/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.coloc.algorithms;

import ij.IJ;
import java.util.Arrays;
import net.imglib2.Cursor;
import net.imglib2.PairIterator;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.TwinCursor;
import net.imglib2.type.logic.BitType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.view.Views;
import sc.fiji.coloc.algorithms.Algorithm;
import sc.fiji.coloc.algorithms.IntArraySorter;
import sc.fiji.coloc.algorithms.IntComparator;
import sc.fiji.coloc.algorithms.MissingPreconditionException;
import sc.fiji.coloc.gadgets.DataContainer;
import sc.fiji.coloc.results.ResultHandler;

public class KendallTauRankCorrelation<T extends RealType<T>>
extends Algorithm<T> {
    private double tau;

    public KendallTauRankCorrelation() {
        super("Kendall's Tau-b Rank Correlation");
    }

    @Override
    public void execute(DataContainer<T> container) throws MissingPreconditionException {
        RandomAccessibleInterval<T> img1 = container.getSourceImage1();
        RandomAccessibleInterval<T> img2 = container.getSourceImage2();
        RandomAccessibleInterval<BitType> mask = container.getMask();
        TwinCursor cursor = new TwinCursor(img1.randomAccess(), img2.randomAccess(), (Cursor<BitType>)Views.iterable(mask).localizingCursor());
        this.tau = KendallTauRankCorrelation.calculateMergeSort(cursor);
    }

    public static <T extends RealType<T>> double calculateNaive(PairIterator<T> iterator) {
        if (!iterator.hasNext()) {
            return Double.NaN;
        }
        int n = 0;
        int max1 = 0;
        int max2 = 0;
        int max = 255;
        int[][] histogram = new int[max + 1][max + 1];
        while (iterator.hasNext()) {
            iterator.fwd();
            RealType type1 = (RealType)iterator.getFirst();
            RealType type2 = (RealType)iterator.getSecond();
            double ch1 = type1.getRealDouble();
            double ch2 = type2.getRealDouble();
            if (ch1 < 0.0 || ch2 < 0.0 || ch1 > (double)max || ch2 > (double)max) {
                IJ.log((String)"Error: The current Kendall Tau implementation is limited to 8-bit data");
                return Double.NaN;
            }
            ++n;
            int ch1Int = (int)Math.round(ch1);
            int ch2Int = (int)Math.round(ch2);
            int[] nArray = histogram[ch1Int];
            int n2 = ch2Int;
            nArray[n2] = nArray[n2] + 1;
            if (max1 < ch1Int) {
                max1 = ch1Int;
            }
            if (max2 >= ch2Int) continue;
            max2 = ch2Int;
        }
        long n0 = (long)n * (long)(n - 1) / 2L;
        long n1 = 0L;
        long n2 = 0L;
        long nc = 0L;
        long nd = 0L;
        for (int i1 = 0; i1 <= max1; ++i1) {
            IJ.log((String)("" + i1 + "/" + max1));
            int ch1 = 0;
            for (int i2 = 0; i2 <= max2; ++i2) {
                int j2;
                int j1;
                ch1 += histogram[i1][i2];
                int count = histogram[i1][i2];
                for (j1 = 0; j1 < i1; ++j1) {
                    for (j2 = 0; j2 < i2; ++j2) {
                        nc += (long)(count * histogram[j1][j2]);
                    }
                }
                for (j1 = 0; j1 < i1; ++j1) {
                    for (j2 = i2 + 1; j2 <= max2; ++j2) {
                        nd += (long)(count * histogram[j1][j2]);
                    }
                }
            }
            n1 += (long)ch1 * (long)(ch1 - 1) / 2L;
        }
        for (int i2 = 0; i2 <= max2; ++i2) {
            int ch2 = 0;
            for (int i1 = 0; i1 <= max1; ++i1) {
                ch2 += histogram[i1][i2];
            }
            n2 += (long)ch2 * (long)(ch2 - 1) / 2L;
        }
        return (double)(nc - nd) / Math.sqrt((double)(n0 - n1) * (double)(n0 - n2));
    }

    private static <T extends RealType<T>> double[][] getPairs(PairIterator<T> iterator) {
        int capacity = 0;
        while (iterator.hasNext()) {
            iterator.fwd();
            ++capacity;
        }
        double[] values1 = new double[capacity];
        double[] values2 = new double[capacity];
        iterator.reset();
        int count = 0;
        while (iterator.hasNext()) {
            iterator.fwd();
            values1[count] = ((RealType)iterator.getFirst()).getRealDouble();
            values2[count] = ((RealType)iterator.getSecond()).getRealDouble();
            ++count;
        }
        if (count < capacity) {
            values1 = Arrays.copyOf(values1, count);
            values2 = Arrays.copyOf(values2, count);
        }
        return new double[][]{values1, values2};
    }

    public static <T extends RealType<T>> double calculateMergeSort(PairIterator<T> iterator) {
        double[][] pairs = KendallTauRankCorrelation.getPairs(iterator);
        final double[] x = pairs[0];
        final double[] y = pairs[1];
        int n = x.length;
        int[] index = new int[n];
        for (int i = 0; i < n; ++i) {
            index[i] = i;
        }
        IntArraySorter.sort(index, new IntComparator(){

            @Override
            public int compare(int a, int b) {
                double xa = x[a];
                double ya = y[a];
                double xb = x[b];
                double yb = y[b];
                int result = Double.compare(xa, xb);
                return result != 0 ? result : Double.compare(ya, yb);
            }
        });
        long n0 = (long)n * (long)(n - 1) / 2L;
        long n1 = 0L;
        long n3 = 0L;
        for (int i = 1; i < n; ++i) {
            double x0 = x[index[i - 1]];
            if (x[index[i]] != x0) continue;
            double y0 = y[index[i - 1]];
            int i1 = i;
            do {
                double y1;
                if ((y1 = y[index[i1++]]) == y0) {
                    int i2 = i1;
                    while (i1 < n && x[index[i1]] == x0 && y[index[i1]] == y0) {
                        ++i1;
                    }
                    n3 += (long)(i1 - i2 + 2) * (long)(i1 - i2 + 1) / 2L;
                }
                y0 = y1;
            } while (i1 < n && x[index[i1]] == x0);
            n1 += (long)(i1 - i + 1) * (long)(i1 - i) / 2L;
            i = i1;
        }
        MergeSort mergeSort = new MergeSort(index, new IntComparator(){

            @Override
            public int compare(int a, int b) {
                double ya = y[a];
                double yb = y[b];
                return Double.compare(ya, yb);
            }
        });
        long S = mergeSort.sort();
        index = mergeSort.getSorted();
        long n2 = 0L;
        for (int i = 1; i < n; ++i) {
            int i1;
            double y0 = y[index[i - 1]];
            if (y[index[i]] != y0) continue;
            for (i1 = i + 1; i1 < n && y[index[i1]] == y0; ++i1) {
            }
            n2 += (long)(i1 - i + 1) * (long)(i1 - i) / 2L;
            i = i1;
        }
        return (double)(n0 - n1 - n2 + n3 - 2L * S) / Math.sqrt((double)(n0 - n1) * (double)(n0 - n2));
    }

    @Override
    public void processResults(ResultHandler<T> handler) {
        super.processResults(handler);
        handler.handleValue("Kendall's Tau-b rank correlation value", this.tau, 4);
    }

    private static final class MergeSort {
        private int[] index;
        private final IntComparator comparator;

        public MergeSort(int[] index, IntComparator comparator) {
            this.index = index;
            this.comparator = comparator;
        }

        public int[] getSorted() {
            return this.index;
        }

        public long sort() {
            long swaps = 0L;
            int n = this.index.length;
            int[] index2 = new int[n];
            for (int step = 1; step < n; step <<= 1) {
                int begin = 0;
                int k = 0;
                while (true) {
                    int begin2;
                    int end;
                    if ((end = (begin2 = begin + step) + step) >= n) {
                        if (begin2 >= n) break;
                        end = n;
                    }
                    int i = begin;
                    int j = begin2;
                    while (i < begin2 && j < end) {
                        int compare = this.comparator.compare(this.index[i], this.index[j]);
                        if (compare > 0) {
                            swaps += (long)(begin2 - i);
                            index2[k++] = this.index[j++];
                            continue;
                        }
                        index2[k++] = this.index[i++];
                    }
                    if (i < begin2) {
                        do {
                            index2[k++] = this.index[i++];
                        } while (i < begin2);
                    } else {
                        while (j < end) {
                            index2[k++] = this.index[j++];
                        }
                    }
                    begin = end;
                }
                if (k < n) {
                    System.arraycopy(this.index, k, index2, k, n - k);
                }
                int[] swapIndex = index2;
                index2 = this.index;
                this.index = swapIndex;
            }
            return swaps;
        }
    }
}

