package org.nd4j.samediff.frameworkimport.tensorflow.ir;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.regex.Pattern;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.text.Regex;
import kotlin.text.StringsKt;
import org.apache.commons.io.IOUtils;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.samediff.frameworkimport.ir.IRAttribute;
import org.nd4j.samediff.frameworkimport.rule.attribute.AttributeValueType;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.OpDef;
import org.tensorflow.framework.OpList;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

/* compiled from: TensorflowIR.kt */
@Metadata(mv = {1, 4, 2}, bv = {1, 0, 3}, k = 2, d1 = {"��`\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0016\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\u001a\u001e\u0010��\u001a\u001a\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u00050\u0001\u001a\u000e\u0010\u0006\u001a\u00020\u00072\u0006\u0010\b\u001a\u00020\u0002\u001a\u000e\u0010\t\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020\u000b\u001a\u0010\u0010\f\u001a\u00020\u000b2\b\u0010\r\u001a\u0004\u0018\u00010\u0005\u001a\u0010\u0010\u000e\u001a\u0004\u0018\u00010\u000f2\u0006\u0010\u0010\u001a\u00020\u0011\u001a\u000e\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u0015\u001a\u0016\u0010\u0016\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u00152\u0006\u0010\u0017\u001a\u00020\u0018\u001a\u0016\u0010\u0019\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u00152\u0006\u0010\u0017\u001a\u00020\u0018\u001a\u0006\u0010\u001a\u001a\u00020\u001b\u001a\u000e\u0010\u001c\u001a\u00020\u000f2\u0006\u0010\u001d\u001a\u00020\u0004\u001a\u0012\u0010\u001e\u001a\u0004\u0018\u00010\u001f2\u0006\u0010 \u001a\u00020!H\u0002\u001a\u000e\u0010\"\u001a\u00020\u00152\u0006\u0010\u0014\u001a\u00020\u0015\u001a\u000e\u0010#\u001a\u00020\u00152\u0006\u0010$\u001a\u00020\u0015\u001a\u0016\u0010%\u001a\u00020\u00072\u0006\u0010&\u001a\u00020\u00152\u0006\u0010\u0017\u001a\u00020\u0018¨\u0006'"}, d2 = {"attrDefaultValue", "Lorg/nd4j/samediff/frameworkimport/ir/IRAttribute;", "Lorg/tensorflow/framework/OpDef$AttrDef;", "Lorg/tensorflow/framework/AttrValue;", "Lorg/tensorflow/framework/TensorProto;", 
"Lorg/tensorflow/framework/DataType;", "attributeValueTypeForTensorflowAttribute", "Lorg/nd4j/samediff/frameworkimport/rule/attribute/AttributeValueType;", "attributeDef", "convertToDataType", "dataType", "Lorg/nd4j/linalg/api/buffer/DataType;", "convertType", "tfType", "getNDArrayFromTensor", "Lorg/nd4j/linalg/api/ndarray/INDArray;", "node", "Lorg/tensorflow/framework/NodeDef;", "isControlDep", "", "name", "", "isTensorflowAttributeName", "opDef", "Lorg/tensorflow/framework/OpDef;", "isTensorflowTensorName", "loadTensorflowOps", "Lorg/tensorflow/framework/OpList;", "mapTensorProto", "tfTensor", "shapeFromShapeProto", "", "tensorShapeProto", "Lorg/tensorflow/framework/TensorShapeProto;", "stripControl", "stripVarSuffix", "varName", "tensorflowAttributeValueTypeFor", "attributeName", "samediff-import-tensorflow"})
/* loaded from: input_file:org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRKt.class */
public final class TensorflowIRKt {

    /*
     * Static helpers for importing TensorFlow graphs into SameDiff. They cover:
     *  - loading the bundled TF op registry;
     *  - converting between ND4J and TF data types;
     *  - classifying op attribute / input names;
     *  - normalising TF tensor names;
     *  - mapping TF TensorProtos to INDArrays.
     */

    /** Classpath resource holding the TensorFlow op registry in protobuf text format. */
    private static final String OPS_RESOURCE = "ops.proto";

