package org.apache.hadoop.hive.ql.exec.tez;

import hive.com.google.common.annotations.VisibleForTesting;
import hive.com.google.common.base.Preconditions;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import jodd.util.StringPool;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveWritableObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputInitializerEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.class */
/**
 * Prunes the input paths of a {@link MapWork} using partition-key values that upstream Tez
 * vertices ("event sources") send at runtime as {@link InputInitializerEvent}s.
 *
 * <p>Lifecycle: the constructor builds one {@link SourceInfo} per target column of every event
 * source listed in the work's event-source maps. Events arrive through {@link #addEvent} and
 * vertex completion through {@link #processVertex}; both update the bookkeeping counters while
 * holding the {@code sourcesWaitingForEvents} monitor and enqueue onto {@link #getQueue()}. Once
 * every source has delivered all expected events, an end-of-events sentinel is queued.
 * {@link #prune()} drains the queue on the calling thread, collects the allowed values, and
 * removes every path whose evaluated partition key is not among them.
 *
 * <p>{@link #prune()} blocks on the queue, so it must not hold the {@code sourcesWaitingForEvents}
 * monitor while doing so: the sentinel it waits for can only be enqueued by code that acquires
 * that monitor.
 */
public class DynamicPartitionPruner {
    private static final Logger LOG = LoggerFactory.getLogger(DynamicPartitionPruner.class);

    private final InputInitializerContext context;
    private final MapWork work;
    private final JobConf jobConf;

    /** Source vertex name -> one SourceInfo per column fed by that vertex. */
    private final Map<String, List<SourceInfo>> sourceInfoMap = new HashMap<>();
    /** Reused buffer for deserializing single values out of event payloads (prune thread only). */
    private final BytesWritable writable = new BytesWritable();
    /** Accepted events followed by the {@link #endOfEvents} sentinel. */
    private final BlockingQueue<Object> queue = new LinkedBlockingQueue<>();
    /** Sources still owing events; also the monitor guarding the counters below. */
    private final Set<String> sourcesWaitingForEvents = new HashSet<>();
    /**
     * Per source: minus the number of SourceInfos while the vertex is running, replaced by the
     * total expected event count once the vertex succeeds (see {@link #processVertex}).
     */
    private final Map<String, MutableInt> numExpectedEventsPerSource = new HashMap<>();
    /** Per source: number of events accepted so far. */
    private final Map<String, MutableInt> numEventsSeenPerSource = new HashMap<>();
    private int sourceInfoCount = 0;
    /** Sentinel queued once no source is waiting for events any more. */
    private final Object endOfEvents = new Object();
    /** Total accepted events; written under the monitor, read by prune after the sentinel. */
    private int totalEventCount = 0;

    /**
     * Minimal {@link InputStream} view over a {@link ByteBuffer}. Reads advance the buffer's
     * position, so {@code buf.hasRemaining()} tells callers whether unread bytes are left.
     */
    public static class ByteBufferBackedInputStream extends InputStream {
        ByteBuffer buf;

        public ByteBufferBackedInputStream(ByteBuffer byteBuffer) {
            this.buf = byteBuffer;
        }

        @Override
        public int read() throws IOException {
            return buf.hasRemaining() ? (buf.get() & 0xFF) : -1;
        }

        @Override
        public int read(byte[] dst, int off, int len) throws IOException {
            // InputStream contract: a zero-length request returns 0, never end-of-stream.
            if (len == 0) {
                return 0;
            }
            if (!buf.hasRemaining()) {
                return -1;
            }
            int n = Math.min(len, buf.remaining());
            buf.get(dst, off, n);
            return n;
        }

        @Override
        public int available() throws IOException {
            return buf.remaining();
        }
    }

    /**
     * Per-(source, column) state: how to deserialize incoming values, which partition-key
     * expression to evaluate, and the set of values collected so far.
     */
    @VisibleForTesting
    public static class SourceInfo {
        public final ExprNodeDesc partKey;
        public final Deserializer deserializer;
        public final StructObjectInspector soi;
        public final StructField field;
        public final ObjectInspector fieldInspector;
        /** Values received so far; may be shared with other SourceInfos for the same column. */
        public Set<Object> values;
        /** Set once a source asks to skip pruning; may be shared like {@link #values}. */
        public AtomicBoolean skipPruning;
        public final String columnName;
        public final String columnType;

        /** Test-only constructor: no deserializer or inspectors are created. */
        @VisibleForTesting
        SourceInfo(TableDesc tableDesc, ExprNodeDesc exprNodeDesc, String str, String str2, JobConf jobConf,
                Object obj) {
            this.values = new HashSet<>();
            this.skipPruning = new AtomicBoolean();
            this.partKey = exprNodeDesc;
            this.columnName = str;
            this.columnType = str2;
            this.deserializer = null;
            this.soi = null;
            this.field = null;
            this.fieldInspector = null;
        }

