package org.apache.mahout.math.random;

import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.stats.OnlineSummarizer;
import org.junit.Before;
import org.junit.Test;

/**
 * Statistical sanity checks for {@link MultiNormal} sampling.
 *
 * <p>Each test draws a fixed number of samples and asserts on summary statistics of the
 * draws, collected with {@link OnlineSummarizer}.
 */
public class MultiNormalTest extends MahoutTestCase {

    /** Number of draws taken by each test. */
    private static final int SAMPLE_COUNT = 10000;

    /**
     * Coordinate pairs whose centred products are checked in {@link #testDiagonal()}; row
     * {@code d} is the pair accumulated alongside coordinate {@code d}.
     */
    private static final int[][] CROSS_PAIRS = {{0, 1}, {1, 2}, {0, 2}};

    /**
     * Switches {@link RandomUtils} to its test seed before each test (presumably a fixed
     * seed, so the sampled draws are reproducible between runs).
     */
    @Override
    @Before
    public void setUp() {
        RandomUtils.useTestSeed();
    }

    /**
     * Checks that draws are centred on the offset vector and that coordinates are
     * uncorrelated.
     *
     * <p>For every coordinate, the mean of the centred values divided by their standard
     * deviation must be within 0.04 of zero. The same bound applies, for every pair in
     * {@link #CROSS_PAIRS}, to the mean of the products of the two centred values divided by
     * their standard deviation.
     */
    @Test
    public void testDiagonal() {
        int dimensions = 3;
        // Subtracted from every draw below, so the test treats it as the distribution mean.
        DenseVector mean = new DenseVector(new double[]{6.0, 3.0, 0.0});
        // NOTE(review): presumably the per-coordinate scales of a diagonal covariance —
        // confirm against the MultiNormal constructor.
        DenseVector scale = new DenseVector(new double[]{1.0, 2.0, 5.0});
        MultiNormal generator = new MultiNormal(scale, mean);

        OnlineSummarizer[] centered = newSummarizers(dimensions);
        OnlineSummarizer[] crossProducts = newSummarizers(dimensions);

        for (int draw = 0; draw < SAMPLE_COUNT; draw++) {
            Vector sample = generator.sample();
            for (int dim = 0; dim < dimensions; dim++) {
                centered[dim].add(sample.get(dim) - mean.get(dim));

                int first = CROSS_PAIRS[dim][0];
                int second = CROSS_PAIRS[dim][1];
                double firstCentered = sample.get(first) - mean.get(first);
                double secondCentered = sample.get(second) - mean.get(second);
                crossProducts[dim].add(firstCentered * secondCentered);
            }
        }

        for (int dim = 0; dim < dimensions; dim++) {
            assertEquals(0.0, centered[dim].getMean() / centered[dim].getSD(), 0.04);
            assertEquals(0.0, crossProducts[dim].getMean() / crossProducts[dim].getSD(), 0.04);
        }
    }

    /**
     * Checks the typical radius of draws from a generator built with scalar argument 0.1
     * around a 10-dimensional {@link DenseVector} (presumably all zeros).
     *
     * <p>The {@code norm(2.0)} of each draw divided by sqrt(10) must average to 0.1,
     * within 0.01.
     */
    @Test
    public void testRadius() {
        int dimension = 10;
        double sigma = 0.1;
        double rootDimension = Math.sqrt(dimension);

        MultiNormal generator = new MultiNormal(sigma, new DenseVector(dimension));
        OnlineSummarizer normalizedRadius = new OnlineSummarizer();
        for (int draw = 0; draw < SAMPLE_COUNT; draw++) {
            Vector sample = generator.sample();
            normalizedRadius.add(sample.norm(2.0) / rootDimension);
        }

        assertEquals(sigma, normalizedRadius.getMean(), 0.01);
    }

    /** Returns an array of {@code count} freshly constructed, empty summarizers. */
    private static OnlineSummarizer[] newSummarizers(int count) {
        OnlineSummarizer[] summarizers = new OnlineSummarizer[count];
        for (int k = 0; k < count; k++) {
            summarizers[k] = new OnlineSummarizer();
        }
        return summarizers;
    }
}
