package quickdt;

import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.uprizer.sensearray.freetools.stats.ReservoirSampler;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import org.javatuples.Pair;
import org.jdesktop.swingx.JXLabel;
import quickdt.scorers.Scorer1;

/* loaded from: input_file:lib/palladian.jar:quickdt/TreeBuilder.class */
/**
 * Builds a binary decision {@link Tree} from a set of training instances.
 *
 * <p>At every node each attribute of the first training instance is considered as a
 * split candidate. Numeric attributes are tried as ordinal (threshold) splits first and
 * fall back to nominal (value-set) splits when the ordinal split scores zero; all other
 * attributes are tried as nominal splits. The candidate with the highest score, as
 * reported by the configured {@link Scorer}, becomes the branch for that node.
 *
 * <p>The configuration methods return {@code this} so that calls can be chained.
 */
public final class TreeBuilder implements PredictiveModelBuilder<Tree> {
    /**
     * Number of quantile buckets used for ordinal attributes; this yields
     * {@code ORDINAL_TEST_SPLITS - 1} candidate thresholds per attribute.
     */
    public static final int ORDINAL_TEST_SPLITS = 5;

    /** Capacity passed to each {@link ReservoirSampler} used to estimate thresholds. */
    private static final int RESERVOIR_SIZE = 1000;

    /** Training sets with at most this many instances never get ordinal splits. */
    private static final int SMALL_TRAINING_SET_LIMIT = 10;

    /** Depth at which recursion stops unconditionally and a leaf is returned. */
    private int maxDepth;

    /** A node becomes a leaf once its best classification probability reaches this value. */
    private double minProbability;

    /** Attributes in {@link #excludeAttributes} are skipped at depths up to and including this one. */
    private int attributeExcludeDepth;

    /** When greater than zero, each attribute is skipped at a node with this probability. */
    private double ignoreAttributeAtNodeProbability;

    /**
     * When greater than zero, a nominal value is ignored if its least frequent
     * classification count is below this number.
     */
    private int minNominalAttributeValueOccurances;

    /** Attribute names that must not be used for splitting near the root. */
    private Set<String> excludeAttributes;

    /** Scores candidate splits; higher is better and zero means "no useful split". */
    Scorer scorer;

    /**
     * Sets the depth at which tree growth stops.
     *
     * @param i the maximum depth
     * @return this builder
     */
    public TreeBuilder maxDepth(int i) {
        this.maxDepth = i;
        return this;
    }

    /**
     * Sets the classification probability at which a node is turned into a leaf.
     *
     * @param d the probability threshold
     * @return this builder
     */
    public TreeBuilder minProbability(double d) {
        this.minProbability = d;
        return this;
    }

    /**
     * Sets the depth up to which (inclusive) excluded attributes are not split on.
     *
     * @param i the deepest level at which the exclusion applies
     * @return this builder
     */
    public TreeBuilder attributeExcludeDepth(int i) {
        this.attributeExcludeDepth = i;
        return this;
    }

    /**
     * Sets the attribute names to exclude from splitting near the root. The set is
     * stored by reference, not copied.
     *
     * @param set the attribute names to exclude
     * @return this builder
     */
    public TreeBuilder excludeAttributes(Set<String> set) {
        this.excludeAttributes = set;
        return this;
    }

    /**
     * Sets the probability with which an attribute is randomly skipped at a node.
     * Values less than or equal to zero disable random skipping.
     *
     * @param d the skip probability
     * @return this builder
     */
    public TreeBuilder ignoreAttributeAtNodeProbability(double d) {
        this.ignoreAttributeAtNodeProbability = d;
        return this;
    }

    /**
     * Sets the minimum per-classification count a nominal value needs in order to be
     * considered. Values less than or equal to zero disable the check.
     *
     * @param i the minimum count
     * @return this builder
     */
    public TreeBuilder minNominalAttributeValueOccurances(int i) {
        this.minNominalAttributeValueOccurances = i;
        return this;
    }

    /** Creates a builder that scores splits with a new {@link Scorer1}. */
    public TreeBuilder() {
        this(new Scorer1());
    }

