/*
 * Decompiled with CFR 0.152.
 */
package edu.mines.jtk.util;

import edu.mines.jtk.util.ArrayMath;
import edu.mines.jtk.util.Parallel;
import edu.mines.jtk.util.Stopwatch;
import java.util.Random;
import junit.framework.Assert;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

public class ParallelTest
extends TestCase {
    /**
     * Command-line entry point.
     * <p>
     * When the first argument is {@code "bench"}, the benchmarks are run
     * first; they run in parallel mode unless the second argument is
     * {@code "serial"}. The JUnit suite for this class is then run in
     * every case, with or without the benchmarks.
     *
     * @param args command-line arguments, possibly empty.
     */
    public static void main(String[] args) {
        final boolean benchRequested = args.length > 0 && args[0].equals("bench");
        if (benchRequested) {
            final boolean serialRequested = args.length > 1 && args[1].equals("serial");
            bench(!serialRequested);
        }
        TestRunner.run((Test) new TestSuite(ParallelTest.class));
    }

    /**
     * Runs the randomized serial-versus-parallel comparison 1000 times.
     */
    public void testRandom() {
        int remaining = 1000;
        while (remaining-- > 0) {
            oneRandomTest();
        }
    }

    /**
     * Performs one randomized comparison of serial and parallel results.
     * <p>
     * Picks a random array length in [100,199], a random non-empty index
     * range [begin,end) within it, a step in [1,6] and a chunk in [1,4].
     * Squares must match exactly; sums must agree to within 1.0E-4 times
     * the larger of the two sums.
     */
    private void oneRandomTest() {
        final Random rng = new Random();
        final int n = 100 + rng.nextInt(100);
        final int begin = rng.nextInt(n);
        final int end = begin + 1 + rng.nextInt(n - begin);
        final int step = 1 + rng.nextInt(6);
        final int chunk = 1 + rng.nextInt(4);

        final float[] input = ArrayMath.randfloat(n);
        final float[] serialSquares = ArrayMath.zerofloat(n);
        final float[] parallelSquares = ArrayMath.zerofloat(n);
        sqrS(begin, end, step, input, serialSquares);
        sqrP(begin, end, step, chunk, input, parallelSquares);
        assertEquals(serialSquares, parallelSquares, 0.0f);

        final float serialSum = sumS(begin, end, step, input);
        final float parallelSum = sumP(begin, end, step, chunk, input);
        final float tolerance = (float) (1.0E-4f * ArrayMath.max(serialSum, parallelSum));
        assertEquals(serialSum, parallelSum, tolerance);
    }

    /**
     * Serial reference: stores a[i]*a[i] in b[i] for
     * i = begin, begin+step, ... while i &lt; end.
     *
     * @param begin first index.
     * @param end exclusive upper bound on the index.
     * @param step index increment.
     * @param a input array.
     * @param b output array; only the visited indices are written.
     */
    private void sqrS(int begin, int end, int step, float[] a, float[] b) {
        int idx = begin;
        while (idx < end) {
            final float v = a[idx];
            b[idx] = v * v;
            idx += step;
        }
    }

    /**
     * Parallel counterpart of the strided serial square: passes the
     * index range, step and chunk to {@code Parallel.loop} with a body
     * that stores a[i]*a[i] in b[i].
     *
     * @param begin first index.
     * @param end end of the index range, forwarded to Parallel.loop.
     * @param step index increment, forwarded to Parallel.loop.
     * @param chunk chunk size, forwarded to Parallel.loop.
     * @param a input array.
     * @param b output array.
     */
    private void sqrP(int begin, int end, int step, int chunk, final float[] a, final float[] b) {
        final Parallel.LoopInt squareOne = new Parallel.LoopInt() {
            @Override
            public void compute(int i) {
                final float v = a[i];
                b[i] = v * v;
            }
        };
        Parallel.loop(begin, end, step, chunk, squareOne);
    }

    /**
     * Serial reference: sums a[i] for i = begin, begin+step, ...
     * while i &lt; end, accumulating in ascending index order.
     *
     * @param begin first index.
     * @param end exclusive upper bound on the index.
     * @param step index increment.
     * @param a input array.
     * @return the sum of the visited elements; 0 if none are visited.
     */
    private float sumS(int begin, int end, int step, float[] a) {
        float total = 0.0f;
        int idx = begin;
        while (idx < end) {
            total += a[idx];
            idx += step;
        }
        return total;
    }

    /**
     * Parallel counterpart of the strided serial sum: passes the index
     * range, step and chunk to {@code Parallel.reduce} with a reducer
     * that maps index i to a[i] and combines two partial results by
     * addition.
     *
     * @param begin first index.
     * @param end end of the index range, forwarded to Parallel.reduce.
     * @param step index increment, forwarded to Parallel.reduce.
     * @param chunk chunk size, forwarded to Parallel.reduce.
     * @param a input array.
     * @return the reduced value as a primitive float.
     */
    private float sumP(int begin, int end, int step, int chunk, final float[] a) {
        final Parallel.ReduceInt<Float> adder = new Parallel.ReduceInt<Float>() {
            @Override
            public Float compute(int i) {
                return Float.valueOf(a[i]);
            }

            @Override
            public Float combine(Float s1, Float s2) {
                final float merged = s1.floatValue() + s2.floatValue();
                return Float.valueOf(merged);
            }
        };
        final Float total = Parallel.reduce(begin, end, step, chunk, adder);
        return total.floatValue();
    }

    /**
     * Exercises {@code Parallel.Unsafe} from inside a parallel loop of 20
     * iterations.
     * <p>
     * Each iteration fetches a Worker from the shared Unsafe holder,
     * creating and storing one when {@code get()} returns null, and then
     * calls {@code work()}. Worker.work() asserts that the worker is not
     * already in use, so the test fails if one worker is handed to two
     * overlapping iterations.
     * <p>
     * The holder is declared with its element type so that no raw type
     * or unchecked downcast is needed.
     */
    public void testUnsafe() {
        final Parallel.Unsafe<Worker> nts = new Parallel.Unsafe<Worker>();
        Parallel.loop(20, new Parallel.LoopInt() {
            @Override
            public void compute(int i) {
                Worker w = nts.get();
                if (w == null) {
                    w = new Worker();
                    nts.set(w);
                }
                w.work();
            }
        });
    }

    /**
     * Asserts that two float arrays have the same length and that
     * corresponding elements are equal within a tolerance.
     * <p>
     * The length check catches extra elements in the actual array and
     * avoids an ArrayIndexOutOfBoundsException when it is too short.
     * Each failure message names the offending index.
     *
     * @param e expected values.
     * @param a actual values.
     * @param t tolerance passed to the scalar float assertion.
     */
    private static void assertEquals(float[] e, float[] a, float t) {
        final int n = e.length;
        ParallelTest.assertEquals("array length", n, a.length);
        for (int i = 0; i < n; ++i) {
            ParallelTest.assertEquals("element " + i, e[i], a[i], t);
        }
    }

    /**
     * Prints a line of text to standard output.
     *
     * @param s the text to print.
     */
    private static void trace(String s) {
        final java.io.PrintStream out = System.out;
        out.println(s);
    }

    /**
     * Benchmarks element-wise squaring of 2D and 3D arrays, serial
     * versus parallel.
     * <p>
     * Three passes are made. In each pass every variant (2D serial,
     * 2D parallel, 3D serial, 3D parallel) is repeated until the
     * stopwatch reports at least {@code maxtime} elapsed, and an integer
     * rate of iterations * mflop / elapsed time is printed. The 2D
     * variants operate on the first slice of the 3D array.
     */
    private static void benchArraySqr() {
        final int n1 = 501;
        final int n2 = 502;
        final int n3 = 503;
        System.out.println("Array sqr: n1=" + n1 + " n2=" + n2 + " n3=" + n3);
        final double maxtime = 5.0;
        final double mflop2 = 1.0E-6 * (double) n1 * (double) n2;
        final double mflop3 = 1.0E-6 * (double) n1 * (double) n2 * (double) n3;
        final Stopwatch timer = new Stopwatch();
        final float[][][] input = ArrayMath.sub(ArrayMath.randfloat(n1, n2, n3), 0.5f);
        final float[][][] serialOut = ArrayMath.copy(input);
        final float[][][] parallelOut = ArrayMath.copy(input);
        for (int pass = 0; pass < 3; ++pass) {
            int count;
            int rate;

            // 2D, serial
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                sqrS(input[0], serialOut[0]);
            }
            timer.stop();
            rate = (int) ((double) count * mflop2 / timer.time());
            System.out.println("2D S: rate = " + rate);

            // 2D, parallel
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                sqrP(input[0], parallelOut[0]);
            }
            timer.stop();
            rate = (int) ((double) count * mflop2 / timer.time());
            System.out.println("2D P: rate = " + rate);

            // 3D, serial
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                sqrS(input, serialOut);
            }
            timer.stop();
            rate = (int) ((double) count * mflop3 / timer.time());
            System.out.println("3D S: rate = " + rate);

            // 3D, parallel
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                sqrP(input, parallelOut);
            }
            timer.stop();
            rate = (int) ((double) count * mflop3 / timer.time());
            System.out.println("3D P: rate = " + rate);
        }
    }

    /**
     * Stores the square of every element of a in the corresponding
     * element of b.
     *
     * @param a input array.
     * @param b output array; must be at least as long as a.
     */
    private static void sqr(float[] a, float[] b) {
        final int count = a.length;
        int k = 0;
        while (k < count) {
            final float v = a[k];
            b[k] = v * v;
            ++k;
        }
    }

    /**
     * Serially squares a 2D array, one row at a time in ascending
     * row order.
     *
     * @param a input array.
     * @param b output array.
     */
    private static void sqrS(float[][] a, float[][] b) {
        for (int row = 0, rows = a.length; row < rows; ++row) {
            sqr(a[row], b[row]);
        }
    }

    /**
     * Serially squares a 3D array, one 2D slice at a time in ascending
     * slice order.
     *
     * @param a input array.
     * @param b output array.
     */
    private static void sqrS(float[][][] a, float[][][] b) {
        for (int slice = 0, slices = a.length; slice < slices; ++slice) {
            sqrS(a[slice], b[slice]);
        }
    }

    /**
     * Squares a 2D array using {@code Parallel.loop} over the rows.
     * <p>
     * The chunk size passed to the loop is 10000 divided by the length
     * of the first row, but never less than 1. An empty array returns
     * immediately, and an empty first row is treated as length 1 for the
     * chunk computation. Both cases would otherwise throw
     * (ArrayIndexOutOfBoundsException or ArithmeticException), whereas
     * the serial sibling handles them.
     *
     * @param a input array.
     * @param b output array.
     */
    private static void sqrP(final float[][] a, final float[][] b) {
        int n = a.length;
        if (n == 0) {
            return; // nothing to square, and a[0] does not exist
        }
        // Guard the divisor so a zero-length first row cannot divide by zero.
        int chunk = ArrayMath.max(1, 10000 / ArrayMath.max(1, a[0].length));
        Parallel.loop(0, n, 1, chunk, new Parallel.LoopInt() {
            @Override
            public void compute(int i) {
                ParallelTest.sqr(a[i], b[i]);
            }
        });
    }

    /**
     * Squares a 3D array using {@code Parallel.loop} over the 2D slices.
     * Each slice is in turn squared with the parallel 2D method, so the
     * parallel loops are nested.
     *
     * @param a input array.
     * @param b output array.
     */
    private static void sqrP(final float[][][] a, final float[][][] b) {
        final Parallel.LoopInt squareSlice = new Parallel.LoopInt() {
            @Override
            public void compute(int i) {
                sqrP(a[i], b[i]);
            }
        };
        Parallel.loop(a.length, squareSlice);
    }

    /**
     * Benchmarks summation of 2D and 3D arrays, serial versus parallel.
     * <p>
     * Three passes are made. In each pass every variant (2D serial,
     * 2D parallel, 3D serial, 3D parallel) is repeated until the
     * stopwatch reports at least {@code maxtime} elapsed, and an integer
     * rate of iterations * mflop / elapsed time is printed. The 2D
     * variants operate on the first slice of the 3D array. The computed
     * sums are stored in locals but not otherwise reported.
     */
    private static void benchArraySum() {
        final int n1 = 501;
        final int n2 = 502;
        final int n3 = 503;
        System.out.println("Array sum: n1=" + n1 + " n2=" + n2 + " n3=" + n3);
        final double maxtime = 5.0;
        final double mflop2 = 1.0E-6 * (double) n1 * (double) n2;
        final double mflop3 = 1.0E-6 * (double) n1 * (double) n2 * (double) n3;
        final Stopwatch timer = new Stopwatch();
        final float[][][] input = ArrayMath.sub(ArrayMath.randfloat(n1, n2, n3), 0.5f);
        for (int pass = 0; pass < 3; ++pass) {
            float serialSum = 0.0f;
            float parallelSum = 0.0f;
            int count;
            int rate;

            // 2D, serial
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                serialSum = sumS(input[0]);
            }
            timer.stop();
            rate = (int) ((double) count * mflop2 / timer.time());
            System.out.println("2D S: rate = " + rate);

            // 2D, parallel
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                parallelSum = sumP(input[0]);
            }
            timer.stop();
            rate = (int) ((double) count * mflop2 / timer.time());
            System.out.println("2D P: rate = " + rate);

            serialSum = 0.0f;
            parallelSum = 0.0f;

            // 3D, serial
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                serialSum = sumS(input);
            }
            timer.stop();
            rate = (int) ((double) count * mflop3 / timer.time());
            System.out.println("3D S: rate = " + rate);

            // 3D, parallel
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                parallelSum = sumP(input);
            }
            timer.stop();
            rate = (int) ((double) count * mflop3 / timer.time());
            System.out.println("3D P: rate = " + rate);
        }
    }

    /**
     * Sums the elements of an array in ascending index order using a
     * single-precision accumulator.
     *
     * @param a input array.
     * @return the sum; 0 for an empty array.
     */
    private static float sum(float[] a) {
        float total = 0.0f;
        for (final float v : a) {
            total += v;
        }
        return total;
    }

    /**
     * Serially sums a 2D array by adding up the per-row sums in
     * ascending row order.
     *
     * @param a input array.
     * @return the sum; 0 for an empty array.
     */
    private static float sumS(float[][] a) {
        float total = 0.0f;
        for (final float[] row : a) {
            total += sum(row);
        }
        return total;
    }

    /**
     * Serially sums a 3D array by adding up the per-slice sums in
     * ascending slice order.
     *
     * @param a input array.
     * @return the sum; 0 for an empty array.
     */
    private static float sumS(float[][][] a) {
        float total = 0.0f;
        for (final float[][] slice : a) {
            total += sumS(slice);
        }
        return total;
    }

    /**
     * Sums a 2D array using {@code Parallel.reduce} over the rows; each
     * row is summed serially and partial results are combined by
     * addition.
     * <p>
     * The chunk size passed to the reduction is 10000 divided by the
     * length of the first row, but never less than 1. An empty array
     * yields 0, matching the serial sibling, and an empty first row is
     * treated as length 1 for the chunk computation. Both cases would
     * otherwise throw (ArrayIndexOutOfBoundsException or
     * ArithmeticException).
     *
     * @param a input array.
     * @return the sum; 0 for an empty array.
     */
    private static float sumP(final float[][] a) {
        int n = a.length;
        if (n == 0) {
            return 0.0f; // same result as the serial sum of an empty array
        }
        // Guard the divisor so a zero-length first row cannot divide by zero.
        int chunk = ArrayMath.max(1, 10000 / ArrayMath.max(1, a[0].length));
        return Parallel.reduce(0, n, 1, chunk, new Parallel.ReduceInt<Float>() {
            @Override
            public Float compute(int i) {
                return Float.valueOf(ParallelTest.sum(a[i]));
            }

            @Override
            public Float combine(Float s1, Float s2) {
                return Float.valueOf(s1.floatValue() + s2.floatValue());
            }
        }).floatValue();
    }

    /**
     * Sums a 3D array using {@code Parallel.reduce} over the 2D slices.
     * Each slice is summed with the parallel 2D method, so the
     * reductions are nested; partial results are combined by addition.
     *
     * @param a input array.
     * @return the reduced value as a primitive float.
     */
    private static float sumP(final float[][][] a) {
        final Parallel.ReduceInt<Float> sliceAdder = new Parallel.ReduceInt<Float>() {
            @Override
            public Float compute(int i) {
                final float sliceSum = sumP(a[i]);
                return Float.valueOf(sliceSum);
            }

            @Override
            public Float combine(Float s1, Float s2) {
                final float merged = s1.floatValue() + s2.floatValue();
                return Float.valueOf(merged);
            }
        };
        final Float total = Parallel.reduce(a.length, sliceAdder);
        return total.floatValue();
    }

    /**
     * Benchmarks serial versus parallel matrix multiplication.
     * <p>
     * Three passes are made. In each pass the serial and then the
     * parallel multiply is repeated until the stopwatch reports at least
     * {@code maxtime} elapsed, and an integer rate of
     * iterations * mflop / elapsed time is printed, where mflop is
     * 2.0E-6 * m * m * n.
     */
    private static void benchMatrixMultiply() {
        final int m = 1001;
        final int n = 1002;
        System.out.println("Matrix multiply for m=" + m + " n=" + n);
        final float[][] a = ArrayMath.randfloat(n, m);
        final float[][] b = ArrayMath.randfloat(m, n);
        final float[][] serialProduct = ArrayMath.zerofloat(m, m);
        final float[][] parallelProduct = ArrayMath.zerofloat(m, m);
        final double maxtime = 5.0;
        final double mflop = 2.0E-6 * (double) m * (double) m * (double) n;
        final Stopwatch timer = new Stopwatch();
        for (int pass = 0; pass < 3; ++pass) {
            int count;
            int rate;

            // serial
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                matrixMultiplySerial(a, b, serialProduct);
            }
            timer.stop();
            rate = (int) ((double) count * mflop / timer.time());
            System.out.println("S: rate = " + rate + " mflops");

            // parallel
            timer.restart();
            for (count = 0; timer.time() < maxtime; ++count) {
                matrixMultiplyParallel(a, b, parallelProduct);
            }
            timer.stop();
            rate = (int) ((double) count * mflop / timer.time());
            System.out.println("P: rate = " + rate + " mflops");
        }
    }

    /**
     * Computes c = a*b serially, one column of c at a time in ascending
     * column order. The column count is taken from the first row of c.
     *
     * @param a left factor.
     * @param b right factor.
     * @param c output matrix.
     */
    private static void matrixMultiplySerial(float[][] a, float[][] b, float[][] c) {
        final int columns = c[0].length;
        int col = 0;
        while (col < columns) {
            computeColumn(col, a, b, c);
            ++col;
        }
    }

    /**
     * Computes c = a*b using {@code Parallel.loop} over the columns of
     * c; each loop index computes one column. The column count is taken
     * from the first row of c.
     *
     * @param a left factor.
     * @param b right factor.
     * @param c output matrix.
     */
    private static void matrixMultiplyParallel(final float[][] a, final float[][] b, final float[][] c) {
        final int columns = c[0].length;
        final Parallel.LoopInt columnTask = new Parallel.LoopInt() {
            @Override
            public void compute(int j) {
                computeColumn(j, a, b, c);
            }
        };
        Parallel.loop(columns, columnTask);
    }

    /**
     * Computes column j of the product c = a*b.
     * <p>
     * Column j of b is first copied into a contiguous scratch array.
     * Then, for every row i of c, the dot product of a[i] with that
     * column is accumulated in ascending k order and stored in c[i][j].
     * The dot product is unrolled by four: the first (nk % 4) terms are
     * handled one at a time and the remainder in groups of four.
     *
     * @param j index of the column to compute.
     * @param a left factor; each row must have at least b.length elements.
     * @param b right factor.
     * @param c output matrix; only column j is written.
     */
    private static void computeColumn(int j, float[][] a, float[][] b, float[][] c) {
        final int rows = c.length;
        final int depth = b.length;

        // Gather column j of b so the inner loops read contiguous memory.
        final float[] column = new float[depth];
        for (int k = 0; k < depth; ++k) {
            column[k] = b[k][j];
        }

        final int head = depth % 4;
        for (int row = 0; row < rows; ++row) {
            final float[] aRow = a[row];
            float dot = 0.0f;
            int k = 0;
            // Leftover terms that do not fill a group of four.
            while (k < head) {
                dot += aRow[k] * column[k];
                ++k;
            }
            // Remaining terms, four per iteration.
            while (k < depth) {
                dot += aRow[k] * column[k];
                dot += aRow[k + 1] * column[k + 1];
                dot += aRow[k + 2] * column[k + 2];
                dot += aRow[k + 3] * column[k + 3];
                k += 4;
            }
            c[row][j] = dot;
        }
    }

    /**
     * Returns the largest absolute difference between corresponding
     * elements of two arrays.
     *
     * @param a first array; its length determines how many elements
     *          are compared.
     * @param b second array.
     * @return the maximum of |b[i]-a[i]|; 0 if a is empty.
     */
    private static float emax(float[] a, float[] b) {
        float worst = 0.0f;
        for (int k = 0, count = a.length; k < count; ++k) {
            final float diff = ArrayMath.abs(b[k] - a[k]);
            worst = ArrayMath.max(worst, diff);
        }
        return worst;
    }

    /**
     * Returns the largest absolute difference between corresponding
     * elements of two 2D arrays.
     * <p>
     * The result is the maximum over all rows of the per-row maximum
     * error. The running maximum is kept across rows, so an error in an
     * early row is not lost when a later row has a smaller one.
     *
     * @param a first array; its length determines how many rows are
     *          compared.
     * @param b second array.
     * @return the maximum error over all rows; 0 if a is empty.
     */
    private static float emax(float[][] a, float[][] b) {
        int n = a.length;
        float worst = 0.0f;
        for (int i = 0; i < n; ++i) {
            // Accumulate, do not overwrite: the answer is the max over rows.
            worst = ArrayMath.max(worst, ParallelTest.emax(a[i], b[i]));
        }
        return worst;
    }

    /**
     * Returns the largest absolute difference between corresponding
     * elements of two 3D arrays.
     * <p>
     * The result is the maximum over all 2D slices of the per-slice
     * error. The running maximum is kept across slices, so an error in
     * an early slice is not lost when a later slice has a smaller one.
     *
     * @param a first array; its length determines how many slices are
     *          compared.
     * @param b second array.
     * @return the maximum error over all slices; 0 if a is empty.
     */
    private static float emax(float[][][] a, float[][][] b) {
        int n = a.length;
        float worst = 0.0f;
        for (int i = 0; i < n; ++i) {
            // Accumulate, do not overwrite: the answer is the max over slices.
            worst = ArrayMath.max(worst, ParallelTest.emax(a[i], b[i]));
        }
        return worst;
    }

    /**
     * Runs all benchmarks: array square, array sum, then matrix
     * multiply. The flag is passed to {@code Parallel.setParallel}
     * before any benchmark starts.
     *
     * @param parallel value forwarded to Parallel.setParallel.
     */
    private static void bench(boolean parallel) {
        Parallel.setParallel(parallel);
        benchArraySqr();
        benchArraySum();
        benchMatrixMultiply();
    }

    /**
     * A deliberately non-thread-safe worker used by testUnsafe.
     * <p>
     * {@link #work()} asserts that the worker is not already in use,
     * marks it busy, sleeps for 10 ms, and marks it idle again. If two
     * overlapping calls ever share one instance, the assertion in the
     * second call is expected to fail.
     */
    private static class Worker {
        /** True while a call to work() is in progress on this instance. */
        private boolean _working;

        private Worker() {
        }

        /**
         * Simulates a short unit of work on this instance.
         *
         * @throws RuntimeException wrapping the InterruptedException if
         *         the sleep is interrupted; the thread's interrupt status
         *         is restored first.
         */
        public void work() {
            Assert.assertTrue(!this._working);
            this._working = true;
            try {
                Thread.sleep(10L);
            } catch (InterruptedException e) {
                // Restore the flag so code further up the stack can see the interrupt.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } finally {
                // Always release, so a failed call does not leave the worker marked busy.
                this._working = false;
            }
        }
    }
}

