/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.gpu.algorithms;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.util.Intervals;
import org.apache.commons.lang3.ArrayUtils;
import sc.fiji.labkit.pixel_classification.gpu.algorithms.GpuKernelConvolution;
import sc.fiji.labkit.pixel_classification.gpu.algorithms.GpuNeighborhoodOperation;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuView;

/**
 * A {@link GpuNeighborhoodOperation} that applies a one-dimensional window operation
 * ({@link Operation#MIN}, {@link Operation#MAX} or {@link Operation#MEAN}) of length
 * {@code windowSize} along a single image axis {@code d}.
 * <p>
 * The work is done by the {@code separable_operation} kernel of the {@code .cl} file that
 * belongs to the chosen {@link Operation}. The kernel iterates along its own x axis. This
 * class maps kernel x onto image axis {@code d} by swapping entry {@code 0} and entry
 * {@code d} of both the global work size and the pixel strides handed to the kernel, so
 * the same kernel source serves every axis. The static {@link #convolve} method uses the
 * same scheme for a 1D convolution on images.
 */
class GpuSeparableOperation implements GpuNeighborhoodOperation {

    /** GPU the kernel is executed on. */
    private final GpuApi gpu;

    /** Selects the kernel source file (min, max or mean). */
    private final Operation operation;

    /** Number of pixels covered by the window along axis {@link #d}. */
    private final int windowSize;

    /** Index of the axis the operation runs along. */
    private final int d;

    /**
     * @param gpu        GPU to execute the kernel on
     * @param operation  window operation to apply
     * @param windowSize length of the window along axis {@code d}
     * @param d          axis to operate along; the kernel set-up in this class only models
     *                   three axes, so presumably {@code 0}, {@code 1} or {@code 2} (TODO
     *                   confirm)
     */
    GpuSeparableOperation(GpuApi gpu, Operation operation, int windowSize, int d) {
        this.gpu = gpu;
        this.operation = operation;
        this.windowSize = windowSize;
        this.d = d;
    }

    /**
     * Returns {@code targetInterval} grown along axis {@code d}. The lower end grows by
     * {@code windowSize / 2} pixels and the upper end by {@code (windowSize - 1) / 2}
     * pixels. For {@code windowSize >= 1} that is {@code windowSize - 1} extra pixels in
     * total. All other axes are left unchanged.
     */
    @Override
    public Interval getRequiredInputInterval(Interval targetInterval) {
        long[] min = Intervals.minAsLongArray(targetInterval);
        long[] max = Intervals.maxAsLongArray(targetInterval);
        min[d] -= windowSize / 2;
        max[d] += (windowSize - 1) / 2;
        return new FinalInterval(min, max);
    }

    /**
     * Runs the configured operation on {@code input} and writes the result to
     * {@code output}. This method does not check that {@code input} covers
     * {@link #getRequiredInputInterval} of the output interval; presumably the caller
     * guarantees that.
     */
    @Override
    public void apply(GpuView input, GpuView output) {
        run(gpu, operation.kernelFile, windowSize, new HashMap<String, Object>(), input, output, d);
    }

    /**
     * Executes the {@code separable_operation} kernel from {@code kernelFile} on two
     * buffer views. Pixels are accessed as {@code float} through an offset plus per-axis
     * strides (see {@link #setSkips}).
     * <p>
     * Side effect: the entries {@code "input"} and {@code "output"} are put into
     * {@code parameters}, overwriting any existing values in the caller's map.
     *
     * @param gpu        GPU to execute on
     * @param kernelFile kernel source file name, passed to {@code gpu.execute} together with
     *                   {@code GpuKernelConvolution.class}
     * @param windowSize value of the {@code KERNEL_LENGTH} define
     * @param parameters kernel arguments; receives the two views' source buffers
     * @param input      view to read from
     * @param output     view to write to; its size defines the work size
     * @param d          image axis the kernel's x axis is mapped onto
     */
    static void run(GpuApi gpu, String kernelFile, long windowSize,
            HashMap<String, Object> parameters, GpuView input, GpuView output, int d) {
        parameters.put("input", input.source());
        parameters.put("output", output.source());
        HashMap<String, Object> defines = new HashMap<>();
        long[] localSizes = localSizesForLine(output.dimensions().dimension(d));
        defines.put("KERNEL_LENGTH", windowSize);
        defines.put("BLOCK_SIZE", localSizes[0]);
        setSkips(defines, "INPUT", input, d);
        setSkips(defines, "OUTPUT", output, d);
        defines.put("OUTPUT_IMAGE_PARAMETER", "__global float* output");
        defines.put("INPUT_IMAGE_PARAMETER", "__global float* input");
        defines.put("OUTPUT_WRITE_PIXEL(x,y,z,v)", "output[OUTPUT_OFFSET + OUTPUT_X_SKIP * (x) + OUTPUT_Y_SKIP * (y) + OUTPUT_Z_SKIP * (z)] = v;");
        defines.put("INPUT_READ_PIXEL(x,y,z)", "input[INPUT_OFFSET + INPUT_X_SKIP * (x) + INPUT_Y_SKIP * (y) + INPUT_Z_SKIP * (z)]");
        // Fresh array (padded to 3 axes); swapping makes kernel x run along image axis d.
        long[] globalSizes = getDimensions(output.dimensions());
        ArrayUtils.swap(globalSizes, 0, d);
        gpu.execute(GpuKernelConvolution.class, kernelFile, "separable_operation", globalSizes, localSizes, parameters, defines);
    }

