DataSourceExtension.java

package net.morimekta.testing.junit5.sql;

import org.apache.tomcat.jdbc.pool.DataSource;
import org.apache.tomcat.jdbc.pool.PoolProperties;
import org.h2.tools.Server;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.Jdbi;
import org.jdbi.v3.core.h2.H2DatabasePlugin;
import org.jdbi.v3.core.statement.DefaultStatementBuilder;
import org.jdbi.v3.core.statement.StatementContext;
import org.jdbi.v3.sqlobject.SqlObjectPlugin;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Optional;
import java.util.TimeZone;

import static net.morimekta.file.FileUtil.deleteRecursively;
import static net.morimekta.testing.io.ResourceUtil.resourceAsString;
import static net.morimekta.testing.junit5.AnnotationUtil.getAnnotationsBottomUp;
import static net.morimekta.testing.junit5.AnnotationUtil.getTopAnnotation;
import static org.junit.jupiter.api.extension.ExtensionContext.Namespace.GLOBAL;

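/**
 * Manage a DataSource for testing.
 * <p>
 * If no external database connection is provided using the
 * {@link DataSourceURIMethod} annotation, a temporary H2 database is created
 * for each test, served through a local H2 TCP server. The H2 compatibility
 * mode is taken from {@link DataSourceMode}, and defaults to PostgreSQL.
 * </p>
 * <p>
 * Test methods may take parameters of type {@code Jdbi},
 * {@code javax.sql.DataSource}, or a {@code String} annotated with
 * {@link DataSourceURI} to get the JDBC URI. Public, non-static, non-final
 * {@code String} fields annotated with {@link DataSourceURI} are set to the
 * JDBC URI before each test.
 * </p>
 * <pre>
 * &#64;ExtendWith(DataSourceExtension.class)
 * &#64;DataSourceMode(DataSourceMode.Mode.MYSQL)
 * public class MyTest {
 *     &#64;Test
 *     &#64;DataSourceSchema("/testing.sql")
 *     public void testThings(Jdbi jdbi) {
 *         var db = jdbi.onDemand(MyDBI.class);
 *         // do stuff.
 *         db.runMyQuery("a", "b");
 *     }
 * }
 * </pre>
 */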
public class DataSourceExtension
        implements BeforeEachCallback, AfterEachCallback, ParameterResolver {
    static {
        // This forces the H2 DB to use UTC as timezone. This is same as setting
        // the timezone on connection for mysql. Sadly this is the only way to
        // enforce H2 to use UTC as timezone, as it also affects the test-local
        // timezone.
        TimeZone.setDefault(TimeZone.getTimeZone("UTC"));
    }

    /**
     * Create a H2 database resource, which will clear the DB and load the
     * given schemas for each test.
     */
    public DataSourceExtension() {
        try {
            DriverManager.registerDriver(new org.h2.Driver());

            // We be like, what is this craziness:
            // http://www.h2database.com/html/advanced.html#java_objects_serialization
            System.setProperty("h2.serializeJavaObject", "false");
        } catch (SQLException e) {
            throw new AssertionError("Failed to set up testing H2 database", e);
        }
    }

    // --- BeforeEachCallback ---

    @Override
    public void beforeEach(ExtensionContext context) {
        if (context.getTestClass().isPresent()) {
            getAnnotationsBottomUp(context, DataSourceSchema.class)
                    .forEach(schema -> loadSchema(context, schema.value()));
        }

        for (Field field : context.getRequiredTestClass().getFields()) {
            if (field.isAnnotationPresent(DataSourceURI.class)) {
                if ((field.getModifiers() & Modifier.PUBLIC) != Modifier.PUBLIC) {
                    throw new AssertionError("");
                }
                if ((field.getModifiers() & Modifier.FINAL) == Modifier.FINAL) {
                    throw new AssertionError("");
                }
                if ((field.getModifiers() & Modifier.STATIC) == Modifier.STATIC) {
                    throw new AssertionError("");
                }
                if (field.getType() != String.class) {
                    throw new AssertionError("");
                }
                try {
                    field.set(context.getRequiredTestInstance(), createJdbcUri(context));
                } catch (IllegalAccessException e) {
                    throw new AssertionError("", e);
                }
            }
        }
    }

    // --- AfterEachCallback ---

    @Override
    public void afterEach(ExtensionContext context) {
        var store = context.getStore(GLOBAL);
        var keepSchemaHandle = (Handle) store.get(Handle.class);
        if (keepSchemaHandle != null) {
            keepSchemaHandle.close();
        }
        var ds = getDataSource(context);
        if (ds != null) {
            ds.close();
        }
        var h2 = getH2Context(context);
        if (h2 != null) {
            h2.server.shutdown();
            h2.server.stop();
            if (Files.exists(h2.dir)) {
                try {
                    deleteRecursively(h2.dir);
                } catch (IOException ignore) {
                }
            }
        }

        store.remove(Handle.class);
        store.remove(Jdbi.class);
        store.remove(DataSource.class);
        store.remove(H2Context.class);
        store.remove(JDBC_URI);
    }

    // --- ParameterResolver ---

    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        var type = parameterContext.getParameter().getType();
        if (parameterContext.isAnnotated(DataSourceURI.class)) {
            if (type == String.class) {
                return true;
            }
            throw new AssertionError("Data Source URI parameter must be String");
        }
        return type == Jdbi.class ||
               type == javax.sql.DataSource.class;
    }

    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        var type = parameterContext.getParameter().getType();
        if (parameterContext.isAnnotated(DataSourceURI.class)) {
            return createJdbcUri(extensionContext);
        }
        if (type == Jdbi.class) {
            return createJdbi(extensionContext);
        }
        return createDataSource(extensionContext);
    }

    // --- Private ---

    private static final String JDBC_URI = "jdbcURI";

    private static class H2Context {
        Path   dir;
        Server server;
    }

    private H2Context createH2Context(ExtensionContext context) {
        return (H2Context) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(
                        H2Context.class,
                        type -> {
                            var h2 = new H2Context();
                            try {
                                h2.dir = Files.createTempDirectory("junit-h2");
                                h2.server = Server.createTcpServer(
                                        "-tcpPort", "0",
                                        "-tcp",
                                        "-tcpAllowOthers",
                                        "-tcpDaemon",
                                        "-ifNotExists");
                                h2.server.setOut(System.err);
                                h2.server.start();
                            } catch (IOException | SQLException e) {
                                throw new AssertionError("Unable to create H2 in-memory DB", e);
                            }
                            return h2;
                        });
    }

    private H2Context getH2Context(ExtensionContext context) {
        return (H2Context) context
                .getStore(GLOBAL)
                .get(H2Context.class);
    }

    private String createJdbcUri(ExtensionContext context) {
        return (String) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(JDBC_URI, (key) -> {
                    var dsUriByAnnotation = getJdbcUriInternal(context);
                    if (dsUriByAnnotation.isPresent()) {
                        return dsUriByAnnotation.get();
                    }
                    var mode = getTopAnnotation(context, DataSourceMode.class)
                            .map(DataSourceMode::value)
                            .orElse(DataSourceMode.Mode.POSTGRESQL);
                    var h2 = createH2Context(context);
                    return "jdbc:h2:" + h2.server.getURL() + "/" + h2.dir.toString() + ";MODE=" + mode;
                });
    }

    private DataSource createDataSource(ExtensionContext context) {
        return (DataSource) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(
                        DataSource.class,
                        (type) -> {
                            var dbUri = createJdbcUri(context);
                            var dbConfig = new PoolProperties();
                            dbConfig.setUrl(dbUri);

                            var driver = getTopAnnotation(context, DataSourceDriverClass.class);
                            if (driver.isPresent()) {
                                dbConfig.setDriverClassName(driver.get().value().getName());
                            } else {
                                try {
                                    dbConfig.setDriverClassName(DriverManager.getDriver(dbUri).getClass().getName());
                                } catch (SQLException e) {
                                    throw new AssertionError("Unable to determine DB driver.", e);
                                }
                            }
                            dbConfig.setInitialSize(2);
                            dbConfig.setMinIdle(2);
                            dbConfig.setMaxIdle(10);
                            dbConfig.setMaxActive(10);
                            dbConfig.setValidationQuery("SELECT 1");
                            dbConfig.setTestOnBorrow(true);
                            dbConfig.setTestOnReturn(false);
                            dbConfig.setValidationInterval(60);
                            dbConfig.setSuspectTimeout(300);
                            dbConfig.setLogAbandoned(true);

                            var ds = new DataSource(dbConfig);
                            if (getH2Context(context) != null) {
                                var dbi = Jdbi.create(ds);
                                context.getStore(GLOBAL).put(Handle.class, dbi.open());
                            }
                            return ds;
                        });
    }

    private DataSource getDataSource(ExtensionContext context) {
        return (DataSource) context
                .getStore(GLOBAL)
                .get(DataSource.class);
    }

    private Jdbi createJdbi(ExtensionContext context) {
        return (Jdbi) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(
                        Jdbi.class,
                        (type) -> {
                            var jdbi = Jdbi.create(createDataSource(context));
                            if (getH2Context(context) != null) {
                                jdbi.setStatementBuilderFactory(conn -> new ScrollingH2StatementBuilderV3());
                                jdbi.installPlugin(new H2DatabasePlugin());
                            }
                            jdbi.installPlugin(new SqlObjectPlugin());
                            return jdbi;
                        });
    }

    private void loadSchema(ExtensionContext context, String... schemas) {
        // Keep at least one handle around until closed, as H2 removes all
        // content when the last connection closes.
        if (schemas.length == 0) {
            return;
        }
        try (Handle handle = createJdbi(context).open()) {
            for (String schema : schemas) {
                handle.createScript(resourceAsString(context.getRequiredTestClass(), schema))
                      .execute();
            }
        }
    }

    private static Optional<String> getJdbcUriInternal(ExtensionContext context) {
        return context.getTestMethod()
                      .map(m -> getJdbcUriInternal(context, m))
                      .or(() -> context.getTestClass()
                                       .map(t -> getJdbcUriInternal(context, t)));
    }

    private static String getJdbcUriInternal(ExtensionContext context, Method method) {
        var dataSourceURIMethod = method.getAnnotation(DataSourceURIMethod.class);
        if (dataSourceURIMethod != null) {
            var getJdbcUriName = dataSourceURIMethod.value();
            try {
                var getJdbcUri = method.getDeclaringClass().getMethod(getJdbcUriName);
                if ((getJdbcUri.getModifiers() & Modifier.STATIC) == Modifier.STATIC) {
                    return (String) getJdbcUri.invoke(null);
                } else {
                    return (String) getJdbcUri.invoke(context.getRequiredTestInstance());
                }
            } catch (InvocationTargetException e) {
                throw new AssertionError("Exception calling " + getJdbcUriName + "().", e.getCause());
            } catch (NoSuchMethodException e) {
                throw new AssertionError("No URI method with no args for " + getJdbcUriName + "().", e);
            } catch (IllegalAccessException e) {
                throw new AssertionError("Unable to access " + getJdbcUriName + "().", e);
            } catch (ClassCastException e) {
                throw new AssertionError("Response from " + getJdbcUriName + "() not a string.", e);
            }
        }
        return null;
    }

    private static String getJdbcUriInternal(ExtensionContext context, Class<?> type) {
        var dataSourceURIMethod = type.getAnnotation(DataSourceURIMethod.class);
        if (dataSourceURIMethod != null) {
            var method = dataSourceURIMethod.value();
            try {
                var getUri = type.getMethod(method);
                if ((getUri.getModifiers() & Modifier.STATIC) == Modifier.STATIC) {
                    return (String) getUri.invoke(null);
                } else {
                    return (String) getUri.invoke(context.getRequiredTestInstance());
                }
            } catch (InvocationTargetException e) {
                throw new AssertionError("Exception calling " + method + "().", e.getCause());
            } catch (NoSuchMethodException e) {
                throw new AssertionError("No URI method with no args for " + method + "().", e);
            } catch (IllegalAccessException e) {
                throw new AssertionError("Unable to access " + method + "().", e);
            } catch (ClassCastException e) {
                throw new AssertionError("Response from " + method + "() not a string.", e);
            }
        }
        if (type.getSuperclass() != null) {
            return getJdbcUriInternal(context, type.getSuperclass());
        }
        return null;
    }

    private static class ScrollingH2StatementBuilderV3 extends DefaultStatementBuilder {
        @Override
        public PreparedStatement create(Connection conn, String sql, StatementContext ctx) throws SQLException {
            if (ctx.isReturningGeneratedKeys()) {
                String[] columnNames = ctx.getGeneratedKeysColumnNames();
                if (columnNames != null && columnNames.length > 0) {
                    return conn.prepareStatement(sql, columnNames);
                }
                return conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
            } else if (ctx.isConcurrentUpdatable()) {
                return conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_UPDATABLE);
            } else {
                // let statement be scrollable
                return conn.prepareStatement(sql, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
            }
        }
    }
}