MessageArgument.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.jdbi.v3.util;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonIOException;
import com.google.protobuf.Message;
import net.morimekta.proto.gson.ProtoTypeAdapterFactory;
import net.morimekta.proto.gson.ProtoTypeOptions;
import net.morimekta.proto.jdbi.MorimektaJdbiOptions.SqlType;
import net.morimekta.proto.jdbi.ProtoJdbi;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.result.ResultSetException;
import org.jdbi.v3.core.statement.StatementContext;

import java.io.ByteArrayInputStream;
import java.io.StringReader;
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.time.temporal.ChronoUnit;
import java.util.Objects;

import static net.morimekta.proto.gson.ProtoTypeOptions.Option.LENIENT_READER;
import static net.morimekta.proto.gson.ProtoTypeOptions.Option.WRITE_COMPACT_MESSAGE;
import static net.morimekta.proto.gson.ProtoTypeOptions.Option.WRITE_UNPACKED_ANY;
import static net.morimekta.proto.jdbi.v3.util.MessageFieldArgument.UTC;
import static net.morimekta.proto.utils.GoogleTypesUtil.toInstant;
import static net.morimekta.proto.utils.GoogleTypesUtil.toJavaDuration;

/**
 * Smart mapping of message fields to SQL bound argument. It will
 * map the type to whichever type is default or selected (if supported)
 * for most field types.
 */
/**
 * Smart mapping of a whole protobuf message to a SQL bound argument.
 * <p>
 * The well-known types {@code google.protobuf.Timestamp} and
 * {@code google.protobuf.Duration} get special temporal, numeric or string
 * mappings for the SQL types that support them. Any other message, or a
 * well-known type bound to a SQL type without a special mapping, is bound
 * as serialized protobuf bytes (binary and BLOB types) or as JSON via
 * {@link #JSON} (character and CLOB types).
 */
public class MessageArgument implements Argument {
    private final Message message;
    private final SqlType type;

    /**
     * Create a message argument.
     *
     * @param message The message to bind.
     * @param type    The SQL type to bind the message as. See {@link Types}.
     */
    public MessageArgument(Message message, SqlType type) {
        this.message = message;
        this.type = type;
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() +
               "{@type=" + type + "(" + ProtoJdbi.getColumnType(type) + "); " +
               message.getDescriptorForType().getFullName() +
               "}";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        // Exact class match: hashCode() includes getClass(), so equal
        // instances must also be of the same class to keep the contract.
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MessageArgument other = (MessageArgument) o;
        return Objects.equals(message, other.message) &&
               type == other.type;
    }

    @Override
    public int hashCode() {
        return Objects.hash(getClass(), message, type);
    }

    /**
     * Bind the message to the given statement position.
     *
     * @param position  The parameter position to bind.
     * @param statement The statement to bind on.
     * @param ctx       The statement context.
     * @throws SQLException       If the driver fails to bind the value.
     * @throws ResultSetException If the SQL type is not supported for the message,
     *                            a numeric value does not fit the target type, or
     *                            JSON serialization fails.
     */
    @Override
    public void apply(int position, PreparedStatement statement, StatementContext ctx) throws SQLException {
        try {
            if (message instanceof com.google.protobuf.Timestamp) {
                if (applyTimestamp(position, statement, (com.google.protobuf.Timestamp) message)) {
                    return;
                }
            } else if (message instanceof com.google.protobuf.Duration) {
                if (applyDuration(position, statement, (com.google.protobuf.Duration) message)) {
                    return;
                }
            }
            // No well-known-type mapping applied: bind as bytes or JSON.
            applyGeneric(position, statement, ctx);
        } catch (JsonIOException e) {
            throw new ResultSetException(e.getMessage(), e, ctx);
        } catch (ArithmeticException e) {
            // Thrown by Math.toIntExact / toEpochMilli / toMillis when the value
            // cannot be represented; fail instead of writing a wrapped value.
            throw new ResultSetException(
                    "Value of " + message.getDescriptorForType().getFullName() +
                    " out of range for type " + type + " (" + ProtoJdbi.getColumnType(type) + ")",
                    e, ctx);
        }
    }

