/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hive.druid.org.apache.calcite.rel.metadata;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.hive.druid.com.google.common.base.Function;
import org.apache.hive.druid.com.google.common.collect.HashMultimap;
import org.apache.hive.druid.com.google.common.collect.Iterables;
import org.apache.hive.druid.com.google.common.collect.Lists;
import org.apache.hive.druid.com.google.common.collect.Sets;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptUtil;
import org.apache.hive.druid.org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.hive.druid.org.apache.calcite.plan.volcano.RelSubset;
import org.apache.hive.druid.org.apache.calcite.rel.RelNode;
import org.apache.hive.druid.org.apache.calcite.rel.core.Aggregate;
import org.apache.hive.druid.org.apache.calcite.rel.core.Exchange;
import org.apache.hive.druid.org.apache.calcite.rel.core.Filter;
import org.apache.hive.druid.org.apache.calcite.rel.core.Join;
import org.apache.hive.druid.org.apache.calcite.rel.core.JoinRelType;
import org.apache.hive.druid.org.apache.calcite.rel.core.Project;
import org.apache.hive.druid.org.apache.calcite.rel.core.Sort;
import org.apache.hive.druid.org.apache.calcite.rel.core.TableScan;
import org.apache.hive.druid.org.apache.calcite.rel.core.Union;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.MetadataDef;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.MetadataHandler;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.hive.druid.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.hive.druid.org.apache.calcite.rex.RexBuilder;
import org.apache.hive.druid.org.apache.calcite.rex.RexInputRef;
import org.apache.hive.druid.org.apache.calcite.rex.RexNode;
import org.apache.hive.druid.org.apache.calcite.rex.RexShuttle;
import org.apache.hive.druid.org.apache.calcite.rex.RexTableInputRef;
import org.apache.hive.druid.org.apache.calcite.rex.RexUtil;
import org.apache.hive.druid.org.apache.calcite.util.BuiltInMethod;
import org.apache.hive.druid.org.apache.calcite.util.ImmutableBitSet;
import org.apache.hive.druid.org.apache.calcite.util.Util;

/**
 * Metadata handler computing the expression lineage of an expression over the output of a
 * relational expression.
 *
 * <p>The lineage is computed by replacing every input reference in the expression with the
 * expression(s) the referenced column originates from, ultimately expressed in terms of
 * {@link RexTableInputRef}s on scanned tables. A {@code null} result means the lineage could not
 * be determined.
 */
