package org.nd4j.linalg.jcublas.util;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.camel.util.URISupport;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.accum.distances.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.AddOp;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.springframework.beans.factory.xml.BeanDefinitionParserDelegate;

/**
 * Static helpers used when launching CUDA kernels for nd4j {@link Op}s.
 *
 * <p>{@link #getModuleNameFor(Op)} selects the module an op belongs to, and
 * {@link #getOpCode(Op)} returns the op's numeric code inside that module. The two must agree,
 * because op codes are only unique within one module: for example, {@code "mean"} in
 * {@code reduce} and the Manhattan distance in {@code reduce3} both have code 0. Both methods
 * therefore read from the same lookup tables below.
 *
 * <p>{@link #argsAndReference(CudaContext, Object...)} converts kernel arguments into device
 * pointers.
 *
 * <p>This class is not instantiable.
 */
public final class CudaArgs {

    /** Reduction ops handled by the {@code reduce} module; a name's position is its op code. */
    private static final Map<String, Integer> REDUCE_OP_CODES = indexByPosition(
            "mean", "sum", "bias", "max", "min", "norm1", "norm2", "normmax", "prod", "std", "var");

    /** Distance reductions handled by the {@code reduce3} module; position is the op code. */
    private static final Map<String, Integer> REDUCE3_OP_CODES = indexByPosition(
            ManhattanDistance.OP_NAME, EuclideanDistance.OP_NAME, CosineSimilarity.OP_NAME);

    /** Element-wise transforms handled by the {@code transform} module; position is the op code. */
    private static final Map<String, Integer> TRANSFORM_OP_CODES = indexByPosition(
            "abs", "ceil", "cos", "exp", "floor", "log", "neg", "pow", "round", "setrange",
            "sigmoid", "sign", "sin", "softplus", "sqrt", "tanh", "acos", "asin", "atan");

    /**
     * Transforms handled by the {@code pairWiseTransform} module; position is the op code.
     * Every name listed here must also be routed to that module by {@link #getModuleNameFor},
     * otherwise its code would be looked up in the {@code transform} table instead.
     */
    private static final Map<String, Integer> PAIRWISE_TRANSFORM_OP_CODES = indexByPosition(
            AddOp.OP_NAME, "copy", "div", "eq", "gt", "lt", "mul", "rdiv", "rsub", "sub",
            "eps", "gte", "lte", "max", "min", "neq");

    /**
     * Scalar op name prefixes, tested in order with {@link String#startsWith}; the index of the
     * first matching prefix is the op code.
     */
    private static final String[] SCALAR_OP_PREFIXES = {
            AddOp.OP_NAME, "sub", "mul", "div", "rdiv", "rsub", "max", "lessthan",
            "greaterthan", "eq", "lte", "neq", "min", "set"
    };

    /** Broadcast ops handled by the {@code broadcast} module; position is the op code. */
    private static final Map<String, Integer> BROADCAST_OP_CODES = indexByPosition(
            "broadcastadd", "broadcastsub", "broadcastmul", "broadcastdiv", "broadcastrdiv",
            "broadcastrsub", "broadcastcopy");

    /** Index reductions handled by the {@code indexReduce} module; position is the op code. */
    private static final Map<String, Integer> INDEX_REDUCE_OP_CODES = indexByPosition("imax", "imin");

    /**
     * Holds the converted kernel argument array together with the {@link CublasPointer}s
     * created for {@link INDArray} arguments, keyed by the array they were created from.
     */
    public static class ArgsAndReferences {
        private Object[] args;
        private Multimap<INDArray, CublasPointer> arrayToPointer;

        /**
         * @param args           the kernel arguments, with buffers and arrays already replaced by
         *                       device pointers
         * @param arrayToPointer the pointers created for each {@link INDArray} argument
         */
        public ArgsAndReferences(Object[] args, Multimap<INDArray, CublasPointer> arrayToPointer) {
            this.args = args;
            this.arrayToPointer = arrayToPointer;
        }

