package org.apache.mahout.cf.taste.hadoop.als;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import com.ibm.icu.text.DateFormat;
import java.io.IOException;
import java.util.Iterator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.hadoop.preparation.PreparePreferenceMatrixJob;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
import org.apache.mahout.common.mapreduce.TransposeMapper;
import org.apache.mahout.common.mapreduce.VectorSumCombiner;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.VarIntWritable;
import org.apache.mahout.math.VarLongWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.class */
/**
 * Factorizes a rating matrix into a user feature matrix {@code U} and an item feature matrix {@code M} using
 * alternating least squares, running one MapReduce solver job per half-iteration.
 *
 * <p>Input is text; each line is split with {@link TasteHadoopUtils#splitPrefTokens(String)} into
 * {@code userID, itemID, rating} (token positions 0, 1 and 2). The pipeline is:
 * <ol>
 *   <li>optionally ({@code --usesLongIDs}) write user and item ID index files,</li>
 *   <li>build per-item rating vectors (temp {@code itemRatings}),</li>
 *   <li>transpose them into per-user rating vectors (output {@code userRatings}),</li>
 *   <li>compute per-item average ratings and use them to seed {@code M},</li>
 *   <li>alternately recompute {@code U} and {@code M} for {@code numIterations} rounds.</li>
 * </ol>
 * The final iteration writes {@code U} and {@code M} to the output path; earlier ones go to temp paths.
 */
public class ParallelALSFactorizationJob extends AbstractJob {

