package org.nd4j.linalg.api.ops.impl.layers.convolution;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * 2D max pooling op, executed by the native {@code maxpool2d} custom op.
 *
 * <p>The pooling geometry (kernel, stride, padding, dilation, same-mode, data format) is held in a
 * {@link Pooling2DConfig} whose type is always forced to {@link Pooling2D.Pooling2DType#MAX}. The
 * config is flattened into the op's 11 integer arguments by {@code addArgs()}. Instances are created
 * directly, through {@link #sameDiffBuilder()}, or by importing a TensorFlow / ONNX graph node.
 */
public class MaxPooling2D extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(MaxPooling2D.class);

    /**
     * Pooling configuration. It may be {@code null} on an instance created with the no-arg
     * constructor until an import method or {@link #propertiesForFunction()} fills it in.
     */
    protected Pooling2DConfig config;

    /** Builder returned by {@link MaxPooling2D#sameDiffBuilder()}; feeds the SameDiff constructor. */
    public static class MaxPooling2DBuilder {
        private SameDiff sameDiff;
        private SDVariable input;
        private Pooling2DConfig config;

        MaxPooling2DBuilder() {
        }

        public MaxPooling2DBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        public MaxPooling2DBuilder input(SDVariable sDVariable) {
            this.input = sDVariable;
            return this;
        }

        public MaxPooling2DBuilder config(Pooling2DConfig pooling2DConfig) {
            this.config = pooling2DConfig;
            return this;
        }

        public MaxPooling2D build() {
            return new MaxPooling2D(this.sameDiff, this.input, this.config);
        }

        public String toString() {
            return "MaxPooling2D.MaxPooling2DBuilder(sameDiff=" + this.sameDiff + ", input=" + this.input + ", config=" + this.config + ")";
        }
    }

    /**
     * No-argument constructor.
     *
     * <p>The configuration is populated later, e.g. by {@code initFromTensorFlow},
     * {@code initFromOnnx} or {@code propertiesForFunction}.
     */
    public MaxPooling2D() {
    }

    /**
     * Creates the op inside a SameDiff graph.
     *
     * <p>Note: the supplied config is modified in place; its type is set to MAX.
     *
     * @param sameDiff the owning graph
     * @param sDVariable the variable to pool
     * @param pooling2DConfig pooling geometry; must not be {@code null}
     */
    public MaxPooling2D(SameDiff sameDiff, SDVariable sDVariable, Pooling2DConfig pooling2DConfig) {
        super(null, sameDiff, new SDVariable[]{sDVariable}, false);
        pooling2DConfig.setType(Pooling2D.Pooling2DType.MAX);
        this.config = pooling2DConfig;
        addArgs();
    }

    /**
     * Creates the op for direct execution on arrays.
     *
     * <p>Note: the supplied config is modified in place; its type is set to MAX.
     *
     * @param iNDArray the input array
     * @param iNDArray2 optional output array; {@code null} lets the op allocate one
     * @param pooling2DConfig pooling geometry
     * @throws NullPointerException if {@code pooling2DConfig} is {@code null}
     */
    public MaxPooling2D(INDArray iNDArray, INDArray iNDArray2, @NonNull Pooling2DConfig pooling2DConfig) {
        super((String) null, new INDArray[]{iNDArray}, wrapOrNull(iNDArray2));
        if (pooling2DConfig == null) {
            throw new NullPointerException("config is marked @NonNull but is null");
        }
        pooling2DConfig.setType(Pooling2D.Pooling2DType.MAX);
        this.config = pooling2DConfig;
        addArgs();
    }

    @Override
    public boolean isConfigProperties() {
        return true;
    }

    @Override
    public String configFieldName() {
        return "config";
    }

    /**
     * Returns the config's properties.
     *
     * <p>If no config object exists but integer arguments do, the config is first rebuilt from
     * them. This presumably happens when the op was restored from a serialized graph. The argument
     * layout mirrors {@code addArgs()}.
     *
     * @throws IllegalStateException if the config cannot be obtained or rebuilt
     */
    @Override
    public Map<String, Object> propertiesForFunction() {
        if (this.config == null && this.iArguments.size() > 0) {
            Preconditions.checkState(this.iArguments.size() >= 11,
                    "Expected 11 integer arguments to rebuild the pooling config of %s, got %s",
                    getClass(), this.iArguments);
            this.config = Pooling2DConfig.builder()
                    .kH(this.iArguments.get(0).longValue())
                    .kW(this.iArguments.get(1).longValue())
                    .sH(this.iArguments.get(2).longValue())
                    .sW(this.iArguments.get(3).longValue())
                    .pH(this.iArguments.get(4).longValue())
                    .pW(this.iArguments.get(5).longValue())
                    .dH(this.iArguments.get(6).longValue())
                    .dW(this.iArguments.get(7).longValue())
                    .isSameMode(this.iArguments.get(8).longValue() == 1)
                    .extra(this.iArguments.get(9).longValue())
                    .isNHWC(this.iArguments.get(10).longValue() == 1)
                    .type(Pooling2D.Pooling2DType.MAX)
                    .build();
        }
        Preconditions.checkState(this.config != null,
                "No pooling configuration set for %s; integer arguments: %s",
                getClass(), this.iArguments);
        return this.config.toProperties();
    }

    /**
     * Flattens {@link #config} into the op's integer arguments.
     *
     * <p>The order is: kH, kW, sH, sW, pH, pW, dH, dW, same-mode flag, extra (truncated to int),
     * NHWC flag. {@link #propertiesForFunction()} reads them back by these indices, so the two
     * must stay in sync.
     */
    private void addArgs() {
        addIArgument(this.config.getKH(), this.config.getKW(), this.config.getSH(), this.config.getSW(), this.config.getPH(), this.config.getPW(), this.config.getDH(), this.config.getDW(), ArrayUtil.fromBoolean(this.config.isSameMode()), (int) this.config.getExtra(), ArrayUtil.fromBoolean(this.config.isNHWC()));
    }

    /** Prefix combined with {@code "pool2d"} to form {@link #opName()}. */
    public String getPoolingPrefix() {
        return "max";
    }

    @Override
    public String opName() {
        return getPoolingPrefix() + "pool2d";
    }

    /**
     * Builds the backward pass.
     *
     * <p>The pass is a {@code Pooling2DDerivative} whose inputs are this op's forward inputs followed
     * by the incoming gradient {@code list.get(0)}. It shares this op's config.
     *
     * @param list gradients with respect to this op's outputs
     * @return a mutable list of the derivative op's output variables
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> list) {
        List<SDVariable> derivativeInputs = new ArrayList<>(Arrays.asList(args()));
        derivativeInputs.add(list.get(0));
        SDVariable[] gradients = Pooling2DDerivative.derivativeBuilder()
                .inputs(derivativeInputs.toArray(new SDVariable[0]))
                .sameDiff(this.sameDiff)
                .config(this.config)
                .build()
                .outputVariables();
        return new ArrayList<>(Arrays.asList(gradients));
    }

    /**
     * Configures this op from a TensorFlow pooling node.
     *
     * <p>The {@code strides}, {@code ksize} and {@code padding} attributes are read from the node.
     * The optional {@code data_format} attribute defaults to NHWC. Any other value is treated as
     * NCHW.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        // NOTE(review): TensorFlow's MaxPoolV2 supplies ksize/strides as graph inputs rather than node
        // attributes, so the getAttrOrThrow calls below would presumably fail for it — confirm.
        List<Long> tfStrides = nodeDef.getAttrOrThrow("strides").getList().getIList();
        List<Long> tfKernel = nodeDef.getAttrOrThrow("ksize").getList().getIList();
        AttrValue paddingAttr = nodeDef.getAttrOrThrow("padding");
        List<Long> tfPadding = paddingAttr.getList().getIList();
        boolean isSameMode = paddingAttr.getS().toStringUtf8().replace("\"", "").equalsIgnoreCase("SAME");

        String dataFormat = nodeDef.containsAttr("data_format")
                ? nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8()
                : "nhwc";
        boolean isNhwc = dataFormat.equalsIgnoreCase("nhwc");

        // TF gives one value per input dimension. The height axis is 1 for NHWC and 2 otherwise;
        // the width axis follows it.
        int hAxis = isNhwc ? 1 : 2;
        int wAxis = hAxis + 1;

        int sH = tfStrides.get(hAxis).intValue();
        int sW = tfStrides.get(wAxis).intValue();
        int kH = tfKernel.get(hAxis).intValue();
        int kW = tfKernel.get(wAxis).intValue();
        // The padding attribute is read as a string above, so this list is presumably usually empty,
        // which means no explicit padding.
        int pH = tfPadding.isEmpty() ? 0 : tfPadding.get(hAxis).intValue();
        int pW = tfPadding.isEmpty() ? 0 : tfPadding.get(wAxis).intValue();

        this.config = Pooling2DConfig.builder()
                .sH(sH)
                .sW(sW)
                .type(Pooling2D.Pooling2DType.MAX)
                .isSameMode(isSameMode)
                .kH(kH)
                .kW(kW)
                .pH(pH)
                .pW(pW)
                .isNHWC(isNhwc)
                .extra(1.0d)
                .build();
        addArgs();
        log.debug("Pooling: k: [{}, {}]; s: [{}, {}]; p: [{}, {}]; same: {}", kH, kW, sH, sW, pH, pW, isSameMode);
    }

    /**
     * Configures this op from an ONNX {@code MaxPool} node.
     *
     * <p>{@code kernel_shape} is required. {@code strides} and {@code pads} are optional and default
     * to 1 and 0 per axis. A single value is applied to both axes. Any {@code auto_pad} value
     * starting with {@code SAME} ({@code SAME_UPPER}, {@code SAME_LOWER}) enables same-mode.
     */
    @Override
    public void initFromOnnx(Onnx.NodeProto nodeProto, SameDiff sameDiff, Map<String, Onnx.AttributeProto> map, Onnx.GraphProto graphProto) {
        String autoPad = map.containsKey("auto_pad") ? map.get("auto_pad").getS().toStringUtf8() : "VALID";
        boolean isSameMode = autoPad.startsWith("SAME");

        List<Long> kernelShape = map.get("kernel_shape").getIntsList();
        List<Long> strides = onnxIntsOrEmpty(map, "strides");
        // ONNX pads are [h_begin, w_begin, h_end, w_end]. Only the begin values are used here.
        // NOTE(review): asymmetric end padding is ignored.
        List<Long> pads = onnxIntsOrEmpty(map, "pads");

        this.config = Pooling2DConfig.builder()
                .sH(perAxisValue(strides, 0, 1))
                .sW(perAxisValue(strides, 1, 1))
                .type(Pooling2D.Pooling2DType.MAX)
                .isSameMode(isSameMode)
                .kH(kernelShape.get(0).intValue())
                .kW(kernelShape.size() < 2 ? kernelShape.get(0).intValue() : kernelShape.get(1).intValue())
                .pH(perAxisValue(pads, 0, 0))
                .pW(perAxisValue(pads, 1, 0))
                .build();
        addArgs();
    }

    /** Returns the ints of an optional ONNX attribute, or an empty list when the attribute is absent. */
    private static List<Long> onnxIntsOrEmpty(Map<String, Onnx.AttributeProto> attributes, String name) {
        Onnx.AttributeProto attribute = attributes.get(name);
        return attribute == null ? Collections.<Long>emptyList() : attribute.getIntsList();
    }

    /**
     * Returns the per-axis value at {@code axis}.
     *
     * <p>Falls back to element 0 when fewer values are given, and to {@code defaultValue} when the
     * list is empty.
     */
    private static int perAxisValue(List<Long> values, int axis, int defaultValue) {
        if (values.isEmpty()) {
            return defaultValue;
        }
        return values.get(axis < values.size() ? axis : 0).intValue();
    }

    /**
     * Maps TensorFlow/ONNX attribute names onto config property names.
     *
     * <p>The same table is registered for both frameworks.
     */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, Map<String, PropertyMapping>> byFramework = new HashMap<>();
        Map<String, PropertyMapping> byProperty = new HashMap<>();

        PropertyMapping strideMapping = PropertyMapping.builder().tfAttrName("strides").onnxAttrName("strides").propertyNames(new String[]{"sW", "sH"}).build();
        // NOTE(review): initFromOnnx reads "pads" and "kernel_shape", not the "padding"/"ksize" names
        // used below — confirm whether these ONNX names are relied upon.
        PropertyMapping paddingMapping = PropertyMapping.builder().onnxAttrName("padding").tfAttrName("padding").propertyNames(new String[]{"pH", "pW"}).build();
        PropertyMapping kernelMapping = PropertyMapping.builder().propertyNames(new String[]{"kH", "kW"}).tfInputPosition(1).onnxAttrName("ksize").build();
        PropertyMapping dilationMapping = PropertyMapping.builder().onnxAttrName("dilations").propertyNames(new String[]{"dW", "dH"}).tfAttrName("rates").build();
        PropertyMapping dataFormatMapping = PropertyMapping.builder().propertyNames(new String[]{"isNHWC"}).tfAttrName("data_format").build();

        byProperty.put("sW", strideMapping);
        byProperty.put("sH", strideMapping);
        byProperty.put("kH", kernelMapping);
        byProperty.put("kW", kernelMapping);
        byProperty.put("dW", dilationMapping);
        byProperty.put("dH", dilationMapping);
        byProperty.put("pH", paddingMapping);
        byProperty.put("pW", paddingMapping);
        byProperty.put("isNHWC", dataFormatMapping);

        byFramework.put(onnxName(), byProperty);
        byFramework.put(tensorflowName(), byProperty);
        return byFramework;
    }

    @Override
    public String onnxName() {
        return "MaxPool";
    }

    @Override
    public String[] tensorflowNames() {
        return new String[]{"MaxPool", "MaxPoolV2"};
    }

    /** The output has the same data type as the single input. */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), list);
        return Collections.singletonList(list.get(0));
    }

    /** Returns a builder for the SameDiff constructor. */
    public static MaxPooling2DBuilder sameDiffBuilder() {
        return new MaxPooling2DBuilder();
    }

    /** Returns the pooling configuration, possibly {@code null} if it has not been set yet. */
    public Pooling2DConfig getConfig() {
        return this.config;
    }
}
