package org.elasticsearch.compute.aggregation.blockhash;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

/**
 * {@link BlockHash} that groups text values into categories using a {@link TokenListCategorizer}.
 * <p>
 * Group ordinal {@code 0} ({@code NULL_ORD}) is reserved for null. It is assigned to null input
 * positions and to values for which the categorizer produced no category. Every real category is
 * assigned ordinal {@code categoryId + 1}.
 * <p>
 * Input handling depends on the {@link AggregatorMode}:
 * <ul>
 *   <li>Raw input: values are categorized with a {@link CategorizationAnalyzer}.</li>
 *   <li>Partial (intermediate) input: the serialized state of an upstream categorizer is merged
 *       into this one.</li>
 * </ul>
 * When the mode emits partial output, the keys are the serialized categorizer state. Otherwise
 * they are the category regexes.
 */
public class CategorizeBlockHash extends BlockHash {
    private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(
        List.of()
    );

    /** Group ordinal reserved for null values and for text that produced no category. */
    private static final int NULL_ORD = 0;

    private final int channel;
    private final AggregatorMode aggregatorMode;
    private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
    /** Analyzer-backed evaluator; {@code null} when the input is partial (already categorized). */
    private final CategorizeEvaluator evaluator;
    /** Whether any position has been mapped to {@code NULL_ORD}. */
    private boolean seenNull = false;

    /**
     * Converts raw text values into group ordinals by feeding them to the enclosing hash's
     * categorizer.
     */
    public final class CategorizeEvaluator implements Releasable {
        private final CategorizationAnalyzer analyzer;

        CategorizeEvaluator(CategorizationAnalyzer analyzer) {
            this.analyzer = analyzer;
        }

        /**
         * Categorizes every value in {@code values}. Returns a block of group ordinals with the
         * same position count. Takes the vector fast path when {@code asVector()} is available.
         */
        IntBlock eval(BytesRefBlock values) {
            BytesRefVector vector = values.asVector();
            if (vector == null) {
                return eval(values.getPositionCount(), values);
            }
            return eval(values.getPositionCount(), vector).asBlock();
        }

        /**
         * Block path: null positions become {@code NULL_ORD}, single values map to one ordinal,
         * and multi-valued positions map to a multi-valued entry of ordinals.
         */
        IntBlock eval(int positionCount, BytesRefBlock values) {
            try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
                BytesRef scratch = new BytesRef();
                for (int p = 0; p < positionCount; p++) {
                    if (values.isNull(p)) {
                        seenNull = true;
                        result.appendInt(NULL_ORD);
                        continue;
                    }
                    int first = values.getFirstValueIndex(p);
                    int count = values.getValueCount(p);
                    if (count == 1) {
                        result.appendInt(process(values.getBytesRef(first, scratch)));
                        continue;
                    }
                    int end = first + count;
                    result.beginPositionEntry();
                    for (int i = first; i < end; i++) {
                        result.appendInt(process(values.getBytesRef(i, scratch)));
                    }
                    result.endPositionEntry();
                }
                return result.build();
            }
        }

