package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.Iterator;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;

/* loaded from: input_file:org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.class */
/**
 * Factorizes the rating matrix of a {@link DataModel} with stochastic gradient descent (SGD),
 * visiting every cached (user, item) rating once per iteration.
 *
 * <p>Every user and item vector has {@code numFeatures} entries. The first {@link #FEATURE_OFFSET}
 * slots are reserved and are paired so that the plain dot product of a user vector and an item
 * vector already contains the baseline terms:
 *
 * <pre>
 *   slot               user vector        item vector
 *   0                  average rating     1
 *   USER_BIAS_INDEX    user bias          1
 *   ITEM_BIAS_INDEX    1                  item bias
 *   FEATURE_OFFSET..   latent features    latent features
 * </pre>
 *
 * <p>Only the bias entries and the latent features are adjusted during training; the constant
 * {@code 1} entries and the average rating in slot 0 are never written after initialization.
 */
public class RatingSGDFactorizer extends AbstractFactorizer {

    /** Number of reserved leading slots (average, user bias, item bias) in every vector. */
    protected static final int FEATURE_OFFSET = 3;
    /** Position of the user bias inside a user vector (the item vector holds 1 there). */
    protected static final int USER_BIAS_INDEX = 1;
    /** Position of the item bias inside an item vector (the user vector holds 1 there). */
    protected static final int ITEM_BIAS_INDEX = 2;

    /** Factor the learning rate is multiplied by after every iteration. */
    protected final double learningRateDecay;
    /** Initial SGD step size. */
    protected final double learningRate;
    /** Regularization weight applied to biases and latent features. */
    protected final double preventOverfitting;
    /** Total vector length: requested latent features plus {@link #FEATURE_OFFSET}. */
    protected final int numFeatures;
    /** Number of passes over the cached ratings. */
    private final int numIterations;
    /** Scale applied to the Gaussian draws that initialize the latent features. */
    protected final double randomNoise;
    /** One row per user index; allocated in {@link #prepareTraining()}. */
    protected double[][] userVectors;
    /** One row per item index; allocated in {@link #prepareTraining()}. */
    protected double[][] itemVectors;
    protected final DataModel dataModel;
    /** User ID of each cached rating; parallel to {@link #cachedItemIDs}. */
    private long[] cachedUserIDs;
    /** Item ID of each cached rating; parallel to {@link #cachedUserIDs}. */
    private long[] cachedItemIDs;
    /** Extra multiplier on the learning rate used for the bias updates only. */
    protected double biasLearningRate = 0.5d;
    /** Extra multiplier on {@link #preventOverfitting} used for the bias updates only. */
    protected double biasReg = 0.1d;

