package weka.classifiers.bayes.net.search.local;

import java.util.Enumeration;
import java.util.Vector;
import org.xmlcml.cml.element.CMLBond;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Statistics;
import weka.core.Tag;
import weka.core.Utils;

/* loaded from: input_file:lib/ches-mapper_lib/weka-3-7-6/weka.jar:weka/classifiers/bayes/net/search/local/LocalScoreSearchAlgorithm.class */
/**
 * Base class for Bayes net structure search algorithms that maximise a score which
 * decomposes over nodes. The score type is one of {@link #TAGS_SCORE_TYPE}
 * (BAYES, BDeu, MDL, ENTROPY, AIC). A single node's score is computed by
 * {@link #calcNodeScore(int)}, and the whole network's score by {@link #logScore(int)}.
 */
public class LocalScoreSearchAlgorithm extends SearchAlgorithm {
    static final long serialVersionUID = 3325995552474190374L;

    /** Score type IDs; these are the IDs used in {@link #TAGS_SCORE_TYPE}. */
    private static final int SCORE_BAYES = 0;
    private static final int SCORE_BDEU = 1;
    private static final int SCORE_MDL = 2;
    private static final int SCORE_ENTROPY = 3;
    private static final int SCORE_AIC = 4;

    /** Command-line names of the score types, indexed by score type ID. */
    private static final String[] SCORE_TYPE_NAMES = {"BAYES", "BDeu", "MDL", "ENTROPY", "AIC"};

    /** The network whose structure is scored; set by the constructor or buildStructure. */
    BayesNet m_BayesNet;

    /** The selectable score types. */
    public static final Tag[] TAGS_SCORE_TYPE = {
        new Tag(SCORE_BAYES, "BAYES"),
        new Tag(SCORE_BDEU, "BDeu"),
        new Tag(SCORE_MDL, "MDL"),
        new Tag(SCORE_ENTROPY, "ENTROPY"),
        new Tag(SCORE_AIC, "AIC")
    };

    /** Pseudo-count added to every cell count by the BAYES score. */
    double m_fAlpha = 0.5d;

    /** The currently selected score type (an ID from TAGS_SCORE_TYPE). */
    int m_nScoreType = SCORE_BAYES;

    public LocalScoreSearchAlgorithm() {
    }

    /**
     * Creates a search algorithm bound to the given network.
     *
     * @param bayesNet the network to score
     * @param instances not used by this constructor
     */
    public LocalScoreSearchAlgorithm(BayesNet bayesNet, Instances instances) {
        this.m_BayesNet = bayesNet;
    }

    /**
     * Computes the log score of the whole network from its current distributions.
     *
     * @param nType score type ID to use; a negative value means "use the selected score type"
     * @return the summed log score, or 0 if the network has no distributions yet
     */
    public double logScore(int nType) {
        if (this.m_BayesNet.m_Distributions == null) {
            return 0.0d;
        }
        if (nType < 0) {
            nType = this.m_nScoreType;
        }
        Instances instances = this.m_BayesNet.m_Instances;
        double score = 0.0d;
        for (int node = 0; node < instances.numAttributes(); node++) {
            int cardinality = this.m_BayesNet.getParentSet(node).getCardinalityOfParents();
            for (int parentConfig = 0; parentConfig < cardinality; parentConfig++) {
                score += ((Scoreable) this.m_BayesNet.m_Distributions[node][parentConfig]).logScore(nType, cardinality);
            }
            score = subtractComplexityPenalty(score, nType, cardinality,
                    instances.attribute(node).numValues(), instances.numInstances());
        }
        return score;
    }

    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm
    public void buildStructure(BayesNet bayesNet, Instances instances) throws Exception {
        this.m_BayesNet = bayesNet;
        super.buildStructure(bayesNet, instances);
    }

    /**
     * Computes the local score of one node given its current parent set.
     * The ADTree is used for counting when the network enables it and has one;
     * otherwise the instances are scanned directly.
     *
     * @param nNode attribute index of the node
     * @return the node's local score under the selected score type
     */
    public double calcNodeScore(int nNode) {
        if (this.m_BayesNet.getUseADTree() && this.m_BayesNet.getADTree() != null) {
            return calcNodeScoreADTree(nNode);
        }
        return calcNodeScorePlain(nNode);
    }