    /**
     * Creates a builder with default settings: unlimited depth, a minimum leaf
     * probability of 1.0, attribute exclusion down to depth 1 (with an empty exclusion
     * set), no random attribute skipping, and a minimum nominal value occurrence of 5.
     *
     * @param scorer the scorer used to rate candidate splits
     */
    public TreeBuilder(Scorer scorer) {
        this.maxDepth = Integer.MAX_VALUE;
        this.minProbability = 1.0d;
        this.attributeExcludeDepth = 1;
        this.ignoreAttributeAtNodeProbability = 0.0d;
        this.minNominalAttributeValueOccurances = 5;
        this.excludeAttributes = Collections.emptySet();
        this.scorer = scorer;
    }

    /**
     * Builds a tree from the given training data, starting at depth 0 with thresholds
     * estimated over the whole training set.
     *
     * @param iterable the training instances
     * @return the resulting tree
     */
    @Override // quickdt.PredictiveModelBuilder
    public Tree buildPredictiveModel(Iterable<? extends AbstractInstance> iterable) {
        return new Tree(buildTree(null, iterable, 0, createOrdinalSplits(iterable)));
    }

    /**
     * Estimates candidate thresholds for a single attribute over the given instances.
     * Instances whose value for the attribute is missing or not a {@link Number} are
     * skipped, consistent with {@link #createOrdinalSplits(Iterable)}.
     *
     * @param trainingData the instances to sample
     * @param attribute the attribute name
     * @return the thresholds, or an empty array when no numeric value was found
     */
    private double[] createOrdinalSplit(Iterable<? extends AbstractInstance> trainingData, String attribute) {
        ReservoirSampler<Double> sampler = new ReservoirSampler<Double>(RESERVOIR_SIZE);
        for (AbstractInstance instance : trainingData) {
            Serializable value = instance.getAttributes().get(attribute);
            if (value instanceof Number) {
                sampler.addSample(Double.valueOf(((Number) value).doubleValue()));
            }
        }
        return quantileThresholds(sampler);
    }

    /**
     * Estimates candidate thresholds for every attribute that has at least one
     * {@link Number} value in the given instances.
     *
     * @param trainingData the instances to sample
     * @return a mutable map from attribute name to its thresholds
     */
    private Map<String, double[]> createOrdinalSplits(Iterable<? extends AbstractInstance> trainingData) {
        Map<String, ReservoirSampler<Double>> samplers = Maps.newHashMap();
        for (AbstractInstance instance : trainingData) {
            for (Map.Entry<String, Serializable> attribute : instance.getAttributes().entrySet()) {
                if (!(attribute.getValue() instanceof Number)) {
                    continue;
                }
                ReservoirSampler<Double> sampler = samplers.get(attribute.getKey());
                if (sampler == null) {
                    sampler = new ReservoirSampler<Double>(RESERVOIR_SIZE);
                    samplers.put(attribute.getKey(), sampler);
                }
                sampler.addSample(Double.valueOf(((Number) attribute.getValue()).doubleValue()));
            }
        }
        Map<String, double[]> splits = Maps.newHashMap();
        for (Map.Entry<String, ReservoirSampler<Double>> entry : samplers.entrySet()) {
            splits.put(entry.getKey(), quantileThresholds(entry.getValue()));
        }
        return splits;
    }

    /**
     * Sorts the sampled values and picks the values at the 1/N .. (N-1)/N positions,
     * where N is {@link #ORDINAL_TEST_SPLITS}. Shared by both split-creation methods so
     * that root and child nodes use the same quantiles.
     *
     * @param sampler the sampler holding the observed values
     * @return the thresholds in ascending order, or an empty array when there are no samples
     */
    private static double[] quantileThresholds(ReservoirSampler<Double> sampler) {
        ArrayList<Double> sorted = Lists.newArrayList();
        for (Double sample : sampler.getSamples()) {
            sorted.add(sample);
        }
        if (sorted.isEmpty()) {
            // Nothing to index into; an empty array makes the caller fall back to a nominal split.
            return new double[0];
        }
        Collections.sort(sorted);
        double[] thresholds = new double[ORDINAL_TEST_SPLITS - 1];
        for (int i = 0; i < thresholds.length; i++) {
            thresholds[i] = sorted.get(((i + 1) * sorted.size()) / ORDINAL_TEST_SPLITS).doubleValue();
        }
        return thresholds;
    }

