/*
 * Decompiled with CFR 0.152.
 */
package spim.process.fusion.deconvolution;

import ij.CompositeImage;
import ij.ImagePlus;
import ij.ImageStack;
import java.io.File;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import mpicbg.spim.io.IOFunctions;
import net.imglib2.Cursor;
import net.imglib2.Dimensions;
import net.imglib2.IterableInterval;
import net.imglib2.Localizable;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.exception.IncompatibleTypeException;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Pair;
import net.imglib2.util.RealSum;
import net.imglib2.util.Util;
import net.imglib2.view.Views;
import spim.Threads;
import spim.fiji.spimdata.imgloaders.LegacyStackImgLoaderIJ;
import spim.process.fusion.FusionHelper;
import spim.process.fusion.ImagePortion;
import spim.process.fusion.deconvolution.FirstIteration;
import spim.process.fusion.deconvolution.MVDeconFFT;
import spim.process.fusion.deconvolution.MVDeconInput;
import spim.process.fusion.export.DisplayImage;

/**
 * Iterative multi-view deconvolution driver.
 * <p>
 * All work happens in the constructor. It first initializes the estimate {@code psi}, either
 * from {@link #initialImage} or as a constant image set to the average intensity reported by
 * {@link #fuseFirstIteration}. It then runs {@code numIterations} update passes. Each pass
 * visits every view in turn. Finally, pixels where a re-run of {@link #fuseFirstIteration}
 * produced 0 are set to 0. The result is available via {@link #getPsi()}.
 * <p>
 * Per view, the update computed here is
 * {@code psi' = psi + (f(psi * convolve2(observed / convolve1(psi))) - psi) * weight}.
 * Here {@code f} optionally applies {@link #tikhonov} and clamps to {@link #minValue}.
 * NOTE(review): this looks like a Richardson-Lucy-style scheme. What
 * {@code convolve1}/{@code convolve2} do is defined in {@link MVDeconFFT} and is not visible
 * here.
 */
public class MVDeconvolution {
    /** Path of an image to use as the initial estimate; {@code null} means "fuse the views". */
    public static String initialImage = null;
    /** Whether {@link #loadInitialImage} validates (and clamps non-positive) pixel values. */
    public static boolean checkNumbers = true;
    /** Whether to show intermediate estimates in an ImageJ composite "debug view". */
    public static boolean debug = true;
    /** Show the debug view every {@code debugInterval} iterations (see constructor). */
    public static int debugInterval = 1;
    /** NOTE(review): not read anywhere in this class; kept for external users. */
    public static boolean setBackgroundToAvg = true;
    /** Lower clamp applied to every updated estimate value. */
    static final float minValue = 1.0E-4f;

    final int numViews;
    final int numDimensions;
    /** Tikhonov regularization strength; values {@code <= 0} disable regularization. */
    final double lambda;
    ImageStack stack;
    CompositeImage ci;
    boolean collectStatistics = true;
    /** Number of iterations run so far. */
    int i = 0;
    /** Current deconvolved estimate. */
    Img<FloatType> psi = null;
    final Img<FloatType> tmp1;
    final Img<FloatType> tmp2;
    final MVDeconInput views;
    ArrayList<MVDeconFFT> data;
    String name;