        public Object[] getArgs() {
            return this.args;
        }

        public Multimap<INDArray, CublasPointer> getArrayToPointer() {
            return this.arrayToPointer;
        }

        public void setArgs(Object[] args) {
            this.args = args;
        }

        public void setArrayToPointer(Multimap<INDArray, CublasPointer> arrayToPointer) {
            this.arrayToPointer = arrayToPointer;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ArgsAndReferences)) {
                return false;
            }
            ArgsAndReferences other = (ArgsAndReferences) obj;
            if (!other.canEqual(this) || !Arrays.deepEquals(getArgs(), other.getArgs())) {
                return false;
            }
            Multimap<INDArray, CublasPointer> mine = getArrayToPointer();
            Multimap<INDArray, CublasPointer> theirs = other.getArrayToPointer();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        /** Symmetry hook for {@link #equals} when subclasses are involved. */
        protected boolean canEqual(Object obj) {
            return obj instanceof ArgsAndReferences;
        }

        @Override
        public int hashCode() {
            // Same hash values as the original generated implementation (prime 59, null -> 43).
            int result = 59 + Arrays.deepHashCode(getArgs());
            Multimap<INDArray, CublasPointer> pointers = getArrayToPointer();
            return result * 59 + (pointers == null ? 43 : pointers.hashCode());
        }

