/*
 * Decompiled with CFR 0.152.
 */
package spim.process.fusion.deconvolution;

import bdv.util.ConstantRandomAccessible;
import ij.ImagePlus;
import java.io.File;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import mpicbg.spim.data.sequence.Angle;
import mpicbg.spim.data.sequence.Channel;
import mpicbg.spim.data.sequence.Illumination;
import mpicbg.spim.data.sequence.ImgLoader;
import mpicbg.spim.data.sequence.SequenceDescription;
import mpicbg.spim.data.sequence.TimePoint;
import mpicbg.spim.data.sequence.ViewDescription;
import mpicbg.spim.data.sequence.ViewId;
import mpicbg.spim.data.sequence.ViewSetup;
import mpicbg.spim.io.IOFunctions;
import net.imglib2.Cursor;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Pair;
import net.imglib2.view.Views;
import spim.Threads;
import spim.fiji.spimdata.SpimData2;
import spim.fiji.spimdata.ViewSetupUtils;
import spim.fiji.spimdata.imgloaders.LegacyStackImgLoaderIJ;
import spim.fiji.spimdata.interestpoints.CorrespondingInterestPoints;
import spim.fiji.spimdata.interestpoints.InterestPoint;
import spim.fiji.spimdata.interestpoints.InterestPointList;
import spim.process.fusion.FusionHelper;
import spim.process.fusion.ImagePortion;
import spim.process.fusion.boundingbox.BoundingBoxGUI;
import spim.process.fusion.deconvolution.ChannelPSF;
import spim.process.fusion.deconvolution.ExtractPSF;
import spim.process.fusion.deconvolution.TransformInput;
import spim.process.fusion.deconvolution.TransformInputAndWeights;
import spim.process.fusion.deconvolution.TransformWeights;
import spim.process.fusion.deconvolution.WeightNormalizer;
import spim.process.fusion.export.DisplayImage;
import spim.process.fusion.weightedavg.ProcessFusion;
import spim.process.fusion.weights.Blending;
import spim.process.fusion.weights.NormalizingRandomAccessibleInterval;
import spim.process.fusion.weights.TransformedRealRandomAccessibleInterval;

public class ProcessForDeconvolution {
    protected final SpimData2 spimData;
    protected final List<ViewId> viewIdsToProcess;
    final BoundingBoxGUI bb;
    final int[] blendingBorder;
    final int[] blendingRange;
    int minOverlappingViews;
    double avgOverlappingViews;
    ArrayList<ViewDescription> viewDescriptions;
    HashMap<ViewId, RandomAccessibleInterval<FloatType>> imgs;
    HashMap<ViewId, RandomAccessibleInterval<FloatType>> weights;
    ExtractPSF<FloatType> ePSF;
    public static String[] files;
    public static boolean debugImport;

    public ProcessForDeconvolution(SpimData2 spimData, List<ViewId> viewIdsToProcess, BoundingBoxGUI bb, int[] blendingBorder, int[] blendingRange) {
        this.spimData = spimData;
        this.viewIdsToProcess = viewIdsToProcess;
        this.bb = bb;
        this.blendingBorder = blendingBorder;
        this.blendingRange = blendingRange;
    }

    public ExtractPSF<FloatType> getExtractPSF() {
        return this.ePSF;
    }

    public HashMap<ViewId, RandomAccessibleInterval<FloatType>> getTransformedImgs() {
        return this.imgs;
    }

    public HashMap<ViewId, RandomAccessibleInterval<FloatType>> getTransformedWeights() {
        return this.weights;
    }

    public ArrayList<ViewDescription> getViewDescriptions() {
        return this.viewDescriptions;
    }

    public int getMinOverlappingViews() {
        return this.minOverlappingViews;
    }

    public double getAvgOverlappingViews() {
        return this.avgOverlappingViews;
    }

