/*
 * Decompiled with CFR 0.152.
 */
package mpicbg.imglib.algorithm.fft;

import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicInteger;
import mpicbg.imglib.algorithm.Algorithm;
import mpicbg.imglib.algorithm.Benchmark;
import mpicbg.imglib.algorithm.MultiThreaded;
import mpicbg.imglib.algorithm.fft.FourierTransform;
import mpicbg.imglib.algorithm.fft.InverseFourierTransform;
import mpicbg.imglib.algorithm.fft.PhaseCorrelationPeak;
import mpicbg.imglib.cursor.Cursor;
import mpicbg.imglib.cursor.LocalizableByDimCursor;
import mpicbg.imglib.cursor.special.LocalNeighborhoodCursor;
import mpicbg.imglib.cursor.special.RegionOfInterestCursor;
import mpicbg.imglib.image.Image;
import mpicbg.imglib.multithreading.SimpleMultiThreading;
import mpicbg.imglib.outofbounds.OutOfBoundsStrategyPeriodicFactory;
import mpicbg.imglib.type.numeric.RealType;
import mpicbg.imglib.type.numeric.complex.ComplexFloatType;
import mpicbg.imglib.type.numeric.real.FloatType;
import mpicbg.imglib.util.Util;

/**
 * Estimates the translation between two images of equal dimensionality by phase correlation.
 *
 * <p>Both images are Fourier transformed (with padding to a common extended size), their
 * spectra are normalized to unit length, the second spectrum is complex-conjugated, the two
 * are multiplied and the product is transformed back into the phase correlation matrix (PCM).
 * The {@code numPeaks} highest local maxima of the PCM become candidate shifts. Optionally every
 * candidate, together with its periodic alternatives, is re-ranked by the cross correlation of
 * the overlapping image regions.
 *
 * <p>Typical use: construct, optionally configure, call {@link #process()}, then read
 * {@link #getShift()} / {@link #getAllShifts()}.
 */
