/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.svd.AbstractFactorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;

/**
 * Matrix factorization of rating data trained with stochastic gradient descent.
 *
 * <p>Every user and item is represented by a vector of {@code numFeatures} doubles. The first
 * {@link #FEATURE_OFFSET} slots are reserved so that a plain dot product of a user vector and
 * an item vector yields {@code globalAverage + userBias + itemBias + latentFeatures}:
 * <ul>
 *   <li>slot 0: user side holds the global average rating, item side holds the constant 1;</li>
 *   <li>slot {@link #USER_BIAS_INDEX}: user side holds the learned user bias, item side holds 1;</li>
 *   <li>slot {@link #ITEM_BIAS_INDEX}: user side holds 1, item side holds the learned item bias.</li>
 * </ul>
 * The remaining slots are the latent features, initialised with scaled Gaussian noise.
 */
public class RatingSGDFactorizer extends AbstractFactorizer {
    /** Number of reserved leading slots (global average, user bias, item bias) in each vector. */
    protected static final int FEATURE_OFFSET = 3;
    /** Slot of the user bias inside a user vector (the matching item slot is fixed to 1). */
    protected static final int USER_BIAS_INDEX = 1;
    /** Slot of the item bias inside an item vector (the matching user slot is fixed to 1). */
    protected static final int ITEM_BIAS_INDEX = 2;
    /** Slot holding the global average (user side) and its constant multiplier 1 (item side). */
    private static final int GLOBAL_AVERAGE_INDEX = 0;

    /** Factor the learning rate is multiplied by after each full pass over the preferences. */
    protected final double learningRateDecay;
    /** Initial step size of the gradient updates. */
    protected final double learningRate;
    /** Regularization coefficient applied to the updated parameters. */
    protected final double preventOverfitting;
    /** Total vector length: the requested latent features plus {@link #FEATURE_OFFSET}. */
    protected final int numFeatures;
    /** Number of passes over the cached preferences. */
    private final int numIterations;
    /** Scale applied to the Gaussian samples used to initialise the latent features. */
    protected final double randomNoise;
    /** One row per user index, each of length {@link #numFeatures}. */
    protected double[][] userVectors;
    /** One row per item index, each of length {@link #numFeatures}. */
    protected double[][] itemVectors;
    /** Source of users, items and preference values. */
    protected final DataModel dataModel;
    /** User ID of every cached preference; parallel to {@link #cachedItemIDs}. */
    private long[] cachedUserIDs;
    /** Item ID of every cached preference; parallel to {@link #cachedUserIDs}. */
    private long[] cachedItemIDs;
    /** Extra multiplier on the learning rate when updating the bias slots. */
    protected double biasLearningRate = 0.5;
    /** Extra multiplier on {@link #preventOverfitting} when regularizing the bias slots. */
    protected double biasReg = 0.1;

