/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.ui.segmentation;

import java.util.Arrays;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.cell.CellGrid;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.Views;
import org.apache.commons.lang3.ArrayUtils;
import sc.fiji.labkit.ui.inputimage.ImgPlusViewsOld;
import sc.fiji.labkit.ui.models.CachedImageFactory;
import sc.fiji.labkit.ui.models.DefaultCachedImageFactory;
import sc.fiji.labkit.ui.segmentation.Segmenter;

public class SegmentationUtils {
    private SegmentationUtils() {
        // Static utility class; not meant to be instantiated.
    }

    /**
     * Creates a cached image holding the per-class output of {@code segmenter.predict}.
     * <p>
     * The image has the dimensions of {@link #intervalNoChannels(ImgPlus)} for {@code image},
     * plus one trailing dimension whose size is {@code segmenter.classNames().size()}. Cells use
     * {@code segmenter.suggestCellSize(image)} in the leading dimensions and always span the whole
     * trailing dimension. The loader handed to the factory fills each requested block by calling
     * {@code segmenter.predict(image, block)}.
     *
     * @param segmenter the segmenter providing predictions, class names and the cell size
     * @param image the input image
     * @param cachedImageFactory factory used to build the cached image; if {@code null},
     *     {@link DefaultCachedImageFactory#getInstance()} is used
     * @return the cached probability image produced by the factory
     */
    public static Img<FloatType> createCachedProbabilityMap(Segmenter segmenter, ImgPlus<?> image, CachedImageFactory cachedImageFactory) {
        CachedImageFactory factory = factoryOrDefault(cachedImageFactory);
        int[] cellSize = segmenter.suggestCellSize(image);
        CellGrid grid = spatialGrid(image, cellSize);
        int classCount = segmenter.classNames().size();
        CellGrid gridWithChannel = addDimensionToGrid(classCount, grid);
        int[] cellSizeWithChannel = getCellDimensions(gridWithChannel);
        return factory.setupCachedImage(segmenter,
            target -> segmenter.predict(image, ensureCellSize(segmenter, cellSizeWithChannel, target)),
            gridWithChannel, new FloatType());
    }

    /**
     * Creates a cached {@link UnsignedByteType} segmentation of {@code image}.
     * <p>
     * Equivalent to calling
     * {@link #createCachedSegmentation(Segmenter, ImgPlus, CachedImageFactory, IntegerType)}
     * with a new {@link UnsignedByteType}.
     *
     * @param segmenter the segmenter used to label each block
     * @param image the input image
     * @param cachedImageFactory factory used to build the cached image; may be {@code null}
     * @return the cached segmentation image
     */
    public static Img<UnsignedByteType> createCachedSegmentation(Segmenter segmenter, ImgPlus<?> image, CachedImageFactory cachedImageFactory) {
        return createCachedSegmentation(segmenter, image, cachedImageFactory, new UnsignedByteType());
    }

    /**
     * Creates a cached segmentation image of the given pixel type.
     * <p>
     * The image has the dimensions of {@link #intervalNoChannels(ImgPlus)} for {@code image}, and
     * it is split into cells of {@code segmenter.suggestCellSize(image)}. The loader handed to the
     * factory fills each requested block by calling {@code segmenter.segment(image, block)}.
     *
     * @param segmenter the segmenter used to label each block
     * @param image the input image
     * @param cachedImageFactory factory used to build the cached image; if {@code null},
     *     {@link DefaultCachedImageFactory#getInstance()} is used
     * @param type an instance of the pixel type of the result
     * @param <T> the integer pixel type of the result
     * @return the cached segmentation image produced by the factory
     */
    public static <T extends IntegerType<T> & NativeType<T>> Img<T> createCachedSegmentation(Segmenter segmenter, ImgPlus<?> image, CachedImageFactory cachedImageFactory, T type) {
        CachedImageFactory factory = factoryOrDefault(cachedImageFactory);
        int[] cellSize = segmenter.suggestCellSize(image);
        CellGrid grid = spatialGrid(image, cellSize);
        return factory.setupCachedImage(segmenter,
            target -> segmenter.segment(image, ensureCellSize(segmenter, cellSize, target)),
            grid, type);
    }

    /** Returns {@code factory}, or the default factory instance when it is {@code null}. */
    private static CachedImageFactory factoryOrDefault(CachedImageFactory factory) {
        return factory != null ? factory : DefaultCachedImageFactory.getInstance();
    }

    /** Builds a cell grid over the channel-free dimensions of {@code image}. */
    private static CellGrid spatialGrid(ImgPlus<?> image, int[] cellSize) {
        long[] dimensions = Intervals.dimensionsAsLongArray(intervalNoChannels(image));
        return new CellGrid(dimensions, cellSize);
    }

    /**
     * Returns a copy of {@code grid} with one extra trailing dimension of length {@code size}.
     * Each cell spans the whole new dimension.
     */
    private static CellGrid addDimensionToGrid(int size, CellGrid grid) {
        long[] dimensions = ArrayUtils.add(grid.getImgDimensions(), size);
        int[] cellDimensions = ArrayUtils.add(getCellDimensions(grid), size);
        return new CellGrid(dimensions, cellDimensions);
    }

    /**
     * Adapts {@code target} to segmenters that require a fixed cell size.
     * <p>
     * If {@code segmenter.requiresFixedCellSize()} is {@code true} and the dimensions of
     * {@code target} differ from {@code cellSize}, this returns a view with the same minimum as
     * {@code target} and exactly {@code cellSize} dimensions. Outside {@code target} the view is
     * zero-extended. Such a mismatch can occur, for example, for cells at the border of a grid.
     * Otherwise {@code target} itself is returned.
     */
    private static <T extends NativeType<T> & NumericType<T>> RandomAccessibleInterval<T> ensureCellSize(Segmenter segmenter, int[] cellSize, RandomAccessibleInterval<T> target) {
        if (!segmenter.requiresFixedCellSize()) {
            return target;
        }
        int[] targetSize = Intervals.dimensionsAsIntArray(target);
        if (Arrays.equals(cellSize, targetSize)) {
            return target;
        }
        long[] min = Intervals.minAsLongArray(target);
        long[] max = new long[min.length];
        // The interval is inclusive, so the last index is min + size - 1.
        Arrays.setAll(max, d -> min[d] + cellSize[d] - 1L);
        // Pad with zeros so the segmenter always receives a block of exactly cellSize.
        return Views.interval(Views.extendZero(target), min, max);
    }

    /**
     * Returns the interval of {@code image} without its channel axis.
     * <p>
     * If {@code image} has an {@link Axes#CHANNEL} axis, the interval of the hyperslice at channel
     * 0 is returned. Otherwise the interval of {@code image} itself is returned. The result is
     * always a fresh {@link FinalInterval}.
     *
     * @param image the image whose interval is taken
     * @return the channel-free interval of {@code image}
     */
    public static Interval intervalNoChannels(ImgPlus<?> image) {
        return new FinalInterval(ImgPlusViewsOld.hasAxis(image, Axes.CHANNEL)
            ? ImgPlusViewsOld.hyperSlice(image, Axes.CHANNEL, 0L)
            : image);
    }

    /** Returns the cell dimensions of {@code grid} as a new array. */
    private static int[] getCellDimensions(CellGrid grid) {
        int[] cellDimensions = new int[grid.numDimensions()];
        grid.cellDimensions(cellDimensions);
        return cellDimensions;
    }
}

