MessageNamedArgumentFinder.java
/*
* Copyright 2018-2019 Providence Authors
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.proto.jdbi.v3.util;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.proto.jdbi.MorimektaJdbiOptions.SqlType;
import net.morimekta.proto.jdbi.ProtoJdbi;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.argument.NamedArgumentFinder;
import org.jdbi.v3.core.argument.NullArgument;
import org.jdbi.v3.core.statement.StatementContext;
import java.util.Locale;
import java.util.Optional;
import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultColumnType;
import static net.morimekta.strings.StringUtil.isNotEmpty;
import static net.morimekta.strings.StringUtil.isNullOrEmpty;
/**
 * A {@link NamedArgumentFinder} implementation that uses a message
 * and finds values based on the protobuf declared field names. This
 * supports chained calls to any depth as long as each level is a
 * single (non-repeated) message field.
 * <p>
 * Accepted identifier forms, where {@code name} and {@code index} are
 * the constructor arguments:
 * <ul>
 *     <li>{@code :index} or {@code :name} - the message itself.</li>
 *     <li>{@code :index.a.b} or {@code :name.a.b} - a (nested) field.</li>
 *     <li>{@code :a.b} - a (nested) field, only when no name is given.</li>
 * </ul>
 * Any identifier may be suffixed with {@code |sql_type} to override the
 * SQL type used, matched case-insensitively against {@link SqlType}
 * value names.
 * <p>
 * Can be combined with {@link TypedParameterSqlParser} to parse the
 * SQL type with the named parameters.
 */
public class MessageNamedArgumentFinder implements NamedArgumentFinder {
    private final Message message;
    private final String name;
    private final String index;
    private final String fieldPrefix;
    private final String indexPrefix;
    private final SqlType messageType;

    /**
     * Create a named argument finder.
     *
     * @param name        Optional prefix name. E.g. "x" will make for lookup
     *                    tags like ":x.my_field".
     * @param index       Argument index. E.g. 0 will make for lookup tags
     *                    like ":0" and ":0.my_field".
     * @param messageType The SQL type to serialize the message itself.
     * @param message     The message to look up fields in.
     */
    public MessageNamedArgumentFinder(String name,
                                      int index,
                                      SqlType messageType,
                                      Message message) {
        this.message = message;
        this.name = name;
        this.fieldPrefix = isNullOrEmpty(name) ? "" : name + ".";
        this.index = String.valueOf(index);
        this.indexPrefix = index + ".";
        this.messageType = messageType;
    }

    /**
     * Find the argument for the given identifier.
     *
     * @param identifier The parameter identifier, without the leading ':',
     *                   optionally with a {@code |sql_type} suffix.
     * @param ctx        The statement context (unused).
     * @return The argument, or empty if the identifier is not handled by
     *         this finder.
     * @throws IllegalArgumentException If the SQL type suffix is unknown, or
     *                                  an intermediate path element is not a
     *                                  single message field.
     */
    @Override
    public Optional<Argument> find(String identifier, StatementContext ctx) {
        var specType = SqlType.UNSPECIFIED; // aka default.
        var pipe = identifier.indexOf('|');
        if (pipe > 0) {
            var typeName = identifier.substring(pipe + 1);
            var val = SqlType.getDescriptor().findValueByName(typeName.toUpperCase(Locale.US));
            if (val == null) {
                throw new IllegalArgumentException("Unknown SQL type: " + typeName + " for :" + identifier);
            }
            identifier = identifier.substring(0, pipe);
            specType = SqlType.valueOf(val);
        }

        // Field names can never start with a digit, so index-based lookups
        // cannot clash with field paths, and are accepted whether or not a
        // name prefix is set.
        String path;
        if (identifier.equals(index)) {
            return Optional.of(new MessageArgument(message, getMessageSqlType(specType)));
        } else if (identifier.startsWith(indexPrefix)) {
            path = identifier.substring(indexPrefix.length());
        } else if (isNotEmpty(name)) {
            if (identifier.equals(name)) {
                return Optional.of(new MessageArgument(message, getMessageSqlType(specType)));
            }
            if (!identifier.startsWith(fieldPrefix)) {
                // Named finder, and identifier belongs to something else.
                return Optional.empty();
            }
            path = identifier.substring(fieldPrefix.length());
        } else {
            path = identifier;
        }
        return findField(path, specType);
    }

    /**
     * Resolve a dot-separated field path relative to the message.
     *
     * @param path     Field path, e.g. "a.b.c".
     * @param specType Explicitly requested SQL type, or UNSPECIFIED.
     * @return The field argument, a typed NULL argument if an intermediate
     *         message field is unset, or empty if a field is not found.
     */
    private Optional<Argument> findField(String path, SqlType specType) {
        String[] parts = path.split("\\.", -1);
        MessageOrBuilder leaf = message;
        Descriptors.Descriptor leafDescriptor = message.getDescriptorForType();
        for (int i = 0; i < parts.length - 1; ++i) {
            var field = leafDescriptor.findFieldByName(parts[i]);
            if (field == null) {
                return Optional.empty();
            }
            if (field.getJavaType() != Descriptors.FieldDescriptor.JavaType.MESSAGE) {
                throw new IllegalArgumentException(
                        "Not a message field: " +
                        parts[i] +
                        " of " +
                        leafDescriptor.getFullName() +
                        ": " +
                        field.getType());
            }
            if (field.isRepeated()) {
                throw new IllegalArgumentException(
                        "Repeated message field: " + parts[i] + " of " + leafDescriptor.getFullName());
            }
            leafDescriptor = field.getMessageType();
            if (leaf != null) {
                // getField() returns the default instance for unset message
                // fields, so presence must be checked explicitly to bind NULL.
                leaf = leaf.hasField(field) ? (Message) leaf.getField(field) : null;
            }
        }
        var field = leafDescriptor.findFieldByName(parts[parts.length - 1]);
        if (field == null) {
            return Optional.empty();
        }
        var sqlType = getFieldSqlType(field, specType);
        if (leaf == null) {
            return Optional.of(new NullArgument(ProtoJdbi.getColumnType(sqlType)));
        }
        return Optional.of(new MessageFieldArgument(leaf, field, sqlType));
    }

    /** SQL type for the whole message: the explicit type, or the configured message type. */
    private SqlType getMessageSqlType(SqlType specType) {
        return specType == SqlType.UNSPECIFIED ? messageType : specType;
    }

    /** SQL type for a field: the explicit type, or the field's default column type. */
    private SqlType getFieldSqlType(Descriptors.FieldDescriptor field, SqlType specType) {
        if (specType != SqlType.UNSPECIFIED) return specType;
        return getDefaultColumnType(field);
    }
}