/*
 * Decompiled with CFR 0.152.
 */
package bigwarp.scripts;

import bigwarp.transforms.NgffTransformations;
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.stream.IntStream;
import net.imagej.Dataset;
import net.imagej.axis.CalibratedAxis;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.Converter;
import net.imglib2.converter.Converters;
import net.imglib2.realtransform.AffineGet;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.Views;
import org.janelia.saalfeldlab.n5.Compression;
import org.janelia.saalfeldlab.n5.N5Writer;
import org.janelia.saalfeldlab.n5.ij.N5ScalePyramidExporter;
import org.janelia.saalfeldlab.n5.imglib2.N5DisplacementField;
import org.janelia.saalfeldlab.n5.universe.N5Factory;
import org.janelia.saalfeldlab.n5.universe.metadata.ome.ngff.v05.transformations.DisplacementFieldCoordinateTransform;
import org.scijava.command.Command;
import org.scijava.log.LogService;
import org.scijava.plugin.Parameter;
import org.scijava.plugin.Plugin;
import org.scijava.ui.UIService;

/**
 * SciJava command ("Plugins > Transform > Write Displacement Field") that writes a displacement
 * field held in an ImageJ {@link Dataset} to an N5 container, either in the N5 layout
 * ({@link N5DisplacementField}) or as an OME-NGFF coordinate transformation
 * ({@link NgffTransformations}).
 * <p>
 * The input dataset must have exactly one non-spatial axis holding the vector components, and the
 * length of that axis must equal the number of spatial axes.
 */