    public boolean fuseStacksAndGetPSFs(TimePoint timepoint, Channel channel, ImgFactory<FloatType> imgFactory, int osemIndex, double osemspeedup, WeightType weightType, HashMap<Channel, ChannelPSF> extractPSFLabels, long[] psfSize, HashMap<Channel, ArrayList<Pair<Pair<Angle, Illumination>, String>>> psfFiles, boolean transformLoadedPSFs) {
        boolean loadPSFs;
        if (files != null) {
            weightType = WeightType.LOAD_WEIGHTS;
            IOFunctions.println("WARNING: LOADING WEIGHTS FROM IMAGES, files.length()=" + files.length);
        }
        this.viewDescriptions = FusionHelper.assembleInputData(this.spimData, timepoint, channel, this.viewIdsToProcess);
        if (this.viewDescriptions.size() == 0) {
            return false;
        }
        this.imgs = new HashMap();
        this.weights = new HashMap();
        Img overlapImg = weightType == WeightType.WEIGHTS_ONLY ? imgFactory.create(this.bb.getDimensions(), (Object)new FloatType()) : null;
        boolean extractPSFs = extractPSFLabels != null && extractPSFLabels.get(channel).getLabel() != null;
        boolean bl = loadPSFs = psfFiles != null;
        this.ePSF = extractPSFs ? new ExtractPSF() : (loadPSFs ? this.loadPSFs(channel, this.viewDescriptions, psfFiles, transformLoadedPSFs) : this.assignOtherChannel(channel, extractPSFLabels));
        if (this.ePSF == null) {
            return false;
        }
        extractPSFLabels.get(channel).setExtractPSFInstance(this.ePSF);
        for (int i = 0; i < this.viewDescriptions.size(); ++i) {
            Object weightImg;
            RandomAccessibleInterval<FloatType> img;
            ViewDescription vd = this.viewDescriptions.get(i);
            IOFunctions.println("Transforming view " + i + " of " + (this.viewDescriptions.size() - 1) + " (viewsetup=" + vd.getViewSetupId() + ", tp=" + vd.getTimePointId() + ")");
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Reserving memory for transformed & weight image.");
            Img transformedImg = weightType == WeightType.WEIGHTS_ONLY ? overlapImg : imgFactory.create(this.bb.getDimensions(), (Object)new FloatType());
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Transformed image factory: " + imgFactory.getClass().getSimpleName());
            if (weightType == WeightType.WEIGHTS_ONLY && !extractPSFs) {
                img = null;
            } else {
                IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Loading image.");
                img = ProcessFusion.getImage(new FloatType(), this.spimData, (ViewId)vd, true);
                if (Img.class.isInstance(img)) {
                    IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Input image factory: " + ((Img)img).factory().getClass().getSimpleName());
                }
            }
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Initializing transformation & weights: " + weightType.name());
            this.spimData.getViewRegistrations().getViewRegistration((ViewId)vd).updateModel();
            AffineTransform3D transform = this.spimData.getViewRegistrations().getViewRegistration((ViewId)vd).getModel();
            long[] offset = new long[]{this.bb.min(0), this.bb.min(1), this.bb.min(2)};
            if (weightType == WeightType.PRECOMPUTED_WEIGHTS || weightType == WeightType.WEIGHTS_ONLY) {
                weightImg = imgFactory.create(this.bb.getDimensions(), (Object)new FloatType());
            } else if (weightType == WeightType.NO_WEIGHTS) {
                weightImg = Views.interval((RandomAccessible)new ConstantRandomAccessible((Object)new FloatType(1.0f), transformedImg.numDimensions()), (Interval)transformedImg);
            } else if (weightType == WeightType.VIRTUAL_WEIGHTS) {
                Blending blending = this.getBlending((Interval)img, this.blendingBorder, this.blendingRange, vd);
                weightImg = new TransformedRealRandomAccessibleInterval<FloatType>(blending, new FloatType(), (Interval)transformedImg, transform, offset);
            } else {
                IOFunctions.println("WARNING: LOADING WEIGHTS FROM: '" + new File(files[i]) + "'");
                ImagePlus imp = LegacyStackImgLoaderIJ.open(new File(files[i]));
                weightImg = imgFactory.create(this.bb.getDimensions(), (Object)new FloatType());
                LegacyStackImgLoaderIJ.imagePlus2ImgLib2Img(imp, (Img<FloatType>)weightImg, false);
                imp.close();
                if (debugImport) {
                    imp = ImageJFunctions.show((RandomAccessibleInterval)weightImg);
                    imp.setTitle("ViewSetup " + vd.getViewSetupId() + " Timepoint " + vd.getTimePointId());
                }
            }
            Vector<ImagePortion> portions = FusionHelper.divideIntoPortions(Views.iterable((RandomAccessibleInterval)transformedImg).size(), Threads.numThreads() * 4);
            ExecutorService taskExecutor = Executors.newFixedThreadPool(Threads.numThreads());
            ArrayList<Callable<String>> tasks = new ArrayList<Callable<String>>();
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Transforming image & computing weights.");
            for (ImagePortion portion : portions) {
                if (weightType == WeightType.WEIGHTS_ONLY) {
                    FinalInterval imgInterval = new FinalInterval(ViewSetupUtils.getSizeOrLoad(vd.getViewSetup(), vd.getTimePoint(), (ImgLoader)((SequenceDescription)this.spimData.getSequenceDescription()).getImgLoader()));
                    Blending blending = this.getBlending((Interval)imgInterval, this.blendingBorder, this.blendingRange, vd);
                    tasks.add(new TransformWeights(portion, (Interval)imgInterval, blending, transform, (RandomAccessibleInterval<FloatType>)overlapImg, (RandomAccessibleInterval<FloatType>)weightImg, offset));
                    continue;
                }
                if (weightType == WeightType.PRECOMPUTED_WEIGHTS) {
                    Blending blending = this.getBlending((Interval)img, this.blendingBorder, this.blendingRange, vd);
                    tasks.add(new TransformInputAndWeights(portion, img, blending, transform, (RandomAccessibleInterval<FloatType>)transformedImg, (RandomAccessibleInterval<FloatType>)weightImg, offset));
                    continue;
                }
                if (weightType == WeightType.NO_WEIGHTS || weightType == WeightType.VIRTUAL_WEIGHTS || weightType == WeightType.LOAD_WEIGHTS) {
                    tasks.add(new TransformInput(portion, img, transform, (RandomAccessibleInterval<FloatType>)transformedImg, offset));
                    continue;
                }
                throw new RuntimeException(weightType.name() + " not implemented yet.");
            }
            try {
                taskExecutor.invokeAll(tasks);
            }
            catch (InterruptedException e) {
                IOFunctions.println("Failed to compute fusion: " + e);
                e.printStackTrace();
                return false;
            }
            taskExecutor.shutdown();
            if (extractPSFs) {
                ArrayList<double[]> llist = this.getLocationsOfCorrespondingBeads(timepoint, vd, extractPSFLabels.get(channel).getLabel());
                IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Extracting PSF for viewsetup " + vd.getViewSetupId() + " using label '" + extractPSFLabels.get(channel).getLabel() + "' (" + llist.size() + " corresponding detections available)");
                this.ePSF.extractNextImg(img, (ViewId)vd, transform, llist, psfSize);
            }
            if (weightType != WeightType.WEIGHTS_ONLY) {
                this.imgs.put((ViewId)vd, (RandomAccessibleInterval<FloatType>)transformedImg);
            }
            this.weights.put((ViewId)vd, (RandomAccessibleInterval<FloatType>)weightImg);
            tasks.clear();
            System.gc();
        }
        ArrayList<RandomAccessibleInterval<FloatType>> weightsSorted = new ArrayList<RandomAccessibleInterval<FloatType>>();
        for (ViewDescription vd : this.viewDescriptions) {
            weightsSorted.add(this.weights.get(vd));
        }
        IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Computing weight normalization for deconvolution.");
        WeightNormalizer wn = weightType == WeightType.WEIGHTS_ONLY || weightType == WeightType.PRECOMPUTED_WEIGHTS || weightType == WeightType.LOAD_WEIGHTS ? new WeightNormalizer(weightsSorted) : (weightType == WeightType.VIRTUAL_WEIGHTS ? new WeightNormalizer(weightsSorted, imgFactory) : null);
        if (wn != null && !wn.process()) {
            return false;
        }
        for (int i = 0; i < this.viewDescriptions.size(); ++i) {
            this.weights.put((ViewId)this.viewDescriptions.get(i), weightsSorted.get(i));
        }
        if (wn != null) {
            this.minOverlappingViews = wn.getMinOverlappingViews();
            this.avgOverlappingViews = wn.getAvgOverlappingViews();
            this.minOverlappingViews = Math.max(1, this.minOverlappingViews);
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Minimal number of overlapping views: " + this.getMinOverlappingViews() + ", using " + this.minOverlappingViews);
            this.avgOverlappingViews = Math.max(1.0, this.avgOverlappingViews);
            IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Average number of overlapping views: " + this.getAvgOverlappingViews() + ", using " + this.avgOverlappingViews);
        }
        if (osemIndex == 1) {
            osemspeedup = this.getMinOverlappingViews();
        } else if (osemIndex == 2) {
            osemspeedup = this.getAvgOverlappingViews();
        }
        IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Adjusting for OSEM speedup = " + osemspeedup);
        if (weightType == WeightType.WEIGHTS_ONLY) {
            this.displayWeights(osemspeedup, weightsSorted, (RandomAccessibleInterval<FloatType>)overlapImg, imgFactory);
        } else {
            ProcessForDeconvolution.adjustForOSEM(this.weights, weightType, osemspeedup);
        }
        IOFunctions.println("(" + new Date(System.currentTimeMillis()) + "): Finished precomputations for deconvolution.");
        return true;
    }

