package org.apache.drill.exec.physical.impl.join;

import java.util.Iterator;
import java.util.Map;
import org.apache.drill.exec.physical.impl.join.BatchSizePredictor;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;

/* loaded from: input_file:org/apache/drill/exec/physical/impl/join/BatchSizePredictorImpl.class */
/**
 * {@link BatchSizePredictor} implementation that samples a {@link RecordBatch} and scales the
 * sampled size to an arbitrary record count.
 *
 * <p>Usage protocol enforced by this class: {@link #updateStats()} must be called before
 * {@link #getBatchSize()} or {@link #getNumRecords()}, and {@link #predictBatchSize(int, boolean)}
 * may only be called when the last sampled batch contained at least one record.
 */
public class BatchSizePredictorImpl implements BatchSizePredictor {
    /**
     * Extra bytes added per requested record when the caller asks for hash space.
     * Declared as a {@code long} so the multiplication by the record count cannot wrap in
     * {@code int} arithmetic. NOTE(review): presumably the width of one int hash value per
     * record; confirm against the hash-join memory calculator.
     */
    private static final long HASH_BYTES_PER_RECORD = 4L;

    private final RecordBatch batch;
    private final double fragmentationFactor;
    private final double safetyFactor;

    /** Size estimate captured by the last {@link #updateStats()} call that saw data. */
    private long batchSize;
    /** Row count captured by the last {@link #updateStats()} call. */
    private int numRecords;
    /** True once {@link #updateStats()} has run at least once. */
    private boolean updatedStats;
    /** True when the last {@link #updateStats()} call saw a row count greater than zero. */
    private boolean hasData;

    /** Singleton factory that produces {@link BatchSizePredictorImpl} instances. */
    public static class Factory implements BatchSizePredictor.Factory {
        public static final Factory INSTANCE = new Factory();

        private Factory() {
        }

        /**
         * Creates a predictor for the given batch.
         *
         * @param recordBatch the batch to sample; must not be null
         * @param d the fragmentation factor
         * @param d2 the safety factor
         * @return a new {@link BatchSizePredictorImpl}
         */
        @Override // org.apache.drill.exec.physical.impl.join.BatchSizePredictor.Factory
        public BatchSizePredictor create(RecordBatch recordBatch, double d, double d2) {
            return new BatchSizePredictorImpl(recordBatch, d, d2);
        }
    }

    /**
     * Creates a predictor bound to one batch.
     *
     * @param recordBatch the batch to sample; must not be null
     * @param d the fragmentation factor applied to predictions
     * @param d2 the safety factor applied to predictions
     * @throws NullPointerException if {@code recordBatch} is null
     */
    public BatchSizePredictorImpl(RecordBatch recordBatch, double d, double d2) {
        this.batch = (RecordBatch) Preconditions.checkNotNull(recordBatch);
        this.fragmentationFactor = d;
        this.safetyFactor = d2;
    }

    /**
     * Returns the size estimate recorded by the last {@link #updateStats()} call.
     *
     * @return the recorded size, or 0 when the last sampled batch had no records
     * @throws IllegalStateException if {@link #updateStats()} has never been called
     */
    @Override // org.apache.drill.exec.physical.impl.join.BatchSizePredictor
    public long getBatchSize() {
        Preconditions.checkState(this.updatedStats);
        return this.hasData ? this.batchSize : 0L;
    }

    /**
     * Returns the row count recorded by the last {@link #updateStats()} call.
     *
     * @return the recorded row count, or 0 when the last sampled batch had no records
     * @throws IllegalStateException if {@link #updateStats()} has never been called
     */
    @Override // org.apache.drill.exec.physical.impl.join.BatchSizePredictor
    public int getNumRecords() {
        Preconditions.checkState(this.updatedStats);
        return this.hasData ? this.numRecords : 0;
    }

    /**
     * Reports whether the last {@link #updateStats()} call saw at least one record.
     *
     * @return {@code true} if the last sampled batch had a positive row count; {@code false}
     *     otherwise, including when {@link #updateStats()} has never been called
     */
    @Override // org.apache.drill.exec.physical.impl.join.BatchSizePredictor
    public boolean hadDataLastTime() {
        return this.hasData;
    }

    /**
     * Re-samples the bound batch: records its row count and, when that count is positive,
     * a fresh size estimate. When the batch is empty the previously stored size is left
     * untouched, but the getters report 0 because {@code hasData} is false.
     */
    @Override // org.apache.drill.exec.physical.impl.join.BatchSizePredictor
    public void updateStats() {
        this.numRecords = new RecordBatchSizer(this.batch).rowCount();
        this.updatedStats = true;
        this.hasData = this.numRecords > 0;
        if (this.hasData) {
            this.batchSize = getBatchSizeEstimate(this.batch);
        }
    }

    /**
     * Predicts the size of a batch holding {@code i} records, based on the last sample.
     *
     * @param i the number of records to predict for
     * @param z whether to add the per-record hash allowance to the prediction
     * @return the predicted size
     * @throws IllegalStateException if the last sampled batch had no records
     */
    @Override // org.apache.drill.exec.physical.impl.join.BatchSizePredictor
    public long predictBatchSize(int i, boolean z) {
        Preconditions.checkState(this.hasData);
        return computeMaxBatchSize(this.batchSize, this.numRecords, i, this.fragmentationFactor, this.safetyFactor, z);
    }

