package org.apache.mahout.math.random;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.MahoutTestCase;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests for {@link Multinomial}: construction from a multiset, sampling at the edges of the
 * sampling intervals, and incremental add / set / delete of individual weights.
 */
public class MultinomialTest extends MahoutTestCase {

    /** Offset used to probe just inside the edges of a sampling interval. */
    private static final double EPSILON = 1.0e-15;

    // NOTE(review): this override does not call super.setUp(); confirm the base-class setup
    // is not needed here.
    @Override
    @Before
    public void setUp() {
        RandomUtils.useTestSeed();
    }

    /** Building a multinomial from an empty multiset must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testNoValues() {
        new Multinomial<String>(HashMultiset.<String>create());
    }

    /** With a single element, sampling at 0, 0.1 and 1 all return that element. */
    @Test
    public void testSingleton() {
        HashMultiset<String> counts = HashMultiset.create();
        counts.add("one");
        Multinomial<String> table = new Multinomial<String>(counts);
        assertEquals("one", table.sample(0.0));
        assertEquals("one", table.sample(0.1));
        assertEquals("one", table.sample(1.0));
    }

    /**
     * Five elements with multiplicity one. Probing the start, just after the start, and just
     * before the end of each fifth of [0, 1] must reach all five elements, each about three times.
     */
    @Test
    public void testEvenSplit() {
        HashMultiset<String> counts = HashMultiset.create();
        for (int i = 0; i < 5; i++) {
            counts.add(String.valueOf(i));
        }
        Multinomial<String> table = new Multinomial<String>(counts);

        HashMultiset<String> samples = HashMultiset.create();
        for (int i = 0; i < 5; i++) {
            double lo = i * 0.2;
            double hi = (i + 1) * 0.2;
            samples.add(table.sample(lo));
            samples.add(table.sample(lo + EPSILON));
            samples.add(table.sample(hi - EPSILON));
        }

        assertEquals(5L, samples.elementSet().size());
        for (String element : samples.elementSet()) {
            assertEquals(3.0, samples.count(element), 1.01);
        }
        // The closed right end point must map to an element already seen, and must agree
        // with the point just before it.
        assertTrue(samples.contains(table.sample(1.0)));
        assertEquals(table.sample(1.0 - EPSILON), table.sample(1.0));
    }

    /**
     * Seventeen items labelled by the highest set bit among the low four bits of their index.
     * This gives "0" x2, "1" x1, "2" x2, "3" x4 and "4" x8.
     *
     * <p>Three tables built from the same counts must agree at every probe point. The
     * frequencies observed over 50 equal slices of [0, 1] must match the expected counts
     * to within 2.
     */
    @Test
    public void testPrime() {
        List<String> data = Lists.newArrayList();
        for (int i = 0; i < 17; i++) {
            String label = "0";
            if ((i & 1) != 0) {
                label = "1";
            }
            if ((i & 2) != 0) {
                label = "2";
            }
            if ((i & 4) != 0) {
                label = "3";
            }
            if ((i & 8) != 0) {
                label = "4";
            }
            data.add(label);
        }

        HashMultiset<String> counts = HashMultiset.create();
        for (String label : data) {
            counts.add(label);
        }
        Multinomial<String> table = new Multinomial<String>(counts);
        Multinomial<String> table2 = new Multinomial<String>(counts);
        Multinomial<String> table3 = new Multinomial<String>(counts);

        HashMultiset<String> samples = HashMultiset.create();
        for (int i = 0; i < 50; i++) {
            double lo = i * 0.02;
            double hi = (i + 1) * 0.02;
            samples.add(table.sample(lo));
            samples.add(table.sample(lo + EPSILON));
            samples.add(table.sample(hi - EPSILON));
            assertSameSamples(table, table2, lo, lo + EPSILON, hi - EPSILON);
            assertSameSamples(table, table3, lo, lo + EPSILON, hi - EPSILON);
        }
        assertSameSamples(table, table2, 0.0, 0.0 + EPSILON, 1.0 - EPSILON, 1.0);
        assertSameSamples(table, table3, 0.0, 0.0 + EPSILON, 1.0 - EPSILON, 1.0);

        assertEquals(5L, samples.elementSet().size());
        ImmutableMap<String, Integer> expected =
                ImmutableMap.of("3", 35, "2", 18, "1", 9, "0", 16, "4", 72);
        for (String label : samples.elementSet()) {
            assertEquals("sample count for " + label,
                    expected.get(label).intValue(), samples.count(label), 2.0);
        }
        assertTrue(samples.contains(table.sample(1.0)));
        assertEquals(table.sample(1.0 - EPSILON), table.sample(1.0));
    }

