package org.bigml.mimir.math.gpu;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;

/* loaded from: input_file:org/bigml/mimir/math/gpu/LocalKernel.class */
/**
 * Per-thread holder for an OpenCL kernel.
 *
 * <p>Each thread that calls {@link #get()} lazily receives its own {@code cl_kernel}, created from
 * {@code _program.getProgram()} for the kernel function named {@code _fn}. Every argument
 * registered through one of the {@code addArg} overloads is bound to that kernel when it is
 * created. Device buffers produced while binding those arguments are remembered per thread and
 * released, together with the thread's kernel, by {@link #cleanUp()}.
 *
 * <p>NOTE(review): {@code kernelArgs} is a plain {@link ArrayList}. Register all arguments before
 * any thread calls {@link #get()}, because adding concurrently with kernel creation is
 * unsynchronized. Arguments added after a thread's kernel exists are not applied to that kernel.
 */
public class LocalKernel extends ThreadLocal<cl_kernel> {
    /** Size in bytes of one element of the float/int arrays and local-space arguments bound here. */
    private static final int ELEMENT_BYTES = 4;

    protected final List<KernelArg> kernelArgs = new ArrayList<KernelArg>();
    protected final ThreadMemory allocatedMemory = new ThreadMemory();
    protected final Program _program;
    protected final String _fn;

    /** Argument backed by a {@code float[]}, bound through {@code Program.setArg}. */
    private static class FloatArg extends KernelArg {
        private final float[] _arg;

        private FloatArg(int i, float[] fArr) {
            super(i);
            this._arg = fArr;
        }

        @Override
        public cl_mem setArgument(Program program, cl_kernel kernel) {
            return program.setArg(kernel, this._index, Pointer.to(this._arg), byteSize(this._arg.length));
        }
    }

    /** Argument backed by an {@code int[]}, bound through {@code Program.setArg}. */
    private static class IntegerArg extends KernelArg {
        private final int[] _arg;

        private IntegerArg(int i, int[] iArr) {
            super(i);
            this._arg = iArr;
        }

        @Override
        public cl_mem setArgument(Program program, cl_kernel kernel) {
            return program.setArg(kernel, this._index, Pointer.to(this._arg), byteSize(this._arg.length));
        }
    }

    /**
     * A kernel argument remembered by index and re-applied to every per-thread kernel.
     */
    public static abstract class KernelArg {
        /** Position of the argument in the kernel's parameter list. */
        protected final int _index;

        private KernelArg(int index) {
            this._index = index;
        }

        /**
         * Binds this argument to {@code kernel}.
         *
         * @param program program used to create any device buffer the argument needs
         * @param kernel kernel to bind the argument to
         * @return the buffer created for the argument, or {@code null} when none was created
         */
        public abstract cl_mem setArgument(Program program, cl_kernel kernel);
    }

    /**
     * Argument that only reserves space: {@code _size} counts 4-byte elements.
     *
     * <p>It is set with a {@code null} value and a byte size, which is how OpenCL sizes a
     * {@code __local} argument.
     */
    private static class SpaceArg extends KernelArg {
        private final long _size;

        private SpaceArg(int i, long j) {
            super(i);
            this._size = j;
        }

        @Override
        public cl_mem setArgument(Program program, cl_kernel kernel) {
            CL.clSetKernelArg(kernel, this._index, ELEMENT_BYTES * this._size, (Pointer) null);
            // No buffer object is created, so there is nothing for cleanUp() to release.
            return null;
        }
    }

    /** Per-thread list of buffers created while binding registered arguments. */
    public static class ThreadMemory extends ThreadLocal<List<cl_mem>> {
        private ThreadMemory() {
        }

        @Override
        public List<cl_mem> initialValue() {
            return new ArrayList<cl_mem>();
        }
    }

    /**
     * @param program program containing the kernel function
     * @param str name of the kernel function to instantiate per thread
     */
    public LocalKernel(Program program, String str) {
        this._program = program;
        this._fn = str;
    }

