package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.class */
public class CudaExecutioner extends DefaultOpExecutioner {
    /** Native-ops backend, taken from {@code NativeOpsHolder.getInstance().getDeviceNativeOps()} at class load. */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Logger for this class. */
    private static Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Supplies the TAD buffer pairs that the exec/invoke methods request via {@code getTADOnlyShapeInfo}. */
    protected static TADManager tadManager = new DeviceTADManager();
    /**
     * Per-thread {@link PointerPointer} reused to hand auxiliary pointers to native calls.
     * It starts out empty for each thread; the exec/invoke methods create it on first use
     * with a capacity of 32 entries.
     */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();

    /**
     * Exposes the native-ops backend used by this executioner.
     *
     * @return the current value of the static {@code nativeOps} field
     */
    public NativeOps getNativeOps() {
        return CudaExecutioner.nativeOps;
    }

    /**
     * Executes a broadcast op along the given dimensions through the native
     * {@code execBroadcastDouble/Float/Half} call selected by the data type of {@code x}.
     *
     * <p>The data pointers are held as plain {@link Pointer}s and cast to the typed pointer
     * class only inside the branch for the matching data type. Typing them as
     * {@code DoublePointer} up front and casting to {@code FloatPointer}/{@code ShortPointer}
     * later is a cast between unrelated classes and is rejected by the compiler.
     *
     * @param broadcastOp the op to run; its {@code x}, {@code y} and {@code z} are all dereferenced
     * @param iArr        dimensions to broadcast along; sorted in place (the caller's array is modified)
     * @return {@code broadcastOp.z()}
     */
    public INDArray exec(BroadcastOp broadcastOp, int... iArr) {
        checkForCompression(broadcastOp);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Arrays.sort(iArr);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());

