MessageRowMapper.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.jdbi.v3;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.collect.UnmodifiableMap;
import net.morimekta.proto.ProtoEnum;
import net.morimekta.proto.ProtoMessage;
import net.morimekta.proto.ProtoMessageBuilder;
import net.morimekta.proto.jdbi.MorimektaJdbiOptions;
import net.morimekta.strings.ReaderUtil;
import net.morimekta.strings.StringUtil;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.result.ResultSetException;
import org.jdbi.v3.core.result.UnableToProduceResultException;
import org.jdbi.v3.core.statement.StatementContext;

import java.io.IOException;
import java.io.Reader;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.Calendar;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.TimeZone;

import static net.morimekta.proto.ProtoEnum.getEnumDescriptor;
import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultColumnName;

/**
 * Map a result set row to a message, based on the result set meta data and the
 * message descriptor.
 * <p>
 * Each column is matched to a field by its lower-cased column label. Columns that
 * have no mapped field, or that belong to another table than the configured table
 * name (if any), are ignored. Column values that read as SQL {@code NULL} are
 * skipped, leaving the field unset.
 *
 * @param <M> The message type.
 */
public class MessageRowMapper<M extends MessageOrBuilder> implements RowMapper<M> {
    /**
     * Create a message row mapper.
     *
     * @param tableName        The name of the table to filter fields for this mapper.
     *                         If null or empty, columns are not filtered by table.
     * @param descriptor       Message descriptor.
     * @param fieldNameMapping The mapping from column name to field. Column names are
     *                         matched case-insensitively. Note that an empty mapping
     *                         maps no columns at all; use {@link #builder(Class)} to
     *                         get a mapper for all fields with default column names.
     */
    public MessageRowMapper(String tableName,
                            Descriptors.Descriptor descriptor,
                            Map<String, Descriptors.FieldDescriptor> fieldNameMapping) {
        this.tableName = StringUtil.emptyToNull(tableName);
        this.descriptor = Objects.requireNonNull(descriptor, "descriptor");
        // Column labels are lower-cased before lookup in map(), so the keys must be
        // lower-cased as well, otherwise mixed-case keys would never match.
        Map<String, Descriptors.FieldDescriptor> normalized = new HashMap<>();
        fieldNameMapping.forEach((columnName, field) ->
                normalized.put(columnName.toLowerCase(Locale.US), field));
        this.fieldNameMapping = UnmodifiableMap.asMap(normalized);
    }

    @Override
    public String toString() {
        return "MessageRowMapper{type=" +
               descriptor.getFullName() +
               (tableName == null ? "" : ", tableName=" + tableName) +
               "}";
    }

    @Override
    public M map(ResultSet rs, StatementContext ctx) throws SQLException {
        ProtoMessageBuilder builder = new ProtoMessageBuilder(descriptor);
        // The meta data is the same for every column lookup, so fetch it once.
        var metaData    = rs.getMetaData();
        int columnCount = metaData.getColumnCount();
        for (int i = 1; i <= columnCount; ++i) {
            if (tableName != null && !tableName.equalsIgnoreCase(metaData.getTableName(i))) {
                continue;
            }

            String                      name  = metaData.getColumnLabel(i).toLowerCase(Locale.US);
            Descriptors.FieldDescriptor field = fieldNameMapping.get(name);
            if (field == null) {
                continue;
            }
            try {
                readColumn(builder, field, rs, i, metaData.getColumnType(i), ctx);
            } catch (ResultSetException | UnableToProduceResultException e) {
                // Already a JDBI result exception; pass through unchanged.
                throw e;
            } catch (Exception e) {
                throw new ResultSetException(
                        "Error for field " + typeNameOf(field) +
                        " " + field.getName() +
                        " = " + field.getNumber() +
                        "; in " + descriptor.getFullName() +
                        ": " + e.getMessage(),
                        e, ctx);
            }
        }

        return buildMessage(builder);
    }

    /**
     * Create a builder for a row mapper that only maps columns from the given table.
     *
     * @param tableName   The table name to filter columns on. If empty, columns are
     *                    not filtered by table.
     * @param messageType The message type class.
     * @param <M>         The message type.
     * @return The row mapper builder.
     */
    public static <M extends MessageOrBuilder> Builder<M> builder(String tableName, Class<M> messageType) {
        return new Builder<>(tableName, messageType);
    }

    /**
     * Create a builder for a row mapper that maps columns regardless of table.
     *
     * @param messageType The message type class.
     * @param <M>         The message type.
     * @return The row mapper builder.
     */
    public static <M extends MessageOrBuilder> Builder<M> builder(Class<M> messageType) {
        return new Builder<>("", messageType);
    }

