package org.apache.lucene.codecs.hnsw;

import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/* loaded from: input_file:META-INF/bundled-dependencies/lucene-core-9.11.1.jar:org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.class */
/**
 * A {@link FlatVectorsScorer} that scores scalar-quantized vectors directly in their quantized
 * form.
 *
 * <p>When the supplied vector values are not {@link RandomAccessQuantizedByteVectorValues}, and
 * for every {@code byte[]} query, scoring is handed to the non-quantized delegate given at
 * construction time.
 */
public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
    /** Scorer used for non-quantized vector values and for byte[] queries. */
    private final FlatVectorsScorer nonQuantizedDelegate;

    /**
     * Supplies scorers that compare one stored quantized vector (identified by its ordinal) against
     * other stored quantized vectors.
     */
    public static class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
        private final RandomAccessQuantizedByteVectorValues values;
        private final ScalarQuantizedVectorSimilarity similarity;
        private final VectorSimilarityFunction vectorSimilarityFunction;

        /**
         * Creates a supplier whose quantized similarity is derived from {@code scalarQuantizer}.
         *
         * @param vectorSimilarityFunction the similarity function to score with
         * @param scalarQuantizer the quantizer whose constant multiplier and bit width parameterize
         *     the quantized similarity
         * @param values the quantized vectors to score
         */
        public ScalarQuantizedRandomVectorScorerSupplier(VectorSimilarityFunction vectorSimilarityFunction, ScalarQuantizer scalarQuantizer, RandomAccessQuantizedByteVectorValues values) {
            this(
                    ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
                            vectorSimilarityFunction,
                            scalarQuantizer.getConstantMultiplier(),
                            scalarQuantizer.getBits()),
                    vectorSimilarityFunction,
                    values);
        }

        /** Used by {@link #copy()} to reuse an already-built similarity with a fresh values copy. */
        private ScalarQuantizedRandomVectorScorerSupplier(ScalarQuantizedVectorSimilarity similarity, VectorSimilarityFunction vectorSimilarityFunction, RandomAccessQuantizedByteVectorValues values) {
            this.similarity = similarity;
            this.values = values;
            this.vectorSimilarityFunction = vectorSimilarityFunction;
        }

        /**
         * Returns a scorer that compares the stored vector at {@code ord} against other stored
         * vectors.
         *
         * <p>The returned scorer reads candidate vectors from its own {@code copy()} of the values.
         * It also keeps a private copy of the query bytes, so later reads on this supplier do not
         * affect it.
         *
         * @param ord ordinal of the stored vector to use as the query
         * @return a scorer for {@code ord}
         * @throws IOException if reading the vectors fails
         */
        @Override // org.apache.lucene.util.hnsw.RandomVectorScorerSupplier
        public RandomVectorScorer scorer(int ord) throws IOException {
            final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
            // Clone the query bytes. vectorValue may hand back a buffer owned by `values` that is
            // overwritten on its next read (e.g. the next scorer(int) call on this supplier). Without
            // the clone, the query of scorers already returned would change silently.
            final byte[] queryVector = values.vectorValue(ord).clone();
            final float queryOffset = values.getScoreCorrectionConstant(ord);
            return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
                @Override // org.apache.lucene.util.hnsw.RandomVectorScorer
                public float score(int node) throws IOException {
                    final byte[] nodeVector = vectorsCopy.vectorValue(node);
                    final float nodeOffset = vectorsCopy.getScoreCorrectionConstant(node);
                    return similarity.score(queryVector, queryOffset, nodeVector, nodeOffset);
                }
            };
        }

        /**
         * Returns a new supplier that shares this supplier's similarity but reads from a fresh
         * {@code copy()} of the underlying values.
         */
        @Override // org.apache.lucene.util.hnsw.RandomVectorScorerSupplier
        public RandomVectorScorerSupplier copy() throws IOException {
            return new ScalarQuantizedRandomVectorScorerSupplier(similarity, vectorSimilarityFunction, values.copy());
        }
    }

    /**
     * Quantizes a float query into {@code quantizedQuery} using {@code scalarQuantizer}.
     *
     * <p>For {@link VectorSimilarityFunction#COSINE} the query is L2-normalized first, on a copy,
     * so the caller's array is left unchanged. EUCLIDEAN, DOT_PRODUCT and MAXIMUM_INNER_PRODUCT
     * quantize the query as-is.
     *
     * @param query the float query vector
     * @param quantizedQuery destination for the quantized bytes
     * @param similarityFunction the similarity function the query will be scored with
     * @param scalarQuantizer the quantizer to apply
     * @return the value returned by {@link ScalarQuantizer#quantize}. Within this class it is used
     *     as the query's score-correction constant.
     * @throws IllegalArgumentException if {@code similarityFunction} is not one of the handled
     *     functions
     */
    public static float quantizeQuery(float[] query, byte[] quantizedQuery, VectorSimilarityFunction similarityFunction, ScalarQuantizer scalarQuantizer) {
        final float[] processedQuery;
        switch (similarityFunction) {
            case EUCLIDEAN:
            case DOT_PRODUCT:
            case MAXIMUM_INNER_PRODUCT:
                processedQuery = query;
                break;
            case COSINE:
                // Normalize a copy so the caller's query array is not mutated.
                processedQuery = ArrayUtil.copyArray(query);
                VectorUtil.l2normalize(processedQuery);
                break;
            default:
                throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction);
        }
        return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction);
    }

    /**
     * @param nonQuantizedDelegate scorer used for non-quantized vector values and for byte[] queries
     */
    public ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) {
        this.nonQuantizedDelegate = nonQuantizedDelegate;
    }

    /**
     * Returns a quantized scorer supplier when {@code vectorValues} are quantized. Otherwise the
     * call is forwarded to the non-quantized delegate.
     */
    @Override // org.apache.lucene.codecs.hnsw.FlatVectorsScorer
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) throws IOException {
        if (!(vectorValues instanceof RandomAccessQuantizedByteVectorValues)) {
            return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
        }
        final RandomAccessQuantizedByteVectorValues quantizedValues = (RandomAccessQuantizedByteVectorValues) vectorValues;
        return new ScalarQuantizedRandomVectorScorerSupplier(similarityFunction, quantizedValues.getScalarQuantizer(), quantizedValues);
    }

    /**
     * Returns a scorer for a float query against the stored vectors.
     *
     * <p>If {@code vectorValues} are quantized, the query is quantized once with the values' own
     * {@link ScalarQuantizer} and compared in quantized form. Otherwise the call is forwarded to
     * the non-quantized delegate.
     */
    @Override // org.apache.lucene.codecs.hnsw.FlatVectorsScorer
    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, float[] target) throws IOException {
        if (!(vectorValues instanceof RandomAccessQuantizedByteVectorValues)) {
            return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
        }
        final RandomAccessQuantizedByteVectorValues quantizedValues = (RandomAccessQuantizedByteVectorValues) vectorValues;
        final ScalarQuantizer scalarQuantizer = quantizedValues.getScalarQuantizer();
        // One quantized byte per query dimension; this array is owned by the returned scorer.
        final byte[] quantizedTarget = new byte[target.length];
        final float targetOffset = quantizeQuery(target, quantizedTarget, similarityFunction, scalarQuantizer);
        final ScalarQuantizedVectorSimilarity quantizedSimilarity =
                ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
                        similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
        return new RandomVectorScorer.AbstractRandomVectorScorer(quantizedValues) {
            @Override // org.apache.lucene.util.hnsw.RandomVectorScorer
            public float score(int node) throws IOException {
                final byte[] nodeVector = quantizedValues.vectorValue(node);
                final float nodeOffset = quantizedValues.getScoreCorrectionConstant(node);
                return quantizedSimilarity.score(quantizedTarget, targetOffset, nodeVector, nodeOffset);
            }
        };
    }

    /** Byte queries are never quantized by this scorer and always go to the non-quantized delegate. */
    @Override // org.apache.lucene.codecs.hnsw.FlatVectorsScorer
    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, byte[] target) throws IOException {
        return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    @Override
    public String toString() {
        return "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + this.nonQuantizedDelegate + ")";
    }
}
