package org.nd4j.linalg.jcublas;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.blas.JcublasLapack;
import org.nd4j.linalg.jcublas.blas.JcublasLevel1;
import org.nd4j.linalg.jcublas.blas.JcublasLevel2;
import org.nd4j.linalg.jcublas.blas.JcublasLevel3;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.complex.ComplexDouble;
import org.nd4j.linalg.jcublas.complex.ComplexFloat;
import org.nd4j.linalg.jcublas.complex.JCublasComplexNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/JCublasNDArrayFactory.class */
public class JCublasNDArrayFactory extends BaseNDArrayFactory {
    /** Handle to the device native ops, fetched from {@link NativeOpsHolder} for every instance. */
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private static Logger log = LoggerFactory.getLogger(JCublasNDArrayFactory.class);

    /** Creates a factory through the no-argument {@link BaseNDArrayFactory} constructor. */
    public JCublasNDArrayFactory() {
        super();
    }

    /**
     * Creates a factory for the given data type and boxed ordering character.
     *
     * @param dtype    data type handed to the base factory
     * @param ordering ordering handed to the base factory
     */
    public JCublasNDArrayFactory(DataBuffer.Type dtype, Character ordering) {
        super(dtype, ordering);
    }

    /**
     * Creates a factory for the given data type and primitive ordering character.
     *
     * @param dtype    data type handed to the base factory
     * @param ordering ordering handed to the base factory
     */
    public JCublasNDArrayFactory(DataBuffer.Type dtype, char ordering) {
        super(dtype, ordering);
    }

    /** Installs {@link JcublasLevel1} as this factory's Level 1 BLAS implementation. */
    public void createLevel1() {
        level1 = new JcublasLevel1();
    }

    /** Installs {@link JcublasLevel2} as this factory's Level 2 BLAS implementation. */
    public void createLevel2() {
        level2 = new JcublasLevel2();
    }

    /** Installs {@link JcublasLevel3} as this factory's Level 3 BLAS implementation. */
    public void createLevel3() {
        level3 = new JcublasLevel3();
    }

    /** Installs {@link JcublasLapack} as this factory's LAPACK implementation. */
    public void createLapack() {
        lapack = new JcublasLapack();
    }

    /**
     * Creates a {@link JCublasNDArray} of the given shape backed by {@code data}.
     *
     * @param shape array shape
     * @param data  backing buffer
     * @return the new array
     */
    public INDArray create(int[] shape, DataBuffer data) {
        return new JCublasNDArray(shape, data);
    }

    /**
     * Creates a single-precision complex scalar.
     *
     * @param real      first (real) component
     * @param imaginary second (imaginary) component
     * @return a new {@link ComplexFloat}
     */
    public IComplexFloat createFloat(float real, float imaginary) {
        return new ComplexFloat(real, imaginary);
    }

    /**
     * Creates a double-precision complex scalar.
     *
     * @param real      first (real) component
     * @param imaginary second (imaginary) component
     * @return a new {@link ComplexDouble}
     */
    public IComplexDouble createDouble(double real, double imaginary) {
        return new ComplexDouble(real, imaginary);
    }

    /** Creates a {@link JCublasNDArray} from a two-dimensional {@code double} matrix. */
    public INDArray create(double[][] data) {
        return new JCublasNDArray(data);
    }

    /** Creates a {@link JCublasNDArray} from a two-dimensional {@code double} matrix with the given ordering. */
    public INDArray create(double[][] data, char ordering) {
        return new JCublasNDArray(data, ordering);
    }

    /** Creates a {@link JCublasComplexNDArray} from the given {@link INDArray}. */
    public IComplexNDArray createComplex(INDArray source) {
        return new JCublasComplexNDArray(source);
    }

