package org.deeplearning4j.nn.conf.layers;

import java.util.Map;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.CapsuleUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/CapsuleLayer.class */
/**
 * Capsule layer built on {@link SameDiffLayer}, using iterative routing-by-agreement between
 * input capsules and output capsules.
 *
 * <p>Input must be of type {@link InputType.Type#RNN}. The recurrent "size" is taken as the number
 * of input capsules and the time-series length as the input capsule dimensionality, unless both
 * were configured explicitly (see {@link #setNIn(InputType, boolean)}). The output type is
 * {@code InputType.recurrent(capsules, capsuleDimensions)}.
 *
 * <p>Parameters:
 * <ul>
 *   <li>{@code "weight"}: shape {@code [1, inputCapsules, capsules * capsuleDimensions,
 *       inputCapsuleDimensions, 1]}</li>
 *   <li>{@code "bias"} (only when {@code hasBias}): shape
 *       {@code [1, 1, capsules, capsuleDimensions, 1]}, initialised to zero</li>
 * </ul>
 */
public class CapsuleLayer extends SameDiffLayer {
    private static final String WEIGHT_PARAM = "weight";
    private static final String BIAS_PARAM = "bias";

    /** Whether the bias parameter is added to the weighted prediction sum before squashing. */
    private boolean hasBias = false;
    /** Number of input capsules; values <= 0 are replaced from the input type in setNIn. */
    private long inputCapsules = 0;
    /** Dimensionality of each input capsule; values <= 0 are replaced from the input type in setNIn. */
    private long inputCapsuleDimensions = 0;
    /** Number of output capsules. */
    private int capsules;
    /** Dimensionality of each output capsule. */
    private int capsuleDimensions;
    /** Number of routing iterations; must be > 0. */
    private int routings = 3;

    /** Builder for {@link CapsuleLayer}. */
    public static class Builder extends SameDiffLayer.Builder<Builder> {
        private int capsules;
        private int capsuleDimensions;
        private int routings = 3;
        private boolean hasBias = false;
        private int inputCapsules = 0;
        private int inputCapsuleDimensions = 0;

        /**
         * Creates a builder with 3 routing iterations.
         *
         * @param capsules          number of output capsules
         * @param capsuleDimensions dimensionality of each output capsule
         */
        public Builder(int capsules, int capsuleDimensions) {
            this(capsules, capsuleDimensions, 3);
        }

        /**
         * Creates a builder.
         *
         * @param capsules          number of output capsules
         * @param capsuleDimensions dimensionality of each output capsule
         * @param routings          number of routing iterations
         */
        public Builder(int capsules, int capsuleDimensions, int routings) {
            setCapsules(capsules);
            setCapsuleDimensions(capsuleDimensions);
            setRoutings(routings);
        }

        /**
         * Builds the layer configuration.
         *
         * @throws IllegalArgumentException if the configured sizes are invalid (see the
         *                                  {@link CapsuleLayer#CapsuleLayer(Builder)} constructor)
         */
        @Override
        @SuppressWarnings("unchecked") // callers choose E; this builder only ever produces a CapsuleLayer
        public <E extends Layer> E build() {
            return (E) new CapsuleLayer(this);
        }

        /** Sets the number of output capsules. */
        public Builder capsules(int capsules) {
            setCapsules(capsules);
            return this;
        }

        /** Sets the dimensionality of each output capsule. */
        public Builder capsuleDimensions(int capsuleDimensions) {
            setCapsuleDimensions(capsuleDimensions);
            return this;
        }

        /** Sets the number of routing iterations. */
        public Builder routings(int routings) {
            setRoutings(routings);
            return this;
        }

        /** Sets the number of input capsules explicitly (otherwise inferred in setNIn). */
        public Builder inputCapsules(int inputCapsules) {
            setInputCapsules(inputCapsules);
            return this;
        }

        /** Sets the input capsule dimensionality explicitly (otherwise inferred in setNIn). */
        public Builder inputCapsuleDimensions(int inputCapsuleDimensions) {
            setInputCapsuleDimensions(inputCapsuleDimensions);
            return this;
        }

        /**
         * Sets both input sizes at once.
         *
         * @param inputShape values validated by {@code ValidationUtils.validate2NonNegative}; the
         *                   first is used as inputCapsules, the second as inputCapsuleDimensions
         */
        public Builder inputShape(int... inputShape) {
            int[] validated = ValidationUtils.validate2NonNegative(inputShape, false, "inputShape");
            setInputCapsules(validated[0]);
            setInputCapsuleDimensions(validated[1]);
            return this;
        }

        /** Sets whether a bias parameter is used. */
        public Builder hasBias(boolean hasBias) {
            setHasBias(hasBias);
            return this;
        }

        public int getCapsules() {
            return this.capsules;
        }

        public int getCapsuleDimensions() {
            return this.capsuleDimensions;
        }

        public int getRoutings() {
            return this.routings;
        }

        public boolean isHasBias() {
            return this.hasBias;
        }

        public int getInputCapsules() {
            return this.inputCapsules;
        }

