package org.deeplearning4j.graph.models.loader;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.deepwalk.DeepWalk;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/graph/models/loader/GraphVectorSerializer.class */
/**
 * Static helpers for writing vertex vectors to, and reading them back from, a plain-text file.
 *
 * <p>File format: one line per vertex, fields separated by a tab character. The first field is
 * the vertex index and the remaining fields are the components of that vertex's vector.
 */
public class GraphVectorSerializer {
    private static final Logger log = LoggerFactory.getLogger(GraphVectorSerializer.class);
    private static final String DELIM = "\t";

    /** Not instantiable: this class only exposes static methods. */
    private GraphVectorSerializer() {
    }

    /**
     * Writes every vertex vector of the given model to a text file, overwriting any existing file.
     *
     * <p>Vertices are written in index order, {@code 0} to {@code deepWalk.numVertices() - 1}. Each
     * line holds the vertex index followed by {@code deepWalk.getVectorSize()} values and ends
     * with {@code "\n"}.
     *
     * @param deepWalk model whose vertex vectors are written
     * @param str path of the file to write
     * @throws IOException if the file cannot be opened, written or closed
     */
    public static void writeGraphVectors(DeepWalk deepWalk, String str) throws IOException {
        int numVertices = deepWalk.numVertices();
        int vectorSize = deepWalk.getVectorSize();

        // try-with-resources closes the writer on every path and attaches a failing close() as a
        // suppressed exception instead of letting it replace the original failure.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(str), false))) {
            for (int vertex = 0; vertex < numVertices; vertex++) {
                StringBuilder line = new StringBuilder();
                line.append(vertex);
                INDArray vector = deepWalk.getVertexVector(vertex);
                for (int j = 0; j < vectorSize; j++) {
                    line.append(DELIM).append(vector.getDouble(j));
                }
                // Fixed "\n" terminator (not the platform line separator) keeps output identical
                // across operating systems.
                line.append("\n");
                writer.write(line.toString());
            }
        }

        // Explicit Object[] keeps this call valid for both the array and varargs logger overloads.
        log.info("Wrote {} vectors of length {} to: {}", new Object[] {numVertices, vectorSize, str});
    }

    /**
     * Loads vertex vectors from a tab-delimited text file.
     *
     * <p>The first field of every line is skipped; the remaining fields are parsed as doubles.
     * Vectors are stored in the order their lines appear in the file, and the vector length is
     * taken from the first line.
     *
     * @param file file to read
     * @return graph vectors backed by an in-memory lookup table holding the loaded vectors; the
     *     returned object is constructed without a graph ({@code null})
     * @throws IOException if the file cannot be read, or if it contains no lines
     * @throws NumberFormatException if a vector component cannot be parsed as a double
     */
    public static GraphVectors loadTxtVectors(File file) throws IOException {
        List<double[]> vectorList = new ArrayList<>();

        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            LineIterator lines = IOUtils.lineIterator(reader);
            while (lines.hasNext()) {
                String[] fields = lines.next().split(DELIM);
                // fields[0] is the vertex index; only the trailing fields are vector components.
                double[] vector = new double[fields.length - 1];
                for (int i = 1; i < fields.length; i++) {
                    vector[i - 1] = Double.parseDouble(fields[i]);
                }
                vectorList.add(vector);
            }
        }

        // Fail with a descriptive error rather than an IndexOutOfBoundsException from get(0).
        if (vectorList.isEmpty()) {
            throw new IOException("No vectors found in file: " + file);
        }

        int vectorSize = vectorList.get(0).length;
        int numVertices = vectorList.size();
        INDArray vectors = Nd4j.create(numVertices, vectorSize);
        for (int row = 0; row < numVertices; row++) {
            double[] vector = vectorList.get(row);
            for (int col = 0; col < vector.length; col++) {
                vectors.put(row, col, Double.valueOf(vector[col]));
            }
        }

        // NOTE(review): the null and 0.01 arguments are presumably "no tree" and the learning
        // rate - confirm against the InMemoryGraphLookupTable constructor.
        InMemoryGraphLookupTable table = new InMemoryGraphLookupTable(numVertices, vectorSize, null, 0.01d);
        table.setVertexVectors(vectors);
        return new GraphVectorsImpl(null, table);
    }
}
