BinaryInputStream.java

/*
 * Copyright (c) 2016, Stein Eldar Johnsen
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.io;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

/**
 * IO-optimized binary reader. This is somewhat similar to the native java {@link java.io.ObjectInputStream}, but
 * with some specific properties.
 *
 * <ul>
 *     <li>
 *         The stream never reads more from the enveloped input
 *         stream than is consumed. This differs from {@link java.io.ObjectInputStream}
 *         which is also buffering bytes.
 *     </li>
 *     <li>
 *         Has methods designed to control the endian-ness of the
 *         numbers read. See {@link BigEndianBinaryInputStream}
 *         and {@link LittleEndianBinaryInputStream} for the two main
 *         variants.
 *     </li>
 *     <li>
 *         Value readers come in two flavours: {@code expect*} methods,
 *         which fail with an {@link IOException} if the entire value
 *         cannot be read, and {@code read*} methods, which return a
 *         default value (0) if the stream is already at its end.
 *     </li>
 * </ul>
 */
public abstract class BinaryInputStream
        extends InputStream {
    /** The wrapped stream; every read is delegated to it directly, without buffering. */
    private final InputStream in;

    /**
     * Create a binary input stream reading from the given input stream.
     *
     * @param in The input stream to read from.
     */
    public BinaryInputStream(InputStream in) {
        this.in = in;
    }

    /**
     * Read a single byte.
     *
     * @return the byte value, or -1 if end of stream.
     * @throws IOException if unable to read from stream.
     */
    @Override
    public int read() throws IOException {
        return in.read();
    }

    /**
     * Read binary data from stream. Unlike the plain {@link InputStream#read(byte[])},
     * this keeps reading until the buffer is full or the end of stream is reached.
     *
     * @param out The output buffer to read into.
     * @return The number of bytes read, 0 if {@code out} is empty, or -1 if the
     *         end of stream was reached before any byte could be read.
     * @throws IOException if unable to read from stream.
     */
    @Override
    public int read(byte[] out) throws IOException {
        return readInto(out, 0, out.length);
    }

    /**
     * Read binary data from stream. Unlike the plain {@link InputStream#read(byte[], int, int)},
     * this keeps reading until {@code len} bytes are read or the end of stream is reached.
     *
     * @param out The output buffer to read into.
     * @param off Offset in out array to write to.
     * @param len Number of bytes to read.
     * @return The number of bytes read, 0 if {@code len} is 0, or -1 if the
     *         end of stream was reached before any byte could be read.
     * @throws IllegalArgumentException if off or len is negative, or the range
     *         {@code [off, off + len)} does not fit in {@code out}.
     * @throws IOException if unable to read from stream.
     */
    @Override
    public int read(byte[] out, final int off, final int len) throws IOException {
        // Written as a subtraction so that a large len cannot overflow past the check.
        if (off < 0 || len < 0 || len > out.length - off) {
            throw new IllegalArgumentException(String.format(
                    "Illegal arguments for read: byte[%d], off:%d, len:%d",
                    out.length, off, len));
        }
        return readInto(out, off, len);
    }

    /**
     * Repeatedly read from the wrapped stream into {@code out[off, off + len)} until
     * the range is filled, or the wrapped stream returns end of stream (or a zero-length read).
     * Arguments are assumed to be already validated.
     *
     * @return The number of bytes read, or -1 if the wrapped stream reported end of
     *         stream before any byte was read.
     */
    private int readInto(byte[] out, final int off, final int len) throws IOException {
        final int end = off + len;
        int pos = off;
        int n = 0;
        // Each chunk is appended at pos, asking only for what is still missing,
        // so no more than len bytes are ever consumed from the wrapped stream.
        while (pos < end && (n = in.read(out, pos, end - pos)) > 0) {
            pos += n;
        }
        if (pos == off && n < 0) {
            return -1;
        }
        return pos - off;
    }

    /** Closes the wrapped stream. */
    @Override
    public void close() throws IOException {
        in.close();
    }

    /** Delegates to the wrapped stream's {@link InputStream#available()}. */
    @Override
    public int available() throws IOException {
        return in.available();
    }

    /** Delegates to the wrapped stream's {@link InputStream#markSupported()}. */
    @Override
    public boolean markSupported() {
        return in.markSupported();
    }

    /** Delegates to the wrapped stream's {@link InputStream#mark(int)}. */
    @Override
    public synchronized void mark(int readLimit) {
        in.mark(readLimit);
    }

    /** Delegates to the wrapped stream's {@link InputStream#reset()}. */
    @Override
    public synchronized void reset() throws IOException {
        in.reset();
    }

    /**
     * Read binary data from stream, filling the entire buffer.
     *
     * @param out The output buffer to read into.
     * @throws IOException if unable to read from stream, or if the stream ends
     *         before the buffer is filled.
     */
    public void expect(byte[] out) throws IOException {
        // readInto returns -1 only when nothing was read; report that as 0 bytes.
        int got = Math.max(0, readInto(out, 0, out.length));
        if (got < out.length) {
            throw new IOException("Not enough data available on stream: " + got + " < " + out.length);
        }
    }

    /**
     * Read a byte from the input stream.
     *
     * @return The number read.
     * @throws IOException If no byte to read.
     */
    public byte expectByte() throws IOException {
        int read = in.read();
        if (read < 0) {
            throw new IOException("Missing expected byte");
        }
        return (byte) read;
    }

    /**
     * Read a short from the input stream.
     *
     * @return The number read.
     * @throws IOException if unable to read from stream.
     */
    public short expectShort() throws IOException {
        int b1 = in.read();
        if (b1 < 0) {
            throw new IOException("Missing byte 1 to expected short");
        }
        int b2 = in.read();
        if (b2 < 0) {
            throw new IOException("Missing byte 2 to expected short");
        }
        return (short) unshift2bytes(b1, b2);
    }

    /**
     * Read an int from the input stream.
     *
     * @return The number read.
     * @throws IOException if unable to read from stream.
     */
    public int expectInt() throws IOException {
        int b1 = in.read();
        if (b1 < 0) {
            throw new IOException("Missing byte 1 to expected int");
        }
        int b2 = in.read();
        if (b2 < 0) {
            throw new IOException("Missing byte 2 to expected int");
        }
        int b3 = in.read();
        if (b3 < 0) {
            throw new IOException("Missing byte 3 to expected int");
        }
        int b4 = in.read();
        if (b4 < 0) {
            throw new IOException("Missing byte 4 to expected int");
        }
        return unshift4bytes(b1, b2, b3, b4);
    }

    /**
     * Read a long int from the input stream.
     *
     * @return The number read.
     * @throws IOException if unable to read from stream.
     */
    public long expectLong() throws IOException {
        // Mixed int/long locals are fine: all are widened to long for unshift8bytes.
        int b1 = in.read();
        if (b1 < 0) {
            throw new IOException("Missing byte 1 to expected long");
        }
        int b2 = in.read();
        if (b2 < 0) {
            throw new IOException("Missing byte 2 to expected long");
        }
        int b3 = in.read();
        if (b3 < 0) {
            throw new IOException("Missing byte 3 to expected long");
        }
        long b4 = in.read();
        if (b4 < 0) {
            throw new IOException("Missing byte 4 to expected long");
        }
        long b5 = in.read();
        if (b5 < 0) {
            throw new IOException("Missing byte 5 to expected long");
        }
        long b6 = in.read();
        if (b6 < 0) {
            throw new IOException("Missing byte 6 to expected long");
        }
        long b7 = in.read();
        if (b7 < 0) {
            throw new IOException("Missing byte 7 to expected long");
        }
        long b8 = in.read();
        if (b8 < 0) {
            throw new IOException("Missing byte 8 to expected long");
        }

        return unshift8bytes(b1, b2, b3, b4, b5, b6, b7, b8);
    }

    /**
     * Read a float from the input stream, as the 32 bits of its IEEE 754 representation.
     *
     * @return The number read.
     * @throws IOException if unable to read from stream.
     */
    public float expectFloat() throws IOException {
        return Float.intBitsToFloat(expectInt());
    }

    /**
     * Read a double from the input stream, as the 64 bits of its IEEE 754 representation.
     *
     * @return The number read.
     * @throws IOException if unable to read from stream.
     */
    public double expectDouble() throws IOException {
        return Double.longBitsToDouble(expectLong());
    }

    /**
     * Read exactly the given number of bytes from the stream.
     *
     * @param bytes Number of bytes to read.
     * @return The bytes read.
     * @throws IllegalArgumentException if bytes is negative.
     * @throws IOException if unable to read from stream, or if the stream ends early.
     */
    public byte[] expectBytes(final int bytes) throws IOException {
        if (bytes < 0) {
            throw new IllegalArgumentException("Negative number of bytes to read: " + bytes);
        }
        if (bytes > (128 * 1024)) {
            // More than 128 kB: grow the result as data actually arrives, instead of
            // allocating the full (possibly bogus) size up front.
            ByteArrayOutputStream tmp = new ByteArrayOutputStream(128 * 1024);
            int remaining = bytes;
            byte[] buffer = new byte[4 * 1024];
            int r;
            while (remaining > 0 &&
                   (r = in.read(buffer, 0, Math.min(remaining, buffer.length))) > 0) {
                tmp.write(buffer, 0, r);
                remaining -= r;
            }
            if (tmp.size() < bytes) {
                throw new IOException("Not enough data available on stream: " + tmp.size() + " < " + bytes);
            }
            return tmp.toByteArray();
        }
        byte[] out = new byte[bytes];
        expect(out);
        return out;
    }

    /**
     * Read an unsigned byte from the input stream.
     *
     * @return Unsigned byte, 0..255.
     * @throws IOException If no number to read.
     */
    public int expectUInt8() throws IOException {
        int read = in.read();
        if (read < 0) {
            throw new IOException("Missing unsigned byte");
        }
        return read;
    }

    /**
     * Read an unsigned short (16 bits) from the input stream.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public int expectUInt16() throws IOException {
        int b1 = in.read();
        if (b1 < 0) {
            throw new IOException("Missing byte 1 to expected uint16");
        }
        int b2 = in.read();
        if (b2 < 0) {
            throw new IOException("Missing byte 2 to expected uint16");
        }
        return unshift2bytes(b1, b2);
    }

    /**
     * Read an unsigned short (16 bits) from the input stream, accepting end of stream.
     *
     * @return The number read, or 0 if the stream was already at its end.
     * @throws IOException If the first byte was read but the second is missing.
     */
    public int readUInt16() throws IOException {
        int b1 = in.read();
        if (b1 < 0) {
            return 0;
        }
        int b2 = in.read();
        if (b2 < 0) {
            throw new IOException("Missing byte 2 to read uint16");
        }
        return unshift2bytes(b1, b2);
    }

    /**
     * Read an unsigned 24-bit integer from the input stream.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public int expectUInt24() throws IOException {
        int b1 = in.read();
        if (b1 < 0) {
            throw new IOException("Missing byte 1 to expected uint24");
        }
        int b2 = in.read();
        if (b2 < 0) {
            throw new IOException("Missing byte 2 to expected uint24");
        }
        int b3 = in.read();
        if (b3 < 0) {
            throw new IOException("Missing byte 3 to expected uint24");
        }
        return unshift3bytes(b1, b2, b3);
    }

    /**
     * Read an unsigned 32-bit integer from the input stream. The result is the
     * raw 32 bits as an int, so values of 2^31 and above come back negative.
     * Use {@link #expectULong32()} to get them as a non-negative long.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public int expectUInt32() throws IOException {
        return expectInt();
    }

    /**
     * Read an unsigned 32-bit integer from the input stream as a non-negative long.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public long expectULong32() throws IOException {
        return 0xFFFFFFFFL & expectInt();
    }

    /**
     * Read an unsigned 40-bit (5 byte) integer from the input stream.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public long expectULong40() throws IOException {
        return unshiftNBytes(expectBytes(5));
    }

    /**
     * Read an unsigned 48-bit (6 byte) integer from the input stream.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public long expectULong48() throws IOException {
        return unshiftNBytes(expectBytes(6));
    }

    /**
     * Read an unsigned 56-bit (7 byte) integer from the input stream.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public long expectULong56() throws IOException {
        return unshiftNBytes(expectBytes(7));
    }

    /**
     * Read a 64-bit integer from the input stream. Values of 2^63 and above
     * come back negative, as Java has no unsigned long type.
     *
     * @return The number read.
     * @throws IOException If no number to read.
     */
    public long expectULong64() throws IOException {
        return expectLong();
    }

    /**
     * Read an unsigned number from the input stream.
     *
     * @param bytes Number of bytes to read, 1 to 4.
     * @return The number read.
     * @throws IllegalArgumentException if bytes is not in 1..4.
     * @throws IOException if unable to read from stream.
     */
    public int expectUnsigned(int bytes) throws IOException {
        switch (bytes) {
            case 4:
                return expectUInt32();
            case 3:
                return expectUInt24();
            case 2:
                return expectUInt16();
            case 1:
                return expectUInt8();
        }
        throw new IllegalArgumentException("Unsupported byte count for unsigned: " + bytes);
    }

    /**
     * Read an unsigned number from the input stream.
     *
     * @param bytes Number of bytes to read, 1 to 8.
     * @return The number read.
     * @throws IllegalArgumentException if bytes is not in 1..8.
     * @throws IOException if unable to read from stream.
     */
    public long expectUnsignedLong(int bytes) throws IOException {
        switch (bytes) {
            case 8:
                return expectLong();
            case 7:
            case 6:
            case 5:
                return unshiftNBytes(expectBytes(bytes));
            case 4:
                return expectULong32();
            case 3:
                return expectUInt24();
            case 2:
                return expectUInt16();
            case 1:
                return expectUInt8();
        }
        throw new IllegalArgumentException("Unsupported byte count for unsigned long: " + bytes);
    }

    /**
     * Read a signed number from the input stream.
     *
     * @param bytes Number of bytes to read: 1, 2, 4 or 8.
     * @return The number read, sign-extended to long.
     * @throws IllegalArgumentException if bytes is not 1, 2, 4 or 8.
     * @throws IOException if unable to read from stream.
     */
    public long expectSigned(int bytes) throws IOException {
        switch (bytes) {
            case 8:
                return expectLong();
            case 4:
                return expectInt();
            case 2:
                return expectShort();
            case 1:
                return expectByte();
        }
        throw new IllegalArgumentException("Unsupported byte count for signed: " + bytes);
    }

    /**
     * Read an int number as zigzag encoded from the stream. The least significant bit becomes the sign, and the
     * actual value is absolute and shifted one bit. This keeps small values short whether positive or negative.
     * End of stream is read as 0 (see {@link #readIntBase128()}).
     *
     * @return The zigzag decoded value.
     * @throws IOException if unable to read from stream.
     */
    public int readIntZigzag() throws IOException {
        int value = readIntBase128();
        return (value & 1) != 0 ? ~(value >>> 1) : value >>> 1;
    }

    /**
     * Read an int number as zigzag encoded from the stream. The least significant bit becomes the sign, and the
     * actual value is absolute and shifted one bit. This keeps small values short whether positive or negative.
     *
     * @return The zigzag decoded value.
     * @throws IOException if unable to read from stream.
     */
    public int expectIntZigzag() throws IOException {
        int value = expectIntBase128();
        return (value & 1) != 0 ? ~(value >>> 1) : value >>> 1;
    }

    /**
     * Read a long number as zigzag encoded from the stream. The least significant bit becomes the sign, and the
     * actual value is absolute and shifted one bit. This keeps small values short whether positive or negative.
     * End of stream is read as 0 (see {@link #readLongBase128()}).
     *
     * @return The zigzag decoded value.
     * @throws IOException if unable to read from stream.
     */
    public long readLongZigzag() throws IOException {
        long value = readLongBase128();
        return (value & 1) != 0 ? ~(value >>> 1) : value >>> 1;
    }

    /**
     * Read a long number as zigzag encoded from the stream. The least significant bit becomes the sign, and the
     * actual value is absolute and shifted one bit. This keeps small values short whether positive or negative.
     *
     * @return The zigzag decoded value.
     * @throws IOException if unable to read from stream.
     */
    public long expectLongZigzag() throws IOException {
        long value = expectLongBase128();
        return (value & 1) != 0 ? ~(value >>> 1) : value >>> 1;
    }

    /**
     * Read a number as base128 (integer with variable number of bytes, determined as part of the bytes
     * themselves), using the endianness to determine order of 7-bit block assembly.
     * <p>
     * NOTE: Reading base128 accepts end of stream as '0'.
     *
     * @return The base128 number read from stream.
     * @throws IOException if unable to read from stream.
     */
    public int readIntBase128() throws IOException {
        int i = read();
        if (i < 0) {
            return 0;
        }
        return internalReadIntBase128(i);
    }

    /**
     * Read a number as base128 (integer with variable number of bytes, determined as part of the bytes
     * themselves), using the endianness to determine order of 7-bit block assembly.
     *
     * @return The base128 number read from stream.
     * @throws IOException if unable to read from stream, including if it is at its end.
     */
    public int expectIntBase128() throws IOException {
        return internalReadIntBase128(expectUInt8());
    }

    /**
     * Read a base128 encoded int.
     *
     * @return The base128 number read from stream.
     * @throws IOException if unable to read from stream.
     * @deprecated Use {@link #readIntBase128()} instead.
     */
    @Deprecated
    public int readIntVarint() throws IOException {
        return readIntBase128();
    }

    /**
     * Read a number as base128 (integer with variable number of bytes, determined as part of the bytes
     * themselves), using the endianness to determine order of 7-bit block assembly.
     * <p>
     * NOTE: Reading base128 accepts end of stream as '0'.
     *
     * @return The base128 number read from stream.
     * @throws IOException if unable to read from stream.
     */
    public long readLongBase128() throws IOException {
        int i = read();
        if (i < 0) {
            return 0L;
        }
        return internalReadLongBase128(i);
    }

    /**
     * Read a number as base128 (integer with variable number of bytes, determined as part of the bytes
     * themselves), using the endianness to determine order of 7-bit block assembly.
     *
     * @return The base128 number read from stream.
     * @throws IOException if unable to read from stream, including if it is at its end.
     */
    public long expectLongBase128() throws IOException {
        return internalReadLongBase128(expectUInt8());
    }

    /**
     * Read a base128 encoded long.
     *
     * @return The base128 number read from stream.
     * @throws IOException if unable to read from stream.
     * @deprecated Use {@link #readLongBase128()} instead.
     */
    @Deprecated
    public long readLongVarint() throws IOException {
        return readLongBase128();
    }

    // ---- Protected methods to handle endianness ----

    /**
     * Finish reading a base128 int whose first byte has already been read.
     *
     * @param i The first byte of the value, as an unsigned byte (0..255).
     * @return The decoded number.
     * @throws IOException if unable to read the remaining bytes.
     */
    protected abstract int internalReadIntBase128(int i) throws IOException;

    /**
     * Finish reading a base128 long whose first byte has already been read.
     *
     * @param i The first byte of the value, as an unsigned byte (0..255).
     * @return The decoded number.
     * @throws IOException if unable to read the remaining bytes.
     */
    protected abstract long internalReadLongBase128(int i) throws IOException;

    /**
     * Combine two unsigned bytes (0..255), given in stream order, into a number
     * according to the stream's endianness.
     */
    protected abstract int unshift2bytes(int b1, int b2);

    /**
     * Combine three unsigned bytes (0..255), given in stream order, into a number
     * according to the stream's endianness.
     */
    protected abstract int unshift3bytes(int b1, int b2, int b3);

    /**
     * Combine four unsigned bytes (0..255), given in stream order, into a number
     * according to the stream's endianness.
     */
    protected abstract int unshift4bytes(int b1, int b2, int b3, int b4);

    /**
     * Combine eight unsigned bytes (0..255), given in stream order, into a number
     * according to the stream's endianness.
     */
    protected abstract long unshift8bytes(long b1, long b2, long b3, long b4, long b5, long b6, long b7, long b8);

    /**
     * Combine raw (signed java) bytes, given in stream order, into an unsigned
     * number according to the stream's endianness. Called from this class with
     * 5 to 7 bytes.
     */
    protected abstract long unshiftNBytes(byte[] bytes);
}