/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.pixel_feature.calculator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.haesleinhuepf.clij.coremem.enums.NativeTypeEnum;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
import org.scijava.Context;
import sc.fiji.labkit.pixel_classification.RevampUtils;
import sc.fiji.labkit.pixel_classification.gpu.GpuFeatureInput;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuCopy;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPool;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuView;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuViews;
import sc.fiji.labkit.pixel_classification.pixel_feature.calculator.ColorInputPreprocessor;
import sc.fiji.labkit.pixel_classification.pixel_feature.calculator.DeprecatedColorInputPreprocessor;
import sc.fiji.labkit.pixel_classification.pixel_feature.calculator.GrayInputPreprocessor;
import sc.fiji.labkit.pixel_classification.pixel_feature.calculator.InputPreprocessor;
import sc.fiji.labkit.pixel_classification.pixel_feature.calculator.MultiChannelInputPreprocessor;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureInput;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureJoiner;
import sc.fiji.labkit.pixel_classification.pixel_feature.filter.FeatureOp;
import sc.fiji.labkit.pixel_classification.pixel_feature.settings.ChannelSetting;
import sc.fiji.labkit.pixel_classification.pixel_feature.settings.FeatureSetting;
import sc.fiji.labkit.pixel_classification.pixel_feature.settings.FeatureSettings;
import sc.fiji.labkit.pixel_classification.pixel_feature.settings.GlobalSettings;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;

/**
 * Computes a stack of pixel features for an image.
 * <p>
 * Every {@link FeatureOp} listed in the {@link FeatureSettings} is applied to
 * every input channel delivered by the {@link InputPreprocessor} that matches
 * the configured {@link ChannelSetting}. Results carry one extra, last
 * dimension of size {@link #count()}. Along that dimension the channels are
 * interleaved: slice {@code k} is assigned to input channel
 * {@code k % nChannels}, and within one channel the slices follow the
 * {@link FeatureJoiner}'s output order. This is the same order as
 * {@link #attributeLabels()}.
 * <p>
 * Features are computed on the CPU unless {@link #setUseGpu(boolean)} was
 * called with {@code true}. In that case a GPU is borrowed from
 * {@link GpuPool} for each {@code apply} call.
 */
public class FeatureCalculator {

    /** Combines all configured feature ops; applied once per input channel. */
    private final FeatureJoiner joiner;

    /** Global settings plus the list of features this calculator computes. */
    private final FeatureSettings settings;

    /** Splits the raw input into the channels the features are computed on. */
    private final InputPreprocessor preprocessor;

    /** Whether {@code apply} runs on a pooled GPU instead of the CPU. */
    private boolean useGpu = false;

    /**
     * Creates a calculator for the given settings.
     *
     * @param context SciJava context used to instantiate the feature ops
     * @param settings global settings and the features to compute
     * @throws UnsupportedOperationException if the channel setting of
     *         {@code settings} is not supported
     */
    public FeatureCalculator(Context context, FeatureSettings settings) {
        this.settings = settings;
        List<FeatureOp> featureOps = settings.features().stream()
            .map(feature -> feature.newInstance(context, settings.globals()))
            .collect(Collectors.toList());
        this.joiner = new FeatureJoiner(featureOps);
        this.preprocessor = initPreprocessor(settings.globals().channelSetting());
    }

    /**
     * Returns a builder preconfigured for two-dimensional images.
     *
     * @return a new {@link Builder} with {@code dimensions(2)} already set
     */
    public static Builder default2d() {
        return (Builder) new Builder().dimensions(2);
    }

    /**
     * Picks the input preprocessor that matches the channel setting.
     * <p>
     * This is called from the constructor, so it only reads the
     * already-assigned {@code settings} field and its own argument. It does not
     * call any overridable method.
     *
     * @param channelSetting how the input image's channels are interpreted
     * @return the matching preprocessor
     * @throws UnsupportedOperationException if no preprocessor handles the
     *         setting
     */
    private InputPreprocessor initPreprocessor(ChannelSetting channelSetting) {
        if (ChannelSetting.RGB.equals(channelSetting)) {
            return new ColorInputPreprocessor(this.settings.globals());
        }
        if (ChannelSetting.DEPRECATED_RGB.equals(channelSetting)) {
            return new DeprecatedColorInputPreprocessor(this.settings.globals());
        }
        if (ChannelSetting.SINGLE.equals(channelSetting)) {
            return new GrayInputPreprocessor(this.settings.globals());
        }
        if (channelSetting.isMultiple()) {
            return new MultiChannelInputPreprocessor(this.settings.globals());
        }
        throw new UnsupportedOperationException("Unsupported channel setting: " + channelSetting);
    }

    /** @return the settings this calculator was created with */
    public FeatureSettings settings() {
        return this.settings;
    }

    /** @return the feature ops computed for each input channel */
    public List<FeatureOp> features() {
        return this.joiner.features();
    }

