DataSourceExtension.java
package net.morimekta.testing.junit5.sql;
import org.apache.tomcat.jdbc.pool.DataSource;
import org.apache.tomcat.jdbc.pool.PoolProperties;
import org.h2.tools.Server;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.Jdbi;
import org.jdbi.v3.core.h2.H2DatabasePlugin;
import org.jdbi.v3.core.statement.DefaultStatementBuilder;
import org.jdbi.v3.core.statement.StatementContext;
import org.jdbi.v3.sqlobject.SqlObjectPlugin;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Optional;
import java.util.TimeZone;
import static net.morimekta.file.FileUtil.deleteRecursively;
import static net.morimekta.testing.io.ResourceUtil.resourceAsString;
import static net.morimekta.testing.junit5.AnnotationUtil.getAnnotationsBottomUp;
import static net.morimekta.testing.junit5.AnnotationUtil.getTopAnnotation;
import static org.junit.jupiter.api.extension.ExtensionContext.Namespace.GLOBAL;
/**
* Manage a DataSource for testing.
* <p>
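* If no external database connection is provided via the
* {@link DataSourceURIMethod} annotation, a temporary H2 database, served by
* a local H2 TCP server, is created and used for the test. The H2 database
* runs in PostgreSQL compatibility mode unless {@link DataSourceMode} selects
* another mode.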
* </p>
* <pre>{@code
* {@literal@}ExtendWith(DataSourceExtension.class)
* {@literal@}DataSourceMode(DataSourceMode.Mode.MYSQL)
* public class MyTest {
* {@literal@}Test
* {@literal@}DataSourceSchema("/testing.sql")
* public void testThings(Jdbi jdbi) {
* var db = jdbi.onDemand(MyDBI.class);
* // do stuff.
* db.runMyQuery("a", "b");
* }
* }
* }</pre>
*/
public class DataSourceExtension
        implements BeforeEachCallback, AfterEachCallback, ParameterResolver {
    static {
        // Force H2 to use UTC as its timezone. For MySQL the same is done by
        // setting the timezone on the connection, but for H2 changing the JVM
        // default timezone was the only way found to enforce it. Note that
        // this also changes the timezone seen by the test code itself.
        TimeZone.setDefault(TimeZone.getTimeZone("UTC"));
    }

    /**
     * Create the extension. Registers the H2 JDBC driver with the
     * {@link DriverManager} and sets the {@code h2.serializeJavaObject}
     * system property to {@code false}.
     *
     * @throws AssertionError if the H2 driver could not be registered.
     */
    public DataSourceExtension() {
        try {
            DriverManager.registerDriver(new org.h2.Driver());
            // We be like, what is this craziness:
            // http://www.h2database.com/html/advanced.html#java_objects_serialization
            System.setProperty("h2.serializeJavaObject", "false");
        } catch (SQLException e) {
            throw new AssertionError("Failed to set up testing H2 database", e);
        }
    }

    // --- BeforeEachCallback ---

    /**
     * Prepare the data source for a test. Load every schema named by a
     * {@link DataSourceSchema} annotation (as found by
     * {@code getAnnotationsBottomUp}), then inject the JDBC URI into all
     * fields annotated with {@link DataSourceURI}.
     *
     * @param context The extension context of the test about to run.
     * @throws AssertionError if an annotated field is misdeclared or cannot be set.
     */
    @Override
    public void beforeEach(ExtensionContext context) {
        if (context.getTestClass().isPresent()) {
            getAnnotationsBottomUp(context, DataSourceSchema.class)
                    .forEach(schema -> loadSchema(context, schema.value()));
        }
        injectDataSourceUriFields(context);
    }

    // --- AfterEachCallback ---

    /**
     * Release everything created for the test, in order:
     * <ol>
     *   <li>the keep-alive handle,</li>
     *   <li>the connection pool,</li>
     *   <li>the H2 server and its temporary directory, if one was started.</li>
     * </ol>
     * Every step runs even if an earlier one fails, and all entries this
     * extension put in the store are removed.
     *
     * @param context The extension context of the test that just finished.
     */
    @Override
    public void afterEach(ExtensionContext context) {
        var store = context.getStore(GLOBAL);
        runAllCleanupSteps(
                () -> {
                    var keepAliveHandle = store.get(Handle.class, Handle.class);
                    if (keepAliveHandle != null) {
                        keepAliveHandle.close();
                    }
                },
                () -> {
                    var ds = getDataSource(context);
                    if (ds != null) {
                        ds.close();
                    }
                },
                () -> stopH2(getH2Context(context)),
                () -> {
                    store.remove(Handle.class);
                    store.remove(Jdbi.class);
                    store.remove(DataSource.class);
                    store.remove(H2Context.class);
                    store.remove(JDBC_URI);
                });
    }

    // --- ParameterResolver ---

    /**
     * Supports {@link Jdbi} and {@link javax.sql.DataSource} parameters, and
     * {@code String} parameters annotated with {@link DataSourceURI}.
     *
     * @throws AssertionError if a {@link DataSourceURI} parameter is not a String.
     */
    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        var type = parameterContext.getParameter().getType();
        if (parameterContext.isAnnotated(DataSourceURI.class)) {
            if (type == String.class) {
                return true;
            }
            throw new AssertionError("Data Source URI parameter must be String");
        }
        return type == Jdbi.class ||
               type == javax.sql.DataSource.class;
    }

    /**
     * Resolve a supported parameter. The JDBC URI, pool and Jdbi instance are
     * created on first use and shared for the rest of the test.
     */
    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        var type = parameterContext.getParameter().getType();
        if (type == String.class && parameterContext.isAnnotated(DataSourceURI.class)) {
            return createJdbcUri(extensionContext);
        }
        if (type == Jdbi.class) {
            return createJdbi(extensionContext);
        }
        return createDataSource(extensionContext);
    }

    // --- Private ---

    /** Store key for the JDBC URI used by the current test. */
    private static final String JDBC_URI = "jdbcURI";

    /** The running H2 TCP server and the temporary directory holding its database. */
    private static class H2Context {
        Path dir;
        Server server;
    }

    /**
     * Set every {@link DataSourceURI}-annotated field declared on the test
     * class or any of its superclasses to the JDBC URI of this test. Declared
     * (not only public) fields are scanned, so a misdeclared field fails
     * loudly instead of being silently skipped.
     *
     * @throws AssertionError if an annotated field is not a public, non-final,
     *         non-static {@code String}, or could not be set.
     */
    private void injectDataSourceUriFields(ExtensionContext context) {
        for (Class<?> type = context.getRequiredTestClass(); type != null; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                if (!field.isAnnotationPresent(DataSourceURI.class)) {
                    continue;
                }
                int modifiers = field.getModifiers();
                if (!Modifier.isPublic(modifiers)) {
                    throw new AssertionError("@DataSourceURI field " + field + " must be public");
                }
                if (Modifier.isFinal(modifiers)) {
                    throw new AssertionError("@DataSourceURI field " + field + " must not be final");
                }
                if (Modifier.isStatic(modifiers)) {
                    throw new AssertionError("@DataSourceURI field " + field + " must not be static");
                }
                if (field.getType() != String.class) {
                    throw new AssertionError("@DataSourceURI field " + field + " must be of type String");
                }
                try {
                    field.set(context.getRequiredTestInstance(), createJdbcUri(context));
                } catch (IllegalAccessException e) {
                    throw new AssertionError("Unable to set @DataSourceURI field " + field, e);
                }
            }
        }
    }

    /**
     * Run each cleanup step even if an earlier one fails. The first failure is
     * rethrown once all steps have run; later failures are attached to it as
     * suppressed exceptions.
     */
    private static void runAllCleanupSteps(Runnable... steps) {
        RuntimeException failure = null;
        for (Runnable step : steps) {
            try {
                step.run();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Stop the H2 server and delete its temporary directory. Does nothing if
     * {@code h2} is null.
     */
    private static void stopH2(H2Context h2) {
        if (h2 == null) {
            return;
        }
        h2.server.shutdown();
        h2.server.stop();
        if (Files.exists(h2.dir)) {
            try {
                deleteRecursively(h2.dir);
            } catch (IOException ignored) {
                // Best effort: a leftover temp directory must not fail the test.
            }
        }
    }

    /**
     * Get the H2 context for this test, creating it on first use. Creating
     * it makes a temporary directory and starts an H2 TCP server on a free
     * port.
     *
     * @throws AssertionError if the directory or server could not be created.
     */
    private H2Context createH2Context(ExtensionContext context) {
        return (H2Context) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(
                        H2Context.class,
                        type -> {
                            var h2 = new H2Context();
                            try {
                                h2.dir = Files.createTempDirectory("junit-h2");
                                h2.server = Server.createTcpServer(
                                        "-tcpPort", "0",
                                        "-tcp",
                                        "-tcpAllowOthers",
                                        "-tcpDaemon",
                                        "-ifNotExists");
                                h2.server.setOut(System.err);
                                h2.server.start();
                            } catch (IOException | SQLException e) {
                                throw new AssertionError("Unable to create H2 in-memory DB", e);
                            }
                            return h2;
                        });
    }

    /** @return The H2 context of this test, or null if none was created. */
    private H2Context getH2Context(ExtensionContext context) {
        return (H2Context) context
                .getStore(GLOBAL)
                .get(H2Context.class);
    }

    /**
     * Get the JDBC URI for this test, computing it on first use. A URI
     * supplied through {@link DataSourceURIMethod} wins. Otherwise an H2
     * database is started in the mode given by {@link DataSourceMode}
     * (default {@code POSTGRESQL}).
     */
    private String createJdbcUri(ExtensionContext context) {
        return (String) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(JDBC_URI, (key) -> {
                    var dsUriByAnnotation = getJdbcUriInternal(context);
                    if (dsUriByAnnotation.isPresent()) {
                        return dsUriByAnnotation.get();
                    }
                    var mode = getTopAnnotation(context, DataSourceMode.class)
                            .map(DataSourceMode::value)
                            .orElse(DataSourceMode.Mode.POSTGRESQL);
                    var h2 = createH2Context(context);
                    // H2 derives its file names by appending suffixes (e.g.
                    // ".mv.db") to the database path. The database therefore
                    // lives *inside* the temp directory, so deleting that
                    // directory in afterEach also removes the database files.
                    var dbPath = h2.dir.resolve("db");
                    return "jdbc:h2:" + h2.server.getURL() + "/" + dbPath + ";MODE=" + mode.mode;
                });
    }

    /**
     * Get the pooled data source for this test, creating it on first use.
     * The driver comes from {@link DataSourceDriverClass} if present,
     * otherwise it is looked up from the JDBC URI.
     *
     * @throws AssertionError if no driver could be determined.
     */
    private DataSource createDataSource(ExtensionContext context) {
        return (DataSource) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(
                        DataSource.class,
                        (type) -> {
                            var dbUri = createJdbcUri(context);
                            var dbConfig = new PoolProperties();
                            dbConfig.setUrl(dbUri);
                            var driver = getTopAnnotation(context, DataSourceDriverClass.class);
                            if (driver.isPresent()) {
                                dbConfig.setDriverClassName(driver.get().value().getName());
                            } else {
                                try {
                                    dbConfig.setDriverClassName(DriverManager.getDriver(dbUri).getClass().getName());
                                } catch (SQLException e) {
                                    throw new AssertionError("Unable to determine DB driver.", e);
                                }
                            }
                            dbConfig.setInitialSize(2);
                            dbConfig.setMinIdle(2);
                            dbConfig.setMaxIdle(10);
                            dbConfig.setMaxActive(10);
                            dbConfig.setValidationQuery("SELECT 1");
                            dbConfig.setTestOnBorrow(true);
                            dbConfig.setTestOnReturn(false);
                            dbConfig.setValidationInterval(60);
                            dbConfig.setSuspectTimeout(300);
                            dbConfig.setLogAbandoned(true);
                            var ds = new DataSource(dbConfig);
                            if (getH2Context(context) != null) {
                                // Keep one handle open until afterEach. H2 closes a
                                // database when its last connection closes, so this
                                // keeps loaded schemas and data available for the
                                // whole test.
                                var dbi = Jdbi.create(ds);
                                context.getStore(GLOBAL).put(Handle.class, dbi.open());
                            }
                            return ds;
                        });
    }

    /** @return The pooled data source of this test, or null if none was created. */
    private DataSource getDataSource(ExtensionContext context) {
        return (DataSource) context
                .getStore(GLOBAL)
                .get(DataSource.class);
    }

    /**
     * Get the {@link Jdbi} for this test, creating it on first use. It has
     * the SqlObject plugin installed. For H2 it additionally installs the H2
     * plugin and a statement builder that makes query results scrollable.
     */
    private Jdbi createJdbi(ExtensionContext context) {
        return (Jdbi) context
                .getStore(GLOBAL)
                .getOrComputeIfAbsent(
                        Jdbi.class,
                        (type) -> {
                            var jdbi = Jdbi.create(createDataSource(context));
                            if (getH2Context(context) != null) {
                                jdbi.setStatementBuilderFactory(conn -> new ScrollingH2StatementBuilderV3());
                                jdbi.installPlugin(new H2DatabasePlugin());
                            }
                            jdbi.installPlugin(new SqlObjectPlugin());
                            return jdbi;
                        });
    }

    /**
     * Execute each schema script, read as a resource relative to the test
     * class, on a single short-lived handle.
     */
    private void loadSchema(ExtensionContext context, String... schemas) {
        if (schemas.length == 0) {
            return;
        }
        try (Handle handle = createJdbi(context).open()) {
            for (String schema : schemas) {
                handle.createScript(resourceAsString(context.getRequiredTestClass(), schema))
                      .execute();
            }
        }
    }

    /**
     * Find a JDBC URI supplied via {@link DataSourceURIMethod}. A method-level
     * annotation is checked first, then the class hierarchy. An annotated
     * method that returns null is treated as "no URI".
     */
    private static Optional<String> getJdbcUriInternal(ExtensionContext context) {
        return context.getTestMethod()
                      .map(m -> getJdbcUriInternal(context, m))
                      .or(() -> context.getTestClass()
                                       .map(t -> getJdbcUriInternal(context, t)));
    }

    /** @return URI from the test method's annotation, or null if not annotated. */
    private static String getJdbcUriInternal(ExtensionContext context, Method method) {
        var dataSourceURIMethod = method.getAnnotation(DataSourceURIMethod.class);
        if (dataSourceURIMethod == null) {
            return null;
        }
        return invokeUriMethod(context, method.getDeclaringClass(), dataSourceURIMethod.value());
    }

    /**
     * @return URI from the nearest annotated class in the hierarchy starting
     *         at {@code type}, or null if no class there is annotated.
     */
    private static String getJdbcUriInternal(ExtensionContext context, Class<?> type) {
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            var dataSourceURIMethod = current.getAnnotation(DataSourceURIMethod.class);
            if (dataSourceURIMethod != null) {
                return invokeUriMethod(context, current, dataSourceURIMethod.value());
            }
        }
        return null;
    }

    /**
     * Call the public no-arg method {@code methodName} on {@code owner}. Static
     * methods are called without a target; instance methods are called on
     * the current test instance.
     *
     * @return The String returned by the method (possibly null).
     * @throws AssertionError if the method is missing, inaccessible, throws,
     *         or returns a non-String.
     */
    private static String invokeUriMethod(ExtensionContext context, Class<?> owner, String methodName) {
        try {
            var uriMethod = owner.getMethod(methodName);
            Object target = Modifier.isStatic(uriMethod.getModifiers())
                            ? null
                            : context.getRequiredTestInstance();
            Object result = uriMethod.invoke(target);
            if (result != null && !(result instanceof String)) {
                throw new AssertionError("Response from " + methodName + "() not a string.");
            }
            return (String) result;
        } catch (InvocationTargetException e) {
            throw new AssertionError("Exception calling " + methodName + "().", e.getCause());
        } catch (NoSuchMethodException e) {
            throw new AssertionError("No URI method with no args for " + methodName + "().", e);
        } catch (IllegalAccessException e) {
            throw new AssertionError("Unable to access " + methodName + "().", e);
        }
    }

    /**
     * Statement builder for H2. Statements that are not for generated keys or
     * updatable results are prepared with scroll-insensitive, read-only
     * result sets instead of the default forward-only ones.
     */
    private static class ScrollingH2StatementBuilderV3 extends DefaultStatementBuilder {
        @Override
        public PreparedStatement create(Connection conn, String sql, StatementContext ctx) throws SQLException {
            if (ctx.isReturningGeneratedKeys()) {
                String[] columnNames = ctx.getGeneratedKeysColumnNames();
                if (columnNames != null && columnNames.length > 0) {
                    return conn.prepareStatement(sql, columnNames);
                }
                return conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
            } else if (ctx.isConcurrentUpdatable()) {
                return conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_UPDATABLE);
            } else {
                // let statement be scrollable
                return conn.prepareStatement(sql, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
            }
        }
    }
}