package org.apache.mahout.math.hadoop.stochasticsvd;

import com.google.common.collect.Lists;
import com.ibm.icu.text.DateFormat;
import java.io.Closeable;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Iterator;
import java.util.regex.Matcher;
import org.apache.commons.lang3.Validate;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep;

/* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.class */
public final class ABtDenseOutJob {
    /** Configuration key holding the path (used as a glob by the mapper) of the B' input. */
    public static final String PROP_BT_PATH = "ssvd.Bt.path";
    /** Configuration key whose mere presence tells the mapper to read B' from the cached local files. */
    public static final String PROP_BT_BROADCAST = "ssvd.Bt.broadcast";
    /** Configuration key for the path of the {@code sb} vector; the reducer subtracts it from every Y row when set. */
    public static final String PROP_SB_PATH = "ssvdpca.sb.path";
    /** Configuration key for the path of the {@code sq} vector; read by the mapper only when the xi path is set. */
    public static final String PROP_SQ_PATH = "ssvdpca.sq.path";
    /** Configuration key for the path of the {@code xi} vector; when set, the mapper applies the xi/sq correction. */
    public static final String PROP_XI_PATH = "ssvdpca.xi.path";

    /* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob$ABtMapper.class */
    public static class ABtMapper extends Mapper<Writable, VectorWritable, SplitPartitionedWritable, DenseBlockWritable> {
        private SplitPartitionedWritable outKey;
        private final Deque<Closeable> closeables = new ArrayDeque();
        private SequenceFileDirIterator<IntWritable, VectorWritable> btInput;
        private Vector[] aCols;
        private double[][] yiCols;
        private int aRowCount;
        private int kp;
        private int blockHeight;
        private boolean distributedBt;
        private Path[] btLocalPath;
        private Configuration localFsConfig;
        protected Vector xi;
        protected Vector sq;

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.apache.hadoop.mapreduce.Mapper
        public void map(Writable writable, VectorWritable vectorWritable, Mapper<Writable, VectorWritable, SplitPartitionedWritable, DenseBlockWritable>.Context context) throws IOException, InterruptedException {
            Vector vector = vectorWritable.get();
            int size = vector.size();
            if (this.aCols == null) {
                this.aCols = new Vector[size];
            } else if (this.aCols.length < size) {
                this.aCols = (Vector[]) Arrays.copyOf(this.aCols, size);
            }
            if (vector.isDense()) {
                for (int i = 0; i < size; i++) {
                    extendAColIfNeeded(i, this.aRowCount + 1);
                    this.aCols[i].setQuick(this.aRowCount, vector.getQuick(i));
                }
            } else if (vector.size() > 0) {
                for (Vector.Element element : vector.nonZeroes()) {
                    int index = element.index();
                    extendAColIfNeeded(index, this.aRowCount + 1);
                    this.aCols[index].setQuick(this.aRowCount, element.get());
                }
            }
            this.aRowCount++;
        }

