/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.algorithm.metrics.imagequality;

import java.util.Arrays;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.convolution.fast_gauss.FastGauss;
import net.imglib2.algorithm.gauss3.Gauss3;
import net.imglib2.converter.Converters;
import net.imglib2.img.Img;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.util.Intervals;
import net.imglib2.util.RealSum;
import net.imglib2.util.Util;
import net.imglib2.view.ExtendedRandomAccessibleInterval;
import net.imglib2.view.Views;

/**
 * Structural similarity index (SSIM) between two 2D or 3D images.
 *
 * <p>Local means, variances and the covariance are estimated by Gaussian smoothing with
 * standard deviation {@code sigma}. The per-pixel SSIM map is evaluated only on the image
 * interior (a margin of {@code padding} pixels, derived from {@code sigma}, is cropped from
 * every border) and its mean is returned.
 */
public class SSIM {
    /** Luminance stabilization factor: {@code C1 = (K1 * range)^2}. */
    public static final double K1 = 0.01;
    /** Contrast/structure stabilization factor: {@code C2 = (K2 * range)^2}. */
    public static final double K2 = 0.03;

    /**
     * Computes the mean SSIM between a reference image and a processed image.
     *
     * <p>The dynamic range used in the stabilization constants is taken from the pixel type as
     * {@code getMaxValue() - getMinValue()}. NOTE(review): for floating-point pixel types these
     * bounds are typically huge, which makes the constants dominate; this method presumably
     * targets integer types — confirm.
     *
     * @param reference reference image (2D or 3D)
     * @param processed processed image; must have the same dimensions as {@code reference}
     * @param sigma standard deviation of the Gaussian window; must be finite and non-negative
     * @param <T> pixel type
     * @return mean of the SSIM map over the cropped interior
     * @throws IllegalArgumentException if the dimensions differ, there are more than three
     *     dimensions, {@code sigma} is not a finite non-negative number, or an image dimension
     *     is too small for the margin implied by {@code sigma}
     */
    public static <T extends RealType<T>> double computeMetrics(RandomAccessibleInterval<T> reference, RandomAccessibleInterval<T> processed, double sigma) {
        if (!Intervals.equalDimensions(reference, processed)) {
            throw new IllegalArgumentException("Image dimensions must match.");
        }
        if (reference.numDimensions() > 3) {
            throw new IllegalArgumentException("2D or 3D images expected.");
        }
        if (Double.isNaN(sigma) || Double.isInfinite(sigma) || sigma < 0) {
            throw new IllegalArgumentException("Sigma must be a finite, non-negative number, got " + sigma + ".");
        }
        // Read the type at the interval's min corner instead of wherever a freshly created
        // RandomAccess happens to sit, which may be outside the image's interval.
        final T type = Util.getTypeFromInterval(reference);
        final double range = type.getMaxValue() - type.getMinValue();
        final Filter filter = new Filter(sigma);
        // Long arithmetic: 2 * padding can overflow an int for very large sigma.
        final long minSize = 2L * filter.padding + 1;
        for (long d : reference.dimensionsAsLongArray()) {
            if (d < minSize) {
                throw new IllegalArgumentException("Sigma = " + filter.sigma + " is not compatible with dimension of depth " + d + ", minimum size: " + minSize + ".");
            }
        }
        final RandomAccessibleInterval<DoubleType> refIm = Converters.convert(reference, (i, o) -> o.set(i.getRealDouble()), new DoubleType());
        final RandomAccessibleInterval<DoubleType> procIm = Converters.convert(processed, (i, o) -> o.set(i.getRealDouble()), new DoubleType());
        final RandomAccessibleInterval<DoubleType> ux = SSIM.computeWeightedMean(filter, refIm);
        final RandomAccessibleInterval<DoubleType> uy = SSIM.computeWeightedMean(filter, procIm);
        final RandomAccessibleInterval<DoubleType> vx = SSIM.computeWeightedVariance(filter, refIm, ux);
        final RandomAccessibleInterval<DoubleType> vy = SSIM.computeWeightedVariance(filter, procIm, uy);
        final RandomAccessibleInterval<DoubleType> vxy = SSIM.computeWeightedCovariance(filter, refIm, procIm, ux, uy);
        final RandomAccessibleInterval<DoubleType> ssimMap = SSIM.computeSSIM(ux, uy, vx, vy, vxy, range);
        return SSIM.computeMean(ssimMap);
    }