    /**
     * Convolves {@code input} with the one-dimensional {@code kernel} image along axis
     * {@code d} and writes the result to {@code output}, using {@code convolve1d.cl}.
     * When reading {@code input}, {@code kernel_center} is subtracted from the coordinate
     * along axis {@code d}.
     *
     * @param gpu           GPU to execute on
     * @param kernel        kernel values; its width is used as {@code KERNEL_LENGTH}
     * @param input         image to read from
     * @param kernel_center offset subtracted from the axis-{@code d} read coordinate
     * @param output        image to write to; its dimensions define the work size
     * @param d             axis to convolve along
     */
    public static void convolve(GpuApi gpu, GpuImage kernel, GpuImage input, int kernel_center, GpuImage output, int d) {
        HashMap<String, Object> parameters = new HashMap<>();
        parameters.put("input", input);
        parameters.put("kernelValues", kernel);
        parameters.put("output", output);
        // Work on a copy: getDimensions() may return the image's own array, and the array
        // is permuted below. Swapping it in place would corrupt the output's dimensions.
        long[] globalSizes = output.getDimensions().clone();
        long[] localSizes = localSizesForLine(globalSizes[d]);
        HashMap<String, Object> defines = new HashMap<>();
        defines.put("KERNEL_LENGTH", kernel.getWidth());
        defines.put("BLOCK_SIZE", localSizes[0]);
        defines.put("OUTPUT_IMAGE_PARAMETER", "IMAGE_output_TYPE output");
        defines.put("INPUT_IMAGE_PARAMETER", "IMAGE_input_TYPE input");
        defines.put("OUTPUT_WRITE_PIXEL(cx,cy,cz,v)", "WRITE_output_IMAGE(output, POS_output_INSTANCE(" + position(d, 0) + ",0), v)");
        defines.put("INPUT_READ_PIXEL(cx,cy,cz)", "READ_input_IMAGE(input, sampler, POS_input_INSTANCE(" + position(d, kernel_center) + ",0)).x");
        ArrayUtils.swap(globalSizes, 0, d);
        gpu.execute(GpuKernelConvolution.class, "convolve1d.cl", "separable_operation", globalSizes, localSizes, parameters, defines);
    }

    /**
     * Returns a local work size in which one work-group spans a whole line.
     * {@code lineLength} goes on the first kernel axis and {@code 1} on the other two.
     */
    private static long[] localSizesForLine(long lineLength) {
        long[] localSizes = new long[3];
        Arrays.fill(localSizes, 1L);
        localSizes[0] = lineLength;
        return localSizes;
    }

    /**
     * Builds the {@code "(x),(y),(z)"} coordinate text used inside the pixel macros.
     * The kernel coordinates {@code cx, cy, cz} are used, with {@code kernel_center}
     * subtracted from {@code cx}. Entries 0 and {@code d} are then swapped, so that
     * {@code cx} addresses image axis {@code d}.
     */
    private static String position(int d, int kernel_center) {
        ArrayList<String> list = new ArrayList<>(Arrays.asList("(cx) - " + kernel_center, "(cy)", "(cz)"));
        Collections.swap(list, 0, d);
        return "(" + list.get(0) + "),(" + list.get(1) + "),(" + list.get(2) + ")";
    }

    /**
     * Puts {@code <prefix>_OFFSET} and {@code <prefix>_X/Y/Z_SKIP} into {@code defines}
     * for {@code view}. The strides assume an x-fastest buffer layout
     * ({@code 1, width, width * height}). Entries 0 and {@code d} are swapped, so that the
     * kernel's x stride steps along image axis {@code d}.
     */
    private static void setSkips(HashMap<String, Object> defines, String prefix, GpuView view, int d) {
        GpuImage buffer = view.source();
        long[] skip = {1L, buffer.getWidth(), buffer.getWidth() * buffer.getHeight()};
        defines.put(prefix + "_OFFSET", view.offset());
        ArrayUtils.swap(skip, 0, d);
        defines.put(prefix + "_X_SKIP", skip[0]);
        defines.put(prefix + "_Y_SKIP", skip[1]);
        defines.put(prefix + "_Z_SKIP", skip[2]);
    }

    /** Returns the sizes of the first three axes, using 1 for axes that do not exist. */
    private static long[] getDimensions(Dimensions dimensions) {
        return new long[]{getDimension(dimensions, 0), getDimension(dimensions, 1), getDimension(dimensions, 2)};
    }

    /** Returns the size of axis {@code d}, or 1 if {@code dimensions} has fewer axes. */
    private static long getDimension(Dimensions dimensions, int d) {
        return d < dimensions.numDimensions() ? dimensions.dimension(d) : 1L;
    }

    /** Supported window operations, each backed by its own kernel source file. */
    enum Operation {
        MIN("min1d.cl"),
        MAX("max1d.cl"),
        MEAN("mean1d.cl");

        /** Kernel source file passed to {@code gpu.execute}. */
        private final String kernelFile;

        Operation(String kernelFile) {
            this.kernelFile = kernelFile;
        }
    }
}