        /**
         * Instantiates and initializes the table's deserializer and captures its single field.
         *
         * @throws SerDeException if the deserializer fails to initialize or exposes no fields
         */
        public SourceInfo(TableDesc tableDesc, ExprNodeDesc exprNodeDesc, String str, String str2, JobConf jobConf)
                throws SerDeException {
            this.values = new HashSet<>();
            this.skipPruning = new AtomicBoolean(false);
            this.partKey = exprNodeDesc;
            this.columnName = str;
            this.columnType = str2;
            this.deserializer =
                    (Deserializer) ReflectionUtils.newInstance(tableDesc.getDeserializerClass(), (Configuration) null);
            this.deserializer.initialize(jobConf, tableDesc.getProperties());
            ObjectInspector objectInspector = this.deserializer.getObjectInspector();
            LOG.debug("Type of obj insp: {}", objectInspector.getTypeName());
            this.soi = (StructObjectInspector) objectInspector;
            List<? extends StructField> fields = this.soi.getAllStructFieldRefs();
            if (fields.isEmpty()) {
                // Previously surfaced as an opaque IndexOutOfBoundsException from get(0).
                throw new SerDeException("expecting single field in input, but found none");
            }
            if (fields.size() > 1) {
                LOG.error("expecting single field in input");
            }
            this.field = fields.get(0);
            this.fieldInspector = ObjectInspectorUtils.getStandardObjectInspector(this.field.getFieldObjectInspector());
        }
    }

    /**
     * Builds the per-source bookkeeping for {@code mapWork}.
     *
     * @throws SerDeException if a source's deserializer cannot be initialized
     */
    public DynamicPartitionPruner(InputInitializerContext inputInitializerContext, MapWork mapWork, JobConf jobConf)
            throws SerDeException {
        this.context = inputInitializerContext;
        this.work = mapWork;
        this.jobConf = jobConf;
        synchronized (this) {
            initialize();
        }
    }

    /**
     * Waits until every event source has delivered its events, then removes non-matching paths
     * from the work. Returns immediately if no source is waiting for events.
     *
     * @throws HiveException if the number of received events does not match the expected count
     */
    public void prune() throws SerDeException, IOException, InterruptedException, HiveException {
        synchronized (sourcesWaitingForEvents) {
            if (sourcesWaitingForEvents.isEmpty()) {
                return;
            }
            Set<VertexState> succeededOnly = Collections.singleton(VertexState.SUCCEEDED);
            // Iterate a snapshot in case a registration callback re-enters processVertex and
            // removes a completed source from the set.
            for (String source : new ArrayList<>(sourcesWaitingForEvents)) {
                context.registerForVertexStateUpdates(source, succeededOnly);
            }
        }
        // The monitor must be released before blocking: addEvent/processVertex need it to enqueue
        // the events and the end-of-events sentinel that processEvents waits for.
        LOG.info("Waiting for events ({} sources) ...", sourceInfoCount);
        processEvents();
        prunePartitions();
        LOG.info("Ok to proceed.");
    }

    /** Returns the queue that accepted events (and the end-of-events sentinel) are placed on. */
    public BlockingQueue<Object> getQueue() {
        return queue;
    }

    private void clear() {
        sourceInfoMap.clear();
        sourceInfoCount = 0;
    }

    /** Creates one SourceInfo per column of every event source and resets the counters. */
    private void initialize() throws SerDeException {
        clear();
        // Column name -> last SourceInfo created for it; SourceInfos targeting the same column
        // share one value set and one skip flag.
        Map<String, SourceInfo> infoByColumn = new HashMap<>();
        Set<String> sources = work.getEventSourceTableDescMap().keySet();
        sourcesWaitingForEvents.addAll(sources);
        for (String source : sources) {
            MutableInt expected = new MutableInt(0);
            numExpectedEventsPerSource.put(source, expected);
            numEventsSeenPerSource.put(source, new MutableInt(0));
            List<TableDesc> tableDescs = work.getEventSourceTableDescMap().get(source);
            List<String> columnNameList = work.getEventSourceColumnNameMap().get(source);
            List<String> columnTypeList = work.getEventSourceColumnTypeMap().get(source);
            List<ExprNodeDesc> partKeyList = work.getEventSourcePartKeyExprMap().get(source);
            Iterator<String> columnNames = columnNameList.iterator();
            Iterator<String> columnTypes = columnTypeList.iterator();
            Iterator<ExprNodeDesc> partKeys = partKeyList.iterator();
            for (TableDesc tableDesc : tableDescs) {
                // Negative count of SourceInfos; turned into an event total by processVertex.
                expected.decrement();
                sourceInfoCount++;
                String columnName = columnNames.next();
                SourceInfo info = createSourceInfo(tableDesc, partKeys.next(), columnName, columnTypes.next(), jobConf);
                List<SourceInfo> infos = sourceInfoMap.get(source);
                if (infos == null) {
                    infos = new ArrayList<>();
                    sourceInfoMap.put(source, infos);
                }
                infos.add(info);
                SourceInfo previous = infoByColumn.get(columnName);
                if (previous != null) {
                    info.values = previous.values;
                    info.skipPruning = previous.skipPruning;
                }
                infoByColumn.put(columnName, info);
            }
        }
    }

