/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.utils;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.flink.table.functions.DeclarativeAggregateFunction;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.python.PythonFunction;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;

public class PythonUtil {
    public static boolean containsPythonCall(RexNode node, PythonFunctionKind pythonFunctionKind) {
        FunctionFinder functionFinder = new FunctionFinder(true, Optional.ofNullable(pythonFunctionKind), true);
        return node.accept(functionFinder);
    }

    public static boolean containsPythonCall(RexNode node) {
        return PythonUtil.containsPythonCall(node, null);
    }

    public static boolean containsNonPythonCall(RexNode node) {
        FunctionFinder functionFinder = new FunctionFinder(false, Optional.empty(), true);
        return node.accept(functionFinder);
    }

    public static boolean isPythonCall(RexNode node, PythonFunctionKind pythonFunctionKind) {
        FunctionFinder functionFinder = new FunctionFinder(true, Optional.ofNullable(pythonFunctionKind), false);
        return node.accept(functionFinder);
    }

    public static boolean isPythonCall(RexNode node) {
        return PythonUtil.isPythonCall(node, null);
    }

    public static boolean isNonPythonCall(RexNode node) {
        FunctionFinder functionFinder = new FunctionFinder(false, Optional.empty(), false);
        return node.accept(functionFinder);
    }

    public static boolean isPythonAggregate(AggregateCall call) {
        return PythonUtil.isPythonAggregate(call, null);
    }

    public static boolean isPythonAggregate(AggregateCall call, PythonFunctionKind pythonFunctionKind) {
        SqlAggFunction aggregation = call.getAggregation();
        if (aggregation instanceof AggSqlFunction) {
            return PythonUtil.isPythonFunction(((AggSqlFunction)aggregation).aggregateFunction(), pythonFunctionKind);
        }
        if (aggregation instanceof BridgingSqlAggFunction) {
            return PythonUtil.isPythonFunction(((BridgingSqlAggFunction)aggregation).getDefinition(), pythonFunctionKind);
        }
        return false;
    }

    public static boolean isBuiltInAggregate(AggregateCall call) {
        SqlAggFunction aggregation = call.getAggregation();
        if (aggregation instanceof AggSqlFunction) {
            AggSqlFunction aggSqlFunction = (AggSqlFunction)aggregation;
            return aggSqlFunction.aggregateFunction() instanceof BuiltInAggregateFunction;
        }
        if (aggregation instanceof BridgingSqlAggFunction) {
            BridgingSqlAggFunction bridgingSqlAggFunction = (BridgingSqlAggFunction)aggregation;
            return bridgingSqlAggFunction.getDefinition() instanceof DeclarativeAggregateFunction;
        }
        return true;
    }

    public static boolean takesRowAsInput(RexCall call) {
        if (call.getOperator() instanceof ScalarSqlFunction) {
            ScalarSqlFunction sfc = (ScalarSqlFunction)call.getOperator();
            return ((PythonFunction)sfc.scalarFunction()).takesRowAsInput();
        }
        if (call.getOperator() instanceof TableSqlFunction) {
            TableSqlFunction tfc = (TableSqlFunction)call.getOperator();
            return ((PythonFunction)tfc.udtf()).takesRowAsInput();
        }
        if (call.getOperator() instanceof BridgingSqlFunction) {
            BridgingSqlFunction bsf = (BridgingSqlFunction)call.getOperator();
            return ((PythonFunction)bsf.getDefinition()).takesRowAsInput();
        }
        return false;
    }

    private static boolean isPythonFunction(FunctionDefinition function, PythonFunctionKind pythonFunctionKind) {
        if (function instanceof PythonFunction) {
            PythonFunction pythonFunction = (PythonFunction)function;
            return pythonFunctionKind == null || pythonFunction.getPythonFunctionKind() == pythonFunctionKind;
        }
        return false;
    }

