/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.algorithm.metrics.segmentation;

import java.util.Arrays;
import java.util.HashMap;
import java.util.stream.Stream;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.metrics.segmentation.ConfusionMatrix;
import net.imglib2.algorithm.metrics.segmentation.assignment.MunkresKuhnAlgorithm;
import net.imglib2.roi.labeling.ImgLabeling;
import net.imglib2.roi.labeling.Labelings;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
 * Instance-segmentation metrics obtained by matching ground-truth labels to
 * prediction labels.
 *
 * <p>For every (ground-truth, prediction) label pair an intersection-over-union
 * (IoU) score is computed; scores below the threshold are zeroed. Labels are then
 * paired by {@link MunkresKuhnAlgorithm} on the negated scores, and a pair counts
 * as a true positive when its IoU is at least the threshold. Images that have more
 * than three dimensions and more than one plane along axis 3 are processed one
 * slice at a time along that axis, with the counts accumulated over all slices.
 */
public class MultiMetrics {
    /** Index of the axis along which images are split into frames. */
    private static final int T_AXIS = 3;

    /**
     * Computes the metrics between two labelings by delegating to the index-image
     * overload.
     *
     * @param groundTruth ground-truth labeling
     * @param prediction predicted labeling
     * @param threshold minimum IoU for a matched pair to count as a true positive
     * @return the value of every {@link Metrics} entry
     * @throws UnsupportedOperationException if either labeling has intersecting labels
     * @throws IllegalArgumentException if the two index images differ in dimensions
     */
    public static <T, I extends IntegerType<I>, U, J extends IntegerType<J>> HashMap<Metrics, Double> computeMetrics(ImgLabeling<T, I> groundTruth, ImgLabeling<U, J> prediction, double threshold) {
        if (Labelings.hasIntersectingLabels(groundTruth) || Labelings.hasIntersectingLabels(prediction)) {
            throw new UnsupportedOperationException("ImgLabeling with intersecting labels are not supported.");
        }
        return MultiMetrics.computeMetrics(groundTruth.getIndexImg(), prediction.getIndexImg(), threshold);
    }

    /**
     * Computes the metrics between two label images.
     *
     * <p>If the images have more than {@code T_AXIS} dimensions and more than one
     * plane along axis {@code T_AXIS}, each slice along that axis is scored
     * separately and the counts are accumulated; otherwise the images are scored
     * as a whole.
     *
     * @param groundTruth ground-truth label image
     * @param prediction predicted label image
     * @param threshold minimum IoU for a matched pair to count as a true positive
     * @return the value of every {@link Metrics} entry
     * @throws IllegalArgumentException if the images differ in dimensions
     */
    public static <I extends IntegerType<I>, J extends IntegerType<J>> HashMap<Metrics, Double> computeMetrics(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction, double threshold) {
        if (!Intervals.equalDimensions(groundTruth, prediction)) {
            throw new IllegalArgumentException("Image dimensions must match.");
        }
        boolean timeLapse = groundTruth.numDimensions() > T_AXIS && groundTruth.dimension(T_AXIS) > 1L;
        MetricsSummary summary;
        if (timeLapse) {
            summary = MultiMetrics.runAverageOverTime(groundTruth, prediction, threshold);
        } else {
            summary = MultiMetrics.runSingle(groundTruth, prediction, threshold);
        }
        return summary.getScores();
    }

    /**
     * Scores every slice along axis {@code T_AXIS} independently and accumulates
     * the per-slice counts (TP, FP, FN and summed IoU) into a single summary.
     *
     * <p>The derived scores are therefore computed from the pooled counts, not as
     * a mean of per-slice scores.
     *
     * @param groundTruth ground-truth label image with an axis {@code T_AXIS}
     * @param prediction predicted label image of the same dimensions
     * @param threshold minimum IoU for a matched pair to count as a true positive
     * @return the counts accumulated over all slices
     */
    protected static <I extends IntegerType<I>, J extends IntegerType<J>> MetricsSummary runAverageOverTime(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction, double threshold) {
        int nFrames = (int) groundTruth.dimension(T_AXIS);
        MetricsSummary total = new MetricsSummary();
        for (int t = 0; t < nFrames; ++t) {
            IntervalView<I> gtFrame = Views.hyperSlice(groundTruth, T_AXIS, t);
            IntervalView<J> predFrame = Views.hyperSlice(prediction, T_AXIS, t);
            total.addPoint(MultiMetrics.runSingle(gtFrame, predFrame, threshold));
        }
        return total;
    }

    /**
     * Scores one pair of images: builds the confusion matrix, derives the cost
     * matrix from it and turns the resulting assignment into counts.
     *
     * @param groundTruth ground-truth label image
     * @param prediction predicted label image
     * @param threshold minimum IoU for a matched pair to count as a true positive
     * @return the counts for this pair of images
     */
    protected static <I extends IntegerType<I>, J extends IntegerType<J>> MetricsSummary runSingle(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction, double threshold) {
        ConfusionMatrix<I, J> confusionMatrix = new ConfusionMatrix<I, J>(groundTruth, prediction);
        double[][] costMatrix = MultiMetrics.computeCostMatrix(confusionMatrix, threshold);
        return MultiMetrics.computeFinalScores(confusionMatrix, costMatrix, threshold);
    }