    /** Prunes for every SourceInfo, then checks that each task of each source sent its events. */
    private void prunePartitions() throws HiveException {
        int expectedEvents = 0;
        for (Map.Entry<String, List<SourceInfo>> entry : sourceInfoMap.entrySet()) {
            String source = entry.getKey();
            for (SourceInfo info : entry.getValue()) {
                int numTasks = context.getVertexNumTasks(source);
                LOG.info("Expecting {} events for vertex {}, for column {}", numTasks, source, info.columnName);
                expectedEvents += numTasks;
                prunePartitionSingleSource(source, info);
            }
        }
        if (expectedEvents != totalEventCount) {
            LOG.error("Expecting: {}, received: {}", expectedEvents, totalEventCount);
            throw new HiveException("Incorrect event count in dynamic partition pruning");
        }
    }

    /**
     * Removes the work's paths whose partition value, after applying {@code sourceInfo.partKey},
     * is not among the values collected for this source. Does nothing if pruning was skipped.
     */
    @VisibleForTesting
    protected void prunePartitionSingleSource(String str, SourceInfo sourceInfo) throws HiveException {
        if (sourceInfo.skipPruning.get()) {
            LOG.info("Skip pruning on {}, column {}", str, sourceInfo.columnName);
            return;
        }
        Set<Object> values = sourceInfo.values;
        String columnName = sourceInfo.columnName;
        if (LOG.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder("Pruning ");
            sb.append(columnName).append(" with ");
            for (Object value : values) {
                sb.append(value).append(", ");
            }
            LOG.debug(sb.toString());
        }
        ObjectInspector columnOi = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                TypeInfoFactory.getPrimitiveTypeInfo(sourceInfo.columnType));
        // Partition values are stored as strings; convert them to the column's writable type.
        ObjectInspectorConverters.Converter converter =
                ObjectInspectorConverters.getConverter(PrimitiveObjectInspectorFactory.javaStringObjectInspector, columnOi);
        StandardStructObjectInspector rowOi = ObjectInspectorFactory.getStandardStructObjectInspector(
                Collections.singletonList(columnName), Collections.singletonList(columnOi));
        ExprNodeEvaluator evaluator = ExprNodeEvaluatorFactory.get(sourceInfo.partKey);
        evaluator.initialize(rowOi);
        applyFilterToPartitions(converter, evaluator, columnName, values);
    }

    /**
     * Evaluates the partition-key expression on each path's partition value and removes the path
     * (and its alias mapping) when the result is not in {@code set}.
     *
     * @throws IllegalStateException if a path has no partition spec or no value for {@code str}
     */
    private void applyFilterToPartitions(ObjectInspectorConverters.Converter converter,
            ExprNodeEvaluator exprNodeEvaluator, String str, Set<Object> set) throws HiveException {
        Object[] row = new Object[1];
        Iterator<Path> paths = work.getPathToPartitionInfo().keySet().iterator();
        while (paths.hasNext()) {
            Path path = paths.next();
            Map<String, String> partSpec = work.getPathToPartitionInfo().get(path).getPartSpec();
            if (partSpec == null) {
                throw new IllegalStateException("No partition spec found in dynamic pruning");
            }
            String rawValue = partSpec.get(str);
            if (rawValue == null) {
                throw new IllegalStateException("Could not find partition value for column: " + str);
            }
            Object converted = converter.convert(rawValue);
            LOG.debug("Converted partition value: {} original ({})", converted, rawValue);
            row[0] = converted;
            Object keyValue = exprNodeEvaluator.evaluate(row);
            LOG.debug("part key expr applied: {}", keyValue);
            if (!set.contains(keyValue)) {
                LOG.info("Pruning path: {}", path);
                paths.remove();
                work.removePathToAlias(path);
            }
        }
    }

    /** Factory hook so tests can substitute SourceInfo instances. */
    @VisibleForTesting
    protected SourceInfo createSourceInfo(TableDesc tableDesc, ExprNodeDesc exprNodeDesc, String str, String str2,
            JobConf jobConf) throws SerDeException {
        return new SourceInfo(tableDesc, exprNodeDesc, str, str2, jobConf);
    }

    /** Consumes queued events until the end-of-events sentinel is taken. */
    private void processEvents() throws SerDeException, IOException, InterruptedException {
        int received = 0;
        for (Object item = queue.take(); item != endOfEvents; item = queue.take()) {
            InputInitializerEvent event = (InputInitializerEvent) item;
            ByteBuffer payload = event.getUserPayload();
            LOG.info("Input event: {}, {}, {}", event.getTargetInputName(), event.getTargetVertexName(),
                    payload.remaining());
            processPayload(payload, event.getSourceVertexName());
            received++;
        }
        LOG.info("Received events: {}", received);
    }

    /**
     * Decodes one event payload: a UTF column name, a skip-pruning flag and, unless skipping,
     * a sequence of serialized values that are added to the matching SourceInfo.
     *
     * @return {@code str}, the source vertex name
     * @throws IllegalStateException if the source or column is unknown
     */
    @VisibleForTesting
    protected String processPayload(ByteBuffer byteBuffer, String str) throws SerDeException, IOException {
        try (DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(byteBuffer))) {
            String columnName = in.readUTF();
            LOG.info("Source of event: {}", str);
            List<SourceInfo> infos = sourceInfoMap.get(str);
            if (infos == null) {
                throw new IllegalStateException("no source info for event source: " + str);
            }
            SourceInfo target = null;
            for (SourceInfo candidate : infos) {
                if (columnName.equals(candidate.columnName)) {
                    target = candidate;
                    break;
                }
            }
            if (target == null) {
                throw new IllegalStateException("no source info for column: " + columnName);
            }
            if (target.skipPruning.get()) {
                return str;
            }
            if (in.readBoolean()) {
                target.skipPruning.set(true);
                return str;
            }
            // The stream is unbuffered, so the buffer position tracks exactly what was consumed.
            while (byteBuffer.hasRemaining()) {
                writable.readFields(in);
                Object row = target.deserializer.deserialize(writable);
                Object value = ObjectInspectorUtils.copyToStandardObject(
                        target.soi.getStructFieldData(row, target.field), target.fieldInspector);
                LOG.debug("Adding: {} to list of required partitions", value);
                target.values.add(value);
            }
            return str;
        }
    }

    /**
     * Accepts an event from a source that is still waiting; events from other sources are ignored.
     *
     * @throws IllegalStateException if the queue rejects the event or a source sends too many
     */
    public void addEvent(InputInitializerEvent inputInitializerEvent) {
        String source = inputInitializerEvent.getSourceVertexName();
        synchronized (sourcesWaitingForEvents) {
            if (!sourcesWaitingForEvents.contains(source)) {
                return;
            }
            totalEventCount++;
            numEventsSeenPerSource.get(source).increment();
            if (!queue.offer(inputInitializerEvent)) {
                throw new IllegalStateException("Queue full");
            }
            checkForSourceCompletion(source);
        }
    }

    /**
     * Records that vertex {@code str} succeeded: its expected event count becomes
     * (number of SourceInfos) x (number of tasks).
     */
    public void processVertex(String str) {
        LOG.info("Vertex succeeded: {}", str);
        synchronized (sourcesWaitingForEvents) {
            MutableInt expected = numExpectedEventsPerSource.get(str);
            int oldVal = expected.intValue();
            Preconditions.checkState(oldVal < 0,
                    "Invalid value for numExpectedEvents for source: " + str + ", oldVal=" + oldVal);
            expected.setValue(-oldVal * context.getVertexNumTasks(str));
            checkForSourceCompletion(str);
        }
    }

    /**
     * Marks {@code str} complete once its vertex succeeded and all expected events arrived; queues
     * the end-of-events sentinel when it was the last pending source. Caller must hold the
     * {@code sourcesWaitingForEvents} monitor.
     */
    private void checkForSourceCompletion(String str) {
        int expected = numExpectedEventsPerSource.get(str).intValue();
        if (expected < 0) {
            // Vertex has not succeeded yet, so the total is still unknown.
            return;
        }
        int seen = numEventsSeenPerSource.get(str).intValue();
        if (seen > expected) {
            throw new IllegalStateException(
                    "Received too many events for " + str + ", Expected=" + expected + ", Received=" + seen);
        }
        if (seen < expected) {
            return;
        }
        sourcesWaitingForEvents.remove(str);
        if (sourcesWaitingForEvents.isEmpty()) {
            if (!queue.offer(endOfEvents)) {
                throw new IllegalStateException("Queue full");
            }
        } else {
            LOG.info("Waiting for {} sources.", sourcesWaitingForEvents.size());
        }
    }
}