    /**
     * Matches a whole name that ends in {@code :<digits>} (an output-index suffix such as
     * {@code "node:1"}). Compiled once instead of on every {@link #stripVarSuffix} call.
     */
    private static final Pattern VAR_SUFFIX_PATTERN = Pattern.compile(".*:\\d+");

    /**
     * Loads the TensorFlow op definitions bundled on the classpath as {@code ops.proto}
     * (protobuf text format) and parses them into an {@link OpList}.
     *
     * <p>The resource is decoded as UTF-8 (not the platform default charset). The stream is
     * always closed, even if reading or parsing fails.
     *
     * @return the parsed op list
     * @throws UncheckedIOException if the resource cannot be read or is not valid
     *     text-format protobuf; the original {@link IOException} is kept as the cause
     */
    @NotNull
    public static final OpList loadTensorflowOps() {
        OpList.Builder builder = OpList.newBuilder();
        try (InputStream in = new ClassPathResource(OPS_RESOURCE).getInputStream()) {
            TextFormat.merge(IOUtils.toString(in, StandardCharsets.UTF_8), builder);
        } catch (IOException e) {
            // TextFormat.ParseException is an IOException, so parse failures also land here.
            throw new UncheckedIOException(
                    "Failed to load TensorFlow op definitions from classpath resource " + OPS_RESOURCE, e);
        }
        return builder.build();
    }

    /**
     * Returns an attribute wrapper built from the default (empty) {@code AttrDef} and
     * {@code AttrValue} protobuf instances.
     *
     * @return a {@link TensorflowIRAttr} over the default protobuf instances
     */
    @NotNull
    public static final IRAttribute<OpDef.AttrDef, AttrValue, TensorProto, org.tensorflow.framework.DataType> attrDefaultValue() {
        return new TensorflowIRAttr(OpDef.AttrDef.getDefaultInstance(), AttrValue.getDefaultInstance());
    }

    /**
     * Converts an ND4J data type to the corresponding TensorFlow data type.
     *
     * @param dataType the ND4J type; must not be null
     * @return the matching TensorFlow {@code DT_*} type
     * @throws NullPointerException if {@code dataType} is null
     * @throws UnsupportedOperationException if there is no TensorFlow equivalent in this mapping
     */
    @NotNull
    public static final org.tensorflow.framework.DataType convertToDataType(@NotNull DataType dataType) {
        Objects.requireNonNull(dataType, "dataType");
        switch (dataType) {
            case UINT16:
                return org.tensorflow.framework.DataType.DT_UINT16;
            case UINT32:
                return org.tensorflow.framework.DataType.DT_UINT32;
            case UINT64:
                return org.tensorflow.framework.DataType.DT_UINT64;
            case BOOL:
                return org.tensorflow.framework.DataType.DT_BOOL;
            case BFLOAT16:
                return org.tensorflow.framework.DataType.DT_BFLOAT16;
            case FLOAT:
                return org.tensorflow.framework.DataType.DT_FLOAT;
            case INT:
                return org.tensorflow.framework.DataType.DT_INT32;
            case LONG:
                return org.tensorflow.framework.DataType.DT_INT64;
            case BYTE:
                return org.tensorflow.framework.DataType.DT_INT8;
            case SHORT:
                return org.tensorflow.framework.DataType.DT_INT16;
            case DOUBLE:
                return org.tensorflow.framework.DataType.DT_DOUBLE;
            case UBYTE:
                return org.tensorflow.framework.DataType.DT_UINT8;
            case HALF:
                return org.tensorflow.framework.DataType.DT_HALF;
            case UTF8:
                return org.tensorflow.framework.DataType.DT_STRING;
            default:
                throw new UnsupportedOperationException("Unknown TF data type: [" + dataType.name() + "]");
        }
    }

