/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.transformations.SourceTransformation;
import org.apache.flink.table.catalog.CatalogTable;
import org.apache.flink.table.catalog.ContextResolvedTable;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.table.connector.source.DataStreamScanProvider;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.ScanTableSource;
import org.apache.flink.table.connector.source.SourceProvider;
import org.apache.flink.table.connector.source.abilities.SupportsDynamicFiltering;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.connectors.TransformationScanProvider;
import org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec;
import org.apache.flink.table.planner.plan.abilities.source.FilterPushDownSpec;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalDynamicFilteringDataCollector;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalDynamicFilteringTableSourceScan;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
import org.apache.flink.table.planner.plan.schema.TableSourceTable;
import org.apache.flink.table.runtime.connector.source.ScanRuntimeProviderContext;

public class DynamicPartitionPruningUtils {
    public static boolean isDppDimSide(RelNode rel) {
        DppDimSideChecker dimSideChecker = new DppDimSideChecker(rel);
        return dimSideChecker.isDppDimSide();
    }

    public static Tuple2<Boolean, RelNode> canConvertAndConvertDppFactSide(RelNode rel, ImmutableIntList joinKeys, RelNode dimSide, ImmutableIntList dimSideJoinKey) {
        DppFactSideChecker dppFactSideChecker = new DppFactSideChecker(rel, joinKeys, dimSide, dimSideJoinKey);
        return dppFactSideChecker.canConvertAndConvertDppFactSide();
    }

    public static boolean isSuitableJoin(Join join) {
        if (join.getJoinType() != JoinRelType.INNER && join.getJoinType() != JoinRelType.SEMI && join.getJoinType() != JoinRelType.LEFT && join.getJoinType() != JoinRelType.RIGHT) {
            return false;
        }
        JoinInfo joinInfo = join.analyzeCondition();
        return !joinInfo.leftKeys.isEmpty();
    }

    private static class DppDimSideChecker {
        private final RelNode relNode;
        private boolean hasFilter;
        private boolean hasPartitionedScan;
        private final Map<ObjectIdentifier, ContextResolvedTable> tables = new HashMap<ObjectIdentifier, ContextResolvedTable>();

        public DppDimSideChecker(RelNode relNode) {
            this.relNode = relNode;
        }

        public boolean isDppDimSide() {
            this.visitDimSide(this.relNode);
            return this.hasFilter && !this.hasPartitionedScan && this.tables.size() == 1;
        }

        private void visitDimSide(RelNode rel) {
            if (rel instanceof TableScan) {
                CatalogTable catalogTable;
                TableScan scan = (TableScan)rel;
                TableSourceTable table = scan.getTable().unwrap(TableSourceTable.class);
                if (table == null) {
                    return;
                }
                if (!this.hasFilter && table.abilitySpecs() != null && table.abilitySpecs().length != 0) {
                    for (SourceAbilitySpec spec : table.abilitySpecs()) {
                        if (!(spec instanceof FilterPushDownSpec)) continue;
                        List<RexNode> predicates = ((FilterPushDownSpec)spec).getPredicates();
                        for (RexNode predicate : predicates) {
                            if (!DppDimSideChecker.isSuitableFilter(predicate)) continue;
                            this.hasFilter = true;
                        }
                    }
                }
                if ((catalogTable = (CatalogTable)table.contextResolvedTable().getResolvedTable()).isPartitioned()) {
                    this.hasPartitionedScan = true;
                    return;
                }
                this.setTables(table.contextResolvedTable());
            } else if (rel instanceof HepRelVertex) {
                this.visitDimSide(((HepRelVertex)rel).getCurrentRel());
            } else if (rel instanceof Exchange || rel instanceof Project) {
                this.visitDimSide(rel.getInput(0));
            } else if (rel instanceof Calc) {
                RexProgram origProgram = ((Calc)rel).getProgram();
                if (origProgram.getCondition() != null && DppDimSideChecker.isSuitableFilter(origProgram.expandLocalRef(origProgram.getCondition()))) {
                    this.hasFilter = true;
                }
                this.visitDimSide(rel.getInput(0));
            } else if (rel instanceof Filter) {
                if (DppDimSideChecker.isSuitableFilter(((Filter)rel).getCondition())) {
                    this.hasFilter = true;
                }
                this.visitDimSide(rel.getInput(0));
            } else if (rel instanceof Join) {
                Join join = (Join)rel;
                this.visitDimSide(join.getLeft());
                this.visitDimSide(join.getRight());
            } else if (rel instanceof BatchPhysicalGroupAggregateBase) {
                this.visitDimSide(((BatchPhysicalGroupAggregateBase)rel).getInput());
            } else if (rel instanceof Union) {
                Union union = (Union)rel;
                for (RelNode input : union.getInputs()) {
                    this.visitDimSide(input);
                }
            }
        }

