/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.random_forest;

import hr.irb.fastRandomForest.FastRandomForest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import sc.fiji.labkit.pixel_classification.random_forest.ReflectionUtils;
import sc.fiji.labkit.pixel_classification.random_forest.TransparentRandomTree;
import sc.fiji.labkit.pixel_classification.utils.ArrayUtils;
import weka.core.Instance;

/**
 * A random forest whose trees are exposed as a list of {@link TransparentRandomTree}s, so they
 * can be inspected or evaluated directly.
 */
public class TransparentRandomForest {
    private final List<TransparentRandomTree> trees;

    /**
     * Creates a forest from the given trees.
     *
     * @param trees the trees of the forest; the list is copied, so later changes to it do not
     *     affect this forest
     * @throws NullPointerException if {@code trees} is {@code null}
     */
    public TransparentRandomForest(List<TransparentRandomTree> trees) {
        // trees() hands this list out, so store a read-only copy: neither the caller's list
        // nor users of trees() can mutate the forest after construction.
        this.trees = Collections.unmodifiableList(
            new ArrayList<>(Objects.requireNonNull(trees, "trees")));
    }

    /**
     * Builds a transparent view of a {@link FastRandomForest}.
     *
     * @param original the forest whose trees are extracted
     * @return a forest containing one {@link TransparentRandomTree} per tree found in
     *     {@code original}; empty if no trees could be found
     */
    public static TransparentRandomForest forFastRandomForest(FastRandomForest original) {
        return new TransparentRandomForest(TransparentRandomForest.initTrees(original));
    }

    /**
     * Extracts the trees of {@code original} by reading its private {@code m_bagger} field and
     * that object's {@code m_Classifiers} array via reflection, converting each element with
     * {@link TransparentRandomTree#forFastRandomTree}.
     *
     * @return the converted trees, or an empty list if either field is {@code null}
     */
    private static List<TransparentRandomTree> initTrees(FastRandomForest original) {
        Object bagger = ReflectionUtils.getPrivateField(original, "m_bagger", Object.class);
        if (bagger == null) {
            return Collections.emptyList();
        }
        Object[] trees = ReflectionUtils.getPrivateField(bagger, "m_Classifiers", Object[].class);
        if (trees == null) {
            // Stream.of(null array) would throw a bare NPE; treat it like a missing bagger.
            return Collections.emptyList();
        }
        return Stream.of(trees)
            .map(TransparentRandomTree::forFastRandomTree)
            .collect(Collectors.toList());
    }

    /**
     * Returns the trees of this forest.
     *
     * @return an unmodifiable list of the trees
     */
    public List<TransparentRandomTree> trees() {
        return this.trees;
    }

    /**
     * Returns the number of classes reported by the first tree.
     *
     * @return the first tree's {@code numberOfClasses()}, or {@code 0} if the forest has no trees
     */
    public int numberOfClasses() {
        return this.trees.isEmpty() ? 0 : this.trees.get(0).numberOfClasses();
    }

    /**
     * Combines the class distributions of all trees for {@code instance}.
     *
     * <p>Each tree's distribution is passed to {@code ArrayUtils.add} together with an
     * accumulator array of length {@code numberOfClasses}. The accumulator is then passed through
     * {@code ArrayUtils.normalize}.
     *
     * <p>NOTE(review): presumably {@code add} accumulates element-wise into the second argument
     * and {@code normalize} scales to a sum of 1. Confirm in {@code ArrayUtils}, including what
     * {@code normalize} does for an all-zero array (e.g. an empty forest).
     *
     * @param instance the instance to classify
     * @param numberOfClasses length of the accumulator array; presumably expected to match the
     *     trees' distribution length (TODO confirm)
     * @return the result of {@code ArrayUtils.normalize} applied to the accumulated distribution
     */
    public double[] distributionForInstance(Instance instance, int numberOfClasses) {
        double[] result = new double[numberOfClasses];
        for (TransparentRandomTree tree : this.trees) {
            ArrayUtils.add(tree.distributionForInstance(instance), result);
        }
        return ArrayUtils.normalize(result);
    }
}

