/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.lda.cvb.CVB0Driver;
import org.apache.mahout.clustering.lda.cvb.ModelTrainer;
import org.apache.mahout.clustering.lda.cvb.TopicModel;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Mapper that trains a CVB0 topic model over the documents of its input split and writes the
 * trained model out once, during {@link #cleanup}.
 *
 * <p>Each input record is an ({@code IntWritable} document id, {@code VectorWritable} document)
 * pair. {@link #map} emits nothing: documents are fed to a {@link ModelTrainer} built in
 * {@link #setup}. On cleanup, every slice of the trainer's read model is emitted as one
 * ({@code IntWritable} slice index, {@code VectorWritable} slice vector) record.
 */
public class CachingCVB0Mapper
extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
    private static final Logger log = LoggerFactory.getLogger(CachingCVB0Mapper.class);
    private ModelTrainer modelTrainer;
    private TopicModel readModel;
    private TopicModel writeModel;
    private int maxIters;
    private int numTopics;

    /**
     * Returns the trainer created in {@link #setup}.
     *
     * @return the model trainer, or {@code null} before {@code setup} has run
     */
    protected ModelTrainer getModelTrainer() {
        return this.modelTrainer;
    }

    /**
     * Returns the per-document iteration cap read from {@code max_doc_topic_iters} (default 10).
     *
     * @return maximum number of training iterations per document
     */
    protected int getMaxIters() {
        return this.maxIters;
    }

    /**
     * Returns the topic count read from {@code num_topics}.
     *
     * @return number of topics
     */
    protected int getNumTopics() {
        return this.numTopics;
    }

    /**
     * Reads the job configuration, builds the read model, write model and model trainer, and
     * starts the trainer.
     *
     * <p>If {@code CVB0Driver.getModelPaths(conf)} returns at least one path, the read model is
     * loaded from those paths. Otherwise it is built with a random generator seeded from
     * {@code random_seed} (default 1234). When {@code prev_iter_mult} is exactly {@code 1.0} a
     * separate write model is constructed; otherwise the read model also serves as the write
     * model.
     *
     * @param context task context supplying the job configuration
     * @throws IllegalStateException if {@code num_topics} is not positive, or if
     *     {@code term_topic_smoothing} or {@code doc_topic_smoothing} is not configured
     */
    @Override
    protected void setup(
            Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context)
            throws IOException, InterruptedException {
        log.info("Retrieving configuration");
        Configuration conf = context.getConfiguration();
        // The NaN / -1 defaults below act as "not configured" sentinels and are checked
        // by validateConfiguration before any model is built.
        float eta = conf.getFloat("term_topic_smoothing", Float.NaN);
        float alpha = conf.getFloat("doc_topic_smoothing", Float.NaN);
        long seed = conf.getLong("random_seed", 1234L);
        this.numTopics = conf.getInt("num_topics", -1);
        int numTerms = conf.getInt("num_terms", -1);
        int numUpdateThreads = conf.getInt("num_update_threads", 1);
        int numTrainThreads = conf.getInt("num_train_threads", 4);
        this.maxIters = conf.getInt("max_doc_topic_iters", 10);
        float modelWeight = conf.getFloat("prev_iter_mult", 1.0f);
        validateConfiguration(this.numTopics, eta, alpha);

        log.info("Initializing read model");
        Path[] modelPaths = CVB0Driver.getModelPaths(conf);
        if (modelPaths != null && modelPaths.length > 0) {
            this.readModel = new TopicModel(conf, eta, (double)alpha, null, numUpdateThreads, modelWeight, modelPaths);
        } else {
            log.info("No model files found");
            this.readModel = new TopicModel(this.numTopics, numTerms, (double)eta, alpha, RandomUtils.getRandom(seed), null, numTrainThreads, modelWeight);
        }

        log.info("Initializing write model");
        // NOTE(review): the (double) cast shows numUpdateThreads binding to a double-typed
        // parameter of this 6-argument TopicModel constructor, possibly a model-weight slot
        // rather than a thread count. Kept unchanged to preserve behavior; confirm against
        // TopicModel's constructors.
        this.writeModel = modelWeight == 1.0f
                ? new TopicModel(this.numTopics, numTerms, (double)eta, (double)alpha, null, (double)numUpdateThreads)
                : this.readModel;

        log.info("Initializing model trainer");
        this.modelTrainer = new ModelTrainer(this.readModel, this.writeModel, numTrainThreads, this.numTopics, numTerms);
        this.modelTrainer.start();
    }

    /**
     * Rejects configurations that would otherwise fail late or silently: a non-positive topic
     * count would make {@link #map} allocate a negative-size vector, and a NaN smoothing value
     * would be passed straight into the models.
     */
    private static void validateConfiguration(int topics, float eta, float alpha) {
        if (topics <= 0) {
            throw new IllegalStateException(
                    "num_topics must be configured to a positive value, but was " + topics);
        }
        if (Float.isNaN(eta)) {
            throw new IllegalStateException("term_topic_smoothing must be configured");
        }
        if (Float.isNaN(alpha)) {
            throw new IllegalStateException("doc_topic_smoothing must be configured");
        }
    }

    /**
     * Trains the model on one document, starting from a uniform topic distribution of
     * {@code 1 / numTopics} per topic. Nothing is emitted here; the model is written in
     * {@link #cleanup}.
     *
     * @param docId document id (unused)
     * @param document the document vector to train on
     * @param context task context (unused)
     */
    @Override
    public void map(IntWritable docId, VectorWritable document,
            Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context)
            throws IOException, InterruptedException {
        Vector topicVector = new DenseVector(this.numTopics).assign(1.0 / (double)this.numTopics);
        // NOTE(review): the meaning of the boolean flag is defined by ModelTrainer.train,
        // presumably "update the model with this document"; confirm there.
        this.modelTrainer.train(document.get(), topicVector, true, this.maxIters);
    }

    /**
     * Stops the trainer, then emits every slice of the trainer's read model as
     * (slice index, slice vector). The read and write models are stopped even if stopping
     * the trainer or writing output fails. When both fields refer to the same model, it is
     * stopped only once.
     *
     * @param context task context receiving the model output
     */
    @Override
    protected void cleanup(
            Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context)
            throws IOException, InterruptedException {
        try {
            log.info("Stopping model trainer");
            this.modelTrainer.stop();
            log.info("Writing model");
            TopicModel readFrom = this.modelTrainer.getReadModel();
            for (MatrixSlice topic : readFrom) {
                context.write(new IntWritable(topic.index()), new VectorWritable(topic.vector()));
            }
        } finally {
            stopModels();
        }
    }

    /**
     * Stops the read model and, if it is a distinct object, the write model. The write model
     * is stopped even if stopping the read model throws.
     */
    private void stopModels() throws IOException, InterruptedException {
        try {
            this.readModel.stop();
        } finally {
            // setup() aliases writeModel to readModel when prev_iter_mult != 1.0.
            if (this.writeModel != this.readModel) {
                this.writeModel.stop();
            }
        }
    }
}