    /**
     * Determines the attribute value type for a name declared on a TensorFlow op.
     *
     * <p>If {@code attributeName} names one of the op's input arguments, the result is
     * {@link AttributeValueType#TENSOR}. This takes precedence over an attribute with the same
     * name. Otherwise the matching {@code AttrDef} is wrapped in a {@link TensorflowIRAttr}
     * and its {@code attributeValueType()} is returned.
     *
     * @param attributeName the input-argument or attribute name to classify
     * @param opDef the op definition to look the name up in
     * @return the value type for the name
     * @throws NullPointerException if either argument is null
     * @throws IllegalArgumentException if the op has neither an input nor an attribute with that name
     */
    @NotNull
    public static final AttributeValueType tensorflowAttributeValueTypeFor(@NotNull String attributeName, @NotNull OpDef opDef) {
        Objects.requireNonNull(attributeName, "attributeName");
        Objects.requireNonNull(opDef, "opDef");
        if (isTensorflowTensorName(attributeName, opDef)) {
            return AttributeValueType.TENSOR;
        }
        for (OpDef.AttrDef attrDef : opDef.getAttrList()) {
            if (attributeName.equals(attrDef.getName())) {
                return new TensorflowIRAttr(attrDef, AttrValue.getDefaultInstance()).attributeValueType();
            }
        }
        throw new IllegalArgumentException("Tensorflow op " + opDef.getName() + " does not have attribute name " + attributeName);
    }

