MessageUpserter.java
/*
* Copyright 2018-2019 Providence Authors
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.proto.jdbi.v3;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.collect.UnmodifiableList;
import net.morimekta.collect.UnmodifiableMap;
import net.morimekta.collect.UnmodifiableSet;
import net.morimekta.collect.util.SetOperations;
import net.morimekta.proto.ProtoField;
import net.morimekta.proto.ProtoMessage;
import net.morimekta.proto.jdbi.MorimektaJdbiOptions.SqlType;
import net.morimekta.proto.jdbi.v3.util.MessageFieldArgument;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.statement.Update;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static java.util.Objects.requireNonNull;
import static net.morimekta.collect.UnmodifiableList.asList;
import static net.morimekta.collect.UnmodifiableList.toList;
import static net.morimekta.collect.util.Pair.pairOf;
import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultColumnName;
import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultColumnType;
/**
* Helper class to handle inserting content from messages into a table.
* The helper will only select values from the message itself, not using
* nested structure or anything like that.
* <p>
* The inserter is built in such a way that you can create the inserter
* (even as a static field), and use it any number of times with a handle
* to do the pre-programmed insert. The execute method is thread safe, as
* long as none of the modification methods are called.
*
* <pre>{@code
* class MyInserter {
* private static final MessageUpserter<MyMessage> INSERTER =
* new MessageUpserter.Builder<>(MyMessage.class, "my_message")
* .setFields(MyMessage.UUID_FIELD_NUMBER, MyMessage.NAME_FIELD_NUMBER)
* .setColumn("amount", MyMessage.VALUE_FIELD_NUMBER, amountType) // amountType: explicit SqlType, e.g. DOUBLE -> INTEGER
* .onDuplicateKeyUpdateColumns("amount")
* .build();
*
* private final Jdbi dbi;
*
* public MyInserter(Jdbi dbi) {
* this.dbi = dbi;
* }
*
* int insert(MyMessage... messages) {
* try (Handle handle = dbi.open()) {
* return INSERTER.execute(handle, messages);
* }
* }
* }
* }</pre>
* <p>
* Or it can be handled in line where needed. The building process is pretty cheap,
* so this should not be a problem unless it is called <i>a lot</i> for very small
* messages.
*
* <pre>{@code
* class MyInserter {
* int insert(MyMessage... messages) {
* try (Handle handle = dbi.open()) {
* return new MessageUpserter.Builder<MyMessage>(MyMessage.class, "my_message")
* .setColumn("id", MyMessage.UUID_FIELD_NUMBER)
* .setFields(MyMessage.NAME_FIELD_NUMBER)
* .setColumn("amount", MyMessage.VALUE_FIELD_NUMBER, amountType) // amountType: explicit SqlType, e.g. DOUBLE -> INTEGER
* .onDuplicateKeyUpdateAllColumnsExcept("id")
* .build()
* .execute(handle, messages);
* }
* }
* }
* }</pre>
* <p>
* The rules for using this are pretty simple:
*
* <ul>
* <li>
* All fields set must be specified before onDuplicateKey* behavior.
* </li>
* <li>
* Only one of <code>onDuplicateKeyIgnore</code> and <code>onDuplicateKeyUpdate</code>
* can be set.
* </li>
* <li>
* <code>execute(...)</code> can be called any number of times, and is thread safe.
* </li>
* </ul>
*/
public class MessageUpserter<M extends MessageOrBuilder> {
    /** Statement text up to and including {@code VALUES }. */
    private final String queryPrefix;
    /** Statement text following the value tuples; empty unless update columns were configured. */
    private final String querySuffix;
    /** Column name to the message field and SQL type it is bound from. */
    private final Map<String, ColumnSpec> columnMap;
    /** Column names in the order they appear in the statement's column list. */
    private final List<String> columnOrder;
    /** One parenthesized tuple of {@code ?} markers, one marker per column. */
    private final String valueMarkers;

