/*
 * Decompiled with CFR 0.152.
 */
package mpicbg.spim.postprocessing.deconvolution2;

import ij.CompositeImage;
import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import java.util.ArrayList;
import java.util.Date;
import java.util.Vector;
import java.util.concurrent.atomic.AtomicInteger;
import mpicbg.imglib.cursor.Cursor;
import mpicbg.imglib.image.Image;
import mpicbg.imglib.image.ImageFactory;
import mpicbg.imglib.image.display.imagej.ImageJFunctions;
import mpicbg.imglib.io.LOCI;
import mpicbg.imglib.multithreading.Chunk;
import mpicbg.imglib.multithreading.SimpleMultiThreading;
import mpicbg.imglib.type.numeric.real.FloatType;
import mpicbg.spim.io.IOFunctions;
import mpicbg.spim.postprocessing.deconvolution2.AdjustInput;
import mpicbg.spim.postprocessing.deconvolution2.Deconvolver;
import mpicbg.spim.postprocessing.deconvolution2.LRFFT;
import mpicbg.spim.postprocessing.deconvolution2.LRInput;
import net.imglib2.util.Util;
import spim.Threads;

/**
 * Iterative multi-view deconvolution of a set of {@link LRFFT} views into a single estimate {@code psi}.
 * <p>
 * Each iteration processes the views one after another. For every view:
 * <ol>
 * <li>psi is convolved with {@code convolve1};</li>
 * <li>the view's image is divided by that blurred estimate;</li>
 * <li>the quotient is convolved with {@code convolve2};</li>
 * <li>psi is multiplied by the result, optionally damped by {@code lambda}, clamped to {@link #minValue} and
 * blended with the previous value using the view's weight image.</li>
 * </ol>
 * NOTE(review): this looks like a Richardson-Lucy-type multiplicative scheme; confirm against the LRFFT docs.
 * <p>
 * The whole deconvolution runs inside the constructor.
 */
