package org.apache.lucene.benchmark.jmh;

import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

/**
 * JMH microbenchmarks for {@link VectorUtil} similarity functions over signed byte vectors,
 * int4 ("half byte") vectors and float vectors.
 *
 * <p>Each operation is benchmarked in pairs. The {@code *Scalar} variant runs with the
 * class-level fork settings. The {@code *Vector} variant's fork additionally prepends
 * {@code --add-modules=jdk.incubator.vector}. Both variants call exactly the same
 * {@link VectorUtil} method. Presumably the module flag is what lets {@link VectorUtil} pick a
 * vectorized implementation (TODO confirm against VectorUtil's implementation lookup).
 */
@Warmup(iterations = 4, time = 1)
@State(Scope.Benchmark)
@Measurement(iterations = 5, time = 1)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@Fork(value = 3, jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
@BenchmarkMode({Mode.Throughput})
public class VectorUtilBenchmark {
    /** Random signed byte vectors of length {@link #size}. */
    private byte[] bytesA;
    private byte[] bytesB;

    /** Unpacked int4 vectors: one value in [0, 15] stored per byte. */
    private byte[] halfBytesA;
    private byte[] halfBytesB;

    /** {@link #halfBytesB} packed two values per byte; null when {@link #size} is odd. */
    private byte[] halfBytesBPacked;

    /** Random float vectors of length {@link #size}. */
    private float[] floatsA;
    private float[] floatsB;

    /** Reference dot product of {@link #halfBytesA} and {@link #halfBytesB}, used to validate packed results. */
    private int expectedHalfByteDotProduct;

    /** Vector dimension. Odd sizes (1, 207) make the packed int4 benchmarks fail by design. */
    @Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
    int size;

    /**
     * Packs two int4 values per byte.
     *
     * <p>Output byte {@code i} holds {@code raw[i]} in its high nibble and
     * {@code raw[compressed.length + i]} in its low nibble. The first half of {@code raw} is
     * therefore paired with the second half.
     *
     * <p>Values in {@code raw} are expected to lie in [0, 15], as generated by {@link #init()}.
     * Any higher bits in the low-nibble source would be OR-ed into the high nibble.
     *
     * @param raw unpacked values; must hold at least {@code 2 * compressed.length} entries
     * @param compressed destination array, fully overwritten
     */
    static void compressBytes(byte[] raw, byte[] compressed) {
        int half = compressed.length;
        for (int i = 0; i < half; i++) {
            compressed[i] = (byte) ((raw[i] << 4) | raw[half + i]);
        }
    }

    /** Regenerates all input vectors (and the expected int4 dot product) before each iteration. */
    @Setup(Level.Iteration)
    public void init() {
        ThreadLocalRandom random = ThreadLocalRandom.current();

        bytesA = new byte[size];
        bytesB = new byte[size];
        random.nextBytes(bytesA);
        random.nextBytes(bytesB);

        halfBytesA = new byte[size];
        halfBytesB = new byte[size];
        expectedHalfByteDotProduct = 0;
        for (int i = 0; i < size; i++) {
            halfBytesA[i] = (byte) random.nextInt(16);
            halfBytesB[i] = (byte) random.nextInt(16);
            expectedHalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
        }

        // Packing pairs the two halves of the vector, so it is only defined for even sizes.
        // The packed benchmarks reject odd sizes themselves.
        halfBytesBPacked = null;
        if (size % 2 == 0) {
            halfBytesBPacked = new byte[size / 2];
            compressBytes(halfBytesB, halfBytesBPacked);
        }

        floatsA = new float[size];
        floatsB = new float[size];
        for (int i = 0; i < size; i++) {
            floatsA[i] = random.nextFloat();
            floatsB[i] = random.nextFloat();
        }
    }

    @Benchmark
    public float binaryCosineScalar() {
        return VectorUtil.cosine(bytesA, bytesB);
    }

    @Benchmark
    @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public float binaryCosineVector() {
        return VectorUtil.cosine(bytesA, bytesB);
    }

    @Benchmark
    public int binaryDotProductScalar() {
        return VectorUtil.dotProduct(bytesA, bytesB);
    }

    @Benchmark
    @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public int binaryDotProductVector() {
        return VectorUtil.dotProduct(bytesA, bytesB);
    }

    @Benchmark
    public int binarySquareScalar() {
        return VectorUtil.squareDistance(bytesA, bytesB);
    }

    @Benchmark
    @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public int binarySquareVector() {
        return VectorUtil.squareDistance(bytesA, bytesB);
    }

    @Benchmark
    public int binaryHalfByteScalar() {
        return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
    }

    @Benchmark
    @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public int binaryHalfByteVector() {
        return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
    }

    @Benchmark
    public int binaryHalfByteScalarPacked() {
        return checkedPackedHalfByteDotProduct();
    }

    @Benchmark
    @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public int binaryHalfByteVectorPacked() {
        return checkedPackedHalfByteDotProduct();
    }

    /**
     * Computes the packed int4 dot product and verifies it against the value precomputed in
     * {@link #init()}. A wrong packed implementation therefore fails the benchmark instead of
     * reporting misleading numbers.
     *
     * @return the packed dot product of {@link #halfBytesA} and {@link #halfBytesBPacked}
     * @throws IllegalStateException if {@link #size} is odd or the result does not match
     */
    private int checkedPackedHalfByteDotProduct() {
        if (size % 2 != 0) {
            throw new IllegalStateException("Size must be even for this benchmark");
        }
        int result = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
        if (result != expectedHalfByteDotProduct) {
            throw new IllegalStateException(
                "Expected " + expectedHalfByteDotProduct + " but got " + result);
        }
        return result;
    }

    @Benchmark
    public float floatCosineScalar() {
        return VectorUtil.cosine(floatsA, floatsB);
    }

    @Benchmark
    @Fork(value = 15, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public float floatCosineVector() {
        return VectorUtil.cosine(floatsA, floatsB);
    }

    @Benchmark
    public float floatDotProductScalar() {
        return VectorUtil.dotProduct(floatsA, floatsB);
    }

    @Benchmark
    @Fork(value = 15, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public float floatDotProductVector() {
        return VectorUtil.dotProduct(floatsA, floatsB);
    }

    @Benchmark
    public float floatSquareScalar() {
        return VectorUtil.squareDistance(floatsA, floatsB);
    }

    @Benchmark
    @Fork(value = 15, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
    public float floatSquareVector() {
        return VectorUtil.squareDistance(floatsA, floatsB);
    }
}
