/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.algorithm.metrics.segmentation;

import java.util.HashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.metrics.segmentation.MultiMetrics;
import net.imglib2.roi.labeling.ImgLabeling;
import net.imglib2.roi.labeling.Labelings;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.util.Intervals;

/**
 * Accumulates multi-metrics segmentation statistics over any number of
 * ground-truth / prediction image pairs (for instance one pair per time point),
 * and only turns them into scores when {@link #computeScore()} is called.
 *
 * <p>Each {@code addTimePoint} call evaluates one pair through
 * {@code MultiMetrics.runSingle} and adds the resulting TP, FP and FN counts and
 * the IoU value to running totals.
 *
 * <p>The running totals are held in atomic variables, and every individual
 * counter update is atomic. {@code addTimePoint} can therefore be called from
 * several threads without losing contributions. {@link #computeScore()} reads the
 * counters one after another rather than as a single snapshot. If it runs
 * concurrently with {@code addTimePoint}, it may observe a partially applied update.
 */
public class LazyMultiMetrics {
    /** Running total of true positives. */
    private final AtomicInteger aTP = new AtomicInteger(0);
    /** Running total of false positives. */
    private final AtomicInteger aFP = new AtomicInteger(0);
    /** Running total of false negatives. */
    private final AtomicInteger aFN = new AtomicInteger(0);
    /**
     * Running sum of IoU values. It is stored as the raw bit pattern of a
     * {@code double}, because the JDK has no {@code AtomicDouble}.
     */
    private final AtomicLong aIoU = new AtomicLong(Double.doubleToRawLongBits(0.0));
    /**
     * Threshold forwarded to {@code MultiMetrics.runSingle}. It is presumably the IoU
     * threshold a prediction must reach to match a ground-truth object; TODO confirm
     * against MultiMetrics.
     */
    private final double threshold;

    /** Creates an accumulator with the default threshold of {@code 0.5}. */
    public LazyMultiMetrics() {
        this.threshold = 0.5;
    }

    /**
     * Creates an accumulator with a custom threshold.
     *
     * @param threshold value passed to {@code MultiMetrics.runSingle} for every added pair
     */
    public LazyMultiMetrics(double threshold) {
        this.threshold = threshold;
    }

    /**
     * Evaluates one ground-truth / prediction labeling pair and adds its statistics
     * to the running totals.
     *
     * @param groundTruth ground-truth labeling; must not contain intersecting labels
     * @param prediction predicted labeling; must not contain intersecting labels
     * @throws UnsupportedOperationException if either labeling has intersecting labels
     * @throws IllegalArgumentException if the index images' dimensions differ
     */
    public <T, I extends IntegerType<I>, U, J extends IntegerType<J>> void addTimePoint(ImgLabeling<T, I> groundTruth, ImgLabeling<U, J> prediction) {
        // Only the index images are evaluated. Overlapping labels cannot be represented
        // by a single index per pixel, so they are rejected up front.
        if (Labelings.hasIntersectingLabels(groundTruth) || Labelings.hasIntersectingLabels(prediction)) {
            throw new UnsupportedOperationException("ImgLabeling with intersecting labels are not supported.");
        }
        this.addTimePoint(groundTruth.getIndexImg(), prediction.getIndexImg());
    }

    /**
     * Evaluates one ground-truth / prediction label-image pair and adds its
     * statistics to the running totals.
     *
     * @param groundTruth ground-truth label image
     * @param prediction predicted label image; must have the same dimensions as {@code groundTruth}
     * @throws IllegalArgumentException if the image dimensions differ
     */
    public <I extends IntegerType<I>, J extends IntegerType<J>> void addTimePoint(RandomAccessibleInterval<I> groundTruth, RandomAccessibleInterval<J> prediction) {
        if (!Intervals.equalDimensions(groundTruth, prediction)) {
            throw new IllegalArgumentException("Image dimensions must match.");
        }
        MultiMetrics.MetricsSummary result = MultiMetrics.runSingle(groundTruth, prediction, this.threshold);
        this.addPoint(result);
    }

    /**
     * Computes the scores from everything accumulated so far.
     *
     * <p>The counters are read individually, not as one atomic snapshot. See the
     * class documentation for the consequences under concurrent use.
     *
     * @return a map from each metric to its score, as produced by
     *         {@code MultiMetrics.MetricsSummary#getScores()}
     */
    public HashMap<MultiMetrics.Metrics, Double> computeScore() {
        MultiMetrics.MetricsSummary summary = new MultiMetrics.MetricsSummary();
        int tp = this.aTP.get();
        int fp = this.aFP.get();
        int fn = this.aFN.get();
        double sumIoU = this.atomicLongToDouble(this.aIoU);
        summary.addPoint(tp, fp, fn, sumIoU);
        return summary.getScores();
    }

    /**
     * Adds the counts and IoU value of one evaluated pair to the running totals.
     * Each total is updated atomically.
     *
     * @param newResult statistics of a single evaluated image pair
     */
    protected void addPoint(MultiMetrics.MetricsSummary newResult) {
        this.aTP.addAndGet(newResult.getTP());
        this.aFP.addAndGet(newResult.getFP());
        this.aFN.addAndGet(newResult.getFN());
        this.addToAtomicLong(this.aIoU, newResult.getIoU());
    }

    /**
     * Atomically adds {@code b} to the {@code double} whose raw bits are stored in {@code a}.
     *
     * <p>The original get-then-set could lose updates when two threads raced between
     * the read and the write. A compare-and-set retry loop guarantees that every
     * contribution is applied exactly once.
     *
     * @param a holder of the raw bits of a {@code double}
     * @param b value to add
     */
    private void addToAtomicLong(AtomicLong a, double b) {
        while (true) {
            long expectedBits = a.get();
            long updatedBits = Double.doubleToRawLongBits(Double.longBitsToDouble(expectedBits) + b);
            if (a.compareAndSet(expectedBits, updatedBits)) {
                return;
            }
            // Another thread changed the value in between; retry with the fresh value.
        }
    }

    /**
     * Decodes the {@code double} whose raw bits are stored in {@code a}.
     *
     * @param a holder of the raw bits of a {@code double}
     * @return the decoded value
     */
    private double atomicLongToDouble(AtomicLong a) {
        return Double.longBitsToDouble(a.get());
    }
}

