package org.apache.mahout.classifier.sgd;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import java.io.BufferedInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.io.Writable;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/LogisticModelParameters.class */
public class LogisticModelParameters implements Writable {
    private String targetVariable;
    private Map<String, String> typeMap;
    private int numFeatures;
    private boolean useBias;
    private int maxTargetCategories;
    private List<String> targetCategories;
    private double lambda;
    private double learningRate;
    private CsvRecordFactory csv;
    private OnlineLogisticRegression lr;

    public CsvRecordFactory getCsvRecordFactory() {
        if (this.csv == null) {
            this.csv = new CsvRecordFactory(getTargetVariable(), getTypeMap()).maxTargetValue(getMaxTargetCategories()).includeBiasTerm(useBias());
            if (this.targetCategories != null) {
                this.csv.defineTargetCategories(this.targetCategories);
            }
        }
        return this.csv;
    }

    public OnlineLogisticRegression createRegression() {
        if (this.lr == null) {
            this.lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1()).lambda(getLambda()).learningRate(getLearningRate()).alpha(0.999d);
        }
        return this.lr;
    }

    public void saveTo(OutputStream outputStream) throws IOException {
        Closeables.close(this.lr, false);
        this.targetCategories = getCsvRecordFactory().getTargetCategories();
        write(new DataOutputStream(outputStream));
    }

    public static LogisticModelParameters loadFrom(InputStream inputStream) throws IOException {
        LogisticModelParameters logisticModelParameters = new LogisticModelParameters();
        logisticModelParameters.readFields(new DataInputStream(inputStream));
        return logisticModelParameters;
    }

    public static LogisticModelParameters loadFrom(File file) throws IOException {
        FileInputStream fileInputStream = new FileInputStream(file);
        Throwable th = null;
        try {
            LogisticModelParameters loadFrom = loadFrom(fileInputStream);
            if (fileInputStream != null) {
                if (0 != 0) {
                    try {
                        fileInputStream.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                } else {
                    fileInputStream.close();
                }
            }
            return loadFrom;
        } catch (Throwable th3) {
            if (fileInputStream != null) {
                if (0 != 0) {
                    try {
                        fileInputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    fileInputStream.close();
                }
            }
            throw th3;
        }
    }

    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeUTF(this.targetVariable);
        dataOutput.writeInt(this.typeMap.size());
        for (Map.Entry<String, String> entry : this.typeMap.entrySet()) {
            dataOutput.writeUTF(entry.getKey());
            dataOutput.writeUTF(entry.getValue());
        }
        dataOutput.writeInt(this.numFeatures);
        dataOutput.writeBoolean(this.useBias);
        dataOutput.writeInt(this.maxTargetCategories);
        if (this.targetCategories == null) {
            dataOutput.writeInt(0);
        } else {
            dataOutput.writeInt(this.targetCategories.size());
            Iterator<String> it = this.targetCategories.iterator();
            while (it.hasNext()) {
                dataOutput.writeUTF(it.next());
            }
        }
        dataOutput.writeDouble(this.lambda);
        dataOutput.writeDouble(this.learningRate);
        this.lr.write(dataOutput);
    }

    public void readFields(DataInput dataInput) throws IOException {
        this.targetVariable = dataInput.readUTF();
        int readInt = dataInput.readInt();
        this.typeMap = new HashMap(readInt);
        for (int i = 0; i < readInt; i++) {
            this.typeMap.put(dataInput.readUTF(), dataInput.readUTF());
        }
        this.numFeatures = dataInput.readInt();
        this.useBias = dataInput.readBoolean();
        this.maxTargetCategories = dataInput.readInt();
        int readInt2 = dataInput.readInt();
        this.targetCategories = new ArrayList(readInt2);
        for (int i2 = 0; i2 < readInt2; i2++) {
            this.targetCategories.add(dataInput.readUTF());
        }
        this.lambda = dataInput.readDouble();
        this.learningRate = dataInput.readDouble();
        this.csv = null;
        this.lr = new OnlineLogisticRegression();
        this.lr.readFields(dataInput);
    }

    public void setTypeMap(Iterable<String> iterable, List<String> list) {
        Preconditions.checkArgument(!list.isEmpty(), "Must have at least one type specifier");
        this.typeMap = new HashMap();
        Iterator<String> it = list.iterator();
        String str = null;
        for (String str2 : iterable) {
            if (it.hasNext()) {
                str = it.next();
            }
            this.typeMap.put(str2.toString(), str);
        }
    }

    public void setTargetVariable(String str) {
        this.targetVariable = str;
    }

    public void setMaxTargetCategories(int i) {
        this.maxTargetCategories = i;
    }

    public void setNumFeatures(int i) {
        this.numFeatures = i;
    }

    public void setTargetCategories(List<String> list) {
        this.targetCategories = list;
        this.maxTargetCategories = list.size();
    }

    public List<String> getTargetCategories() {
        return this.targetCategories;
    }

    public void setUseBias(boolean z) {
        this.useBias = z;
    }

    public boolean useBias() {
        return this.useBias;
    }

    public String getTargetVariable() {
        return this.targetVariable;
    }

    public Map<String, String> getTypeMap() {
        return this.typeMap;
    }

    public void setTypeMap(Map<String, String> map) {
        this.typeMap = map;
    }

    public int getNumFeatures() {
        return this.numFeatures;
    }

    public int getMaxTargetCategories() {
        return this.maxTargetCategories;
    }

    public double getLambda() {
        return this.lambda;
    }

    public void setLambda(double d) {
        this.lambda = d;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public void setLearningRate(double d) {
        this.learningRate = d;
    }
}
