package org.nd4j.linalg.api.ops.impl.reduce;

import java.lang.reflect.Field;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/reduce/Mmul.class */
/**
 * Matrix multiplication custom op ({@code "mmul"}), with optional transposition of either input
 * and of the result as described by an {@link MMulTranspose}.
 *
 * <p>The transpose flags are stored in {@link #mt} and mirrored into the op's integer arguments
 * (1 = transpose, 0 = no transpose) in the order A, B[, result].
 */
public class Mmul extends DynamicCustomOp {
    /**
     * Transpose configuration. May be {@code null} after the no-arg constructor, or after the
     * INDArray constructor when it is given a null configuration. Readers go through
     * {@link #transposeOrDefault()}, which substitutes a default built by
     * {@code MMulTranspose.builder().build()}.
     */
    protected MMulTranspose mt;

    /**
     * Creates a SameDiff mmul of {@code x} and {@code y} with the given transpose configuration.
     *
     * @param sameDiff  owning SameDiff instance
     * @param x         left input
     * @param y         right input
     * @param transpose transpose configuration; must be non-null (its flags are read immediately)
     */
    public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, MMulTranspose transpose) {
        super((String) null, sameDiff, new SDVariable[]{x, y});
        this.mt = transpose;
        addIArgument(ArrayUtil.fromBoolean(transpose.isTransposeA()), ArrayUtil.fromBoolean(transpose.isTransposeB()), ArrayUtil.fromBoolean(transpose.isTransposeResult()));
    }

    /** Creates a SameDiff mmul of {@code x} and {@code y} with no transposition. */
    public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y) {
        this(sameDiff, x, y, MMulTranspose.allFalse());
    }

    /**
     * Creates an eager mmul over concrete arrays.
     *
     * @param x         left input
     * @param y         right input
     * @param z         optional output array; {@code null} lets the op allocate it
     * @param transpose transpose configuration; when {@code null} no integer arguments are added
     *                  and {@link #mt} stays unset until it is first read
     */
    public Mmul(INDArray x, INDArray y, INDArray z, MMulTranspose transpose) {
        super((String) null, new INDArray[]{x, y}, z == null ? null : new INDArray[]{z});
        if (transpose != null) {
            this.mt = transpose;
            addIArgument(ArrayUtil.fromBoolean(transpose.isTransposeA()), ArrayUtil.fromBoolean(transpose.isTransposeB()), ArrayUtil.fromBoolean(transpose.isTransposeResult()));
        }
    }

    /** No-arg constructor, used for reflective instantiation and import. */
    public Mmul() {
    }

    /**
     * Returns {@link #mt}, first initialising it to {@code MMulTranspose.builder().build()} if
     * it is unset.
     */
    private MMulTranspose transposeOrDefault() {
        if (this.mt == null) {
            this.mt = MMulTranspose.builder().build();
        }
        return this.mt;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Object getValue(Field field) {
        return transposeOrDefault().getValue(field);
    }

    /**
     * Returns the transpose configuration as a property map. Uses the default configuration when
     * none was set, instead of throwing NullPointerException.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Object> propertiesForFunction() {
        return transposeOrDefault().toProperties();
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean isConfigProperties() {
        return true;
    }

    /** Name of the field holding the config object; must match the {@link #mt} field name. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String configFieldName() {
        return "mt";
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void setPropertiesForFunction(Map<String, Object> properties) {
        transposeOrDefault().setProperties(properties);
    }

    /**
     * Returns the shape obtained by swapping the two matrix dimensions.
     *
     * @param shape a rank-2 shape, or a rank-3 shape whose first dimension is kept in place
     * @return the reversed shape for rank 2, or {@code [d0, d2, d1]} for rank 3
     * @throws IllegalArgumentException if {@code shape} is neither rank 2 nor rank 3
     */
    public long[] transposeShapeArray(long[] shape) {
        if (shape.length == 2) {
            return ArrayUtil.reverseCopy(shape);
        }
        if (shape.length == 3) {
            return new long[]{shape[0], shape[2], shape[1]};
        }
        throw new IllegalArgumentException("Matrix input has to be of length 2 or 3, got: " + shape.length);
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        return "MatMul";
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String[] tensorflowNames() {
        return new String[]{"MatMul", "BatchMatMul", "BatchMatMulV2"};
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return "mmul";
    }

    /**
     * Reads a boolean TensorFlow attribute.
     *
     * @param attrs    node attributes
     * @param name     preferred attribute name
     * @param fallback alternate attribute name consulted only when {@code name} is absent; may be null
     * @return the attribute's boolean value, or {@code false} when neither attribute is present
     */
    private static boolean boolAttr(Map<String, AttrValue> attrs, String name, String fallback) {
        AttrValue value = attrs.get(name);
        if (value == null && fallback != null) {
            value = attrs.get(fallback);
        }
        return value != null && value.getB();
    }

    /**
     * Initialises transpose flags from a TensorFlow node. {@code MatMul} uses
     * {@code transpose_a}/{@code transpose_b}. The batch variants prefer those names and fall back
     * to {@code adj_x}/{@code adj_y}. Missing attributes are treated as {@code false}.
     * Only the A and B flags are written to the integer arguments.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        super.initFromTensorFlow(nodeDef, sameDiff, attributesForNode, graph);
        boolean plainMatMul = nodeDef.getOp().equalsIgnoreCase("MatMul");
        boolean transposeA = boolAttr(attributesForNode, "transpose_a", plainMatMul ? null : "adj_x");
        boolean transposeB = boolAttr(attributesForNode, "transpose_b", plainMatMul ? null : "adj_y");
        this.mt = MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build();
        this.iArguments.clear();
        addIArgument(ArrayUtil.fromBoolean(this.mt.isTransposeA()), ArrayUtil.fromBoolean(this.mt.isTransposeB()));
    }

    /**
     * Initialises transpose flags from an ONNX node's {@code transA}/{@code transB} integer
     * attributes (values greater than 0 mean transpose; absent means no transpose). The integer
     * arguments are synced the same way as in {@link #initFromTensorFlow}, so that the parsed
     * flags reach the executed op.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
        boolean transposeA = attributesForNode.containsKey("transA") && attributesForNode.get("transA").getI() > 0;
        boolean transposeB = attributesForNode.containsKey("transB") && attributesForNode.get("transB").getI() > 0;
        this.mt = MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build();
        this.iArguments.clear();
        addIArgument(ArrayUtil.fromBoolean(this.mt.isTransposeA()), ArrayUtil.fromBoolean(this.mt.isTransposeB()));
    }

    /**
     * Delegates the backward pass to {@code mmulBp}, passing the transpose configuration. When
     * none was set, the default configuration is passed rather than null.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        return this.sameDiff.f().mmulBp(larg(), rarg(), gradients.get(0), transposeOrDefault());
    }

    /**
     * Maps the {@code transposeA}/{@code transposeB} properties to the ONNX ({@code transA},
     * {@code transB}) and TensorFlow ({@code transpose_a}, {@code transpose_b}) attribute names,
     * for every TensorFlow op name and for the ONNX op name.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, PropertyMapping> properties = new HashMap<>();
        properties.put("transposeA", PropertyMapping.builder().onnxAttrName("transA").tfAttrName("transpose_a").propertyNames(new String[]{"transposeA"}).build());
        properties.put("transposeB", PropertyMapping.builder().onnxAttrName("transB").tfAttrName("transpose_b").propertyNames(new String[]{"transposeB"}).build());

        Map<String, Map<String, PropertyMapping>> result = new HashMap<>();
        for (String tfName : tensorflowNames()) {
            result.put(tfName, properties);
        }
        result.put(onnxName(), properties);
        return result;
    }

    /**
     * Requires exactly two floating-point inputs. The output takes the first input's data type.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 inputs to mmul op, got %s", dataTypes);
        Preconditions.checkState(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Inputs to mmul op must both be a floatingpoint type: got %s", dataTypes);
        return Collections.singletonList(dataTypes.get(0));
    }

    /** Two {@code Mmul} ops are equal when their transpose configurations are equal. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Mmul)) {
            return false;
        }
        Mmul other = (Mmul) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.mt == null ? other.mt == null : this.mt.equals(other.mt);
    }

    /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof Mmul;
    }

    /** Hash derived from the transpose configuration only, consistent with {@link #equals}. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int hashCode() {
        return 59 + (this.mt == null ? 43 : this.mt.hashCode());
    }
}