    /**
     * Reports whether {@code name} is the name of one of the op's input arguments.
     *
     * @param name the candidate name
     * @param opDef the op definition whose input arguments are searched
     * @return {@code true} if some input argument has exactly this name
     * @throws NullPointerException if either argument is null
     */
    public static final boolean isTensorflowTensorName(@NotNull String name, @NotNull OpDef opDef) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(opDef, "opDef");
        for (OpDef.ArgDef inputDef : opDef.getInputArgList()) {
            if (name.equals(inputDef.getName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reports whether {@code name} is the name of one of the op's attributes.
     *
     * @param name the candidate name
     * @param opDef the op definition whose attributes are searched
     * @return {@code true} if some attribute has exactly this name
     * @throws NullPointerException if either argument is null
     */
    public static final boolean isTensorflowAttributeName(@NotNull String name, @NotNull OpDef opDef) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(opDef, "opDef");
        for (OpDef.AttrDef attrDef : opDef.getAttrList()) {
            if (name.equals(attrDef.getName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Copies the dimension sizes of a {@link TensorShapeProto} into a {@code long[]}, in order.
     *
     * @param tensorShapeProto the shape to read
     * @return one entry per dimension, holding that dimension's {@code size}
     */
    private static final long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) {
        long[] shape = new long[tensorShapeProto.getDimCount()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = tensorShapeProto.getDim(i).getSize();
        }
        return shape;
    }

    /**
     * Converts a TensorFlow data type to the corresponding ND4J data type.
     *
     * <p>The unsigned types {@code DT_UINT16/32/64} map to {@code UINT16/32/64}. This mirrors
     * {@link #convertToDataType}, so those types round-trip.
     *
     * @param tfType the TensorFlow type; may be null
     * @return the matching ND4J type, or {@link DataType#UNKNOWN} for null or unmapped types
     */
    @NotNull
    public static final DataType convertType(@Nullable org.tensorflow.framework.DataType tfType) {
        if (tfType == null) {
            return DataType.UNKNOWN;
        }
        switch (tfType) {
            case DT_DOUBLE:
                return DataType.DOUBLE;
            case DT_FLOAT:
                return DataType.FLOAT;
            case DT_HALF:
                return DataType.HALF;
            case DT_BFLOAT16:
                return DataType.BFLOAT16;
            case DT_INT8:
                return DataType.BYTE;
            case DT_INT16:
                return DataType.SHORT;
            case DT_INT32:
                return DataType.INT;
            case DT_INT64:
                return DataType.LONG;
            case DT_UINT8:
                return DataType.UBYTE;
            case DT_UINT16:
                return DataType.UINT16;
            case DT_UINT32:
                return DataType.UINT32;
            case DT_UINT64:
                return DataType.UINT64;
            case DT_STRING:
                return DataType.UTF8;
            case DT_BOOL:
                return DataType.BOOL;
            default:
                return DataType.UNKNOWN;
        }
    }

    /**
     * Reports whether a node-input name is a control dependency, i.e. starts with {@code '^'}.
     *
     * @param name the input name
     * @return {@code true} if {@code name} starts with {@code "^"}
     * @throws NullPointerException if {@code name} is null
     */
    public static final boolean isControlDep(@NotNull String name) {
        Objects.requireNonNull(name, "name");
        return name.startsWith("^");
    }

    /**
     * Removes a single leading {@code '^'} (control-dependency marker) from a name, if present.
     *
     * @param name the input name
     * @return {@code name} without its leading {@code '^'}, or {@code name} unchanged
     * @throws NullPointerException if {@code name} is null
     */
    @NotNull
    public static final String stripControl(@NotNull String name) {
        Objects.requireNonNull(name, "name");
        return name.startsWith("^") ? name.substring(1) : name;
    }

    /**
     * Strips a trailing {@code :<digits>} output-index suffix, e.g. {@code "node:1"} becomes
     * {@code "node"}.
     *
     * <p>The whole name must match {@code .*:\d+}; otherwise it is returned unchanged. Only
     * the text from the last {@code ':'} onwards is removed.
     *
     * @param varName the variable or tensor name
     * @return the name without its output-index suffix
     * @throws NullPointerException if {@code varName} is null
     */
    @NotNull
    public static final String stripVarSuffix(@NotNull String varName) {
        Objects.requireNonNull(varName, "varName");
        if (!VAR_SUFFIX_PATTERN.matcher(varName).matches()) {
            return varName;
        }
        return varName.substring(0, varName.lastIndexOf(':'));
    }

    /**
     * Reads the tensor stored in a node's {@code "value"} attribute and maps it to an NDArray.
     *
     * @param node the node to read
     * @return the mapped array, or {@code null} if the node has no {@code "value"} attribute
     * @throws NullPointerException if {@code node} is null
     * @throws RuntimeException if the tensor's dtype has no mapper (see {@link #mapTensorProto})
     */
    @Nullable
    public static final INDArray getNDArrayFromTensor(@NotNull NodeDef node) {
        Objects.requireNonNull(node, "node");
        AttrValue value = node.getAttrMap().get("value");
        if (value == null) {
            return null;
        }
        return mapTensorProto(value.getTensor());
    }

    /**
     * Converts a TensorFlow {@link TensorProto} to an ND4J array via {@link TFTensorMappers}.
     *
     * @param tfTensor the tensor to convert
     * @return the converted array
     * @throws NullPointerException if {@code tfTensor} is null, or if the mapper returns null
     * @throws RuntimeException if no mapper exists for the tensor's dtype
     */
    @NotNull
    public static final INDArray mapTensorProto(@NotNull TensorProto tfTensor) {
        Objects.requireNonNull(tfTensor, "tfTensor");
        TFTensorMapper mapper = TFTensorMappers.newMapper(tfTensor);
        if (mapper == null) {
            throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype());
        }
        // Keep the @NotNull contract: fail fast rather than leak a null to Kotlin callers.
        return Objects.requireNonNull(mapper.toNDArray(), "m.toNDArray()");
    }

    /**
     * Maps a TensorFlow {@code AttrDef} type string (e.g. {@code "int"}, {@code "list(float)"})
     * to an {@link AttributeValueType}.
     *
     * @param attributeDef the attribute definition whose {@code type} is inspected
     * @return the matching value type, or {@link AttributeValueType#INVALID} for an unrecognised type
     * @throws NullPointerException if {@code attributeDef} is null
     */
    @NotNull
    public static final AttributeValueType attributeValueTypeForTensorflowAttribute(@NotNull OpDef.AttrDef attributeDef) {
        Objects.requireNonNull(attributeDef, "attributeDef");
        String type = attributeDef.getType();
        if (type == null) {
            return AttributeValueType.INVALID;
        }
        switch (type) {
            case "list(float)":
                return AttributeValueType.LIST_FLOAT;
            case "list(int)":
                return AttributeValueType.LIST_INT;
            case "string":
                return AttributeValueType.STRING;
            case "list(bool)":
                return AttributeValueType.LIST_BOOL;
            case "tensor":
                return AttributeValueType.TENSOR;
            case "int":
                return AttributeValueType.INT;
            case "bool":
                return AttributeValueType.BOOL;
            case "type":
                return AttributeValueType.DATA_TYPE;
            case "float":
                return AttributeValueType.FLOAT;
            case "list(string)":
                return AttributeValueType.LIST_STRING;
            case "list(tensor)":
                return AttributeValueType.LIST_TENSOR;
            default:
                return AttributeValueType.INVALID;
        }
    }
}
