package quickdt.randomForest;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import quickdt.AbstractInstance;
import quickdt.Branch;
import quickdt.Misc;
import quickdt.PredictiveModelBuilder;
import quickdt.Tree;
import quickdt.TreeBuilder;

/* loaded from: input_file:lib/palladian.jar:quickdt/randomForest/RandomForestBuilder.class */
/**
 * Builds a {@link RandomForest} by repeatedly running a {@link TreeBuilder} over the training
 * data, varying the excluded attributes (and optionally the sampled instances) per tree.
 *
 * <p>Configuration is fluent: every setter returns this builder. Defaults are 8 trees, no
 * bagging, and {@code attributesPerTree == 0}.
 *
 * <p>NOTE(review): the supplied {@link TreeBuilder} is reconfigured through
 * {@code excludeAttributes} on every iteration and keeps a reference to the last exclusion set
 * after the build returns, so it should presumably not be shared with other code.
 */
public class RandomForestBuilder implements PredictiveModelBuilder<RandomForest> {
    private final TreeBuilder treeBuilder;
    private int numTrees = 8;
    private boolean useBagging = false;
    private int attributesPerTree = 0;

    /** Creates a forest builder backed by a default-constructed {@link TreeBuilder}. */
    public RandomForestBuilder() {
        this(new TreeBuilder());
    }

    /**
     * Creates a forest builder that delegates the construction of each tree to the given builder.
     *
     * @param treeBuilder builder used for every tree of the forest
     */
    public RandomForestBuilder(TreeBuilder treeBuilder) {
        this.treeBuilder = treeBuilder;
    }

    /**
     * Sets how many trees the forest will contain.
     *
     * @param i number of trees to build
     * @return this builder
     */
    public RandomForestBuilder numTrees(int i) {
        this.numTrees = i;
        return this;
    }

    /**
     * Chooses whether each tree is trained on a bootstrap sample (drawn with replacement, same
     * size as the training data) instead of on the full training data.
     *
     * @param z {@code true} to train each tree on a bootstrap sample
     * @return this builder
     */
    public RandomForestBuilder useBagging(boolean z) {
        this.useBagging = z;
        return this;
    }

    /**
     * Sets the number of attributes left available to each tree.
     *
     * <p>When positive, a fresh random set of attributes is excluded for every tree so that this
     * many remain. When {@code 0}, no random exclusion takes place; instead the root split
     * attribute of each built tree is excluded from all subsequent trees.
     *
     * @param i attributes to keep per tree, or {@code 0} for the root-exclusion strategy
     * @return this builder
     */
    public RandomForestBuilder attributesPerTree(int i) {
        this.attributesPerTree = i;
        return this;
    }

    /**
     * Builds the forest from the given training data.
     *
     * <p>The attribute names are taken from the first training instance only. The method fails
     * (via {@code Iterables.get}) when the training data is empty.
     *
     * @param iterable training instances; iterated at least once per tree
     * @return a forest containing {@code numTrees} trees
     */
    @Override
    public RandomForest buildPredictiveModel(Iterable<? extends AbstractInstance> iterable) {
        final List<Tree> trees = Lists.newArrayListWithCapacity(this.numTrees);
        final AbstractInstance firstInstance = Iterables.get(iterable, 0);
        final Object[] attributeNames = firstInstance.getAttributes().keySet().toArray();

        // One set instance is reused (and handed to the tree builder) for the whole build; in the
        // attributesPerTree == 0 mode it deliberately accumulates across trees.
        final Set<String> excludedAttributes = Sets.newHashSet();

        for (int treeIndex = 0; treeIndex < this.numTrees; treeIndex++) {
            if (this.attributesPerTree > 0) {
                excludedAttributes.clear();
                // Draw names until only attributesPerTree remain; duplicates are absorbed by the
                // set. If attributesPerTree >= attribute count, nothing is excluded.
                final int exclusionTarget = attributeNames.length - this.attributesPerTree;
                while (excludedAttributes.size() < exclusionTarget) {
                    final int pick = Misc.random.nextInt(attributeNames.length);
                    excludedAttributes.add((String) attributeNames[pick]);
                }
            }
            this.treeBuilder.excludeAttributes(excludedAttributes);

            final Iterable<? extends AbstractInstance> trainingData;
            if (this.useBagging) {
                trainingData = getBootstrapSampling(iterable);
            } else {
                trainingData = iterable;
            }
            final Tree tree = this.treeBuilder.buildPredictiveModel(trainingData);

            // Without per-tree attribute sampling, bar this tree's root attribute from later trees.
            if (this.attributesPerTree == 0 && (tree.node instanceof Branch)) {
                excludedAttributes.add(((Branch) tree.node).attribute);
            }
            trees.add(tree);
        }
        return new RandomForest(trees);
    }

    /**
     * Draws a bootstrap sample: as many instances as the input holds, each picked uniformly at
     * random with replacement using {@code Misc.random}.
     *
     * @param iterable instances to sample from
     * @return a new list of the same size as the input; empty if the input is empty
     */
    private static List<AbstractInstance> getBootstrapSampling(Iterable<? extends AbstractInstance> iterable) {
        final List<AbstractInstance> population = Lists.<AbstractInstance>newArrayList(iterable);
        final int size = population.size();
        final List<AbstractInstance> sample = Lists.newArrayListWithCapacity(size);
        for (int drawn = 0; drawn < size; drawn++) {
            sample.add(population.get(Misc.random.nextInt(size)));
        }
        return sample;
    }
}
