MessageFieldArgument.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.jdbi.v2;

import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.serializer.BinarySerializer;
import net.morimekta.providence.serializer.JsonSerializer;
import net.morimekta.util.Binary;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.tweak.Argument;

import javax.annotation.Nonnull;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringReader;
import java.io.StringWriter;
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.SQLDataException;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.Calendar;
import java.util.Objects;
import java.util.TimeZone;

/**
 * Smart mapping of message fields to SQL bound argument. It will
 * map the type to whichever type is default or selected (if supported)
 * for most field types.
 *
 * @param <M> The message type.
 */
public class MessageFieldArgument<M extends PMessage<M>> implements Argument {
    private static final BinarySerializer BINARY = new BinarySerializer();
    private static final JsonSerializer   JSON   = new JsonSerializer().named();
    private static final Calendar         UTC    = Calendar.getInstance(TimeZone.getTimeZone("UTC"));

    private final PMessageOrBuilder<M> message;
    private final PField<M>            field;
    private final int                  type;

    /**
     * Create a message field argument.
     *
     * @param message The message to get the field from.
     * @param field The field to select.
     */
    public MessageFieldArgument(@Nonnull PMessageOrBuilder<M> message, @Nonnull PField<M> field) {
        this(message, field, getDefaultColumnType(field));
    }

    /**
     * Create a message field argument.
     *
     * @param message The message to get the field from.
     * @param field The field to select.
     * @param type The SQL type. See {@link Types}.
     */
    public MessageFieldArgument(@Nonnull PMessageOrBuilder<M> message, @Nonnull PField<M> field, int type) {
        this.message = message;
        this.field = field;
        this.type = type;
    }

