/*
 * Decompiled with CFR 0.152.
 */
package org.eigenbase.rel.rules;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.eigenbase.rel.JoinRel;
import org.eigenbase.rel.RelNode;
import org.eigenbase.rel.rules.SemiJoinRel;
import org.eigenbase.relopt.RelOptRule;
import org.eigenbase.relopt.RelOptRuleCall;
import org.eigenbase.relopt.RelOptRuleOperand;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.reltype.RelDataTypeField;
import org.eigenbase.rex.RexNode;

public class PushSemiJoinPastJoinRule
extends RelOptRule {
    public static final PushSemiJoinPastJoinRule INSTANCE = new PushSemiJoinPastJoinRule();

    private PushSemiJoinPastJoinRule() {
        super(PushSemiJoinPastJoinRule.operand(SemiJoinRel.class, PushSemiJoinPastJoinRule.some(PushSemiJoinPastJoinRule.operand(JoinRel.class, PushSemiJoinPastJoinRule.any()), new RelOptRuleOperand[0])));
    }

    public void onMatch(RelOptRuleCall call) {
        RelNode rightJoinRel;
        RelNode leftJoinRel;
        List<Integer> newLeftKeys;
        RexNode newSemiJoinFilter;
        SemiJoinRel semiJoin = (SemiJoinRel)call.rel(0);
        JoinRel joinRel = (JoinRel)call.rel(1);
        List<Integer> leftKeys = semiJoin.getLeftKeys();
        List<Integer> rightKeys = semiJoin.getRightKeys();
        int nFieldsX = joinRel.getLeft().getRowType().getFieldList().size();
        int nFieldsY = joinRel.getRight().getRowType().getFieldList().size();
        int nFieldsZ = semiJoin.getRight().getRowType().getFieldList().size();
        int nTotalFields = nFieldsX + nFieldsY + nFieldsZ;
        ArrayList<RelDataTypeField> fields = new ArrayList<RelDataTypeField>();
        List<RelDataTypeField> joinFields = semiJoin.getRowType().getFieldList();
        int i = 0;
        while (i < nFieldsX + nFieldsY) {
            fields.add(joinFields.get(i));
            ++i;
        }
        joinFields = semiJoin.getRight().getRowType().getFieldList();
        i = 0;
        while (i < nFieldsZ) {
            fields.add(joinFields.get(i));
            ++i;
        }
        int nKeysFromX = 0;
        for (int leftKey : leftKeys) {
            if (leftKey >= nFieldsX) continue;
            ++nKeysFromX;
        }
        assert (nKeysFromX == 0 || nKeysFromX == leftKeys.size());
        int[] adjustments = new int[nTotalFields];
        if (nKeysFromX > 0) {
            this.setJoinAdjustments(adjustments, nFieldsX, nFieldsY, nFieldsZ, 0, -nFieldsY);
            newSemiJoinFilter = semiJoin.getCondition().accept(new RelOptUtil.RexInputConverter(semiJoin.getCluster().getRexBuilder(), fields, adjustments));
            newLeftKeys = leftKeys;
        } else {
            this.setJoinAdjustments(adjustments, nFieldsX, nFieldsY, nFieldsZ, -nFieldsX, -nFieldsX);
            newSemiJoinFilter = semiJoin.getCondition().accept(new RelOptUtil.RexInputConverter(semiJoin.getCluster().getRexBuilder(), fields, adjustments));
            newLeftKeys = RelOptUtil.adjustKeys(leftKeys, -nFieldsX);
        }
        RelNode leftSemiJoinOp = nKeysFromX > 0 ? joinRel.getLeft() : joinRel.getRight();
        SemiJoinRel newSemiJoin = new SemiJoinRel(semiJoin.getCluster(), leftSemiJoinOp, semiJoin.getRight(), newSemiJoinFilter, newLeftKeys, rightKeys);
        if (nKeysFromX > 0) {
            leftJoinRel = newSemiJoin;
            rightJoinRel = joinRel.getRight();
        } else {
            leftJoinRel = joinRel.getLeft();
            rightJoinRel = newSemiJoin;
        }
        JoinRel newJoinRel = new JoinRel(joinRel.getCluster(), leftJoinRel, rightJoinRel, joinRel.getCondition(), joinRel.getJoinType(), Collections.<String>emptySet(), joinRel.isSemiJoinDone(), joinRel.getSystemFieldList());
        call.transformTo(newJoinRel);
    }

    private void setJoinAdjustments(int[] adjustments, int nFieldsX, int nFieldsY, int nFieldsZ, int adjustY, int adjustZ) {
        int i = 0;
        while (i < nFieldsX) {
            adjustments[i] = 0;
            ++i;
        }
        i = nFieldsX;
        while (i < nFieldsX + nFieldsY) {
            adjustments[i] = adjustY;
            ++i;
        }
        i = nFieldsX + nFieldsY;
        while (i < nFieldsX + nFieldsY + nFieldsZ) {
            adjustments[i] = adjustZ;
            ++i;
        }
    }
}

