/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.classification;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import hr.irb.fastRandomForest.FastRandomForest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.haesleinhuepf.clij.coremem.enums.NativeTypeEnum;
import net.imglib2.Dimensions;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.view.Views;
import net.imglib2.view.composite.Composite;
import org.scijava.Context;
import sc.fiji.labkit.pixel_classification.RevampUtils;
import sc.fiji.labkit.pixel_classification.classification.ClassifierSerialization;
import sc.fiji.labkit.pixel_classification.classification.Training;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuCopy;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPool;
import sc.fiji.labkit.pixel_classification.gpu.random_forest.GpuRandomForestPrediction;
import sc.fiji.labkit.pixel_classification.pixel_feature.calculator.FeatureCalculator;
import sc.fiji.labkit.pixel_classification.pixel_feature.settings.FeatureSettings;
import sc.fiji.labkit.pixel_classification.random_forest.CpuRandomForestPrediction;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Pixel classifier that combines a {@link FeatureCalculator} with a Weka
 * {@link Classifier}.
 * <p>
 * The classifier is cast to {@link FastRandomForest} in order to build a
 * {@link CpuRandomForestPrediction} and a {@link GpuRandomForestPrediction};
 * which of the two is used by {@code segment} / {@code predict} is selected
 * with {@link #setUseGpu(boolean)} (CPU by default).
 * <p>
 * NOTE(review): the prediction helpers are replaced without synchronization
 * when {@link Training#train()} runs; no thread-safety guarantee is visible in
 * this class.
 */
public class Segmenter {
    private final FeatureCalculator features;
    private final List<String> classNames;
    private final Classifier classifier;
    private GpuRandomForestPrediction gpuPrediction;
    private CpuRandomForestPrediction cpuPrediction;
    private boolean useGpu = false;

    /**
     * Shared constructor.
     *
     * @param classNames names of the classes; wrapped in an unmodifiable view
     *                   (not copied)
     * @param features   feature calculator, must not be {@code null}
     * @param classifier classifier, must not be {@code null}; it is cast to
     *                   {@link FastRandomForest} while building the prediction
     *                   helpers
     * @throws NullPointerException if any argument is {@code null}
     */
    private Segmenter(List<String> classNames, FeatureCalculator features, Classifier classifier) {
        this.classNames = Collections.unmodifiableList(classNames);
        this.features = Objects.requireNonNull(features, "features");
        this.classifier = Objects.requireNonNull(classifier, "classifier");
        this.updatePrecacheRandomForests();
    }

    /**
     * Rebuilds the CPU and GPU prediction helpers from the current state of
     * the classifier. Called after construction and after training.
     */
    private void updatePrecacheRandomForests() {
        // Cast once; both helpers are built from the same forest instance.
        FastRandomForest forest = Cast.unchecked(this.classifier);
        int featureCount = this.features.count();
        this.gpuPrediction = new GpuRandomForestPrediction(forest, featureCount);
        this.cpuPrediction = new CpuRandomForestPrediction(forest, featureCount);
    }

    /**
     * Creates a segmenter whose feature calculator is built from the given
     * context and feature settings.
     *
     * @param context    context handed to the {@link FeatureCalculator}
     * @param classNames names of the classes
     * @param features   settings used to construct the feature calculator
     * @param classifier classifier used for prediction and training
     */
    public Segmenter(Context context, List<String> classNames, FeatureSettings features, Classifier classifier) {
        this(classNames, new FeatureCalculator(context, features), classifier);
    }

    /** @return the classifier instance this segmenter was created with */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * Selects the GPU or CPU code path for subsequent calls and forwards the
     * same setting to the feature calculator.
     *
     * @param useGpu {@code true} to use the GPU code path
     */
    public void setUseGpu(boolean useGpu) {
        this.useGpu = useGpu;
        this.features.setUseGpu(useGpu);
    }

    /** @return the feature calculator used by this segmenter */
    public FeatureCalculator features() {
        return this.features;
    }

    /** @return the settings of the underlying feature calculator */
    public FeatureSettings settings() {
        return this.features.settings();
    }

    /**
     * Segments the image into a newly allocated {@link UnsignedByteType}
     * image.
     *
     * @param image input image, must not be {@code null}
     * @return the segmentation result
     * @throws NullPointerException if {@code image} is {@code null}
     */
    public RandomAccessibleInterval<UnsignedByteType> segment(RandomAccessibleInterval<?> image) {
        Objects.requireNonNull(image, "image");
        return this.segment(image, new UnsignedByteType());
    }

    /**
     * Allocates an array-backed image of the given pixel type whose size and
     * position match the given interval.
     */
    private <T extends NativeType<T>> RandomAccessibleInterval<T> createImage(T type, Interval interval) {
        long[] size = Intervals.dimensionsAsLongArray(interval);
        long[] min = Intervals.minAsLongArray(interval);
        ArrayImg<T, ?> img = new ArrayImgFactory<>(type).create(size);
        // The array image starts at the origin; shift it to the interval's minimum.
        return Views.translate(img, min);
    }

    /**
     * Segments the image into a newly allocated image of the given integer
     * type. The output interval is obtained from
     * {@link FeatureCalculator#outputIntervalFromInput}, and the input is
     * extended with {@link Views#extendBorder} before classification.
     *
     * @param image input image, must not be {@code null}
     * @param type  pixel type of the result, must not be {@code null}
     * @return the segmentation result
     * @throws NullPointerException if an argument is {@code null}
     */
    public <T extends IntegerType<T> & NativeType<T>> RandomAccessibleInterval<T> segment(RandomAccessibleInterval<?> image, T type) {
        Objects.requireNonNull(image, "image");
        Objects.requireNonNull(type, "type");
        Interval interval = this.features.outputIntervalFromInput(image);
        RandomAccessibleInterval<T> rai = this.createImage(type, interval);
        RandomAccessible<?> extended = Views.extendBorder(image);
        this.segment(rai, extended);
        return rai;
    }

    /**
     * Writes the segmentation of {@code image} into {@code out}, using the
     * GPU or CPU code path depending on {@link #setUseGpu(boolean)}.
     *
     * @param out   target image; its interval determines the region that is
     *              processed
     * @param image input image
     * @throws NullPointerException if an argument is {@code null}
     */
    public void segment(RandomAccessibleInterval<? extends IntegerType<?>> out, RandomAccessible<?> image) {
        Objects.requireNonNull(out);
        Objects.requireNonNull(image);
        if (this.useGpu) {
            this.segmentGpu(image, out);
        } else {
            this.segmentCpu(image, out);
        }
    }

    /** CPU path: computes the feature values for {@code out}'s interval, then classifies them. */
    private void segmentCpu(RandomAccessible<?> image, RandomAccessibleInterval<? extends IntegerType<?>> out) {
        RandomAccessibleInterval<FloatType> featureValues = this.features.apply(image, (Interval) out);
        this.cpuPrediction.segment(featureValues, out);
    }

    /**
     * GPU path: borrows a {@link GpuApi} from the {@link GpuPool}, computes
     * features and segmentation with it, and copies the result into
     * {@code out}. The borrowed scope is closed when the block exits.
     */
    private void segmentGpu(RandomAccessible<?> image, RandomAccessibleInterval<? extends IntegerType<?>> out) {
        try (GpuApi scope = GpuPool.borrowGpu()) {
            GpuImage featureStack = this.features.applyUseGpu(scope, image, (Interval) out);
            GpuImage segmentationBuffer = this.gpuPrediction.segment(scope, featureStack);
            GpuCopy.copyFromTo(segmentationBuffer, out);
        }
    }

    /**
     * Computes the class distribution for the image into a newly allocated
     * {@link FloatType} image. The result interval is the feature
     * calculator's output interval with one dimension appended, bounded by
     * {@code 0} and {@code classNames().size() - 1}.
     *
     * @param image input image, must not be {@code null}
     * @return the prediction result
     * @throws NullPointerException if {@code image} is {@code null}
     */
    public RandomAccessibleInterval<? extends RealType<?>> predict(RandomAccessibleInterval<?> image) {
        Objects.requireNonNull(image);
        Interval outputInterval = this.features.outputIntervalFromInput(image);
        Interval resultInterval = RevampUtils.appendDimensionToInterval(outputInterval, 0L, this.classNames.size() - 1);
        RandomAccessibleInterval<FloatType> result = RevampUtils.createImage(resultInterval, new FloatType());
        RandomAccessible<?> extended = Views.extendBorder(image);
        this.predict(result, extended);
        return result;
    }

    /**
     * Writes the class distribution of {@code image} into {@code out}, using
     * the GPU or CPU code path depending on {@link #setUseGpu(boolean)}.
     *
     * @param out   target image; its interval without the last dimension is
     *              the region for which features are computed
     * @param image input image
     * @throws NullPointerException if an argument is {@code null}
     */
    public void predict(RandomAccessibleInterval<? extends RealType<?>> out, RandomAccessible<?> image) {
        Objects.requireNonNull(out);
        Objects.requireNonNull(image);
        if (this.useGpu) {
            this.predictGpu(out, image);
        } else {
            this.predictCpu(out, image);
        }
    }

    /** CPU path: computes feature values, then lets the CPU forest fill {@code out}. */
    private void predictCpu(RandomAccessibleInterval<? extends RealType<?>> out, RandomAccessible<?> image) {
        Interval interval = RevampUtils.removeLastDimension(out);
        RandomAccessibleInterval<FloatType> featureValues = this.features.apply(image, interval);
        this.cpuPrediction.distribution(featureValues, out);
    }

    /**
     * GPU path: borrows a {@link GpuApi} from the {@link GpuPool}, computes
     * the distribution into a float GPU image created with
     * {@code classNames.size()} as its second argument, and copies it into
     * {@code out}. The borrowed scope is closed when the block exits.
     */
    private void predictGpu(RandomAccessibleInterval<? extends RealType<?>> out, RandomAccessible<?> image) {
        Interval interval = RevampUtils.removeLastDimension(out);
        try (GpuApi scope = GpuPool.borrowGpu()) {
            GpuImage featureStack = this.features.applyUseGpu(scope, image, interval);
            GpuImage distribution = scope.create(featureStack.getDimensions(), this.classNames.size(), NativeTypeEnum.Float);
            this.gpuPrediction.distribution(scope, featureStack, distribution);
            GpuCopy.copyFromTo(distribution, out);
        }
    }

    /** @return an unmodifiable view of the class names */
    public List<String> classNames() {
        return this.classNames;
    }

    /**
     * @return a new, empty {@link Training} session that collects samples
     *         and trains this segmenter's classifier
     */
    public Training training() {
        return new MyTrainingData();
    }

    /**
     * Serializes this segmenter to a JSON object with the entries
     * {@code "features"}, {@code "classNames"} and {@code "classifier"}.
     *
     * @return the JSON representation
     */
    public JsonElement toJsonTree() {
        JsonObject json = new JsonObject();
        json.add("features", this.features.settings().toJson());
        json.add("classNames", new Gson().toJsonTree(this.classNames));
        json.add("classifier", ClassifierSerialization.wekaToJson(this.classifier));
        return json;
    }

    /**
     * Restores a segmenter from the JSON object layout written by
     * {@link #toJsonTree()}.
     *
     * @param context context handed to the feature calculator
     * @param json    JSON object with {@code "classNames"},
     *                {@code "features"} and {@code "classifier"} entries
     * @return the restored segmenter
     */
    public static Segmenter fromJson(Context context, JsonElement json) {
        JsonObject object = json.getAsJsonObject();
        List<String> classNames = new Gson().fromJson(object.get("classNames"), new TypeToken<List<String>>(){}.getType());
        FeatureSettings featureSettings = FeatureSettings.fromJson(object.get("features"));
        Classifier classifier = ClassifierSerialization.jsonToWeka(object.get("classifier"));
        return new Segmenter(context, classNames, featureSettings, classifier);
    }

    /**
     * Builds the Weka attribute list: one attribute per feature label,
     * followed by a final attribute named {@code "class"} built from the
     * class names.
     */
    private List<Attribute> attributes() {
        Stream<Attribute> featureAttributes = this.features.attributeLabels().stream().map(Attribute::new);
        Stream<Attribute> classAttribute = Stream.of(new Attribute("class", this.classNames));
        return Stream.concat(featureAttributes, classAttribute).collect(Collectors.toList());
    }

    /**
     * {@link Training} implementation that accumulates samples in a Weka
     * {@link Instances} set and trains the enclosing segmenter's classifier.
     */
    private class MyTrainingData
    implements Training {
        final Instances instances;
        final int featureCount;

        MyTrainingData() {
            this.instances = new Instances("segment", new ArrayList<>(Segmenter.this.attributes()), 1);
            this.featureCount = Segmenter.this.features.count();
            // The class attribute is the last one, directly after the feature attributes.
            this.instances.setClassIndex(this.featureCount);
        }

        /**
         * Adds one training sample.
         *
         * @param featureVector feature values of the sample
         * @param classIndex    index of the class the sample belongs to
         */
        @Override
        public void add(Composite<? extends RealType<?>> featureVector, int classIndex) {
            this.instances.add((Instance) RevampUtils.getInstance(this.featureCount, classIndex, featureVector));
        }

        /**
         * Builds the classifier from the collected samples and then refreshes
         * the CPU/GPU prediction helpers of the enclosing segmenter.
         */
        @Override
        public void train() {
            // buildClassifier is routed through RevampUtils.wrapException
            // (presumably converting checked exceptions - confirm in RevampUtils).
            RevampUtils.wrapException(() -> Segmenter.this.classifier.buildClassifier(this.instances));
            Segmenter.this.updatePrecacheRandomForests();
        }
    }
}

