MessageUpserter.java

/*
 * Copyright 2018-2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.jdbi.v3;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.collect.UnmodifiableList;
import net.morimekta.collect.UnmodifiableMap;
import net.morimekta.collect.UnmodifiableSet;
import net.morimekta.collect.util.SetOperations;
import net.morimekta.proto.ProtoField;
import net.morimekta.proto.ProtoMessage;
import net.morimekta.proto.jdbi.MorimektaJdbiOptions.SqlType;
import net.morimekta.proto.jdbi.v3.util.MessageFieldArgument;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.statement.Update;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.util.Objects.requireNonNull;
import static net.morimekta.collect.UnmodifiableList.asList;
import static net.morimekta.collect.UnmodifiableList.toList;
import static net.morimekta.collect.util.Pair.pairOf;
import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultColumnName;
import static net.morimekta.proto.jdbi.ProtoJdbi.getDefaultColumnType;

/**
 * Helper class to handle inserting content from messages into a table.
 * The helper will only select values from the message itself, not using
 * nested structure or anything like that.
 * <p>
 * The inserter is built in such a way that you can create the inserter
 * (even as a static field), and use it any number of times with a handle
 * to do the pre-programmed insert. The execute method is thread safe, as
 * long as none of the modification methods are called.
 *
 * <pre>{@code
 * class MyInserter {
 *     private static final MessageUpserter<MyMessage> INSERTER =
 *             new MessageUpserter.Builder<>(MyMessage.class, "my_message")
 *                     .setFields(MyMessage.UUID_FIELD_NUMBER, MyMessage.NAME_FIELD_NUMBER)
 *                     .setColumn("amount", MyMessage.VALUE_FIELD_NUMBER)  // custom column name
 *                     .onDuplicateKeyUpdateColumns("amount")
 *                     .build();
 *
 *     private final Jdbi dbi;
 *
 *     public MyInserter(Jdbi dbi) {
 *         this.dbi = dbi;
 *     }
 *
 *     int insert(MyMessage... messages) {
 *         try (Handle handle = dbi.open()) {
 *             return INSERTER.execute(handle, messages);
 *         }
 *     }
 * }
 * }</pre>
 * <p>
 * Or it can be handled in line where needed. The building process is pretty cheap,
 * so this should not be a problem unless it is called <i>a lot</i> for very small
 * messages.
 *
 * <pre>{@code
 * class MyInserter {
 *     int insert(MyMessage... messages) {
 *         try (Handle handle = dbi.open()) {
 *             return new MessageUpserter.Builder<>(MyMessage.class, "my_message")
 *                     .setFields(MyMessage.UUID_FIELD_NUMBER, MyMessage.NAME_FIELD_NUMBER)
 *                     .setColumn("amount", MyMessage.VALUE_FIELD_NUMBER)  // custom column name
 *                     .onDuplicateKeyUpdateAllColumnsExcept("uuid")
 *                     .build()
 *                     .execute(handle, messages);
 *         }
 *     }
 * }
 * }</pre>
 * <p>
 * The rules for using this is pretty simple:
 *
 * <ul>
 *     <li>
 *         All columns must be specified before any onDuplicateKey* behavior is set.
 *     </li>
 *     <li>
 *         Only one of <code>onDuplicateKeyIgnore</code> and <code>onDuplicateKeyUpdateColumns</code>
 *         can be set.
 *     </li>
 *     <li>
 *         <code>execute(...)</code> can be called any number of times, and is thread safe.
 *     </li>
 * </ul>
 */
public class MessageUpserter<M extends MessageOrBuilder> {
    private final String                  queryPrefix;
    private final String                  querySuffix;
    private final Map<String, ColumnSpec> columnMap;
    private final List<String>            columnOrder;
    private final String                  valueMarkers;

    /**
     * Create the upserter from pre-built query parts.
     *
     * @param queryPrefix The query text up to and including "VALUES ".
     * @param querySuffix The (possibly empty) "ON DUPLICATE KEY UPDATE" clause.
     * @param columnOrder The column names in insert order.
     * @param columnMap   The column spec for each column name.
     */
    private MessageUpserter(String queryPrefix,
                            String querySuffix,
                            List<String> columnOrder,
                            Map<String, ColumnSpec> columnMap) {
        this.queryPrefix = queryPrefix;
        this.querySuffix = querySuffix;
        this.columnOrder = asList(columnOrder);
        this.columnMap = UnmodifiableMap.asMap(columnMap);
        // One "(?,?,...)" group per inserted row, with one marker per column.
        this.valueMarkers = columnOrder.stream()
                                       .map(column -> "?")
                                       .collect(Collectors.joining(",", "(", ")"));
    }