        private static boolean isSuitableFilter(RexNode filterCondition) {
            switch (filterCondition.getKind()) {
                case AND: {
                    List<RexNode> conjunctions = RelOptUtil.conjunctions(filterCondition);
                    return DppDimSideChecker.isSuitableFilter(conjunctions.get(0)) || DppDimSideChecker.isSuitableFilter(conjunctions.get(1));
                }
                case OR: {
                    List<RexNode> disjunctions = RelOptUtil.disjunctions(filterCondition);
                    return DppDimSideChecker.isSuitableFilter(disjunctions.get(0)) && DppDimSideChecker.isSuitableFilter(disjunctions.get(1));
                }
                case NOT: {
                    return DppDimSideChecker.isSuitableFilter((RexNode)((RexCall)filterCondition).operands.get(0));
                }
                case EQUALS: 
                case GREATER_THAN: 
                case GREATER_THAN_OR_EQUAL: 
                case LESS_THAN: 
                case LESS_THAN_OR_EQUAL: 
                case NOT_EQUALS: 
                case IN: 
                case LIKE: 
                case CONTAINS: 
                case SEARCH: 
                case IS_FALSE: 
                case IS_NOT_FALSE: 
                case IS_NOT_TRUE: 
                case IS_TRUE: {
                    return true;
                }
            }
            return false;
        }

        private void setTables(ContextResolvedTable catalogTable) {
            ObjectIdentifier identifier = catalogTable.getIdentifier();
            this.tables.putIfAbsent(identifier, catalogTable);
        }
    }

    private static class DppFactSideChecker {
        private final RelNode relNode;
        private final ImmutableIntList joinKeys;
        private final RelNode dimSide;
        private final ImmutableIntList dimSideJoinKey;
        private boolean isChanged;

        public DppFactSideChecker(RelNode relNode, ImmutableIntList joinKeys, RelNode dimSide, ImmutableIntList dimSideJoinKey) {
            this.relNode = relNode;
            this.joinKeys = joinKeys;
            this.dimSide = dimSide;
            this.dimSideJoinKey = dimSideJoinKey;
        }

        public Tuple2<Boolean, RelNode> canConvertAndConvertDppFactSide() {
            return Tuple2.of((Object)this.isChanged, (Object)this.convertDppFactSide(this.relNode, this.joinKeys, this.dimSide, this.dimSideJoinKey));
        }

