/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.algorithm.metrics.segmentation;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.metrics.segmentation.ConfusionMatrix;
import net.imglib2.roi.labeling.ImgLabeling;
import net.imglib2.roi.labeling.Labelings;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.util.Intervals;
import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

public class SEGMetrics {
    private static final int T_AXIS = 3;

    public static <T, I extends IntegerType<I>, U, J extends IntegerType<J>> double computeMetrics(ImgLabeling<T, I> groundTruth, ImgLabeling<U, J> prediction) {
        if (Labelings.hasIntersectingLabels(groundTruth) || Labelings.hasIntersectingLabels(prediction)) {
            throw new UnsupportedOperationException("ImgLabeling with intersecting labels are not supported.");
        }
        return SEGMetrics.computeMetrics(groundTruth.getIndexImg(), prediction.getIndexImg());
    }

    public static <I extends IntegerType<I>, J extends IntegerType<J>> double computeMetrics(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction) {
        if (!Intervals.equalDimensions(groundTruth, prediction)) {
            throw new IllegalArgumentException("Image dimensions must match.");
        }
        boolean timeLapse = false;
        if (groundTruth.numDimensions() > 3) {
            boolean bl = timeLapse = groundTruth.dimension(3) > 1L;
        }
        if (timeLapse) {
            return SEGMetrics.runAverageOverTime(groundTruth, prediction);
        }
        Pair<Integer, Double> result = SEGMetrics.runSingle(groundTruth, prediction);
        return (Integer)result.getA() > 0 ? (Double)result.getB() / (double)((Integer)result.getA()).intValue() : Double.NaN;
    }

    private static <I extends IntegerType<I>, J extends IntegerType<J>> double runAverageOverTime(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction) {
        int nFrames = (int)groundTruth.dimension(3);
        double sumScores = 0.0;
        double nGT = 0.0;
        for (int i = 0; i < nFrames; ++i) {
            IntervalView predFrame;
            IntervalView gtFrame = Views.hyperSlice(groundTruth, (int)3, (long)i);
            Pair<Integer, Double> result = SEGMetrics.runSingle(gtFrame, predFrame = Views.hyperSlice(prediction, (int)3, (long)i));
            if (Double.compare((Double)result.getB(), Double.NaN) == 0) continue;
            nGT += (double)((Integer)result.getA()).intValue();
            sumScores += ((Double)result.getB()).doubleValue();
        }
        return nGT > 0.0 ? sumScores / nGT : Double.NaN;
    }

    protected static <I extends IntegerType<I>, J extends IntegerType<J>> Pair<Integer, Double> runSingle(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction) {
        ConfusionMatrix<I, J> confusionMatrix = new ConfusionMatrix<I, J>(groundTruth, prediction);
        int n = confusionMatrix.getNumberGroundTruthLabels();
        double[][] costMatrix = SEGMetrics.computeCostMatrix(confusionMatrix);
        return new ValuePair((Object)n, (Object)SEGMetrics.computeFinalScore(costMatrix));
    }

    protected static double[][] computeCostMatrix(ConfusionMatrix cM) {
        int M = cM.getNumberGroundTruthLabels();
        int N = cM.getNumberPredictionLabels();
        double[][] costMatrix = new double[M][N];
        for (int i = 0; i < M; ++i) {
            for (int j = 0; j < N; ++j) {
                costMatrix[i][j] = SEGMetrics.getLocalIoUScore(cM, i, j);
            }
        }
        return costMatrix;
    }

    protected static double getLocalIoUScore(ConfusionMatrix cM, int i, int j) {
        double gtSize;
        double intersection = cM.getIntersection(i, j);
        double overlap = intersection / (gtSize = (double)cM.getGroundTruthLabelSize(i));
        if (overlap > 0.5) {
            double predSize = cM.getPredictionLabelSize(j);
            return intersection / (gtSize + predSize - intersection);
        }
        return 0.0;
    }

    private static double computeFinalScore(double[][] costMatrix) {
        if (costMatrix.length != 0 && costMatrix[0].length != 0) {
            int M = costMatrix.length;
            int N = costMatrix[0].length;
            double precision = 0.0;
            for (int i = 0; i < M; ++i) {
                for (int j = 0; j < N; ++j) {
                    precision += costMatrix[i][j];
                }
            }
            return precision;
        }
        return 0.0;
    }
}

