package org.apache.mahout.math.ssvd;

import com.google.common.collect.Lists;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DiagonalMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.RandomTrinaryMatrix;
import org.apache.mahout.math.SingularValueDecomposition;
import org.apache.mahout.math.function.Functions;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/ssvd/SequentialOutOfCoreSvdTest.class */
/**
 * Tests for {@link SequentialOutOfCoreSvd} and {@link SequentialBigSvd}.
 *
 * <p>The out-of-core test writes a synthetic matrix to disk as a column of row-blocks
 * (files named {@code A-<offset>}), runs the decomposition over those files, and reads
 * the resulting {@code U-*} and {@code V-*} block files back to check the reconstruction.
 */
public final class SequentialOutOfCoreSvdTest extends MahoutTestCase {
    private File tmpDir;

    /** Creates the scratch directory that holds the block files of each test. */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        tmpDir = getTestTempDir("matrix");
    }

    /**
     * Decomposes a 970 x 1020 matrix stored on disk in blocks of 200 rows and checks that
     * (a) the leading six singular values agree with {@link SequentialBigSvd}, and
     * (b) {@code U * diag(s) * V'} rebuilt from the block files reproduces the input.
     */
    @Test
    public void testSingularValues() throws IOException {
        Matrix a = lowRankMatrix(tmpDir, "A", 200, 970, 1020);

        // Hand the blocks of A over in reversed listing order, so the decomposition
        // cannot rely on the order in which the files are supplied.
        List<File> partsOfA = Lists.reverse(listMatchingFiles(tmpDir, "A-.*"));

        SequentialOutOfCoreSvd outOfCore = new SequentialOutOfCoreSvd(partsOfA, tmpDir, 100, 210);
        SequentialBigSvd inMemory = new SequentialBigSvd(a, 100);

        // NOTE(review): maxValue() is the signed maximum of the difference, so a large
        // negative deviation would not be detected here -- consider comparing absolute values.
        assertEquals(
            0.0d,
            new DenseVector(inMemory.getSingularValues()).viewPart(0, 6)
                .minus(outOfCore.getSingularValues().viewPart(0, 6))
                .maxValue(),
            1.0E-9d);

        outOfCore.computeU(partsOfA, tmpDir);
        Matrix u = readBlockMatrix(listMatchingFiles(tmpDir, "U-.*"));

        outOfCore.computeV(tmpDir, a.columnSize());
        Matrix v = readBlockMatrix(listMatchingFiles(tmpDir, "V-.*"));

        // Sum of absolute element-wise reconstruction errors.
        Matrix reconstructed = u.times(new DiagonalMatrix(outOfCore.getSingularValues())).times(v.transpose());
        assertEquals(0.0d, a.minus(reconstructed).aggregate(Functions.PLUS, Functions.ABS), 1.0E-7d);
    }

    /**
     * Lists the files directly inside {@code dir} whose name matches {@code namePattern}.
     *
     * @param dir         directory to scan
     * @param namePattern regular expression the whole file name must match
     * @return a fixed-size list of the matching files, in the order the file system reports them
     * @throws IOException if {@code dir} cannot be listed ({@link File#listFiles} returned {@code null})
     */
    private static List<File> listMatchingFiles(File dir, final String namePattern) throws IOException {
        File[] matches = dir.listFiles(new FilenameFilter() {
            @Override
            public boolean accept(File parent, String name) {
                return name.matches(namePattern);
            }
        });
        if (matches == null) {
            throw new IOException("Unable to list files in " + dir);
        }
        return Arrays.asList(matches);
    }

    /**
     * Reads a column of matrix blocks from the given files and stacks them vertically.
     *
     * <p>The files are stacked in the natural (path) ordering of {@link File}. The result buffer
     * is sized as {@code rows-of-first-block * number-of-files}, so no block may have more rows
     * than the first one; if the blocks add up to fewer rows, the result is cropped accordingly.
     *
     * @param list files that each hold one serialized {@link MatrixWritable}; not modified
     * @return the stacked matrix
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws IOException if a block file cannot be read
     */
    private static Matrix readBlockMatrix(List<File> list) throws IOException {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("No block files to read");
        }

        // Sort a copy so that the caller's list is left untouched.
        List<File> files = Lists.newArrayList(list);
        Collections.sort(files);

        // The dimensions are unknown until the first block has been read.
        int maxRows = -1;
        int columns = -1;
        Matrix result = null;

        MatrixWritable block = new MatrixWritable();
        int row = 0;
        for (File file : files) {
            try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
                block.readFields(in);
            }
            Matrix m = block.get();
            if (result == null) {
                // Upper bound on the result size: every block is at most as tall as the first.
                maxRows = m.rowSize() * files.size();
                columns = m.columnSize();
                result = new DenseMatrix(maxRows, columns);
            }
            result.viewPart(row, m.rowSize(), 0, result.columnSize()).assign(m);
            row += m.rowSize();
        }

        // 'row' now holds the true number of rows; crop the buffer if it was over-allocated.
        if (row != maxRows) {
            result = result.viewPart(0, row, 0, columns);
        }
        return result;
    }

    /**
     * Checks the first three left singular vectors of {@link SequentialBigSvd} against
     * {@link SingularValueDecomposition} on a 20 x 20 matrix. Absolute values are compared,
     * presumably because singular vectors are only determined up to sign.
     */
    @Test
    public void testLeftVectors() throws IOException {
        Matrix a = lowRankMatrixInMemory(20, 20);

        Matrix expected = new SingularValueDecomposition(a).getU().viewPart(0, 20, 0, 3).assign(Functions.ABS);
        Matrix actual = new SequentialBigSvd(a, 6).getU().viewPart(0, 20, 0, 3).assign(Functions.ABS);

        assertEquals(expected, actual);
    }

    /** Builds the synthetic test matrix without writing any block files. */
    private static Matrix lowRankMatrixInMemory(int rows, int columns) throws IOException {
        return lowRankMatrix(null, null, 0, rows, columns);
    }

    /** Asserts that the largest absolute element-wise difference of the two matrices is at most 1e-10. */
    private static void assertEquals(Matrix expected, Matrix actual) {
        assertEquals(0.0d, expected.minus(actual).aggregate(Functions.MAX, Functions.ABS), 1.0E-10d);
    }

    /**
     * Checks the first three right singular vectors of {@link SequentialBigSvd} against
     * {@link SingularValueDecomposition} on a 20 x 20 matrix. Absolute values are compared,
     * presumably because singular vectors are only determined up to sign.
     */
    @Test
    public void testRightVectors() throws IOException {
        Matrix a = lowRankMatrixInMemory(20, 20);

        Matrix expected = new SingularValueDecomposition(a).getV().viewPart(0, 20, 0, 3).assign(Functions.ABS);
        Matrix actual = new SequentialBigSvd(a, 6).getV().viewPart(0, 20, 0, 3).assign(Functions.ABS);

        assertEquals(expected, actual);
    }

    /**
     * Builds the {@code rows x columns} matrix {@code U * D * V'}, where {@code U} and {@code V}
     * are seeded {@link RandomTrinaryMatrix} instances with 10 columns and {@code D} is a
     * 10 x 10 matrix whose only non-zero entries are 5, 3, 1 and 0.5 on the diagonal.
     * Optionally writes the matrix to disk as consecutive row-blocks.
     *
     * @param tmpDir       directory that receives the block files, or {@code null} to skip writing
     * @param aBase        file name prefix; block files are named {@code <aBase>-<9-digit row offset>}
     * @param rowsPerSlice number of rows per block file (the last block may be shorter);
     *                     only used when {@code tmpDir} is not {@code null}
     * @param rows         number of rows of the generated matrix
     * @param columns      number of columns of the generated matrix
     * @return the generated matrix
     * @throws IllegalArgumentException if files are to be written and {@code rowsPerSlice} is not positive
     * @throws IOException if a block file cannot be written
     */
    private static Matrix lowRankMatrix(File tmpDir, String aBase, int rowsPerSlice, int rows, int columns)
            throws IOException {
        int rank = 10;

        Matrix u = new RandomTrinaryMatrix(1, rows, rank, false);
        Matrix d = new DenseMatrix(rank, rank);
        d.set(0, 0, 5.0d);
        d.set(1, 1, 3.0d);
        d.set(2, 2, 1.0d);
        d.set(3, 3, 0.5d);
        Matrix v = new RandomTrinaryMatrix(2, columns, rank, false);
        Matrix a = u.times(d).times(v.transpose());

        if (tmpDir != null) {
            if (rowsPerSlice <= 0) {
                // A non-positive step would never advance the loop below.
                throw new IllegalArgumentException("rowsPerSlice must be positive but was " + rowsPerSlice);
            }
            for (int row = 0; row < a.rowSize(); row += rowsPerSlice) {
                int sliceRows = Math.min(a.rowSize() - row, rowsPerSlice);
                MatrixWritable slice = new MatrixWritable(a.viewPart(row, sliceRows, 0, a.columnSize()));
                File sliceFile = new File(tmpDir, String.format("%s-%09d", aBase, row));
                try (DataOutputStream out = new DataOutputStream(new FileOutputStream(sliceFile))) {
                    slice.write(out);
                }
            }
        }
        return a;
    }
}
