/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.mahout.cf.taste.impl.TasteTestCase;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
import org.apache.mahout.cf.taste.impl.recommender.svd.ALSWRFactorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.SVDRecommender;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ALSWRFactorizerTest
extends TasteTestCase {
    private ALSWRFactorizer factorizer;
    private DataModel dataModel;
    private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizerTest.class);

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        FastByIDMap userData = new FastByIDMap();
        userData.put(1L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1L, 1L, 5.0f), new GenericPreference(1L, 2L, 5.0f), new GenericPreference(1L, 3L, 2.0f))));
        userData.put(2L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2L, 1L, 2.0f), new GenericPreference(2L, 3L, 3.0f), new GenericPreference(2L, 4L, 5.0f))));
        userData.put(3L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(3L, 2L, 5.0f), new GenericPreference(3L, 4L, 3.0f))));
        userData.put(4L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(4L, 1L, 3.0f), new GenericPreference(4L, 4L, 5.0f))));
        this.dataModel = new GenericDataModel(userData);
        this.factorizer = new ALSWRFactorizer(this.dataModel, 3, 0.065, 10);
    }

    @Test
    public void setFeatureColumn() throws Exception {
        ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(this.factorizer);
        DenseVector vector = new DenseVector(new double[]{0.5, 2.0, 1.5});
        int index = 1;
        features.setFeatureColumnInM(index, (Vector)vector);
        double[][] matrix = features.getM();
        ALSWRFactorizerTest.assertEquals((double)vector.get(0), (double)matrix[index][0], (double)1.0E-6);
        ALSWRFactorizerTest.assertEquals((double)vector.get(1), (double)matrix[index][1], (double)1.0E-6);
        ALSWRFactorizerTest.assertEquals((double)vector.get(2), (double)matrix[index][2], (double)1.0E-6);
    }

    @Test
    public void ratingVector() throws Exception {
        PreferenceArray prefs = this.dataModel.getPreferencesFromUser(1L);
        Vector ratingVector = ALSWRFactorizer.ratingVector((PreferenceArray)prefs);
        ALSWRFactorizerTest.assertEquals((long)prefs.length(), (long)ratingVector.getNumNondefaultElements());
        ALSWRFactorizerTest.assertEquals((double)prefs.get(0).getValue(), (double)ratingVector.get(0), (double)1.0E-6);
        ALSWRFactorizerTest.assertEquals((double)prefs.get(1).getValue(), (double)ratingVector.get(1), (double)1.0E-6);
        ALSWRFactorizerTest.assertEquals((double)prefs.get(2).getValue(), (double)ratingVector.get(2), (double)1.0E-6);
    }

    @Test
    public void averageRating() throws Exception {
        ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(this.factorizer);
        ALSWRFactorizerTest.assertEquals((double)2.5, (double)features.averateRating(3L), (double)1.0E-6);
    }

    @Test
    public void initializeM() throws Exception {
        ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(this.factorizer);
        double[][] M = features.getM();
        ALSWRFactorizerTest.assertEquals((double)3.333333333, (double)M[0][0], (double)1.0E-6);
        ALSWRFactorizerTest.assertEquals((double)5.0, (double)M[1][0], (double)1.0E-6);
        ALSWRFactorizerTest.assertEquals((double)2.5, (double)M[2][0], (double)1.0E-6);
        ALSWRFactorizerTest.assertEquals((double)4.333333333, (double)M[3][0], (double)1.0E-6);
        for (int itemIndex = 0; itemIndex < this.dataModel.getNumItems(); ++itemIndex) {
            for (int feature = 1; feature < 3; ++feature) {
                ALSWRFactorizerTest.assertTrue((M[itemIndex][feature] >= 0.0 ? 1 : 0) != 0);
                ALSWRFactorizerTest.assertTrue((M[itemIndex][feature] <= 0.1 ? 1 : 0) != 0);
            }
        }
    }

    @ThreadLeakLingering(linger=10)
    @Test
    public void toyExample() throws Exception {
        SVDRecommender svdRecommender = new SVDRecommender(this.dataModel, (Factorizer)this.factorizer);
        FullRunningAverage avg = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            for (Preference pref : this.dataModel.getPreferencesFromUser(userID)) {
                double rating = pref.getValue();
                double estimate = svdRecommender.estimatePreference(userID, pref.getItemID());
                double err = rating - estimate;
                avg.addDatum(err * err);
            }
        }
        double rmse = Math.sqrt(avg.getAverage());
        ALSWRFactorizerTest.assertTrue((rmse < 0.2 ? 1 : 0) != 0);
    }

    @Test
    public void toyExampleImplicit() throws Exception {
        SparseRowMatrix observations = new SparseRowMatrix(4, 4, new Vector[]{new DenseVector(new double[]{5.0, 5.0, 2.0, 0.0}), new DenseVector(new double[]{2.0, 0.0, 3.0, 5.0}), new DenseVector(new double[]{0.0, 5.0, 0.0, 3.0}), new DenseVector(new double[]{3.0, 0.0, 0.0, 5.0})});
        SparseRowMatrix preferences = new SparseRowMatrix(4, 4, new Vector[]{new DenseVector(new double[]{1.0, 1.0, 1.0, 0.0}), new DenseVector(new double[]{1.0, 0.0, 1.0, 1.0}), new DenseVector(new double[]{0.0, 1.0, 0.0, 1.0}), new DenseVector(new double[]{1.0, 0.0, 0.0, 1.0})});
        double alpha = 20.0;
        ALSWRFactorizer factorizer = new ALSWRFactorizer(this.dataModel, 3, 0.065, 5, true, alpha);
        SVDRecommender svdRecommender = new SVDRecommender(this.dataModel, (Factorizer)factorizer);
        FullRunningAverage avg = new FullRunningAverage();
        Iterator sliceIterator = preferences.iterateAll();
        while (sliceIterator.hasNext()) {
            MatrixSlice slice = (MatrixSlice)sliceIterator.next();
            for (Vector.Element e : slice.vector().all()) {
                long userID = slice.index() + 1;
                long itemID = e.index() + 1;
                if (Double.isNaN(e.get())) continue;
                double pref = e.get();
                double estimate = svdRecommender.estimatePreference(userID, itemID);
                double confidence = 1.0 + alpha * observations.getQuick(slice.index(), e.index());
                double err = confidence * (pref - estimate) * (pref - estimate);
                avg.addDatum(err);
                log.info("Comparing preference of user [{}] towards item [{}], was [{}] with confidence [{}] estimate is [{}]", new Object[]{slice.index(), e.index(), pref, confidence, estimate});
            }
        }
        double rmse = Math.sqrt(avg.getAverage());
        log.info("RMSE: {}", (Object)rmse);
        ALSWRFactorizerTest.assertTrue((rmse < 0.4 ? 1 : 0) != 0);
    }
}

