/*
 * Decompiled with CFR 0.152.
 */
package org.eigenbase.rel.rules;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import net.hydromatic.optiq.util.BitSets;
import org.eigenbase.rel.RelNode;
import org.eigenbase.rel.metadata.RelColumnOrigin;
import org.eigenbase.rel.metadata.RelMdUtil;
import org.eigenbase.rel.metadata.RelMetadataQuery;
import org.eigenbase.rel.rules.LoptMultiJoin;
import org.eigenbase.rel.rules.SemiJoinRel;
import org.eigenbase.relopt.RelOptCost;
import org.eigenbase.relopt.RelOptTable;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.rex.RexBuilder;
import org.eigenbase.rex.RexCall;
import org.eigenbase.rex.RexInputRef;
import org.eigenbase.rex.RexNode;
import org.eigenbase.rex.RexUtil;
import org.eigenbase.sql.SqlKind;
import org.eigenbase.sql.SqlOperator;
import org.eigenbase.sql.fun.SqlStdOperatorTable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Chooses semijoins to apply among the factors of a {@link LoptMultiJoin}.
 *
 * <p>{@link #makePossibleSemiJoins} collects, for each factor treated as a
 * fact table, candidate semijoins against the dimension factors it is joined
 * to by simple column-equality filters. {@link #chooseBestSemiJoin} applies
 * the single best-scoring candidate whose score exceeds
 * {@link #THRESHOLD_SCORE} and reports whether one was applied (presumably
 * callers invoke it repeatedly until it returns false — confirm against
 * callers). {@link #getChosenSemiJoin} returns each factor wrapped in the
 * semijoins chosen for it so far.
 *
 * <p>NOTE(review): the {@code Lcs*}, {@code FemLocalIndex} and
 * {@code LucidDbSpecialOperators} nested classes at the bottom are
 * behaviorless stubs. {@code LcsTable} has a private constructor and no
 * subclasses here, so {@code validateKeys} always returns null and no
 * candidate semijoins are produced in this build.
 */
public class LoptSemiJoinOptimizer {
    /** Minimum score (see {@link #computeScore}) a semijoin must beat to be chosen. */
    private static final int THRESHOLD_SCORE = 10;

    private final RexBuilder rexBuilder;

    /**
     * For each factor index, the factor itself or the semijoin(s) chosen to
     * wrap it so far.
     */
    private final RelNode[] chosenSemiJoins;

    /**
     * fact factor index -> (dimension factor index -> candidate semijoin).
     * Populated by {@link #makePossibleSemiJoins}; a fact factor is present
     * only while it has at least one candidate.
     */
    private Map<Integer, Map<Integer, SemiJoinRel>> possibleSemiJoins;

    private final Comparator<Integer> factorCostComparator = new FactorCostComparator();

    /**
     * Creates an optimizer whose initial "chosen" rel for every factor is the
     * factor itself.
     *
     * @param multiJoin  join whose factors are considered
     * @param rexBuilder builder used to create semijoin conditions
     */
    public LoptSemiJoinOptimizer(LoptMultiJoin multiJoin, RexBuilder rexBuilder) {
        this.rexBuilder = rexBuilder;
        int nJoinFactors = multiJoin.getNumJoinFactors();
        this.chosenSemiJoins = new RelNode[nJoinFactors];
        for (int i = 0; i < nJoinFactors; i++) {
            this.chosenSemiJoins[i] = multiJoin.getJoinFactor(i);
        }
    }

    /**
     * Rebuilds the set of candidate semijoins.
     *
     * <p>For every factor taken as the fact table, the join filters that are
     * column equalities with exactly one other factor are grouped by that
     * (dimension) factor. Each group is then costed via
     * {@link #findSemiJoinIndexByCost}. No candidates are produced for a full
     * outer join, or when either the fact or the dimension factor is
     * null-generating.
     *
     * @param multiJoin join whose factors and filters are examined
     */
    public void makePossibleSemiJoins(LoptMultiJoin multiJoin) {
        this.possibleSemiJoins = new HashMap<Integer, Map<Integer, SemiJoinRel>>();
        if (multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            return;
        }
        int nJoinFactors = multiJoin.getNumJoinFactors();
        for (int factIdx = 0; factIdx < nJoinFactors; factIdx++) {
            // Group the usable equality filters by the dimension factor they join to.
            Map<Integer, List<RexNode>> dimFilters = new HashMap<Integer, List<RexNode>>();
            for (RexNode joinFilter : multiJoin.getJoinFilters()) {
                int dimIdx = isSuitableFilter(multiJoin, joinFilter, factIdx);
                if (dimIdx == -1
                    || multiJoin.isNullGenerating(factIdx)
                    || multiJoin.isNullGenerating(dimIdx)) {
                    continue;
                }
                List<RexNode> currDimFilters = dimFilters.get(dimIdx);
                if (currDimFilters == null) {
                    currDimFilters = new ArrayList<RexNode>();
                    dimFilters.put(dimIdx, currDimFilters);
                }
                currDimFilters.add(joinFilter);
            }

            Map<Integer, SemiJoinRel> semiJoinMap = new HashMap<Integer, SemiJoinRel>();
            for (Map.Entry<Integer, List<RexNode>> entry : dimFilters.entrySet()) {
                int dimIdx = entry.getKey();
                SemiJoinRel semiJoin =
                    findSemiJoinIndexByCost(multiJoin, entry.getValue(), factIdx, dimIdx);
                if (semiJoin != null) {
                    semiJoinMap.put(dimIdx, semiJoin);
                }
            }
            // Only record fact factors that actually have a candidate.
            if (!semiJoinMap.isEmpty()) {
                this.possibleSemiJoins.put(factIdx, semiJoinMap);
            }
        }
    }

    /**
     * Determines whether a join filter can drive a semijoin for a fact factor.
     *
     * @param multiJoin  join containing the filter
     * @param joinFilter candidate filter
     * @param factIdx    index of the fact factor
     * @return index of the other (dimension) factor when {@code joinFilter} is
     *         an EQUALS between two input references and one side belongs to
     *         {@code factIdx}; otherwise -1
     */
    private int isSuitableFilter(LoptMultiJoin multiJoin, RexNode joinFilter, int factIdx) {
        if (joinFilter.getKind() != SqlKind.EQUALS) {
            return -1;
        }
        List<RexNode> operands = ((RexCall) joinFilter).getOperands();
        if (!(operands.get(0) instanceof RexInputRef)
            || !(operands.get(1) instanceof RexInputRef)) {
            return -1;
        }
        BitSet joinRefs = multiJoin.getFactorsRefByJoinFilter(joinFilter);
        assert joinRefs.cardinality() == 2;
        int factor1 = joinRefs.nextSetBit(0);
        int factor2 = joinRefs.nextSetBit(factor1 + 1);
        if (factor1 == factIdx) {
            return factor2;
        }
        if (factor2 == factIdx) {
            return factor1;
        }
        return -1;
    }

    /**
     * Builds a semijoin of the fact factor by the dimension factor if an index
     * on the fact table can support it.
     *
     * <p>The join filters are ANDed together and re-based so the fact fields
     * come first, followed by the dimension fields. They are then split into
     * key lists, which are trimmed to the columns that map back to a single
     * column-store table. Finally they are handed to the index optimizer;
     * if it picks an index using only some of the keys, the condition is
     * reduced to match.
     *
     * @return the semijoin, or null if no usable keys or index were found
     */
    private SemiJoinRel findSemiJoinIndexByCost(LoptMultiJoin multiJoin,
            List<RexNode> joinFilters, int factIdx, int dimIdx) {
        RexNode semiJoinCondition = RexUtil.composeConjunction(this.rexBuilder, joinFilters, true);
        int leftAdjustment = 0;
        for (int i = 0; i < factIdx; i++) {
            leftAdjustment -= multiJoin.getNumFieldsInJoinFactor(i);
        }
        semiJoinCondition =
            adjustSemiJoinCondition(multiJoin, leftAdjustment, semiJoinCondition, factIdx, dimIdx);

        List<Integer> leftKeys = new ArrayList<Integer>();
        List<Integer> rightKeys = new ArrayList<Integer>();
        RelNode factRel = multiJoin.getJoinFactor(factIdx);
        RelNode dimRel = multiJoin.getJoinFactor(dimIdx);
        RelOptUtil.splitJoinCondition(factRel, dimRel, semiJoinCondition, leftKeys, rightKeys);
        assert leftKeys.size() > 0;

        List<Integer> actualLeftKeys = new ArrayList<Integer>();
        LcsTable factTable = validateKeys(factRel, leftKeys, rightKeys, actualLeftKeys);
        if (factTable == null) {
            return null;
        }

        List<Integer> bestKeyOrder = new ArrayList<Integer>();
        LcsRowScanRel tmpFactRel = (LcsRowScanRel) ((Object) factTable.toRel(
            RelOptUtil.getContext(factRel.getCluster())));
        LcsIndexOptimizer indexOptimizer = new LcsIndexOptimizer(tmpFactRel);
        FemLocalIndex bestIndex =
            indexOptimizer.findSemiJoinIndexByCost(dimRel, actualLeftKeys, rightKeys, bestKeyOrder);
        if (bestIndex == null) {
            return null;
        }

        List<Integer> truncatedLeftKeys;
        List<Integer> truncatedRightKeys;
        if (actualLeftKeys.size() == bestKeyOrder.size()) {
            truncatedLeftKeys = leftKeys;
            truncatedRightKeys = rightKeys;
        } else {
            // The index covers only some keys: keep those (in index order)
            // and drop the equality filters on the rest.
            truncatedLeftKeys = new ArrayList<Integer>();
            truncatedRightKeys = new ArrayList<Integer>();
            for (int key : bestKeyOrder) {
                truncatedLeftKeys.add(leftKeys.get(key));
                truncatedRightKeys.add(rightKeys.get(key));
            }
            semiJoinCondition = removeExtraFilters(
                truncatedLeftKeys, multiJoin.getNumFieldsInJoinFactor(factIdx), semiJoinCondition);
        }
        return new SemiJoinRel(factRel.getCluster(), factRel, dimRel, semiJoinCondition,
            truncatedLeftKeys, truncatedRightKeys);
    }

    /**
     * Re-bases the input references of a semijoin condition from multi-join
     * field positions to a two-input layout. In that layout the left (fact)
     * factor's fields start at 0 and the right (dimension) factor's fields
     * follow immediately after.
     *
     * @param multiJoin         join supplying field counts
     * @param leftAdjustment    negated start offset of the left factor's fields
     * @param semiJoinCondition condition expressed over multi-join fields
     * @param leftIdx           left (fact) factor index
     * @param rightIdx          right (dimension) factor index
     * @return the adjusted condition, or the original if no shift is needed
     */
    private RexNode adjustSemiJoinCondition(LoptMultiJoin multiJoin, int leftAdjustment,
            RexNode semiJoinCondition, int leftIdx, int rightIdx) {
        int rightStart = 0;
        for (int i = 0; i < rightIdx; i++) {
            rightStart += multiJoin.getNumFieldsInJoinFactor(i);
        }
        int numFieldsLeftIdx = multiJoin.getNumFieldsInJoinFactor(leftIdx);
        int numFieldsRightIdx = multiJoin.getNumFieldsInJoinFactor(rightIdx);
        // Must be computed unconditionally: the right factor always lands
        // directly after the left factor's fields, whatever the left shift is.
        int rightAdjustment = numFieldsLeftIdx - rightStart;

        if (leftAdjustment == 0 && rightAdjustment == 0) {
            return semiJoinCondition;
        }
        int[] adjustments = new int[multiJoin.getNumTotalFields()];
        if (leftAdjustment != 0) {
            int leftStart = -leftAdjustment;
            Arrays.fill(adjustments, leftStart, leftStart + numFieldsLeftIdx, leftAdjustment);
        }
        if (rightAdjustment != 0) {
            Arrays.fill(adjustments, rightStart, rightStart + numFieldsRightIdx, rightAdjustment);
        }
        return semiJoinCondition.accept(new RelOptUtil.RexInputConverter(
            this.rexBuilder, multiJoin.getMultiJoinFields(), adjustments));
    }

    /**
     * Keeps only the semijoin keys whose fact-side column has a known origin
     * in a column-store table, is not the row-id column, and comes from the
     * same table as the other kept keys.
     *
     * <p>Rejected keys are removed in place from both {@code leftKeys} and
     * {@code rightKeys}. Each kept key's origin column ordinal is appended to
     * {@code actualLeftKeys}.
     *
     * @return the originating table, or null if no key was kept
     */
    private LcsTable validateKeys(RelNode factRel, List<Integer> leftKeys,
            List<Integer> rightKeys, List<Integer> actualLeftKeys) {
        int keyIdx = 0;
        RelOptTable theTable = null;
        ListIterator<Integer> keyIter = leftKeys.listIterator();
        while (keyIter.hasNext()) {
            int leftKey = keyIter.next();
            RelColumnOrigin colOrigin = RelMetadataQuery.getColumnOrigin(factRel, leftKey);
            boolean keep;
            if (colOrigin == null
                || LucidDbSpecialOperators.isLcsRidColumnId(colOrigin.getOriginColumnOrdinal())) {
                keep = false;
            } else {
                RelOptTable table = colOrigin.getOriginTable();
                if (theTable == null) {
                    keep = table instanceof LcsTable;
                    if (keep) {
                        theTable = table;
                    }
                } else {
                    // every subsequent key must originate from the same table
                    assert table == theTable;
                    keep = true;
                }
            }
            if (keep) {
                actualLeftKeys.add(colOrigin.getOriginColumnOrdinal());
                keyIdx++;
            } else {
                keyIter.remove();
                rightKeys.remove(keyIdx); // int overload: remove by position
            }
        }
        return actualLeftKeys.isEmpty() ? null : (LcsTable) theTable;
    }

    /**
     * Drops the equality filters whose fact-side field is not one of
     * {@code keys}.
     *
     * @param keys      fact-side field indexes to keep
     * @param nFields   number of fact fields; indexes {@code >= nFields} are
     *                  dimension fields
     * @param condition an EQUALS of two input refs, or an AND (of any arity)
     *                  of such conditions
     * @return the reduced condition, or null if nothing remains
     */
    private RexNode removeExtraFilters(List<Integer> keys, int nFields, RexNode condition) {
        assert condition instanceof RexCall;
        RexCall call = (RexCall) condition;
        List<RexNode> operands = call.getOperands();
        if (condition.isA(SqlKind.AND)) {
            // Visit every conjunct, not just the first two, so n-ary ANDs
            // keep all their matching filters.
            RexNode result = null;
            for (RexNode operand : operands) {
                RexNode kept = removeExtraFilters(keys, nFields, operand);
                if (kept == null) {
                    continue;
                }
                result = result == null
                    ? kept
                    : this.rexBuilder.makeCall((SqlOperator) SqlStdOperatorTable.AND, result, kept);
            }
            return result;
        }
        assert call.getOperator() == SqlStdOperatorTable.EQUALS;
        assert operands.get(0) instanceof RexInputRef;
        assert operands.get(1) instanceof RexInputRef;
        int idx = ((RexInputRef) operands.get(0)).getIndex();
        if (idx >= nFields) {
            // first operand is the dimension side; the fact side is the second
            idx = ((RexInputRef) operands.get(1)).getIndex();
        }
        return keys.contains(idx) ? condition : null;
    }

    /**
     * Applies at most one semijoin.
     *
     * <p>Factors are visited in ascending order of the cumulative cost of
     * their current chosen rel. For the first factor with a candidate scoring
     * above {@link #THRESHOLD_SCORE}, the best-scoring candidate is rebuilt
     * on top of the factor's current rel and recorded as its chosen semijoin.
     * The dimension's join is marked removable where possible, and the
     * candidate is dropped in both directions.
     *
     * @param multiJoin join being optimized
     * @return true if a semijoin was applied
     */
    public boolean chooseBestSemiJoin(LoptMultiJoin multiJoin) {
        int nJoinFactors = multiJoin.getNumJoinFactors();
        Integer[] sortedFactors = new Integer[nJoinFactors];
        for (int i = 0; i < nJoinFactors; i++) {
            sortedFactors[i] = i;
        }
        Arrays.sort(sortedFactors, this.factorCostComparator);

        for (int factIdx : sortedFactors) {
            Map<Integer, SemiJoinRel> possibleDimensions = this.possibleSemiJoins.get(factIdx);
            if (possibleDimensions == null) {
                continue;
            }
            RelNode factRel = this.chosenSemiJoins[factIdx];
            double bestScore = 0.0;
            int bestDimIdx = -1;
            for (Map.Entry<Integer, SemiJoinRel> entry : possibleDimensions.entrySet()) {
                SemiJoinRel semiJoin = entry.getValue();
                if (semiJoin == null) {
                    continue;
                }
                int dimIdx = entry.getKey();
                double score = computeScore(factRel, this.chosenSemiJoins[dimIdx], semiJoin);
                if (score > THRESHOLD_SCORE && score > bestScore) {
                    bestDimIdx = dimIdx;
                    bestScore = score;
                }
            }
            if (bestDimIdx == -1) {
                continue;
            }

            SemiJoinRel semiJoin = possibleDimensions.get(bestDimIdx);
            SemiJoinRel chosenSemiJoin = new SemiJoinRel(factRel.getCluster(), factRel,
                this.chosenSemiJoins[bestDimIdx], semiJoin.getCondition(),
                semiJoin.getLeftKeys(), semiJoin.getRightKeys());
            this.chosenSemiJoins[factIdx] = chosenSemiJoin;
            removeJoin(multiJoin, chosenSemiJoin, factIdx, bestDimIdx);
            removePossibleSemiJoin(possibleDimensions, factIdx, bestDimIdx);
            removePossibleSemiJoin(this.possibleSemiJoins.get(bestDimIdx), bestDimIdx, factIdx);
            return true;
        }
        return false;
    }

    /**
     * Scores a candidate semijoin; higher is better, 0 means "not worth it".
     *
     * <p>Returns 0 if the semijoin keeps more than half the fact rows, or if
     * the required cost/row-count metadata is unavailable. Otherwise the
     * score is computed as follows. The estimated fact-row savings are
     * {@code (1 - sqrt(selectivity)) * max(1, factRows)}. Those savings are
     * doubled when the dimension keys are unique once nulls are filtered.
     * The result is divided by {@code max(1, dimRows)} from the dimension's
     * cumulative cost.
     */
    private double computeScore(RelNode factRel, RelNode dimRel, SemiJoinRel semiJoin) {
        BitSet dimCols = new BitSet();
        for (int dimCol : semiJoin.getRightKeys()) {
            dimCols.set(dimCol);
        }
        double selectivity = RelMdUtil.computeSemiJoinSelectivity(factRel, dimRel, semiJoin);
        if (selectivity > 0.5) {
            return 0.0;
        }
        RelOptCost factCost = RelMetadataQuery.getCumulativeCost(factRel);
        if (factCost == null) {
            return 0.0;
        }
        double savings = (1.0 - Math.sqrt(selectivity)) * Math.max(1.0, factCost.getRows());
        boolean uniq = RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(dimRel, dimCols);
        if (uniq) {
            savings *= 2.0;
        }
        // Check the boxed row count directly: the former "uniq ? 0.0 : rowCount"
        // ternary unboxed it and threw NPE before this check could run.
        Double dimRowCount = RelMetadataQuery.getRowCount(dimRel);
        RelOptCost dimCost = RelMetadataQuery.getCumulativeCost(dimRel);
        if (dimRowCount == null || dimCost == null) {
            return 0.0;
        }
        return savings / Math.max(1.0, dimCost.getRows());
    }

    /**
     * Marks the join with the dimension factor as removable when the
     * semijoin already enforces it.
     *
     * <p>This requires that the dimension keys are unique once nulls are
     * filtered, that {@code BitSets.contains(dimKeys, dimProjRefs)} holds
     * (presumably: every projected dimension field is a key), and that no
     * non-key dimension field is referenced by a join. If, in addition, the
     * dimension projects nothing and every join-referenced field is a key
     * referenced at most once, the fact factor's join reference counts for
     * the semijoin's left keys are decremented.
     */
    private void removeJoin(LoptMultiJoin multiJoin, SemiJoinRel semiJoin, int factIdx, int dimIdx) {
        if (multiJoin.getJoinRemovalFactor(dimIdx) != null) {
            return;
        }
        BitSet dimKeys = new BitSet();
        for (int key : semiJoin.getRightKeys()) {
            dimKeys.set(key);
        }
        RelNode dimRel = multiJoin.getJoinFactor(dimIdx);
        if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(dimRel, dimKeys)) {
            return;
        }
        BitSet dimProjRefs = multiJoin.getProjFields(dimIdx);
        if (dimProjRefs == null) {
            // treated as "all fields of the dimension factor"
            dimProjRefs = BitSets.range(0, multiJoin.getNumFieldsInJoinFactor(dimIdx));
        }
        if (!BitSets.contains(dimKeys, dimProjRefs)) {
            return;
        }
        int[] dimJoinRefCounts = multiJoin.getJoinFieldRefCounts(dimIdx);
        for (int i = 0; i < dimJoinRefCounts.length; i++) {
            if (dimJoinRefCounts[i] > 0 && !dimKeys.get(i)) {
                return;
            }
        }

        multiJoin.setJoinRemovalFactor(dimIdx, factIdx);
        multiJoin.setJoinRemovalSemiJoin(dimIdx, semiJoin);

        if (dimProjRefs.cardinality() != 0) {
            return;
        }
        for (int i = 0; i < dimJoinRefCounts.length; i++) {
            if (dimJoinRefCounts[i] > 1
                || (dimJoinRefCounts[i] == 1 && !dimKeys.get(i))) {
                return;
            }
        }
        int[] factJoinRefCounts = multiJoin.getJoinFieldRefCounts(factIdx);
        for (int key : semiJoin.getLeftKeys()) {
            factJoinRefCounts[key]--;
        }
    }

    /**
     * Removes the candidate for {@code dimIdx} from a fact factor's candidate
     * map. The fact factor is dropped from {@link #possibleSemiJoins} entirely
     * once it has no candidates left. A null map is ignored.
     */
    private void removePossibleSemiJoin(Map<Integer, SemiJoinRel> possibleDimensions,
            Integer factIdx, Integer dimIdx) {
        if (possibleDimensions == null) {
            return;
        }
        possibleDimensions.remove(dimIdx);
        if (possibleDimensions.isEmpty()) {
            this.possibleSemiJoins.remove(factIdx);
        } else {
            this.possibleSemiJoins.put(factIdx, possibleDimensions);
        }
    }

    /**
     * Returns the factor at {@code factIdx}, wrapped in any semijoins chosen
     * for it so far.
     */
    public RelNode getChosenSemiJoin(int factIdx) {
        return this.chosenSemiJoins[factIdx];
    }

    /**
     * Orders factor indexes by the cumulative cost of their current chosen
     * rel, cheapest first. Factors with unknown (null) cost sort first.
     * Costs where neither is {@code isLt} the other compare as equal. This
     * keeps the ordering sign-symmetric, as required by {@link Comparator}.
     * Non-static because it reads {@code chosenSemiJoins}.
     */
    private class FactorCostComparator implements Comparator<Integer> {
        @Override
        public int compare(Integer rel1Idx, Integer rel2Idx) {
            RelOptCost c1 = RelMetadataQuery.getCumulativeCost(chosenSemiJoins[rel1Idx]);
            RelOptCost c2 = RelMetadataQuery.getCumulativeCost(chosenSemiJoins[rel2Idx]);
            if (c1 == null) {
                return c2 == null ? 0 : -1;
            }
            if (c2 == null) {
                return 1;
            }
            if (c1.isLt(c2)) {
                return -1;
            }
            if (c2.isLt(c1)) {
                return 1;
            }
            return 0;
        }
    }

    /** Behaviorless stand-in for a LucidDB index descriptor type. */
    private static class FemLocalIndex {
        private FemLocalIndex() {
        }
    }

    /** Stand-in for the LucidDB index optimizer; never finds an index. */
    private static class LcsIndexOptimizer {
        public LcsIndexOptimizer(LcsRowScanRel rel) {
        }

        /** Stub: always returns null. */
        public FemLocalIndex findSemiJoinIndexByCost(RelNode dimRel, List<Integer> actualLeftKeys,
                List<Integer> rightKeys, List<Integer> bestKeyOrder) {
            return null;
        }
    }

    /** Behaviorless stand-in for a LucidDB row-scan rel. */
    private static class LcsRowScanRel {
        private LcsRowScanRel() {
        }
    }

    /** Stand-in for a LucidDB column-store table; cannot be instantiated. */
    private static abstract class LcsTable implements RelOptTable {
        private LcsTable() {
        }
    }

    /** Stand-in for LucidDB special-operator helpers. */
    private static class LucidDbSpecialOperators {
        private LucidDbSpecialOperators() {
        }

        /** Stub: never reports a column as the row-id column. */
        public static boolean isLcsRidColumnId(int originColumnOrdinal) {
            return false;
        }
    }
}

