package org.deeplearning4j.nn.conf.graph;

import org.deeplearning4j.clustering.kdtree.KDTree;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Objects;

/* loaded from: input_file:org/deeplearning4j/nn/conf/graph/PreprocessorVertex.class */
/**
 * Graph-vertex configuration that applies a single {@link InputPreProcessor} to exactly one input.
 *
 * <p>The vertex's output type is either an explicit override (constructor argument or
 * {@link #setOutputType(InputType)}) or, when no override is set, is inferred from the input type
 * and the concrete class of the configured preprocessor.
 */
public class PreprocessorVertex extends GraphVertex {
    /** Preprocessor applied to the vertex input; may be null only on a not-yet-configured vertex. */
    private InputPreProcessor preProcessor;
    /** Optional explicit output type; when non-null it takes precedence over inference. */
    private InputType outputType;

    /** Creates a vertex with no preprocessor and no output-type override. */
    public PreprocessorVertex() {
    }

    /**
     * Creates a vertex whose output type is inferred from its input.
     *
     * @param inputPreProcessor the preprocessor to apply
     */
    public PreprocessorVertex(InputPreProcessor inputPreProcessor) {
        this(inputPreProcessor, null);
    }

    /**
     * Creates a vertex with an optional explicit output type.
     *
     * @param inputPreProcessor the preprocessor to apply
     * @param inputType         output type to report from {@link #getOutputType(InputType...)};
     *                          {@code null} means "infer from the input type"
     */
    public PreprocessorVertex(InputPreProcessor inputPreProcessor, InputType inputType) {
        this.preProcessor = inputPreProcessor;
        this.outputType = inputType;
    }

    /**
     * Returns a copy of this vertex configuration.
     *
     * <p>The preprocessor is copied through its own clone method. The output-type override is carried
     * over by reference; it used to be dropped, which made the clone's
     * {@link #getOutputType(InputType...)} disagree with the original's.
     *
     * @return a new vertex with a cloned preprocessor and the same output-type override
     */
    @Override // org.deeplearning4j.nn.conf.graph.GraphVertex
    /* renamed from: clone */
    public GraphVertex mo42clone() {
        InputPreProcessor copy = (this.preProcessor == null) ? null : this.preProcessor.m57clone();
        return new PreprocessorVertex(copy, this.outputType);
    }

