MessageColumnMapper.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.jdbi.v3;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.proto.ProtoMessage;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.result.ResultSetException;
import org.jdbi.v3.core.result.UnableToProduceResultException;
import org.jdbi.v3.core.statement.StatementContext;

import java.io.IOException;
import java.io.InputStream;
import java.sql.Blob;
import java.sql.Clob;
import java.sql.JDBCType;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.sql.Types;
import java.time.Duration;
import java.time.Instant;
import java.util.Calendar;
import java.util.Date;
import java.util.TimeZone;
import java.util.regex.Pattern;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static net.morimekta.proto.ProtoMessage.getMessageClass;
import static net.morimekta.proto.jdbi.v3.util.MessageArgument.JSON;
import static net.morimekta.proto.utils.GoogleTypesUtil.makeProtoDuration;
import static net.morimekta.proto.utils.GoogleTypesUtil.makeProtoTimestamp;
import static net.morimekta.proto.utils.GoogleTypesUtil.toProtoDuration;
import static net.morimekta.proto.utils.GoogleTypesUtil.toProtoTimestamp;

/**
 * Map a single result-set column to a protobuf message of the type described
 * by the message descriptor given to the constructor.
 * <p>
 * Binary columns are parsed as serialized protobuf and text / CLOB columns are
 * parsed as JSON. In addition, {@code google.protobuf.Timestamp} and
 * {@code google.protobuf.Duration} messages can be read from native temporal
 * and numeric columns, and from ISO-8601 strings.
 *
 * @param <M> The message type.
 */
public class MessageColumnMapper<M extends Message> implements ColumnMapper<M> {
    /**
     * Create a message column mapper.
     *
     * @param descriptor Descriptor of the message type to produce.
     */
    public MessageColumnMapper(Descriptors.Descriptor descriptor) {
        this.descriptor = descriptor;
    }

    @Override
    public String toString() {
        return "MessageColumnMapper{type=" + descriptor.getFullName() + "}";
    }