    /**
     * Builds the node's count table with the network's ADTree and scores it.
     * The count layout is the same as in calcNodeScorePlain:
     * index = numValues * parentConfig + value.
     */
    private double calcNodeScoreADTree(int nNode) {
        Instances instances = this.m_BayesNet.m_Instances;
        ParentSet parentSet = this.m_BayesNet.getParentSet(nNode);
        int nrOfParents = parentSet.getNrOfParents();

        // nodes = parents followed by the node itself. offsets = the stride of each node's
        // value in the flattened count table: the node has stride 1, and the first parent
        // is the most significant digit.
        int[] nodes = new int[nrOfParents + 1];
        int[] offsets = new int[nrOfParents + 1];
        for (int p = 0; p < nrOfParents; p++) {
            nodes[p] = parentSet.getParent(p);
        }
        nodes[nrOfParents] = nNode;
        offsets[nrOfParents] = 1;
        int stride = instances.attribute(nNode).numValues();
        for (int p = nrOfParents - 1; p >= 0; p--) {
            offsets[p] = stride;
            stride *= instances.attribute(nodes[p]).numValues();
        }

        // Insertion-sort nodes by ascending attribute index, keeping offsets paired with them
        // (presumably the order ADTree.getCounts expects — NOTE(review): confirm in ADTree).
        for (int k = 1; k < nodes.length; k++) {
            for (int j = k; j > 0 && nodes[j] < nodes[j - 1]; j--) {
                int tmpNode = nodes[j];
                nodes[j] = nodes[j - 1];
                nodes[j - 1] = tmpNode;
                int tmpOffset = offsets[j];
                offsets[j] = offsets[j - 1];
                offsets[j - 1] = tmpOffset;
            }
        }

        int cardinality = parentSet.getCardinalityOfParents();
        int numValues = instances.attribute(nNode).numValues();
        int[] counts = new int[cardinality * numValues];
        this.m_BayesNet.getADTree().getCounts(counts, nodes, offsets, 0, 0, false);
        return calcScoreOfCounts(counts, cardinality, numValues, instances);
    }

    /**
     * Builds the node's count table by scanning all instances, then scores it.
     * Cell index = numValues * parentConfig + value, where parentConfig encodes the parents'
     * values in mixed radix with the first parent as the most significant digit.
     */
    private double calcNodeScorePlain(int nNode) {
        Instances instances = this.m_BayesNet.m_Instances;
        ParentSet parentSet = this.m_BayesNet.getParentSet(nNode);
        int cardinality = parentSet.getCardinalityOfParents();
        int numValues = instances.attribute(nNode).numValues();
        int[] counts = new int[cardinality * numValues]; // Java zero-initialises arrays

        for (int k = 0; k < instances.numInstances(); k++) {
            Instance instance = instances.instance(k);
            double parentConfig = 0.0d;
            for (int p = 0; p < parentSet.getNrOfParents(); p++) {
                int parent = parentSet.getParent(p);
                parentConfig = parentConfig * instances.attribute(parent).numValues() + instance.value(parent);
            }
            counts[numValues * ((int) parentConfig) + ((int) instance.value(nNode))]++;
        }
        return calcScoreOfCounts(counts, cardinality, numValues, instances);
    }

    /**
     * Scores a flattened count table under the selected score type.
     *
     * @param counts counts laid out as counts[parentConfig * numValues + value]
     * @param cardinality number of parent configurations
     * @param numValues number of values of the node
     * @param instances data set; only its size is used (by the MDL penalty)
     * @return the local score
     */
    protected double calcScoreOfCounts(int[] counts, int cardinality, int numValues, Instances instances) {
        double score = 0.0d;
        for (int parentConfig = 0; parentConfig < cardinality; parentConfig++) {
            score = addLocalScore(score, counts, parentConfig * numValues, numValues, cardinality);
        }
        return subtractComplexityPenalty(score, this.m_nScoreType, cardinality, numValues, instances.numInstances());
    }

