package org.jetbrains.kotlinx.dl.api.core;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.MapsKt;
import kotlin.io.CloseableKt;
import kotlin.io.FilesKt;
import kotlin.jdk7.AutoCloseableKt;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.LongSpreadBuilder;
import kotlin.jvm.internal.Ref;
import kotlin.jvm.internal.Reflection;
import kotlin.ranges.RangesKt;
import kotlin.text.Charsets;
import kotlin.text.StringsKt;
import mu.KLogger;
import mu.KotlinLogging;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.callback.Callback;
import org.jetbrains.kotlinx.dl.api.core.exception.RepeatableLayerNameException;
import org.jetbrains.kotlinx.dl.api.core.history.BatchEvent;
import org.jetbrains.kotlinx.dl.api.core.history.History;
import org.jetbrains.kotlinx.dl.api.core.history.TrainingHistory;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariable;
import org.jetbrains.kotlinx.dl.api.core.layer.Layer;
import org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayerKt;
import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayerKt;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer;
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense;
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input;
import org.jetbrains.kotlinx.dl.api.core.loss.LossFunction;
import org.jetbrains.kotlinx.dl.api.core.loss.Losses;
import org.jetbrains.kotlinx.dl.api.core.loss.SoftmaxCrossEntropyWithLogits;
import org.jetbrains.kotlinx.dl.api.core.metric.EvaluationResult;
import org.jetbrains.kotlinx.dl.api.core.metric.Metric;
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics;
import org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer;
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer;
import org.jetbrains.kotlinx.dl.api.core.shape.ShapeFunctionsKt;
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape;
import org.jetbrains.kotlinx.dl.api.core.summary.LayerSummary;
import org.jetbrains.kotlinx.dl.api.core.summary.TfModelSummary;
import org.jetbrains.kotlinx.dl.api.core.util.ConvertersKt;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.jetbrains.kotlinx.dl.api.core.util.NameConventionsKt;
import org.jetbrains.kotlinx.dl.api.core.util.TensorNamesKt;
import org.jetbrains.kotlinx.dl.api.extension.TensorExtensionFunctionsKt;
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel;
import org.jetbrains.kotlinx.dl.api.inference.keras.ModelSaverKt;
import org.jetbrains.kotlinx.dl.dataset.DataBatch;
import org.jetbrains.kotlinx.dl.dataset.Dataset;
import org.jetbrains.kotlinx.dl.impl.util.AutoClosableExtensionsKt;
import org.jetbrains.kotlinx.dl.impl.util.FloatArrayExtensionFunctionsKt;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;
import org.tensorflow.op.NnOps;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.nn.Softmax;

/* compiled from: GraphTrainableModel.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��\u0098\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0016\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\b\u0005\n\u0002\u0010$\n\u0002\u0010\u000e\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0002\b\r\n\u0002\u0010\u000b\n\u0002\b\u0006\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0010\u0014\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u001e\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0015\n\u0002\b\u0007\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\u0007\b&\u0018�� \u009e\u00012\u00020\u0001:\u0002\u009e\u0001B\u0019\u0012\u0012\u0010\u0002\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u00040\u0003\"\u00020\u0004¢\u0006\u0002\u0010\u0005J \u00104\u001a\u0002052\u0006\u00106\u001a\u0002072\u0006\u00108\u001a\u00020\u00072\u0006\u00109\u001a\u00020\u0007H\u0002J<\u0010:\u001a\u001a\u0012\n\u0012\b\u0012\u0004\u0012\u00020 0<\u0012\n\u0012\b\u0012\u0004\u0012\u00020 0\u001f0;2\f\u0010-\u001a\b\u0012\u0004\u0012\u00020.0\u001f2\f\u0010=\u001a\b\u0012\u0004\u0012\u00020 0\u001fH$J\u0016\u0010>\u001a\b\u0012\u0004\u0012\u00020 
0\u001f2\u0006\u0010?\u001a\u00020@H\u0002J\u0010\u0010A\u001a\u00020\u00072\u0006\u0010B\u001a\u00020CH\u0002J\u001c\u0010D\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070;2\u0006\u00106\u001a\u000207H\u0002J\u0010\u0010E\u001a\u00020\u00072\u0006\u0010B\u001a\u00020CH\u0002J&\u0010F\u001a\u0002052\u0006\u0010G\u001a\u00020H2\u0006\u0010?\u001a\u00020@2\f\u0010I\u001a\b\u0012\u0004\u0012\u00020J0\u000eH\u0016J \u0010F\u001a\u0002052\u0006\u0010G\u001a\u00020H2\u0006\u0010?\u001a\u00020@2\u0006\u0010K\u001a\u00020JH\u0016J \u0010F\u001a\u0002052\u0006\u0010G\u001a\u00020H2\u0006\u0010?\u001a\u00020@2\u0006\u0010K\u001a\u00020LH\u0016J \u0010F\u001a\u0002052\u0006\u0010G\u001a\u00020H2\u0006\u0010?\u001a\u00020M2\u0006\u0010K\u001a\u00020JH\u0016J \u0010F\u001a\u0002052\u0006\u0010G\u001a\u00020H2\u0006\u0010?\u001a\u00020M2\u0006\u0010K\u001a\u00020LH\u0016J&\u0010N\u001a\u00020O2\u0006\u0010P\u001a\u00020Q2\u0006\u0010B\u001a\u00020C2\f\u0010R\u001a\b\u0012\u0004\u0012\u00020S0\u000eH\u0016J\u001d\u0010T\u001a\u0002052\u0006\u0010U\u001a\u00020V2\u0006\u0010W\u001a\u00020XH��¢\u0006\u0002\bYJ.\u0010Z\u001a\u00020[2\u0006\u0010P\u001a\u00020Q2\u0006\u0010\\\u001a\u00020C2\u0006\u0010B\u001a\u00020C2\f\u0010R\u001a\b\u0012\u0004\u0012\u00020S0\u000eH\u0016J>\u0010Z\u001a\u00020[2\u0006\u0010]\u001a\u00020Q2\u0006\u0010^\u001a\u00020Q2\u0006\u0010\\\u001a\u00020C2\u0006\u0010_\u001a\u00020C2\u0006\u0010`\u001a\u00020C2\f\u0010R\u001a\b\u0012\u0004\u0012\u00020S0\u000eH\u0016J0\u0010a\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030b0\u000e2\u0006\u0010c\u001a\u00020\u00152\f\u0010d\u001a\b\u0012\u0004\u0012\u00020 0b2\u0006\u0010e\u001a\u00020.H\u0002J\u000e\u0010f\u001a\b\u0012\u0004\u0012\u00020V0\u000eH\u0002J\u0011\u0010g\u001a\u00020\u00042\u0006\u0010h\u001a\u00020\u0015H\u0086\u0004J,\u0010i\u001a\u001e\u0012\u001a\u0012\u0018\u0012\n\u0012\b\u0012\u0004\u0012\u00020 
0j\u0012\b\u0012\u0006\u0012\u0002\b\u00030b0;0\u000e2\u0006\u0010k\u001a\u00020.H\u0002J\u0006\u0010l\u001a\u000205J\u0015\u0010l\u001a\u0002052\u0006\u0010U\u001a\u00020VH��¢\u0006\u0002\bmJO\u0010n\u001a\u00020[2\u0006\u0010_\u001a\u00020C2\u0006\u0010\\\u001a\u00020C2\u0006\u0010]\u001a\u00020Q2\u0006\u0010o\u001a\u00020.2\b\u0010^\u001a\u0004\u0018\u00010Q2\b\u0010`\u001a\u0004\u0018\u00010C2\f\u0010p\u001a\b\u0012\u0004\u0012\u00020S0\u000eH\u0002¢\u0006\u0002\u0010qJ0\u0010r\u001a\u0012\u0012\u0004\u0012\u00020s\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u000e0;2\u0006\u0010t\u001a\u00020s2\u0006\u0010e\u001a\u00020.2\u0006\u0010c\u001a\u00020\u0015H\u0002J\u0010\u0010u\u001a\u00020.2\u0006\u0010v\u001a\u00020\u0015H\u0002J\u0006\u0010w\u001a\u00020xJ\u000e\u0010y\u001a\b\u0012\u0004\u0012\u00020V0\u000eH\u0002J0\u0010z\u001a\u0002052\f\u0010{\u001a\b\u0012\u0004\u0012\u00020\u00150|2\u0018\u0010}\u001a\u0014\u0012\u0004\u0012\u00020\u0015\u0012\u0004\u0012\u00020\u007f\u0012\u0004\u0012\u00020X0~H\u0014J\u001c\u0010\u0080\u0001\u001a\u0002052\b\u0010\u0081\u0001\u001a\u00030\u0082\u00012\u0007\u0010\u0083\u0001\u001a\u00020.H\u0016J\u0011\u0010\u0084\u0001\u001a\u00020C2\u0006\u0010t\u001a\u00020sH\u0016J\u0019\u0010\u0084\u0001\u001a\u00020C2\u0006\u0010t\u001a\u00020s2\u0006\u0010c\u001a\u00020\u0015H\u0016J(\u0010\u0084\u0001\u001a\u00030\u0085\u00012\u0006\u0010P\u001a\u00020Q2\u0006\u0010B\u001a\u00020C2\f\u0010R\u001a\b\u0012\u0004\u0012\u00020S0\u000eH\u0016J)\u0010\u0086\u0001\u001a\u0012\u0012\u0004\u0012\u00020C\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u000e0;2\u0006\u0010t\u001a\u00020s2\u0006\u0010c\u001a\u00020\u0015H\u0016J\u0019\u0010\u0087\u0001\u001a\u00020s2\u0006\u0010t\u001a\u00020s2\u0006\u0010c\u001a\u00020\u0015H\u0016J3\u0010\u0087\u0001\u001a\b\u0012\u0004\u0012\u00020s0\u00032\u0006\u0010P\u001a\u00020Q2\u0006\u0010B\u001a\u00020C2\f\u0010R\u001a\b\u0012\u0004\u0012\u00020S0\u000eH\u0016¢\u0006\u0003\u0010\u0088\u0001J)\u0010\u0089\u0
001\u001a\u0012\u0012\u0004\u0012\u00020s\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u000e0;2\u0006\u0010t\u001a\u00020s2\u0006\u0010c\u001a\u00020\u0015H\u0014J\u0007\u0010\u008a\u0001\u001a\u000205J/\u0010\u008b\u0001\u001a\u0002052\b\u0010\u0081\u0001\u001a\u00030\u0082\u00012\b\u0010\u008c\u0001\u001a\u00030\u008d\u00012\u0006\u0010k\u001a\u00020.2\b\u0010\u008e\u0001\u001a\u00030\u008f\u0001H\u0016J\u0012\u0010\u0090\u0001\u001a\u0002052\u0007\u0010\u0091\u0001\u001a\u00020\u0015H\u0002J\u001a\u0010\u0092\u0001\u001a\u0002052\u0007\u0010\u0091\u0001\u001a\u00020\u00152\u0006\u0010k\u001a\u00020.H\u0002J\u0012\u0010\u0093\u0001\u001a\u0002052\u0007\u0010\u0091\u0001\u001a\u00020\u0015H\u0002J\u001a\u0010\u0094\u0001\u001a\u0002052\u0007\u0010\u0091\u0001\u001a\u00020\u00152\u0006\u0010k\u001a\u00020.H\u0002J\u0012\u0010\u0095\u0001\u001a\u0002052\u0007\u0010\u0091\u0001\u001a\u00020\u0015H\u0002J\u001a\u0010\u0096\u0001\u001a\u0002052\u0007\u0010\u0091\u0001\u001a\u00020\u00152\u0006\u0010k\u001a\u00020.H\u0004J\n\u0010\u0097\u0001\u001a\u00030\u0098\u0001H\u0016J\t\u0010\u0099\u0001\u001a\u00020\u0015H\u0016J~\u0010\u009a\u0001\u001a\u0014\u0012\u0004\u0012\u00020 \u0012\n\u0012\b\u0012\u0004\u0012\u00020 0\u000e0;2\u0012\u0010*\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020 0\u001f0\u000e2\r\u0010\u009b\u0001\u001a\b\u0012\u0004\u0012\u00020 0b2\r\u0010\u009c\u0001\u001a\b\u0012\u0004\u0012\u00020 0b2\f\u0010=\u001a\b\u0012\u0004\u0012\u00020 0b2\r\u0010\u009d\u0001\u001a\b\u0012\u0004\u0012\u00020 0b2\u0012\u0010%\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020 0\u001f0\u000eH\u0002R\u0014\u0010\u0006\u001a\u00020\u00078VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\b\u0010\tR\u0011\u0010\n\u001a\u00020\u000b8F¢\u0006\u0006\u001a\u0004\b\f\u0010\rR 
\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u000eX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u000f\u0010\u0010\"\u0004\b\u0011\u0010\u0012R&\u0010\u0013\u001a\u000e\u0012\u0004\u0012\u00020\u0015\u0012\u0004\u0012\u00020\u00040\u0014X\u0084\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0016\u0010\u0017\"\u0004\b\u0018\u0010\u0019R\u0011\u0010\u001a\u001a\u00020\u001b¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u001dR \u0010\u001e\u001a\b\u0012\u0004\u0012\u00020 0\u001fX\u0084.¢\u0006\u000e\n��\u001a\u0004\b!\u0010\"\"\u0004\b#\u0010$R\u001a\u0010%\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020 0\u001f0\u000eX\u0082.¢\u0006\u0002\n��R \u0010&\u001a\b\u0012\u0004\u0012\u00020 0\u001fX\u0084.¢\u0006\u000e\n��\u001a\u0004\b'\u0010\"\"\u0004\b(\u0010$R\u0014\u0010)\u001a\b\u0012\u0004\u0012\u00020 0\u001fX\u0082.¢\u0006\u0002\n��R&\u0010*\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020 0\u001f0\u000eX\u0084.¢\u0006\u000e\n��\u001a\u0004\b+\u0010\u0010\"\u0004\b,\u0010\u0012R \u0010-\u001a\b\u0012\u0004\u0012\u00020.0\u001fX\u0084.¢\u0006\u000e\n��\u001a\u0004\b/\u0010\"\"\u0004\b0\u0010$R\u0014\u00101\u001a\b\u0012\u0004\u0012\u00020 0\u001fX\u0082.¢\u0006\u0002\n��R\u0014\u00102\u001a\b\u0012\u0004\u0012\u00020 0\u001fX\u0082.¢\u0006\u0002\n��R\u0014\u00103\u001a\b\u0012\u0004\u0012\u00020 0\u001fX\u0082.¢\u0006\u0002\n��¨\u0006\u009f\u0001"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/GraphTrainableModel;", "Lorg/jetbrains/kotlinx/dl/api/core/TrainableModel;", "layers", "", "Lorg/jetbrains/kotlinx/dl/api/core/layer/Layer;", "([Lorg/jetbrains/kotlinx/dl/api/core/layer/Layer;)V", "inputDimensions", "", "getInputDimensions", "()[J", "inputLayer", "Lorg/jetbrains/kotlinx/dl/api/core/layer/core/Input;", "getInputLayer", "()Lorg/jetbrains/kotlinx/dl/api/core/layer/core/Input;", "", "getLayers", "()Ljava/util/List;", "setLayers", "(Ljava/util/List;)V", "layersByName", "", "", "getLayersByName", "()Ljava/util/Map;", "setLayersByName", "(Ljava/util/Map;)V", "logger", 
"Lmu/KLogger;", "getLogger", "()Lmu/KLogger;", "lossOp", "Lorg/tensorflow/Operand;", "", "getLossOp", "()Lorg/tensorflow/Operand;", "setLossOp", "(Lorg/tensorflow/Operand;)V", "metricOps", "numberOfLossesOp", "getNumberOfLossesOp", "setNumberOfLossesOp", "predictionOp", "targets", "getTargets", "setTargets", "training", "", "getTraining", "setTraining", "xOp", "yPredOp", "yTrueOp", "batchValidation", "", "batch", "Lorg/jetbrains/kotlinx/dl/dataset/DataBatch;", "xBatchShape", "yBatchShape", "buildLayers", "Lkotlin/Pair;", "Lorg/tensorflow/op/core/Placeholder;", "numberOfLosses", "buildLossFunction", "loss", "Lorg/jetbrains/kotlinx/dl/api/core/loss/LossFunction;", "calculateXShape", "batchSize", "", "calculateXYShapes", "calculateYShape", "compile", "optimizer", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer;", "metrics", "Lorg/jetbrains/kotlinx/dl/api/core/metric/Metric;", "metric", "Lorg/jetbrains/kotlinx/dl/api/core/metric/Metrics;", "Lorg/jetbrains/kotlinx/dl/api/core/loss/Losses;", "evaluate", "Lorg/jetbrains/kotlinx/dl/api/core/metric/EvaluationResult;", "dataset", "Lorg/jetbrains/kotlinx/dl/dataset/Dataset;", "callbacks", "Lorg/jetbrains/kotlinx/dl/api/core/callback/Callback;", "fill", "variable", "Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "data", "", "fill$tensorflow", "fit", "Lorg/jetbrains/kotlinx/dl/api/core/history/TrainingHistory;", "epochs", "trainingDataset", "validationDataset", "trainBatchSize", "validationBatchSize", "formPredictionAndActivationsTensors", "Lorg/tensorflow/Tensor;", "predictionTensorName", "testImages", "visualizationIsEnabled", "frozenLayerVariables", "getLayer", "layerName", "getVariablesAndTensors", "Lorg/tensorflow/op/core/Variable;", "saveOptimizerState", "init", "init$tensorflow", "internalFit", "validationIsEnabled", "fitCallbacks", 
"(IILorg/jetbrains/kotlinx/dl/dataset/Dataset;ZLorg/jetbrains/kotlinx/dl/dataset/Dataset;Ljava/lang/Integer;Ljava/util/List;)Lorg/jetbrains/kotlinx/dl/api/core/history/TrainingHistory;", "internalPredict", "", "inputData", "isVariableRelatedToFrozenLayer", "variableName", "kGraph", "Lorg/jetbrains/kotlinx/dl/api/core/KGraph;", "layerVariables", "loadVariables", "variableNames", "", "getData", "Lkotlin/Function2;", "Lorg/tensorflow/Shape;", "loadWeights", "modelDirectory", "Ljava/io/File;", "loadOptimizerState", "predict", "", "predictAndGetActivations", "predictSoftly", "(Lorg/jetbrains/kotlinx/dl/dataset/Dataset;ILjava/util/List;)[[F", "predictSoftlyAndGetActivations", "reset", "save", "savingFormat", "Lorg/jetbrains/kotlinx/dl/api/core/SavingFormat;", "writingMode", "Lorg/jetbrains/kotlinx/dl/api/core/WritingMode;", "saveGraphDef", "pathToModelDirectory", "saveInKerasFormat", "saveInSavedModelFormat", "saveInSimpleFormat", "saveModel", "saveVariables", "summary", "Lorg/jetbrains/kotlinx/dl/api/core/summary/TfModelSummary;", "toString", "trainOnBatch", "batchImages", "batchLabels", "isTraining", "Companion", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.class */
public abstract class GraphTrainableModel extends TrainableModel {

