package org.apache.spark.sql.catalyst.expressions;

import java.nio.charset.StandardCharsets;
import java.util.Random;
import org.apache.spark.SparkConf;
import org.apache.spark.internal.config.package$;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.types.UTF8String;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.class */
/**
 * Tests for {@link RowBasedKeyValueBatch}.
 *
 * <p>A key schema containing a string column yields a {@link VariableLengthRowBasedKeyValueBatch},
 * while an all-long key schema yields a {@link FixedLengthRowBasedKeyValueBatch}; both are
 * exercised here.
 *
 * <p>Every test allocates its batch in a try-with-resources statement, so the batch is closed
 * exactly once even when an assertion fails. {@link #tearDown()} then asserts that the task
 * memory manager reports no remaining memory consumption.
 */
public class RowBasedKeyValueBatchSuite {
    /** Row capacity used by tests that are not probing the capacity limit. */
    private static final int DEFAULT_CAPACITY = 65536;

    private TestMemoryManager memoryManager;
    private TaskMemoryManager taskMemoryManager;

    /** Fixed seed so {@link #randomizedTest()} is reproducible. */
    private final Random rand = new Random(42);

    /** Key schema (long, string): allocate() returns a variable-length batch for it. */
    private final StructType keySchema =
        new StructType().add("k1", DataTypes.LongType).add("k2", DataTypes.StringType);

    /** Key schema (long, long): allocate() returns a fixed-length batch for it. */
    private final StructType fixedKeySchema =
        new StructType().add("k1", DataTypes.LongType).add("k2", DataTypes.LongType);

    /** Value schema shared by all tests: (count: long, sum: long). */
    private final StructType valueSchema =
        new StructType().add("count", DataTypes.LongType).add("sum", DataTypes.LongType);

    /**
     * Returns a string built from {@code length} random bytes decoded as UTF-8.
     *
     * <p>The charset is explicit so the generated keys do not depend on the JVM's default
     * charset. Invalid byte sequences are replaced during decoding, which is harmless here
     * because keys are only compared after the same String-to-UTF8String conversion.
     */
    private String getRandomString(int length) {
        Assert.assertTrue(length >= 0);
        byte[] bytes = new byte[length];
        rand.nextBytes(bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }

    /** Builds a two-column UnsafeRow holding {@code (first, second)} as longs. */
    private UnsafeRow makeLongPairRow(long first, long second) {
        UnsafeRowWriter writer = new UnsafeRowWriter(2);
        writer.reset();
        writer.write(0, first);
        writer.write(1, second);
        return writer.getRow();
    }

    /** Builds a key row matching {@link #keySchema}. */
    private UnsafeRow makeKeyRow(long k1, String k2) {
        UnsafeRowWriter writer = new UnsafeRowWriter(2);
        writer.reset();
        writer.write(0, k1);
        writer.write(1, UTF8String.fromString(k2));
        return writer.getRow();
    }

    /** Builds a key row matching {@link #fixedKeySchema}. */
    private UnsafeRow makeKeyRow(long k1, long k2) {
        return makeLongPairRow(k1, k2);
    }

    /** Builds a value row matching {@link #valueSchema}. */
    private UnsafeRow makeValueRow(long count, long sum) {
        return makeLongPairRow(count, sum);
    }

    /**
     * Appends a key/value pair to {@code batch}.
     *
     * @return the value row returned by {@link RowBasedKeyValueBatch#appendRow}; the tests
     *     below expect {@code null} once capacity or memory is exhausted
     */
    private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) {
        return batch.appendRow(
            key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
            value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes());
    }

    /** Overwrites both long columns of {@code valueRow} in place. */
    private void updateValueRow(UnsafeRow valueRow, long count, long sum) {
        valueRow.setLong(0, count);
        valueRow.setLong(1, sum);
    }

    /** Returns whether a (long, string) key row holds {@code (k1, k2)}. */
    private boolean checkKey(UnsafeRow keyRow, long k1, String k2) {
        return keyRow.getLong(0) == k1 && keyRow.getUTF8String(1).equals(UTF8String.fromString(k2));
    }

    /** Returns whether a (long, long) key row holds {@code (k1, k2)}. */
    private boolean checkKey(UnsafeRow keyRow, long k1, long k2) {
        return keyRow.getLong(0) == k1 && keyRow.getLong(1) == k2;
    }

    /** Returns whether a value row holds {@code (count, sum)}. */
    private boolean checkValue(UnsafeRow valueRow, long count, long sum) {
        return valueRow.getLong(0) == count && valueRow.getLong(1) == sum;
    }