    private static void adjustForOSEM(HashMap<ViewId, RandomAccessibleInterval<FloatType>> weights, WeightType weightType, double osemspeedup) {
        if (osemspeedup == 1.0) {
            return;
        }
        if (weightType == WeightType.PRECOMPUTED_WEIGHTS || weightType == WeightType.WEIGHTS_ONLY || weightType == WeightType.LOAD_WEIGHTS) {
            for (RandomAccessibleInterval<FloatType> w : weights.values()) {
                for (FloatType f : Views.iterable(w)) {
                    f.set(Math.min(1.0f, f.get() * (float)osemspeedup));
                }
            }
        } else if (weightType == WeightType.NO_WEIGHTS) {
            for (RandomAccessibleInterval<FloatType> w : weights.values()) {
                RandomAccess r = w.randomAccess();
                long[] min = new long[w.numDimensions()];
                w.min(min);
                r.setPosition(min);
                ((FloatType)r.get()).set(Math.min(1.0f, ((FloatType)r.get()).get() * (float)osemspeedup));
            }
        } else if (weightType == WeightType.VIRTUAL_WEIGHTS) {
            for (RandomAccessibleInterval<FloatType> w : weights.values()) {
                ((NormalizingRandomAccessibleInterval)w).setOSEMspeedup(osemspeedup);
            }
        } else {
            throw new RuntimeException("Weight Type: " + weightType.name() + " not supported in ProcessForDeconvolution.adjustForOSEM()");
        }
    }

