package org.nd4j.linalg.api.ops;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/BaseBroadcastOp.class */
/**
 * Abstract base for broadcast operations.
 *
 * <p>Both {@link #opType()} and {@link #getOpType()} report {@link Op.Type#BROADCAST}. Instances
 * can be built for a SameDiff graph (from {@link SDVariable} inputs) or for direct execution
 * (from {@link INDArray} operands).
 */
public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) BaseBroadcastOp.class);

    /**
     * Broadcast dimensions for this op. May be {@code null}, in which case {@link #getDimension()}
     * derives them lazily from the shapes of {@code larg()} and {@code rarg()}.
     */
    protected int[] dimension;

    /** No-argument constructor; leaves all state at its defaults. */
    public BaseBroadcastOp() {
    }

    /**
     * Creates an op bound to {@code sameDiff} without registering any inputs or dimensions.
     *
     * @param sameDiff owning SameDiff instance
     */
    public BaseBroadcastOp(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    /**
     * Creates a non-in-place SameDiff broadcast op over two variables.
     *
     * @param sameDiff  owning SameDiff instance
     * @param input1    first input variable; must not be {@code null}
     * @param input2    second input variable; must not be {@code null}
     * @param dimension broadcast dimensions, or {@code null} to derive them in {@link #getDimension()}
     * @throws IllegalArgumentException if either input is {@code null}
     */
    public BaseBroadcastOp(SameDiff sameDiff, SDVariable input1, SDVariable input2, int[] dimension) {
        this(sameDiff, input1, input2, false, dimension);
    }

    /**
     * Creates a SameDiff broadcast op over two variables.
     *
     * <p>Each input whose shape is a placeholder shape is registered with {@code sameDiff} as a
     * property to resolve; both inputs are then registered as this op's arguments.
     *
     * @param sameDiff  owning SameDiff instance
     * @param input1    first input variable; must not be {@code null}
     * @param input2    second input variable; must not be {@code null}
     * @param inPlace   whether the op runs in place
     * @param dimension broadcast dimensions, or {@code null} to derive them in {@link #getDimension()}
     * @throws IllegalArgumentException if either input is {@code null}
     */
    public BaseBroadcastOp(SameDiff sameDiff, SDVariable input1, SDVariable input2, boolean inPlace, int[] dimension) {
        // NOTE(review): input2 is forwarded to the superclass in a single-element Object[]
        // (presumably the extra-args slot) — confirm this is intended.
        super(sameDiff, inPlace, new Object[]{input2});
        if (input1 == null || input2 == null) {
            throw new IllegalArgumentException("Input not null variables.");
        }
        f().validateDifferentialFunctionsameDiff(input1);
        f().validateDifferentialFunctionsameDiff(input2);
        this.sameDiff = sameDiff;
        this.inPlace = inPlace;
        this.dimension = dimension;
        if (Shape.isPlaceholderShape(input1.getShape())) {
            sameDiff.addPropertyToResolve(this, input1.getVarName());
        }
        if (Shape.isPlaceholderShape(input2.getShape())) {
            sameDiff.addPropertyToResolve(this, input2.getVarName());
        }
        sameDiff.addArgsFor(new SDVariable[]{input1, input2}, this);
    }

    /**
     * Creates a SameDiff broadcast op over two variables with explicit extra arguments.
     *
     * @param sameDiff  owning SameDiff instance
     * @param input1    first input variable; must not be {@code null}
     * @param input2    second input variable; must not be {@code null}
     * @param dimension broadcast dimensions, or {@code null} to derive them in {@link #getDimension()}
     * @param extraArgs array forwarded unchanged to the superclass constructor
     * @throws IllegalArgumentException if either input is {@code null}
     */
    public BaseBroadcastOp(SameDiff sameDiff, SDVariable input1, SDVariable input2, int[] dimension, Object[] extraArgs) {
        super(sameDiff, extraArgs);
        this.dimension = dimension;
        if (input1 == null || input2 == null) {
            throw new IllegalArgumentException("Input not null variables.");
        }
        f().validateDifferentialFunctionsameDiff(input1);
        f().validateDifferentialFunctionsameDiff(input2);
        this.sameDiff = sameDiff;
        sameDiff.addArgsFor(new SDVariable[]{input1, input2}, this);
    }

    /**
     * Creates a single-input SameDiff broadcast op with no extra arguments.
     *
     * @param sameDiff  owning SameDiff instance
     * @param input     input variable; must not be {@code null}
     * @param dimension broadcast dimensions
     * @param inPlace   whether the op runs in place
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public BaseBroadcastOp(SameDiff sameDiff, SDVariable input, int[] dimension, boolean inPlace) {
        // Guard the shape lookup so a null input reaches the delegated constructor's
        // IllegalArgumentException instead of failing here with a NullPointerException.
        this(sameDiff, input, input == null ? null : input.getShape(), inPlace, dimension, (Object[]) null);
    }

    /**
     * Creates a single-input SameDiff broadcast op, converting an {@code int[]} shape to
     * {@code long[]} and delegating to the {@code long[]} overload.
     *
     * @param sameDiff  owning SameDiff instance
     * @param input     input variable; must not be {@code null}
     * @param shape     input shape (not used by the delegated constructor)
     * @param inPlace   whether the op runs in place
     * @param dimension broadcast dimensions
     * @param extraArgs array forwarded unchanged to the superclass constructor
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public BaseBroadcastOp(SameDiff sameDiff, SDVariable input, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) {
        this(sameDiff, input, ArrayUtil.toLongArray(shape), inPlace, dimension, extraArgs);
    }

    /**
     * Creates a single-input SameDiff broadcast op.
     *
     * @param sameDiff  owning SameDiff instance
     * @param input     input variable; must not be {@code null}
     * @param shape     input shape; not read by this constructor
     * @param inPlace   whether the op runs in place
     * @param dimension broadcast dimensions
     * @param extraArgs array forwarded unchanged to the superclass constructor
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public BaseBroadcastOp(SameDiff sameDiff, SDVariable input, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) {
        super(sameDiff, inPlace, extraArgs);
        this.dimension = dimension;
        if (input == null) {
            throw new IllegalArgumentException("Input not null variable.");
        }
        f().validateDifferentialFunctionsameDiff(input);
        sameDiff.addArgsFor(new SDVariable[]{input}, this);
    }

    /**
     * Creates a single-input, non-in-place SameDiff broadcast op.
     *
     * @param sameDiff  owning SameDiff instance
     * @param input     input variable; must not be {@code null}
     * @param dimension broadcast dimensions
     * @param extraArgs array forwarded unchanged to the superclass constructor
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public BaseBroadcastOp(SameDiff sameDiff, SDVariable input, int[] dimension, Object[] extraArgs) {
        // Same null guard as above: let the delegated constructor reject a null input.
        this(sameDiff, input, input == null ? null : input.getShape(), false, dimension, extraArgs);
    }

    /**
     * Creates a broadcast op over concrete arrays.
     *
     * <p>The operands and dimensions are validated with {@link Broadcast#validateBroadcastDims}
     * before the dimensions are stored and passed to {@code defineDimensions}.
     *
     * @param x         first operand
     * @param y         second operand
     * @param z         result array
     * @param dimension broadcast dimensions
     */
    public BaseBroadcastOp(INDArray x, INDArray y, INDArray z, int... dimension) {
        super(x, y, z);
        Broadcast.validateBroadcastDims(x, y, z, dimension);
        this.dimension = dimension;
        defineDimensions(dimension);
    }

    /** @return {@link Op.Type#BROADCAST} */
    @Override
    public Op.Type opType() {
        return Op.Type.BROADCAST;
    }

    /**
     * Computes the output shape descriptor from the X and Y arrays.
     *
     * @return an empty list if X or Y is {@code null}; otherwise a single descriptor whose shape
     *         comes from {@code Shape.broadcastOutputShape} and whose data type comes from
     *         {@code Shape.pickPairwiseDataType}
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        if (this.x == null || this.y == null) {
            return Collections.emptyList();
        }
        return Collections.singletonList(LongShapeDescriptor.fromShape(
                Shape.broadcastOutputShape(this.x.shape(), this.y.shape()),
                Shape.pickPairwiseDataType(this.x.dataType(), this.y.dataType())));
    }

    /**
     * Returns the broadcast dimensions.
     *
     * <p>If none were set, they are computed with {@code Shape.getBroadcastDimensions} from the
     * shapes of {@code larg()} and {@code rarg()}, stored in {@link #dimension}, and returned.
     *
     * @return the broadcast dimensions
     */
    @Override
    public int[] getDimension() {
        if (this.dimension == null) {
            this.dimension = Shape.getBroadcastDimensions(larg().getShape(), rarg().getShape());
        }
        return this.dimension;
    }

    /**
     * Replaces the broadcast dimensions.
     *
     * @param dimension new broadcast dimensions (stored as-is, without copying)
     */
    @Override
    public void setDimension(int... dimension) {
        this.dimension = dimension;
    }

    /** Does nothing in this base class. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    /** Does nothing in this base class. */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    /**
     * Validates the data types of X, Y and Z.
     *
     * <ul>
     *   <li>If Y and Z are present, Z must share a data type with X or with Y.</li>
     *   <li>Outside experimental mode, if Y is present, X and Y must share a data type unless Y is
     *       {@link DataType#BOOL}.</li>
     *   <li>If any operand is floating point, Z must be floating point too. This check is skipped
     *       for op number 1 when Y is present.</li>
     * </ul>
     *
     * @param experimentalMode when {@code true}, the X/Y same-type requirement is skipped
     * @return {@code true} when all checks pass
     * @throws IllegalArgumentException (via {@link Preconditions}) when a check fails
     */
    @Override
    public boolean validateDataTypes(boolean experimentalMode) {
        int opNum = opNum();

        if (y() != null && z() != null) {
            Preconditions.checkArgument(y().dataType() == z().dataType() || x().dataType() == z().dataType(),
                    "Op.Z type must be either Op.X or Op.Y: x.dataType=%s, y.dataType=%s, z.dataType=%s, op=%s",
                    this.x.dataType(), this.y.dataType(), this.z.dataType(), getClass().getName());
        }

        // Only compare X with Y when Y exists: dereferencing a missing Y here used to throw an
        // NPE and made the Y-less branch below unreachable outside experimental mode.
        if (!experimentalMode && y() != null) {
            Preconditions.checkArgument(this.x.dataType() == this.y.dataType() || this.y.dataType() == DataType.BOOL,
                    "Op.X must have same data type as Op.Y: X.datatype=%s, Y.datatype=%s",
                    this.x.dataType(), this.y.dataType());
        }

        if (y() == null) {
            // Single-operand case: a floating-point X requires a floating-point Z.
            if (x().isR()) {
                Preconditions.checkArgument(z().isR(),
                        "Op.Z must have floating point type, since one of operands is floating point: x.dataType=%s, z.dataType=%s, op=%s",
                        this.x.dataType(), this.z.dataType(), getClass().getName());
            }
            return true;
        }

        // NOTE(review): op number 1 is exempt from the floating-point output check; the reason
        // is not visible here — confirm against the op definitions.
        if (opNum != 1 && (x().isR() || y().isR())) {
            Preconditions.checkArgument(z().isR(),
                    "Op.Z must have floating point type, since one of operands is floating point: x.dataType=%s, y.dataType=%s, z.dataType=%s, op=%s",
                    this.x.dataType(), this.y.dataType(), this.z.dataType(), getClass().getName());
        }
        return true;
    }

    /** @return {@link Op.Type#BROADCAST} */
    @Override
    public Op.Type getOpType() {
        return Op.Type.BROADCAST;
    }
}