    /**
     * Same as {@link #calcScoreOfCounts(int[], int, int, Instances)} but for a count table
     * stored as counts[parentConfig][value].
     */
    protected double calcScoreOfCounts2(int[][] counts, int cardinality, int numValues, Instances instances) {
        double score = 0.0d;
        for (int parentConfig = 0; parentConfig < cardinality; parentConfig++) {
            score = addLocalScore(score, counts[parentConfig], 0, numValues, cardinality);
        }
        return subtractComplexityPenalty(score, this.m_nScoreType, cardinality, numValues, instances.numInstances());
    }

    /**
     * Adds the contribution of one parent configuration (the numValues counts starting at
     * offset) to the running score, using the selected score type. The running total is
     * passed in and returned so that accumulation order is the same for both count layouts.
     * Unknown score types leave the score unchanged.
     */
    private double addLocalScore(double score, int[] counts, int offset, int numValues, int cardinality) {
        switch (this.m_nScoreType) {
            case SCORE_BAYES: {
                double sum = 0.0d;
                for (int v = 0; v < numValues; v++) {
                    double alphaPlusCount = this.m_fAlpha + counts[offset + v];
                    // Skip zero arguments: lnGamma is undefined at 0.
                    if (alphaPlusCount != 0.0d) {
                        score += Statistics.lnGamma(alphaPlusCount);
                        sum += alphaPlusCount;
                    }
                }
                if (sum != 0.0d) {
                    score -= Statistics.lnGamma(sum);
                }
                if (this.m_fAlpha != 0.0d) {
                    score = (score - (numValues * Statistics.lnGamma(this.m_fAlpha)))
                            + Statistics.lnGamma(numValues * this.m_fAlpha);
                }
                return score;
            }
            case SCORE_BDEU: {
                // The BDeu prior per cell is 1/(numValues * cardinality). It is always > 0, so
                // lnGamma is always defined and no m_fAlpha-based guard applies. The product is
                // taken in double to avoid int overflow.
                double prior = 1.0d / ((double) numValues * cardinality);
                double sum = 0.0d;
                for (int v = 0; v < numValues; v++) {
                    int count = counts[offset + v];
                    score += Statistics.lnGamma(prior + count);
                    sum += prior + count;
                }
                return ((score - Statistics.lnGamma(sum)) - (numValues * Statistics.lnGamma(prior)))
                        + Statistics.lnGamma(1.0d / cardinality);
            }
            case SCORE_MDL:
            case SCORE_ENTROPY:
            case SCORE_AIC: {
                // Log-likelihood term sum_v n_v * log(n_v / N); the MDL/AIC penalty is applied separately.
                double total = 0.0d;
                for (int v = 0; v < numValues; v++) {
                    total += counts[offset + v];
                }
                for (int v = 0; v < numValues; v++) {
                    int count = counts[offset + v];
                    if (count > 0) {
                        score += count * Math.log(count / total);
                    }
                }
                return score;
            }
            default:
                return score;
        }
    }

    /**
     * Subtracts the model-complexity penalty for one node:
     * MDL: 0.5 * cardinality * (numValues - 1) * log(numInstances);
     * AIC: cardinality * (numValues - 1).
     * Other score types get no penalty.
     */
    private static double subtractComplexityPenalty(double score, int scoreType, int cardinality, int numValues, int numInstances) {
        switch (scoreType) {
            case SCORE_MDL:
                return score - ((0.5d * cardinality) * (numValues - 1)) * Math.log(numInstances);
            case SCORE_AIC:
                return score - cardinality * (numValues - 1);
            default:
                return score;
        }
    }

    /**
     * Scores node nNode as if nCandidateParent were added as an extra parent.
     * The parent set is restored afterwards, even if scoring throws.
     *
     * @return the hypothetical node score, or -1.0E100 if nCandidateParent already is a parent
     */
    public double calcScoreWithExtraParent(int nNode, int nCandidateParent) {
        ParentSet parentSet = this.m_BayesNet.getParentSet(nNode);
        if (parentSet.contains(nCandidateParent)) {
            return -1.0E100d;
        }
        parentSet.addParent(nCandidateParent, this.m_BayesNet.m_Instances);
        try {
            return calcNodeScore(nNode);
        } finally {
            parentSet.deleteLastParent(this.m_BayesNet.m_Instances);
        }
    }

