package org.nd4j.autodiff.execution;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.autodiff.execution.GraphExecutioner;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.FlatResult;
import org.nd4j.graph.FlatVariable;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.ResultWrapperAbstraction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/execution/NativeGraphExecutioner.class */
/**
 * {@link GraphExecutioner} that serializes a {@link SameDiff} graph to FlatBuffers, hands the
 * serialized graph to the native backend for execution, and reads the resulting variables back
 * into {@link INDArray}s.
 *
 * <p>Several interface methods are placeholders: {@link #reuseGraph} and {@link #importProto}
 * throw {@link UnsupportedOperationException}, while {@link #executeGraph(int, SDVariable...)}
 * and {@link #registerGraph} return fixed dummy values.
 */
public class NativeGraphExecutioner implements GraphExecutioner {
    private static final Logger log = LoggerFactory.getLogger(NativeGraphExecutioner.class);

    /**
     * Reports the kind of this executioner.
     *
     * @return always {@link GraphExecutioner.Type#LOCAL}
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public GraphExecutioner.Type getExecutionerType() {
        return GraphExecutioner.Type.LOCAL;
    }

    /**
     * Executes the graph with a default configuration: implicit output mode, sequential
     * execution mode and profiling disabled.
     *
     * @param sameDiff graph to execute
     * @return arrays of the variables returned by the native executor
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public INDArray[] executeGraph(SameDiff sameDiff) {
        ExecutorConfiguration defaultConfiguration = ExecutorConfiguration.builder()
                .outputMode(OutputMode.IMPLICIT)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .build();
        return executeGraph(sameDiff, defaultConfiguration);
    }

    /**
     * Not supported by this executioner.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public INDArray[] reuseGraph(SameDiff sameDiff, Map<Integer, INDArray> map) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * Serializes the graph to its FlatBuffers representation.
     *
     * @param sameDiff graph to serialize
     * @param executorConfiguration configuration passed to the serializer (also logged)
     * @param map currently unused by this implementation
     * @return buffer produced by {@link SameDiff#asFlatBuffers}
     */
    public ByteBuffer convertToFlatBuffers(SameDiff sameDiff, ExecutorConfiguration executorConfiguration, Map<Integer, Node> map) {
        log.info("Configuration: {}", executorConfiguration);
        return sameDiff.asFlatBuffers(executorConfiguration);
    }

    /**
     * Serializes the graph to its FlatBuffers representation.
     *
     * @param sameDiff graph to serialize
     * @param executorConfiguration configuration passed to the serializer
     * @return buffer produced by {@link SameDiff#asFlatBuffers}
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public ByteBuffer convertToFlatBuffers(SameDiff sameDiff, ExecutorConfiguration executorConfiguration) {
        return convertToFlatBuffers(sameDiff, executorConfiguration, new HashMap<>());
    }

    /**
     * Serializes the graph, executes it through the native backend and returns the resulting
     * arrays. Every returned variable whose name is known to {@code sameDiff} also has its array
     * associated with that variable; unknown names are logged at WARN level.
     *
     * @param sameDiff graph to execute
     * @param executorConfiguration execution configuration embedded into the serialized graph
     * @return one array per variable in the native result, in the order they were returned
     * @throws ND4JIllegalStateException if the native call returns no result
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public INDArray[] executeGraph(SameDiff sameDiff, ExecutorConfiguration executorConfiguration) {
        ByteBuffer serializedGraph = convertToFlatBuffers(sameDiff, executorConfiguration, new HashMap<>());
        BytePointer graphPointer = new BytePointer(serializedGraph);
        log.info("Buffer length: {}", serializedGraph.limit());

        ResultWrapperAbstraction nativeResult =
                NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraphFloat(null, graphPointer);
        if (nativeResult == null) {
            throw new ND4JIllegalStateException("Graph execution failed");
        }

        // The wrapper is handed back to the native side even when decoding the result throws;
        // otherwise a failure here would leak it.
        try {
            ByteBuffer resultBuffer = new PagedPointer(nativeResult.pointer(), nativeResult.size())
                    .asBytePointer()
                    .asByteBuffer();
            FlatResult flatResult = FlatResult.getRootAsFlatResult(resultBuffer);
            log.info("VarMap: {}", sameDiff.variableMap());
            return collectOutputs(sameDiff, flatResult);
        } finally {
            NativeOpsHolder.getInstance().getDeviceNativeOps().deleteResultWrapper(nativeResult);
        }
    }

    /**
     * Converts every variable of the native result into an array and binds it to the
     * same-named variable of the graph when one exists.
     */
    private static INDArray[] collectOutputs(SameDiff sameDiff, FlatResult flatResult) {
        int variableCount = flatResult.variablesLength();
        INDArray[] outputs = new INDArray[variableCount];
        for (int i = 0; i < variableCount; i++) {
            FlatVariable flatVariable = flatResult.variables(i);
            INDArray array = Nd4j.createFromFlatArray(flatVariable.ndarray());
            outputs[i] = array;

            String name = flatVariable.name();
            SDVariable target = name == null ? null : sameDiff.variableMap().get(name);
            if (target != null) {
                sameDiff.associateArrayWithVariable(array, target);
            } else {
                log.warn("Unknown variable received: [{}]", name);
            }
        }
        return outputs;
    }