public class PhaseCorrelation<T extends RealType<T>, S extends RealType<S>>
implements MultiThreaded,
Algorithm,
Benchmark {
    /** Dimensionality of the input images (taken from {@code image1}). */
    final int numDimensions;
    /** Whether the two forward FFTs may run concurrently (each with half of the threads). */
    boolean computeFFTinParalell = true;
    /** Whether the PCM is kept open after cross-correlation verification. */
    boolean keepPCM = false;
    Image<T> image1;
    Image<S> image2;
    /** The (inverse transformed) phase correlation matrix produced by {@link #process()}. */
    Image<FloatType> invPCM;
    /** Number of PCM maxima that are kept as candidate shifts. */
    int numPeaks;
    /** Minimal overlap, per dimension, required for a cross correlation to be computed. */
    int[] minOverlapPx;
    /** Complex values whose length is below this threshold are zeroed instead of normalized. */
    float normalizationThreshold;
    /** Whether candidates are re-ranked by cross correlation. */
    boolean verifyWithCrossCorrelation;
    /** Candidate shifts, sorted; filled by {@link #process()}. */
    ArrayList<PhaseCorrelationPeak> phaseCorrelationPeaks;
    String errorMessage = "";
    int numThreads;
    /** Wall-clock duration of the last successful {@link #process()} in ms, or -1. */
    long processingTime;

    /**
     * Creates a phase correlation between two images.
     *
     * <p>Defaults: normalization threshold {@code 1e-5}, a minimal overlap of 3 pixels in every
     * dimension, one thread per available processor.
     *
     * @param image1 the first image; its dimensionality defines {@link #numDimensions}
     * @param image2 the second image
     * @param numPeaks how many PCM maxima to keep as candidate shifts
     * @param verifyWithCrossCorrelation whether {@link #process()} re-ranks the candidates by
     *     cross correlation
     */
    public PhaseCorrelation(Image<T> image1, Image<S> image2, int numPeaks, boolean verifyWithCrossCorrelation) {
        this.image1 = image1;
        this.image2 = image2;
        this.numPeaks = numPeaks;
        this.verifyWithCrossCorrelation = verifyWithCrossCorrelation;
        this.numDimensions = image1.getNumDimensions();
        this.normalizationThreshold = 1.0E-5f;
        this.minOverlapPx = new int[this.numDimensions];
        this.setMinimalPixelOverlap(3);
        this.setNumThreads();
        this.processingTime = -1L;
    }

    /**
     * Creates a phase correlation that investigates 5 peaks and verifies them with cross
     * correlation.
     *
     * @param image1 the first image
     * @param image2 the second image
     */
    public PhaseCorrelation(Image<T> image1, Image<S> image2) {
        this(image1, image2, 5, true);
    }

    /** @param computeFFTinParalell whether the two forward FFTs may run concurrently */
    public void setComputeFFTinParalell(boolean computeFFTinParalell) {
        this.computeFFTinParalell = computeFFTinParalell;
    }

    /** @param numPeaks how many PCM maxima to keep as candidate shifts (should be at least 1) */
    public void setInvestigateNumPeaks(int numPeaks) {
        this.numPeaks = numPeaks;
    }

    /**
     * @param keepPCM whether the phase correlation matrix stays open after cross-correlation
     *     verification so it can be read through {@link #getPhaseCorrelationMatrix()}
     */
    public void setKeepPhaseCorrelationMatrix(boolean keepPCM) {
        this.keepPCM = keepPCM;
    }

    /**
     * Sets the normalization threshold from an integer value.
     *
     * <p>Kept for backward compatibility; integer thresholds cannot express the small values
     * usually needed (the default is {@code 1e-5}). Prefer {@link #setNormalizationThreshold(float)}.
     *
     * @param normalizationThreshold the threshold
     */
    public void setNormalizationThreshold(int normalizationThreshold) {
        this.setNormalizationThreshold((float)normalizationThreshold);
    }

    /**
     * Sets the length below which complex spectrum values are set to zero instead of being
     * normalized to unit length.
     *
     * @param normalizationThreshold the threshold
     */
    public void setNormalizationThreshold(float normalizationThreshold) {
        this.normalizationThreshold = normalizationThreshold;
    }

    /** @param verifyWithCrossCorrelation whether candidates are re-ranked by cross correlation */
    public void setVerifyWithCrossCorrelation(boolean verifyWithCrossCorrelation) {
        this.verifyWithCrossCorrelation = verifyWithCrossCorrelation;
    }

    /**
     * Sets the minimal overlap per dimension; the array is copied.
     *
     * @param minOverlapPx one entry per image dimension
     */
    public void setMinimalPixelOverlap(int[] minOverlapPx) {
        this.minOverlapPx = (int[])minOverlapPx.clone();
    }

    /**
     * Sets the same minimal overlap for every dimension.
     *
     * @param minOverlapPx the minimal overlap in pixels
     */
    public void setMinimalPixelOverlap(int minOverlapPx) {
        for (int d = 0; d < this.numDimensions; ++d) {
            this.minOverlapPx[d] = minOverlapPx;
        }
    }

    /** @return whether the two forward FFTs may run concurrently */
    public boolean getComputeFFTinParalell() {
        return this.computeFFTinParalell;
    }

    /** @return how many PCM maxima are kept as candidate shifts */
    public int getInvestigateNumPeaks() {
        return this.numPeaks;
    }

    /** @return whether the PCM is kept open after cross-correlation verification */
    public boolean getKeepPhaseCorrelationMatrix() {
        return this.keepPCM;
    }

    /** @return the normalization threshold */
    public float getNormalizationThreshold() {
        return this.normalizationThreshold;
    }

    /** @return whether candidates are re-ranked by cross correlation */
    public boolean getVerifyWithCrossCorrelation() {
        return this.verifyWithCrossCorrelation;
    }

    /** @return a copy of the minimal overlap per dimension */
    public int[] getMinimalPixelOverlap() {
        return (int[])this.minOverlapPx.clone();
    }

    /**
     * Returns the phase correlation matrix of the last {@link #process()} run.
     *
     * <p>When cross-correlation verification ran and {@link #getKeepPhaseCorrelationMatrix()}
     * is false, the returned image has already been closed.
     *
     * @return the PCM, or {@code null} before {@link #process()} produced one
     */
    public Image<FloatType> getPhaseCorrelationMatrix() {
        return this.invPCM;
    }

    /**
     * Returns the last entry of the sorted candidate list. This is the best candidate,
     * presumably, given that {@code PhaseCorrelationPeak} sorts ascending (confirm).
     *
     * @return the best candidate shift
     * @throws NullPointerException if {@link #process()} has not produced peaks yet
     * @throws IndexOutOfBoundsException if the candidate list is empty
     */
    public PhaseCorrelationPeak getShift() {
        return this.phaseCorrelationPeaks.get(this.phaseCorrelationPeaks.size() - 1);
    }

    /**
     * @return the live, sorted candidate list of the last {@link #process()} run, or
     *     {@code null} before it ran
     */
    public ArrayList<PhaseCorrelationPeak> getAllShifts() {
        return this.phaseCorrelationPeaks;
    }

    /**
     * Runs the phase correlation.
     *
     * @return {@code true} on success; on failure {@link #getErrorMessage()} describes the cause
     */
    @Override
    public boolean process() {
        final long startTime = System.currentTimeMillis();
        final FourierTransform<T, ComplexFloatType> fft1 = new FourierTransform<T, ComplexFloatType>(this.image1, new ComplexFloatType());
        final FourierTransform<S, ComplexFloatType> fft2 = new FourierTransform<S, ComplexFloatType>(this.image2, new ComplexFloatType());
        PhaseCorrelation.configureFourierTransform(fft1);
        PhaseCorrelation.configureFourierTransform(fft2);
        this.matchExtendedSizes(fft1, fft2, PhaseCorrelation.getMaxDim(this.image1, this.image2));
        if (!fft1.checkInput()) {
            this.errorMessage = "Fourier Transform of first image failed: " + fft1.getErrorMessage();
            return false;
        }
        if (!fft2.checkInput()) {
            this.errorMessage = "Fourier Transform of second image failed: " + fft2.getErrorMessage();
            return false;
        }
        if (!this.computeFFT(fft1, fft2)) {
            this.errorMessage = "Fourier Transform of failed: fft1=" + fft1.getErrorMessage() + " fft2=" + fft2.getErrorMessage();
            return false;
        }
        final Image<ComplexFloatType> fftImage1 = fft1.getResult();
        final Image<ComplexFloatType> fftImage2 = fft2.getResult();
        this.normalizeAndConjugate(fftImage1, fftImage2);
        this.multiplyInPlace(fftImage1, fftImage2);
        final InverseFourierTransform<FloatType, ComplexFloatType> invFFT = new InverseFourierTransform<FloatType, ComplexFloatType>(fftImage1, fft1, new FloatType());
        invFFT.setInPlaceTransform(true);
        invFFT.setCropBackToOriginalSize(false);
        final boolean inverseSucceeded = invFFT.checkInput() && invFFT.process();
        // The spectra are no longer needed; release them on success and on failure alike.
        fftImage1.close();
        fftImage2.close();
        if (!inverseSucceeded) {
            this.errorMessage = "Inverse Fourier Transform of failed: " + invFFT.getErrorMessage();
            return false;
        }
        this.invPCM = invFFT.getResult();
        this.phaseCorrelationPeaks = this.extractPhaseCorrelationPeaks(this.invPCM, this.numPeaks, fft1, fft2);
        if (this.verifyWithCrossCorrelation) {
            this.verifyWithCrossCorrelation(this.phaseCorrelationPeaks, this.invPCM.getDimensions(), this.image1, this.image2);
            // NOTE(review): the PCM is only released on this path; without verification it stays
            // open regardless of keepPCM (existing behavior callers may rely on).
            if (!this.keepPCM) {
                this.invPCM.close();
            }
        }
        this.processingTime = System.currentTimeMillis() - startTime;
        return true;
    }

    /** Applies the padding, fade-out and (no) quadrant rearrangement used for both transforms. */
    private static void configureFourierTransform(final FourierTransform<?, ?> fft) {
        fft.setRelativeImageExtension(0.1f);
        fft.setRelativeFadeOutDistance(0.1f);
        fft.setRearrangement(FourierTransform.Rearrangement.UNCHANGED);
    }

    /**
     * Grows {@code maxDim} until both transforms report the same extended size in every
     * dimension, so that their spectra can be multiplied element-wise.
     */
    private void matchExtendedSizes(final FourierTransform<?, ?> fft1, final FourierTransform<?, ?> fft2, final int[] maxDim) {
        boolean sizeFound;
        do {
            sizeFound = true;
            fft1.setExtendedOriginalImageSize(maxDim);
            fft2.setExtendedOriginalImageSize(maxDim);
            for (int d = 0; d < this.numDimensions; ++d) {
                final int diff = Math.abs(fft1.getExtendedSize()[d] - fft2.getExtendedSize()[d]);
                if (diff > 0) {
                    maxDim[d] += diff;
                    sizeFound = false;
                }
            }
        } while (!sizeFound);
    }

    /**
     * Expands every PCM candidate into all of its periodic alternatives, scores each alternative
     * by cross correlation of the overlapping image regions and replaces the content of
     * {@code peakList} with the sorted alternatives.
     *
     * <p>The PCM is periodic, so a maximum at position {@code p} may also correspond to
     * {@code p +/- dimInvPCM} in every dimension; each of the {@code 2^numDimensions}
     * combinations is tested.
     *
     * @param peakList candidates to expand; modified in place
     * @param dimInvPCM dimensions of the PCM (the period of the shifts)
     * @param image1 the first image
     * @param image2 the second image
     */
    protected void verifyWithCrossCorrelation(ArrayList<PhaseCorrelationPeak> peakList, int[] dimInvPCM, final Image<T> image1, final Image<S> image2) {
        final boolean[][] coordinates = Util.getRecursiveCoordinates(this.numDimensions);
        final ArrayList<PhaseCorrelationPeak> newPeakList = new ArrayList<PhaseCorrelationPeak>();
        for (PhaseCorrelationPeak peak : peakList) {
            for (int i = 0; i < coordinates.length; ++i) {
                final boolean[] currentPossibility = coordinates[i];
                // Defensive copy: the periodic shift applied below must never alias the source
                // peak's position, whether or not getPosition() already returns a copy.
                final int[] peakPosition = (int[])peak.getPosition().clone();
                for (int d = 0; d < currentPossibility.length; ++d) {
                    if (!currentPossibility[d]) continue;
                    if (peakPosition[d] < 0) {
                        peakPosition[d] += dimInvPCM[d];
                    } else {
                        peakPosition[d] -= dimInvPCM[d];
                    }
                }
                final PhaseCorrelationPeak newPeak = new PhaseCorrelationPeak(peakPosition, peak.getPhaseCorrelationPeak());
                newPeak.setOriginalInvPCMPosition(peak.getOriginalInvPCMPosition());
                newPeakList.add(newPeak);
            }
        }
        final int[] minOverlap = this.minOverlapPx;
        final AtomicInteger ai = new AtomicInteger(0);
        final Thread[] threads = SimpleMultiThreading.newThreads(Math.max(1, this.getNumThreads()));
        final int numThreads = threads.length;
        for (int ithread = 0; ithread < threads.length; ++ithread) {
            threads[ithread] = new Thread(new Runnable(){

                @Override
                public void run() {
                    final int myNumber = ai.getAndIncrement();
                    // Round-robin split: each peak is touched by exactly one thread.
                    for (int i = myNumber; i < newPeakList.size(); i += numThreads) {
                        final PhaseCorrelationPeak peak = newPeakList.get(i);
                        final long[] numPixels = new long[1];
                        peak.setCrossCorrelationPeak((float)PhaseCorrelation.testCrossCorrelation(peak.getPosition(), image1, image2, minOverlap, numPixels));
                        peak.setNumPixels(numPixels[0]);
                        peak.setSortPhaseCorrelation(false);
                    }
                }
            });
        }
        SimpleMultiThreading.startAndJoin(threads);
        peakList.clear();
        peakList.addAll(newPeakList);
        Collections.sort(peakList);
    }

    /** Cross correlation with a minimal overlap of 5 pixels per dimension. */
    public static <T extends RealType<T>, S extends RealType<S>> double testCrossCorrelation(int[] shift, Image<T> image1, Image<S> image2) {
        return PhaseCorrelation.testCrossCorrelation(shift, image1, image2, 5);
    }

    /** Cross correlation with the same minimal overlap in every dimension. */
    public static <T extends RealType<T>, S extends RealType<S>> double testCrossCorrelation(int[] shift, Image<T> image1, Image<S> image2, int minOverlapPx) {
        return PhaseCorrelation.testCrossCorrelation(shift, image1, image2, minOverlapPx, null);
    }

    /** Cross correlation with the same minimal overlap in every dimension; reports the overlap size. */
    public static <T extends RealType<T>, S extends RealType<S>> double testCrossCorrelation(int[] shift, Image<T> image1, Image<S> image2, int minOverlapPx, long[] numPixels) {
        return PhaseCorrelation.testCrossCorrelation(shift, image1, image2, Util.getArrayFromValue(minOverlapPx, image1.getNumDimensions()), numPixels);
    }

    /** Cross correlation with a per-dimension minimal overlap. */
    public static <T extends RealType<T>, S extends RealType<S>> double testCrossCorrelation(int[] shift, Image<T> image1, Image<S> image2, int[] minOverlapPx) {
        return PhaseCorrelation.testCrossCorrelation(shift, image1, image2, minOverlapPx, null);
    }

    /**
     * Computes the Pearson correlation coefficient of the region where the two images overlap
     * under the given shift.
     *
     * <p>For {@code shift[d] >= 0} the overlap starts at {@code shift[d]} in image1 and at 0 in
     * image2; for negative shifts it starts at 0 in image1 and at {@code -shift[d]} in image2.
     *
     * @param shift the shift per dimension
     * @param image1 the first image
     * @param image2 the second image
     * @param minOverlapPx minimal overlap per dimension; smaller overlaps yield 0
     * @param numPixels optional; if non-null and non-empty, element 0 receives the number of
     *     overlapping pixels (0 when no correlation was computed)
     * @return the correlation coefficient; 0 if the overlap is too small; if either region has
     *     zero standard deviation, 1 when both have zero deviation and equal means, else 0
     */
    public static <T extends RealType<T>, S extends RealType<S>> double testCrossCorrelation(int[] shift, Image<T> image1, Image<S> image2, int[] minOverlapPx, long[] numPixels) {
        final int numDimensions = image1.getNumDimensions();
        final int[] overlapSize = new int[numDimensions];
        final int[] offsetImage1 = new int[numDimensions];
        final int[] offsetImage2 = new int[numDimensions];
        long numPx = 1L;
        for (int d = 0; d < numDimensions; ++d) {
            if (shift[d] >= 0) {
                if (shift[d] >= image1.getDimension(d)) {
                    PhaseCorrelation.reportNumPixels(numPixels, 0L);
                    return 0.0;
                }
                offsetImage1[d] = shift[d];
                offsetImage2[d] = 0;
                overlapSize[d] = Math.min(image1.getDimension(d) - shift[d], image2.getDimension(d));
            } else {
                // Written as a comparison against -dim so that negating shift cannot overflow.
                if (shift[d] <= -image2.getDimension(d)) {
                    PhaseCorrelation.reportNumPixels(numPixels, 0L);
                    return 0.0;
                }
                offsetImage1[d] = 0;
                offsetImage2[d] = -shift[d];
                overlapSize[d] = Math.min(image2.getDimension(d) + shift[d], image1.getDimension(d));
            }
            numPx *= (long)overlapSize[d];
            if (overlapSize[d] < minOverlapPx[d]) {
                PhaseCorrelation.reportNumPixels(numPixels, 0L);
                return 0.0;
            }
        }
        PhaseCorrelation.reportNumPixels(numPixels, numPx);
        final LocalizableByDimCursor<T> cursor1 = image1.createLocalizableByDimCursor();
        final LocalizableByDimCursor<S> cursor2 = image2.createLocalizableByDimCursor();
        final RegionOfInterestCursor<T> roiCursor1 = cursor1.createRegionOfInterestCursor(offsetImage1, overlapSize);
        final RegionOfInterestCursor<S> roiCursor2 = cursor2.createRegionOfInterestCursor(offsetImage2, overlapSize);
        try {
            // Values are read through the parent cursors, which the ROI cursors move.
            double avg1 = 0.0;
            double avg2 = 0.0;
            while (roiCursor1.hasNext()) {
                roiCursor1.fwd();
                roiCursor2.fwd();
                avg1 += cursor1.getType().getRealFloat();
                avg2 += cursor2.getType().getRealFloat();
            }
            avg1 /= (double)numPx;
            avg2 /= (double)numPx;
            roiCursor1.reset();
            roiCursor2.reset();
            double var1 = 0.0;
            double var2 = 0.0;
            double coVar = 0.0;
            while (roiCursor1.hasNext()) {
                roiCursor1.fwd();
                roiCursor2.fwd();
                final double dist1 = (double)cursor1.getType().getRealFloat() - avg1;
                final double dist2 = (double)cursor2.getType().getRealFloat() - avg2;
                coVar += dist1 * dist2;
                var1 += dist1 * dist1;
                var2 += dist2 * dist2;
            }
            coVar /= (double)numPx;
            final double stDev1 = Math.sqrt(var1 / (double)numPx);
            final double stDev2 = Math.sqrt(var2 / (double)numPx);
            if (stDev1 == 0.0 || stDev2 == 0.0) {
                return stDev1 == stDev2 && avg1 == avg2 ? 1.0 : 0.0;
            }
            return coVar / (stDev1 * stDev2);
        } finally {
            // Released on every path, including the zero-deviation early return.
            roiCursor1.close();
            roiCursor2.close();
            cursor1.close();
            cursor2.close();
        }
    }

    /** Stores {@code value} into {@code numPixels[0]} when the caller supplied such an array. */
    private static void reportNumPixels(final long[] numPixels, final long value) {
        if (numPixels != null && numPixels.length > 0) {
            numPixels[0] = value;
        }
    }

    /**
     * Collects the {@code numPeaks} highest local maxima of the PCM.
     *
     * <p>A pixel is a local maximum if no pixel of its neighborhood (periodic boundaries) is
     * larger. Each peak position is corrected by the difference of the two transforms' original
     * offsets and wrapped into the range {@code (-size/2, size/2]} per dimension.
     *
     * @param invPCM the phase correlation matrix
     * @param numPeaks how many peaks to keep
     * @param fft1 transform of the first image (provides its original offset)
     * @param fft2 transform of the second image (provides its original offset)
     * @return the sorted peaks; slots not filled by a real maximum hold placeholder peaks with
     *     value {@code -Float.MAX_VALUE}
     */
    protected ArrayList<PhaseCorrelationPeak> extractPhaseCorrelationPeaks(Image<FloatType> invPCM, int numPeaks, FourierTransform<?, ?> fft1, FourierTransform<?, ?> fft2) {
        final ArrayList<PhaseCorrelationPeak> peakList = new ArrayList<PhaseCorrelationPeak>();
        for (int i = 0; i < numPeaks; ++i) {
            peakList.add(new PhaseCorrelationPeak(new int[this.numDimensions], -Float.MAX_VALUE));
        }
        final LocalizableByDimCursor<FloatType> cursor = invPCM.createLocalizableByDimCursor(new OutOfBoundsStrategyPeriodicFactory());
        final LocalNeighborhoodCursor<FloatType> localCursor = cursor.createLocalNeighborhoodCursor();
        final int[] originalOffset1 = fft1.getOriginalOffset();
        final int[] originalOffset2 = fft2.getOriginalOffset();
        final int[] offset = new int[this.numDimensions];
        for (int d = 0; d < this.numDimensions; ++d) {
            offset[d] = originalOffset2[d] - originalOffset1[d];
        }
        final int[] imgSize = invPCM.getDimensions();
        while (cursor.hasNext()) {
            cursor.fwd();
            localCursor.update();
            final float value = cursor.getType().get();
            boolean isMax = true;
            // NOTE: neighbor values are read through the parent cursor; this relies on
            // LocalNeighborhoodCursor moving the parent and reset() restoring it (imglib1).
            while (localCursor.hasNext() && isMax) {
                localCursor.fwd();
                isMax = cursor.getType().get() <= value;
            }
            localCursor.reset();
            if (!isMax) continue;
            // Find the weakest peak currently kept; it is replaced if this maximum is stronger.
            float lowestValue = Float.MAX_VALUE;
            int lowestValueIndex = -1;
            for (int i = 0; i < numPeaks; ++i) {
                final float v = peakList.get(i).getPhaseCorrelationPeak();
                if (v < lowestValue) {
                    lowestValue = v;
                    lowestValueIndex = i;
                }
            }
            if (lowestValueIndex < 0 || !(value > lowestValue)) continue;
            peakList.remove(lowestValueIndex);
            final int[] position = cursor.getPosition();
            for (int d = 0; d < this.numDimensions; ++d) {
                // Java's % keeps the sign of the dividend; normalize to [0, size) first.
                int wrapped = (position[d] + offset[d]) % imgSize[d];
                if (wrapped < 0) {
                    wrapped += imgSize[d];
                }
                position[d] = wrapped > imgSize[d] / 2 ? wrapped - imgSize[d] : wrapped;
            }
            final PhaseCorrelationPeak pcp = new PhaseCorrelationPeak(position, value);
            pcp.setOriginalInvPCMPosition(cursor.getPosition());
            peakList.add(pcp);
        }
        cursor.close();
        Collections.sort(peakList);
        return peakList;
    }

    /**
     * @param image1 the first image
     * @param image2 the second image (must have at least as many dimensions as image1)
     * @return the per-dimension maximum of the two images' sizes
     */
    protected static int[] getMaxDim(Image<?> image1, Image<?> image2) {
        final int[] maxDim = new int[image1.getNumDimensions()];
        for (int d = 0; d < image1.getNumDimensions(); ++d) {
            maxDim[d] = Math.max(image1.getDimension(d), image2.getDimension(d));
        }
        return maxDim;
    }

    /**
     * Multiplies {@code fftImage1} element-wise by {@code fftImage2}, storing the result in
     * {@code fftImage1}. Both images are iterated in lock-step, so they must have the same size.
     */
    protected void multiplyInPlace(Image<ComplexFloatType> fftImage1, Image<ComplexFloatType> fftImage2) {
        final Cursor<ComplexFloatType> cursor1 = fftImage1.createCursor();
        final Cursor<ComplexFloatType> cursor2 = fftImage2.createCursor();
        while (cursor1.hasNext()) {
            cursor1.fwd();
            cursor2.fwd();
            cursor1.getType().mul(cursor2.getType());
        }
        cursor1.close();
        cursor2.close();
    }

    /**
     * Normalizes both spectra to unit length and complex-conjugates the second one; uses up to
     * two threads (one per image).
     */
    protected void normalizeAndConjugate(final Image<ComplexFloatType> fftImage1, final Image<ComplexFloatType> fftImage2) {
        final float threshold = this.normalizationThreshold;
        final AtomicInteger ai = new AtomicInteger(0);
        final Thread[] threads = SimpleMultiThreading.newThreads(Math.max(1, Math.min(2, this.numThreads)));
        final int numThreads = threads.length;
        for (int ithread = 0; ithread < threads.length; ++ithread) {
            threads[ithread] = new Thread(new Runnable(){

                @Override
                public void run() {
                    final int myNumber = ai.getAndIncrement();
                    if (numThreads == 1) {
                        PhaseCorrelation.normalizeComplexImage(fftImage1, threshold);
                        PhaseCorrelation.normalizeAndConjugateComplexImage(fftImage2, threshold);
                    } else if (myNumber == 0) {
                        PhaseCorrelation.normalizeComplexImage(fftImage1, threshold);
                    } else {
                        PhaseCorrelation.normalizeAndConjugateComplexImage(fftImage2, threshold);
                    }
                }
            });
        }
        SimpleMultiThreading.startAndJoin(threads);
    }

    /** Normalizes every complex value of {@code fftImage} to unit length (see normalizeLength). */
    private static final void normalizeComplexImage(Image<ComplexFloatType> fftImage, float normalizationThreshold) {
        final Cursor<ComplexFloatType> cursor = fftImage.createCursor();
        while (cursor.hasNext()) {
            cursor.fwd();
            PhaseCorrelation.normalizeLength(cursor.getType(), normalizationThreshold);
        }
        cursor.close();
    }

    /** Normalizes every complex value of {@code fftImage} to unit length and conjugates it. */
    private static final void normalizeAndConjugateComplexImage(Image<ComplexFloatType> fftImage, float normalizationThreshold) {
        final Cursor<ComplexFloatType> cursor = fftImage.createCursor();
        while (cursor.hasNext()) {
            cursor.fwd();
            PhaseCorrelation.normalizeLength(cursor.getType(), normalizationThreshold);
            cursor.getType().complexConjugate();
        }
        cursor.close();
    }

    /**
     * Scales {@code type} to length 1, or sets it to zero if its length is below
     * {@code threshold} (avoids amplifying near-zero frequencies).
     */
    private static void normalizeLength(ComplexFloatType type, float threshold) {
        final float real = type.getRealFloat();
        final float complex = type.getComplexFloat();
        final float length = (float)Math.sqrt(real * real + complex * complex);
        if (length < threshold) {
            type.setReal(0.0f);
            type.setComplex(0.0f);
        } else {
            type.setReal(real / length);
            type.setComplex(complex / length);
        }
    }

    /**
     * Runs both forward transforms, concurrently (each with half of the threads) if
     * {@link #getComputeFFTinParalell()} is set and at least two threads are available,
     * otherwise one after the other with all threads.
     *
     * @return {@code true} if both transforms succeeded
     */
    protected boolean computeFFT(final FourierTransform<T, ComplexFloatType> fft1, final FourierTransform<S, ComplexFloatType> fft2) {
        // Guard against a non-positive thread count, which would otherwise start no workers.
        final int totalThreads = Math.max(1, this.getNumThreads());
        final int maxConcurrentTransforms = this.computeFFTinParalell ? 2 : 1;
        final AtomicInteger ai = new AtomicInteger(0);
        final Thread[] threads = SimpleMultiThreading.newThreads(Math.min(maxConcurrentTransforms, totalThreads));
        final int numThreads = threads.length;
        final boolean[] success = new boolean[2];
        for (int ithread = 0; ithread < threads.length; ++ithread) {
            threads[ithread] = new Thread(new Runnable(){

                @Override
                public void run() {
                    final int myNumber = ai.getAndIncrement();
                    if (numThreads == 1) {
                        fft1.setNumThreads(totalThreads);
                        fft2.setNumThreads(totalThreads);
                        success[0] = fft1.process();
                        success[1] = fft2.process();
                    } else if (myNumber == 0) {
                        fft1.setNumThreads(totalThreads / 2);
                        success[0] = fft1.process();
                    } else {
                        fft2.setNumThreads(totalThreads / 2);
                        success[1] = fft2.process();
                    }
                }
            });
        }
        // startAndJoin joins all workers, so their writes to success[] are visible here.
        SimpleMultiThreading.startAndJoin(threads);
        return success[0] && success[1];
    }

    /** @return duration of the last successful {@link #process()} in milliseconds, or -1 */
    @Override
    public long getProcessingTime() {
        return this.processingTime;
    }

    /** Uses one thread per available processor. */
    @Override
    public void setNumThreads() {
        this.numThreads = Runtime.getRuntime().availableProcessors();
    }

    /** @param numThreads the number of threads to use; values below 1 are treated as 1 */
    @Override
    public void setNumThreads(int numThreads) {
        this.numThreads = numThreads;
    }

    /** @return the configured number of threads */
    @Override
    public int getNumThreads() {
        return this.numThreads;
    }

    /**
     * Validates the configuration. Once an error message has been recorded, this keeps returning
     * {@code false}.
     *
     * @return {@code true} if the input is usable
     */
    @Override
    public boolean checkInput() {
        if (this.errorMessage.length() > 0) {
            return false;
        }
        if (this.image1 == null || this.image2 == null) {
            this.errorMessage = "One of the input images is null";
            return false;
        }
        if (this.image1.getNumDimensions() != this.image2.getNumDimensions()) {
            this.errorMessage = "Dimensionality of images is not the same";
            return false;
        }
        if (this.numPeaks < 1) {
            this.errorMessage = "Number of peaks to investigate must be at least 1, but is " + this.numPeaks;
            return false;
        }
        if (this.minOverlapPx.length < this.numDimensions) {
            this.errorMessage = "Minimal pixel overlap must be given for all " + this.numDimensions + " dimensions";
            return false;
        }
        return true;
    }

    /** @return the last error message, or an empty string */
    @Override
    public String getErrorMessage() {
        return this.errorMessage;
    }
}

