package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.ProcessNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
import org.apache.tsfile.read.common.type.BooleanType;
import org.apache.tsfile.read.common.type.LongType;
import org.apache.tsfile.read.common.type.Type;

/**
 * Decorrelates a {@link CorrelatedJoinNode} whose subquery is a global aggregation (an aggregation
 * with no grouping columns) sitting directly on top of its source, with no projection in between.
 *
 * <p>The rule matches correlated joins that have a non-empty correlation list and a {@code TRUE}
 * filter. When the subquery source can be decorrelated, the plan is rewritten to:
 *
 * <pre>
 * Aggregation  (grouped by the join's left-side outputs, i.e. the AssignUniqueId output, plus the
 *               original grouping keys; every aggregation masked by non_null)
 *   [Project   (combined masks: existing mask AND non_null)]       -- only if some mask existed
 *     [Distinct aggregation restored above the join]               -- only if it had to be pulled out
 *       LEFT Join (on the decorrelated correlation predicates)
 *         AssignUniqueId(input)
 *         Project(decorrelated subquery, non_null := TRUE)
 * </pre>
 *
 * <p>Masking every aggregation with {@code non_null} keeps rows that the LEFT join pads for
 * unmatched input rows from contributing to the aggregation results.
 */
public class TransformCorrelatedGlobalAggregationWithoutProjection implements Rule<CorrelatedJoinNode> {
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
    private static final Capture<PlanNode> SOURCE = Capture.newCapture();
    private static final Pattern<CorrelatedJoinNode> PATTERN =
            Patterns.correlatedJoin()
                    .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
                    .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
                    .with(
                            Patterns.CorrelatedJoin.subquery()
                                    .matching(
                                            Patterns.aggregation()
                                                    .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
                                                    .with(Patterns.source().capturedAs(SOURCE))
                                                    .capturedAs(AGGREGATION)));

    private final PlannerContext plannerContext;

    /**
     * Creates the rule.
     *
     * @param plannerContext planner context handed to the {@link PlanNodeDecorrelator}; must not be null
     * @throws NullPointerException if {@code plannerContext} is null
     */
    public TransformCorrelatedGlobalAggregationWithoutProjection(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched correlated join into an uncorrelated LEFT join followed by the
     * (re-grouped, re-masked) global aggregation.
     *
     * @param correlatedJoinNode the matched correlated join; must be INNER or LEFT
     * @param captures captures holding the subquery aggregation and its source
     * @param context rule context providing the symbol/id allocators and lookup
     * @return the rewritten plan, or {@link Rule.Result#empty()} if the subquery source (or, for a
     *     distinct operator, its child) cannot be decorrelated
     * @throws IllegalArgumentException if the join type is neither INNER nor LEFT
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Preconditions.checkArgument(
                correlatedJoinNode.getJoinType() == JoinNode.JoinType.INNER
                        || correlatedJoinNode.getJoinType() == JoinNode.JoinType.LEFT,
                "unexpected correlated join type: %s",
                correlatedJoinNode.getJoinType());

        PlanNode source = captures.get(SOURCE);

        // Set only when the source cannot be decorrelated as a whole but is a distinct operator whose
        // child can be; the distinct is then re-applied on top of the join below.
        AggregationNode distinct = null;

        PlanNodeDecorrelator decorrelator =
                new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource =
                decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
        if (!decorrelatedSource.isPresent()) {
            if (AggregationDecorrelation.isDistinctOperator(source)) {
                distinct = (AggregationNode) source;
                decorrelatedSource =
                        decorrelator.decorrelateFilters(distinct.getChild(), correlatedJoinNode.getCorrelation());
            }
            if (!decorrelatedSource.isPresent()) {
                return Rule.Result.empty();
            }
        }
        PlanNode decorrelatedNode = decorrelatedSource.get().getNode();

        Type booleanType = BooleanType.getInstance();

        // Marker column: TRUE on every real subquery row, NULL on rows padded by the LEFT join below.
        Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", booleanType);
        ProjectNode subqueryWithMarker =
                new ProjectNode(
                        context.getIdAllocator().genPlanNodeId(),
                        decorrelatedNode,
                        Assignments.builder()
                                .putIdentities(decorrelatedNode.getOutputSymbols())
                                .put(nonNull, BooleanLiteral.TRUE_LITERAL)
                                .build());

        // Tag each input row with a unique id; the id is part of the join's left outputs, which become
        // grouping keys of the final aggregation, so each input row gets its own group.
        AssignUniqueId inputWithUniqueId =
                new AssignUniqueId(
                        context.getIdAllocator().genPlanNodeId(),
                        correlatedJoinNode.getInput(),
                        context.getSymbolAllocator().newSymbol("unique", LongType.getInstance()));

        JoinNode join =
                new JoinNode(
                        context.getIdAllocator().genPlanNodeId(),
                        JoinNode.JoinType.LEFT,
                        inputWithUniqueId,
                        subqueryWithMarker,
                        ImmutableList.of(),
                        inputWithUniqueId.getOutputSymbols(),
                        subqueryWithMarker.getOutputSymbols(),
                        decorrelatedSource.get().getCorrelatedPredicates(),
                        Optional.empty());

        ProcessNode root = join;
        if (distinct != null) {
            // Explicit <Symbol> witness: without it the builder is inferred as ImmutableList<Object>.
            root =
                    AggregationDecorrelation.restoreDistinctAggregation(
                            distinct,
                            join,
                            ImmutableList.<Symbol>builder()
                                    .addAll(join.getLeftOutputSymbols())
                                    .add(nonNull)
                                    .addAll(distinct.getGroupingKeys())
                                    .build());
        }

        // Mask every aggregation with non_null so padded rows are ignored. An aggregation that already
        // has a mask gets a fresh mask symbol bound to (existing mask AND non_null).
        AggregationNode globalAggregation = captures.get(AGGREGATION);
        ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
        Assignments.Builder maskAssignmentsBuilder = Assignments.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
                globalAggregation.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.getMask().isPresent()) {
                Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", booleanType);
                maskAssignmentsBuilder.put(
                        combinedMask,
                        IrUtils.and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference()));
                masks.put(entry.getKey(), combinedMask);
            } else {
                masks.put(entry.getKey(), nonNull);
            }
        }

        Assignments maskAssignments = maskAssignmentsBuilder.build();
        if (!maskAssignments.isEmpty()) {
            root =
                    new ProjectNode(
                            context.getIdAllocator().genPlanNodeId(),
                            root,
                            Assignments.builder()
                                    .putIdentities(root.getOutputSymbols())
                                    .putAll(maskAssignments)
                                    .build());
        }

        // Rebuild the global aggregation (keeping its id and step) grouped per input row.
        AggregationNode rewrittenAggregation =
                new AggregationNode(
                        globalAggregation.getPlanNodeId(),
                        root,
                        AggregationDecorrelation.rewriteWithMasks(
                                globalAggregation.getAggregations(), masks.buildOrThrow()),
                        AggregationNode.singleGroupingSet(
                                ImmutableList.<Symbol>builder()
                                        .addAll(join.getLeftOutputSymbols())
                                        .addAll(globalAggregation.getGroupingKeys())
                                        .build()),
                        ImmutableList.of(),
                        globalAggregation.getStep(),
                        Optional.empty(),
                        Optional.empty());

        // Drop helper symbols (unique id, masks, ...) that the original correlated join did not expose.
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(
                                context.getIdAllocator(),
                                rewrittenAggregation,
                                ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(rewrittenAggregation));
    }
}