    /**
     * Builds the cost matrix handed to the assignment solver.
     *
     * <p>Entry {@code [i][j]} is the negated thresholded IoU between ground-truth
     * label {@code i} and prediction label {@code j}. The matrix has
     * {@code max(M + 1, N)} columns for {@code M} ground-truth and {@code N}
     * prediction labels, so it always has more columns than rows; columns at index
     * {@code N} and beyond are padding and keep the default cost of {@code 0}.
     *
     * @param cM confusion matrix of the two images
     * @param threshold IoU values below this are replaced by zero
     * @return the cost matrix, one row per ground-truth label
     */
    protected static double[][] computeCostMatrix(ConfusionMatrix<?, ?> cM, double threshold) {
        int nGroundTruth = cM.getNumberGroundTruthLabels();
        int nPrediction = cM.getNumberPredictionLabels();
        // NOTE(review): the extra column is presumably required by MunkresKuhnAlgorithm
        // (strictly wider than tall) - confirm against that class.
        double[][] costMatrix = new double[nGroundTruth][Math.max(nGroundTruth + 1, nPrediction)];
        for (int i = 0; i < nGroundTruth; ++i) {
            for (int j = 0; j < nPrediction; ++j) {
                // Negated so that maximising IoU becomes a minimum-cost assignment.
                costMatrix[i][j] = -MultiMetrics.getLocalIoUScore(cM, i, j, threshold);
            }
        }
        return costMatrix;
    }

    /**
     * Computes the thresholded IoU between one ground-truth label and one
     * prediction label.
     *
     * @param cM confusion matrix of the two images
     * @param iGT index of the ground-truth label
     * @param jPred index of the prediction label
     * @param threshold IoU values below this are replaced by zero
     * @return {@code tp / (tp + fp + fn)}, or {@code 0} if that denominator is not
     *     positive or the ratio is below {@code threshold}
     */
    protected static double getLocalIoUScore(ConfusionMatrix<?, ?> cM, int iGT, int jPred, double threshold) {
        double tp = cM.getIntersection(iGT, jPred);
        int groundTruthSize = cM.getGroundTruthLabelSize(iGT);
        int predictionSize = cM.getPredictionLabelSize(jPred);
        double fp = (double) predictionSize - tp;
        double fn = (double) groundTruthSize - tp;
        double union = tp + fp + fn;
        double iou = union > 0.0 ? tp / union : 0.0;
        return iou < threshold ? 0.0 : iou;
    }

    /**
     * Solves the assignment problem on {@code costMatrix} and converts the result
     * into TP / FP / FN counts and a summed IoU.
     *
     * <p>An assigned pair is a true positive when its IoU (the negated cost) is at
     * least {@code threshold}. Assignments to padding columns, i.e. column indices
     * that do not correspond to a prediction label, are never counted: their cost
     * is {@code 0}, which would otherwise satisfy {@code threshold <= 0} and yield
     * more true positives than there are prediction labels (negative FP).
     *
     * @param confusionMatrix confusion matrix the cost matrix was derived from
     * @param costMatrix negated thresholded IoU scores, one row per ground-truth label
     * @param threshold minimum IoU for a matched pair to count as a true positive
     * @return a summary holding the counts for this cost matrix
     */
    protected static MetricsSummary computeFinalScores(ConfusionMatrix<?, ?> confusionMatrix, double[][] costMatrix, double threshold) {
        MetricsSummary summary = new MetricsSummary();
        int[][] assignment = new MunkresKuhnAlgorithm().computeAssignments(costMatrix);
        int nGroundTruth = confusionMatrix.getNumberGroundTruthLabels();
        int nPrediction = confusionMatrix.getNumberPredictionLabels();
        int tp = 0;
        double sumIoU = 0.0;
        if (assignment.length != 0 && assignment[0].length != 0) {
            for (int[] pair : assignment) {
                int row = pair[0];
                int col = pair[1];
                if (col >= nPrediction) {
                    // Padding column: the ground-truth label was left unmatched.
                    continue;
                }
                double iou = -costMatrix[row][col];
                // NOTE(review): with threshold <= 0 a real pair with zero overlap still
                // counts as a match here - confirm whether that is intended.
                if (iou >= threshold) {
                    ++tp;
                    sumIoU += iou;
                }
            }
        }
        // With no usable assignment tp stays 0, so every label is an FP or an FN.
        summary.addPoint(tp, nPrediction - tp, nGroundTruth - tp, sumIoU);
        return summary;
    }

    /**
     * Accumulator for true-positive, false-positive and false-negative counts and
     * the summed IoU of matched pairs, from which the final scores are derived.
     */
    protected static class MetricsSummary {
        private int tp = 0;
        private int fp = 0;
        private int fn = 0;
        private double sumIoU = 0.0;

