/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.classification;

import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMSuite;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-API tests for {@link SVMWithSGD}.
 *
 * <p>Each test trains an SVM on synthetic data produced by {@code SVMSuite.generateSVMInputAsList}
 * (seed 42). It then checks that the model labels a second synthetic set (seed 17) correctly for
 * strictly more than 80% of the points.
 */
public class JavaSVMSuite extends SharedSparkSession {

    /**
     * Counts how many validation points the model labels correctly.
     *
     * @param validationData labelled points to score
     * @param model          trained SVM used to predict a label for each point
     * @return number of points whose predicted label equals {@code point.label()} exactly
     */
    int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
        int numAccurate = 0;
        for (LabeledPoint point : validationData) {
            // Exact equality is intended. The model is assumed to emit a thresholded class label
            // rather than a raw margin; NOTE(review): confirm no caller clears the threshold.
            double prediction = model.predict(point.features());
            if (prediction == point.label()) {
                numAccurate++;
            }
        }
        return numAccurate;
    }

    /**
     * Fails the test unless strictly more than 80% of {@code nPoints} predictions were correct.
     *
     * @param numAccurate number of correct predictions
     * @param nPoints     total number of predictions made
     */
    private static void assertMostlyAccurate(int numAccurate, int nPoints) {
        double threshold = nPoints * 4.0 / 5.0;
        Assert.assertTrue(
            "expected more than " + threshold + " correct predictions out of " + nPoints
                + ", got " + numAccurate,
            numAccurate > threshold);
    }

    /** Trains through the builder-style API (constructor plus optimizer setters), fitting an intercept. */
    @Test
    public void runSVMUsingConstructor() {
        int nPoints = 10000;
        // Passed to the generator alongside the weights; presumably the intercept of the
        // generating hyperplane, which is why this test enables setIntercept(true) below.
        double A = 2.0;
        double[] weights = {-1.5, 1.0};

        // Training and validation sets come from different seeds (42 vs. 17).
        JavaRDD<LabeledPoint> testRDD =
            jsc.parallelize(SVMSuite.generateSVMInputAsList(A, weights, nPoints, 42), 2).cache();
        List<LabeledPoint> validationData =
            SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);

        SVMWithSGD svmSGDImpl = new SVMWithSGD();
        svmSGDImpl.setIntercept(true);
        svmSGDImpl.optimizer()
            .setStepSize(1.0)
            .setRegParam(1.0)
            .setNumIterations(100);
        SVMModel model = svmSGDImpl.run(testRDD.rdd());

        int numAccurate = validatePrediction(validationData, model);
        assertMostlyAccurate(numAccurate, nPoints);
    }

    /** Trains through the static {@code SVMWithSGD.train} entry point. */
    @Test
    public void runSVMUsingStaticMethods() {
        int nPoints = 10000;
        // Zero here, presumably because this train(...) overload does not fit an intercept.
        // NOTE(review): confirm against SVMWithSGD defaults.
        double A = 0.0;
        double[] weights = {-1.5, 1.0};

        JavaRDD<LabeledPoint> testRDD =
            jsc.parallelize(SVMSuite.generateSVMInputAsList(A, weights, nPoints, 42), 2).cache();
        List<LabeledPoint> validationData =
            SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);

        // Arguments: numIterations = 100, stepSize = 1.0, regParam = 1.0, miniBatchFraction = 1.0.
        SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);

        int numAccurate = validatePrediction(validationData, model);
        assertMostlyAccurate(numAccurate, nPoints);
    }
}

