package org.apache.mahout.math.decomposer.hebbian;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.decomposer.AsyncEigenVerifier;
import org.apache.mahout.math.decomposer.SolverTest;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.class */
/**
 * Tests {@link HebbianSolver} against a randomly generated sparse matrix, and provides
 * static helpers that run the solver once and report the elapsed wall-clock time.
 */
public final class TestHebbianSolver extends SolverTest {
    /** Divisor that converts a {@link System#nanoTime()} delta into whole milliseconds. */
    private static final long NANOS_PER_MILLI = 1000000L;

    /**
     * Runs the solver requesting 10 rows of output; see
     * {@link #timeSolver(Matrix, double, int, int, TrainingState)}.
     *
     * @param matrix        input handed to {@code HebbianSolver.solve}
     * @param d             third {@code HebbianSolver} constructor argument
     * @param i             fourth {@code HebbianSolver} constructor argument
     * @param trainingState receives the eigens and eigenvalues produced by the solver
     * @return elapsed time in whole milliseconds
     */
    public static long timeSolver(Matrix matrix, double d, int i, TrainingState trainingState) {
        return timeSolver(matrix, d, i, 10, trainingState);
    }

    /**
     * Builds a {@link HebbianSolver} with a fresh {@link HebbianUpdater} and
     * {@link AsyncEigenVerifier}, solves {@code matrix}, copies the result into
     * {@code trainingState}, and returns how long that took.
     *
     * <p>The timed region spans the {@code solve} call, the null check on its result and the
     * two setter calls on {@code trainingState}; solver construction is not timed.
     *
     * @param matrix        input handed to {@code HebbianSolver.solve}
     * @param d             third {@code HebbianSolver} constructor argument
     *                      (presumably the convergence target - confirm against HebbianSolver)
     * @param i             fourth {@code HebbianSolver} constructor argument
     *                      (presumably the maximum number of passes - confirm against HebbianSolver)
     * @param i2            second argument to {@code solve}; the resulting eigen matrix is
     *                      asserted to have exactly this many rows
     * @param trainingState receives the eigens and eigenvalues produced by the solver
     * @return elapsed time in whole milliseconds (nanosecond delta, integer-divided)
     */
    public static long timeSolver(Matrix matrix, double d, int i, int i2, TrainingState trainingState) {
        HebbianSolver solver = new HebbianSolver(new HebbianUpdater(), new AsyncEigenVerifier(), d, i);

        long startNanos = System.nanoTime();
        TrainingState solved = solver.solve(matrix, i2);
        assertNotNull(solved);
        trainingState.setCurrentEigens(solved.getCurrentEigens());
        trainingState.setCurrentEigenValues(solved.getCurrentEigenValues());
        long elapsedNanos = System.nanoTime() - startNanos;

        // JUnit convention is (expected, actual): the requested row count is the expectation,
        // so a failure message reports the two values the right way round.
        assertEquals(i2, trainingState.getCurrentEigens().numRows());
        return elapsedNanos / NANOS_PER_MILLI;
    }

    /**
     * Runs the solver with the default settings of
     * {@link #timeSolver(Matrix, TrainingState, int)} and 10 requested rows.
     *
     * @param matrix        input handed to {@code HebbianSolver.solve}
     * @param trainingState receives the eigens and eigenvalues produced by the solver
     * @return elapsed time in whole milliseconds
     */
    public static long timeSolver(Matrix matrix, TrainingState trainingState) {
        return timeSolver(matrix, trainingState, 10);
    }

    /**
     * Runs the solver with {@code 0.01} and {@code 20} as the third and fourth
     * {@code HebbianSolver} constructor arguments.
     *
     * @param matrix        input handed to {@code HebbianSolver.solve}
     * @param trainingState receives the eigens and eigenvalues produced by the solver
     * @param i             second argument to {@code solve}; also the asserted row count
     * @return elapsed time in whole milliseconds
     */
    public static long timeSolver(Matrix matrix, TrainingState trainingState, int i) {
        return timeSolver(matrix, 0.01d, 20, i, trainingState);
    }

    /**
     * Solves for 50 rows on a random sparse matrix, then checks the result with the inherited
     * {@code assertEigen} (tolerance argument 0.05) and {@code assertOrthonormal}
     * (tolerance argument 1e-6) helpers, and prints the elapsed time.
     */
    @Test
    public void testHebbianSolver() {
        // NOTE(review): the meaning of each numeric argument is defined by the inherited
        // randomSequentialAccessSparseMatrix helper, which is not visible here; the 800
        // matches the column count of the DenseMatrix below.
        Matrix corpus = randomSequentialAccessSparseMatrix(1000, 900, 800, 30, 1.0d);
        TrainingState state = new TrainingState(new DenseMatrix(50, 800), (Matrix) null);

        long elapsedMillis = timeSolver(corpus, 1.0E-5d, 5, 50, state);

        Matrix eigens = state.getCurrentEigens();
        assertEigen(eigens, corpus, 0.05d, false);
        assertOrthonormal(eigens, 1.0E-6d);
        System.out.println("Avg solving (Hebbian) time in ms: " + elapsedMillis);
    }
}