@Plugin(type=Command.class, menuPath="Plugins>Transform>Write Displacement Field")
public class WriteDisplacementField
implements Callable<Void>,
Command {
    /** Output type choices (values of {@link #outputType}). */
    public static final String INT8 = "INT8";
    public static final String INT16 = "INT16";
    public static final String FLOAT32 = "FLOAT32";
    public static final String FLOAT64 = "FLOAT64";
    @Parameter
    private UIService ui;
    @Parameter
    private LogService log;
    /** Location of the output container, opened with {@link N5Factory#openWriter}. */
    @Parameter
    private String n5Root;
    /** Path of the displacement field dataset inside the container. */
    @Parameter
    private String n5Dataset;
    /** The displacement field to write. */
    @Parameter
    private Dataset dataset;
    @Parameter(label="Chunk size", description="The size of chunks. Comma separated, for example: \"64,32,16\".\n ImageJ's axis order is X,Y,C,Z,T. The chunk size must be specified in this order.\nYou must skip any axis whose size is 1, e.g. a 2D time-series without channels\nmay have a chunk size of 1024,1024,1 (X,Y,T).\nYou may provide fewer values than the data dimension. In that case, the size will\nbe expanded to necessary size with the last value, for example \"64\", will expand\nto \"64,64,64\" for 3D data.")
    private String chunkSizeArg;
    @Parameter(label="Compression", style="listBox", choices={"gzip", "raw", "lz4", "xz", "blosc", "zstd"})
    private String compressionArg = "gzip";
    @Parameter(label="Output type", style="listBox", choices={"FLOAT64", "FLOAT32", "INT16", "INT8"})
    private String outputType;
    @Parameter(label="Output format", style="listBox", choices={"NGFF", "N5", "TPS"})
    private String format = "NGFF";
    @Parameter(label="Thread count", required=true, min="1", max="999")
    private int nThreads = 1;
    @Parameter(label="Quantization Error", required=true, min="0")
    private double quantizationError = 0.01;
    /** Number of spatial dimensions (dataset dimensionality minus the vector axis); set by {@link #process()}. */
    private int nd = -1;
    /** Index of the non-spatial (vector) axis in {@link #dataset}, or -1 if none was found. */
    private int vectorDim = -1;
    /** Length of the vector axis, or -1 if none was found. */
    private int vectorSize = -1;

    /**
     * Reads the spatial calibration of {@link #dataset}, validates its layout and writes it as a
     * displacement field to {@link #n5Dataset} inside {@link #n5Root} using the selected format.
     * <p>
     * Layout errors and unsupported formats are reported with a UI dialog and abort the write
     * before the container is opened. Failures while writing are logged through {@link #log}.
     */
    public <T extends RealType<T> & NativeType<T>, S extends RealType<S> & NativeType<S>, Q extends NativeType<Q> & IntegerType<Q>> void process() {
        // No affine is passed to N5DisplacementField.save (always null here).
        final AffineGet affine = null;
        final Compression compression = N5ScalePyramidExporter.getCompression(this.compressionArg);
        final int numAxes = this.dataset.numDimensions();
        // One axis holds the vector components; all remaining axes are expected to be spatial.
        this.nd = numAxes - 1;
        this.vectorDim = -1;
        this.vectorSize = -1;
        final long[] spatialDims = new long[this.nd];
        final double[] offset = new double[this.nd];
        final double[] spacing = new double[this.nd];
        Arrays.fill(spacing, 1.0);
        String unit = null;
        int nonSpatialAxes = 0;
        int j = 0;
        for (int i = 0; i < numAxes; ++i) {
            final CalibratedAxis axis = (CalibratedAxis) this.dataset.axis(i);
            if (!axis.type().isSpatial()) {
                ++nonSpatialAxes;
                this.vectorDim = i;
                this.vectorSize = (int) this.dataset.dimension(i);
                continue;
            }
            // If every axis is spatial there are more spatial axes than slots; validation below rejects it.
            if (j >= this.nd) {
                continue;
            }
            if (j == 0) {
                // Use the first spatial axis for the unit; axis 0 may be the vector axis.
                unit = axis.unit();
            }
            spatialDims[j] = this.dataset.dimension(i);
            offset[j] = axis.calibratedValue(0.0);
            spacing[j] = this.dataset.averageScale(i);
            ++j;
        }
        if (!this.validateAndWarn(nonSpatialAxes)) {
            return;
        }
        final boolean n5Format = this.format.equals("N5");
        final boolean ngffFormat = this.format.equals("NGFF");
        if (!n5Format && !ngffFormat) {
            // Checked before opening the writer so no container is created for an unsupported format.
            this.ui.showDialog(String.format("Error: output format %s is not supported. Exiting.", this.format));
            return;
        }
        final int[] chunkSizeSpatial = N5ScalePyramidExporter.parseBlockSize(this.chunkSizeArg, spatialDims);
        // The vector axis is moved first and chunked as a single block of size vectorSize.
        final int[] chunkSize = IntStream.concat(IntStream.of(this.vectorSize), Arrays.stream(chunkSizeSpatial)).toArray();
        final RandomAccessibleInterval vectorAxisFirst = Views.moveAxis((RandomAccessibleInterval) this.dataset, this.vectorDim, 0);
        try (N5Writer n5 = new N5Factory().openWriter(this.n5Root)) {
            if (n5Format) {
                if (this.outputType.equals(FLOAT32) || this.outputType.equals(FLOAT64)) {
                    final RandomAccessibleInterval<S> converted = this.convertIfNecessary((RandomAccessibleInterval<T>) vectorAxisFirst, this.<S>getTargetType());
                    N5DisplacementField.save(n5, this.n5Dataset, affine, converted, spacing, offset, chunkSize, compression);
                } else {
                    // Integer outputs: the target type and quantizationError are passed to N5DisplacementField.save.
                    final NativeType quantizedType = (NativeType) this.getTargetType();
                    N5DisplacementField.save(n5, this.n5Dataset, affine, (RandomAccessibleInterval) vectorAxisFirst, spacing, offset, chunkSize, compression, quantizedType, this.quantizationError);
                }
            } else {
                // NOTE(review): outputType and quantizationError are not applied to NGFF output - confirm intent.
                final DisplacementFieldCoordinateTransform<?> dfieldTform = NgffTransformations.save(n5, this.n5Dataset, vectorAxisFirst, "input", "output", spacing, offset, unit, chunkSize, compression, this.nThreads);
                NgffTransformations.addCoordinateTransformations(n5, "/", dfieldTform);
            }
        }
        catch (Exception e) {
            this.log.error("Failed to write displacement field at " + this.n5Root, e);
        }
    }

    /**
     * Maps the {@link #outputType} choice to a fresh instance of the corresponding pixel type.
     *
     * @return a new type instance, or {@code null} if {@link #outputType} is not a known choice
     */
    private <T extends RealType<T> & NativeType<T>> T getTargetType() {
        switch (this.outputType) {
            case FLOAT32: {
                return (T)new FloatType();
            }
            case FLOAT64: {
                return (T)new DoubleType();
            }
            case INT16: {
                return (T)new ShortType();
            }
            case INT8: {
                return (T)new ByteType();
            }
        }
        return null;
    }

    /**
     * Returns {@code dfield} viewed as {@code targetType}. If the field is already of that concrete
     * type it is returned as-is; otherwise a lazy converted view is returned that copies each value
     * through {@code getRealDouble}/{@code setReal}.
     *
     * @param dfield     the displacement field
     * @param targetType an instance of the desired output type
     * @return {@code dfield} itself or a converted view of it
     */
    @SuppressWarnings("unchecked")
    private <T extends RealType<T> & NativeType<T>, S extends RealType<S> & NativeType<S>> RandomAccessibleInterval<S> convertIfNecessary(RandomAccessibleInterval<T> dfield, S targetType) {
        if (dfield.getType().getClass().equals(targetType.getClass())) {
            // Same concrete class, so T and S denote the same type; the cast is safe.
            return (RandomAccessibleInterval<S>) (RandomAccessibleInterval<?>) dfield;
        }
        final Converter<T, S> conv = (input, output) -> output.setReal(input.getRealDouble());
        return Converters.convertRAI(dfield, conv, targetType);
    }

    /**
     * Checks that the dataset has exactly one vector axis whose length equals the number of
     * spatial dimensions. Shows an error dialog on failure.
     *
     * @param nonSpatialAxes number of non-spatial axes found in the dataset
     * @return {@code true} if the layout is valid and processing may continue
     */
    private boolean validateAndWarn(final int nonSpatialAxes) {
        if (nonSpatialAxes != 1) {
            this.ui.showDialog(String.format("Error: expected exactly one non-spatial (vector) axis but found %d. Exiting.", nonSpatialAxes));
            return false;
        }
        if (this.vectorSize != this.nd) {
            this.ui.showDialog(String.format("Error: channel dimension size (%d) must match dimensionality (%d). Exiting.", this.vectorSize, this.nd));
            return false;
        }
        return true;
    }

    /** Runs the command; delegates to {@link #call()}. */
    @Override
    public void run() {
        this.call();
    }

    /** Runs {@link #process()}. Always returns {@code null}. */
    @Override
    public Void call() {
        this.process();
        return null;
    }
}

