package org.apache.spark.ml.feature;

import org.apache.spark.ml.Model;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.util.MLTestingUtils$;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataTypes;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.runtime.BoxesRunTime;

/* compiled from: LSHTest.scala */
/* loaded from: input_file:org/apache/spark/ml/feature/LSHTest$.class */
/**
 * Java rendering of the Scala test-helper object {@code LSHTest} (compiled from {@code LSHTest.scala}).
 *
 * <p>Each helper does the following:
 * <ul>
 *   <li>fits an {@link LSH} estimator;</li>
 *   <li>runs one operation of the fitted model;</li>
 *   <li>checks the output schema with {@code Predef.assert} / {@code SchemaUtils.checkColumnType};</li>
 *   <li>returns a pair of boxed-double ratios comparing the approximate result with a brute-force
 *       computation over the same data.</li>
 * </ul>
 *
 * <p>The per-row UDF bodies live in the closure classes {@code LSHTest$$anonfun$1..4}, which are not
 * part of this file. The comments below only describe how their results are used.
 *
 * <p>All ratios use floating-point division. An empty denominator therefore yields NaN or Infinity
 * instead of throwing.
 */
public final class LSHTest$ {
    /** The singleton instance of the {@code LSHTest} object. */
    public static final LSHTest$ MODULE$ = new LSHTest$();

    private LSHTest$() {
    }

    /** Returns the column {@code alias.column}, i.e. {@code column} qualified by a join alias. */
    private static Column qualifiedCol(String alias, String column) {
        return functions$.MODULE$.col(alias + "." + column);
    }

    /** Wraps columns as the Scala {@code Seq} that a UDF's {@code apply} expects. */
    private static Seq<Column> columns(Column... cols) {
        return Predef$.MODULE$.wrapRefArray(cols);
    }

    /**
     * Fits {@code lsh} on {@code dataset} and measures how bucket membership agrees with distance.
     *
     * <p>First the method checks the fitted model:
     * <ul>
     *   <li>{@code MLTestingUtils.checkCopyAndUids} passes;</li>
     *   <li>the output column is an array of vectors;</li>
     *   <li>the first row's array has {@code getNumHashTables()} entries.</li>
     * </ul>
     *
     * <p>It then cross-joins the transformed data with itself (aliases {@code a} / {@code b}) and adds
     * two columns:
     * <ul>
     *   <li>{@code same_bucket}: a boolean UDF ({@code LSHTest$$anonfun$2}) over the two hash columns;</li>
     *   <li>{@code distance}: a double UDF ({@code LSHTest$$anonfun$1}) over the two input columns.</li>
     * </ul>
     *
     * @param dataset data to fit on and transform
     * @param lsh the LSH estimator under test
     * @param d distance threshold for pairs where {@code same_bucket} is true
     * @param d2 distance threshold for pairs where {@code same_bucket} is false
     * @return a pair made of:
     *     (fraction of {@code same_bucket} pairs whose distance is greater than {@code d},
     *     fraction of the other pairs whose distance is less than {@code d2}).
     *     These are presumably the false-positive and false-negative rates.
     */
    public <T extends LSHModel<T>> Tuple2<Object, Object> calculateLSHProperty(Dataset<?> dataset, LSH<T> lsh, double d, double d2) {
        T model = lsh.fit(dataset);
        String inputCol = model.getInputCol();
        String outputCol = model.getOutputCol();
        Dataset<Row> transformed = model.transform(dataset);

        MLTestingUtils$.MODULE$.checkCopyAndUids(lsh, model);
        SchemaUtils$.MODULE$.checkColumnType(transformed.schema(), outputCol, DataTypes.createArrayType(new VectorUDT()), SchemaUtils$.MODULE$.checkColumnType$default$4());
        Row firstRow = transformed.select(outputCol, Predef$.MODULE$.wrapRefArray(new String[0])).head();
        Predef$.MODULE$.assert(((Seq<?>) firstRow.get(0)).length() == model.getNumHashTables());

        Dataset<Row> pairs = transformed.as("a").crossJoin(transformed.as("b"))
            .withColumn("same_bucket",
                functions$.MODULE$.udf(new LSHTest$$anonfun$2(model), DataTypes.BooleanType)
                    .apply(columns(qualifiedCol("a", outputCol), qualifiedCol("b", outputCol))))
            .withColumn("distance",
                functions$.MODULE$.udf(new LSHTest$$anonfun$1(model), DataTypes.DoubleType)
                    .apply(columns(qualifiedCol("a", inputCol), qualifiedCol("b", inputCol))));

        Dataset<Row> sameBucket = pairs.filter(functions$.MODULE$.col("same_bucket"));
        Dataset<Row> otherBucket = pairs.filter(functions$.MODULE$.col("same_bucket").unary_$bang());

        // Widen to double before dividing. A long / long quotient would truncate each ratio to 0
        // (or 1), and would throw ArithmeticException when a bucket group is empty.
        double sameBucketFarRatio = (double) sameBucket.filter(functions$.MODULE$.col("distance").$greater(BoxesRunTime.boxToDouble(d))).count() / sameBucket.count();
        double otherBucketNearRatio = (double) otherBucket.filter(functions$.MODULE$.col("distance").$less(BoxesRunTime.boxToDouble(d2))).count() / otherBucket.count();
        return new Tuple2<Object, Object>(BoxesRunTime.boxToDouble(sameBucketFarRatio), BoxesRunTime.boxToDouble(otherBucketNearRatio));
    }

