/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.random_forest;

import hr.irb.fastRandomForest.FastRandomForest;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.composite.Composite;
import sc.fiji.labkit.pixel_classification.random_forest.CpuRandomForestCore;
import sc.fiji.labkit.pixel_classification.utils.ArrayUtils;
import sc.fiji.labkit.pixel_classification.utils.views.FastViews;

/**
 * Applies a random forest to a stack of per-pixel feature values on the CPU.
 *
 * <p>The forest is wrapped in a {@link CpuRandomForestCore}; for every pixel the
 * feature values are copied into a scratch array, handed to
 * {@code CpuRandomForestCore.distributionForInstance}, and the resulting
 * per-class values are either reduced to a single class index
 * ({@link #segment}) or written out as-is ({@link #distribution}).
 */
public class CpuRandomForestPrediction {
    private final CpuRandomForestCore core;
    private final int numberOfFeatures;

    /**
     * Creates a predictor backed by the given forest.
     *
     * @param forest the trained forest, passed on to {@link CpuRandomForestCore}
     * @param numberOfFeatures number of feature values read per pixel; must not be negative
     * @throws IllegalArgumentException if {@code numberOfFeatures} is negative
     */
    public CpuRandomForestPrediction(FastRandomForest forest, int numberOfFeatures) {
        // Fail fast: a negative count would otherwise only surface later as a
        // NegativeArraySizeException when the scratch buffers are allocated.
        if (numberOfFeatures < 0) {
            throw new IllegalArgumentException(
                "numberOfFeatures must not be negative: " + numberOfFeatures);
        }
        this.numberOfFeatures = numberOfFeatures;
        this.core = new CpuRandomForestCore(forest);
    }

    /**
     * Writes one class index per pixel into {@code out}.
     *
     * <p>The index is whatever {@code ArrayUtils.findMax} returns for the per-class
     * values of that pixel (presumably the position of the largest value — confirm
     * against {@code ArrayUtils}).
     *
     * @param featureStack feature image; it is collapsed with {@code FastViews.collapse}
     *     and the first {@link #numberOfFeatures()} components of each pixel are read
     * @param out image that receives the class index of every pixel
     */
    public void segment(RandomAccessibleInterval<FloatType> featureStack, RandomAccessibleInterval<? extends IntegerType<?>> out) {
        LoopBuilder.setImages(FastViews.collapse(featureStack), out).forEachChunk(chunk -> {
            // Scratch buffers are allocated once per chunk and reused for all of its pixels.
            float[] features = new float[this.numberOfFeatures];
            float[] probabilities = new float[this.numberOfClasses()];
            chunk.forEachPixel((featureVector, classIndex) -> {
                copyFromTo((Composite<FloatType>) featureVector, features);
                this.core.distributionForInstance(features, probabilities);
                classIndex.setInteger(ArrayUtils.findMax(probabilities));
            });
            return null;
        });
    }

    /**
     * Writes the per-class values of every pixel into {@code out}.
     *
     * @param featureStack feature image; it is collapsed with {@code FastViews.collapse}
     *     and the first {@link #numberOfFeatures()} components of each pixel are read
     * @param out image that is collapsed the same way; components
     *     {@code 0 .. numberOfClasses() - 1} of each pixel are overwritten
     */
    public void distribution(RandomAccessibleInterval<FloatType> featureStack, RandomAccessibleInterval<? extends RealType<?>> out) {
        LoopBuilder.setImages(FastViews.collapse(featureStack), FastViews.collapse(out)).forEachChunk(chunk -> {
            // Scratch buffers are allocated once per chunk and reused for all of its pixels.
            float[] features = new float[this.numberOfFeatures];
            float[] probabilities = new float[this.numberOfClasses()];
            chunk.forEachPixel((featureVector, probabilityVector) -> {
                copyFromTo((Composite<FloatType>) featureVector, features);
                this.core.distributionForInstance(features, probabilities);
                copyFromTo(probabilities, probabilityVector);
            });
            return null;
        });
    }

    /** Fills {@code output} with the first {@code output.length} components of {@code input}. */
    private static void copyFromTo(Composite<FloatType> input, float[] output) {
        for (int i = 0; i < output.length; i++) {
            output[i] = input.get(i).getRealFloat();
        }
    }

    /** Stores every element of {@code input} into the component of {@code output} with the same index. */
    private static void copyFromTo(float[] input, Composite<? extends RealType<?>> output) {
        for (int i = 0; i < input.length; i++) {
            // setReal(float) does not depend on the element's type parameter,
            // so it can be called through the wildcard without a raw-type cast.
            output.get(i).setReal(input[i]);
        }
    }

    /** Returns the number of feature values read per pixel, as given to the constructor. */
    public int numberOfFeatures() {
        return this.numberOfFeatures;
    }

    /** Returns the number of classes reported by the underlying {@link CpuRandomForestCore}. */
    public int numberOfClasses() {
        return this.core.numberOfClasses();
    }
}