    /**
     * Map the column at {@code position} of the current row to a message.
     * <p>
     * For {@code google.protobuf.Timestamp} and {@code google.protobuf.Duration}
     * descriptors, some column types are converted directly:
     * <ul>
     *     <li>Timestamp: TIMESTAMP (with or without time zone) and DATE read in
     *         UTC, INTEGER / NUMERIC as epoch seconds, BIGINT as epoch
     *         milliseconds, DOUBLE as fractional epoch seconds (millisecond
     *         resolution), and non-JSON text parsed with {@link Instant#parse}.</li>
     *     <li>Duration: TIME (with or without time zone) read in UTC, REAL /
     *         FLOAT / DOUBLE / DECIMAL as fractional seconds (millisecond
     *         resolution), TINYINT / SMALLINT / NUMERIC as seconds, INTEGER /
     *         BIGINT as milliseconds, and text matching an ISO-8601 duration.</li>
     * </ul>
     * Text values not handled above, and all other message types, use the
     * generic handling: BINARY / VARBINARY / LONGVARBINARY / BLOB are parsed as
     * serialized protobuf, and CHAR / VARCHAR / LONGVARCHAR (and N-variants) /
     * CLOB / NCLOB are parsed as JSON.
     *
     * @param rs       The result set, positioned on the row to read.
     * @param position Column index of the value to map.
     * @param ctx      Statement context, passed on to thrown exceptions.
     * @return The mapped message, or null if the column value is SQL NULL.
     * @throws SQLException If reading from the result set fails.
     * @throws ResultSetException If the column type is not supported.
     * @throws UnableToProduceResultException If reading or parsing the value
     *                                        fails with an {@link IOException}.
     */
    @Override
    public M map(ResultSet rs, int position, StatementContext ctx) throws SQLException {
        var columnType = rs.getMetaData().getColumnType(position);
        try {
            if (descriptor.equals(com.google.protobuf.Timestamp.getDescriptor())) {
                switch (columnType) {
                    case Types.TIMESTAMP:
                    case Types.TIMESTAMP_WITH_TIMEZONE: {
                        Timestamp ts = rs.getTimestamp(position, utc());
                        if (ts != null) {
                            var tsb = com.google.protobuf.Timestamp.newBuilder();
                            // getTime() includes the millis, and getNanos() is always the
                            // non-negative fraction of the second, so the seconds must be
                            // floored (not truncated) to stay correct before 1970.
                            tsb.setSeconds(Math.floorDiv(ts.getTime(), 1000L));
                            tsb.setNanos(ts.getNanos());
                            return cast(tsb.build());
                        }
                        return null;
                    }
                    case Types.DATE: {
                        Date date = rs.getDate(position, utc());
                        if (date != null) {
                            return cast(makeProtoTimestamp(date.getTime(), MILLISECONDS));
                        }
                        return null;
                    }
                    case Types.INTEGER:
                    case Types.NUMERIC: {
                        // Read as long so NUMERIC epoch seconds beyond 2^31 (year 2038)
                        // do not overflow.
                        var seconds = rs.getLong(position);
                        if (!rs.wasNull()) {
                            return cast(makeProtoTimestamp(seconds, SECONDS));
                        }
                        return null;
                    }
                    case Types.BIGINT: {
                        var milliseconds = rs.getLong(position);
                        if (!rs.wasNull()) {
                            return cast(makeProtoTimestamp(milliseconds, MILLISECONDS));
                        }
                        return null;
                    }
                    case Types.DOUBLE: {
                        var seconds = rs.getDouble(position);
                        if (!rs.wasNull()) {
                            return cast(timestampFromSeconds(seconds));
                        }
                        return null;
                    }
                    case Types.CHAR:
                    case Types.NCHAR:
                    case Types.VARCHAR:
                    case Types.NVARCHAR:
                    case Types.LONGVARCHAR:
                    case Types.LONGNVARCHAR: {
                        String tmp = rs.getString(position);
                        // Values that look like a JSON object / array go straight to
                        // the JSON parser below; anything else is tried as an instant.
                        if (tmp != null && !JSON_PATTERN.matcher(tmp).matches()) {
                            try {
                                var ts = Instant.parse(tmp);
                                return cast(toProtoTimestamp(ts));
                            } catch (Exception ignored) {
                                // Not an ISO instant: fall back to the JSON handling.
                            }
                        }
                        // continue to basic matchers.
                        break;
                    }
                }
            }

            if (descriptor.equals(com.google.protobuf.Duration.getDescriptor())) {
                switch (columnType) {
                    case Types.TIME:
                    case Types.TIME_WITH_TIMEZONE: {
                        var time = rs.getTime(position, utc());
                        if (time != null) {
                            return cast(makeProtoDuration(time.getTime(), MILLISECONDS));
                        }
                        return null;
                    }
                    case Types.REAL:
                    case Types.FLOAT:
                    case Types.DOUBLE:
                    case Types.DECIMAL: {
                        var seconds = rs.getDouble(position);
                        if (!rs.wasNull()) {
                            return cast(durationFromSeconds(seconds));
                        }
                        return null;
                    }
                    case Types.TINYINT:
                    case Types.SMALLINT:
                    case Types.NUMERIC: {
                        var seconds = rs.getLong(position);
                        if (!rs.wasNull()) {
                            var db = com.google.protobuf.Duration.newBuilder();
                            db.setSeconds(seconds);
                            return cast(db.build());
                        }
                        return null;
                    }
                    case Types.INTEGER:
                    case Types.BIGINT: {
                        var milliseconds = rs.getLong(position);
                        if (!rs.wasNull()) {
                            return cast(makeProtoDuration(milliseconds, MILLISECONDS));
                        }
                        return null;
                    }
                    case Types.CHAR:
                    case Types.VARCHAR:
                    case Types.LONGVARCHAR:
                    case Types.NCHAR:
                    case Types.NVARCHAR:
                    case Types.LONGNVARCHAR: {
                        String tmp = rs.getString(position);
                        if (tmp != null) {
                            if (DURATION_PATTERN.matcher(tmp).matches()) {
                                var dur = Duration.parse(tmp);
                                return cast(toProtoDuration(dur));
                            }
                        }
                        // Continue to basic matchers.
                        break;
                    }
                }
            }

            switch (columnType) {
                case Types.BINARY:
                case Types.VARBINARY:
                case Types.LONGVARBINARY: {
                    // A null resource is allowed in try-with-resources; close() is skipped.
                    try (InputStream is = rs.getBinaryStream(position)) {
                        if (is != null) {
                            Message.Builder bld = ProtoMessage.newBuilder(descriptor);
                            bld.mergeFrom(is);
                            return cast(bld.build());
                        }
                    }
                    return null;
                }
                case Types.BLOB: {
                    Blob blob = rs.getBlob(position);
                    if (blob != null) {
                        try (InputStream is = blob.getBinaryStream()) {
                            Message.Builder bld = ProtoMessage.newBuilder(descriptor);
                            bld.mergeFrom(is);
                            return cast(bld.build());
                        }
                    }
                    return null;
                }
                case Types.CHAR:
                case Types.VARCHAR:
                case Types.LONGVARCHAR:
                case Types.NCHAR:
                case Types.NVARCHAR:
                case Types.LONGNVARCHAR: {
                    String tmp = rs.getString(position);
                    if (tmp != null) {
                        return cast(JSON.fromJson(tmp, getMessageClass(descriptor)));
                    }
                    return null;
                }
                case Types.CLOB:
                case Types.NCLOB: {
                    Clob clob = rs.getClob(position);
                    if (clob != null) {
                        try (var reader = clob.getCharacterStream()) {
                            return cast(JSON.fromJson(reader, getMessageClass(descriptor)));
                        }
                    }
                    return null;
                }
                default:
                    throw new ResultSetException(
                            "Unknown column type " + jdbcTypeName(columnType) +
                            "(" + columnType + ")" +
                            " for " +
                            descriptor.getFullName(),
                            null, ctx);
            }
        } catch (IOException e) {
            throw new UnableToProduceResultException(e.getMessage(), e, ctx);
        }
    }