    /**
     * Two vertices are equal when both their preprocessors and their output-type overrides are equal.
     * Null-safe for either field.
     */
    @Override // org.deeplearning4j.nn.conf.graph.GraphVertex
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof PreprocessorVertex)) {
            return false;
        }
        PreprocessorVertex other = (PreprocessorVertex) obj;
        return Objects.equals(this.preProcessor, other.preProcessor)
                && Objects.equals(this.outputType, other.outputType);
    }

    /** Hash over the same fields compared by {@link #equals(Object)}; null-safe. */
    @Override // org.deeplearning4j.nn.conf.graph.GraphVertex
    public int hashCode() {
        return Objects.hash(this.preProcessor, this.outputType);
    }

    /**
     * This vertex has no trainable parameters.
     *
     * @param z ignored
     * @return always 0
     */
    @Override // org.deeplearning4j.nn.conf.graph.GraphVertex
    public int numParams(boolean z) {
        return 0;
    }

    /**
     * Builds the runtime vertex that applies this configuration's preprocessor.
     *
     * @param computationGraph the owning graph
     * @param str              vertex name
     * @param i                vertex index
     * @param iNDArray         parameter view; unused because this vertex has no parameters
     * @return the runtime preprocessor vertex
     */
    @Override // org.deeplearning4j.nn.conf.graph.GraphVertex
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph computationGraph, String str, int i, INDArray iNDArray) {
        return new org.deeplearning4j.nn.graph.vertex.impl.PreprocessorVertex(computationGraph, str, i, this.preProcessor);
    }

    /**
     * Determines the output type of this vertex.
     *
     * <p>If an explicit output type was configured it is returned as-is. Otherwise the result depends
     * on the input kind and on which known preprocessor class is configured. Any other preprocessor
     * leaves the input kind unchanged.
     *
     * @param inputTypeArr exactly one input type
     * @return the output type
     * @throws InvalidInputTypeException if the number of inputs is not exactly one
     * @throws RuntimeException          if the input's type is not FF, RNN or CNN
     */
    @Override // org.deeplearning4j.nn.conf.graph.GraphVertex
    public InputType getOutputType(InputType... inputTypeArr) throws InvalidInputTypeException {
        if (inputTypeArr.length != 1) {
            throw new InvalidInputTypeException("Invalid input: Preprocessor vertex expects exactly one input");
        }
        if (this.outputType != null) {
            return this.outputType;
        }
        InputType input = inputTypeArr[0];
        // Casts to the concrete InputType subclasses are done only in the branches that need them,
        // so a CNN-producing preprocessor never casts its input.
        switch (input.getType()) {
            case FF:
                if (this.preProcessor instanceof FeedForwardToCnnPreProcessor) {
                    FeedForwardToCnnPreProcessor ffToCnn = (FeedForwardToCnnPreProcessor) this.preProcessor;
                    return InputType.convolutional(ffToCnn.getNumChannels(), ffToCnn.getInputWidth(), ffToCnn.getInputHeight());
                }
                if (this.preProcessor instanceof FeedForwardToRnnPreProcessor) {
                    return InputType.recurrent(((InputType.InputTypeFeedForward) input).getSize());
                }
                return InputType.feedForward(((InputType.InputTypeFeedForward) input).getSize());
            case RNN:
                if (this.preProcessor instanceof RnnToCnnPreProcessor) {
                    RnnToCnnPreProcessor rnnToCnn = (RnnToCnnPreProcessor) this.preProcessor;
                    return InputType.convolutional(rnnToCnn.getNumChannels(), rnnToCnn.getInputWidth(), rnnToCnn.getInputHeight());
                }
                if (this.preProcessor instanceof RnnToFeedForwardPreProcessor) {
                    return InputType.feedForward(((InputType.InputTypeRecurrent) input).getSize());
                }
                return InputType.recurrent(((InputType.InputTypeRecurrent) input).getSize());
            case CNN:
                if (this.preProcessor instanceof CnnToFeedForwardPreProcessor) {
                    CnnToFeedForwardPreProcessor cnnToFf = (CnnToFeedForwardPreProcessor) this.preProcessor;
                    return InputType.feedForward(cnnToFf.getInputHeight() * cnnToFf.getInputWidth() * cnnToFf.getNumChannels());
                }
                if (this.preProcessor instanceof CnnToRnnPreProcessor) {
                    CnnToRnnPreProcessor cnnToRnn = (CnnToRnnPreProcessor) this.preProcessor;
                    return InputType.recurrent(cnnToRnn.getInputHeight() * cnnToRnn.getInputWidth() * cnnToRnn.getNumChannels());
                }
                return input;
            default:
                throw new RuntimeException("Unknown InputType: " + input);
        }
    }

    /** @return the configured preprocessor (may be null on an unconfigured vertex) */
    public InputPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /** @param inputPreProcessor the preprocessor to apply */
    public void setPreProcessor(InputPreProcessor inputPreProcessor) {
        this.preProcessor = inputPreProcessor;
    }

    /** @param inputType explicit output type, or {@code null} to infer it from the input */
    public void setOutputType(InputType inputType) {
        this.outputType = inputType;
    }

    /**
     * Describes the configuration. Reads the {@code outputType} field directly; calling
     * {@code getOutputType()} with zero inputs always throws {@link InvalidInputTypeException}.
     */
    @Override
    public String toString() {
        return "PreprocessorVertex(preProcessor=" + this.preProcessor + ", outputType=" + this.outputType + ")";
    }
}
