/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.crypto.key.kms.server;

import java.io.IOException;
import java.io.PrintStream;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.crypto.key.KeyProvider;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
import org.apache.hadoop.util.ExitUtil;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.hadoop.util.KMSUtil;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KMSBenchmark
implements Tool {
    // Logger shared by the benchmark and all of its nested helper classes.
    private static final Logger LOG = LoggerFactory.getLogger(KMSBenchmark.class);
    // Usage fragment for the options common to every operation; the same text
    // is the last line of the message written by printUsage().
    private static final String GENERAL_OPTIONS_USAGE = "[-threads int] | [-numops int] | [{-warmup (true|false)}]";
    // Static, so it is shared by all instances: assigned by the constructor and
    // setConf(), returned by getConf().
    private static Configuration config;
    // Key provider every benchmark operation is issued against; built in the
    // constructor from the configuration.
    private KeyProviderCryptoExtension kp;
    // Encrypted key version produced by generateEncryptedKey(); it is the input
    // of the decrypt operation.
    private KeyProviderCryptoExtension.EncryptedKeyVersion eek = null;
    // Name of the encryption key the benchmark works with; replaced by the
    // value that follows "-createkey" on the command line.
    private String encryptionKeyName = "systest";
    // When true, the constructor creates the encryption key unless the
    // provider already lists a key with that name.
    private boolean createEncryptionKey = false;
    // Parsed from "-warmup"; when true, the constructor asks the provider to
    // warm up encrypted keys for encryptionKeyName.
    private boolean warmupKey = false;
    // Key names reported by kp.getKeys(); only refreshed on the create-key path.
    private List<String> keys = new ArrayList<String>();

    KMSBenchmark(Configuration conf, String[] args) throws IOException {
        config = conf;
        this.kp = KMSBenchmark.createKeyProviderCryptoExtension(config);
        try {
            this.eek = this.kp.generateEncryptedKey(this.encryptionKeyName);
        }
        catch (GeneralSecurityException e) {
            LOG.warn("failed to generate key", (Throwable)e);
        }
        for (int i = 2; i < args.length; ++i) {
            if (args[i].equals("-warmup")) {
                this.warmupKey = Boolean.parseBoolean(args[++i]);
                continue;
            }
            if (!args[i].equals("-createkey")) continue;
            this.encryptionKeyName = args[++i];
        }
        try {
            if (this.createEncryptionKey) {
                this.keys = this.kp.getKeys();
                if (!this.keys.contains(this.encryptionKeyName)) {
                    this.kp.createKey(this.encryptionKeyName, KeyProvider.options((Configuration)conf));
                } else {
                    LOG.warn("encryption key already exists: {}", (Object)this.encryptionKeyName);
                }
            }
            if (this.warmupKey) {
                this.kp.warmUpEncryptedKeys(new String[]{this.encryptionKeyName});
            }
        }
        catch (GeneralSecurityException e) {
            LOG.warn(" failed to create or warmup encryption key", (Throwable)e);
        }
    }

    /**
     * Writes the command line synopsis followed by the generic Hadoop options
     * to standard error, then calls {@code ExitUtil.terminate(-1)}.
     */
    static void printUsage() {
        PrintStream err = System.err;
        String synopsis = "Usage: KMSBenchmark"
            + "\n\t-op all <other ops options> | "
            + "\n\t-op encrypt [-threads T -numops N -warmup F] | "
            + "\n\t-op decrypt [-threads T -numops N -warmup F] | "
            + "\n\t[-threads int] | [-numops int] | [{-warmup (true|false)}]";
        err.println(synopsis);
        err.println();
        GenericOptionsParser.printGenericCommandUsage(err);
        ExitUtil.terminate(-1);
    }

    /**
     * Creates the key provider named by the
     * {@code hadoop.security.key.provider.path} setting and wraps it in a
     * {@link KeyProviderCryptoExtension}.
     *
     * @param conf configuration holding the provider path
     * @return the crypto extension wrapping the configured provider
     * @throws IOException if no provider could be created from {@code conf}
     */
    public static KeyProviderCryptoExtension createKeyProviderCryptoExtension(Configuration conf) throws IOException {
        KeyProvider provider = KMSUtil.createKeyProvider(conf, "hadoop.security.key.provider.path");
        if (provider != null) {
            return KeyProviderCryptoExtension.createKeyProviderCryptoExtension(provider);
        }
        throw new IOException("Key provider was not configured.");
    }

    /**
     * Builds a {@code KMSBenchmark} from the given configuration and arguments
     * and runs it through {@code ToolRunner}. A completion message is logged
     * whether or not construction or the run succeeds.
     *
     * @param conf configuration handed to the benchmark constructor
     * @param args command line arguments, passed to both the constructor and
     *        {@code ToolRunner}
     * @throws Exception whatever the constructor or the tool run throws
     */
    public static void runBenchmark(Configuration conf, String[] args) throws Exception {
        try {
            Tool tool = new KMSBenchmark(conf, args);
            ToolRunner.run(tool, args);
        }
        finally {
            LOG.info("runBenchmark finished.");
        }
    }

    /**
     * Runs the operation(s) selected by {@code -op <type>}.
     *
     * <p>{@code type} may be {@code encrypt}, {@code decrypt} or {@code all}
     * (encrypt followed by decrypt). Every selected operation is benchmarked
     * and cleaned up first; the results are printed afterwards, in the same
     * order. Usage is printed when the arguments do not start with
     * {@code -op <type>} or when the type matches no operation.
     *
     * @param aArgs command line arguments, starting with {@code -op <type>}
     * @return 0 when all selected benchmarks completed
     * @throws Exception any failure raised while building or running an
     *         operation; it is logged before being rethrown
     */
    public int run(String[] aArgs) throws Exception {
        List<String> args = new ArrayList<String>(Arrays.asList(aArgs));
        if (args.size() < 2 || !args.get(0).startsWith("-op")) {
            KMSBenchmark.printUsage();
        }
        String type = args.get(1);
        boolean runAll = "all".equals(type);
        // Typed on the common base class: the list holds encrypt AND decrypt
        // operations, which are siblings rather than subtypes of each other.
        List<OperationStatsBase> ops = new ArrayList<OperationStatsBase>();
        try {
            if (runAll || "encrypt".equals(type)) {
                ops.add(new EncryptKeyStats(args));
            }
            if (runAll || "decrypt".equals(type)) {
                ops.add(new DecryptKeyStats(args));
            }
            if (ops.isEmpty()) {
                KMSBenchmark.printUsage();
            }
            for (OperationStatsBase op : ops) {
                LOG.info("Starting benchmark: " + op.getOpName());
                op.benchmark();
                op.cleanUp();
            }
            for (OperationStatsBase op : ops) {
                LOG.info("");
                op.printResults();
            }
        }
        catch (Exception e) {
            LOG.error("failed to run benchmarks", e);
            throw e;
        }
        return 0;
    }

    /**
     * Command line entry point: runs the benchmark with a freshly created
     * {@code Configuration}.
     *
     * @param args command line arguments forwarded to {@code runBenchmark}
     * @throws Exception if the benchmark fails
     */
    public static void main(String[] args) throws Exception {
        Configuration defaults = new Configuration();
        runBenchmark(defaults, args);
    }

    /**
     * Stores the configuration in the static {@code config} field, so the
     * value is visible through every instance of this class.
     *
     * @param conf configuration to store
     */
    public void setConf(Configuration conf) {
        KMSBenchmark.config = conf;
    }

    /**
     * Returns the configuration held in the static {@code config} field.
     *
     * @return the most recently stored configuration
     */
    public Configuration getConf() {
        return KMSBenchmark.config;
    }

    /**
     * Benchmark operation that times {@code decryptEncryptedKey} calls on the
     * enclosing benchmark's current encrypted key version.
     */
    class DecryptKeyStats
    extends OperationStatsBase {
        static final String OP_DECRYPT_KEY = "decrypt";
        static final String OP_DECRYPT_USAGE = "-op decrypt [-threads T -numops N -warmup F]";

        /**
         * @param args full argument list, starting with {@code -op <type>}
         */
        DecryptKeyStats(List<String> args) {
            this.parseArguments(args);
        }

        @Override
        String getOpName() {
            return OP_DECRYPT_KEY;
        }

        /**
         * Reads {@code -threads <int>} and {@code -numops <int>} from the
         * arguments that follow {@code -op <type>}; other arguments are
         * ignored. Usage is printed when either option is the last argument
         * and therefore has no value.
         */
        @Override
        void parseArguments(List<String> args) {
            this.verifyOpArgument(args);
            for (int i = 2; i < args.size(); ++i) {
                String option = args.get(i);
                if (option.equals("-threads")) {
                    this.setNumThreads(Integer.parseInt(this.requireValue(args, i)));
                    ++i; // skip the value that was just consumed
                } else if (option.equals("-numops")) {
                    this.setNumOpsRequired(Integer.parseInt(this.requireValue(args, i)));
                    ++i; // skip the value that was just consumed
                }
            }
        }

        /**
         * Returns the argument following position {@code optionIdx}, printing
         * usage first when the option is the last argument.
         */
        private String requireValue(List<String> args, int optionIdx) {
            if (optionIdx + 1 >= args.size()) {
                KMSBenchmark.printUsage();
            }
            return args.get(optionIdx + 1);
        }

        @Override
        String getExecutionArgument(int daemonId) {
            return this.getClientName(daemonId);
        }

        /**
         * Performs one decrypt call and returns how long it took, measured
         * with {@code Time.now()}. A {@code GeneralSecurityException} is
         * logged and the call is still timed.
         */
        @Override
        long executeOp(int daemonId, int inputIdx, String clientName) throws IOException {
            long start = Time.now();
            try {
                KMSBenchmark.this.kp.decryptEncryptedKey(KMSBenchmark.this.eek);
            }
            catch (GeneralSecurityException e) {
                LOG.warn("failed to generate and/or decrypt key", e);
            }
            long end = Time.now();
            return end - start;
        }

        @Override
        void printResults() {
            LOG.info("--- {} inputs ---", this.getOpName());
            LOG.info("nrOps = {}", this.getNumOpsRequired());
            LOG.info("nrThreads = {}", this.getNumThreads());
            this.printStats();
        }
    }

    /**
     * Benchmark operation that times {@code generateEncryptedKey} calls for
     * the enclosing benchmark's encryption key. Each successful call replaces
     * the enclosing benchmark's current encrypted key version.
     */
    class EncryptKeyStats
    extends OperationStatsBase {
        static final String OP_ENCRYPT_KEY = "encrypt";
        static final String OP_ENCRYPT_USAGE = "-op encrypt [-threads T -numops N -warmup F]";

        /**
         * @param args full argument list, starting with {@code -op <type>}
         */
        EncryptKeyStats(List<String> args) {
            this.parseArguments(args);
        }

        @Override
        String getOpName() {
            return OP_ENCRYPT_KEY;
        }

        /**
         * Reads {@code -threads <int>} and {@code -numops <int>} from the
         * arguments that follow {@code -op <type>}; other arguments are
         * ignored. Usage is printed when either option is the last argument
         * and therefore has no value.
         */
        @Override
        void parseArguments(List<String> args) {
            this.verifyOpArgument(args);
            for (int i = 2; i < args.size(); ++i) {
                String option = args.get(i);
                if (option.equals("-threads")) {
                    this.setNumThreads(Integer.parseInt(this.requireValue(args, i)));
                    ++i; // skip the value that was just consumed
                } else if (option.equals("-numops")) {
                    this.setNumOpsRequired(Integer.parseInt(this.requireValue(args, i)));
                    ++i; // skip the value that was just consumed
                }
            }
        }

        /**
         * Returns the argument following position {@code optionIdx}, printing
         * usage first when the option is the last argument.
         */
        private String requireValue(List<String> args, int optionIdx) {
            if (optionIdx + 1 >= args.size()) {
                KMSBenchmark.printUsage();
            }
            return args.get(optionIdx + 1);
        }

        @Override
        String getExecutionArgument(int daemonId) {
            return this.getClientName(daemonId);
        }

        /**
         * Performs one generate-encrypted-key call and returns how long it
         * took, measured with {@code Time.now()}. A
         * {@code GeneralSecurityException} is logged and the call is still
         * timed.
         */
        @Override
        long executeOp(int daemonId, int inputIdx, String clientName) throws IOException {
            long start = Time.now();
            try {
                KMSBenchmark.this.eek = KMSBenchmark.this.kp.generateEncryptedKey(KMSBenchmark.this.encryptionKeyName);
            }
            catch (GeneralSecurityException e) {
                LOG.warn("failed to generate encrypted key", e);
            }
            long end = Time.now();
            return end - start;
        }

        @Override
        void printResults() {
            LOG.info("--- {} inputs ---", this.getOpName());
            LOG.info("nOps = {}", this.getNumOpsRequired());
            LOG.info("nThreads = {}", this.getNumThreads());
            this.printStats();
        }
    }

    private class StatsDaemon
    extends Thread {
        private final int daemonId;
        private int opsPerThread;
        private String arg1;
        private volatile int localNumOpsExecuted = 0;
        private volatile long localCumulativeTime = 0L;
        private final OperationStatsBase statsOp;

        StatsDaemon(int daemonId, int nOps, OperationStatsBase op) {
            this.daemonId = daemonId;
            this.opsPerThread = nOps;
            this.statsOp = op;
            this.setName(this.toString());
        }

        @Override
        public void run() {
            this.localNumOpsExecuted = 0;
            this.localCumulativeTime = 0L;
            this.arg1 = this.statsOp.getExecutionArgument(this.daemonId);
            try {
                this.benchmarkOne();
            }
            catch (IOException ex) {
                LOG.error("StatsDaemon " + this.daemonId + " failed: \n" + StringUtils.stringifyException((Throwable)ex));
            }
        }

        @Override
        public String toString() {
            return "StatsDaemon-" + this.daemonId;
        }

        void benchmarkOne() throws IOException {
            for (int idx = 0; idx < this.opsPerThread; ++idx) {
                long stat = this.statsOp.executeOp(this.daemonId, idx, this.arg1);
                ++this.localNumOpsExecuted;
                this.localCumulativeTime += stat;
            }
        }

        boolean isInProgress() {
            return this.localNumOpsExecuted < this.opsPerThread;
        }

        void terminate() {
            this.opsPerThread = this.localNumOpsExecuted;
        }
    }

    /**
     * Base class of a benchmark operation: holds the thread and operation
     * counts, splits the work across {@code StatsDaemon} threads, and
     * aggregates and prints the resulting statistics.
     */
    abstract class OperationStatsBase {
        protected static final String OP_ALL_NAME = "all";
        protected static final String OP_ALL_USAGE = "-op all <other ops options>";
        private int numThreads = 3;
        private int numOpsRequired = 10000;
        private int numOpsExecuted = 0;
        private long cumulativeTime = 0L;
        private long elapsedTime = 0L;
        private List<StatsDaemon> daemons;

        /** @return the operation name as used after {@code -op} */
        abstract String getOpName();

        /** Reads this operation's options from the full argument list. */
        abstract void parseArguments(List<String> args) throws IOException;

        /** @return the value passed as third argument to executeOp for the given worker */
        abstract String getExecutionArgument(int daemonId);

        /**
         * Executes the operation once.
         *
         * @param daemonId index of the calling worker
         * @param inputIdx index of the call within that worker
         * @param arg1 value obtained from getExecutionArgument(daemonId)
         * @return the duration of the call, added to the cumulative time
         */
        abstract long executeOp(int daemonId, int inputIdx, String arg1) throws IOException;

        /** Logs the inputs and statistics of this operation. */
        abstract void printResults();

        OperationStatsBase() {
        }

        /**
         * Runs the benchmark: distributes {@code numOpsRequired} operations
         * over {@code numThreads} workers, starts them, waits until none is in
         * progress any more, then records the elapsed time and sums up the
         * per-worker counters. Returns without starting anything when
         * {@code numThreads} is below 1.
         *
         * <p>The wait is not abandoned on interruption; the interrupt status
         * is restored once the wait is over.
         */
        void benchmark() throws IOException {
            this.daemons = new ArrayList<StatsDaemon>();
            // Provisional start, so elapsedTime stays meaningful when the
            // method leaves before any worker is started.
            long start = Time.now();
            try {
                this.numOpsExecuted = 0;
                this.cumulativeTime = 0L;
                if (this.numThreads < 1) {
                    return;
                }
                // Each worker gets an even share of what is still unscheduled,
                // but at least one operation; workers left over once
                // everything is scheduled keep the array default of zero.
                int[] opsPerThread = new int[this.numThreads];
                int tIdx = 0;
                int opsScheduled = 0;
                while (opsScheduled < this.numOpsRequired) {
                    int share = (this.numOpsRequired - opsScheduled) / (this.numThreads - tIdx);
                    if (share == 0) {
                        share = 1;
                    }
                    opsPerThread[tIdx++] = share;
                    opsScheduled += share;
                }
                for (tIdx = 0; tIdx < this.numThreads; ++tIdx) {
                    this.daemons.add(new StatsDaemon(tIdx, opsPerThread[tIdx], this));
                }
                start = Time.now();
                LOG.info("Starting " + this.numOpsRequired + " " + this.getOpName() + "(s).");
                for (StatsDaemon d : this.daemons) {
                    d.start();
                }
            }
            finally {
                boolean interrupted = false;
                while (this.isInProgress()) {
                    try {
                        Thread.sleep(500L);
                    }
                    catch (InterruptedException e) {
                        // Keep waiting for the workers, but remember the
                        // request so it is not lost to the caller.
                        interrupted = true;
                    }
                }
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
                this.elapsedTime = Time.now() - start;
                for (StatsDaemon d : this.daemons) {
                    this.incrementStats(d.localNumOpsExecuted, d.localCumulativeTime);
                    System.out.println(d.toString() + ": ops Exec = " + d.localNumOpsExecuted);
                }
            }
        }

        /** @return true if at least one worker still reports being in progress */
        private boolean isInProgress() {
            for (StatsDaemon d : this.daemons) {
                if (d.isInProgress()) {
                    return true;
                }
            }
            return false;
        }

        /** Hook called after benchmark(); does nothing in this class. */
        void cleanUp() throws IOException {
        }

        int getNumOpsExecuted() {
            return this.numOpsExecuted;
        }

        long getCumulativeTime() {
            return this.cumulativeTime;
        }

        long getElapsedTime() {
            return this.elapsedTime;
        }

        /**
         * Logs both inputs, then returns the cumulative time divided by the
         * number of executed operations (integer division), or 0 when nothing
         * was executed.
         */
        long getAverageTime() {
            LOG.info("getAverageTime, cumulativeTime = " + this.cumulativeTime);
            LOG.info("getAverageTime, numOpsExecuted = " + this.numOpsExecuted);
            return this.numOpsExecuted == 0 ? 0L : this.cumulativeTime / (long)this.numOpsExecuted;
        }

        /**
         * @return executed operations per 1000 units of elapsed time, or 0
         *         when the elapsed time is 0
         */
        double getOpsPerSecond() {
            return this.elapsedTime == 0L ? 0.0 : 1000.0 * (double)this.numOpsExecuted / (double)this.elapsedTime;
        }

        String getClientName(int idx) {
            return this.getOpName() + "-client-" + idx;
        }

        /** Adds one worker's counters to the totals of this operation. */
        void incrementStats(int ops, long time) {
            this.numOpsExecuted += ops;
            this.cumulativeTime += time;
        }

        int getNumThreads() {
            return this.numThreads;
        }

        void setNumThreads(int num) {
            this.numThreads = num;
        }

        int getNumOpsRequired() {
            return this.numOpsRequired;
        }

        void setNumOpsRequired(int num) {
            this.numOpsRequired = num;
        }

        /**
         * Checks that the arguments start with {@code -op <type>} and that
         * {@code type} is either {@code all} or this operation's name; usage
         * is printed otherwise.
         *
         * @param args full argument list
         * @return true when the type is {@code all}, false when it names this
         *         operation
         */
        protected boolean verifyOpArgument(List<String> args) {
            if (args.size() < 2 || !args.get(0).startsWith("-op")) {
                KMSBenchmark.printUsage();
            }
            String type = args.get(1);
            if (OP_ALL_NAME.equals(type)) {
                return true;
            }
            if (!this.getOpName().equals(type)) {
                KMSBenchmark.printUsage();
            }
            return false;
        }

        /** Logs operation count, elapsed time, throughput and average time. */
        void printStats() {
            LOG.info("--- " + this.getOpName() + " stats  ---");
            LOG.info("# operations: " + this.getNumOpsExecuted());
            LOG.info("Elapsed Time: " + this.getElapsedTime());
            LOG.info(" Ops per sec: " + this.getOpsPerSecond());
            LOG.info("Average Time: " + this.getAverageTime());
        }
    }
}