    /**
     * Creates a factorizer with learning rate 0.01, regularization 0.1, noise scale 0.01 and
     * no learning-rate decay (decay factor 1.0).
     *
     * @param dataModel     preference data to factorize
     * @param numFeatures   number of latent features (reserved slots are added on top)
     * @param numIterations number of passes over the preferences
     * @throws TasteException declared for the delegated constructor chain
     */
    public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
        this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
    }

    /**
     * Creates a fully parameterised factorizer.
     *
     * @param dataModel          preference data to factorize
     * @param numFeatures        number of latent features (reserved slots are added on top)
     * @param learningRate       initial gradient step size
     * @param preventOverfitting regularization coefficient
     * @param randomNoise        scale of the Gaussian noise used for initialisation
     * @param numIterations      number of passes over the preferences
     * @param learningRateDecay  factor applied to the learning rate after each pass
     * @throws TasteException declared by the superclass constructor
     */
    public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting, double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;
        this.numFeatures = numFeatures + FEATURE_OFFSET;
        this.numIterations = numIterations;
        this.learningRate = learningRate;
        this.learningRateDecay = learningRateDecay;
        this.preventOverfitting = preventOverfitting;
        this.randomNoise = randomNoise;
    }

    /**
     * Allocates and initialises the user and item vectors, then caches and shuffles the
     * (user, item) pairs that will be visited during training.
     *
     * @throws TasteException if the data model cannot be read
     */
    protected void prepareTraining() throws TasteException {
        RandomWrapper random = RandomUtils.getRandom();
        userVectors = new double[dataModel.getNumUsers()][numFeatures];
        itemVectors = new double[dataModel.getNumItems()][numFeatures];
        double globalAverage = getAveragePreference();

        // Users are initialised before items so the random sequence is consumed in a fixed order.
        for (double[] userVector : userVectors) {
            userVector[GLOBAL_AVERAGE_INDEX] = globalAverage;
            userVector[USER_BIAS_INDEX] = 0.0; // learned user bias starts at zero
            userVector[ITEM_BIAS_INDEX] = 1.0; // multiplies the item's bias
            fillLatentFeatures(userVector, random);
        }
        for (double[] itemVector : itemVectors) {
            itemVector[GLOBAL_AVERAGE_INDEX] = 1.0; // multiplies the global average
            itemVector[USER_BIAS_INDEX] = 1.0;      // multiplies the user's bias
            itemVector[ITEM_BIAS_INDEX] = 0.0;      // learned item bias starts at zero
            fillLatentFeatures(itemVector, random);
        }

        cachePreferences();
        shufflePreferences();
    }

    /** Overwrites every latent slot of {@code vector} with a Gaussian sample scaled by {@link #randomNoise}. */
    private void fillLatentFeatures(double[] vector, RandomWrapper random) {
        for (int slot = FEATURE_OFFSET; slot < numFeatures; slot++) {
            vector[slot] = random.nextGaussian() * randomNoise;
        }
    }

    /** Returns the total number of preferences, summed over all users of the data model. */
    private int countPreferences() throws TasteException {
        int total = 0;
        for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext(); ) {
            total += dataModel.getPreferencesFromUser(users.nextLong()).length();
        }
        return total;
    }

    /** Copies the user ID and item ID of every preference into the two parallel cache arrays. */
    private void cachePreferences() throws TasteException {
        int capacity = countPreferences();
        cachedUserIDs = new long[capacity];
        cachedItemIDs = new long[capacity];

        LongPrimitiveIterator users = dataModel.getUserIDs();
        int next = 0;
        while (users.hasNext()) {
            long userID = users.nextLong();
            for (Preference preference : dataModel.getPreferencesFromUser(userID)) {
                cachedUserIDs[next] = userID;
                cachedItemIDs[next] = preference.getItemID();
                next++;
            }
        }
    }

    /**
     * Randomly permutes the cached preferences in place: walking from the last position down to
     * position 1, each entry is swapped with one drawn from the positions at or before it.
     */
    protected void shufflePreferences() {
        RandomWrapper random = RandomUtils.getRandom();
        int position = cachedUserIDs.length - 1;
        while (position > 0) {
            int partner = random.nextInt(position + 1);
            swapCachedPreferences(position, partner);
            position--;
        }
    }

    /** Exchanges the cached (user, item) pairs stored at the two given positions. */
    private void swapCachedPreferences(int posA, int posB) {
        long userAtA = cachedUserIDs[posA];
        long itemAtA = cachedItemIDs[posA];

        cachedUserIDs[posA] = cachedUserIDs[posB];
        cachedItemIDs[posA] = cachedItemIDs[posB];

        cachedUserIDs[posB] = userAtA;
        cachedItemIDs[posB] = itemAtA;
    }

    /**
     * Trains the model: after preparation, performs {@code numIterations} passes over the cached
     * preferences, applying one parameter update per preference and decaying the learning rate
     * after each pass.
     *
     * @return the factorization built from the trained user and item vectors
     * @throws TasteException if the data model cannot be read
     */
    @Override
    public Factorization factorize() throws TasteException {
        prepareTraining();

        double stepSize = learningRate;
        for (int pass = 0; pass < numIterations; pass++) {
            for (int i = 0; i < cachedUserIDs.length; i++) {
                long userId = cachedUserIDs[i];
                long itemId = cachedItemIDs[i];
                // The rating is looked up from the data model on every visit; only IDs are cached.
                float rating = dataModel.getPreferenceValue(userId, itemId).floatValue();
                updateParameters(userId, itemId, rating, stepSize);
            }
            stepSize *= learningRateDecay;
        }
        return createFactorization(userVectors, itemVectors);
    }

    /**
     * Computes the mean of all preference values in the data model.
     *
     * @return the running average over every preference of every user
     * @throws TasteException if the data model cannot be read
     */
    double getAveragePreference() throws TasteException {
        FullRunningAverage mean = new FullRunningAverage();
        for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext(); ) {
            PreferenceArray preferences = dataModel.getPreferencesFromUser(users.nextLong());
            for (Preference preference : preferences) {
                mean.addDatum(preference.getValue());
            }
        }
        return mean.getAverage();
    }

    /**
     * Applies one gradient step for a single observed rating.
     *
     * @param userID              ID of the rating user
     * @param itemID              ID of the rated item
     * @param rating              observed preference value
     * @param currentLearningRate step size to use for this update
     */
    protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
        int userIdx = userIndex(userID);
        int itemIdx = itemIndex(itemID);
        double[] userVector = userVectors[userIdx];
        double[] itemVector = itemVectors[itemIdx];

        double err = (double) rating - predictRating(userIdx, itemIdx);

        // Bias slots use a scaled learning rate and a scaled regularization coefficient.
        double biasStep = biasLearningRate * currentLearningRate;
        double biasPenalty = biasReg * preventOverfitting;
        userVector[USER_BIAS_INDEX] += biasStep * (err - biasPenalty * userVector[USER_BIAS_INDEX]);
        itemVector[ITEM_BIAS_INDEX] += biasStep * (err - biasPenalty * itemVector[ITEM_BIAS_INDEX]);

        for (int slot = FEATURE_OFFSET; slot < numFeatures; slot++) {
            // Both gradients are computed from the values held before this step.
            double userFeature = userVector[slot];
            double itemFeature = itemVector[slot];

            double userGradient = err * itemFeature - preventOverfitting * userFeature;
            userVector[slot] += currentLearningRate * userGradient;

            double itemGradient = err * userFeature - preventOverfitting * itemFeature;
            itemVector[slot] += currentLearningRate * itemGradient;
        }
    }

    /** Returns the dot product of the user vector and item vector at the given row indices. */
    private double predictRating(int userRow, int itemRow) {
        double[] userVector = userVectors[userRow];
        double[] itemVector = itemVectors[itemRow];
        double dot = 0.0;
        for (int slot = 0; slot < numFeatures; slot++) {
            dot += userVector[slot] * itemVector[slot];
        }
        return dot;
    }
}