    private ExtractPSF<FloatType> loadPSFs(Channel ch, ArrayList<ViewDescription> allInputData, HashMap<Channel, ArrayList<Pair<Pair<Angle, Illumination>, String>>> psfFiles, boolean transformLoadedPSFs) {
        HashMap<ViewId, AffineTransform3D> models;
        if (transformLoadedPSFs) {
            models = new HashMap<ViewId, AffineTransform3D>();
            for (ViewDescription viewDesc : allInputData) {
                models.put((ViewId)viewDesc, this.spimData.getViewRegistrations().getViewRegistration((ViewId)viewDesc).getModel());
            }
        } else {
            models = null;
        }
        return ExtractPSF.loadAndTransformPSFs(psfFiles.get(ch), allInputData, new FloatType(), models);
    }

    protected ExtractPSF<FloatType> assignOtherChannel(Channel channel, HashMap<Channel, ChannelPSF> extractPSFLabels) {
        ChannelPSF thisChannelPSF = extractPSFLabels.get(channel);
        ChannelPSF otherChannelPSF = extractPSFLabels.get(thisChannelPSF.getOtherChannel());
        Channel otherChannel = thisChannelPSF.getOtherChannel();
        for (int i = 0; i < this.viewDescriptions.size(); ++i) {
            ViewDescription sourceVD = this.viewDescriptions.get(i);
            for (ViewId viewId : this.viewIdsToProcess) {
                ViewDescription otherVD = ((SequenceDescription)this.spimData.getSequenceDescription()).getViewDescription(viewId);
                if (((ViewSetup)otherVD.getViewSetup()).getAngle().getId() != ((ViewSetup)sourceVD.getViewSetup()).getAngle().getId() || ((ViewSetup)otherVD.getViewSetup()).getIllumination().getId() != ((ViewSetup)sourceVD.getViewSetup()).getIllumination().getId() || otherVD.getTimePointId() != sourceVD.getTimePointId() || ((ViewSetup)otherVD.getViewSetup()).getChannel().getId() != otherChannel.getId()) continue;
                this.ePSF.getViewIdMapping().put((ViewId)sourceVD, (ViewId)otherVD);
                IOFunctions.println("ViewID=" + sourceVD.getViewSetupId() + ", TPID=" + sourceVD.getTimePointId() + " takes the PSF from ViewID=" + otherVD.getViewSetupId() + ", TPID=" + otherVD.getTimePointId());
            }
        }
        return otherChannelPSF.getExtractPSFInstance();
    }

