package de.datexis.cdv.tagger;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.L2NormalizeVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.util.Objects;

/* loaded from: input_file:de/datexis/cdv/tagger/CDVModelBuilder.class */
public class CDVModelBuilder {
    public static ComputationGraph buildSingleTaskCDV(long j, long j2, long j3, long j4, long j5, double d, double d2, double d3, ILossFunction iLossFunction, Activation activation) {
        ComputationGraph computationGraph = new ComputationGraph(new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Adam(new ExponentialSchedule(ScheduleType.EPOCH, d, 0.95d))).weightInit(WeightInit.XAVIER).weightDecay(d3).dropOut(0.0d).trainingWorkspaceMode(WorkspaceMode.ENABLED).inferenceWorkspaceMode(WorkspaceMode.ENABLED).cacheMode(CacheMode.HOST).graphBuilder().addInputs(new String[]{"input"}).addInputs(new String[]{"position"}).addVertex("sentence", new MergeVertex(), new String[]{"input", "position"}).addLayer("BLSTM", new Bidirectional(Bidirectional.Mode.CONCAT, new LSTM.Builder().nIn(j + j2).nOut(j3).activation(Activation.TANH).gateActivationFunction(Activation.SIGMOID).dropOut(d2).build()), new String[]{"sentence"}).addLayer("bottleneck", new DenseLayer.Builder().nIn(2 * j3).nOut(j4).activation(Activation.TANH).build(), new String[]{"BLSTM"}).addVertex("embedding", new L2NormalizeVertex(new int[]{1}, 1.0E-8d), new String[]{"bottleneck"}).addLayer("target", new RnnOutputLayer.Builder(iLossFunction).nIn(j4).nOut(j5).activation(activation).build(), new String[]{"embedding"}).setOutputs(new String[]{"target"}).setInputTypes(new InputType[]{InputType.recurrent(j), InputType.recurrent(j2)}).backpropType(BackpropType.Standard).build());
        computationGraph.init();
        computationGraph.setListeners(new TrainingListener[]{new PerformanceListener(16, true)});
        return computationGraph;
    }

    public static ComputationGraph buildMultiTaskCDV(long j, long j2, long j3, long j4, long j5, long j6, double d, double d2, double d3, ILossFunction iLossFunction, Activation activation) {
        ComputationGraph computationGraph = new ComputationGraph(new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Adam(new ExponentialSchedule(ScheduleType.EPOCH, d, 0.975d))).weightInit(WeightInit.XAVIER).weightDecay(d3).dropOut(0.0d).trainingWorkspaceMode(WorkspaceMode.ENABLED).inferenceWorkspaceMode(WorkspaceMode.ENABLED).cacheMode(CacheMode.HOST).graphBuilder().addInputs(new String[]{"input"}).addInputs(new String[]{"position"}).addVertex("sentence", new MergeVertex(), new String[]{"input", "position"}).addLayer("BLSTM", new Bidirectional(Bidirectional.Mode.CONCAT, new LSTM.Builder().nIn(j + j2).nOut(j3).activation(Activation.TANH).gateActivationFunction(Activation.SIGMOID).dropOut(d2).build()), new String[]{"sentence"}).addLayer("embedding", new DenseLayer.Builder().nIn(2 * j3).nOut(j4).activation(Activation.TANH).build(), new String[]{"BLSTM"}).addLayer("dense_entity", new DenseLayer.Builder().nIn(j4).nOut(j5).activation(Activation.TANH).build(), new String[]{"embedding"}).addLayer("dense_aspect", new DenseLayer.Builder().nIn(j4).nOut(j6).activation(Activation.TANH).build(), new String[]{"embedding"}).addVertex("emb_entity", new L2NormalizeVertex(new int[]{1}, 1.0E-8d), new String[]{"dense_entity"}).addVertex("emb_aspect", new L2NormalizeVertex(new int[]{1}, 1.0E-8d), new String[]{"dense_aspect"}).addLayer("entity", new RnnOutputLayer.Builder(iLossFunction).nIn(j5).nOut(j5).activation(activation).build(), new String[]{"emb_entity"}).addLayer("aspect", new RnnOutputLayer.Builder(iLossFunction).nIn(j6).nOut(j6).activation(activation).build(), new String[]{"emb_aspect"}).setOutputs(new String[]{"entity", "aspect"}).setInputTypes(new InputType[]{InputType.recurrent(j), InputType.recurrent(j2)}).backpropType(BackpropType.Standard).build());
        computationGraph.init();
        computationGraph.setListeners(new TrainingListener[]{new PerformanceListener(32, true), new ScoreIterationListener(4)});
        return computationGraph;
    }
}
