package com.github.database.rider.junit5;

import com.github.database.rider.core.api.connection.ConnectionHolder;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.dataset.ExpectedDataSet;
import com.github.database.rider.core.api.exporter.DataSetExportConfig;
import com.github.database.rider.core.api.exporter.ExportDataSet;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.ConnectionConfig;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.configuration.DataSetConfig;
import com.github.database.rider.core.connection.ConnectionHolderImpl;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.exporter.DataSetExporter;
import com.github.database.rider.core.leak.LeakHunterException;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.core.util.EntityManagerProvider;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.DriverManager;
import java.util.Arrays;
import java.util.Locale;
import java.util.Optional;
import org.dbunit.DatabaseUnitException;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JUnit 5 extension that wires Database Rider into the test lifecycle.
 *
 * <p>Before a test method runs, and only when {@link DataSet} is present on the test method or
 * its class, it resolves a connection, runs the configured "before" statements and scripts,
 * optionally records a leak-hunter baseline, seeds the database and (for transactional datasets)
 * starts a transaction.
 *
 * <p>After the test it verifies {@link ExpectedDataSet}, checks for leaked connections and then
 * always runs the "after" actions (export, statements, scripts, optional cleanup and constraint
 * re-enabling). The "after" actions run only for tests that were seeded by this extension.
 *
 * <p>Per-test state is kept in a {@code DBUnitTestContext} stored in the extension-context store.
 */
public class DBUnitExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {

    private static final Logger log = LoggerFactory.getLogger(DBUnitExtension.class);

    /** Store namespace under which the per-test {@code DBUnitTestContext} is kept. */
    private static final ExtensionContext.Namespace namespace = ExtensionContext.Namespace.create(DBUnitExtension.class);

    /**
     * Seeds the database for the upcoming test. Does nothing when neither the test method nor
     * its class carries {@link DataSet}.
     *
     * @param extensionContext the JUnit context of the test about to run
     * @throws Exception a {@link DatabaseUnitException} raised by a "before" script is
     *     propagated as-is; any failure while creating the dataset or starting the transaction
     *     is wrapped in a {@link RuntimeException}
     */
    @Override
    public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
        if (!shouldCreateDataSet(extensionContext)) {
            return;
        }
        Method testMethod = testMethod(extensionContext);
        ConnectionHolder connectionHolder = findTestConnection(extensionContext);
        if (EntityManagerProvider.isEntityManagerActive()) {
            // Clear the persistence context before the database is re-seeded.
            EntityManagerProvider.em().clear();
        }
        DataSet dataSet = findAnnotation(extensionContext, DataSet.class);
        DBUnitConfig dbUnitConfig = DBUnitConfig.from(testMethod);
        DataSetConfig dataSetConfig = new DataSetConfig().from(dataSet);
        if (connectionHolder == null || connectionHolder.getConnection() == null) {
            // The test does not expose a ConnectionHolder: fall back to the configured JDBC settings.
            connectionHolder = createConnection(dbUnitConfig, testMethod.getName());
        }
        DataSetExecutorImpl executor = DataSetExecutorImpl.instance(dataSetConfig.getExecutorId(), connectionHolder);
        executor.setDBUnitConfig(dbUnitConfig);
        DBUnitTestContext testContext = getTestContext(extensionContext);
        testContext.setExecutor(executor).setDataSetConfig(dataSetConfig);

        runStatements(executor, dataSetConfig.getExecuteStatementsBefore(), testMethod, "Before");
        runScripts(executor, dataSetConfig.getExecuteScriptsBefore(), testMethod, "scriptsBefore");

        if (Boolean.TRUE.equals(dbUnitConfig.isLeakHunter())) {
            // Baseline compared against in afterTestExecution to detect connections left open.
            LeakHunter leakHunter = LeakHunterFactory.from(connectionHolder.getConnection());
            testContext.setLeakHunter(leakHunter).setOpenConnections(leakHunter.openConnections());
        }

