package org.deeplearning4j.models.featuredetectors.autoencoder;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.models.featuredetectors.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/featuredetectors/autoencoder/SemanticHashing.class */
public class SemanticHashing extends BaseMultiLayerNetwork {
    /** Serialization version identifier for this class. */
    private static final long serialVersionUID = -3571832097247806784L;

    /**
     * Encoder network associated with this instance. Within this class it is only
     * assigned through {@link #setEncoder(BaseMultiLayerNetwork)}.
     */
    private BaseMultiLayerNetwork encoder;

    /** Class logger; declared final because it is never reassigned. */
    private static final Logger log = LoggerFactory.getLogger(SemanticHashing.class);

    /* loaded from: input_file:org/deeplearning4j/models/featuredetectors/autoencoder/SemanticHashing$Builder.class */
    /**
     * Fluent builder that assembles a {@link SemanticHashing} network from an existing
     * encoder network supplied through {@link #withEncoder(BaseMultiLayerNetwork)}.
     *
     * <p>Most methods simply delegate to the superclass builder and return this builder so
     * calls can be chained. Method names ending in {@code 2} are decompiler renames of
     * bridge-merged overrides and are kept as-is so existing references keep resolving.
     */
    public static class Builder extends BaseMultiLayerNetwork.Builder<SemanticHashing> {
        /** Source network whose layers are copied and mirrored by {@link #build()}. */
        private BaseMultiLayerNetwork encoder;

        /** Creates a builder whose target class is {@link SemanticHashing}. */
        public Builder() {
            this.clazz = SemanticHashing.class;
        }

