MessageRowMapper.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.jdbi.v2;

import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.serializer.BinarySerializer;
import net.morimekta.providence.serializer.JsonSerializer;
import net.morimekta.util.Binary;
import net.morimekta.util.collect.UnmodifiableMap;
import net.morimekta.util.io.IOUtils;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.tweak.ResultSetMapper;

import javax.annotation.Nonnull;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.sql.Blob;
import java.sql.Clob;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLDataException;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.TimeZone;

/**
 * Map a result set row to a message based on the result set meta-data and
 * the message descriptor.
 * <p>
 * Each column is matched by its upper-cased column label against a field
 * name mapping. Columns without a matching field are ignored. If a table name
 * is given, columns that the meta-data reports as belonging to another table
 * are skipped. Values are converted by field type and JDBC column type:
 * <ul>
 *     <li>I32 fields read from TIMESTAMP / DATE columns get epoch seconds, and
 *         I64 fields get epoch milliseconds, both interpreted in UTC.</li>
 *     <li>BINARY fields read from character columns are base64 decoded.</li>
 *     <li>MESSAGE fields are deserialized with the binary serializer from
 *         binary columns, and with the JSON serializer from character columns.</li>
 * </ul>
 *
 * @param <M> The message type.
 */
public class MessageRowMapper<M extends PMessage<M>> implements ResultSetMapper<M> {
    /**
     * Time zone used to interpret TIMESTAMP and DATE columns. A {@link Calendar}
     * is mutable and some JDBC drivers modify the calendar they are given, so
     * only the zone is shared; a fresh calendar is created for each read.
     */
    private static final TimeZone UTC_ZONE = TimeZone.getTimeZone("UTC");

    /**
     * Field mapping key meaning "also map every field of the message by its
     * default (upper-cased) name". Explicitly mapped column names take
     * precedence over the defaults.
     */
    public static final String ALL_FIELDS = "*";

    /**
     * Create a message row mapper.
     *
     * @param descriptor Message descriptor.
     */
    public MessageRowMapper(@Nonnull PMessageDescriptor<M> descriptor) {
        this(descriptor, UnmodifiableMap.mapOf());
    }

    /**
     * Create a message row mapper.
     *
     * @param tableName  The name of the table to filter fields for this mapper.
     * @param descriptor Message descriptor.
     */
    public MessageRowMapper(@Nonnull String tableName, @Nonnull PMessageDescriptor<M> descriptor) {
        this(tableName, descriptor, UnmodifiableMap.mapOf());
    }

    /**
     * Create a message row mapper.
     *
     * @param descriptor   Message descriptor.
     * @param fieldMapping The field mapping. If empty will map all fields with default names.
     */
    public MessageRowMapper(@Nonnull PMessageDescriptor<M> descriptor,
                            @Nonnull Map<String, PField<M>> fieldMapping) {
        this("", descriptor, fieldMapping);
    }

    /**
     * Create a message row mapper.
     *
     * @param tableName    The name of the table to filter fields for this mapper.
     *                     An empty string means columns from all tables are mapped.
     * @param descriptor   Message descriptor.
     * @param fieldMapping The field mapping, keyed by column name (matched case
     *                     insensitively). If empty will map all fields with default
     *                     names. The key {@link #ALL_FIELDS} adds default-named
     *                     mappings for all fields not otherwise mapped.
     * @throws NullPointerException if tableName is null.
     */
    public MessageRowMapper(@Nonnull String tableName,
                            @Nonnull PMessageDescriptor<M> descriptor,
                            @Nonnull Map<String, PField<M>> fieldMapping) {
        // Fail fast: map() and toString() rely on tableName being non-null.
        this.tableName = Objects.requireNonNull(tableName, "tableName");
        this.descriptor = descriptor;

        Map<String, PField<M>> mappingBuilder = new HashMap<>();
        if (fieldMapping.isEmpty()) {
            for (PField<M> field : descriptor.getFields()) {
                mappingBuilder.put(field.getName().toUpperCase(Locale.US), field);
            }
        } else {
            fieldMapping.forEach((name, addField) -> {
                if (ALL_FIELDS.equals(name)) {
                    for (PField<M> field : descriptor.getFields()) {
                        String fieldName = field.getName().toUpperCase(Locale.US);
                        // To avoid overwriting already specified fields.
                        if (!mappingBuilder.containsKey(fieldName)) {
                            mappingBuilder.put(fieldName, field);
                        }
                    }
                } else {
                    mappingBuilder.put(name.toUpperCase(Locale.US), addField);
                }
            });
        }
        this.fieldNameMapping = UnmodifiableMap.copyOf(mappingBuilder);
    }

    @Override
    public String toString() {
        // tableName is never null; an empty name means "no table filter".
        return getClass().getSimpleName() + "{type=" +
               descriptor.getQualifiedName() +
               (tableName.isEmpty() ? "" : ", tableName=" + tableName) + "}";
    }

