package org.apache.mahout.math.solver;

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/math/solver/LSMR.class */
/**
 * Iterative least-squares solver (LSMR), ported from the Matlab reference implementation
 * version 1.02 of 14 Apr 2010.
 *
 * <p>{@link #solve(Matrix, Vector)} returns an approximate solution {@code x} of the
 * least-squares problem {@code min ||A x - b||}. The damping parameter {@code lambda} is fixed
 * at 0 in this class. After each call, the getters report diagnostics for that call: the
 * iteration count, the estimates of {@code ||b - A x||} ("norm r") and {@code ||A' (b - A x)||}
 * ("norm A'r"), and the estimates of {@code ||A||}, {@code cond(A)} and {@code ||x||}.
 *
 * <p>Per-solve state is kept in instance fields, so one instance must not be used by
 * concurrent {@code solve} calls.
 */
public final class LSMR {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) LSMR.class);

    // Circular buffer of recent v vectors used for local reorthogonalization, and its write index.
    private int localPointer;
    private Vector[] localV;

    // Diagnostics of the most recent solve.
    private double residualNorm;
    private double normalEquationResidual;
    private double xNorm;
    private int iteration;
    private double normA;
    private double condA;

    // Damping parameter; this implementation only supports the undamped problem.
    private final double lambda = 0.0d;

    // User-settable stopping parameters.
    private double aTolerance = 1.0E-6d;
    private double bTolerance = 1.0E-6d;
    private double conditionLimit = 1.0E8d;
    // -1 means "use min(rows, cols) of the matrix passed to solve".
    private int iterationLimit = -1;
    // Number of recent v vectors kept for local reorthogonalization; <= 0 disables it.
    private int localSize = 0;

    /** Reason the iteration stopped; the messages are only used for debug logging. */
    private enum StopCode {
        CONTINUE("Not done"),
        TRIVIAL("The exact solution is  x = 0"),
        CONVERGED("Ax - b is small enough, given atol, btol"),
        LEAST_SQUARE_CONVERGED("The least-squares solution is good enough, given atol"),
        CONDITION("The estimate of cond(Abar) has exceeded condition limit"),
        CONVERGED_MACHINE_TOLERANCE("Ax - b is small enough for this machine"),
        LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE("The least-squares solution is good enough for this machine"),
        CONDITION_MACHINE_TOLERANCE("Cond(Abar) seems to be too large for this machine"),
        ITERATION_LIMIT("The iteration limit has been reached");

        private final String message;

        StopCode(String message) {
            this.message = message;
        }

        public String getMessage() {
            return message;
        }
    }

    /** Returns the number of iterations performed by the last {@link #solve} call. */
    public int getIterationCount() {
        return iteration;
    }

    /** Returns the estimate of {@code ||b - A x||} from the last {@link #solve} call. */
    public double getResidualNorm() {
        return residualNorm;
    }

    /** Returns the estimate of {@code ||A' (b - A x)||} from the last {@link #solve} call. */
    public double getNormalEquationResidual() {
        return normalEquationResidual;
    }

    /** Returns the running estimate of the norm of A from the last {@link #solve} call. */
    public double getANorm() {
        return normA;
    }

    /** Returns the running estimate of cond(A) from the last {@link #solve} call. */
    public double getCondition() {
        return condA;
    }

    /** Returns {@code ||x||} for the solution returned by the last {@link #solve} call. */
    public double getXNorm() {
        return xNorm;
    }

    /**
     * Computes an approximate least-squares solution of {@code A x = b}.
     *
     * <p>Iteration stops at the first of these conditions, checked in priority order:
     * <ol>
     *   <li>the residual test {@code ||r|| / ||b|| <= btol + atol ||A|| ||x|| / ||b||};</li>
     *   <li>the normal-equation test {@code ||A'r|| / (||A|| ||r||) <= atol};</li>
     *   <li>the condition limit;</li>
     *   <li>the machine-precision variants of these tests;</li>
     *   <li>the iteration limit.</li>
     * </ol>
     *
     * @param matrix the matrix A
     * @param vector the right-hand side b (not modified)
     * @return the approximate solution x, a new vector of length {@code matrix.numCols()};
     *     all zeros when {@code b = 0} or {@code A'b = 0}
     */
    public Vector solve(Matrix matrix, Vector vector) {
        log.debug("   itn         x(1)     norm r   norm A'r");
        log.debug("   compatible   LS      norm A   cond A");

        Matrix transposedA = matrix.transpose();

        // Start the bidiagonalization: beta * u = b, alpha * v = A' u.
        Vector u = vector;
        double beta = u.norm(2.0d);
        if (beta > 0.0d) {
            u = u.divide(beta);
        }
        Vector v = transposedA.times(u);

        int numRows = matrix.numRows();
        int numCols = matrix.numCols();
        int minDim = Math.min(numRows, numCols);
        // Resolve the default per call instead of overwriting the field. Otherwise a later
        // solve on a differently sized matrix would be capped by a stale limit.
        int maxIterations = iterationLimit == -1 ? minDim : iterationLimit;

        if (log.isDebugEnabled()) {
            log.debug("LSMR - Least-squares solution of  Ax = b, based on Matlab Version 1.02, 14 Apr 2010, Mahout version {}", getClass().getPackage().getImplementationVersion());
            log.debug(String.format("The matrix A has %d rows  and %d cols, lambda = %.4g, atol = %g, btol = %g", numRows, numCols, lambda, aTolerance, bTolerance));
        }

        double alpha = v.norm(2.0d);
        if (alpha > 0.0d) {
            v.assign(Functions.div(alpha));
        }

        // Local reorthogonalization buffer. More than min(rows, cols) slots is never useful.
        // Clamping at 0 keeps a negative localSize from failing the allocation.
        localPointer = 0;
        localV = new Vector[Math.max(0, Math.min(localSize, minDim))];
        boolean localOrtho = localV.length > 0;
        if (localOrtho) {
            localV[0] = v;
        }

        iteration = 0;
        double zetabar = alpha * beta;
        double alphabar = alpha;
        Vector h = v;
        Vector hbar = zeros(numCols);
        Vector x = zeros(numCols);

        // Running quantities for the ||r|| and ||A|| estimates.
        double betadd = beta;
        double normA2 = alpha * alpha;
        double normB = beta;
        double ctol = conditionLimit > 0.0d ? 1.0d / conditionLimit : 0.0d;

        // Initialise every reported diagnostic so that no getter returns a previous call's value.
        residualNorm = beta;
        normalEquationResidual = alpha * beta;
        normA = alpha;
        condA = 1.0d;
        xNorm = 0.0d;

        // b = 0 or A'b = 0: x = 0 is the exact solution.
        if (normalEquationResidual == 0.0d) {
            log.debug("Finished: {}", StopCode.TRIVIAL.getMessage());
            return x;
        }

        if (log.isDebugEnabled()) {
            log.debug("{} {}", iteration, x.get(0));
            log.debug("{} {}", residualNorm, normalEquationResidual);
            log.debug("{} {}", 1.0d, alpha / beta);
        }

        double rho = 1.0d;
        double rhobar = 1.0d;
        double cbar = 1.0d;
        double sbar = 0.0d;
        double betad = 0.0d;
        double rhodold = 1.0d;
        double tautildeold = 0.0d;
        double thetatilde = 0.0d;
        double zeta = 0.0d;
        double sumBetacheckSquared = 0.0d;
        double maxrbar = 0.0d;
        double minrbar = 1.0E100d;

        StopCode stopCode = StopCode.CONTINUE;
        while (iteration < maxIterations && stopCode == StopCode.CONTINUE) {
            iteration++;

            // Next bidiagonalization step:
            //   beta * u  = A v  - alpha * u
            //   alpha * v = A' u - beta * v
            u = matrix.times(v).minus(u.times(alpha));
            beta = u.norm(2.0d);
            if (beta > 0.0d) {
                u.assign(Functions.div(beta));
                if (localOrtho) {
                    localVEnqueue(v);
                }
                v = transposedA.times(u).minus(v.times(beta));
                if (localOrtho) {
                    v = localVOrtho(v);
                }
                alpha = v.norm(2.0d);
                if (alpha > 0.0d) {
                    v.assign(Functions.div(alpha));
                }
            }

            // Rotation Qhat folds in the damping term lambda.
            double alphahat = Math.hypot(alphabar, lambda);
            double chat = alphabar / alphahat;
            double shat = lambda / alphahat;

            // Rotation Q_i turns B_i into R_i.
            double rhoold = rho;
            rho = Math.hypot(alphahat, beta);
            double c = alphahat / rho;
            double s = beta / rho;
            double thetanew = s * alpha;
            alphabar = c * alpha;

            // Rotation Qbar_i turns R_i^T into Rbar_i.
            double rhobarold = rhobar;
            double zetaold = zeta;
            double thetabar = sbar * rho;
            double rhotemp = cbar * rho;
            rhobar = Math.hypot(rhotemp, thetanew);
            cbar = rhotemp / rhobar;
            sbar = thetanew / rhobar;
            zeta = cbar * zetabar;
            zetabar = -sbar * zetabar;

            // Update hbar, x and h.
            hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold)));
            x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS);
            h = v.minus(h.times(thetanew / rho));

            // Estimate ||r|| by applying the rotations to the right-hand side.
            double betaacute = chat * betadd;
            double betacheck = -shat * betadd;
            double betahat = c * betaacute;
            betadd = -s * betaacute;

            double thetatildeold = thetatilde;
            double rhotildeold = Math.hypot(rhodold, thetabar);
            double ctildeold = rhodold / rhotildeold;
            double stildeold = thetabar / rhotildeold;
            thetatilde = stildeold * rhobar;
            rhodold = ctildeold * rhobar;
            betad = -stildeold * betad + ctildeold * betahat;

            tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold;
            double taud = (zeta - thetatilde * tautildeold) / rhodold;
            sumBetacheckSquared += betacheck * betacheck;
            residualNorm = Math.sqrt(sumBetacheckSquared + (betad - taud) * (betad - taud) + betadd * betadd);

            // Estimate ||A||: the reported value includes beta_{k+1} but not yet alpha_{k+1}.
            normA2 += beta * beta;
            normA = Math.sqrt(normA2);
            normA2 += alpha * alpha;

            // Estimate cond(A) from the extreme rhobar values seen so far.
            maxrbar = Math.max(maxrbar, rhobarold);
            if (iteration > 1) {
                minrbar = Math.min(minrbar, rhobarold);
            }
            condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp);

            normalEquationResidual = Math.abs(zetabar);
            xNorm = x.norm(2.0d);

            double test1 = residualNorm / normB;
            double test2 = normalEquationResidual / (normA * residualNorm);
            double test3 = 1.0d / condA;
            double t1 = test1 / (1.0d + normA * xNorm / normB);
            double rtol = bTolerance + aTolerance * normA * xNorm / normB;

            stopCode = stoppingRule(iteration, maxIterations, test1, test2, test3, t1, rtol, ctol);

            if (log.isDebugEnabled()
                && (numCols <= 40 || iteration <= 10 || iteration >= maxIterations - 10
                    || iteration % 10 == 0 || test3 <= 1.1d * ctol || test2 <= 1.1d * aTolerance
                    || test1 <= 1.1d * rtol || stopCode != StopCode.CONTINUE)) {
                statusDump(x, normA, condA, test1, test2);
            }
        }
        // Reached only when the loop body never ran (non-positive iteration limit).
        if (stopCode == StopCode.CONTINUE) {
            stopCode = StopCode.ITERATION_LIMIT;
        }
        log.debug("Finished: {}", stopCode.getMessage());
        return x;
    }

    /**
     * Applies the stopping rules in priority order and returns the first one that fires.
     * User-supplied tolerances take precedence over the machine-precision guards, and those take
     * precedence over the iteration limit.
     *
     * @return the reason to stop, or {@link StopCode#CONTINUE} if no rule fires
     */
    private StopCode stoppingRule(int currentIteration, int maxIterations, double test1,
                                  double test2, double test3, double t1, double rtol, double ctol) {
        if (test1 <= rtol) {
            return StopCode.CONVERGED;
        }
        if (test2 <= aTolerance) {
            return StopCode.LEAST_SQUARE_CONVERGED;
        }
        if (test3 <= ctol) {
            return StopCode.CONDITION;
        }
        // These guard against tolerances set too small (or to zero) by the caller.
        if (1.0d + t1 <= 1.0d) {
            return StopCode.CONVERGED_MACHINE_TOLERANCE;
        }
        if (1.0d + test2 <= 1.0d) {
            return StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE;
        }
        if (1.0d + test3 <= 1.0d) {
            return StopCode.CONDITION_MACHINE_TOLERANCE;
        }
        if (currentIteration >= maxIterations) {
            return StopCode.ITERATION_LIMIT;
        }
        return StopCode.CONTINUE;
    }

    /** Logs one progress record, in the same line order as the initial status lines. */
    private void statusDump(Vector x, double aNorm, double cond, double test1, double test2) {
        log.debug("{} {}", iteration, x.get(0));
        log.debug("{} {}", residualNorm, normalEquationResidual);
        log.debug("{} {}", test1, test2);
        log.debug("{} {}", aNorm, cond);
    }

    /** Returns a new all-zero dense vector of the given size. */
    private static Vector zeros(int size) {
        return new DenseVector(size);
    }

    /** Stores {@code vector} in the circular buffer, overwriting the oldest entry once it is full. */
    private void localVEnqueue(Vector vector) {
        if (localV.length > 0) {
            localV[localPointer] = vector;
            localPointer = (localPointer + 1) % localV.length;
        }
    }

    /**
     * Subtracts from {@code vector} its projection onto each stored vector, one after another.
     * This assumes the stored vectors have unit norm; v is normalized before it is enqueued
     * whenever alpha > 0.
     *
     * @return a new vector (or the argument itself if the buffer holds no vectors)
     */
    private Vector localVOrtho(Vector vector) {
        Vector result = vector;
        for (Vector stored : localV) {
            if (stored != null) {
                result = result.minus(stored.times(result.dot(stored)));
            }
        }
        return result;
    }

    /**
     * Sets the relative tolerance on A, used in the residual and normal-equation stopping tests.
     * The default is 1e-6.
     */
    public void setAtolerance(double d) {
        this.aTolerance = d;
    }

    /** Sets the relative tolerance on b, used in the residual stopping test. The default is 1e-6. */
    public void setBtolerance(double d) {
        this.bTolerance = d;
    }

    /**
     * Sets the limit on the cond(A) estimate; iteration stops once the estimate exceeds it.
     * A value of 0 or less disables this rule. The default is 1e8.
     */
    public void setConditionLimit(double d) {
        this.conditionLimit = d;
    }

    /**
     * Sets the maximum number of iterations. The default, -1, means min(rows, cols) of the
     * matrix passed to each {@link #solve} call.
     */
    public void setIterationLimit(int i) {
        this.iterationLimit = i;
    }

    /**
     * Sets how many recent v vectors are kept for local reorthogonalization. A value of 0 or
     * less disables it (the default is 0).
     */
    public void setLocalSize(int i) {
        this.localSize = i;
    }

    /** Returns the damping parameter, which this implementation fixes at 0. */
    public double getLambda() {
        return lambda;
    }

    /** Returns the relative tolerance on A. */
    public double getAtolerance() {
        return aTolerance;
    }

    /** Returns the relative tolerance on b. */
    public double getBtolerance() {
        return bTolerance;
    }
}