    private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class);

    /** Configuration key: dimension of the feature space, read by the solver mappers. */
    static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
    /** Configuration key: regularization parameter, read by the solver mappers. */
    static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
    /** Configuration key: confidence parameter (only used with implicit feedback). */
    static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha";
    /**
     * Configuration key: number of rows in the feature matrix that is held fixed during a solve (the one shipped
     * through the distributed cache) — the item count when recomputing U, the user count when recomputing M.
     */
    static final String NUM_ENTITIES = ParallelALSFactorizationJob.class.getName() + ".numEntities";
    /** Configuration key: whether input IDs are longs that must be translated to int indices. */
    static final String USES_LONG_IDS = ParallelALSFactorizationJob.class.getName() + ".usesLongIDs";
    /** Configuration key: which input token {@link MapLongIDsMapper} reads (0 = user ID, 1 = item ID). */
    static final String TOKEN_POS = ParallelALSFactorizationJob.class.getName() + ".tokenPos";

    /** Name of the user feature matrix, used in output/temp paths and job names. */
    private static final String USER_FEATURES = "U";
    /** Name of the item feature matrix, used in output/temp paths and job names. */
    private static final String ITEM_FEATURES = "M";

    private static final int USER_TOKEN_POS = 0;
    private static final int ITEM_TOKEN_POS = 1;

    private boolean implicitFeedback;
    private int numIterations;
    private int numFeatures;
    private double lambda;
    private double alpha;
    private int numThreadsPerSolver;
    private boolean usesLongIDs;
    private int numItems;
    private int numUsers;

    /**
     * For each item rating vector, emits the item's average rating as the single entry of a sparse vector
     * under key 0, so that a merging reducer can assemble all averages into one row indexed by item.
     */
    static class AverageRatingMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        private final IntWritable firstIndex = new IntWritable(0);
        private final Vector featureVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
        private final VectorWritable featureVectorWritable = new VectorWritable();

        @Override
        public void map(IntWritable item, VectorWritable ratings, Context context)
            throws IOException, InterruptedException {
            FullRunningAverage average = new FullRunningAverage();
            for (Vector.Element rating : ratings.get().nonZeroes()) {
                average.addDatum(rating.get());
            }
            featureVector.setQuick(item.get(), average.getAverage());
            featureVectorWritable.set(featureVector);
            context.write(firstIndex, featureVectorWritable);
            // the vector is reused across calls, so clear the entry we just set
            featureVector.setQuick(item.get(), 0.0d);
        }
    }

    /** Keeps one long ID per int index; all values sharing an index key are assumed identical. */
    static class IDMapReducer extends Reducer<VarIntWritable, VarLongWritable, VarIntWritable, VarLongWritable> {
        @Override
        public void reduce(VarIntWritable index, Iterable<VarLongWritable> ids, Context context)
            throws IOException, InterruptedException {
            context.write(index, ids.iterator().next());
        }
    }

    /**
     * Parses a preference line into (user, item, rating) and emits the item index as key with a one-entry
     * sparse vector mapping the user index to the rating.
     */
    static class ItemRatingVectorsMapper extends Mapper<LongWritable, Text, IntWritable, VectorWritable> {
        private final IntWritable itemIDWritable = new IntWritable();
        private final VectorWritable ratingsWritable = new VectorWritable(true);
        private final Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
        private boolean usesLongIDs;

        @Override
        public void setup(Context context) throws IOException, InterruptedException {
            usesLongIDs = context.getConfiguration().getBoolean(USES_LONG_IDS, false);
        }

        @Override
        public void map(LongWritable offset, Text line, Context context) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
            int userID = TasteHadoopUtils.readID(tokens[0], usesLongIDs);
            int itemID = TasteHadoopUtils.readID(tokens[1], usesLongIDs);
            float rating = Float.parseFloat(tokens[2]);

            ratings.setQuick(userID, rating);
            itemIDWritable.set(itemID);
            ratingsWritable.set(ratings);
            context.write(itemIDWritable, ratingsWritable);
            // the vector is reused across calls, so clear the entry we just set
            ratings.setQuick(userID, 0.0d);
        }
    }

    /**
     * Reads the long ID at token position {@link #TOKEN_POS} and emits it keyed by the int index that
     * {@link TasteHadoopUtils#idToIndex(long)} assigns to it.
     */
    static class MapLongIDsMapper extends Mapper<LongWritable, Text, VarIntWritable, VarLongWritable> {
        private int tokenPos;
        private final VarIntWritable index = new VarIntWritable();
        private final VarLongWritable idWritable = new VarLongWritable();

        @Override
        public void setup(Context context) throws IOException, InterruptedException {
            tokenPos = context.getConfiguration().getInt(TOKEN_POS, -1);
            Preconditions.checkState(tokenPos >= 0);
        }

        @Override
        public void map(LongWritable offset, Text line, Context context) throws IOException, InterruptedException {
            long id = Long.parseLong(TasteHadoopUtils.splitPrefTokens(line.toString())[tokenPos]);
            index.set(TasteHadoopUtils.idToIndex(id));
            idWritable.set(id);
            context.write(index, idWritable);
        }
    }

    /**
     * Merges the partial rating vectors of one user into a single sequential-access vector and counts the
     * user in {@link Stats#NUM_USERS}.
     */
    static class MergeUserVectorsReducer
        extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
        private final VectorWritable result = new VectorWritable();

        @Override
        public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context context)
            throws IOException, InterruptedException {
            result.set(new SequentialAccessSparseVector(VectorWritable.merge(vectors.iterator()).get()));
            context.write(key, result);
            context.getCounter(Stats.NUM_USERS).increment(1L);
        }
    }

    /** Job counters. */
    public enum Stats {
        NUM_USERS
    }

    /** Sums all vectors for a key into a single sequential-access vector. */
    static class VectorSumReducer
        extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
        private final VectorWritable result = new VectorWritable();

        @Override
        public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context context)
            throws IOException, InterruptedException {
            result.set(new SequentialAccessSparseVector(Vectors.sum(vectors.iterator())));
            context.write(key, result);
        }
    }

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new ParallelALSFactorizationJob(), args);
    }

    /**
     * Parses the command line and runs the whole factorization pipeline.
     *
     * @return 0 on success, -1 if argument parsing or any preparation job fails
     * @throws IllegalArgumentException if {@code numFeatures} or {@code numThreadsPerSolver} is not positive
     * @throws IllegalStateException if a solver job fails
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption("lambda", null, "regularization parameter", true);
        addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false));
        addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40));
        addOption("numFeatures", null, "dimension of the feature space", true);
        addOption("numIterations", null, "number of iterations", true);
        addOption("numThreadsPerSolver", null, "threads per solver mapper", String.valueOf(1));
        addOption("usesLongIDs", null, "input contains long IDs that need to be translated");

        if (parseArguments(args) == null) {
            return -1;
        }

        numFeatures = Integer.parseInt(getOption("numFeatures"));
        numIterations = Integer.parseInt(getOption("numIterations"));
        lambda = Double.parseDouble(getOption("lambda"));
        alpha = Double.parseDouble(getOption("alpha"));
        implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback"));
        numThreadsPerSolver = Integer.parseInt(getOption("numThreadsPerSolver"));
        // --usesLongIDs is registered as an argument-less flag, so getOption(..) has no value to return for it
        // and a getOption(.., "false") lookup would always yield false. Test for the flag's presence instead.
        usesLongIDs = hasOption("usesLongIDs");

        // fail fast instead of an index error in initializeM / a meaningless thread count in the solver
        Preconditions.checkArgument(numFeatures > 0, "numFeatures must be positive, got %s", numFeatures);
        Preconditions.checkArgument(numThreadsPerSolver > 0,
            "numThreadsPerSolver must be positive, got %s", numThreadsPerSolver);

        if (usesLongIDs) {
            if (!runIDMappingJob(getOutputPath("userIDIndex"), USER_TOKEN_POS)
                || !runIDMappingJob(getOutputPath(PreparePreferenceMatrixJob.ITEMID_INDEX), ITEM_TOKEN_POS)) {
                return -1;
            }
        }

        Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(), TextInputFormat.class,
            ItemRatingVectorsMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class,
            IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        itemRatings.setCombinerClass(VectorSumCombiner.class);
        itemRatings.getConfiguration().set(USES_LONG_IDS, String.valueOf(usesLongIDs));
        if (!itemRatings.waitForCompletion(true)) {
            return -1;
        }

        Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(), TransposeMapper.class,
            IntWritable.class, VectorWritable.class, MergeUserVectorsReducer.class, IntWritable.class,
            VectorWritable.class);
        userRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!userRatings.waitForCompletion(true)) {
            return -1;
        }

        Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"),
            AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
            IntWritable.class, VectorWritable.class);
        averageItemRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!averageItemRatings.waitForCompletion(true)) {
            return -1;
        }

        Vector averageRatings = ALS.readFirstRow(getTempPath("averageRatings"), getConf());
        numItems = averageRatings.getNumNondefaultElements();
        numUsers = (int) userRatings.getCounters().findCounter(Stats.NUM_USERS).getValue();
        log.info("Found {} users and {} items", numUsers, numItems);

        initializeM(averageRatings);

        for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
            log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations);
            runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1),
                currentIteration, USER_FEATURES, numItems);
            log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations);
            runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration),
                currentIteration, ITEM_FEATURES, numUsers);
        }
        return 0;
    }

    /**
     * Runs one job that writes the (int index to long ID) mapping for the input token at {@code tokenPos}.
     *
     * @return whether the job completed successfully
     */
    private boolean runIDMappingJob(Path outputPath, int tokenPos)
        throws IOException, ClassNotFoundException, InterruptedException {
        Job mapIDs = prepareJob(getInputPath(), outputPath, TextInputFormat.class, MapLongIDsMapper.class,
            VarIntWritable.class, VarLongWritable.class, IDMapReducer.class, VarIntWritable.class,
            VarLongWritable.class, SequenceFileOutputFormat.class);
        mapIDs.getConfiguration().setInt(TOKEN_POS, tokenPos);
        return mapIDs.waitForCompletion(true);
    }

    /**
     * Writes the initial item feature matrix M: for each item with a non-default average rating, feature 0 is
     * that average and the remaining {@code numFeatures - 1} features are uniform random values.
     */
    private void initializeM(Vector averageRatings) throws IOException {
        RandomWrapper random = RandomUtils.getRandom();
        Path initialM = pathToM(-1);

        SequenceFile.Writer writer = null;
        boolean threw = true;
        try {
            writer = new SequenceFile.Writer(FileSystem.get(initialM.toUri(), getConf()), getConf(),
                new Path(initialM, "part-m-00000"), IntWritable.class, VectorWritable.class);

            IntWritable index = new IntWritable();
            VectorWritable featureVector = new VectorWritable();

            for (Vector.Element e : averageRatings.nonZeroes()) {
                Vector row = new DenseVector(numFeatures);
                row.setQuick(0, e.get());
                for (int feature = 1; feature < numFeatures; feature++) {
                    row.setQuick(feature, random.nextDouble());
                }
                index.set(e.index());
                featureVector.set(row);
                writer.append(index, featureVector);
            }
            threw = false;
        } finally {
            // swallow close() failures only when the body already threw, so the original exception survives
            Closeables.close(writer, threw);
        }
    }

    /**
     * Recomputes one feature matrix from the ratings and the other, fixed feature matrix.
     *
     * @param ratings rating vectors of the entities being recomputed
     * @param output where the recomputed feature matrix is written
     * @param fixedFeatures the other feature matrix; its part files are added to the distributed cache
     * @param currentIteration zero-based iteration number (used in the job name)
     * @param matrixName "U" or "M" (used in the job name)
     * @param numEntities number of rows in {@code fixedFeatures}, passed as {@link #NUM_ENTITIES}
     * @throws IllegalStateException if the job fails
     */
    private void runSolver(Path ratings, Path output, Path fixedFeatures, int currentIteration, String matrixName,
        int numEntities) throws ClassNotFoundException, IOException, InterruptedException {

        // reset SharingMapper's static state before launching another solver job
        SharingMapper.reset();

        Class<? extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>> solverMapperClass;
        String feedbackKind;
        if (implicitFeedback) {
            solverMapperClass = SolveImplicitFeedbackMapper.class;
            feedbackKind = "implicit";
        } else {
            solverMapperClass = SolveExplicitFeedbackMapper.class;
            feedbackKind = "explicit";
        }
        String jobName = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations
            + "), (" + numThreadsPerSolver + " threads, " + numFeatures + " features, " + feedbackKind
            + " feedback)";

        Job solver = prepareJob(ratings, output, SequenceFileInputFormat.class, MultithreadedSharingMapper.class,
            IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName);
        Configuration solverConf = solver.getConfiguration();
        solverConf.set(LAMBDA, String.valueOf(lambda));
        solverConf.set(ALPHA, String.valueOf(alpha));
        solverConf.setInt(NUM_FEATURES, numFeatures);
        solverConf.set(NUM_ENTITIES, String.valueOf(numEntities));

        FileSystem fs = FileSystem.get(fixedFeatures.toUri(), solverConf);
        for (FileStatus part : fs.listStatus(fixedFeatures, PathFilters.partFilter())) {
            log.debug("Adding {} to distributed cache", part.getPath());
            DistributedCache.addCacheFile(part.getPath().toUri(), solverConf);
        }

        MultithreadedMapper.setMapperClass(solver, solverMapperClass);
        MultithreadedMapper.setNumberOfThreads(solver, numThreadsPerSolver);

        if (!solver.waitForCompletion(true)) {
            throw new IllegalStateException("Job failed!");
        }
    }

    /** The final iteration's M goes to the output path, all others (including the seed, -1) to temp. */
    private Path pathToM(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath(ITEM_FEATURES) : getTempPath("M-" + iteration);
    }

    /** The final iteration's U goes to the output path, all others to temp. */
    private Path pathToU(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath(USER_FEATURES) : getTempPath("U-" + iteration);
    }

    private Path pathToItemRatings() {
        return getTempPath("itemRatings");
    }

    private Path pathToUserRatings() {
        return getOutputPath("userRatings");
    }
}
