/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.ElasticBandPrior;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.L2;
import org.apache.mahout.classifier.sgd.LogisticModelParameters;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.classifier.sgd.TPrior;
import org.apache.mahout.classifier.sgd.UniformPrior;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.GroupedOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;

/**
 * Holds the settings used to build an {@link AdaptiveLogisticRegression} and
 * (de)serializes those settings together with the learner itself.
 *
 * <p>Defaults: interval 800, averaging window 500, 4 threads, prior {@code "L1"},
 * no prior option ({@code NaN}) and no AUC evaluator selection ({@code null}).
 *
 * <p>Prior and AUC selections are stored as strings and compared
 * case-insensitively (upper-cased with {@link Locale#ENGLISH}, then trimmed).
 */
public class AdaptiveLogisticModelParameters extends LogisticModelParameters {
    /** Lazily created (or deserialized) learner; {@code null} until first needed. */
    private AdaptiveLogisticRegression alr;
    private int interval = 800;
    private int averageWindow = 500;
    private int threads = 4;
    /** Prior selector: one of L1, L2, UP, TP, EBP; anything else yields no prior. */
    private String prior = "L1";
    /** Argument passed to the TP and EBP priors; {@code NaN} means "not set". */
    private double priorOption = Double.NaN;
    /** AUC evaluator selector: GLOBAL or GROUPED; anything else yields no evaluator. */
    private String auc = null;

