/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.algorithm.metrics.segmentation;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import net.imglib2.Cursor;
import net.imglib2.Localizable;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.view.Views;

/**
 * Pixel-wise confusion matrix between the instance labels of a ground-truth image and a
 * prediction image.
 *
 * <p>Only strictly positive labels are treated as instances; {@code 0} (background) and any
 * negative value are ignored, both in the per-label sizes and in the intersection counts.
 * Instance labels are mapped to dense matrix indices in the order in which they are first
 * encountered while iterating over the ground-truth image.
 *
 * <p>NOTE(review): the prediction is sampled at every position of the ground-truth iteration
 * domain; the two images are presumably expected to share the same interval, but this is not
 * checked here.
 *
 * @param <I> pixel type of the ground-truth label image
 * @param <J> pixel type of the prediction label image
 */
public class ConfusionMatrix<I extends IntegerType<I>, J extends IntegerType<J>> {
    /** Pixel count of each ground-truth instance, indexed by confusion-matrix row. */
    private final ArrayList<Integer> gtCMHistogram;
    /** Pixel count of each predicted instance, indexed by confusion-matrix column. */
    private final ArrayList<Integer> predCMHistogram;
    /** {@code confusionMatrix[i][j]}: number of pixels labelled with GT instance i and predicted instance j. */
    private final int[][] confusionMatrix;

    /**
     * Builds the confusion matrix by iterating twice over the ground-truth domain: once to
     * collect per-label pixel counts, once to count pairwise overlaps.
     *
     * @param groundTruth ground-truth label image
     * @param prediction  predicted label image, sampled at each ground-truth position
     */
    public ConfusionMatrix(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction) {
        final LinkedHashMap<Integer, Integer> gtHistogram = new LinkedHashMap<>();
        final LinkedHashMap<Integer, Integer> predHistogram = new LinkedHashMap<>();
        final Cursor<I> cGT = Views.iterable(groundTruth).localizingCursor();
        final RandomAccess<J> cPD = prediction.randomAccess();

        // First pass: pixel count per label (LinkedHashMap keeps first-encounter order).
        while (cGT.hasNext()) {
            gtHistogram.merge(cGT.next().getInteger(), 1, Integer::sum);
            cPD.setPosition(cGT);
            predHistogram.merge(cPD.get().getInteger(), 1, Integer::sum);
        }

        // Keep only instance labels. This must match the "> 0" test of the counting pass below;
        // otherwise negative labels would get rows/columns whose intersections are never counted.
        gtHistogram.keySet().removeIf(label -> label <= 0);
        predHistogram.keySet().removeIf(label -> label <= 0);

        final LinkedHashMap<Integer, Integer> groundTruthLUT = indexLabels(gtHistogram);
        final LinkedHashMap<Integer, Integer> predictionLUT = indexLabels(predHistogram);
        this.gtCMHistogram = new ArrayList<>(gtHistogram.values());
        this.predCMHistogram = new ArrayList<>(predHistogram.values());
        this.confusionMatrix = new int[gtHistogram.size()][predHistogram.size()];

        // Second pass: count pixels where a GT instance overlaps a predicted instance.
        cGT.reset();
        while (cGT.hasNext()) {
            final int gtLabel = cGT.next().getInteger();
            cPD.setPosition(cGT);
            final int predLabel = cPD.get().getInteger();
            if (gtLabel > 0 && predLabel > 0) {
                this.confusionMatrix[groundTruthLUT.get(gtLabel)][predictionLUT.get(predLabel)]++;
            }
        }
    }

    /** Maps each label (in the histogram's iteration order) to a dense index 0, 1, 2, .... */
    private static LinkedHashMap<Integer, Integer> indexLabels(LinkedHashMap<Integer, Integer> histogram) {
        final LinkedHashMap<Integer, Integer> lut = new LinkedHashMap<>();
        for (final Integer label : histogram.keySet()) {
            lut.put(label, lut.size());
        }
        return lut;
    }

    /**
     * Returns the pixel count of a ground-truth instance.
     *
     * @param labelIndex confusion-matrix row index of the instance
     * @return the number of pixels of that instance, or {@code -1} if the index is out of range
     */
    public int getGroundTruthLabelSize(int labelIndex) {
        if (labelIndex < 0 || labelIndex >= this.gtCMHistogram.size()) {
            return -1;
        }
        return this.gtCMHistogram.get(labelIndex);
    }

    /**
     * Returns the pixel count of a predicted instance, counted over the ground-truth domain.
     *
     * @param labelIndex confusion-matrix column index of the instance
     * @return the number of pixels of that instance, or {@code -1} if the index is out of range
     */
    public int getPredictionLabelSize(int labelIndex) {
        if (labelIndex < 0 || labelIndex >= this.predCMHistogram.size()) {
            return -1;
        }
        return this.predCMHistogram.get(labelIndex);
    }

    /**
     * Returns the number of pixels shared by a ground-truth instance and a predicted instance.
     *
     * @param gtLabelIndex   confusion-matrix row index of the ground-truth instance
     * @param predLabelIndex confusion-matrix column index of the predicted instance
     * @return the overlap in pixels, or {@code 0} if either image has no instance at all
     * @throws ArrayIndexOutOfBoundsException if both images have instances and an index is out of range
     */
    public int getIntersection(int gtLabelIndex, int predLabelIndex) {
        if (this.getNumberGroundTruthLabels() == 0 || this.getNumberPredictionLabels() == 0) {
            return 0;
        }
        return this.confusionMatrix[gtLabelIndex][predLabelIndex];
    }

    /** @return the number of ground-truth instances (rows of the matrix) */
    public int getNumberGroundTruthLabels() {
        return this.gtCMHistogram.size();
    }

    /** @return the number of predicted instances (columns of the matrix) */
    public int getNumberPredictionLabels() {
        return this.predCMHistogram.size();
    }
}

