MessageFieldArgument.java
/*
* Copyright 2018-2019 Providence Authors
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.providence.jdbi.v3;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.serializer.BinarySerializer;
import net.morimekta.providence.serializer.JsonSerializer;
import net.morimekta.util.Binary;
import net.morimekta.util.Strings;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.result.ResultSetException;
import org.jdbi.v3.core.statement.StatementContext;
import javax.annotation.Nonnull;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringReader;
import java.io.StringWriter;
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.Calendar;
import java.util.Objects;
import java.util.TimeZone;
/**
 * Smart mapping of message fields to SQL bound arguments. The field value is
 * bound using the SQL type given in the constructor, or, if none is given, the
 * default type for the field type (see {@link #getDefaultColumnType(PField)}).
 *
 * <ul>
 *     <li>{@code bool}: bound as boolean for BOOLEAN / BIT, otherwise as the integer 0 or 1.</li>
 *     <li>{@code i32}: seconds since epoch for TIMESTAMP / TIMESTAMP_WITH_TIMEZONE and DATE,
 *         otherwise as integer.</li>
 *     <li>{@code i64}: milliseconds since epoch for TIMESTAMP / TIMESTAMP_WITH_TIMEZONE and DATE,
 *         otherwise as long.</li>
 *     <li>{@code string}: as CLOB for CLOB / NCLOB, otherwise as string.</li>
 *     <li>{@code binary}: raw bytes, binary stream, blob, or base64 string depending on type.</li>
 *     <li>{@code enum}: the enum integer value.</li>
 *     <li>{@code message}: binary serialized for binary / blob types, JSON
 *         ({@code JsonSerializer#named()}) for string / clob types.</li>
 * </ul>
 *
 * If the field is not set on the message, SQL NULL with the selected type is bound.
 *
 * @param <M> The message type.
 */
public class MessageFieldArgument<M extends PMessage<M>> implements Argument {
    private static final BinarySerializer BINARY = new BinarySerializer();
    private static final JsonSerializer   JSON   = new JsonSerializer().named();
    // Only the zone is shared. java.util.Calendar is mutable and not thread-safe,
    // so a fresh calendar is created for every bound timestamp.
    private static final TimeZone         UTC    = TimeZone.getTimeZone("UTC");

    private final PMessageOrBuilder<M> message;
    private final PField<M>            field;
    private final int                  type;

    /**
     * Create a message field argument using the default SQL type for the field.
     *
     * @param message The message to get the field from.
     * @param field The field to select.
     * @throws IllegalArgumentException If the field type has no default SQL type.
     */
    public MessageFieldArgument(@Nonnull PMessageOrBuilder<M> message, @Nonnull PField<M> field) {
        this(message, field, getDefaultColumnType(field));
    }

    /**
     * Create a message field argument.
     *
     * @param message The message to get the field from.
     * @param field The field to select.
     * @param type The SQL type. See {@link Types}.
     */
    public MessageFieldArgument(@Nonnull PMessageOrBuilder<M> message, @Nonnull PField<M> field, int type) {
        this.message = message;
        this.field = field;
        this.type = type;
    }

    @Override
    public String toString() {
        if (message.has(field)) {
            if (field.getType() == PType.STRING) {
                return "'" + message.get(field) + "'";
            } else if (field.getType() == PType.BINARY) {
                Binary binary = message.get(field);
                return "b64(" + binary.toBase64() + ")";
            } else if (field.getType() == PType.MESSAGE) {
                return ((PMessage) message.get(field)).asString();
            }
            return Strings.asString((Object) message.get(field));
        }
        return "null";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        // Exact class comparison: hashCode() includes getClass(), so an
        // instanceof check would let a subclass instance equal a base-class
        // instance while producing a different hash code.
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        MessageFieldArgument<?> other = (MessageFieldArgument<?>) o;
        return Objects.equals(message, other.message) &&
               Objects.equals(field, other.field) &&
               type == other.type;
    }

    @Override
    public int hashCode() {
        return Objects.hash(getClass(), message, field, type);
    }

    /**
     * Bind the field value to the statement.
     *
     * @param position The parameter position in the statement.
     * @param statement The statement to bind to.
     * @param ctx The statement context.
     * @throws SQLException If the driver fails to bind the value.
     * @throws ResultSetException If the field type or SQL type combination is not
     *         supported, or the message value fails to serialize.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void apply(int position, PreparedStatement statement, StatementContext ctx) throws SQLException {
        if (!message.has(field)) {
            statement.setNull(position, type);
            return;
        }
        switch (field.getType()) {
            case BOOL: {
                boolean value = message.get(field);
                if (type == Types.BOOLEAN || type == Types.BIT) {
                    statement.setBoolean(position, value);
                } else {
                    statement.setInt(position, value ? 1 : 0);
                }
                break;
            }
            case BYTE: {
                statement.setByte(position, message.get(field));
                break;
            }
            case I16: {
                statement.setShort(position, message.get(field));
                break;
            }
            case I32: {
                // i32 timestamps are seconds since epoch.
                int value = message.get(field);
                if (isTimestampType()) {
                    setUtcTimestamp(statement, position, 1000L * value);
                } else if (type == Types.DATE) {
                    // NOTE(review): DATE is bound without the UTC calendar, so the
                    // driver's default zone applies - confirm this is intended.
                    statement.setDate(position, new Date(1000L * value));
                } else {
                    statement.setInt(position, value);
                }
                break;
            }
            case I64: {
                // i64 timestamps are milliseconds since epoch.
                long value = message.get(field);
                if (isTimestampType()) {
                    setUtcTimestamp(statement, position, value);
                } else if (type == Types.DATE) {
                    // NOTE(review): see the DATE note in the i32 branch.
                    statement.setDate(position, new Date(value));
                } else {
                    statement.setLong(position, value);
                }
                break;
            }
            case DOUBLE: {
                statement.setDouble(position, message.get(field));
                break;
            }
            case STRING: {
                String value = message.get(field);
                if (type == Types.CLOB || type == Types.NCLOB) {
                    statement.setClob(position, new StringReader(value));
                } else {
                    statement.setString(position, value);
                }
                break;
            }
            case BINARY: {
                Binary binary = message.get(field);
                applyBinary(position, statement, ctx, binary);
                break;
            }
            case ENUM: {
                PEnumValue value = message.get(field);
                statement.setInt(position, value.asInteger());
                break;
            }
            case MESSAGE: {
                PMessage value = message.get(field);
                applyMessage(position, statement, ctx, value);
                break;
            }
            default:
                // Includes SET / LIST / MAP, which have no binding implemented here.
                throw new ResultSetException("Unhandled field type in SQL: " + field, null, ctx);
        }
    }

    /** True if the selected SQL type is one of the timestamp types. */
    private boolean isTimestampType() {
        return type == Types.TIMESTAMP || type == Types.TIMESTAMP_WITH_TIMEZONE;
    }