    @Override
    public String toString() {
        return queryPrefix + "(...)" + querySuffix;
    }

    /**
     * Insert (or upsert) the given messages as rows, using a single statement.
     *
     * @param handle The handle to execute the statement on.
     * @param items  The messages to insert, one row per message.
     * @return The result of the update execution, or 0 if no items were given.
     * @throws NullPointerException If the handle is null.
     */
    @SafeVarargs
    public final int execute(Handle handle, M... items) {
        // The varargs array is only read (copied into a list), never stored.
        return execute(requireNonNull(handle), asList(items));
    }

    /**
     * Insert (or upsert) the given messages as rows, using a single statement.
     *
     * @param handle The handle to execute the statement on.
     * @param items  The messages to insert, one row per message.
     * @return The result of the update execution, or 0 if no items were given.
     * @throws NullPointerException If the handle or items is null.
     */
    public int execute(Handle handle, Collection<M> items) {
        requireNonNull(handle, "handle == null");
        requireNonNull(items, "items == null");
        if (items.isEmpty()) {
            return 0;
        }

        // Multi-row insert: "... VALUES (?,?), (?,?) ..." with one group per item.
        String query = items.stream()
                            .map(item -> valueMarkers)
                            .collect(Collectors.joining(", ", queryPrefix, querySuffix));
        Update update = handle.createUpdate(query);
        // Positional parameters are bound row by row, then column by column,
        // which matches the order of the markers in the query.
        int position = 0;
        for (M item : items) {
            for (String column : columnOrder) {
                ColumnSpec spec = columnMap.get(column);
                update.bind(position++, new MessageFieldArgument(item, spec.field, spec.sqlType));
            }
        }
        return update.execute();
    }

    /**
     * The message field and SQL type used to fill one column.
     */
    private static class ColumnSpec {
        final FieldDescriptor field;
        final SqlType         sqlType;

        ColumnSpec(FieldDescriptor field, SqlType sqlType) {
            this.field = field;
            this.sqlType = sqlType;
        }
    }

    /**
     * Builder for a {@link MessageUpserter}. Columns must all be specified before
     * the duplicate key behavior is set. The builder itself is not thread safe.
     *
     * @param <M> The message type.
     */
    public static class Builder<M extends MessageOrBuilder> {
        private final String                  intoTable;
        private final Map<String, ColumnSpec> columnMap;
        private final AtomicBoolean           onDuplicateIgnore;
        private final Set<String>             onDuplicateUpdate;

        private final Descriptors.Descriptor descriptor;

        /**
         * Create a message upserter builder.
         *
         * @param type      The message type class.
         * @param intoTable The table name to insert into.
         * @throws NullPointerException If the table name is null.
         */
        public Builder(Class<M> type,
                       String intoTable) {
            this.descriptor = ProtoMessage.getMessageDescriptor(type);
            this.intoTable = requireNonNull(intoTable, "intoTable == null");
            this.columnMap = new LinkedHashMap<>();
            // Sorted, so the generated update clause has a deterministic order.
            this.onDuplicateUpdate = new TreeSet<>();
            this.onDuplicateIgnore = new AtomicBoolean();
        }

        /**
         * Set all fields of the message with default column name and type.
         * <p>
         * This must be the first column specification, as it does not merge
         * with columns that are already set.
         *
         * @return The builder.
         * @throws IllegalStateException If any column is already set, if two fields
         *                               map to the same column name, or if the
         *                               duplicate key behavior is already determined.
         */
        public final Builder<M> setAll() {
            if (!columnMap.isEmpty()) {
                throw new IllegalStateException("columnMap is already set");
            }
            return setAllFieldsExcept(Collections.emptySet());
        }

        /**
         * Set all fields with default column name and type, except the given ones.
         *
         * @param except Field numbers to exclude.
         * @return The builder.
         * @throws IllegalStateException If a column name is already set, or if the
         *                               duplicate key behavior is already determined.
         */
        public final Builder<M> setAllFieldsExcept(int... except) {
            return setAllFieldsExcept(
                    IntStream.of(except)
                             .boxed()
                             .collect(UnmodifiableSet.toSet()));
        }

