package org.apache.mahout.classifier.sgd;

import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/OnlineBaseTest.class */
/**
 * Shared fixture for tests of online learners: loads a CSV data set from the
 * classpath, trains a learner on it in a reproducible random order and checks
 * the resulting classifier against the expected labels.
 */
public abstract class OnlineBaseTest extends MahoutTestCase {
    /** Feature matrix loaded by the most recent call to {@link #readStandardData()}. */
    private Matrix input;

    /**
     * Returns the feature matrix loaded by {@link #readStandardData()}.
     *
     * @return the loaded matrix, or {@code null} if no data has been read yet
     */
    public Matrix getInput() {
        return this.input;
    }

    /**
     * Loads {@code sgd.csv} into the input matrix and builds the matching target vector.
     *
     * @return a 60-element vector whose first 30 entries are 0 and whose last 30 entries are 1
     * @throws IOException if the resource cannot be read
     */
    public Vector readStandardData() throws IOException {
        this.input = readCsv("sgd.csv");
        Vector target = new DenseVector(60);
        target.assign(0.0d);
        target.viewPart(30, 30).assign(1.0d);
        return target;
    }

    /**
     * Trains the learner with a single pass over the samples, visited in a
     * pseudo-random order that is fixed by the test seed.
     *
     * @param matrix        one sample per row
     * @param vector        target value for each row of {@code matrix}; its size determines
     *                      how many rows are used
     * @param onlineLearner the learner to train; it is closed when training finishes
     */
    public static void train(Matrix matrix, Vector vector, OnlineLearner onlineLearner) {
        RandomUtils.useTestSeed();
        Random gen = RandomUtils.getRandom();
        // The number of samples comes from the target vector rather than a fixed constant.
        for (int row : permute(gen, vector.size())) {
            onlineLearner.train((int) vector.get(row), matrix.viewRow(row));
        }
        onlineLearner.close();
    }

    /**
     * Classifies every row of {@code matrix} and asserts that the scores are close to the
     * targets and that the classifier's different classify methods agree with each other.
     *
     * @param matrix                   one sample per row
     * @param vector                   expected score for each row
     * @param abstractVectorClassifier the classifier under test
     * @param d                        tolerance for the mean absolute error
     * @param d2                       tolerance for the maximum absolute error
     */
    public static void test(Matrix matrix, Vector vector, AbstractVectorClassifier abstractVectorClassifier, double d, double d2) {
        Matrix scores = abstractVectorClassifier.classify(matrix);
        // The per-sample error is needed for both statistics, so compute it once.
        Vector errors = scores.viewColumn(0).minus(vector);
        double meanAbsoluteError = errors.aggregate(Functions.PLUS, Functions.ABS) / vector.size();
        double maxAbsoluteError = errors.aggregate(Functions.MAX, Functions.ABS);
        System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
        assertEquals(0.0d, meanAbsoluteError, d);
        assertEquals(0.0d, maxAbsoluteError, d2);

        // classifyScalar must reproduce column 0 of classify ...
        assertEquals(0.0d, abstractVectorClassifier.classifyScalar(matrix).minus(scores.viewColumn(0)).norm(1.0d), 1.0E-5d);
        // ... and so must column 1 of classifyFull.
        assertEquals(0.0d, abstractVectorClassifier.classifyFull(matrix).viewColumn(1).minus(scores.viewColumn(0)).norm(1.0d), 1.0E-4d);
    }

    /**
     * Builds a permutation of the integers {@code 0 .. i-1}, inserting each new value at a
     * randomly chosen position among those filled so far.
     *
     * @param random source of randomness
     * @param i      number of elements; {@code 0} yields an empty array
     * @return an array of length {@code i} containing each of {@code 0 .. i-1} exactly once
     */
    public static int[] permute(Random random, int i) {
        // A fresh int[] is zero-filled, so slot 0 already holds 0; not writing it
        // explicitly keeps i == 0 from failing on an out-of-range index.
        int[] permutation = new int[i];
        for (int next = 1; next < i; next++) {
            int slot = random.nextInt(next + 1);
            if (slot == next) {
                permutation[next] = next;
            } else {
                // Move the current occupant of the chosen slot to the end and put the new value there.
                permutation[next] = permutation[slot];
                permutation[slot] = next;
            }
        }
        return permutation;
    }

    /**
     * Reads a classpath CSV resource into a dense matrix. The first line supplies the column
     * labels; every following line supplies one row of numeric values. Fields are split on
     * commas and stripped of surrounding spaces and double quotes.
     *
     * @param str name of the classpath resource
     * @return a matrix with one row per data line and column label bindings taken from the header
     * @throws IOException if the resource cannot be opened or read
     */
    public static Matrix readCsv(String str) throws IOException {
        Splitter onComma = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));

        List<String> lines;
        InputStreamReader reader = new InputStreamReader(Resources.getResource(str).openStream(), Charsets.UTF_8);
        try {
            lines = CharStreams.readLines(reader);
        } finally {
            // Release the underlying resource stream even if reading fails.
            reader.close();
        }

        List<String> header = Lists.newArrayList(onComma.split(lines.get(0)));
        List<String> dataLines = lines.subList(1, lines.size());
        DenseMatrix result = new DenseMatrix(dataLines.size(), header.size());

        // Bind each header label to its column index.
        HashMap<String, Integer> labels = Maps.newHashMap();
        int labelIndex = 0;
        for (String label : header) {
            labels.put(label, labelIndex);
            labelIndex++;
        }
        result.setColumnLabelBindings(labels);

        int row = 0;
        for (String line : dataLines) {
            int column = 0;
            for (String value : onComma.split(line)) {
                result.set(row, column, Double.parseDouble(value));
                column++;
            }
            row++;
        }
        return result;
    }
}
