/*
 * Decompiled with CFR 0.152.
 */
package ai;

import ai.AttributeClassPair;
import ai.SplitFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Random;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Split function that examines {@code numOfFeatures} randomly chosen non-class attributes and
 * keeps the attribute/threshold pair with the lowest weighted Gini impurity over the given
 * instances. After {@link #init}, {@link #evaluate} returns {@code true} for instances whose
 * value on the chosen attribute is strictly below the threshold.
 */
public class GiniFunction extends SplitFunction {
    private static final long serialVersionUID = 9707184791345L;

    /**
     * Ascending order by attribute value. {@link Double#compare} gives a total order, so NaN
     * (which Weka uses for missing values) sorts last instead of breaking the sort contract.
     */
    private static final Comparator<AttributeClassPair> BY_ATTRIBUTE_VALUE =
            new Comparator<AttributeClassPair>() {
                @Override
                public int compare(AttributeClassPair a, AttributeClassPair b) {
                    return Double.compare(a.attributeValue, b.attributeValue);
                }
            };

    /** Attribute chosen by the last call to {@link #init}. */
    private int index;
    /** Instances with a value strictly below this go to the "left" side ({@code evaluate == true}). */
    private double threshold;
    /** True when no split was learned; {@link #evaluate} then returns true for every instance. */
    private boolean allSame;
    /** How many distinct non-class attributes are sampled per {@link #init} call. */
    private final int numOfFeatures;
    private final Random random;

    public GiniFunction(int numOfFeatures, Random random) {
        this.numOfFeatures = numOfFeatures;
        this.random = random;
    }

    /**
     * Learns the split for the instances of {@code data} selected by {@code indices}.
     *
     * <p>Samples up to {@code numOfFeatures} distinct non-class attributes without replacement
     * and keeps the threshold with the lowest weighted Gini impurity. If {@code indices} is empty,
     * or there is no attribute to try, no split is learned and {@link #evaluate} always returns
     * true.
     *
     * @param data dataset holding the instances; its class attribute is used as the label
     *     (assumed nominal, since class values index a {@code numClasses()}-sized array)
     * @param indices positions in {@code data} of the instances reaching this node
     */
    @Override
    public void init(Instances data, ArrayList<Integer> indices) {
        // Start from the "no split" state so a reused instance never keeps an earlier result.
        this.index = 0;
        this.threshold = 0.0;
        this.allSame = true;
        if (indices.isEmpty()) {
            return;
        }
        int classIndex = data.classIndex();
        ArrayList<Integer> candidates = new ArrayList<Integer>();
        for (int i = 0; i < data.numAttributes(); ++i) {
            if (i != classIndex) {
                candidates.add(i);
            }
        }
        // Never ask for more attributes than exist: random.nextInt(0) would throw.
        int featuresToTry = Math.min(this.numOfFeatures, candidates.size());
        if (featuresToTry <= 0) {
            return;
        }
        this.allSame = false;
        double minimumGini = Double.MAX_VALUE;
        for (int f = 0; f < featuresToTry; ++f) {
            int pick = this.random.nextInt(candidates.size());
            int featureToUse = candidates.remove(pick); // remove(int): sampling without replacement
            minimumGini = this.searchFeature(data, indices, featureToUse, minimumGini);
        }
    }

    /**
     * Scans every realisable threshold on one attribute. If a threshold beats {@code bestGini},
     * it is recorded in {@link #index} / {@link #threshold}.
     *
     * @return the lower of {@code bestGini} and the best gini found on {@code feature}
     */
    private double searchFeature(Instances data, ArrayList<Integer> indices, int feature,
            double bestGini) {
        int numElements = indices.size();
        int classIndex = data.classIndex();
        ArrayList<AttributeClassPair> list = new ArrayList<AttributeClassPair>(numElements);
        for (int j = 0; j < numElements; ++j) {
            Instance ins = data.get(indices.get(j).intValue());
            list.add(new AttributeClassPair(ins.value(feature), (int) ins.value(classIndex)));
        }
        Collections.sort(list, BY_ATTRIBUTE_VALUE);

        // Per-class counts on each side; everything starts on the right.
        double[] countLeft = new double[data.numClasses()];
        double[] countRight = new double[data.numClasses()];
        for (AttributeClassPair pair : list) {
            countRight[pair.classValue] += 1.0;
        }
        for (int splitPoint = 0; splitPoint < numElements; ++splitPoint) {
            AttributeClassPair current = list.get(splitPoint);
            // evaluate() uses "value < threshold". It reproduces the partition
            // list[0, splitPoint) | list[splitPoint, n) only if the previous value is
            // strictly below this one. Ties, and NaN (sorted last), would otherwise be
            // scored for a partition evaluate() can never produce.
            boolean realisable = splitPoint == 0
                    || list.get(splitPoint - 1).attributeValue < current.attributeValue;
            if (realisable) {
                double gini = weightedGini(countLeft, splitPoint, countRight,
                        numElements - splitPoint);
                if (gini < bestGini) {
                    bestGini = gini;
                    this.index = feature;
                    this.threshold = current.attributeValue;
                }
            }
            countLeft[current.classValue] += 1.0;
            countRight[current.classValue] -= 1.0;
        }
        return bestGini;
    }

    /**
     * Computes the size-weighted Gini impurity of a two-way partition from its per-class counts.
     * An empty side contributes zero.
     */
    private static double weightedGini(double[] countLeft, int leftSize, double[] countRight,
            int rightSize) {
        double sumSqLeft = 0.0;
        double sumSqRight = 0.0;
        for (int c = 0; c < countLeft.length; ++c) {
            double pLeft = leftSize != 0 ? countLeft[c] / (double) leftSize : countLeft[c];
            sumSqLeft += pLeft * pLeft;
            double pRight = rightSize != 0 ? countRight[c] / (double) rightSize : countRight[c];
            sumSqRight += pRight * pRight;
        }
        return ((1.0 - sumSqLeft) * (double) leftSize + (1.0 - sumSqRight) * (double) rightSize)
                / (double) (leftSize + rightSize);
    }

    /**
     * Applies the learned split.
     *
     * @return true if no split was learned, otherwise whether the instance's value on the chosen
     *     attribute is strictly below the threshold (NaN values compare false)
     */
    @Override
    public boolean evaluate(Instance instance) {
        if (this.allSame) {
            return true;
        }
        return instance.value(this.index) < this.threshold;
    }

    /** Returns a fresh, untrained split function with the same feature count and shared RNG. */
    @Override
    public SplitFunction newInstance() {
        return new GiniFunction(this.numOfFeatures, this.random);
    }
}

