/*
 * Decompiled with CFR 0.152.
 */
package fiji.plugin.trackmate.gui.editor.labkit.model;

import fiji.plugin.trackmate.Logger;
import fiji.plugin.trackmate.Model;
import fiji.plugin.trackmate.Spot;
import fiji.plugin.trackmate.detection.MaskUtils;
import fiji.plugin.trackmate.gui.editor.labkit.model.TMLabKitUtils;
import ij.IJ;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.roi.IterableRegion;
import net.imglib2.roi.labeling.LabelingMapping;
import net.imglib2.type.logic.BitType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
import org.jgrapht.graph.DefaultWeightedEdge;
import sc.fiji.labkit.ui.labeling.Label;
import sc.fiji.labkit.ui.labeling.Labeling;

/**
 * Re-imports into a TrackMate {@link Model} the spots that were edited in a
 * Labkit {@link Labeling}.
 * <p>
 * The importer compares the current index image of the labeling with the
 * index image captured before editing, recreates spots for every label whose
 * index changed, and then adds, replaces or removes the corresponding spots
 * in the model. Instances are created through {@link #create()}.
 */
public class LabkitImporter {
    /*
     * NOTE(review): this flag is not referenced anywhere in the class as
     * decompiled. Presumably the compiler folded "if (DEBUG)" guards around
     * the IJ.log() calls - confirm against the original sources.
     */
    private static final boolean DEBUG = true;

    /** Threshold passed to {@link MaskUtils} when converting a label region into spots. */
    private static final double THRESHOLD = 0.5;

    /** Number of threads passed to {@link MaskUtils} when converting a label region into spots. */
    private static final int NUM_THREADS = 1;

    /** Value written to the QUALITY feature of every spot created by this importer. */
    private static final double IMPORTED_SPOT_QUALITY = -1.0;

    private final Model model;
    private final double[] calibration;
    private final double dt;
    private final Labeling labeling;
    private final RandomAccessibleInterval<UnsignedIntType> previousIndexImg;
    private final Map<Label, Spot> initialMapping;
    private final int targetTimePoint;
    private final boolean simplifyContours;

    /**
     * Returns a new builder used to configure and create an importer.
     *
     * @return a new, empty builder.
     */
    public static Builder create() {
        return new Builder();
    }

    private LabkitImporter(Model model, Labeling labeling, RandomAccessibleInterval<UnsignedIntType> initialIndexImg, Map<Label, Spot> initialMapping, int targetTimePoint, boolean simplifyContours, double[] calibration, double dt) {
        this.model = model;
        this.labeling = labeling;
        this.previousIndexImg = initialIndexImg;
        this.initialMapping = initialMapping;
        this.targetTimePoint = targetTimePoint;
        this.simplifyContours = simplifyContours;
        this.calibration = calibration;
        this.dt = dt;
    }

    /**
     * Tells whether the labeling differs from its initial state, as reported
     * by {@code TMLabKitUtils.isDifferent} on the initial and current index
     * images.
     *
     * @return the result of the comparison.
     */
    public boolean hasChanges() {
        return TMLabKitUtils.isDifferent(this.previousIndexImg, currentIndexImg());
    }

    /**
     * Re-imports the modified labels into the model.
     * <p>
     * Does nothing besides setting the logger status when no index was
     * modified. Otherwise spots are recreated for every modified label, grouped
     * by time-point, and each time-point is re-imported in its own model
     * update. Progress is reported on {@link Logger#IJ_LOGGER}.
     */
    public void run() {
        final Logger log = Logger.IJ_LOGGER;
        log.setStatus("Re-importing from Labkit...");

        final Set<Integer> modifiedIndices = TMLabKitUtils.getModifiedIndices(currentIndexImg(), this.previousIndexImg);
        if (modifiedIndices.isEmpty()) {
            return;
        }

        // Collect every label present at one of the modified indices.
        final LabelingMapping<Label> mapping = this.labeling.getType().getMapping();
        final Set<Label> modifiedLabels = new HashSet<>();
        for (final Integer index : modifiedIndices) {
            modifiedLabels.addAll(mapping.labelsAtIndex(index.intValue()));
        }

        IJ.log("\nRe-importing from Labkit");
        IJ.log("Modified indices: " + modifiedIndices);
        IJ.log("Corresponding modified labels & initial spot (if any):");
        for (final Label label : modifiedLabels) {
            IJ.log(" - " + label.name() + " -> " + this.initialMapping.get(label));
        }
        IJ.log("Re-insertion in the model:");

        final Map<Integer, Map<Label, List<Spot>>> spotsPerTimepoint = collectModifiedSpots(modifiedLabels);

        final int nTimepoints = spotsPerTimepoint.size();
        int progress = 0;
        for (final Map.Entry<Integer, Map<Label, List<Spot>>> entry : spotsPerTimepoint.entrySet()) {
            final int timepoint = entry.getKey().intValue();

            // Initial spots that live in the time-point being re-imported.
            final Map<Label, Spot> previousSpots = new HashMap<>();
            for (final Map.Entry<Label, Spot> initial : this.initialMapping.entrySet()) {
                final Integer frame = frameOf(initial.getValue());
                if (frame != null && frame.intValue() == timepoint) {
                    previousSpots.put(initial.getKey(), initial.getValue());
                }
            }

            reimport(previousSpots, entry.getValue(), timepoint);
            log.setProgress((double) (++progress) / (double) nTimepoints);
        }
    }

