package com.googlecode.flyway.test.dbunit;

import com.googlecode.flyway.test.ExecutionListenerHelper;
import java.io.File;
import java.io.FileWriter;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.util.Locale;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dbunit.database.AmbiguousTableNameException;
import org.dbunit.database.DatabaseSequenceFilter;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.database.QueryDataSet;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.FilteredDataSet;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSetBuilder;
import org.dbunit.operation.DatabaseOperation;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.core.io.ClassPathResource;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestExecutionListener;

/* loaded from: input_file:com/googlecode/flyway/test/dbunit/DBUnitTestExecutionListener.class */
/**
 * Spring {@link TestExecutionListener} that drives DbUnit from the {@link DBUnitSupport} annotation.
 * <p>
 * Before a test class or test method annotated with {@link DBUnitSupport}, the flat-XML files listed
 * in {@code loadFilesForRun} are applied to the test {@link DataSource}. After an annotated test
 * method, the tables listed in {@code saveTableAfterRun} (or all tables, when none are listed) are
 * exported to the flat-XML file named by {@code saveFileAfterRun}.
 */
public class DBUnitTestExecutionListener implements TestExecutionListener {
    protected final Log logger = LogFactory.getLog(getClass());

    /**
     * Wraps JDBC connections into DbUnit connections. Each call to
     * {@link #getConnection(DataSource, TestContext)} replaces it with a
     * {@link DatabaseConnectionFactory} bean from the test application context when one is found.
     */
    @Autowired(required = false)
    protected DatabaseConnectionFactory dbConnectionFactory = new DefaultDatabaseConnectionFactory();

    /**
     * Loads the data files declared by a class-level {@link DBUnitSupport} annotation, if present.
     */
    public void beforeTestClass(TestContext testContext) throws Exception {
        DBUnitSupport dbUnitSupport = testContext.getTestClass().getAnnotation(DBUnitSupport.class);
        if (dbUnitSupport != null) {
            loadFiles(testContext, dbUnitSupport);
        }
    }

    /** No-op. */
    public void prepareTestInstance(TestContext testContext) throws Exception {
    }

    /**
     * Loads the data files declared by a method-level {@link DBUnitSupport} annotation, if present.
     */
    public void beforeTestMethod(TestContext testContext) throws Exception {
        DBUnitSupport dbUnitSupport = testContext.getTestMethod().getAnnotation(DBUnitSupport.class);
        if (dbUnitSupport != null) {
            loadFiles(testContext, dbUnitSupport);
        }
    }

