package org.apache.mahout.flinkbindings.blas;

import com.google.common.collect.Lists;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.scala.JoinDataSet;
import org.apache.flink.util.Collector;
import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm;
import org.apache.mahout.flinkbindings.drm.FlinkDrm;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.drm.logical.OpAtB;
import org.apache.mahout.math.scalabindings.RLikeOps$;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.JavaConverters$;
import scala.collection.TraversableOnce;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.reflect.ClassTag$;

/* compiled from: FlinkOpAtB.scala */
/* loaded from: input_file:org/apache/mahout/flinkbindings/blas/FlinkOpAtB$.class */
/**
 * Decompiled holder class for the Scala singleton object {@code FlinkOpAtB}
 * (see the "compiled from" note above). The single instance is published
 * through {@link #MODULE$}.
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java. It contains
 * unresolved register names ({@code r0}, {@code r4}) and assigns a {@code final}
 * static field from the constructor, so it is not expected to compile as-is.
 */
public final class FlinkOpAtB$ {
    // Singleton instance slot. Declared with a null initializer here but assigned
    // in the private constructor below (decompiler rendering of a Scala object).
    public static final FlinkOpAtB$ MODULE$ = null;

    // Instantiates the singleton when the class is initialized; the constructor
    // stores the new instance into MODULE$.
    static {
        new FlinkOpAtB$();
    }

    /**
     * Builds a block-wise DRM from two operand DRMs by joining their row-wise
     * datasets on tuple field 0, flat-mapping each joined pair, grouping the
     * results by tuple field 0 and reducing each group to a single
     * {@code (int[] keys, Matrix)} block.
     *
     * <p>The per-pair flat-map ({@code FlinkOpAtB$$anon$5}) and the per-group
     * map/reduce functions ({@code $$anonfun$1/2/3}) are defined in classes not
     * shown here, so the exact arithmetic they perform is not visible from this
     * method. NOTE(review): given the operator type {@code OpAtB}, this is
     * presumably the A' * B product for operands that cannot be zipped directly
     * - confirm against the Scala source.
     *
     * @param opAtB     logical operator; only its {@code nrow()} and {@code ncol()} are read here
     * @param flinkDrm  first operand; accessed via {@code asRowWise().ds()}
     * @param flinkDrm2 second operand; accessed via {@code asRowWise().ds()}
     * @return a {@code BlockifiedFlinkDrm} wrapping the reduced dataset, constructed with {@code opAtB.ncol()}
     */
    public <A> FlinkDrm<Object> notZippable(OpAtB<A> opAtB, FlinkDrm<A> flinkDrm, FlinkDrm<A> flinkDrm2) {
        // Join the two row-wise datasets on tuple field 0 of each side.
        JoinDataSet joinDataSet = (JoinDataSet) flinkDrm.asRowWise().ds().join(flinkDrm2.asRowWise().ds()).where(Predef$.MODULE$.wrapIntArray(new int[]{0})).equalTo(Predef$.MODULE$.wrapIntArray(new int[]{0}));
        // Pipeline: flatMap(anon$5) -> groupBy(field 0) -> reduceGroup(anon$6 below).
        // anon$5 receives (int) opAtB.nrow(), the literal 10, and ((r0 - 1) / 10) + 1
        // passed through safeToNonNegInt.
        // NOTE(review): `r0` is an unresolved decompiler register; presumably it is the
        // (int) opAtB.nrow() value, making the third argument ceil(nrow / 10) - confirm.
        // The literal 10 is also the constructor argument of the reducer, where it is
        // stored as blockHeight$1.
        return new BlockifiedFlinkDrm(joinDataSet.flatMap(new FlinkOpAtB$$anon$5((int) opAtB.nrow(), 10, org.apache.mahout.math.drm.package$.MODULE$.safeToNonNegInt(((r0 - 1) / 10) + 1)), new FlinkOpAtB$$anon$3(), ClassTag$.MODULE$.apply(Tuple2.class)).groupBy(Predef$.MODULE$.wrapIntArray(new int[]{0})).reduceGroup(new GroupReduceFunction<Tuple2<Object, Matrix>, Tuple2<int[], Matrix>>(10) { // from class: org.apache.mahout.flinkbindings.blas.FlinkOpAtB$$anon$6
            // Multiplier applied to the group's key to form the offset handed to
            // $$anonfun$1 when the output key array is tabulated.
            private final int blockHeight$1;

            /**
             * Collapses all (key, Matrix) tuples of one group into a single
             * (int[] keys, Matrix) tuple and emits it.
             *
             * @param iterable  the tuples of one group (grouped on tuple field 0 upstream)
             * @param collector receives exactly one output tuple per call
             */
            public void reduce(Iterable<Tuple2<Object, Matrix>> iterable, Collector<Tuple2<int[], Matrix>> collector) {
                // Materialize the group into a list and view it as a Scala Buffer.
                Buffer buffer = (Buffer) JavaConverters$.MODULE$.asScalaBufferConverter(Lists.newArrayList(iterable)).asScala();
                // The first element supplies the group's key; head() is called without
                // an emptiness check.
                Tuple2 tuple2 = (Tuple2) buffer.head();
                // Residue of a Scala pattern match on the tuple: a null head is a MatchError.
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                // Primitive-int accessor for the tuple's first component (the group key).
                int _1$mcI$sp = tuple2._1$mcI$sp();
                // Map every tuple through $$anonfun$2, then fold the results pairwise with
                // $$anonfun$3. Both functions live in classes not shown here.
                // NOTE(review): presumably extracts each Matrix and sums them - confirm.
                Matrix matrix = (Matrix) ((TraversableOnce) buffer.map(new FlinkOpAtB$$anon$6$$anonfun$2(this), Buffer$.MODULE$.canBuildFrom())).reduce(new FlinkOpAtB$$anon$6$$anonfun$3(this));
                // Emit (keys -> matrix): keys is an int[] with one entry per row of the reduced
                // matrix, produced by $$anonfun$1, which is given the offset key * blockHeight$1.
                collector.collect(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.any2ArrowAssoc((int[]) Array$.MODULE$.tabulate(RLikeOps$.MODULE$.m2mOps(matrix).nrow(), new FlinkOpAtB$$anon$6$$anonfun$1(this, _1$mcI$sp * this.blockHeight$1), ClassTag$.MODULE$.Int())), matrix));
            }

            // Instance initializer standing in for the synthetic constructor.
            // NOTE(review): `r4` is an unresolved decompiler register; presumably it is the
            // constructor argument (the literal 10 passed at the creation site) - confirm.
            {
                this.blockHeight$1 = r4;
            }
        }, new FlinkOpAtB$$anon$4(), ClassTag$.MODULE$.apply(Tuple2.class)), opAtB.ncol(), BasicTypeInfo.getInfoFor(Integer.TYPE), ClassTag$.MODULE$.Int());
    }

    // Private constructor: publishes this instance as the singleton.
    private FlinkOpAtB$() {
        MODULE$ = this;
    }
}
