package org.nd4j.jdbc.loader.impl;

import com.mchange.v2.c3p0.ComboPooledDataSource;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.sql.Blob;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import javax.sql.DataSource;
import javax.sql.rowset.serial.SerialBlob;
import org.nd4j.jdbc.driverfinder.DriverFinder;
import org.nd4j.jdbc.loader.api.JDBCNDArrayIO;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.binary.BinarySerde;

/**
 * Base {@link JDBCNDArrayIO} implementation that stores serialized ND4J arrays in a binary column
 * of a JDBC table, keyed by a string id.
 *
 * <p>Subclasses supply the SQL through {@code insertStatement()}, {@code loadStatement()} and
 * {@code deleteStatement()}. This class binds parameters as follows:
 * <ul>
 *   <li>insert: parameter 1 is the id and parameter 2 is the serialized bytes;</li>
 *   <li>load: parameter 1 is the id, and the blob is read from result column 2;</li>
 *   <li>delete: parameter 1 is the id.</li>
 * </ul>
 *
 * <p>Every database operation borrows a connection from {@link #dataSource} and closes it (returns
 * it to the pool) before the method returns.
 */
public abstract class BaseLoader implements JDBCNDArrayIO {
    /** Name of the table holding the arrays. */
    protected String tableName;
    /** Name of the column holding the serialized array bytes. */
    protected String columnName;
    /** Name of the column holding the array id. */
    protected String idColumnName;
    /** JDBC URL, used to build a pooled data source when none is supplied. */
    protected String jdbcUrl;
    /** Source of connections for all operations. Never null after construction. */
    protected DataSource dataSource;

    /**
     * Creates a loader backed by the given data source. If {@code dataSource} is null, a c3p0
     * pooled data source is created for {@code jdbcUrl} instead.
     *
     * @param dataSource   connection source, or null to create a pooled one from {@code jdbcUrl}
     * @param jdbcUrl      JDBC URL of the database
     * @param tableName    table holding the arrays
     * @param idColumnName column holding the array id
     * @param columnName   column holding the serialized array
     * @throws Exception if the driver cannot be located or configured
     */
    protected BaseLoader(DataSource dataSource, String jdbcUrl, String tableName, String idColumnName,
                    String columnName) throws Exception {
        this.jdbcUrl = jdbcUrl;
        this.tableName = tableName;
        this.idColumnName = idColumnName;
        this.columnName = columnName;
        // Previously the fallback pool was built but never assigned, leaving dataSource null.
        this.dataSource = dataSource != null ? dataSource : createPooledDataSource(jdbcUrl);
    }

    /**
     * Creates a loader backed by a new c3p0 pooled data source for {@code jdbcUrl}.
     *
     * @param jdbcUrl      JDBC URL of the database
     * @param tableName    table holding the arrays
     * @param idColumnName column holding the array id
     * @param columnName   column holding the serialized array
     * @throws Exception if the driver cannot be located or configured
     */
    protected BaseLoader(String jdbcUrl, String tableName, String idColumnName, String columnName)
                    throws Exception {
        this.jdbcUrl = jdbcUrl;
        this.tableName = tableName;
        this.idColumnName = idColumnName;
        this.columnName = columnName;
        this.dataSource = createPooledDataSource(jdbcUrl);
    }

    /**
     * Creates a loader whose id column is named {@code "id"}.
     *
     * @param dataSource connection source, or null to create a pooled one from {@code jdbcUrl}
     * @param jdbcUrl    JDBC URL of the database
     * @param tableName  table holding the arrays
     * @param columnName column holding the serialized array
     * @throws Exception if the driver cannot be located or configured
     */
    public BaseLoader(DataSource dataSource, String jdbcUrl, String tableName, String columnName)
                    throws Exception {
        this(dataSource, jdbcUrl, tableName, "id", columnName);
    }

    /** Builds a c3p0 pool for {@code jdbcUrl}, using the driver class found by {@link DriverFinder}. */
    private static DataSource createPooledDataSource(String jdbcUrl) throws Exception {
        ComboPooledDataSource pooled = new ComboPooledDataSource();
        pooled.setJdbcUrl(jdbcUrl);
        pooled.setDriverClass(DriverFinder.getDriver().getClass().getName());
        return pooled;
    }

