# Source code for keras_mml.utils.array.encoding

"""
Utilities for encoding special types of arrays.
"""

from collections import Counter
from typing import List, Tuple

import numpy as np

from keras_mml.utils.misc.number import int_to_bin

# fmt: off
# Prefix-free bit codes shared by the encoder and decoder.
# - In the encoded body the key is an element's frequency rank.
# - In the header the same table is indexed by an element's position in
#   ``[0, 1, -1]``.
BITS_MAPPING = {
    0: "0",   # rank 0: most common element (1 bit)
    1: "10",  # rank 1: second most common element (2 bits)
    2: "11",  # rank 2: least common element (2 bits)
}
# fmt: on


def encode_ternary_array(x: np.ndarray) -> Tuple[Tuple[int, ...], bytes]:
    """
    Encodes a ternary array into a more space efficient format.

    For a given ternary array, which consists of only the elements ``0``, ``1``, and ``-1``, we
    convert the individual elements into bit sequences. Specifically,

    - the most common element becomes ``0``;
    - the second most common element becomes ``10``; and
    - the least common element becomes ``11``.

    The bit stream starts with a three to four bit header that names the most common and the second
    most common element, in that order. Within the header,

    - ``0`` is represented by ``0``;
    - ``1`` is represented by ``10``; and
    - ``-1`` is represented by ``11``.

    The header and the element bit sequences are concatenated, padded with trailing zero bits to a
    whole number of bytes, and returned as bytes.

    Args:
        x: Ternary array to encode.

    Raises:
        ValueError: If the array contains a value other than ``0``, ``1``, or ``-1``.

    Returns:
        A tuple. The first element is the shape of the array. The second element is the encoded
        representation of the array.

    Examples:
        >>> x = np.array([1, -1, 1, 0, -1, 1])
        >>> shape, encoded = encode_ternary_array(x)
        >>> shape
        (6,)
        >>> encoded
        b'\\xb4\\xe0'
        >>> x = np.array([[0, 1, -1], [-1, 0, 0]])
        >>> shape, encoded = encode_ternary_array(x)
        >>> shape
        (2, 3)
        >>> encoded
        b'n\\x80'
        >>> x = np.array([[[0, 1, -1], [-1, 1, 0]], [[1, 0, 0], [0, -1, 1]]])
        >>> shape, encoded = encode_ternary_array(x)
        >>> shape
        (2, 2, 3)
        >>> encoded
        b'K\\xe48'
        >>> x = np.array([[[0, 1, -1], [-1, 1, 0]], [[1, -1, 0], [0, -1, 1]]])
        >>> shape, encoded = encode_ternary_array(x)
        >>> shape
        (2, 2, 3)
        >>> encoded
        b'K\\xe5\\x9c'
    """

    ternary_values = (0, 1, -1)  # Header codes are assigned by position in this tuple

    shape = x.shape
    values = x.flatten().tolist()  # Plain Python scalars are cheaper to count and look up

    # Get the frequency of each of the elements
    counts = Counter(values)
    unexpected = set(counts).difference(ternary_values)
    if unexpected:
        # Fail early with one clear error instead of a lookup error further down
        raise ValueError(
            f"Array must contain only 0, 1, and -1; found {', '.join(sorted(map(repr, unexpected)))}"
        )

    # Most common first; elements that never occur are appended in a fixed order so that
    # `elem_order` always has 3 entries
    elem_order = [elem for elem, _ in counts.most_common()]
    elem_order += [elem for elem in ternary_values if elem not in elem_order]

    # Record down what the top 2 most frequent elements are
    header = "".join(BITS_MAPPING[ternary_values.index(elem)] for elem in elem_order[:2])

    # Then perform the encoding: each element is replaced by the code of its frequency rank
    codes = {elem: BITS_MAPPING[rank] for rank, elem in enumerate(elem_order)}
    bits = header + "".join(codes[elem] for elem in values)

    # Pad the final partial byte with zero bits, then pack 8 bits per byte
    bits += "0" * (-len(bits) % 8)
    return shape, int(bits, 2).to_bytes(len(bits) // 8, "big")
def _retrieve_bit_encoding(encoded: bytes) -> Tuple[List[int], List[str]]:
    """
    Retrieves the bit encoding map from the encoded representation.

    The header is read from the first byte. It names the most common and the second most common
    element, in that order, where ``0`` is written as ``0``, ``1`` as ``10``, and ``-1`` as ``11``.

    Args:
        encoded: Encoded ternary array. Must not be empty.

    Raises:
        ValueError: If the header names the same element twice.

    Returns:
        A tuple. The first element is the ordering of the distinct elements, from most common to
        least common. The second element is the remaining unconsumed bit buffer of the first byte.
    """

    # Encoding information is 3 to 4 bits, which is within first byte.
    # NOTE(review): `int_to_bin` is assumed to return the byte as 8 characters of "0"/"1" -- confirm
    buffer = list(int_to_bin(encoded[0], pad_len=8))

    elem_order: List[int] = []
    while len(elem_order) < 2:
        if buffer.pop(0) == "0":  # Full sequence is "0" -> 0
            elem = 0
        elif buffer.pop(0) == "0":  # Full sequence is "10" -> 1
            elem = 1
        else:  # Full sequence is "11" -> -1
            elem = -1

        if elem in elem_order:
            # A valid header names two different elements
            raise ValueError(f"Malformed header: element {elem} is listed twice")
        elem_order.append(elem)

    # Whichever element the header did not name is the least common one
    elem_order += [elem for elem in (0, 1, -1) if elem not in elem_order]
    return elem_order, buffer
def decode_ternary_array(shape: Tuple[int, ...], encoded: bytes) -> np.ndarray:
    """
    Decodes the ternary array generated by :py:func:`~encode_ternary_array`.

    Args:
        shape: Shape of the original array. May be ``()`` for a zero-dimensional array.
        encoded: Encoded ternary array.

    Raises:
        ValueError: If the encoded byte string is empty, or if it ends before every element of the
            requested shape has been decoded.

    Returns:
        The decoded ternary array.

    Examples:
        >>> decode_ternary_array((6,), b"\\xb4\\xe0")
        array([ 1, -1,  1,  0, -1,  1])
        >>> decode_ternary_array((2, 3), b"n\\x80")
        array([[ 0,  1, -1],
               [-1,  0,  0]])
        >>> decode_ternary_array((2, 2, 3), b"K\\xe48")
        array([[[ 0,  1, -1],
                [-1,  1,  0]],
        <BLANKLINE>
               [[ 1,  0,  0],
                [ 0, -1,  1]]])
        >>> decode_ternary_array((2, 2, 3), b"K\\xe5\\x9c")
        array([[[ 0,  1, -1],
                [-1,  1,  0]],
        <BLANKLINE>
               [[ 1, -1,  0],
                [ 0, -1,  1]]])
    """

    if len(encoded) == 0:
        raise ValueError("Cannot decode empty encoded array")

    # Allocate memory for output. Starting the product at 1 also covers the 0-d shape `()`
    total = 1
    for dim in shape:
        total *= dim
    flattened = np.zeros((total,), dtype=int)

    # Retrieve bit encoding
    elem_order, buffer = _retrieve_bit_encoding(encoded)

    # Decode array
    byte_num = 0
    for i in range(total):
        if len(buffer) == 0 or (len(buffer) == 1 and buffer[0] == "1"):  # Can't decode lone "1"
            byte_num += 1
            if byte_num >= len(encoded):
                raise ValueError(
                    f"Encoded data ends after {byte_num} byte(s), before all {total} elements of "
                    f"shape {shape} were decoded"
                )
            buffer += list(int_to_bin(encoded[byte_num], pad_len=8))

        if buffer.pop(0) == "0":  # The full sequence is "0" -> get most common element
            flattened[i] = elem_order[0]
        elif buffer.pop(0) == "0":  # The full sequence is "10" -> get second most common element
            flattened[i] = elem_order[1]
        else:  # The full sequence is "11" -> get least common element
            flattened[i] = elem_order[2]

    return flattened.reshape(shape)