        /**
         * Set all fields with default column name and type, except the given ones.
         * Either all the columns are added, or none are (if an exception is thrown).
         *
         * @param except Field numbers to exclude.
         * @return The builder.
         * @throws IllegalStateException If a column name is already set, or if the
         *                               duplicate key behavior is already determined.
         */
        public final Builder<M> setAllFieldsExcept(Collection<Integer> except) {
            var excluded = UnmodifiableSet.asSet(except);
            // Collect and validate all new columns first, so a failure does not
            // leave the builder with only some of the fields set.
            Map<String, ColumnSpec> added = new LinkedHashMap<>();
            for (FieldDescriptor field : descriptor.getFields()) {
                if (excluded.contains(field.getNumber())) {
                    continue;
                }
                var name = getDefaultColumnName(field);
                if (columnMap.containsKey(name) || added.containsKey(name)) {
                    throw new IllegalStateException(
                            "Duplicate column '" + name + "' for field: " +
                            descriptor.getFullName() + "{ " + ProtoField.toString(field) + "; }");
                }
                checkDuplicateKeyBehaviorUnset(name, field);
                added.put(name, new ColumnSpec(field, getDefaultColumnType(field)));
            }
            columnMap.putAll(added);
            return this;
        }

        /**
         * Set the specific fields with default name and type.
         *
         * @param fields The field numbers to be set.
         * @return The builder.
         * @throws IllegalArgumentException If a field number is unknown, or a column is already set.
         */
        public final Builder<M> setFields(int... fields) {
            return setFields(IntStream.of(fields).boxed().collect(toList()));
        }

        /**
         * Set the specific fields with default name and type.
         *
         * @param fieldNumbers The field numbers to be set.
         * @return The builder.
         * @throws IllegalArgumentException If a field number is unknown, or a column is already set.
         * @throws IllegalStateException    If the duplicate key behavior is already determined.
         */
        public final Builder<M> setFields(Collection<Integer> fieldNumbers) {
            // Resolve every number before setting anything, so unknown numbers do not
            // leave the builder with only the recognized fields set.
            var resolved = fieldNumbers.stream()
                                       .map(num -> pairOf(num, descriptor.findFieldByNumber(num)))
                                       .collect(Collectors.toList());
            var missing = resolved.stream()
                                  .filter(pair -> pair.second == null)
                                  .map(pair -> pair.first)
                                  .collect(Collectors.toList());
            if (!missing.isEmpty()) {
                throw new IllegalArgumentException(
                        "Unrecognized field numbers " + missing + " in " + descriptor.getFullName());
            }
            for (var pair : resolved) {
                setInternal(getDefaultColumnName(pair.second),
                            pair.second,
                            getDefaultColumnType(pair.second));
            }
            return this;
        }

        /**
         * Set the specific field with specific type and default name.
         *
         * @param fieldNumber The field number to be set.
         * @param type        The field type to set as.
         * @return The builder.
         * @throws IllegalArgumentException If the field number is unknown, or the column is already set.
         */
        public final Builder<M> setField(int fieldNumber, SqlType type) {
            var field = requireField(fieldNumber);
            return setInternal(getDefaultColumnName(field), field, type);
        }

        /**
         * Set the specific field with specific name and default type.
         *
         * @param columnName  The column name to set.
         * @param fieldNumber The field to be set.
         * @return The builder.
         * @throws IllegalArgumentException If the field number is unknown, or the column is already set.
         */
        public final Builder<M> setColumn(String columnName, int fieldNumber) {
            var field = requireField(fieldNumber);
            return setInternal(columnName, field, getDefaultColumnType(field));
        }

        /**
         * Set the specific field with specific name and type.
         *
         * @param columnName  The column name to set.
         * @param fieldNumber The field to be set.
         * @param type        The SQL field type to set.
         * @return The builder.
         * @throws IllegalArgumentException If the field number is unknown, or the column is already set.
         */
        public final Builder<M> setColumn(String columnName, int fieldNumber, SqlType type) {
            return setInternal(columnName, requireField(fieldNumber), type);
        }

        /**
         * Set the specific field with specific name and type.
         *
         * @param column The column name to set.
         * @param field  The field to be set.
         * @param type   The field type to set as.
         * @return The builder.
         * @throws NullPointerException     If column or field is null.
         * @throws IllegalArgumentException If the column is already set.
         * @throws IllegalStateException    If the duplicate key behavior is already determined.
         */
        public final Builder<M> setInternal(String column, FieldDescriptor field, SqlType type) {
            requireNonNull(column, "column == null");
            requireNonNull(field, "field == null");
            if (columnMap.containsKey(column)) {
                var old = columnMap.get(column);
                throw new IllegalArgumentException(
                        "Column " + column + " already inserted, replacing '" +
                        ProtoField.toString(old.field) + "' with '" +
                        ProtoField.toString(field) + "'");
            }
            checkDuplicateKeyBehaviorUnset(column, field);
            this.columnMap.put(column, new ColumnSpec(field, type));
            return this;
        }

        /**
         * On duplicate keys update all columns set so far, except the given ones.
         *
         * @param exceptColumns The column names NOT to update.
         * @return The builder.
         * @throws IllegalStateException If the duplicate key behavior is already set to ignore.
         */
        public final Builder<M> onDuplicateKeyUpdateAllColumnsExcept(String... exceptColumns) {
            return onDuplicateKeyUpdateAllColumnsExcept(UnmodifiableList.asList(exceptColumns));
        }