        private void extendAColIfNeeded(int i, int i2) {
            if (this.aCols[i] == null) {
                this.aCols[i] = new SequentialAccessSparseVector(i2 < this.blockHeight ? this.blockHeight : i2, 1);
            } else if (this.aCols[i].size() < i2) {
                SequentialAccessSparseVector sequentialAccessSparseVector = new SequentialAccessSparseVector(i2 + this.blockHeight, this.aCols[i].getNumNondefaultElements() << 1);
                sequentialAccessSparseVector.viewPart(0, this.aCols[i].size()).assign(this.aCols[i]);
                this.aCols[i] = sequentialAccessSparseVector;
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* JADX WARN: Type inference failed for: r1v2, types: [double[], double[][]] */
        @Override // org.apache.hadoop.mapreduce.Mapper
        public void cleanup(Mapper<Writable, VectorWritable, SplitPartitionedWritable, DenseBlockWritable>.Context context) throws IOException, InterruptedException {
            Vector vector;
            try {
                this.yiCols = new double[this.kp];
                for (int i = 0; i < this.kp; i++) {
                    this.yiCols[i] = new double[Math.min(this.aRowCount, this.blockHeight)];
                }
                int i2 = ((this.aRowCount - 1) / this.blockHeight) + 1;
                String str = context.getConfiguration().get("ssvd.Bt.path");
                Validate.notNull(str, "Bt input is not set", new Object[0]);
                Path path = new Path(str);
                DenseBlockWritable denseBlockWritable = new DenseBlockWritable();
                int i3 = -1;
                for (int i4 = 0; i4 < i2; i4++) {
                    if (this.distributedBt) {
                        this.btInput = new SequenceFileDirIterator<>(this.btLocalPath, true, this.localFsConfig);
                    } else {
                        this.btInput = new SequenceFileDirIterator<>(path, PathType.GLOB, null, null, true, context.getConfiguration());
                    }
                    this.closeables.addFirst(this.btInput);
                    Validate.isTrue(this.btInput.hasNext(), "Empty B' input!", new Object[0]);
                    int i5 = i4 * this.blockHeight;
                    int min = Math.min(this.blockHeight, this.aRowCount - i5);
                    if (i4 > 0) {
                        if (min == this.blockHeight) {
                            for (int i6 = 0; i6 < this.kp; i6++) {
                                Arrays.fill(this.yiCols[i6], 0.0d);
                            }
                        } else {
                            for (int i7 = 0; i7 < this.kp; i7++) {
                                this.yiCols[i7] = null;
                            }
                            for (int i8 = 0; i8 < this.kp; i8++) {
                                this.yiCols[i8] = new double[min];
                            }
                        }
                    }
                    while (this.btInput.hasNext()) {
                        Pair next = this.btInput.next();
                        int i9 = ((IntWritable) next.getFirst()).get();
                        Vector vector2 = ((VectorWritable) next.getSecond()).get();
                        if (i9 <= this.aCols.length && (vector = this.aCols[i9]) != null && vector.size() != 0) {
                            int i10 = -1;
                            for (Vector.Element element : vector.nonZeroes()) {
                                i10 = element.index();
                                if (i10 >= i5) {
                                    if (i10 >= i5 + min) {
                                        break;
                                    }
                                    if (this.xi != null) {
                                        for (int i11 = 0; i11 < this.kp; i11++) {
                                            double d = this.xi.size() > i9 ? this.xi.get(i9) : 0.0d;
                                            double[] dArr = this.yiCols[i11];
                                            int i12 = i10 - i5;
                                            dArr[i12] = dArr[i12] + (element.get() * (vector2.getQuick(i11) - (d * this.sq.get(i11))));
                                        }
                                    } else {
                                        for (int i13 = 0; i13 < this.kp; i13++) {
                                            double[] dArr2 = this.yiCols[i13];
                                            int i14 = i10 - i5;
                                            dArr2[i14] = dArr2[i14] + (element.get() * vector2.getQuick(i13));
                                        }
                                    }
                                }
                            }
                            if (i3 < i10) {
                                i3 = i10;
                            }
                        }
                    }
                    denseBlockWritable.setBlock(this.yiCols);
                    this.outKey.setTaskItemOrdinal(i4);
                    context.write(this.outKey, denseBlockWritable);
                    this.closeables.remove(this.btInput);
                    this.btInput.close();
                }
            } finally {
                IOUtils.close(this.closeables);
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.apache.hadoop.mapreduce.Mapper
        public void setup(Mapper<Writable, VectorWritable, SplitPartitionedWritable, DenseBlockWritable>.Context context) throws IOException, InterruptedException {
            Configuration configuration = context.getConfiguration();
            this.kp = Integer.parseInt(configuration.get("ssvd.k")) + Integer.parseInt(configuration.get("ssvd.p"));
            this.outKey = new SplitPartitionedWritable(context);
            this.blockHeight = configuration.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
            this.distributedBt = configuration.get("ssvd.Bt.broadcast") != null;
            if (this.distributedBt) {
                this.btLocalPath = HadoopUtil.getCachedFiles(configuration);
                this.localFsConfig = new Configuration();
                this.localFsConfig.set("fs.default.name", "file:///");
            }
            String str = configuration.get("ssvdpca.xi.path");
            if (str != null) {
                this.xi = SSVDHelper.loadAndSumUpVectors(new Path(str), configuration);
                this.sq = SSVDHelper.loadAndSumUpVectors(new Path(configuration.get("ssvdpca.sq.path")), configuration);
            }
        }
    }

    /* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob$QRReducer.class */
    public static class QRReducer extends Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable> {
        private static final NumberFormat NUMBER_FORMAT = NumberFormat.getInstance();
        protected int blockHeight;
        protected int accumSize;
        protected OutputCollector<Writable, DenseBlockWritable> qhatCollector;
        protected OutputCollector<Writable, VectorWritable> rhatCollector;
        protected QRFirstStep qr;
        protected Vector yiRow;
        protected Vector sb;
        private final Deque<Closeable> closeables = Lists.newLinkedList();
        protected int lastTaskId = -1;

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.apache.hadoop.mapreduce.Reducer
        public void setup(Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            Configuration configuration = context.getConfiguration();
            this.blockHeight = configuration.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
            String str = configuration.get("ssvdpca.sb.path");
            if (str != null) {
                this.sb = SSVDHelper.loadAndSumUpVectors(new Path(str), configuration);
            }
        }

        protected void setupBlock(Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable>.Context context, SplitPartitionedWritable splitPartitionedWritable) throws InterruptedException, IOException {
            IOUtils.close(this.closeables);
            this.qhatCollector = createOutputCollector(QJob.OUTPUT_QHAT, splitPartitionedWritable, context, DenseBlockWritable.class);
            this.rhatCollector = createOutputCollector(QJob.OUTPUT_RHAT, splitPartitionedWritable, context, VectorWritable.class);
            this.qr = new QRFirstStep(context.getConfiguration(), this.qhatCollector, this.rhatCollector);
            this.closeables.addFirst(this.qr);
            this.lastTaskId = splitPartitionedWritable.getTaskId();
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.apache.hadoop.mapreduce.Reducer
        public void reduce(SplitPartitionedWritable splitPartitionedWritable, Iterable<DenseBlockWritable> iterable, Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            if (splitPartitionedWritable.getTaskId() != this.lastTaskId) {
                setupBlock(context, splitPartitionedWritable);
            }
            Iterator<DenseBlockWritable> it = iterable.iterator();
            double[][] block = it.next().getBlock();
            if (it.hasNext()) {
                throw new IOException("Unexpected extra Y_i block in reducer input.");
            }
            long taskItemOrdinal = splitPartitionedWritable.getTaskItemOrdinal() * this.blockHeight;
            int length = block[0].length;
            if (this.yiRow == null) {
                this.yiRow = new DenseVector(block.length);
            }
            for (int i = 0; i < length; i++) {
                for (int i2 = 0; i2 < block.length; i2++) {
                    this.yiRow.setQuick(i2, block[i2][i]);
                }
                splitPartitionedWritable.setTaskItemOrdinal(taskItemOrdinal + i);
                if (this.sb != null) {
                    this.yiRow.assign(this.sb, Functions.MINUS);
                }
                this.qr.collect((Writable) splitPartitionedWritable, this.yiRow);
            }
        }

        private Path getSplitFilePath(String str, SplitPartitionedWritable splitPartitionedWritable, Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable>.Context context) throws InterruptedException, IOException {
            return new Path(FileOutputFormat.getWorkOutputPath(context), FileOutputFormat.getUniqueFile(context, str, "").replaceFirst("-r-", "-m-").replaceFirst("\\d+$", Matcher.quoteReplacement(NUMBER_FORMAT.format(splitPartitionedWritable.getTaskId()))));
        }

        private <K, V> OutputCollector<K, V> createOutputCollector(String str, final SplitPartitionedWritable splitPartitionedWritable, Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable>.Context context, Class<V> cls) throws IOException, InterruptedException {
            Path splitFilePath = getSplitFilePath(str, splitPartitionedWritable, context);
            final SequenceFile.Writer createWriter = SequenceFile.createWriter(FileSystem.get(splitFilePath.toUri(), context.getConfiguration()), context.getConfiguration(), splitFilePath, SplitPartitionedWritable.class, cls);
            this.closeables.addFirst(createWriter);
            return new OutputCollector<K, V>() { // from class: org.apache.mahout.math.hadoop.stochasticsvd.ABtDenseOutJob.QRReducer.1
                @Override // org.apache.hadoop.mapred.OutputCollector
                public void collect(K k, V v) throws IOException {
                    createWriter.append(splitPartitionedWritable, v);
                }
            };
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.apache.hadoop.mapreduce.Reducer
        public void cleanup(Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            IOUtils.close(this.closeables);
        }

        static {
            NUMBER_FORMAT.setMinimumIntegerDigits(5);
            NUMBER_FORMAT.setGroupingUsed(false);
        }
    }

    /** Not instantiable: this class only exposes the static {@code run} entry point and nested task classes. */
    private ABtDenseOutJob() {}

    /**
     * Configures the "ABt-job" MapReduce job, submits it and blocks until it finishes.
     *
     * @param configuration base configuration the job configuration is copied from
     * @param pathArr input paths (sequence files) fed to {@link ABtMapper}
     * @param path location of the B' input; stored under {@link #PROP_BT_PATH}
     * @param path2 xi vector path, or {@code null} to skip the xi/sq/sb settings altogether
     * @param path3 sq vector path; must be non-null when {@code path2} is non-null
     * @param path4 sb vector path; must be non-null when {@code path2} is non-null
     * @param path5 job output directory
     * @param i value stored under {@code ssvd.arowblock.size}
     * @param i2 minimum input split size; applied only when positive
     * @param i3 value stored under {@code ssvd.k}
     * @param i4 value stored under {@code ssvd.p}
     * @param i5 outer-product block height
     * @param i6 number of reduce tasks
     * @param z whether to ship the B' files through the distributed cache
     * @throws IOException if the job completes unsuccessfully
     */
    public static void run(Configuration configuration, Path[] pathArr, Path path, Path path2, Path path3, Path path4, Path path5, int i, int i2, int i3, int i4, int i5, int i6, boolean z) throws ClassNotFoundException, InterruptedException, IOException {
        Configuration oldApiConf = new JobConf(configuration);
        Job job = new Job(oldApiConf);
        job.setJobName("ABt-job");
        job.setJarByClass(ABtDenseOutJob.class);

        job.setInputFormatClass(SequenceFileInputFormat.class);
        FileInputFormat.setInputPaths(job, pathArr);
        if (i2 > 0) {
            FileInputFormat.setMinInputSplitSize(job, i2);
        }
        FileOutputFormat.setOutputPath(job, path5);
        SequenceFileOutputFormat.setOutputCompressionType(job, SequenceFile.CompressionType.BLOCK);

        job.setMapOutputKeyClass(SplitPartitionedWritable.class);
        job.setMapOutputValueClass(DenseBlockWritable.class);
        job.setOutputKeyClass(SplitPartitionedWritable.class);
        job.setOutputValueClass(VectorWritable.class);
        job.setMapperClass(ABtMapper.class);
        job.setReducerClass(QRReducer.class);

        Configuration jobConf = job.getConfiguration();
        jobConf.setInt("ssvd.arowblock.size", i);
        jobConf.setInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, i5);
        jobConf.setInt("ssvd.k", i3);
        jobConf.setInt("ssvd.p", i4);
        jobConf.set(PROP_BT_PATH, path.toString());
        if (path2 != null) {
            jobConf.set(PROP_XI_PATH, path2.toString());
            jobConf.set(PROP_SB_PATH, path4.toString());
            jobConf.set(PROP_SQ_PATH, path3.toString());
        }
        job.setNumReduceTasks(i6);

        if (z) {
            // The mapper only tests this key for presence; "y" is just a non-null marker.
            jobConf.set(PROP_BT_BROADCAST, "y");
            FileStatus[] btFiles = FileSystem.get(path.toUri(), configuration).globStatus(path);
            if (btFiles != null) {
                for (FileStatus btFile : btFiles) {
                    DistributedCache.addCacheFile(btFile.getPath().toUri(), jobConf);
                }
            }
        }

        job.submit();
        job.waitForCompletion(false);
        if (!job.isSuccessful()) {
            throw new IOException("ABt job unsuccessful.");
        }
    }
}
