BitPackingInputStream.java

package net.morimekta.io;

import java.io.IOException;
import java.io.InputStream;

import static java.lang.Math.min;
import static java.util.Objects.requireNonNull;

/**
 * An input stream that reads individual bits in groups of up to 31 bits at a
 * time. The 31 limit is to allow for negative numbers to still represent
 * <code>EOF</code>. The bits themselves will be read from most to least
 * significant per byte.
 *
 * <pre>{@code
 * |---|---|---|---|---|---|---|---|
 * | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 |
 * |---|---|---|---|---|---|---|---|
 * | a | b | c | d | e | f | g | h |
 * | i | j | k | l | m | n | o | p |
 * |---|---|---|---|---|---|---|---|
 * }</pre>
 */
public class BitPackingInputStream extends InputStream {
    /** Max bits per read; keeps every successful result non-negative so -1 can mean EOF. */
    private static final int MAX_BITS = 31;

    /** Wrapped stream, or null once this stream has been closed. */
    private InputStream in;
    /** The most recently read byte from {@code in}, or -1 once EOF has been reached. */
    private int         unreadBits;
    /** Number of low-order bits of {@code unreadBits} not yet consumed (0..7 between calls). */
    private int         unreadBitsCount;
    /** Total number of bits returned to callers so far. */
    private int         readBits;

    /**
     * Create a bit packing input stream.
     *
     * @param in Input stream to read bits from.
     */
    public BitPackingInputStream(InputStream in) {
        this.in = requireNonNull(in, "in == null");
        this.unreadBits = 0;
        this.unreadBitsCount = 0;
        this.readBits = 0;
    }

    /**
     * Read a number of bits and return them as a big-endian integer, i.e. the
     * first bit read ends up as the most significant bit of the result.
     *
     * @param bits Number of bits to read. 1 is same as boolean, 8 same as a byte.
     *             Maximum value of 31.
     * @return The read bits. -1 if not enough bit data to read. Bits that were
     * available before EOF are consumed (and counted by {@link #getReadBits()})
     * regardless.
     * @throws IOException              If unable to read from stream, or the stream is closed.
     * @throws IllegalArgumentException If bad number of bits requested.
     */
    public int readBits(final int bits) throws IOException {
        checkBits(bits);
        if (in == null) {
            throw new IOException("Reading from closed stream");
        }
        if (unreadBits < 0) {
            // EOF was already reached by an earlier read.
            return -1;
        }

        if (unreadBitsCount == 0) {
            unreadBits = in.read();
            if (unreadBits < 0) {
                return -1;
            }
            unreadBitsCount = 8;
        }

        if (bits <= unreadBitsCount) {
            // The whole request fits in the buffered byte (so bits <= 8 here).
            int remainingBits = unreadBitsCount - bits;
            int mask = 0xff >>> (8 - bits);
            unreadBitsCount = remainingBits;
            readBits += bits;
            return (unreadBits >> remainingBits) & mask;
        }

        // The request spans byte boundaries: assemble it from chunks of at most
        // one byte each, most significant chunk first.
        int result = 0;
        int left = bits;
        while (left > 0) {
            int more = min(left, unreadBitsCount == 0 ? 8 : unreadBitsCount);
            int read = readBits(more);
            if (read < 0) {
                // Ran out of data part way through; bits read so far stay consumed.
                return -1;
            }
            result = (result << more) | read;
            left -= more;
        }
        return result;
    }

    /**
     * @return The number of bits that have been read from the stream. Bits
     * discarded by {@link #align()} are not counted.
     */
    public int getReadBits() {
        return readBits;
    }

    /**
     * Return the number of available chunks of given number of bits.
     *
     * @param bits The number of bits per chunk.
     * @return The number of readable chunks, capped at {@link Integer#MAX_VALUE}.
     * 0 if the stream is closed.
     * @throws IOException If unable to get available bytes.
     * @throws IllegalArgumentException If bad number of bits requested.
     */
    public int available(int bits) throws IOException {
        checkBits(bits);
        if (in == null) return 0;
        // Compute in long: bytes * 8 overflows int for large available() values.
        long availableBits = (long) in.available() * 8 + unreadBitsCount;
        return (int) min(availableBits / bits, (long) Integer.MAX_VALUE);
    }

    /**
     * Ignore the remainder of any partially read byte (octet), and continue
     * to read from the beginning of next byte. The skipped bits are not added
     * to {@link #getReadBits()}.
     */
    public void align() {
        if (unreadBitsCount > 0 && unreadBitsCount < 8) {
            unreadBitsCount = 0;
            unreadBits = 0;
        }
    }

    /**
     * Read the next 8 bits, which need not be aligned to a byte boundary of the
     * underlying stream.
     *
     * @return The next 8 bits as 0..255, or -1 if fewer than 8 bits remain.
     * @throws IOException If unable to read from stream, or the stream is closed.
     */
    @Override
    public int read() throws IOException {
        return readBits(8);
    }

    /**
     * @return The number of whole 8-bit chunks available, 0 if closed.
     * @throws IOException If unable to get available bytes.
     */
    @Override
    public int available() throws IOException {
        if (in == null) return 0;
        return available(8);
    }

    /**
     * Close the wrapped stream. Subsequent reads fail with an IOException;
     * closing more than once is a no-op.
     *
     * @throws IOException If closing the wrapped stream fails.
     */
    @Override
    public void close() throws IOException {
        if (in != null) {
            try {
                in.close();
            } finally {
                in = null;
            }
        }
    }

    /**
     * Validate a requested chunk size.
     *
     * @param bits Number of bits requested.
     * @throws IllegalArgumentException If bits is outside 1..31.
     */
    private static void checkBits(int bits) {
        if (bits > MAX_BITS) {
            throw new IllegalArgumentException("Trying to read " + bits + " bits, more than max " + MAX_BITS + " allowed");
        } else if (bits < 1) {
            throw new IllegalArgumentException("Trying to read " + bits + " bits");
        }
    }
}