        /**
         * On duplicate keys update all columns set so far, except the given ones.
         *
         * @param exceptColumns The column names NOT to update.
         * @return The builder.
         * @throws IllegalStateException If the duplicate key behavior is already set to ignore.
         */
        public final Builder<M> onDuplicateKeyUpdateAllColumnsExcept(Collection<String> exceptColumns) {
            return onDuplicateKeyUpdateColumns(
                    SetOperations.subtract(
                            columnMap.keySet(),
                            UnmodifiableSet.asSet(exceptColumns)));
        }

        /**
         * On duplicate keys update the given columns.
         *
         * @param columns The column names to update.
         * @return The builder.
         * @throws IllegalStateException If the duplicate key behavior is already set to ignore.
         */
        public final Builder<M> onDuplicateKeyUpdateColumns(String... columns) {
            return onDuplicateKeyUpdateColumns(UnmodifiableList.asList(columns));
        }

        /**
         * On duplicate keys update the given columns.
         *
         * @param columns The column names to update.
         * @return The builder.
         * @throws IllegalStateException If the duplicate key behavior is already set to ignore.
         */
        public final Builder<M> onDuplicateKeyUpdateColumns(Collection<String> columns) {
            if (onDuplicateIgnore.get()) {
                throw new IllegalStateException("Duplicate key behavior already set to ignore");
            }
            onDuplicateUpdate.addAll(columns);
            return this;
        }

        /**
         * On duplicate keys ignore updates (uses "INSERT IGNORE").
         *
         * @return The builder.
         * @throws IllegalStateException If the duplicate key behavior is already set to update.
         */
        public final Builder<M> onDuplicateKeyIgnore() {
            if (!onDuplicateUpdate.isEmpty()) {
                throw new IllegalStateException("Duplicate key behavior already set to update");
            }
            onDuplicateIgnore.set(true);
            return this;
        }

        /**
         * Build the upserter. The result is a snapshot; later changes to this
         * builder do not affect it.
         *
         * @return The final built inserter.
         * @throws IllegalStateException If no columns are set.
         */
        public MessageUpserter<M> build() {
            if (columnMap.isEmpty()) {
                throw new IllegalStateException("No columns inserted");
            }
            List<String> columnOrder = new ArrayList<>(columnMap.keySet());

            String prefix = (onDuplicateIgnore.get() ? "INSERT IGNORE INTO " : "INSERT INTO ") +
                            intoTable +
                            columnOrder.stream()
                                       .map(column -> quoteIdentifier(column))
                                       .collect(Collectors.joining(", ", " (", ") VALUES "));

            String suffix = "";
            if (!onDuplicateUpdate.isEmpty()) {
                suffix = onDuplicateUpdate.stream()
                                          .map(column -> {
                                              var quoted = quoteIdentifier(column);
                                              return " " + quoted + " = VALUES(" + quoted + ")";
                                          })
                                          .collect(Collectors.joining(",", " ON DUPLICATE KEY UPDATE", ""));
            }

            // Copy the column map so the built instance is independent of this builder.
            return new MessageUpserter<>(prefix,
                                         suffix,
                                         columnOrder,
                                         new LinkedHashMap<>(columnMap));
        }

        /**
         * Look up a field by number.
         *
         * @param fieldNumber The field number.
         * @return The field descriptor, never null.
         * @throws IllegalArgumentException If the message has no such field.
         */
        private FieldDescriptor requireField(int fieldNumber) {
            var field = descriptor.findFieldByNumber(fieldNumber);
            if (field == null) {
                throw new IllegalArgumentException(
                        "Unrecognized field number " + fieldNumber + " in " + descriptor.getFullName());
            }
            return field;
        }

        /**
         * Ensure columns may still be added, i.e. no duplicate key behavior is set yet.
         *
         * @param column The column being added.
         * @param field  The field being added.
         * @throws IllegalStateException If the duplicate key behavior is already determined.
         */
        private void checkDuplicateKeyBehaviorUnset(String column, FieldDescriptor field) {
            if (onDuplicateIgnore.get() || !onDuplicateUpdate.isEmpty()) {
                throw new IllegalStateException(
                        "Duplicate key behavior already determined when specifying column '" +
                        column + "': " + ProtoField.toString(field));
            }
        }

        /**
         * Quote an identifier with backticks, doubling any embedded backtick
         * as required by MySQL identifier quoting.
         *
         * @param name The identifier name.
         * @return The quoted identifier.
         */
        private static String quoteIdentifier(String name) {
            return "`" + name.replace("`", "``") + "`";
        }
    }
}