    public static boolean isFlattenCalc(FlinkLogicalCalc calc) {
        RelNode child = calc.getInput();
        if (child instanceof RelSubset) {
            child = ((RelSubset)child).getOriginal();
        } else if (child instanceof HepRelVertex) {
            child = ((HepRelVertex)child).getCurrentRel();
        } else {
            return false;
        }
        if (!(child instanceof FlinkLogicalCalc)) {
            return false;
        }
        if (calc.getProgram().getCondition() != null) {
            return false;
        }
        List<RelDataTypeField> inputFields = calc.getProgram().getInputRowType().getFieldList();
        if (inputFields.size() != 1 || !inputFields.get(0).getType().isStruct()) {
            return false;
        }
        List projects = calc.getProgram().getProjectList().stream().map(calc.getProgram()::expandLocalRef).collect(Collectors.toList());
        if (inputFields.get(0).getType().getFieldCount() != projects.size()) {
            return false;
        }
        return IntStream.range(0, projects.size()).allMatch(idx -> ((RexNode)projects.get(idx)).accept(new FieldReferenceDetector(idx)));
    }

    private static class FunctionFinder
    extends RexDefaultVisitor<Boolean> {
        private final boolean findPythonFunction;
        private final Optional<PythonFunctionKind> pythonFunctionKind;
        private final boolean recursive;

        public FunctionFinder(boolean findPythonFunction, Optional<PythonFunctionKind> pythonFunctionKind, boolean recursive) {
            this.findPythonFunction = findPythonFunction;
            this.pythonFunctionKind = pythonFunctionKind;
            this.recursive = recursive;
        }

        private boolean isPythonRexCall(RexCall rexCall) {
            if (rexCall.getOperator() instanceof ScalarSqlFunction) {
                ScalarSqlFunction sfc = (ScalarSqlFunction)rexCall.getOperator();
                return this.isPythonFunction((FunctionDefinition)sfc.scalarFunction());
            }
            if (rexCall.getOperator() instanceof TableSqlFunction) {
                TableSqlFunction tfc = (TableSqlFunction)rexCall.getOperator();
                return this.isPythonFunction((FunctionDefinition)tfc.udtf());
            }
            if (rexCall.getOperator() instanceof BridgingSqlFunction) {
                BridgingSqlFunction bsf = (BridgingSqlFunction)rexCall.getOperator();
                return this.isPythonFunction(bsf.getDefinition());
            }
            return false;
        }

        private boolean isPythonFunction(FunctionDefinition functionDefinition) {
            if (functionDefinition instanceof PythonFunction) {
                PythonFunction pythonFunction = (PythonFunction)functionDefinition;
                return !this.pythonFunctionKind.isPresent() || pythonFunction.getPythonFunctionKind() == this.pythonFunctionKind.get();
            }
            return false;
        }

        @Override
        public Boolean visitCall(RexCall call) {
            return this.findPythonFunction == this.isPythonRexCall(call) || this.recursive && call.getOperands().stream().anyMatch(operand -> operand.accept(this));
        }

        @Override
        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
            return fieldAccess.getReferenceExpr().accept(this);
        }

        @Override
        public Boolean visitNode(RexNode rexNode) {
            return false;
        }
    }

    private static class FieldReferenceDetector
    extends RexDefaultVisitor<Boolean> {
        private final int idx;

        public FieldReferenceDetector(int idx) {
            this.idx = idx;
        }

        @Override
        public Boolean visitNode(RexNode rexNode) {
            return false;
        }

        @Override
        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
            if (fieldAccess.getField().getIndex() != this.idx) {
                return false;
            }
            RexNode expr = fieldAccess.getReferenceExpr();
            if (expr instanceof RexInputRef) {
                return ((RexInputRef)expr).getIndex() == 0;
            }
            return false;
        }

        @Override
        public Boolean visitCall(RexCall call) {
            if (call.getKind() == SqlKind.AS) {
                return call.getOperands().get(0).accept(this);
            }
            return false;
        }
    }
}