    /**
     * Computes {@code j * j2} rounded up to the next power of two.
     *
     * @param j first multiplicand (a record count when called from {@link #getBatchSizeEstimate})
     * @param j2 second multiplicand (a per-entry size when called from {@link #getBatchSizeEstimate})
     * @return the product rounded up to a power of two
     * @throws IllegalArgumentException if the product is less than 1
     */
    public static long computeValueVectorSize(long j, long j2) {
        return roundUpToPowerOf2(j * j2);
    }

    /**
     * Computes {@code j * j2}, scales it via {@code RecordBatchSizer.multiplyByFactor} with
     * {@code d}, and rounds the result up to the next power of two.
     *
     * @param j first multiplicand
     * @param j2 second multiplicand
     * @param d factor passed to {@code RecordBatchSizer.multiplyByFactor}
     * @return the scaled product rounded up to a power of two
     * @throws IllegalArgumentException if the scaled product is less than 1
     */
    public static long computeValueVectorSize(long j, long j2, double d) {
        long scaled = RecordBatchSizer.multiplyByFactor(j * j2, d);
        return roundUpToPowerOf2(scaled);
    }

    /**
     * Rounds a positive value up to the nearest power of two; exact powers of two are
     * returned unchanged.
     *
     * <p>Note: for inputs above 2^62 the left shift overflows and the result is negative.
     *
     * @param j the value to round; must be at least 1
     * @return the smallest power of two that is greater than or equal to {@code j}
     * @throws IllegalArgumentException if {@code j} is less than 1
     */
    public static long roundUpToPowerOf2(long j) {
        Preconditions.checkArgument(j >= 1);
        if (j == 1) {
            // highestOneBit(0) is 0, so 1 has to be special-cased.
            return 1L;
        }
        return Long.highestOneBit(j - 1) << 1;
    }

    /**
     * Scales a sampled batch size to {@code i2} records and applies both factors, without
     * any hash allowance.
     *
     * @param j sampled batch size
     * @param i number of records in the sampled batch
     * @param i2 number of records to predict for
     * @param d first factor passed to {@code RecordBatchSizer.multiplyByFactors}
     * @param d2 second factor passed to {@code RecordBatchSizer.multiplyByFactors}
     * @return the scaled size
     * @throws ArithmeticException if {@code i} is zero
     */
    public static long computeMaxBatchSizeNoHash(long j, int i, int i2, double d, double d2) {
        return RecordBatchSizer.multiplyByFactors(computePartitionBatchSize(j, i, i2), d, d2);
    }

    /**
     * Scales a sampled batch size to {@code i2} records, optionally adding a per-record
     * hash allowance.
     *
     * @param j sampled batch size
     * @param i number of records in the sampled batch
     * @param i2 number of records to predict for
     * @param d first factor; also the only factor applied to the hash allowance
     * @param d2 second factor; applied to the data portion only
     * @param z when {@code true}, adds 4 bytes per requested record (scaled by {@code d})
     * @return the predicted size
     * @throws ArithmeticException if {@code i} is zero
     */
    public static long computeMaxBatchSize(long j, int i, int i2, double d, double d2, boolean z) {
        long withoutHash = computeMaxBatchSizeNoHash(j, i, i2, d, d2);
        if (!z) {
            return withoutHash;
        }
        // Multiply in long arithmetic: an int product wraps once i2 exceeds Integer.MAX_VALUE / 4.
        long hashBytes = RecordBatchSizer.multiplyByFactors(i2 * HASH_BYTES_PER_RECORD, d);
        return withoutHash + hashBytes;
    }

    /**
     * Scales a sampled batch size linearly to a different record count, rounding up.
     *
     * @param j sampled batch size
     * @param i number of records in the sampled batch
     * @param i2 number of records to predict for
     * @return {@code ceil(j / i * i2)} computed in floating point
     * @throws ArithmeticException if {@code i} is zero
     */
    public static long computePartitionBatchSize(long j, int i, int i2) {
        if (i == 0) {
            // Floating-point division would silently yield Infinity/NaN; keep failing loudly
            // with the same exception type integer division by zero produces.
            throw new ArithmeticException("/ by zero");
        }
        // Divide in double so the fractional per-record size survives until Math.ceil;
        // integer division would truncate it first and under-estimate the batch size.
        double bytesPerRecord = (double) j / (double) i;
        return (long) Math.ceil(bytesPerRecord * (double) i2);
    }

    /**
     * Estimates the size of a batch by summing, over every column reported by a
     * {@link RecordBatchSizer}, the batch's record count times the column's
     * {@code getStdNetOrNetSizePerEntry()} value, each rounded up to a power of two.
     *
     * @param recordBatch the batch to measure
     * @return the summed per-column estimate; 0 when the sizer reports no columns
     * @throws IllegalArgumentException if any per-column product is less than 1
     */
    public static long getBatchSizeEstimate(RecordBatch recordBatch) {
        long total = 0;
        for (Map.Entry<String, RecordBatchSizer.ColumnSize> column : new RecordBatchSizer(recordBatch).columns().entrySet()) {
            total += computeValueVectorSize(recordBatch.getRecordCount(), column.getValue().getStdNetOrNetSizePerEntry());
        }
        return total;
    }
}