        protected MetricsSummary() {
        }

        /**
         * Adds the counts of another summary to this one.
         *
         * @param otherMetrics summary whose counts are added
         */
        public void addPoint(MetricsSummary otherMetrics) {
            this.tp += otherMetrics.tp;
            this.fp += otherMetrics.fp;
            this.fn += otherMetrics.fn;
            this.sumIoU += otherMetrics.sumIoU;
        }

        /**
         * Adds the given counts to this summary.
         *
         * @param tp number of true positives to add
         * @param fp number of false positives to add
         * @param fn number of false negatives to add
         * @param sumIoU summed IoU of the matched pairs to add
         */
        public void addPoint(int tp, int fp, int fn, double sumIoU) {
            this.tp += tp;
            this.fp += fp;
            this.fn += fn;
            this.sumIoU += sumIoU;
        }

        /** @return the accumulated number of true positives */
        public int getTP() {
            return this.tp;
        }

        /** @return the accumulated number of false negatives */
        public int getFN() {
            return this.fn;
        }

        /** @return the accumulated number of false positives */
        public int getFP() {
            return this.fp;
        }

        /** @return the accumulated sum (not the mean) of the IoU values */
        public double getIoU() {
            return this.sumIoU;
        }

        /** @return {@code sumIoU / tp}, or {@code NaN} if {@code tp} is not positive */
        protected double meanMatchedIoU(double tp, double sumIoU) {
            return tp > 0.0 ? sumIoU / tp : Double.NaN;
        }

        /** @return {@code sumIoU / (tp + fn)}, or {@code NaN} if {@code tp + fn} is not positive */
        protected double meanTrueIoU(double tp, double fn, double sumIoU) {
            double groundTruthCount = tp + fn;
            return groundTruthCount > 0.0 ? sumIoU / groundTruthCount : Double.NaN;
        }

        /** @return {@code tp / (tp + fp)}, or {@code NaN} if {@code tp + fp} is not positive */
        protected double precision(double tp, double fp) {
            double predictedCount = tp + fp;
            return predictedCount > 0.0 ? tp / predictedCount : Double.NaN;
        }

        /** @return {@code tp / (tp + fn)}, or {@code NaN} if {@code tp + fn} is not positive */
        protected double recall(double tp, double fn) {
            double groundTruthCount = tp + fn;
            return groundTruthCount > 0.0 ? tp / groundTruthCount : Double.NaN;
        }

        /**
         * @return the harmonic mean of precision and recall, or {@code NaN} if their
         *     sum is not positive (which includes either of them being {@code NaN})
         */
        protected double f1(double precision, double recall) {
            double sum = precision + recall;
            return sum > 0.0 ? 2.0 * precision * recall / sum : Double.NaN;
        }

        /** @return {@code tp / (tp + fn + fp)}, or {@code NaN} if that sum is not positive */
        protected double accuracy(double tp, double fp, double fn) {
            double total = tp + fn + fp;
            return total > 0.0 ? tp / total : Double.NaN;
        }

        /**
         * Derives every metric from the accumulated counts.
         *
         * @return a new map holding one value per {@link Metrics} entry; ratios whose
         *     denominator is zero are {@code NaN}
         */
        public HashMap<Metrics, Double> getScores() {
            double precision = this.precision(this.tp, this.fp);
            double recall = this.recall(this.tp, this.fn);

            HashMap<Metrics, Double> scores = new HashMap<Metrics, Double>();
            scores.put(Metrics.TP, Double.valueOf(this.tp));
            scores.put(Metrics.FP, Double.valueOf(this.fp));
            scores.put(Metrics.FN, Double.valueOf(this.fn));
            scores.put(Metrics.MEAN_MATCHED_IOU, this.meanMatchedIoU(this.tp, this.sumIoU));
            scores.put(Metrics.MEAN_TRUE_IOU, this.meanTrueIoU(this.tp, this.fn, this.sumIoU));
            scores.put(Metrics.PRECISION, precision);
            scores.put(Metrics.RECALL, recall);
            scores.put(Metrics.F1, this.f1(precision, recall));
            scores.put(Metrics.ACCURACY, this.accuracy(this.tp, this.fp, this.fn));
            return scores;
        }
    }

    /** The metrics reported by {@link #computeMetrics}, each with a display name. */
    public enum Metrics {
        ACCURACY("Accuracy"),
        MEAN_MATCHED_IOU("Mean matched IoU"),
        MEAN_TRUE_IOU("Mean true IoU"),
        TP("True positives"),
        FP("False positives"),
        FN("False negatives"),
        PRECISION("Precision"),
        RECALL("Recall"),
        F1("F1");

        private final String name;

        Metrics(String name) {
            this.name = name;
        }

        /** @return the human-readable name of this metric */
        public String getName() {
            return this.name;
        }

        /** @return a stream over all metrics in declaration order */
        public static Stream<Metrics> stream() {
            return Arrays.stream(Metrics.values());
        }
    }
}

