package org.nd4j.linalg.api.ops.impl.layers.convolution;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater;
import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter;
import org.nd4j.imports.descriptors.properties.adapters.StringNotEqualsAdapter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * 3D convolution custom op, executed under the native op name {@code conv3dnew}.
 *
 * <p>The op carries its hyper-parameters in a {@link Conv3DConfig}; the same values are
 * mirrored into the op's integer arguments by {@link #addArgs()} using this layout:
 * <pre>
 *   0..2   kD, kH, kW
 *   3..5   sD, sH, sW
 *   6..8   pD, pH, pW
 *   9..11  dD, dH, dW
 *   12     1 if same-mode, otherwise 0
 *   13     0 if the config reports NCDHW, otherwise 1
 * </pre>
 */
public class Conv3D extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(Conv3D.class);
    private static final String INVALID_CONFIGURATION = "Invalid Conv3D configuration : sW = %s pH = %s dW = %s ";

    protected Conv3DConfig config;

    /** Fluent builder for the SameDiff flavour of {@link Conv3D}; obtain via {@link Conv3D#sameDiffBuilder()}. */
    public static class Conv3DBuilder {
        private SameDiff sameDiff;
        private SDVariable[] inputFunctions;
        private Conv3DConfig config;

        Conv3DBuilder() {
        }

        /** Sets the SameDiff instance the op is created with. */
        public Conv3DBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        /** Sets the input variables of the op. */
        public Conv3DBuilder inputFunctions(SDVariable[] inputFunctions) {
            this.inputFunctions = inputFunctions;
            return this;
        }

        /** Sets the convolution configuration. */
        public Conv3DBuilder config(Conv3DConfig config) {
            this.config = config;
            return this;
        }

        /** Creates the op from the values collected so far. */
        public Conv3D build() {
            return new Conv3D(this.sameDiff, this.inputFunctions, this.config);
        }

        @Override
        public String toString() {
            return "Conv3D.Conv3DBuilder(sameDiff=" + this.sameDiff
                    + ", inputFunctions=" + Arrays.deepToString(this.inputFunctions)
                    + ", config=" + this.config + ")";
        }
    }

    /** Creates an unconfigured op; {@link #config} stays {@code null} until it is populated elsewhere. */
    public Conv3D() {
    }

    /**
     * Creates the op on SameDiff variables.
     *
     * @param sameDiff       SameDiff instance passed to the superclass
     * @param inputFunctions input variables passed to the superclass
     * @param config         convolution configuration; validated and copied into the integer arguments
     */
    public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
        super(sameDiff, inputFunctions);
        initConfig(config);
    }

    /**
     * Creates the op on concrete arrays.
     *
     * @param inputs  input arrays passed to the superclass
     * @param outputs output arrays passed to the superclass
     * @param config  convolution configuration; validated and copied into the integer arguments
     */
    public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config) {
        super(inputs, outputs);
        initConfig(config);
    }

    /**
     * Convenience constructor taking the arrays individually.
     *
     * @param input   input array, must not be {@code null}
     * @param weights weights array, must not be {@code null}
     * @param bias    bias array; passed through {@code wrapFilterNull} together with input and weights
     * @param output  output array; passed through {@code wrapOrNull}
     * @param config  convolution configuration, must not be {@code null}
     * @throws NullPointerException if {@code input}, {@code weights} or {@code config} is {@code null}
     */
    public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config) {
        this((INDArray[]) wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
        // The delegating call has to come first, so these checks run only after it returns.
        if (input == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        if (weights == null) {
            throw new NullPointerException("weights is marked @NonNull but is null");
        }
        if (config == null) {
            throw new NullPointerException("config is marked @NonNull but is null");
        }
    }

    /**
     * Stores the configuration, checks {@code sW >= 1}, {@code pH >= 0} and {@code dW >= 1},
     * then appends the integer arguments.
     */
    private void initConfig(Conv3DConfig config) {
        this.config = config;
        Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
                INVALID_CONFIGURATION, config.getSW(), config.getPH(), config.getDW());
        addArgs();
    }

    /** Appends the 14 integer arguments derived from the current configuration (layout in the class comment). */
    private void addArgs() {
        Conv3DConfig c = getConfig();
        addIArgument(new long[]{
                c.getKD(), c.getKH(), c.getKW(),
                c.getSD(), c.getSH(), c.getSW(),
                c.getPD(), c.getPH(), c.getPW(),
                c.getDD(), c.getDH(), c.getDW(),
                c.isSameMode() ? 1L : 0L,
                c.isNCDHW() ? 0L : 1L
        });
    }

    /**
     * Reads the value of {@code field} from the configuration. When no configuration is set but
     * integer arguments are present, the configuration is first rebuilt from those arguments.
     *
     * @param field the configuration field to read
     * @return the value reported by {@link Conv3DConfig#getValue(Field)}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Object getValue(Field field) {
        if (this.config == null && !this.iArguments.isEmpty()) {
            // Decoding must mirror addArgs(): argument 13 is 0 for NCDHW and 1 otherwise (NDHWC).
            String dataFormat = this.iArguments.get(13).longValue() == 1 ? "NDHWC" : "NCDHW";
            this.config = Conv3DConfig.builder()
                    .kD(this.iArguments.get(0).longValue())
                    .kH(this.iArguments.get(1).longValue())
                    .kW(this.iArguments.get(2).longValue())
                    .sD(this.iArguments.get(3).longValue())
                    .sH(this.iArguments.get(4).longValue())
                    .sW(this.iArguments.get(5).longValue())
                    .pD(this.iArguments.get(6).longValue())
                    .pH(this.iArguments.get(7).longValue())
                    .pW(this.iArguments.get(8).longValue())
                    .dD(this.iArguments.get(9).longValue())
                    .dH(this.iArguments.get(10).longValue())
                    .dW(this.iArguments.get(11).longValue())
                    .isSameMode(this.iArguments.get(12).longValue() == 1)
                    .dataFormat(dataFormat)
                    .build();
        }
        return this.config.getValue(field);
    }

    /** Returns the integer arguments, populating them from the configuration first if none are present. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public long[] iArgs() {
        if (this.iArguments.size() == 0) {
            addArgs();
        }
        return super.iArgs();
    }

    /** Returns the attribute adapters, keyed by {@link #tensorflowName()} and then by property name. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        Map<String, Map<String, AttributeAdapter>> adaptersByOp = new LinkedHashMap<>();
        Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
        // Result is unused; the call is retained from the original code in case it has side effects.
        DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);

        tfAdapters.put("kD", new NDArrayShapeAdapter(0));
        tfAdapters.put("kH", new NDArrayShapeAdapter(1));
        tfAdapters.put("kW", new NDArrayShapeAdapter(2));

        tfAdapters.put("sD", new IntArrayIntIndexAdpater(1));
        tfAdapters.put("sH", new IntArrayIntIndexAdpater(2));
        tfAdapters.put("sW", new IntArrayIntIndexAdpater(3));

        tfAdapters.put("pD", new IntArrayIntIndexAdpater(1));
        tfAdapters.put("pH", new IntArrayIntIndexAdpater(2));
        tfAdapters.put("pW", new IntArrayIntIndexAdpater(3));

        tfAdapters.put("isSameMode", new StringNotEqualsAdapter("VALID"));

        adaptersByOp.put(tensorflowName(), tfAdapters);
        return adaptersByOp;
    }

    /** Returns the configuration's properties, or an empty map when no configuration is set. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Object> propertiesForFunction() {
        if (this.config == null) {
            return Collections.emptyMap();
        }
        return this.config.toProperties();
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return "conv3dnew";
    }

    /** Returns the property mappings, keyed by {@link #tensorflowName()} and then by property name. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        PropertyMapping kernel = PropertyMapping.builder()
                .propertyNames(new String[]{"kD", "kW", "kH"})
                .tfInputPosition(1)
                .onnxAttrName("kernel_shape")
                .build();
        PropertyMapping strides = PropertyMapping.builder()
                .tfAttrName("strides")
                .onnxAttrName("strides")
                .propertyNames(new String[]{"sD", "sW", "sH"})
                .build();
        PropertyMapping dilations = PropertyMapping.builder()
                .onnxAttrName("dilations")
                .propertyNames(new String[]{"dD", "dH", "dW"})
                .tfAttrName("rates")
                .build();
        PropertyMapping sameMode = PropertyMapping.builder()
                .onnxAttrName("auto_pad")
                .propertyNames(new String[]{"isSameMode"})
                .tfAttrName("padding")
                .build();
        PropertyMapping padding = PropertyMapping.builder()
                .onnxAttrName("padding")
                .propertyNames(new String[]{"pD", "pW", "pH"})
                .build();
        PropertyMapping dataFormat = PropertyMapping.builder()
                .onnxAttrName("data_format")
                .tfAttrName("data_format")
                .propertyNames(new String[]{"dataFormat"})
                .build();
        PropertyMapping outputPadding = PropertyMapping.builder()
                .propertyNames(new String[]{"aD", "aH", "aW"})
                .build();
        PropertyMapping biasUsed = PropertyMapping.builder()
                .propertyNames(new String[]{"biasUsed"})
                .build();

        // Every property name of a mapping points at that same mapping instance.
        Map<String, PropertyMapping> byProperty = new HashMap<>();
        for (PropertyMapping mapping : new PropertyMapping[]{
                kernel, strides, dilations, sameMode, padding, dataFormat, outputPadding, biasUsed}) {
            for (String propertyName : mapping.getPropertyNames()) {
                byProperty.put(propertyName, mapping);
            }
        }

        Map<String, Map<String, PropertyMapping>> byOp = new HashMap<>();
        byOp.put(tensorflowName(), byProperty);
        return byOp;
    }

    /** Initialises the op's properties from a TensorFlow node, then appends the integer arguments. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, map, nodeDef, graphDef);
        addArgs();
    }

    /**
     * Builds a {@link Conv3DDerivative} whose inputs are this op's arguments followed by the
     * first incoming gradient, and returns that op's output variables.
     *
     * @param list incoming gradients; only the first element is used
     * @return a new mutable list holding the derivative op's output variables
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        List<SDVariable> derivativeInputs = new ArrayList<>(Arrays.asList(args()));
        derivativeInputs.add(list.get(0));

        SDVariable[] derivativeOutputs = Conv3DDerivative.derivativeBuilder()
                .conv3DConfig(this.config)
                .inputFunctions(derivativeInputs.toArray(new SDVariable[derivativeInputs.size()]))
                .sameDiff(this.sameDiff)
                .build()
                .outputVariables();
        return new ArrayList<>(Arrays.asList(derivativeOutputs));
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean isConfigProperties() {
        return true;
    }

    /** Returns the name of the field holding the configuration object. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String configFieldName() {
        return "config";
    }

    /** Always throws: no ONNX op name is defined for this op. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        return "Conv3D";
    }

    /**
     * Returns a single-element list holding the first input data type.
     *
     * @param list input data types; must be non-null with one entry per op argument
     * @return singleton list containing {@code list.get(0)}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        int expected = args().length;
        Preconditions.checkState(list != null && list.size() == expected,
                "Expected %s input data types for %s, got %s", Integer.valueOf(expected), getClass(), list);
        return Collections.singletonList(list.get(0));
    }

    /** Returns a new builder for the SameDiff constructor. */
    public static Conv3DBuilder sameDiffBuilder() {
        return new Conv3DBuilder();
    }

    /** Returns the current configuration, which may be {@code null}. */
    public Conv3DConfig getConfig() {
        return this.config;
    }
}
