/*
 * Decompiled with CFR 0.152.
 */
package edu.mines.jtk.bench;

import edu.mines.jtk.util.ArrayMath;
import edu.mines.jtk.util.Check;
import edu.mines.jtk.util.Parallel;
import edu.mines.jtk.util.Stopwatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Benchmarks one single-threaded and four multi-threaded implementations
 * of the matrix product C = A*B.
 * <p>
 * Every implementation computes C one column at a time through the same
 * {@code computeColumn} routine, so all of them perform the floating-point
 * operations for a given element in the same order. That is what allows
 * the results to be compared for exact equality after each trial.
 */
public class MtMatMulBench {
    /** Number of worker threads used by the multi-threaded methods. */
    public static final int NTHREAD = Runtime.getRuntime().availableProcessors();

    /** Number of multiplication methods benchmarked (mul1 through mul5). */
    private static final int NMETHOD = 5;

    /**
     * Runs three trials; in each trial every method is run repeatedly for
     * about five seconds, its rate is printed, and the products of the
     * multi-threaded methods are checked against the single-threaded one.
     *
     * @param args ignored.
     */
    public static void main(String[] args) {
        int m = 1001;
        int n = 1002;
        float[][] a = ArrayMath.randfloat(n, m);
        float[][] b = ArrayMath.randfloat(m, n);
        // One output matrix per method, so results can be compared later.
        float[][][] results = new float[NMETHOD][][];
        for (int imethod = 0; imethod < NMETHOD; ++imethod) {
            results[imethod] = ArrayMath.zerofloat(m, m);
        }
        Stopwatch s = new Stopwatch();
        // Millions of floating-point operations in one multiply: each of
        // the m*m output elements costs n multiplies and n adds.
        double mflops = 2.0E-6 * (double)m * (double)m * (double)n;
        double maxtime = 5.0;
        System.out.println("Matrix multiply benchmark");
        System.out.println("m=" + m + " n=" + n + " nthread=" + NTHREAD);
        System.out.println("Methods:");
        System.out.println("mul1 = single-threaded");
        System.out.println("mul2 = multi-threaded (equal chunks)");
        System.out.println("mul3 = multi-threaded (atomic-integer)");
        System.out.println("mul4 = multi-threaded (thread-pool)");
        System.out.println("mul5 = multi-threaded (fork-join)");
        for (int ntrial = 0; ntrial < 3; ++ntrial) {
            System.out.println();
            for (int imethod = 0; imethod < NMETHOD; ++imethod) {
                int method = imethod + 1;
                int rate = measureRate(method, a, b, results[imethod], s, maxtime, mflops);
                System.out.println("mul" + method + ": rate=" + rate + " mflops");
            }
            // The single-threaded product is the reference for the others.
            for (int imethod = 1; imethod < NMETHOD; ++imethod) {
                assertEquals(results[0], results[imethod]);
            }
        }
    }

    /**
     * Repeats one multiplication method until at least {@code maxtime}
     * has elapsed on the stopwatch and returns the measured rate.
     *
     * @param method  method number, 1 through 5.
     * @param a       left operand.
     * @param b       right operand.
     * @param c       output matrix, overwritten by every repetition.
     * @param s       stopwatch used for timing; it is restarted here.
     * @param maxtime minimum time to keep repeating, in stopwatch units.
     * @param mflops  millions of floating-point operations per multiply.
     * @return the rate in mflops, truncated to an int.
     */
    private static int measureRate(
            int method, float[][] a, float[][] b, float[][] c,
            Stopwatch s, double maxtime, double mflops) {
        s.restart();
        int nmul = 0;
        while (s.time() < maxtime) {
            multiply(method, a, b, c);
            ++nmul;
        }
        s.stop();
        return (int)((double)nmul * mflops / s.time());
    }

    /**
     * Dispatches to mul1 ... mul5 by method number.
     *
     * @throws IllegalArgumentException if the method number is unknown.
     */
    private static void multiply(int method, float[][] a, float[][] b, float[][] c) {
        switch (method) {
            case 1:
                mul1(a, b, c);
                break;
            case 2:
                mul2(a, b, c);
                break;
            case 3:
                mul3(a, b, c);
                break;
            case 4:
                mul4(a, b, c);
                break;
            case 5:
                mul5(a, b, c);
                break;
            default:
                throw new IllegalArgumentException("unknown method: " + method);
        }
    }

    /**
     * Computes column j of the product: c[i][j] = sum over k of
     * a[i][k]*b[k][j], for every row i of c.
     *
     * @param j index of the column to compute.
     * @param a left operand; each row must hold at least b.length values.
     * @param b right operand.
     * @param c output matrix; only column j is written.
     */
    private static void computeColumn(int j, float[][] a, float[][] b, float[][] c) {
        int ni = c.length;
        int nk = b.length;
        // Gather column j of b into a contiguous array so the inner
        // product below reads two plain 1-D arrays.
        float[] bj = new float[nk];
        for (int k = 0; k < nk; ++k) {
            bj[k] = b[k][j];
        }
        for (int i = 0; i < ni; ++i) {
            int k;
            float[] ai = a[i];
            float cij = 0.0f;
            // Inner product unrolled by four. The nk % 4 leftover terms
            // are summed first so the unrolled loop ends exactly at nk.
            int mk = nk % 4;
            for (k = 0; k < mk; ++k) {
                cij += ai[k] * bj[k];
            }
            for (k = mk; k < nk; k += 4) {
                cij += ai[k] * bj[k];
                cij += ai[k + 1] * bj[k + 1];
                cij += ai[k + 2] * bj[k + 2];
                cij += ai[k + 3] * bj[k + 3];
            }
            c[i][j] = cij;
        }
    }