    /** Single shared instance of the nested {@link Companion} helper class. */
    @NotNull
    public static final Companion Companion = new Companion(null);

    /** Logger obtained from {@code KotlinLogging} in the constructor. */
    @NotNull
    private final KLogger logger;

    /** Layers of the model; element 0 is read and cast to {@link Input} by the input accessors. */
    @NotNull
    private List<? extends Layer> layers;

    /** Layers keyed by layer name; filled in by the constructor, which rejects duplicate names. */
    @NotNull
    private Map<String, ? extends Layer> layersByName;
    // The operand fields below have no initializer. They are assigned by the list-based
    // compile(...) overload; the getters visible in this class throw via
    // Intrinsics.throwUninitializedPropertyAccessException when a field is still null.
    // Second component of buildLayers(...).
    private Operand<Float> yPredOp;
    // Result of buildLossFunction(...), stored through setLossOp.
    protected Operand<Float> lossOp;
    // Softmax or identity over yPredOp, created under the OUTPUT_NAME name.
    private Operand<Float> predictionOp;
    // One operand per metric passed to compile(...).
    private List<? extends Operand<Float>> metricOps;
    // Result of optimizer.prepareTargets$tensorflow(...).
    protected List<? extends Operand<Float>> targets;
    // First component of buildLayers(...).
    private Operand<Float> xOp;
    // Placeholder created in compile(...) without an explicit shape option.
    private Operand<Float> yTrueOp;
    // Scalar placeholder named "numberOfLosses".
    protected Operand<Float> numberOfLossesOp;
    // Scalar Boolean placeholder named "training".
    protected Operand<Boolean> training;

