/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.random_forest;

import hr.irb.fastRandomForest.FastRandomForest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import sc.fiji.labkit.pixel_classification.random_forest.TransparentRandomForest;
import sc.fiji.labkit.pixel_classification.random_forest.TransparentRandomTree;
import sc.fiji.labkit.pixel_classification.utils.ArrayUtils;

/**
 * Array-backed copy of a random forest that can be evaluated without touching
 * the original tree objects.
 *
 * <p>Trees are grouped by their height and stored in ascending height order in
 * three flat arrays ({@link #attributes}, {@link #thresholds},
 * {@link #probabilities}):
 *
 * <ul>
 * <li><b>Height 0</b> (a single leaf): nothing is stored per tree; the leaf's
 * class probabilities are summed into {@link #prior}.</li>
 * <li><b>Dense layout</b>, {@code 1 <= height < COMPACT_STORAGE_MIN_HEIGHT}:
 * every tree occupies {@code 2^height - 1} slots in {@code attributes} and
 * {@code thresholds} (node {@code i} has its children at {@code 2i+1} and
 * {@code 2i+2}) and {@code 2^height * numClasses} slots in
 * {@code probabilities}, addressed by the sequence of branch decisions. A leaf
 * that sits above the bottom level is marked with attribute {@code -1}.</li>
 * <li><b>Compact layout</b>, {@code height >= COMPACT_STORAGE_MIN_HEIGHT}:
 * every tree occupies two header ints in {@code attributes} (number of inner
 * nodes, number of probability values) followed by one
 * {@code (attributeIndex, smallerChild, biggerChild)} triple per inner node.
 * A non-negative child is the index of another inner node; a negative child is
 * {@code offset + Integer.MIN_VALUE}, where {@code offset} locates the leaf's
 * class probabilities inside the tree's probability block.</li>
 * </ul>
 *
 * <p>All fields are assigned in the constructor and the evaluation methods only
 * read them; the only array they write to is the caller-supplied
 * {@code distribution}.
 */
class CpuRandomForestCore {
    /** Upper bound for the length of any of the flat arrays. */
    private static final int MAX_ARRAY_SIZE = 0x7FFFFFF7;
    /** Trees of at least this height use the compact layout, lower ones the dense layout. */
    private static final int COMPACT_STORAGE_MIN_HEIGHT = 4;
    /** Split attribute indices (dense layout) or headers plus node triples (compact layout). */
    private final int[] attributes;
    /** Split thresholds, one per inner-node slot. */
    private final float[] thresholds;
    /** Class probabilities of all leaves, {@code numClasses} consecutive values per leaf. */
    private final float[] probabilities;
    private final int numClasses;
    /** {@code numTreesOfHeight[h]} is the number of stored trees of height {@code h}, for {@code h >= 1}. */
    private final int[] numTreesOfHeight;
    /** Sum of the class probabilities of all height-0 trees. */
    private final float[] prior;

    /**
     * Builds the core from a {@link FastRandomForest} by first converting it
     * with {@link TransparentRandomForest#forFastRandomForest}.
     *
     * @param classifier the forest to copy
     */
    public CpuRandomForestCore(FastRandomForest classifier) {
        this(TransparentRandomForest.forFastRandomForest(classifier));
    }

