package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.LimitNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CoalesceExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignatureTranslator;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
import org.apache.iotdb.db.utils.constant.SqlConstant;
import org.apache.tsfile.read.common.type.BooleanType;
import org.apache.tsfile.read.common.type.LongType;
import org.apache.tsfile.read.common.type.Type;

/* loaded from: input_file:org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.class */
/**
 * Planner rule that replaces an {@link ApplyNode} carrying exactly one subquery assignment of
 * kind {@link ApplyNode.Exists} with an equivalent plan built around a {@link CorrelatedJoinNode}.
 *
 * <p>Two rewrites are available:
 *
 * <ul>
 *   <li>a LEFT correlated join against the subquery limited to one row, with the result coalesced
 *       to {@code false} (see {@link #rewriteToNonDefaultAggregation});
 *   <li>an INNER correlated join against a global {@code count} aggregation over the subquery,
 *       compared with zero (see {@link #rewriteToDefaultAggregation}).
 * </ul>
 */
public class TransformExistsApplyToCorrelatedJoin implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode();
    private final PlannerContext plannerContext;

    /**
     * @param plannerContext planner context used to resolve the count aggregate and to build the
     *     decorrelator; must not be null
     * @throws NullPointerException if {@code plannerContext} is null
     */
    public TransformExistsApplyToCorrelatedJoin(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    /** Matches every {@link ApplyNode}; the finer filtering happens in {@link #apply}. */
    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the node when its only subquery assignment is an EXISTS expression.
     *
     * @return an empty result when the node has a number of subquery assignments other than one or
     *     the single assignment is not {@link ApplyNode.Exists}; otherwise the rewritten plan
     */
    @Override
    public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
        if (applyNode.getSubqueryAssignments().size() != 1) {
            return Rule.Result.empty();
        }

        ApplyNode.SetExpression onlyExpression =
                Iterables.getOnlyElement(applyNode.getSubqueryAssignments().values());
        if (!(onlyExpression instanceof ApplyNode.Exists)) {
            return Rule.Result.empty();
        }

        // With no correlation symbols the count-based rewrite is used directly.
        if (applyNode.getCorrelation().isEmpty()) {
            return Rule.Result.ofPlanNode(rewriteToDefaultAggregation(applyNode, context));
        }

        // Otherwise try the LEFT-join form first and fall back to the count-based form.
        Optional<PlanNode> leftJoinRewrite = rewriteToNonDefaultAggregation(applyNode, context);
        if (leftJoinRewrite.isPresent()) {
            return Rule.Result.ofPlanNode(leftJoinRewrite.get());
        }
        return Rule.Result.ofPlanNode(rewriteToDefaultAggregation(applyNode, context));
    }

    /**
     * Builds the LEFT correlated join form: the subquery is limited to one row and projected to a
     * constant {@code true}; the EXISTS symbol becomes {@code COALESCE(subqueryTrue, false)}.
     *
     * @return the rewritten plan, or {@link Optional#empty()} when the decorrelator yields no
     *     result for the limited/projected subquery
     * @throws IllegalStateException if the subquery still exposes output symbols
     */
    private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, Rule.Context context) {
        Preconditions.checkState(
                applyNode.getSubquery().getOutputSymbols().isEmpty(),
                "Expected subquery output symbols to be pruned");

        Symbol subqueryTrue = context.getSymbolAllocator().newSymbol("subqueryTrue", BooleanType.BOOLEAN);

        // Node ids are requested parent-first (project, then limit), matching argument order.
        ProjectNode limitedSubquery =
                new ProjectNode(
                        context.getIdAllocator().genPlanNodeId(),
                        new LimitNode(
                                context.getIdAllocator().genPlanNodeId(),
                                applyNode.getSubquery(),
                                1L,
                                Optional.empty()),
                        Assignments.of(subqueryTrue, BooleanLiteral.TRUE_LITERAL));

        PlanNodeDecorrelator decorrelator =
                new PlanNodeDecorrelator(
                        this.plannerContext, context.getSymbolAllocator(), context.getLookup());
        boolean decorrelatable =
                decorrelator.decorrelateFilters(limitedSubquery, applyNode.getCorrelation()).isPresent();
        if (!decorrelatable) {
            return Optional.empty();
        }

        Symbol existsSymbol = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().keySet());
        List<Expression> coalesceOperands =
                ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL);

        PlanNode rewritten =
                new ProjectNode(
                        context.getIdAllocator().genPlanNodeId(),
                        new CorrelatedJoinNode(
                                applyNode.getPlanNodeId(),
                                applyNode.getInput(),
                                limitedSubquery,
                                applyNode.getCorrelation(),
                                JoinNode.JoinType.LEFT,
                                BooleanLiteral.TRUE_LITERAL,
                                applyNode.getOriginSubquery()),
                        Assignments.builder()
                                .putIdentities(applyNode.getInput().getOutputSymbols())
                                .put(existsSymbol, new CoalesceExpression(coalesceOperands))
                                .build());
        return Optional.of(rewritten);
    }

    /**
     * Builds the count-based form: an INNER correlated join against a global {@code count}
     * aggregation over the subquery, with the EXISTS symbol assigned {@code count > CAST(0)}.
     */
    private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Rule.Context context) {
        ResolvedFunction countFunction =
                org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.Util
                        .getResolvedBuiltInAggregateFunction(
                                this.plannerContext.getMetadata(), SqlConstant.COUNT, ImmutableList.of());
        Symbol countSymbol = context.getSymbolAllocator().newSymbol(SqlConstant.COUNT, LongType.getInstance());

        Symbol existsSymbol = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().keySet());

        AggregationNode.Aggregation countAggregation =
                new AggregationNode.Aggregation(
                        countFunction,
                        ImmutableList.of(),
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty());
        ImmutableMap<Symbol, AggregationNode.Aggregation> aggregations =
                ImmutableMap.of(countSymbol, countAggregation);

        Expression zero =
                new Cast(new LongLiteral("0"), TypeSignatureTranslator.toSqlType(LongType.getInstance()));
        Expression countIsPositive =
                new ComparisonExpression(
                        ComparisonExpression.Operator.GREATER_THAN, countSymbol.toSymbolReference(), zero);

        // Node ids are requested parent-first (project, then aggregation), matching argument order.
        ProjectNode existsProjection =
                new ProjectNode(
                        context.getIdAllocator().genPlanNodeId(),
                        AggregationNode.singleAggregation(
                                context.getIdAllocator().genPlanNodeId(),
                                applyNode.getSubquery(),
                                aggregations,
                                AggregationNode.globalAggregation()),
                        Assignments.of(existsSymbol, countIsPositive));

        return new CorrelatedJoinNode(
                applyNode.getPlanNodeId(),
                applyNode.getInput(),
                existsProjection,
                applyNode.getCorrelation(),
                JoinNode.JoinType.INNER,
                BooleanLiteral.TRUE_LITERAL,
                applyNode.getOriginSubquery());
    }
}