    /**
     * Map the current row of the result set to a message.
     *
     * @param idx The row index (unused).
     * @param rs  The result set positioned on the row to map.
     * @param ctx The statement context (unused).
     * @return The built message.
     * @throws SQLException     If reading the result set fails.
     * @throws SQLDataException If a mapped column has a type that cannot be
     *                          converted to its field type.
     * @throws UncheckedIOException If reading a stream or deserializing a value fails.
     */
    @Override
    public M map(int idx, ResultSet rs, StatementContext ctx) throws SQLException {
        // The meta-data is the same for every column, so look it up once.
        ResultSetMetaData meta = rs.getMetaData();
        int columnCount = meta.getColumnCount();
        PMessageBuilder<M> builder = descriptor.builder();
        for (int i = 1; i <= columnCount; ++i) {
            if (!tableName.isEmpty() && !tableName.equalsIgnoreCase(meta.getTableName(i))) {
                // Skip columns the meta-data reports as belonging to another table.
                continue;
            }
            String name = meta.getColumnLabel(i).toUpperCase(Locale.US);
            PField<M> field = fieldNameMapping.get(name);
            if (field != null) {
                readColumn(builder, field, rs, meta, i);
            }
        }
        return builder.build();
    }

    /**
     * Read column {@code i} into {@code field} of the builder, converting
     * based on the field type and the JDBC column type. SQL NULL values leave
     * the field unset.
     */
    private void readColumn(PMessageBuilder<M> builder,
                            PField<M> field,
                            ResultSet rs,
                            ResultSetMetaData meta,
                            int i) throws SQLException {
        int columnType = meta.getColumnType(i);
        switch (field.getType()) {
            case BOOL: {
                boolean value;
                if (columnType == Types.BOOLEAN || columnType == Types.BIT) {
                    value = rs.getBoolean(i);
                } else {
                    // Numeric column: any non-zero value is true.
                    value = rs.getInt(i) != 0;
                }
                if (!rs.wasNull()) {
                    builder.set(field, value);
                }
                break;
            }
            case BYTE: {
                byte value = rs.getByte(i);
                if (!rs.wasNull()) {
                    builder.set(field, value);
                }
                break;
            }
            case I16: {
                short value = rs.getShort(i);
                if (!rs.wasNull()) {
                    builder.set(field, value);
                }
                break;
            }
            case I32: {
                if (isDateTime(columnType)) {
                    Long millis = readEpochMillis(rs, i, columnType);
                    if (millis != null) {
                        // An i32 cannot hold epoch milliseconds; store epoch seconds.
                        builder.set(field, (int) (millis / 1000L));
                    }
                } else {
                    int value = rs.getInt(i);
                    if (!rs.wasNull()) {
                        builder.set(field, value);
                    }
                }
                break;
            }
            case I64: {
                if (isDateTime(columnType)) {
                    Long millis = readEpochMillis(rs, i, columnType);
                    if (millis != null) {
                        builder.set(field, millis);
                    }
                } else {
                    long value = rs.getLong(i);
                    if (!rs.wasNull()) {
                        builder.set(field, value);
                    }
                }
                break;
            }
            case DOUBLE: {
                double value = rs.getDouble(i);
                if (!rs.wasNull()) {
                    builder.set(field, value);
                }
                break;
            }
            case STRING:
                readString(builder, field, rs, i, columnType);
                break;
            case BINARY:
                readBinary(builder, field, rs, meta, i, columnType);
                break;
            case ENUM: {
                int id = rs.getInt(i);
                if (!rs.wasNull()) {
                    PEnumDescriptor ed = (PEnumDescriptor) field.getDescriptor();
                    builder.set(field, ed.findById(id));
                }
                break;
            }
            case MESSAGE:
                readMessage(builder, field, rs, meta, i, columnType);
                break;
            case LIST:
            case SET:
            case MAP:
            case VOID:
            default:
                // Containers (and void) have no supported column representation.
                throw columnTypeError("Unhandled column of type ", meta, i, columnType, field);
        }
    }

    /** Whether the JDBC column type is a timestamp or date type. */
    private static boolean isDateTime(int columnType) {
        return columnType == Types.TIMESTAMP ||
               columnType == Types.TIMESTAMP_WITH_TIMEZONE ||
               columnType == Types.DATE;
    }

    /**
     * Read a TIMESTAMP or DATE column as epoch milliseconds, interpreted in UTC.
     *
     * @return The epoch milliseconds, or null if the value is SQL NULL.
     */
    private static Long readEpochMillis(ResultSet rs, int i, int columnType) throws SQLException {
        // New calendar per read: drivers may mutate the calendar they receive.
        Calendar utc = Calendar.getInstance(UTC_ZONE);
        Date value = columnType == Types.DATE ? rs.getDate(i, utc) : rs.getTimestamp(i, utc);
        return value == null ? null : value.getTime();
    }

    /** Read a string field; CLOB columns are read as a (closed) character stream. */
    private void readString(PMessageBuilder<M> builder,
                            PField<M> field,
                            ResultSet rs,
                            int i,
                            int columnType) throws SQLException {
        if (columnType == Types.CLOB || columnType == Types.NCLOB) {
            try (Reader reader = rs.getCharacterStream(i)) {
                if (reader != null) {
                    builder.set(field, IOUtils.readString(reader));
                } else {
                    builder.clear(field);
                }
            } catch (IOException e) {
                throw new UncheckedIOException(e.getMessage(), e);
            }
        } else {
            builder.set(field, rs.getString(i));
        }
    }

