package org.apache.hadoop.hive.ql.optimizer.optiq.translator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.util.Addressing;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.optimizer.optiq.HiveOptiqUtil;
import org.apache.hadoop.hive.ql.optimizer.optiq.OptiqSemanticException;
import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveAggregateRel;
import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveProjectRel;
import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveSortRel;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.eigenbase.rel.AggregateCall;
import org.eigenbase.rel.AggregateRelBase;
import org.eigenbase.rel.EmptyRel;
import org.eigenbase.rel.FilterRelBase;
import org.eigenbase.rel.JoinRelBase;
import org.eigenbase.rel.OneRowRelBase;
import org.eigenbase.rel.ProjectRelBase;
import org.eigenbase.rel.RelCollationImpl;
import org.eigenbase.rel.RelNode;
import org.eigenbase.rel.SetOpRel;
import org.eigenbase.rel.SingleRel;
import org.eigenbase.rel.SortRel;
import org.eigenbase.rel.rules.MultiJoinRel;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.relopt.hep.HepRelVertex;
import org.eigenbase.relopt.volcano.RelSubset;
import org.eigenbase.reltype.RelDataType;
import org.eigenbase.reltype.RelDataTypeFactory;
import org.eigenbase.rex.RexNode;
import org.eigenbase.sql.SqlKind;
import org.eigenbase.util.Pair;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/optiq/translator/PlanModifierForASTConv.class */
public class PlanModifierForASTConv {
    /** Class logger; used for debug dumps of the plan between rewrite steps and for schema-mismatch errors. */
    private static final Log LOG = LogFactory.getLog(PlanModifierForASTConv.class);

    /**
     * Entry point: reshapes an optimized plan and aligns its top-level select with the
     * expected result schema.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>if the root is neither a project nor a sort, wrap it in a derived-table project;</li>
     *   <li>run the recursive {@code convertOpTree(RelNode, RelNode)} pass with a null parent;</li>
     *   <li>apply {@code fixTopOBSchema} to the top-level select;</li>
     *   <li>apply {@code renameTopLevelSelectInResultSchema} and return its result.</li>
     * </ol>
     * The plan is dumped at debug level before the first step and after each step.
     *
     * @param relNode root of the plan to rewrite
     * @param list    result schema the top-level select must match
     * @return the root of the rewritten plan
     * @throws OptiqSemanticException if the top-level select does not match {@code list}
     */
    public static RelNode convertOpTree(RelNode relNode, List<FieldSchema> list) throws OptiqSemanticException {
        RelNode root = relNode;
        tracePlan("Original plan for PlanModifier\n ", root);

        boolean rootIsProjectOrSort = (root instanceof ProjectRelBase) || (root instanceof SortRel);
        if (!rootIsProjectOrSort) {
            root = introduceDerivedTable(root);
            tracePlan("Plan after top-level introduceDerivedTable\n ", root);
        }

        // The cast selects the (RelNode, RelNode) overload; a bare null would be ambiguous.
        convertOpTree(root, (RelNode) null);
        tracePlan("Plan after nested convertOpTree\n ", root);

        fixTopOBSchema(root, HiveOptiqUtil.getTopLevelSelect(root), list);
        tracePlan("Plan after fixTopOBSchema\n ", root);

        // The top-level select is looked up again because fixTopOBSchema may have replaced it.
        RelNode result = renameTopLevelSelectInResultSchema(root, HiveOptiqUtil.getTopLevelSelect(root), list);
        tracePlan("Final plan after modifier\n ", result);
        return result;
    }

