package org.elasticsearch.compute.aggregation;

import java.util.List;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/* loaded from: input_file:org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.class */
/**
 * Grouping aggregator implementing {@code COUNT}: keeps one {@code long} counter per group id in a
 * {@link LongArrayState}.
 *
 * <p>When constructed with no input channels ({@code countAll}), every input row is counted once for
 * each group id it maps to. Otherwise the values block at the first channel is inspected: a row
 * contributes its number of values, and null rows contribute nothing.
 *
 * <p>The intermediate state is two blocks, {@code count} (LONG) and {@code seen} (BOOLEAN).
 */
public class CountGroupingAggregatorFunction implements GroupingAggregatorFunction {
    private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
        new IntermediateStateDesc("count", ElementType.LONG),
        new IntermediateStateDesc("seen", ElementType.BOOLEAN)
    );

    private final LongArrayState state;
    private final List<Integer> channels;
    private final DriverContext driverContext;
    /** True when no input channels were given, meaning every row is counted regardless of values. */
    private final boolean countAll;

    /**
     * Creates a count aggregator whose per-group counters start at {@code 0}.
     *
     * @param driverContext context supplying the {@code BigArrays} used for the per-group state
     * @param list input channels; an empty list means "count all rows"
     */
    public static CountGroupingAggregatorFunction create(DriverContext driverContext, List<Integer> list) {
        return new CountGroupingAggregatorFunction(list, new LongArrayState(driverContext.bigArrays(), 0L), driverContext);
    }

    /** Describes the intermediate blocks produced by {@link #evaluateIntermediate}: {@code count} and {@code seen}. */
    public static List<IntermediateStateDesc> intermediateStateDesc() {
        return INTERMEDIATE_STATE_DESC;
    }

    private CountGroupingAggregatorFunction(List<Integer> list, LongArrayState longArrayState, DriverContext driverContext) {
        this.channels = list;
        this.state = longArrayState;
        this.driverContext = driverContext;
        this.countAll = list.isEmpty();
    }

    /** Index of the first input block this aggregator reads: channel 0 when counting all rows, else the first channel. */
    private int blockIndex() {
        return countAll ? 0 : channels.get(0);
    }

    @Override
    public int intermediateBlockCount() {
        return intermediateStateDesc().size();
    }

    /**
     * Chooses how rows of {@code page} are counted.
     *
     * <p>If every row is counted, or the values block is a vector (one non-null value per row), each
     * row adds exactly 1 to each of its groups. Otherwise the per-row value count is used, and nulls
     * are skipped.
     */
    @Override
    public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) {
        if (countAll) {
            // No values block is needed, so don't fetch one (the page might not even have it).
            return countEveryRowInput();
        }
        final Block valuesBlock = page.getBlock(blockIndex());
        if (valuesBlock.asVector() != null) {
            // A vector holds exactly one non-null value per position, so each row counts as 1.
            return countEveryRowInput();
        }
        if (valuesBlock.mayHaveNulls()) {
            // Rows with null values never increment their group. Presumably tracking lets those
            // groups be told apart from groups that were incremented; confirm in LongArrayState.
            state.enableGroupIdTracking(seenGroupIds);
        }
        return new GroupingAggregatorFunction.AddInput() {
            @Override
            public void add(int positionOffset, IntBlock groupIds) {
                addRawInput(positionOffset, groupIds, valuesBlock);
            }

            @Override
            public void add(int positionOffset, IntVector groupIds) {
                addRawInput(positionOffset, groupIds, valuesBlock);
            }

            @Override
            public void close() {
            }
        };
    }

    /** Input that adds 1 to every group id it receives, ignoring the position offset. */
    private GroupingAggregatorFunction.AddInput countEveryRowInput() {
        return new GroupingAggregatorFunction.AddInput() {
            @Override
            public void add(int positionOffset, IntBlock groupIds) {
                addRawInput(groupIds);
            }

            @Override
            public void add(int positionOffset, IntVector groupIds) {
                addRawInput(groupIds);
            }

            @Override
            public void close() {
            }
        };
    }

    /**
     * Adds each row's value count to its single group; rows whose value is null are skipped.
     *
     * @param positionOffset position in {@code values} that corresponds to group position 0
     */
    private void addRawInput(int positionOffset, IntVector groups, Block values) {
        int valuePosition = positionOffset;
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, valuePosition++) {
            if (values.isNull(valuePosition)) {
                continue;
            }
            state.increment(groups.getInt(groupPosition), values.getValueCount(valuePosition));
        }
    }

    /**
     * Adds each row's value count to every group id listed for that row. Rows with null group ids
     * or null values are skipped.
     *
     * @param positionOffset position in {@code values} that corresponds to group position 0
     */
    private void addRawInput(int positionOffset, IntBlock groups, Block values) {
        int valuePosition = positionOffset;
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, valuePosition++) {
            if (groups.isNull(groupPosition) || values.isNull(valuePosition)) {
                continue;
            }
            int valueCount = values.getValueCount(valuePosition);
            int first = groups.getFirstValueIndex(groupPosition);
            int end = first + groups.getValueCount(groupPosition);
            for (int g = first; g < end; g++) {
                state.increment(groups.getInt(g), valueCount);
            }
        }
    }

    /** Adds 1 to the group of every position. */
    private void addRawInput(IntVector groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
            state.increment(groups.getInt(groupPosition), 1L);
        }
    }

    /** Adds 1 to every group id of every non-null position. */
    private void addRawInput(IntBlock groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
            if (groups.isNull(groupPosition)) {
                continue;
            }
            int first = groups.getFirstValueIndex(groupPosition);
            int end = first + groups.getValueCount(groupPosition);
            for (int g = first; g < end; g++) {
                state.increment(groups.getInt(g), 1L);
            }
        }
    }

    @Override
    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
        state.enableGroupIdTracking(seenGroupIds);
    }

    /**
     * Merges intermediate {@code count} values from {@code page} into the groups listed in
     * {@code intVector}.
     *
     * @param i position in the intermediate blocks that corresponds to group position 0
     * @param intVector group id for each position
     * @param page page whose {@code channels.get(0)} / {@code channels.get(1)} blocks hold the
     *             {@code count} / {@code seen} vectors
     */
    @Override
    public void addIntermediateInput(int i, IntVector intVector, Page page) {
        assert channels.size() == intermediateBlockCount();
        assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
        state.enableGroupIdTracking(new SeenGroupIds.Empty());
        LongVector counts = ((LongBlock) page.getBlock(channels.get(0))).asVector();
        BooleanVector seen = ((BooleanBlock) page.getBlock(channels.get(1))).asVector();
        assert counts.getPositionCount() == seen.getPositionCount();
        for (int groupPosition = 0; groupPosition < intVector.getPositionCount(); groupPosition++) {
            state.increment(intVector.getInt(groupPosition), counts.getLong(groupPosition + i));
        }
    }

    /**
     * Adds group {@code i2} of another count aggregator into group {@code i} of this one.
     *
     * @throws IllegalArgumentException if {@code groupingAggregatorFunction} is not exactly this class
     */
    @Override
    public void addIntermediateRowInput(int i, GroupingAggregatorFunction groupingAggregatorFunction, int i2) {
        if (groupingAggregatorFunction.getClass() != getClass()) {
            throw new IllegalArgumentException("expected " + String.valueOf(getClass()) + "; got " + String.valueOf(groupingAggregatorFunction.getClass()));
        }
        LongArrayState otherState = ((CountGroupingAggregatorFunction) groupingAggregatorFunction).state;
        state.enableGroupIdTracking(new SeenGroupIds.Empty());
        if (otherState.hasValue(i2)) {
            state.increment(i, otherState.get(i2));
        }
    }

    @Override
    public void evaluateIntermediate(Block[] blockArr, int i, IntVector intVector) {
        state.toIntermediate(blockArr, i, intVector, driverContext);
    }

    /**
     * Writes the final count of each selected group into {@code blockArr[i]}. Groups that never
     * received a value report {@code 0}.
     */
    @Override
    public void evaluateFinal(Block[] blockArr, int i, IntVector intVector, DriverContext driverContext) {
        int positionCount = intVector.getPositionCount();
        // try-with-resources also covers build(), so the builder is released if building fails.
        try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) {
            for (int position = 0; position < positionCount; position++) {
                int groupId = intVector.getInt(position);
                builder.appendLong(state.hasValue(groupId) ? state.get(groupId) : 0L);
            }
            blockArr[i] = builder.build().asBlock();
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "[channels=" + channels + "]";
    }

    /** Releases the per-group state. */
    @Override
    public void close() {
        state.close();
    }
}