    @Override
    public String toString() {
        if (message.has(field)) {
            if (field.getType() == PType.STRING) {
                return "'" + message.get(field) + "'";
            } else if (field.getType() == PType.BINARY) {
                Binary binary = message.get(field);
                return binary.toHexString();
            } else if (field.getType() == PType.MESSAGE) {
                return ((PMessage) message.get(field)).asString();
            }
            return String.valueOf((Object) message.get(field));
        }

        return "null";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) return true;
        if (!(o instanceof MessageFieldArgument)) return false;
        MessageFieldArgument other = (MessageFieldArgument) o;
        return Objects.equals(message, other.message) &&
                Objects.equals(field, other.field) &&
                type == other.type;
    }

    @Override
    public int hashCode() {
        return Objects.hash(getClass(), message, field, type);
    }

    @Override
    @SuppressWarnings("unchecked")
    public void apply(int position, PreparedStatement statement, StatementContext ctx) throws SQLException {
        if (message.has(field)) {
            switch (field.getType()) {
                case BOOL: {
                    boolean value = message.get(field);
                    if (type == Types.BOOLEAN || type == Types.BIT) {
                        statement.setBoolean(position, value);
                    } else {
                        statement.setInt(position, value ? 1 : 0);
                    }
                    break;
                }
                case BYTE: {
                    statement.setByte(position, message.get(field));
                    break;
                }
                case I16: {
                    statement.setShort(position, message.get(field));
                    break;
                }
                case I32: {
                    if (type == Types.TIMESTAMP || type == Types.TIMESTAMP_WITH_TIMEZONE) {
                        Timestamp timestamp = new Timestamp(1000L * (int) message.get(field));
                        statement.setTimestamp(position, timestamp, UTC);
                    } else if (type == Types.DATE) {
                        Date date = new Date(1000L * (int) message.get(field));
                        statement.setDate(position, date, UTC);
                    } else {
                        statement.setInt(position, message.get(field));
                    }
                    break;
                }
                case I64: {
                    if (type == Types.TIMESTAMP || type == Types.TIMESTAMP_WITH_TIMEZONE) {
                        Timestamp timestamp = new Timestamp(message.get(field));
                        statement.setTimestamp(position, timestamp, UTC);
                    } else if (type == Types.DATE) {
                        Date date = new Date(message.get(field));
                        statement.setDate(position, date, UTC);
                    } else {
                        statement.setLong(position, message.get(field));
                    }
                    break;
                }
                case DOUBLE: {
                    statement.setDouble(position, message.get(field));
                    break;
                }
                case STRING: {
                    switch (type) {
                        case Types.CLOB:
                        case Types.NCLOB: {
                            StringReader reader = new StringReader(message.get(field));
                            statement.setClob(position, reader);
                            break;
                        }
                        default: {
                            statement.setString(position, message.get(field));
                            break;
                        }
                    }
                    break;
                }
                case BINARY: {
                    Binary binary = message.get(field);
                    switch (type) {
                        case Types.BINARY:
                        case Types.VARBINARY: {
                            statement.setBytes(position, binary.get());
                            break;
                        }
                        case Types.LONGVARBINARY: {
                            statement.setBinaryStream(position, binary.getInputStream());
                            break;
                        }
                        case Types.BLOB: {
                            statement.setBlob(position, binary.getInputStream());
                            break;
                        }
                        case Types.CHAR:
                        case Types.VARCHAR:
                        case Types.LONGVARCHAR:
                        case Types.NCHAR:
                        case Types.NVARCHAR: {
                            statement.setString(position, binary.toBase64());
                            break;
                        }
                        default:
                            throw new SQLDataException("Unknown binary field type: " + type + " for " + field);
                    }
                    break;
                }
                case ENUM: {
                    PEnumValue value = message.get(field);
                    statement.setInt(position, value.asInteger());
                    break;
                }
                case MESSAGE: {
                    PMessage value = message.get(field);
                    switch (type) {
                        case Types.BINARY:
                        case Types.LONGVARBINARY:
                        case Types.VARBINARY: {
                            ByteArrayOutputStream out = new ByteArrayOutputStream();
                            try {
                                BINARY.serialize(out, value);
                                statement.setBytes(position, out.toByteArray());
                            } catch (IOException e) {
                                throw new SQLDataException(e.getMessage(), e);
                            }
                            break;
                        }
                        case Types.BLOB: {
                            ByteArrayOutputStream out = new ByteArrayOutputStream();
                            try {
                                BINARY.serialize(out, value);
                                statement.setBlob(position, new ByteArrayInputStream(out.toByteArray()));
                            } catch (IOException e) {
                                throw new SQLDataException(e.getMessage(), e);
                            }
                            break;
                        }
                        case Types.CHAR:
                        case Types.VARCHAR:
                        case Types.LONGVARCHAR:
                        case Types.NCHAR:
                        case Types.NVARCHAR: {
                            StringWriter writer = new StringWriter();
                            try {
                                JSON.serialize(new PrintWriter(writer), value);
                                statement.setString(position, writer.getBuffer().toString());
                            } catch (IOException e) {
                                throw new SQLDataException(e.getMessage(), e);
                            }
                            break;
                        }
                        case Types.NCLOB:
                        case Types.CLOB: {
                            StringWriter writer = new StringWriter();
                            try {
                                JSON.serialize(new PrintWriter(writer), value);
                                statement.setClob(position, new StringReader(writer.getBuffer().toString()));
                            } catch (IOException e) {
                                throw new SQLDataException(e.getMessage(), e);
                            }
                            break;
                        }
                        default:
                            throw new SQLDataException("Unknown message field type: " + type + " for " + field);
                    }
                    break;
                }
                default:
                    throw new SQLDataException("Unhandled field type in SQL: " + field);
            }
        } else {
            statement.setNull(position, type);
        }
    }

    static int getDefaultColumnType(PField field) {
        switch (field.getType()) {
            case BOOL: return Types.BIT;
            case BYTE: return Types.TINYINT;
            case I16: return Types.SMALLINT;
            case I32: return Types.INTEGER;
            case I64: return Types.BIGINT;
            case DOUBLE: return Types.DOUBLE;
            case STRING: return Types.VARCHAR;
            case BINARY: return Types.VARBINARY;
            case ENUM: return Types.INTEGER;
            case MESSAGE: return Types.VARCHAR;  // JSON string.
            default: {
                throw new IllegalArgumentException("No default column type for " + field.toString());
            }
        }
    }
}