    /** Bind a millisecond epoch value as a timestamp interpreted in UTC. */
    private static void setUtcTimestamp(PreparedStatement statement, int position, long epochMillis)
            throws SQLException {
        // A new calendar per call, as drivers may use or modify the passed calendar.
        statement.setTimestamp(position, new Timestamp(epochMillis), Calendar.getInstance(UTC));
    }

    /** Bind a binary field value according to the selected SQL type. */
    private void applyBinary(int position, PreparedStatement statement, StatementContext ctx, Binary binary)
            throws SQLException {
        switch (type) {
            case Types.BINARY:
            case Types.VARBINARY:
                statement.setBytes(position, binary.get());
                break;
            case Types.LONGVARBINARY:
                statement.setBinaryStream(position, binary.getInputStream());
                break;
            case Types.BLOB:
                statement.setBlob(position, binary.getInputStream());
                break;
            case Types.CHAR:
            case Types.VARCHAR:
            case Types.LONGVARCHAR:
            case Types.NCHAR:
            case Types.NVARCHAR:
                statement.setString(position, binary.toBase64());
                break;
            default:
                throw new ResultSetException("Unknown binary field type: " + type + " for " + field, null, ctx);
        }
    }

    /** Bind a message field value, binary serialized or as JSON, according to the selected SQL type. */
    private void applyMessage(int position, PreparedStatement statement, StatementContext ctx, PMessage value)
            throws SQLException {
        switch (type) {
            case Types.BINARY:
            case Types.LONGVARBINARY:
            case Types.VARBINARY:
                statement.setBytes(position, toBinary(value, ctx));
                break;
            case Types.BLOB:
                statement.setBlob(position, new ByteArrayInputStream(toBinary(value, ctx)));
                break;
            case Types.CHAR:
            case Types.VARCHAR:
            case Types.LONGVARCHAR:
            case Types.NCHAR:
            case Types.NVARCHAR:
                statement.setString(position, toJson(value, ctx));
                break;
            case Types.NCLOB:
            case Types.CLOB:
                statement.setClob(position, new StringReader(toJson(value, ctx)));
                break;
            default:
                throw new ResultSetException("Unknown message field type: " + type + " for " + field, null, ctx);
        }
    }

    /** Serialize the message with the binary serializer, wrapping I/O failures. */
    @SuppressWarnings("unchecked")
    private static byte[] toBinary(PMessage value, StatementContext ctx) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try {
            BINARY.serialize(out, value);
        } catch (IOException e) {
            throw new ResultSetException(e.getMessage(), e, ctx);
        }
        return out.toByteArray();
    }

    /** Serialize the message with the JSON serializer, wrapping I/O failures. */
    @SuppressWarnings("unchecked")
    private static String toJson(PMessage value, StatementContext ctx) {
        StringWriter writer = new StringWriter();
        PrintWriter printer = new PrintWriter(writer);
        try {
            JSON.serialize(printer, value);
        } catch (IOException e) {
            throw new ResultSetException(e.getMessage(), e, ctx);
        }
        printer.flush();
        return writer.toString();
    }

    /**
     * Get the default SQL column type for a field.
     *
     * NOTE(review): SET and LIST map to {@link Types#ARRAY}, but
     * {@link #apply(int, PreparedStatement, StatementContext)} only supports them
     * when the field is unset (binding NULL). A set value is rejected.
     *
     * @param field The field to get the default type for.
     * @return The default {@link Types} value.
     * @throws IllegalArgumentException If the field type has no default column type.
     */
    static int getDefaultColumnType(PField field) {
        switch (field.getType()) {
            case BOOL: return Types.BIT;
            case BYTE: return Types.TINYINT;
            case I16: return Types.SMALLINT;
            case I32:
            case ENUM:
                return Types.INTEGER;
            case I64: return Types.BIGINT;
            case DOUBLE: return Types.DOUBLE;
            case STRING:
            case MESSAGE: // JSON string.
                return Types.VARCHAR;
            case BINARY: return Types.VARBINARY;
            case SET:
            case LIST:
                return Types.ARRAY;
            default: {
                throw new IllegalArgumentException("No default column type for " + field.toString());
            }
        }
    }
}