/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.gpu.random_forest;

import sc.fiji.labkit.pixel_classification.random_forest.TransparentRandomTree;
import weka.core.Instance;

/**
 * Flattened, array-based representation of a {@link TransparentRandomTree} used for prediction.
 *
 * <p>Internal (split) nodes are stored in the parallel arrays {@link #attributeIndicies},
 * {@link #threshold}, {@link #smallerChild} and {@link #biggerChild}. Nodes are numbered in
 * pre-order, so the root is always at index 0. Leaves are numbered in traversal order
 * (smaller child first), and their class probabilities are stored in
 * {@link #classProbabilities}.
 *
 * <p>A child reference in {@link #smallerChild} / {@link #biggerChild} is encoded as follows:
 * <ul>
 *   <li>a non-negative value is the index of another internal node;</li>
 *   <li>a negative value refers to a leaf, whose index is {@code value - Short.MIN_VALUE}.</li>
 * </ul>
 * This encoding can address at most {@code -Short.MIN_VALUE} (32768) leaves. Trees with more
 * leaves are rejected by the constructor.
 *
 * <p>If the source tree is a single leaf, one dummy internal node is created. Both of its
 * children point to leaf 0, so the root can still be evaluated uniformly.
 */
public class GpuRandomTreePrediction {

    /**
     * Largest number of leaves the child-reference encoding can address. Leaf {@code i} is
     * stored as {@code i + Short.MIN_VALUE}. That value stays negative, and therefore
     * distinguishable from an internal-node index, only while {@code i < -Short.MIN_VALUE}.
     */
    private static final int MAX_LEAFS = -Short.MIN_VALUE;

    /** Number of internal nodes (1 for a single-leaf tree, which gets a dummy node). */
    final int numberOfNodes;
    /** Number of leaves, i.e. the length of {@link #classProbabilities}. */
    final int numberOfLeafs;
    /** Next free internal-node slot while the tree is being flattened. */
    private int nodeCount = 0;
    /** Next free leaf slot while the tree is being flattened. */
    private int leafCount = 0;
    /** Per internal node: index of the attribute that is compared against the threshold. */
    final int[] attributeIndicies;
    /** Per internal node: split threshold. */
    final double[] threshold;
    /** Per internal node: encoded child taken when {@code value < threshold}. */
    final int[] smallerChild;
    /** Per internal node: encoded child taken otherwise (including when value is NaN). */
    final int[] biggerChild;
    /** Per leaf: the class-probability array taken from the source tree's leaf. */
    final double[][] classProbabilities;

    /**
     * Flattens the given tree into array form.
     *
     * @param tree the tree to convert
     * @throws IllegalArgumentException if the tree has more leaves than the child-reference
     *     encoding can represent (more than 32768)
     */
    public GpuRandomTreePrediction(TransparentRandomTree tree) {
        if (tree.isLeaf()) {
            // A dummy root whose children both point to leaf 0 (encoded as 0 + Short.MIN_VALUE).
            this.numberOfLeafs = 1;
            this.numberOfNodes = 1;
            this.attributeIndicies = new int[]{0};
            this.threshold = new double[]{0.0};
            this.smallerChild = new int[]{Short.MIN_VALUE};
            this.biggerChild = new int[]{Short.MIN_VALUE};
            this.classProbabilities = new double[][]{tree.classProbabilities()};
        } else {
            this.numberOfLeafs = this.countLeafs(tree);
            // Without this check, leaf indices >= 32768 would encode to non-negative values
            // and be misread as internal-node indices during prediction.
            if (this.numberOfLeafs > MAX_LEAFS) {
                throw new IllegalArgumentException("Random tree has " + this.numberOfLeafs
                    + " leaves, but at most " + MAX_LEAFS + " are supported.");
            }
            this.numberOfNodes = this.countNodes(tree);
            this.attributeIndicies = new int[this.numberOfNodes];
            this.threshold = new double[this.numberOfNodes];
            this.smallerChild = new int[this.numberOfNodes];
            this.biggerChild = new int[this.numberOfNodes];
            this.classProbabilities = new double[this.numberOfLeafs][];
            this.addTree(tree);
        }
    }

    /** Counts the internal (non-leaf) nodes of the subtree rooted at {@code node}. */
    private int countNodes(TransparentRandomTree node) {
        return node.isLeaf() ? 0 : 1 + this.countNodes(node.smallerChild()) + this.countNodes(node.biggerChild());
    }

    /** Counts the leaves of the subtree rooted at {@code node}. */
    private int countLeafs(TransparentRandomTree node) {
        return node.isLeaf() ? 1 : this.countLeafs(node.smallerChild()) + this.countLeafs(node.biggerChild());
    }

    /**
     * Appends the subtree rooted at {@code node} to the arrays.
     *
     * @return the encoded reference to the subtree's root: a node index, or a negative leaf
     *     reference
     */
    int addTree(TransparentRandomTree node) {
        return node.isLeaf() ? this.addLeaf(node) : this.addNode(node);
    }

    /** Stores an internal node, then recursively its children; returns the node's index. */
    private int addNode(TransparentRandomTree node) {
        // The slot is reserved before recursing, which yields pre-order numbering (root = 0).
        int i = this.nodeCount++;
        this.attributeIndicies[i] = node.attributeIndex();
        this.threshold[i] = node.threshold();
        this.smallerChild[i] = this.addTree(node.smallerChild());
        this.biggerChild[i] = this.addTree(node.biggerChild());
        return i;
    }

    /** Stores a leaf's class probabilities and returns its encoded (negative) reference. */
    private int addLeaf(TransparentRandomTree node) {
        int i = this.leafCount++;
        if (i >= this.classProbabilities.length) {
            // Invariant: countLeafs() sized classProbabilities for exactly this many leaves.
            throw new AssertionError("Leaf count exceeds the precomputed number of leaves.");
        }
        this.classProbabilities[i] = node.classProbabilities();
        return i + Short.MIN_VALUE;
    }

    /**
     * Walks the flattened tree for the given instance and returns the reached leaf's class
     * probabilities.
     *
     * <p>At each node, the instance's attribute value is compared with the node's threshold.
     * If it is smaller, the smaller child is followed; otherwise the bigger child is followed.
     * A NaN value therefore always goes to the bigger child.
     *
     * <p>The returned array is the tree's internal array, not a copy, so callers must not modify
     * it.
     *
     * @param instance the instance to classify
     * @return the class-probability array stored for the reached leaf
     */
    public double[] distributionForInstance(Instance instance) {
        int nodeIndex = 0;
        // Non-negative values are internal nodes; a negative value means a leaf was reached.
        while (nodeIndex >= 0) {
            int attributeIndex = this.attributeIndicies[nodeIndex];
            double attributeValue = instance.value(attributeIndex);
            nodeIndex = attributeValue < this.threshold[nodeIndex] ? this.smallerChild[nodeIndex] : this.biggerChild[nodeIndex];
        }
        int leafIndex = nodeIndex - Short.MIN_VALUE;
        return this.classProbabilities[leafIndex];
    }
}

