/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.pixel_feature.filter.stats;

import net.haesleinhuepf.clij.coremem.enums.NativeTypeEnum;
import net.imglib2.Dimensions;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.convolution.Convolution;
import net.imglib2.converter.Converters;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
import org.scijava.plugin.Plugin;
import sc.fiji.labkit.pixel_classification.gpu.algorithms.GpuNeighborhoodOperations;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPixelWiseOperation;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuView;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuViews;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureOp;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.stats.AbstractSingleStatisticFeature;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.stats.SumFilter;

/**
 * Feature that computes the local variance of the input image.
 * <p>
 * For every output pixel the variance is taken over the window described by
 * {@code windowSize}, using the unbiased (sample) estimator
 * {@code (sum(x^2) - n * mean^2) / (n - 1)}, where {@code n} is the number of
 * pixels in the window (the product of the window sizes). If the window holds
 * at most one pixel, the output is set to zero.
 * <p>
 * NOTE(review): the "sum of squares minus squared mean" formulation subtracts
 * two similar float quantities, so nearly constant regions may yield slightly
 * negative values due to rounding.
 */
@Plugin(type = FeatureOp.class, label = "variance filter")
public class SingleVarianceFeature extends AbstractSingleStatisticFeature {

    @Override
    protected String filterName() {
        return "variance";
    }

    /**
     * CPU implementation.
     *
     * @param windowSize size of the window along each dimension
     * @param input source pixels; read by the window sums around {@code output}'s interval
     * @param output receives the local sample variance
     */
    @Override
    protected void apply(int[] windowSize, RandomAccessible<FloatType> input,
        RandomAccessibleInterval<FloatType> output)
    {
        long n = Intervals.numElements(windowSize);
        if (n <= 1) {
            // A single sample has no spread; also avoids dividing by (n - 1) below.
            LoopBuilder.setImages(output).forEachPixel(FloatType::setZero);
            return;
        }
        // Local mean: window sum of the input, scaled by 1 / n. The buffer is
        // translated so that it covers exactly the same interval as the output.
        RandomAccessibleInterval<FloatType> mean = Views.translate(
            ArrayImgs.floats(Intervals.dimensionsAsLongArray(output)),
            Intervals.minAsLongArray(output));
        SumFilter.convolution(windowSize).process(input, mean);
        double factor = 1.0 / n;
        LoopBuilder.setImages(mean).forEachPixel(pixel -> pixel.mul(factor));
        // Window sum of the squared input, written directly into the output.
        RandomAccessible<FloatType> squared = Converters.convert(input,
            (i, o) -> o.set(square(i.getRealFloat())), new FloatType());
        SumFilter.convolution(windowSize).process(squared, output);
        // variance = sumOfSquares / (n - 1) - mean^2 * n / (n - 1)
        float a = 1.0f / (n - 1);
        float b = (float) n / (n - 1);
        LoopBuilder.setImages(output, mean).forEachPixel(
            (o, m) -> o.setReal(o.getRealFloat() * a - square(m.getRealFloat()) * b));
    }

    /** Returns {@code x * x}. */
    private float square(float x) {
        return x * x;
    }

    /**
     * GPU implementation. It computes the same quantity as the CPU path, in the
     * algebraically equivalent form {@code (mean(x^2) - mean(x)^2) * n / (n - 1)}.
     *
     * @param gpu GPU context used to allocate temporary images and run kernels
     * @param windowSize size of the window along each dimension
     * @param input source pixels
     * @param output receives the local sample variance
     */
    @Override
    protected void apply(GpuApi gpu, int[] windowSize, GpuView input, GpuView output) {
        long n = Intervals.numElements(windowSize);
        if (n <= 1) {
            // Degenerate window: skip all temporary allocations and filtering.
            GpuPixelWiseOperation.gpu(gpu).addOutput("variance", output).forEachPixel("variance = 0");
            return;
        }
        long[] dimensions = Intervals.dimensionsAsLongArray(output.dimensions());
        GpuImage mean = gpu.create(dimensions, NativeTypeEnum.Float);
        GpuImage meanOfSquared = gpu.create(dimensions, NativeTypeEnum.Float);
        // The squared buffer matches the input size, because the mean filter
        // below reads beyond the output's extent.
        GpuImage squared = gpu.create(Intervals.dimensionsAsLongArray(input.dimensions()),
            NativeTypeEnum.Float);
        GpuNeighborhoodOperations.mean(gpu, windowSize).apply(input, GpuViews.wrap(mean));
        square(gpu, input, squared);
        GpuNeighborhoodOperations.mean(gpu, windowSize).apply(GpuViews.wrap(squared),
            GpuViews.wrap(meanOfSquared));
        GpuPixelWiseOperation.gpu(gpu)
            .addInput("mean", mean)
            .addInput("mean_of_squared", meanOfSquared)
            .addInput("factor", (float) n / (n - 1))
            .addOutput("variance", output)
            .forEachPixel("variance = (mean_of_squared - mean * mean) * factor");
    }

    /** Writes the pixel-wise square of {@code inputBuffer} into {@code tmp2} on the GPU. */
    private void square(GpuApi gpu, GpuView inputBuffer, GpuImage tmp2) {
        GpuPixelWiseOperation.gpu(gpu).addInput("a", inputBuffer).addOutput("b", tmp2).forEachPixel("b = a * a");
    }
}