public class RelMdExpressionLineage
implements MetadataHandler<BuiltInMetadata.ExpressionLineage> {
    public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider.reflectiveSource(BuiltInMethod.EXPRESSION_LINEAGE.method, new RelMdExpressionLineage());

    private RelMdExpressionLineage() {
    }

    @Override
    public MetadataDef<BuiltInMetadata.ExpressionLineage> getDef() {
        return BuiltInMetadata.ExpressionLineage.DEF;
    }

    /** Catch-all handler: the lineage through an unrecognized relational expression is unknown. */
    public Set<RexNode> getExpressionLineage(RelNode rel, RelMetadataQuery mq, RexNode outputExpression) {
        return null;
    }

    /** Delegates to the relational expression currently held by the HEP vertex. */
    public Set<RexNode> getExpressionLineage(HepRelVertex rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getCurrentRel(), outputExpression);
    }

    /** Delegates to the subset's best expression, or to its original one when no best is known. */
    public Set<RexNode> getExpressionLineage(RelSubset rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(Util.first(rel.getBest(), rel.getOriginal()), outputExpression);
    }

    /**
     * Origin case: each referenced column of the scan is mapped to a {@link RexTableInputRef} on
     * the scanned table (entity number 0), and those references are substituted into
     * {@code outputExpression}.
     */
    public Set<RexNode> getExpressionLineage(TableScan rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final List<RelDataTypeField> fields = rel.getRowType().getFieldList();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        // Every column of the scan belongs to the same single occurrence of the table.
        final RexTableInputRef.RelTableRef tableRef = RexTableInputRef.RelTableRef.of(rel.getTable(), 0);
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<RexInputRef, Set<RexNode>>();
        for (int idx : inputFieldsUsed) {
            final RexInputRef ref = RexInputRef.of(idx, fields);
            final Set<RexNode> originalExprs = new HashSet<RexNode>();
            originalExprs.add(RexTableInputRef.of(tableRef, ref));
            mapping.put(ref, originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Only group-key columns have a lineage: each is traced back to the input column it groups
     * on. If {@code outputExpression} references any aggregate-call output, returns {@code null}.
     */
    public Set<RexNode> getExpressionLineage(Aggregate rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RelNode input = rel.getInput();
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        // Check every reference before issuing any metadata query, so we bail out cheaply.
        for (int idx : inputFieldsUsed) {
            if (idx >= rel.getGroupCount()) {
                return null;
            }
        }
        final List<RelDataTypeField> inputFields = input.getRowType().getFieldList();
        final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<RexInputRef, Set<RexNode>>();
        for (int idx : inputFieldsUsed) {
            // The idx-th output column is the idx-th grouping column of the input.
            final RexInputRef inputRef = RexInputRef.of(rel.getGroupSet().nth(idx), inputFields);
            final Set<RexNode> originalExprs = mq.getExpressionLineage(input, inputRef);
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, outputFields), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Inner joins only; other join types return {@code null}. Left columns keep their lineage
     * unchanged. For right columns, table references whose table already occurs on the left are
     * renumbered so that the two occurrences stay distinguishable.
     */
    public Set<RexNode> getExpressionLineage(Join rel, RelMetadataQuery mq, RexNode outputExpression) {
        if (rel.getJoinType() != JoinRelType.INNER) {
            return null;
        }
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final RelNode leftInput = rel.getLeft();
        final RelNode rightInput = rel.getRight();
        final List<RelDataTypeField> leftFields = leftInput.getRowType().getFieldList();
        final List<RelDataTypeField> rightFields = rightInput.getRowType().getFieldList();
        final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();
        final int nLeftColumns = leftFields.size();
        final HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs = HashMultimap.create();
        final Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> currentTablesMapping =
            new HashMap<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef>();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<RexInputRef, Set<RexNode>>();
        for (int idx = 0; idx < outputFields.size(); ++idx) {
            final RexInputRef outputRef = RexInputRef.of(idx, outputFields);
            if (idx < nLeftColumns) {
                final Set<RexNode> originalExprs = mq.getExpressionLineage(leftInput, RexInputRef.of(idx, leftFields));
                if (originalExprs == null) {
                    return null;
                }
                // Left references are kept as-is; record them so clashing right ones get shifted.
                for (RexTableInputRef.RelTableRef leftRef : RexUtil.gatherTableReferences(Lists.newArrayList(originalExprs))) {
                    qualifiedNamesToRefs.put(leftRef.getQualifiedName(), leftRef);
                }
                mapping.put(outputRef, originalExprs);
            } else {
                final Set<RexNode> originalExprs = mq.getExpressionLineage(rightInput, RexInputRef.of(idx - nLeftColumns, rightFields));
                if (originalExprs == null) {
                    return null;
                }
                mapShiftedTableReferences(RexUtil.gatherTableReferences(Lists.newArrayList(originalExprs)),
                    qualifiedNamesToRefs, currentTablesMapping);
                mapping.put(outputRef, applyTableMapping(rexBuilder, originalExprs, currentTablesMapping));
            }
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * The lineage of each output column is the union of its lineage in every input. Table
     * references from later inputs are renumbered past the occurrences of the same table already
     * seen in earlier inputs.
     */
    public Set<RexNode> getExpressionLineage(Union rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();
        final HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs = HashMultimap.create();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<RexInputRef, Set<RexNode>>();
        for (RelNode input : rel.getInputs()) {
            final List<RelDataTypeField> inputFields = input.getRowType().getFieldList();
            final Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> currentTablesMapping =
                new HashMap<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef>();
            for (int idx = 0; idx < inputFields.size(); ++idx) {
                final Set<RexNode> originalExprs = mq.getExpressionLineage(input, RexInputRef.of(idx, inputFields));
                if (originalExprs == null) {
                    return null;
                }
                final RexInputRef outputRef = RexInputRef.of(idx, outputFields);
                mapShiftedTableReferences(RexUtil.gatherTableReferences(Lists.newArrayList(originalExprs)),
                    qualifiedNamesToRefs, currentTablesMapping);
                final Set<RexNode> updatedExprs = applyTableMapping(rexBuilder, originalExprs, currentTablesMapping);
                final Set<RexNode> existing = mapping.get(outputRef);
                if (existing == null) {
                    mapping.put(outputRef, updatedExprs);
                } else {
                    // Safe to mutate: every value in this map is a set created by applyTableMapping.
                    existing.addAll(updatedExprs);
                }
            }
            // Register this input's renumbered references only after the whole input is done,
            // so the shift is the same for all of its columns.
            for (RexTableInputRef.RelTableRef newRef : currentTablesMapping.values()) {
                qualifiedNamesToRefs.put(newRef.getQualifiedName(), newRef);
            }
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Each output column of the project is traced through the project expression that computes
     * it. The lineage of that expression is then taken on the input.
     */
    public Set<RexNode> getExpressionLineage(Project rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RelNode input = rel.getInput();
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final List<RexNode> projects = rel.getChildExps();
        final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<RexInputRef, Set<RexNode>>();
        for (int idx : inputFieldsUsed) {
            final Set<RexNode> originalExprs = mq.getExpressionLineage(input, projects.get(idx));
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, outputFields), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /** A filter does not change its input's columns; delegates to the input unchanged. */
    public Set<RexNode> getExpressionLineage(Filter rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getInput(), outputExpression);
    }

    /** A sort does not change its input's columns; delegates to the input unchanged. */
    public Set<RexNode> getExpressionLineage(Sort rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getInput(), outputExpression);
    }

    /** An exchange does not change its input's columns; delegates to the input unchanged. */
    public Set<RexNode> getExpressionLineage(Exchange rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getInput(), outputExpression);
    }

    /**
     * Returns every expression obtained by replacing each input reference of {@code expr} with
     * one of its candidate replacements, i.e. the cartesian product over all referenced inputs.
     * After substitution, each expression is split into its conjunctions, and every conjunct is
     * added to the result individually.
     *
     * @param rexBuilder builder used to copy {@code expr} before substitution
     * @param expr       expression whose input references are to be replaced
     * @param mapping    candidate replacements per input reference; entries for references that
     *                   {@code expr} does not use are ignored, and the map is never modified
     * @return {@code {expr}} if {@code expr} references no input; otherwise the substituted
     *         expressions, or {@code null} if a referenced input has no entry in {@code mapping}
     *         or {@code expr} cannot be copied
     */
    protected static Set<RexNode> createAllPossibleExpressions(RexBuilder rexBuilder, RexNode expr, Map<RexInputRef, Set<RexNode>> mapping) {
        final ImmutableBitSet predFieldsUsed = extractInputRefs(expr);
        if (predFieldsUsed.isEmpty()) {
            return Sets.newHashSet(expr);
        }
        // Keep only the references expr actually uses. Join and Union map every output column,
        // and unused entries would only deepen the recursion without affecting the result.
        final Map<RexInputRef, Set<RexNode>> usedMapping = new LinkedHashMap<RexInputRef, Set<RexNode>>();
        for (Map.Entry<RexInputRef, Set<RexNode>> entry : mapping.entrySet()) {
            if (predFieldsUsed.get(entry.getKey().getIndex())) {
                usedMapping.put(entry.getKey(), entry.getValue());
            }
        }
        if (usedMapping.size() != predFieldsUsed.cardinality()) {
            // Some referenced input has no known origin.
            return null;
        }
        try {
            // Loop-invariant: copy and split once, then substitute into fresh lists per combination.
            final List<RexNode> conjunctions = RelOptUtil.conjunctions(rexBuilder.copy(expr));
            final Set<RexNode> result = new HashSet<RexNode>();
            enumerateReplacements(conjunctions, usedMapping, new HashMap<RexInputRef, RexNode>(), result);
            return result;
        } catch (UnsupportedOperationException e) {
            // RexBuilder.copy does not support every RexNode kind; NOTE(review): e.g. windowed
            // aggregates in the bundled Calcite version - confirm. Treat the lineage as unknown.
            return null;
        }
    }

    /**
     * Recursively assigns one replacement to each input reference left in {@code mapping}. For
     * every complete assignment, adds the substituted copies of {@code conjunctions} to
     * {@code result}.
     *
     * <p>{@code mapping} and {@code singleMapping} are modified during the call. They are restored
     * when the call returns normally.
     */
    private static void enumerateReplacements(List<RexNode> conjunctions, Map<RexInputRef, Set<RexNode>> mapping, Map<RexInputRef, RexNode> singleMapping, Set<RexNode> result) {
        if (mapping.isEmpty()) {
            final List<RexNode> updatedPreds = new ArrayList<RexNode>(conjunctions);
            new RexReplacer(singleMapping).mutate(updatedPreds);
            result.addAll(updatedPreds);
            return;
        }
        final RexInputRef inputRef = mapping.keySet().iterator().next();
        final Set<RexNode> replacements = mapping.remove(inputRef);
        assert !replacements.isEmpty();
        for (RexNode replacement : replacements) {
            singleMapping.put(inputRef, replacement);
            enumerateReplacements(conjunctions, mapping, singleMapping, result);
            singleMapping.remove(inputRef);
        }
        mapping.put(inputRef, replacements);
    }

    /** Returns the indexes of the input fields referenced by {@code expr}. */
    private static ImmutableBitSet extractInputRefs(RexNode expr) {
        final Set<RelDataTypeField> inputExtraFields = new LinkedHashSet<RelDataTypeField>();
        final RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(inputExtraFields);
        expr.accept(inputFinder);
        return inputFinder.inputBitSet.build();
    }

    /**
     * Records in {@code tablesMapping} a renumbered copy of every reference in {@code tableRefs}.
     * The entity number is shifted by the number of references already registered in
     * {@code qualifiedNamesToRefs} under the same qualified table name.
     */
    private static void mapShiftedTableReferences(Set<RexTableInputRef.RelTableRef> tableRefs,
            HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs,
            Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> tablesMapping) {
        for (RexTableInputRef.RelTableRef tableRef : tableRefs) {
            // Multimap.get never returns null: an unseen name yields an empty collection.
            final int shift = qualifiedNamesToRefs.get(tableRef.getQualifiedName()).size();
            tablesMapping.put(tableRef,
                RexTableInputRef.RelTableRef.of(tableRef.getTable(), shift + tableRef.getEntityNumber()));
        }
    }

    /** Returns a new mutable set with {@code tableMapping} applied to each expression's table references. */
    private static Set<RexNode> applyTableMapping(RexBuilder rexBuilder, Set<RexNode> exprs,
            Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> tableMapping) {
        final Set<RexNode> result = new HashSet<RexNode>();
        for (RexNode expr : exprs) {
            result.add(RexUtil.swapTableReferences(rexBuilder, expr, tableMapping));
        }
        return result;
    }

    /** Shuttle that replaces each input reference with the expression assigned to it in a map. */
    private static class RexReplacer
    extends RexShuttle {
        private final Map<RexInputRef, RexNode> replacementValues;

        RexReplacer(Map<RexInputRef, RexNode> replacementValues) {
            this.replacementValues = replacementValues;
        }

        @Override
        public RexNode visitInputRef(RexInputRef inputRef) {
            return this.replacementValues.get(inputRef);
        }
    }
}