    /**
     * Create the upserter. Only called from {@link Builder#build()}.
     *
     * @param queryPrefix Statement text before the value tuples.
     * @param querySuffix Statement text after the value tuples.
     * @param columnOrder Column names in statement order.
     * @param columnMap   Column name to column spec.
     */
    private MessageUpserter(String queryPrefix,
                            String querySuffix,
                            List<String> columnOrder,
                            Map<String, ColumnSpec> columnMap) {
        this.queryPrefix = queryPrefix;
        this.querySuffix = querySuffix;
        this.columnOrder = asList(columnOrder);
        this.columnMap = UnmodifiableMap.asMap(columnMap);
        this.valueMarkers = columnOrder.stream()
                                       .map(k -> "?")
                                       .collect(Collectors.joining(",", "(", ")"));
    }

    /**
     * @return The statement with the value tuples abbreviated as {@code (...)}.
     */
    @Override
    public String toString() {
        return queryPrefix + "(...)" + querySuffix;
    }

    /**
     * Execute the statement for the given messages.
     *
     * @param handle The handle to create the update on.
     * @param items  The messages to write, one value tuple per message.
     * @return The value returned from executing the update, or 0 if no items were given.
     * @throws NullPointerException If the handle is null.
     */
    @SafeVarargs
    public final int execute(Handle handle, M... items) {
        return execute(requireNonNull(handle, "handle"), asList(items));
    }

    /**
     * Execute the statement for the given messages.
     *
     * @param handle The handle to create the update on.
     * @param items  The messages to write, one value tuple per message.
     * @return The value returned from executing the update, or 0 if no items were given.
     * @throws NullPointerException If the handle or the items collection is null.
     */
    public int execute(Handle handle, Collection<M> items) {
        // Validate up front so both overloads reject a null handle, even for empty input.
        requireNonNull(handle, "handle");
        requireNonNull(items, "items");
        if (items.isEmpty()) {
            return 0;
        }

        // One "(?,?,...)" tuple per item between the fixed prefix and suffix.
        String query = queryPrefix
                       + String.join(", ", Collections.nCopies(items.size(), valueMarkers))
                       + querySuffix;
        Update update = handle.createUpdate(query);

        // Positional binding: the offset runs across all tuples in item order,
        // and within each tuple in column order.
        int offset = 0;
        for (MessageOrBuilder item : items) {
            for (String column : columnOrder) {
                ColumnSpec spec = columnMap.get(column);
                update.bind(offset++, new MessageFieldArgument(item, spec.field, spec.sqlType));
            }
        }
        return update.execute();
    }

    /** The message field and SQL type backing a single column. */
    private static final class ColumnSpec {
        final FieldDescriptor field;
        final SqlType sqlType;

        ColumnSpec(FieldDescriptor field, SqlType sqlType) {
            this.field = field;
            this.sqlType = sqlType;
        }
    }

    /**
     * Builder for the upserter. Columns are kept in the order they were added.
     *
     * @param <M> The message type.
     */
    public static class Builder<M extends MessageOrBuilder> {
        private final String intoTable;
        private final Map<String, ColumnSpec> columnMap;
        private final AtomicBoolean onDuplicateIgnore;
        private final Set<String> onDuplicateUpdate;
        private final Descriptors.Descriptor descriptor;

        /**
         * Create a message inserter builder.
         *
         * @param type      The message type class.
         * @param intoTable The table name to insert into.
         * @throws NullPointerException If the table name is null.
         */
        public Builder(Class<M> type,
                       String intoTable) {
            this.descriptor = ProtoMessage.getMessageDescriptor(type);
            // Fail here rather than building a statement containing the text "null".
            this.intoTable = requireNonNull(intoTable, "intoTable");
            this.columnMap = new LinkedHashMap<>();
            this.onDuplicateUpdate = new TreeSet<>();
            this.onDuplicateIgnore = new AtomicBoolean();
        }

        /**
         * Set all fields of the message with default column name and type.
         * No column may have been set before this is called.
         *
         * @return The builder.
         * @throws IllegalStateException If any column is already set, or if two
         *                               fields resolve to the same default column name.
         */
        public final Builder<M> setAll() {
            if (!columnMap.isEmpty()) {
                throw new IllegalStateException("columnMap is already set");
            }
            // Delegate so that colliding default column names are rejected
            // instead of silently overwriting each other.
            return setAllFieldsExcept(Collections.<Integer>emptySet());
        }

        /**
         * Set all fields with defaults.
         *
         * @param except Field numbers to exclude.
         * @return The builder.
         * @throws IllegalStateException If a default column name is already set.
         */
        public final Builder<M> setAllFieldsExcept(int... except) {
            return setAllFieldsExcept(
                    IntStream.of(except)
                             .boxed()
                             .collect(UnmodifiableSet.toSet()));
        }

