/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.pixel_feature.filter.structure;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.StringJoiner;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import net.haesleinhuepf.clij.coremem.enums.NativeTypeEnum;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.convolution.Convolution;
import net.imglib2.algorithm.convolution.kernel.Kernel1D;
import net.imglib2.algorithm.convolution.kernel.SeparableKernelConvolution;
import net.imglib2.algorithm.gauss3.Gauss3;
import net.imglib2.algorithm.linalg.eigen.EigenValues;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
import net.imglib2.view.composite.Composite;
import org.scijava.plugin.Parameter;
import org.scijava.plugin.Plugin;
import sc.fiji.labkit.pixel_classification.RevampUtils;
import sc.fiji.labkit.pixel_classification.gpu.GpuFeatureInput;
import sc.fiji.labkit.pixel_classification.gpu.algorithms.GpuEigenvalues;
import sc.fiji.labkit.pixel_classification.gpu.algorithms.GpuGauss;
import sc.fiji.labkit.pixel_classification.gpu.algorithms.GpuNeighborhoodOperation;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPixelWiseOperation;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuView;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuViews;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.AbstractFeatureOp;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureInput;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureOp;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.hessian.EigenValuesSymmetric3D;
import sc.fiji.labkit.pixel_classification.utils.views.FastViews;

/**
 * Feature that computes, per pixel, the eigenvalues of the structure tensor.
 * <p>
 * The input image is smoothed with a Gaussian of scale {@link #sigma} and
 * differentiated with central differences. The pairwise products of the partial
 * derivatives are then smoothed with a Gaussian of scale
 * {@link #integrationScale}, and the eigenvalues of the resulting symmetric
 * tensor are written to the output channels. Both scales are divided by the
 * per-dimension pixel size before use.
 * <p>
 * The per-dimension helpers distinguish only the 3D case from "everything else",
 * which they treat with the 2D layout, so only 2D and 3D images are handled
 * consistently.
 */