    /** Weights added one at a time are reported back unchanged and sample consistently. */
    @Test
    public void testInsert() {
        RandomWrapper rand = RandomUtils.getRandom();
        Multinomial<Integer> table = new Multinomial<Integer>();
        double[] weights = new double[10];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = rand.nextDouble();
            table.add(i, weights[i]);
        }
        checkSelfConsistent(table);
        for (int i = 0; i < weights.length; i++) {
            assertEquals(weights[i], table.getWeight(i), 0.0);
        }
    }

    /**
     * Calls {@code set(key, 0)} on every key while walking the table's own iterator. The test
     * passes as long as no exception is thrown.
     */
    @Test
    public void testSetZeroWhileIterating() {
        Multinomial<Integer> table = new Multinomial<Integer>();
        for (int i = 0; i < 10000; i++) {
            table.add(i, i);
        }
        for (Iterator<Integer> it = table.iterator(); it.hasNext(); ) {
            table.set(it.next(), 0.0);
        }
    }

    /** Adding a {@code null} element must fail with a {@link NullPointerException}. */
    @Test(expected = NullPointerException.class)
    public void testNoNullValuesAllowed() {
        new Multinomial<Integer>().add(null, 1.0);
    }

    /**
     * Covers several table updates:
     * <ul>
     *   <li>the total weight tracks each {@code add};</li>
     *   <li>{@code delete} and {@code set(key, 0)} remove weight from the total;</li>
     *   <li>{@code set} with a new value replaces the old one.</li>
     * </ul>
     * After each update, per-key weights and probabilities must match.
     */
    @Test
    public void testDeleteAndUpdate() {
        RandomWrapper rand = RandomUtils.getRandom();
        Multinomial<Integer> table = new Multinomial<Integer>();
        assertEquals(0.0, table.getWeight(), 1.0e-9);

        double total = 0.0;
        double[] weights = new double[10];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = rand.nextDouble();
            table.add(i, weights[i]);
            total += weights[i];
            assertEquals(total, table.getWeight(), 1.0e-9);
        }
        assertEquals(total, table.getWeight(), 1.0e-9);
        checkSelfConsistent(table);

        // Drop key 7 entirely and zero out key 8; both weights leave the total.
        total -= weights[7] + weights[8];
        table.delete(7);
        weights[7] = 0.0;
        table.set(8, 0.0);
        weights[8] = 0.0;
        checkSelfConsistent(table);
        assertEquals(total, table.getWeight(), 1.0e-9);
        assertWeightsAndProbabilities(table, weights, total);

        // Overwrite the weight of key 9.
        table.set(9, 5.1);
        total = total - weights[9] + 5.1;
        weights[9] = 5.1;
        assertEquals(total, table.getWeight(), 1.0e-9);
        assertWeightsAndProbabilities(table, weights, total);
        checkSelfConsistent(table);

        // Probing the table in checkSelfConsistent must not have altered any weight.
        for (int i = 0; i < weights.length; i++) {
            assertEquals(weights[i], table.getWeight(i), 0.0);
        }
    }

    /** Asserts that {@code a} and {@code b} return the same element at each of {@code points}. */
    private static <T> void assertSameSamples(Multinomial<T> a, Multinomial<T> b, double... points) {
        for (double p : points) {
            assertEquals("sample(" + p + ")", a.sample(p), b.sample(p));
        }
    }

    /** Asserts that key {@code i} has weight {@code weights[i]} and probability {@code weights[i] / total}. */
    private static void assertWeightsAndProbabilities(Multinomial<Integer> table,
                                                      double[] weights, double total) {
        for (int i = 0; i < weights.length; i++) {
            assertEquals(weights[i], table.getWeight(i), 0.0);
            assertEquals(weights[i] / total, table.getProbability(i), 1.0e-10);
        }
    }

    /**
     * Treats {@code getWeights()} as consecutive sampling intervals, normalised by
     * {@code getWeight()}.
     *
     * <p>Probes just after the start and just before the end of every positive-weight interval.
     * Then checks that:
     * <ul>
     *   <li>the normalised weights sum to 1;</li>
     *   <li>every positive-weight key is hit exactly twice;</li>
     *   <li>every zero-weight key is never hit.</li>
     * </ul>
     *
     * <p>Sampled keys are used directly as array indices, so the keys are assumed to be the
     * integers {@code 0 .. n-1}, where {@code n = getWeights().size()}.
     */
    private static void checkSelfConsistent(Multinomial<Integer> multinomial) {
        List<Double> weights = multinomial.getWeights();
        double total = multinomial.getWeight();
        double cumulative = 0.0;
        int[] hits = new int[weights.size()];
        for (double w : weights) {
            if (w > 0.0) {
                if (cumulative > 0.0) {
                    // Just before this interval starts, i.e. the tail of the previous one.
                    hits[multinomial.sample(cumulative - 1.0e-9)]++;
                }
                hits[multinomial.sample(cumulative + 1.0e-9)]++;
            }
            cumulative += w / total;
        }
        // Check the normalisation before relying on it for the final probe.
        assertEquals(1.0, cumulative, 1.0e-9);
        hits[multinomial.sample(cumulative - 1.0e-9)]++;

        for (int i = 0; i < weights.size(); i++) {
            if (multinomial.getWeight(i) > 0.0) {
                assertEquals(2L, hits[i]);
            } else {
                assertEquals(0L, hits[i]);
            }
        }
    }
}
