MessageNamedArgumentFinder.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.jdbi.v3.util;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.proto.jdbi.MorimektaJdbiOptions.SqlType;
import net.morimekta.proto.jdbi.ProtoJdbi;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.argument.NamedArgumentFinder;
import org.jdbi.v3.core.argument.NullArgument;
import org.jdbi.v3.core.statement.StatementContext;

import java.util.Locale;
import java.util.Optional;

import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultColumnType;
import static net.morimekta.strings.StringUtil.isNotEmpty;
import static net.morimekta.strings.StringUtil.isNullOrEmpty;

/**
 * A {@link NamedArgumentFinder} implementation that uses a message
 * and finds values based on the protobuf declared field names. This
 * supports chained calls to any depth as long as each level is a
 * single (non-repeated) message field.
 * <p>
 * Fields can be looked up by the argument index, e.g. {@code :0.my_field},
 * or, when a name is given, by the name prefix, e.g. {@code :x.my_field}.
 * The index or name alone ({@code :0} or {@code :x}) refers to the message
 * itself. A SQL type can be appended after a pipe, e.g.
 * {@code :x.my_field|varchar}.
 * <p>
 * Can be combined with {@link TypedParameterSqlParser} to parse the
 * SQL type with the named parameters.
 */
public class MessageNamedArgumentFinder implements NamedArgumentFinder {
    private final Message message;
    private final String  name;
    private final String  index;
    private final String  fieldPrefix;
    private final String  indexPrefix;
    private final SqlType messageType;

    /**
     * Create a named argument finder.
     *
     * @param name        Optional prefix name. E.g. "x" will make for lookup
     *                    tags like ":x.my_field".
     * @param index       Argument index. E.g. 0 will make for lookup tags
     *                    like ":0.my_field".
     * @param messageType The SQL type to serialize the message itself.
     * @param message     The message to look up fields in.
     */
    public MessageNamedArgumentFinder(String name,
                                      int index,
                                      SqlType messageType,
                                      Message message) {
        this.message = message;
        this.name = name;
        this.fieldPrefix = isNullOrEmpty(name) ? "" : name + ".";
        this.index = String.valueOf(index);
        this.indexPrefix = index + ".";
        this.messageType = messageType;
    }

    /**
     * Find the argument matching the given parameter identifier.
     * <p>
     * The identifier may be suffixed with {@code |sql_type} to override the
     * SQL type used for binding. The identifier is matched, in order, against
     * the argument index, the index prefix, and (if a name is set) the name
     * and the name prefix. Anything after the prefix is resolved as a
     * dot-separated path of field names in the message.
     *
     * @param identifier The parameter identifier, without the leading colon.
     * @param ctx        The statement context.
     * @return The argument if this finder handles the identifier, otherwise empty.
     * @throws IllegalArgumentException If the SQL type suffix is unknown, or an
     *                                  intermediate path element is not a
     *                                  singular message field.
     */
    @Override
    public Optional<Argument> find(String identifier, StatementContext ctx) {
        var specType = SqlType.UNSPECIFIED;  // aka default.
        var pipe     = identifier.indexOf('|');
        if (pipe > 0) {
            var typeName = identifier.substring(pipe + 1);
            var val      = SqlType.getDescriptor().findValueByName(typeName.toUpperCase(Locale.US));
            if (val == null) {
                throw new IllegalArgumentException("Unknown SQL type: " + typeName + " for :" + identifier);
            }
            identifier = identifier.substring(0, pipe);
            specType = SqlType.valueOf(val);
        }

        if (identifier.equals(index)) {
            return Optional.of(new MessageArgument(message, getMessageSqlType(specType)));
        } else if (identifier.startsWith(indexPrefix)) {
            // Strip the index prefix (not the name prefix) so the remaining
            // identifier is the field path.
            identifier = identifier.substring(indexPrefix.length());
        } else if (isNotEmpty(name)) {
            if (identifier.equals(name)) {
                return Optional.of(new MessageArgument(message, getMessageSqlType(specType)));
            } else if (identifier.startsWith(fieldPrefix)) {
                identifier = identifier.substring(fieldPrefix.length());
            } else {
                // A named finder only handles its own name prefix (or the index,
                // checked above); anything else belongs to another argument.
                return Optional.empty();
            }
        }

        String[]               parts          = identifier.split("\\.", Byte.MAX_VALUE);
        MessageOrBuilder       leaf           = message;
        Descriptors.Descriptor leafDescriptor = message.getDescriptorForType();

        // Walk every path element except the last, descending into sub-messages.
        for (int i = 0; i < parts.length - 1; ++i) {
            var part  = parts[i];
            var field = leafDescriptor.findFieldByName(part);
            if (field == null) {
                return Optional.empty();
            }
            if (field.getType() != Descriptors.FieldDescriptor.Type.MESSAGE) {
                throw new IllegalArgumentException(
                        "Not a message field: " +
                        part +
                        " of " +
                        leafDescriptor.getFullName() +
                        ": " +
                        field.getType());
            }
            if (field.isRepeated()) {
                throw new IllegalArgumentException(
                        "Repeated message field: " + part + " of " + leafDescriptor.getFullName());
            }
            leafDescriptor = field.getMessageType();
            // NOTE(review): leaf only becomes null here if it already was null;
            // confirm whether unset sub-messages were meant to bind as NULL.
            if (leaf != null) {
                leaf = (Message) leaf.getField(field);
            }
        }
        String leafName = parts[parts.length - 1];
        var    field    = leafDescriptor.findFieldByName(leafName);
        if (field != null) {
            var sqlType = getFieldSqlType(field, specType);
            if (leaf != null) {
                return Optional.of(new MessageFieldArgument(leaf, field, sqlType));
            }
            return Optional.of(new NullArgument(ProtoJdbi.getColumnType(sqlType)));
        }
        return Optional.empty();
    }

    /**
     * @param specType The SQL type given in the identifier, or UNSPECIFIED.
     * @return The SQL type to use when binding the whole message.
     */
    private SqlType getMessageSqlType(SqlType specType) {
        return specType == SqlType.UNSPECIFIED ? messageType : specType;
    }

    /**
     * @param field    The field to be bound.
     * @param specType The SQL type given in the identifier, or UNSPECIFIED.
     * @return The SQL type to use when binding the field.
     */
    private SqlType getFieldSqlType(Descriptors.FieldDescriptor field, SqlType specType) {
        if (specType != SqlType.UNSPECIFIED) {
            return specType;
        }
        return getDefaultColumnType(field);
    }
}