/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.random;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.ChineseRestaurant;
import org.junit.Test;

/**
 * Statistical tests for {@link ChineseRestaurant}.
 *
 * <p>Each test draws many samples and checks aggregate properties against fixed targets
 * with tolerances:
 * <ul>
 *   <li>the distribution of table occupancies;</li>
 *   <li>the growth rate of the number of tables;</li>
 *   <li>the fraction of tables holding exactly one customer.</li>
 * </ul>
 */
public final class ChineseRestaurantTest extends MahoutTestCase {

    /**
     * Checks the summed occupancy of the largest tables.
     *
     * <p>Runs 1000 independent restaurants, each built with {@code new ChineseRestaurant(10.0)},
     * and seats 100 customers in each. Each run's table occupancies are sorted in descending
     * order and added rank-by-rank into {@code totals}. The totals at ranks 0, 1, 2, 15 and 20
     * must match fixed targets.
     */
    @Test
    public void testDepth() {
        // totals.get(r) = sum over all runs of the occupancy of that run's r-th largest table
        List<Integer> totals = Lists.newArrayList();
        for (int run = 0; run < 1000; run++) {
            ChineseRestaurant x = new ChineseRestaurant(10.0);
            HashMultiset<Integer> counts = HashMultiset.create();
            for (int j = 0; j < 100; j++) {
                counts.add(x.sample());
            }

            List<Integer> tableSizes = Lists.newArrayList();
            for (Integer table : counts.elementSet()) {
                tableSizes.add(counts.count(table));
            }
            Collections.sort(tableSizes, Collections.reverseOrder());

            // grow totals so that every rank seen in this run has a slot
            while (totals.size() < tableSizes.size()) {
                totals.add(0);
            }
            for (int rank = 0; rank < tableSizes.size(); rank++) {
                totals.set(rank, totals.get(rank) + tableSizes.get(rank));
            }
        }
        assertEquals(25000.0, totals.get(0).doubleValue(), 1000.0);
        assertEquals(24000.0, totals.get(1).doubleValue(), 1000.0);
        assertEquals(8000.0, totals.get(2).doubleValue(), 200.0);
        assertEquals(1000.0, totals.get(15).doubleValue(), 50.0);
        assertEquals(1000.0, totals.get(20).doubleValue(), 40.0);
    }

    /**
     * With a second constructor argument of 1.0, every sample must open a new table.
     *
     * <p>After 10000 samples there must be 10000 tables, each holding exactly one customer.
     */
    @Test
    public void testExtremeDiscount() {
        ChineseRestaurant x = new ChineseRestaurant(100.0, 1.0);
        for (int i = 0; i < 10000; i++) {
            x.sample();
        }
        assertEquals(10000, x.size());
        for (int i = 0; i < 10000; i++) {
            assertEquals(1, x.count(i));
        }
    }

    /**
     * Checks how the number of tables grows with the number of customers.
     *
     * <p>The test uses three restaurants whose second constructor argument (presumably the
     * discount) is 0.0, 0.5 and 0.9. It only inspects them at checkpoints {@code i = d * 10^e},
     * where {@code d} is in {@code splits}. For the 0.5 and 0.9 restaurants:
     * <ul>
     *   <li>For checkpoints in (50, 900], {@code (log size, log i, 1)} is recorded as a row of a
     *       design matrix.</li>
     *   <li>For checkpoints above 900, a line is fit to those rows by least squares. The line must
     *       have a slope close to the second constructor argument, and must predict the current
     *       {@code log size} to within 1.0.</li>
     * </ul>
     * Above 10000 customers, the fraction of single-customer tables in each restaurant must be
     * near its second constructor argument.
     */
    @Test
    public void testGrowth() {
        ChineseRestaurant s0 = new ChineseRestaurant(10.0, 0.0);
        ChineseRestaurant s5 = new ChineseRestaurant(10.0, 0.5);
        ChineseRestaurant s9 = new ChineseRestaurant(10.0, 0.9);
        // base-10 mantissas at which the restaurants are inspected
        ImmutableSet<Double> splits = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0);

        // number of fitting rows recorded so far (checkpoints 80..800 give 7, well under 20)
        int k = 0;
        Matrix m5 = new DenseMatrix(20, 3);
        Matrix m9 = new DenseMatrix(20, 3);
        for (int i = 0; i <= 200000; i++) {
            // mantissa of i in base 10; for i == 0 this is 0/0 = NaN, which is never in splits
            double n = i / Math.pow(10, Math.floor(Math.log10(i)));
            if (splits.contains(n)) {
                if (i > 900) {
                    double predict5 = predictSize(m5.viewPart(0, k, 0, 3), i, 0.5);
                    assertEquals(predict5, Math.log(s5.size()), 1.0);
                    double predict9 = predictSize(m9.viewPart(0, k, 0, 3), i, 0.9);
                    assertEquals(predict9, Math.log(s9.size()), 1.0);
                } else if (i > 50) {
                    m5.viewRow(k).assign(new double[]{Math.log(s5.size()), Math.log(i), 1.0});
                    m9.viewRow(k).assign(new double[]{Math.log(s9.size()), Math.log(i), 1.0});
                    k++;
                }
                if (i > 10000) {
                    assertEquals(0.0, (double) hapaxCount(s0) / s0.size(), 0.25);
                    assertEquals(0.5, (double) hapaxCount(s5) / s5.size(), 0.1);
                    assertEquals(0.9, (double) hapaxCount(s9) / s9.size(), 0.05);
                }
            }
            s0.sample();
            s5.sample();
            s9.sample();
        }
    }

    /**
     * Fits {@code log size ~ c0 * log i + c1} by least squares and evaluates the fit at
     * {@code currentIndex}.
     *
     * <p>The fit solves the normal equations {@code (A^T A) c = A^T b} with a QR decomposition.
     *
     * @param m rows of {@code (log size, log i, 1)}; column 0 is the response and columns 1-2
     *          are the regressors
     * @param currentIndex the number of customers at which to evaluate the fit
     * @param expectedCoefficient the slope {@code c0} is asserted to be within 0.2 of this value
     * @return the predicted {@code log size} at {@code currentIndex}
     */
    private static double predictSize(Matrix m, int currentIndex, double expectedCoefficient) {
        int rows = m.rowSize();
        Matrix a = m.viewPart(0, rows, 1, 2);
        Matrix b = m.viewPart(0, rows, 0, 1);
        Matrix ata = a.transpose().times(a);
        Matrix atb = a.transpose().times(b);
        // 1 x 2 row of coefficients (c0, c1)
        Matrix coefficients = new QRDecomposition(ata).solve(atb).transpose();
        assertEquals(expectedCoefficient, coefficients.get(0, 0), 0.2);
        Vector point = new DenseVector(new double[]{Math.log(currentIndex), 1.0});
        return coefficients.times(point).get(0);
    }

    /**
     * Counts the tables that hold exactly one customer.
     *
     * <p>Assumes tables are indexed {@code 0 .. s.size() - 1}; this matches how the loop below
     * uses {@code size()} and {@code count(int)}.
     */
    private static int hapaxCount(ChineseRestaurant s) {
        int singletons = 0;
        for (int i = 0; i < s.size(); i++) {
            if (s.count(i) == 1) {
                singletons++;
            }
        }
        return singletons;
    }
}

