package quickdt.bagging;

import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import quickdt.AbstractInstance;
import quickdt.Attributes;
import quickdt.Misc;
import quickdt.Tree;
import quickdt.TreeBuilder;

/* loaded from: input_file:lib/palladian.jar:quickdt/bagging/BaggedTree.class */
/**
 * A bagging ensemble of decision trees.
 *
 * <p>Each tree is trained by the supplied {@link TreeBuilder} on its own bootstrap sample of the
 * training data: a sample of the same size as the data, drawn uniformly with replacement.
 * Predictions combine the leaf outputs of all trees.
 */
public class BaggedTree implements Serializable {
    private static final long serialVersionUID = 8996197519632788949L;

    /** Number of trees built by {@link #build(TreeBuilder, Iterable)}. */
    public static final int DEFAULT_TREE_COUNT = 10;

    private final List<Tree> trees;

    private BaggedTree(List<Tree> trees) {
        this.trees = trees;
    }

    /**
     * Builds a bagged ensemble of {@code numTrees} trees.
     *
     * @param treeBuilder builder used to train each individual tree
     * @param numTrees number of trees (and bootstrap samples) to create; must be positive
     * @param trainingData training instances; iterated exactly once
     * @return the trained ensemble
     * @throws NullPointerException if {@code treeBuilder} or {@code trainingData} is null
     * @throws IllegalArgumentException if {@code numTrees <= 0}
     */
    public static BaggedTree build(TreeBuilder treeBuilder, int numTrees,
            Iterable<? extends AbstractInstance> trainingData) {
        Preconditions.checkNotNull(treeBuilder);
        Preconditions.checkArgument(numTrees > 0, "numTrees must be greater than zero");
        Preconditions.checkNotNull(trainingData);
        // Materialize the training data once. Every bootstrap sample draws from this list, which
        // avoids re-copying the data per tree and makes single-pass iterables safe to use.
        List<AbstractInstance> instances = Lists.<AbstractInstance>newArrayList(trainingData);
        List<Tree> builtTrees = new ArrayList<Tree>(numTrees);
        for (int t = 0; t < numTrees; t++) {
            Iterable<? extends AbstractInstance> sample = getBootstrapSampling(instances);
            builtTrees.add(treeBuilder.buildPredictiveModel(sample));
        }
        return new BaggedTree(builtTrees);
    }

    /**
     * Builds a bagged ensemble of {@link #DEFAULT_TREE_COUNT} trees.
     *
     * @param treeBuilder builder used to train each individual tree
     * @param trainingData training instances
     * @return the trained ensemble
     */
    public static BaggedTree build(TreeBuilder treeBuilder, Iterable<? extends AbstractInstance> trainingData) {
        return build(treeBuilder, DEFAULT_TREE_COUNT, trainingData);
    }

    /**
     * Collects one vote per tree.
     *
     * <p>Each vote is the best classification of the leaf that {@code attributes} reaches in that
     * tree.
     *
     * @param attributes the instance attributes to classify
     * @return a {@link BaggingResult} wrapping the multiset of per-tree classifications
     * @throws NullPointerException if {@code attributes} is null
     */
    public BaggingResult predict(Attributes attributes) {
        Preconditions.checkNotNull(attributes);
        HashMultiset<Serializable> votes = HashMultiset.create();
        for (Tree tree : trees) {
            votes.add(tree.node.getLeaf(attributes).getBestClassification());
        }
        return new BaggingResult(votes);
    }

    /**
     * Averages, over all trees, the probability of {@code classification}.
     *
     * <p>For each tree, the probability is read from the leaf that {@code attributes} reaches.
     *
     * @param attributes the instance attributes to evaluate
     * @param classification the class whose probability is requested
     * @return the mean per-tree leaf probability
     * @throws NullPointerException if {@code attributes} is null
     */
    public double getProbability(Attributes attributes, Serializable classification) {
        Preconditions.checkNotNull(attributes);
        double sum = 0.0d;
        for (Tree tree : trees) {
            sum += tree.node.getLeaf(attributes).getProbability(classification);
        }
        // trees is never empty: build() rejects numTrees <= 0 and the constructor is private.
        return sum / trees.size();
    }

    /**
     * Draws a bootstrap sample from {@code instances}.
     *
     * <p>The sample has the same size as the input and is drawn uniformly with replacement using
     * {@code Misc.random}.
     */
    private static List<AbstractInstance> getBootstrapSampling(List<? extends AbstractInstance> instances) {
        int size = instances.size();
        List<AbstractInstance> sample = new ArrayList<AbstractInstance>(size);
        for (int n = 0; n < size; n++) {
            sample.add(instances.get(Misc.random.nextInt(size)));
        }
        return sample;
    }

    /**
     * Returns the trees of this ensemble.
     *
     * <p>This is the backing list, not a copy: modifying it changes this model.
     *
     * @return the ensemble's trees
     */
    public List<Tree> getTrees() {
        return this.trees;
    }
}