    /**
     * @return the size of the last (feature) dimension of the output: the
     *         joiner's output count times the number of channels
     */
    public int count() {
        return this.joiner.count() * this.channelCount();
    }

    /**
     * Returns one label per output slice, in output order.
     * <p>
     * For each joiner label, one entry is produced per channel name. The entry
     * has the form {@code channel_label}, or just {@code label} when the channel
     * name is empty.
     *
     * @return labels for all {@link #count()} output slices
     */
    public List<String> attributeLabels() {
        return prepend(this.settings.globals().channelSetting().channels(), this.joiner.attributeLabels());
    }

    /**
     * Chooses between GPU and CPU computation for subsequent {@code apply}
     * calls.
     *
     * @param useGpu {@code true} to borrow a GPU from {@link GpuPool}
     */
    public void setUseGpu(boolean useGpu) {
        this.useGpu = useGpu;
    }

    /**
     * Computes all features and writes them into {@code output}.
     *
     * @param input image to compute the features of. It is presumably expected
     *        to be readable around the output interval, e.g. an extended view
     *        as built by {@link #apply(RandomAccessibleInterval)}.
     * @param output target image; all but its last dimension define the region
     *        to compute, and the last dimension indexes the {@link #count()}
     *        feature slices
     */
    public void apply(RandomAccessible<?> input, RandomAccessibleInterval<FloatType> output) {
        if (this.useGpu) {
            Interval interval = RevampUtils.removeLastDimension(output);
            try (GpuApi scope = GpuPool.borrowGpu()) {
                GpuImage result = this.applyUseGpu(scope, input, interval);
                GpuCopy.copyFromTo(result, output);
            }
        } else {
            this.applyUseCpu(input, output);
        }
    }

    /**
     * Computes all features of {@code image}. Out-of-bounds pixels are filled
     * by border extension.
     *
     * @param image the input image
     * @return a newly allocated feature stack over the interval reported by
     *         {@link #outputIntervalFromInput(RandomAccessibleInterval)}, with an
     *         extra last dimension of size {@link #count()}
     */
    public RandomAccessibleInterval<FloatType> apply(RandomAccessibleInterval<?> image) {
        return this.apply((RandomAccessible<?>) Views.extendBorder(image), this.preprocessor.outputIntervalFromInput(image));
    }

    /**
     * Computes all features of {@code extendedImage} over {@code interval}.
     *
     * @param extendedImage input image, readable wherever the features need it
     * @param interval region to compute, without the feature dimension
     * @return a newly created feature stack whose interval is {@code interval}
     *         plus a last dimension ranging from {@code 0} to {@code count() - 1}
     */
    public RandomAccessibleInterval<FloatType> apply(RandomAccessible<?> extendedImage, Interval interval) {
        Interval fullInterval = Intervals.addDimension(interval, 0, this.count() - 1);
        long[] offset = Intervals.minAsLongArray(fullInterval);
        if (this.useGpu) {
            try (GpuApi scope = GpuPool.borrowGpu()) {
                GpuImage featureStack = this.applyUseGpu(scope, extendedImage, interval);
                // NOTE(review): assumes pullRAIMultiChannel copies the data to
                // CPU memory, so the result stays valid after the scope closes.
                return Views.translate(scope.pullRAIMultiChannel(featureStack), offset);
            }
        }
        // ArrayImgs starts at the origin; translate it onto the requested interval.
        RandomAccessibleInterval<FloatType> result = Views.translate(
            ArrayImgs.floats(Intervals.dimensionsAsLongArray(fullInterval)), offset);
        this.applyUseCpu(extendedImage, result);
        return result;
    }

    /**
     * CPU implementation: runs the joiner once per input channel, writing to
     * that channel's interleaved slices of {@code output}.
     */
    private void applyUseCpu(RandomAccessible<?> input, RandomAccessibleInterval<FloatType> output) {
        List<RandomAccessible<FloatType>> channels = this.preprocessor.getChannels(input);
        List<List<RandomAccessibleInterval<FloatType>>> outputs = split(RevampUtils.slices(output), channels.size());
        // Same interval derivation as the GPU path. Unlike reading it from the
        // first slice, this also works when a channel receives no slices.
        Interval interval = RevampUtils.removeLastDimension(output);
        double[] pixelSize = this.settings.globals().pixelSizeAsDoubleArray();
        for (int i = 0; i < channels.size(); ++i) {
            FeatureInput in = new FeatureInput(channels.get(i), interval, pixelSize);
            this.joiner.apply(in, outputs.get(i));
        }
    }