        private RelNode convertDppFactSide(RelNode rel, ImmutableIntList joinKeys, RelNode dimSide, ImmutableIntList dimSideJoinKey) {
            if (rel instanceof TableScan) {
                TableScan scan = (TableScan)rel;
                if (scan instanceof BatchPhysicalDynamicFilteringTableSourceScan) {
                    return rel;
                }
                TableSourceTable tableSourceTable = scan.getTable().unwrap(TableSourceTable.class);
                if (tableSourceTable == null) {
                    return rel;
                }
                CatalogTable catalogTable = (CatalogTable)tableSourceTable.contextResolvedTable().getResolvedTable();
                List partitionKeys = catalogTable.getPartitionKeys();
                if (partitionKeys.isEmpty()) {
                    return rel;
                }
                DynamicTableSource tableSource = tableSourceTable.tableSource();
                if (!(tableSource instanceof SupportsDynamicFiltering) || !(tableSource instanceof ScanTableSource)) {
                    return rel;
                }
                if (Arrays.stream(tableSourceTable.abilitySpecs()).anyMatch(spec -> spec instanceof AggregatePushDownSpec)) {
                    return rel;
                }
                if (!DppFactSideChecker.isNewSource((ScanTableSource)tableSource)) {
                    return rel;
                }
                List<String> candidateFields = joinKeys.stream().map(i -> scan.getRowType().getFieldNames().get((int)i)).collect(Collectors.toList());
                if (candidateFields.isEmpty()) {
                    return rel;
                }
                List<String> acceptedFilterFields = DppFactSideChecker.getSuitableDynamicFilteringFieldsInFactSide(tableSource, candidateFields);
                if (acceptedFilterFields.size() == 0) {
                    return rel;
                }
                ((SupportsDynamicFiltering)tableSource).applyDynamicFiltering(acceptedFilterFields);
                List<Integer> acceptedFieldIndices = acceptedFilterFields.stream().map(f -> scan.getRowType().getFieldNames().indexOf(f)).collect(Collectors.toList());
                ArrayList<Integer> dynamicFilteringFieldIndices = new ArrayList<Integer>();
                for (int i2 = 0; i2 < joinKeys.size(); ++i2) {
                    if (!acceptedFieldIndices.contains(joinKeys.get(i2))) continue;
                    dynamicFilteringFieldIndices.add(dimSideJoinKey.get(i2));
                }
                BatchPhysicalDynamicFilteringDataCollector dynamicFilteringDataCollector = DppFactSideChecker.createDynamicFilteringConnector(dimSide, dynamicFilteringFieldIndices);
                this.isChanged = true;
                return new BatchPhysicalDynamicFilteringTableSourceScan(scan.getCluster(), scan.getTraitSet(), scan.getHints(), tableSourceTable, dynamicFilteringDataCollector, acceptedFieldIndices);
            }
            if (rel instanceof Exchange || rel instanceof Filter) {
                return rel.copy(rel.getTraitSet(), Collections.singletonList(this.convertDppFactSide(rel.getInput(0), joinKeys, dimSide, dimSideJoinKey)));
            }
            if (rel instanceof Project) {
                List<RexNode> projects = ((Project)rel).getProjects();
                ImmutableIntList inputJoinKeys = DppFactSideChecker.getInputIndices(projects, joinKeys);
                if (inputJoinKeys.isEmpty()) {
                    return rel;
                }
                return rel.copy(rel.getTraitSet(), Collections.singletonList(this.convertDppFactSide(rel.getInput(0), inputJoinKeys, dimSide, dimSideJoinKey)));
            }
            if (rel instanceof Calc) {
                Calc calc = (Calc)rel;
                RexProgram program = calc.getProgram();
                List<RexNode> projects = program.getProjectList().stream().map(program::expandLocalRef).collect(Collectors.toList());
                ImmutableIntList inputJoinKeys = DppFactSideChecker.getInputIndices(projects, joinKeys);
                if (inputJoinKeys.isEmpty()) {
                    return rel;
                }
                return rel.copy(rel.getTraitSet(), Collections.singletonList(this.convertDppFactSide(rel.getInput(0), inputJoinKeys, dimSide, dimSideJoinKey)));
            }
            if (rel instanceof Join) {
                Join currentJoin = (Join)rel;
                return currentJoin.copy(currentJoin.getTraitSet(), (List)Arrays.asList(this.convertDppFactSide(currentJoin.getLeft(), DppFactSideChecker.getInputIndices(currentJoin, joinKeys, true), dimSide, dimSideJoinKey), this.convertDppFactSide(currentJoin.getRight(), DppFactSideChecker.getInputIndices(currentJoin, joinKeys, false), dimSide, dimSideJoinKey)));
            }
            if (rel instanceof Union) {
                Union union = (Union)rel;
                ArrayList<RelNode> newInputs = new ArrayList<RelNode>();
                for (RelNode input : union.getInputs()) {
                    newInputs.add(this.convertDppFactSide(input, joinKeys, dimSide, dimSideJoinKey));
                }
                return union.copy(union.getTraitSet(), newInputs, union.all);
            }
            if (rel instanceof BatchPhysicalGroupAggregateBase) {
                BatchPhysicalGroupAggregateBase agg = (BatchPhysicalGroupAggregateBase)rel;
                RelNode input = agg.getInput();
                int[] grouping = agg.grouping();
                for (int k : joinKeys) {
                    if (k < grouping.length) continue;
                    return rel;
                }
                RelNode convertedRel = this.convertDppFactSide(input, ImmutableIntList.copyOf(joinKeys.stream().map(joinKey -> agg.grouping()[joinKey]).collect(Collectors.toList())), dimSide, dimSideJoinKey);
                return agg.copy(agg.getTraitSet(), Collections.singletonList(convertedRel));
            }
            return rel;
        }

