BitPackingOutputStream.java

package net.morimekta.io;

import java.io.IOException;
import java.io.OutputStream;

/**
 * Output stream that writes individual bits consecutively to the output stream.
 * The bits will be grouped into bytes and each written when completed. The bits
 * themselves will be assigned from most to least significant per byte.
 *
 * <pre>{@code
 * |---|---|---|---|---|---|---|---|
 * | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 |
 * |---|---|---|---|---|---|---|---|
 * | a | b | c | d | e | f | g | h |
 * | i | j | k | l | m | n | o | p |
 * |---|---|---|---|---|---|---|---|
 * }</pre>
 */
public class BitPackingOutputStream extends OutputStream {
    /** Underlying stream, or {@code null} once this stream has been closed. */
    private OutputStream out;
    /** Number of bits currently held in {@link #pendingBits} (0..7 between calls). */
    private int          pendingBitsCount;
    /** The partially filled byte; bits are placed from the most significant end. */
    private int          pendingBits;
    /**
     * Total number of bits accepted by {@link #writeBits(int, int)}. Padding bits
     * emitted by {@link #align()} or {@link #close()} are not counted. Note that
     * this is an {@code int} counter.
     */
    private int          writtenBits;

    /**
     * Create a bit packing output stream.
     *
     * @param out Output stream to write bits to.
     */
    public BitPackingOutputStream(OutputStream out) {
        this.out = out;
        this.pendingBits = 0;
        this.pendingBitsCount = 0;
        this.writtenBits = 0;
    }

    /**
     * Write a number of bits to the stream, and make the next written bits
     * written just after the last bits with no regard to byte border
     * alignment.
     * <p>
     * The lowest {@code bitCount} bits of {@code bitData} are written, most
     * significant first; any higher bits of {@code bitData} are ignored.
     *
     * @param bitCount Number of bits to be written, in the range 1..31.
     * @param bitData  Data for bits to be written. Must not be negative.
     * @throws IOException              If unable to write to stream, or the stream is closed.
     * @throws IllegalArgumentException On invalid input.
     */
    public void writeBits(int bitCount, int bitData) throws IOException {
        if (bitCount < 1 || bitCount > 31) {
            throw new IllegalArgumentException("Illegal writing bit count " + bitCount);
        } else if (bitData < 0) {
            throw new IllegalArgumentException("Writing negative bit data: " + bitData);
        }

        if (out == null) {
            throw new IOException("Writing to closed stream");
        }

        if (bitCount + pendingBitsCount <= 8) {
            // Everything fits in the current byte: keep only the lowest bitCount
            // bits and place them directly after the bits already pending.
            int mask = 0xff >>> (8 - bitCount);
            int shift = 8 - (bitCount + pendingBitsCount);
            pendingBits = pendingBits | ((bitData & mask) << shift);
            pendingBitsCount = pendingBitsCount + bitCount;
            writtenBits = writtenBits + bitCount;
            if (pendingBitsCount == 8) {
                // The byte is complete; emit it and start a new one.
                out.write(pendingBits);

                pendingBits = 0;
                pendingBitsCount = 0;
            }
        } else {
            // Fill the current byte with the most significant bits first, then
            // write the remainder. The recursive call masks away the bits that
            // were already written.
            int writeBits = 8 - pendingBitsCount;
            writeBits(writeBits, bitData >>> (bitCount - writeBits));
            writeBits(bitCount - writeBits, bitData);
        }
    }

    /**
     * Force the next bit to be on a byte (octet) boundary.
     * Write all pending bits to the stream as if the remainder was written with 0 bits.
     * Does nothing if there are no pending bits. The padding bits are not
     * counted in {@link #getWrittenBits()}.
     *
     * @throws IOException If unable to write to stream.
     */
    public void align() throws IOException {
        if (pendingBitsCount > 0) {
            out.write(pendingBits);
            pendingBitsCount = 0;
            pendingBits = 0;
        }
    }

    /**
     * Number of bits written, including those not written to stream
     * yet, as the byte is not completed yet.
     *
     * @return Number of bits written.
     */
    public int getWrittenBits() {
        return writtenBits;
    }

    /**
     * Write the low-order 8 bits of {@code octet} as 8 consecutive bits.
     * As specified by {@link OutputStream#write(int)}, the 24 high-order bits
     * are ignored. This means sign-extended bytes (e.g. from the inherited
     * {@code write(byte[])}) are accepted.
     *
     * @param octet The byte to write.
     * @throws IOException If unable to write to stream, or the stream is closed.
     */
    @Override
    public void write(int octet) throws IOException {
        writeBits(8, octet & 0xff);
    }

    /**
     * Flush the underlying stream. Pending bits of an incomplete byte are NOT
     * written; use {@link #align()} for that. Does nothing when closed.
     *
     * @throws IOException If the underlying flush fails.
     */
    @Override
    public void flush() throws IOException {
        if (out != null) {
            out.flush();
        }
    }

    /**
     * Write any pending bits (zero-padded to a full byte) and close the
     * underlying stream. The underlying stream is closed even if writing the
     * pending byte fails. Calling this again after closing does nothing.
     *
     * @throws IOException If writing the pending byte or closing fails.
     */
    @Override
    public void close() throws IOException {
        if (out == null) {
            return;
        }
        OutputStream target = out;
        boolean hasPending = pendingBitsCount > 0;
        int lastByte = pendingBits;

        // Mark as closed up front, so the stream is closed whatever happens below.
        out = null;
        pendingBitsCount = 0;
        pendingBits = 0;

        // try-with-resources guarantees target.close() runs even if the final
        // write throws, and attaches a close() failure as a suppressed exception.
        try (OutputStream closing = target) {
            if (hasPending) {
                closing.write(lastByte);
            }
        }
    }
}