    /**
     * Initializes the estimate, runs all iterations, and masks uncovered pixels.
     *
     * @param views input views; {@code views.init(iterationType)} is called before iterating
     * @param iterationType passed to {@link MVDeconInput#init}
     * @param numIterations number of update passes to run
     * @param lambda Tikhonov regularization strength ({@code <= 0} disables it)
     * @param osemspeedup not used by this constructor
     * @param osemspeedupindex not used by this constructor
     * @param name name returned by {@link #getName()}
     * @throws IncompatibleTypeException propagated from image/factory setup
     * @throws IllegalStateException if {@link #initialImage} is set but cannot be used
     */
    public MVDeconvolution(final MVDeconInput views, final MVDeconFFT.PSFTYPE iterationType, final int numIterations, final double lambda, final double osemspeedup, final int osemspeedupindex, final String name) throws IncompatibleTypeException {
        this.name = name;
        this.data = views.getViews();
        this.views = views;
        this.numViews = this.data.size();
        this.numDimensions = this.data.get(0).getImage().numDimensions();
        this.lambda = lambda;
        IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Deconvolved & temporary image factory: " + views.imgFactory().getClass().getSimpleName());
        views.init(iterationType);
        if (initialImage != null) {
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Loading intial image '" + initialImage + "'");
            this.psi = loadInitialImage(initialImage, checkNumbers, minValue, this.data.get(0).getImage(), views.imgFactory());
            // loadInitialImage reports the reason and returns null; fail loudly instead of an NPE below
            if (this.psi == null) {
                throw new IllegalStateException("Could not use initial image '" + initialImage + "' as start for the deconvolution.");
            }
        } else {
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Fusing image for first iteration");
            this.psi = views.imgFactory().create(this.data.get(0).getImage(), new FloatType());
            double avg = fuseFirstIteration(this.psi, views.getViews());
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Average intensity in overlapping area: " + avg);
            // fuseFirstIteration returns NaN when nothing was counted and -1 when a task failed;
            // neither is a usable starting intensity for a multiplicative update.
            if (Double.isNaN(avg) || avg < 0.0) {
                final double failed = avg;
                avg = 0.5;
                IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): ERROR! Computing average FAILED (result: " + failed + "), setting it to: " + avg);
            }
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Setting image to average intensity: " + avg);
            final float start = (float) avg;
            for (final FloatType t : this.psi) {
                t.set(start);
            }
        }
        this.tmp1 = views.imgFactory().create(this.psi, new FloatType());
        this.tmp2 = views.imgFactory().create(this.psi, new FloatType());
        while (this.i < numIterations) {
            // Java's % keeps the sign, so for debugInterval > 1 iteration 0 is skipped
            if (debug && (this.i - 1) % debugInterval == 0) {
                this.updateDebugView();
            }
            this.runIteration();
        }
        IOFunctions.println("Masking never updated pixels.");
        // Reuse tmp1 as a coverage map; psi and tmp1 come from the same factory with the same
        // dimensions, so their cursors visit pixels in the same order.
        // NOTE(review): assumes FirstIteration writes 0 where no view contributes — TODO confirm.
        fuseFirstIteration(this.tmp1, views.getViews());
        final Cursor<FloatType> coverage = this.tmp1.cursor();
        for (final FloatType t : this.psi) {
            if (coverage.next().get() == 0.0f) {
                t.set(0.0f);
            }
        }
        IOFunctions.println("DONE (" + new Date(System.currentTimeMillis()) + ").");
    }

    /**
     * Appends the current estimate as a new frame to the ImageJ "debug view".
     * Assumes {@code psi} has at least three dimensions (uses {@code dimension(2)} as depth).
     * Slices of frame {@code k} are labelled {@code "Iteration k"}.
     */
    private void updateDebugView() {
        final ImagePlus tmp = DisplayImage.getImagePlusInstance(this.psi, true, "Psi", 0.0, 1.0).duplicate();
        final int depth = (int) this.psi.dimension(2);
        final ImageStack newSlices = tmp.getImageStack();
        if (this.stack == null) {
            // first frame: the duplicated stack itself becomes the debug stack
            this.stack = newSlices;
            for (int z = 0; z < depth; ++z) {
                this.stack.setSliceLabel("Iteration 1", z + 1);
            }
            tmp.setTitle("debug view");
            this.ci = new CompositeImage(tmp, 1);
            this.ci.setDimensions(1, depth, 1);
            this.ci.show();
            return;
        }
        final int frame = this.stack.getSize() / depth + 1;
        final String label = "Iteration " + frame;
        for (int z = 0; z < depth; ++z) {
            this.stack.addSlice(label, newSlices.getProcessor(z + 1));
        }
        if (frame == 2) {
            // the composite must be recreated once the stack gains a time axis
            this.ci.hide();
            this.ci = new CompositeImage(new ImagePlus("debug view", this.stack), 1);
            this.ci.setDimensions(1, depth, 2);
            this.ci.show();
        } else {
            this.ci.setStack(this.stack, 1, depth, frame);
        }
    }

    /**
     * Runs {@link FirstIteration} tasks over {@code psi} in parallel portions and averages the
     * intensities they report.
     *
     * @param psi image processed by the tasks
     * @param views views whose images are handed to every task
     * @return sum of reported intensities divided by the reported count (NaN if the count is 0),
     *         or {@code -1.0} if any task failed or the wait was interrupted
     */
    protected static final double fuseFirstIteration(final Img<FloatType> psi, final ArrayList<MVDeconFFT> views) {
        final int nThreads = Threads.numThreads();
        final int nPortions = nThreads * 2;
        final Vector<ImagePortion> portions = FusionHelper.divideIntoPortions(psi.size(), nPortions);
        final ArrayList<RandomAccessibleInterval<FloatType>> imgs = new ArrayList<RandomAccessibleInterval<FloatType>>();
        for (final MVDeconFFT mvdecon : views) {
            imgs.add(mvdecon.getImage());
        }
        final ArrayList<FirstIteration> tasks = new ArrayList<FirstIteration>();
        for (final ImagePortion portion : portions) {
            tasks.add(new FirstIteration(portion, (RandomAccessibleInterval<FloatType>) psi, imgs));
        }
        final RealSum sum = new RealSum();
        long count = 0L;
        final ExecutorService taskExecutor = Executors.newFixedThreadPool(nThreads);
        try {
            final List<Future<Pair<RealSum, Long>>> results = taskExecutor.invokeAll(tasks);
            for (final Future<Pair<RealSum, Long>> future : results) {
                final Pair<RealSum, Long> result = future.get();
                sum.add(result.getA().getSum());
                count += result.getB().longValue();
            }
        } catch (final Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            IOFunctions.println("Failed to fuse initial iteration: " + e);
            e.printStackTrace();
            return -1.0;
        } finally {
            // always release the pool, also on the failure path
            taskExecutor.shutdown();
        }
        return sum.getSum() / (double) count;
    }

    /**
     * Loads an image file as the initial estimate.
     *
     * @param fileName file to open with {@link LegacyStackImgLoaderIJ#open}
     * @param checkNumbers if true, warn about values below {@code minValue} and replace values
     *        {@code <= 0} with {@code minValue}
     * @param minValue minimal allowed value
     * @param dimensions required dimensions (count and size must both match)
     * @param imageFactory factory for the returned image
     * @return the loaded image, or {@code null} (after logging why) if it could not be opened or
     *         its dimensions do not match
     */
    protected static Img<FloatType> loadInitialImage(final String fileName, final boolean checkNumbers, final float minValue, final Dimensions dimensions, final ImgFactory<FloatType> imageFactory) {
        IOFunctions.println("Loading image '" + fileName + "' as start for iteration.");
        final ImagePlus impPSI = LegacyStackImgLoaderIJ.open(new File(fileName));
        if (impPSI == null) {
            IOFunctions.println("Could not load image '" + fileName + "'.");
            return null;
        }
        final int numSlices = impPSI.getStack().getSize();
        final long[] dimPsi = numSlices == 1
                ? new long[] { impPSI.getWidth(), impPSI.getHeight() }
                : new long[] { impPSI.getWidth(), impPSI.getHeight(), numSlices };
        final Img<FloatType> psi = imageFactory.create(dimPsi, new FloatType());
        if (psi == null) {
            IOFunctions.println("Could not load image '" + fileName + "'.");
            return null;
        }
        LegacyStackImgLoaderIJ.imagePlus2ImgLib2Img(impPSI, psi, false);

        final long[] dim = new long[dimensions.numDimensions()];
        dimensions.dimensions(dim);
        // compare the dimension count first: a 2D file must not "match" a 3D target on its first two axes
        boolean dimensionsMatch = psi.numDimensions() == dim.length;
        for (int d = 0; dimensionsMatch && d < dim.length; ++d) {
            dimensionsMatch = psi.dimension(d) == dim[d];
        }
        if (!dimensionsMatch) {
            IOFunctions.println("Dimensions of '" + fileName + "' do not match: " + Util.printCoordinates(dimPsi) + " != " + Util.printCoordinates(dim));
            return null;
        }
        if (checkNumbers) {
            IOFunctions.println("Checking values of '" + fileName + "' you can disable this check by setting spim.process.fusion.deconvolution.MVDeconvolution.checkNumbers = false;");
            boolean smaller = false;
            boolean hasZerosOrNeg = false;
            for (final FloatType v : psi) {
                if (v.get() < minValue) {
                    smaller = true;
                }
                // non-positive values would get stuck in a multiplicative update
                if (v.get() <= 0.0f) {
                    hasZerosOrNeg = true;
                    v.set(minValue);
                }
            }
            if (smaller) {
                IOFunctions.println("Some values '" + fileName + "' are smaller than the minimal value of " + minValue + ", this can lead to instabilities.");
            }
            if (hasZerosOrNeg) {
                IOFunctions.println("Some values '" + fileName + "' were smaller or equal to zero,they have been replaced with the min value of " + minValue);
            }
        }
        return psi;
    }

    public MVDeconInput getData() {
        return this.views;
    }

    public String getName() {
        return this.name;
    }

    public Img<FloatType> getPsi() {
        return this.psi;
    }

    /** @return number of iterations run so far */
    public int getCurrentIteration() {
        return this.i;
    }

    /** Runs one update pass over all views and increments the iteration counter. */
    public void runIteration() {
        runIteration(this.psi, this.tmp1, this.tmp2, this.data, this.lambda, minValue, this.collectStatistics, this.i++);
    }

    /**
     * One update pass: for every view, sequentially,
     * (1) {@code convolve1(psi -> tmp1)}, (2) {@code tmp1 = observed / tmp1},
     * (3) {@code convolve2(tmp1 -> tmp2)}, (4) update {@code psi} from {@code tmp2} and the
     * view's weight. Steps 2 and 4 run in parallel portions. Logs the summed and maximal
     * per-pixel change per view.
     * {@code minValue} and {@code collectStatistic} are currently unused.
     */
    private static final void runIteration(final Img<FloatType> psi, final Img<FloatType> tmp1, final Img<FloatType> tmp2, final ArrayList<MVDeconFFT> data, final double lambda, final float minValue, final boolean collectStatistic, final int iteration) {
        IOFunctions.println("iteration: " + iteration + " (" + new Date(System.currentTimeMillis()) + ")");
        final int numViews = data.size();
        final int nThreads = Threads.numThreads();
        final int nPortions = nThreads * 2;
        final Vector<ImagePortion> portions = FusionHelper.divideIntoPortions(psi.size(), nPortions);
        final ArrayList<Callable<Void>> tasks = new ArrayList<Callable<Void>>();
        for (int view = 0; view < numViews; ++view) {
            final MVDeconFFT processingData = data.get(view);
            processingData.convolve1(psi, tmp1);

            tasks.clear();
            for (final ImagePortion portion : portions) {
                tasks.add(new Callable<Void>() {
                    @Override
                    public Void call() throws Exception {
                        computeQuotient(portion.getStartPosition(), portion.getLoopSize(), tmp1, (RandomAccessibleInterval<FloatType>) processingData.getImage());
                        return null;
                    }
                });
            }
            execTasks(tasks, nThreads, "compute quotient");

            processingData.convolve2(tmp1, tmp2);

            // one {sum, max} slot per actual portion (divideIntoPortions may not return exactly nPortions)
            final double[][] sumMax = new double[portions.size()][2];
            tasks.clear();
            for (int p = 0; p < portions.size(); ++p) {
                final ImagePortion portion = portions.get(p);
                final double[] portionSumMax = sumMax[p];
                tasks.add(new Callable<Void>() {
                    @Override
                    public Void call() throws Exception {
                        computeFinalValues(portion.getStartPosition(), portion.getLoopSize(), psi, tmp2, (RandomAccessibleInterval<FloatType>) processingData.getWeight(), lambda, portionSumMax);
                        return null;
                    }
                });
            }
            execTasks(tasks, nThreads, "compute final values");

            double sumChange = 0.0;
            double maxChange = -1.0;
            for (final double[] portionSumMax : sumMax) {
                sumChange += portionSumMax[0];
                maxChange = Math.max(maxChange, portionSumMax[1]);
            }
            IOFunctions.println("iteration: " + iteration + ", view: " + view + " --- sum change: " + sumChange + " --- max change per pixel: " + maxChange);
        }
    }

    /**
     * Runs all tasks on a fresh fixed-size pool and waits for them. Failures of individual
     * tasks and interruption are logged (prefixed with "Failed to " + jobDescription); the pool
     * is always shut down and the interrupt flag is restored.
     */
    private static final void execTasks(final ArrayList<Callable<Void>> tasks, final int nThreads, final String jobDescription) {
        final ExecutorService taskExecutor = Executors.newFixedThreadPool(nThreads);
        try {
            for (final Future<Void> future : taskExecutor.invokeAll(tasks)) {
                try {
                    // surfaces exceptions thrown inside the task instead of silently dropping them
                    future.get();
                } catch (final ExecutionException e) {
                    IOFunctions.println("Failed to " + jobDescription + ": " + e.getCause());
                    e.printStackTrace();
                }
            }
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            IOFunctions.println("Failed to " + jobDescription + ": " + e);
            e.printStackTrace();
        } finally {
            taskExecutor.shutdown();
        }
    }

    /**
     * For {@code loopSize} pixels starting at iteration index {@code start}, replaces each value
     * of {@code psiBlurred} with {@code observed / psiBlurred} where {@code observed > 0}, and
     * with {@code 1} otherwise. Uses lock-step cursors when both images share an iteration order,
     * and random access otherwise.
     * NOTE(review): a blurred value of 0 with observed > 0 yields +Infinity; not guarded here.
     */
    private static final void computeQuotient(final long start, final long loopSize, final RandomAccessibleInterval<FloatType> psiBlurred, final RandomAccessibleInterval<FloatType> observedImg) {
        final IterableInterval<FloatType> psiBlurredIterable = Views.iterable(psiBlurred);
        final IterableInterval<FloatType> observedImgIterable = Views.iterable(observedImg);
        if (psiBlurredIterable.iterationOrder().equals(observedImgIterable.iterationOrder())) {
            final Cursor<FloatType> cursorPsiBlurred = psiBlurredIterable.cursor();
            final Cursor<FloatType> cursorImg = observedImgIterable.cursor();
            cursorPsiBlurred.jumpFwd(start);
            cursorImg.jumpFwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                final FloatType blurred = cursorPsiBlurred.next();
                blurred.set(quotient(cursorImg.next().get(), blurred.get()));
            }
        } else {
            final RandomAccess<FloatType> raPsiBlurred = psiBlurred.randomAccess();
            final Cursor<FloatType> cursorImg = observedImgIterable.localizingCursor();
            cursorImg.jumpFwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                cursorImg.fwd();
                raPsiBlurred.setPosition((Localizable) cursorImg);
                final FloatType blurred = raPsiBlurred.get();
                blurred.set(quotient(cursorImg.get().get(), blurred.get()));
            }
        }
    }

    /** {@code observed / blurred} for positive observed values, {@code 1} otherwise. */
    private static float quotient(final float observed, final float blurred) {
        return observed > 0.0f ? observed / blurred : 1.0f;
    }

    /**
     * Copies {@code loopSize} pixels starting at iteration index {@code start} from
     * {@code source} to {@code target}. Uses lock-step cursors only when source and target share
     * an iteration order; otherwise positions a random access on {@code source} per target pixel.
     */
    public static final void copyImg(final long start, final long loopSize, final RandomAccessibleInterval<FloatType> source, final RandomAccessibleInterval<FloatType> target) {
        final IterableInterval<FloatType> sourceIterable = Views.iterable(source);
        final IterableInterval<FloatType> targetIterable = Views.iterable(target);
        // compare source against TARGET (the previous code compared source with itself)
        if (sourceIterable.iterationOrder().equals(targetIterable.iterationOrder())) {
            final Cursor<FloatType> cursorSource = sourceIterable.cursor();
            final Cursor<FloatType> cursorTarget = targetIterable.cursor();
            cursorSource.jumpFwd(start);
            cursorTarget.jumpFwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                cursorTarget.next().set(cursorSource.next());
            }
        } else {
            final RandomAccess<FloatType> raSource = source.randomAccess();
            final Cursor<FloatType> cursorTarget = targetIterable.localizingCursor();
            cursorTarget.jumpFwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                cursorTarget.fwd();
                raSource.setPosition((Localizable) cursorTarget);
                cursorTarget.get().set(raSource.get());
            }
        }
    }

    /**
     * Updates {@code loopSize} pixels of {@code psi} (from iteration index {@code start}) with
     * {@link #computeNextValue} and stores the summed absolute change in {@code sumMax[0]} and
     * the maximal absolute change in {@code sumMax[1]} ({@code -1} if no pixel was visited).
     */
    private static final void computeFinalValues(final long start, final long loopSize, final RandomAccessibleInterval<FloatType> psi, final RandomAccessibleInterval<FloatType> integral, final RandomAccessibleInterval<FloatType> weight, final double lambda, final double[] sumMax) {
        double sumChange = 0.0;
        double maxChange = -1.0;
        final IterableInterval<FloatType> psiIterable = Views.iterable(psi);
        final IterableInterval<FloatType> integralIterable = Views.iterable(integral);
        final IterableInterval<FloatType> weightIterable = Views.iterable(weight);
        if (psiIterable.iterationOrder().equals(integralIterable.iterationOrder()) && psiIterable.iterationOrder().equals(weightIterable.iterationOrder())) {
            final Cursor<FloatType> cursorPsi = psiIterable.cursor();
            final Cursor<FloatType> cursorIntegral = integralIterable.cursor();
            final Cursor<FloatType> cursorWeight = weightIterable.cursor();
            cursorPsi.jumpFwd(start);
            cursorIntegral.jumpFwd(start);
            cursorWeight.jumpFwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                final FloatType psiValue = cursorPsi.next();
                final float lastPsiValue = psiValue.get();
                final float nextPsiValue = computeNextValue(lastPsiValue, cursorIntegral.next().get(), cursorWeight.next().get(), lambda);
                psiValue.set(nextPsiValue);
                final float change = change(lastPsiValue, nextPsiValue);
                sumChange += change;
                maxChange = Math.max(maxChange, change);
            }
        } else {
            final Cursor<FloatType> cursorPsi = psiIterable.localizingCursor();
            final RandomAccess<FloatType> raIntegral = integral.randomAccess();
            final RandomAccess<FloatType> raWeight = weight.randomAccess();
            cursorPsi.jumpFwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                cursorPsi.fwd();
                raIntegral.setPosition((Localizable) cursorPsi);
                raWeight.setPosition((Localizable) cursorPsi);
                final FloatType psiValue = cursorPsi.get();
                final float lastPsiValue = psiValue.get();
                final float nextPsiValue = computeNextValue(lastPsiValue, raIntegral.get().get(), raWeight.get().get(), lambda);
                psiValue.set(nextPsiValue);
                final float change = change(lastPsiValue, nextPsiValue);
                sumChange += change;
                maxChange = Math.max(maxChange, change);
            }
        }
        sumMax[0] = sumChange;
        sumMax[1] = maxChange;
    }

    /** Absolute difference between two successive estimates. */
    private static final float change(final float lastPsiValue, final float nextPsiValue) {
        return Math.abs(nextPsiValue - lastPsiValue);
    }

    /**
     * Multiplicative update of one pixel: {@code last * integral}, optionally passed through
     * {@link #tikhonov} (when {@code lambda > 0}), clamped to at least {@link #minValue} (NaN and
     * non-positive products become {@code minValue}), then blended with the previous value by
     * {@code weight} ({@code 0} keeps the old value, {@code 1} takes the new one).
     */
    private static final float computeNextValue(final float lastPsiValue, final float integralValue, final float weight, final double lambda) {
        final float value = lastPsiValue * integralValue;
        final float adjustedValue;
        if (value > 0.0f) {
            adjustedValue = lambda > 0.0 ? (float) tikhonov(value, lambda) : value;
        } else {
            adjustedValue = minValue;
        }
        // Math.max propagates NaN, so NaN has to be mapped explicitly
        final float nextPsiValue = Float.isNaN(adjustedValue) ? minValue : Math.max(minValue, adjustedValue);
        return lastPsiValue + (nextPsiValue - lastPsiValue) * weight;
    }

    /** Tikhonov-regularized value: {@code (sqrt(1 + 2*lambda*value) - 1) / lambda}. */
    private static final double tikhonov(final double value, final double lambda) {
        return (Math.sqrt(1.0 + 2.0 * lambda * value) - 1.0) / lambda;
    }

    /** Prints {@link #tikhonov} for a range of values with lambda = 6e-4 (manual check). */
    public static void main(final String[] args) {
        for (double d = 0.0; d < 10.0; d += 0.1) {
            System.out.println(d * 10000.0 + ": " + tikhonov(d * 10000.0, 6.0E-4));
            System.out.println(d + ": " + tikhonov(d, 6.0E-4));
        }
    }
}