        private static List<String> getSuitableDynamicFilteringFieldsInFactSide(DynamicTableSource tableSource, List<String> candidateFields) {
            List acceptedFilterFields = ((SupportsDynamicFiltering)tableSource).listAcceptedFilterFields();
            if (acceptedFilterFields == null || acceptedFilterFields.isEmpty()) {
                return new ArrayList<String>();
            }
            ArrayList<String> suitableFields = new ArrayList<String>();
            for (String candidateField : candidateFields) {
                if (!acceptedFilterFields.contains(candidateField)) continue;
                suitableFields.add(candidateField);
            }
            return suitableFields;
        }

        private static BatchPhysicalDynamicFilteringDataCollector createDynamicFilteringConnector(RelNode dimSide, List<Integer> dynamicFilteringFieldIndices) {
            RelDataType outputType = ((FlinkTypeFactory)dimSide.getCluster().getTypeFactory()).projectStructType(dimSide.getRowType(), dynamicFilteringFieldIndices.stream().mapToInt(i -> i).toArray());
            return new BatchPhysicalDynamicFilteringDataCollector(dimSide.getCluster(), dimSide.getTraitSet(), DppFactSideChecker.ignoreExchange(dimSide), outputType, dynamicFilteringFieldIndices.stream().mapToInt(i -> i).toArray());
        }

        private static RelNode ignoreExchange(RelNode dimSide) {
            if (dimSide instanceof Exchange) {
                return dimSide.getInput(0);
            }
            return dimSide;
        }

        private static boolean isNewSource(ScanTableSource scanTableSource) {
            ScanTableSource.ScanRuntimeProvider provider = scanTableSource.getScanRuntimeProvider((ScanTableSource.ScanContext)ScanRuntimeProviderContext.INSTANCE);
            if (provider instanceof SourceProvider) {
                return true;
            }
            if (provider instanceof TransformationScanProvider) {
                Transformation<RowData> transformation = ((TransformationScanProvider)provider).createTransformation(name -> Optional.empty());
                return transformation instanceof SourceTransformation;
            }
            return provider instanceof DataStreamScanProvider;
        }

        private static ImmutableIntList getInputIndices(List<RexNode> projects, ImmutableIntList joinKeys) {
            ArrayList<Integer> indices = new ArrayList<Integer>();
            for (int k : joinKeys) {
                RexNode rexNode = projects.get(k);
                if (!(rexNode instanceof RexInputRef)) continue;
                indices.add(((RexInputRef)rexNode).getIndex());
            }
            return ImmutableIntList.copyOf(indices);
        }

        private static ImmutableIntList getInputIndices(Join join, ImmutableIntList joinKeys, boolean isLeft) {
            ArrayList<Integer> indices = new ArrayList<Integer>();
            RelNode left = join.getLeft();
            int leftSize = left.getRowType().getFieldCount();
            for (int k : joinKeys) {
                if (isLeft) {
                    if (k >= leftSize) continue;
                    indices.add(k);
                    continue;
                }
                if (k < leftSize) continue;
                indices.add(k - leftSize);
            }
            return ImmutableIntList.copyOf(indices);
        }
    }
}

