/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.gpu.compute_cache;

import java.util.Objects;
import net.haesleinhuepf.clij.coremem.enums.NativeTypeEnum;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.util.Intervals;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuImage;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPixelWiseOperation;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuView;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuViews;
import sc.fiji.labkit.pixel_classification.gpu.compute_cache.GpuComputeCache;

/**
 * {@link GpuComputeCache.Content} that computes the first derivative of another
 * cached content along a single dimension {@code d}.
 * <p>
 * The derivative is approximated by central differences,
 * {@code (f(x + 1) - f(x - 1)) * 0.5 / pixelSize[d]}, where {@code pixelSize}
 * is taken from the owning {@link GpuComputeCache}. The result is a float image.
 * <p>
 * Two instances are equal when their {@code input} contents are equal and they
 * differentiate along the same dimension; the owning cache is not part of the
 * identity.
 */
public class GpuDerivativeContent
implements GpuComputeCache.Content {
    private final GpuComputeCache cache;
    private final GpuComputeCache.Content input;
    private final int d;

    /**
     * @param cache the cache that provides the input content, the GPU API and the pixel size
     * @param input the content to differentiate
     * @param d     the dimension along which to differentiate; must be non-negative and
     *              smaller than the dimensionality of the intervals later requested/loaded
     * @throws NullPointerException     if {@code cache} or {@code input} is null
     * @throws IllegalArgumentException if {@code d} is negative
     */
    public GpuDerivativeContent(GpuComputeCache cache, GpuComputeCache.Content input, int d) {
        this.cache = Objects.requireNonNull(cache, "cache");
        // equals() dereferences input, so a null input would break the equals/hashCode contract.
        this.input = Objects.requireNonNull(input, "input");
        if (d < 0) {
            throw new IllegalArgumentException("d must be non-negative, but was " + d);
        }
        this.d = d;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.input, this.d);
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof GpuDerivativeContent)) {
            return false;
        }
        GpuDerivativeContent other = (GpuDerivativeContent) obj;
        return this.input.equals(other.input) && this.d == other.d;
    }

    /**
     * Requests the input content over {@code interval} grown by one pixel on both
     * sides along {@code d}, which is the neighbourhood the central difference needs.
     */
    @Override
    public void request(Interval interval) {
        cache.request(input, requiredInput(interval));
    }

    /** Returns {@code interval} grown by one pixel on both sides along dimension {@code d}. */
    private FinalInterval requiredInput(Interval interval) {
        return expandAlongD(interval, 1L);
    }

    /** Returns {@code interval} shrunk by one pixel on both sides along dimension {@code d}. */
    private FinalInterval shrink(Interval interval) {
        return expandAlongD(interval, -1L);
    }

    /**
     * Grows (positive {@code amount}) or shrinks (negative {@code amount}) {@code interval}
     * on both sides along dimension {@code d}, leaving all other dimensions untouched.
     */
    private FinalInterval expandAlongD(Interval interval, long amount) {
        long[] border = new long[interval.numDimensions()];
        border[d] = amount;
        return Intervals.expand(interval, border);
    }

    /**
     * Computes the derivative for {@code interval}.
     * <p>
     * Fetches the input over {@code interval} expanded by one pixel along {@code d}.
     * It then evaluates {@code (front - back) * 0.5 / pixelSize[d]} on the GPU, where
     * {@code front} and {@code back} are the input shifted by +1 and -1 pixel along {@code d}.
     *
     * @return a newly created float image whose size equals the fetched input's size
     *         minus two pixels along {@code d} (presumably the size of {@code interval},
     *         assuming the cache returns exactly the requested extent)
     */
    @Override
    public GpuImage load(Interval interval) {
        GpuApi gpu = cache.gpuApi();
        double[] pixelSize = cache.pixelSize();
        GpuView source = cache.get(input, requiredInput(interval));
        // Work in the view's local coordinates, which start at zero. "center" drops the
        // one-pixel border on each side along d. Shifting it by +1/-1 yields the forward
        // and backward neighbours, which still lie inside the source view.
        FinalInterval center = shrink(new FinalInterval(source.dimensions()));
        GpuView front = GpuViews.crop(source, Intervals.translate(center, 1L, d));
        GpuView back = GpuViews.crop(source, Intervals.translate(center, -1L, d));
        GpuImage result = gpu.create(Intervals.dimensionsAsLongArray(center), NativeTypeEnum.Float);
        GpuPixelWiseOperation.gpu(gpu)
            .addInput("f", front)
            .addInput("b", back)
            .addInput("factor", 0.5 / pixelSize[d])
            .addOutput("r", result)
            .forEachPixel("r = (f - b) * factor");
        return result;
    }
}

