package org.apache.mahout.math.als;

import java.util.Arrays;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.junit.Test;

/**
 * Tests for the static helpers of {@link AlternatingLeastSquaresSolver} and for the
 * {@code Y^T Y} computation of {@link ImplicitFeedbackAlternatingLeastSquaresSolver}.
 */
public class AlternatingLeastSquaresSolverTest extends MahoutTestCase {

    /** Absolute tolerance used when comparing floating-point matrix entries. */
    private static final double TOLERANCE = 1.0e-6;

    /**
     * Number of times the multi-threaded {@code Y^T Y} check is repeated per test matrix.
     * Presumably repeated to surface nondeterministic, thread-interleaving failures.
     */
    private static final int MULTI_THREADED_REPETITIONS = 100;

    /** Verifies {@code getYtransposeY} against {@code Y^T * Y} with 4 worker threads and with 1. */
    @Test
    public void testYtY() {
        double[][][] testMatrices = {
            {
                {1, 2, 3, 4, 5},
                {1, 2, 3, 4, 5},
                {1, 2, 3, 4, 5},
                {1, 2, 3, 4, 5},
                {1, 2, 3, 4, 5}
            },
            {
                {1, 2, 3, 4, 5, 6},
                {5, 4, 3, 2, 1, 7},
                {1, 2, 3, 4, 5, 8},
                {1, 2, 3, 4, 5, 8},
                {11, 12, 13, 20, 27, 8}
            }
        };
        for (double[][] values : testMatrices) {
            Matrix matrix = new DenseMatrix(values);
            for (int run = 0; run < MULTI_THREADED_REPETITIONS; run++) {
                validateYtY(matrix, 4);
            }
            validateYtY(matrix, 1);
        }
    }

    /**
     * Asserts that the solver's {@code Y^T Y} equals {@code matrix^T * matrix} exactly.
     *
     * @param matrix     the matrix Y, whose rows are handed to the solver as feature vectors
     * @param numThreads the thread count passed to the solver's constructor
     */
    private void validateYtY(Matrix matrix, int numThreads) {
        OpenIntObjectHashMap<Vector> rows = asRowVectors(matrix);
        // The two 1.0 arguments are solver hyper-parameters (presumably alpha and lambda —
        // confirm against the constructor); only getYtransposeY is exercised here.
        ImplicitFeedbackAlternatingLeastSquaresSolver solver =
            new ImplicitFeedbackAlternatingLeastSquaresSolver(matrix.columnSize(), 1.0, 1.0, rows, numThreads);

        Matrix expected = matrix.transpose().times(matrix);
        Matrix actual = solver.getYtransposeY(rows);

        // Check the shape first so a wrongly sized result cannot pass or fail with an index error.
        assertEquals(expected.rowSize(), actual.rowSize());
        assertEquals(expected.columnSize(), actual.columnSize());
        for (int row = 0; row < expected.rowSize(); row++) {
            for (int column = 0; column < expected.columnSize(); column++) {
                // The fixtures in testYtY hold only small integers, so both computations are
                // exact and any difference indicates a real defect; hence a zero tolerance.
                assertEquals(expected.getQuick(row, column), actual.getQuick(row, column), 0.0);
            }
        }
    }

    /**
     * Copies every row of {@code matrix} into a map keyed by row index.
     * Rows are cloned so the map holds independent copies rather than row views.
     */
    private OpenIntObjectHashMap<Vector> asRowVectors(Matrix matrix) {
        OpenIntObjectHashMap<Vector> rows = new OpenIntObjectHashMap<>();
        for (int row = 0; row < matrix.numRows(); row++) {
            rows.put(row, matrix.viewRow(row).clone());
        }
        return rows;
    }

    /** Adding lambda * n = 0.2 * 5 = 1.0 to a zero 5x5 matrix touches only the diagonal. */
    @Test
    public void addLambdaTimesNuiTimesE() {
        int size = 5;
        SparseMatrix matrix = new SparseMatrix(size, size);
        AlternatingLeastSquaresSolver.addLambdaTimesNuiTimesE(matrix, 0.2, size);
        for (int row = 0; row < size; row++) {
            for (int column = 0; column < size; column++) {
                double expected = row == column ? 1.0 : 0.0;
                assertEquals(expected, matrix.getQuick(row, column), TOLERANCE);
            }
        }
    }

    /** Each feature vector passed to {@code createMiIi} becomes one column of the result. */
    @Test
    public void createMiIi() {
        Vector first = new DenseVector(new double[] {1.0, 2.0, 3.0});
        Vector second = new DenseVector(new double[] {4.0, 5.0, 6.0});

        Matrix miIi = AlternatingLeastSquaresSolver.createMiIi(Arrays.asList(first, second), 3);

        assertEquals(1.0, miIi.getQuick(0, 0), TOLERANCE);
        assertEquals(2.0, miIi.getQuick(1, 0), TOLERANCE);
        assertEquals(3.0, miIi.getQuick(2, 0), TOLERANCE);
        assertEquals(4.0, miIi.getQuick(0, 1), TOLERANCE);
        assertEquals(5.0, miIi.getQuick(1, 1), TOLERANCE);
        assertEquals(6.0, miIi.getQuick(2, 1), TOLERANCE);
    }

    /** The non-zero ratings of a sequential-access vector become a single column, in index order. */
    @Test
    public void createRiIiMaybeTransposed() {
        // Cardinality 6 so that every index written below (1, 3, 5) is in range.
        SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(6);
        ratings.setQuick(1, 1.0);
        ratings.setQuick(3, 3.0);
        ratings.setQuick(5, 5.0);

        Matrix riIi = AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings);

        // Exact dimension checks; a delta equal to the expected value would accept wrong shapes.
        assertEquals(1, riIi.numCols());
        assertEquals(3, riIi.numRows());
        assertEquals(1.0, riIi.getQuick(0, 0), TOLERANCE);
        assertEquals(3.0, riIi.getQuick(1, 0), TOLERANCE);
        assertEquals(5.0, riIi.getQuick(2, 0), TOLERANCE);
    }

    /** A random-access (non-sequential) ratings vector must be rejected with IllegalArgumentException. */
    @Test
    public void createRiIiMaybeTransposedExceptionOnNonSequentialVector() {
        // Cardinality 6 so that every index written below (1, 3, 5) is in range.
        RandomAccessSparseVector ratings = new RandomAccessSparseVector(6);
        ratings.setQuick(1, 1.0);
        ratings.setQuick(3, 3.0);
        ratings.setQuick(5, 5.0);
        try {
            AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings);
            fail("Expected IllegalArgumentException for a non-sequential-access vector");
        } catch (IllegalArgumentException expected) {
            // Expected: only sequential-access vectors are accepted.
        }
    }
}