    /**
     * Flattens the given forest into the array layouts described on the class.
     *
     * @param forest the forest to copy
     * @throws IllegalArgumentException if one of the flat arrays would need
     *         more than {@code MAX_ARRAY_SIZE} elements
     */
    public CpuRandomForestCore(TransparentRandomForest forest) {
        this.numClasses = forest.numberOfClasses();
        this.prior = new float[this.numClasses];

        HashMap<Integer, List<TransparentRandomTree>> treesByHeight = new HashMap<>();
        for (TransparentRandomTree tree : forest.trees()) {
            // The key is only used for grouping; it must not become the list's initial capacity.
            treesByHeight.computeIfAbsent(tree.height(), h -> new ArrayList<>()).add(tree);
        }
        int[] heights = treesByHeight.keySet().stream().mapToInt(Integer::intValue).sorted().toArray();
        int maxHeight = heights.length == 0 ? 0 : heights[heights.length - 1];
        this.numTreesOfHeight = new int[maxHeight + 1];

        // First pass: count how much space every array needs. All arithmetic is done in
        // long so that an oversized forest is detected instead of silently wrapping around.
        long attributesSize = 0L;
        long thresholdsSize = 0L;
        long probabilitiesSize = 0L;
        for (int height = 1; height <= maxHeight; ++height) {
            List<TransparentRandomTree> trees = treesByHeight.getOrDefault(height, Collections.emptyList());
            this.numTreesOfHeight[height] = trees.size();
            if (height < COMPACT_STORAGE_MIN_HEIGHT) {
                long numTrees = trees.size();
                long numLeafs = 1L << height;
                long numNonLeafs = numLeafs - 1;
                attributesSize += numNonLeafs * numTrees;
                thresholdsSize += numNonLeafs * numTrees;
                probabilitiesSize += numLeafs * this.numClasses * numTrees;
            } else {
                for (TransparentRandomTree tree : trees) {
                    long numLeafs = tree.numberOfLeafs();
                    long numNonLeafs = tree.numberOfNodes() - numLeafs;
                    attributesSize += 2 + 3 * numNonLeafs;
                    thresholdsSize += numNonLeafs;
                    probabilitiesSize += numLeafs * this.numClasses;
                }
            }
        }
        if (attributesSize > MAX_ARRAY_SIZE || thresholdsSize > MAX_ARRAY_SIZE || probabilitiesSize > MAX_ARRAY_SIZE) {
            throw new IllegalArgumentException("forest is too big to represent in " + CpuRandomForestCore.class.getSimpleName());
        }
        this.attributes = new int[(int) attributesSize];
        this.thresholds = new float[(int) thresholdsSize];
        this.probabilities = new float[(int) probabilitiesSize];

        // Second pass: write the trees, lowest height first, so that all dense trees
        // precede all compact trees.
        int attributesBase = 0;
        int thresholdsBase = 0;
        int probabilitiesBase = 0;
        int[] probCursor = new int[1];
        for (int height : heights) {
            List<TransparentRandomTree> trees = treesByHeight.get(height);
            if (height == 0) {
                for (TransparentRandomTree tree : trees) {
                    for (int i = 0; i < this.numClasses; ++i) {
                        this.prior[i] += (float) tree.classProbabilities()[i];
                    }
                }
            } else if (height < COMPACT_STORAGE_MIN_HEIGHT) {
                int numLeafs = 1 << height;
                int dataSize = numLeafs - 1;
                int probSize = numLeafs * this.numClasses;
                for (TransparentRandomTree tree : trees) {
                    // In the dense region attributes and thresholds share the same offsets.
                    this.write(tree, 0, 0, 0, height, attributesBase, probabilitiesBase);
                    attributesBase += dataSize;
                    thresholdsBase += dataSize;
                    probabilitiesBase += probSize;
                }
            } else {
                for (TransparentRandomTree tree : trees) {
                    probCursor[0] = 0;
                    int numNonLeafs = this.write_compact(tree, 0, probCursor, attributesBase + 2, thresholdsBase, probabilitiesBase);
                    this.attributes[attributesBase] = numNonLeafs;
                    this.attributes[attributesBase + 1] = probCursor[0];
                    attributesBase += 2 + numNonLeafs * 3;
                    thresholdsBase += numNonLeafs;
                    probabilitiesBase += probCursor[0];
                }
            }
        }
    }

    /**
     * Recursively writes a (sub)tree in the dense layout.
     *
     * @param node         root of the subtree to write
     * @param nodeIndex    slot of {@code node} inside its tree (children are at {@code 2i+1}, {@code 2i+2})
     * @param branchBits   branch decisions taken so far, one bit per level (0 = smaller, 1 = bigger)
     * @param depth        depth of {@code node} below the tree root
     * @param height       height of the whole tree
     * @param treeDataBase offset of the tree in {@code attributes} and {@code thresholds}
     * @param treeProbBase offset of the tree in {@code probabilities}
     */
    private void write(TransparentRandomTree node, int nodeIndex, int branchBits, int depth, int height, int treeDataBase, int treeProbBase) {
        if (!node.isLeaf()) {
            int slot = treeDataBase + nodeIndex;
            this.attributes[slot] = node.attributeIndex();
            this.thresholds[slot] = (float) node.threshold();
            this.write(node.smallerChild(), 2 * nodeIndex + 1, branchBits << 1, depth + 1, height, treeDataBase, treeProbBase);
            this.write(node.biggerChild(), 2 * nodeIndex + 2, (branchBits << 1) + 1, depth + 1, height, treeDataBase, treeProbBase);
            return;
        }
        int leafSlot = branchBits;
        if (depth < height) {
            // Early leaf: mark the node slot and pad the branch bits as if only
            // "smaller" branches had been taken down to the bottom level.
            this.attributes[treeDataBase + nodeIndex] = -1;
            leafSlot = branchBits << (height - depth);
        }
        int probOffset = treeProbBase + leafSlot * this.numClasses;
        for (int i = 0; i < this.numClasses; ++i) {
            this.probabilities[probOffset + i] = (float) node.classProbabilities()[i];
        }
    }

