package org.nd4j.linalg.api.ops.aggregates.impl;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.BaseAggregate;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Legacy aggregate op that packs its inputs into the generic {@link BaseAggregate} argument lists
 * for the native {@code aggregate_cbow} op (op number 4).
 *
 * <p>Argument layout produced by the constructors (positions are significant: the 18-argument
 * constructor overwrites indexing slots 9-11 and array slot 5 by index):
 * <ul>
 *   <li>indexing arguments (12): vectorLength, idxSyn1.length, negativeRounds, expTable length,
 *       vocabSize, ngStarter, negTable length (0 when absent), idxSyn0.length, wordIdx,
 *       numLabels (default 0), trainWords flag (default 1), inference-vector-present flag
 *       (default 0);</li>
 *   <li>arrays (6): syn0, syn1, expTable, syn1Neg, negTable, inferenceVector (default
 *       {@code null});</li>
 *   <li>int arrays (3): idxSyn0, idxSyn1, codes;</li>
 *   <li>real arguments (2): alpha, nextRandom (widened to {@code double}).</li>
 * </ul>
 *
 * <p>NOTE(review): how the native kernel interprets each slot is not visible here; the names
 * follow the upstream ND4J word2vec CBOW aggregate and should be confirmed against the kernel.
 *
 * @deprecated legacy aggregate API retained for compatibility; no replacement is referenced here.
 */
@Deprecated
public class AggregateCBOW extends BaseAggregate {
    /** Upper bound returned by {@link #getThreadsPerInstance()}; smaller vector lengths are used as-is. */
    private static final int MAX_THREADS_PER_INSTANCE = 768;

    /** Constant number of bytes added to the vector-proportional shared-memory request. */
    private static final int SHARED_MEMORY_PADDING_BYTES = 512;

    /** Value of indexing argument 0; also drives shared-memory size and threads per instance. */
    private final int vectorLength;

    /**
     * Creates the aggregate with an explicit label count, word-training flag and optional
     * inference vector.
     *
     * <p>Delegates to the 15-argument constructor, then overrides indexing arguments 9-11 and
     * array argument 5. All other parameters have the same meaning as in
     * {@link #AggregateCBOW(INDArray, INDArray, INDArray, INDArray, INDArray, int, int[], int[], int[],
     * int, int, int, double, long, int)}.
     *
     * @param numLabels       stored as indexing argument 9
     * @param trainWords      stored as indexing argument 10 ({@code 1} when true, {@code 0} otherwise)
     * @param inferenceVector optional array stored as argument 5; indexing argument 11 records
     *                        whether it is non-null
     * @throws NullPointerException if {@code syn0}, {@code expTable}, {@code idxSyn0} or
     *                              {@code idxSyn1} is null
     * @throws ArithmeticException  if the length of {@code expTable} or {@code negTable} does not
     *                              fit in an {@code int}
     */
    public AggregateCBOW(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable,
            INDArray negTable, int wordIdx, int[] idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds,
            int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize, int numLabels,
            boolean trainWords, INDArray inferenceVector) {
        // The delegated constructor rejects a null syn0/expTable before doing anything else,
        // so repeating those checks after this call would be unreachable.
        this(syn0, syn1, syn1Neg, expTable, negTable, wordIdx, idxSyn0, idxSyn1, codes, negativeRounds,
                ngStarter, vectorLength, alpha, nextRandom, vocabSize);
        indexingArguments.set(9, numLabels);
        indexingArguments.set(10, trainWords ? 1 : 0);
        indexingArguments.set(11, inferenceVector == null ? 0 : 1);
        arguments.set(5, inferenceVector);
    }

