package org.nd4j.linalg.jcublas.util;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
import java.util.List;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;

/* loaded from: input_file:org/nd4j/linalg/jcublas/util/KernelParamsWrapper.class */
/**
 * Converts an op's kernel arguments into values that can be passed to a CUDA kernel, and manages
 * the device-side state created for that call.
 *
 * <p>Every {@link JCudaBuffer} or {@link INDArray} argument is wrapped in a {@link CublasPointer}
 * bound to a {@link CudaContext} created by this wrapper. That pointer's device pointer takes the
 * argument's place in {@link #kernelParameters}. Any other argument is passed through unchanged.
 *
 * <p>Intended for try-with-resources. When context teardown was requested, {@link #close()}
 * copies registered result arrays back to the host, closes their pointers and destroys the
 * context. Otherwise it hands the result pointers to the context.
 *
 * <p>NOTE(review): pointers that were NOT registered as results are never closed by this class.
 * Confirm they are released elsewhere.
 */
public class KernelParamsWrapper implements AutoCloseable {

    /** CUDA {@code cudaMemcpyKind} value for a device-to-host copy ({@code cudaMemcpyDeviceToHost}). */
    private static final int CUDA_MEMCPY_DEVICE_TO_HOST = 2;

    /** Set once {@link #close()} has run, so its logic executes at most once. */
    private boolean closeInvoked;
    /**
     * When true, close() copies results back to the host, closes result pointers and destroys the
     * context.
     */
    private final boolean closeContext;
    /** Context created for this wrapper; every pointer it creates is bound to it. */
    private final CudaContext context;
    /** Arguments to hand to the kernel; device pointers replace buffer/array arguments. */
    public final Object[] kernelParameters;
    /** Every pointer created for a buffer or array argument, in argument order. */
    final List<CublasPointer> pointersToFree;
    /** Pointers registered as kernel outputs via setResultArray/setResultOp. */
    final List<CublasPointer> resultPointers;
    /** Op whose final result may be set from a scalar device result on close; may be null. */
    private Op resultOp;
    /** Maps each INDArray argument to the pointer(s) created for it. */
    private final Multimap<INDArray, CublasPointer> arrayToPointer;
    /** Number of elements read back by setResultForOp; never changed from 1 in this class. */
    private final int resultLength = 1;

    /**
     * Returns the argument array to pass to the kernel.
     *
     * @return the same array as {@link #kernelParameters}
     */
    public Object[] getKernelParameters() {
        return this.kernelParameters;
    }

    /**
     * Registers an array that was passed to the constructor as a kernel argument as an output of
     * the kernel, so that {@link #close()} treats its pointer as a result.
     *
     * @param resultArray an array previously supplied to the constructor as a kernel parameter
     * @return this wrapper, for chaining
     * @throws RuntimeException if {@code resultArray} was not supplied as a kernel parameter
     */
    public KernelParamsWrapper setResultArray(INDArray resultArray) {
        // Multimap.get never returns null (an unknown key maps to an empty collection), so the
        // membership check must happen before the pointer is dereferenced.
        if (!this.arrayToPointer.containsKey(resultArray)) {
            throw new RuntimeException("Results array must be supplied as a kernel parameter");
        }
        // If the same array was passed more than once, the pointer for its first occurrence is used.
        CublasPointer resultPointer = this.arrayToPointer.get(resultArray).iterator().next();
        resultPointer.setResultPointer(true);
        this.resultPointers.add(resultPointer);
        return this;
    }

    /**
     * Registers an accumulation op together with its result array.
     *
     * <p>When the wrapper is closed with context teardown enabled and the result buffer holds a
     * single element, that element is read back and stored via
     * {@link Accumulation#setFinalResult}.
     *
     * @param accumulation the op whose final result should be populated
     * @param resultArray the op's result array; must have been supplied as a kernel parameter
     * @return this wrapper, for chaining
     * @throws RuntimeException if {@code resultArray} was not supplied as a kernel parameter
     */
    public KernelParamsWrapper setResultOp(Accumulation accumulation, INDArray resultArray) {
        this.resultOp = accumulation;
        setResultArray(resultArray);
        return this;
    }

    /**
     * Creates a wrapper that does not tear down its context on close. This is equivalent to
     * {@code new KernelParamsWrapper(op, false, kernelParams)}.
     *
     * @param op the op being executed (not used by the constructor itself)
     * @param kernelParams raw kernel arguments
     */
    public KernelParamsWrapper(Op op, Object... kernelParams) {
        this(op, false, kernelParams);
    }