    /**
     * Builder for {@link MessageRowMapper}, collecting the column to field mapping.
     *
     * @param <BM> The message type.
     */
    public static class Builder<BM extends MessageOrBuilder> {
        private final String                                   tableName;
        private final Descriptors.Descriptor                   descriptor;
        private final Map<String, Descriptors.FieldDescriptor> fieldMapping = new HashMap<>();
        private       boolean                                  allColumns   = false;

        private Builder(String tableName, Class<BM> messageType) {
            this.tableName = tableName;
            this.descriptor = ProtoMessage.getMessageDescriptor(messageType);
        }

        /**
         * Map every field that is not already mapped, and is not marked with the
         * {@code MorimektaJdbiOptions.sqlIgnore} option, to its default column name.
         * No explicit mappings can be added after this call.
         *
         * @return The builder.
         * @throws IllegalStateException If called more than once, or if a default
         *                               column name is already mapped to another field.
         */
        public Builder<BM> allColumns() {
            if (allColumns) {
                throw new IllegalStateException("All columns already default mapped, allColumns() called twice.");
            }
            allColumns = true;
            for (Descriptors.FieldDescriptor field : descriptor.getFields()) {
                if (fieldMapping.containsValue(field) ||
                    field.getOptions().getExtension(MorimektaJdbiOptions.sqlIgnore)) {
                    // Field itself is already covered, or should be ignored by decree of
                    // the proto extension.
                    continue;
                }
                // Keys are stored lower-cased (see map(...)), so compare lower-cased.
                var name = getDefaultColumnName(field).toLowerCase(Locale.US);
                if (fieldMapping.containsKey(name)) {
                    throw new IllegalStateException("Duplicate column name " + name);
                }
                fieldMapping.put(name, field);
            }
            return this;
        }

        /**
         * Map a column to a field.
         *
         * @param columnName  The column name, matched case-insensitively.
         * @param fieldNumber The number of the field in the message.
         * @return The builder.
         * @throws IllegalStateException    If {@link #allColumns()} was already called.
         * @throws IllegalArgumentException If the message has no field with the number.
         */
        public Builder<BM> map(String columnName, int fieldNumber) {
            Objects.requireNonNull(columnName);
            if (allColumns) {
                throw new IllegalStateException("All (remaining) columns already default mapped.");
            }
            var field = descriptor.findFieldByNumber(fieldNumber);
            if (field == null) {
                throw new IllegalArgumentException(
                        "Field " + fieldNumber + " not found in " + descriptor.getFullName());
            }
            fieldMapping.put(columnName.toLowerCase(Locale.US), field);
            return this;
        }

        /**
         * Build the row mapper. If no column has been mapped, all fields are mapped
         * with their default column names, as with {@link #allColumns()}.
         *
         * @return The row mapper.
         */
        public MessageRowMapper<BM> build() {
            if (fieldMapping.isEmpty() && !allColumns) {
                allColumns();
            }
            // The constructor copies (and lower-cases) the mapping, so later changes
            // to this builder do not affect the built mapper.
            return new MessageRowMapper<>(tableName, descriptor, fieldMapping);
        }
    }