    /**
     * Creates a factorizer with learning rate 0.01, regularization 0.1, random noise 0.01 and
     * no learning-rate decay (decay factor 1.0).
     *
     * @param dataModel     source of the ratings
     * @param numFeatures   number of latent features (reserved slots are added on top)
     * @param numIterations number of passes over the ratings
     * @throws TasteException declared for the superclass constructor
     */
    public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
        this(dataModel, numFeatures, 0.01d, 0.1d, 0.01d, numIterations, 1.0d);
    }

    /**
     * Creates a fully parameterized factorizer.
     *
     * @param dataModel          source of the ratings
     * @param numFeatures        number of latent features (reserved slots are added on top)
     * @param learningRate       initial SGD step size
     * @param preventOverfitting regularization weight
     * @param randomNoise        scale of the Gaussian initialization of the latent features
     * @param numIterations      number of passes over the ratings
     * @param learningRateDecay  factor applied to the learning rate after each iteration
     * @throws TasteException declared for the superclass constructor
     */
    public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate,
            double preventOverfitting, double randomNoise, int numIterations, double learningRateDecay)
            throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;
        this.numFeatures = numFeatures + FEATURE_OFFSET;
        this.numIterations = numIterations;
        this.learningRate = learningRate;
        this.learningRateDecay = learningRateDecay;
        this.preventOverfitting = preventOverfitting;
        this.randomNoise = randomNoise;
    }

    /**
     * Allocates and initializes the user and item vectors, then caches and shuffles the
     * (user, item) pairs of all ratings.
     *
     * <p>NOTE: the decompiler recorded that this method was originally {@code protected}; it is
     * kept {@code public} so that code compiled against this version continues to work.
     *
     * @throws TasteException if the data model cannot be read
     */
    public void prepareTraining() throws TasteException {
        RandomWrapper random = RandomUtils.getRandom();
        userVectors = new double[dataModel.getNumUsers()][numFeatures];
        itemVectors = new double[dataModel.getNumItems()][numFeatures];

        double globalAverage = getAveragePreference();
        for (double[] userVector : userVectors) {
            userVector[0] = globalAverage;
            userVector[USER_BIAS_INDEX] = 0.0d; // learned user bias starts at zero
            userVector[ITEM_BIAS_INDEX] = 1.0d; // constant partner of the item bias
            for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
                userVector[feature] = random.nextGaussian() * randomNoise;
            }
        }
        for (double[] itemVector : itemVectors) {
            itemVector[0] = 1.0d;               // constant partner of the global average
            itemVector[USER_BIAS_INDEX] = 1.0d; // constant partner of the user bias
            itemVector[ITEM_BIAS_INDEX] = 0.0d; // learned item bias starts at zero
            for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
                itemVector[feature] = random.nextGaussian() * randomNoise;
            }
        }

        cachePreferences();
        shufflePreferences();
    }

    /** Returns the total number of preferences summed over all users of the data model. */
    private int countPreferences() throws TasteException {
        int total = 0;
        LongPrimitiveIterator userIDs = dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            total += dataModel.getPreferencesFromUser(userIDs.nextLong()).length();
        }
        return total;
    }

    /** Fills the two parallel ID arrays with one entry per preference in the data model. */
    private void cachePreferences() throws TasteException {
        int numPreferences = countPreferences();
        cachedUserIDs = new long[numPreferences];
        cachedItemIDs = new long[numPreferences];

        LongPrimitiveIterator userIDs = dataModel.getUserIDs();
        int index = 0;
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            for (Preference preference : dataModel.getPreferencesFromUser(userID)) {
                cachedUserIDs[index] = userID;
                cachedItemIDs[index] = preference.getItemID();
                index++;
            }
        }
    }

    /** Randomly permutes the cached (user, item) pairs in place with a Fisher-Yates shuffle. */
    protected void shufflePreferences() {
        RandomWrapper random = RandomUtils.getRandom();
        for (int currentPos = cachedUserIDs.length - 1; currentPos > 0; currentPos--) {
            int swapPos = random.nextInt(currentPos + 1);
            swapCachedPreferences(currentPos, swapPos);
        }
    }

    /** Swaps the cached pair at {@code posA} with the one at {@code posB}, keeping both arrays aligned. */
    private void swapCachedPreferences(int posA, int posB) {
        long userID = cachedUserIDs[posA];
        long itemID = cachedItemIDs[posA];
        cachedUserIDs[posA] = cachedUserIDs[posB];
        cachedItemIDs[posA] = cachedItemIDs[posB];
        cachedUserIDs[posB] = userID;
        cachedItemIDs[posB] = itemID;
    }

    /**
     * Runs the training: prepares the vectors, then performs {@code numIterations} passes over
     * the cached ratings, multiplying the learning rate by {@code learningRateDecay} after each pass.
     *
     * @return the factorization built from the trained user and item vectors
     * @throws TasteException if the data model cannot be read
     */
    @Override // org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer
    public Factorization factorize() throws TasteException {
        prepareTraining();
        double currentLearningRate = learningRate;

        for (int iteration = 0; iteration < numIterations; iteration++) {
            for (int index = 0; index < cachedUserIDs.length; index++) {
                long userID = cachedUserIDs[index];
                long itemID = cachedItemIDs[index];
                // The rating itself is not cached, only its IDs; look it up on every visit.
                float rating = dataModel.getPreferenceValue(userID, itemID);
                updateParameters(userID, itemID, rating, currentLearningRate);
            }
            currentLearningRate *= learningRateDecay;
        }
        return createFactorization(userVectors, itemVectors);
    }

    /**
     * Computes the mean of all preference values in the data model.
     *
     * @return the average reported by a {@link FullRunningAverage} fed with every preference value
     * @throws TasteException if the data model cannot be read
     */
    double getAveragePreference() throws TasteException {
        FullRunningAverage average = new FullRunningAverage();
        LongPrimitiveIterator userIDs = dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            for (Preference preference : dataModel.getPreferencesFromUser(userIDs.nextLong())) {
                average.addDatum(preference.getValue());
            }
        }
        return average.getAverage();
    }

    /**
     * Performs one SGD step for a single rating: adjusts the user bias, the item bias and all
     * latent features of the affected user and item in the direction that reduces the
     * prediction error, with regularization.
     *
     * @param userID              user who gave the rating
     * @param itemID              item that was rated
     * @param rating              observed rating value
     * @param currentLearningRate step size for this update
     */
    protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
        int userIdx = userIndex(userID);
        int itemIdx = itemIndex(itemID);

        double[] userVector = userVectors[userIdx];
        double[] itemVector = itemVectors[itemIdx];
        double err = rating - predictRating(userIdx, itemIdx);

        // Biases use their own, additionally scaled learning rate and regularization.
        userVector[USER_BIAS_INDEX] += biasLearningRate * currentLearningRate
                * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);
        itemVector[ITEM_BIAS_INDEX] += biasLearningRate * currentLearningRate
                * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);

        for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
            // Read both old values first so the item update does not see the new user value.
            double userFeature = userVector[feature];
            double itemFeature = itemVector[feature];
            userVector[feature] += currentLearningRate * (err * itemFeature - preventOverfitting * userFeature);
            itemVector[feature] += currentLearningRate * (err * userFeature - preventOverfitting * itemFeature);
        }
    }

    /**
     * Predicts a rating as the dot product of the full user and item vectors, reserved slots
     * included.
     *
     * @param userIdx row index into {@link #userVectors}
     * @param itemIdx row index into {@link #itemVectors}
     */
    private double predictRating(int userIdx, int itemIdx) {
        double sum = 0.0d;
        for (int feature = 0; feature < numFeatures; feature++) {
            sum += userVectors[userIdx][feature] * itemVectors[itemIdx][feature];
        }
        return sum;
    }
}