        /**
         * Sets the encoder network that {@link #build()} reads its layers from.
         *
         * @param baseMultiLayerNetwork the encoder network
         * @return this builder
         */
        public Builder withEncoder(BaseMultiLayerNetwork baseMultiLayerNetwork) {
            this.encoder = baseMultiLayerNetwork;
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param z value forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: useGaussNewtonVectorProductBackProp, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> useGaussNewtonVectorProductBackProp2(boolean z) {
            super.useGaussNewtonVectorProductBackProp2(z);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param z value forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: useDropConnection, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> useDropConnection2(boolean z) {
            super.useDropConnection2(z);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param map transforms forwarded to the superclass, keyed by integer
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<SemanticHashing> withVisibleBiasTransforms(Map<Integer, MatrixTransform> map) {
            super.withVisibleBiasTransforms(map);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param map transforms forwarded to the superclass, keyed by integer
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<SemanticHashing> withHiddenBiasTransforms(Map<Integer, MatrixTransform> map) {
            super.withHiddenBiasTransforms(map);
            return this;
        }

        /**
         * Delegates to the superclass and returns this builder.
         *
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: forceIterations, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> forceIterations2() {
            super.forceIterations2();
            return this;
        }

        /**
         * Delegates to the superclass and returns this builder.
         *
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: disableBackProp, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> disableBackProp2() {
            super.disableBackProp2();
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param i               index forwarded to the superclass
         * @param matrixTransform transform forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: transformWeightsAt, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> transformWeightsAt2(int i, MatrixTransform matrixTransform) {
            super.transformWeightsAt2(i, matrixTransform);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param map transforms forwarded to the superclass, keyed by integer
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<SemanticHashing> transformWeightsAt(Map<Integer, MatrixTransform> map) {
            super.transformWeightsAt(map);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param numArr hidden layer sizes forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: hiddenLayerSizes, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> hiddenLayerSizes2(Integer[] numArr) {
            super.hiddenLayerSizes2(numArr);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param iArr hidden layer sizes forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: hiddenLayerSizes, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> hiddenLayerSizes2(int[] iArr) {
            super.hiddenLayerSizes2(iArr);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param list per-layer configurations forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<SemanticHashing> layerWiseConfiguration(List<NeuralNetConfiguration> list) {
            super.layerWiseConfiguration(list);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param iNDArray input forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withInput, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> withInput2(INDArray iNDArray) {
            super.withInput2(iNDArray);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param iNDArray labels forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withLabels, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<SemanticHashing> withLabels2(INDArray iNDArray) {
            super.withLabels2(iNDArray);
            return this;
        }

        /**
         * Delegates to the superclass setter and returns this builder.
         *
         * @param cls network class forwarded to the superclass
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<SemanticHashing> withClazz(Class<? extends BaseMultiLayerNetwork> cls) {
            super.withClazz(cls);
            return this;
        }

        /**
         * Returns the result of the superclass {@code buildEmpty()} cast to {@link SemanticHashing}.
         *
         * @return the network produced by the superclass
         */
        /* JADX WARN: Can't rename method to resolve collision */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public SemanticHashing buildEmpty() {
            return (SemanticHashing) super.buildEmpty();
        }

        /**
         * Builds a {@link SemanticHashing} network from the configured encoder.
         *
         * <p>With {@code n} encoder networks the result holds {@code 2n - 1} networks and
         * {@code 2n} layers:
         * <ul>
         *   <li>indices {@code 0..n-1}: an {@link AutoEncoder} built from a clone of each encoder
         *       configuration and duplicated biases, paired with a clone of the encoder layer;
         *       the last of these gets a linear activation function;</li>
         *   <li>indices {@code n..2n-2}: encoder networks {@code n-1} down to {@code 1}, rebuilt
         *       with transposed weights and swapped visible/hidden biases;</li>
         *   <li>final layer: an {@link OutputLayer} built from encoder network {@code 0} with
         *       transposed weights, using the loss function of the encoder's output layer.</li>
         * </ul>
         *
         * @return the assembled network
         * @throws IllegalStateException if no encoder was supplied or the encoder has no networks
         */
        /* JADX WARN: Can't rename method to resolve collision */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public SemanticHashing build() {
            // Fail with a descriptive message instead of a bare NullPointerException /
            // NegativeArraySizeException further down.
            if (this.encoder == null) {
                throw new IllegalStateException("An encoder must be supplied via withEncoder(...) before calling build()");
            }
            final NeuralNetwork[] encoderNets = this.encoder.getNeuralNets();
            final Layer[] encoderLayers = this.encoder.getLayers();
            if (encoderNets == null || encoderNets.length == 0) {
                throw new IllegalStateException("The encoder must contain at least one neural network");
            }

            // Index of the encoder network mirrored next by the decoder half; counts down.
            int mirrorIndex = encoderNets.length - 1;
            NeuralNetwork[] networks = new NeuralNetwork[(encoderNets.length * 2) - 1];
            Layer[] builtLayers = new Layer[networks.length + 1];

            for (int i = 0; i < networks.length; i++) {
                if (i < encoderNets.length) {
                    // Encoder half: copy configuration and biases of encoder network i.
                    // NOTE(review): unlike the decoder branch, no weights are passed to the
                    // AutoEncoder builder here - confirm the encoder weights are meant to be
                    // whatever AutoEncoder.Builder produces by default.
                    AutoEncoder copy = new AutoEncoder.Builder()
                            .configure(encoderNets[i].conf().m21clone())
                            .withVisibleBias2(encoderNets[i].getvBias().dup())
                            .withHBias2(encoderNets[i].gethBias().dup())
                            .build();
                    int rows = copy.getW().rows();
                    int columns = copy.getW().columns();
                    Layer layerCopy = encoderLayers[i].mo26clone();
                    layerCopy.setConfiguration(copy.conf());
                    builtLayers[i] = layerCopy;
                    networks[i] = copy;
                    builtLayers[i].setB(copy.gethBias());
                    builtLayers[i].setW(copy.getW());
                    builtLayers[i].conf().setnIn(rows);
                    builtLayers[i].conf().setnOut(columns);
                    networks[i].conf().setnIn(rows);
                    networks[i].conf().setnOut(columns);
                    if (i == encoderNets.length - 1) {
                        // The last encoder layer uses a linear activation.
                        copy.conf().setActivationFunction(Activations.linear());
                    }
                } else {
                    // Decoder half: mirror encoder network mirrorIndex with transposed weights
                    // and the visible/hidden biases swapped.
                    NeuralNetConfiguration mirroredConf = encoderNets[mirrorIndex].conf().m21clone();
                    AutoEncoder mirrored = new AutoEncoder.Builder()
                            .configure(mirroredConf)
                            .withWeights2(encoderNets[mirrorIndex].getW().transpose())
                            .withVisibleBias2(encoderNets[mirrorIndex].gethBias().dup())
                            .withHBias2(encoderNets[mirrorIndex].getvBias().dup())
                            .build();
                    int rows = mirrored.getW().rows();
                    int columns = mirrored.getW().columns();
                    mirroredConf.setnIn(rows);
                    mirroredConf.setnOut(columns);
                    networks[i] = mirrored;
                    builtLayers[i] = encoderLayers[mirrorIndex].transpose();
                    builtLayers[i].setConfiguration(mirroredConf);
                    builtLayers[i].setB(mirrored.gethBias());
                    builtLayers[i].setW(mirrored.getW());
                    mirrorIndex--;
                }
            }

            // Output layer mirrors encoder network 0.
            // NOTE(review): the configuration passed to configure(...) is not cloned here,
            // unlike the branches above - confirm the setters below are not meant to be
            // isolated from the encoder's own configuration.
            OutputLayer output = new OutputLayer.Builder()
                    .configure(encoderNets[0].conf())
                    .withBias(encoderNets[0].getvBias())
                    .withWeights(encoderNets[0].getW().transpose())
                    .build();
            output.conf().setLossFunction(this.encoder.getOutputLayer().conf().getLossFunction());
            output.conf().setActivationType(NeuralNetConfiguration.ActivationType.HIDDEN_LAYER_ACTIVATION);
            output.conf().setnIn(output.getW().rows());
            output.conf().setnOut(output.getW().columns());
            builtLayers[builtLayers.length - 1] = output;

            SemanticHashing semanticHashing = new SemanticHashing();
            semanticHashing.setLayers(builtLayers);
            semanticHashing.setNeuralNets(networks);
            semanticHashing.setDefaultConfiguration(this.conf);
            semanticHashing.setUseDropConnect(this.encoder.isUseDropConnect());
            semanticHashing.setUseGaussNewtonVectorProductBackProp(this.encoder.isUseGaussNewtonVectorProductBackProp());
            semanticHashing.setSampleFromHiddenActivations(this.encoder.isSampleFromHiddenActivations());
            semanticHashing.setForceNumEpochs(this.shouldForceEpochs);

            // Collect the per-layer configurations in layer order; the first one also
            // becomes the default configuration.
            List<NeuralNetConfiguration> layerConfigurations =
                    new ArrayList<NeuralNetConfiguration>(semanticHashing.layers.length);
            for (int i = 0; i < semanticHashing.layers.length; i++) {
                layerConfigurations.add(semanticHashing.layers[i].conf());
            }
            semanticHashing.setLayerWiseConfigurations(layerConfigurations);
            semanticHashing.setDefaultConfiguration(layerConfigurations.get(0));
            semanticHashing.dimensionCheck();
            return semanticHashing;
        }

        /** Bridge method: forwards to {@link #withClazz(Class)} with a cast of the raw argument. */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withClazz, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<SemanticHashing> withClazz2(Class cls) {
            return withClazz((Class<? extends BaseMultiLayerNetwork>) cls);
        }

        /** Bridge method: forwards to {@link #transformWeightsAt(Map)} with a cast of the raw argument. */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: transformWeightsAt, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<SemanticHashing> transformWeightsAt2(Map map) {
            return transformWeightsAt((Map<Integer, MatrixTransform>) map);
        }

        /** Bridge method: forwards to {@link #withHiddenBiasTransforms(Map)} with a cast of the raw argument. */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withHiddenBiasTransforms, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<SemanticHashing> withHiddenBiasTransforms2(Map map) {
            return withHiddenBiasTransforms((Map<Integer, MatrixTransform>) map);
        }

        /** Bridge method: forwards to {@link #withVisibleBiasTransforms(Map)} with a cast of the raw argument. */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withVisibleBiasTransforms, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<SemanticHashing> withVisibleBiasTransforms2(Map map) {
            return withVisibleBiasTransforms((Map<Integer, MatrixTransform>) map);
        }

        /** Bridge method: forwards to {@link #layerWiseConfiguration(List)} with a cast of the raw argument. */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: layerWiseConfiguration, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<SemanticHashing> layerWiseConfiguration2(List list) {
            return layerWiseConfiguration((List<NeuralNetConfiguration>) list);
        }
    }

    /**
     * Not supported by this class; always throws.
     *
     * @param iNDArray unused
     * @param objArr   unused
     * @throws IllegalStateException always, with the message "Not implemented"
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public void pretrain(INDArray iNDArray, Object[] objArr) {
        throw new IllegalStateException("Not implemented");
    }

    /**
     * Not supported by this class; always throws. All five arguments are ignored.
     *
     * @return never returns normally
     * @throws IllegalStateException always, with the message "Not implemented"
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public NeuralNetwork createLayer(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, int i) {
        throw new IllegalStateException("Not implemented");
    }

    /**
     * Produces one array per layer (including the output layer) from the results of
     * {@link #feedForward()} and {@link #feedForwardR(List, INDArray)}.
     *
     * <p>The running term starts as the last {@code feedForwardR} entry divided by the number
     * of input rows. Layers are then visited from the top index down to zero: each layer's
     * result is the transposed activation multiplied by the running term, and for every
     * layer above zero the running term is pushed through that layer's weights (with the
     * bias added as a row vector, transposed) and scaled in place by the derivative of the
     * activation function of the layer below.
     *
     * @param iNDArray passed through to {@code feedForwardR}
     * @return the per-layer arrays in layer order; entries whose layer configuration
     *         constrains gradients to unit norm are divided by their {@code norm2}
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public List<INDArray> computeDeltasR(INDArray iNDArray) {
        INDArray[] perLayer = new INDArray[getnLayers() + 1];
        List<INDArray> activations = feedForward();
        List<INDArray> rActivations = feedForwardR(activations, iNDArray);

        // Gather weights, hidden biases and activation functions of every layer,
        // hidden layers first and the output layer last.
        List<INDArray> weights = new ArrayList<INDArray>();
        List<INDArray> biases = new ArrayList<INDArray>();
        List<ActivationFunction> activationFunctions = new ArrayList<ActivationFunction>();
        for (NeuralNetwork network : getNeuralNets()) {
            weights.add(network.getW());
            biases.add(network.gethBias());
            activationFunctions.add(((AutoEncoder) network).conf().getActivationFunction());
        }
        weights.add(getOutputLayer().getW());
        biases.add(getOutputLayer().getB());
        activationFunctions.add(getOutputLayer().conf().getActivationFunction());

        INDArray running = rActivations.get(rActivations.size() - 1).div(Integer.valueOf(this.input.rows()));
        for (int layer = perLayer.length - 1; layer >= 0; layer--) {
            perLayer[layer] = activations.get(layer).transpose().mmul(running);
            applyDropConnectIfNecessary(perLayer[layer]);
            if (layer == 0) {
                continue;
            }
            INDArray biasedWeightsT = weights.get(layer).addRowVector(biases.get(layer)).transpose();
            INDArray propagated = running.mmul(biasedWeightsT);
            INDArray derivative = activationFunctions.get(layer - 1).applyDerivative(activations.get(layer));
            running = propagated.muli(derivative);
        }

        List<INDArray> result = new ArrayList<INDArray>();
        for (int layer = 0; layer < perLayer.length; layer++) {
            INDArray value = perLayer[layer];
            boolean unitNorm = this.layerWiseConfigurations.get(layer).isConstrainGradientToUnitNorm();
            result.add(unitNorm ? value.div(value.norm2(Integer.MAX_VALUE)) : value);
        }
        return result;
    }

    /**
     * Builds a list of arrays, one more than the number of layers, starting from a zero
     * matrix shaped like the network input.
     *
     * <p>For every hidden network the next entry is
     * {@code (previous * W + (activation * (first + second as row vector) + 1)) * derivative},
     * where {@code first}/{@code second} come from {@code unPack(iNDArray)} and the
     * derivative is that network's activation function applied to the next activation.
     * The output layer uses the same form with the last entries of each list.
     * NOTE(review): the hidden-layer term adds a constant 1 that the output-layer term
     * does not - confirm this asymmetry is intended.
     *
     * @param list     activations; one entry per layer plus the input
     *                 (presumably the result of {@code feedForward()} - confirm against callers)
     * @param iNDArray array handed to {@code unPack} to obtain one pair per layer
     * @return the computed list, in layer order
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public List<INDArray> feedForwardR(List<INDArray> list, INDArray iNDArray) {
        List<INDArray> result = new ArrayList<INDArray>();
        result.add(Nd4j.zeros(this.input.rows(), this.input.columns()));
        List<Pair<INDArray, INDArray>> unpacked = unPack(iNDArray);
        List<INDArray> weights = weightMatrices();

        for (int layer = 0; layer < this.neuralNets.length; layer++) {
            INDArray carried = result.get(layer).mmul(weights.get(layer));
            Pair<INDArray, INDArray> pair = unpacked.get(layer);
            INDArray injected = list.get(layer).mmul(pair.getFirst().addRowVector(pair.getSecond())).add(1);
            INDArray combined = carried.add(injected);
            AutoEncoder network = (AutoEncoder) getNeuralNets()[layer];
            INDArray derivative = network.conf().getActivationFunction().applyDerivative(list.get(layer + 1));
            result.add(combined.mul(derivative));
        }

        // Output layer: same form, using the last entry of each list and no constant term.
        INDArray carriedOut = result.get(result.size() - 1).mmul(weights.get(weights.size() - 1));
        Pair<INDArray, INDArray> lastPair = unpacked.get(unpacked.size() - 1);
        INDArray injectedOut = list.get(list.size() - 2).mmul(lastPair.getFirst().addRowVector(lastPair.getSecond()));
        INDArray combinedOut = carriedOut.add(injectedOut);
        INDArray derivativeOut = getOutputLayer().conf().getActivationFunction().applyDerivative(list.get(list.size() - 1));
        result.add(combinedOut.mul(derivativeOut));
        return result;
    }

    /**
     * Not supported by this class; always throws.
     *
     * @param dataSetIterator unused
     * @param objArr          unused
     * @throws UnsupportedOperationException always, with the message "Not implemented"
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public void pretrain(DataSetIterator dataSetIterator, Object[] objArr) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Allocates an array of the requested length; every element is {@code null}.
     *
     * @param i length of the array to allocate
     * @return a new, unpopulated {@link NeuralNetwork} array of length {@code i}
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public NeuralNetwork[] createNetworkLayers(int i) {
        return new NeuralNetwork[i];
    }

    /**
     * Uses the given array as both the network input and the labels, then runs the
     * superclass fine-tuning with it.
     *
     * @param iNDArray array installed as input and as labels before delegating
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public void finetune(INDArray iNDArray) {
        // The field is assigned directly and again through the setter; both are kept
        // because the setter's behaviour is defined in the superclass.
        this.input = iNDArray;
        setInput(iNDArray);
        setLabels(iNDArray);
        super.finetune(iNDArray);
    }

    /**
     * Produces one pair of arrays per layer (including the output layer) from the
     * activations returned by {@link #feedForward()}.
     *
     * <p>The running term starts as {@code (lastActivation - labels) / labelRows}. Layers
     * are visited from the top down. For each layer the first array is the transposed
     * activation multiplied by the running term; the second is the element-wise square of
     * the transposed activation multiplied by the element-wise square of a copy of the
     * running term, scaled in place by the number of label rows. For every layer above
     * zero the running term is then pushed through that layer's transposed weights and
     * scaled in place by the derivative of the activation function of the layer below.
     *
     * @return the pairs in layer order; where the layer configuration constrains gradients
     *         to unit norm, the first array is divided in place by its {@code norm2}
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public List<Pair<INDArray, INDArray>> computeDeltas2() {
        List<INDArray> activations = feedForward();
        int slots = activations.size() - 1;
        INDArray[] firstTerms = new INDArray[slots];
        INDArray[] squaredTerms = new INDArray[slots];
        INDArray running = activations.get(activations.size() - 1)
                .sub(this.labels)
                .divi(Integer.valueOf(this.labels.rows()));

        // Gather weights, hidden biases and activation functions of every layer,
        // hidden layers first and the output layer last.
        List<INDArray> weights = new ArrayList<INDArray>();
        List<INDArray> biases = new ArrayList<INDArray>();
        List<ActivationFunction> activationFunctions = new ArrayList<ActivationFunction>();
        for (NeuralNetwork network : getNeuralNets()) {
            weights.add(network.getW());
            biases.add(network.gethBias());
            activationFunctions.add(((AutoEncoder) network).conf().getActivationFunction());
        }
        biases.add(getOutputLayer().getB());
        weights.add(getOutputLayer().getW());
        activationFunctions.add(getOutputLayer().conf().getActivationFunction());

        for (int layer = weights.size() - 1; layer >= 0; layer--) {
            firstTerms[layer] = activations.get(layer).transpose().mmul(running);
            INDArray squaredActivation = Transforms.pow(activations.get(layer).transpose(), 2);
            // The running term is copied before squaring so it is not modified here.
            INDArray squaredRunning = Transforms.pow(running.dup(), 2);
            squaredTerms[layer] = squaredActivation.mmul(squaredRunning).muli(Integer.valueOf(this.labels.rows()));
            applyDropConnectIfNecessary(firstTerms[layer]);
            if (layer == 0) {
                continue;
            }
            INDArray propagated = running.mmul(weights.get(layer).transpose());
            INDArray derivative = activationFunctions.get(layer - 1).applyDerivative(activations.get(layer));
            running = propagated.muli(derivative);
        }

        List<Pair<INDArray, INDArray>> result = new ArrayList<Pair<INDArray, INDArray>>();
        for (int layer = 0; layer < firstTerms.length; layer++) {
            INDArray first = firstTerms[layer];
            if (this.layerWiseConfigurations.get(layer).isConstrainGradientToUnitNorm()) {
                first = first.divi(first.norm2(Integer.MAX_VALUE));
            }
            result.add(new Pair<INDArray, INDArray>(first, squaredTerms[layer]));
        }
        return result;
    }

    /**
     * Returns the encoder network stored on this instance.
     *
     * @return the stored encoder, or {@code null} if {@link #setEncoder} has not been called
     */
    public BaseMultiLayerNetwork getEncoder() {
        return this.encoder;
    }

    /**
     * Stores the given encoder network on this instance.
     *
     * @param baseMultiLayerNetwork the encoder to store
     */
    public void setEncoder(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this.encoder = baseMultiLayerNetwork;
    }

    /**
     * Stub implementation: ignores the data set and always returns zero.
     *
     * @param dataSet unused
     * @return always {@code 0.0}
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork, org.deeplearning4j.nn.api.Classifier
    public double score(DataSet dataSet) {
        return 0.0d;
    }

    /**
     * Stub implementation: always returns zero.
     *
     * @return always {@code 0}
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork, org.deeplearning4j.nn.api.Classifier
    public int numLabels() {
        return 0;
    }

    /**
     * Stub implementation: ignores its argument and always returns {@code null}.
     * Callers must be prepared for a {@code null} result.
     *
     * @param iNDArray unused
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork, org.deeplearning4j.nn.api.Classifier
    public INDArray labelProbabilities(INDArray iNDArray) {
        return null;
    }

    /**
     * Wraps the two arrays in a {@code org.nd4j.linalg.dataset.DataSet} and delegates to
     * {@link #fit(DataSet)}.
     *
     * @param iNDArray  first constructor argument of the data set
     * @param iNDArray2 second constructor argument of the data set
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork, org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, INDArray iNDArray2) {
        fit((DataSet) new org.nd4j.linalg.dataset.DataSet(iNDArray, iNDArray2));
    }

    /**
     * Assigns the data set's feature matrix to the input field, then fine-tunes on the
     * data set's labels.
     *
     * @param dataSet source of the feature matrix and labels
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork, org.deeplearning4j.nn.api.Classifier
    public void fit(DataSet dataSet) {
        // NOTE(review): finetune(INDArray) in this class reassigns this.input to its
        // argument, so the feature matrix stored here is replaced by the labels before
        // training - confirm that is intended.
        this.input = dataSet.getFeatureMatrix();
        finetune(dataSet.getLabels());
    }

    /**
     * Intentionally empty: this overload performs no training and ignores both arguments.
     *
     * @param iNDArray unused
     * @param iArr     unused
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork, org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, int[] iArr) {
    }

    /**
     * Intentionally empty: this overload performs no training and ignores both arguments.
     *
     * @param iNDArray unused
     * @param objArr   unused
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray, Object[] objArr) {
    }
}
