package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.cf.taste.impl.TasteTestCase;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
import org.apache.mahout.cf.taste.impl.recommender.svd.ParallelSGDFactorizer;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.VectorFunction;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizerTest.class */
public class ParallelSGDFactorizerTest extends TasteTestCase {
    protected DataModel dataModel;
    protected int rank;
    protected double lambda;
    protected int numIterations;
    private RandomWrapper random = RandomUtils.getRandom();
    protected Factorizer factorizer;
    protected SVDRecommender svdRecommender;
    private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizerTest.class);

    private Matrix randomMatrix(int i, int i2, double d) {
        double[][] dArr = new double[i][i2];
        for (int i3 = 0; i3 < i; i3++) {
            for (int i4 = 0; i4 < i2; i4++) {
                dArr[i3][i4] = this.random.nextDouble() * d;
            }
        }
        return new DenseMatrix(dArr);
    }

    private void normalize(Matrix matrix, final double d) {
        final double maxValue = matrix.aggregateColumns(new VectorFunction() { // from class: org.apache.mahout.cf.taste.impl.recommender.svd.ParallelSGDFactorizerTest.1
            public double apply(Vector vector) {
                return vector.maxValue();
            }
        }).maxValue();
        final double minValue = matrix.aggregateColumns(new VectorFunction() { // from class: org.apache.mahout.cf.taste.impl.recommender.svd.ParallelSGDFactorizerTest.2
            public double apply(Vector vector) {
                return vector.minValue();
            }
        }).minValue();
        matrix.assign(new DoubleFunction() { // from class: org.apache.mahout.cf.taste.impl.recommender.svd.ParallelSGDFactorizerTest.3
            public double apply(double d2) {
                return ((d2 - minValue) * d) / (maxValue - minValue);
            }
        });
    }

    public void setUpSyntheticData() throws Exception {
        this.rank = 20;
        this.lambda = 1.0E-9d;
        this.numIterations = 100;
        Matrix times = randomMatrix(2000, this.rank, 1.0d).times(randomMatrix(this.rank, 1000, 1.0d));
        normalize(times, 5.0d);
        FastByIDMap fastByIDMap = new FastByIDMap();
        for (int i = 0; i < 2000; i++) {
            ArrayList newArrayList = Lists.newArrayList();
            for (int i2 = 0; i2 < 1000; i2++) {
                if (this.random.nextDouble() <= 0.5d) {
                    newArrayList.add(new GenericPreference(i, i2, (float) times.get(i, i2)));
                }
            }
            fastByIDMap.put(i, new GenericUserPreferenceArray(newArrayList));
        }
        this.dataModel = new GenericDataModel(fastByIDMap);
    }

    public void setUpToyData() throws Exception {
        this.rank = 3;
        this.lambda = 0.01d;
        this.numIterations = 1000;
        FastByIDMap fastByIDMap = new FastByIDMap();
        fastByIDMap.put(1L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1L, 1L, 5.0f), new GenericPreference(1L, 2L, 5.0f), new GenericPreference(1L, 3L, 2.0f))));
        fastByIDMap.put(2L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2L, 1L, 2.0f), new GenericPreference(2L, 3L, 3.0f), new GenericPreference(2L, 4L, 5.0f))));
        fastByIDMap.put(3L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(3L, 2L, 5.0f), new GenericPreference(3L, 4L, 3.0f))));
        fastByIDMap.put(4L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(4L, 1L, 3.0f), new GenericPreference(4L, 4L, 5.0f))));
        this.dataModel = new GenericDataModel(fastByIDMap);
    }

    @Test
    public void testPreferenceShufflerWithSyntheticData() throws Exception {
        setUpSyntheticData();
        ParallelSGDFactorizer.PreferenceShuffler preferenceShuffler = new ParallelSGDFactorizer.PreferenceShuffler(this.dataModel);
        preferenceShuffler.shuffle();
        preferenceShuffler.stage();
        FastByIDMap fastByIDMap = new FastByIDMap();
        for (int i = 0; i < preferenceShuffler.size(); i++) {
            Preference preference = preferenceShuffler.get(i);
            assertEquals(preference.getValue(), this.dataModel.getPreferenceValue(preference.getUserID(), preference.getItemID()).floatValue(), 0.0d);
            if (!fastByIDMap.containsKey(preference.getUserID())) {
                fastByIDMap.put(preference.getUserID(), new FastByIDMap());
            }
            assertNull(((FastByIDMap) fastByIDMap.get(preference.getUserID())).get(preference.getItemID()));
            ((FastByIDMap) fastByIDMap.get(preference.getUserID())).put(preference.getItemID(), true);
        }
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        int i2 = 0;
        while (userIDs.hasNext()) {
            for (Preference preference2 : this.dataModel.getPreferencesFromUser(userIDs.nextLong())) {
                assertTrue(((Boolean) ((FastByIDMap) fastByIDMap.get(preference2.getUserID())).get(preference2.getItemID())).booleanValue());
                i2++;
            }
        }
        assertEquals(i2, preferenceShuffler.size());
    }

    @Test
    @ThreadLeakLingering(linger = 1000)
    public void testFactorizerWithToyData() throws Exception {
        setUpToyData();
        long currentTimeMillis = System.currentTimeMillis();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01d, 1.0d, 0, 0.0d);
        Factorization factorize = this.factorizer.factorize();
        long currentTimeMillis2 = System.currentTimeMillis() - currentTimeMillis;
        FullRunningAverage fullRunningAverage = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long nextLong = userIDs.nextLong();
            Iterator it = this.dataModel.getPreferencesFromUser(nextLong).iterator();
            while (it.hasNext()) {
                double value = r0.getValue() - new DenseVector(factorize.getUserFeatures(nextLong)).dot(new DenseVector(factorize.getItemFeatures(((Preference) it.next()).getItemID())));
                fullRunningAverage.addDatum(value * value);
            }
        }
        double d = 0.0d;
        LongPrimitiveIterator userIDs2 = this.dataModel.getUserIDs();
        while (userIDs2.hasNext()) {
            DenseVector denseVector = new DenseVector(factorize.getUserFeatures(userIDs2.nextLong()));
            d += denseVector.dot(denseVector);
        }
        LongPrimitiveIterator itemIDs = this.dataModel.getItemIDs();
        while (itemIDs.hasNext()) {
            DenseVector denseVector2 = new DenseVector(factorize.getUserFeatures(itemIDs.nextLong()));
            d += denseVector2.dot(denseVector2);
        }
        double sqrt = Math.sqrt(fullRunningAverage.getAverage());
        logger.info("RMSE: " + sqrt + ";\tLoss: " + ((fullRunningAverage.getAverage() / 2.0d) + ((this.lambda / 2.0d) * d)) + ";\tTime Used: " + currentTimeMillis2);
        assertTrue(sqrt < 0.2d);
    }

    @Test
    @ThreadLeakLingering(linger = 1000)
    public void testRecommenderWithToyData() throws Exception {
        setUpToyData();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01d, 1.0d, 0, 0.0d);
        this.svdRecommender = new SVDRecommender(this.dataModel, this.factorizer);
        FullRunningAverage fullRunningAverage = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            for (Preference preference : this.dataModel.getPreferencesFromUser(userIDs.nextLong())) {
                double value = preference.getValue() - this.svdRecommender.estimatePreference(r0, preference.getItemID());
                fullRunningAverage.addDatum(value * value);
            }
        }
        double sqrt = Math.sqrt(fullRunningAverage.getAverage());
        logger.info("rmse: " + sqrt);
        assertTrue(sqrt < 0.2d);
    }

    @Test
    public void testFactorizerWithWithSyntheticData() throws Exception {
        setUpSyntheticData();
        long currentTimeMillis = System.currentTimeMillis();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01d, 1.0d, 0, 0.0d);
        Factorization factorize = this.factorizer.factorize();
        long currentTimeMillis2 = System.currentTimeMillis() - currentTimeMillis;
        FullRunningAverage fullRunningAverage = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long nextLong = userIDs.nextLong();
            Iterator it = this.dataModel.getPreferencesFromUser(nextLong).iterator();
            while (it.hasNext()) {
                double value = r0.getValue() - new DenseVector(factorize.getUserFeatures(nextLong)).dot(new DenseVector(factorize.getItemFeatures(((Preference) it.next()).getItemID())));
                fullRunningAverage.addDatum(value * value);
            }
        }
        double d = 0.0d;
        LongPrimitiveIterator userIDs2 = this.dataModel.getUserIDs();
        while (userIDs2.hasNext()) {
            DenseVector denseVector = new DenseVector(factorize.getUserFeatures(userIDs2.nextLong()));
            d += denseVector.dot(denseVector);
        }
        LongPrimitiveIterator itemIDs = this.dataModel.getItemIDs();
        while (itemIDs.hasNext()) {
            DenseVector denseVector2 = new DenseVector(factorize.getUserFeatures(itemIDs.nextLong()));
            d += denseVector2.dot(denseVector2);
        }
        double sqrt = Math.sqrt(fullRunningAverage.getAverage());
        logger.info("RMSE: " + sqrt + ";\tLoss: " + ((fullRunningAverage.getAverage() / 2.0d) + ((this.lambda / 2.0d) * d)) + ";\tTime Used: " + currentTimeMillis2 + "ms");
        assertTrue(sqrt < 0.2d);
    }

    @Test
    public void testRecommenderWithSyntheticData() throws Exception {
        setUpSyntheticData();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01d, 1.0d, 0, 0.0d);
        this.svdRecommender = new SVDRecommender(this.dataModel, this.factorizer);
        FullRunningAverage fullRunningAverage = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            for (Preference preference : this.dataModel.getPreferencesFromUser(userIDs.nextLong())) {
                double value = preference.getValue() - this.svdRecommender.estimatePreference(r0, preference.getItemID());
                fullRunningAverage.addDatum(value * value);
            }
        }
        double sqrt = Math.sqrt(fullRunningAverage.getAverage());
        logger.info("rmse: " + sqrt);
        assertTrue(sqrt < 0.2d);
    }
}
