/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.gpu.random_forest;

import hr.irb.fastRandomForest.FastRandomForest;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import net.haesleinhuepf.clij.coremem.enums.NativeTypeEnum;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.RealType;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;
import sc.fiji.labkit.pixel_classification.gpu.random_forest.GpuRandomForestKernel;
import sc.fiji.labkit.pixel_classification.gpu.random_forest.GpuRandomTreePrediction;
import sc.fiji.labkit.pixel_classification.random_forest.TransparentRandomForest;
import sc.fiji.labkit.pixel_classification.utils.ArrayUtils;

/**
 * A random forest whose trees are flattened into primitive arrays so that
 * pixel-wise prediction can run as a GPU kernel.
 * <p>
 * Every tree occupies the same number of slots in the arrays: the maximum
 * node count (respectively leaf count) over all trees of the forest. Trees
 * with fewer nodes or leafs leave their trailing slots zero-filled.
 * <p>
 * Node and child indices are stored as 16-bit {@code short}s. A negative child
 * index denotes a leaf, decoded as {@code index - Short.MIN_VALUE}. Hence each
 * tree may have at most {@code 2^15} nodes and at most {@code 2^15} leafs.
 */
public class GpuRandomForestPrediction {

    /** Maximum number of nodes (or leafs) per tree representable by the 16-bit index encoding. */
    private static final int MAX_INDEX_COUNT = Short.MAX_VALUE + 1;

    private final int numberOfClasses;
    private final int numberOfFeatures;
    private final int numberOfTrees;
    /** Node slots reserved per tree (maximum node count over all trees). */
    private final int numberOfNodes;
    /** Leaf slots reserved per tree (maximum leaf count over all trees). */
    private final int numberOfLeafs;
    /** Per node slot, three shorts: attribute index, smaller-child index, bigger-child index. */
    private final short[] nodeIndices;
    /** Per node slot, the threshold the attribute value is compared against. */
    private final float[] nodeThresholds;
    /** Per leaf slot, one probability per class. */
    private final float[] leafProbabilities;

    /**
     * Flattens the trees of the given classifier into the GPU-friendly array layout.
     *
     * @param classifier the trained random forest
     * @param numberOfFeatures number of features (channels) of the feature stacks
     *        that will be classified; forwarded to the GPU kernel
     * @throws IllegalArgumentException if a tree has more nodes or leafs than the
     *         16-bit index encoding can represent
     */
    public GpuRandomForestPrediction(FastRandomForest classifier, int numberOfFeatures) {
        TransparentRandomForest forest = TransparentRandomForest.forFastRandomForest(classifier);
        List<GpuRandomTreePrediction> trees = forest.trees().stream()
            .map(GpuRandomTreePrediction::new)
            .collect(Collectors.toList());
        this.numberOfClasses = forest.numberOfClasses();
        this.numberOfFeatures = numberOfFeatures;
        this.numberOfTrees = trees.size();
        this.numberOfNodes = trees.stream().mapToInt(tree -> tree.numberOfNodes).max().orElse(0);
        this.numberOfLeafs = trees.stream().mapToInt(tree -> tree.numberOfLeafs).max().orElse(0);
        // Indices are narrowed to short below; larger trees would be silently corrupted.
        checkIndexCount("nodes", numberOfNodes);
        checkIndexCount("leafs", numberOfLeafs);
        this.nodeIndices = new short[numberOfTrees * numberOfNodes * 3];
        this.nodeThresholds = new float[numberOfTrees * numberOfNodes];
        this.leafProbabilities = new float[numberOfTrees * numberOfLeafs * numberOfClasses];
        for (int t = 0; t < numberOfTrees; ++t) {
            GpuRandomTreePrediction tree = trees.get(t);
            for (int i = 0; i < tree.numberOfNodes; ++i) {
                int node = t * numberOfNodes + i;
                nodeIndices[node * 3] = (short) tree.attributeIndicies[i];
                nodeIndices[node * 3 + 1] = (short) tree.smallerChild[i];
                nodeIndices[node * 3 + 2] = (short) tree.biggerChild[i];
                nodeThresholds[node] = (float) tree.threshold[i];
            }
            for (int i = 0; i < tree.numberOfLeafs; ++i) {
                int leafOffset = (t * numberOfLeafs + i) * numberOfClasses;
                for (int k = 0; k < numberOfClasses; ++k) {
                    leafProbabilities[leafOffset + k] = (float) tree.classProbabilities[i][k];
                }
            }
        }
    }

    /**
     * Rejects index counts that cannot be encoded in 16 bits.
     *
     * @param what human-readable name of the counted entity, used in the message
     * @param count number of nodes or leafs in the largest tree
     */
    private static void checkIndexCount(String what, int count) {
        if (count > MAX_INDEX_COUNT) {
            throw new IllegalArgumentException("Random forest tree has " + count + " " + what
                + ", but at most " + MAX_INDEX_COUNT + " are supported.");
        }
    }