    /** Creates a fresh on-heap test memory manager and task memory manager for each test. */
    @Before
    public void setup() {
        memoryManager = new TestMemoryManager(new SparkConf()
            .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)
            .set(package$.MODULE$.SHUFFLE_SPILL_COMPRESS(), false)
            .set(package$.MODULE$.SHUFFLE_COMPRESS(), false));
        taskMemoryManager = new TaskMemoryManager(memoryManager, 0L);
    }

    /**
     * Leak check. Asserts that cleanUpAllAllocatedMemory() returns 0 and that the task's memory
     * consumption is 0. The field is cleared before the last assertion so it is not re-checked.
     */
    @After
    public void tearDown() {
        if (taskMemoryManager != null) {
            Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
            long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
            taskMemoryManager = null;
            Assert.assertEquals(0L, leakedMemory);
        }
    }

    @Test
    public void emptyBatch() throws Exception {
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            Assert.assertEquals(0L, batch.numRows());
            // Out-of-range row ids are rejected with an AssertionError.
            Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(-1));
            Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(-1));
            Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(0));
            Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(0));
            Assert.assertFalse(batch.rowIterator().next());
        }
    }

    @Test
    public void batchType() {
        try (RowBasedKeyValueBatch variableBatch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
             RowBasedKeyValueBatch fixedBatch = RowBasedKeyValueBatch.allocate(
                fixedKeySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            Assert.assertEquals(VariableLengthRowBasedKeyValueBatch.class, variableBatch.getClass());
            Assert.assertEquals(FixedLengthRowBasedKeyValueBatch.class, fixedBatch.getClass());
        }
    }

    @Test
    public void setAndRetrieve() {
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            Assert.assertTrue(checkValue(appendRow(batch, makeKeyRow(1L, "A"), makeValueRow(1L, 1L)), 1L, 1L));
            Assert.assertTrue(checkValue(appendRow(batch, makeKeyRow(2L, "B"), makeValueRow(2L, 2L)), 2L, 2L));
            Assert.assertTrue(checkValue(appendRow(batch, makeKeyRow(3L, "C"), makeValueRow(3L, 3L)), 3L, 3L));
            Assert.assertEquals(3L, batch.numRows());

            Assert.assertTrue(checkKey(batch.getKeyRow(0), 1L, "A"));
            Assert.assertTrue(checkKey(batch.getKeyRow(1), 2L, "B"));
            Assert.assertTrue(checkValue(batch.getValueRow(1), 2L, 2L));
            Assert.assertTrue(checkValue(batch.getValueRow(2), 3L, 3L));

            Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(3));
            Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(3));
        }
    }

    @Test
    public void setUpdateAndRetrieve() {
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            appendRow(batch, makeKeyRow(1L, "A"), makeValueRow(1L, 1L));
            Assert.assertEquals(1L, batch.numRows());
            // Writes through the returned value row must be visible on the next lookup.
            updateValueRow(batch.getValueRow(0), 2L, 2L);
            Assert.assertTrue(checkValue(batch.getValueRow(0), 2L, 2L));
        }
    }

    @Test
    public void iteratorTest() throws Exception {
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            appendRow(batch, makeKeyRow(1L, "A"), makeValueRow(1L, 1L));
            appendRow(batch, makeKeyRow(2L, "B"), makeValueRow(2L, 2L));
            appendRow(batch, makeKeyRow(3L, "C"), makeValueRow(3L, 3L));
            Assert.assertEquals(3L, batch.numRows());

            KVIterator<UnsafeRow, UnsafeRow> iterator = batch.rowIterator();
            Assert.assertTrue(iterator.next());
            Assert.assertTrue(checkKey(iterator.getKey(), 1L, "A"));
            Assert.assertTrue(checkValue(iterator.getValue(), 1L, 1L));
            Assert.assertTrue(iterator.next());
            Assert.assertTrue(checkKey(iterator.getKey(), 2L, "B"));
            Assert.assertTrue(checkValue(iterator.getValue(), 2L, 2L));
            Assert.assertTrue(iterator.next());
            Assert.assertTrue(checkKey(iterator.getKey(), 3L, "C"));
            Assert.assertTrue(checkValue(iterator.getValue(), 3L, 3L));
            Assert.assertFalse(iterator.next());
        }
    }

    @Test
    public void fixedLengthTest() throws Exception {
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                fixedKeySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            appendRow(batch, makeKeyRow(11L, 11L), makeValueRow(1L, 1L));
            appendRow(batch, makeKeyRow(22L, 22L), makeValueRow(2L, 2L));
            appendRow(batch, makeKeyRow(33L, 33L), makeValueRow(3L, 3L));
            Assert.assertTrue(checkKey(batch.getKeyRow(0), 11L, 11L));
            Assert.assertTrue(checkKey(batch.getKeyRow(1), 22L, 22L));
            Assert.assertTrue(checkValue(batch.getValueRow(1), 2L, 2L));
            Assert.assertTrue(checkValue(batch.getValueRow(2), 3L, 3L));
            Assert.assertEquals(3L, batch.numRows());

            KVIterator<UnsafeRow, UnsafeRow> iterator = batch.rowIterator();
            Assert.assertTrue(iterator.next());
            Assert.assertTrue(checkKey(iterator.getKey(), 11L, 11L));
            Assert.assertTrue(checkValue(iterator.getValue(), 1L, 1L));
            Assert.assertTrue(iterator.next());
            Assert.assertTrue(checkKey(iterator.getKey(), 22L, 22L));
            Assert.assertTrue(checkValue(iterator.getValue(), 2L, 2L));
            Assert.assertTrue(iterator.next());
            Assert.assertTrue(checkKey(iterator.getKey(), 33L, 33L));
            Assert.assertTrue(checkValue(iterator.getValue(), 3L, 3L));
            Assert.assertFalse(iterator.next());
        }
    }

    @Test
    public void appendRowUntilExceedingCapacity() throws Exception {
        int capacity = 10;
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, capacity)) {
            UnsafeRow key = makeKeyRow(1L, "A");
            UnsafeRow value = makeValueRow(1L, 1L);
            for (int i = 0; i < capacity; i++) {
                appendRow(batch, key, value);
            }
            // One append past the row capacity must be refused.
            UnsafeRow overflow = appendRow(batch, key, value);
            Assert.assertEquals(capacity, batch.numRows());
            Assert.assertNull(overflow);

            KVIterator<UnsafeRow, UnsafeRow> iterator = batch.rowIterator();
            for (int i = 0; i < capacity; i++) {
                Assert.assertTrue(iterator.next());
                Assert.assertTrue(checkKey(iterator.getKey(), 1L, "A"));
                Assert.assertTrue(checkValue(iterator.getValue(), 1L, 1L));
            }
            Assert.assertFalse(iterator.next());
        }
    }

    @Test
    public void appendRowUntilExceedingPageSize() throws Exception {
        int pageSize = (int) memoryManager.pageSizeBytes();
        // Capacity equals the page size in bytes, so the page fills up before the row limit.
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, pageSize)) {
            UnsafeRow key = makeKeyRow(1L, "A");
            UnsafeRow value = makeValueRow(1L, 1L);
            // NOTE(review): 16 bytes of per-record overhead plus 4 bytes at the start of the page
            // presumably mirror VariableLengthRowBasedKeyValueBatch's layout; confirm there.
            int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8;
            int totalSize = 4;
            int numRows = 0;
            while (totalSize + recordLength < pageSize) {
                appendRow(batch, key, value);
                totalSize += recordLength;
                numRows++;
            }
            UnsafeRow overflow = appendRow(batch, key, value);
            Assert.assertEquals(numRows, batch.numRows());
            Assert.assertNull(overflow);

            KVIterator<UnsafeRow, UnsafeRow> iterator = batch.rowIterator();
            for (int i = 0; i < numRows; i++) {
                Assert.assertTrue(iterator.next());
                Assert.assertTrue(checkKey(iterator.getKey(), 1L, "A"));
                Assert.assertTrue(checkValue(iterator.getValue(), 1L, 1L));
            }
            Assert.assertFalse(iterator.next());
        }
    }

    @Test
    public void failureToAllocateFirstPage() throws Exception {
        // With the memory manager limited to 1 KiB, the first append is expected to fail.
        memoryManager.limit(1024L);
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            Assert.assertNull(appendRow(batch, makeKeyRow(1L, "A"), makeValueRow(11L, 11L)));
            Assert.assertFalse(batch.rowIterator().next());
        }
    }

    @Test
    public void randomizedTest() {
        int numEntries = 100;
        int numLookups = 10000;
        try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(
                keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
            long[] expectedK1 = new long[numEntries];
            String[] expectedK2 = new String[numEntries];
            long[] expectedCount = new long[numEntries];
            long[] expectedSum = new long[numEntries];
            for (int i = 0; i < numEntries; i++) {
                long k1 = rand.nextLong();
                String k2 = getRandomString(rand.nextInt(256));
                long count = rand.nextLong();
                long sum = rand.nextLong();
                appendRow(batch, makeKeyRow(k1, k2), makeValueRow(count, sum));
                expectedK1[i] = k1;
                expectedK2[i] = k2;
                expectedCount[i] = count;
                expectedSum[i] = sum;
            }
            // Random-access lookups, randomly checking the key, the value, both, or neither.
            for (int i = 0; i < numLookups; i++) {
                int rowId = rand.nextInt(numEntries);
                if (rand.nextBoolean()) {
                    Assert.assertTrue(checkKey(batch.getKeyRow(rowId), expectedK1[rowId], expectedK2[rowId]));
                }
                if (rand.nextBoolean()) {
                    Assert.assertTrue(checkValue(batch.getValueRow(rowId), expectedCount[rowId], expectedSum[rowId]));
                }
            }
        }
    }
}