    /* compiled from: GraphTrainableModel.kt */
    @Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��&\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0011\n\u0002\b\u0003\b\u0080\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u001b\u0010\u0003\u001a\u00020\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006H��¢\u0006\u0002\b\bJ\u001f\u0010\t\u001a\u00020\u00042\u000e\u0010\u0005\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u00070\nH��¢\u0006\u0004\b\u000b\u0010\f¨\u0006\r"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/GraphTrainableModel$Companion;", "", "()V", "layerValidation", "", "layers", "", "Lorg/jetbrains/kotlinx/dl/api/core/layer/Layer;", "layerValidation$tensorflow", "preProcessLayerNames", "", "preProcessLayerNames$tensorflow", "([Lorg/jetbrains/kotlinx/dl/api/core/layer/Layer;)V", "tensorflow"})
    /* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel$Companion.class */
    /* Static helpers used while assembling a model from its layers. */
    public static final class Companion {
        private Companion() {
        }

        /**
         * Assigns a generated name to every layer whose name is empty.
         *
         * <p>The generated name is the lower-cased simple class name of the layer
         * (or {@code "layer"} when the class has no simple name), followed by
         * {@code '_'} and the 1-based position of the layer in the array.
         * Layers that already have a non-empty name are left untouched.
         *
         * @param layerArr layers to process; must not be {@code null}
         */
        public final void preProcessLayerNames$tensorflow(@NotNull Layer[] layerArr) {
            Intrinsics.checkNotNullParameter(layerArr, "layers");
            for (int index = 0; index < layerArr.length; index++) {
                Layer layer = layerArr[index];
                if (!layer.getName().isEmpty()) {
                    continue;
                }
                String simpleName = Reflection.getOrCreateKotlinClass(layer.getClass()).getSimpleName();
                if (simpleName == null) {
                    simpleName = "layer";
                }
                // Locale.ROOT keeps the generated identifier independent of the JVM's
                // default locale (e.g. a Turkish locale would map 'I' to a dotless 'ı').
                layer.setName(simpleName.toLowerCase(Locale.ROOT) + '_' + (index + 1));
            }
        }

        /**
         * Checks the basic structural requirements of a layer list.
         *
         * @param list layers of the model; must not be {@code null}
         * @throws IllegalArgumentException if the list is empty or its first
         *         element is not an {@link Input} layer
         */
        public final void layerValidation$tensorflow(@NotNull List<? extends Layer> list) {
            Intrinsics.checkNotNullParameter(list, "layers");
            if (list.isEmpty()) {
                throw new IllegalArgumentException("Model should contain layers!");
            }
            if (!(list.get(0) instanceof Input)) {
                throw new IllegalArgumentException("Model should start from the Input layer");
            }
        }

        /* Synthetic constructor emitted by the Kotlin compiler; the marker argument is ignored. */
        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /* compiled from: GraphTrainableModel.kt */
    @Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = 3, xi = 48)
    /* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel$WhenMappings.class */
    /* Compiler-generated lookup tables: enum ordinal -> case number (1-based) for switches over
       WritingMode and SavingFormat. */
    public /* synthetic */ class WhenMappings {
        public static final /* synthetic */ int[] $EnumSwitchMapping$0;
        public static final /* synthetic */ int[] $EnumSwitchMapping$1;

        static {
            // Table indexed by WritingMode ordinal.
            final int[] writingModeCases = new int[WritingMode.values().length];
            writingModeCases[WritingMode.FAIL_IF_EXISTS.ordinal()] = 1;
            writingModeCases[WritingMode.OVERRIDE.ordinal()] = 2;
            writingModeCases[WritingMode.APPEND.ordinal()] = 3;
            $EnumSwitchMapping$0 = writingModeCases;

            // Table indexed by SavingFormat ordinal.
            final int[] savingFormatCases = new int[SavingFormat.values().length];
            savingFormatCases[SavingFormat.TF_GRAPH_CUSTOM_VARIABLES.ordinal()] = 1;
            savingFormatCases[SavingFormat.TF_GRAPH.ordinal()] = 2;
            savingFormatCases[SavingFormat.JSON_CONFIG_CUSTOM_VARIABLES.ordinal()] = 3;
            $EnumSwitchMapping$1 = savingFormatCases;
        }
    }

    /**
     * Creates a model from the given layers.
     *
     * <p>Stores the layers, indexes them by name, registers this model as the parent of
     * every layer, and then creates the {@code KGraph}, the {@code Ops} builder and the
     * {@code Session} used by the model.
     *
     * @param layerArr layers of the model; must not be {@code null}
     * @throws RepeatableLayerNameException if two layers share the same name
     */
    public GraphTrainableModel(@NotNull Layer... layerArr) {
        Intrinsics.checkNotNullParameter(layerArr, "layers");
        // The logger is requested with an empty lambda (decompiled form of `KotlinLogging.logger {}`).
        this.logger = KotlinLogging.INSTANCE.logger(new Function0<Unit>() { // from class: org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel$logger$1
            public final void invoke() {
            }

            /* renamed from: invoke, reason: collision with other method in class */
            public /* bridge */ /* synthetic */ Object m11invoke() {
                invoke();
                return Unit.INSTANCE;
            }
        });
        // Copy the varargs array so the stored list does not wrap the caller's array.
        this.layers = CollectionsKt.listOf(Arrays.copyOf(layerArr, layerArr.length));
        this.layersByName = MapsKt.emptyMap();
        for (final Layer layer : layerArr) {
            // Layer names act as map keys, so a duplicate name is rejected.
            if (this.layersByName.containsKey(layer.getName())) {
                throw new RepeatableLayerNameException(layer.getName());
            }
            // MapsKt.plus returns a new map on every iteration; the field is re-assigned each time.
            this.layersByName = MapsKt.plus(this.layersByName, TuplesKt.to(layer.getName(), layer));
            // A layer that already has a parent model only triggers a warning; it is re-parented below.
            if (layer.getParentModel() != null) {
                this.logger.warn(new Function0<Object>() { // from class: org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel.1
                    {
                        super(0);
                    }

                    @Nullable
                    public final Object invoke() {
                        return "Layer " + Layer.this.getName() + " is a part of model " + Layer.this.getParentModel();
                    }
                });
            }
            layer.setParentModel(this);
        }
        // NOTE(review): the temporary Graph created here is never closed in this constructor;
        // presumably it holds native resources — confirm whether this leaks.
        byte[] graphDef = new Graph().toGraphDef();
        Intrinsics.checkNotNullExpressionValue(graphDef, "Graph().toGraphDef()");
        setKGraph(new KGraph(graphDef));
        // Ops builder and Session are both bound to the graph owned by the KGraph.
        Ops create = Ops.create(getKGraph().getTfGraph$tensorflow());
        Intrinsics.checkNotNullExpressionValue(create, "create(kGraph.tfGraph)");
        setTf(create);
        setSession$tensorflow(new Session(getKGraph().getTfGraph$tensorflow()));
    }

    /**
     * Returns the logger of this model.
     *
     * @return the logger created in the constructor
     */
    @NotNull
    public final KLogger getLogger() {
        return logger;
    }

    /**
     * Returns the layers of this model.
     *
     * @return the currently stored layer list
     */
    @NotNull
    public final List<Layer> getLayers() {
        return layers;
    }

    /**
     * Replaces the stored layer list.
     *
     * @param list new layer list; must not be {@code null}
     */
    public final void setLayers(@NotNull List<? extends Layer> list) {
        Intrinsics.checkNotNullParameter(list, "<set-?>");
        layers = list;
    }

    /**
     * Returns the first layer of the model as an {@link Input} layer.
     *
     * @return element 0 of the layer list, cast to {@link Input}
     */
    @NotNull
    public final Input getInputLayer() {
        final Layer first = this.layers.get(0);
        Intrinsics.checkNotNull(first, "null cannot be cast to non-null type org.jetbrains.kotlinx.dl.api.core.layer.core.Input");
        final Input input = (Input) first;
        return input;
    }

    /**
     * Returns the packed dimensions of the input layer.
     *
     * @return the result of {@code getPackedDims()} on the first layer, which is read as an
     *         {@link Input} layer
     */
    @Override // org.jetbrains.kotlinx.dl.api.inference.TensorFlowInferenceModel
    @NotNull
    public long[] getInputDimensions() {
        // getInputLayer() performs the same element-0 lookup, null check and cast.
        return getInputLayer().getPackedDims();
    }

    /**
     * Returns the layers keyed by their names.
     *
     * @return the currently stored name-to-layer map
     */
    @NotNull
    protected final Map<String, Layer> getLayersByName() {
        return layersByName;
    }

    /**
     * Replaces the stored name-to-layer map.
     *
     * @param map new map; must not be {@code null}
     */
    protected final void setLayersByName(@NotNull Map<String, ? extends Layer> map) {
        Intrinsics.checkNotNullParameter(map, "<set-?>");
        layersByName = map;
    }

    /**
     * Returns the loss operand.
     *
     * <p>When the field has not been assigned yet, the failure is reported through
     * {@code Intrinsics.throwUninitializedPropertyAccessException("lossOp")}.
     *
     * @return the stored loss operand
     */
    @NotNull
    protected final Operand<Float> getLossOp() {
        final Operand<Float> current = this.lossOp;
        if (current == null) {
            Intrinsics.throwUninitializedPropertyAccessException("lossOp");
        }
        return current;
    }

    /**
     * Stores the loss operand.
     *
     * @param operand new loss operand; must not be {@code null}
     */
    protected final void setLossOp(@NotNull Operand<Float> operand) {
        Intrinsics.checkNotNullParameter(operand, "<set-?>");
        lossOp = operand;
    }

    /**
     * Returns the target operands.
     *
     * <p>When the field has not been assigned yet, the failure is reported through
     * {@code Intrinsics.throwUninitializedPropertyAccessException("targets")}.
     *
     * @return the stored list of target operands
     */
    @NotNull
    protected final List<Operand<Float>> getTargets() {
        final List<? extends Operand<Float>> current = this.targets;
        if (current == null) {
            Intrinsics.throwUninitializedPropertyAccessException("targets");
        }
        return current;
    }

    /**
     * Stores the target operands.
     *
     * @param list new list of target operands; must not be {@code null}
     */
    protected final void setTargets(@NotNull List<? extends Operand<Float>> list) {
        Intrinsics.checkNotNullParameter(list, "<set-?>");
        targets = list;
    }

    /**
     * Returns the "number of losses" operand.
     *
     * <p>When the field has not been assigned yet, the failure is reported through
     * {@code Intrinsics.throwUninitializedPropertyAccessException("numberOfLossesOp")}.
     *
     * <p>Decompiler note (JADX): access modifiers changed from: protected.
     *
     * @return the stored operand
     */
    @NotNull
    public final Operand<Float> getNumberOfLossesOp() {
        final Operand<Float> current = this.numberOfLossesOp;
        if (current == null) {
            Intrinsics.throwUninitializedPropertyAccessException("numberOfLossesOp");
        }
        return current;
    }

    /**
     * Stores the "number of losses" operand.
     *
     * @param operand new operand; must not be {@code null}
     */
    protected final void setNumberOfLossesOp(@NotNull Operand<Float> operand) {
        Intrinsics.checkNotNullParameter(operand, "<set-?>");
        numberOfLossesOp = operand;
    }

    /**
     * Returns the Boolean "training" operand.
     *
     * <p>When the field has not been assigned yet, the failure is reported through
     * {@code Intrinsics.throwUninitializedPropertyAccessException("training")}.
     *
     * @return the stored operand
     */
    @NotNull
    protected final Operand<Boolean> getTraining() {
        final Operand<Boolean> current = this.training;
        if (current == null) {
            Intrinsics.throwUninitializedPropertyAccessException("training");
        }
        return current;
    }

    /**
     * Stores the Boolean "training" operand.
     *
     * @param operand new operand; must not be {@code null}
     */
    protected final void setTraining(@NotNull Operand<Boolean> operand) {
        Intrinsics.checkNotNullParameter(operand, "<set-?>");
        training = operand;
    }

    /** Returns the result of {@code ParametrizedLayerKt.variables} applied to the model's layers. */
    private final List<KVariable> layerVariables() {
        final List<KVariable> all = ParametrizedLayerKt.variables(this.layers);
        return all;
    }

    /** Returns the result of {@code ParametrizedLayerKt.frozenVariables} applied to the model's layers. */
    private final List<KVariable> frozenLayerVariables() {
        final List<KVariable> frozen = ParametrizedLayerKt.frozenVariables(this.layers);
        return frozen;
    }

    /**
     * Compiles the model from enum-style loss and metric identifiers.
     *
     * <p>Both identifiers are converted through their companion {@code convert} helpers and
     * the call is forwarded to the overload taking a {@link LossFunction} and a {@link Metric}.
     *
     * @param optimizer optimizer to use; must not be {@code null}
     * @param losses    loss identifier; must not be {@code null}
     * @param metrics   metric identifier; must not be {@code null}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public void compile(@NotNull Optimizer optimizer, @NotNull Losses losses, @NotNull Metrics metrics) {
        Intrinsics.checkNotNullParameter(optimizer, "optimizer");
        Intrinsics.checkNotNullParameter(losses, "loss");
        Intrinsics.checkNotNullParameter(metrics, "metric");
        final LossFunction convertedLoss = Losses.Companion.convert(losses);
        final Metric convertedMetric = Metric.Companion.convert(metrics);
        compile(optimizer, convertedLoss, convertedMetric);
    }

    /**
     * Compiles the model with a single metric.
     *
     * <p>The metric is wrapped into a one-element list and the call is forwarded to the
     * list-based overload.
     *
     * @param optimizer    optimizer to use; must not be {@code null}
     * @param lossFunction loss function; must not be {@code null}
     * @param metric       the only metric to evaluate; must not be {@code null}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public void compile(@NotNull Optimizer optimizer, @NotNull LossFunction lossFunction, @NotNull Metric metric) {
        Intrinsics.checkNotNullParameter(optimizer, "optimizer");
        Intrinsics.checkNotNullParameter(lossFunction, "loss");
        Intrinsics.checkNotNullParameter(metric, "metric");
        final List<Metric> singleMetric = CollectionsKt.listOf(metric);
        compile(optimizer, lossFunction, singleMetric);
    }

    /**
     * Compiles the model: creates the placeholders, the layer graph, the loss, the optimizer
     * targets, the prediction operand and one operand per metric, then marks the model as
     * compiled.
     *
     * @param optimizer    optimizer used to prepare the training targets; must not be {@code null}
     * @param lossFunction loss function; must not be {@code null}
     * @param list         metrics to build operands for; must not be {@code null}
     * @throws IllegalStateException if the model is already compiled
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public void compile(@NotNull Optimizer optimizer, @NotNull LossFunction lossFunction, @NotNull List<? extends Metric> list) {
        Operand<Float> operand;
        Intrinsics.checkNotNullParameter(optimizer, "optimizer");
        Intrinsics.checkNotNullParameter(lossFunction, "loss");
        Intrinsics.checkNotNullParameter(list, "metrics");
        // A model may be compiled only once.
        if (!(!isModelCompiled())) {
            throw new IllegalStateException("The model is compiled already. Graph is created. Create new model and compile it.".toString());
        }
        setLoss(lossFunction);
        setMetrics(list);
        setOptimizer(optimizer);
        // Scalar Boolean placeholder named "training".
        Placeholder placeholder = getTf().withName("training").placeholder(Boolean.class, new Placeholder.Options[]{Placeholder.shape(Shape.scalar())});
        Intrinsics.checkNotNullExpressionValue(placeholder, "tf.withName(\"training\").…Shape.scalar())\n        )");
        setTraining((Operand) placeholder);
        // Scalar placeholder named "numberOfLosses".
        Placeholder placeholder2 = getTf().withName("numberOfLosses").placeholder(DtypeConversionUtilKt.getDType(), new Placeholder.Options[]{Placeholder.shape(Shape.scalar())});
        Intrinsics.checkNotNullExpressionValue(placeholder2, "tf.withName(\"numberOfLos…Shape.scalar())\n        )");
        setNumberOfLossesOp((Operand) placeholder2);
        // buildLayers returns the pair (input placeholder, output operand of the layers).
        Pair<Placeholder<Float>, Operand<Float>> buildLayers = buildLayers(getTraining(), getNumberOfLossesOp());
        Operand<Float> operand2 = (Placeholder) buildLayers.component1();
        Operand<Float> operand3 = (Operand) buildLayers.component2();
        this.xOp = operand2;
        this.yPredOp = operand3;
        // Number of classes comes from the last layer: a Dense layer's output size, the last
        // element of an ActivationLayer's output shape tail, or 1 for any other layer type.
        Layer layer = (Layer) CollectionsKt.last(this.layers);
        setNumberOfClasses(layer instanceof Dense ? ((Dense) layer).getOutputSize() : layer instanceof ActivationLayer ? ArraysKt.last(layer.getOutputShape().tail()) : 1L);
        // Placeholder for the true labels, created without a shape option.
        Operand<Float> placeholder3 = getTf().placeholder(DtypeConversionUtilKt.getDType(), new Placeholder.Options[0]);
        Intrinsics.checkNotNull(placeholder3, "null cannot be cast to non-null type org.tensorflow.Operand<kotlin.Float>");
        this.yTrueOp = placeholder3;
        setLossOp(buildLossFunction(lossFunction));
        KGraph kGraph = getKGraph();
        // Collect the underlying variable of every trainable KVariable for the optimizer.
        List<KVariable> trainableVariables = ParametrizedLayerKt.trainableVariables(this.layers);
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(trainableVariables, 10));
        Iterator<T> it = trainableVariables.iterator();
        while (it.hasNext()) {
            arrayList.add(((KVariable) it.next()).getVariable());
        }
        setTargets(optimizer.prepareTargets$tensorflow(kGraph, arrayList, getTf(), getLossOp()));
        // The prediction operand is created under OUTPUT_NAME: softmax over yPredOp for
        // SoftmaxCrossEntropyWithLogits, identity of yPredOp for every other loss.
        if (lossFunction instanceof SoftmaxCrossEntropyWithLogits) {
            NnOps nnOps = getTf().withName(TensorNamesKt.OUTPUT_NAME).nn;
            Operand<Float> operand4 = this.yPredOp;
            if (operand4 == null) {
                Intrinsics.throwUninitializedPropertyAccessException("yPredOp");
                operand4 = null;
            }
            Softmax softmax = nnOps.softmax(operand4);
            Intrinsics.checkNotNullExpressionValue(softmax, "tf.withName(OUTPUT_NAME).nn.softmax(yPredOp)");
            operand = (Operand) softmax;
        } else {
            Ops withName = getTf().withName(TensorNamesKt.OUTPUT_NAME);
            Operand<Float> operand5 = this.yPredOp;
            if (operand5 == null) {
                Intrinsics.throwUninitializedPropertyAccessException("yPredOp");
                operand5 = null;
            }
            Operand<Float> identity = withName.identity(operand5);
            Intrinsics.checkNotNullExpressionValue(identity, "tf.withName(OUTPUT_NAME).identity(yPredOp)");
            operand = identity;
        }
        this.predictionOp = operand;
        // Build one metric operand per requested metric from predictionOp and yTrueOp.
        List<? extends Metric> list2 = list;
        ArrayList arrayList2 = new ArrayList(CollectionsKt.collectionSizeOrDefault(list2, 10));
        for (Metric metric : list2) {
            Ops tf = getTf();
            Operand<Float> operand6 = this.predictionOp;
            if (operand6 == null) {
                Intrinsics.throwUninitializedPropertyAccessException("predictionOp");
                operand6 = null;
            }
            Operand<Float> operand7 = this.yTrueOp;
            if (operand7 == null) {
                Intrinsics.throwUninitializedPropertyAccessException("yTrueOp");
                operand7 = null;
            }
            arrayList2.add(metric.apply(tf, operand6, operand7, getNumberOfLossesOp()));
        }
        this.metricOps = arrayList2;
        setModelCompiled$tensorflow(true);
    }

    /**
     * Assembles the training-loss operand: applies {@code lossFunction} to the
     * predicted and true label operands, adds the output of every regularizer
     * attached to a trainable variable of the layers, and wraps the sum in an
     * identity op named {@link TensorNamesKt#TRAINING_LOSS}.
     *
     * @param lossFunction loss applied to {@code yPredOp} / {@code yTrueOp}
     * @return the identity operand carrying the total (loss + regularization) value
     */
    private final Operand<Float> buildLossFunction(LossFunction lossFunction) {
        Ops ops = getTf();
        Operand<Float> predicted = this.yPredOp;
        if (predicted == null) {
            // lateinit guard emitted by the Kotlin compiler; the call always throws.
            Intrinsics.throwUninitializedPropertyAccessException("yPredOp");
        }
        Operand<Float> expected = this.yTrueOp;
        if (expected == null) {
            Intrinsics.throwUninitializedPropertyAccessException("yTrueOp");
        }
        Add totalLoss = lossFunction.apply(ops, predicted, expected, getNumberOfLossesOp());
        List<KVariable> variables = ParametrizedLayerKt.trainableVariables(this.layers);
        for (KVariable variable : variables) {
            Regularizer regularizer = variable.getRegularizer();
            if (regularizer == null) {
                continue;
            }
            // Fold this variable's regularization term into the running total.
            Add withPenalty = getTf().math.add((Operand) totalLoss, regularizer.apply(getTf(), (Operand) variable.getVariable()));
            Intrinsics.checkNotNullExpressionValue(withPenalty, "tf.math.add(totalLoss, r…y(tf, variable.variable))");
            totalLoss = withPenalty;
        }
        Operand<Float> namedLoss = getTf().withName(TensorNamesKt.TRAINING_LOSS).identity((Operand) totalLoss);
        Intrinsics.checkNotNullExpressionValue(namedLoss, "tf.withName(TRAINING_LOSS).identity(totalLoss)");
        return namedLoss;
    }

    /**
     * Convenience overload: converts the {@link Losses} enum value to a
     * {@link LossFunction} and delegates to the overload that accepts it.
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public void compile(@NotNull Optimizer optimizer, @NotNull Losses losses, @NotNull Metric metric) {
        Intrinsics.checkNotNullParameter(optimizer, "optimizer");
        Intrinsics.checkNotNullParameter(losses, "loss");
        Intrinsics.checkNotNullParameter(metric, "metric");
        LossFunction convertedLoss = Losses.Companion.convert(losses);
        compile(optimizer, convertedLoss, metric);
    }

    /**
     * Convenience overload: converts the {@link Metrics} enum value to a
     * {@link Metric} and delegates to the overload that accepts it.
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public void compile(@NotNull Optimizer optimizer, @NotNull LossFunction lossFunction, @NotNull Metrics metrics) {
        Intrinsics.checkNotNullParameter(optimizer, "optimizer");
        Intrinsics.checkNotNullParameter(lossFunction, "loss");
        Intrinsics.checkNotNullParameter(metrics, "metric");
        Metric convertedMetric = Metric.Companion.convert(metrics);
        compile(optimizer, lossFunction, convertedMetric);
    }

    /**
     * Hook implemented by subclasses to build the model's layers.
     *
     * NOTE(review): the call site is outside this chunk; judging by the types, the
     * returned pair is presumably (input placeholder, output operand) and the two
     * arguments are presumably the "training" flag and the "number of losses"
     * operands — confirm against the caller in {@code compile}.
     *
     * @param operand  boolean operand supplied by the caller
     * @param operand2 float operand supplied by the caller
     * @return a non-null pair of a float placeholder and a float operand
     */
    @NotNull
    protected abstract Pair<Placeholder<Float>, Operand<Float>> buildLayers(@NotNull Operand<Boolean> operand, @NotNull Operand<Float> operand2);

    /**
     * Trains on {@code dataset} with a separate validation dataset by delegating
     * to {@code internalFit} with the validation flag set to {@code true}.
     * Note that {@code i2} is forwarded before {@code i}, and {@code i3} is boxed.
     *
     * @return the training history produced by {@code internalFit}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    @NotNull
    public TrainingHistory fit(@NotNull Dataset dataset, @NotNull Dataset dataset2, int i, int i2, int i3, @NotNull List<? extends Callback> list) {
        Intrinsics.checkNotNullParameter(dataset, "trainingDataset");
        Intrinsics.checkNotNullParameter(dataset2, "validationDataset");
        Intrinsics.checkNotNullParameter(list, "callbacks");
        final boolean withValidation = true;
        final Integer boxedThirdSize = Integer.valueOf(i3);
        TrainingHistory history = internalFit(i2, i, dataset, withValidation, dataset2, boxedThirdSize, list);
        return history;
    }

    /**
     * Trains on {@code dataset} without a validation dataset by delegating to
     * {@code internalFit} with the validation flag set to {@code false} and both
     * validation arguments {@code null}. {@code i2} is forwarded before {@code i}.
     *
     * @return the training history produced by {@code internalFit}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    @NotNull
    public TrainingHistory fit(@NotNull Dataset dataset, int i, int i2, @NotNull List<? extends Callback> list) {
        Intrinsics.checkNotNullParameter(dataset, "dataset");
        Intrinsics.checkNotNullParameter(list, "callbacks");
        final boolean withValidation = false;
        TrainingHistory history = internalFit(i2, i, dataset, withValidation, null, null, list);
        return history;
    }

    /**
     * Initializes the layers' variables in the model's session and marks the
     * model as initialized.
     *
     * @throws IllegalStateException if the model is not compiled, is already
     *         initialized, or its optimizer variables are already initialized
     */
    public final void init() {
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.");
        }
        if (isModelInitialized()) {
            throw new IllegalStateException("Model is initialized already!");
        }
        if (isOptimizerVariableInitialized()) {
            throw new IllegalStateException("Optimizer variables are initialized already!");
        }
        // The message is wrapped in a Function0 and handed to the logger's debug call.
        Function0<Object> debugMessage = new Function0<Object>() { // from class: org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel$init$4
            @Nullable
            public final Object invoke() {
                return "Initialization of TensorFlow Graph variables.";
            }
        };
        this.logger.debug(debugMessage);
        ParametrizedLayerKt.initializeVariables(this.layers, getSession$tensorflow());
        setModelInitialized$tensorflow(true);
    }

    /**
     * Re-runs initialization of the layers' variables in the model's session,
     * marks the model as initialized and clears the "optimizer variables
     * initialized" flag.
     *
     * @throws IllegalStateException if the model is not compiled
     */
    public final void reset() {
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.");
        }
        // The message is wrapped in a Function0 and handed to the logger's debug call.
        Function0<Object> debugMessage = new Function0<Object>() { // from class: org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel$reset$2
            @Nullable
            public final Object invoke() {
                return "Initialization of TensorFlow Graph variables.";
            }
        };
        this.logger.debug(debugMessage);
        ParametrizedLayerKt.initializeVariables(this.layers, getSession$tensorflow());
        setModelInitialized$tensorflow(true);
        setOptimizerVariableInitialized$tensorflow(false);
    }

    /* JADX WARN: Code restructure failed: missing block: B:66:0x02e1, code lost:
    
        if ((r0 == Float.NEGATIVE_INFINITY) != false) goto L60;
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Shared implementation behind both {@code fit(...)} overloads.
     *
     * The decompiler could not reconstruct this method, so in this file the body
     * is only a stub that unconditionally throws
     * {@link UnsupportedOperationException}; the real logic is not visible here.
     *
     * Argument order as used by the visible call sites in {@code fit(...)}:
     * two ints, the training dataset, a boolean that is {@code true} only when a
     * validation dataset is supplied, the validation dataset (or {@code null}),
     * a boxed int (or {@code null}), and the callbacks list.
     * NOTE(review): the meaning of the two leading ints (presumably batch size and
     * epochs) cannot be confirmed from this chunk — verify against the original
     * Kotlin source.
     */
    private final org.jetbrains.kotlinx.dl.api.core.history.TrainingHistory internalFit(int r11, int r12, org.jetbrains.kotlinx.dl.dataset.Dataset r13, boolean r14, org.jetbrains.kotlinx.dl.dataset.Dataset r15, java.lang.Integer r16, java.util.List<? extends org.jetbrains.kotlinx.dl.api.core.callback.Callback> r17) {
        /*
            Method dump skipped, instructions count: 1979
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel.internalFit(int, int, org.jetbrains.kotlinx.dl.dataset.Dataset, boolean, org.jetbrains.kotlinx.dl.dataset.Dataset, java.lang.Integer, java.util.List):org.jetbrains.kotlinx.dl.api.core.history.TrainingHistory");
    }

    /**
     * Executes a single session run for one batch and returns the loss together
     * with the values of the monitored metric operands.
     *
     * @param list    operands registered on the runner as targets
     * @param tensor  tensor fed to {@code xOp}
     * @param tensor2 tensor fed to {@code yTrueOp}
     * @param tensor3 tensor fed to the "number of losses" operand
     * @param tensor4 tensor fed to the "training" operand
     * @param list2   metric operands fetched after the {@code TRAINING_LOSS} tensor
     * @return pair of (value of the first fetched tensor, i.e. the training loss;
     *         values of the remaining fetched tensors in the order of {@code list2})
     * @throws IllegalStateException if the number of returned tensors is not
     *         {@code list2.size() + 1}
     * @throws RuntimeException wrapping any {@code TensorFlowException} raised by the
     *         run; the original exception is preserved as the cause
     */
    private final Pair<Float, List<Float>> trainOnBatch(List<? extends Operand<Float>> list, Tensor<Float> tensor, Tensor<Float> tensor2, Tensor<Float> tensor3, Tensor<Float> tensor4, final List<? extends Operand<Float>> list2) {
        Session.Runner runner = getSession$tensorflow().runner();
        for (Operand<Float> target : list) {
            runner.addTarget((Operand) target);
        }
        Operand<Float> input = this.xOp;
        if (input == null) {
            // lateinit guard emitted by the Kotlin compiler; the call always throws.
            Intrinsics.throwUninitializedPropertyAccessException("xOp");
        }
        Session.Runner feed = runner.feed(input.asOutput(), tensor);
        Operand<Float> labels = this.yTrueOp;
        if (labels == null) {
            Intrinsics.throwUninitializedPropertyAccessException("yTrueOp");
        }
        feed.feed(labels.asOutput(), tensor2).feed(getNumberOfLossesOp().asOutput(), tensor3).feed(getTraining().asOutput(), tensor4);
        // The loss is fetched first, so it sits at index 0 of the run result.
        runner.fetch(TensorNamesKt.TRAINING_LOSS);
        for (Operand<Float> metricOp : list2) {
            runner.fetch((Operand) metricOp);
        }
        try {
            List run = runner.run();
            Intrinsics.checkNotNullExpressionValue(run, "runner.run()");
            return (Pair) AutoClosableExtensionsKt.use(run, new Function1<List<? extends Tensor<?>>, Pair<? extends Float, ? extends List<Float>>>() { // from class: org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel$trainOnBatch$3
                /* JADX INFO: Access modifiers changed from: package-private */
                /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
                {
                    super(1);
                }

                @NotNull
                public final Pair<Float, List<Float>> invoke(@NotNull List<? extends Tensor<?>> list3) {
                    Intrinsics.checkNotNullParameter(list3, "tensorList");
                    float lossValue = list3.get(0).floatValue();
                    int metricCount = list2.size();
                    if (list3.size() != metricCount + 1) {
                        throw new IllegalStateException((metricCount + " metrics are monitored, but " + (list3.size() - 1) + " metrics are returned!").toString());
                    }
                    ArrayList metricValues = new ArrayList();
                    // Metric tensors follow the loss tensor, hence the +1 offset.
                    for (int k = 0; k < metricCount; k++) {
                        metricValues.add(Float.valueOf(list3.get(k + 1).floatValue()));
                    }
                    return TuplesKt.to(Float.valueOf(lossValue), metricValues);
                }
            });
        } catch (TensorFlowException e) {
            // Chain the cause so the TensorFlow stack trace travels with the rethrown
            // exception instead of being printed to stderr and then discarded.
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Evaluates the model on {@code dataset} batch by batch and returns the loss
     * and metric values averaged over the number of processed batches.
     *
     * For every batch the callbacks receive {@code onTestBatchBegin} /
     * {@code onTestBatchEnd}; {@code onTestBegin} / {@code onTestEnd} bracket the
     * whole evaluation. Each batch is also appended to a local {@link History}.
     *
     * NOTE(review): the nested try/finally structure and the reuse of
     * {@code tensor5} for two different tensors look like decompiler artifacts of
     * Kotlin {@code use { }} blocks — do not treat the close ordering here as
     * authoritative; confirm against the original Kotlin source.
     *
     * @param dataset data to evaluate on
     * @param i       batch size passed to {@code dataset.batchIterator}
     * @param list    callbacks notified during evaluation
     * @return evaluation result holding the averaged loss and a map from each
     *         converted-back metric to its averaged value
     * @throws IllegalStateException if the model is not compiled or not initialized,
     *         or if a run returns a number of tensors different from
     *         {@code metricOps.size() + 1}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    @NotNull
    public EvaluationResult evaluate(@NotNull Dataset dataset, int i, @NotNull List<? extends Callback> list) {
        int i2;
        Intrinsics.checkNotNullParameter(dataset, "dataset");
        Intrinsics.checkNotNullParameter(list, "callbacks");
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.".toString());
        }
        if (!isModelInitialized()) {
            throw new IllegalStateException("The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method.".toString());
        }
        History history = new History();
        // Attach this model to every callback, then signal the start of evaluation.
        Iterator<T> it = list.iterator();
        while (it.hasNext()) {
            ((Callback) it.next()).setModel$tensorflow(this);
        }
        Iterator<T> it2 = list.iterator();
        while (it2.hasNext()) {
            ((Callback) it2.next()).onTestBegin();
        }
        Dataset.BatchIterator batchIterator = dataset.batchIterator(i);
        // fArr accumulates the per-metric sums across batches; f accumulates the loss.
        int size = getMetrics().size();
        float[] fArr = new float[size];
        for (int i3 = 0; i3 < size; i3++) {
            fArr[i3] = 0.0f;
        }
        float f = 0.0f;
        // i4 / i2 is the zero-based batch counter; after the loop i2 equals the batch count.
        int i4 = 0;
        while (true) {
            i2 = i4;
            if (!batchIterator.hasNext()) {
                break;
            }
            Iterator<T> it3 = list.iterator();
            while (it3.hasNext()) {
                ((Callback) it3.next()).onTestBatchBegin(i2, i, history);
            }
            DataBatch next = batchIterator.next();
            Pair<long[], long[]> calculateXYShapes = calculateXYShapes(next);
            long[] jArr = (long[]) calculateXYShapes.component1();
            long[] jArr2 = (long[]) calculateXYShapes.component2();
            // tensor: batch inputs, created with shape jArr.
            Tensor tensor = (AutoCloseable) Tensor.create(jArr, ConvertersKt.serializeToBuffer(next.getX()));
            try {
                Tensor tensor2 = tensor;
                // tensor3: batch labels, created with shape jArr2.
                Tensor tensor3 = (AutoCloseable) Tensor.create(jArr2, ConvertersKt.serializeLabelsToBuffer(next.getY(), getNumberOfClasses()));
                try {
                    Tensor tensor4 = tensor3;
                    // Scalar tensor: number of elements in the label shape, fed to numberOfLossesOp below.
                    Tensor tensor5 = (AutoCloseable) Tensor.create(Float.valueOf((float) new TensorShape(jArr2).numElements()));
                    Throwable th = null;
                    try {
                        try {
                            Tensor tensor6 = tensor5;
                            // Scalar boolean tensor (false) fed to the training operand; note the variable is reused.
                            tensor5 = (AutoCloseable) Tensor.create(false);
                            Throwable th2 = null;
                            try {
                                try {
                                    Tensor tensor7 = tensor5;
                                    // Fetch the loss first, then every metric operand, so the loss is at index 0.
                                    Session.Runner fetch = getSession$tensorflow().runner().fetch(TensorNamesKt.TRAINING_LOSS);
                                    List<? extends Operand<Float>> list2 = this.metricOps;
                                    if (list2 == null) {
                                        Intrinsics.throwUninitializedPropertyAccessException("metricOps");
                                        list2 = null;
                                    }
                                    Iterator<T> it4 = list2.iterator();
                                    while (it4.hasNext()) {
                                        fetch.fetch((Operand) it4.next());
                                    }
                                    Operand<Float> operand = this.xOp;
                                    if (operand == null) {
                                        Intrinsics.throwUninitializedPropertyAccessException("xOp");
                                        operand = null;
                                    }
                                    Session.Runner feed = fetch.feed(operand.asOutput(), tensor2);
                                    Operand<Float> operand2 = this.yTrueOp;
                                    if (operand2 == null) {
                                        Intrinsics.throwUninitializedPropertyAccessException("yTrueOp");
                                        operand2 = null;
                                    }
                                    List run = feed.feed(operand2.asOutput(), tensor4).feed(getTraining().asOutput(), tensor7).feed(getNumberOfLossesOp().asOutput(), tensor6).run();
                                    Intrinsics.checkNotNullExpressionValue(run, "runner\n                 …                   .run()");
                                    // Convert the returned tensors into (loss, metric values) via the project's use(...) helper.
                                    Pair pair = (Pair) AutoClosableExtensionsKt.use(run, new Function1<List<? extends Tensor<?>>, Pair<? extends Float, ? extends List<Float>>>() { // from class: org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel$evaluate$6$1$1$1$2
                                        /* JADX INFO: Access modifiers changed from: package-private */
                                        {
                                            super(1);
                                        }

                                        @NotNull
                                        public final Pair<Float, List<Float>> invoke(@NotNull List<? extends Tensor<?>> list3) {
                                            List list4;
                                            List list5;
                                            List list6;
                                            Intrinsics.checkNotNullParameter(list3, "lossAndMetricsTensors");
                                            float floatValue = list3.get(0).floatValue();
                                            ArrayList arrayList = new ArrayList();
                                            int size2 = list3.size();
                                            list4 = GraphTrainableModel.this.metricOps;
                                            if (list4 == null) {
                                                Intrinsics.throwUninitializedPropertyAccessException("metricOps");
                                                list4 = null;
                                            }
                                            // Expect exactly one loss tensor plus one tensor per metric operand.
                                            boolean z = size2 == list4.size() + 1;
                                            GraphTrainableModel graphTrainableModel = GraphTrainableModel.this;
                                            if (!z) {
                                                StringBuilder sb = new StringBuilder();
                                                list6 = graphTrainableModel.metricOps;
                                                if (list6 == null) {
                                                    Intrinsics.throwUninitializedPropertyAccessException("metricOps");
                                                    list6 = null;
                                                }
                                                throw new IllegalStateException(sb.append(list6.size()).append(" metrics are monitored, but ").append(list3.size() - 1).append(" metrics are returned!").toString().toString());
                                            }
                                            int i5 = 1;
                                            list5 = GraphTrainableModel.this.metricOps;
                                            if (list5 == null) {
                                                Intrinsics.throwUninitializedPropertyAccessException("metricOps");
                                                list5 = null;
                                            }
                                            int size3 = list5.size();
                                            // Inclusive loop over indices 1..size3: the metric tensors that follow the loss.
                                            if (1 <= size3) {
                                                while (true) {
                                                    arrayList.add(Float.valueOf(list3.get(i5).floatValue()));
                                                    if (i5 == size3) {
                                                        break;
                                                    }
                                                    i5++;
                                                }
                                            }
                                            return TuplesKt.to(Float.valueOf(floatValue), arrayList);
                                        }
                                    });
                                    float floatValue = ((Number) pair.component1()).floatValue();
                                    List list3 = (List) pair.component2();
                                    f += floatValue;
                                    // Add this batch's metric values to the running per-metric sums.
                                    int i5 = 0;
                                    for (Object obj : getMetrics()) {
                                        int i6 = i5;
                                        i5++;
                                        if (i6 < 0) {
                                            CollectionsKt.throwIndexOverflow();
                                        }
                                        fArr[i6] = fArr[i6] + ((Number) list3.get(i6)).floatValue();
                                    }
                                    double d = floatValue;
                                    // The event carries this batch's loss but the accumulated (running-sum)
                                    // metric values from fArr, not the per-batch metric values.
                                    ArrayList arrayList = new ArrayList(fArr.length);
                                    for (float f2 : fArr) {
                                        arrayList.add(Double.valueOf(f2));
                                    }
                                    BatchEvent batchEvent = new BatchEvent(i2, d, arrayList);
                                    history.appendBatch(batchEvent);
                                    Iterator<T> it5 = list.iterator();
                                    while (it5.hasNext()) {
                                        ((Callback) it5.next()).onTestBatchEnd(i2, i, batchEvent, history);
                                    }
                                    // NOTE(review): tensor5 is closed twice below while tensor6 is never closed
                                    // explicitly; presumably a decompiler artifact of variable reuse — confirm.
                                    Unit unit = Unit.INSTANCE;
                                    AutoCloseableKt.closeFinally(tensor5, (Throwable) null);
                                    Unit unit2 = Unit.INSTANCE;
                                    AutoCloseableKt.closeFinally(tensor5, (Throwable) null);
                                    Unit unit3 = Unit.INSTANCE;
                                    AutoCloseableKt.closeFinally(tensor3, (Throwable) null);
                                    Unit unit4 = Unit.INSTANCE;
                                    AutoCloseableKt.closeFinally(tensor, (Throwable) null);
                                    i4 = i2 + 1;
                                } finally {
                                }
                            } finally {
                            }
                        } finally {
                        }
                    } finally {
                        AutoCloseableKt.closeFinally(tensor5, th);
                    }
                } catch (Throwable th3) {
                    AutoCloseableKt.closeFinally(tensor3, (Throwable) null);
                    throw th3;
                }
            } catch (Throwable th4) {
                AutoCloseableKt.closeFinally(tensor, (Throwable) null);
                throw th4;
            }
        }
        // Average the accumulated metric sums over the number of processed batches.
        // With zero batches i2 is 0; these are floating-point divisions, so the
        // results are NaN rather than an ArithmeticException.
        int size2 = getMetrics().size();
        float[] fArr2 = new float[size2];
        for (int i7 = 0; i7 < size2; i7++) {
            fArr2[i7] = 0.0f;
        }
        int i8 = 0;
        for (float f3 : fArr) {
            int i9 = i8;
            i8++;
            fArr2[i9] = f3 / i2;
        }
        double d2 = f / i2;
        Iterator<T> it6 = list.iterator();
        while (it6.hasNext()) {
            ((Callback) it6.next()).onTestEnd(history);
        }
        // LinkedHashMap keeps the metrics in the iteration order of getMetrics().
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        int i10 = 0;
        for (Object obj2 : getMetrics()) {
            int i11 = i10;
            i10++;
            if (i11 < 0) {
                CollectionsKt.throwIndexOverflow();
            }
            linkedHashMap.put(Metric.Companion.convertBack((Metric) obj2), Double.valueOf(fArr2[i11]));
        }
        return new EvaluationResult(d2, linkedHashMap);
    }

    /**
     * Runs the prediction operand over {@code dataset} in batches of size {@code i}
     * and returns an int array with one slot per dataset element.
     *
     * NOTE(review): in this decompiled body the fetched result ({@code run}) is
     * never read, {@code intRef.element} is never advanced, no per-batch "end"
     * callback is invoked, and the result array keeps its
     * {@code Integer.MIN_VALUE} fill. The captured-looking {@code final} locals
     * suggest a lambda was lost by the decompiler — confirm against the original
     * Kotlin source before relying on this code.
     *
     * @param dataset data to predict on; its {@code xSize()} must be a multiple of {@code i}
     * @param i       batch size
     * @param list    callbacks notified during prediction
     * @return int array of length {@code dataset.xSize()}
     * @throws IllegalArgumentException if {@code dataset.xSize()} is not a multiple of {@code i}
     * @throws IllegalStateException if the model is not compiled or not initialized
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    @NotNull
    public int[] predict(@NotNull Dataset dataset, final int i, @NotNull final List<? extends Callback> list) {
        Intrinsics.checkNotNullParameter(dataset, "dataset");
        Intrinsics.checkNotNullParameter(list, "callbacks");
        if (!(dataset.xSize() % i == 0)) {
            throw new IllegalArgumentException("The amount of images must be a multiple of batch size.".toString());
        }
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.".toString());
        }
        if (!isModelInitialized()) {
            throw new IllegalStateException("The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method.".toString());
        }
        // Attach this model to every callback, then signal the start of prediction.
        Iterator<T> it = list.iterator();
        while (it.hasNext()) {
            ((Callback) it.next()).setModel$tensorflow(this);
        }
        Iterator<T> it2 = list.iterator();
        while (it2.hasNext()) {
            ((Callback) it2.next()).onPredictBegin();
        }
        final long[] calculateXShape = calculateXShape(i);
        // Result array, pre-filled with Integer.MIN_VALUE.
        int xSize = dataset.xSize();
        final int[] iArr = new int[xSize];
        for (int i2 = 0; i2 < xSize; i2++) {
            iArr[i2] = Integer.MIN_VALUE;
        }
        Dataset.BatchIterator batchIterator = dataset.batchIterator(i);
        // Batch counter holder passed to onPredictBatchBegin; not modified anywhere in this body.
        final Ref.IntRef intRef = new Ref.IntRef();
        while (batchIterator.hasNext()) {
            Iterator<T> it3 = list.iterator();
            while (it3.hasNext()) {
                ((Callback) it3.next()).onPredictBatchBegin(intRef.element, i);
            }
            // tensor: batch inputs, created with the shape computed from the batch size.
            Tensor tensor = (AutoCloseable) Tensor.create(calculateXShape, ConvertersKt.serializeToBuffer(batchIterator.next().getX()));
            Throwable th = null;
            try {
                try {
                    Tensor tensor2 = tensor;
                    // Scalar boolean tensor (false) fed to the training operand.
                    Tensor tensor3 = (AutoCloseable) Tensor.create(false);
                    Throwable th2 = null;
                    try {
                        try {
                            Tensor tensor4 = tensor3;
                            Session.Runner runner = getSession$tensorflow().runner();
                            Operand<Float> operand = this.predictionOp;
                            if (operand == null) {
                                Intrinsics.throwUninitializedPropertyAccessException("predictionOp");
                                operand = null;
                            }
                            Session.Runner fetch = runner.fetch(operand);
                            Operand<Float> operand2 = this.xOp;
                            if (operand2 == null) {
                                Intrinsics.throwUninitializedPropertyAccessException("xOp");
                                operand2 = null;
                            }
                            List run = fetch.feed(operand2.asOutput(), tensor2).feed(getTraining().asOutput(), tensor4).run();
                            Intrinsics.checkNotNullExpressionValue(run, "session.runner()\n       …                   .run()");
                            // The visible code stops using `run` here (see the method-level note).
                            AutoCloseableKt.closeFinally(tensor3, (Throwable) null);
                            AutoCloseableKt.closeFinally(tensor, (Throwable) null);
                        } finally {
                        }
                    } finally {
                    }
                } finally {
                }
            } catch (Throwable th3) {
                AutoCloseableKt.closeFinally(tensor, th);
                throw th3;
            }
        }
        Iterator<T> it4 = list.iterator();
        while (it4.hasNext()) {
            ((Callback) it4.next()).onPredictEnd();
        }
        return iArr;
    }

    /**
     * Predicts a class index for a single input: obtains the soft prediction via
     * the default-argument bridge of {@code predictSoftly} and returns its argmax.
     *
     * @param fArr input data
     * @return index of the largest element of the soft prediction
     */
    @Override // org.jetbrains.kotlinx.dl.api.inference.TensorFlowInferenceModel
    public int predict(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        float[] softPrediction = InferenceModel.DefaultImpls.predictSoftly$default(this, fArr, (String) null, 2, (Object) null);
        return FloatArrayExtensionFunctionsKt.argmax(softPrediction);
    }

    /**
     * Predicts a class index for a single input using the named prediction tensor:
     * returns the argmax of {@code predictSoftly(fArr, str)}.
     *
     * @param fArr input data
     * @param str  name of the prediction tensor
     * @return index of the largest element of the soft prediction
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public int predict(@NotNull float[] fArr, @NotNull String str) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(str, "predictionTensorName");
        float[] softPrediction = predictSoftly(fArr, str);
        return FloatArrayExtensionFunctionsKt.argmax(softPrediction);
    }

    /**
     * Predicts a class index for a single input and also returns the activations
     * collected by {@code internalPredict}.
     *
     * @param fArr input data
     * @param str  name of the prediction tensor
     * @return pair of (argmax of the soft prediction, activations list)
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    @NotNull
    public Pair<Integer, List<?>> predictAndGetActivations(@NotNull float[] fArr, @NotNull String str) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(str, "predictionTensorName");
        Pair<float[], List<?>> prediction = internalPredict(fArr, true, str);
        float[] softPrediction = prediction.getFirst();
        List<?> activations = prediction.getSecond();
        int predictedIndex = FloatArrayExtensionFunctionsKt.argmax(softPrediction);
        return new Pair<>(Integer.valueOf(predictedIndex), activations);
    }

    /* JADX WARN: Type inference failed for: r0v31, types: [float[], float[][]] */
    /**
     * Runs the prediction operand over {@code dataset} in batches of size {@code i}
     * and returns one float array (of length {@code getNumberOfClasses()}) per
     * dataset element.
     *
     * NOTE(review): in this decompiled body the fetched result ({@code run}) is
     * never read, {@code intRef.element} is never advanced, no per-batch "end"
     * callback is invoked, and the returned rows keep their zero fill. The
     * {@code final ?? r0} declaration is a decompiler artifact (see the JADX
     * warning above), and the captured-looking {@code final} locals suggest a
     * lambda was lost — confirm against the original Kotlin source.
     * Unlike the batch {@code predict} overload, no tensor is fed to the training
     * operand here.
     *
     * @param dataset data to predict on; its {@code xSize()} must be a multiple of {@code i}
     * @param i       batch size
     * @param list    callbacks notified during prediction
     * @return array with {@code dataset.xSize()} rows
     * @throws IllegalArgumentException if {@code dataset.xSize()} is not a multiple of {@code i}
     * @throws IllegalStateException if the model is not compiled or not initialized
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    @NotNull
    public float[][] predictSoftly(@NotNull Dataset dataset, final int i, @NotNull final List<? extends Callback> list) {
        Intrinsics.checkNotNullParameter(dataset, "dataset");
        Intrinsics.checkNotNullParameter(list, "callbacks");
        if (!(dataset.xSize() % i == 0)) {
            throw new IllegalArgumentException("The amount of images must be a multiple of batch size.".toString());
        }
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.".toString());
        }
        if (!isModelInitialized()) {
            throw new IllegalStateException("The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method.".toString());
        }
        // Attach this model to every callback, then signal the start of prediction.
        Iterator<T> it = list.iterator();
        while (it.hasNext()) {
            ((Callback) it.next()).setModel$tensorflow(this);
        }
        Iterator<T> it2 = list.iterator();
        while (it2.hasNext()) {
            ((Callback) it2.next()).onPredictBegin();
        }
        final long[] calculateXShape = calculateXShape(i);
        // Result matrix: one zero-filled row of length getNumberOfClasses() per dataset element.
        int xSize = dataset.xSize();
        final ?? r0 = new float[xSize];
        for (int i2 = 0; i2 < xSize; i2++) {
            int i3 = i2;
            int numberOfClasses = (int) getNumberOfClasses();
            float[] fArr = new float[numberOfClasses];
            for (int i4 = 0; i4 < numberOfClasses; i4++) {
                fArr[i4] = 0.0f;
            }
            r0[i3] = fArr;
        }
        Dataset.BatchIterator batchIterator = dataset.batchIterator(i);
        // Batch counter holder passed to onPredictBatchBegin; not modified anywhere in this body.
        final Ref.IntRef intRef = new Ref.IntRef();
        while (batchIterator.hasNext()) {
            Iterator<T> it3 = list.iterator();
            while (it3.hasNext()) {
                ((Callback) it3.next()).onPredictBatchBegin(intRef.element, i);
            }
            // tensor: batch inputs, created with the shape computed from the batch size.
            Tensor tensor = (AutoCloseable) Tensor.create(calculateXShape, ConvertersKt.serializeToBuffer(batchIterator.next().getX()));
            Throwable th = null;
            try {
                try {
                    Tensor tensor2 = tensor;
                    Session.Runner runner = getSession$tensorflow().runner();
                    Operand<Float> operand = this.predictionOp;
                    if (operand == null) {
                        Intrinsics.throwUninitializedPropertyAccessException("predictionOp");
                        operand = null;
                    }
                    Session.Runner fetch = runner.fetch(operand);
                    Operand<Float> operand2 = this.xOp;
                    if (operand2 == null) {
                        Intrinsics.throwUninitializedPropertyAccessException("xOp");
                        operand2 = null;
                    }
                    List run = fetch.feed(operand2.asOutput(), tensor2).run();
                    Intrinsics.checkNotNullExpressionValue(run, "session.runner()\n       …                   .run()");
                    // The visible code stops using `run` here (see the method-level note).
                    AutoCloseableKt.closeFinally(tensor, (Throwable) null);
                } finally {
                }
            } catch (Throwable th2) {
                AutoCloseableKt.closeFinally(tensor, th);
                throw th2;
            }
        }
        Iterator<T> it4 = list.iterator();
        while (it4.hasNext()) {
            ((Callback) it4.next()).onPredictEnd();
        }
        return r0;
    }

    /**
     * Returns the soft prediction for a single input: the first component of the
     * pair produced by {@code internalPredict} with activation collection disabled.
     *
     * @param fArr input data
     * @param str  name of the prediction tensor
     * @return the flattened prediction values
     */
    @Override // org.jetbrains.kotlinx.dl.api.inference.TensorFlowInferenceModel
    @NotNull
    public float[] predictSoftly(@NotNull float[] fArr, @NotNull String str) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(str, "predictionTensorName");
        Pair<float[], List<?>> prediction = internalPredict(fArr, false, str);
        return prediction.getFirst();
    }

    /**
     * Returns the soft prediction for a single input together with the activations,
     * exactly as produced by {@code internalPredict} with activation collection enabled.
     *
     * @param fArr input data
     * @param str  name of the prediction tensor
     * @return pair of (flattened prediction values, activations list)
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    @NotNull
    protected Pair<float[], List<?>> predictSoftlyAndGetActivations(@NotNull float[] fArr, @NotNull String str) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(str, "predictionTensorName");
        final boolean collectActivations = true;
        Pair<float[], List<?>> predictionWithActivations = internalPredict(fArr, collectActivations, str);
        return predictionWithActivations;
    }

    /**
     * Feeds a single input into the session and converts the fetched tensors into plain arrays.
     *
     * <p>The input tensor is created with the shape returned by {@code calculateXShape(1)} and is
     * always closed when this method exits, including on failure. If closing fails while another
     * exception is already propagating, try-with-resources records it as a suppressed exception
     * instead of losing the original failure.
     *
     * @param fArr input data wrapped into the input tensor
     * @param z    when {@code true}, every fetched tensor after the first is converted to a
     *             multi-dimensional array and returned as an activation
     * @param str  name of the tensor to fetch; an empty string selects the model's prediction operand
     * @return pair of (flattened first fetched tensor, list of activation arrays; empty when
     *         {@code z} is {@code false} or nothing else was fetched)
     * @throws IllegalStateException if the model is not compiled or not initialized
     */
    private final Pair<float[], List<?>> internalPredict(float[] fArr, final boolean z, String str) {
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.");
        }
        if (!isModelInitialized()) {
            throw new IllegalStateException("The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method.");
        }
        try (Tensor<Float> testImages = Tensor.create(calculateXShape(1), FloatBuffer.wrap(fArr))) {
            Intrinsics.checkNotNullExpressionValue(testImages, "testImages");
            List<Tensor<?>> fetched = formPredictionAndActivationsTensors(str, testImages, z);
            // NOTE(review): AutoClosableExtensionsKt.use presumably closes every fetched tensor
            // after the callback returns - confirm against its definition.
            return (Pair) AutoClosableExtensionsKt.use(fetched, new Function1<List<? extends Tensor<?>>, Pair<? extends float[], ? extends List<Object>>>() {
                @Override
                @NotNull
                public Pair<float[], List<Object>> invoke(@NotNull List<? extends Tensor<?>> list) {
                    Intrinsics.checkNotNullParameter(list, "tensors");
                    // The first fetched tensor is the prediction itself.
                    float[] prediction = TensorExtensionFunctionsKt.convertTensorToFlattenFloatArray(list.get(0));
                    List<Object> activations = new ArrayList<>();
                    if (z) {
                        // Remaining tensors (if any) are the per-layer activations.
                        for (int i = 1; i < list.size(); i++) {
                            activations.add(TensorExtensionFunctionsKt.convertTensorToMultiDimArray(list.get(i)));
                        }
                    }
                    return TuplesKt.to(prediction, activations);
                }
            });
        }
    }

    /**
     * Builds a session runner for one prediction, runs it and returns the fetched tensors.
     *
     * <p>The first fetch is either the model's {@code predictionOp} (when {@code str} is empty) or
     * the graph operation named {@code str}. When {@code z} is set, the default activation name of
     * every layer that reports an activation, except the last layer, is fetched as well.
     *
     * @param str    name of the tensor to fetch, or an empty string for the prediction operand
     * @param tensor input tensor fed into {@code xOp}
     * @param z      whether intermediate activations should be fetched too
     * @return the tensors returned by the runner, in fetch order
     * @throws IllegalArgumentException if {@code str} is non-empty and the graph has no such operation
     */
    private final List<Tensor<?>> formPredictionAndActivationsTensors(String str, Tensor<Float> tensor, boolean z) {
        Session.Runner runner = getSession$tensorflow().runner();

        // Step 1: register the primary fetch.
        Session.Runner withPrimaryFetch;
        if (str.isEmpty()) {
            Operand<Float> prediction = this.predictionOp;
            if (prediction == null) {
                Intrinsics.throwUninitializedPropertyAccessException("predictionOp");
            }
            withPrimaryFetch = runner.fetch(prediction);
        } else {
            boolean operationMissing = kGraph().getTfGraph$tensorflow().operation(str) == null;
            if (operationMissing) {
                throw new IllegalArgumentException("No such tensor output named [" + str + "] in the TensorFlow graph!");
            }
            withPrimaryFetch = runner.fetch(str);
        }

        // Step 2: feed the input placeholder (identical for both branches above).
        Operand<Float> input = this.xOp;
        if (input == null) {
            Intrinsics.throwUninitializedPropertyAccessException("xOp");
        }
        withPrimaryFetch.feed(input.asOutput(), tensor);

        // Step 3: optionally fetch the activations of all layers but the last one.
        if (z) {
            for (Layer candidate : this.layers) {
                if (!candidate.getHasActivation()) {
                    continue;
                }
                if (Intrinsics.areEqual(candidate, CollectionsKt.last(this.layers))) {
                    continue;
                }
                runner.fetch(NameConventionsKt.defaultActivationName(candidate));
            }
        }

        List<Tensor<?>> fetched = runner.run();
        Intrinsics.checkNotNullExpressionValue(fetched, "runner.run()");
        return fetched;
    }

    /**
     * Computes the input and label tensor shapes for a batch and, for non-empty batches,
     * validates the batch against them.
     *
     * @param dataBatch batch whose size drives the leading dimension of both shapes
     * @return pair of (input shape, label shape)
     */
    private final Pair<long[], long[]> calculateXYShapes(DataBatch dataBatch) {
        int batchSize = dataBatch.getSize();
        long[] xShape = calculateXShape(batchSize);
        long[] yShape = calculateYShape(batchSize);
        boolean hasSamples = batchSize > 0;
        if (hasSamples) {
            // Validation indexes into the first sample, so it is skipped for empty batches.
            batchValidation(dataBatch, xShape, yShape);
        }
        return new Pair<>(xShape, yShape);
    }

    /**
     * Builds the two-element label shape {@code [i, numberOfClasses]}.
     *
     * @param i size of the leading (batch) dimension
     * @return a new two-element shape array
     */
    private final long[] calculateYShape(int i) {
        long[] yShape = new long[2];
        yShape[0] = i;
        yShape[1] = getNumberOfClasses();
        return yShape;
    }

    /**
     * Verifies that the element counts implied by the calculated shapes match the batch contents.
     *
     * <p>The data size is taken as {@code x.length * x[0].length}, so the batch must contain at
     * least one sample. The label size is taken as {@code y.length * numberOfClasses}.
     *
     * @param dataBatch batch to validate
     * @param jArr      calculated input shape
     * @param jArr2     calculated label shape
     * @throws IllegalStateException if either element count does not match
     */
    private final void batchValidation(DataBatch dataBatch, long[] jArr, long[] jArr2) {
        int expectedDataSize = (int) new TensorShape(jArr).numElements();
        int actualDataSize = dataBatch.getX().length * dataBatch.getX()[0].length;
        if (expectedDataSize != actualDataSize) {
            throw new IllegalStateException(
                    "The calculated [from the Model] data batch shape "
                            + Arrays.toString(jArr)
                            + " doesn't match actual data buffer size "
                            + actualDataSize
                            + ". Please, check input data.");
        }

        // The label shape is only examined once the data shape has been accepted.
        int expectedLabelSize = (int) new TensorShape(jArr2).numElements();
        int actualLabelSize = dataBatch.getY().length * ((int) getNumberOfClasses());
        if (expectedLabelSize != actualLabelSize) {
            throw new IllegalStateException(
                    "The calculated [from the model] label batch shape "
                            + Arrays.toString(jArr2)
                            + " doesn't match actual data buffer size "
                            + actualLabelSize
                            + ". \nPlease, check the input label data or correct number of classes [number of neurons] in last Dense layer, if you have a classification problem.\nHighly likely, you have different number of classes presented in data and described in model as desired output.");
        }
    }

    /**
     * Builds the input tensor shape: {@code i} followed by the tail of the first layer's input shape.
     *
     * <p>The first layer is cast to {@code Input}; a different layer type fails the cast.
     *
     * @param i size of the leading (batch) dimension
     * @return a new array {@code [i, tail...]}
     */
    private final long[] calculateXShape(int i) {
        Object firstLayer = CollectionsKt.first(this.layers);
        Intrinsics.checkNotNull(firstLayer, "null cannot be cast to non-null type org.jetbrains.kotlinx.dl.api.core.layer.core.Input");
        Shape xTensorShape = ((Input) firstLayer).getInput().asOutput().shape();
        Intrinsics.checkNotNullExpressionValue(xTensorShape, "xTensorShape");

        long[] trailingDims = ShapeFunctionsKt.tail(xTensorShape);
        long[] xShape = new long[trailingDims.length + 1];
        xShape[0] = i;
        System.arraycopy(trailingDims, 0, xShape, 1, trailingDims.length);
        return xShape;
    }

    /**
     * Public accessor for the graph wrapper; returns whatever {@code getKGraph()} returns.
     *
     * @return the model's {@code KGraph}
     */
    @NotNull
    public final KGraph kGraph() {
        KGraph graph = getKGraph();
        return graph;
    }

    /**
     * Saves the model into {@code file} using the requested format.
     *
     * <p>The target directory is prepared first according to {@code writingMode}: mapping 1 refuses
     * an existing path, mapping 2 deletes an existing path recursively and recreates it, mapping 3
     * creates the directory only when it is missing. The switch values come from the
     * compiler-generated {@code WhenMappings} tables.
     * NOTE(review): the enum constants behind each value are not visible here - confirm against
     * {@code WritingMode} and {@code SavingFormat}.
     *
     * @param file         directory to save the model into
     * @param savingFormat format selector (dispatches to the simple, saved-model or Keras writer)
     * @param z            whether optimizer variables are saved as well
     * @param writingMode  policy for an already existing directory
     * @throws IllegalStateException if the model is not compiled, not initialized, the optimizer
     *                               variables are required but not initialized, or the directory
     *                               exists under mapping 1
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public void save(@NotNull File file, @NotNull SavingFormat savingFormat, boolean z, @NotNull WritingMode writingMode) {
        Intrinsics.checkNotNullParameter(file, "modelDirectory");
        Intrinsics.checkNotNullParameter(savingFormat, "savingFormat");
        Intrinsics.checkNotNullParameter(writingMode, "writingMode");
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.");
        }
        if (!isModelInitialized()) {
            throw new IllegalStateException("The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method.");
        }
        if (z && !isOptimizerVariableInitialized()) {
            throw new IllegalStateException("The optimizer variables are not initialized yet. Initialize the optimizer variables with init() method or load optimizer weights to use this method.");
        }
        String absolutePath = file.getAbsolutePath();

        // Files.createDirectories either leaves the directory in place or throws, so the
        // follow-up File.mkdir() calls of the previous version were dead code and are gone.
        switch (WhenMappings.$EnumSwitchMapping$0[writingMode.ordinal()]) {
            case 1:
                if (file.exists()) {
                    throw new IllegalStateException("The directory exists on path " + absolutePath + ", please be careful it could contain valuable model! Change this mode to OVERRIDE if you want to override this directory.");
                }
                Files.createDirectories(file.toPath(), new FileAttribute[0]);
                break;
            case 2:
                if (file.exists()) {
                    FilesKt.deleteRecursively(file);
                }
                Files.createDirectories(file.toPath(), new FileAttribute[0]);
                break;
            case 3:
                if (!file.exists()) {
                    Files.createDirectories(file.toPath(), new FileAttribute[0]);
                }
                break;
            default:
                break;
        }

        Intrinsics.checkNotNullExpressionValue(absolutePath, "pathToModelDirectory");
        switch (WhenMappings.$EnumSwitchMapping$1[savingFormat.ordinal()]) {
            case 1:
                saveInSimpleFormat(absolutePath, z);
                return;
            case 2:
                saveInSavedModelFormat(absolutePath);
                return;
            case 3:
                saveInKerasFormat(absolutePath, z);
                return;
            default:
                return;
        }
    }

    /**
     * Writes the model configuration and then the variables into the given directory.
     *
     * @param str path to the model directory
     * @param z   forwarded to {@code saveVariables}; selects whether optimizer variables are included
     */
    private final void saveInKerasFormat(String str, boolean z) {
        final String modelDirectory = str;
        saveModel(modelDirectory);
        saveVariables(modelDirectory, z);
    }

    /**
     * Saves the model configuration to {@code modelConfig.json} inside the given directory.
     *
     * @param str path to the model directory
     */
    private final void saveModel(String str) {
        File configFile = new File(str + "/modelConfig.json");
        ModelSaverKt.saveModelConfiguration$default(this, configFile, false, 2, null);
    }

    /**
     * Writes only the graph definition into the given directory.
     *
     * @param str path to the model directory
     */
    private final void saveInSavedModelFormat(String str) {
        final String modelDirectory = str;
        saveGraphDef(modelDirectory);
    }

    /**
     * Writes the graph definition and then the variables into the given directory.
     *
     * @param str path to the model directory
     * @param z   forwarded to {@code saveVariables}; selects whether optimizer variables are included
     */
    private final void saveInSimpleFormat(String str, boolean z) {
        final String modelDirectory = str;
        saveGraphDef(modelDirectory);
        saveVariables(modelDirectory, z);
    }

    /**
     * Serializes the TensorFlow graph and writes it to {@code graph.pb} inside the given directory,
     * creating the directory first if needed.
     *
     * @param str path to the model directory
     */
    private final void saveGraphDef(String str) {
        Files.createDirectories(Paths.get(str));
        byte[] serializedGraph = getKGraph().getTfGraph$tensorflow().toGraphDef();
        Intrinsics.checkNotNullExpressionValue(serializedGraph, "kGraph.tfGraph.toGraphDef()");
        File target = new File(str + "/graph.pb");
        FilesKt.writeBytes(target, serializedGraph);
    }

    /**
     * Writes every variable returned by {@code getVariablesAndTensors(z)} to disk.
     *
     * <p>{@code variableNames.txt} receives one variable name per line. For each variable a file
     * {@code <name>.txt} receives its flattened float values separated by single spaces, with no
     * trailing separator. A variable with zero elements produces an empty file.
     *
     * <p>Writers and tensors are managed with try-with-resources, so they are closed on every exit
     * path and secondary close failures are attached as suppressed exceptions. Tensors that were
     * fetched but never reached because of an earlier failure are closed as well.
     *
     * @param str path to the model directory; created if it does not exist
     * @param z   forwarded to {@code getVariablesAndTensors}; selects whether optimizer variables
     *            are included
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public final void saveVariables(@NotNull String str, boolean z) {
        Intrinsics.checkNotNullParameter(str, "pathToModelDirectory");
        List<Pair<Variable<Float>, Tensor<?>>> variablesAndTensors = getVariablesAndTensors(z);
        // Index of the first tensor whose ownership has not yet been handed to try-with-resources.
        int nextUnprocessed = 0;
        try {
            Files.createDirectories(Paths.get(str, new String[0]), new FileAttribute[0]);
            try (BufferedWriter namesWriter = new BufferedWriter(
                    new OutputStreamWriter(new FileOutputStream(new File(str + "/variableNames.txt")), Charsets.UTF_8), 8192)) {
                while (nextUnprocessed < variablesAndTensors.size()) {
                    Pair<Variable<Float>, Tensor<?>> entry = variablesAndTensors.get(nextUnprocessed);
                    nextUnprocessed++;
                    // The tensor is declared first so it is released even if the value file cannot be opened.
                    try (Tensor<?> tensor = entry.component2()) {
                        String name = entry.component1().asOutput().op().name();
                        namesWriter.write(name);
                        namesWriter.newLine();
                        try (BufferedWriter valuesWriter = new BufferedWriter(
                                new OutputStreamWriter(new FileOutputStream(new File(str + '/' + name + ".txt")), Charsets.UTF_8), 8192)) {
                            float[] values = TensorExtensionFunctionsKt.convertTensorToFlattenFloatArray(tensor);
                            for (int i = 0; i < values.length; i++) {
                                if (i > 0) {
                                    valuesWriter.write(' ');
                                }
                                valuesWriter.write(String.valueOf(values[i]));
                            }
                        }
                        // Keep the names file in sync with the value files already written.
                        namesWriter.flush();
                    }
                }
            }
        } finally {
            // Only reached with leftovers when something above failed: release the remaining tensors.
            for (int i = nextUnprocessed; i < variablesAndTensors.size(); i++) {
                variablesAndTensors.get(i).component2().close();
            }
        }
    }

    /**
     * Fetches the current values of the layer variables (and optionally the optimizer variables)
     * from the session and pairs each variable with its fetched tensor.
     *
     * @param z when {@code true}, the graph's optimizer variables are appended to the layer variables
     * @return variables zipped with the tensors returned by the runner, in fetch order
     */
    private final List<Pair<Variable<Float>, Tensor<?>>> getVariablesAndTensors(boolean z) {
        List<KVariable> kVariables = layerVariables();
        Collection layerVars = new ArrayList(CollectionsKt.collectionSizeOrDefault(kVariables, 10));
        for (KVariable kVariable : kVariables) {
            layerVars.add(kVariable.getVariable());
        }

        Collection variablesToFetch = z
                ? CollectionsKt.plus(layerVars, getKGraph().optimizerVariables())
                : layerVars;

        Session.Runner modelWeightsExtractorRunner = getSession$tensorflow().runner();
        for (Object variable : variablesToFetch) {
            modelWeightsExtractorRunner.fetch((Operand) variable);
        }

        List fetchedTensors = modelWeightsExtractorRunner.run();
        Intrinsics.checkNotNullExpressionValue(fetchedTensors, "modelWeightsExtractorRunner.run()");
        return CollectionsKt.zip(variablesToFetch, fetchedTensors);
    }

    /**
     * Loads variable values from text files in the given directory and marks the model as initialized.
     *
     * <p>A variable is accepted for loading when it is not an optimizer variable, or when
     * {@code z} is set and the variable is not related to a frozen layer.
     *
     * @param file directory with the saved variables (created if missing)
     * @param z    whether optimizer variables are loaded; when set, the optimizer is marked
     *             initialized too
     * @throws IllegalStateException if the model is not compiled or is already initialized
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel
    public void loadWeights(@NotNull File file, final boolean z) {
        Intrinsics.checkNotNullParameter(file, "modelDirectory");
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.");
        }
        if (isModelInitialized()) {
            throw new IllegalStateException("The model is initialized already.");
        }
        Files.createDirectories(file.toPath());
        String directoryPath = file.getPath();
        Intrinsics.checkNotNullExpressionValue(directoryPath, "modelDirectory.path");

        Function1<String, Boolean> shouldLoad = new Function1<String, Boolean>() {
            @Override
            @NotNull
            public Boolean invoke(@NotNull String variableName) {
                Intrinsics.checkNotNullParameter(variableName, "variableName");
                if (!GraphTrainableModel.this.isOptimizerVariable(variableName)) {
                    return Boolean.valueOf(true);
                }
                if (!z) {
                    return Boolean.valueOf(false);
                }
                return Boolean.valueOf(!GraphTrainableModel.this.isVariableRelatedToFrozenLayer(variableName));
            }
        };
        loadVariablesFromTxt(directoryPath, shouldLoad);

        setModelInitialized$tensorflow(true);
        if (z) {
            setOptimizerVariableInitialized$tensorflow(true);
        }
    }

    /**
     * Tells whether the given variable name contains the name of any frozen-layer variable.
     *
     * @param str variable name to test
     * @return {@code true} if at least one frozen-layer variable name is a substring of {@code str}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public final boolean isVariableRelatedToFrozenLayer(String str) {
        List<KVariable> frozenVariables = frozenLayerVariables();
        List<String> frozenNames = new ArrayList<>(CollectionsKt.collectionSizeOrDefault(frozenVariables, 10));
        for (KVariable frozenVariable : frozenVariables) {
            frozenNames.add(frozenVariable.getName());
        }
        // Case-sensitive substring match; an empty name list yields false.
        for (String frozenName : frozenNames) {
            if (str.contains(frozenName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Loads data into each named variable.
     *
     * <p>For every name the matching graph operation is looked up, its first output's shape is
     * passed to {@code function2} to obtain the data, and the data is written either through the
     * layer variable of the same name (if one exists) or through {@code assignVariable}.
     *
     * @param collection names of the variables to load
     * @param function2  supplier of the data for a given (variable name, shape)
     * @throws IllegalStateException if a name has no operation in the graph
     */
    @Override // org.jetbrains.kotlinx.dl.api.inference.TensorFlowInferenceModel
    protected void loadVariables(@NotNull Collection<String> collection, @NotNull Function2<? super String, ? super Shape, ? extends Object> function2) {
        Intrinsics.checkNotNullParameter(collection, "variableNames");
        Intrinsics.checkNotNullParameter(function2, "getData");

        // Index the layer variables by name; a later duplicate name replaces an earlier one.
        List<KVariable> knownVariables = layerVariables();
        int initialCapacity = RangesKt.coerceAtLeast(MapsKt.mapCapacity(CollectionsKt.collectionSizeOrDefault(knownVariables, 10)), 16);
        Map<String, KVariable> variablesByName = new LinkedHashMap<>(initialCapacity);
        for (KVariable known : knownVariables) {
            variablesByName.put(known.getName(), known);
        }

        for (String variableName : collection) {
            GraphOperation operation = getKGraph().getTfGraph$tensorflow().operation(variableName);
            if (operation == null) {
                throw new IllegalStateException("Operation " + variableName + " is not found in static graph.");
            }
            Shape variableShape = operation.output(0).shape();
            Intrinsics.checkNotNullExpressionValue(variableShape, "variableShape");
            Object data = function2.invoke(variableName, variableShape);

            KVariable layerVariable = variablesByName.get(variableName);
            if (layerVariable == null) {
                assignVariable(variableName, variableShape, data);
                continue;
            }
            fill$tensorflow(layerVariable, data);
        }
    }

    /**
     * Fills the given variable with data by invoking its initializer operation's {@code fill}
     * with the model's session.
     *
     * @param kVariable variable to fill
     * @param obj       data handed to the initializer operation
     */
    public final void fill$tensorflow(@NotNull KVariable kVariable, @NotNull Object obj) {
        Intrinsics.checkNotNullParameter(kVariable, "variable");
        Intrinsics.checkNotNullParameter(obj, "data");
        kVariable.getInitializerOperation()
                .fill(obj, getSession$tensorflow());
    }

    /**
     * Runs the given variable's initializer operation in the model's session.
     *
     * @param kVariable variable to initialize
     */
    public final void init$tensorflow(@NotNull KVariable kVariable) {
        Intrinsics.checkNotNullParameter(kVariable, "variable");
        kVariable.getInitializerOperation()
                .run(getSession$tensorflow());
    }

    /**
     * Looks a layer up by name.
     *
     * @param str name of the layer
     * @return the layer registered under that name
     * @throws IllegalStateException if no layer with that name exists
     */
    @NotNull
    public final Layer getLayer(@NotNull String str) {
        Intrinsics.checkNotNullParameter(str, "layerName");
        Layer found = this.layersByName.get(str);
        if (found != null) {
            return found;
        }
        throw new IllegalStateException("No such layer " + str + " in the model.");
    }

    /**
     * Describes the model by its layer count, followed by the superclass description.
     *
     * @return {@code "GraphTrainableModel(numberOfLayers=N) "} plus {@code super.toString()}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.TrainableModel, org.jetbrains.kotlinx.dl.api.inference.TensorFlowInferenceModel
    @NotNull
    public String toString() {
        int layerCount = this.layers.size();
        StringBuilder description = new StringBuilder();
        description.append("GraphTrainableModel(numberOfLayers=").append(layerCount).append(") ");
        description.append(super.toString());
        return description.toString();
    }

    /**
     * Builds a summary of the model: one entry per layer plus the total parameter counts of
     * trainable and non-trainable layers.
     *
     * <p>The previous version obtained a fresh iterator on every evaluation of each counting
     * loop's condition and read from an undeclared variable, so the counting loops could neither
     * compile nor terminate. Counts are now accumulated in the same pass that builds the layer
     * summaries.
     *
     * @return summary holding the model class name, the model name, the per-layer summaries and
     *         the trainable / non-trainable parameter counts
     * @throws IllegalStateException if the model is not compiled
     */
    @NotNull
    /* renamed from: summary, reason: merged with bridge method [inline-methods] */
    public TfModelSummary m4summary() {
        if (!isModelCompiled()) {
            throw new IllegalStateException("The model is not compiled yet. Compile the model to use this method.");
        }
        String modelType = String.valueOf(Reflection.getOrCreateKotlinClass(getClass()).getSimpleName());
        String modelName = getName();

        List<LayerSummary> layerSummaries = new ArrayList<>(this.layers.size());
        long trainableParamsCount = 0;
        long frozenParamsCount = 0;
        for (Layer layer : this.layers) {
            long paramCount = ParametrizedLayerKt.getParamCount(layer);
            if (TrainableLayerKt.isTrainable(layer)) {
                trainableParamsCount += paramCount;
            } else {
                frozenParamsCount += paramCount;
            }

            List<Layer> inboundLayers = layer.getInboundLayers();
            List<String> inboundNames = new ArrayList<>(inboundLayers.size());
            for (Layer inbound : inboundLayers) {
                inboundNames.add(inbound.getName());
            }
            layerSummaries.add(new LayerSummary(
                    layer.getName(),
                    String.valueOf(Reflection.getOrCreateKotlinClass(layer.getClass()).getSimpleName()),
                    layer.getOutputShape(),
                    paramCount,
                    inboundNames));
        }
        return new TfModelSummary(modelType, modelName, layerSummaries, trainableParamsCount, frozenParamsCount);
    }
}