        Pointer hostYShapeInfo = broadcastOp.y() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = broadcastOp.z() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.z().shapeInfoDataBuffer());

        Pointer x = allocator.getPointer(broadcastOp.x(), context);
        Pointer y = allocator.getPointer(broadcastOp.y(), context);
        Pointer z = allocator.getPointer(broadcastOp.z(), context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.x().shapeInfoDataBuffer(), context);

        // TAD buffer pair for x: host + allocator pointers for the first buffer, allocator pointer for the second.
        Pair xTad = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), iArr);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) xTad.getFirst());
        Pointer tadFirst = allocator.getPointer((DataBuffer) xTad.getFirst(), context);
        Pointer tadSecond = allocator.getPointer((DataBuffer) xTad.getSecond(), context);
        Pair zTad = tadManager.getTADOnlyShapeInfo(broadcastOp.z(), iArr);

        // The slot order is the contract with the native side; do not reorder.
        PointerPointer extras = this.extraz.get().put(new Pointer[]{
            AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            tadFirst,
            tadSecond,
            allocator.getPointer((DataBuffer) zTad.getFirst(), context),
            allocator.getPointer((DataBuffer) zTad.getSecond(), context)
        });

        IntPointer dimensionPointer = (IntPointer) allocator.getPointer(allocator.getConstantBuffer(iArr), context);
        // Looked up once here instead of once per data-type branch.
        IntPointer yShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.y().shapeInfoDataBuffer(), context);
        IntPointer zShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.z().shapeInfoDataBuffer(), context);

        DataBuffer.Type dataType = broadcastOp.x().data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(extras, broadcastOp.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) y, yShapeInfo, (DoublePointer) z, zShapeInfo, dimensionPointer, iArr.length);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(extras, broadcastOp.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) y, yShapeInfo, (FloatPointer) z, zShapeInfo, dimensionPointer, iArr.length);
        } else {
            // Any remaining data type is routed to the half-precision entry point.
            nativeOps.execBroadcastHalf(extras, broadcastOp.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) y, yShapeInfo, (ShortPointer) z, zShapeInfo, dimensionPointer, iArr.length);
        }

        allocator.registerAction(context, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        return broadcastOp.z();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs an accumulation through the native reduce entry points and returns its {@code z}.
     *
     * <p>Dispatch is three-level:
     * <ol>
     *   <li>data type of {@code x}: DOUBLE, FLOAT, anything else is treated as half precision;</li>
     *   <li>kind of op: {@link Variance} uses {@code execSummaryStats*}, an op with a non-null
     *       {@code y} uses {@code execReduce3*}, everything else uses {@code execReduce*};</li>
     *   <li>result shape: when {@code z} is scalar the {@code *Scalar*} native call is used, its
     *       return value is written into {@code z} at index 0 and stored as the op's final result;
     *       otherwise the native call writes into {@code z} and the action is registered with
     *       the allocator.</li>
     * </ol>
     *
     * <p>Two changes compared with the previous implementation:
     * <ul>
     *   <li>the scalar Variance paths now pass {@code ((Variance) accumulation).isBiasCorrected()}
     *       instead of a hard-coded {@code true}, matching the non-scalar Variance paths;</li>
     *   <li>data pointers are held as {@link Pointer} and cast per branch, because casting a
     *       {@code DoublePointer}-typed local to {@code FloatPointer}/{@code ShortPointer} is a
     *       cast between unrelated classes.</li>
     * </ul>
     *
     * @param accumulation the reduction to run; {@code z} must already be set
     * @param iArr         dimensions to reduce along, passed on as given (not sorted or normalised here)
     * @return {@code accumulation.z()}
     */
    public INDArray naiveExec(Accumulation accumulation, int... iArr) {
        INDArray result = accumulation.z();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(accumulation.z(), accumulation.x(), accumulation.y());

        Pointer hostYShapeInfo = accumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = accumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.z().shapeInfoDataBuffer());

        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(accumulation.x(), iArr);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
        Pointer tadFirst = allocator.getPointer((DataBuffer) tadBuffers.getFirst(), context);
        // The second TAD buffer may be absent; pass null through in that case.
        DataBuffer tadSecondBuffer = (DataBuffer) tadBuffers.getSecond();
        Pointer tadSecond = tadSecondBuffer == null ? null : allocator.getPointer(tadSecondBuffer, context);

        Pointer x = allocator.getPointer(accumulation.x(), context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(accumulation.x().shapeInfoDataBuffer(), context);

        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // The slot order is the contract with the native side; do not reorder.
        PointerPointer extras = this.extraz.get().put(new Pointer[]{
            AddressRetriever.retrieveHostPointer(accumulation.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            tadFirst,
            tadSecond
        });

        Pointer extraArgs = accumulation.extraArgs() != null ? allocator.getPointer(accumulation.extraArgsDataBuff(), context) : null;
        IntPointer dimensionPointer = (IntPointer) allocator.getPointer(allocator.getConstantBuffer(iArr), context);

        DataBuffer.Type dataType = accumulation.x().data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            if (accumulation instanceof Variance) {
                if (result.isScalar()) {
                    allocator.tickHostWrite(result);
                    result.putScalar(0, nativeOps.execSummaryStatsScalarDouble(extras, accumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            ((Variance) accumulation).isBiasCorrected()));
                    accumulation.setFinalResult(Double.valueOf(result.getDouble(0)));
                } else {
                    nativeOps.execSummaryStatsDouble(extras, accumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            dimensionPointer, iArr.length, ((Variance) accumulation).isBiasCorrected());
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (accumulation.y() != null) {
                if (result.isScalar()) {
                    allocator.tickHostWrite(result);
                    result.putScalar(0, nativeOps.execReduce3ScalarDouble(extras, accumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context)));
                    accumulation.setFinalResult(Double.valueOf(result.getDouble(0)));
                } else {
                    nativeOps.execReduce3Double(extras, accumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (DoublePointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            dimensionPointer, iArr.length);
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (result.isScalar()) {
                allocator.tickHostWrite(result);
                result.putScalar(0, nativeOps.execReduceScalarDouble(extras, accumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs));
                accumulation.setFinalResult(Double.valueOf(result.getDouble(0)));
            } else {
                nativeOps.execReduceDouble(extras, accumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                        (DoublePointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        dimensionPointer, iArr.length);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (dataType == DataBuffer.Type.FLOAT) {
            if (accumulation instanceof Variance) {
                if (result.isScalar()) {
                    allocator.tickHostWrite(result);
                    result.putScalar(0, nativeOps.execSummaryStatsScalarFloat(extras, accumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            ((Variance) accumulation).isBiasCorrected()));
                    accumulation.setFinalResult(Float.valueOf(result.getFloat(0)));
                } else {
                    nativeOps.execSummaryStatsFloat(extras, accumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            dimensionPointer, iArr.length, ((Variance) accumulation).isBiasCorrected());
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (accumulation.y() != null) {
                if (result.isScalar()) {
                    allocator.tickHostWrite(result);
                    result.putScalar(0, nativeOps.execReduce3ScalarFloat(extras, accumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context)));
                    accumulation.setFinalResult(Float.valueOf(result.getFloat(0)));
                } else {
                    nativeOps.execReduce3Float(extras, accumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (FloatPointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            dimensionPointer, iArr.length);
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (result.isScalar()) {
                allocator.tickHostWrite(result);
                result.putScalar(0, nativeOps.execReduceScalarFloat(extras, accumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs));
                accumulation.setFinalResult(Float.valueOf(result.getFloat(0)));
            } else {
                nativeOps.execReduceFloat(extras, accumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                        (FloatPointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        dimensionPointer, iArr.length);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (accumulation instanceof Variance) {
            // Remaining data types are routed to the half-precision entry points.
            if (result.isScalar()) {
                allocator.tickHostWrite(result);
                result.putScalar(0, nativeOps.execSummaryStatsScalarHalf(extras, accumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        ((Variance) accumulation).isBiasCorrected()));
                accumulation.setFinalResult(Float.valueOf(result.getFloat(0)));
            } else {
                nativeOps.execSummaryStatsHalf(extras, accumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        dimensionPointer, iArr.length, ((Variance) accumulation).isBiasCorrected());
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (accumulation.y() != null) {
            if (result.isScalar()) {
                allocator.tickHostWrite(result);
                result.putScalar(0, nativeOps.execReduce3ScalarHalf(extras, accumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context)));
                accumulation.setFinalResult(Float.valueOf(result.getFloat(0)));
            } else {
                nativeOps.execReduce3Half(extras, accumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                        (ShortPointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        dimensionPointer, iArr.length);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (result.isScalar()) {
            allocator.tickHostWrite(result);
            result.putScalar(0, nativeOps.execReduceScalarHalf(extras, accumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs));
            accumulation.setFinalResult(Float.valueOf(result.getFloat(0)));
        } else {
            nativeOps.execReduceHalf(extras, accumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                    (ShortPointer) allocator.getPointer(accumulation.z(), context),
                    (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                    dimensionPointer, iArr.length);
            allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
        }
        return accumulation.z();
    }

    /**
     * Reduces {@code accumulation.x()} along the given dimensions.
     *
     * <p>Steps: normalise negative dimensions by adding the rank of {@code x}, sort them,
     * collapse "every dimension" into the whole-array marker {@code Integer.MAX_VALUE}, derive
     * the result shape (padded to two dimensions), allocate {@code z} filled with the op's zero
     * value, then delegate to {@link #naiveExec(Accumulation, int...)}.
     *
     * <p>Negative dimensions are resolved <em>before</em> sorting. Sorting first left inputs
     * such as {@code {-1, 0}} as {@code {1, 0}} after normalisation, i.e. unsorted, and nothing
     * downstream re-sorted them.
     *
     * @param accumulation the reduction to run
     * @param iArr         dimensions to reduce along; may contain negative (from-the-end) indices.
     *                     The array is normalised and sorted in place.
     * @return {@code accumulation.noOp()} when {@code x} is a vector whose length equals the
     *         product of the result shape, otherwise the freshly allocated and computed {@code z}
     */
    public INDArray exec(Accumulation accumulation, int... iArr) {
        checkForCompression(accumulation);

        // Normalise first, then sort, so the array handed on is sorted by its final values.
        int rank = accumulation.x().rank();
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] < 0) {
                iArr[i] += rank;
            }
        }
        Arrays.sort(iArr);

        // Reducing over every dimension is the same as reducing the whole array.
        if (iArr.length == rank) {
            iArr = new int[]{Integer.MAX_VALUE};
        }

        int[] resultShape = Shape.wholeArrayDimension(iArr) ? new int[]{1, 1} : ArrayUtil.removeIndex(accumulation.x().shape(), iArr);
        if (resultShape.length == 1) {
            // Pad to 2-D: a row vector when dimension 0 was reduced, a column vector otherwise.
            resultShape = iArr[0] == 0 ? new int[]{1, resultShape[0]} : new int[]{resultShape[0], 1};
        } else if (resultShape.length == 0) {
            resultShape = new int[]{1, 1};
        }

        if (accumulation.x().isVector() && accumulation.x().length() == ArrayUtil.prod(resultShape)) {
            return accumulation.noOp();
        }

        // Pre-fill z with the op's zero value unless its magnitude is below the threshold,
        // in which case a plain zero array is used.
        double zero = accumulation.zeroDouble();
        accumulation.setZ(Math.abs(zero) >= 0.009999999776482582d ? Nd4j.valueArrayOf(resultShape, zero) : Nd4j.zeros(resultShape));

        naiveExec(accumulation, iArr);
        return accumulation.z();
    }

    /**
     * Runs an index reduction over {@code indexAccumulation.x()} along the given dimensions
     * through the native {@code execIndexReduceDouble/Float/Half} call selected by the data
     * type of {@code x}.
     *
     * <p>Changes compared with the previous implementation:
     * <ul>
     *   <li>negative dimensions are resolved before sorting, so the dimension array handed to
     *       the TAD manager and the native call is sorted by its final values;</li>
     *   <li>a second, always-redundant "all dimensions" check after allocating {@code z} was removed;</li>
     *   <li>data pointers are held as {@link Pointer} and cast per branch, because casting a
     *       {@code DoublePointer}-typed local to {@code FloatPointer}/{@code ShortPointer} is a
     *       cast between unrelated classes.</li>
     * </ul>
     *
     * @param indexAccumulation the index reduction to run
     * @param iArr              dimensions to reduce along; may contain negative indices.
     *                          The array is normalised and sorted in place.
     * @return {@code indexAccumulation.x()} when {@code x} is a vector whose length equals the
     *         product of the result shape, otherwise the freshly allocated and computed {@code z}
     */
    public INDArray exec(IndexAccumulation indexAccumulation, int... iArr) {
        checkForCompression(indexAccumulation);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        // Normalise first, then sort.
        int rank = indexAccumulation.x().rank();
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] < 0) {
                iArr[i] += rank;
            }
        }
        Arrays.sort(iArr);

        // Reducing over every dimension is the same as reducing the whole array.
        if (iArr.length == rank) {
            iArr = new int[]{Integer.MAX_VALUE};
        }

        int[] resultShape = Shape.wholeArrayDimension(iArr) ? new int[]{1, 1} : ArrayUtil.removeIndex(indexAccumulation.x().shape(), iArr);
        if (indexAccumulation.x().isVector() && indexAccumulation.x().length() == ArrayUtil.prod(resultShape)) {
            return indexAccumulation.x();
        }
        if (resultShape.length == 1) {
            // Pad to 2-D: a row vector when dimension 0 was reduced, a column vector otherwise.
            resultShape = iArr[0] == 0 ? new int[]{1, resultShape[0]} : new int[]{resultShape[0], 1};
        } else if (resultShape.length == 0) {
            resultShape = new int[]{1, 1};
        }

        double zero = indexAccumulation.zeroDouble();
        indexAccumulation.setZ(Math.abs(zero) >= 0.009999999776482582d ? Nd4j.valueArrayOf(resultShape, zero) : Nd4j.zeros(resultShape));

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());

        Pointer hostYShapeInfo = indexAccumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = indexAccumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.z().shapeInfoDataBuffer());

        Pointer x = allocator.getPointer(indexAccumulation.x(), context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(indexAccumulation.x().shapeInfoDataBuffer(), context);
        Pointer z = allocator.getPointer(indexAccumulation.z(), context);
        IntPointer zShapeInfo = (IntPointer) allocator.getPointer(indexAccumulation.z().shapeInfoDataBuffer(), context);

        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), iArr);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
        Pointer tadFirst = allocator.getPointer((DataBuffer) tadBuffers.getFirst(), context);
        // The second TAD buffer may be absent; pass null through in that case.
        DataBuffer tadSecondBuffer = (DataBuffer) tadBuffers.getSecond();

        // The slot order is the contract with the native side; do not reorder.
        PointerPointer extras = this.extraz.get().put(new Pointer[]{
            AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            tadFirst,
            tadSecondBuffer == null ? null : allocator.getPointer(tadSecondBuffer, context)
        });

        Pointer extraArgs = indexAccumulation.extraArgs() != null ? allocator.getPointer(indexAccumulation.extraArgsDataBuff(), context) : null;
        IntPointer dimensionPointer = (IntPointer) allocator.getPointer(allocator.getConstantBuffer(iArr), context);

        DataBuffer.Type dataType = indexAccumulation.x().data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            nativeOps.execIndexReduceDouble(extras, indexAccumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs, (DoublePointer) z, zShapeInfo, dimensionPointer, iArr.length);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            nativeOps.execIndexReduceFloat(extras, indexAccumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs, (FloatPointer) z, zShapeInfo, dimensionPointer, iArr.length);
        } else {
            // Any remaining data type is routed to the half-precision entry point.
            nativeOps.execIndexReduceHalf(extras, indexAccumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs, (ShortPointer) z, zShapeInfo, dimensionPointer, iArr.length);
        }

        allocator.registerAction(context, indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        return indexAccumulation.z();
    }

    /**
     * Runs a generic op along the given dimensions by delegating to the superclass
     * after the compression check and an in-place sort of the dimension array.
     *
     * @param op   the op to run
     * @param iArr dimensions to operate along; sorted in place (the caller's array is modified)
     * @return whatever the superclass {@code exec(op, dimensions)} returns
     */
    public Op exec(Op op, int... iArr) {
        checkForCompression(op);
        Arrays.sort(iArr);
        final Op executed = super.exec(op, iArr);
        return executed;
    }

    /**
     * Executes an op, choosing between the superclass implementation and the native
     * {@code invoke(...)} overloads of this class.
     *
     * <p>The superclass path is taken when {@code x} is an {@link IComplexNDArray}, when the
     * execution mode is {@code JAVA}, or when the op is a {@link CopyOp}. On that path {@code x}
     * and {@code y} are synchronised to the host first, and {@code z} is ticked as host-written
     * afterwards.
     *
     * <p>This path previously returned {@code null}, while the native path returned {@code op};
     * callers that chain on the result would hit a {@code NullPointerException}. Both paths now
     * return {@code op}.
     *
     * <p>On the native path an op that matches none of the handled op interfaces is returned
     * without being executed.
     *
     * @param op the op to run
     * @return {@code op}
     */
    public Op exec(Op op) {
        checkForCompression(op);

        boolean useSuperclassPath = (op.x() instanceof IComplexNDArray)
                || executionMode() == OpExecutioner.ExecutionMode.JAVA
                || (op instanceof CopyOp);
        if (useSuperclassPath) {
            AtomicAllocator allocator = AtomicAllocator.getInstance();
            // Only the inputs are synchronised before the call; z is ticked afterwards.
            if (op.x() != null) {
                allocator.synchronizeHostData(op.x());
            }
            if (op.y() != null) {
                allocator.synchronizeHostData(op.y());
            }
            super.exec(op);
            if (op.z() != null) {
                allocator.tickHostWrite(op.z());
            }
            return op;
        }

        if (op instanceof TransformOp) {
            invoke((TransformOp) op);
        } else if (op instanceof Accumulation) {
            invoke((Accumulation) op, (int[]) null);
        } else if (op instanceof ScalarOp) {
            invoke((ScalarOp) op);
        } else if (op instanceof BroadcastOp) {
            invoke((BroadcastOp) op);
        } else if (op instanceof IndexAccumulation) {
            invoke((IndexAccumulation) op, (int[]) null);
        }
        return op;
    }

    /**
     * Runs a transform op via {@code invoke(TransformOp)} and hands back its result array.
     *
     * @param transformOp the transform to run
     * @return {@code transformOp.z()} as read after the invocation
     */
    public INDArray execAndReturn(TransformOp transformOp) {
        checkForCompression(transformOp);
        invoke(transformOp);
        final INDArray output = transformOp.z();
        return output;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a broadcast op along the dimensions reported by {@code broadcastOp.getDimension()}
     * through the native {@code execBroadcastDouble/Float/Half} call selected by the data type
     * of {@code x}.
     *
     * <p>Data pointers are held as {@link Pointer} and cast per branch, because casting a
     * {@code DoublePointer}-typed local to {@code FloatPointer}/{@code ShortPointer} is a cast
     * between unrelated classes. The op's dimension array is read once and reused.
     *
     * @param broadcastOp the op to run; its {@code x}, {@code y} and {@code z} are all dereferenced
     * @return always {@code null}; the context used for the call is not handed back
     */
    public CudaContext invoke(BroadcastOp broadcastOp) {
        checkForCompression(broadcastOp);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());

        Pointer x = allocator.getPointer(broadcastOp.x(), context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.x().shapeInfoDataBuffer(), context);

        Pointer hostYShapeInfo = broadcastOp.y() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = broadcastOp.z() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.z().shapeInfoDataBuffer());

        int[] dimension = broadcastOp.getDimension();
        Pair xTad = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) xTad.getFirst());
        Pointer tadFirst = allocator.getPointer((DataBuffer) xTad.getFirst(), context);
        Pointer tadSecond = allocator.getPointer((DataBuffer) xTad.getSecond(), context);
        Pair zTad = tadManager.getTADOnlyShapeInfo(broadcastOp.z(), dimension);

        // The slot order is the contract with the native side; do not reorder.
        PointerPointer extras = this.extraz.get().put(new Pointer[]{
            AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            tadFirst,
            tadSecond,
            allocator.getPointer((DataBuffer) zTad.getFirst(), context),
            allocator.getPointer((DataBuffer) zTad.getSecond(), context)
        });

        Pointer y = allocator.getPointer(broadcastOp.y(), context);
        IntPointer yShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.y().shapeInfoDataBuffer(), context);
        Pointer z = allocator.getPointer(broadcastOp.z(), context);
        IntPointer zShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.z().shapeInfoDataBuffer(), context);
        IntPointer dimensionPointer = (IntPointer) allocator.getPointer(allocator.getConstantBuffer(dimension), context);

        DataBuffer.Type dataType = broadcastOp.x().data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(extras, broadcastOp.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) y, yShapeInfo, (DoublePointer) z, zShapeInfo, dimensionPointer, dimension.length);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(extras, broadcastOp.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) y, yShapeInfo, (FloatPointer) z, zShapeInfo, dimensionPointer, dimension.length);
        } else {
            // Any remaining data type is routed to the half-precision entry point.
            nativeOps.execBroadcastHalf(extras, broadcastOp.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) y, yShapeInfo, (ShortPointer) z, zShapeInfo, dimensionPointer, dimension.length);
        }

        allocator.registerAction(context, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes an index reduction either along dimensions or over the whole array.
     *
     * <p>The dimensional native call ({@code execIndexReduce*}) is used when {@code z} is not
     * scalar, {@code iArr} is non-null and its first element is not {@code Integer.MAX_VALUE}.
     * In every other case the scalar native call ({@code execIndexReduceScalar*}) is used and
     * its return value, truncated to {@code int}, is stored as the op's final result.
     *
     * <p>Data pointers are held as {@link Pointer} and cast per branch, because casting a
     * {@code DoublePointer}-typed local to {@code FloatPointer}/{@code ShortPointer} is a cast
     * between unrelated classes.
     *
     * @param indexAccumulation the index reduction to run; {@code z} must already be set
     * @param iArr              dimensions to reduce along, or {@code null} for a whole-array
     *                          reduction; sorted in place on the dimensional path
     * @return always {@code null}; the context used for the call is not handed back
     */
    public CudaContext invoke(IndexAccumulation indexAccumulation, int[] iArr) {
        checkForCompression(indexAccumulation);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());

        Pointer x = allocator.getPointer(indexAccumulation.x(), context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(indexAccumulation.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = indexAccumulation.extraArgs() != null ? allocator.getPointer(indexAccumulation.extraArgsDataBuff(), context) : null;

        Pointer hostYShapeInfo = indexAccumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = indexAccumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.z().shapeInfoDataBuffer());

        // The TAD lookup needs some dimension even for a whole-array call; dimension 0 stands in.
        int[] tadDimensions = iArr == null ? new int[]{0} : iArr;
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), tadDimensions);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
        Pointer tadFirst = allocator.getPointer((DataBuffer) tadBuffers.getFirst(), context);
        // The second TAD buffer may be absent; pass null through in that case.
        DataBuffer tadSecondBuffer = (DataBuffer) tadBuffers.getSecond();

        // The slot order is the contract with the native side; do not reorder.
        PointerPointer extras = this.extraz.get().put(new Pointer[]{
            AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            tadFirst,
            tadSecondBuffer == null ? null : allocator.getPointer(tadSecondBuffer, context)
        });

        if (!indexAccumulation.z().isScalar() && iArr != null && iArr[0] != Integer.MAX_VALUE) {
            // NOTE(review): the TAD lookup above ran with the caller's original ordering; the
            // sort only takes effect for the constant buffer and native call below. Confirm
            // that the TAD manager is insensitive to dimension order.
            Arrays.sort(iArr);
            Pointer z = allocator.getPointer(indexAccumulation.z(), context);
            IntPointer zShapeInfo = (IntPointer) allocator.getPointer(indexAccumulation.z().shapeInfoDataBuffer(), context);
            IntPointer dimensionPointer = (IntPointer) allocator.getPointer(allocator.getConstantBuffer(iArr), context);
            if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                nativeOps.execIndexReduceDouble(extras, indexAccumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs, (DoublePointer) z, zShapeInfo, dimensionPointer, iArr.length);
            } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
                nativeOps.execIndexReduceFloat(extras, indexAccumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs, (FloatPointer) z, zShapeInfo, dimensionPointer, iArr.length);
            } else {
                // Any remaining data type is routed to the half-precision entry point.
                nativeOps.execIndexReduceHalf(extras, indexAccumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs, (ShortPointer) z, zShapeInfo, dimensionPointer, iArr.length);
            }
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            indexAccumulation.setFinalResult((int) nativeOps.execIndexReduceScalarDouble(extras, indexAccumulation.opNum(), (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs));
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
            indexAccumulation.setFinalResult((int) nativeOps.execIndexReduceScalarFloat(extras, indexAccumulation.opNum(), (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs));
        } else {
            indexAccumulation.setFinalResult((int) nativeOps.execIndexReduceScalarHalf(extras, indexAccumulation.opNum(), (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs));
        }

        allocator.registerAction(context, indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        return null;
    }

    /**
     * Executes an accumulation (reduction) op through {@code nativeOps}.
     *
     * <p>The result array of the op is always replaced here (via {@code setZ}) with a freshly
     * created array of the reduced shape, pre-filled with {@code accumulation.zeroDouble()}.
     * When that result is a scalar, the native call returns the value and it is stored with
     * {@code setFinalResult}; otherwise the native call is handed the device pointer of the
     * new result array.
     *
     * @param accumulation the op to run; {@code x()} is dereferenced unconditionally,
     *                     {@code y()} is optional and selects the {@code execReduce3*} calls
     * @param iArr         dimensions to reduce along; {@code null} is replaced by
     *                     {@code {Integer.MAX_VALUE}}. The array is sorted in place.
     * @return the context obtained from {@code prepareAction}, or {@code null} when {@code x}
     *         is a vector whose length equals the element count of the reduced shape (in that
     *         case nothing is executed)
     */
    protected CudaContext invoke(Accumulation accumulation, int[] iArr) {
        checkForCompression(accumulation);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (iArr == null) {
            iArr = new int[]{Integer.MAX_VALUE};
        }
        Arrays.sort(iArr);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(accumulation.z(), accumulation.x(), accumulation.y());

        // Host-side shape pointers; note the Z one is taken before setZ() below swaps the result array.
        Pointer hostYShapeInfo = accumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = accumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.z().shapeInfoDataBuffer());

        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(accumulation.x(), iArr);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
        Pointer devTadShapeInfo = allocator.getPointer((DataBuffer) tadBuffers.getFirst(), context);
        DataBuffer tadOffsets = (DataBuffer) tadBuffers.getSecond();

        // Slot order of this array is consumed by the native side; keep it unchanged.
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(accumulation.x().shapeInfoDataBuffer()),
                context.getOldStream(),
                allocator.getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                hostTadShapeInfo,
                devTadShapeInfo,
                tadOffsets == null ? null : allocator.getPointer(tadOffsets, context)});

        // Kept as base Pointer: the concrete pointer class depends on the data type and is
        // selected by the cast at each native call below.
        Pointer x = allocator.getPointer(accumulation.x(), context);
        Pointer xShapeInfo = allocator.getPointer(accumulation.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = accumulation.extraArgs() != null ? allocator.getPointer(accumulation.extraArgsDataBuff(), context) : null;

        // Shape of the result: the reduced dimensions are removed and the remainder padded to rank 2.
        int[] retShape = Shape.wholeArrayDimension(iArr) ? new int[]{1, 1} : ArrayUtil.removeIndex(accumulation.x().shape(), iArr);
        if (retShape.length == 1) {
            retShape = iArr[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }

        // NOTE(review): this returns after prepareAction() without a matching registerAction() - confirm intended.
        if (accumulation.x().isVector() && accumulation.x().length() == ArrayUtil.prod(retShape)) {
            return null;
        }

        // 0.009999999776482582d is the float literal 0.01f widened to double.
        if (accumulation.zeroDouble() <= -0.009999999776482582d || accumulation.zeroDouble() >= 0.009999999776482582d) {
            accumulation.setZ(Nd4j.valueArrayOf(retShape, accumulation.zeroDouble()));
        } else {
            accumulation.setZ(Nd4j.zeros(retShape));
        }

        DataBuffer.Type dataType = accumulation.x().data().dataType();
        if (!accumulation.z().isScalar()) {
            Pointer result = allocator.getPointer(accumulation.z(), context);
            Pointer resultShapeInfo = allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context);
            Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(iArr), context);

            if (dataType == DataBuffer.Type.DOUBLE) {
                if (accumulation.y() != null) {
                    nativeOps.execReduce3Double(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (DoublePointer) result, (IntPointer) resultShapeInfo,
                            (IntPointer) dimensionPointer, iArr.length);
                } else if (accumulation instanceof Variance) {
                    nativeOps.execSummaryStatsDouble(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) result, (IntPointer) resultShapeInfo,
                            (IntPointer) dimensionPointer, iArr.length,
                            ((Variance) accumulation).isBiasCorrected());
                } else {
                    nativeOps.execReduceDouble(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) result, (IntPointer) resultShapeInfo,
                            (IntPointer) dimensionPointer, iArr.length);
                }
            } else if (dataType == DataBuffer.Type.FLOAT) {
                if (accumulation.y() != null) {
                    nativeOps.execReduce3Float(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (FloatPointer) result, (IntPointer) resultShapeInfo,
                            (IntPointer) dimensionPointer, iArr.length);
                } else if (accumulation instanceof Variance) {
                    nativeOps.execSummaryStatsFloat(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) result, (IntPointer) resultShapeInfo,
                            (IntPointer) dimensionPointer, iArr.length,
                            ((Variance) accumulation).isBiasCorrected());
                } else {
                    nativeOps.execReduceFloat(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) result, (IntPointer) resultShapeInfo,
                            (IntPointer) dimensionPointer, iArr.length);
                }
            } else if (accumulation.y() != null) {
                nativeOps.execReduce3Half(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                        (ShortPointer) result, (IntPointer) resultShapeInfo,
                        (IntPointer) dimensionPointer, iArr.length);
            } else if (accumulation instanceof Variance) {
                nativeOps.execSummaryStatsHalf(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) result, (IntPointer) resultShapeInfo,
                        (IntPointer) dimensionPointer, iArr.length,
                        ((Variance) accumulation).isBiasCorrected());
            } else {
                nativeOps.execReduceHalf(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) result, (IntPointer) resultShapeInfo,
                        (IntPointer) dimensionPointer, iArr.length);
            }
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            // Scalar result: the native call returns the value directly.
            if (accumulation instanceof Variance) {
                // Honour the op's own bias-correction flag, exactly as the per-dimension path does.
                accumulation.setFinalResult(Double.valueOf(nativeOps.execSummaryStatsScalarDouble(extraPointers, accumulation.opNum(),
                        (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs,
                        ((Variance) accumulation).isBiasCorrected())));
            } else if (accumulation.y() != null) {
                accumulation.setFinalResult(Double.valueOf(nativeOps.execReduce3ScalarDouble(extraPointers, accumulation.opNum(),
                        (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs,
                        (DoublePointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context))));
            } else {
                accumulation.setFinalResult(Double.valueOf(nativeOps.execReduceScalarDouble(extraPointers, accumulation.opNum(),
                        (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs)));
            }
        } else if (dataType == DataBuffer.Type.FLOAT) {
            if (accumulation instanceof Variance) {
                accumulation.setFinalResult(Float.valueOf(nativeOps.execSummaryStatsScalarFloat(extraPointers, accumulation.opNum(),
                        (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs,
                        ((Variance) accumulation).isBiasCorrected())));
            } else if (accumulation.y() != null) {
                accumulation.setFinalResult(Float.valueOf(nativeOps.execReduce3ScalarFloat(extraPointers, accumulation.opNum(),
                        (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs,
                        (FloatPointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context))));
            } else {
                accumulation.setFinalResult(Float.valueOf(nativeOps.execReduceScalarFloat(extraPointers, accumulation.opNum(),
                        (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs)));
            }
        } else if (accumulation instanceof Variance) {
            accumulation.setFinalResult(Float.valueOf(nativeOps.execSummaryStatsScalarHalf(extraPointers, accumulation.opNum(),
                    (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs,
                    ((Variance) accumulation).isBiasCorrected())));
        } else if (accumulation.y() != null) {
            accumulation.setFinalResult(Float.valueOf(nativeOps.execReduce3ScalarHalf(extraPointers, accumulation.opNum(),
                    (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs,
                    (ShortPointer) allocator.getPointer(accumulation.y(), context),
                    (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context))));
        } else {
            accumulation.setFinalResult(Float.valueOf(nativeOps.execReduceScalarHalf(extraPointers, accumulation.opNum(),
                    (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs)));
        }

        allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
        return context;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a scalar op (array combined with a single scalar value) through {@code nativeOps}.
     *
     * <p>The native entry point is chosen from the data type of {@code x}: DOUBLE and FLOAT
     * have dedicated calls, every other type goes to the half-precision call. The scalar is
     * passed as {@code double} for DOUBLE data and as {@code float} otherwise.
     *
     * @param scalarOp the op to run; {@code x()} and {@code z()} are dereferenced
     *                 unconditionally, {@code y()} and {@code extraArgs()} may be {@code null}
     * @return always {@code null}
     */
    public CudaContext invoke(ScalarOp scalarOp) {
        checkForCompression(scalarOp);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(scalarOp.z(), scalarOp.x(), scalarOp.y());

        Pointer hostYShapeInfo = scalarOp.y() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = scalarOp.z() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.z().shapeInfoDataBuffer());

        // Kept as base Pointer: the concrete pointer class depends on the data type and is
        // selected by the cast at the native call below.
        Pointer x = allocator.getPointer(scalarOp.x(), context);
        Pointer xShapeInfo = allocator.getPointer(scalarOp.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = scalarOp.extraArgs() != null ? allocator.getPointer(scalarOp.extraArgsDataBuff(), context) : null;
        Pointer z = allocator.getPointer(scalarOp.z(), context);
        Pointer zShapeInfo = allocator.getPointer(scalarOp.z().shapeInfoDataBuffer(), context);

        // Slot order of this array is consumed by the native side; the last two slots are left empty here.
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(scalarOp.x().shapeInfoDataBuffer()),
                context.getOldStream(),
                allocator.getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                null,
                null});

        DataBuffer.Type dataType = scalarOp.x().data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            nativeOps.execScalarDouble(extraPointers, scalarOp.opNum(),
                    (DoublePointer) x, (IntPointer) xShapeInfo,
                    (DoublePointer) z, (IntPointer) zShapeInfo,
                    scalarOp.scalar().doubleValue(), (DoublePointer) extraArgs);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            nativeOps.execScalarFloat(extraPointers, scalarOp.opNum(),
                    (FloatPointer) x, (IntPointer) xShapeInfo,
                    (FloatPointer) z, (IntPointer) zShapeInfo,
                    scalarOp.scalar().floatValue(), (FloatPointer) extraArgs);
        } else {
            nativeOps.execScalarHalf(extraPointers, scalarOp.opNum(),
                    (ShortPointer) x, (IntPointer) xShapeInfo,
                    (ShortPointer) z, (IntPointer) zShapeInfo,
                    scalarOp.scalar().floatValue(), (ShortPointer) extraArgs);
        }

        allocator.registerAction(context, scalarOp.z(), scalarOp.x(), scalarOp.y());
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a transform op (single-input, or pairwise when {@code y()} is set) through
     * {@code nativeOps}.
     *
     * <p>For every data type there are two native variants: one driven by shape-info buffers
     * and one driven by element-wise strides plus an element count. The stride variant is used
     * only when the strides are positive, the op is not flagged {@code isExecSpecial()} and the
     * involved arrays share the same ordering; otherwise the shape-info variant is used.
     *
     * <p>Op numbers 38..41 additionally get TAD shape/offset buffers placed into the extra
     * pointers, and op number 41 with extra args also gets a dimension buffer.
     * NOTE(review): these op numbers presumably are the softmax family and IsMax - confirm
     * against the native op table.
     *
     * @param transformOp the op to run; {@code x()} and {@code z()} are dereferenced
     *                    unconditionally, {@code y()} and {@code extraArgs()} may be {@code null}
     * @return always {@code null}
     */
    public CudaContext invoke(TransformOp transformOp) {
        checkForCompression(transformOp);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = allocator.getFlowController().prepareAction(transformOp.z(), transformOp.x(), transformOp.y());

        // Kept as base Pointer: the concrete pointer class depends on the data type and is
        // selected by the cast at each native call below.
        Pointer x = allocator.getPointer(transformOp.x(), context);
        Pointer xShapeInfo = allocator.getPointer(transformOp.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = transformOp.extraArgs() != null ? allocator.getPointer(transformOp.extraArgsDataBuff(), context) : null;
        Pointer hostYShapeInfo = transformOp.y() == null ? null : AddressRetriever.retrieveHostPointer(transformOp.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = transformOp.z() == null ? null : AddressRetriever.retrieveHostPointer(transformOp.z().shapeInfoDataBuffer());

        Pointer dimensionDevPointer = null;
        Pointer dimensionHostPointer = null;
        int[] dimension = null;
        if (transformOp.opNum() == 41 && transformOp.extraArgs() != null) {
            // The dimension is carried as the second extra argument; negative values count from the end.
            dimension = new int[]{((Integer) transformOp.extraArgs()[1]).intValue()};
            for (int i = 0; i < dimension.length; i++) {
                if (dimension[i] < 0) {
                    dimension[i] += transformOp.x().rank();
                }
            }
            if (dimension.length == transformOp.x().rank()) {
                dimension = new int[]{Integer.MAX_VALUE};
            }

            // Shape left after removing the dimension, padded to rank 2.
            int[] retShape = Shape.wholeArrayDimension(dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex(transformOp.x().shape(), dimension);
            if (retShape.length == 1) {
                retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
            } else if (retShape.length == 0) {
                retShape = new int[]{1, 1};
            }

            // The Y host-shape slot is repurposed here: it now carries the allocator pointer of the reduced shape.
            hostYShapeInfo = allocator.getPointer(Nd4j.zeros(retShape).shapeInfoDataBuffer(), context);
            DataBuffer dimensionBuffer = allocator.getConstantBuffer(dimension);
            dimensionDevPointer = allocator.getPointer(dimensionBuffer, context);
            dimensionHostPointer = allocator.getHostPointer(dimensionBuffer);
        }

        Pointer hostTadShapeInfo = null;
        Pointer devTadShapeInfo = null;
        Pointer hostSecondTadShapeInfo = null;
        Pointer devSecondTadShapeInfo = null;
        Pointer devTadOffsets = null;
        Pointer devSecondTadOffsets = null;
        if (transformOp.opNum() >= 38 && transformOp.opNum() <= 41) {
            if (transformOp.opNum() != 41) {
                // Two TAD descriptions of x: along dimension 0 and along dimension 1.
                Pair firstTad = tadManager.getTADOnlyShapeInfo(transformOp.x(), new int[]{0});
                Pair secondTad = tadManager.getTADOnlyShapeInfo(transformOp.x(), new int[]{1});
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) firstTad.getFirst());
                devTadShapeInfo = allocator.getPointer((DataBuffer) firstTad.getFirst(), context);
                hostSecondTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) secondTad.getFirst());
                devSecondTadShapeInfo = allocator.getPointer((DataBuffer) secondTad.getFirst(), context);
                DataBuffer firstOffsets = (DataBuffer) firstTad.getSecond();
                devTadOffsets = firstOffsets == null ? null : allocator.getPointer(firstOffsets, context);
                DataBuffer secondOffsets = (DataBuffer) secondTad.getSecond();
                devSecondTadOffsets = secondOffsets == null ? null : allocator.getPointer(secondOffsets, context);
            } else {
                // Single TAD description of z along the requested dimension (null when no extra args were given).
                Pair zTad = tadManager.getTADOnlyShapeInfo(transformOp.z(), dimension);
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) zTad.getFirst());
                devTadShapeInfo = allocator.getPointer((DataBuffer) zTad.getFirst(), context);
                DataBuffer zOffsets = (DataBuffer) zTad.getSecond();
                devTadOffsets = zOffsets == null ? null : allocator.getPointer(zOffsets, context);
            }
        }

        Pointer z = allocator.getPointer(transformOp.z(), context);
        Pointer zShapeInfo = allocator.getPointer(transformOp.z().shapeInfoDataBuffer(), context);

        // Slot order of this array is consumed by the native side; keep it unchanged.
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(transformOp.x().shapeInfoDataBuffer()),
                context.getOldStream(),
                allocator.getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                hostTadShapeInfo,
                devTadShapeInfo,
                devTadOffsets,
                hostSecondTadShapeInfo,
                devSecondTadShapeInfo,
                devSecondTadOffsets,
                dimensionDevPointer,
                dimensionHostPointer});

        DataBuffer.Type dataType = transformOp.x().data().dataType();
        if (transformOp.y() != null) {
            Pointer y = allocator.getPointer(transformOp.y(), context);
            Pointer yShapeInfo = allocator.getPointer(transformOp.y().shapeInfoDataBuffer(), context);

            // Only the non-DOUBLE paths additionally require x and y to have equal strides;
            // this asymmetry is carried over unchanged from the previous implementation.
            boolean useShapeInfo = transformOp.x().elementWiseStride() < 1
                    || transformOp.y().elementWiseStride() < 1
                    || (dataType != DataBuffer.Type.DOUBLE && transformOp.x().elementWiseStride() != transformOp.y().elementWiseStride())
                    || transformOp.isExecSpecial()
                    || transformOp.x().ordering() != transformOp.y().ordering()
                    || transformOp.x().ordering() != transformOp.z().ordering();

            if (dataType == DataBuffer.Type.DOUBLE) {
                if (useShapeInfo) {
                    nativeOps.execPairwiseTransformDouble(extraPointers, transformOp.opNum(),
                            (DoublePointer) x, (IntPointer) xShapeInfo,
                            (DoublePointer) y, (IntPointer) yShapeInfo,
                            (DoublePointer) z, (IntPointer) zShapeInfo,
                            (DoublePointer) extraArgs);
                } else {
                    nativeOps.execPairwiseTransformDouble(extraPointers, transformOp.opNum(),
                            (DoublePointer) x, transformOp.x().elementWiseStride(),
                            (DoublePointer) y, transformOp.y().elementWiseStride(),
                            (DoublePointer) z, transformOp.z().elementWiseStride(),
                            (DoublePointer) extraArgs, transformOp.n());
                }
            } else if (dataType == DataBuffer.Type.FLOAT) {
                if (useShapeInfo) {
                    nativeOps.execPairwiseTransformFloat(extraPointers, transformOp.opNum(),
                            (FloatPointer) x, (IntPointer) xShapeInfo,
                            (FloatPointer) y, (IntPointer) yShapeInfo,
                            (FloatPointer) z, (IntPointer) zShapeInfo,
                            (FloatPointer) extraArgs);
                } else {
                    nativeOps.execPairwiseTransformFloat(extraPointers, transformOp.opNum(),
                            (FloatPointer) x, transformOp.x().elementWiseStride(),
                            (FloatPointer) y, transformOp.y().elementWiseStride(),
                            (FloatPointer) z, transformOp.z().elementWiseStride(),
                            (FloatPointer) extraArgs, transformOp.n());
                }
            } else if (useShapeInfo) {
                nativeOps.execPairwiseTransformHalf(extraPointers, transformOp.opNum(),
                        (ShortPointer) x, (IntPointer) xShapeInfo,
                        (ShortPointer) y, (IntPointer) yShapeInfo,
                        (ShortPointer) z, (IntPointer) zShapeInfo,
                        (ShortPointer) extraArgs);
            } else {
                nativeOps.execPairwiseTransformHalf(extraPointers, transformOp.opNum(),
                        (ShortPointer) x, transformOp.x().elementWiseStride(),
                        (ShortPointer) y, transformOp.y().elementWiseStride(),
                        (ShortPointer) z, transformOp.z().elementWiseStride(),
                        (ShortPointer) extraArgs, transformOp.n());
            }
        } else {
            boolean useShapeInfo = transformOp.x().elementWiseStride() < 1
                    || transformOp.isExecSpecial()
                    || transformOp.z().ordering() != transformOp.x().ordering();

            if (dataType == DataBuffer.Type.DOUBLE) {
                if (useShapeInfo) {
                    nativeOps.execTransformDouble(extraPointers, transformOp.opNum(),
                            (DoublePointer) x, (IntPointer) xShapeInfo,
                            (DoublePointer) z, (IntPointer) zShapeInfo,
                            (DoublePointer) extraArgs);
                } else {
                    nativeOps.execTransformDouble(extraPointers, transformOp.opNum(),
                            (DoublePointer) x, transformOp.x().elementWiseStride(),
                            (DoublePointer) z, transformOp.z().elementWiseStride(),
                            (DoublePointer) extraArgs, transformOp.n());
                }
            } else if (dataType == DataBuffer.Type.FLOAT) {
                if (useShapeInfo) {
                    nativeOps.execTransformFloat(extraPointers, transformOp.opNum(),
                            (FloatPointer) x, (IntPointer) xShapeInfo,
                            (FloatPointer) z, (IntPointer) zShapeInfo,
                            (FloatPointer) extraArgs);
                } else {
                    nativeOps.execTransformFloat(extraPointers, transformOp.opNum(),
                            (FloatPointer) x, transformOp.x().elementWiseStride(),
                            (FloatPointer) z, transformOp.z().elementWiseStride(),
                            (FloatPointer) extraArgs, transformOp.n());
                }
            } else if (useShapeInfo) {
                nativeOps.execTransformHalf(extraPointers, transformOp.opNum(),
                        (ShortPointer) x, (IntPointer) xShapeInfo,
                        (ShortPointer) z, (IntPointer) zShapeInfo,
                        (ShortPointer) extraArgs);
            } else {
                nativeOps.execTransformHalf(extraPointers, transformOp.opNum(),
                        (ShortPointer) x, transformOp.x().elementWiseStride(),
                        (ShortPointer) z, transformOp.z().elementWiseStride(),
                        (ShortPointer) extraArgs, transformOp.n());
            }
        }

        allocator.registerAction(context, transformOp.z(), transformOp.x(), transformOp.y());
        return null;
    }

    /**
     * Gives access to the TAD manager shared by all instances of this executioner.
     *
     * @return the class-level {@code tadManager} instance
     */
    public static TADManager getTadManager() {
        return CudaExecutioner.tadManager;
    }
}
