/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.pixel_feature.filter.gradient;

import java.util.Collections;
import java.util.List;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import org.scijava.plugin.Parameter;
import org.scijava.plugin.Plugin;
import sc.fiji.labkit.pixel_classification.gpu.GpuFeatureInput;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPixelWiseOperation;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuView;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.AbstractFeatureOp;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureInput;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureOp;

/**
 * Feature that computes the magnitude of the Gaussian gradient at a single
 * scale {@link #sigma}.
 * <p>
 * For each pixel, a first-order Gaussian derivative is taken along every axis.
 * The single output channel is {@code sqrt(dx^2 + dy^2)} for 2D images and
 * {@code sqrt(dx^2 + dy^2 + dz^2)} for 3D images. Only 2D and 3D images are
 * supported. Both the CPU and the GPU code paths reject other dimensionalities
 * with the same {@link AssertionError}, instead of silently computing a
 * partial gradient.
 */
@Plugin(type=FeatureOp.class, label="gaussian gradient magnitude")
public class SingleGaussianGradientMagnitudeFeature
extends AbstractFeatureOp {
    /** Scale passed to every Gaussian derivative computed by this feature. */
    @Parameter
    private double sigma = 1.0;

    /**
     * Returns the number of output channels.
     *
     * @return always {@code 1}: the gradient magnitude
     */
    @Override
    public int count() {
        return 1;
    }

    /**
     * Returns the label of the single output channel, including the sigma used.
     *
     * @return a singleton list holding the channel label
     */
    @Override
    public List<String> attributeLabels() {
        return Collections.singletonList("gaussian gradient magnitude sigma=" + this.sigma);
    }

    /**
     * Computes the gradient magnitude on the CPU and writes it to the first
     * output image.
     *
     * @param input  source of the Gaussian derivatives
     * @param output list whose first element receives the result
     * @throws AssertionError if the configured image is neither 2D nor 3D
     */
    @Override
    public void apply(FeatureInput input, List<RandomAccessibleInterval<FloatType>> output) {
        // Validate before touching the output list, so an unsupported
        // dimensionality is reported as such rather than as an index error.
        if (this.requireSupportedDimensionCount() == 3) {
            this.apply3d(input, output.get(0));
        } else {
            this.apply2d(input, output.get(0));
        }
    }

    /**
     * Returns the configured number of image dimensions after checking that
     * this feature supports it.
     *
     * @return the number of dimensions, either 2 or 3
     * @throws AssertionError if the number of dimensions is neither 2 nor 3
     */
    private int requireSupportedDimensionCount() {
        int n = this.globalSettings().numDimensions();
        if (n != 2 && n != 3) {
            throw new AssertionError(
                "SingleGaussianGradientMagnitudeFeature supports only 2D and 3D images, but numDimensions() = " + n);
        }
        return n;
    }

    /**
     * Writes {@code sqrt(dx^2 + dy^2 + dz^2)} to {@code output}, pixel by pixel.
     * The sum is computed in double precision and stored as float.
     */
    private void apply3d(FeatureInput input, RandomAccessibleInterval<FloatType> output) {
        RandomAccessibleInterval<DoubleType> dx = this.derive(input, 0);
        RandomAccessibleInterval<DoubleType> dy = this.derive(input, 1);
        RandomAccessibleInterval<DoubleType> dz = this.derive(input, 2);
        LoopBuilder.setImages(dx, dy, dz, output).forEachPixel((x, y, z, o) ->
            o.setReal(magnitude(x.getRealDouble(), y.getRealDouble(), z.getRealDouble())));
    }

    /**
     * Writes {@code sqrt(dx^2 + dy^2)} to {@code output}, pixel by pixel.
     * The sum is computed in double precision and stored as float.
     */
    private void apply2d(FeatureInput input, RandomAccessibleInterval<FloatType> output) {
        RandomAccessibleInterval<DoubleType> dx = this.derive(input, 0);
        RandomAccessibleInterval<DoubleType> dy = this.derive(input, 1);
        LoopBuilder.setImages(dx, dy, output).forEachPixel((x, y, o) ->
            o.setReal(magnitude(x.getRealDouble(), y.getRealDouble())));
    }

    /** Euclidean norm of a 2D vector. */
    private static double magnitude(double x, double y) {
        return Math.sqrt(square(x) + square(y));
    }

    /** Euclidean norm of a 3D vector. */
    private static double magnitude(double x, double y, double z) {
        return Math.sqrt(square(x) + square(y) + square(z));
    }

    /** Returns {@code x * x}. */
    private static double square(double x) {
        return x * x;
    }

    /**
     * Requests the Gaussian derivative of the input with derivative order 1
     * along axis {@code d} and order 0 along every other axis.
     *
     * @param input source of the derivative
     * @param d     axis to differentiate along
     * @return the derivative image as provided by {@link FeatureInput#derivedGauss}
     */
    private RandomAccessibleInterval<DoubleType> derive(FeatureInput input, int d) {
        int[] orders = new int[this.globalSettings().numDimensions()];
        orders[d] = 1;
        return input.derivedGauss(this.sigma, orders);
    }

    /**
     * Announces to the GPU input which first-order derivatives will be needed:
     * one per image dimension, all at {@link #sigma} over the target interval.
     */
    @Override
    public void prefetch(GpuFeatureInput input) {
        for (int d = 0; d < this.globalSettings().numDimensions(); ++d) {
            input.prefetchDerivative(this.sigma, d, input.targetInterval());
        }
    }

    /**
     * Computes the gradient magnitude on the GPU and writes it to the first
     * output view.
     *
     * @param input  source of the GPU derivatives and the {@link GpuApi}
     * @param output list whose first element receives the result
     * @throws AssertionError if the configured image is neither 2D nor 3D.
     *         Previously any non-3D image was silently treated as 2D.
     */
    @Override
    public void apply(GpuFeatureInput input, List<GpuView> output) {
        // Check first so that no GPU work is queued for an unsupported image.
        boolean is3d = this.requireSupportedDimensionCount() == 3;
        GpuApi gpu = input.gpuApi();
        GpuPixelWiseOperation loopBuilder = GpuPixelWiseOperation.gpu(gpu);
        loopBuilder.addInput("dx", input.derivative(this.sigma, 0, input.targetInterval()));
        loopBuilder.addInput("dy", input.derivative(this.sigma, 1, input.targetInterval()));
        if (is3d) {
            loopBuilder.addInput("dz", input.derivative(this.sigma, 2, input.targetInterval()));
        }
        loopBuilder.addOutput("output", output.get(0));
        String operation = is3d ? "output = sqrt(dx * dx + dy * dy + dz * dz)" : "output = sqrt(dx * dx + dy * dy)";
        loopBuilder.forEachPixel(operation);
    }
}