    /**
     * Creates a wrapper, allocates its {@link CudaContext} and converts every argument.
     *
     * @param op the op being executed (not used by the constructor itself)
     * @param closeContext whether {@link #close()} should copy results back, close result pointers
     *     and destroy the context; also forwarded to the {@link CudaContext} constructor (its
     *     meaning there is defined by CudaContext)
     * @param kernelParams raw kernel arguments; JCudaBuffer/INDArray values are replaced by device
     *     pointers
     */
    public KernelParamsWrapper(Op op, boolean closeContext, Object... kernelParams) {
        this.closeContext = closeContext;
        this.kernelParameters = new Object[kernelParams.length];
        this.arrayToPointer = ArrayListMultimap.create();
        this.pointersToFree = new ArrayList<>();
        this.resultPointers = new ArrayList<>();
        this.context = new CudaContext(closeContext);
        this.context.initOldStream();
        this.context.initStream();
        for (int i = 0; i < kernelParams.length; i++) {
            this.kernelParameters[i] = toKernelArgument(kernelParams[i]);
        }
    }

    /**
     * Converts one raw argument into the value handed to the kernel.
     *
     * <p>A buffer or array is wrapped in a CublasPointer, which is recorded for cleanup; anything
     * else is returned unchanged.
     */
    private Object toKernelArgument(Object arg) {
        if (arg instanceof JCudaBuffer) {
            CublasPointer pointer = new CublasPointer((JCudaBuffer) arg, this.context);
            Object devicePointer = pointer.getDevicePointer();
            this.pointersToFree.add(pointer);
            return devicePointer;
        }
        if (arg instanceof INDArray) {
            INDArray array = (INDArray) arg;
            CublasPointer pointer = new CublasPointer(array, this.context);
            Object devicePointer = pointer.getDevicePointer();
            this.pointersToFree.add(pointer);
            this.arrayToPointer.put(array, pointer);
            return devicePointer;
        }
        return arg;
    }

    /**
     * Finalizes registered result pointers and, if requested at construction, destroys the
     * context. Subsequent calls are no-ops.
     *
     * @throws Exception if copying a result back, closing a pointer or destroying the context fails
     */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        if (this.closeInvoked) {
            return;
        }
        try {
            for (CublasPointer pointer : this.pointersToFree) {
                if (this.resultPointers.contains(pointer)) {
                    releaseResultPointer(pointer);
                }
            }
        } finally {
            // Mark closed first so a failure below cannot cause a second teardown on retry.
            // Destroy the context even if handling a result pointer failed.
            this.closeInvoked = true;
            if (this.closeContext) {
                this.context.destroy();
            }
        }
    }

    /**
     * Handles one result pointer during close.
     *
     * <p>With teardown enabled, the result is copied back and the pointer is closed. A single
     * element of an accumulation result goes to the op itself. Otherwise the pointer is handed to
     * the context.
     */
    private void releaseResultPointer(CublasPointer pointer) throws Exception {
        if (!this.closeContext) {
            // The context is not torn down here; what it does with the pointer is up to CudaContext.
            this.context.setResultPointer(pointer);
            return;
        }
        if (pointer.getBuffer().length() == 1 && (this.resultOp instanceof Accumulation)) {
            setResultForOp(this.resultOp, pointer);
        } else {
            pointer.copyToHost();
        }
        pointer.close();
    }

    /**
     * Reads the first {@code resultLength} elements of a device result back to the host.
     *
     * <p>When {@code op} is an Accumulation, element 0 is stored as its final result. Only DOUBLE
     * and FLOAT buffers are handled. For any other data type nothing is copied and the final
     * result is left unset.
     */
    private void setResultForOp(Op op, CublasPointer devicePointer) {
        DataBuffer.Type dataType = devicePointer.getBuffer().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            double[] host = new double[this.resultLength];
            // 8 bytes per double. Wait on the old stream before reading the host copy.
            JCuda.cudaMemcpyAsync(Pointer.to(host), devicePointer.getDevicePointer(),
                    this.resultLength * 8, CUDA_MEMCPY_DEVICE_TO_HOST, this.context.getOldStream());
            this.context.syncOldStream();
            if (op instanceof Accumulation) {
                ((Accumulation) op).setFinalResult(Double.valueOf(host[0]));
            }
        } else if (dataType == DataBuffer.Type.FLOAT) {
            float[] host = new float[this.resultLength];
            // 4 bytes per float. Wait on the old stream before reading the host copy.
            JCuda.cudaMemcpyAsync(Pointer.to(host), devicePointer.getDevicePointer(),
                    this.resultLength * 4, CUDA_MEMCPY_DEVICE_TO_HOST, this.context.getOldStream());
            this.context.syncOldStream();
            if (op instanceof Accumulation) {
                ((Accumulation) op).setFinalResult(Float.valueOf(host[0]));
            }
        }
    }

    /**
     * Returns the context created by this wrapper.
     *
     * @return the context all of this wrapper's pointers are bound to
     */
    public CudaContext getContext() {
        return this.context;
    }

    /** Synchronizes the context's stream, then the whole device via {@code cudaDeviceSynchronize}. */
    public void sync() {
        this.context.syncStream();
        JCuda.cudaDeviceSynchronize();
    }
}