    /**
     * Allocates a {@link DoubleType} image covering {@code input}'s interval shrunk by
     * {@code padding} on every side, placed at the same coordinates as that shrunk interval.
     *
     * @param input image whose interval (and image factory) is used
     * @param padding number of pixels removed from each border; 0 yields input's own interval
     * @return newly allocated, translated image
     */
    private static RandomAccessibleInterval<DoubleType> createRAI(RandomAccessibleInterval<DoubleType> input, int padding) {
        final FinalInterval cropped = Intervals.expand(input, -(long) padding);
        final Img<DoubleType> output = Util.getSuitableImgFactory(input, new DoubleType()).create(cropped);
        // The new image starts at the origin; move it onto the cropped interval so it stays
        // aligned with input even when input's min coordinate is not zero.
        return Views.translate(output, Intervals.minAsLongArray(cropped));
    }

    /**
     * Smooths {@code input} into {@code output} with the implementation selected by
     * {@code filter}; {@code output}'s interval determines which pixels are computed.
     */
    private static void filter(Filter filter, ExtendedRandomAccessibleInterval<DoubleType, RandomAccessibleInterval<DoubleType>> input, RandomAccessibleInterval<DoubleType> output) {
        if (filter.algorithm == FilteringAlgorithm.FASTGAUSS) {
            FastGauss.convolve(filter.sigma, input, output);
        } else {
            Gauss3.gauss(filter.sigma, input, output);
        }
    }

    /**
     * Gaussian-weighted local mean of {@code input}, restricted to the cropped interior.
     * The input is mirror-extended so the filter may read beyond its borders.
     */
    private static RandomAccessibleInterval<DoubleType> computeWeightedMean(Filter filter, RandomAccessibleInterval<DoubleType> input) {
        final RandomAccessibleInterval<DoubleType> output = SSIM.createRAI(input, filter.padding);
        SSIM.filter(filter, Views.extendMirrorDouble(input), output);
        return output;
    }

    /**
     * Gaussian-weighted local variance, computed as {@code E[x^2] - E[x]^2}.
     *
     * @param weightedMean local mean of {@code img}, as returned by computeWeightedMean
     */
    private static RandomAccessibleInterval<DoubleType> computeWeightedVariance(Filter filter, RandomAccessibleInterval<DoubleType> img, RandomAccessibleInterval<DoubleType> weightedMean) {
        final RandomAccessibleInterval<DoubleType> square = SSIM.createRAI(img, 0);
        LoopBuilder.setImages(img, square).forEachPixel((i, o) -> o.set(i.get() * i.get()));
        final RandomAccessibleInterval<DoubleType> meanSquare = SSIM.createRAI(img, filter.padding);
        SSIM.filter(filter, Views.extendMirrorDouble(square), meanSquare);
        // In place: meanSquare becomes E[x^2] - E[x]^2.
        LoopBuilder.setImages(meanSquare, weightedMean).forEachPixel((v, u) -> v.set(v.get() - u.get() * u.get()));
        return meanSquare;
    }