    /** @return the number of classes predicted by the forest */
    public int numberOfClasses() {
        return numberOfClasses;
    }

    /** @return the number of features passed to the constructor */
    public int numberOfFeatures() {
        return numberOfFeatures;
    }

    /**
     * Uploads the flattened forest to the GPU and runs the random forest kernel,
     * writing the per-class results into {@code distribution}.
     * <p>
     * The uploaded forest buffers are created in a sub-scope of {@code gpu}, which
     * is closed when this method returns.
     *
     * @param gpu GPU context in which a temporary sub-scope is opened
     * @param featureStack the feature stack to classify
     * @param distribution output image that the kernel writes into
     */
    public void distribution(GpuApi gpu, GpuImage featureStack, GpuImage distribution) {
        try (GpuApi scope = gpu.subScope()) {
            RandomAccessibleInterval<? extends RealType<?>> indices =
                ArrayImgs.unsignedShorts(nodeIndices, new long[] { 3, numberOfNodes, numberOfTrees });
            RandomAccessibleInterval<? extends RealType<?>> thresholds =
                ArrayImgs.floats(nodeThresholds, new long[] { 1, numberOfNodes, numberOfTrees });
            RandomAccessibleInterval<? extends RealType<?>> probabilities =
                ArrayImgs.floats(leafProbabilities, new long[] { numberOfClasses, numberOfLeafs, numberOfTrees });
            GpuImage thresholdsClBuffer = scope.push(thresholds);
            GpuImage probabilitiesClBuffer = scope.push(probabilities);
            GpuImage indicesClBuffer = scope.push(indices);
            GpuRandomForestKernel.randomForest(scope, distribution, featureStack, thresholdsClBuffer,
                probabilitiesClBuffer, indicesClBuffer, numberOfFeatures);
        }
    }

    /**
     * Classifies every pixel of the feature stack and returns the label image
     * produced by {@code GpuRandomForestKernel.findMax} (presumably the index of
     * the most probable class per pixel).
     *
     * @param gpu GPU context; the returned image is allocated on it
     * @param featureStack the feature stack to classify
     * @return an unsigned-short image with the dimensions of the distribution image
     */
    public GpuImage segment(GpuApi gpu, GpuImage featureStack) {
        try (GpuApi scope = gpu.subScope()) {
            GpuImage distribution = scope.create(featureStack.getDimensions(), numberOfClasses, NativeTypeEnum.Float);
            distribution(scope, featureStack, distribution);
            // Allocated on 'gpu', not 'scope' -- presumably so it is not released when the sub-scope closes.
            GpuImage output = gpu.create(distribution.getDimensions(), NativeTypeEnum.UnsignedShort);
            GpuRandomForestKernel.findMax(scope, distribution, output);
            return output;
        }
    }

    /**
     * CPU reference implementation: sums the leaf probabilities reached in every
     * tree, then normalizes the result with {@code ArrayUtils.normalize}.
     *
     * @param instance feature values of one sample, indexed by attribute index
     * @param distribution output array with one entry per class; overwritten
     */
    private void distributionForInstance(float[] instance, float[] distribution) {
        Arrays.fill(distribution, 0.0f);
        for (int tree = 0; tree < numberOfTrees; ++tree) {
            addDistributionForTree(instance, tree, distribution);
        }
        ArrayUtils.normalize(distribution);
    }

    /**
     * Walks one tree from its root (node 0) to a leaf and adds that leaf's class
     * probabilities to {@code distribution}.
     *
     * @param instance feature values of one sample
     * @param tree index of the tree to evaluate
     * @param distribution accumulator with one entry per class
     */
    private void addDistributionForTree(float[] instance, int tree, float[] distribution) {
        int node = 0;
        // Non-negative indices are inner nodes; a negative index marks a leaf.
        while (node >= 0) {
            int nodeOffset = tree * numberOfNodes + node;
            short attributeIndex = nodeIndices[nodeOffset * 3];
            float attributeValue = instance[attributeIndex];
            // Slot 1 holds the smaller child, slot 2 the bigger child. Values that are
            // not strictly below the threshold (including NaN) go to the bigger child.
            int childSlot = attributeValue < nodeThresholds[nodeOffset] ? 1 : 2;
            node = nodeIndices[nodeOffset * 3 + childSlot];
        }
        int leaf = node - Short.MIN_VALUE;
        int leafOffset = (tree * numberOfLeafs + leaf) * numberOfClasses;
        for (int k = 0; k < numberOfClasses; ++k) {
            distribution[k] += leafProbabilities[leafOffset + k];
        }
    }
}