        public int getInputCapsuleDimensions() {
            return this.inputCapsuleDimensions;
        }

        public void setCapsules(int capsules) {
            this.capsules = capsules;
        }

        public void setCapsuleDimensions(int capsuleDimensions) {
            this.capsuleDimensions = capsuleDimensions;
        }

        public void setRoutings(int routings) {
            this.routings = routings;
        }

        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }

        public void setInputCapsules(int inputCapsules) {
            this.inputCapsules = inputCapsules;
        }

        public void setInputCapsuleDimensions(int inputCapsuleDimensions) {
            this.inputCapsuleDimensions = inputCapsuleDimensions;
        }
    }

    /**
     * Creates the layer from a builder.
     *
     * @throws IllegalArgumentException if capsules, capsuleDimensions or routings is not > 0, or
     *                                  if inputCapsules or inputCapsuleDimensions is negative
     */
    public CapsuleLayer(Builder builder) {
        super(builder);
        this.hasBias = builder.hasBias;
        this.inputCapsules = builder.inputCapsules;
        this.inputCapsuleDimensions = builder.inputCapsuleDimensions;
        this.capsules = builder.capsules;
        this.capsuleDimensions = builder.capsuleDimensions;
        this.routings = builder.routings;
        if (this.capsules <= 0 || this.capsuleDimensions <= 0 || this.routings <= 0) {
            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): capsules, capsuleDimensions, and routings must be > 0.  Got: " + this.capsules + ", " + this.capsuleDimensions + ", " + this.routings);
        }
        if (this.inputCapsules < 0 || this.inputCapsuleDimensions < 0) {
            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): inputCapsules and inputCapsuleDimensions must be >= 0 if set.  Got: " + this.inputCapsules + ", " + this.inputCapsuleDimensions);
        }
    }

    /**
     * No-arg constructor leaving the field defaults in place. No validation is performed here.
     * NOTE(review): presumably used by configuration deserialization; confirm.
     */
    public CapsuleLayer() {
    }

    /**
     * Infers inputCapsules and inputCapsuleDimensions from an RNN input type.
     *
     * <p>If either value is still <= 0, both are taken from the input type: size becomes
     * inputCapsules and time-series length becomes inputCapsuleDimensions.
     *
     * @throws IllegalStateException if {@code inputType} is null or not an RNN type
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for Capsule layer (layer name = \"" + this.layerName + "\"): expect RNN input.  Got: " + inputType);
        }
        // NOTE(review): `override` is ignored; explicitly configured sizes are never replaced.
        if (this.inputCapsules <= 0 || this.inputCapsuleDimensions <= 0) {
            InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
            this.inputCapsules = recurrent.getSize();
            this.inputCapsuleDimensions = recurrent.getTimeSeriesLength();
        }
    }

    /**
     * Builds the routing graph.
     *
     * <p>The layer runs {@code routings - 1} passes that update the routing logits by agreement.
     * A final pass then produces the output capsules.
     *
     * @return output of shape {@code [mb, capsules, capsuleDimensions]}
     * @throws IllegalStateException if routings is not > 0 (possible via setRoutings or the
     *                               no-arg constructor); previously this silently returned null
     */
    @Override
    public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
        if (this.routings <= 0) {
            throw new IllegalStateException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): routings must be > 0.  Got: " + this.routings);
        }
        // Assumes layerInput is [mb, inputCapsules, inputCapsuleDimensions] (RNN layout) - TODO confirm.
        // expanded: [mb, inputCapsules, 1, inputCapsuleDimensions, 1]
        SDVariable expanded = sd.expandDims(sd.expandDims(layerInput, 2), 4);
        // tiled: [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1]
        SDVariable tiled = sd.tile(expanded, 1, 1, this.capsules * this.capsuleDimensions, 1, 1);
        // Prediction vectors: (weight * tiled) summed over axis 3, reshaped to
        // [mb, inputCapsules, capsules, capsuleDimensions, 1].
        SDVariable uHat = paramTable.get(WEIGHT_PARAM).times(tiled).sum(true, 3)
                .reshape(-1, this.inputCapsules, this.capsules, this.capsuleDimensions, 1);
        // Routing logits start at zero: [mb, inputCapsules, capsules, 1, 1].
        SDVariable logits = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(),
                SDIndex.interval((Integer) 0, (Integer) 1), SDIndex.interval((Integer) 0, (Integer) 1));

        for (int i = 0; i < this.routings - 1; i++) {
            SDVariable v = routingStep(sd, logits, uHat, paramTable);
            // Agreement: tile v over the input-capsule axis, then logits += sum over axis 3 of uHat * v.
            SDVariable vTiled = sd.tile(v, 1, (int) this.inputCapsules, 1, 1, 1);
            logits = logits.plus(uHat.times(vTiled).sum(true, 3));
        }

        SDVariable output = routingStep(sd, logits, uHat, paramTable);
        // Drop the singleton axes 1 and 3: [mb, 1, capsules, capsuleDimensions, 1] -> [mb, capsules, capsuleDimensions].
        return sd.squeeze(sd.squeeze(output, 1), 3);
    }

    /**
     * Runs one routing pass.
     *
     * <p>Steps: coupling coefficients from the logits, weighted sum of predictions over input
     * capsules, optional bias, then squash.
     *
     * @return squashed capsule activations of shape {@code [mb, 1, capsules, capsuleDimensions, 1]}
     */
    private SDVariable routingStep(SameDiff sd, SDVariable logits, SDVariable uHat, Map<String, SDVariable> paramTable) {
        // NOTE(review): CapsuleUtils.softmax(sd, x, 2, 5) presumably means softmax over axis 2
        // (output capsules) of a rank-5 tensor; confirm.
        SDVariable coupling = CapsuleUtils.softmax(sd, logits, 2, 5);
        SDVariable weighted = coupling.times(uHat).sum(true, 1);
        if (this.hasBias) {
            weighted = weighted.plus(paramTable.get(BIAS_PARAM));
        }
        // NOTE(review): squash along axis 3 (capsuleDimensions), per the CapsuleUtils argument; confirm.
        return CapsuleUtils.squash(sd, weighted, 3);
    }

    /** Declares the weight parameter and, if hasBias is set, the bias parameter. */
    @Override
    public void defineParameters(SDLayerParams sDLayerParams) {
        sDLayerParams.clear();
        sDLayerParams.addWeightParam(WEIGHT_PARAM, 1, this.inputCapsules, this.capsules * this.capsuleDimensions, this.inputCapsuleDimensions, 1);
        if (this.hasBias) {
            sDLayerParams.addBiasParam(BIAS_PARAM, 1, 1, this.capsules, this.capsuleDimensions, 1);
        }
    }

    /**
     * Initialises the parameters outside any active workspace.
     *
     * <p>The bias is set to zero. The weight is initialised with the layer's weight init, using
     * fanIn = inputCapsules * inputCapsuleDimensions and fanOut = capsules * capsuleDimensions.
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        // try-with-resources skips close() on a null resource, matching the previous null check.
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                if (BIAS_PARAM.equals(entry.getKey())) {
                    entry.getValue().assign((Number) 0);
                } else if (WEIGHT_PARAM.equals(entry.getKey())) {
                    WeightInitUtil.initWeights(
                            this.inputCapsules * this.inputCapsuleDimensions,
                            this.capsules * this.capsuleDimensions,
                            new long[]{1, this.inputCapsules, this.capsules * this.capsuleDimensions, this.inputCapsuleDimensions, 1},
                            this.weightInit,
                            (Distribution) null,
                            'c',
                            entry.getValue());
                }
            }
        }
    }

    /** Returns {@code InputType.recurrent(capsules, capsuleDimensions)} regardless of the input. */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        return InputType.recurrent(this.capsules, this.capsuleDimensions);
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    public long getInputCapsules() {
        return this.inputCapsules;
    }

    public long getInputCapsuleDimensions() {
        return this.inputCapsuleDimensions;
    }

    public int getCapsules() {
        return this.capsules;
    }

    public int getCapsuleDimensions() {
        return this.capsuleDimensions;
    }

    public int getRoutings() {
        return this.routings;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public void setInputCapsules(long inputCapsules) {
        this.inputCapsules = inputCapsules;
    }

    public void setInputCapsuleDimensions(long inputCapsuleDimensions) {
        this.inputCapsuleDimensions = inputCapsuleDimensions;
    }

    public void setCapsules(int capsules) {
        this.capsules = capsules;
    }

    public void setCapsuleDimensions(int capsuleDimensions) {
        this.capsuleDimensions = capsuleDimensions;
    }

    public void setRoutings(int routings) {
        this.routings = routings;
    }

    @Override
    public String toString() {
        return "CapsuleLayer(hasBias=" + isHasBias() + ", inputCapsules=" + getInputCapsules() + ", inputCapsuleDimensions=" + getInputCapsuleDimensions() + ", capsules=" + getCapsules() + ", capsuleDimensions=" + getCapsuleDimensions() + ", routings=" + getRoutings() + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CapsuleLayer)) {
            return false;
        }
        CapsuleLayer other = (CapsuleLayer) obj;
        return other.canEqual(this)
                && super.equals(obj)
                && isHasBias() == other.isHasBias()
                && getInputCapsules() == other.getInputCapsules()
                && getInputCapsuleDimensions() == other.getInputCapsuleDimensions()
                && getCapsules() == other.getCapsules()
                && getCapsuleDimensions() == other.getCapsuleDimensions()
                && getRoutings() == other.getRoutings();
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof CapsuleLayer;
    }

    @Override
    public int hashCode() {
        // Same 59-multiplier scheme (and therefore same values) as before.
        int result = super.hashCode();
        result = result * 59 + (isHasBias() ? 79 : 97);
        result = result * 59 + Long.hashCode(getInputCapsules());
        result = result * 59 + Long.hashCode(getInputCapsuleDimensions());
        result = result * 59 + getCapsules();
        result = result * 59 + getCapsuleDimensions();
        result = result * 59 + getRoutings();
        return result;
    }
}