    /**
     * Returns the learner for these parameters, creating and configuring it on
     * the first call and returning the same instance afterwards.
     *
     * @return the cached {@link AdaptiveLogisticRegression}
     */
    public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
        if (alr == null) {
            PriorFunction priorFunction = createPrior(prior, priorOption);
            alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), getNumFeatures(), priorFunction);
            alr.setInterval(interval);
            alr.setAveragingWindow(averageWindow);
            alr.setThreadCount(threads);
            alr.setAucEvaluator(createAUC(auc));
        }
        return alr;
    }

    /**
     * Validates that a prior option was supplied when the selected prior needs one.
     *
     * @throws IllegalArgumentException if the prior is TP or EBP and the prior
     *         option is still {@code NaN}
     */
    public void checkParameters() {
        if (prior == null) {
            return;
        }
        String priorUppercase = normalize(prior);
        boolean needsOption = "TP".equals(priorUppercase) || "EBP".equals(priorUppercase);
        if (needsOption && Double.isNaN(priorOption)) {
            throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior.");
        }
    }

    /**
     * Maps a prior selector to a {@link PriorFunction} instance.
     *
     * @param cmd selector string (L1, L2, UP, TP or EBP, case-insensitive); may be {@code null}
     * @param priorOption argument handed to the TP and EBP priors
     * @return the matching prior, or {@code null} when {@code cmd} is {@code null} or unrecognized
     */
    private static PriorFunction createPrior(String cmd, double priorOption) {
        if (cmd == null) {
            return null;
        }
        String name = normalize(cmd);
        if ("L1".equals(name)) {
            return new L1();
        }
        if ("L2".equals(name)) {
            return new L2();
        }
        if ("UP".equals(name)) {
            return new UniformPrior();
        }
        if ("TP".equals(name)) {
            return new TPrior(priorOption);
        }
        if ("EBP".equals(name)) {
            return new ElasticBandPrior(priorOption);
        }
        return null;
    }

    /**
     * Maps an AUC selector to an {@link OnlineAuc} instance.
     *
     * @param cmd selector string (GLOBAL or GROUPED, case-insensitive); may be {@code null}
     * @return the matching evaluator, or {@code null} when {@code cmd} is {@code null} or unrecognized
     */
    private static OnlineAuc createAUC(String cmd) {
        if (cmd == null) {
            return null;
        }
        String name = normalize(cmd);
        if ("GLOBAL".equals(name)) {
            return new GlobalOnlineAuc();
        }
        if ("GROUPED".equals(name)) {
            return new GroupedOnlineAuc();
        }
        return null;
    }

    /** Upper-cases (English locale) and trims a selector so it can be compared to the known names. */
    private static String normalize(String cmd) {
        return cmd.toUpperCase(Locale.ENGLISH).trim();
    }

    /** Substitutes the empty string for {@code null}, because {@code writeUTF} rejects {@code null}. */
    private static String nullToEmpty(String value) {
        return value == null ? "" : value;
    }

    /** Inverse of {@link #nullToEmpty(String)}: turns the empty string back into {@code null}. */
    private static String emptyToNull(String value) {
        return value == null || value.isEmpty() ? null : value;
    }

    /**
     * Closes the learner if one exists, copies the target categories from the
     * CSV record factory into this object and then serializes everything to
     * {@code out} via {@link #write(DataOutput)}.
     *
     * @param out stream to write to; it is not closed by this method
     * @throws IOException if writing fails
     * @throws IllegalStateException if no learner has been created or loaded yet
     */
    @Override
    public void saveTo(OutputStream out) throws IOException {
        if (alr != null) {
            alr.close();
        }
        setTargetCategories(getCsvRecordFactory().getTargetCategories());
        write(new DataOutputStream(out));
    }

    /**
     * Serializes, in order: target variable, type map (size then key/value
     * pairs), number of features, max target categories, target categories
     * (size then values), interval, averaging window, threads, prior, prior
     * option, AUC selector and finally the learner itself.
     *
     * <p>A {@code null} prior or AUC selector is written as the empty string,
     * since {@link DataOutput#writeUTF(String)} cannot encode {@code null}.
     *
     * @param out destination for the serialized form
     * @throws IOException if writing fails
     * @throws IllegalStateException if no learner has been created or loaded yet
     */
    @Override
    public void write(DataOutput out) throws IOException {
        // Fail before emitting anything so a missing learner cannot leave a truncated stream behind.
        if (alr == null) {
            throw new IllegalStateException(
                "No AdaptiveLogisticRegression to serialize; call createAdaptiveLogisticRegression() or readFields() first.");
        }
        out.writeUTF(getTargetVariable());

        Map<String, String> typeMap = getTypeMap();
        out.writeInt(typeMap.size());
        for (Map.Entry<String, String> entry : typeMap.entrySet()) {
            out.writeUTF(entry.getKey());
            out.writeUTF(entry.getValue());
        }

        out.writeInt(getNumFeatures());
        out.writeInt(getMaxTargetCategories());
        out.writeInt(getTargetCategories().size());
        for (String category : getTargetCategories()) {
            out.writeUTF(category);
        }

        out.writeInt(interval);
        out.writeInt(averageWindow);
        out.writeInt(threads);
        out.writeUTF(nullToEmpty(prior));
        out.writeDouble(priorOption);
        out.writeUTF(nullToEmpty(auc));
        alr.write(out);
    }

    /**
     * Restores this object from the layout produced by {@link #write(DataOutput)}.
     * An empty prior or AUC selector is restored as {@code null}; both values
     * select no prior / no evaluator, so the resulting learner configuration is
     * the same.
     *
     * @param in source of the serialized form
     * @throws IOException if reading fails
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        setTargetVariable(in.readUTF());

        int typeMapSize = in.readInt();
        HashMap<String, String> typeMap = new HashMap<String, String>(typeMapSize);
        for (int i = 0; i < typeMapSize; ++i) {
            String key = in.readUTF();
            String value = in.readUTF();
            typeMap.put(key, value);
        }
        setTypeMap(typeMap);

        setNumFeatures(in.readInt());
        setMaxTargetCategories(in.readInt());

        int targetCategoriesSize = in.readInt();
        ArrayList<String> targetCategories = Lists.newArrayListWithCapacity(targetCategoriesSize);
        for (int i = 0; i < targetCategoriesSize; ++i) {
            targetCategories.add(in.readUTF());
        }
        setTargetCategories(targetCategories);

        interval = in.readInt();
        averageWindow = in.readInt();
        threads = in.readInt();
        prior = emptyToNull(in.readUTF());
        priorOption = in.readDouble();
        auc = emptyToNull(in.readUTF());

        alr = new AdaptiveLogisticRegression();
        alr.readFields(in);
    }

    /** Builds a fresh instance and populates it from {@code in} via {@link #readFields(DataInput)}. */
    private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException {
        AdaptiveLogisticModelParameters result = new AdaptiveLogisticModelParameters();
        result.readFields(new DataInputStream(in));
        return result;
    }

    /**
     * Loads parameters (and the learner) from a file previously produced by
     * {@link #saveTo(OutputStream)}.
     *
     * @param in file to read
     * @return the deserialized parameters
     * @throws IOException if the file cannot be opened or read
     */
    public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException {
        FileInputStream input = new FileInputStream(in);
        try {
            return loadFromStream(input);
        }
        finally {
            // swallowIOException = true: a failure while closing must not mask the read result or error.
            Closeables.close(input, true);
        }
    }

    public int getInterval() {
        return interval;
    }

    public void setInterval(int interval) {
        this.interval = interval;
    }

    public int getAverageWindow() {
        return averageWindow;
    }

    public void setAverageWindow(int averageWindow) {
        this.averageWindow = averageWindow;
    }

    public int getThreads() {
        return threads;
    }

    public void setThreads(int threads) {
        this.threads = threads;
    }

    public String getPrior() {
        return prior;
    }

    public void setPrior(String prior) {
        this.prior = prior;
    }

    public String getAuc() {
        return auc;
    }

    public void setAuc(String auc) {
        this.auc = auc;
    }

    public double getPriorOption() {
        return priorOption;
    }

    public void setPriorOption(double priorOption) {
        this.priorOption = priorOption;
    }
}

