package org.apache.mahout.math.solver;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SingularValueDecomposition;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/solver/LSMRTest.class */
/**
 * Tests for {@link LSMR} that solve small least-squares problems and check
 * both the residuals of the returned solution and the residual estimates
 * reported by the solver itself.
 */
public final class LSMRTest extends MahoutTestCase {

    /**
     * Solves {@code H x = b} for the 5x5 matrix built by {@link #hilbert(int)}
     * and an all-ones right-hand side.
     *
     * <p>The solution is judged by its residuals instead of being compared to
     * the reference vector (presumably because this system is ill-conditioned
     * and an iterative solver will not reproduce the reference exactly).
     */
    @Test
    public void basics() {
        Matrix h = hilbert(5);

        // Spot-check a few entries of the generated matrix.
        assertEquals(1.0, h.get(0, 0), 0.0);
        assertEquals(0.5, h.get(0, 1), 0.0);
        assertEquals(1.0 / 6.0, h.get(2, 3), 1.0e-9);

        // Reference vector that satisfies H * reference == ones (checked just below).
        Vector reference = new DenseVector(new double[] {5.0, -120.0, 630.0, -1120.0, 630.0});
        Vector ones = new DenseVector(5);
        ones.assign(1.0);
        assertEquals(0.0, residualNorm(h, reference, ones), 1.0e-9);

        LSMR solver = new LSMR();
        Vector solution = solver.solve(h, ones);

        // The computed solution must have small plain and normal-equation residuals.
        double residual = residualNorm(h, solution, ones);
        assertEquals(0.0, residual, 0.01);
        double normalResidual = normalEquationResidualNorm(h, solution, ones);
        assertEquals(0.0, normalResidual, 1.0e-7);

        // The solver's own estimates must agree with the directly computed norms.
        assertEquals(residual, solver.getResidualNorm(), 1.0e-5);
        assertEquals(normalResidual, solver.getNormalEquationResidual(), 1.0e-9);
    }

    /**
     * Solves a 200x30 system whose matrix is filled using
     * {@code Functions.random()} against an all-ones right-hand side, then
     * checks the normal-equation residual relative to the 2-norm of the
     * matrix's singular values, as well as the solver's reported estimates.
     */
    @Test
    public void random() {
        Matrix a = new DenseMatrix(200, 30).assign(Functions.random());
        Vector b = new DenseVector(200).assign(1.0);

        LSMR solver = new LSMR();
        Vector solution = solver.solve(a, b);

        // 2-norm of the diagonal of S from the SVD; used to scale the tolerance.
        double singularValueNorm = new SingularValueDecomposition(a).getS().viewDiagonal().norm(2.0);
        double normalResidual = normalEquationResidualNorm(a, solution, b);

        // Report the scaled residual (times 1e6) for diagnostic purposes.
        System.out.printf("%.4f\n", (normalResidual / singularValueNorm) * 1000000.0);
        assertEquals(0.0, normalResidual, singularValueNorm * 1.0e-5);

        // The solver's own estimates must agree with the directly computed norms.
        assertEquals(residualNorm(a, solution, b), solver.getResidualNorm(), 1.0e-5);
        assertEquals(normalResidual, solver.getNormalEquationResidual(), 1.0e-9);
    }

    /**
     * Computes the 2-norm of {@code a * x - b}.
     *
     * @param a coefficient matrix
     * @param x candidate solution
     * @param b right-hand side
     * @return {@code ||a x - b||_2}
     */
    private static double residualNorm(Matrix a, Vector x, Vector b) {
        Vector difference = a.times(x).minus(b);
        return difference.norm(2.0);
    }

    /**
     * Computes the 2-norm of {@code (a' a) x - a' b}, evaluated in exactly
     * that order.
     *
     * @param a coefficient matrix
     * @param x candidate solution
     * @param b right-hand side
     * @return {@code ||a' a x - a' b||_2}
     */
    private static double normalEquationResidualNorm(Matrix a, Vector x, Vector b) {
        Vector left = a.transpose().times(a).times(x);
        Vector right = a.transpose().times(b);
        return left.minus(right).norm(2.0);
    }

    /**
     * Builds a square matrix whose entry at zero-based {@code (row, col)} is
     * {@code 1 / (row + col + 1)}.
     *
     * @param size number of rows and columns
     * @return the populated {@code size x size} matrix
     */
    private static Matrix hilbert(int size) {
        DenseMatrix result = new DenseMatrix(size, size);
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                result.set(row, col, 1.0 / (row + col + 1));
            }
        }
        return result;
    }
}