    /**
     * Read column {@code i} of the current row and set the value on the field.
     * Values that read as null are skipped.
     *
     * @param builder    The message builder to set the field on.
     * @param field      The field the column maps to.
     * @param rs         The result set, positioned on the row to read.
     * @param i          The 1-based column index.
     * @param columnType The JDBC column type, from {@link Types}.
     * @param ctx        The statement context.
     * @throws SQLException If reading the column failed.
     */
    private void readColumn(ProtoMessageBuilder builder,
                            Descriptors.FieldDescriptor field,
                            ResultSet rs,
                            int i,
                            int columnType,
                            StatementContext ctx) throws SQLException {
        switch (field.getType()) {
            case BOOL: {
                if (columnType == Types.BOOLEAN || columnType == Types.BIT) {
                    boolean b = rs.getBoolean(i);
                    if (!rs.wasNull()) {
                        builder.set(field, b);
                    }
                } else {
                    // Numeric columns: any non-zero value is true.
                    int b = rs.getInt(i);
                    if (!rs.wasNull()) {
                        builder.set(field, b != 0);
                    }
                }
                break;
            }
            case INT32:
            case UINT32:
            case SINT32:
            case FIXED32:
            case SFIXED32: {
                if (isTimestamp(columnType)) {
                    Timestamp ts = getUtcTimestamp(rs, i);
                    if (ts != null) {
                        // 32-bit fields get the timestamp as seconds since epoch.
                        // NOTE(review): the int cast wraps for times past 2038 on
                        // signed types; kept as-is since unsigned types rely on it.
                        builder.set(field, (int) (ts.getTime() / 1000L));
                    }
                } else {
                    int b = rs.getInt(i);
                    if (!rs.wasNull()) {
                        builder.set(field, b);
                    }
                }
                break;
            }
            case INT64:
            case UINT64:
            case SINT64:
            case FIXED64:
            case SFIXED64: {
                if (isTimestamp(columnType)) {
                    Timestamp ts = getUtcTimestamp(rs, i);
                    if (ts != null) {
                        // 64-bit fields get the timestamp as milliseconds since epoch.
                        builder.set(field, ts.getTime());
                    }
                } else {
                    long b = rs.getLong(i);
                    if (!rs.wasNull()) {
                        builder.set(field, b);
                    }
                }
                break;
            }
            case FLOAT: {
                float flt = rs.getFloat(i);
                if (!rs.wasNull()) {
                    builder.set(field, flt);
                }
                break;
            }
            case DOUBLE: {
                double dbl = rs.getDouble(i);
                if (!rs.wasNull()) {
                    builder.set(field, dbl);
                }
                break;
            }
            case STRING: {
                switch (columnType) {
                    case Types.CLOB:
                    case Types.NCLOB: {
                        Reader reader = rs.getCharacterStream(i);
                        if (reader != null && !rs.wasNull()) {
                            // Close the stream when done so the driver can release it.
                            try (reader) {
                                builder.set(field, ReaderUtil.readAll(reader));
                            } catch (IOException e) {
                                throw new UnableToProduceResultException(e, ctx);
                            }
                        }
                        break;
                    }
                    default: {
                        String str = rs.getString(i);
                        if (str != null) {
                            builder.set(field, str);
                        }
                        break;
                    }
                }
                break;
            }
            case BYTES: {
                var b = ByteStringColumnMapper.INSTANCE.map(rs, i, ctx);
                if (b != null) {
                    builder.set(field, b);
                }
                break;
            }
            case ENUM: {
                var ed = field.getEnumType();
                var acceptUnknown = ed
                        .getOptions().getExtension(MorimektaJdbiOptions.sqlAcceptUnknown);
                var mapper = new EnumColumnMapper<>(
                        acceptUnknown, (ProtoEnum<?>) getEnumDescriptor(ed));
                var value = mapper.map(rs, i, ctx);
                if (value != null) {
                    builder.set(field, value);
                }
                break;
            }
            case MESSAGE: {
                Descriptors.Descriptor md = field.getMessageType();
                Message out = new MessageColumnMapper<>(md).map(rs, i, ctx);
                if (out != null) {
                    builder.set(field, out);
                }
                break;
            }
            default: {
                throw new ResultSetException(
                        "Unhandled column of type " + rs.getMetaData().getColumnTypeName(i) +
                        "(" + columnType + ")" +
                        " for " + field.getType() +
                        " field " + field.getName() + " in " +
                        descriptor.getFullName(),
                        null, ctx);
            }
        }
    }

    /**
     * @param field The field to describe.
     * @return The type name used in error messages: the full proto name for enum
     *         and message fields, otherwise the java type name.
     */
    private static String typeNameOf(Descriptors.FieldDescriptor field) {
        switch (field.getType()) {
            case ENUM:
                return field.getEnumType().getFullName();
            case MESSAGE:
                return field.getMessageType().getFullName();
            default:
                return field.getJavaType().name();
        }
    }

    /**
     * @param columnType JDBC column type.
     * @return True if the column type is a timestamp, with or without time zone.
     */
    private static boolean isTimestamp(int columnType) {
        return columnType == Types.TIMESTAMP || columnType == Types.TIMESTAMP_WITH_TIMEZONE;
    }

    /**
     * Read a timestamp column using a UTC calendar. A new calendar is created per
     * call because {@link Calendar} is mutable and not thread-safe, and the driver
     * may modify the calendar it is given.
     *
     * @param rs     The result set.
     * @param column The 1-based column index.
     * @return The timestamp, or null if the value was SQL NULL.
     * @throws SQLException If reading the column failed.
     */
    private static Timestamp getUtcTimestamp(ResultSet rs, int column) throws SQLException {
        return rs.getTimestamp(column, Calendar.getInstance(UTC));
    }

    // The builder's message type is M, as it was created from M's descriptor.
    @SuppressWarnings("unchecked")
    private M buildMessage(ProtoMessageBuilder builder) {
        return (M) builder.getMessage().build();
    }

    private static final TimeZone UTC = TimeZone.getTimeZone("UTC");

    private final Descriptors.Descriptor                   descriptor;
    private final Map<String, Descriptors.FieldDescriptor> fieldNameMapping;
    private final String                                   tableName;
}