/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.random;

import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.MultiNormal;
import org.apache.mahout.math.stats.OnlineSummarizer;
import org.junit.Before;
import org.junit.Test;

public class MultiNormalTest extends MahoutTestCase {

    /** Number of draws taken in each statistical test. */
    private static final int SAMPLES = 10000;

    /**
     * Tolerance for ratios (mean / SD) that should be zero. With {@link #SAMPLES} draws the
     * standard error of such a ratio is about 1/sqrt(10000) = 0.01, so this is roughly 4 sigma.
     */
    private static final double ZERO_RATIO_EPSILON = 0.04;

    @Override
    @Before
    public void setUp() {
        // Fixed seed so the statistical assertions below are reproducible from run to run.
        RandomUtils.useTestSeed();
    }

    /**
     * Samples a 3-dimensional normal built from a diagonal scale vector {1, 2, 5} and a mean
     * (offset) {6, 3, 0}, then checks, on the offset-subtracted coordinates, that:
     * <ul>
     *   <li>each coordinate has mean close to zero (relative to its SD),</li>
     *   <li>each coordinate's SD is close to the corresponding diagonal scale entry,</li>
     *   <li>every pair of coordinates is uncorrelated (their product has mean close to zero).</li>
     * </ul>
     */
    @Test
    public void testDiagonal() {
        double[] scales = {1.0, 2.0, 5.0};
        Vector offset = new DenseVector(new double[]{6.0, 3.0, 0.0});
        Vector diagonal = new DenseVector(scales);
        MultiNormal n = new MultiNormal(diagonal, offset);

        // The three distinct off-diagonal coordinate pairs whose products are summarized.
        int[][] pairs = {{0, 1}, {1, 2}, {0, 2}};

        OnlineSummarizer[] s = new OnlineSummarizer[scales.length];
        OnlineSummarizer[] cross = new OnlineSummarizer[pairs.length];
        for (int j = 0; j < s.length; ++j) {
            s[j] = new OnlineSummarizer();
        }
        for (int p = 0; p < cross.length; ++p) {
            cross[p] = new OnlineSummarizer();
        }

        double[] centred = new double[scales.length];
        for (int i = 0; i < SAMPLES; ++i) {
            Vector v = n.sample();
            for (int j = 0; j < centred.length; ++j) {
                centred[j] = v.get(j) - offset.get(j);
                s[j].add(centred[j]);
            }
            for (int p = 0; p < pairs.length; ++p) {
                cross[p].add(centred[pairs[p][0]] * centred[pairs[p][1]]);
            }
        }

        for (int j = 0; j < s.length; ++j) {
            assertEquals(0.0, s[j].getMean() / s[j].getSD(), ZERO_RATIO_EPSILON);
            // The sample SD has a standard error of about sigma / sqrt(2 * SAMPLES), i.e. ~0.7% of
            // sigma here, so a 5% relative tolerance is generous. Without this check a sampler
            // that ignored its scale vector would still pass the ratio-only assertions.
            assertEquals(scales[j], s[j].getSD(), 0.05 * scales[j]);
        }
        for (int p = 0; p < cross.length; ++p) {
            assertEquals(0.0, cross[p].getMean() / cross[p].getSD(), ZERO_RATIO_EPSILON);
        }
    }

    /**
     * Samples a 10-dimensional normal built with radius 0.1 around the zero vector, and checks
     * that the root-mean-square coordinate (L2 norm divided by sqrt(dimension)) averages to
     * about 0.1 across samples.
     */
    @Test
    public void testRadius() {
        int dimension = 10;
        Vector mean = new DenseVector(dimension);
        MultiNormal gen = new MultiNormal(0.1, mean);
        OnlineSummarizer s = new OnlineSummarizer();
        for (int i = 0; i < SAMPLES; ++i) {
            s.add(gen.sample().norm(2.0) / Math.sqrt(dimension));
        }
        assertEquals(0.1, s.getMean(), 0.01);
    }
}