    /**
     * Gaussian-weighted local covariance, computed as {@code E[xy] - E[x]E[y]}.
     *
     * @param weightedMean1 local mean of {@code im1}
     * @param weightedMean2 local mean of {@code im2}
     */
    private static RandomAccessibleInterval<DoubleType> computeWeightedCovariance(Filter filter, RandomAccessibleInterval<DoubleType> im1, RandomAccessibleInterval<DoubleType> im2, RandomAccessibleInterval<DoubleType> weightedMean1, RandomAccessibleInterval<DoubleType> weightedMean2) {
        final RandomAccessibleInterval<DoubleType> product = SSIM.createRAI(im1, 0);
        LoopBuilder.setImages(im1, im2, product).forEachPixel((i1, i2, o) -> o.set(i1.get() * i2.get()));
        final RandomAccessibleInterval<DoubleType> meanProduct = SSIM.createRAI(im1, filter.padding);
        SSIM.filter(filter, Views.extendMirrorDouble(product), meanProduct);
        // In place: meanProduct becomes E[xy] - E[x]E[y].
        LoopBuilder.setImages(meanProduct, weightedMean1, weightedMean2).forEachPixel((v, u1, u2) -> v.set(v.get() - u1.get() * u2.get()));
        return meanProduct;
    }

    /**
     * Per-pixel SSIM map:
     * {@code (2 ux uy + C1)(2 vxy + C2) / ((ux^2 + uy^2 + C1)(vx + vy + C2))}.
     *
     * @param range dynamic range of the pixel type, used for the stabilization constants
     * @return newly allocated map on the same interval as {@code ux}
     */
    private static RandomAccessibleInterval<DoubleType> computeSSIM(RandomAccessibleInterval<DoubleType> ux, RandomAccessibleInterval<DoubleType> uy, RandomAccessibleInterval<DoubleType> vx, RandomAccessibleInterval<DoubleType> vy, RandomAccessibleInterval<DoubleType> vxy, double range) {
        final double c1 = K1 * K1 * range * range;
        final double c2 = K2 * K2 * range * range;
        final RandomAccessibleInterval<DoubleType> ssim = SSIM.createRAI(ux, 0);
        // Pass 1: numerator, stored in the output to avoid full-size temporaries.
        LoopBuilder.setImages(ux, uy, vxy, ssim).forEachPixel((u1, u2, v12, s) -> {
            final double a1 = 2.0 * u1.get() * u2.get() + c1;
            final double a2 = 2.0 * v12.get() + c2;
            s.set(a1 * a2);
        });
        // Pass 2: divide by both denominator factors in the order (a1 * a2) / b1 / b2.
        LoopBuilder.setImages(ux, uy, vx, vy, ssim).forEachPixel((u1, u2, v1, v2, s) -> {
            final double b1 = u1.get() * u1.get() + u2.get() * u2.get() + c1;
            final double b2 = v1.get() + v2.get() + c2;
            s.set(s.get() / b1 / b2);
        });
        return ssim;
    }

    /**
     * Arithmetic mean of all pixels, accumulated with {@link RealSum}. A {@code long} counter
     * is used so very large (3D) images do not overflow the pixel count.
     */
    private static double computeMean(RandomAccessibleInterval<DoubleType> img) {
        final RealSum sum = new RealSum();
        long count = 0;
        for (DoubleType d : Views.iterable(img)) {
            sum.add(d.get());
            ++count;
        }
        return sum.getSum() / count;
    }

    /**
     * Smoothing configuration derived from sigma: {@link Gauss3} for sigma &lt;= 2,
     * {@link FastGauss} above, plus the border margin cropped from the SSIM map.
     * NOTE(review): the FastGauss padding constants presumably mirror that filter's kernel
     * support — confirm.
     */
    private static final class Filter {
        private final FilteringAlgorithm algorithm;
        private final double sigma;
        private final int padding;

        public Filter(double sigma) {
            this.sigma = sigma;
            if (sigma > 2.0) {
                this.algorithm = FilteringAlgorithm.FASTGAUSS;
                this.padding = Math.max(2, (int)(3.721 * sigma + 0.20157 + 0.5));
            } else {
                this.algorithm = FilteringAlgorithm.GAUSS;
                this.padding = Math.max(2, (int)(3.0 * sigma + 0.5));
            }
        }
    }

    /** Available Gaussian smoothing implementations. */
    private enum FilteringAlgorithm {
        GAUSS,
        FASTGAUSS
    }
}