        /** Vector path: exactly one value per position, so exactly one ordinal per position. */
        IntVector eval(int positionCount, BytesRefVector values) {
            try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
                BytesRef scratch = new BytesRef();
                for (int p = 0; p < positionCount; p++) {
                    result.appendInt(p, process(values.getBytesRef(p, scratch)));
                }
                return result.build();
            }
        }

        /**
         * Categorizes one value. Returns the category id plus one. If the categorizer produced no
         * category, records that a null was seen and returns {@code NULL_ORD}.
         */
        int process(BytesRef value) {
            TokenListCategory category = categorizer.computeCategory(value.utf8ToString(), analyzer);
            if (category == null) {
                seenNull = true;
                return NULL_ORD;
            }
            return category.getId() + 1;
        }

        @Override
        public void close() {
            analyzer.close();
        }
    }

    /**
     * Creates the hash.
     *
     * @param blockFactory factory for every block this hash builds
     * @param channel index of the input block to categorize (raw input) or merge (partial input)
     * @param aggregatorMode decides whether input is raw text or serialized categorizer state, and
     *     whether output keys are regexes or serialized state
     * @param analysisRegistry required unless {@code aggregatorMode.isInputPartial()}; used to
     *     build the categorization analyzer
     * @throws RuntimeException wrapping any failure while building the analyzer; the categorizer
     *     is released first
     */
    public CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) {
        super(blockFactory);
        this.channel = channel;
        this.aggregatorMode = aggregatorMode;
        this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
            new CategorizationBytesRefHash(new BytesRefHash(2048L, blockFactory.bigArrays())),
            CategorizationPartOfSpeechDictionary.getInstance(),
            0.7f
        );
        if (aggregatorMode.isInputPartial()) {
            // Partial input is already categorized; only the categorizer is needed to merge it.
            this.evaluator = null;
            return;
        }
        try {
            Objects.requireNonNull(analysisRegistry);
            this.evaluator = new CategorizeEvaluator(new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG));
        } catch (Exception e) {
            // The categorizer holds tracked memory; release it, since no one will close this half-built hash.
            categorizer.close();
            throw new RuntimeException(e);
        }
    }

    /** Returns whether any position has been assigned {@code NULL_ORD}. */
    public boolean seenNull() {
        return seenNull;
    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        try (IntBlock groupIds = add(page)) {
            // An empty intermediate page yields no ordinals (null); there is nothing to hand downstream.
            if (groupIds != null) {
                addInput.add(0, groupIds);
            }
        }
    }

    @Override
    public Block[] getKeys() {
        Block keys = aggregatorMode.isOutputPartial() ? buildIntermediateBlock() : buildFinalBlock();
        return new Block[] { keys };
    }

    @Override
    public IntVector nonEmpty() {
        int first = seenNull ? NULL_ORD : NULL_ORD + 1;
        return IntVector.range(first, categorizer.getCategoryCount() + 1, blockFactory);
    }

    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        int first = seenNull ? NULL_ORD : NULL_ORD + 1;
        return new SeenGroupIds.Range(first, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
    }

    @Override
    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue byteSizeValue) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void close() {
        Releasables.close(evaluator, categorizer);
    }

    /** Computes group ordinals for {@code page}; returns {@code null} for an empty partial-input page. */
    private IntBlock add(Page page) {
        return aggregatorMode.isInputPartial() ? addIntermediate(page) : addInitial(page);
    }

    /** Categorizes the raw text in the configured channel of {@code page}. */
    public IntBlock addInitial(Page page) {
        return evaluator.eval((BytesRefBlock) page.getBlock(channel));
    }

    /**
     * Merges the serialized categorizer state carried in {@code page} into this hash.
     * Returns one ordinal per position, or {@code null} if the page is empty.
     */
    private IntBlock addIntermediate(Page page) {
        if (page.getPositionCount() == 0) {
            return null;
        }
        BytesRefBlock state = (BytesRefBlock) page.getBlock(channel);
        if (state.areAllValuesNull()) {
            seenNull = true;
            return blockFactory.newConstantIntBlockWith(NULL_ORD, 1);
        }
        return recategorize(state.getBytesRef(0, new BytesRef()), null).asBlock();
    }

    /**
     * Merges the serialized categorizer state in {@code bytes} into this hash and translates
     * upstream ordinals into this hash's ordinals.
     *
     * @param bytes state written by {@link #serializeCategorizer()}
     * @param ids upstream ordinals to translate. When {@code null}, every upstream ordinal is
     *     emitted in ascending order, starting at {@code NULL_ORD} if the state recorded a null
     *     and at {@code 1} otherwise.
     * @return the translated ordinals
     */
    public IntVector recategorize(BytesRef bytes, IntVector ids) {
        Map<Integer, Integer> idMap = mergeSerializedCategories(bytes);
        try (IntVector.Builder newIds = blockFactory.newIntVectorBuilder(idMap.size())) {
            if (ids == null) {
                int offset = idMap.containsKey(NULL_ORD) ? 0 : 1;
                for (int i = 0; i < idMap.size(); i++) {
                    newIds.appendInt(idMap.get(i + offset));
                }
            } else {
                for (int p = 0; p < ids.getPositionCount(); p++) {
                    newIds.appendInt(idMap.get(ids.getInt(p)));
                }
            }
            return newIds.build();
        }
    }

    /**
     * Reads serialized categorizer state and merges each category into this categorizer.
     * Returns a map from upstream ordinals to this hash's ordinals.
     */
    private Map<Integer, Integer> mergeSerializedCategories(BytesRef bytes) {
        Map<Integer, Integer> idMap = new HashMap<>();
        try (StreamInput in = new BytesArray(bytes).streamInput()) {
            if (in.readBoolean()) {
                seenNull = true;
                idMap.put(NULL_ORD, NULL_ORD);
            }
            int count = in.readVInt();
            for (int oldId = 0; oldId < count; oldId++) {
                int newId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
                // Both sides shift category ids by one because ordinal 0 is reserved for null.
                idMap.put(oldId + 1, newId + 1);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return idMap;
    }

    /**
     * Builds the partial-output keys: a constant block repeating the serialized categorizer state,
     * with one position per group. If there are no categories, builds a constant null block instead.
     */
    private Block buildIntermediateBlock() {
        int categoryCount = categorizer.getCategoryCount();
        int positionCount = categoryCount + (seenNull ? 1 : 0);
        if (categoryCount == 0) {
            return blockFactory.newConstantNullBlock(positionCount);
        }
        return blockFactory.newConstantBytesRefBlockWith(serializeCategorizer(), positionCount);
    }

    /**
     * Serializes the categorizer state in this order: the {@code seenNull} flag, the category
     * count, then every category in id order.
     */
    public BytesRef serializeCategorizer() {
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            out.writeBoolean(seenNull);
            out.writeVInt(categorizer.getCategoryCount());
            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                category.writeTo(out);
            }
            return out.bytes().toBytesRef();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the final-output keys: each category's regex in id order. When a null group was
     * seen, a leading null entry is added.
     */
    private Block buildFinalBlock() {
        BytesRefBuilder scratch = new BytesRefBuilder();
        int categoryCount = categorizer.getCategoryCount();
        if (seenNull == false) {
            try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categoryCount)) {
                for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                    scratch.copyChars(category.getRegex());
                    result.appendBytesRef(scratch.get());
                    scratch.clear();
                }
                return result.build().asBlock();
            }
        }
        // One extra slot for the leading null entry.
        try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categoryCount + 1)) {
            result.appendNull();
            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                scratch.copyChars(category.getRegex());
                result.appendBytesRef(scratch.get());
                scratch.clear();
            }
            return result.build();
        }
    }
}
