package org.apache.spark.mllib.recommendation;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.junit.Assert;
import org.junit.Test;
import scala.Tuple2;
import scala.Tuple3;

/* loaded from: input_file:org/apache/spark/mllib/recommendation/JavaALSSuite.class */
/**
 * Exercises the Java-facing API of {@link ALS} and {@link MatrixFactorizationModel}.
 *
 * <p>Each test generates ratings through {@code ALSSuite.generateRatingsAsJava}, trains a
 * model, and then checks either the predictions for every (user, product) pair or the
 * ordering of the top recommendations.
 *
 * <p>NOTE(review): the fourth argument to {@code generateRatingsAsJava} (0.7) and the two
 * trailing boolean flags are defined outside this file; their exact meaning should be
 * confirmed against {@code ALSSuite}. The fifth flag is {@code true} in exactly the tests
 * that train with implicit preferences, and the sixth only in the negative-weight test.
 */
public class JavaALSSuite extends SharedSparkSession {

    /**
     * Predicts a rating for every (user, product) pair and asserts the model's accuracy.
     *
     * <p>Both {@code trueRatings} and {@code truePrefs} are read at index
     * {@code product * users + user}.
     *
     * @param model          trained model under test
     * @param users          number of users; user ids {@code 0 .. users - 1} are queried
     * @param products       number of products; product ids {@code 0 .. products - 1} are queried
     * @param trueRatings    expected ratings, indexed as described above
     * @param matchThreshold exclusive upper bound on the per-prediction absolute error
     *                       (explicit case) or on the confidence-weighted RMSE (implicit case)
     * @param implicitPrefs  {@code true} to validate with the confidence-weighted RMSE,
     *                       {@code false} to validate each prediction individually
     * @param truePrefs      expected preferences; only read when {@code implicitPrefs} is true
     */
    private void validatePrediction(
            MatrixFactorizationModel model,
            int users,
            int products,
            double[] trueRatings,
            double matchThreshold,
            boolean implicitPrefs,
            double[] truePrefs) {
        List<Tuple2<Integer, Integer>> userProductPairs = new ArrayList<>(users * products);
        for (int user = 0; user < users; user++) {
            for (int product = 0; product < products; product++) {
                userProductPairs.add(new Tuple2<>(user, product));
            }
        }

        List<Rating> predictions = model.predict(jsc.parallelizePairs(userProductPairs)).collect();
        Assert.assertEquals(users * products, predictions.size());

        if (!implicitPrefs) {
            // Explicit feedback: every single prediction must be close to its true rating.
            for (Rating predicted : predictions) {
                double prediction = predicted.rating();
                double expected = trueRatings[flatIndex(predicted, users)];
                Assert.assertTrue(
                        String.format("Prediction=%2.4f not below match threshold of %2.2f",
                                prediction, matchThreshold),
                        Math.abs(prediction - expected) < matchThreshold);
            }
            return;
        }

        // Implicit feedback: squared errors against the true preference are weighted by a
        // confidence of 1 + |true rating| and normalised by the total confidence.
        double weightedSquaredError = 0.0;
        double totalConfidence = 0.0;
        for (Rating predicted : predictions) {
            int index = flatIndex(predicted, users);
            double error = truePrefs[index] - predicted.rating();
            double confidence = 1.0 + Math.abs(trueRatings[index]);
            weightedSquaredError += confidence * error * error;
            totalConfidence += confidence;
        }
        double rmse = Math.sqrt(weightedSquaredError / totalConfidence);
        Assert.assertTrue(
                String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
                        rmse, matchThreshold),
                rmse < matchThreshold);
    }

    /** Returns the position of {@code rating}'s (user, product) pair in the flattened arrays. */
    private static int flatIndex(Rating rating, int users) {
        return rating.product() * users + rating.user();
    }

    /** Trains an explicit-feedback model via the static {@code ALS.train} entry point. */
    @Test
    public void runALSUsingStaticMethods() {
        int features = 1;
        int iterations = 15;
        int users = 50;
        int products = 100;
        Tuple3<List<Rating>, double[], double[]> testData =
                ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);

        MatrixFactorizationModel model =
                ALS.train(jsc.parallelize(testData._1()).rdd(), features, iterations);

        validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
    }

    /** Trains an explicit-feedback model via the {@code ALS} builder, passing a Java RDD to run. */
    @Test
    public void runALSUsingConstructor() {
        int features = 2;
        int iterations = 15;
        int users = 100;
        int products = 200;
        Tuple3<List<Rating>, double[], double[]> testData =
                ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);

        MatrixFactorizationModel model = new ALS()
                .setRank(features)
                .setIterations(iterations)
                .run(jsc.parallelize(testData._1()));

        validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
    }

    /** Trains an implicit-feedback model via the static {@code ALS.trainImplicit} entry point. */
    @Test
    public void runImplicitALSUsingStaticMethods() {
        int features = 1;
        int iterations = 15;
        int users = 80;
        int products = 160;
        Tuple3<List<Rating>, double[], double[]> testData =
                ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);

        MatrixFactorizationModel model =
                ALS.trainImplicit(jsc.parallelize(testData._1()).rdd(), features, iterations);

        validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
    }

    /** Trains an implicit-feedback model via the {@code ALS} builder. */
    @Test
    public void runImplicitALSUsingConstructor() {
        int features = 2;
        int iterations = 15;
        int users = 100;
        int products = 200;
        Tuple3<List<Rating>, double[], double[]> testData =
                ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);

        MatrixFactorizationModel model = new ALS()
                .setRank(features)
                .setIterations(iterations)
                .setImplicitPrefs(true)
                .run(jsc.parallelize(testData._1()).rdd());

        validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
    }

    /**
     * Trains an implicit-feedback model with a fixed seed on data generated with the last
     * {@code generateRatingsAsJava} flag enabled.
     */
    @Test
    public void runImplicitALSWithNegativeWeight() {
        int features = 2;
        int iterations = 15;
        int users = 80;
        int products = 160;
        Tuple3<List<Rating>, double[], double[]> testData =
                ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);

        MatrixFactorizationModel model = new ALS()
                .setRank(features)
                .setIterations(iterations)
                .setImplicitPrefs(true)
                .setSeed(8675309L)
                .run(jsc.parallelize(testData._1()).rdd());

        validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
    }

    /** Checks the size and ordering of product and user recommendations from a trained model. */
    @Test
    public void runRecommend() {
        int features = 5;
        int iterations = 10;
        int users = 200;
        int products = 50;
        List<Rating> testData =
                ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false)._1();

        MatrixFactorizationModel model = new ALS()
                .setRank(features)
                .setIterations(iterations)
                .setImplicitPrefs(true)
                .setSeed(8675309L)
                .run(jsc.parallelize(testData).rdd());

        validateRecommendations(model.recommendProducts(1, 10), 10);
        validateRecommendations(model.recommendUsers(1, 20), 20);
    }

    /**
     * Asserts that {@code recommendations} has exactly {@code howMany} entries, that ratings
     * never increase from one entry to the next, and that the first rating exceeds 0.7.
     *
     * @param recommendations recommendations returned by the model
     * @param howMany         expected number of recommendations
     */
    private static void validateRecommendations(Rating[] recommendations, int howMany) {
        Assert.assertEquals(howMany, recommendations.length);
        // Fail with a readable message instead of an ArrayIndexOutOfBoundsException below.
        Assert.assertTrue("Expected at least one recommendation", recommendations.length > 0);
        for (int i = 1; i < recommendations.length; i++) {
            Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
        }
        Assert.assertTrue(recommendations[0].rating() > 0.7);
    }
}