    /**
     * Returns the current index image of the labeling, viewed as an image of
     * {@link UnsignedIntType}.
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private RandomAccessibleInterval<UnsignedIntType> currentIndexImg() {
        // Same unchecked narrowing as the original code; going through the raw
        // type keeps the cast legal whatever type getIndexImg() declares.
        final RandomAccessibleInterval raw = this.labeling.getIndexImg();
        return (RandomAccessibleInterval<UnsignedIntType>) raw;
    }

    /**
     * Recreates the spots of every modified label and groups them by
     * time-point, in ascending time-point order.
     * <p>
     * When the labeling has no time axis, all the spots are stored under the
     * target time-point (clamped to 0). Otherwise each region is sliced along
     * the time axis and the spots of each slice are stored under the slice
     * position.
     *
     * @param modifiedLabels
     *            the labels to recreate spots for.
     * @return a map of time-point to label to spots.
     */
    private Map<Integer, Map<Label, List<Spot>>> collectModifiedSpots(Set<Label> modifiedLabels) {
        final Map<Label, IterableRegion<BitType>> regions = this.labeling.iterableRegions();
        final int timeAxis = TMLabKitUtils.timeAxis(this.labeling);
        final Map<Integer, Map<Label, List<Spot>>> spotsPerTimepoint = new TreeMap<>();

        if (timeAxis < 0) {
            // Single time-point: everything goes to the target time-point.
            final int timepoint = Math.max(0, this.targetTimePoint);
            final Map<Label, List<Spot>> spotsOfTimepoint = new HashMap<>();
            for (final Label label : modifiedLabels) {
                final IterableRegion<BitType> region = regions.get(label);
                final List<Spot> spots;
                if (region == null) {
                    // No region left for this label: handled like an empty
                    // region, so that its initial spot (if any) is removed.
                    // NOTE(review): confirm iterableRegions() can omit a label.
                    spots = Collections.<Spot>emptyList();
                } else {
                    spots = MaskUtils.fromThresholdWithROI(region, (Interval) region, this.calibration, THRESHOLD, this.simplifyContours, NUM_THREADS, null);
                }
                spotsOfTimepoint.put(label, spots);
            }
            spotsPerTimepoint.put(timepoint, spotsOfTimepoint);
            return spotsPerTimepoint;
        }

        for (final Label label : modifiedLabels) {
            final IterableRegion<BitType> region = regions.get(label);
            if (region == null) {
                // No region left for this label: register an empty result at
                // the frame of its initial spot so that the spot is removed,
                // instead of failing with a NullPointerException.
                // NOTE(review): confirm iterableRegions() can omit a label.
                final Integer frame = frameOf(this.initialMapping.get(label));
                if (frame != null) {
                    spotsPerTimepoint.computeIfAbsent(frame, k -> new HashMap<>()).put(label, Collections.<Spot>emptyList());
                }
                continue;
            }
            final long minT = region.min(timeAxis);
            final long maxT = region.max(timeAxis);
            for (long t = minT; t <= maxT; t++) {
                final IntervalView<BitType> slice = Views.hyperSlice(region, timeAxis, t);
                final List<Spot> spots = MaskUtils.fromThresholdWithROI(slice, (Interval) slice, this.calibration, THRESHOLD, this.simplifyContours, NUM_THREADS, null);
                spotsPerTimepoint.computeIfAbsent((int) t, k -> new HashMap<>()).put(label, spots);
            }
        }
        return spotsPerTimepoint;
    }