    /**
     * Computes all features on the given GPU.
     *
     * @param gpu GPU scope used to allocate the result and per-channel
     *        sub-scopes
     * @param input image to compute the features of
     * @param interval region to compute, without the feature dimension
     * @return a GPU image with the dimensions of {@code interval} and
     *         {@link #count()} float channels, allocated via {@code gpu.create}
     * @throws IllegalArgumentException if {@code interval} does not have the
     *         dimensionality configured in the global settings
     */
    public GpuImage applyUseGpu(GpuApi gpu, RandomAccessible<?> input, Interval interval) {
        if (interval.numDimensions() != this.settings().globals().numDimensions()) {
            throw new IllegalArgumentException("Wrong dimension of the output interval.");
        }
        double[] pixelSize = this.settings.globals().pixelSizeAsDoubleArray();
        List<RandomAccessible<FloatType>> channels = this.preprocessor.getChannels(input);
        GpuImage featureStack = gpu.create(Intervals.dimensionsAsLongArray(interval), this.count(), NativeTypeEnum.Float);
        List<List<GpuView>> outputs = split(GpuViews.channels(featureStack), channels.size());
        for (int i = 0; i < channels.size(); ++i) {
            // Each channel is processed in its own sub-scope; presumably this
            // releases that channel's temporary GPU buffers when it closes.
            try (GpuApi scope = gpu.subScope()) {
                GpuFeatureInput in = new GpuFeatureInput(scope, channels.get(i), interval, pixelSize);
                this.joiner.prefetch(in);
                this.joiner.apply(in, outputs.get(i));
            }
        }
        return featureStack;
    }

    /**
     * @param image the input image
     * @return the region (without the feature dimension) that
     *         {@link #apply(RandomAccessibleInterval)} computes for
     *         {@code image}, as decided by the preprocessor
     */
    public Interval outputIntervalFromInput(RandomAccessibleInterval<?> image) {
        return this.preprocessor.outputIntervalFromInput(image);
    }

    /** @return number of channel names in the configured channel setting */
    private int channelCount() {
        return this.settings.globals().channelSetting().channels().size();
    }

    /**
     * Combines every label with every prefix. The output is label-major and
     * prefix-minor. An empty prefix leaves the label unchanged; otherwise the
     * result is {@code prefix + "_" + label}.
     */
    private static List<String> prepend(List<String> prepend, List<String> labels) {
        return labels.stream()
            .flatMap(label -> prepend.stream().map(pre -> pre.isEmpty() ? label : pre + "_" + label))
            .collect(Collectors.toList());
    }

    /**
     * Deals {@code input} round-robin into {@code count} lists: element
     * {@code k} goes to list {@code k % count}. Returns an empty list when
     * {@code count} is 0.
     */
    private static <T> List<List<T>> split(List<T> input, int count) {
        return IntStream.range(0, count)
            .mapToObj(i -> filterByIndexPredicate(input, index -> index % count == i))
            .collect(Collectors.toList());
    }

    /** Returns the elements of {@code in} whose index satisfies {@code predicate}, in order. */
    private static <T> List<T> filterByIndexPredicate(List<T> in, IntPredicate predicate) {
        return IntStream.range(0, in.size()).filter(predicate).mapToObj(in::get).collect(Collectors.toList());
    }

    /**
     * Fluent builder for {@link FeatureCalculator}. Global settings are
     * configured through the inherited {@link GlobalSettings.AbstractBuilder}
     * methods.
     */
    public static class Builder
    extends GlobalSettings.AbstractBuilder<Builder> {
        /** SciJava context; {@link SingletonContext} is used when left unset. */
        private Context context;

        /** Features to compute, in insertion order. */
        private final List<FeatureSetting> features = new ArrayList<FeatureSetting>();

        private Builder() {
        }

        /** Sets the SciJava context used to instantiate the feature ops. */
        public Builder context(Context context) {
            this.context = context;
            return this;
        }

        /** Appends the given features. */
        public Builder addFeatures(FeatureSetting ... features) {
            this.features.addAll(Arrays.asList(features));
            return this;
        }

        /** Appends a feature given by its op class and parameters. */
        public Builder addFeature(Class<? extends FeatureOp> clazz, Object ... parameters) {
            this.features.add(new FeatureSetting(clazz, parameters));
            return this;
        }

        /** Appends a single feature. */
        public Builder addFeature(FeatureSetting featureSetting) {
            this.addFeatures(featureSetting);
            return this;
        }

        /**
         * Builds the calculator. If no context was set, the shared singleton
         * context is used.
         *
         * @return a new calculator for the configured global settings and
         *         features
         */
        public FeatureCalculator build() {
            if (this.context == null) {
                this.context = SingletonContext.getInstance();
            }
            GlobalSettings globalSettings = this.buildGlobalSettings();
            // Snapshot the list so later addFeature calls on this builder
            // cannot alter a calculator that has already been built.
            FeatureSettings featureSettings = new FeatureSettings(globalSettings, new ArrayList<FeatureSetting>(this.features));
            return new FeatureCalculator(this.context, featureSettings);
        }
    }
}

