/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DistributedRowMatrixWriter;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.Sampler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * In-memory topic/term count model (package {@code lda.cvb}).
 *
 * <p>The model is a {@code numTopics x numTerms} matrix of per-topic term counts together with a
 * vector holding one sum per topic. Bulk updates submitted through {@link #update(Matrix)} are
 * applied asynchronously by a fixed set of {@code Updater} workers, each running on its own thread
 * of an internal pool; call {@link #stop()} to flush pending updates and shut the pool down.
 *
 * <p>Iterating over the model yields the rows of the topic/term count matrix.
 */
public class TopicModel
implements Configurable,
Iterable<MatrixSlice> {
    private static final Logger log = LoggerFactory.getLogger(TopicModel.class);
    /** Maximum number of entries rendered by {@link #vectorToSortedString(Vector, String[])}. */
    private static final int MAX_TERMS_IN_STRING = 25;
    private final String[] dictionary;
    private final Matrix topicTermCounts;
    private final Vector topicSums;
    private final int numTopics;
    private final int numTerms;
    private final double eta;
    private final double alpha;
    private Configuration conf;
    private final Sampler sampler;
    private final int numThreads;
    private ThreadPoolExecutor threadPool;
    private Updater[] updaters;

    /** @return the number of columns of the topic/term count matrix */
    public int getNumTerms() {
        return this.numTerms;
    }

    /** @return the number of topics (the size of the topic sums vector) */
    public int getNumTopics() {
        return this.numTopics;
    }

    /**
     * Creates a single-threaded model by delegating to the {@link Random}-accepting constructor
     * with a {@code null} random source (all-zero counts, topic sums of 1.0).
     */
    public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary, double modelWeight) {
        this(numTopics, numTerms, eta, alpha, null, dictionary, 1, modelWeight);
    }

    /**
     * Creates a model from the rows stored at the given paths.
     *
     * @throws IOException if the paths cannot be read or contain no vectors
     * @see #loadModel(Configuration, Path...)
     */
    public TopicModel(Configuration conf, double eta, double alpha, String[] dictionary, int numThreads, double modelWeight, Path ... modelpath) throws IOException {
        this(TopicModel.loadModel(conf, modelpath), eta, alpha, dictionary, numThreads, modelWeight);
    }

    /** Creates a model backed by a fresh dense count matrix and a fresh dense topic sums vector. */
    public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary, int numThreads, double modelWeight) {
        this(new DenseMatrix(numTopics, numTerms), new DenseVector(numTopics), eta, alpha, dictionary, numThreads, modelWeight);
    }

    /**
     * Creates a model whose counts are drawn from {@code random}; when {@code random} is
     * {@code null} the counts stay at zero and every topic sum is set to 1.0.
     */
    public TopicModel(int numTopics, int numTerms, double eta, double alpha, Random random, String[] dictionary, int numThreads, double modelWeight) {
        this(TopicModel.randomMatrix(numTopics, numTerms, random), eta, alpha, dictionary, numThreads, modelWeight);
    }

    private TopicModel(Pair<Matrix, Vector> model, double eta, double alpha, String[] dict, int numThreads, double modelWeight) {
        this(model.getFirst(), model.getSecond(), eta, alpha, dict, numThreads, modelWeight);
    }

    /** Creates a single-threaded model over the supplied counts and topic sums. */
    public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha, String[] dictionary, double modelWeight) {
        this(topicTermCounts, topicSums, eta, alpha, dictionary, 1, modelWeight);
    }

    /** Creates a model over the supplied counts, deriving each topic sum as the L1 norm of its row. */
    public TopicModel(Matrix topicTermCounts, double eta, double alpha, String[] dictionary, int numThreads, double modelWeight) {
        this(topicTermCounts, TopicModel.viewRowSums(topicTermCounts), eta, alpha, dictionary, numThreads, modelWeight);
    }

    /**
     * Primary constructor; every other constructor delegates here.
     *
     * <p>The supplied matrix and vector are stored by reference (not copied) and, when
     * {@code modelWeight != 1.0}, are scaled in place by {@code modelWeight}. The updater thread
     * pool is started before the constructor returns.
     *
     * @param topicTermCounts per-topic term counts, one row per topic
     * @param topicSums       one sum per topic; its size defines the number of topics
     * @param eta             value added to term counts when computing probabilities
     * @param alpha           value added to topic weights when computing probabilities
     * @param dictionary      optional term strings indexed by term id; may be {@code null}
     * @param numThreads      number of updater threads; must be at least 1
     * @param modelWeight     factor applied to the counts and sums
     * @throws IllegalArgumentException if {@code numThreads < 1}
     */
    public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha, String[] dictionary, int numThreads, double modelWeight) {
        // Fail fast, before the caller's matrix/vector are mutated below. Without this check the
        // same exception type surfaces later from the queue/pool constructors with a vaguer message.
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be at least 1, got " + numThreads);
        }
        this.dictionary = dictionary;
        this.topicTermCounts = topicTermCounts;
        this.topicSums = topicSums;
        this.numTopics = topicSums.size();
        this.numTerms = topicTermCounts.numCols();
        this.eta = eta;
        this.alpha = alpha;
        this.sampler = new Sampler(RandomUtils.getRandom());
        this.numThreads = numThreads;
        if (modelWeight != 1.0) {
            topicSums.assign(Functions.mult(modelWeight));
            for (int topic = 0; topic < this.numTopics; ++topic) {
                topicTermCounts.viewRow(topic).assign(Functions.mult(modelWeight));
            }
        }
        this.initializeThreadPool();
    }

    /** Builds a dense vector holding the L1 norm of every row of {@code m}. */
    private static Vector viewRowSums(Matrix m) {
        Vector sums = new DenseVector(m.numRows());
        for (MatrixSlice slice : m) {
            sums.set(slice.index(), slice.vector().norm(1.0));
        }
        return sums;
    }

    /**
     * (Re)creates the thread pool and one {@code Updater} per thread. An existing pool is shut
     * down first and given up to 100 seconds to terminate.
     */
    private synchronized void initializeThreadPool() {
        if (this.threadPool != null) {
            this.threadPool.shutdown();
            try {
                if (!this.threadPool.awaitTermination(100L, TimeUnit.SECONDS)) {
                    log.warn("Could not terminate all threads for TopicModel in time.");
                }
            }
            catch (InterruptedException e) {
                log.error("Could not terminate all threads for TopicModel in time.", e);
                // Preserve the interrupt for the caller instead of swallowing it.
                Thread.currentThread().interrupt();
            }
        }
        this.threadPool = new ThreadPoolExecutor(this.numThreads, this.numThreads, 0L, TimeUnit.SECONDS, new ArrayBlockingQueue<Runnable>(this.numThreads * 10));
        this.threadPool.allowCoreThreadTimeOut(false);
        this.updaters = new Updater[this.numThreads];
        for (int i = 0; i < this.numThreads; ++i) {
            this.updaters[i] = new Updater();
            this.threadPool.submit(this.updaters[i]);
        }
    }

    /** @return the live (not copied) topic/term count matrix */
    Matrix topicTermCounts() {
        return this.topicTermCounts;
    }

    /** @return an iterator over all rows of the topic/term count matrix */
    @Override
    public Iterator<MatrixSlice> iterator() {
        return this.topicTermCounts.iterateAll();
    }

    /** @return the live (not copied) vector of per-topic sums */
    public Vector topicSums() {
        return this.topicSums;
    }

    /**
     * Builds an initial dense model. With a non-null {@code random} every cell is set to
     * {@code random.nextDouble()} and each topic sum is the L1 norm of its row; with a
     * {@code null} random the counts stay zero and each topic sum is 1.0.
     */
    private static Pair<Matrix, Vector> randomMatrix(int numTopics, int numTerms, Random random) {
        Matrix topicTermCounts = new DenseMatrix(numTopics, numTerms);
        Vector topicSums = new DenseVector(numTopics);
        if (random != null) {
            for (int topic = 0; topic < numTopics; ++topic) {
                for (int term = 0; term < numTerms; ++term) {
                    topicTermCounts.viewRow(topic).set(term, random.nextDouble());
                }
            }
        }
        for (int topic = 0; topic < numTopics; ++topic) {
            topicSums.set(topic, random == null ? 1.0 : topicTermCounts.viewRow(topic).norm(1.0));
        }
        return Pair.of(topicTermCounts, topicSums);
    }

    /**
     * Reads (row index, vector) pairs from the given sequence files and assembles them into a
     * dense matrix plus a vector of per-row L1 norms.
     *
     * <p>The number of rows is the largest row index seen plus one; the number of columns is the
     * size of the first vector read.
     *
     * @param conf       configuration used to open the files
     * @param modelPaths files to read
     * @return the assembled matrix and its row sums
     * @throws IOException if reading fails or no vectors are found
     */
    public static Pair<Matrix, Vector> loadModel(Configuration conf, Path ... modelPaths) throws IOException {
        int maxTopicIndex = -1;
        int numTerms = -1;
        ArrayList<Pair<Integer, Vector>> rows = new ArrayList<Pair<Integer, Vector>>();
        for (Path path : modelPaths) {
            for (Pair<IntWritable, VectorWritable> row : new SequenceFileIterable<IntWritable, VectorWritable>(path, true, conf)) {
                int topic = row.getFirst().get();
                Vector counts = row.getSecond().get();
                rows.add(Pair.of(topic, counts));
                maxTopicIndex = Math.max(maxTopicIndex, topic);
                if (numTerms < 0) {
                    numTerms = counts.size();
                }
            }
        }
        if (rows.isEmpty()) {
            throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it");
        }
        int numTopics = maxTopicIndex + 1;
        Matrix model = new DenseMatrix(numTopics, numTerms);
        Vector topicSums = new DenseVector(numTopics);
        for (Pair<Integer, Vector> row : rows) {
            int topic = row.getFirst().intValue();
            model.viewRow(topic).assign(row.getSecond());
            topicSums.set(topic, row.getSecond().norm(1.0));
        }
        return Pair.of(model, topicSums);
    }

    /**
     * Renders one line per topic: with a dictionary, the L1-normalized row via
     * {@link #vectorToSortedString(Vector, String[])}; otherwise the row's format string.
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        for (int topic = 0; topic < this.numTopics; ++topic) {
            Vector row = this.topicTermCounts.viewRow(topic);
            String rendered = this.dictionary != null
                ? TopicModel.vectorToSortedString(row.normalize(1.0), this.dictionary)
                : row.asFormatString();
            buf.append(rendered).append('\n');
        }
        return buf.toString();
    }

    /** Samples a topic from {@code topicDistribution}, then samples a term from that topic's row. */
    public int sampleTerm(Vector topicDistribution) {
        return this.sampler.sample(this.topicTermCounts.viewRow(this.sampler.sample(topicDistribution)));
    }

    /** Samples a term from the count row of the given topic. */
    public int sampleTerm(int topic) {
        return this.sampler.sample(this.topicTermCounts.viewRow(topic));
    }

    /**
     * Replaces every count row with an empty sparse vector, sets every topic sum to 1.0, and
     * restarts the updater pool if it has terminated.
     */
    public synchronized void reset() {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.topicTermCounts.assignRow(topic, new SequentialAccessSparseVector(this.numTerms));
        }
        this.topicSums.assign(1.0);
        if (this.threadPool.isTerminated()) {
            this.initializeThreadPool();
        }
    }

    /**
     * Shuts down every updater (waiting for each to finish its queued work) and then the thread
     * pool, waiting up to 60 seconds for the pool to terminate. If the calling thread is
     * interrupted, the interrupt status is restored before returning.
     */
    public synchronized void stop() {
        for (Updater updater : this.updaters) {
            updater.shutdown();
        }
        this.threadPool.shutdown();
        try {
            if (!this.threadPool.awaitTermination(60L, TimeUnit.SECONDS)) {
                log.warn("Threadpool timed out on await termination - jobs still running!");
            }
        }
        catch (InterruptedException e) {
            log.error("Interrupted shutting down!", e);
            Thread.currentThread().interrupt();
        }
    }

    /** L1-normalizes every count row and sets every topic sum to 1.0. */
    public void renormalize() {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.topicTermCounts.assignRow(topic, this.topicTermCounts.viewRow(topic).normalize(1.0));
        }
        // Loop-invariant: assigning once is equivalent to assigning on every iteration.
        this.topicSums.assign(1.0);
    }

    /**
     * Fills {@code docTopicModel} for one document and recomputes {@code topics} from it.
     *
     * <p>Steps: compute the unnormalized per-topic likelihood of each term in the document,
     * normalize each term's column across topics, scale each entry by the term's weight in the
     * document, then set {@code topics[x]} to the L1 norm of row {@code x} and L1-normalize
     * {@code topics}.
     *
     * @param original      the document's term weights
     * @param topics        current topic weights for the document; overwritten with the new ones
     * @param docTopicModel output matrix (topics x terms); written in place
     */
    public void trainDocTopicModel(Vector original, Vector topics, Matrix docTopicModel) {
        this.pTopicGivenTerm(original, topics, docTopicModel);
        this.normalizeByTopic(docTopicModel);
        for (Vector.Element e : original.nonZeroes()) {
            int term = e.index();
            for (int topic = 0; topic < this.numTopics; ++topic) {
                Vector docTopicModelRow = docTopicModel.viewRow(topic);
                docTopicModelRow.setQuick(term, docTopicModelRow.getQuick(term) * e.get());
            }
        }
        topics.assign(0.0);
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topics.set(topic, docTopicModel.viewRow(topic).norm(1.0));
        }
        topics.assign(Functions.mult(1.0 / topics.norm(1.0)));
    }

    /**
     * For every non-zero term of {@code original}, computes
     * {@code sum_x count[x][term] / topicSum[x] * docTopics[x]}.
     *
     * @return a vector "like" {@code original} holding that value for each of its non-zero terms
     */
    public Vector infer(Vector original, Vector docTopics) {
        Vector pTerm = original.like();
        for (Vector.Element e : original.nonZeroes()) {
            int term = e.index();
            double probability = 0.0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                probability += this.topicTermCounts.viewRow(topic).get(term) / this.topicSums.get(topic) * docTopics.get(topic);
            }
            pTerm.set(term, probability);
        }
        return pTerm;
    }

    /**
     * Queues one update per topic row of {@code docTopicCounts}. Topic {@code x} always goes to
     * updater {@code x % numThreads}, so a given topic is only ever applied by one worker.
     *
     * @throws IllegalStateException if the target updater has been shut down
     */
    public void update(Matrix docTopicCounts) {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.updaters[topic % this.updaters.length].update(topic, docTopicCounts.viewRow(topic));
        }
    }

    /** Adds {@code docTopicCounts} to the topic's count row and its L1 norm to the topic's sum. */
    public void updateTopic(int topic, Vector docTopicCounts) {
        this.topicTermCounts.viewRow(topic).assign(docTopicCounts, Functions.PLUS);
        this.topicSums.set(topic, this.topicSums.get(topic) + docTopicCounts.norm(1.0));
    }

    /** Adds {@code topicCounts[x]} to the count of {@code termId} in every topic and to the topic sums. */
    public void update(int termId, Vector topicCounts) {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            Vector row = this.topicTermCounts.viewRow(topic);
            row.set(termId, row.get(termId) + topicCounts.get(topic));
        }
        this.topicSums.assign(topicCounts, Functions.PLUS);
    }

    /**
     * Writes the topic/term count matrix to {@code outputDir} using the configuration set via
     * {@link #setConf(Configuration)}.
     *
     * @param overwrite when {@code true}, {@code outputDir} is recursively deleted first
     * @throws IOException if the file system operations fail
     */
    public void persist(Path outputDir, boolean overwrite) throws IOException {
        FileSystem fs = outputDir.getFileSystem(this.conf);
        if (overwrite) {
            fs.delete(outputDir, true);
        }
        DistributedRowMatrixWriter.write(outputDir, this.conf, this.topicTermCounts);
    }

    /**
     * For every topic {@code x} and every non-zero term {@code a} of {@code document}, writes
     * {@code (count[x][a] + eta) * (docTopics[x] + alpha) / (topicSum[x] + eta * numTerms)} into
     * {@code termTopicDist[x][a]}. A {@code null} {@code docTopics} is treated as weight 1.0 for
     * every topic.
     */
    private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopicDist) {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double topicWeight = docTopics == null ? 1.0 : docTopics.get(topic);
            Vector topicTermRow = this.topicTermCounts.viewRow(topic);
            double topicSum = this.topicSums.get(topic);
            Vector termTopicRow = termTopicDist.viewRow(topic);
            for (Vector.Element e : document.nonZeroes()) {
                int termIndex = e.index();
                double termTopicLikelihood = (topicTermRow.get(termIndex) + this.eta) * (topicWeight + this.alpha) / (topicSum + this.eta * (double)this.numTerms);
                termTopicRow.set(termIndex, termTopicLikelihood);
            }
        }
    }

    /**
     * Returns the negated sum, over the document's non-zero terms, of
     * {@code weight(term) * log(p(term))}, where {@code p(term)} mixes the smoothed per-topic term
     * probabilities using the smoothed, normalized {@code docTopics} weights.
     */
    public double perplexity(Vector document, Vector docTopics) {
        double logLikelihood = 0.0;
        double norm = docTopics.norm(1.0) + (double)docTopics.size() * this.alpha;
        for (Vector.Element e : document.nonZeroes()) {
            int term = e.index();
            double prob = 0.0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                double topicWeight = (docTopics.get(topic) + this.alpha) / norm;
                prob += topicWeight * (this.topicTermCounts.viewRow(topic).get(term) + this.eta) / (this.topicSums.get(topic) + this.eta * (double)this.numTerms);
            }
            logLikelihood += e.get() * Math.log(prob);
        }
        return -logLikelihood;
    }

    /**
     * For every column that is non-zero in row 0, divides that column's entry in each topic row by
     * the column's sum across all topic rows. Modifies the matrix in place.
     */
    private void normalizeByTopic(Matrix perTopicSparseDistributions) {
        for (Vector.Element e : perTopicSparseDistributions.viewRow(0).nonZeroes()) {
            int term = e.index();
            double sum = 0.0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                sum += perTopicSparseDistributions.viewRow(topic).get(term);
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                perTopicSparseDistributions.viewRow(topic).set(term, perTopicSparseDistributions.viewRow(topic).get(term) / sum);
            }
        }
    }

    /**
     * Renders the non-zero entries of {@code vector} as <code>{name:value,name:value,...}</code>,
     * sorted by descending value and truncated to the first 25 entries.
     *
     * @param vector     the vector to render
     * @param dictionary term strings indexed by element index; when {@code null} the numeric
     *                   index is used as the name
     * @return the rendered string; <code>{}</code> when the vector has no non-zero entries
     */
    public static String vectorToSortedString(Vector vector, String[] dictionary) {
        ArrayList<Pair<String, Double>> vectorValues = new ArrayList<Pair<String, Double>>(vector.getNumNondefaultElements());
        for (Vector.Element e : vector.nonZeroes()) {
            vectorValues.add(Pair.of(dictionary != null ? dictionary[e.index()] : String.valueOf(e.index()), e.get()));
        }
        Collections.sort(vectorValues, new Comparator<Pair<String, Double>>(){

            @Override
            public int compare(Pair<String, Double> x, Pair<String, Double> y) {
                // Reversed operands: largest value first.
                return y.getSecond().compareTo(x.getSecond());
            }
        });
        StringBuilder bldr = new StringBuilder(2048);
        bldr.append('{');
        int limit = Math.min(vectorValues.size(), MAX_TERMS_IN_STRING);
        for (int i = 0; i < limit; ++i) {
            Pair<String, Double> entry = vectorValues.get(i);
            bldr.append(entry.getFirst());
            bldr.append(':');
            bldr.append(entry.getSecond());
            bldr.append(',');
        }
        if (bldr.length() > 1) {
            // Turn the trailing comma into the closing brace.
            bldr.setCharAt(bldr.length() - 1, '}');
        } else {
            // No entries were written: still emit a balanced "{}".
            bldr.append('}');
        }
        return bldr.toString();
    }

    @Override
    public void setConf(Configuration configuration) {
        this.conf = configuration;
    }

    @Override
    public Configuration getConf() {
        return this.conf;
    }

    /**
     * Worker that applies queued (topic, counts) updates to the enclosing model on a pool thread.
     * Producers enqueue via {@link #update(int, Vector)}; {@link #shutdown()} asks the worker to
     * stop and waits until it has drained its queue.
     */
    private final class Updater
    implements Runnable {
        private final ArrayBlockingQueue<Pair<Integer, Vector>> queue = new ArrayBlockingQueue<Pair<Integer, Vector>>(100);
        // Both flags are written and read by different threads without a common lock on every
        // path (run() and update() read them unsynchronized), hence volatile.
        private volatile boolean shutdown = false;
        private volatile boolean shutdownComplete = false;

        private Updater() {
        }

        /**
         * Requests shutdown and blocks until {@link #run()} has signalled completion, re-checking
         * every 10 seconds. Returns early, with the interrupt status restored, if the calling
         * thread is interrupted.
         */
        public void shutdown() {
            try {
                synchronized (this) {
                    while (!this.shutdownComplete) {
                        this.shutdown = true;
                        this.wait(10000L);
                    }
                }
            }
            catch (InterruptedException e) {
                log.warn("Interrupted waiting to shutdown() : ", e);
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Enqueues an update, blocking while the queue is full. Interrupts do not abort the
         * enqueue; the interrupt status is restored once the element has been queued.
         *
         * @return always {@code true}
         * @throws IllegalStateException if this updater has been shut down
         */
        public boolean update(int topic, Vector v) {
            if (this.shutdown) {
                throw new IllegalStateException("In SHUTDOWN state: cannot submit tasks");
            }
            boolean interrupted = false;
            try {
                while (true) {
                    try {
                        this.queue.put(Pair.of(topic, v));
                        return true;
                    }
                    catch (InterruptedException e) {
                        // Re-interrupting here would make the next put() fail immediately, so
                        // remember the interrupt and restore it after the retry loop.
                        interrupted = true;
                        log.warn("Interrupted trying to queue update:", e);
                    }
                }
            }
            finally {
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }

        /**
         * Applies queued updates until shutdown is requested, then drains whatever is left in the
         * queue. Completion is always signalled, even if applying an update throws, so that
         * {@link #shutdown()} cannot wait forever on a dead worker.
         */
        @Override
        public void run() {
            boolean interrupted = false;
            try {
                while (!this.shutdown) {
                    try {
                        Pair<Integer, Vector> pair = this.queue.poll(1L, TimeUnit.SECONDS);
                        if (pair != null) {
                            TopicModel.this.updateTopic(pair.getFirst(), pair.getSecond());
                        }
                    }
                    catch (InterruptedException e) {
                        interrupted = true;
                        log.warn("Interrupted waiting to poll for update", e);
                    }
                }
                // Drain by removal so that a producer blocked in put() on a full queue is released.
                Pair<Integer, Vector> pending;
                while ((pending = this.queue.poll()) != null) {
                    TopicModel.this.updateTopic(pending.getFirst(), pending.getSecond());
                }
            }
            catch (RuntimeException e) {
                // The Future returned by submit() is never inspected, so log here or lose it.
                log.error("Updater terminated abnormally while applying an update", e);
                throw e;
            }
            finally {
                synchronized (this) {
                    // Also mark as shut down so producers are rejected rather than left blocking
                    // on a queue that no longer has a consumer.
                    this.shutdown = true;
                    this.shutdownComplete = true;
                    this.notifyAll();
                }
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }
}

