/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.gpu.random_forest;

import java.util.HashMap;
import net.imglib2.util.Intervals;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;

/**
 * Static helpers that launch the random-forest OpenCL kernels ({@code random_forest.cl} and
 * {@code find_max.cl}) through a {@link GpuApi}.
 */
public class GpuRandomForestKernel {

    /**
     * Upper bound, in bytes, on the combined size of the {@code thresholds} and {@code indices}
     * buffers for them to be declared {@code __constant} in the kernel. 65536 bytes (64 KiB) is
     * an assumption, not a queried device limit. NOTE(review): presumably chosen to match the
     * minimum {@code CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE} that OpenCL guarantees; confirm.
     */
    private static final long ASSUMED_CONSTANT_MEMORY_SIZE = 65536L;

    /**
     * Launches the {@code random_forest} kernel from {@code random_forest.cl}. The global work
     * size equals the dimensions of {@code src}; no local work size is specified.
     *
     * @param gpu GPU context used to compile and execute the kernel
     * @param distributions image bound to the kernel argument {@code dst}
     * @param src image bound to the kernel argument {@code src}; its dimensions define the
     *     global work size
     * @param thresholds image bound to the kernel argument {@code thresholds}
     * @param probabilities image bound to the kernel argument {@code probabilities}; its width
     *     is passed as the compile-time constant {@code NUMBER_OF_CLASSES}
     * @param indices image bound to the kernel argument {@code indices}; its total element count
     *     is passed as the compile-time constant {@code INDICES_SIZE}
     * @param numberOfFeatures passed as the compile-time constant {@code NUMBER_OF_FEATURES}
     */
    public static void randomForest(GpuApi gpu, GpuImage distributions, GpuImage src, GpuImage thresholds, GpuImage probabilities, GpuImage indices, int numberOfFeatures) {
        // Copy so the kernel launch cannot alias the image's own dimensions array.
        long[] globalSizes = src.getDimensions().clone();

        HashMap<String, Object> parameters = new HashMap<>();
        parameters.put("src", src);
        parameters.put("dst", distributions);
        parameters.put("thresholds", thresholds);
        parameters.put("probabilities", probabilities);
        parameters.put("indices", indices);

        HashMap<String, Object> constants = new HashMap<>();
        constants.put("NUMBER_OF_CLASSES", probabilities.getWidth());
        constants.put("NUMBER_OF_FEATURES", numberOfFeatures);
        constants.put("INDICES_SIZE", Intervals.numElements(indices.getDimensions()));
        // Chooses the OpenCL address-space qualifier for the thresholds/indices buffers.
        constants.put("CONSTANT_OR_GLOBAL", appropriateMemory(thresholds, indices));

        gpu.execute(GpuRandomForestKernel.class, "random_forest.cl", "random_forest", globalSizes, null, parameters, constants);
    }

    /**
     * Picks the OpenCL address-space qualifier for the {@code thresholds} and {@code indices}
     * buffers.
     *
     * @return {@code "__constant"} if the combined byte size of both buffers is strictly below
     *     {@link #ASSUMED_CONSTANT_MEMORY_SIZE}, otherwise {@code "__global"}
     */
    private static String appropriateMemory(GpuImage thresholds, GpuImage indices) {
        long requiredConstantMemory = thresholds.clearCLBuffer().getSizeInBytes() + indices.clearCLBuffer().getSizeInBytes();
        boolean fitsConstantMemory = requiredConstantMemory < ASSUMED_CONSTANT_MEMORY_SIZE;
        return fitsConstantMemory ? "__constant" : "__global";
    }

    /**
     * Launches the {@code find_max} kernel from {@code find_max.cl}, with a 3-D global work size
     * of {@code dst}'s width, height and depth and no local work size.
     *
     * @param gpu GPU context used to compile and execute the kernel
     * @param distributions image bound to the kernel argument {@code src}; its channel count is
     *     passed as the {@code num_classes} argument
     * @param dst image bound to the kernel argument {@code dst}
     */
    public static void findMax(GpuApi gpu, GpuImage distributions, GpuImage dst) {
        long[] globalSizes = new long[]{dst.getWidth(), dst.getHeight(), dst.getDepth()};

        HashMap<String, Object> parameters = new HashMap<>();
        parameters.put("dst", dst);
        parameters.put("src", distributions);
        // The cast fixes the boxed argument type handed to the kernel as Integer.
        parameters.put("num_classes", (int) distributions.getNumberOfChannels());

        gpu.execute(GpuRandomForestKernel.class, "find_max.cl", "find_max", globalSizes, null, parameters, null);
    }
}