@Plugin(type=FeatureOp.class, label="structure tensor eigenvalues")
public class SingleStructureTensorEigenvaluesFeature
extends AbstractFeatureOp {
    /** Scale of the Gaussian applied before differentiation (divided by the pixel size per dimension). */
    @Parameter
    double sigma = 1.0;
    /** Scale of the Gaussian used to smooth the derivative products (divided by the pixel size per dimension). */
    @Parameter
    double integrationScale = 1.0;

    /**
     * Returns the number of output channels: one eigenvalue per image dimension.
     */
    @Override
    public int count() {
        return this.globalSettings().numDimensions();
    }

    /**
     * Returns one label per eigenvalue, combining the rank prefix with both scale parameters.
     * NOTE(review): getPrefix() returns 2 entries for every non-3D case, so this only agrees
     * with {@link #count()} for 2D and 3D images.
     */
    @Override
    public List<String> attributeLabels() {
        List<String> prefix = this.getPrefix();
        return prefix.stream().map(s -> "structure tensor - " + s + " eigenvalue sigma=" + this.sigma + " integrationScale=" + this.integrationScale).collect(Collectors.toList());
    }

    /**
     * CPU implementation: derivatives, products, integration blur, then eigen decomposition.
     *
     * @param input  source image access
     * @param output one image per eigenvalue; the first image's interval is the target interval
     */
    @Override
    public void apply(FeatureInput input, List<RandomAccessibleInterval<FloatType>> output) {
        Interval targetInterval = output.get(0);
        Convolution<NumericType<?>> convolution = this.gaussConvolution();
        // The integration blur needs a margin around the target, so derivatives are computed on a larger interval.
        Interval derivativeInterval = convolution.requiredSourceInterval(targetInterval);
        RandomAccessibleInterval<DoubleType> derivatives = this.derivatives(input, derivativeInterval);
        RandomAccessibleInterval<DoubleType> products = this.products(derivatives);
        // The blur target is a view onto `products` itself (target region x all product channels).
        // NOTE(review): source and target overlap; this relies on Convolution.process tolerating that — confirm.
        int channelDim = products.numDimensions() - 1;
        IntervalView<DoubleType> blurredProducts = Views.interval(products,
            Intervals.addDimension(targetInterval, products.min(channelDim), products.max(channelDim)));
        convolution.process(products, blurredProducts);
        EigenValues<DoubleType, FloatType> eigenvalueComputer = this.globalSettings().numDimensions() == 3
            ? new EigenValuesSymmetric3D<>()
            : EigenValues.symmetric2D();
        LoopBuilder.setImages(FastViews.collapse(blurredProducts), RevampUtils.vectorizeStack(output))
            .forEachPixel(eigenvalueComputer::compute);
    }

    /**
     * Divides {@code scale} by the pixel size of every dimension and returns one value per dimension.
     */
    private double[] scaledSigmas(double scale) {
        return this.globalSettings().pixelSize().stream().mapToDouble(p -> scale / p).toArray();
    }

    /**
     * Builds the separable Gaussian used as the integration blur, one kernel per dimension.
     */
    private Convolution<NumericType<?>> gaussConvolution() {
        Kernel1D[] kernels = DoubleStream.of(this.scaledSigmas(this.integrationScale))
            .mapToObj(this::gaussKernel)
            .toArray(Kernel1D[]::new);
        return SeparableKernelConvolution.convolution(kernels);
    }

    /**
     * Returns a symmetric 1D Gaussian kernel built from Gauss3's half kernel for sigma {@code v}.
     */
    private Kernel1D gaussKernel(double v) {
        return Kernel1D.symmetric(Gauss3.halfkernels(new double[]{v})[0]);
    }

    /**
     * Computes the pairwise derivative products (upper triangle of the tensor) per pixel.
     *
     * @param derivatives image whose last dimension indexes the partial derivatives
     * @return image of the same spatial interval whose last dimension indexes the products
     */
    private RandomAccessibleInterval<DoubleType> products(RandomAccessibleInterval<DoubleType> derivatives) {
        Interval interval = RevampUtils.removeLastDimension(derivatives);
        FinalInterval outputInterval = Intervals.addDimension(interval, 0L, this.getNumberOfProducts() - 1L);
        RandomAccessibleInterval<DoubleType> output = RevampUtils.createImage(outputInterval, new DoubleType());
        LoopBuilder.setImages(FastViews.collapse(derivatives), FastViews.collapse(output)).forEachPixel(this.getProductPerPixelAction());
        return output;
    }

    /**
     * Smooths the input with {@link #sigma} and takes central differences along each dimension.
     *
     * @param input              source image access
     * @param derivativeInterval spatial interval on which derivatives are required
     * @return image over {@code derivativeInterval} plus one extra dimension indexing the derivative direction
     */
    private RandomAccessibleInterval<DoubleType> derivatives(FeatureInput input, Interval derivativeInterval) {
        // One extra pixel on every side is needed for the central differences in derive().
        RandomAccessibleInterval<DoubleType> smoothed = RevampUtils.createImage(Intervals.expand(derivativeInterval, 1L), new DoubleType());
        List<Double> pixelSize = this.globalSettings().pixelSize();
        RandomAccessible<FloatType> original = input.original();
        Gauss3.gauss(this.scaledSigmas(this.sigma), original, smoothed);
        int n = derivativeInterval.numDimensions();
        RandomAccessibleInterval<DoubleType> result = RevampUtils.createImage(RevampUtils.appendDimensionToInterval(derivativeInterval, 0L, n - 1), new DoubleType());
        for (int d = 0; d < n; ++d) {
            this.derive(smoothed, Views.hyperSlice(result, n, d), d, pixelSize.get(d));
        }
        return result;
    }

    /**
     * Writes the central difference {@code (f(x+1) - f(x-1)) / (2 * pixelSize)} along dimension {@code d}
     * into every pixel of {@code tmp}.
     */
    private void derive(RandomAccessible<? extends RealType<?>> input, RandomAccessibleInterval<? extends RealType<?>> tmp, int d, double pixelSize) {
        RandomAccessibleInterval<? extends RealType<?>> back = Views.interval(input, Intervals.translate(tmp, -1L, d));
        RandomAccessibleInterval<? extends RealType<?>> front = Views.interval(input, Intervals.translate(tmp, 1L, d));
        double factor = 0.5 / pixelSize;
        LoopBuilder.setImages(tmp, back, front).forEachPixel((r, b, f) -> r.setReal((f.getRealDouble() - b.getRealDouble()) * factor));
    }

    /** Label prefixes for the eigenvalue channels: three for 3D, otherwise two. */
    private List<String> getPrefix() {
        return this.globalSettings().numDimensions() == 3 ? Arrays.asList("largest", "middle", "smallest") : Arrays.asList("largest", "smallest");
    }

    /** Number of distinct tensor entries: 6 for 3D, otherwise 3. */
    private int getNumberOfProducts() {
        return this.globalSettings().numDimensions() == 3 ? 6 : 3;
    }

    /** Selects the per-pixel product routine matching {@link #getNumberOfProducts()}. */
    private BiConsumer<Composite<DoubleType>, Composite<DoubleType>> getProductPerPixelAction() {
        return this.globalSettings().numDimensions() == 3 ? SingleStructureTensorEigenvaluesFeature::productPerPixel3d : SingleStructureTensorEigenvaluesFeature::productPerPixel2d;
    }

    /** Writes xx, xy, xz, yy, yz, zz into {@code o} from the derivatives x, y, z in {@code i}. */
    private static void productPerPixel3d(Composite<DoubleType> i, Composite<DoubleType> o) {
        double x = i.get(0L).getRealDouble();
        double y = i.get(1L).getRealDouble();
        double z = i.get(2L).getRealDouble();
        o.get(0L).setReal(x * x);
        o.get(1L).setReal(x * y);
        o.get(2L).setReal(x * z);
        o.get(3L).setReal(y * y);
        o.get(4L).setReal(y * z);
        o.get(5L).setReal(z * z);
    }

    /** Writes xx, xy, yy into {@code o} from the derivatives x, y in {@code i}. */
    private static void productPerPixel2d(Composite<DoubleType> i, Composite<DoubleType> o) {
        double x = i.get(0L).getRealDouble();
        double y = i.get(1L).getRealDouble();
        o.get(0L).setReal(x * x);
        o.get(1L).setReal(x * y);
        o.get(2L).setReal(y * y);
    }

    /**
     * Requests the GPU derivatives that {@link #apply(GpuFeatureInput, List)} will need.
     */
    @Override
    public void prefetch(GpuFeatureInput input) {
        double[] integrationSigma = this.scaledSigmas(this.integrationScale);
        double[] gaussSigma = this.scaledSigmas(this.sigma);
        // NOTE(review): this margin ((long) (4 * sigma) per dimension) is computed independently of the
        // interval apply() obtains from GpuGauss.getRequiredInputInterval; if the two differ, the prefetched
        // derivatives may not match the ones requested later — confirm.
        long[] border = DoubleStream.of(integrationSigma).mapToLong(s -> (long) (4.0 * s)).toArray();
        Interval derivativeInterval = Intervals.expand(input.targetInterval(), border);
        for (int d = 0; d < this.globalSettings().numDimensions(); ++d) {
            input.prefetchDerivative(gaussSigma[d], d, derivativeInterval);
        }
    }

    /**
     * GPU implementation. It mirrors the CPU path: derivatives, products, integration blur, eigenvalues.
     * All intermediate images live in a sub-scope that is closed on exit.
     */
    @Override
    public void apply(GpuFeatureInput input, List<GpuView> output) {
        try (GpuApi scope = input.gpuApi().subScope();){
            GpuNeighborhoodOperation integrationGauss = GpuGauss.gauss(scope, this.scaledSigmas(this.integrationScale));
            Interval border = integrationGauss.getRequiredInputInterval(input.targetInterval());
            List<GpuView> derivatives = this.derivatives(input, border);
            GpuImage products = this.products(scope, derivatives);
            GpuImage blurredProducts = this.blur(scope, GpuViews.channels(products), integrationGauss, input.targetInterval());
            GpuEigenvalues.symmetric(scope, GpuViews.channels(blurredProducts), output);
        }
    }

    /** Fetches one Gaussian derivative per dimension over {@code derivativeInterval}. */
    private List<GpuView> derivatives(GpuFeatureInput input, Interval derivativeInterval) {
        double[] gaussSigma = this.scaledSigmas(this.sigma);
        int n = this.globalSettings().numDimensions();
        List<GpuView> derivatives = new ArrayList<>(n);
        for (int d = 0; d < n; ++d) {
            derivatives.add(input.derivative(gaussSigma[d], d, derivativeInterval));
        }
        return derivatives;
    }

    /**
     * Computes all products derivative_i * derivative_j with i <= j into one multi-channel image.
     * The channel order is row-major over the upper triangle, the same layout as the CPU product routines.
     */
    private GpuImage products(GpuApi gpu, List<GpuView> derivatives) {
        int n = derivatives.size();
        int numProducts = n * (n + 1) / 2;
        long[] dimensions = Intervals.dimensionsAsLongArray(derivatives.get(0).dimensions());
        GpuPixelWiseOperation loopBuilder = GpuPixelWiseOperation.gpu(gpu);
        StringJoiner operation = new StringJoiner("; ");
        for (int i = 0; i < n; ++i) {
            loopBuilder.addInput("derivative" + i, derivatives.get(i));
        }
        GpuImage products = gpu.create(dimensions, numProducts, NativeTypeEnum.Float);
        Iterator<GpuView> iterator = GpuViews.channels(products).iterator();
        for (int i = 0; i < n; ++i) {
            for (int j = i; j < n; ++j) {
                loopBuilder.addOutput("product" + i + j, iterator.next());
                operation.add("product" + i + j + " = derivative" + i + " * derivative" + j);
            }
        }
        loopBuilder.forEachPixel(operation.toString());
        return products;
    }

    /** Applies the integration Gaussian to each product channel, producing an image of the target size. */
    private GpuImage blur(GpuApi gpu, List<GpuView> products, GpuNeighborhoodOperation integrationGauss, Interval targetInterval) {
        long[] dimensions = Intervals.dimensionsAsLongArray(targetInterval);
        GpuImage blurred = gpu.create(dimensions, products.size(), NativeTypeEnum.Float);
        List<GpuView> blurredChannels = GpuViews.channels(blurred);
        for (int i = 0; i < products.size(); ++i) {
            integrationGauss.apply(products.get(i), blurredChannels.get(i));
        }
        return blurred;
    }
}

