package org.nd4j.linalg.api.ops.aggregates.impl;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.BaseAggregate;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/ops/aggregates/impl/AggregateDot.class */
/**
 * Aggregate descriptor for the native {@code aggregate_dot} op (op number 1), which operates on
 * two arrays {@code x} and {@code y}.
 *
 * <p>The length of {@code x} is passed to the native side as the single indexing argument. That
 * length also sizes the thread count and shared-memory request below.
 */
public class AggregateDot extends BaseAggregate {
    /** Upper bound on threads requested per aggregate instance. */
    private static final int MAX_THREADS_PER_INSTANCE = 768;

    /**
     * Fixed number of bytes added on top of the per-thread shared-memory requirement.
     * NOTE(review): purpose of this padding is not visible here — presumably kernel scratch space.
     */
    private static final int EXTRA_SHARED_MEMORY_BYTES = 512;

    /** Number of elements processed; taken from {@code x.length()} at construction. */
    private final int vectorLength;

    /**
     * Creates a dot aggregate over {@code x} and {@code y}.
     *
     * @param x first operand; its length is used as the vector length for the op
     * @param y second operand; must have at least as many elements as {@code x}
     * @throws NullPointerException if {@code x} or {@code y} is null
     * @throws IllegalArgumentException if {@code y} is shorter than {@code x}. The op is told to
     *     process {@code x.length()} elements of both arrays, so a shorter {@code y} would be read
     *     past its end.
     */
    public AggregateDot(@NonNull INDArray x, @NonNull INDArray y) {
        if (x == null) {
            throw new NullPointerException("x");
        }
        if (y == null) {
            throw new NullPointerException("y");
        }

        int length = x.length();
        if (y.length() < length) {
            throw new IllegalArgumentException(
                    "y must have at least as many elements as x: x.length=" + length
                            + ", y.length=" + y.length());
        }

        this.arguments.add(x);
        this.arguments.add(y);
        this.indexingArguments.add(Integer.valueOf(length));
        this.vectorLength = length;
    }

    /**
     * Returns the shared memory requested per instance. This is one data-type-sized slot per
     * thread plus a fixed {@value #EXTRA_SHARED_MEMORY_BYTES}-byte padding.
     */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int getSharedMemorySize() {
        return (getThreadsPerInstance() * Nd4j.sizeOfDataType()) + EXTRA_SHARED_MEMORY_BYTES;
    }

    /**
     * Returns the vector length, capped at {@value #MAX_THREADS_PER_INSTANCE}.
     */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int getThreadsPerInstance() {
        return Math.min(this.vectorLength, MAX_THREADS_PER_INSTANCE);
    }

    /** Returns the native op name, {@code "aggregate_dot"}. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public String name() {
        return "aggregate_dot";
    }

    /** Returns the native op number, {@code 1}. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int opNum() {
        return 1;
    }

    /** Two array arguments: {@code x} and {@code y}. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxArguments() {
        return 2;
    }

    /** No shape arguments are used. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxShapes() {
        return 0;
    }

    /** No int-array arguments are used. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxIntArrays() {
        return 0;
    }

    /** No int-array arguments are used, so their size is zero. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxIntArraySize() {
        return 0;
    }

    /** One indexing argument: the vector length. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxIndexArguments() {
        return 1;
    }

    /** No real-valued (floating-point) arguments are used. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxRealArguments() {
        return 0;
    }
}