public class BayesMVDeconvolution
implements Deconvolver {
    /** If non-null, file name of an image that is loaded as the starting estimate instead of a constant image. */
    public static String initialImage = null;
    /** Whether the values of {@link #initialImage} are checked after loading (values {@code <= 0} are replaced). */
    public static boolean checkNumbers = true;
    /** Whether intermediate estimates are collected and shown in an ImageJ "debug view" composite. */
    public static boolean debug = true;
    /** A debug snapshot is taken after iteration 1 and then every {@code debugInterval} iterations. */
    public static int debugInterval = 1;
    /** Lower bound applied to every value of the deconvolved estimate. */
    static final float minValue = 1.0E-4f;
    final int numViews;
    final int numDimensions;
    final float avg;
    final double lambda;
    ImageStack stack;
    CompositeImage ci;
    boolean collectStatistics = true;
    /** Number of completed iterations. */
    int i = 0;
    Image<FloatType> psi;
    final LRInput views;
    ArrayList<LRFFT> data;
    String name;

    /**
     * Normalizes the inputs, initializes psi and runs {@code numIterations} iterations.
     *
     * @param views            the input views; {@code views.getViews()} must be non-empty
     * @param iterationType    passed to {@code views.init}
     * @param numIterations    number of iterations to run
     * @param lambda           damping parameter of the update; {@code <= 0} disables damping
     * @param osemspeedup      factor applied to all weight images (values capped at 1)
     * @param osemspeedupindex 1 or 2 replaces {@code osemspeedup} by {@code max(1, result[1 or 2])} of
     *                         {@code AdjustInput.normAllImages}; any other value keeps the given factor
     * @param name             name returned by {@link #getName()}
     */
    public BayesMVDeconvolution(LRInput views, LRFFT.PSFTYPE iterationType, int numIterations, double lambda, double osemspeedup, int osemspeedupindex, String name) {
        this.name = name;
        this.data = views.getViews();
        this.views = views;
        this.numViews = this.data.size();
        this.numDimensions = this.data.get(0).getImage().getNumDimensions();
        this.lambda = lambda;
        if (initialImage != null) {
            this.psi = BayesMVDeconvolution.loadInitialImage(initialImage, checkNumbers, minValue, this.data.get(0).getImage().getDimensions(), (ImageFactory<FloatType>)this.data.get(0).getImage().getImageFactory());
        }
        double[] result = AdjustInput.normAllImages(this.data);
        this.avg = (float)result[0];
        if (osemspeedupindex == 1) {
            osemspeedup = Math.max(1.0, result[1]);
        } else if (osemspeedupindex == 2) {
            osemspeedup = Math.max(1.0, result[2]);
        }
        this.adjustOSEMspeedup(views, osemspeedup);
        IJ.log("Average intensity in overlapping area: " + this.avg);
        IJ.log("OSEM acceleration: " + osemspeedup);
        views.init(iterationType);
        // No (valid) initial image was given: start from a constant image at the average intensity.
        if (this.psi == null) {
            this.psi = this.data.get(0).getImage().createNewImage("psi (deconvolved image)");
            for (FloatType f : this.psi) {
                f.set(this.avg);
            }
        }
        IOFunctions.println("Deconvolved image container: " + this.psi.getImageFactory().getContainerFactory().getClass().getSimpleName());
        while (this.i < numIterations) {
            this.runIteration();
            // this.i now counts completed iterations; snapshot after iteration 1, 1+k, 1+2k, ...
            if (debug && (this.i - 1) % debugInterval == 0) {
                this.showDebugView();
            }
        }
        IJ.log("DONE (" + new Date(System.currentTimeMillis()) + ").");
    }

    /**
     * Appends the current psi (one slice per z-plane) to the debug stack and shows or refreshes the composite view.
     * Slices are labelled with the number of completed iterations.
     */
    private void showDebugView() {
        final int depth = this.psi.getDimension(2);
        final String label = "Iteration " + this.i;
        this.psi.getDisplay().setMinMax(0.0, 1.0);
        final ImagePlus tmp = ImageJFunctions.copyToImagePlus(this.psi);
        if (this.stack == null) {
            // First snapshot: reuse the copied stack directly.
            this.stack = tmp.getImageStack();
            for (int z = 0; z < depth; ++z) {
                this.stack.setSliceLabel(label, z + 1);
            }
            tmp.setTitle("debug view");
            this.ci = new CompositeImage(tmp, 1);
            this.ci.setDimensions(1, depth, 1);
            this.ci.show();
            return;
        }
        final ImageStack snapshot = tmp.getImageStack();
        if (this.stack.getSize() == depth) {
            // Second snapshot: the composite has to be recreated to gain a frame dimension.
            IJ.log("Stack size = " + this.stack.getSize());
            for (int z = 0; z < depth; ++z) {
                this.stack.addSlice(label, snapshot.getProcessor(z + 1));
            }
            IJ.log("Stack size = " + this.stack.getSize());
            this.ci.hide();
            IJ.log("Stack size = " + this.stack.getSize());
            this.ci = new CompositeImage(new ImagePlus("debug view", this.stack), 1);
            this.ci.setDimensions(1, depth, 2);
            this.ci.show();
            return;
        }
        for (int z = 0; z < depth; ++z) {
            this.stack.addSlice(label, snapshot.getProcessor(z + 1));
        }
        this.ci.setStack(this.stack, 1, depth, this.stack.getSize() / depth);
    }

    /** Multiplies every weight image by {@code osemspeedup}, capping each weight at 1; no-op for a factor of 1. */
    private void adjustOSEMspeedup(LRInput views, double osemspeedup) {
        if (osemspeedup == 1.0) {
            return;
        }
        for (LRFFT view : views.getViews()) {
            for (FloatType f : view.getWeight()) {
                f.set(Math.min(1.0f, f.get() * (float)osemspeedup));
            }
        }
    }

    /**
     * Loads an image to be used as the starting estimate.
     *
     * @param fileName     file to open via LOCI
     * @param checkNumbers if true, values {@code <= 0} are replaced by {@code minValue}, and values below
     *                     {@code minValue} are reported
     * @param minValue     lower bound used by the value check
     * @param dimensions   expected dimensions; the loaded image must have exactly these
     * @param imageFactory factory used to create the loaded image
     * @return the loaded image, or {@code null} if it could not be opened or its dimensions do not match
     */
    protected static Image<FloatType> loadInitialImage(String fileName, boolean checkNumbers, float minValue, int[] dimensions, ImageFactory<FloatType> imageFactory) {
        IOFunctions.println("Loading image '" + fileName + "' as start for iteration.");
        final Image<FloatType> psi = LOCI.openLOCIFloatType(fileName, imageFactory);
        if (psi == null) {
            IOFunctions.println("Could not load image '" + fileName + "'.");
            return null;
        }
        // Compare dimensionality first so that dimensions[d] is never indexed out of bounds.
        boolean dimensionsMatch = psi.getNumDimensions() == dimensions.length;
        for (int d = 0; dimensionsMatch && d < dimensions.length; ++d) {
            dimensionsMatch = psi.getDimension(d) == dimensions[d];
        }
        if (!dimensionsMatch) {
            IOFunctions.println("Dimensions of '" + fileName + "' do not match: " + Util.printCoordinates(psi.getDimensions()) + " != " + Util.printCoordinates(dimensions));
            psi.close();
            return null;
        }
        if (checkNumbers) {
            IOFunctions.println("Checking values of '" + fileName + "' you can disable this check by setting mpicbg.spim.postprocessing.deconvolution2.BayesMVDeconvolution.checkNumbers = false;");
            boolean smaller = false;
            boolean hasZerosOrNeg = false;
            for (FloatType v : psi) {
                final float value = v.get();
                if (value < minValue) {
                    smaller = true;
                }
                if (value <= 0.0f) {
                    hasZerosOrNeg = true;
                    v.set(minValue);
                }
            }
            if (smaller) {
                IOFunctions.println("Some values '" + fileName + "' are smaller than the minimal value of " + minValue + ", this can lead to instabilities.");
            }
            if (hasZerosOrNeg) {
                IOFunctions.println("Some values '" + fileName + "' were smaller or equal to zero, they have been replaced with the min value of " + minValue);
            }
        }
        return psi;
    }

    @Override
    public LRInput getData() {
        return this.views;
    }

    @Override
    public String getName() {
        return this.name;
    }

    @Override
    public double getAvg() {
        return this.avg;
    }

    @Override
    public Image<FloatType> getPsi() {
        return this.psi;
    }

    /** @return the number of completed iterations */
    public int getCurrentIteration() {
        return this.i;
    }

    /** Runs one iteration over all views and increments the iteration counter. */
    @Override
    public void runIteration() {
        BayesMVDeconvolution.runIteration(this.psi, this.data, this.lambda, minValue, this.collectStatistics, this.i++);
    }

    /** Work to perform on one chunk of the linear pixel range. */
    private interface ChunkTask {
        void run(int chunkIndex, Chunk chunk);
    }

    /** Runs {@code task} once per chunk, each chunk on its own thread, and waits for all threads to finish. */
    private static void runChunked(final Vector<Chunk> chunks, final ChunkTask task) {
        final AtomicInteger ai = new AtomicInteger(0);
        final Thread[] threads = SimpleMultiThreading.newThreads(chunks.size());
        for (int t = 0; t < threads.length; ++t) {
            threads[t] = new Thread(new Runnable() {

                @Override
                public void run() {
                    final int myNumber = ai.getAndIncrement();
                    task.run(myNumber, chunks.get(myNumber));
                }
            });
        }
        SimpleMultiThreading.startAndJoin(threads);
    }

    /**
     * Performs one deconvolution iteration, updating {@code psi} in place.
     *
     * @param psi              current estimate, updated in place
     * @param data             the views, processed sequentially
     * @param lambda           damping parameter; {@code <= 0} disables damping
     * @param minValue         lower bound for every updated value of psi
     * @param collectStatistic if true, logs the sum and maximum of the per-pixel absolute change
     * @param iteration        iteration index, used only for logging
     */
    private static final void runIteration(final Image<FloatType> psi, ArrayList<LRFFT> data, final double lambda, final float minValue, boolean collectStatistic, int iteration) {
        IJ.log("iteration: " + iteration + " (" + new Date(System.currentTimeMillis()) + ")");
        final int numViews = data.size();
        final Vector<Chunk> threadChunks = SimpleMultiThreading.divideIntoChunks((long)psi.getNumPixels(), (int)Threads.numThreads());
        final int numThreads = threadChunks.size();
        final Image<FloatType> lastIteration = collectStatistic ? psi.clone() : null;
        for (int view = 0; view < numViews; ++view) {
            final LRFFT processingData = data.get(view);
            long time = System.currentTimeMillis();
            final Image<FloatType> psiBlurred = processingData.convolve1(psi);
            System.out.println(view + " a: " + (System.currentTimeMillis() - time) + " ms.");
            // psiBlurred <- image / psiBlurred
            runChunked(threadChunks, new ChunkTask() {

                @Override
                public void run(int chunkIndex, Chunk chunk) {
                    BayesMVDeconvolution.computeQuotient(chunk.getStartPosition(), chunk.getLoopSize(), psiBlurred, processingData);
                }
            });
            time = System.currentTimeMillis();
            final Image<FloatType> integral = processingData.convolve2(psiBlurred);
            System.out.println(view + " b: " + (System.currentTimeMillis() - time) + " ms.");
            // psi <- weighted, clamped psi * integral
            runChunked(threadChunks, new ChunkTask() {

                @Override
                public void run(int chunkIndex, Chunk chunk) {
                    BayesMVDeconvolution.computeFinalValues(chunk.getStartPosition(), chunk.getLoopSize(), psi, integral, processingData.getWeight(), lambda, minValue);
                }
            });
        }
        if (collectStatistic) {
            final double[][] sumMax = new double[numThreads][2];
            runChunked(threadChunks, new ChunkTask() {

                @Override
                public void run(int chunkIndex, Chunk chunk) {
                    BayesMVDeconvolution.collectStatistics(chunk.getStartPosition(), chunk.getLoopSize(), psi, lastIteration, sumMax[chunkIndex]);
                }
            });
            // The snapshot was created here and is only needed for the statistics.
            lastIteration.close();
            double sumChange = 0.0;
            double maxChange = -1.0;
            for (int t = 0; t < numThreads; ++t) {
                sumChange += sumMax[t][0];
                maxChange = Math.max(maxChange, sumMax[t][1]);
            }
            IJ.log("iteration: " + iteration + " --- sum change: " + sumChange + " --- max change per pixel: " + maxChange);
        }
    }

    /**
     * Accumulates |psi - lastIteration| over {@code loopSize} pixels starting at linear position {@code start}.
     * Writes the sum to {@code sumMax[0]} and the maximum to {@code sumMax[1]} (-1 if {@code loopSize == 0}).
     */
    private static final void collectStatistics(long start, long loopSize, Image<FloatType> psi, Image<FloatType> lastIteration, double[] sumMax) {
        double sumChange = 0.0;
        double maxChange = -1.0;
        final Cursor<FloatType> cursorPsi = psi.createCursor();
        final Cursor<FloatType> cursorLast = lastIteration.createCursor();
        try {
            cursorPsi.fwd(start);
            cursorLast.fwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                final float last = cursorLast.next().get();
                final float next = cursorPsi.next().get();
                final float change = Math.abs(next - last);
                sumChange += change;
                maxChange = Math.max(maxChange, change);
            }
        } finally {
            cursorPsi.close();
            cursorLast.close();
        }
        sumMax[0] = sumChange;
        sumMax[1] = maxChange;
    }

    /** Replaces each value of {@code psiBlurred} in the chunk with {@code image / psiBlurred} for the view's image. */
    private static final void computeQuotient(long start, long loopSize, Image<FloatType> psiBlurred, LRFFT processingData) {
        final Cursor<FloatType> cursorImg = processingData.getImage().createCursor();
        final Cursor<FloatType> cursorPsiBlurred = psiBlurred.createCursor();
        try {
            cursorImg.fwd(start);
            cursorPsiBlurred.fwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                cursorImg.fwd();
                cursorPsiBlurred.fwd();
                final float imgValue = cursorImg.getType().get();
                final float psiBlurredValue = cursorPsiBlurred.getType().get();
                cursorPsiBlurred.getType().set(imgValue / psiBlurredValue);
            }
        } finally {
            cursorImg.close();
            cursorPsiBlurred.close();
        }
    }

    /**
     * Updates psi in the chunk.
     * <p>
     * The candidate value {@code v = psi * integral} is damped to {@code (sqrt(1 + 2*lambda*v) - 1) / lambda} when
     * {@code lambda > 0}. Non-positive or NaN candidates become {@code minValue}, and every candidate is clamped to
     * at least {@code minValue}. The result is blended as {@code psi + (candidate - psi) * weight}.
     */
    private static final void computeFinalValues(long start, long loopSize, Image<FloatType> psi, Image<FloatType> integral, Image<FloatType> weight, double lambda, float minValue) {
        final Cursor<FloatType> cursorPsi = psi.createCursor();
        final Cursor<FloatType> cursorIntegral = integral.createCursor();
        final Cursor<FloatType> cursorWeight = weight.createCursor();
        try {
            cursorPsi.fwd(start);
            cursorIntegral.fwd(start);
            cursorWeight.fwd(start);
            for (long l = 0L; l < loopSize; ++l) {
                cursorPsi.fwd();
                cursorIntegral.fwd();
                cursorWeight.fwd();
                final float lastPsiValue = cursorPsi.getType().get();
                float value = lastPsiValue * cursorIntegral.getType().get();
                if (value > 0.0f) {
                    if (lambda > 0.0) {
                        value = (float)((Math.sqrt(1.0 + 2.0 * lambda * (double)value) - 1.0) / lambda);
                    }
                } else {
                    value = minValue;
                }
                final float nextPsiValue = Float.isNaN(value) ? minValue : Math.max(minValue, value);
                final float change = (nextPsiValue - lastPsiValue) * cursorWeight.getType().get();
                cursorPsi.getType().set(lastPsiValue + change);
            }
        } finally {
            cursorPsi.close();
            cursorIntegral.close();
            cursorWeight.close();
        }
    }
}

