/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.naivebayes.training;

import com.google.common.base.Preconditions;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.math.Vector;

/**
 * Accumulates a per-label "theta normalizer" for complementary naive Bayes training.
 *
 * <p>For each label passed to {@link #train(int, Vector)}, the trainer sums the absolute values of
 * {@code ComplementaryNaiveBayesClassifier.computeWeight(...)} over every index of that label's
 * weight vector. The accumulated sums are exposed through {@link #retrievePerLabelThetaNormalizer()}.
 */
public class ComplementaryThetaTrainer {
    /** Aggregate weight per feature index, as supplied by the caller. */
    private final Vector weightsPerFeature;
    /** Aggregate weight per label index, as supplied by the caller. */
    private final Vector weightsPerLabel;
    /** Running per-label sums; created with {@code weightsPerLabel.like()}. */
    private final Vector perLabelThetaNormalizer;
    /** Smoothing parameter forwarded unchanged to {@code computeWeight}. */
    private final double alphaI;
    /** Sum of all entries of {@code weightsPerLabel} ({@code zSum()}). */
    private final double totalWeightSum;
    /** Count of non-default entries in {@code weightsPerFeature}, held as a double. */
    private final double numFeatures;

    /**
     * Creates a trainer over the given aggregate weights.
     *
     * @param weightsPerFeature aggregate weight for each feature; must not be null
     * @param weightsPerLabel aggregate weight for each label; must not be null
     * @param alphaI smoothing parameter passed through to the weight computation
     * @throws NullPointerException if either vector is null
     */
    public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
        Preconditions.checkNotNull(weightsPerFeature);
        Preconditions.checkNotNull(weightsPerLabel);
        this.weightsPerFeature = weightsPerFeature;
        this.weightsPerLabel = weightsPerLabel;
        this.alphaI = alphaI;
        // Start the accumulator as a fresh vector shaped like the label weights.
        perLabelThetaNormalizer = weightsPerLabel.like();
        totalWeightSum = weightsPerLabel.zSum();
        numFeatures = weightsPerFeature.getNumNondefaultElements();
    }

    /**
     * Adds the absolute complementary weight of every position of {@code perLabelWeight} to the
     * running normalizer of {@code label}.
     *
     * <p>Positions are visited by index through {@code getElement}, so entries that hold the
     * default (zero) value contribute as well, not only the explicitly stored ones.
     *
     * @param label index of the label being trained
     * @param perLabelWeight weights for this label, indexed by feature (presumably; confirm with
     *     callers)
     */
    public void train(int label, Vector perLabelWeight) {
        final double weightOfLabel = labelWeight(label);
        int position = 0;
        while (position < perLabelWeight.size()) {
            Vector.Element entry = perLabelWeight.getElement(position);
            // Evaluate the arguments in the same order the combined expression would.
            double featureTotal = featureWeight(entry.index());
            double entryValue = entry.get();
            double complementaryWeight = ComplementaryNaiveBayesClassifier.computeWeight(
                    featureTotal, entryValue, totalWeightSum(), weightOfLabel, alphaI(), numFeatures());
            updatePerLabelThetaNormalizer(label, complementaryWeight);
            position++;
        }
    }

    /** @return the smoothing parameter supplied at construction */
    protected double alphaI() {
        return alphaI;
    }

    /** @return the number of non-default entries in the per-feature weight vector */
    protected double numFeatures() {
        return numFeatures;
    }

    /**
     * @param label label index
     * @return the aggregate weight stored for {@code label} in the per-label weight vector
     */
    protected double labelWeight(int label) {
        return weightsPerLabel.get(label);
    }

    /** @return the sum of all per-label weights */
    protected double totalWeightSum() {
        return totalWeightSum;
    }

    /**
     * @param feature feature index
     * @return the aggregate weight stored for {@code feature} in the per-feature weight vector
     */
    protected double featureWeight(int feature) {
        return weightsPerFeature.get(feature);
    }

    /**
     * Adds {@code |weight|} to the accumulator slot of {@code label}.
     *
     * @param label label index whose normalizer is incremented
     * @param weight contribution whose absolute value is added
     */
    protected void updatePerLabelThetaNormalizer(int label, double weight) {
        double accumulated = perLabelThetaNormalizer.get(label);
        perLabelThetaNormalizer.set(label, accumulated + Math.abs(weight));
    }

    /**
     * Returns a copy of the accumulated per-label normalizers.
     *
     * @return a clone, so callers cannot mutate the trainer's internal accumulator
     */
    public Vector retrievePerLabelThetaNormalizer() {
        return perLabelThetaNormalizer.clone();
    }
}