    /**
     * Recursively writes the inner node {@code node} and everything below it
     * in the compact layout.
     *
     * @param node              inner node to write; must not be a leaf
     * @param i                 index of {@code node} among the tree's inner nodes
     * @param j                 one-element cursor holding the next free probability offset of the tree
     * @param attributesBase    offset of the tree's first node triple in {@code attributes}
     * @param thresholdsBase    offset of the tree in {@code thresholds}
     * @param probabilitiesBase offset of the tree in {@code probabilities}
     * @return number of inner nodes written, including {@code node}
     */
    private int write_compact(TransparentRandomTree node, int i, int[] j, int attributesBase, int thresholdsBase, int probabilitiesBase) {
        int triple = attributesBase + 3 * i;
        this.attributes[triple] = node.attributeIndex();
        this.thresholds[thresholdsBase + i] = (float) node.threshold();

        int smallerSize = 0;
        TransparentRandomTree smaller = node.smallerChild();
        if (smaller.isLeaf()) {
            this.attributes[triple + 1] = this.writeCompactLeaf(smaller, j, probabilitiesBase);
        } else {
            this.attributes[triple + 1] = i + 1;
            smallerSize = this.write_compact(smaller, i + 1, j, attributesBase, thresholdsBase, probabilitiesBase);
        }

        int biggerSize = 0;
        TransparentRandomTree bigger = node.biggerChild();
        if (bigger.isLeaf()) {
            this.attributes[triple + 2] = this.writeCompactLeaf(bigger, j, probabilitiesBase);
        } else {
            // The bigger subtree is placed right after all inner nodes of the smaller one.
            int biggerIndex = i + smallerSize + 1;
            this.attributes[triple + 2] = biggerIndex;
            biggerSize = this.write_compact(bigger, biggerIndex, j, attributesBase, thresholdsBase, probabilitiesBase);
        }
        return 1 + smallerSize + biggerSize;
    }

    /**
     * Appends the class probabilities of a leaf to the current compact tree.
     *
     * @return the encoded (negative) child reference for the leaf
     */
    private int writeCompactLeaf(TransparentRandomTree leaf, int[] j, int probabilitiesBase) {
        int encoded = j[0] + Integer.MIN_VALUE;
        for (int c = 0; c < this.numClasses; ++c) {
            this.probabilities[probabilitiesBase + j[0]++] = (float) leaf.classProbabilities()[c];
        }
        return encoded;
    }

    /**
     * Evaluates all trees for one instance and writes the resulting class
     * distribution.
     *
     * @param instance     feature values, indexed by attribute index
     * @param distribution output array with at least {@code numberOfClasses()} elements
     */
    void distributionForInstance(float[] instance, float[] distribution) {
        switch (this.numClasses) {
            case 2:
                this.distributionForInstance_c2(instance, distribution);
                break;
            case 3:
                this.distributionForInstance_c3(instance, distribution);
                break;
            default:
                this.distributionForInstance_ck(instance, distribution);
        }
    }