    /**
     * Creates the aggregate with default trailing settings: numLabels {@code 0}, trainWords
     * enabled ({@code 1}) and no inference vector.
     *
     * @param syn0           required array, stored as argument 0
     * @param syn1           array stored as argument 1 (not null-checked here)
     * @param syn1Neg        array stored as argument 3 (not null-checked here)
     * @param expTable       required array, stored as argument 2; its length is indexing argument 3
     * @param negTable       optional array, stored as argument 4; its length (0 when null) is
     *                       indexing argument 6
     * @param wordIdx        indexing argument 8
     * @param idxSyn0        int array 0; must be non-null, its length is indexing argument 7
     * @param idxSyn1        int array 1; must be non-null, its length is indexing argument 1
     * @param codes          int array 2
     * @param negativeRounds indexing argument 2
     * @param ngStarter      indexing argument 5
     * @param vectorLength   indexing argument 0; also sizes shared memory and threads per instance
     * @param alpha          real argument 0
     * @param nextRandom     real argument 1, widened to {@code double}
     * @param vocabSize      indexing argument 4
     * @throws NullPointerException if {@code syn0}, {@code expTable}, {@code idxSyn0} or
     *                              {@code idxSyn1} is null
     * @throws ArithmeticException  if the length of {@code expTable} or {@code negTable} does not
     *                              fit in an {@code int}
     */
    public AggregateCBOW(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable,
            INDArray negTable, int wordIdx, int[] idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds,
            int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize) {
        if (syn0 == null) {
            throw new NullPointerException("syn0 is marked @NonNull but is null");
        }
        if (expTable == null) {
            throw new NullPointerException("expTable is marked @NonNull but is null");
        }

        // Indexing arguments; slot numbers are in the trailing comments.
        indexingArguments.add(vectorLength);                              // 0
        indexingArguments.add(idxSyn1.length);                            // 1
        indexingArguments.add(negativeRounds);                            // 2
        // toIntExact: fail loudly instead of passing a silently truncated length.
        indexingArguments.add(Math.toIntExact(expTable.length()));        // 3
        indexingArguments.add(vocabSize);                                 // 4
        indexingArguments.add(ngStarter);                                 // 5
        indexingArguments.add(negTable == null ? 0 : Math.toIntExact(negTable.length())); // 6
        indexingArguments.add(idxSyn0.length);                            // 7
        indexingArguments.add(wordIdx);                                   // 8
        indexingArguments.add(0);                                         // 9: numLabels default
        indexingArguments.add(1);                                         // 10: trainWords default (on)
        indexingArguments.add(0);                                         // 11: no inference vector

        // Array arguments; note that expTable precedes syn1Neg in this layout.
        arguments.add(syn0);
        arguments.add(syn1);
        arguments.add(expTable);
        arguments.add(syn1Neg);
        arguments.add(negTable);
        arguments.add(null); // slot 5: inference vector, filled in by the 18-argument constructor

        intArrayArguments.add(idxSyn0);
        intArrayArguments.add(idxSyn1);
        intArrayArguments.add(codes);

        realArguments.add(alpha);
        // NOTE(review): real arguments are carried as doubles, so a nextRandom above 2^53
        // loses its low-order bits in this conversion.
        realArguments.add((double) nextRandom);

        this.vectorLength = vectorLength;
    }

    /** Returns the native op name. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public String name() {
        return "aggregate_cbow";
    }

    /** Returns the native op number. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int opNum() {
        return 4;
    }

    /** Six array slots: syn0, syn1, expTable, syn1Neg, negTable, inference vector. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxArguments() {
        return 6;
    }

    /** This aggregate passes no shape arguments. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxShapes() {
        return 0;
    }

    /** Three int-array slots: idxSyn0, idxSyn1, codes. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxIntArrays() {
        return 3;
    }

    /** Declared maximum length of each int array. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxIntArraySize() {
        return 40;
    }

    /** Twelve indexing slots, matching the layout written by the constructors. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxIndexArguments() {
        return 12;
    }

    /** Two real slots: alpha and nextRandom. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int maxRealArguments() {
        return 2;
    }

    /** Requests {@code 2 * vectorLength} elements of the current data type plus fixed padding. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int getSharedMemorySize() {
        return vectorLength * Nd4j.sizeOfDataType() * 2 + SHARED_MEMORY_PADDING_BYTES;
    }

    /** Returns the vector length, capped at {@value #MAX_THREADS_PER_INSTANCE}. */
    @Override // org.nd4j.linalg.api.ops.aggregates.Aggregate
    public int getThreadsPerInstance() {
        return Math.min(vectorLength, MAX_THREADS_PER_INSTANCE);
    }
}
