package org.deeplearning4j.util;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/util/ModelSerializer.class */
/**
 * Saves and restores {@link MultiLayerNetwork} and {@link ComputationGraph} models as a ZIP
 * archive holding up to three entries:
 * <ul>
 *   <li>{@code configuration.json} - the network configuration as JSON (UTF-8),</li>
 *   <li>{@code coefficients.bin} - the parameter array as written by {@code Nd4j.write},</li>
 *   <li>{@code updater.bin} - optional, the Java-serialized updater.</li>
 * </ul>
 */
public class ModelSerializer {
    private static final String CONFIGURATION_ENTRY = "configuration.json";
    private static final String COEFFICIENTS_ENTRY = "coefficients.bin";
    private static final String UPDATER_ENTRY = "updater.bin";

    /**
     * Writes a model archive to the given file.
     *
     * @param model the model to save
     * @param file  destination file
     * @param z     when {@code true} the model's updater is stored as well
     * @throws IOException          if the file cannot be opened or written
     * @throws NullPointerException if {@code model} or {@code file} is {@code null}
     */
    public static void writeModel(@NonNull Model model, @NonNull File file, boolean z) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (file == null) {
            throw new NullPointerException("file");
        }
        // try-with-resources keeps the primary exception and records a failing close() as suppressed.
        try (OutputStream fileStream = new BufferedOutputStream(new FileOutputStream(file))) {
            writeModel(model, fileStream, z);
        }
    }

    /**
     * Writes a model archive to the file at the given path.
     *
     * @param model the model to save
     * @param str   path of the destination file
     * @param z     when {@code true} the model's updater is stored as well
     * @throws IOException          if the file cannot be opened or written
     * @throws NullPointerException if {@code model} or {@code str} is {@code null}
     */
    public static void writeModel(@NonNull Model model, @NonNull String str, boolean z) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (str == null) {
            throw new NullPointerException("path");
        }
        writeModel(model, new File(str), z);
    }

    /**
     * Writes a model archive to the given stream. The stream is closed when this method returns,
     * whether it succeeds or fails.
     *
     * <p>If {@code model} is neither a {@link MultiLayerNetwork} nor a {@link ComputationGraph},
     * an empty configuration entry is written (and, when {@code z} is set, an updater entry that
     * contains no object).
     *
     * @param model        the model to save
     * @param outputStream destination stream
     * @param z            when {@code true} the model's updater is stored as well
     * @throws IOException          if writing fails
     * @throws NullPointerException if {@code model} or {@code outputStream} is {@code null}
     */
    public static void writeModel(@NonNull Model model, @NonNull OutputStream outputStream, boolean z) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (outputStream == null) {
            throw new NullPointerException("stream");
        }
        String configurationJson = "";
        if (model instanceof MultiLayerNetwork) {
            configurationJson = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();
        } else if (model instanceof ComputationGraph) {
            configurationJson = ((ComputationGraph) model).getConfiguration().toJson();
        }
        try (ZipOutputStream zipStream = new ZipOutputStream(outputStream)) {
            // Explicit UTF-8 so the archive does not depend on the writer's platform charset.
            writeEntry(zipStream, CONFIGURATION_ENTRY, configurationJson.getBytes(StandardCharsets.UTF_8));
            writeEntry(zipStream, COEFFICIENTS_ENTRY, serializeParameters(model));
            if (z) {
                writeEntry(zipStream, UPDATER_ENTRY, serializeUpdater(model));
            }
        }
    }

    /** Serializes the model's parameter array into a byte array via {@code Nd4j.write}. */
    private static byte[] serializeParameters(Model model) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream dataStream = new DataOutputStream(buffer)) {
            Nd4j.write(model.params(), dataStream);
        }
        return buffer.toByteArray();
    }

    /** Java-serializes the model's updater; writes only the stream header for unknown model types. */
    private static byte[] serializeUpdater(Model model) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream objectStream = new ObjectOutputStream(buffer)) {
            if (model instanceof MultiLayerNetwork) {
                objectStream.writeObject(((MultiLayerNetwork) model).getUpdater());
            } else if (model instanceof ComputationGraph) {
                objectStream.writeObject(((ComputationGraph) model).getUpdater());
            }
        }
        return buffer.toByteArray();
    }

    /** Adds one named entry with the given content to the archive. */
    private static void writeEntry(ZipOutputStream zipStream, String entryName, byte[] content) throws IOException {
        zipStream.putNextEntry(new ZipEntry(entryName));
        zipStream.write(content, 0, content.length);
        zipStream.closeEntry();
    }

    /**
     * Restores a {@link MultiLayerNetwork} from a model archive.
     *
     * @param file the archive to read
     * @return the initialized network with parameters (and updater, when stored) applied
     * @throws IOException           if the archive cannot be read
     * @throws IllegalStateException if the configuration or coefficients entry is missing
     * @throws RuntimeException      if the stored updater's class cannot be found
     * @throws NullPointerException  if {@code file} is {@code null}
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file");
        }
        SavedModel<Updater> saved = readArchive(file, Updater.class);
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(saved.configurationJson));
        network.init();
        network.setParameters(saved.parameters);
        if (saved.updater != null) {
            network.setUpdater(saved.updater);
        }
        return network;
    }

    /**
     * Restores a {@link MultiLayerNetwork} from the archive at the given path.
     *
     * @param str path of the archive
     * @return the restored network
     * @throws IOException          if the archive cannot be read
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String str) throws IOException {
        if (str == null) {
            throw new NullPointerException("path");
        }
        return restoreMultiLayerNetwork(new File(str));
    }

    /**
     * Restores a {@link ComputationGraph} from the archive at the given path.
     *
     * @param str path of the archive
     * @return the restored graph
     * @throws IOException          if the archive cannot be read
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public static ComputationGraph restoreComputationGraph(@NonNull String str) throws IOException {
        if (str == null) {
            throw new NullPointerException("path");
        }
        return restoreComputationGraph(new File(str));
    }

    /**
     * Restores a {@link ComputationGraph} from a model archive.
     *
     * @param file the archive to read
     * @return the initialized graph with parameters (and updater, when stored) applied
     * @throws IOException           if the archive cannot be read
     * @throws IllegalStateException if the configuration or coefficients entry is missing
     * @throws RuntimeException      if the stored updater's class cannot be found
     * @throws NullPointerException  if {@code file} is {@code null}
     */
    public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file");
        }
        SavedModel<ComputationGraphUpdater> saved = readArchive(file, ComputationGraphUpdater.class);
        ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(saved.configurationJson));
        graph.init();
        graph.setParams(saved.parameters);
        if (saved.updater != null) {
            graph.setUpdater(saved.updater);
        }
        return graph;
    }

    /** Immutable holder for the pieces read out of a model archive. */
    private static final class SavedModel<U> {
        final String configurationJson;
        final INDArray parameters;
        final U updater;

        SavedModel(String configurationJson, INDArray parameters, U updater) {
            this.configurationJson = configurationJson;
            this.parameters = parameters;
            this.updater = updater;
        }
    }

    /**
     * Reads every known entry from the archive and verifies that the mandatory ones were present.
     * The archive is closed before validation, on both the success and the failure path.
     */
    private static <U> SavedModel<U> readArchive(File file, Class<U> updaterType) throws IOException {
        boolean gotConfig = false;
        boolean gotCoefficients = false;
        boolean gotUpdater = false;
        String configurationJson = null;
        INDArray parameters = null;
        U updater = null;

        try (ZipFile archive = new ZipFile(file)) {
            ZipEntry configurationEntry = archive.getEntry(CONFIGURATION_ENTRY);
            if (configurationEntry != null) {
                configurationJson = readConfiguration(archive, configurationEntry);
                gotConfig = true;
            }
            ZipEntry coefficientsEntry = archive.getEntry(COEFFICIENTS_ENTRY);
            if (coefficientsEntry != null) {
                parameters = readCoefficients(archive, coefficientsEntry);
                gotCoefficients = true;
            }
            ZipEntry updaterEntry = archive.getEntry(UPDATER_ENTRY);
            if (updaterEntry != null) {
                updater = readUpdater(archive, updaterEntry, updaterType);
                gotUpdater = true;
            }
        }

        if (!gotConfig || !gotCoefficients) {
            throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
        }
        return new SavedModel<>(configurationJson, parameters, updater);
    }

    /** Reads the JSON configuration entry as UTF-8 text, terminating every line with {@code \n}. */
    private static String readConfiguration(ZipFile archive, ZipEntry entry) throws IOException {
        StringBuilder json = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(archive.getInputStream(entry), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                json.append(line).append("\n");
            }
        }
        return json.toString();
    }

    /** Reads the parameter array entry via {@code Nd4j.read}. */
    private static INDArray readCoefficients(ZipFile archive, ZipEntry entry) throws IOException {
        try (DataInputStream dataStream = new DataInputStream(archive.getInputStream(entry))) {
            return Nd4j.read(dataStream);
        }
    }

    /**
     * Deserializes the updater entry and casts it to the expected type.
     *
     * <p>NOTE(security): this uses native Java deserialization, which is unsafe on archives from
     * untrusted sources. Only load model files you trust.
     */
    private static <U> U readUpdater(ZipFile archive, ZipEntry entry, Class<U> updaterType) throws IOException {
        try (InputStream rawStream = archive.getInputStream(entry);
                ObjectInputStream objectStream = new ObjectInputStream(rawStream)) {
            // Class.cast raises ClassCastException for a wrong type, like the plain cast it replaces.
            return updaterType.cast(objectStream.readObject());
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /* JADX WARN: Code restructure failed: missing block: B:37:0x005a, code lost:
    
        r0.setArchitectureType(org.nd4j.linalg.heartbeat.reports.Task.ArchitectureType.RBM);
     */
    /* JADX WARN: Code restructure failed: missing block: B:65:0x014e, code lost:
    
        r0.setArchitectureType(org.nd4j.linalg.heartbeat.reports.Task.ArchitectureType.RECURRENT);
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Placeholder for a method whose body the decompiler could not recover; in this source it
     * always throws. The decompiler notes above show that the original code set an architecture
     * type (for example {@code RBM} or {@code RECURRENT}) on a {@code Task}.
     *
     * @param r3 the model to describe
     * @return never returns normally in this decompiled form
     * @throws UnsupportedOperationException always
     */
    public static org.nd4j.linalg.heartbeat.reports.Task taskByModel(org.deeplearning4j.nn.api.Model r3) {
        /*
            Method dump skipped, instructions count: 383
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.deeplearning4j.util.ModelSerializer.taskByModel(org.deeplearning4j.nn.api.Model):org.nd4j.linalg.heartbeat.reports.Task");
    }
}
