MessageFieldArgument.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.jdbi.v3.util;

import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.proto.ProtoField;
import net.morimekta.proto.jdbi.MorimektaJdbiOptions;
import net.morimekta.proto.jdbi.ProtoJdbi;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.result.ResultSetException;
import org.jdbi.v3.core.statement.StatementContext;

import java.io.ByteArrayInputStream;
import java.io.StringReader;
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.Calendar;
import java.util.Objects;
import java.util.TimeZone;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static net.morimekta.proto.jdbi.ProtoJdbi.getColumnType;
import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultEnumName;

/**
 * Smart mapping of message fields to SQL bound argument. It will
 * map the type to whichever type is default or selected (if supported)
 * for most field types.
 * <p>
 * Only singular (non-repeated) fields are accepted. When the field is not
 * set on the message, SQL {@code NULL} is bound using the JDBC column type
 * derived from the selected {@link MorimektaJdbiOptions.SqlType}.
 */
public class MessageFieldArgument implements Argument {
    /**
     * UTC calendar instance. {@link Calendar} is mutable and not thread-safe,
     * so {@link #apply(int, PreparedStatement, StatementContext)} creates a
     * fresh UTC calendar per call instead of passing this shared instance to
     * the JDBC driver. Kept for compatibility with any package-local users;
     * treat it as read-only.
     */
    static final Calendar UTC = Calendar.getInstance(TimeZone.getTimeZone("UTC"));

    /** Message prefix shared by all "unsupported SQL type" failures. */
    private static final String UNSUPPORTED_TYPE_PREFIX = "Unsupported type ";

    private final MessageOrBuilder             message;
    private final Descriptors.FieldDescriptor  field;
    private final MorimektaJdbiOptions.SqlType type;