    /** Logs {@code header} followed by the textual form of {@code plan}, only when debug logging is on. */
    private static void tracePlan(String header, RelNode plan) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(header + RelOptUtil.toString(plan));
        }
    }

    /**
     * Recursive pass over the plan: rejects node kinds that must not appear and inserts
     * derived-table projects wherever a node sits under (or over) a neighbour that the
     * {@code valid*Parent} / {@code valid*Child} checks reject.
     *
     * @param relNode  node being visited
     * @param relNode2 parent of {@code relNode}, or null for the root
     * @throws RuntimeException if an EmptyRel, HepRelVertex, MultiJoinRel, OneRowRelBase or
     *                          RelSubset is encountered
     */
    private static void convertOpTree(RelNode relNode, RelNode relNode2) {
        if (relNode instanceof EmptyRel) {
            throw new RuntimeException("Found Empty Rel");
        } else if (relNode instanceof HepRelVertex) {
            throw new RuntimeException("Found HepRelVertex");
        } else if (relNode instanceof JoinRelBase) {
            if (!validJoinParent(relNode, relNode2)) {
                introduceDerivedTable(relNode, relNode2);
            }
        } else if (relNode instanceof MultiJoinRel) {
            throw new RuntimeException("Found MultiJoinRel");
        } else if (relNode instanceof OneRowRelBase) {
            throw new RuntimeException("Found OneRowRelBase");
        } else if (relNode instanceof RelSubset) {
            throw new RuntimeException("Found RelSubset");
        } else if (relNode instanceof SetOpRel) {
            if (!validSetopParent(relNode, relNode2)) {
                introduceDerivedTable(relNode, relNode2);
            }
            SetOpRel setOp = (SetOpRel) relNode;
            for (RelNode branch : setOp.getInputs()) {
                if (!validSetopChild(branch)) {
                    introduceDerivedTable(branch, setOp);
                }
            }
        } else if (relNode instanceof SingleRel) {
            if (relNode instanceof FilterRelBase) {
                if (!validFilterParent(relNode, relNode2)) {
                    introduceDerivedTable(relNode, relNode2);
                }
            } else if (relNode instanceof HiveSortRel) {
                HiveSortRel sort = (HiveSortRel) relNode;
                if (!validSortParent(relNode, relNode2)) {
                    introduceDerivedTable(relNode, relNode2);
                }
                if (!validSortChild(sort)) {
                    introduceDerivedTable(sort.getChild(), relNode);
                }
            } else if (relNode instanceof HiveAggregateRel) {
                // If a project is inserted above the aggregate, that project becomes the
                // effective parent for the empty-aggregate replacement below.
                RelNode effectiveParent = relNode2;
                if (!validGBParent(relNode, relNode2)) {
                    effectiveParent = introduceDerivedTable(relNode, relNode2);
                }
                if (isEmptyGrpAggr(relNode)) {
                    // NOTE(review): this swaps a copy of the aggregate into the tree, while the
                    // recursion below still walks the original node's inputs - confirm intended.
                    replaceEmptyGroupAggr(relNode, effectiveParent);
                }
            }
        }

        // Inputs are read after the rewrites above, so newly inserted children are visited.
        List<RelNode> children = relNode.getInputs();
        if (children != null) {
            for (RelNode child : children) {
                convertOpTree(child, relNode);
            }
        }
    }

    /**
     * Trims the project under a top-level ordering sort down to the result schema width.
     *
     * <p>Does nothing unless the parent of the top-level select is a {@code SortRel} for which
     * {@code HiveOptiqUtil.orderRelNode} is true and the select has more fields than the result
     * schema. Otherwise, the trailing select expressions (positions at or beyond
     * {@code list.size()}) that the sort's collation refers to are collected into a map, the
     * sort's input is replaced by a project holding only the first {@code list.size()}
     * expressions, and the map is handed to the sort via {@code setInputRefToCallMap}.
     *
     * @param relNode root of the plan (not read by this method)
     * @param pair    (parent of top-level select, top-level select)
     * @param list    expected result schema
     * @throws OptiqSemanticException if removing the collected expressions does not leave
     *                                exactly {@code list.size()} fields
     */
    private static void fixTopOBSchema(RelNode relNode, Pair<RelNode, RelNode> pair, List<FieldSchema> list) throws OptiqSemanticException {
        RelNode selectParent = pair.getKey();
        if (!(selectParent instanceof SortRel) || !HiveOptiqUtil.orderRelNode(selectParent)) {
            return;
        }

        HiveSortRel orderBy = (HiveSortRel) selectParent;
        ProjectRelBase select = (ProjectRelBase) pair.getValue();
        RelDataType selectType = select.getRowType();
        int selectWidth = selectType.getFieldCount();
        int resultWidth = list.size();
        if (selectWidth <= resultWidth) {
            return;
        }

        // Trailing select columns that exist only because the sort refers to them.
        HashSet<Integer> sortOrdinals = new HashSet<Integer>(RelCollationImpl.ordinals(orderBy.getCollation()));
        ImmutableMap.Builder<Integer, RexNode> extraSortExps = ImmutableMap.builder();
        for (int ordinal = resultWidth; ordinal < selectWidth; ordinal++) {
            if (sortOrdinals.contains(ordinal)) {
                extraSortExps.put(ordinal, select.getChildExps().get(ordinal));
            }
        }
        ImmutableMap<Integer, RexNode> inputRefToCall = extraSortExps.build();

        if (selectWidth - inputRefToCall.size() != resultWidth) {
            LOG.error(generateInvalidSchemaMessage(select, list, inputRefToCall.size()));
            throw new OptiqSemanticException("Result Schema didn't match Optimized Op Tree Schema");
        }

        List<RexNode> keptExps = select.getChildExps().subList(0, resultWidth);
        List<String> keptNames = selectType.getFieldNames().subList(0, resultWidth);
        orderBy.replaceInput(0, HiveProjectRel.create(select.getChild(), keptExps, keptNames));
        orderBy.setInputRefToCallMap(inputRefToCall);
    }

    /**
     * Builds the diagnostic text logged when the result schema and the top-level projection
     * disagree.
     *
     * <p>The message lists every result-schema field as {@code [name:type]}, then every
     * projection expression as {@code [expr:type]}, and, when {@code i} is non-zero, how many
     * fields were removed due to ORDER BY. The last two characters (a trailing {@code ", "} or
     * the two trailing spaces of the ORDER BY note) are dropped.
     *
     * <p>The name/type separator is a plain {@code ":"} literal instead of the HBase
     * {@code Addressing.HOSTNAME_PORT_SEPARATOR} constant (whose value is {@code ":"}), so this
     * message no longer depends on an unrelated networking utility. The text is accumulated in
     * a {@link StringBuilder} rather than by repeated string concatenation.
     *
     * @param projectRelBase top-level projection whose expressions are listed
     * @param list           expected result schema
     * @param i              number of projection fields removed because of ORDER BY; 0 for none
     * @return the formatted message
     */
    private static String generateInvalidSchemaMessage(ProjectRelBase projectRelBase, List<FieldSchema> list, int i) {
        final String nameTypeSeparator = ":";
        StringBuilder message = new StringBuilder("Result Schema didn't match Optiq Optimized Op Tree; schema: ");
        for (FieldSchema fieldSchema : list) {
            message.append("[").append(fieldSchema.getName()).append(nameTypeSeparator)
                    .append(fieldSchema.getType()).append("], ");
        }
        message.append(" projection fields: ");
        for (RexNode rexNode : projectRelBase.getChildExps()) {
            message.append("[").append(rexNode.toString()).append(nameTypeSeparator)
                    .append(rexNode.getType()).append("], ");
        }
        if (i != 0) {
            // Two trailing spaces on purpose: they are what the final trim below removes.
            message.append(i).append(" fields removed due to ORDER BY  ");
        }
        // Drop the trailing ", " (or the two spaces after the ORDER BY note).
        message.setLength(message.length() - 2);
        return message.toString();
    }

    /**
     * Rebuilds the top-level select so its output names come from the result schema.
     *
     * <p>Each name is taken from the corresponding entry of {@code list}; one leading
     * underscore, if present, is stripped. The new project replaces input 0 of the select's
     * parent, unless the select is itself the root, in which case the new project is returned.
     *
     * @param relNode root of the plan
     * @param pair    (parent of top-level select, top-level select)
     * @param list    expected result schema
     * @return the new project when the select was the root, otherwise {@code relNode}
     * @throws OptiqSemanticException if the select's expression count differs from
     *                                {@code list.size()}
     */
    private static RelNode renameTopLevelSelectInResultSchema(RelNode relNode, Pair<RelNode, RelNode> pair, List<FieldSchema> list) throws OptiqSemanticException {
        RelNode selectParent = pair.getKey();
        HiveProjectRel select = (HiveProjectRel) pair.getValue();
        List<RexNode> selectExps = select.getChildExps();

        if (list.size() != selectExps.size()) {
            LOG.error(generateInvalidSchemaMessage(select, list, 0));
            throw new OptiqSemanticException("Result Schema didn't match Optimized Op Tree Schema");
        }

        // Sizes are equal here, so walking the schema visits one name per select expression.
        List<String> aliases = new ArrayList<String>();
        for (FieldSchema column : list) {
            String alias = column.getName();
            aliases.add(alias.startsWith("_") ? alias.substring(1) : alias);
        }

        HiveProjectRel renamed = HiveProjectRel.create(select.getChild(), select.getChildExps(), aliases);
        if (relNode == select) {
            return renamed;
        } else {
            selectParent.replaceInput(0, renamed);
            return relNode;
        }
    }

    /**
     * Creates a project on top of {@code relNode} that reuses its cluster, row type and
     * collation list, with expressions supplied by
     * {@code HiveOptiqUtil.getProjsFromBelowAsInputRef(relNode)}.
     *
     * @param relNode node to wrap
     * @return the new project whose input is {@code relNode}
     */
    private static RelNode introduceDerivedTable(RelNode relNode) {
        return HiveProjectRel.create(
                relNode.getCluster(),
                relNode,
                HiveOptiqUtil.getProjsFromBelowAsInputRef(relNode),
                relNode.getRowType(),
                relNode.getCollationList());
    }

    /**
     * Wraps {@code relNode} in a derived-table project and splices that project into
     * {@code relNode2} at the input position {@code relNode} occupied.
     *
     * @param relNode  child to wrap
     * @param relNode2 parent whose input list contains {@code relNode}
     * @return the newly created project
     * @throws RuntimeException if {@code relNode} is not among the parent's inputs
     */
    private static RelNode introduceDerivedTable(RelNode relNode, RelNode relNode2) {
        List<RelNode> parentInputs = relNode2.getInputs();
        int position = -1;
        for (int idx = 0; idx < parentInputs.size(); idx++) {
            // Identity comparison: we need this exact node instance.
            if (parentInputs.get(idx) == relNode) {
                position = idx;
                break;
            }
        }
        if (position == -1) {
            throw new RuntimeException("Couldn't find child node in parent's inputs");
        }

        RelNode derivedTable = introduceDerivedTable(relNode);
        relNode2.replaceInput(position, derivedTable);
        return derivedTable;
    }

    /**
     * Tells whether a join may sit directly under the given parent.
     *
     * @param relNode  the join node
     * @param relNode2 its parent (may be null)
     * @return false when the parent is a join whose right input is {@code relNode}, or when
     *         the parent is a set operation; true otherwise
     */
    private static boolean validJoinParent(RelNode relNode, RelNode relNode2) {
        if (relNode2 instanceof JoinRelBase) {
            return ((JoinRelBase) relNode2).getRight() != relNode;
        }
        return !(relNode2 instanceof SetOpRel);
    }

    /**
     * Tells whether a filter may sit directly under the given parent.
     *
     * @param relNode  the filter node (not inspected)
     * @param relNode2 its parent (may be null)
     * @return false when the parent is a filter, a join or a set operation; true otherwise
     */
    private static boolean validFilterParent(RelNode relNode, RelNode relNode2) {
        return !((relNode2 instanceof FilterRelBase)
                || (relNode2 instanceof JoinRelBase)
                || (relNode2 instanceof SetOpRel));
    }

    /**
     * Tells whether an aggregate may sit directly under the given parent.
     *
     * @param relNode  the aggregate node; cast to {@code AggregateRelBase} only when the
     *                 parent is a filter
     * @param relNode2 its parent (may be null)
     * @return false when the parent is a join, set operation or aggregate, or when the parent
     *         is a filter and the aggregate's group set is empty; true otherwise
     */
    private static boolean validGBParent(RelNode relNode, RelNode relNode2) {
        if ((relNode2 instanceof JoinRelBase)
                || (relNode2 instanceof SetOpRel)
                || (relNode2 instanceof AggregateRelBase)) {
            return false;
        }
        if (relNode2 instanceof FilterRelBase) {
            return !((AggregateRelBase) relNode).getGroupSet().isEmpty();
        }
        return true;
    }

    /**
     * Tells whether a sort may sit directly under the given parent.
     *
     * @param relNode  the sort node (not inspected)
     * @param relNode2 its parent (may be null)
     * @return true when there is no parent, or the parent is a project or a sort, or
     *         {@code HiveOptiqUtil.orderRelNode} accepts the parent; false otherwise
     */
    private static boolean validSortParent(RelNode relNode, RelNode relNode2) {
        return relNode2 == null
                || (relNode2 instanceof ProjectRelBase)
                || (relNode2 instanceof SortRel)
                || HiveOptiqUtil.orderRelNode(relNode2);
    }

    /**
     * Tells whether the sort's current child is acceptable.
     *
     * @param hiveSortRel the sort node
     * @return true when {@code HiveOptiqUtil.limitRelNode} holds for the sort and
     *         {@code HiveOptiqUtil.orderRelNode} holds for its child, or when the child is a
     *         project; false otherwise
     */
    private static boolean validSortChild(HiveSortRel hiveSortRel) {
        RelNode input = hiveSortRel.getChild();
        boolean limitOverOrder = HiveOptiqUtil.limitRelNode(hiveSortRel) && HiveOptiqUtil.orderRelNode(input);
        return limitOverOrder || (input instanceof ProjectRelBase);
    }

    /**
     * Tells whether a set operation may sit directly under the given parent.
     *
     * @param relNode  the set-operation node (not inspected)
     * @param relNode2 its parent (may be null)
     * @return true when there is no parent or the parent is a project; false otherwise
     */
    private static boolean validSetopParent(RelNode relNode, RelNode relNode2) {
        return relNode2 == null || (relNode2 instanceof ProjectRelBase);
    }

    /**
     * Tells whether a node is acceptable as a direct input of a set operation.
     *
     * @param relNode candidate input
     * @return true only when the input is a project
     */
    private static boolean validSetopChild(RelNode relNode) {
        return relNode instanceof ProjectRelBase;
    }

    /**
     * Tells whether an aggregate has neither grouping keys nor aggregate calls.
     *
     * @param relNode node to inspect; must be an {@code AggregateRelBase}
     * @return true when both the group set and the aggregate-call list are empty
     */
    private static boolean isEmptyGrpAggr(RelNode relNode) {
        AggregateRelBase aggr = (AggregateRelBase) relNode;
        if (!aggr.getGroupSet().isEmpty()) {
            return false;
        }
        return aggr.getAggCallList().isEmpty();
    }

    /**
     * Replaces an aggregate that has no group keys and no aggregate calls with a copy that
     * carries a single {@code count} call over input column 0, wrapped in a derived-table
     * project, and installs that project as input 0 of the parent.
     *
     * @param relNode  the empty aggregate; must be a {@code HiveAggregateRel}
     * @param relNode2 the parent, whose child expressions must all be literals
     * @throws RuntimeException if any child expression of the parent is not a literal
     */
    private static void replaceEmptyGroupAggr(RelNode relNode, RelNode relNode2) {
        // The parent must not read any column from the aggregate being replaced.
        for (RexNode parentExp : relNode2.getChildExps()) {
            if (parentExp.getKind() != SqlKind.LITERAL) {
                throw new RuntimeException("We expect " + relNode2.toString() + " to contain only constants. However, " + parentExp.toString() + " is " + parentExp.getKind());
            }
        }

        HiveAggregateRel emptyAggr = (HiveAggregateRel) relNode;
        RelDataTypeFactory typeFactory = emptyAggr.getCluster().getTypeFactory();
        RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
        RelDataType intType = TypeConverter.convert(TypeInfoFactory.intTypeInfo, typeFactory);

        // Placeholder aggregate call: count over input ordinal 0, returning long.
        AggregateCall countCall = new AggregateCall(
                SqlFunctionConverter.getOptiqAggFn("count", ImmutableList.of(intType), longType),
                false,
                ImmutableList.of(0),
                longType,
                null);

        RelNode countingAggr = emptyAggr.copy(
                emptyAggr.getTraitSet(),
                emptyAggr.getChild(),
                emptyAggr.getGroupSet(),
                ImmutableList.of(countCall));
        relNode2.replaceInput(0, introduceDerivedTable(countingAggr));
    }
}