    /**
     * Creates a complex array of the given shape from complex numbers, using complex strides
     * computed for the global default ordering ({@link Nd4j#order()}).
     */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape) {
        char ordering = Nd4j.order().charValue();
        return new JCublasComplexNDArray(data, shape, Nd4j.getComplexStrides(shape, ordering));
    }

    /** Creates a complex array of the given shape from a list of complex arrays. */
    public IComplexNDArray createComplex(List<IComplexNDArray> arrays, int[] shape) {
        return new JCublasComplexNDArray(arrays, shape);
    }

    /** Wraps {@code data} in a new {@link JCublasNDArray}. */
    public INDArray create(DataBuffer data) {
        return new JCublasNDArray(data);
    }

    /** Wraps {@code data} in a new {@link JCublasComplexNDArray}. */
    public IComplexNDArray createComplex(DataBuffer data) {
        return new JCublasComplexNDArray(data);
    }

    /** Creates a {@code rows x columns} complex matrix over {@code data} with explicit stride and offset. */
    public IComplexNDArray createComplex(DataBuffer data, int rows, int columns, int[] stride, int offset) {
        int[] shape = {rows, columns};
        return new JCublasComplexNDArray(data, shape, stride, offset);
    }

    /** Creates a {@code rows x columns} matrix over {@code data} with explicit stride and offset. */
    public INDArray create(DataBuffer data, int rows, int columns, int[] stride, int offset) {
        int[] shape = {rows, columns};
        return new JCublasNDArray(data, shape, stride, offset);
    }

    /** Creates a complex array over {@code data} with explicit shape, stride and offset. */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape, int[] stride, int offset) {
        return new JCublasComplexNDArray(data, shape, stride, offset);
    }

    /** Creates a complex array over {@code float} data with explicit shape, stride and offset. */
    public IComplexNDArray createComplex(float[] data, int[] shape, int[] stride, int offset) {
        return new JCublasComplexNDArray(data, shape, stride, offset);
    }

    /** Creates a {@link JCublasNDArray} of the given shape and ordering. */
    public INDArray create(int[] shape, char ordering) {
        return new JCublasNDArray(shape, ordering);
    }

    /**
     * Creates an array of the given shape and ordering without initializing its contents.
     * Strides come from {@link Nd4j#getStrides(int[], char)} and the offset is 0; the trailing
     * {@code false} presumably disables initialization (per this method's name) — confirm in
     * {@link JCublasNDArray}.
     */
    public INDArray createUninitialized(int[] shape, char ordering) {
        return new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
    }

    /** Creates an array over {@code data} with explicit shape, stride, offset and ordering. */
    public INDArray create(DataBuffer data, int[] shape, int[] stride, int offset, char ordering) {
        return new JCublasNDArray(data, shape, stride, offset, ordering);
    }

    /** Creates a complex array over {@code data} with explicit shape, stride, offset and ordering. */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape, int[] stride, int offset, char ordering) {
        return new JCublasComplexNDArray(data, shape, stride, offset, ordering);
    }

    /** Creates a complex array from {@code float} data with the given (boxed) ordering. */
    public IComplexNDArray createComplex(float[] data, Character ordering) {
        return new JCublasComplexNDArray(data, ordering);
    }

    /** Creates an array over {@code float} data; {@code ordering} is unboxed and must not be null. */
    public INDArray create(float[] data, int[] shape, int offset, Character ordering) {
        return new JCublasNDArray(data, shape, offset, ordering.charValue());
    }

    /** Creates a {@code rows x columns} matrix over {@code float} data. */
    public INDArray create(float[] data, int rows, int columns, int[] stride, int offset, char ordering) {
        int[] shape = {rows, columns};
        return new JCublasNDArray(data, shape, stride, offset, ordering);
    }

    /** Creates an array over {@code double} data with the given shape and ordering. */
    public INDArray create(double[] data, int[] shape, char ordering) {
        return new JCublasNDArray(data, shape, ordering);
    }

    /** Assembles an array of the given shape and ordering from a list of arrays. */
    public INDArray create(List<INDArray> list, int[] shape, char ordering) {
        return new JCublasNDArray(list, shape, ordering);
    }

    /**
     * Creates an array over {@code double} data with the given shape, starting at {@code offset}.
     *
     * <p>Strides are computed for the global default ordering ({@link Nd4j#order()}), which is also
     * passed as the array's ordering.
     *
     * @param data   backing values
     * @param shape  array shape
     * @param offset offset into {@code data}
     * @return the new array
     */
    public INDArray create(double[] data, int[] shape, int offset) {
        // The offset must be forwarded as an offset. Casting it to char would turn it into an
        // ordering flag (e.g. '\0' for offset 0) and silently drop the offset.
        char ordering = Nd4j.order().charValue();
        return new JCublasNDArray(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering);
    }

    /** Creates an array over {@code double} data with explicit shape, stride, offset and ordering. */
    public INDArray create(double[] data, int[] shape, int[] stride, int offset, char ordering) {
        return new JCublasNDArray(data, shape, stride, offset, ordering);
    }

    /** Creates a complex array from complex numbers with explicit shape, stride and offset. */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int[] stride, int offset) {
        return new JCublasComplexNDArray(data, shape, stride, offset);
    }

    /** Creates a complex array from complex numbers with explicit shape, stride, offset and ordering. */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int[] stride, int offset, char ordering) {
        return new JCublasComplexNDArray(data, shape, stride, offset, ordering);
    }

    /** Creates a complex array from complex numbers with explicit shape, stride and ordering. */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int[] stride, char ordering) {
        return new JCublasComplexNDArray(data, shape, stride, ordering);
    }

    /** Creates a complex array from complex numbers with explicit shape, offset and ordering. */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int offset, char ordering) {
        return new JCublasComplexNDArray(data, shape, offset, ordering);
    }

    /** Creates a complex array from complex numbers with the given shape and ordering. */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, char ordering) {
        return new JCublasComplexNDArray(data, shape, ordering);
    }

    /** Creates an array over {@code float} data with explicit shape, stride and offset. */
    public INDArray create(float[] data, int[] shape, int[] stride, int offset) {
        return new JCublasNDArray(data, shape, stride, offset);
    }

    /**
     * Creates a complex array from {@code double} data. The data goes through
     * {@link ArrayUtil#floatCopyOf} first, presumably a narrowing copy to {@code float}, so double
     * precision is likely lost (confirm).
     */
    public IComplexNDArray createComplex(double[] data, int[] shape, int[] stride, int offset) {
        return new JCublasComplexNDArray(ArrayUtil.floatCopyOf(data), shape, stride, offset);
    }

    /** Creates an array over {@code double} data with explicit shape, stride and offset. */
    public INDArray create(double[] data, int[] shape, int[] stride, int offset) {
        return new JCublasNDArray(data, shape, stride, offset);
    }

    /** Creates an array of the given shape backed by {@code data}. */
    public INDArray create(DataBuffer data, int[] shape) {
        return new JCublasNDArray(data, shape);
    }

    /** Creates a complex array of the given shape backed by {@code data}. */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape) {
        return new JCublasComplexNDArray(data, shape);
    }

    /** Creates a complex array over {@code data} with explicit shape and stride. */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape, int[] stride) {
        return new JCublasComplexNDArray(data, shape, stride);
    }

    /** Creates an array over {@code data} with explicit shape, stride and offset. */
    public INDArray create(DataBuffer data, int[] shape, int[] stride, int offset) {
        return new JCublasNDArray(data, shape, stride, offset);
    }

    /**
     * Assembles an array of the given shape from {@code list}. When this factory's ordering is
     * {@code 'f'}, Fortran strides from {@link ArrayUtil#calcStridesFortran} are passed explicitly.
     * Otherwise the two-argument constructor is used.
     */
    public INDArray create(List<INDArray> list, int[] shape) {
        if (this.order == 'f') {
            return new JCublasNDArray(list, shape, ArrayUtil.calcStridesFortran(shape));
        }
        return new JCublasNDArray(list, shape);
    }

    /**
     * Creates a complex array from {@code double} data, converted through
     * {@link ArrayUtil#floatCopyOf} (presumably a float narrowing copy — confirm).
     */
    public IComplexNDArray createComplex(double[] data, int[] shape, int[] stride, int offset, char ordering) {
        return new JCublasComplexNDArray(ArrayUtil.floatCopyOf(data), shape, stride, offset, ordering);
    }

    /** As above, with strides chosen by the complex-array constructor. */
    public IComplexNDArray createComplex(double[] data, int[] shape, int offset, char ordering) {
        return new JCublasComplexNDArray(ArrayUtil.floatCopyOf(data), shape, offset, ordering);
    }

    /** Creates a complex array over {@code data} with explicit shape, offset and ordering. */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape, int offset, char ordering) {
        return new JCublasComplexNDArray(data, shape, offset, ordering);
    }

    /** Creates a complex array from {@code double} data (via {@link ArrayUtil#floatCopyOf}). */
    public IComplexNDArray createComplex(double[] data, int[] shape, int offset) {
        return new JCublasComplexNDArray(ArrayUtil.floatCopyOf(data), shape, offset);
    }

    /** Creates a complex array over {@code data} with explicit shape and offset. */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape, int offset) {
        return new JCublasComplexNDArray(data, shape, offset);
    }

    /** Creates an array over {@code float} data with explicit shape and offset. */
    public INDArray create(float[] data, int[] shape, int offset) {
        return new JCublasNDArray(data, shape, offset);
    }

    /** Creates a complex array over {@code float} data using complex strides computed for {@code ordering}. */
    public IComplexNDArray createComplex(float[] data, int[] shape, int offset, char ordering) {
        return new JCublasComplexNDArray(data, shape, Nd4j.getComplexStrides(shape, ordering), offset, ordering);
    }

    /** Creates a complex array over {@code float} data with explicit shape and offset. */
    public IComplexNDArray createComplex(float[] data, int[] shape, int offset) {
        return new JCublasComplexNDArray(data, shape, offset);
    }

    /** Creates a complex array over {@code float} data with explicit shape, stride, offset and ordering. */
    public IComplexNDArray createComplex(float[] data, int[] shape, int[] stride, int offset, char ordering) {
        return new JCublasComplexNDArray(data, shape, stride, offset, ordering);
    }

    /** Creates an array from a two-dimensional {@code float} matrix. */
    public INDArray create(float[][] data) {
        return new JCublasNDArray(data);
    }

    /** Creates an array from a two-dimensional {@code float} matrix with the given ordering. */
    public INDArray create(float[][] data, char ordering) {
        return new JCublasNDArray(data, ordering);
    }

    /**
     * Builds a complex vector from interleaved pairs. Element {@code k} is built with
     * {@link Nd4j#createDouble} from {@code data[2k]} and {@code data[2k + 1]}.
     *
     * @param data interleaved component pairs; length must be even
     * @return a complex vector of length {@code data.length / 2}
     * @throws IllegalArgumentException if {@code data.length} is odd
     */
    public IComplexNDArray createComplex(float[] data) {
        if (data.length % 2 != 0) {
            throw new IllegalArgumentException("Complex nd array buffers must have an even number of elements");
        }
        int count = data.length / 2;
        IComplexNDArray result = Nd4j.createComplex(count);
        for (int k = 0; k < count; k++) {
            result.putScalar(k, Nd4j.createDouble(data[2 * k], data[2 * k + 1]));
        }
        return result;
    }

    /** Creates an array over {@code float} data with explicit shape, stride, offset and ordering. */
    public INDArray create(float[] data, int[] shape, int[] stride, int offset, char ordering) {
        return new JCublasNDArray(data, shape, stride, offset, ordering);
    }

    /** Creates an array over {@code data} with explicit shape and offset. */
    public INDArray create(DataBuffer data, int[] shape, int offset) {
        return new JCublasNDArray(data, shape, offset);
    }

    /**
     * Flattens {@code matrices} into a single row vector using this factory's {@code order()}.
     *
     * @param matrices arrays to flatten, in iteration order
     * @return the flattened array
     */
    public INDArray toFlattened(Collection<INDArray> matrices) {
        return toFlattened(order(), matrices);
    }

    /**
     * Flattens every array in {@code collection}, in iteration order, into one newly created
     * {@code [1, totalLength]} array with ordering {@code c}.
     *
     * <p>Each input takes one of two copy paths. If its ordering equals {@code c} and both it and
     * the output have element-wise stride 1, it is copied with a single {@code memcpyAsync}.
     * Otherwise the native {@code flattenDouble}/{@code flattenFloat}/{@code flattenHalf} op is
     * called, selected by the input's data type.
     *
     * @param c          ordering of the output (also forwarded to the native flatten op)
     * @param collection arrays to flatten
     * @return the new flattened array
     */
    public INDArray toFlattened(char c, Collection<INDArray> collection) {
        // Flush any ops batched by a GridExecutioner before the inputs are read.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        // Total element count across all inputs (int arithmetic; may overflow for huge inputs).
        int i = 0;
        Iterator<INDArray> it = collection.iterator();
        while (it.hasNext()) {
            i += it.next().length();
        }
        INDArray create = Nd4j.create(new int[]{1, i}, c);
        // Running element offset into the output where the next input is written.
        int i2 = 0;
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        for (INDArray iNDArray : collection) {
            CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(create, iNDArray);
            if (iNDArray.ordering() == c && create.elementWiseStride() == iNDArray.elementWiseStride() && create.elementWiseStride() == 1) {
                // Fast path: raw copy from the input's host pointer into the output at byte offset
                // i2 * elementSize. The element size is 8 for DOUBLE, 4 for FLOAT and 2 otherwise,
                // judged from the INPUT's data type; the product is computed in int.
                // NOTE(review): this reads the host-side copy of the input; confirm it is current
                // when the latest writes to the input happened on the device.
                atomicAllocator.memcpyAsync(create.data(), new CudaPointer(atomicAllocator.getHostPointer(iNDArray).address()), AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(iNDArray)), i2 * (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE ? 8 : iNDArray.data().dataType() == DataBuffer.Type.FLOAT ? 4 : 2));
                i2 += iNDArray.length();
            } else {
                // Generic path: native flatten op. Extra pointers, in order: output host shapeInfo,
                // stream, device id, the context's allocation/reduction/scalar/special buffers,
                // then input host shapeInfo and output host shapeInfo.
                PointerPointer pointerPointer = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(create.shapeInfoDataBuffer()), prepareAction.getOldStream(), atomicAllocator.getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), AddressRetriever.retrieveHostPointer(iNDArray.shapeInfoDataBuffer()), AddressRetriever.retrieveHostPointer(create.shapeInfoDataBuffer())});
                if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
                    this.nativeOps.flattenDouble(pointerPointer, i2, c, atomicAllocator.getPointer(create, prepareAction), atomicAllocator.getPointer(create.shapeInfoDataBuffer(), prepareAction), atomicAllocator.getPointer(iNDArray, prepareAction), atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction));
                } else if (iNDArray.data().dataType() == DataBuffer.Type.FLOAT) {
                    this.nativeOps.flattenFloat(pointerPointer, i2, c, atomicAllocator.getPointer(create, prepareAction), atomicAllocator.getPointer(create.shapeInfoDataBuffer(), prepareAction), atomicAllocator.getPointer(iNDArray, prepareAction), atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction));
                } else {
                    this.nativeOps.flattenHalf(pointerPointer, i2, c, atomicAllocator.getPointer(create, prepareAction), atomicAllocator.getPointer(create.shapeInfoDataBuffer(), prepareAction), atomicAllocator.getPointer(iNDArray, prepareAction), atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction));
                }
                i2 += iNDArray.length();
            }
            // create is assigned unconditionally above, so this null check appears redundant.
            if (create != null) {
                atomicAllocator.registerAction(prepareAction, create, iNDArray);
            }
        }
        return create;
    }

    /**
     * Concatenates {@code toConcat} along {@code dimension} using the native CUDA concat op.
     *
     * <p>A single input is returned as-is, not copied. Compressed inputs are first decompressed via
     * {@code decompressi}. The result is a new, uninitialized array in {@link Nd4j#order()} ordering.
     * Its size along {@code dimension} is the sum of the inputs' sizes; every other dimension must
     * match the first input.
     *
     * @param dimension dimension to concatenate along
     * @param toConcat  arrays to concatenate
     * @return the concatenated array, or {@code toConcat[0]} when exactly one array is given
     * @throws IllegalArgumentException if no arrays are given, or if an input's rank or any
     *     non-concatenated dimension differs from the first input's
     */
    public INDArray concat(int dimension, INDArray... toConcat) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        if (toConcat.length == 0) {
            throw new IllegalArgumentException("No arrays to concatenate");
        }
        if (toConcat.length == 1) {
            return toConcat[0];
        }

        int sumAlongDim = 0;
        for (INDArray array : toConcat) {
            if (array.isCompressed()) {
                Nd4j.getCompressor().decompressi(array);
            }
            sumAlongDim += array.size(dimension);
        }
        int[] outputShape = ArrayUtil.copy(toConcat[0].shape());
        outputShape[dimension] = sumAlongDim;

        // Validate every input before allocating the output or preparing a CUDA context, so a
        // rejected call does no device-side work. A rank mismatch previously either slipped
        // through (smaller rank) or failed with an index error (larger rank).
        for (int k = 0; k < toConcat.length; k++) {
            if (toConcat[k].rank() != outputShape.length) {
                throw new IllegalArgumentException("Illegal concatenation at array " + k + ": expected rank "
                        + outputShape.length + " but got " + toConcat[k].rank());
            }
            for (int d = 0; d < outputShape.length; d++) {
                if (d != dimension && toConcat[k].size(d) != outputShape[d]) {
                    throw new IllegalArgumentException("Illegal concatneation at array " + k + " and shape element " + d);
                }
            }
        }

        INDArray result = Nd4j.createUninitialized(outputShape, Nd4j.order().charValue());
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(result, toConcat);

        // Per-input addresses gathered on the host, then copied to device-side tables below.
        int count = toConcat.length;
        long[] shapeInfoAddresses = new long[count];
        long[] dataAddresses = new long[count];
        long[] tadShapeAddresses = new long[count];
        long[] tadOffsetAddresses = new long[count];
        long[] hostShapeInfoAddresses = new long[count];
        TADManager tadManager = CudaExecutioner.getTadManager();
        for (int k = 0; k < count; k++) {
            INDArray array = toConcat[k];
            shapeInfoAddresses[k] = AddressRetriever.retrieveDeviceAddress(array.shapeInfoDataBuffer(), context);
            dataAddresses[k] = allocator.getPointer(array, context).address();
            hostShapeInfoAddresses[k] = allocator.getHostPointer(array.shapeInfoDataBuffer()).address();
            // TAD shape info (first) and TAD offsets (second) along the concat dimension.
            Pair tad = tadManager.getTADOnlyShapeInfo(array, new int[]{dimension});
            tadShapeAddresses[k] = allocator.getPointer((DataBuffer) tad.getFirst(), context).address();
            tadOffsetAddresses[k] = allocator.getPointer((DataBuffer) tad.getSecond(), context).address();
        }

        // Declared as plain Pointer and cast per data type at the call site. A variable typed
        // DoublePointer cannot be cast to FloatPointer/ShortPointer (inconvertible class types).
        Pointer zData = allocator.getPointer(result, context);
        Pointer zShapeInfo = AddressRetriever.retrieveDevicePointer(result.shapeInfoDataBuffer(), context);

        // Device tables holding the 8-byte addresses collected above.
        CudaDoubleDataBuffer dataTable = new CudaDoubleDataBuffer(count);
        CudaDoubleDataBuffer shapeTable = new CudaDoubleDataBuffer(count);
        CudaDoubleDataBuffer tadShapeTable = new CudaDoubleDataBuffer(count);
        CudaDoubleDataBuffer tadOffsetTable = new CudaDoubleDataBuffer(count);
        allocator.memcpyBlocking(dataTable, new LongPointer(dataAddresses), dataAddresses.length * 8, 0L);
        allocator.memcpyBlocking(shapeTable, new LongPointer(shapeInfoAddresses), shapeInfoAddresses.length * 8, 0L);
        allocator.memcpyBlocking(tadShapeTable, new LongPointer(tadShapeAddresses), tadShapeAddresses.length * 8, 0L);
        allocator.memcpyBlocking(tadOffsetTable, new LongPointer(tadOffsetAddresses), tadOffsetAddresses.length * 8, 0L);
        Pointer dataTablePtr = allocator.getPointer(dataTable, context);
        Pointer shapeTablePtr = allocator.getPointer(shapeTable, context);
        Pointer tadShapeTablePtr = allocator.getPointer(tadShapeTable, context);
        Pointer tadOffsetTablePtr = allocator.getPointer(tadOffsetTable, context);

        PointerPointer extras = new PointerPointer(new Pointer[]{
                AddressRetriever.retrieveHostPointer(result.shapeInfoDataBuffer()),
                context.getOldStream(),
                allocator.getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                AddressRetriever.retrieveHostPointer(toConcat[0].shapeInfoDataBuffer()),
                AddressRetriever.retrieveHostPointer(result.shapeInfoDataBuffer()),
                new LongPointer(hostShapeInfoAddresses)});

        DataBuffer.Type dataType = result.data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            this.nativeOps.concatDouble(extras, dimension, count, new PointerPointer(new Pointer[]{dataTablePtr}), new PointerPointer(new Pointer[]{shapeTablePtr}), (DoublePointer) zData, (IntPointer) zShapeInfo, new PointerPointer(new Pointer[]{tadShapeTablePtr}), new PointerPointer(new Pointer[]{tadOffsetTablePtr}));
        } else if (dataType == DataBuffer.Type.FLOAT) {
            this.nativeOps.concatFloat(extras, dimension, count, new PointerPointer(new Pointer[]{dataTablePtr}), new PointerPointer(new Pointer[]{shapeTablePtr}), (FloatPointer) zData, (IntPointer) zShapeInfo, new PointerPointer(new Pointer[]{tadShapeTablePtr}), new PointerPointer(new Pointer[]{tadOffsetTablePtr}));
        } else {
            this.nativeOps.concatHalf(extras, dimension, count, new PointerPointer(new Pointer[]{dataTablePtr}), new PointerPointer(new Pointer[]{shapeTablePtr}), (ShortPointer) zData, (IntPointer) zShapeInfo, new PointerPointer(new Pointer[]{tadShapeTablePtr}), new PointerPointer(new Pointer[]{tadOffsetTablePtr}));
        }
        allocator.registerAction(context, result, toConcat);
        // Touch the temporary tables after the native call and registration, as average() and
        // shuffle() do, so they stay strongly reachable while the native op may still read them.
        dataTable.address();
        shapeTable.address();
        tadShapeTable.address();
        tadOffsetTable.address();
        return result;
    }

    /**
     * Pulls the tensors at {@code indexes} along {@code sourceDimension} into a new array that uses
     * the global default ordering ({@link Nd4j#order()}).
     */
    public INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) {
        char defaultOrder = Nd4j.order().charValue();
        return pullRows(source, sourceDimension, indexes, defaultOrder);
    }

    /**
     * Gathers the tensors-along-dimension (TADs) of a 2D {@code source} selected by
     * {@code indexes} into a new array with ordering {@code order}, using the native pullRows op.
     *
     * <p>For {@code sourceDimension == 1} the result has shape
     * {@code [indexes.length, source.shape()[1]]}. For {@code sourceDimension == 0} it has shape
     * {@code [source.shape()[0], indexes.length]}.
     *
     * @param source          2D array to read from
     * @param sourceDimension 0 or 1
     * @param indexes         TAD indexes to pull; each must be in {@code [0, source.shape()[1 - sourceDimension])}
     * @param order           ordering of the result
     * @return the new array holding the pulled TADs
     * @throws IllegalStateException         if {@code indexes} is null or empty
     * @throws UnsupportedOperationException if {@code source} is not 2D or {@code sourceDimension} is not 0/1
     * @throws IllegalArgumentException      if any index is out of range
     */
    public INDArray pullRows(INDArray source, int sourceDimension, int[] indexes, char order) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        if (indexes == null || indexes.length < 1) {
            throw new IllegalStateException("Indexes can't be null or zero-length");
        }
        if (source.rank() != 2) {
            throw new UnsupportedOperationException("2D input is expected");
        }
        int[] outputShape;
        if (sourceDimension == 1) {
            outputShape = new int[]{indexes.length, source.shape()[sourceDimension]};
        } else if (sourceDimension == 0) {
            outputShape = new int[]{source.shape()[sourceDimension], indexes.length};
        } else {
            throw new UnsupportedOperationException("2D input is expected");
        }
        // A 2D source has shape[1 - sourceDimension] TADs along sourceDimension. The indexes go
        // straight to the native op, so reject out-of-range values here rather than read past the source.
        int tadCount = source.shape()[1 - sourceDimension];
        for (int index : indexes) {
            if (index < 0 || index >= tadCount) {
                throw new IllegalArgumentException("Index " + index + " is out of range [0, " + tadCount
                        + ") for dimension " + sourceDimension);
            }
        }

        INDArray result = Nd4j.createUninitialized(outputShape, order);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(result, source);

        // Data pointers stay as plain Pointer and are cast per data type at the call site.
        Pointer x = allocator.getPointer(source, context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(source.shapeInfoDataBuffer(), context);
        Pointer z = allocator.getPointer(result, context);
        IntPointer zShapeInfo = (IntPointer) allocator.getPointer(result.shapeInfoDataBuffer(), context);
        PointerPointer extras = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(result.shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer()});

        CudaIntDataBuffer indexBuffer = new CudaIntDataBuffer(indexes.length);
        allocator.memcpyBlocking(indexBuffer, new IntPointer(indexes), indexes.length * 4, 0L);
        IntPointer deviceIndexes = (IntPointer) allocator.getPointer(indexBuffer, context);

        TADManager tadManager = CudaExecutioner.getTadManager();
        Pair sourceTad = tadManager.getTADOnlyShapeInfo(source, new int[]{sourceDimension});
        Pair resultTad = tadManager.getTADOnlyShapeInfo(result, new int[]{sourceDimension});
        IntPointer sourceTadShape = (IntPointer) allocator.getPointer((DataBuffer) sourceTad.getFirst(), context);
        IntPointer resultTadShape = (IntPointer) allocator.getPointer((DataBuffer) resultTad.getFirst(), context);
        IntPointer sourceTadOffsets = (IntPointer) allocator.getPointer((DataBuffer) sourceTad.getSecond(), context);
        IntPointer resultTadOffsets = (IntPointer) allocator.getPointer((DataBuffer) resultTad.getSecond(), context);

        DataBuffer.Type dataType = result.data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            this.nativeOps.pullRowsDouble(extras, (DoublePointer) x, xShapeInfo, (DoublePointer) z, zShapeInfo, indexes.length, deviceIndexes, sourceTadShape, sourceTadOffsets, resultTadShape, resultTadOffsets);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            this.nativeOps.pullRowsFloat(extras, (FloatPointer) x, xShapeInfo, (FloatPointer) z, zShapeInfo, indexes.length, deviceIndexes, sourceTadShape, sourceTadOffsets, resultTadShape, resultTadOffsets);
        } else {
            this.nativeOps.pullRowsHalf(extras, (ShortPointer) x, xShapeInfo, (ShortPointer) z, zShapeInfo, indexes.length, deviceIndexes, sourceTadShape, sourceTadOffsets, resultTadShape, resultTadOffsets);
        }
        allocator.registerAction(context, result, source);
        // Keep the temporary index buffer strongly reachable until after the native call.
        indexBuffer.address();
        return result;
    }

    /**
     * Averages {@code arrays} into {@code target} with the native average op.
     *
     * <p>With a single input, {@code target.assign(arrays[0])} is returned. Every input is marked
     * as device-written ({@code tickDeviceWrite}). The trailing {@code true} passed to the native op
     * presumably asks it to propagate the result back into the inputs — confirm against NativeOps.
     *
     * @param target destination array; its length defines the required length of every input
     * @param arrays arrays to average
     * @return {@code target}
     * @throws RuntimeException if {@code arrays} is null/empty or the lengths differ
     */
    public INDArray average(INDArray target, INDArray[] arrays) {
        if (arrays == null || arrays.length == 0) {
            throw new RuntimeException("Input arrays are missing");
        }
        if (arrays.length == 1) {
            return target.assign(arrays[0]);
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        long length = target.lengthLong();
        // Validate before preparing a CUDA context so a rejected call does no device-side work.
        for (INDArray array : arrays) {
            if (array.lengthLong() != length) {
                throw new RuntimeException("All arrays should have equal length for averaging");
            }
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(target, arrays);
        PointerPointer extras = new PointerPointer(new Pointer[]{null, context.getOldStream(), allocator.getDeviceIdPointer()});
        // Plain Pointer, cast per data type below. A DoublePointer-typed variable cannot be cast
        // to FloatPointer/ShortPointer.
        Pointer z = allocator.getPointer(target, context);

        long[] devicePointers = new long[arrays.length];
        for (int k = 0; k < arrays.length; k++) {
            AllocationPoint point = allocator.getAllocationPoint(arrays[k]);
            devicePointers[k] = point.getPointers().getDevicePointer().address();
            point.tickDeviceWrite();
        }
        // Device-side table of the inputs' 8-byte device addresses.
        CudaDoubleDataBuffer pointerTable = new CudaDoubleDataBuffer(arrays.length);
        allocator.memcpyBlocking(pointerTable, new LongPointer(devicePointers), devicePointers.length * 8, 0L);
        PointerPointer x = new PointerPointer(allocator.getPointer(pointerTable, context));

        DataBuffer.Type dataType = target.data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            this.nativeOps.averageDouble(extras, x, (DoublePointer) z, arrays.length, length, true);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            this.nativeOps.averageFloat(extras, x, (FloatPointer) z, arrays.length, length, true);
        } else {
            this.nativeOps.averageHalf(extras, x, (ShortPointer) z, arrays.length, length, true);
        }
        allocator.getFlowController().registerAction(context, target, arrays);
        // Keep the temporary pointer table strongly reachable until after the native call.
        pointerTable.address();
        return target;
    }

    /** Averages the arrays in {@code arrays} into a freshly allocated array. */
    public INDArray average(Collection<INDArray> arrays) {
        INDArray[] asArray = arrays.toArray(new INDArray[0]);
        return average(asArray);
    }

    /**
     * Averages {@code arrays} into a new uninitialized array that has the shape and ordering of
     * the first input.
     *
     * @throws RuntimeException if {@code arrays} is null or empty
     */
    public INDArray average(INDArray[] arrays) {
        if (arrays == null || arrays.length == 0) {
            throw new RuntimeException("Input arrays are missing");
        }
        INDArray first = arrays[0];
        INDArray target = Nd4j.createUninitialized(first.shape(), first.ordering());
        return average(target, arrays);
    }

    /** Averages the arrays in {@code arrays} into {@code target}. */
    public INDArray average(INDArray target, Collection<INDArray> arrays) {
        INDArray[] asArray = arrays.toArray(new INDArray[0]);
        return average(target, asArray);
    }

    /** Shuffles a single array along {@code dimension} by delegating to the collection overload. */
    public void shuffle(INDArray array, Random rnd, int... dimension) {
        List<INDArray> single = Collections.singletonList(array);
        shuffle(single, rnd, dimension);
    }

    /**
     * Shuffles the arrays in {@code list} in place, along the tensors-along-dimension (TADs)
     * described by {@code list2}. One shuffle map is shared by all arrays.
     *
     * <p>{@code list2} holds either a single dimension set applied to every array, or exactly one
     * per array. The map is built by {@link ArrayUtil#buildInterleavedVector} from {@code random}.
     * Its size is the first array's length divided by the product of that array's sizes along the
     * first dimension set. The same data/shape tables are passed to the native op as both source
     * and destination, which is what makes the shuffle in place.
     *
     * @param list   arrays to shuffle; must be non-empty
     * @param random randomness source for the shuffle map
     * @param list2  dimension sets: one shared entry, or one per array
     * @throws RuntimeException      if {@code list2} or {@code list} is null or empty
     * @throws IllegalStateException if {@code list2} has several entries but not one per array
     */
    public void shuffle(List<INDArray> list, Random random, List<int[]> list2) {
        if (list2 == null || list2.size() == 0) {
            throw new RuntimeException("Dimension can't be null or 0-length");
        }
        if (list == null || list.size() == 0) {
            throw new RuntimeException("No input arrays provided");
        }
        if (list2.size() > 1 && list.size() != list2.size()) {
            throw new IllegalStateException("Number of dimensions do not match number of arrays to shuffle");
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        // Prepare every array; only the context returned for the LAST array is kept and used below.
        CudaContext cudaContext = null;
        for (int i = 0; i < list.size(); i++) {
            cudaContext = atomicAllocator.getFlowController().prepareAction(list.get(i), new INDArray[0]);
        }
        // i2 = product of the first array's sizes along its dimension set (elements per TAD).
        // NOTE(review): the map is sized from the first array only; other arrays' TAD counts are
        // not checked to match.
        int i2 = 1;
        for (int i3 = 0; i3 < list2.get(0).length; i3++) {
            i2 *= list.get(0).shape()[list2.get(0)[i3]];
        }
        CudaIntDataBuffer cudaIntDataBuffer = new CudaIntDataBuffer(ArrayUtil.buildInterleavedVector(random, list.get(0).length() / i2));
        Pointer pointer = atomicAllocator.getPointer(cudaIntDataBuffer, cudaContext);
        PointerPointer pointerPointer = new PointerPointer(new Pointer[]{null, cudaContext.getOldStream(), atomicAllocator.getDeviceIdPointer()});
        // Per-array device addresses: data, shapeInfo, TAD shapeInfo, TAD offsets.
        long[] jArr = new long[list.size()];
        long[] jArr2 = new long[list.size()];
        long[] jArr3 = new long[list.size()];
        long[] jArr4 = new long[list.size()];
        for (int i4 = 0; i4 < list.size(); i4++) {
            INDArray iNDArray = list.get(i4);
            Pointer pointer2 = AtomicAllocator.getInstance().getPointer(iNDArray, cudaContext);
            Pointer pointer3 = AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), cudaContext);
            // Result unused; getTadManager() is invoked statically on the next line.
            Nd4j.getExecutioner();
            Pair tADOnlyShapeInfo = CudaExecutioner.getTadManager().getTADOnlyShapeInfo(iNDArray, list2.size() > 1 ? list2.get(i4) : list2.get(0));
            Pointer pointer4 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), cudaContext);
            Pointer pointer5 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), cudaContext);
            jArr[i4] = pointer2.address();
            jArr2[i4] = pointer3.address();
            jArr3[i4] = pointer4.address();
            jArr4[i4] = pointer5.address();
        }
        // Device-side tables of the 8-byte addresses above (all four arrays have the same length).
        CudaDoubleDataBuffer cudaDoubleDataBuffer = new CudaDoubleDataBuffer(list.size());
        CudaDoubleDataBuffer cudaDoubleDataBuffer2 = new CudaDoubleDataBuffer(list.size());
        CudaDoubleDataBuffer cudaDoubleDataBuffer3 = new CudaDoubleDataBuffer(list.size());
        CudaDoubleDataBuffer cudaDoubleDataBuffer4 = new CudaDoubleDataBuffer(list.size());
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer, new LongPointer(jArr), jArr.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer2, new LongPointer(jArr2), jArr.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer3, new LongPointer(jArr3), jArr.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer4, new LongPointer(jArr4), jArr.length * 8, 0L);
        // The data/shape tables are passed as both source and destination: the shuffle is in place.
        // NOTE(review): the kernel is chosen from the global Nd4j.dataType(), not the arrays' own type.
        if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.shuffleDouble(pointerPointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), list.size(), (IntPointer) pointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer3, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer4, cudaContext)));
        } else if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.shuffleFloat(pointerPointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), list.size(), (IntPointer) pointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer3, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer4, cudaContext)));
        } else {
            this.nativeOps.shuffleHalf(pointerPointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), list.size(), (IntPointer) pointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer3, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer4, cudaContext)));
        }
        for (int i5 = 0; i5 < list.size(); i5++) {
            atomicAllocator.getFlowController().registerAction(cudaContext, list.get(i5), new INDArray[0]);
        }
        // These calls have no visible effect on the result. They presumably keep the temporary
        // buffers strongly reachable until after the native call; confirm before removing.
        cudaIntDataBuffer.address();
        cudaDoubleDataBuffer.dataType();
        cudaDoubleDataBuffer2.dataType();
        cudaDoubleDataBuffer4.dataType();
        cudaDoubleDataBuffer3.dataType();
    }

    /**
     * Shuffles the arrays of {@code collection} by delegating to the list-based {@code shuffle}
     * overload, applying the single dimension set {@code iArr} to every array.
     *
     * @param collection arrays to shuffle; copied into a new list before delegation
     * @param random     random source forwarded unchanged to the list-based overload
     * @param iArr       dimensions used for every array in {@code collection}
     */
    public void shuffle(Collection<INDArray> collection, Random random, int... iArr) {
        // Typed locals keep this call free of raw types and unchecked conversions.
        List<INDArray> arrays = new ArrayList<INDArray>(collection);
        List<int[]> dimensions = Collections.singletonList(iArr);
        shuffle(arrays, random, dimensions);
    }

    /**
     * Converts the backing buffer of {@code iNDArray} from {@code typeEx} to {@code typeEx2} in place.
     * The array receives the converted buffer, and its compressed flag is set according to
     * whether that buffer is a {@link CompressedDataBuffer}.
     *
     * @param typeEx   source element encoding
     * @param iNDArray array whose data is replaced; must not be a view
     * @param typeEx2  target element encoding
     * @return the same {@code iNDArray} instance, now holding the converted buffer
     * @throws UnsupportedOperationException if {@code iNDArray} is a view
     */
    public INDArray convertDataEx(DataBuffer.TypeEx typeEx, INDArray iNDArray, DataBuffer.TypeEx typeEx2) {
        if (iNDArray.isView()) {
            throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. ");
        }

        final DataBuffer converted = convertDataEx(typeEx, iNDArray.data(), typeEx2);
        iNDArray.setData(converted);

        // The compressed flag mirrors the concrete type of the buffer we just installed.
        final boolean nowCompressed = converted instanceof CompressedDataBuffer;
        iNDArray.markAsCompressed(nowCompressed);

        return iNDArray;
    }

    /**
     * Runs the native type conversion from {@code pointer} (encoded as {@code typeEx}) into
     * {@code pointer2} (encoded as {@code typeEx2}).
     *
     * @param typeEx   source element encoding; passed to native code by ordinal
     * @param pointer  source memory
     * @param typeEx2  target element encoding; passed to native code by ordinal
     * @param pointer2 destination memory
     * @param j        element count forwarded to the native call
     */
    public void convertDataEx(DataBuffer.TypeEx typeEx, Pointer pointer, DataBuffer.TypeEx typeEx2, Pointer pointer2, long j) {
        final int sourceTypeId = typeEx.ordinal();
        final int targetTypeId = typeEx2.ordinal();
        // The extra-pointers argument is passed as null for this call.
        this.nativeOps.convertTypes((PointerPointer) null, sourceTypeId, pointer, j, targetTypeId, pointer2);
    }

    /**
     * Converts the contents of {@code dataBuffer} into {@code dataBuffer2} through the
     * pointer-based overload, using the destination buffer's length as the element count.
     *
     * @param typeEx      source element encoding
     * @param dataBuffer  source buffer
     * @param typeEx2     target element encoding
     * @param dataBuffer2 destination buffer; its {@code length()} sizes the conversion
     */
    public void convertDataEx(DataBuffer.TypeEx typeEx, DataBuffer dataBuffer, DataBuffer.TypeEx typeEx2, DataBuffer dataBuffer2) {
        final Pointer sourceAddress = dataBuffer.addressPointer();
        final Pointer targetAddress = dataBuffer2.addressPointer();
        final long elementCount = dataBuffer2.length();
        convertDataEx(typeEx, sourceAddress, typeEx2, targetAddress, elementCount);
    }

    /**
     * Converts {@code dataBuffer} from {@code typeEx} to {@code typeEx2}, returning a freshly
     * allocated destination buffer.
     *
     * <p>Target ordinals below 6 produce a {@link CompressedDataBuffer} sized
     * {@code length * elementSize} bytes. Any other supported target produces a regular buffer from
     * {@code Nd4j.createBuffer}.
     *
     * @param typeEx     source element encoding
     * @param dataBuffer source buffer
     * @param typeEx2    target element encoding
     * @return the newly allocated, converted buffer
     * @throws UnsupportedOperationException if {@code typeEx2} has an ordinal above 7
     */
    public DataBuffer convertDataEx(DataBuffer.TypeEx typeEx, DataBuffer dataBuffer, DataBuffer.TypeEx typeEx2) {
        // Resolved first so an unknown target fails before any flushing or allocation happens.
        final int elementSize = convertTargetElementSize(typeEx2);

        // Drain queued grid ops and sync uncompressed source data to host before the conversion reads it.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueueBlocking();
        }
        if (!(dataBuffer instanceof CompressedDataBuffer)) {
            AtomicAllocator.getInstance().synchronizeHostData(dataBuffer);
        }

        final DataBuffer target;
        if (typeEx2.ordinal() >= 6) {
            // NOTE(review): no element type is passed here, so this presumably uses the global
            // Nd4j data type; confirm it matches typeEx2 (e.g. DOUBLE target with FLOAT default).
            target = Nd4j.createBuffer(dataBuffer.length(), false);
            AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
        } else {
            final BytePointer storage = new BytePointer(dataBuffer.length() * elementSize);
            final CompressionDescriptor descriptor = new CompressionDescriptor(dataBuffer, typeEx2.name());
            descriptor.setCompressionType(CompressionType.LOSSY);
            descriptor.setCompressedLength(dataBuffer.length() * elementSize);
            target = new CompressedDataBuffer(storage, descriptor);
        }

        convertDataEx(typeEx, dataBuffer, typeEx2, target);
        return target;
    }

    /**
     * Maps a conversion target to its per-element byte width by ordinal:
     * 0-2 give 1 byte, 3-5 give 2 bytes, 6 gives 4 bytes and 7 gives 8 bytes.
     *
     * @throws UnsupportedOperationException for any ordinal above 7
     */
    private static int convertTargetElementSize(DataBuffer.TypeEx type) {
        final int ordinal = type.ordinal();
        if (ordinal <= 2) {
            return 1;
        }
        if (ordinal <= 5) {
            return 2;
        }
        switch (ordinal) {
            case 6:
                return 4;
            case 7:
                return 8;
            default:
                throw new UnsupportedOperationException("Unknown target TypeEx: " + type.name());
        }
    }
}