    /**
     * Convert fractional epoch seconds to a timestamp with millisecond
     * resolution. Nanos are kept within {@code [0, 999_999_999]}; for negative
     * fractions one second is borrowed from the seconds field.
     *
     * @param seconds Fractional seconds since the epoch.
     * @return The timestamp.
     */
    private static com.google.protobuf.Timestamp timestampFromSeconds(double seconds) {
        long wholeSeconds = (long) seconds;  // truncates toward zero
        // We only care about millisecond resolution. Range is [-1000, 1000].
        int millis = (int) Math.round((seconds - wholeSeconds) * 1000);
        if (millis < 0) {
            // Negative fraction: borrow one second so nanos becomes non-negative.
            wholeSeconds -= 1;
            millis += 1000;
        }
        if (millis >= 1000) {
            // Rounded up to a whole second.
            wholeSeconds += 1;
            millis -= 1000;
        }
        return com.google.protobuf.Timestamp.newBuilder()
                                            .setSeconds(wholeSeconds)
                                            .setNanos(millis * 1_000_000)
                                            .build();
    }

    /**
     * Convert fractional seconds to a duration with millisecond resolution.
     * The nanos keep the same sign as the seconds (truncation toward zero).
     *
     * @param seconds Fractional number of seconds.
     * @return The duration.
     */
    private static com.google.protobuf.Duration durationFromSeconds(double seconds) {
        var db = com.google.protobuf.Duration.newBuilder();
        db.setSeconds((long) seconds);
        var secondsDecimals = seconds - db.getSeconds();
        // We only care about millisecond resolution.
        var millis = (int) Math.round(secondsDecimals * 1000);
        if (millis >= 1000) {
            db.setSeconds(db.getSeconds() + 1);
        } else if (millis <= -1000) {
            db.setSeconds(db.getSeconds() - 1);
        } else {
            db.setNanos(millis * 1_000_000);
        }
        return db.build();
    }

    /**
     * Name of a {@link Types} code for error messages. Falls back to
     * {@code "UNKNOWN"} for vendor-specific codes, which
     * {@link JDBCType#valueOf(int)} rejects with an IllegalArgumentException.
     *
     * @param columnType The JDBC type code.
     * @return The type name.
     */
    private static String jdbcTypeName(int columnType) {
        try {
            return JDBCType.valueOf(columnType).name();
        } catch (IllegalArgumentException e) {
            return "UNKNOWN";
        }
    }

    /**
     * A fresh UTC calendar for JDBC date/time getters. {@link Calendar} is
     * mutable and not thread-safe, and the driver receives the instance to use
     * during conversion, so it is not shared between calls.
     *
     * @return A new calendar in the UTC time zone.
     */
    private static Calendar utc() {
        return Calendar.getInstance(UTC_ZONE);
    }

    @SuppressWarnings("unchecked")
    private M cast(MessageOrBuilder m) {
        return (M) m;
    }

    private final Descriptors.Descriptor descriptor;

    private static final TimeZone UTC_ZONE         = TimeZone.getTimeZone("UTC");
    // Pattern to match ISO duration (period).
    private static final Pattern  DURATION_PATTERN =
            Pattern.compile("([-+]?)P(?:([-+]?[0-9]+)D)?" +
                            "(T(?:([-+]?[0-9]+)H)?(?:([-+]?[0-9]+)M)?" +
                            "(?:([-+]?[0-9]+)(?:[.,]([0-9]{0,9}))?S)?)?",
                            Pattern.CASE_INSENSITIVE);
    // Pattern to match JSON object (not value).
    private static final Pattern  JSON_PATTERN     =
            Pattern.compile("^(\\{.*}|\\[.*])\\n?$");
}