    /**
     * Applies the modifications of one time-point to the model, inside a
     * single {@code beginUpdate()} / {@code endUpdate()} pair.
     * <p>
     * For each modified label: if it had no initial spot, its spots are added;
     * if it had one but has no spot anymore, the initial spot is removed;
     * otherwise the initial spot is replaced.
     *
     * @param initialSpots
     *            the initial spots of this time-point, per label.
     * @param modifiedSpots
     *            the recreated spots of this time-point, per label.
     * @param currentTimePoint
     *            the time-point being re-imported.
     */
    private void reimport(Map<Label, Spot> initialSpots, Map<Label, List<Spot>> modifiedSpots, int currentTimePoint) {
        this.model.beginUpdate();
        try {
            for (final Map.Entry<Label, List<Spot>> entry : modifiedSpots.entrySet()) {
                final Spot previousSpot = initialSpots.get(entry.getKey());
                final List<Spot> novelSpotList = entry.getValue();
                final boolean nothingLeft = novelSpotList == null || novelSpotList.isEmpty();

                if (previousSpot == null) {
                    if (!nothingLeft) {
                        addNewSpot(novelSpotList, currentTimePoint);
                    }
                } else if (nothingLeft) {
                    IJ.log(" - Removed spot " + str(previousSpot));
                    this.model.removeSpot(previousSpot);
                } else {
                    modifySpot(novelSpotList, previousSpot, currentTimePoint);
                }
            }
        } finally {
            // Always close the update, even if an import step failed.
            this.model.endUpdate();
        }
    }

    /**
     * Replaces an initial spot by the recreated spots of its label.
     * <p>
     * The recreated spot closest to the initial one takes over its name and
     * its edges; the initial spot is then removed. Any other recreated spot is
     * added to the model, named after the initial spot with a numbered suffix.
     *
     * @param novelSpotList
     *            the recreated spots; must not be empty.
     * @param previousSpot
     *            the initial spot to replace.
     * @param currentTimePoint
     *            the time-point the spots belong to.
     */
    private void modifySpot(List<Spot> novelSpotList, Spot previousSpot, int currentTimePoint) {
        // Seed with the first candidate so that a spot is always selected,
        // even if no distance compares as smaller (e.g. NaN distances).
        Spot mainNovelSpot = novelSpotList.get(0);
        double minD2 = Double.POSITIVE_INFINITY;
        for (final Spot candidate : novelSpotList) {
            final double d2 = candidate.squareDistanceTo(previousSpot);
            if (d2 < minD2) {
                minD2 = d2;
                mainNovelSpot = candidate;
            }
        }

        mainNovelSpot.setName(previousSpot.getName());
        addToModel(mainNovelSpot, currentTimePoint);

        // Re-create the edges of the initial spot on its replacement.
        final Set<DefaultWeightedEdge> edges = this.model.getTrackModel().edgesOf(previousSpot);
        for (final DefaultWeightedEdge edge : edges) {
            final double weight = this.model.getTrackModel().getEdgeWeight(edge);
            final Spot source = this.model.getTrackModel().getEdgeSource(edge);
            final Spot target = this.model.getTrackModel().getEdgeTarget(edge);
            if (source == previousSpot) {
                this.model.addEdge(mainNovelSpot, target, weight);
            } else if (target == previousSpot) {
                this.model.addEdge(source, mainNovelSpot, weight);
            } else {
                throw new IllegalArgumentException("The edge of a spot does not have the spot as source or target?!?");
            }
        }
        IJ.log(" - Modified spot " + str(previousSpot) + " -> " + str(mainNovelSpot));
        this.model.removeSpot(previousSpot);

        // The remaining recreated spots are added as new, unlinked spots.
        final Set<Spot> extraSpots = new HashSet<>(novelSpotList);
        extraSpots.remove(mainNovelSpot);
        int suffix = 1;
        for (final Spot extra : extraSpots) {
            extra.setName(previousSpot.getName() + "_" + suffix++);
            addToModel(extra, currentTimePoint);
            IJ.log(" - Added spot " + str(extra));
        }
    }

    /**
     * Adds all the specified spots to the model, at the specified time-point.
     *
     * @param novelSpotList
     *            the spots to add.
     * @param currentTimePoint
     *            the time-point to add them to.
     */
    private void addNewSpot(Iterable<Spot> novelSpotList, int currentTimePoint) {
        for (final Spot spot : novelSpotList) {
            addToModel(spot, currentTimePoint);
            IJ.log(" - Added spot " + str(spot));
        }
    }