        /**
         * Set all fields with defaults.
         * <p>
         * Note: unlike {@link #setInternal(String, FieldDescriptor, SqlType)} this
         * does not check whether duplicate key behavior has already been chosen.
         *
         * @param except Field numbers to exclude.
         * @return The builder.
         * @throws IllegalStateException If a default column name is already set.
         */
        public final Builder<M> setAllFieldsExcept(Collection<Integer> except) {
            var excluded = UnmodifiableSet.asSet(except);
            for (FieldDescriptor field : descriptor.getFields()) {
                if (excluded.contains(field.getNumber())) {
                    continue;
                }
                var name = getDefaultColumnName(field);
                if (columnMap.containsKey(name)) {
                    throw new IllegalStateException(
                            "Duplicate column '" + name + "' for field: " +
                            descriptor.getFullName() + "{ " + ProtoField.toString(field) + "; }");
                }
                columnMap.put(name, new ColumnSpec(field, getDefaultColumnType(field)));
            }
            return this;
        }

        /**
         * Set the specific fields with default name and type.
         *
         * @param fields The field numbers to be set.
         * @return The builder.
         * @throws IllegalArgumentException If any field number is not part of the message.
         */
        public final Builder<M> setFields(int... fields) {
            return setFields(IntStream.of(fields).boxed().collect(toList()));
        }

        /**
         * Set the specific fields with default name and type. Unknown field
         * numbers are rejected before any column is added.
         *
         * @param fieldNumbers The field numbers to be set.
         * @return The builder.
         * @throws IllegalArgumentException If any field number is not part of the message.
         */
        public final Builder<M> setFields(Collection<Integer> fieldNumbers) {
            // Resolve everything first; mutating the builder while streaming would
            // leave it half-updated when an unknown number is reported.
            var resolved = fieldNumbers
                    .stream()
                    .map(num -> pairOf(num, descriptor.findFieldByNumber(num)))
                    .collect(Collectors.toList());
            var missing = resolved
                    .stream()
                    .filter(pair -> pair.second == null)
                    .map(pair -> pair.first)
                    .collect(Collectors.toList());
            if (!missing.isEmpty()) {
                throw new IllegalArgumentException(
                        "Unrecognized field numbers " + missing + " in " + descriptor.getFullName());
            }
            for (var pair : resolved) {
                setInternal(getDefaultColumnName(pair.second),
                            pair.second,
                            getDefaultColumnType(pair.second));
            }
            return this;
        }

        /**
         * Set the specific field with specific type and default name.
         *
         * @param fieldNumber The field number to be set.
         * @param type        The field type to set as.
         * @return The builder.
         * @throws IllegalArgumentException If the field number is not part of the message.
         */
        public final Builder<M> setField(int fieldNumber, SqlType type) {
            var field = requireField(fieldNumber);
            return setInternal(getDefaultColumnName(field), field, type);
        }

        /**
         * Set the specific field with name and default type.
         *
         * @param columnName  The column name to set.
         * @param fieldNumber The field to be set.
         * @return The builder.
         * @throws IllegalArgumentException If the field number is not part of the message.
         */
        public final Builder<M> setColumn(String columnName, int fieldNumber) {
            var field = requireField(fieldNumber);
            return setInternal(columnName, field, getDefaultColumnType(field));
        }

        /**
         * Set the specific field with specific name and type.
         *
         * @param columnName  The column name to set.
         * @param fieldNumber The field to be set.
         * @param type        The SQL field type to set.
         * @return The builder.
         * @throws IllegalArgumentException If the field number is not part of the message.
         */
        public final Builder<M> setColumn(String columnName, int fieldNumber, SqlType type) {
            return setInternal(columnName, requireField(fieldNumber), type);
        }