    /**
     * Scores node nNode as if nCandidateParent were removed from its parents.
     * The parent is re-inserted at its original position afterwards, even if scoring throws.
     *
     * @return the hypothetical node score, or -1.0E100 if nCandidateParent is not a parent
     */
    public double calcScoreWithMissingParent(int nNode, int nCandidateParent) {
        ParentSet parentSet = this.m_BayesNet.getParentSet(nNode);
        if (!parentSet.contains(nCandidateParent)) {
            return -1.0E100d;
        }
        int position = parentSet.deleteParent(nCandidateParent, this.m_BayesNet.m_Instances);
        try {
            return calcNodeScore(nNode);
        } finally {
            parentSet.addParent(nCandidateParent, position, this.m_BayesNet.m_Instances);
        }
    }

    /**
     * Selects the score type. The call is ignored unless the tag's tag array is
     * TAGS_SCORE_TYPE (compared by identity).
     */
    public void setScoreType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_SCORE_TYPE) {
            this.m_nScoreType = selectedTag.getSelectedTag().getID();
        }
    }

    /** Returns the currently selected score type. */
    public SelectedTag getScoreType() {
        return new SelectedTag(this.m_nScoreType, TAGS_SCORE_TYPE);
    }

    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm
    public void setMarkovBlanketClassifier(boolean z) {
        super.setMarkovBlanketClassifier(z);
    }

    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm
    public boolean getMarkovBlanketClassifier() {
        return super.getMarkovBlanketClassifier();
    }

    /** Lists the -mbc and -S options. The -S synopsis names exactly the values setOptions accepts. */
    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm, weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tApplies a Markov Blanket correction to the network structure, \n\tafter a network structure is learned. This ensures that all \n\tnodes in the network are part of the Markov blanket of the \n\tclassifier node.", "mbc", 0, "-mbc"));
        options.addElement(new Option("\tScore type (BAYES, BDeu, MDL, ENTROPY and AIC)", "S", 1, "-S [BAYES|BDeu|MDL|ENTROPY|AIC]"));
        return options.elements();
    }

    /**
     * Parses -mbc and -S &lt;name&gt;. The name match is exact and case-sensitive.
     * An absent or unrecognised -S value leaves the score type unchanged.
     */
    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        setMarkovBlanketClassifier(Utils.getFlag("mbc", strArr));
        String scoreName = Utils.getOption('S', strArr);
        for (int id = 0; id < SCORE_TYPE_NAMES.length; id++) {
            if (SCORE_TYPE_NAMES[id].equals(scoreName)) {
                setScoreType(new SelectedTag(id, TAGS_SCORE_TYPE));
            }
        }
    }

    /**
     * Returns the current options: [-mbc] -S &lt;name&gt;, followed by the superclass options.
     * Unused trailing slots are filled with empty strings, never nulls.
     */
    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm, weka.core.OptionHandler
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        // At most three own entries: "-mbc", "-S" and the score name.
        String[] options = new String[3 + superOptions.length];
        int current = 0;
        if (getMarkovBlanketClassifier()) {
            options[current++] = "-mbc";
        }
        options[current++] = "-S";
        if (this.m_nScoreType >= 0 && this.m_nScoreType < SCORE_TYPE_NAMES.length) {
            options[current++] = SCORE_TYPE_NAMES[this.m_nScoreType];
        }
        for (String option : superOptions) {
            options[current++] = option;
        }
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    public String scoreTypeTipText() {
        return "The score type determines the measure used to judge the quality of a network structure. It can be one of Bayes, BDeu, Minimum Description Length (MDL), Akaike Information Criterion (AIC), and Entropy.";
    }

    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm
    public String markovBlanketClassifierTipText() {
        return super.markovBlanketClassifierTipText();
    }

    public String globalInfo() {
        return "The ScoreBasedSearchAlgorithm class supports Bayes net structure search algorithms that are based on maximizing scores (as opposed to for example conditional independence based search algorithms).";
    }

    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }
}