    protected ArrayList<double[]> getLocationsOfCorrespondingBeads(TimePoint tp, ViewDescription inputData, String label) {
        InterestPointList iplist = this.spimData.getViewInterestPoints().getViewInterestPointLists((ViewId)inputData).getInterestPointList(label);
        HashSet<Integer> ipWithCorrespondences = new HashSet<Integer>();
        for (CorrespondingInterestPoints cip : iplist.getCorrespondingInterestPoints()) {
            ipWithCorrespondences.add(cip.getDetectionId());
        }
        ArrayList<double[]> llist = new ArrayList<double[]>();
        for (InterestPoint ip : iplist.getInterestPoints()) {
            if (!ipWithCorrespondences.contains(ip.getId())) continue;
            llist.add((double[])ip.getL().clone());
        }
        return llist;
    }

    protected void displayWeights(final double osemspeedup, final ArrayList<RandomAccessibleInterval<FloatType>> weights, RandomAccessibleInterval<FloatType> overlapImg, ImgFactory<FloatType> imgFactory) {
        DisplayImage d = new DisplayImage();
        d.exportImage(overlapImg, this.bb, "Number of views per pixel");
        final Img w = imgFactory.create(overlapImg, (Object)new FloatType());
        final Img wosem = imgFactory.create(overlapImg, (Object)new FloatType());
        Vector<ImagePortion> portions = FusionHelper.divideIntoPortions(Views.iterable(weights.get(0)).size(), Threads.numThreads() * 2);
        ExecutorService taskExecutor = Executors.newFixedThreadPool(Threads.numThreads());
        ArrayList<1> tasks = new ArrayList<1>();
        for (final ImagePortion portion : portions) {
            tasks.add(new Callable<String>(){

                @Override
                public String call() throws Exception {
                    ArrayList<Cursor> cursors = new ArrayList<Cursor>();
                    Cursor sum = w.cursor();
                    Cursor sumOsem = wosem.cursor();
                    for (RandomAccessibleInterval imgW : weights) {
                        Cursor c = Views.iterable((RandomAccessibleInterval)imgW).cursor();
                        c.jumpFwd(portion.getStartPosition());
                        cursors.add(c);
                    }
                    sum.jumpFwd(portion.getStartPosition());
                    sumOsem.jumpFwd(portion.getStartPosition());
                    for (long j = 0L; j < portion.getLoopSize(); ++j) {
                        double sumW = 0.0;
                        double sumOsemW = 0.0;
                        for (Cursor c : cursors) {
                            float w2 = ((FloatType)c.next()).get();
                            sumW += (double)w2;
                            sumOsemW += Math.min(1.0, (double)w2 * osemspeedup);
                        }
                        ((FloatType)sum.next()).set((float)sumW);
                        ((FloatType)sumOsem.next()).set((float)sumOsemW);
                    }
                    return "done.";
                }
            });
        }
        try {
            taskExecutor.invokeAll(tasks);
        }
        catch (Exception e) {
            IOFunctions.println("Failed to compute weight normalization for deconvolution: " + e);
            e.printStackTrace();
            return;
        }
        taskExecutor.shutdown();
        d.exportImage(w, this.bb, "Sum of weights per pixel");
        d.exportImage(wosem, this.bb, "OSEM=" + osemspeedup + ", sum of weights per pixel");
    }

    protected Blending getBlending(Interval interval, int[] blendingBorder, int[] blendingRange, ViewDescription desc) {
        float[] blending = new float[3];
        float[] border = new float[3];
        blending[0] = blendingRange[0];
        blending[1] = blendingRange[1];
        blending[2] = blendingRange[2];
        border[0] = blendingBorder[0];
        border[1] = blendingBorder[1];
        border[2] = blendingBorder[2];
        return new Blending(interval, border, blending);
    }

    static {
        debugImport = false;
    }

    public static enum WeightType {
        WEIGHTS_ONLY,
        NO_WEIGHTS,
        VIRTUAL_WEIGHTS,
        PRECOMPUTED_WEIGHTS,
        LOAD_WEIGHTS;

    }
}