    /**
     * Serializes a complex array with {@link Nd4j#writeComplex} into an in-memory blob.
     *
     * <p>The returned blob is a {@link SerialBlob}. It is not tied to a database connection, so no
     * pooled connection is borrowed (and leaked) just to create it.
     *
     * @param toConvert array to serialize
     * @return blob holding the serialized bytes
     * @throws IOException  if serialization fails
     * @throws SQLException if the blob cannot be created
     */
    @Override
    public Blob convert(IComplexNDArray toConvert) throws IOException, SQLException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        Nd4j.writeComplex(toConvert, new DataOutputStream(bytes));
        return new SerialBlob(bytes.toByteArray());
    }

    /**
     * Serializes an array with {@link BinarySerde#toByteBuffer} into an in-memory blob.
     *
     * @param toConvert array to serialize
     * @return a {@link SerialBlob} holding the serialized bytes (not tied to any connection)
     * @throws SQLException if the blob cannot be created
     */
    @Override
    public Blob convert(INDArray toConvert) throws SQLException {
        ByteBuffer buffer = BinarySerde.toByteBuffer(toConvert);
        buffer.rewind();
        // Copy up to the limit, not the capacity: get() underflows if limit < capacity.
        byte[] bytes = new byte[buffer.remaining()];
        buffer.get(bytes);
        return new SerialBlob(bytes);
    }

    /**
     * Deserializes an array previously written by {@link BinarySerde}.
     *
     * @param blob serialized array, may be null
     * @return the array, or null if {@code blob} is null
     * @throws RuntimeException wrapping any I/O or SQL failure, including a stream shorter than
     *                          {@code blob.length()} or a blob larger than {@link Integer#MAX_VALUE}
     */
    @Override
    public INDArray load(Blob blob) throws SQLException {
        if (blob == null) {
            return null;
        }
        try {
            long length = blob.length();
            if (length > Integer.MAX_VALUE) {
                throw new IOException("Blob of " + length + " bytes is too large to load into a single buffer");
            }
            ByteBuffer buffer = ByteBuffer.allocateDirect((int) length);
            // Closing the channel also closes the underlying blob stream.
            try (ReadableByteChannel channel = Channels.newChannel(blob.getBinaryStream())) {
                // A single read() may return fewer bytes than requested; keep reading until full.
                while (buffer.hasRemaining()) {
                    if (channel.read(buffer) < 0) {
                        throw new IOException("Blob stream ended after " + buffer.position() + " of "
                                        + length + " bytes");
                    }
                }
            }
            buffer.rewind();
            return BinarySerde.toArray(buffer);
        } catch (IOException | SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a complex array previously written with {@link Nd4j#writeComplex}.
     *
     * @param blob serialized complex array
     * @return the complex array
     * @throws SQLException if the blob stream cannot be opened
     * @throws IOException  if reading fails
     */
    @Override
    public IComplexNDArray loadComplex(Blob blob) throws SQLException, IOException {
        try (DataInputStream in = new DataInputStream(blob.getBinaryStream())) {
            return Nd4j.readComplex(in);
        }
    }

    /**
     * Inserts {@code toSave} under the given id.
     *
     * @param toSave array to store
     * @param id     id to store it under
     */
    @Override
    public void save(INDArray toSave, String id) throws SQLException, IOException {
        doSave(toSave, id);
    }

    /**
     * Inserts the complex array {@code toSave} under the given id.
     *
     * @param toSave complex array to store
     * @param id     id to store it under
     */
    @Override
    public void save(IComplexNDArray toSave, String id) throws IOException, SQLException {
        doSave(toSave, id);
    }

    /**
     * Serializes {@code toSave} and runs {@link #insertStatement()} with the id and the bytes.
     * Complex arrays use {@link Nd4j#writeComplex}; other arrays use {@link BinarySerde}.
     * The connection and statement are always closed, even on failure.
     */
    private void doSave(INDArray toSave, String id) throws SQLException, IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        if (toSave instanceof IComplexNDArray) {
            Nd4j.writeComplex((IComplexNDArray) toSave, new DataOutputStream(bytes));
        } else {
            BinarySerde.writeArrayToOutputStream(toSave, bytes);
        }
        byte[] payload = bytes.toByteArray();

        try (Connection connection = this.dataSource.getConnection();
                        PreparedStatement statement = connection.prepareStatement(insertStatement())) {
            statement.setString(1, id);
            statement.setBytes(2, payload);
            statement.executeUpdate();
        }
    }

    /**
     * Fetches the stored blob for {@code id}.
     *
     * <p>The blob is copied into an in-memory {@link SerialBlob}, so the caller can still read it
     * after this method has closed its connection, statement and result set.
     *
     * @param id id to look up
     * @return the stored blob, or null if there is no row or the column is SQL NULL
     * @throws SQLException on database errors
     */
    @Override
    public Blob loadForID(String id) throws SQLException {
        try (Connection connection = this.dataSource.getConnection();
                        PreparedStatement statement = connection.prepareStatement(loadStatement())) {
            statement.setString(1, id);
            try (ResultSet rows = statement.executeQuery()) {
                if (!rows.next()) {
                    return null;
                }
                Blob stored = rows.getBlob(2);
                return stored == null ? null : new SerialBlob(stored);
            }
        }
    }

    /**
     * Loads and deserializes the array stored under {@code id}.
     *
     * @param id id to look up
     * @return the array, or null if nothing is stored under {@code id}
     * @throws SQLException on database errors
     */
    @Override
    public INDArray loadArrayForId(String id) throws SQLException {
        return load(loadForID(id));
    }

    /**
     * Runs {@link #deleteStatement()} for {@code id}.
     *
     * @param id id of the row to delete
     * @throws SQLException on database errors
     */
    @Override
    public void delete(String id) throws SQLException {
        try (Connection connection = this.dataSource.getConnection();
                        PreparedStatement statement = connection.prepareStatement(deleteStatement())) {
            statement.setString(1, id);
            statement.execute();
        }
    }
}
