/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;

/**
 * Cross-validating wrapper around a fixed list of {@link OnlineLogisticRegression} models.
 *
 * <p>Each training example is routed by its tracking key. The model whose index equals
 * {@code trackingKey mod models.size()} only <em>scores</em> the example. That score feeds the
 * held-out log-likelihood and accuracy estimates and, when there are two categories, the AUC
 * estimate. Every other model is trained on the example. The classification methods combine
 * the outputs of all models by averaging.
 */
public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Writable {
    /** Floor applied to the predicted probability of the actual class before taking its log. */
    private static final double MIN_SCORE = 1.0E-50;
    /** Number of entries in {@link #parameters}; {@link #readFields} always reads exactly this many. */
    private static final int PARAMETER_COUNT = 4;

    /**
     * Running count of training examples. It is the tracking key used by
     * {@link #train(int, Vector)} and the averaging denominator for the running statistics.
     */
    private int record;
    private OnlineAuc auc = new GlobalOnlineAuc();
    /** Running average of log(probability assigned to the actual class) on held-out examples. */
    private double logLikelihood;
    private final List<OnlineLogisticRegression> models = new ArrayList<OnlineLogisticRegression>();
    private double[] parameters = new double[PARAMETER_COUNT];
    private int numFeatures;
    private PriorFunction prior;
    /** Running fraction of held-out examples whose highest-scoring category was the actual one. */
    private double percentCorrect;
    /** Caps the averaging denominator; once {@code record} exceeds it, averages decay exponentially. */
    private int windowSize = Integer.MAX_VALUE;

    /**
     * Creates a learner with no models. They are expected to be supplied later, e.g. by
     * {@link #readFields(DataInput)} or {@link #addModel(OnlineLogisticRegression)}.
     */
    public CrossFoldLearner() {
    }

    /**
     * Creates one model per fold. Each model is configured with {@code alpha(1.0)},
     * {@code stepOffset(0)} and {@code decayExponent(0.0)}.
     *
     * @param folds         number of models / cross-validation folds
     * @param numCategories number of target categories for each model
     * @param numFeatures   size of the feature vectors
     * @param prior         prior (regularizer) shared by all models
     */
    public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
        this.numFeatures = numFeatures;
        this.prior = prior;
        for (int i = 0; i < folds; ++i) {
            OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
            model.alpha(1.0).stepOffset(0).decayExponent(0.0);
            this.models.add(model);
        }
    }

    /** Sets {@code lambda} on every model; returns {@code this} for chaining. */
    public CrossFoldLearner lambda(double v) {
        for (OnlineLogisticRegression model : this.models) {
            model.lambda(v);
        }
        return this;
    }

    /** Sets the learning rate on every model; returns {@code this} for chaining. */
    public CrossFoldLearner learningRate(double x) {
        for (OnlineLogisticRegression model : this.models) {
            model.learningRate(x);
        }
        return this;
    }

    /** Sets the step offset on every model; returns {@code this} for chaining. */
    public CrossFoldLearner stepOffset(int x) {
        for (OnlineLogisticRegression model : this.models) {
            model.stepOffset(x);
        }
        return this;
    }

    /** Sets the decay exponent on every model; returns {@code this} for chaining. */
    public CrossFoldLearner decayExponent(double x) {
        for (OnlineLogisticRegression model : this.models) {
            model.decayExponent(x);
        }
        return this;
    }

    /** Sets {@code alpha} on every model; returns {@code this} for chaining. */
    public CrossFoldLearner alpha(double alpha) {
        for (OnlineLogisticRegression model : this.models) {
            model.alpha(alpha);
        }
        return this;
    }

    /**
     * Trains using the current record count as the tracking key. The count advances on every
     * call, so consecutive calls cycle through the folds.
     */
    @Override
    public void train(int actual, Vector instance) {
        this.train(this.record, null, actual, instance);
    }

    /** Trains with an explicit tracking key and no group key. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        this.train(trackingKey, null, actual, instance);
    }

    /**
     * Scores the example with the held-out model selected by {@code trackingKey} and trains
     * every other model on it.
     *
     * @param trackingKey selects the held-out fold via a non-negative modulus
     * @param groupKey    passed to the AUC evaluator and to the trained models; may be null
     * @param actual      index of the true category
     * @param instance    feature vector
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        ++this.record;
        int folds = this.models.size();
        if (folds == 0) {
            // Nothing to score or train; also avoids a modulus by zero.
            return;
        }
        long testFold = mod(trackingKey, folds);
        for (int k = 0; k < folds; ++k) {
            OnlineLogisticRegression model = this.models.get(k);
            if (k == testFold) {
                Vector v = model.classifyFull(instance);
                double score = Math.max(v.get(actual), MIN_SCORE);
                double denominator = Math.min(this.record, this.windowSize);
                this.logLikelihood += (Math.log(score) - this.logLikelihood) / denominator;
                boolean correct = v.maxValueIndex() == actual;
                this.percentCorrect += ((correct ? 1.0 : 0.0) - this.percentCorrect) / denominator;
                if (this.numCategories() == 2) {
                    this.auc.addSample(actual, groupKey, v.get(1));
                }
            } else {
                model.train(trackingKey, groupKey, actual, instance);
            }
        }
    }

    /** Returns {@code x mod y} in the range {@code [0, y)}, even for negative {@code x}. */
    private static long mod(long x, int y) {
        long r = x % (long) y;
        return r < 0L ? r + (long) y : r;
    }

    /** Closes every model. */
    @Override
    public void close() {
        for (OnlineLogisticRegression m : this.models) {
            m.close();
        }
    }

    /** Resets the record counter; running statistics are left untouched. */
    public void resetLineCounter() {
        this.record = 0;
    }

    /** Returns true only if every model reports itself valid. */
    public boolean validModel() {
        boolean r = true;
        for (OnlineLogisticRegression model : this.models) {
            r &= model.validModel();
        }
        return r;
    }

    /** Returns the mean of the models' {@code classify} outputs (length {@code numCategories() - 1}). */
    @Override
    public Vector classify(Vector instance) {
        DenseVector r = new DenseVector(this.numCategories() - 1);
        DoubleDoubleFunction scale = Functions.plusMult(1.0 / this.models.size());
        for (OnlineLogisticRegression model : this.models) {
            r.assign(model.classify(instance), scale);
        }
        return r;
    }

    /** Returns the mean of the models' {@code classifyNoLink} outputs (length {@code numCategories() - 1}). */
    @Override
    public Vector classifyNoLink(Vector instance) {
        DenseVector r = new DenseVector(this.numCategories() - 1);
        DoubleDoubleFunction scale = Functions.plusMult(1.0 / this.models.size());
        for (OnlineLogisticRegression model : this.models) {
            r.assign(model.classifyNoLink(instance), scale);
        }
        return r;
    }

    /** Returns the mean of the models' {@code classifyScalar} outputs (NaN when there are no models). */
    @Override
    public double classifyScalar(Vector instance) {
        double r = 0.0;
        for (OnlineLogisticRegression model : this.models) {
            r += model.classifyScalar(instance);
        }
        return r / this.models.size();
    }

    /** Category count of the first model; throws if there are no models. */
    @Override
    public int numCategories() {
        return this.models.get(0).numCategories();
    }

    public double auc() {
        return this.auc.auc();
    }

    public double logLikelihood() {
        return this.logLikelihood;
    }

    public double percentCorrect() {
        return this.percentCorrect;
    }

    /**
     * Returns a new learner holding copies of this learner's models.
     *
     * <p>Each source model is closed before it is copied. Only the models, fold count, category
     * count, feature count and prior are carried over. The record count, running statistics,
     * AUC evaluator, window size and {@code parameters} keep the defaults of a fresh instance.
     * NOTE(review): confirm that dropping {@code windowSize} is intended.
     */
    public CrossFoldLearner copy() {
        CrossFoldLearner r = new CrossFoldLearner(this.models.size(), this.numCategories(), this.numFeatures, this.prior);
        r.models.clear();
        for (OnlineLogisticRegression model : this.models) {
            model.close();
            OnlineLogisticRegression newModel = new OnlineLogisticRegression(model.numCategories(), model.numFeatures(), model.prior);
            newModel.copyFrom(model);
            r.models.add(newModel);
        }
        return r;
    }

    public int getRecord() {
        return this.record;
    }

    public void setRecord(int record) {
        this.record = record;
    }

    public OnlineAuc getAucEvaluator() {
        return this.auc;
    }

    public void setAucEvaluator(OnlineAuc auc) {
        this.auc = auc;
    }

    public double getLogLikelihood() {
        return this.logLikelihood;
    }

    public void setLogLikelihood(double logLikelihood) {
        this.logLikelihood = logLikelihood;
    }

    /** Returns the live (mutable) model list. */
    public List<OnlineLogisticRegression> getModels() {
        return this.models;
    }

    public void addModel(OnlineLogisticRegression model) {
        this.models.add(model);
    }

    public double[] getParameters() {
        return this.parameters;
    }

    /**
     * Stores the array by reference. Only arrays of exactly four entries can be serialized by
     * {@link #write(DataOutput)}.
     */
    public void setParameters(double[] parameters) {
        this.parameters = parameters;
    }

    public int getNumFeatures() {
        return this.numFeatures;
    }

    public void setNumFeatures(int numFeatures) {
        this.numFeatures = numFeatures;
    }

    /** Sets the averaging window for the running statistics and the AUC evaluator. */
    public void setWindowSize(int windowSize) {
        this.windowSize = windowSize;
        this.auc.setWindowSize(windowSize);
    }

    public PriorFunction getPrior() {
        return this.prior;
    }

    public void setPrior(PriorFunction prior) {
        this.prior = prior;
    }

    /**
     * Serializes this learner in the order {@link #readFields(DataInput)} expects.
     *
     * @throws IOException if {@code parameters} does not have exactly {@value #PARAMETER_COUNT}
     *                     entries (the reader could not parse such a stream), or on I/O failure
     */
    @Override
    public void write(DataOutput out) throws IOException {
        // Validate before writing anything so a bad state never produces a partial record.
        if (this.parameters.length != PARAMETER_COUNT) {
            throw new IOException("CrossFoldLearner can only serialize exactly " + PARAMETER_COUNT
                + " parameters but has " + this.parameters.length);
        }
        out.writeInt(this.record);
        PolymorphicWritable.write(out, this.auc);
        out.writeDouble(this.logLikelihood);
        out.writeInt(this.models.size());
        for (OnlineLogisticRegression model : this.models) {
            model.write(out);
        }
        for (double x : this.parameters) {
            out.writeDouble(x);
        }
        out.writeInt(this.numFeatures);
        PolymorphicWritable.write(out, this.prior);
        out.writeDouble(this.percentCorrect);
        out.writeInt(this.windowSize);
    }

    /**
     * Restores state written by {@link #write(DataOutput)}. Any models already held are
     * discarded, so a reused instance ends up with exactly the serialized models.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.record = in.readInt();
        this.auc = PolymorphicWritable.read(in, OnlineAuc.class);
        this.logLikelihood = in.readDouble();
        int n = in.readInt();
        // Writable instances are commonly reused; without this the fold count would grow.
        this.models.clear();
        for (int i = 0; i < n; ++i) {
            OnlineLogisticRegression olr = new OnlineLogisticRegression();
            olr.readFields(in);
            this.models.add(olr);
        }
        this.parameters = new double[PARAMETER_COUNT];
        for (int i = 0; i < PARAMETER_COUNT; ++i) {
            this.parameters[i] = in.readDouble();
        }
        this.numFeatures = in.readInt();
        this.prior = PolymorphicWritable.read(in, PriorFunction.class);
        this.percentCorrect = in.readDouble();
        this.windowSize = in.readInt();
    }
}

