/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Multi-class linear classifier trained online with a passive-aggressive update rule.
 *
 * <p>The model keeps one weight row per category in a dense
 * {@code numCategories x numFeatures} matrix; the raw score of category {@code i} is the dot
 * product of row {@code i} with the instance. Training compares the true category's score with
 * the best-scoring other category and, when the margin of 1 is violated, moves the true row
 * toward the instance and the competing row away from it.
 *
 * <p>No method is synchronized; NOTE(review): callers presumably confine an instance to a single
 * thread or synchronize externally.
 */
public class PassiveAggressive
extends AbstractVectorClassifier
implements OnlineLearner,
Writable {
    private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);
    public static final int WRITABLE_VERSION = 1;
    // Aggressiveness parameter: larger values shrink the regularizing term in the step size.
    private double learningRate = 0.1;
    // Running loss statistics, logged and reset roughly every thousand training examples.
    private int lossCount = 0;
    private double lossSum = 0.0;
    // One row of weights per category.
    private Matrix weights;
    private int numCategories;

    /**
     * Creates an untrained model with all weights set to zero.
     *
     * @param numCategories number of target categories (one weight row each)
     * @param numFeatures   length of the feature vectors this model accepts
     */
    public PassiveAggressive(int numCategories, int numFeatures) {
        this.numCategories = numCategories;
        this.weights = new DenseMatrix(numCategories, numFeatures);
        this.weights.assign(0.0);
    }

    /**
     * Sets the aggressiveness parameter used in the update step size. Larger values allow
     * larger corrections per example.
     *
     * @param learningRate new aggressiveness value
     * @return this model, for call chaining
     */
    public PassiveAggressive learningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    /**
     * Copies the learning rate, category count and weights of {@code other} into this model.
     * The weight matrix is cloned, so later training of either model does not affect the other.
     * Running loss statistics are not copied.
     *
     * @param other model to copy from
     */
    public void copyFrom(PassiveAggressive other) {
        this.learningRate = other.learningRate;
        this.numCategories = other.numCategories;
        // Clone rather than share: aliasing the matrix would make copy() return a model whose
        // training silently mutates the original (and vice versa).
        this.weights = other.weights.clone();
    }

    @Override
    public int numCategories() {
        return this.numCategories;
    }

    /**
     * Returns softmax probabilities of categories {@code 1 .. numCategories-1}. The probability
     * of category 0 is omitted and equals one minus the sum of the returned entries.
     *
     * @param instance feature vector of length {@link #numFeatures()}
     * @return vector of length {@code numCategories - 1}
     */
    @Override
    public Vector classify(Vector instance) {
        Vector result = this.classifyNoLink(instance);
        // Shift by the max score before exponentiating so exp() cannot overflow.
        double max = result.maxValue();
        result.assign(Functions.minus(max)).assign(Functions.EXP);
        result = result.divide(result.norm(1.0));
        return result.viewPart(1, result.size() - 1);
    }

    /**
     * Returns the raw (un-normalized) score of every category: the dot product of each weight
     * row with {@code instance}.
     *
     * @param instance feature vector of length {@link #numFeatures()}
     * @return vector of length {@code numCategories}
     */
    @Override
    public Vector classifyNoLink(Vector instance) {
        int rows = this.weights.numRows();
        DenseVector result = new DenseVector(rows);
        // Every entry is written below, so no explicit zero-fill is needed.
        for (int i = 0; i < rows; ++i) {
            result.setQuick(i, this.weights.viewRow(i).dot(instance));
        }
        return result;
    }

    /**
     * Returns the probability of category 1 for a two-category model, i.e. the softmax of the
     * scores of rows 0 and 1, evaluated at row 1. Only rows 0 and 1 are consulted.
     *
     * <p>Computed as {@code 1 / (1 + exp(s0 - s1))}, which equals
     * {@code exp(s1) / (exp(s0) + exp(s1))} but does not overflow to NaN for large scores.
     *
     * @param instance feature vector of length {@link #numFeatures()}
     * @return probability in {@code [0, 1]}
     */
    @Override
    public double classifyScalar(Vector instance) {
        double score0 = this.weights.viewRow(0).dot(instance);
        double score1 = this.weights.viewRow(1).dot(instance);
        return 1.0 / (1.0 + Math.exp(score0 - score1));
    }

    /**
     * Returns the feature-vector length this model was built for.
     *
     * @return number of columns in the weight matrix
     */
    public int numFeatures() {
        return this.weights.numCols();
    }

    /**
     * Returns an independent copy of this model with the same learning rate and category count
     * and a cloned weight matrix. Running loss statistics are not copied.
     *
     * @return the new model
     */
    public PassiveAggressive copy() {
        this.close();
        PassiveAggressive r = new PassiveAggressive(this.numCategories(), this.numFeatures());
        r.copyFrom(this);
        return r;
    }

    /**
     * Serializes the version tag, learning rate, category count and weight matrix, in that
     * order.
     *
     * @param out destination
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(WRITABLE_VERSION);
        out.writeDouble(this.learningRate);
        out.writeInt(this.numCategories);
        MatrixWritable.writeMatrix(out, this.weights);
    }

    /**
     * Restores state written by {@link #write(DataOutput)}. Running loss statistics are not
     * part of the serialized form and are left unchanged.
     *
     * @param in source
     * @throws IOException if reading fails or the version tag is not {@link #WRITABLE_VERSION}
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        int version = in.readInt();
        if (version != WRITABLE_VERSION) {
            throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION
                + " got " + version);
        }
        this.learningRate = in.readDouble();
        this.numCategories = in.readInt();
        this.weights = MatrixWritable.readMatrix(in);
    }

    /** No-op: this model has no pending work or resources to release. */
    @Override
    public void close() {
    }

    /**
     * Performs one passive-aggressive update for a single labelled example.
     *
     * <p>The loss is the multi-class hinge loss {@code 1 - score(actual) + max_{r != actual}
     * score(r)}. When it is positive, the step size is
     * {@code tau = loss / (|x|^2 + 0.5 / learningRate)}; row {@code actual} gains
     * {@code tau * x} and the competing row loses {@code tau * x}. When the margin is satisfied
     * the model is left unchanged.
     *
     * @param trackingKey unused by this implementation
     * @param groupKey    unused by this implementation
     * @param actual      index of the true category
     * @param instance    feature vector of length {@link #numFeatures()}
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        if (this.lossCount > 1000) {
            log.info("Avg. Loss = {}", this.lossSum / this.lossCount);
            this.lossCount = 0;
            this.lossSum = 0.0;
        }
        Vector result = this.classifyNoLink(instance);
        double myScore = result.get(actual);
        // Find the best-scoring category other than the true one.
        int otherIndex = result.maxValueIndex();
        double otherValue = result.get(otherIndex);
        if (otherIndex == actual) {
            result.setQuick(otherIndex, Double.NEGATIVE_INFINITY);
            otherIndex = result.maxValueIndex();
            otherValue = result.get(otherIndex);
        }
        double loss = 1.0 - myScore + otherValue;
        ++this.lossCount;
        // A zero loss would give tau == 0 (a no-op update), so only positive losses are applied.
        if (loss > 0.0) {
            this.lossSum += loss;
            double tau = loss / (instance.dot(instance) + 0.5 / this.learningRate);
            Vector delta = instance.clone();
            delta.assign(Functions.mult(tau));
            this.weights.viewRow(actual).assign(delta, Functions.PLUS);
            delta.assign(Functions.mult(-1.0));
            this.weights.viewRow(otherIndex).assign(delta, Functions.PLUS);
        }
    }

    /** Trains on one example without a group key; see {@link #train(long, String, int, Vector)}. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        this.train(trackingKey, null, actual, instance);
    }

    /** Trains on one example with tracking key 0 and no group key. */
    @Override
    public void train(int actual, Vector instance) {
        this.train(0L, null, actual, instance);
    }
}