        @Override
        public String toString() {
            return "CudaArgs.ArgsAndReferences(args=" + Arrays.deepToString(getArgs())
                    + ", arrayToPointer=" + getArrayToPointer() + ")";
        }
    }

    private CudaArgs() {
    }

    /**
     * Returns the name of the module that implements {@code op}.
     *
     * <p>Types are checked in this order: {@link Accumulation}, {@link TransformOp},
     * {@link ScalarOp}, {@link BroadcastOp}, {@link IndexAccumulation}.
     * <ul>
     *   <li>Distance accumulations map to {@code "reduce3"}; other accumulations map to
     *       {@code "reduce"}.</li>
     *   <li>Transforms whose code comes from the pairwise table in {@link #getOpCode} map to
     *       {@code "pairWiseTransform"}; other transforms map to {@code "transform"}.</li>
     * </ul>
     *
     * @param op the op to dispatch
     * @return the module name, or {@code null} if {@code op} is none of the supported op types
     */
    public static String getModuleNameFor(Op op) {
        if (op instanceof Accumulation) {
            return REDUCE3_OP_CODES.containsKey(op.opName()) ? "reduce3" : "reduce";
        }
        if (op instanceof TransformOp) {
            // Must stay in sync with getOpCode: any pairwise op code has to run in the pairwise module.
            return PAIRWISE_TRANSFORM_OP_CODES.containsKey(op.opName()) ? "pairWiseTransform" : "transform";
        }
        if (op instanceof ScalarOp) {
            return "scalar";
        }
        if (op instanceof BroadcastOp) {
            return "broadcast";
        }
        if (op instanceof IndexAccumulation) {
            return "indexReduce";
        }
        return null;
    }

    /**
     * Returns the numeric code of {@code op} within the module chosen by
     * {@link #getModuleNameFor(Op)}.
     *
     * <p>Op types are checked in the same order as in {@link #getModuleNameFor(Op)}. Scalar ops
     * are matched by name prefix; all other op types are matched by exact name.
     *
     * @param op the op to encode; must not be {@code null}
     * @return the op code, or {@code -1} if the op type or name is not recognised
     */
    public static int getOpCode(Op op) {
        String opName = op.opName();
        if (op instanceof Accumulation) {
            int code = opCodeIn(REDUCE_OP_CODES, opName);
            return code >= 0 ? code : opCodeIn(REDUCE3_OP_CODES, opName);
        }
        if (op instanceof TransformOp) {
            int code = opCodeIn(TRANSFORM_OP_CODES, opName);
            return code >= 0 ? code : opCodeIn(PAIRWISE_TRANSFORM_OP_CODES, opName);
        }
        if (op instanceof ScalarOp) {
            if (opName != null) {
                for (int code = 0; code < SCALAR_OP_PREFIXES.length; code++) {
                    if (opName.startsWith(SCALAR_OP_PREFIXES[code])) {
                        return code;
                    }
                }
            }
            return -1;
        }
        if (op instanceof BroadcastOp) {
            return opCodeIn(BROADCAST_OP_CODES, opName);
        }
        if (op instanceof IndexAccumulation) {
            return opCodeIn(INDEX_REDUCE_OP_CODES, opName);
        }
        return -1;
    }

    /**
     * Maps a device version to a core count.
     *
     * <p>This appears to return CUDA cores per multiprocessor for compute capability
     * {@code major.minor}; the values match the {@code _ConvertSMVer2Cores} table in the CUDA
     * samples. NOTE(review): architectures with major version 4 or above 5 (Pascal and later)
     * are not listed and return -1. Confirm whether callers treat -1 as "unknown".
     *
     * @param major      compute capability major version
     * @param minor      compute capability minor version (only distinguishes 2.0 from 2.1)
     * @param unusedArg  not read by this implementation; kept for signature compatibility
     * @return the core count, or -1 if {@code major} is not recognised
     */
    public static int convertMPtoCores(int major, int minor, int unusedArg) {
        switch (major) {
            case 1:
                return 8;
            case 2:
                return minor == 1 ? 48 : 32;
            case 3:
                return 192;
            case 5:
                return 128;
            default:
                return -1;
        }
    }

    /**
     * Converts kernel arguments into device pointers.
     *
     * <p>{@link JCudaBuffer} and {@link INDArray} arguments are replaced by the device pointer
     * of a newly created {@link CublasPointer}. The pointer created for each {@code INDArray} is
     * also recorded in the returned multimap under that array. Any other argument is passed
     * through unchanged.
     *
     * @param cudaContext the context used to create the pointers
     * @param kernelArgs  the raw kernel arguments
     * @return the converted arguments plus the array-to-pointer references
     */
    public static ArgsAndReferences argsAndReference(CudaContext cudaContext, Object... kernelArgs) {
        Object[] converted = new Object[kernelArgs.length];
        Multimap<INDArray, CublasPointer> arrayToPointer = ArrayListMultimap.create();
        for (int i = 0; i < kernelArgs.length; i++) {
            Object arg = kernelArgs[i];
            if (arg instanceof JCudaBuffer) {
                // NOTE(review): unlike the INDArray case, this CublasPointer is not recorded
                // anywhere after its device pointer is taken. Confirm that it needs no later
                // copy-back or release.
                converted[i] = new CublasPointer((JCudaBuffer) arg, cudaContext).getDevicePointer();
            } else if (arg instanceof INDArray) {
                INDArray array = (INDArray) arg;
                CublasPointer pointer = new CublasPointer(array, cudaContext);
                converted[i] = pointer.getDevicePointer();
                arrayToPointer.put(array, pointer);
            } else {
                converted[i] = arg;
            }
        }
        return new ArgsAndReferences(converted, arrayToPointer);
    }

    /**
     * Builds an unmodifiable map from each name to its position in {@code names}. If a name
     * appears more than once, its first position is kept, matching first-match if-chain
     * semantics.
     */
    private static Map<String, Integer> indexByPosition(String... names) {
        Map<String, Integer> codes = new HashMap<String, Integer>();
        for (int i = 0; i < names.length; i++) {
            if (!codes.containsKey(names[i])) {
                codes.put(names[i], i);
            }
        }
        return Collections.unmodifiableMap(codes);
    }

    /** Returns the code of {@code opName} in {@code codes}, or -1 if absent (including a null name). */
    private static int opCodeIn(Map<String, Integer> codes, String opName) {
        Integer code = codes.get(opName);
        return code == null ? -1 : code;
    }
}
