/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.classification;

import hr.irb.fastRandomForest.FastRandomForest;
import ij.Prefs;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.imglib2.Cursor;
import net.imglib2.Localizable;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.roi.labeling.LabelRegion;
import net.imglib2.roi.labeling.LabelRegions;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.composite.Composite;
import org.scijava.Context;
import sc.fiji.labkit.pixel_classification.classification.Segmenter;
import sc.fiji.labkit.pixel_classification.classification.Training;
import sc.fiji.labkit.pixel_classification.pixel_feature.calculator.FeatureCalculator;
import sc.fiji.labkit.pixel_classification.pixel_feature.settings.FeatureSettings;
import sc.fiji.labkit.pixel_classification.utils.views.FastViews;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;

/**
 * Feeds labeled pixels into the {@link Training} of a {@link Segmenter}.
 * <p>
 * By default each call to {@link #trainLabeledFeatures} (and therefore
 * {@link #trainLabeledImage}) finishes the training right away. To collect
 * samples from several images first, call {@link #start()}, add the images,
 * and then call {@link #finish()} exactly once.
 */
public class Trainer {

    /** Number of trees used by {@link #initRandomForest()}. */
    private static final int DEFAULT_NUM_TREES = 200;

    /** Value passed to {@code FastRandomForest.setNumFeatures} by {@link #initRandomForest()}. */
    private static final int DEFAULT_NUM_RANDOM_FEATURES = 2;

    /** Fixed seed so that {@link #initRandomForest()} yields reproducible forests. */
    private static final int DEFAULT_SEED = 1;

    private final FeatureCalculator features;
    private final List<String> classNames;
    private final Training training;
    // When true, trainLabeledFeatures() calls finish() itself; start() clears it.
    private boolean autoFinish = true;
    // Set by finish(); a second finish() call is rejected.
    private boolean finished = false;

    private Trainer(Segmenter segmenter) {
        this.features = segmenter.features();
        this.training = segmenter.training();
        this.classNames = segmenter.classNames();
    }

    /**
     * Creates a trainer that adds samples to the given segmenter's training.
     *
     * @param segmenter segmenter providing the feature calculator, the class
     *                  names and the training to fill
     * @return a new trainer in auto-finish mode
     */
    public static Trainer of(Segmenter segmenter) {
        return new Trainer(segmenter);
    }

    /**
     * Switches to manual mode: subsequent calls to {@link #trainLabeledFeatures}
     * no longer call {@link #finish()}; the caller must do so explicitly.
     */
    public void start() {
        this.autoFinish = false;
    }

    /**
     * Marks this trainer as finished and calls {@link Training#train()}.
     *
     * @throws IllegalStateException if {@code finish()} was already called
     */
    public void finish() {
        if (this.finished) {
            throw new IllegalStateException("Trainer.finish() has already been called.");
        }
        this.finished = true;
        this.training.train();
    }

    /**
     * Computes the feature stack of {@code image} and adds every labeled pixel
     * as a training sample.
     *
     * @param image    input image passed to the segmenter's feature calculator
     * @param labeling label regions; a label is used when its {@code toString()}
     *                 equals one of the segmenter's class names
     */
    public void trainLabeledImage(RandomAccessibleInterval<?> image, LabelRegions<?> labeling) {
        // NOTE(review): FastViews.collapse presumably turns the feature axis into
        // one Composite (feature vector) per pixel; confirm in FastViews.
        RandomAccessibleInterval<Composite<FloatType>> featureStack =
                FastViews.collapse(this.features.apply(image));
        // Wildcard capture binds L for the generic call; no raw cast needed.
        this.trainLabeledFeatures(featureStack, labeling);
    }

    /**
     * Adds one training sample per labeled pixel, using the feature vector found
     * at that pixel position in {@code features}.
     * <p>
     * The class index of a sample is the position of the matching class name in
     * the segmenter's class names. Class names without a matching label are
     * skipped, and labels that match no class name are ignored. Unless
     * {@link #start()} was called, {@link #finish()} is invoked at the end.
     *
     * @param features per-pixel feature vectors
     * @param regions  labeled regions, matched to class names via {@code toString()}
     * @param <L>      label type
     * @throws IllegalStateException in auto-finish mode, if the trainer was
     *                               already finished
     */
    public <L> void trainLabeledFeatures(RandomAccessible<? extends Composite<? extends RealType<?>>> features, LabelRegions<L> regions) {
        RandomAccess<? extends Composite<? extends RealType<?>>> ra = features.randomAccess();
        Map<String, L> keyMap = this.createKeyMap(regions);
        for (int classIndex = 0; classIndex < this.classNames.size(); ++classIndex) {
            L label = keyMap.get(this.classNames.get(classIndex));
            if (label == null) {
                continue;
            }
            LabelRegion<L> region = regions.getLabelRegion(label);
            Cursor<?> cursor = region.inside().cursor();
            while (cursor.hasNext()) {
                cursor.next();
                ra.setPosition(cursor);
                this.training.add(ra.get(), classIndex);
            }
        }
        if (this.autoFinish) {
            this.finish();
        }
    }

    /**
     * Maps each existing label's {@code toString()} to the label itself. If two
     * labels share the same string, the one iterated last wins.
     */
    private <L> Map<String, L> createKeyMap(LabelRegions<L> regions) {
        Map<String, L> map = new HashMap<>();
        for (L label : regions.getExistingLabels()) {
            map.put(label.toString(), label);
        }
        return map;
    }

    /**
     * Trains a new segmenter on {@code image} using the default random forest
     * from {@link #initRandomForest()}.
     *
     * @see #train(Context, RandomAccessibleInterval, LabelRegions, FeatureSettings, Classifier)
     */
    public static Segmenter train(Context context, RandomAccessibleInterval<?> image, LabelRegions<?> labeling, FeatureSettings features) {
        return Trainer.train(context, image, labeling, features, Trainer.initRandomForest());
    }

    /**
     * Creates a segmenter whose class names are the {@code toString()} values of
     * the existing labels, trains it on {@code image}, and returns it.
     *
     * @param context               SciJava context handed to the segmenter
     * @param image                 training image
     * @param labeling              training labels
     * @param features              feature settings for the segmenter
     * @param initialWekaClassifier classifier handed to the segmenter
     * @return the trained segmenter
     */
    public static Segmenter train(Context context, RandomAccessibleInterval<?> image, LabelRegions<?> labeling, FeatureSettings features, Classifier initialWekaClassifier) {
        // Class order follows the iteration order of getExistingLabels().
        List<String> classNames = labeling.getExistingLabels().stream()
                .map(Object::toString)
                .collect(Collectors.toList());
        Segmenter segmenter = new Segmenter(context, classNames, features, initialWekaClassifier);
        Trainer.of(segmenter).trainLabeledImage(image, labeling);
        return segmenter;
    }

    /**
     * Creates the default classifier: a {@link FastRandomForest} with
     * {@value #DEFAULT_NUM_TREES} trees, {@code setNumFeatures(}{@value #DEFAULT_NUM_RANDOM_FEATURES}{@code )},
     * seed {@value #DEFAULT_SEED}, and {@link Prefs#getThreads()} threads.
     *
     * @return a new, untrained random forest
     */
    public static AbstractClassifier initRandomForest() {
        FastRandomForest rf = new FastRandomForest();
        rf.setNumTrees(DEFAULT_NUM_TREES);
        rf.setNumFeatures(DEFAULT_NUM_RANDOM_FEATURES);
        rf.setSeed(DEFAULT_SEED);
        rf.setNumThreads(Prefs.getThreads());
        return rf;
    }
}

