/*
 * Decompiled with CFR 0.152.
 */
package ai;

import ai.BalancedRandomTree;
import ai.GiniFunction;
import ai.Splitter;
import ai.VotesCollector;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Randomizable;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/**
 * Weka classifier that grows a forest of {@link BalancedRandomTree}s on a thread pool.
 *
 * <p>Each tree is trained on its own bag of instance indices drawn with replacement in a
 * class-balanced way. Every draw first picks a class uniformly at random, among the classes that
 * have at least one training instance with a known class value. It then picks one instance of
 * that class uniformly at random. The forest predicts by averaging the per-tree class
 * distributions, and an out-of-bag error estimate is computed at the end of training.
 */
public class BalancedRandomForest
extends AbstractClassifier
implements Randomizable {
    private static final long serialVersionUID = "BalancedRandomForest".hashCode();

    /** Seed of the generator that draws the bags and the per-tree random generators. */
    private int seed = 1;

    /** Number of trees grown by {@link #buildClassifier(Instances)}. */
    private int numTrees = 10;

    /**
     * Number of attributes considered at each node. During training, values below 1 are replaced
     * by {@code log2(numAttributes) + 1}. The result is capped at {@code numAttributes - 1}, and
     * the value actually used is written back to this field.
     */
    private int numFeatures = 0;

    /** The trained trees, or {@code null} while the forest is not built. */
    private BalancedRandomTree[] tree = null;

    /** Out-of-bag error measured by the last successful build. */
    private double outOfBagError = 0.0;

    /**
     * Returns a short description of the classifier for the Weka GUI.
     *
     * @return description followed by the bibliographic reference
     */
    public String globalInfo() {
        return "Class for constructing a balanced forest of random trees.\n\nFor more information see: \n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the reference (Breiman 2001, "Random Forests") this classifier is based on.
     *
     * @return the technical information describing the reference
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Leo Breiman");
        result.setValue(TechnicalInformation.Field.YEAR, "2001");
        result.setValue(TechnicalInformation.Field.TITLE, "Random Forests");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "45");
        result.setValue(TechnicalInformation.Field.NUMBER, "1");
        result.setValue(TechnicalInformation.Field.PAGES, "5-32");
        return result;
    }

    /** @return GUI tip text for the number-of-trees option */
    public String numTreesTipText() {
        return "The number of trees to be generated.";
    }

    /** @return GUI tip text for the number-of-features option */
    public String numFeaturesTipText() {
        return "The number of attributes to be used in random selection of each node.";
    }

    /** @return GUI tip text for the seed option */
    public String seedTipText() {
        return "The random number seed to be used.";
    }

    /**
     * Builds the forest from the given training data.
     *
     * <p>If {@code numFeatures} is below 1 it is replaced by {@code log2(numAttributes) + 1}. It
     * is then capped at {@code numAttributes - 1}, and the adjusted value is stored back in the
     * field. Each bag contains as many draws as there are instances with a known class value.
     * Bags and per-tree seeds are drawn sequentially from one generator seeded with
     * {@code seed}; only tree construction and vote collection run on the thread pool. Instances
     * with a missing class value are never drawn and are ignored by the out-of-bag error.
     *
     * @param data training data with its class attribute set (expected to be nominal)
     * @throws IllegalArgumentException if no instance has a class value
     * @throws Exception if building a tree or collecting out-of-bag votes fails, or if the
     *     calling thread is interrupted; the failure is propagated and the forest stays unbuilt
     */
    public void buildClassifier(final Instances data) throws Exception {
        if (this.numFeatures < 1) {
            this.numFeatures = (int) Utils.log2((double) data.numAttributes()) + 1;
        }
        if (this.numFeatures >= data.numAttributes()) {
            this.numFeatures = data.numAttributes() - 1;
        }
        // Drop any previous model so a failed build cannot leave a stale or half-filled forest.
        this.tree = null;

        final int numInstances = data.numInstances();
        // Only classes that actually occur may be sampled: drawing from an empty class would
        // call Random.nextInt(0), which throws.
        final List<List<Integer>> classPools = collectClassPools(data);
        int numLabeled = 0;
        for (List<Integer> pool : classPools) {
            numLabeled += pool.size();
        }
        if (numLabeled == 0) {
            throw new IllegalArgumentException(
                    "Cannot build a balanced random forest: no training instance has a class value");
        }

        final Random random = new Random(this.seed);
        final boolean[][] inBag = new boolean[this.numTrees][numInstances];
        final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            List<Future<BalancedRandomTree>> futures = new ArrayList<Future<BalancedRandomTree>>(this.numTrees);
            for (int t = 0; t < this.numTrees; ++t) {
                final ArrayList<Integer> bagIndices = new ArrayList<Integer>(numLabeled);
                for (int draw = 0; draw < numLabeled; ++draw) {
                    List<Integer> pool = classPools.get(random.nextInt(classPools.size()));
                    Integer index = pool.get(random.nextInt(pool.size()));
                    bagIndices.add(index);
                    inBag[t][index.intValue()] = true;
                }
                final Splitter splitter = new Splitter(new GiniFunction(this.numFeatures,
                        data.getRandomNumberGenerator((long) random.nextInt())));
                futures.add(exe.submit(new Callable<BalancedRandomTree>() {

                    @Override
                    public BalancedRandomTree call() {
                        return new BalancedRandomTree(data, bagIndices, splitter);
                    }
                }));
            }

            final BalancedRandomTree[] trees = new BalancedRandomTree[this.numTrees];
            for (int t = 0; t < trees.length; ++t) {
                trees[t] = awaitResult(futures.get(t));
            }
            this.outOfBagError = computeOutOfBagError(exe, trees, data, inBag);
            // Publish the model only once everything succeeded.
            this.tree = trees;
        }
        finally {
            exe.shutdownNow();
        }
    }

    /**
     * Groups the indices of instances with a known class value by class index.
     *
     * @param data the training data
     * @return the non-empty groups, in increasing class-index order
     */
    private static List<List<Integer>> collectClassPools(Instances data) {
        int numClasses = data.numClasses();
        List<List<Integer>> byClass = new ArrayList<List<Integer>>(numClasses);
        for (int c = 0; c < numClasses; ++c) {
            byClass.add(new ArrayList<Integer>());
        }
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance instance = data.instance(i);
            if (!instance.classIsMissing()) {
                byClass.get((int) instance.classValue()).add(i);
            }
        }
        List<List<Integer>> nonEmpty = new ArrayList<List<Integer>>(numClasses);
        for (List<Integer> pool : byClass) {
            if (!pool.isEmpty()) {
                nonEmpty.add(pool);
            }
        }
        return nonEmpty;
    }

    /**
     * Computes the weighted out-of-bag error of {@code trees} on {@code data}.
     *
     * <p>One {@link VotesCollector} task is submitted per instance with a known class value. It
     * receives {@code inBag}, presumably so it only consults trees for which the instance was out
     * of bag; confirm this in VotesCollector. For a numeric class the result is the weighted mean
     * absolute difference between vote and class value. Otherwise it is the weighted fraction of
     * instances whose vote differs from their class value.
     *
     * @throws Exception if a vote task fails or the calling thread is interrupted
     */
    private static double computeOutOfBagError(ExecutorService exe, BalancedRandomTree[] trees,
            Instances data, boolean[][] inBag) throws Exception {
        final int numInstances = data.numInstances();
        List<Future<Double>> votes = new ArrayList<Future<Double>>(numInstances);
        for (int i = 0; i < numInstances; ++i) {
            if (data.instance(i).classIsMissing()) {
                // Nothing to compare against; keep the slot so indices stay aligned.
                votes.add(null);
            } else {
                votes.add(exe.submit(new VotesCollector(trees, i, data, inBag)));
            }
        }

        boolean numeric = data.classAttribute().isNumeric();
        double weightSum = 0.0;
        double errorSum = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            Future<Double> pending = votes.get(i);
            if (pending == null) {
                continue;
            }
            Instance instance = data.instance(i);
            double vote = awaitResult(pending);
            double weight = instance.weight();
            weightSum += weight;
            if (numeric) {
                errorSum += StrictMath.abs(vote - instance.classValue()) * weight;
            } else if (vote != instance.classValue()) {
                errorSum += weight;
            }
        }
        return errorSum / weightSum;
    }

    /**
     * Waits for {@code future} and returns its value.
     *
     * <p>If the task failed, this rethrows the task's own exception instead of the
     * {@link ExecutionException} wrapper, so callers see the real cause.
     *
     * @throws Exception the task's exception, or {@link InterruptedException} if the wait is
     *     interrupted
     */
    private static <T> T awaitResult(Future<T> future) throws Exception {
        try {
            return future.get();
        }
        catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
    }

    /**
     * Returns the class distribution for {@code instance}, computed as the average of the
     * distributions returned by the individual trees.
     *
     * <p>The divisor is the number of trees actually built, not the current {@code numTrees}
     * setting, so calling {@link #setNumTrees(int)} after training does not break prediction.
     *
     * @param instance the instance to classify
     * @return per-class averaged scores, of length {@code instance.numClasses()}
     * @throws IllegalStateException if the forest has not been built
     */
    public double[] distributionForInstance(Instance instance) {
        BalancedRandomTree[] forest = this.tree;
        if (forest == null) {
            throw new IllegalStateException("Balanced random forest not built yet");
        }
        double[] sums = new double[instance.numClasses()];
        for (BalancedRandomTree member : forest) {
            double[] probs = member.evaluate(instance);
            for (int c = 0; c < probs.length; ++c) {
                sums[c] += probs[c];
            }
        }
        for (int c = 0; c < sums.length; ++c) {
            sums[c] /= (double) forest.length;
        }
        return sums;
    }

    /**
     * Returns the current options: {@code -I numTrees -K numFeatures -S seed}, followed by the
     * superclass options.
     *
     * @return the option array
     */
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-I");
        result.add(String.valueOf(this.getNumTrees()));
        result.add("-K");
        result.add(String.valueOf(this.getNumFeatures()));
        result.add("-S");
        result.add(String.valueOf(this.getSeed()));
        for (String option : super.getOptions()) {
            result.add(option);
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Parses the options {@code -I}, {@code -K} and {@code -S}, then the superclass options.
     *
     * <p>Missing options fall back to: 100 trees, 0 features (automatic), and seed 1.
     *
     * @param options the option array; recognised entries are consumed
     * @throws Exception if an option cannot be parsed or unknown options remain
     */
    public void setOptions(String[] options) throws Exception {
        String tmpStr = Utils.getOption('I', options);
        // NOTE(review): this default (100) differs from the field initializer (10); kept as-is.
        this.numTrees = tmpStr.length() != 0 ? Integer.parseInt(tmpStr) : 100;
        tmpStr = Utils.getOption('K', options);
        this.numFeatures = tmpStr.length() != 0 ? Integer.parseInt(tmpStr) : 0;
        tmpStr = Utils.getOption('S', options);
        this.setSeed(tmpStr.length() != 0 ? Integer.parseInt(tmpStr) : 1);
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /** @return the number of features considered at each node (0 or less means automatic) */
    public int getNumFeatures() {
        return this.numFeatures;
    }

    /** @return the number of trees to grow */
    public int getNumTrees() {
        return this.numTrees;
    }

    /**
     * Returns the superclass capabilities with numeric attributes and a nominal class explicitly
     * enabled.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        return result;
    }

    /** @return the random seed */
    public int getSeed() {
        return this.seed;
    }

    /** @param seed the random seed used by the next build */
    public void setSeed(int seed) {
        this.seed = seed;
    }

    /** @param numTrees the number of trees grown by the next build */
    public void setNumTrees(int numTrees) {
        this.numTrees = numTrees;
    }

    /** @param numFeatures features per node for the next build (0 or less means automatic) */
    public void setNumFeatures(int numFeatures) {
        this.numFeatures = numFeatures;
    }

    /** @return the out-of-bag error measured by the last successful build */
    public double measureOutOfBagError() {
        return this.outOfBagError;
    }

    /** @return the names of the additional measures supported by {@link #getMeasure(String)} */
    public Enumeration<String> enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(1);
        newVector.addElement("measureOutOfBagError");
        return newVector.elements();
    }

    /**
     * Returns the value of a named additional measure.
     *
     * @param additionalMeasureName the measure name (case-insensitive)
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     */
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
            return this.measureOutOfBagError();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (BalancedRandomForest)");
    }

    /**
     * Describes the built model, or reports that it has not been built.
     *
     * @return a human-readable summary including the out-of-bag error
     */
    public String toString() {
        BalancedRandomTree[] forest = this.tree;
        if (forest == null) {
            return "Balanced random forest not built yet";
        }
        return "Balanced random forest of " + forest.length + " trees, each constructed while considering " + this.numFeatures + " random feature" + (this.numFeatures == 1 ? "" : "s") + "\nOut of bag error: " + Utils.doubleToString(this.measureOutOfBagError(), 4) + ".\n";
    }

    /**
     * Command-line entry point that delegates to Weka's standard classifier runner.
     *
     * @param argv Weka command-line options
     */
    public static void main(String[] argv) {
        runClassifier(new BalancedRandomForest(), argv);
    }
}