    /** Evaluation for an arbitrary number of classes; the sums are passed to {@code ArrayUtils.normalize}. */
    private void distributionForInstance_ck(float[] instance, float[] distribution) {
        final int numClasses = this.numClasses;
        System.arraycopy(this.prior, 0, distribution, 0, numClasses);
        int attributesBase = 0;
        int probabilitiesBase = 0;
        int height;
        for (height = 1; height < this.numTreesOfHeight.length && height < COMPACT_STORAGE_MIN_HEIGHT; ++height) {
            final int nh = this.numTreesOfHeight[height];
            final int dataSize = (1 << height) - 1;
            final int probSize = (1 << height) * numClasses;
            switch (height) {
                case 1:
                    for (int tree = 0; tree < nh; ++tree) {
                        this.acc(distribution, numClasses, probabilitiesBase, this.evaluateTree_h1(instance, attributesBase) * numClasses);
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                case 2:
                    for (int tree = 0; tree < nh; ++tree) {
                        this.acc(distribution, numClasses, probabilitiesBase, this.evaluateTree_h2(instance, attributesBase) * numClasses);
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                case 3:
                    for (int tree = 0; tree < nh; ++tree) {
                        this.acc(distribution, numClasses, probabilitiesBase, this.evaluateTree_h3(instance, attributesBase) * numClasses);
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                default:
                    // Only reached if COMPACT_STORAGE_MIN_HEIGHT is raised above 4.
                    for (int tree = 0; tree < nh; ++tree) {
                        this.acc(distribution, numClasses, probabilitiesBase, this.evaluateTree(instance, attributesBase, height) * numClasses);
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
            }
        }
        // Dense trees use identical offsets in attributes and thresholds.
        int thresholdsBase = attributesBase;
        for (; height < this.numTreesOfHeight.length; ++height) {
            final int nh = this.numTreesOfHeight[height];
            for (int tree = 0; tree < nh; ++tree) {
                int numNonLeafs = this.attributes[attributesBase];
                int probSize = this.attributes[attributesBase + 1];
                this.acc(distribution, numClasses, probabilitiesBase, this.evaluateCompactTree(instance, attributesBase, thresholdsBase));
                attributesBase += 2 + 3 * numNonLeafs;
                thresholdsBase += numNonLeafs;
                probabilitiesBase += probSize;
            }
        }
        ArrayUtils.normalize(distribution);
    }

    /** Adds the {@code numClasses} probabilities starting at {@code probBase + offset} to {@code distribution}. */
    private void acc(float[] distribution, int numClasses, int probBase, int offset) {
        int from = probBase + offset;
        for (int k = 0; k < numClasses; ++k) {
            distribution[k] += this.probabilities[from + k];
        }
    }

    /**
     * Evaluation that accumulates exactly two class sums in local variables
     * and divides both by their total.
     *
     * @param instance     feature values, indexed by attribute index
     * @param distribution output array with at least two elements
     */
    void distributionForInstance_c2(float[] instance, float[] distribution) {
        final int numClasses = this.numClasses;
        float c0 = this.prior[0];
        float c1 = this.prior[1];
        int attributesBase = 0;
        int probabilitiesBase = 0;
        int height;
        for (height = 1; height < this.numTreesOfHeight.length && height < COMPACT_STORAGE_MIN_HEIGHT; ++height) {
            final int nh = this.numTreesOfHeight[height];
            final int dataSize = (1 << height) - 1;
            final int probSize = (1 << height) * numClasses;
            switch (height) {
                case 1:
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree_h1(instance, attributesBase) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                case 2:
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree_h2(instance, attributesBase) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                case 3:
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree_h3(instance, attributesBase) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                default:
                    // Only reached if COMPACT_STORAGE_MIN_HEIGHT is raised above 4.
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree(instance, attributesBase, height) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
            }
        }
        // Dense trees use identical offsets in attributes and thresholds.
        int thresholdsBase = attributesBase;
        for (; height < this.numTreesOfHeight.length; ++height) {
            final int nh = this.numTreesOfHeight[height];
            for (int tree = 0; tree < nh; ++tree) {
                int numNonLeafs = this.attributes[attributesBase];
                int probSize = this.attributes[attributesBase + 1];
                int p = probabilitiesBase + this.evaluateCompactTree(instance, attributesBase, thresholdsBase);
                c0 += this.probabilities[p];
                c1 += this.probabilities[p + 1];
                attributesBase += 2 + 3 * numNonLeafs;
                thresholdsBase += numNonLeafs;
                probabilitiesBase += probSize;
            }
        }
        float invsum = 1.0f / (c0 + c1);
        distribution[0] = c0 * invsum;
        distribution[1] = c1 * invsum;
    }

    /** Evaluation that accumulates exactly three class sums in local variables and divides them by their total. */
    private void distributionForInstance_c3(float[] instance, float[] distribution) {
        final int numClasses = 3;
        float c0 = this.prior[0];
        float c1 = this.prior[1];
        float c2 = this.prior[2];
        int attributesBase = 0;
        int probabilitiesBase = 0;
        int height;
        for (height = 1; height < this.numTreesOfHeight.length && height < COMPACT_STORAGE_MIN_HEIGHT; ++height) {
            final int nh = this.numTreesOfHeight[height];
            final int dataSize = (1 << height) - 1;
            final int probSize = (1 << height) * numClasses;
            switch (height) {
                case 1:
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree_h1(instance, attributesBase) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        c2 += this.probabilities[p + 2];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                case 2:
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree_h2(instance, attributesBase) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        c2 += this.probabilities[p + 2];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                case 3:
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree_h3(instance, attributesBase) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        c2 += this.probabilities[p + 2];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
                    break;
                default:
                    // Only reached if COMPACT_STORAGE_MIN_HEIGHT is raised above 4.
                    for (int tree = 0; tree < nh; ++tree) {
                        int p = probabilitiesBase + this.evaluateTree(instance, attributesBase, height) * numClasses;
                        c0 += this.probabilities[p];
                        c1 += this.probabilities[p + 1];
                        c2 += this.probabilities[p + 2];
                        attributesBase += dataSize;
                        probabilitiesBase += probSize;
                    }
            }
        }
        // Dense trees use identical offsets in attributes and thresholds.
        int thresholdsBase = attributesBase;
        for (; height < this.numTreesOfHeight.length; ++height) {
            final int nh = this.numTreesOfHeight[height];
            for (int tree = 0; tree < nh; ++tree) {
                int numNonLeafs = this.attributes[attributesBase];
                int probSize = this.attributes[attributesBase + 1];
                int p = probabilitiesBase + this.evaluateCompactTree(instance, attributesBase, thresholdsBase);
                c0 += this.probabilities[p];
                c1 += this.probabilities[p + 1];
                c2 += this.probabilities[p + 2];
                attributesBase += 2 + 3 * numNonLeafs;
                thresholdsBase += numNonLeafs;
                probabilitiesBase += probSize;
            }
        }
        float invsum = 1.0f / (c0 + c1 + c2);
        distribution[0] = c0 * invsum;
        distribution[1] = c1 * invsum;
        distribution[2] = c2 * invsum;
    }

    /**
     * Walks one tree stored in the compact layout.
     *
     * @param attributesBase offset of the tree's header in {@code attributes}
     * @param thresholdsBase offset of the tree in {@code thresholds}
     * @return offset of the reached leaf's class probabilities, relative to the
     *         start of the tree's probability block
     */
    private int evaluateCompactTree(float[] instance, int attributesBase, int thresholdsBase) {
        final int triplesBase = attributesBase + 2; // skip the two header ints
        int node = 0;
        while (node >= 0) {
            int triple = triplesBase + 3 * node;
            float attributeValue = instance[this.attributes[triple]];
            float threshold = this.thresholds[thresholdsBase + node];
            node = attributeValue < threshold ? this.attributes[triple + 1] : this.attributes[triple + 2];
        }
        // Negative references encode "probability offset + Integer.MIN_VALUE".
        return node - Integer.MIN_VALUE;
    }

    /**
     * Walks one tree of the given height stored in the dense layout.
     *
     * @return the branch bits of the reached leaf, i.e. the index of the leaf's
     *         probability group inside the tree's probability block
     */
    private int evaluateTree(float[] instance, int dataBase, int height) {
        int branchBits = 0;
        int nodeIndex = 0;
        for (int depth = 0; depth < height; ++depth) {
            int slot = dataBase + nodeIndex;
            int attributeIndex = this.attributes[slot];
            if (attributeIndex < 0) {
                // Early leaf: pad with zero bits down to the bottom level.
                branchBits <<= height - depth;
                break;
            }
            int branch = instance[attributeIndex] < this.thresholds[slot] ? 0 : 1;
            nodeIndex = (nodeIndex << 1) + branch + 1;
            branchBits = (branchBits << 1) + branch;
        }
        return branchBits;
    }

    /** Unrolled variant of {@link #evaluateTree} for dense trees of height 1. */
    private int evaluateTree_h1(float[] instance, int dataBase) {
        float attributeValue = instance[this.attributes[dataBase]];
        return attributeValue < this.thresholds[dataBase] ? 0 : 1;
    }

    /** Unrolled variant of {@link #evaluateTree} for dense trees of height 2. */
    private int evaluateTree_h2(float[] instance, int dataBase) {
        int branchBits;
        int child;
        if (instance[this.attributes[dataBase]] < this.thresholds[dataBase]) {
            branchBits = 0;
            child = dataBase + 1;
        } else {
            branchBits = 2;
            child = dataBase + 2;
        }
        int attributeIndex1 = this.attributes[child];
        if (attributeIndex1 < 0) {
            return branchBits;
        }
        // NOTE(review): ">=" sends a NaN value to the smaller branch here, whereas the
        // "<" tests used elsewhere send NaN to the bigger branch. Kept as is.
        if (instance[attributeIndex1] >= this.thresholds[child]) {
            ++branchBits;
        }
        return branchBits;
    }

    /** Unrolled variant of {@link #evaluateTree} for dense trees of height 3. */
    private int evaluateTree_h3(float[] instance, int dataBase) {
        int branchBits;
        int nodeIndex;
        if (instance[this.attributes[dataBase]] < this.thresholds[dataBase]) {
            branchBits = 0;
            nodeIndex = 1;
        } else {
            branchBits = 4;
            nodeIndex = 2;
        }
        int slot1 = dataBase + nodeIndex;
        int attributeIndex1 = this.attributes[slot1];
        if (attributeIndex1 < 0) {
            return branchBits;
        }
        if (instance[attributeIndex1] < this.thresholds[slot1]) {
            nodeIndex = nodeIndex * 2 + 1;
        } else {
            nodeIndex = nodeIndex * 2 + 2;
            branchBits += 2;
        }
        int slot2 = dataBase + nodeIndex;
        int attributeIndex2 = this.attributes[slot2];
        if (attributeIndex2 < 0) {
            return branchBits;
        }
        // NOTE(review): ">=" sends a NaN value to the smaller branch here, whereas the
        // "<" tests used elsewhere send NaN to the bigger branch. Kept as is.
        if (instance[attributeIndex2] >= this.thresholds[slot2]) {
            ++branchBits;
        }
        return branchBits;
    }

    /**
     * Evaluation without any per-height or per-class-count specialisation.
     * Not called from within this class.
     */
    private void generic_distributionForInstance(float[] instance, float[] distribution) {
        final int numClasses = this.numClasses;
        System.arraycopy(this.prior, 0, distribution, 0, numClasses);
        int attributesBase = 0;
        int probabilitiesBase = 0;
        int height;
        for (height = 1; height < this.numTreesOfHeight.length && height < COMPACT_STORAGE_MIN_HEIGHT; ++height) {
            final int nh = this.numTreesOfHeight[height];
            final int dataSize = (1 << height) - 1;
            final int probSize = (1 << height) * numClasses;
            for (int tree = 0; tree < nh; ++tree) {
                int branchBits = this.evaluateTree(instance, attributesBase, height);
                this.acc(distribution, numClasses, probabilitiesBase, branchBits * numClasses);
                attributesBase += dataSize;
                probabilitiesBase += probSize;
            }
        }
        // Dense trees use identical offsets in attributes and thresholds.
        int thresholdsBase = attributesBase;
        for (; height < this.numTreesOfHeight.length; ++height) {
            final int nh = this.numTreesOfHeight[height];
            for (int tree = 0; tree < nh; ++tree) {
                int numNonLeafs = this.attributes[attributesBase];
                int probSize = this.attributes[attributesBase + 1];
                this.acc(distribution, numClasses, probabilitiesBase, this.evaluateCompactTree(instance, attributesBase, thresholdsBase));
                attributesBase += 2 + 3 * numNonLeafs;
                thresholdsBase += numNonLeafs;
                probabilitiesBase += probSize;
            }
        }
        ArrayUtils.normalize(distribution);
    }

    /** @return the number of classes reported by the forest this core was built from */
    public int numberOfClasses() {
        return this.numClasses;
    }
}