    /**
     * Sets the POSITION_T feature (time-point times frame interval) and the
     * QUALITY feature of the spot, then adds it to the model at the specified
     * time-point.
     */
    private void addToModel(Spot spot, int timepoint) {
        spot.putFeature("POSITION_T", (double) timepoint * this.dt);
        spot.putFeature("QUALITY", IMPORTED_SPOT_QUALITY);
        this.model.addSpotTo(spot, timepoint);
    }

    /**
     * Returns the value of the FRAME feature of the spot as an integer, or
     * <code>null</code> if the spot is <code>null</code> or has no FRAME
     * feature.
     */
    private static Integer frameOf(Spot spot) {
        if (spot == null) {
            return null;
        }
        final Double frame = spot.getFeature("FRAME");
        return frame == null ? null : Integer.valueOf(frame.intValue());
    }

    /**
     * Returns a short description of the spot for log messages: its ID, its
     * first three coordinates rounded to one decimal, and its FRAME feature.
     */
    private static String str(Spot spot) {
        return spot.ID() + " (" + roundToN(spot.getDoublePosition(0), 1) + ", " + roundToN(spot.getDoublePosition(1), 1) + ", " + roundToN(spot.getDoublePosition(2), 1) + ") @ t=" + spot.getFeature("FRAME");
    }

    /**
     * Rounds a value to the specified number of decimals.
     *
     * @param num
     *            the value to round.
     * @param n
     *            the number of decimals to keep.
     * @return the rounded value.
     */
    private static double roundToN(double num, int n) {
        final double scale = Math.pow(10.0, n);
        return (double) Math.round(num * scale) / scale;
    }

    /**
     * Builder for {@link LabkitImporter}. Every setter returns this builder so
     * that calls can be chained; {@link #get()} creates the importer from the
     * values set so far, without validating them.
     */
    public static class Builder {
        private Model model;
        private double[] calibration;
        private double dt;
        private Labeling labeling;
        private RandomAccessibleInterval<UnsignedIntType> initialIndexImg;
        private int targetTimePoint;
        private Map<Label, Spot> initialMapping;
        private boolean simplifyContours;

        /**
         * Creates the importer from the values currently set in this builder.
         *
         * @return a new importer.
         */
        public LabkitImporter get() {
            return new LabkitImporter(this.model, this.labeling, this.initialIndexImg, this.initialMapping, this.targetTimePoint, this.simplifyContours, this.calibration, this.dt);
        }

        /**
         * Sets the mapping from labels to the spots they were created from.
         */
        public Builder initialMapping(Map<Label, Spot> initialMapping) {
            this.initialMapping = initialMapping;
            return this;
        }

        /**
         * Sets the index image of the labeling as it was before editing. The
         * current index image is compared against it to detect changes.
         */
        public Builder initialIndexImg(RandomAccessibleInterval<UnsignedIntType> previousIndexImg) {
            this.initialIndexImg = previousIndexImg;
            return this;
        }

        /**
         * Sets the time-point the spots are imported to when the labeling has
         * no time axis. Negative values are treated as 0.
         */
        public Builder targetTimePoint(int targetTimePoint) {
            this.targetTimePoint = targetTimePoint;
            return this;
        }

        /**
         * Sets the labeling to re-import the spots from.
         */
        public Builder labeling(Labeling labeling) {
            this.labeling = labeling;
            return this;
        }

        /**
         * Sets the model to add, modify and remove spots in.
         */
        public Builder trackmateModel(Model model) {
            this.model = model;
            return this;
        }

        /**
         * Sets the calibration array passed to {@link MaskUtils} when
         * recreating spots. The array is stored as is, without copy.
         */
        public Builder calibration(double[] calibration) {
            this.calibration = calibration;
            return this;
        }

        /**
         * Sets the frame interval, used to compute the POSITION_T feature of
         * the imported spots from their time-point.
         */
        public Builder frameInterval(double dt) {
            this.dt = dt;
            return this;
        }

        /**
         * Sets the flag passed to {@link MaskUtils} to request simplified
         * contours when recreating spots.
         */
        public Builder simplifyContours(boolean simplifyContours) {
            this.simplifyContours = simplifyContours;
            return this;
        }
    }
}