    /**
     * Compares {@code approxNearestNeighbors} against a brute-force nearest-neighbor search.
     *
     * <p>The expected neighbors are {@code dataset} sorted ascending by a double UDF
     * ({@code LSHTest$$anonfun$3}, which closes over {@code vector} and the fitted model; presumably
     * the distance to {@code vector}, TODO confirm), truncated to {@code i} rows.
     *
     * <p>The approximate neighbors must have the model's transformed schema plus a double
     * {@code distCol} column. When {@code z} is false they must also contain exactly {@code i} rows.
     *
     * @param lsh the LSH estimator to fit on {@code dataset}
     * @param dataset rows to index and search
     * @param vector the query key
     * @param i number of neighbors requested
     * @param z forwarded as the boolean argument of {@code approxNearestNeighbors}
     *     (presumably single-probe mode)
     * @return (matched / approximate row count, matched / expected row count), where "matched" is the
     *     row count of the inner join of expected and approximate rows on the model's input column.
     *     These are presumably precision and recall.
     */
    public <T extends LSHModel<T>> Tuple2<Object, Object> calculateApproxNearestNeighbors(LSH<T> lsh, Dataset<?> dataset, Vector vector, int i, boolean z) {
        T model = lsh.fit(dataset);
        String inputCol = model.getInputCol();

        Column keyDistance = functions$.MODULE$.udf(new LSHTest$$anonfun$3(vector, model), DataTypes.DoubleType)
            .apply(columns(functions$.MODULE$.col(inputCol)));
        Dataset<?> expected = dataset.sort(columns(keyDistance)).limit(i);

        Dataset<?> approx = model.approxNearestNeighbors(dataset, vector, i, z, "distCol");
        Predef$.MODULE$.assert(approx.schema().sameType(model.transformSchema(dataset.schema()).add("distCol", DataTypes.DoubleType)));
        if (!z) {
            Predef$.MODULE$.assert(approx.count() == ((long) i));
        }

        double matched = expected.join(approx, inputCol).count();
        return new Tuple2<Object, Object>(BoxesRunTime.boxToDouble(matched / approx.count()), BoxesRunTime.boxToDouble(matched / expected.count()));
    }

    /**
     * Compares {@code approxSimilarityJoin} against a brute-force similarity join.
     *
     * <p>The model is fitted on {@code dataset} only. The expected result is the cross join of
     * {@code dataset} (alias {@code a}) and {@code dataset2} (alias {@code b}), filtered to pairs where
     * a double UDF ({@code LSHTest$$anonfun$4}) over the two input columns is less than {@code d}.
     *
     * <p>The approximate join must have:
     * <ul>
     *   <li>a double {@code distCol} column;</li>
     *   <li>a {@code datasetA} struct matching the transformed schema of {@code dataset};</li>
     *   <li>a {@code datasetB} struct matching the transformed schema of {@code dataset2}.</li>
     * </ul>
     *
     * @param lsh the LSH estimator to fit
     * @param dataset left-hand data (also the training data)
     * @param dataset2 right-hand data
     * @param d the distance threshold
     * @return (matched / approximate row count, matched / expected row count), where "matched" is the
     *     number of approximate rows whose {@code distCol} is less than {@code d}.
     *     These are presumably precision and recall.
     */
    public <T extends LSHModel<T>> Tuple2<Object, Object> calculateApproxSimilarityJoin(LSH<T> lsh, Dataset<?> dataset, Dataset<?> dataset2, double d) {
        T model = lsh.fit(dataset);
        String inputCol = model.getInputCol();

        Column pairDistance = functions$.MODULE$.udf(new LSHTest$$anonfun$4(model), DataTypes.DoubleType)
            .apply(columns(qualifiedCol("a", inputCol), qualifiedCol("b", inputCol)));
        Dataset<Row> expected = dataset.as("a").crossJoin(dataset2.as("b")).filter(pairDistance.$less(BoxesRunTime.boxToDouble(d)));

        Dataset<?> approx = model.approxSimilarityJoin(dataset, dataset2, d);
        SchemaUtils$.MODULE$.checkColumnType(approx.schema(), "distCol", DataTypes.DoubleType, SchemaUtils$.MODULE$.checkColumnType$default$4());
        Predef$.MODULE$.assert(approx.schema().apply("datasetA").dataType().sameType(model.transformSchema(dataset.schema())));
        Predef$.MODULE$.assert(approx.schema().apply("datasetB").dataType().sameType(model.transformSchema(dataset2.schema())));

        double matched = approx.filter(functions$.MODULE$.col("distCol").$less(BoxesRunTime.boxToDouble(d))).count();
        return new Tuple2<Object, Object>(BoxesRunTime.boxToDouble(matched / approx.count()), BoxesRunTime.boxToDouble(matched / expected.count()));
    }
}