    /**
     * Recursively builds the subtree for the given training data.
     *
     * <p>A leaf is returned when {@code i} equals the maximum depth, when the leaf's best
     * classification probability reaches the configured minimum, or when no candidate
     * split scores above zero. While recursing below an ordinal branch, the entry for
     * that branch's attribute in {@code map} is temporarily replaced by thresholds
     * computed from each child's subset and restored afterwards.
     *
     * @param node the parent of the node being built, or {@code null} for the root
     * @param iterable the training instances reaching this node; must be non-empty unless
     *        the node is immediately turned into a leaf
     * @param i the depth of the node being built
     * @param map mutable map from attribute name to ordinal thresholds
     * @return the leaf or branch created for this node
     */
    protected Node buildTree(Node node, Iterable<? extends AbstractInstance> iterable, int i, Map<String, double[]> map) {
        Leaf leaf = new Leaf(node, iterable, i);
        if (i == this.maxDepth || leaf.getBestClassificationProbability() >= this.minProbability) {
            return leaf;
        }
        // Only the first instance's attributes are used to enumerate split candidates.
        AbstractInstance firstInstance = (AbstractInstance) Iterables.get(iterable, 0);
        boolean smallTrainingSet = !hasMoreThan(iterable, SMALL_TRAINING_SET_LIMIT);

        Branch bestBranch = null;
        double bestScore = 0.0d;
        for (Map.Entry<String, Serializable> attribute : firstInstance.getAttributes().entrySet()) {
            String name = attribute.getKey();
            boolean excluded = i <= this.attributeExcludeDepth && this.excludeAttributes.contains(name);
            if (excluded) {
                continue;
            }
            boolean consider = this.ignoreAttributeAtNodeProbability <= 0.0d
                    || Misc.random.nextDouble() >= this.ignoreAttributeAtNodeProbability;
            if (!consider) {
                continue;
            }
            Pair<? extends Branch, Double> candidate = null;
            if (!smallTrainingSet && (attribute.getValue() instanceof Number)) {
                candidate = createOrdinalNode(node, name, iterable, map.get(name));
            }
            // A zero-scoring ordinal split is useless; give the nominal split a chance instead.
            if (candidate == null || candidate.getValue1().doubleValue() == 0.0d) {
                candidate = createNominalNode(node, name, iterable);
            }
            double score = candidate.getValue1().doubleValue();
            if (score > bestScore) {
                bestScore = score;
                bestBranch = candidate.getValue0();
            }
        }
        if (bestBranch == null) {
            return leaf;
        }

        LinkedList<? extends AbstractInstance> inSet =
                Lists.newLinkedList(Iterables.filter(iterable, bestBranch.getInPredicate()));
        LinkedList<? extends AbstractInstance> outSet =
                Lists.newLinkedList(Iterables.filter(iterable, bestBranch.getOutPredicate()));
        if (bestBranch instanceof OrdinalBranch) {
            String splitAttribute = ((OrdinalBranch) bestBranch).attribute;
            double[] parentThresholds = map.get(splitAttribute);
            // Each child gets thresholds re-estimated from its own subset of the data.
            map.put(splitAttribute, createOrdinalSplit(inSet, splitAttribute));
            bestBranch.trueChild = buildTree(bestBranch, inSet, i + 1, map);
            map.put(splitAttribute, createOrdinalSplit(outSet, splitAttribute));
            bestBranch.falseChild = buildTree(bestBranch, outSet, i + 1, map);
            // Restore the caller's thresholds so sibling subtrees are unaffected.
            map.put(splitAttribute, parentThresholds);
        } else {
            bestBranch.trueChild = buildTree(bestBranch, inSet, i + 1, map);
            bestBranch.falseChild = buildTree(bestBranch, outSet, i + 1, map);
        }
        return bestBranch;
    }

    /**
     * Reports whether the iterable yields more than {@code limit} elements, stopping as
     * soon as the answer is known.
     */
    private static boolean hasMoreThan(Iterable<?> elements, int limit) {
        int seen = 0;
        for (Iterator<?> it = elements.iterator(); it.hasNext(); it.next()) {
            seen++;
            if (seen > limit) {
                return true;
            }
        }
        return false;
    }