        /**
         * Set the specific field with specific name and type.
         *
         * @param column The column name to set.
         * @param field  The field to be set.
         * @param type   The field type to set as.
         * @return The builder.
         * @throws IllegalArgumentException If the column is already set.
         * @throws IllegalStateException    If duplicate key behavior has already been chosen.
         */
        public final Builder<M> setInternal(String column, FieldDescriptor field, SqlType type) {
            var old = columnMap.get(column);
            if (old != null) {
                throw new IllegalArgumentException(
                        "Column " + column + " already inserted, replacing '" +
                        ProtoField.toString(old.field) + "' with '" +
                        ProtoField.toString(field) + "'");
            }
            if (onDuplicateIgnore.get() || !onDuplicateUpdate.isEmpty()) {
                throw new IllegalStateException(
                        "Duplicate key behavior already determined when specifying column '" +
                        column + "': " + ProtoField.toString(field));
            }
            columnMap.put(column, new ColumnSpec(field, type));
            return this;
        }

        /**
         * On duplicate keys update all columns set so far except the given ones.
         *
         * @param exceptColumns The column names NOT to update.
         * @return The builder.
         * @throws IllegalStateException If duplicate keys are already set to be ignored.
         */
        public final Builder<M> onDuplicateKeyUpdateAllColumnsExcept(String... exceptColumns) {
            return onDuplicateKeyUpdateAllColumnsExcept(UnmodifiableList.asList(exceptColumns));
        }

        /**
         * On duplicate keys update all columns set so far except the given ones.
         *
         * @param exceptColumns The column names NOT to update.
         * @return The builder.
         * @throws IllegalStateException If duplicate keys are already set to be ignored.
         */
        public final Builder<M> onDuplicateKeyUpdateAllColumnsExcept(Collection<String> exceptColumns) {
            return onDuplicateKeyUpdateColumns(
                    SetOperations.subtract(
                            columnMap.keySet(),
                            UnmodifiableSet.asSet(exceptColumns)));
        }

        /**
         * On duplicate keys update the given columns.
         *
         * @param columns The column names to update.
         * @return The builder.
         * @throws IllegalStateException If duplicate keys are already set to be ignored.
         */
        public final Builder<M> onDuplicateKeyUpdateColumns(String... columns) {
            return onDuplicateKeyUpdateColumns(asList(columns));
        }

        /**
         * On duplicate keys update the given columns.
         *
         * @param columns The column names to update.
         * @return The builder.
         * @throws IllegalStateException If duplicate keys are already set to be ignored.
         */
        public final Builder<M> onDuplicateKeyUpdateColumns(Collection<String> columns) {
            if (onDuplicateIgnore.get()) {
                throw new IllegalStateException("Duplicate key behavior already set to ignore");
            }
            onDuplicateUpdate.addAll(columns);
            return this;
        }

        /**
         * On duplicate keys ignore updates.
         *
         * @return The builder.
         * @throws IllegalStateException If columns to update on duplicate keys are already set.
         */
        public final Builder<M> onDuplicateKeyIgnore() {
            if (!onDuplicateUpdate.isEmpty()) {
                throw new IllegalStateException("Duplicate key behavior already set to update");
            }
            onDuplicateIgnore.set(true);
            return this;
        }

        /**
         * @return The final built inserter.
         * @throws IllegalStateException If no columns have been set.
         */
        public MessageUpserter<M> build() {
            if (columnMap.isEmpty()) {
                throw new IllegalStateException("No columns inserted");
            }
            List<String> columnOrder = new ArrayList<>(columnMap.keySet());

            String quotedColumns = columnOrder.stream()
                                              .map(col -> "`" + col + "`")
                                              .collect(Collectors.joining(", "));
            String prefix = "INSERT "
                            + (onDuplicateIgnore.get() ? "IGNORE " : "")
                            + "INTO " + intoTable
                            + " (" + quotedColumns + ") VALUES ";

            // Each assignment carries its own leading space, so the separator is a bare comma.
            String suffix = onDuplicateUpdate.isEmpty()
                            ? ""
                            : onDuplicateUpdate.stream()
                                               .map(col -> " `" + col + "` = VALUES(`" + col + "`)")
                                               .collect(Collectors.joining(",", " ON DUPLICATE KEY UPDATE", ""));

            return new MessageUpserter<>(prefix, suffix, columnOrder, columnMap);
        }

        /** Look up a field by number, rejecting numbers that are not part of the message. */
        private FieldDescriptor requireField(int fieldNumber) {
            var field = descriptor.findFieldByNumber(fieldNumber);
            if (field == null) {
                throw new IllegalArgumentException(
                        "Unrecognized field number " + fieldNumber + " in " + descriptor.getFullName());
            }
            return field;
        }
    }
}