    /**
     * Exports database content to a flat-XML file after a test method annotated with
     * {@link DBUnitSupport} whose {@code saveFileAfterRun} is non-blank.
     */
    public void afterTestMethod(TestContext testContext) throws Exception {
        DBUnitSupport dbUnitSupport = testContext.getTestMethod().getAnnotation(DBUnitSupport.class);
        if (dbUnitSupport == null) {
            return;
        }
        String saveFileAfterRun = dbUnitSupport.saveFileAfterRun();
        if (saveFileAfterRun == null || saveFileAfterRun.trim().length() == 0) {
            return;
        }
        String executionInformation = ExecutionListenerHelper.getExecutionInformation(testContext);
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("******** Start save information '" + executionInformation + "' info file '" + saveFileAfterRun + "'.");
        }
        // Build the data set before creating the target file, matching the original evaluation order.
        IDataSet dataSet = getDataSetToExport(dbUnitSupport.saveTableAfterRun(),
                getConnection(getSaveDataSource(testContext), testContext));
        FileWriter writer = new FileWriter(getFileToExport(saveFileAfterRun));
        try {
            FlatXmlDataSet.write(dataSet, writer);
        } finally {
            // Closing flushes the export and releases the file handle, even if writing failed.
            writer.close();
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("******** Finished save information '" + executionInformation + "' info file '" + saveFileAfterRun + "'.");
        }
    }

    /** No-op. */
    public void afterTestClass(TestContext testContext) throws Exception {
    }

    /**
     * Applies every {@code {operation, classpath file}} pair from {@code loadFilesForRun}, in order,
     * using a single DbUnit connection.
     *
     * @throws IllegalArgumentException if {@code loadFilesForRun} has an odd number of entries or
     *                                  names an operation that is not a static
     *                                  {@link DatabaseOperation} constant
     * @throws NoSuchFieldException     if an operation name matches no {@link DatabaseOperation} field
     */
    private void loadFiles(TestContext testContext, DBUnitSupport dBUnitSupport) throws Exception {
        String[] loadFilesForRun = dBUnitSupport.loadFilesForRun();
        if (loadFilesForRun == null || loadFilesForRun.length == 0) {
            return;
        }
        // Entries are consumed in pairs; fail fast instead of running half the list and then
        // hitting an ArrayIndexOutOfBoundsException.
        if (loadFilesForRun.length % 2 != 0) {
            throw new IllegalArgumentException("Contract {<Operation>,<Classpath file>} is broken.");
        }
        String executionInformation = ExecutionListenerHelper.getExecutionInformation(testContext);
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("******** Load files  '" + executionInformation + "'.");
        }
        // One connection for all files, instead of a new one per file.
        // NOTE(review): the connection is still not closed here, as in the original. Whether
        // closing is safe depends on the configured DataSource (pooled, single-connection,
        // in-memory database lifetime). TODO confirm before releasing it.
        IDatabaseConnection connection = getConnection(getSaveDataSource(testContext), testContext);
        for (int i = 0; i < loadFilesForRun.length; i += 2) {
            String operationName = loadFilesForRun[i];
            String resourcePath = loadFilesForRun[i + 1];
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("******** load file '" + executionInformation + "' op='" + operationName + "' - '" + resourcePath + "'.");
            }
            DatabaseOperation operation = getOperation(operationName);
            InputStream inputStream = new ClassPathResource(resourcePath).getInputStream();
            FlatXmlDataSet dataSet;
            try {
                dataSet = getFileDataSet(inputStream);
            } finally {
                inputStream.close();
            }
            operation.execute(connection, dataSet);
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("******** Finished load files '" + executionInformation + "'.");
        }
    }

    /**
     * Resolves a {@link DatabaseOperation} constant by name, case-insensitively
     * (for example {@code "clean_insert"} maps to {@code DatabaseOperation.CLEAN_INSERT}).
     *
     * @throws NoSuchFieldException     if {@link DatabaseOperation} has no public field with that name
     * @throws IllegalArgumentException if the field is not a static {@link DatabaseOperation} constant
     */
    private DatabaseOperation getOperation(String str) throws SecurityException, NoSuchFieldException, IllegalArgumentException, IllegalAccessException {
        // Locale.ENGLISH avoids locale-sensitive case mapping (e.g. Turkish dotted/dotless i).
        Field field = DatabaseOperation.class.getField(str.toUpperCase(Locale.ENGLISH));
        if (!Modifier.isStatic(field.getModifiers()) || !field.getType().equals(DatabaseOperation.class)) {
            throw new IllegalArgumentException("Operation " + str + " is unknown");
        }
        return (DatabaseOperation) field.get(null);
    }

    /**
     * Returns the export target file and creates its parent directories when it has any. Paths
     * starting with {@code "."} are resolved against the absolute current directory.
     */
    private File getFileToExport(String str) {
        String path = str.startsWith(".") ? new File(".").getAbsolutePath() + File.separator + str : str;
        File file = new File(path);
        // A bare file name has no parent; calling mkdirs on null would throw an NPE.
        File parent = file.getParentFile();
        if (parent != null) {
            parent.mkdirs();
        }
        return file;
    }

    /**
     * Builds the data set to export.
     *
     * @param tableQueryPairs {@code {table name, SELECT query}} pairs. A blank query defaults to
     *                        {@code SELECT * FROM <TABLE>}. A {@code null} or empty array exports all
     *                        tables through a {@link DatabaseSequenceFilter}.
     * @param connection      DbUnit connection to read from
     * @throws IllegalArgumentException if {@code tableQueryPairs} has an odd number of entries
     */
    private IDataSet getDataSetToExport(String[] tableQueryPairs, IDatabaseConnection connection) throws DataSetException, SQLException, AmbiguousTableNameException {
        if (tableQueryPairs == null || tableQueryPairs.length == 0) {
            // The sequence filter is built only here: building it reads database metadata, and the
            // explicit-table branch does not need it.
            DatabaseSequenceFilter sequenceFilter = new DatabaseSequenceFilter(connection);
            return new FilteredDataSet(sequenceFilter, connection.createDataSet());
        }
        if (tableQueryPairs.length % 2 != 0) {
            throw new IllegalArgumentException("Contract {<Table Name>,<SELECT_QUERY>} is broken.");
        }
        QueryDataSet queryDataSet = new QueryDataSet(connection);
        for (int i = 0; i < tableQueryPairs.length; i += 2) {
            String tableName = tableQueryPairs[i].toUpperCase(Locale.ENGLISH);
            String query = tableQueryPairs[i + 1];
            if (query == null || query.trim().length() == 0) {
                query = "SELECT * FROM " + tableName;
            }
            queryDataSet.addTable(tableName, query);
        }
        return queryDataSet;
    }

    /**
     * Returns the first {@link DataSource} bean from the test application context.
     *
     * @throws IllegalArgumentException if there is no application context or no DataSource bean
     */
    private DataSource getSaveDataSource(TestContext testContext) {
        ApplicationContext applicationContext = testContext.getApplicationContext();
        if (applicationContext == null) {
            throw new IllegalArgumentException("The test configuration contains no application context.");
        }
        DataSource bean = getBean(applicationContext, DataSource.class);
        if (bean != null) {
            return bean;
        }
        throw new IllegalArgumentException("The test application context has no configured data source!");
    }

    /**
     * Opens a JDBC connection from {@code dataSource} and wraps it with {@link #dbConnectionFactory}.
     * First, a {@link DatabaseConnectionFactory} bean is looked up in the test application context;
     * if one is found, it replaces the current factory.
     *
     * @return the DbUnit connection, or {@code null} if no factory is available
     */
    protected IDatabaseConnection getConnection(DataSource dataSource, TestContext testContext) throws Exception {
        try {
            DatabaseConnectionFactory databaseConnectionFactory = testContext.getApplicationContext().getBean(DatabaseConnectionFactory.class);
            if (databaseConnectionFactory != null) {
                this.dbConnectionFactory = databaseConnectionFactory;
            }
        } catch (Exception ignored) {
            // A missing (or ambiguous) factory bean is expected; the current factory is kept.
            this.logger.debug(String.format("We ignore if we could not find a instance of '%s'", DatabaseConnectionFactory.class.getName()));
        }
        if (this.dbConnectionFactory == null) {
            // Check before opening a connection, so none is leaked when nothing can wrap it.
            return null;
        }
        Connection connection = dataSource.getConnection();
        try {
            DatabaseMetaData metaData = connection.getMetaData();
            return this.dbConnectionFactory.createConnection(connection, metaData);
        } catch (Exception e) {
            // Nobody else holds a reference to the raw connection, so release it before propagating.
            closeQuietly(connection);
            throw e;
        }
    }

    /** Closes {@code connection}. Close failures are logged at debug level and otherwise ignored. */
    private void closeQuietly(Connection connection) {
        try {
            connection.close();
        } catch (SQLException e) {
            this.logger.debug("Could not close JDBC connection", e);
        }
    }

    /** Parses a DbUnit flat-XML data set from {@code inputStream}; the stream is not closed here. */
    private FlatXmlDataSet getFileDataSet(InputStream inputStream) throws Exception {
        return new FlatXmlDataSetBuilder().build(inputStream);
    }

    /**
     * Returns the first bean of {@code type} in {@code applicationContext}, or {@code null} if there
     * is none.
     */
    private <T> T getBean(ApplicationContext applicationContext, Class<T> type) {
        String[] beanNamesForType = applicationContext.getBeanNamesForType(type);
        if (beanNamesForType == null || beanNamesForType.length == 0) {
            return null;
        }
        return applicationContext.getBean(beanNamesForType[0], type);
    }
}