    /**
     * Greedily builds the best nominal split for an attribute. Starting from an empty
     * "in" set, the value that yields the highest split score is added repeatedly until
     * no remaining value improves on the score achieved so far.
     *
     * @param node the parent node for the new branch
     * @param str the attribute name
     * @param iterable the training instances
     * @return the nominal branch together with its score (zero when no value helped)
     */
    protected Pair<? extends Branch, Double> createNominalNode(Node node, String str, Iterable<? extends AbstractInstance> iterable) {
        HashSet<Serializable> remainingValues = Sets.newHashSet();
        for (AbstractInstance instance : iterable) {
            remainingValues.add(instance.getAttributes().get(str));
        }
        HashSet<Serializable> inValues = Sets.newHashSet();
        Pair<ClassificationCounter, Map<Serializable, ClassificationCounter>> counts =
                ClassificationCounter.countAllByAttributeValues(iterable, str);
        Map<Serializable, ClassificationCounter> countsByValue = counts.getValue1();
        ClassificationCounter inCounts = new ClassificationCounter();
        ClassificationCounter outCounts = counts.getValue0();
        double bestScore = 0.0d;
        while (true) {
            double roundScore = 0.0d;
            Serializable roundValue = null;
            for (Serializable value : remainingValues) {
                ClassificationCounter valueCounts = countsByValue.get(value);
                if (valueCounts == null) {
                    continue;
                }
                if (this.minNominalAttributeValueOccurances > 0 && shouldWeIgnoreThisValue(valueCounts)) {
                    continue;
                }
                double score = this.scorer.scoreSplit(inCounts.add(valueCounts), outCounts.subtract(valueCounts));
                if (score > roundScore) {
                    roundScore = score;
                    roundValue = value;
                }
            }
            if (roundScore <= bestScore) {
                // No remaining value improves the split; stop growing the "in" set.
                return Pair.with(new NominalBranch(node, str, inValues), Double.valueOf(bestScore));
            }
            bestScore = roundScore;
            inValues.add(roundValue);
            remainingValues.remove(roundValue);
            ClassificationCounter moved = countsByValue.get(roundValue);
            inCounts = inCounts.add(moved);
            outCounts = outCounts.subtract(moved);
        }
    }

    /**
     * Reports whether a nominal value is too rare to be considered, i.e. whether its
     * smallest per-classification count is below the configured minimum.
     */
    private boolean shouldWeIgnoreThisValue(ClassificationCounter classificationCounter) {
        double lowestCount = Double.MAX_VALUE;
        for (Double count : classificationCounter.getCounts().values()) {
            lowestCount = Math.min(lowestCount, count.doubleValue());
        }
        return lowestCount < ((double) this.minNominalAttributeValueOccurances);
    }

    /**
     * Finds the best threshold split for a numeric attribute among the given candidate
     * thresholds. Consecutive duplicate thresholds are evaluated only once. Instances
     * whose value is missing or not a {@link Number} count towards neither side.
     *
     * @param node the parent node for the new branch
     * @param str the attribute name
     * @param iterable the training instances
     * @param dArr the candidate thresholds
     * @return the ordinal branch together with its score; the score is zero (and the
     *         threshold 0.0) when no candidate scored above zero
     */
    protected Pair<? extends Branch, Double> createOrdinalNode(Node node, final String str, Iterable<? extends AbstractInstance> iterable, double[] dArr) {
        double bestScore = 0.0d;
        double bestThreshold = 0.0d;
        // NaN never compares equal, so it is a safe "no previous threshold" marker.
        double previousThreshold = Double.NaN;
        for (double threshold : dArr) {
            if (threshold == previousThreshold) {
                continue;
            }
            previousThreshold = threshold;
            ClassificationCounter above =
                    ClassificationCounter.countAll(Iterables.filter(iterable, thresholdPredicate(str, threshold, true)));
            ClassificationCounter notAbove =
                    ClassificationCounter.countAll(Iterables.filter(iterable, thresholdPredicate(str, threshold, false)));
            double score = this.scorer.scoreSplit(above, notAbove);
            if (score > bestScore) {
                bestScore = score;
                bestThreshold = threshold;
            }
        }
        return Pair.with(new OrdinalBranch(node, str, bestThreshold), Double.valueOf(bestScore));
    }

    /**
     * Creates a predicate that selects instances whose numeric value for the attribute
     * is strictly greater than the threshold ({@code above == true}) or less than or
     * equal to it ({@code above == false}). Non-numeric and missing values never match.
     */
    private static Predicate<AbstractInstance> thresholdPredicate(final String attribute, final double threshold, final boolean above) {
        return new Predicate<AbstractInstance>() {
            @Override // com.google.common.base.Predicate
            public boolean apply(AbstractInstance abstractInstance) {
                Serializable value = abstractInstance.getAttributes().get(attribute);
                if (!(value instanceof Number)) {
                    return false;
                }
                double numeric = ((Number) value).doubleValue();
                return above ? numeric > threshold : numeric <= threshold;
            }
        };
    }
}