    /** Read a binary field from a binary, blob or base64 character column. */
    private void readBinary(PMessageBuilder<M> builder,
                            PField<M> field,
                            ResultSet rs,
                            ResultSetMetaData meta,
                            int i,
                            int columnType) throws SQLException {
        try {
            switch (columnType) {
                case Types.BINARY:
                case Types.VARBINARY: {
                    byte[] bytes = rs.getBytes(i);
                    if (bytes != null) {
                        builder.set(field, Binary.copy(bytes));
                    }
                    break;
                }
                case Types.LONGVARBINARY: {
                    try (InputStream in = rs.getBinaryStream(i)) {
                        if (in != null) {
                            builder.set(field, Binary.read(in));
                        }
                    }
                    break;
                }
                case Types.BLOB: {
                    Blob blob = rs.getBlob(i);
                    if (blob != null) {
                        long length = blob.length();
                        if (length > Integer.MAX_VALUE) {
                            // A cast to int would silently overflow.
                            throw new SQLDataException(
                                    "Blob of " + length + " bytes is too large for binary field " +
                                    field.getName() + " in " + descriptor.getQualifiedName());
                        }
                        try (InputStream in = blob.getBinaryStream()) {
                            builder.set(field, Binary.read(in, (int) length));
                        }
                    }
                    break;
                }
                case Types.CHAR:
                case Types.VARCHAR:
                case Types.LONGVARCHAR:
                case Types.NCHAR:
                case Types.NVARCHAR: {
                    String base64 = rs.getString(i);
                    if (base64 != null) {
                        builder.set(field, Binary.fromBase64(base64));
                    }
                    break;
                }
                case Types.NULL:
                    break;
                default:
                    throw columnTypeError("Unknown column type ", meta, i, columnType, field);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e.getMessage(), e);
        }
    }

    /**
     * Read a message field: binary columns use the binary serializer, and
     * character columns use the JSON serializer.
     */
    private void readMessage(PMessageBuilder<M> builder,
                             PField<M> field,
                             ResultSet rs,
                             ResultSetMetaData meta,
                             int i,
                             int columnType) throws SQLException {
        PMessageDescriptor<?> md = (PMessageDescriptor) field.getDescriptor();
        try {
            switch (columnType) {
                case Types.BINARY:
                case Types.VARBINARY: {
                    byte[] data = rs.getBytes(i);
                    if (data != null) {
                        builder.set(field, BINARY.deserialize(new ByteArrayInputStream(data), md));
                    }
                    break;
                }
                case Types.LONGVARBINARY: {
                    try (InputStream in = rs.getBinaryStream(i)) {
                        if (in != null) {
                            builder.set(field, BINARY.deserialize(in, md));
                        }
                    }
                    break;
                }
                case Types.BLOB: {
                    Blob blob = rs.getBlob(i);
                    if (blob != null) {
                        try (InputStream in = blob.getBinaryStream()) {
                            builder.set(field, BINARY.deserialize(in, md));
                        }
                    }
                    break;
                }
                case Types.CHAR:
                case Types.VARCHAR:
                case Types.NCHAR:
                case Types.NVARCHAR: {
                    String json = rs.getString(i);
                    if (json != null) {
                        builder.set(field, JSON.deserialize(new StringReader(json), md));
                    }
                    break;
                }
                case Types.LONGVARCHAR:
                case Types.NCLOB:
                case Types.CLOB: {
                    // getCharacterStream is defined for LONGVARCHAR as well as
                    // CLOB/NCLOB, whereas getClob is only defined for CLOB columns.
                    try (Reader reader = rs.getCharacterStream(i)) {
                        if (reader != null) {
                            builder.set(field, JSON.deserialize(reader, md));
                        }
                    }
                    break;
                }
                case Types.NULL:
                    break;
                default:
                    throw columnTypeError("Unknown column type ", meta, i, columnType, field);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e.getMessage(), e);
        }
    }

    /**
     * Build the exception for a column whose type cannot be converted to the
     * field's type. The message names the field's type, not the containing
     * message's type.
     */
    private SQLDataException columnTypeError(String prefix,
                                             ResultSetMetaData meta,
                                             int i,
                                             int columnType,
                                             PField<M> field) throws SQLException {
        return new SQLDataException(prefix + meta.getColumnTypeName(i) +
                                    "(" + columnType + ")" +
                                    " for " + field.getType().toString() +
                                    " field " + field.getName() + " in " +
                                    descriptor.getQualifiedName());
    }

    private static final BinarySerializer BINARY = new BinarySerializer();
    private static final JsonSerializer   JSON   = new JsonSerializer();

    private final PMessageDescriptor<M>  descriptor;
    private final Map<String, PField<M>> fieldNameMapping;
    private final String                 tableName;
}