    /**
     * Bind a {@code google.protobuf.Timestamp} using its timestamp-specific mapping.
     *
     * @return True if the SQL type has a timestamp mapping and the value was bound,
     *         false if the generic message mapping should be used instead.
     */
    private boolean applyTimestamp(int position,
                                   PreparedStatement statement,
                                   com.google.protobuf.Timestamp timestamp) throws SQLException {
        var instant = toInstant(timestamp);
        switch (type) {
            case TIMESTAMP:
            case TIMESTAMP_WITH_TIMEZONE: {
                var ts = new Timestamp(instant.toEpochMilli());
                statement.setTimestamp(position, ts, UTC);
                return true;
            }
            case DATE: {
                var date = new Date(instant.truncatedTo(ChronoUnit.DAYS).toEpochMilli());
                statement.setDate(position, date, UTC);
                return true;
            }
            case INTEGER:
            case NUMERIC: {
                // Epoch seconds; toIntExact rejects values outside the int range.
                statement.setInt(position, Math.toIntExact(instant.getEpochSecond()));
                return true;
            }
            case BIGINT: {
                // Epoch milliseconds.
                statement.setLong(position, instant.toEpochMilli());
                return true;
            }
            case DOUBLE: {
                statement.setDouble(position, toSecondsWithMillis(instant.getEpochSecond(), instant.getNano()));
                return true;
            }
            case CHAR:
            case NCHAR:
            case VARCHAR:
            case NVARCHAR:
            case LONG_VARCHAR:
            case LONG_NVARCHAR: {
                statement.setString(position, instant.toString());
                return true;
            }
            default:
                return false;
        }
    }

    /**
     * Bind a {@code google.protobuf.Duration} using its duration-specific mapping.
     *
     * @return True if the SQL type has a duration mapping and the value was bound,
     *         false if the generic message mapping should be used instead.
     */
    private boolean applyDuration(int position,
                                  PreparedStatement statement,
                                  com.google.protobuf.Duration duration) throws SQLException {
        var dur = toJavaDuration(duration);
        switch (type) {
            case TIME:
            case TIME_WITH_TIMEZONE: {
                // TIME type, truncated to whole seconds.
                var time = new Time(dur.truncatedTo(ChronoUnit.SECONDS).toMillis());
                statement.setTime(position, time, UTC);
                return true;
            }
            case REAL:
            case FLOAT:
            case DOUBLE:
            case DECIMAL: {
                // Duration in seconds, with decimals, millisecond resolution.
                statement.setDouble(position, toSecondsWithMillis(dur.getSeconds(), dur.getNano()));
                return true;
            }
            case TINYINT:
            case SMALLINT:
            case NUMERIC: {
                // Duration in seconds; toIntExact rejects values outside the int range.
                statement.setInt(position, Math.toIntExact(dur.getSeconds()));
                return true;
            }
            case INTEGER:
            case BIGINT: {
                // Duration in milliseconds.
                statement.setLong(position, dur.toMillis());
                return true;
            }
            case CHAR:
            case NCHAR:
            case VARCHAR:
            case NVARCHAR:
            case LONG_VARCHAR:
            case LONG_NVARCHAR: {
                statement.setString(position, dur.toString());
                return true;
            }
            default:
                return false;
        }
    }

    /**
     * Bind the message as serialized protobuf bytes or as JSON, depending on the SQL type.
     *
     * @throws ResultSetException If the SQL type has no generic message mapping.
     */
    private void applyGeneric(int position, PreparedStatement statement, StatementContext ctx) throws SQLException {
        switch (type) {
            case BINARY:
            case VARBINARY:
            case LONG_VARBINARY: {
                statement.setBytes(position, message.toByteArray());
                return;
            }
            case BLOB: {
                statement.setBlob(position, new ByteArrayInputStream(message.toByteArray()));
                return;
            }
            case CHAR:
            case VARCHAR:
            case LONG_VARCHAR:
            case NCHAR:
            case NVARCHAR:
            case LONG_NVARCHAR: {
                statement.setString(position, JSON.toJson(message));
                return;
            }
            case CLOB:
            case NCLOB: {
                statement.setClob(position, new StringReader(JSON.toJson(message)));
                return;
            }
            default:
                throw new ResultSetException(
                        "Unsupported type " + type + " (" + ProtoJdbi.getColumnType(type) + ") for " +
                        message.getDescriptorForType().getFullName(),
                        null, ctx);
        }
    }

    /**
     * Combine whole seconds and a nanosecond part into fractional seconds,
     * keeping only millisecond resolution of the fraction.
     */
    private static double toSecondsWithMillis(long seconds, int nanos) {
        var dbl = (double) seconds;
        if (nanos != 0) {
            // We only care about millisecond resolution.
            var millis = nanos / 1_000_000;
            dbl += (millis / 1_000.0);
        }
        return dbl;
    }

    /**
     * Gson instance used to serialize messages to JSON for character and CLOB
     * SQL types. It is configured with the WRITE_COMPACT_MESSAGE, WRITE_UNPACKED_ANY
     * and LENIENT_READER proto type options.
     */
    public static final Gson JSON = new GsonBuilder()
            .registerTypeAdapterFactory(
                    new ProtoTypeAdapterFactory(
                            new ProtoTypeOptions()
                                    .withEnabled(WRITE_COMPACT_MESSAGE)
                                    .withEnabled(WRITE_UNPACKED_ANY)
                                    .withEnabled(LENIENT_READER)
                    ))
            .create();
}