        try {
            executor.createDataSet(dataSetConfig);
            if (dataSetConfig.isTransactional()) {
                beginTransaction(executor);
            }
        } catch (Exception e) {
            // The error message is a format *argument*, so a '%' in it cannot break String.format.
            throw new RuntimeException(String.format(
                    "Could not create dataset for test method %s due to following error %s",
                    testMethod.getName(), e.getMessage()), e);
        }
    }

    /**
     * Puts the connection used by the test into transactional mode: plain JDBC connections get
     * auto-commit disabled; with an active EntityManager a transaction is begun unless one is
     * already active.
     */
    private static void beginTransaction(DataSetExecutor executor) throws Exception {
        if (!EntityManagerProvider.isEntityManagerActive()) {
            executor.getRiderDataSource().getConnection().setAutoCommit(false);
        } else if (!EntityManagerProvider.tx().isActive()) {
            EntityManagerProvider.em().getTransaction().begin();
        }
    }

    private boolean shouldCreateDataSet(ExtensionContext extensionContext) {
        return findAnnotation(extensionContext, DataSet.class) != null;
    }

    private boolean shouldCompareDataSet(ExtensionContext extensionContext) {
        return findAnnotation(extensionContext, ExpectedDataSet.class) != null;
    }

    private boolean shouldExportDataSet(ExtensionContext extensionContext) {
        return findAnnotation(extensionContext, ExportDataSet.class) != null;
    }

    /**
     * Exports the current database content as configured by {@link ExportDataSet} on
     * {@code method} (or, failing that, on its declaring class). Export failures are logged and
     * never propagated.
     *
     * @param dataSetExecutor executor whose DBUnit connection is exported
     * @param method test method; its name provides the default output file name
     */
    public void exportDataSet(DataSetExecutor dataSetExecutor, Method method) {
        ExportDataSet exportDataSet = resolveExportDataSet(method);
        if (exportDataSet == null) {
            return;
        }
        DataSetExportConfig exportConfig = DataSetExportConfig.from(exportDataSet);
        String outputFileName = exportConfig.getOutputFileName();
        if (outputFileName == null || outputFileName.trim().isEmpty()) {
            // Locale.ROOT keeps the generated file name independent of the JVM default locale.
            outputFileName = method.getName().toLowerCase(Locale.ROOT) + "."
                    + exportConfig.getDataSetFormat().name().toLowerCase(Locale.ROOT);
        }
        exportConfig.outputFileName(outputFileName);
        try {
            DataSetExporter.getInstance().export(dataSetExecutor.getRiderDataSource().getDBUnitConnection(), exportConfig);
        } catch (Exception e) {
            log.warn("Could not export dataset after method " + method.getName(), e);
        }
    }

    /** Returns {@link ExportDataSet} from the method, else from its declaring class, else null. */
    private ExportDataSet resolveExportDataSet(Method method) {
        ExportDataSet annotation = method.getAnnotation(ExportDataSet.class);
        if (annotation == null) {
            annotation = method.getDeclaringClass().getAnnotation(ExportDataSet.class);
        }
        return annotation;
    }

    /**
     * Verifies the expected dataset, checks for leaked connections and then runs the "after"
     * actions. The "after" actions run in a {@code finally} block, so they also run when the
     * comparison or the leak check fails. An exception from the "after" actions replaces any
     * earlier exception.
     *
     * @param extensionContext the JUnit context of the test that just ran
     * @throws Exception comparison failures, {@link LeakHunterException}, or a
     *     {@link DatabaseUnitException} raised by an "after" script
     */
    @Override
    public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
        DBUnitTestContext testContext = getTestContext(extensionContext);
        DataSetConfig dataSetConfig = testContext.getDataSetConfig();
        DataSetExecutor executor = testContext.getExecutor();
        Method testMethod = testMethod(extensionContext);
        DBUnitConfig dbUnitConfig = executor != null ? executor.getDBUnitConfig() : DBUnitConfig.from(testMethod);
        // Both are set together in beforeTestExecution, and only for tests that declare @DataSet.
        boolean seededByThisExtension = dataSetConfig != null && executor != null;
        try {
            if (seededByThisExtension) {
                compareExpectedDataSet(extensionContext, dataSetConfig, executor, testMethod);
            }
            checkForLeakedConnections(testContext, dbUnitConfig, testMethod);
        } finally {
            if (seededByThisExtension) {
                runAfterTestActions(extensionContext, testMethod, dataSetConfig, executor);
            }
        }
    }

    /** Compares the database with {@link ExpectedDataSet}, if the test declares one. */
    private void compareExpectedDataSet(ExtensionContext extensionContext, DataSetConfig dataSetConfig,
                                        DataSetExecutor executor, Method testMethod) throws Exception {
        if (!shouldCompareDataSet(extensionContext)) {
            return;
        }
        ExpectedDataSet expected = findAnnotation(extensionContext, ExpectedDataSet.class);
        if (expected == null) {
            return;
        }
        if (dataSetConfig.isTransactional()) {
            // The test's changes must be committed before they can be compared.
            commitBeforeComparison(executor, testMethod);
        }
        executor.compareCurrentDataSetWith(new DataSetConfig(expected.value()).disableConstraints(true), expected.ignoreCols());
    }

    /**
     * Commits the test transaction. A failed commit is logged and recovered from instead of
     * propagated, so the comparison still runs.
     */
    private static void commitBeforeComparison(DataSetExecutor executor, Method testMethod) throws Exception {
        try {
            if (!EntityManagerProvider.isEntityManagerActive()) {
                Connection connection = executor.getRiderDataSource().getConnection();
                connection.commit();
                connection.setAutoCommit(false);
            } else if (EntityManagerProvider.tx().isActive()) {
                EntityManagerProvider.tx().commit();
            }
        } catch (Exception e) {
            log.warn(testMethod.getName() + "() - Could not commit transaction before comparing expected dataset:" + e.getMessage(), e);
            if (EntityManagerProvider.isEntityManagerActive()) {
                EntityManagerProvider.tx().rollback();
            } else {
                // NOTE(review): unlike the JPA branch this does not roll back, and it leaves the
                // (possibly shared) connection read-only — confirm this recovery is intended.
                Connection connection = executor.getRiderDataSource().getConnection();
                connection.setAutoCommit(false);
                connection.setReadOnly(true);
            }
        }
    }

    /**
     * Fails the test if it opened more connections than the leak-hunter baseline recorded in
     * beforeTestExecution. Skipped when leak hunting is disabled or no baseline was recorded
     * (for example, the test has no {@link DataSet}).
     */
    private static void checkForLeakedConnections(DBUnitTestContext testContext, DBUnitConfig dbUnitConfig,
                                                  Method testMethod) throws Exception {
        if (dbUnitConfig == null || !Boolean.TRUE.equals(dbUnitConfig.isLeakHunter())) {
            return;
        }
        LeakHunter leakHunter = testContext.getLeakHunter();
        if (leakHunter == null) {
            return;
        }
        int openedBefore = testContext.getOpenConnections();
        int openedAfter = leakHunter.openConnections();
        if (openedAfter > openedBefore) {
            throw new LeakHunterException(testMethod.getName(), openedAfter - openedBefore);
        }
    }

    /**
     * Runs the post-test actions in order: export, "after" statements, "after" scripts, optional
     * database cleanup, and re-enabling of constraints.
     */
    private void runAfterTestActions(ExtensionContext extensionContext, Method testMethod,
                                     DataSetConfig dataSetConfig, DataSetExecutor executor) throws Exception {
        if (shouldExportDataSet(extensionContext)) {
            exportDataSet(executor, testMethod);
        }
        runStatements(executor, dataSetConfig.getExecuteStatementsAfter(), testMethod, "after");
        runScripts(executor, dataSetConfig.getExecuteScriptsAfter(), testMethod, "scriptsAfter");
        if (dataSetConfig.isCleanAfter()) {
            executor.clearDatabase(dataSetConfig);
        }
        executor.enableConstraints();
    }

    /**
     * Executes {@code statements} as one batch; failures are logged, never propagated.
     *
     * @param phase word inserted into the log message ("Before" / "after")
     */
    private static void runStatements(DataSetExecutor executor, String[] statements, Method testMethod, String phase) {
        if (statements == null || statements.length == 0) {
            return;
        }
        try {
            executor.executeStatements(statements);
        } catch (Exception e) {
            log.error(testMethod.getName() + "() - Could not execute statements " + phase + ":" + e.getMessage(), e);
        }
    }

    /**
     * Executes each script in order. A {@link DatabaseUnitException} aborts and is rethrown;
     * other failures are logged and the remaining scripts still run.
     *
     * @param label word inserted into the log message ("scriptsBefore" / "scriptsAfter")
     */
    private static void runScripts(DataSetExecutor executor, String[] scripts, Method testMethod, String label) throws Exception {
        if (scripts == null) {
            return;
        }
        for (String script : scripts) {
            try {
                executor.executeScript(script);
            } catch (Exception e) {
                if (e instanceof DatabaseUnitException) {
                    throw e;
                }
                log.error(testMethod.getName() + "() - Could not execute " + label + ":" + e.getMessage(), e);
            }
        }
    }

    /**
     * Looks up the test's {@link ConnectionHolder}. The first declared field of that type is
     * used; otherwise the first declared method returning it is invoked on the test instance.
     * Only members declared directly on the test class are searched.
     *
     * @return the holder, or {@code null} when the test class declares neither kind of member
     * @throws RuntimeException if the holder or its connection is null, or reflection fails
     */
    private ConnectionHolder findTestConnection(ExtensionContext extensionContext) {
        Class<?> testClass = testClass(extensionContext);
        try {
            Optional<Field> holderField = Arrays.stream(testClass.getDeclaredFields())
                    .filter(field -> field.getType() == ConnectionHolder.class)
                    .findFirst();
            if (holderField.isPresent()) {
                Field field = holderField.get();
                field.setAccessible(true);
                return requireInitialized(ConnectionHolder.class.cast(field.get(extensionContext.getTestInstance().get())));
            }
            Optional<Method> holderMethod = Arrays.stream(testClass.getDeclaredMethods())
                    .filter(method -> method.getReturnType() == ConnectionHolder.class)
                    .findFirst();
            if (!holderMethod.isPresent()) {
                return null;
            }
            Method method = holderMethod.get();
            method.setAccessible(true);
            return requireInitialized(ConnectionHolder.class.cast(method.invoke(extensionContext.getTestInstance().get())));
        } catch (Exception e) {
            throw new RuntimeException("Could not get database connection for test " + testClass, e);
        }
    }

    /** Returns {@code holder} unless it, or its connection, is null. */
    private static ConnectionHolder requireInitialized(ConnectionHolder holder) throws Exception {
        if (holder == null || holder.getConnection() == null) {
            throw new RuntimeException("ConnectionHolder not initialized correctly");
        }
        return holder;
    }

    /**
     * Opens a JDBC connection from the connection settings in {@code dbUnitConfig}.
     *
     * @return the new holder, or {@code null} if the driver could not be loaded or the
     *     connection could not be opened (the error is logged). NOTE(review): the caller hands
     *     that null to DataSetExecutorImpl.instance — confirm it tolerates a null holder.
     * @throws RuntimeException if the URL or user is configured as an empty string
     */
    private ConnectionHolder createConnection(DBUnitConfig dbUnitConfig, String methodName) {
        ConnectionConfig connectionConfig = dbUnitConfig.getConnectionConfig();
        if ("".equals(connectionConfig.getUrl()) || "".equals(connectionConfig.getUser())) {
            throw new RuntimeException(String.format("Could not create JDBC connection for method %s, provide a connection at test level or via configuration, see documentation here: https://github.com/rmpestano/dbunit-rules#jdbc-connection", methodName));
        }
        try {
            if (!"".equals(connectionConfig.getDriver())) {
                Class.forName(connectionConfig.getDriver());
            }
            return new ConnectionHolderImpl(DriverManager.getConnection(connectionConfig.getUrl(), connectionConfig.getUser(), connectionConfig.getPassword()));
        } catch (Exception e) {
            log.error("Could not create JDBC connection for method " + methodName, e);
            return null;
        }
    }

    /** Returns this test's context from the extension store, creating and storing it on first use. */
    private DBUnitTestContext getTestContext(ExtensionContext extensionContext) {
        return extensionContext.getStore(namespace)
                .getOrComputeIfAbsent(testClass(extensionContext), key -> new DBUnitTestContext(), DBUnitTestContext.class);
    }

    /** The test method; always present for the test-execution callbacks. */
    private static Method testMethod(ExtensionContext extensionContext) {
        return extensionContext.getTestMethod().get();
    }

    /** The test class; always present for the test-execution callbacks. */
    private static Class<?> testClass(ExtensionContext extensionContext) {
        return extensionContext.getTestClass().get();
    }

    /** Returns the annotation from the test method, falling back to the test class, else null. */
    private static <A extends Annotation> A findAnnotation(ExtensionContext extensionContext, Class<A> annotationType) {
        A onMethod = testMethod(extensionContext).getAnnotation(annotationType);
        return onMethod != null ? onMethod : testClass(extensionContext).getAnnotation(annotationType);
    }
}