    /**
     * Create a message field argument.
     *
     * @param message The message to get the field from.
     * @param field   The field to select.
     * @param type    The SQL type. See {@link Types}.
     * @throws NullPointerException     If any argument is null.
     * @throws IllegalArgumentException If the field is repeated.
     */
    public MessageFieldArgument(MessageOrBuilder message,
                                Descriptors.FieldDescriptor field,
                                MorimektaJdbiOptions.SqlType type) {
        requireNonNull(message, "message == null");
        requireNonNull(field, "field == null");
        requireNonNull(type, "type == null");
        if (field.isRepeated()) {
            throw new IllegalArgumentException("Repeated fields are not supported");
        }
        this.message = message;
        this.field = field;
        this.type = type;
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() +
               "{@type=" +
               type + "(" + getColumnType(type) + "); " +
               message.getDescriptorForType().getFullName() + "; " +
               ProtoField.toString(field) + "}";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MessageFieldArgument)) {
            return false;
        }
        MessageFieldArgument other = (MessageFieldArgument) o;
        return Objects.equals(message, other.message) &&
               Objects.equals(field, other.field) &&
               type == other.type;
    }

    @Override
    public int hashCode() {
        return Objects.hash(getClass(), message, field, type);
    }

    /**
     * Bind the selected field value to the statement at the given position,
     * converting it according to the configured SQL type.
     *
     * @param position The 1-based parameter position.
     * @param statement The statement to bind to.
     * @param ctx The statement context, used for exception reporting.
     * @throws SQLException If the driver fails to bind the value.
     * @throws ResultSetException If the field type / SQL type combination is
     *                            not supported, or a delegated binding fails.
     */
    @Override
    public void apply(int position, PreparedStatement statement, StatementContext ctx) throws SQLException {
        if (!message.hasField(field)) {
            statement.setNull(position, ProtoJdbi.getColumnType(type));
            return;
        }
        switch (field.getType()) {
            case BOOL: {
                var value = (boolean) message.getField(field);
                switch (type) {
                    case BIT:
                    case BOOLEAN:
                        statement.setBoolean(position, value);
                        break;
                    case TINYINT:
                    case SMALLINT:
                    case INTEGER:
                        statement.setInt(position, value ? 1 : 0);
                        break;
                    default:
                        throw unsupportedType(null, ctx);
                }
                break;
            }
            case INT32:
            case UINT32:
            case SINT32:
            case FIXED32:
            case SFIXED32: {
                var value = (int) message.getField(field);
                // For temporal SQL types the 32-bit value is multiplied by 1000,
                // i.e. treated as seconds since the epoch.
                switch (type) {
                    case TIMESTAMP:
                    case TIMESTAMP_WITH_TIMEZONE:
                        statement.setTimestamp(position, new Timestamp(1000L * value), newUtcCalendar());
                        break;
                    case TIME:
                    case TIME_WITH_TIMEZONE:
                        statement.setTime(position, new Time(1000L * value), newUtcCalendar());
                        break;
                    case DATE:
                        statement.setDate(position, new Date(1000L * value), newUtcCalendar());
                        break;
                    case TINYINT:
                    case SMALLINT:
                    case INTEGER:
                    case BIGINT:
                    case NUMERIC:
                        statement.setInt(position, value);
                        break;
                    default:
                        throw unsupportedType(null, ctx);
                }
                break;
            }
            case INT64:
            case UINT64:
            case SINT64:
            case FIXED64:
            case SFIXED64: {
                var value = (long) message.getField(field);
                // For temporal SQL types the 64-bit value is passed as-is,
                // i.e. treated as milliseconds since the epoch.
                switch (type) {
                    case TIMESTAMP:
                    case TIMESTAMP_WITH_TIMEZONE:
                        statement.setTimestamp(position, new Timestamp(value), newUtcCalendar());
                        break;
                    case TIME:
                    case TIME_WITH_TIMEZONE:
                        statement.setTime(position, new Time(value), newUtcCalendar());
                        break;
                    case DATE:
                        statement.setDate(position, new Date(value), newUtcCalendar());
                        break;
                    case TINYINT:
                    case SMALLINT:
                    case INTEGER:
                    case BIGINT:
                    case NUMERIC:
                        statement.setLong(position, value);
                        break;
                    default:
                        throw unsupportedType(null, ctx);
                }
                break;
            }
            case FLOAT: {
                // NOTE(review): the JDBC standard mapping is REAL -> float and
                // FLOAT/DOUBLE -> double; FLOAT and REAL are swapped here.
                // Kept unchanged for compatibility — confirm intent.
                switch (type) {
                    case FLOAT:
                    case DECIMAL:
                        statement.setFloat(position, (float) message.getField(field));
                        break;
                    case DOUBLE:
                    case REAL:
                        statement.setDouble(position, (float) message.getField(field));
                        break;
                    default:
                        throw unsupportedType(null, ctx);
                }
                break;
            }
            case DOUBLE: {
                switch (type) {
                    case FLOAT:
                    case DECIMAL:
                        // Narrowing to float: precision beyond ~7 digits is lost.
                        statement.setFloat(position, (float) (double) message.getField(field));
                        break;
                    case DOUBLE:
                    case REAL:
                        statement.setDouble(position, (double) message.getField(field));
                        break;
                    default:
                        throw unsupportedType(null, ctx);
                }
                break;
            }
            case STRING: {
                var value = (String) message.getField(field);
                switch (type) {
                    case VARBINARY:
                    case BINARY:
                    case LONG_VARBINARY: {
                        statement.setBytes(position, value.getBytes(UTF_8));
                        break;
                    }
                    case BLOB: {
                        statement.setBlob(position, new ByteArrayInputStream(value.getBytes(UTF_8)));
                        break;
                    }
                    case CLOB:
                    case NCLOB: {
                        statement.setClob(position, new StringReader(value));
                        break;
                    }
                    case CHAR:
                    case NCHAR:
                    case VARCHAR:
                    case NVARCHAR:
                    case LONG_VARCHAR:
                    case LONG_NVARCHAR: {
                        statement.setString(position, value);
                        break;
                    }
                    default: {
                        throw unsupportedType(null, ctx);
                    }
                }
                break;
            }
            case BYTES: {
                ByteString binary = (ByteString) message.getField(field);
                try {
                    new ByteStringArgument(binary, type).apply(position, statement, ctx);
                } catch (Exception e) {
                    throw wrapDelegateFailure(e, ctx);
                }
                break;
            }
            case ENUM: {
                var value = (Descriptors.EnumValueDescriptor) message.getField(field);
                switch (type) {
                    case CHAR:
                    case VARCHAR:
                    case LONG_VARCHAR:
                    case NCHAR:
                    case NVARCHAR:
                    case LONG_NVARCHAR: {
                        statement.setString(position, getDefaultEnumName(value));
                        break;
                    }
                    case TINYINT:
                    case SMALLINT:
                    case INTEGER:
                    case BIGINT: {
                        statement.setInt(position, value.getNumber());
                        break;
                    }
                    default: {
                        throw unsupportedType(null, ctx);
                    }
                }
                break;
            }
            case MESSAGE: {
                Message value = (Message) message.getField(field);
                try {
                    new MessageArgument(value, type).apply(position, statement, ctx);
                } catch (Exception e) {
                    throw wrapDelegateFailure(e, ctx);
                }
                break;
            }
            default:
                throw new ResultSetException("Unhandled field type in SQL: " + field, null, ctx);
        }
    }

    /**
     * Create a fresh UTC calendar for a single JDBC call. A new instance is
     * used each time because {@link Calendar} is mutable and not thread-safe,
     * and the JDBC API does not promise drivers leave it untouched.
     */
    private static Calendar newUtcCalendar() {
        return Calendar.getInstance(TimeZone.getTimeZone("UTC"));
    }

    /**
     * Build the exception reported when the configured SQL type cannot be used
     * for this field's protobuf type.
     */
    private ResultSetException unsupportedType(Throwable cause, StatementContext ctx) {
        return new ResultSetException(
                UNSUPPORTED_TYPE_PREFIX + type + " for " + ProtoField.toString(field),
                cause, ctx);
    }

    /**
     * Translate a failure from a delegated argument (bytes or message binding)
     * into an exception that names this field. The check is null-safe because
     * exceptions may carry no message.
     */
    private ResultSetException wrapDelegateFailure(Exception e, StatementContext ctx) {
        String reason = e.getMessage();
        if (e instanceof ResultSetException &&
            reason != null &&
            reason.startsWith(UNSUPPORTED_TYPE_PREFIX)) {
            return unsupportedType(e, ctx);
        }
        return new ResultSetException(
                "Exception for " + field + " in " + message.getDescriptorForType().getFullName() +
                ": " + reason,
                e, ctx);
    }
}