    /**
     * Creates the calling thread's kernel and binds every registered argument to it.
     *
     * <p>Non-null buffers returned while binding are recorded for {@link #cleanUp()}. If binding
     * fails, the buffers created so far and the kernel are released before the exception
     * propagates, so a failed attempt does not leak.
     */
    @Override
    public cl_kernel initialValue() {
        cl_kernel kernel = CL.clCreateKernel(this._program.getProgram(), this._fn, (int[]) null);
        List<cl_mem> memory = this.allocatedMemory.get();
        try {
            for (KernelArg arg : this.kernelArgs) {
                cl_mem buffer = arg.setArgument(this._program, kernel);
                // SpaceArg yields null; keep only real buffer objects so release never sees null.
                if (buffer != null) {
                    memory.add(buffer);
                }
            }
        } catch (RuntimeException e) {
            releaseAll(memory);
            CL.clReleaseKernel(kernel);
            throw e;
        }
        return kernel;
    }

    /** Registers a {@code float[]} argument at index {@code i} for kernels created from now on. */
    public void addArg(int i, float[] fArr) {
        this.kernelArgs.add(new FloatArg(i, fArr));
    }

    /** Registers an {@code int[]} argument at index {@code i} for kernels created from now on. */
    public void addArg(int i, int[] iArr) {
        this.kernelArgs.add(new IntegerArg(i, iArr));
    }

    /**
     * Registers a space-only argument at index {@code i}, sized as {@code j} 4-byte elements, for
     * kernels created from now on.
     */
    public void addArg(int i, long j) {
        this.kernelArgs.add(new SpaceArg(i, j));
    }

    /**
     * Binds {@code fArr} to argument {@code i} of the calling thread's kernel via
     * {@code Program.setArg}.
     *
     * <p>The returned buffer is NOT tracked by {@link #cleanUp()}; presumably the caller must
     * release it.
     */
    public cl_mem setReadArg(int i, float[] fArr) {
        return this._program.setArg(get(), i, Pointer.to(fArr), byteSize(fArr.length));
    }

    /**
     * Binds {@code iArr} to argument {@code i} of the calling thread's kernel via
     * {@code Program.setArg}.
     *
     * <p>The returned buffer is NOT tracked by {@link #cleanUp()}; presumably the caller must
     * release it.
     */
    public cl_mem setReadArg(int i, int[] iArr) {
        return this._program.setArg(get(), i, Pointer.to(iArr), byteSize(iArr.length));
    }

    /**
     * Binds argument {@code i} of the calling thread's kernel with no host data.
     *
     * <p>A {@code null} host pointer is passed together with the byte size of {@code fArr}. The
     * returned buffer is NOT tracked by {@link #cleanUp()}.
     */
    public cl_mem setWriteArg(int i, float[] fArr) {
        return this._program.setArg(get(), i, null, byteSize(fArr.length));
    }

    /**
     * Releases the calling thread's kernel and the buffers recorded while binding its registered
     * arguments, then resets this thread's state so a later {@link #get()} builds a fresh kernel.
     *
     * <p>If the calling thread never created a kernel, one is created (with its argument buffers)
     * and immediately released. Buffers returned by {@code setReadArg}/{@code setWriteArg} are not
     * released here.
     */
    public void cleanUp() {
        // Obtain the kernel first so the memory list reflects exactly this kernel's allocations,
        // even when this call is what creates it.
        cl_kernel kernel = get();
        releaseAll(this.allocatedMemory.get());
        this.allocatedMemory.remove();
        CL.clReleaseKernel(kernel);
        remove();
    }

    /** Releases every non-null buffer in {@code buffers} and empties the list. */
    private static void releaseAll(List<cl_mem> buffers) {
        for (cl_mem buffer : buffers) {
            if (buffer != null) {
                CL.clReleaseMemObject(buffer);
            }
        }
        buffers.clear();
    }

    /**
     * Returns the byte size of {@code count} 4-byte elements.
     *
     * @throws IllegalArgumentException if the size does not fit in an {@code int}
     */
    private static int byteSize(int count) {
        long bytes = (long) ELEMENT_BYTES * count;
        if (bytes > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Array of " + count + " elements exceeds 2GB buffer size");
        }
        return (int) bytes;
    }
}
