package org.apache.mahout.vectorizer;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.IOException;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Test;

/**
 * Tests for {@link SparseVectorsFromSequenceFiles}: runs the tool end-to-end over a small
 * SequenceFile corpus for each combination of the {@code -seq} / {@code -nv} options, and checks
 * document-frequency pruning via {@code --maxDFSigma} for both the default (TF-IDF) output and
 * {@code --weight tf} output.
 */
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class SparseVectorsFromSequenceFilesTest extends MahoutTestCase {
    private static final int NUM_DOCS = 100;
    private Configuration conf;
    private Path inputPath;

    /**
     * Writes {@code docs} to a fresh temp SequenceFile as (Text, Text) records keyed
     * {@code "Document::ID::<index>"}, and records the configuration and input path used by
     * {@link #runTest}.
     *
     * @param docs document bodies, written in array order
     * @throws IOException if the file cannot be created or written
     */
    private void writeDocuments(String[] docs) throws IOException {
        this.conf = getConfiguration();
        this.inputPath = getTestTempFilePath("documents/docs.file");
        FileSystem fs = FileSystem.get(this.inputPath.toUri(), this.conf);
        SequenceFile.Writer writer =
            new SequenceFile.Writer(fs, this.conf, this.inputPath, Text.class, Text.class);
        // Close exactly once, after every record has been appended.
        try {
            for (int i = 0; i < docs.length; i++) {
                writer.append(new Text("Document::ID::" + i), new Text(docs[i]));
            }
        } finally {
            Closeables.close(writer, false);
        }
    }

    /** Writes {@link #NUM_DOCS} randomly generated documents as the test input. */
    private void setupDocs() throws IOException {
        RandomDocumentGenerator generator = new RandomDocumentGenerator();
        String[] docs = new String[NUM_DOCS];
        for (int i = 0; i < NUM_DOCS; i++) {
            docs[i] = generator.getRandomDocument();
        }
        writeDocuments(docs);
    }

    @Test
    public void testCreateTermFrequencyVectors() throws Exception {
        setupDocs();
        runTest(false, false, false, -1.0d, NUM_DOCS);
    }

    @Test
    public void testCreateTermFrequencyVectorsNam() throws Exception {
        setupDocs();
        runTest(false, false, true, -1.0d, NUM_DOCS);
    }

    @Test
    public void testCreateTermFrequencyVectorsSeq() throws Exception {
        setupDocs();
        runTest(false, true, false, -1.0d, NUM_DOCS);
    }

    @Test
    public void testCreateTermFrequencyVectorsSeqNam() throws Exception {
        setupDocs();
        runTest(false, true, true, -1.0d, NUM_DOCS);
    }

    @Test
    public void testPruning() throws Exception {
        checkPruning(false, "tfidf-vectors");
    }

    @Test
    public void testPruningTF() throws Exception {
        checkPruning(true, "tf-vectors");
    }

    /**
     * Vectorizes a three-document corpus with {@code --maxDFSigma 2} and checks the vectors found
     * under {@code vectorsDirName}. The corpus has three distinct terms, but every vector is
     * expected to have size 2, with 2/1/1 non-default entries. This is consistent with the
     * very frequent term "a" being pruned.
     *
     * @param tfWeighting    whether to pass {@code --weight tf} to the tool
     * @param vectorsDirName output sub-directory to read, e.g. "tf-vectors" or "tfidf-vectors"
     */
    private void checkPruning(boolean tfWeighting, String vectorsDirName) throws Exception {
        String[] docs = {"a b c", "a a a a a b", "a a a a a c"};
        writeDocuments(docs);
        Path vectorsDir =
            new Path(runTest(tfWeighting, false, false, 2.0d, docs.length), vectorsDirName);

        // A list (rather than a fixed array) turns "too many vectors" into an assertion failure.
        List<Vector> vectors = Lists.newArrayList();
        for (VectorWritable value : new SequenceFileDirValueIterable<VectorWritable>(
                vectorsDir, PathType.LIST, PathFilters.partFilter(), null, true, this.conf)) {
            Vector vector = value.get();
            assertEquals(2L, vector.size());
            vectors.add(vector);
        }
        assertEquals(docs.length, vectors.size());
        assertEquals(2L, vectors.get(0).getNumNondefaultElements());
        assertEquals(1L, vectors.get(1).getNumNondefaultElements());
        assertEquals(1L, vectors.get(2).getNumNondefaultElements());
    }

    /**
     * Runs {@link SparseVectorsFromSequenceFiles} on {@link #inputPath}, writing into the temp
     * directory "output", and validates the produced vectors.
     *
     * @param tfWeighting  pass {@code --weight tf}; when false, "tfidf-vectors" is validated too
     * @param sequential   pass {@code -seq}
     * @param namedVectors pass {@code -nv}
     * @param maxDFSigma   value for {@code --maxDFSigma}; omitted when negative
     * @param numDocs      expected number of vectors, forwarded to validation
     * @return the output directory of the run
     */
    private Path runTest(boolean tfWeighting, boolean sequential, boolean namedVectors,
                         double maxDFSigma, int numDocs) throws Exception {
        Path outputPath = getTestTempFilePath("output");
        List<String> args = Lists.newArrayList();
        args.add("-i");
        args.add(this.inputPath.toString());
        args.add("-o");
        args.add(outputPath.toString());
        if (sequential) {
            args.add("-seq");
        }
        if (namedVectors) {
            args.add("-nv");
        }
        if (maxDFSigma >= 0.0d) {
            args.add("--maxDFSigma");
            args.add(String.valueOf(maxDFSigma));
        }
        if (tfWeighting) {
            args.add("--weight");
            args.add("tf");
        }
        int exitCode = ToolRunner.run(getConfiguration(), new SparseVectorsFromSequenceFiles(),
            args.toArray(new String[args.size()]));
        // Fail fast with a clear message instead of a confusing downstream validation error.
        assertEquals("SparseVectorsFromSequenceFiles exit code", 0L, exitCode);

        DictionaryVectorizerTest.validateVectors(
            this.conf, numDocs, new Path(outputPath, "tf-vectors"), sequential, namedVectors);
        if (!tfWeighting) {
            DictionaryVectorizerTest.validateVectors(
                this.conf, numDocs, new Path(outputPath, "tfidf-vectors"), sequential, namedVectors);
        }
        return outputPath;
    }
}