    /**
     * Resolves the numeric identifier of an op.
     *
     * <p>For {@link Op.Type#CUSTOM} the op is looked up by lower-cased name among the
     * executioner's custom operations and its hash is returned; for every other type the lookup
     * is delegated to the op factory's {@code getOpNumByName}.
     *
     * @param str op name
     * @param type op type
     * @return hash of the custom op, or the op number reported by the op factory
     * @throws ND4JIllegalStateException if {@code type} is CUSTOM and no custom op is registered
     *     under the given name
     */
    public static long getOpNum(String str, Op.Type type) {
        if (type != Op.Type.CUSTOM) {
            return Nd4j.getOpFactory().getOpNumByName(str);
        }
        // Locale.ROOT keeps the lookup key stable regardless of the JVM default locale
        // (e.g. the Turkish dotless-i rule would otherwise change the key).
        String key = str.toLowerCase(Locale.ROOT);
        if (!Nd4j.getExecutioner().getCustomOperations().containsKey(key)) {
            throw new ND4JIllegalStateException("Unknown custom op requested: [" + str + "]");
        }
        return Nd4j.getExecutioner().getCustomOperations().get(key).getHash();
    }

    /**
     * Maps an op type to the byte code used for it in the flat graph representation.
     *
     * <p>NOTE(review): the numeric codes presumably mirror the op-type constants of the
     * FlatBuffers graph schema; confirm against the schema before changing them.
     *
     * @param type op type to translate
     * @return byte code for the given type
     * @throws UnsupportedOperationException if the type has no mapping here
     */
    public static byte getFlatOpType(Op.Type type) {
        switch (type) {
            case TRANSFORM:
                return (byte) 0;
            case REDUCE:
                return (byte) 1;
            case INDEXREDUCE:
                return (byte) 2;
            case SCALAR:
                return (byte) 3;
            case BROADCAST:
                return (byte) 4;
            case CUSTOM:
                return (byte) 11;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
    }

    /**
     * Placeholder: performs no execution.
     *
     * @return always an empty array
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public INDArray[] executeGraph(int i, SDVariable... sDVariableArr) {
        return new INDArray[0];
    }

    /**
     * Placeholder: performs no registration.
     *
     * @return always {@code 0}
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public int registerGraph(SameDiff sameDiff) {
        return 0;
    }

    /**
     * Not supported by this executioner.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.autodiff.execution.GraphExecutioner
    public INDArray[] importProto(File file) {
        throw new UnsupportedOperationException("Not implemented yet");
    }
}