    /** Single-threaded multiply: computes the columns of c in order. */
    private static void mul1(float[][] a, float[][] b, float[][] c) {
        checkDimensions(a, b, c);
        int nj = c[0].length;
        for (int j = 0; j < nj; ++j) {
            computeColumn(j, a, b, c);
        }
    }

    /**
     * Multi-threaded multiply in which each of NTHREAD threads is given
     * one contiguous chunk of columns, fixed before the threads start.
     */
    private static void mul2(final float[][] a, final float[][] b, final float[][] c) {
        checkDimensions(a, b, c);
        int nj = c[0].length;
        // Ceiling division: the smallest chunk size for which NTHREAD
        // chunks cover all nj columns. (1 + nj / NTHREAD would leave
        // threads without work whenever nj is a multiple of NTHREAD.)
        int mj = (nj + NTHREAD - 1) / NTHREAD;
        Thread[] threads = new Thread[NTHREAD];
        for (int ithread = 0; ithread < NTHREAD; ++ithread) {
            final int jfirst = ithread * mj;
            // Trailing chunks may be short, or empty when jfirst >= nj.
            final int jlast = Math.min(jfirst + mj, nj);
            threads[ithread] = new Thread(new Runnable() {
                @Override
                public void run() {
                    for (int j = jfirst; j < jlast; ++j) {
                        MtMatMulBench.computeColumn(j, a, b, c);
                    }
                }
            });
        }
        startAndJoin(threads);
    }

    /**
     * Multi-threaded multiply in which NTHREAD threads claim column
     * indices one at a time from a shared atomic counter.
     */
    private static void mul3(final float[][] a, final float[][] b, final float[][] c) {
        checkDimensions(a, b, c);
        final int nj = c[0].length;
        final AtomicInteger aj = new AtomicInteger();
        Thread[] threads = new Thread[NTHREAD];
        for (int ithread = 0; ithread < threads.length; ++ithread) {
            threads[ithread] = new Thread(new Runnable() {
                @Override
                public void run() {
                    // getAndIncrement hands each index to exactly one thread.
                    for (int j = aj.getAndIncrement(); j < nj; j = aj.getAndIncrement()) {
                        MtMatMulBench.computeColumn(j, a, b, c);
                    }
                }
            });
        }
        startAndJoin(threads);
    }

    /**
     * Multi-threaded multiply that submits one task per column to a
     * fixed pool of NTHREAD threads and waits for all tasks to finish.
     * A new pool is created and shut down on every call.
     *
     * @throws RuntimeException if the calling thread is interrupted while
     *  waiting, or if any column task fails.
     */
    private static void mul4(final float[][] a, final float[][] b, final float[][] c) {
        checkDimensions(a, b, c);
        int nj = c[0].length;
        ExecutorService es = Executors.newFixedThreadPool(NTHREAD);
        try {
            ExecutorCompletionService<Void> cs = new ExecutorCompletionService<Void>(es);
            for (int j = 0; j < nj; ++j) {
                final int jj = j;
                cs.submit(new Runnable() {
                    @Override
                    public void run() {
                        MtMatMulBench.computeColumn(jj, a, b, c);
                    }
                }, null);
            }
            for (int j = 0; j < nj; ++j) {
                // get() rethrows any exception raised by the task, so a
                // failed column is reported instead of silently ignored.
                cs.take().get();
            }
        } catch (InterruptedException ie) {
            // Restore the interrupt status for code further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(ie);
        } catch (ExecutionException ee) {
            throw new RuntimeException(ee.getCause());
        } finally {
            // Always release the pool threads. On success every task has
            // already completed; on failure this discards queued tasks.
            es.shutdownNow();
        }
    }

    /** Multi-threaded multiply that runs one column per Parallel.loop index. */
    private static void mul5(final float[][] a, final float[][] b, final float[][] c) {
        checkDimensions(a, b, c);
        int nj = c[0].length;
        Parallel.loop(nj, new Parallel.LoopInt() {
            @Override
            public void compute(int j) {
                MtMatMulBench.computeColumn(j, a, b, c);
            }
        });
    }

    /**
     * Starts every thread, then waits for each of them to finish.
     *
     * @throws RuntimeException if the calling thread is interrupted while
     *  waiting; its interrupt status is restored first.
     */
    private static void startAndJoin(Thread[] threads) {
        for (Thread thread : threads) {
            thread.start();
        }
        try {
            for (Thread thread : threads) {
                thread.join();
            }
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(ie);
        }
    }

    /**
     * Checks that two matrices have the same dimensions and exactly equal
     * elements, reporting any violation through Check.state.
     */
    private static void assertEquals(float[][] a, float[][] b) {
        Check.state(a.length == b.length, "same dimensions");
        // The first index runs over a.length rows and the second over the
        // length of that row, so non-square matrices are handled too.
        for (int i = 0; i < a.length; ++i) {
            float[] ai = a[i];
            float[] bi = b[i];
            Check.state(ai.length == bi.length, "same dimensions");
            for (int j = 0; j < ai.length; ++j) {
                Check.state(ai[j] == bi[j], "same elements");
            }
        }
    }

    /**
     * Checks that the dimensions of a, b and c are consistent with
     * c = a*b, reporting any violation through Check.argument. Column
     * counts are taken from the first row of each matrix.
     */
    private static void checkDimensions(float[][] a, float[][] b, float[][] c) {
        int ma = a.length;
        int na = a[0].length;
        int mb = b.length;
        int nb = b[0].length;
        int mc = c.length;
        int nc = c[0].length;
        Check.argument(na == mb, "number of columns in A = number of rows in B");
        Check.argument(ma == mc, "number of rows in A = number of rows in C");
        Check.argument(nb == nc, "number of columns in B = number of columns in C");
    }
}

