Home | History | Annotate | Line # | Download | only in educational_decoder
      1 /*
      2  * Copyright (c) Meta Platforms, Inc. and affiliates.
      3  * All rights reserved.
      4  *
      5  * This source code is licensed under both the BSD-style license (found in the
      6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
      7  * in the COPYING file in the root directory of this source tree).
      8  * You may select, at your option, one of the above-listed licenses.
      9  */
     10 
     11 /// Zstandard educational decoder implementation
     12 /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
     13 
     14 #include <stdint.h>   // uint8_t, etc.
     15 #include <stdlib.h>   // malloc, free, exit
     16 #include <stdio.h>    // fprintf
     17 #include <string.h>   // memset, memcpy
     18 #include "zstd_decompress.h"
     19 
     20 
     21 /******* IMPORTANT CONSTANTS *********************************************/
     22 
     23 // Zstandard frame
     24 // "Magic_Number
     25 // 4 Bytes, little-endian format. Value : 0xFD2FB528"
     26 #define ZSTD_MAGIC_NUMBER 0xFD2FB528U
     27 
     28 // The size of `Block_Content` is limited by `Block_Maximum_Size`,
     29 #define ZSTD_BLOCK_SIZE_MAX ((size_t)128 * 1024)
     30 
     31 // literal blocks can't be larger than their block
     32 #define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX
     33 
     34 
     35 /******* UTILITY MACROS AND TYPES *********************************************/
        /// Function-like maximum/minimum of two values.  Each argument can be
        /// evaluated twice, so do not pass expressions with side effects.
     36 #define MAX(a, b) ((a) > (b) ? (a) : (b))
     37 #define MIN(a, b) ((a) < (b) ? (a) : (b))
     38 
        /// Diagnostic output.  Expands to nothing when ZDEC_NO_MESSAGE is defined,
        /// otherwise forwards its printf-style arguments to fprintf(stderr, ...).
        /// The leading "" is concatenated with the format argument, so the first
        /// argument has to be a string literal.
     39 #if defined(ZDEC_NO_MESSAGE)
     40 #define MESSAGE(...)
     41 #else
     42 #define MESSAGE(...)  fprintf(stderr, "" __VA_ARGS__)
     43 #endif
     44 
     45 /// This decoder calls exit(1) when it encounters an error; a production
     46 /// library should propagate error codes instead.
        ///
        /// ERROR(s) reports the string `s` through MESSAGE (which may be compiled out)
        /// and then terminates the process, so control never continues past it.
     47 #define ERROR(s)                                                               \
     48     do {                                                                       \
     49         MESSAGE("Error: %s\n", s);                                     \
     50         exit(1);                                                               \
     51     } while (0)
        /// Convenience wrappers around ERROR, one fixed message per failure category.
     52 #define INP_SIZE()                                                             \
     53     ERROR("Input buffer smaller than it should be or input is "                \
     54           "corrupted")
     55 #define OUT_SIZE() ERROR("Output buffer too small for output")
     56 #define CORRUPTION() ERROR("Corruption detected while decompressing")
     57 #define BAD_ALLOC() ERROR("Memory allocation error")
     58 #define IMPOSSIBLE() ERROR("An impossibility has occurred")
     59 
        /// Short aliases for the <stdint.h> fixed-width unsigned integer types
     60 typedef uint8_t  u8;
     61 typedef uint16_t u16;
     62 typedef uint32_t u32;
     63 typedef uint64_t u64;
     64 
        /// Short aliases for the <stdint.h> fixed-width signed integer types
     65 typedef int8_t  i8;
     66 typedef int16_t i16;
     67 typedef int32_t i32;
     68 typedef int64_t i64;
     69 /******* END UTILITY MACROS AND TYPES *****************************************/
     70 
     71 /******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
     72 /// The implementations for these functions can be found at the bottom of this
     73 /// file.  They implement low-level functionality needed for the higher level
     74 /// decompression functions.
     75 
     76 /*** IO STREAM OPERATIONS *************/
     77 
     78 /// ostream_t/istream_t are used to wrap the pointers/length data passed into
     79 /// ZSTD_decompress, so that all IO operations are safely bounds checked
     80 /// They are written/read forward, and reads are treated as little-endian
     81 /// They should be used opaquely to ensure safety
        /// Output stream: a write cursor into the caller's destination buffer.
     82 typedef struct {
            // Current write position.  ZSTD_decompress_with_dict() reports the
            // decompressed size as the distance from the start of the destination
            // buffer to this pointer.
     83     u8 *ptr;
            // Space available for output; decode_data_frame() raises OUT_SIZE() when
            // the frame's declared content size exceeds it.
            // NOTE(review): presumably decremented as bytes are written -- confirm in
            // the IO_* helper definitions.
     84     size_t len;
     85 } ostream_t;
     86 
        /// Input stream: a read cursor over the compressed data, with sub-byte
        /// granularity.
     87 typedef struct {
            // Read position in the input buffer
            // NOTE(review): presumably advanced by the IO_* read helpers -- confirm
     88     const u8 *ptr;
            // NOTE(review): presumably the number of bytes still readable at `ptr` --
            // confirm in the IO_* helper definitions
     89     size_t len;
     90 
     91     // Input often reads a few bits at a time, so maintain an internal offset
            // NOTE(review): presumably the bit position within the byte at `ptr`
     92     int bit_offset;
     93 } istream_t;
     94 
     95 /// The following three functions are the only ones that allow the istream to
     96 /// be non-byte aligned
     97 
     98 /// Reads `num_bits` bits from a bitstream, and updates the internal offset
     99 static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
    100 /// Backs up the stream by `num_bits` bits so they can be read again
    101 static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
    102 /// If the remaining bits in a byte will be unused, advance to the end of the
    103 /// byte
    104 static inline void IO_align_stream(istream_t *const in);
    105 
    106 /// Write the given byte into the output stream
    107 static inline void IO_write_byte(ostream_t *const out, u8 symb);
    108 
    109 /// Returns the number of bytes left to be read in this stream.  The stream must
    110 /// be byte aligned.
    111 static inline size_t IO_istream_len(const istream_t *const in);
    112 
    113 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
    114 /// was skipped.  The stream must be byte aligned.
    115 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
    116 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
    117 /// was skipped so it can be written to.
    118 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
    119 
    120 /// Advance the inner state by `len` bytes.  The stream must be byte aligned.
    121 static inline void IO_advance_input(istream_t *const in, size_t len);
    122 
    123 /// Returns an `ostream_t` constructed from the given pointer and length.
    124 static inline ostream_t IO_make_ostream(u8 *out, size_t len);
    125 /// Returns an `istream_t` constructed from the given pointer and length.
    126 static inline istream_t IO_make_istream(const u8 *in, size_t len);
    127 
    128 /// Returns an `istream_t` with the same base as `in`, and length `len`.
    129 /// Then, advance `in` to account for the consumed bytes.
    130 /// `in` must be byte aligned.
    131 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
    132 /*** END IO STREAM OPERATIONS *********/
    133 
    134 /*** BITSTREAM OPERATIONS *************/
    135 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits,
    136 /// and return them interpreted as a little-endian unsigned integer.
    137 static inline u64 read_bits_LE(const u8 *src, const int num_bits,
    138                                const size_t offset);
    139 
    140 /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
    141 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
    142 /// `src + offset`.  If the offset becomes negative, the extra bits at the
    143 /// bottom are filled in with `0` bits instead of reading from before `src`.
    144 static inline u64 STREAM_read_bits(const u8 *src, const int bits,
    145                                    i64 *const offset);
    146 /*** END BITSTREAM OPERATIONS *********/
    147 
    148 /*** BIT COUNTING OPERATIONS **********/
    149 /// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
    150 static inline int highest_set_bit(const u64 num);
    151 /*** END BIT COUNTING OPERATIONS ******/
    152 
    153 /*** HUFFMAN PRIMITIVES ***************/
    154 // Table decode method uses exponential memory, so we need to limit depth
    155 #define HUF_MAX_BITS (16)
    156 
    157 // Limit the maximum number of symbols to 256 so we can store a symbol in a byte
    158 #define HUF_MAX_SYMBS (256)
    159 
    160 /// Structure containing all tables necessary for efficient Huffman decoding
     161 typedef struct {
            // Decoded-symbol table; heap allocated (released by HUF_free_dtable)
            // NOTE(review): presumably indexed by the decoder state -- confirm in
            // HUF_init_dtable
     162     u8 *symbols;
            // Per-entry bit counts, parallel to `symbols`; heap allocated as well
            // NOTE(review): presumably the number of bits consumed per symbol
     163     u8 *num_bits;
            // NOTE(review): presumably the longest code length in the table, bounded
            // by HUF_MAX_BITS -- confirm in HUF_init_dtable
     164     int max_bits;
     165 } HUF_dtable;
    166 
    167 /// Decode a single symbol and read in enough bits to refresh the state
    168 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
    169                                    u16 *const state, const u8 *const src,
    170                                    i64 *const offset);
    171 /// Read in a full state's worth of bits to initialize it
    172 static inline void HUF_init_state(const HUF_dtable *const dtable,
    173                                   u16 *const state, const u8 *const src,
    174                                   i64 *const offset);
    175 
    176 /// Decompresses a single Huffman stream, returns the number of bytes decoded.
    177 /// `in` must span exactly the Huffman-coded block.
    178 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
    179                                      ostream_t *const out, istream_t *const in);
    180 /// Same as previous but decodes 4 streams, formatted as in the Zstandard
    181 /// specification.
    182 /// `in` must span exactly the Huffman-coded block.
    183 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
    184                                      ostream_t *const out, istream_t *const in);
    185 
    186 /// Initialize a Huffman decoding table using the table of bit counts provided
    187 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
    188                             const int num_symbs);
    189 /// Initialize a Huffman decoding table using the table of weights provided
    190 /// Weights follow the definition provided in the Zstandard specification
    191 static void HUF_init_dtable_usingweights(HUF_dtable *const table,
    192                                          const u8 *const weights,
    193                                          const int num_symbs);
    194 
    195 /// Free the malloc'ed parts of a decoding table
    196 static void HUF_free_dtable(HUF_dtable *const dtable);
    197 /*** END HUFFMAN PRIMITIVES ***********/
    198 
    199 /*** FSE PRIMITIVES *******************/
    200 /// For more description of FSE see
    201 /// https://github.com/Cyan4973/FiniteStateEntropy/
    202 
    203 // FSE table decoding uses exponential memory, so limit the maximum accuracy
    204 #define FSE_MAX_ACCURACY_LOG (15)
    205 // Limit the maximum number of symbols so they can be stored in a single byte
    206 #define FSE_MAX_SYMBS (256)
    207 
    208 /// The tables needed to decode FSE encoded streams
     209 typedef struct {
            // Symbol emitted for each state; heap allocated (released by
            // FSE_free_dtable)
     210     u8 *symbols;
            // NOTE(review): presumably the number of bits to read when leaving each
            // state -- confirm in FSE_update_state
     211     u8 *num_bits;
            // NOTE(review): presumably the base added to the bits read to form the
            // next state -- confirm in FSE_update_state
     212     u16 *new_state_base;
            // NOTE(review): presumably log2 of the table size, bounded by
            // FSE_MAX_ACCURACY_LOG -- confirm in FSE_init_dtable
     213     int accuracy_log;
     214 } FSE_dtable;
    215 
    216 /// Return the symbol for the current state
    217 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
    218                                  const u16 state);
    219 /// Read the number of bits necessary to update state, update, and shift offset
    220 /// back to reflect the bits read
    221 static inline void FSE_update_state(const FSE_dtable *const dtable,
    222                                     u16 *const state, const u8 *const src,
    223                                     i64 *const offset);
    224 
    225 /// Combine peek and update: decode a symbol and update the state
    226 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
    227                                    u16 *const state, const u8 *const src,
    228                                    i64 *const offset);
    229 
    230 /// Read bits from the stream to initialize the state and shift offset back
    231 static inline void FSE_init_state(const FSE_dtable *const dtable,
    232                                   u16 *const state, const u8 *const src,
    233                                   i64 *const offset);
    234 
    235 /// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
    236 /// using an FSE decoding table.  `in` must span exactly the FSE-coded
    237 /// block.
    238 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
    239                                           ostream_t *const out,
    240                                           istream_t *const in);
    241 
    242 /// Initialize a decoding table using normalized frequencies.
    243 static void FSE_init_dtable(FSE_dtable *const dtable,
    244                             const i16 *const norm_freqs, const int num_symbs,
    245                             const int accuracy_log);
    246 
    247 /// Decode an FSE header as defined in the Zstandard format specification and
    248 /// use the decoded frequencies to initialize a decoding table.
    249 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
    250                                 const int max_accuracy_log);
    251 
    252 /// Initialize an FSE table that will always return the same symbol and consume
    253 /// 0 bits per symbol, to be used for RLE mode in sequence commands
    254 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
    255 
    256 /// Free the malloc'ed parts of a decoding table
    257 static void FSE_free_dtable(FSE_dtable *const dtable);
    258 /*** END FSE PRIMITIVES ***************/
    259 
    260 /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
    261 
    262 /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
    263 
    264 /// A small structure that can be reused in various places that need to access
    265 /// frame header information
    266 typedef struct {
    267     // The size of window that we need to be able to contiguously store for
    268     // references.  parse_frame_header() derives it from the Window_Descriptor,
            // or copies frame_content_size when the frame is single-segment.
    269     size_t window_size;
    270     // The total output size of this compressed frame, or 0 when the header
            // carries no Frame_Content_Size field
    271     size_t frame_content_size;
    272 
    273     // The dictionary id if this frame uses one, otherwise 0
    274     u32 dictionary_id;
    275 
    276     // Whether or not the content of this frame has a checksum (0 or 1)
    277     int content_checksum_flag;
    278     // Whether or not the output for this frame is in a single segment (0 or 1)
    279     int single_segment_flag;
    280 } frame_header_t;
    281 
    282 /// The context needed to decode blocks in a frame
    283 typedef struct {
            // Parsed frame header, filled in by parse_frame_header()
    284     frame_header_t header;
    285 
    286     // The total amount of data available for backreferences, to determine if an
    287     // offset is too large to be correct
    288     size_t current_total_output;
    289 
            // Dictionary content and its length; zeroed by init_frame_context()
            // NOTE(review): presumably populated by frame_context_apply_dict() --
            // confirm in its definition
    290     const u8 *dict_content;
    291     size_t dict_content_len;
    292 
    293     // Entropy decoding tables, kept so they can be reused by future blocks
    294     // instead of being retransmitted
            // NOTE(review): ll/ml/of presumably stand for literal length, match
            // length and offset
    295     HUF_dtable literals_dtable;
    296     FSE_dtable ll_dtable;
    297     FSE_dtable ml_dtable;
    298     FSE_dtable of_dtable;
    299 
    300     // The last 3 offsets for the special "repeat offsets".
            // init_frame_context() seeds them with 1, 4 and 8.
    301     u64 previous_offsets[3];
    302 } frame_context_t;
    303 
    304 /// The decoded contents of a dictionary so that it doesn't have to be repeated
    305 /// for each frame that uses it
        // NOTE(review): presumably exposed to callers as `dictionary_t` through
        // zstd_decompress.h -- confirm in the header
    306 struct dictionary_s {
    307     // Entropy tables
    308     HUF_dtable literals_dtable;
    309     FSE_dtable ll_dtable;
    310     FSE_dtable ml_dtable;
    311     FSE_dtable of_dtable;
    312 
    313     // Raw content for backreferences, and its length in bytes
    314     u8 *content;
    315     size_t content_size;
    316 
    317     // Offset history to prepopulate the frame's history
    318     u64 previous_offsets[3];
    319 
            // Identifier of this dictionary
            // NOTE(review): presumably matched against the frame header's
            // dictionary_id -- confirm in frame_context_apply_dict
    320     u32 dictionary_id;
    321 };
    322 
    323 /// A tuple containing the parts necessary to decode and execute a ZSTD sequence
    324 /// command
    325 typedef struct {
            // NOTE(review): presumably the number of literal bytes to copy before the
            // match -- confirm in execute_sequences
    326     u32 literal_length;
            // NOTE(review): presumably the number of bytes to copy from the match
    327     u32 match_length;
            // Offset code: either an actual offset value or an index into the
            // previous-offset history (resolved by compute_offset)
    328     u32 offset;
    329 } sequence_command_t;
    330 
    331 /// The decoder works top-down, starting at the high level like Zstd frames, and
    332 /// working down to lower more technical levels such as blocks, literals, and
    333 /// sequences.  The high-level functions roughly follow the outline of the
    334 /// format specification:
    335 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
    336 
    337 /// Before the implementation of each high-level function declared here, the
    338 /// prototypes for their helper functions are defined and explained
    339 
    340 /// Decode a single Zstd frame, or error if the input is not a valid frame.
    341 /// Accepts a dict argument, which may be NULL indicating no dictionary.
    342 /// See
    343 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
    344 static void decode_frame(ostream_t *const out, istream_t *const in,
    345                          const dictionary_t *const dict);
    346 
    347 // Decode data in a compressed block
    348 static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
    349                              istream_t *const in);
    350 
    351 // Decode the literals section of a block
    352 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
    353                               u8 **const literals);
    354 
    355 // Decode the sequences part of a block
    356 static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
    357                                sequence_command_t **const sequences);
    358 
    359 // Execute the decoded sequences on the literals block
    360 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
    361                               const u8 *const literals,
    362                               const size_t literals_len,
    363                               const sequence_command_t *const sequences,
    364                               const size_t num_sequences);
    365 
    366 // Copies literals and returns the total literal length that was copied
    367 static u32 copy_literals(const size_t seq, istream_t *litstream,
    368                          ostream_t *const out);
    369 
    370 // Given an offset code from a sequence command (either an actual offset value
    371 // or an index for previous offset), computes the correct offset and updates
    372 // the offset history
    373 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
    374 
    375 // Given an offset, match length, and total output, as well as the frame
    376 // context for the dictionary, determines if the dictionary is used and
    377 // executes the copy operation
    378 static void execute_match_copy(frame_context_t *const ctx, size_t offset,
    379                               size_t match_length, size_t total_output,
    380                               ostream_t *const out);
    381 
    382 /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
    383 
    384 size_t ZSTD_decompress(void *const dst, const size_t dst_len,
    385                        const void *const src, const size_t src_len) {
    386     dictionary_t* const uninit_dict = create_dictionary();
    387     size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
    388                                                          src_len, uninit_dict);
    389     free_dictionary(uninit_dict);
    390     return decomp_size;
    391 }
    392 
    393 size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
    394                                  const void *const src, const size_t src_len,
    395                                  dictionary_t* parsed_dict) {
    396 
    397     istream_t in = IO_make_istream(src, src_len);
    398     ostream_t out = IO_make_ostream(dst, dst_len);
    399 
    400     // "A content compressed by Zstandard is transformed into a Zstandard frame.
    401     // Multiple frames can be appended into a single file or stream. A frame is
    402     // totally independent, has a defined beginning and end, and a set of
    403     // parameters which tells the decoder how to decompress it."
    404 
    405     /* this decoder assumes decompression of a single frame */
    406     decode_frame(&out, &in, parsed_dict);
    407 
    408     return (size_t)(out.ptr - (u8 *)dst);
    409 }
    410 
    411 /******* FRAME DECODING ******************************************************/
    412 
    413 static void decode_data_frame(ostream_t *const out, istream_t *const in,
    414                               const dictionary_t *const dict);
    415 static void init_frame_context(frame_context_t *const context,
    416                                istream_t *const in,
    417                                const dictionary_t *const dict);
    418 static void free_frame_context(frame_context_t *const context);
    419 static void parse_frame_header(frame_header_t *const header,
    420                                istream_t *const in);
    421 static void frame_context_apply_dict(frame_context_t *const ctx,
    422                                      const dictionary_t *const dict);
    423 
    424 static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
    425                             istream_t *const in);
    426 
    427 static void decode_frame(ostream_t *const out, istream_t *const in,
    428                          const dictionary_t *const dict) {
    429     const u32 magic_number = (u32)IO_read_bits(in, 32);
    430     if (magic_number == ZSTD_MAGIC_NUMBER) {
    431         // ZSTD frame
    432         decode_data_frame(out, in, dict);
    433 
    434         return;
    435     }
    436 
    437     // not a real frame or a skippable frame
    438     ERROR("Tried to decode non-ZSTD frame");
    439 }
    440 
    441 /// Decode a frame that contains compressed data.  Not all frames do as there
    442 /// are skippable frames.
    443 /// See
    444 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
    445 static void decode_data_frame(ostream_t *const out, istream_t *const in,
    446                               const dictionary_t *const dict) {
    447     frame_context_t ctx;
    448 
    449     // Initialize the context that needs to be carried from block to block
    450     init_frame_context(&ctx, in, dict);
    451 
    452     if (ctx.header.frame_content_size != 0 &&
    453         ctx.header.frame_content_size > out->len) {
    454         OUT_SIZE();
    455     }
    456 
    457     decompress_data(&ctx, out, in);
    458 
    459     free_frame_context(&ctx);
    460 }
    461 
    462 /// Takes the information provided in the header and dictionary, and initializes
    463 /// the context for this frame
    464 static void init_frame_context(frame_context_t *const context,
    465                                istream_t *const in,
    466                                const dictionary_t *const dict) {
    467     // Most fields in context are correct when initialized to 0
    468     memset(context, 0, sizeof(frame_context_t));
    469 
    470     // Parse data from the frame header
    471     parse_frame_header(&context->header, in);
    472 
    473     // Set up the offset history for the repeat offset commands
    474     context->previous_offsets[0] = 1;
    475     context->previous_offsets[1] = 4;
    476     context->previous_offsets[2] = 8;
    477 
    478     // Apply details from the dict if it exists
    479     frame_context_apply_dict(context, dict);
    480 }
    481 
    482 static void free_frame_context(frame_context_t *const context) {
    483     HUF_free_dtable(&context->literals_dtable);
    484 
    485     FSE_free_dtable(&context->ll_dtable);
    486     FSE_free_dtable(&context->ml_dtable);
    487     FSE_free_dtable(&context->of_dtable);
    488 
    489     memset(context, 0, sizeof(frame_context_t));
    490 }
    491 
    492 static void parse_frame_header(frame_header_t *const header,
    493                                istream_t *const in) {
    494     // "The first header's byte is called the Frame_Header_Descriptor. It tells
    495     // which other fields are present. Decoding this byte is enough to tell the
    496     // size of Frame_Header.
    497     //
    498     // Bit number   Field name
    499     // 7-6  Frame_Content_Size_flag
    500     // 5    Single_Segment_flag
    501     // 4    Unused_bit
    502     // 3    Reserved_bit
    503     // 2    Content_Checksum_flag
    504     // 1-0  Dictionary_ID_flag"
    505     const u8 descriptor = (u8)IO_read_bits(in, 8);
    506 
    507     // decode frame header descriptor into flags
    508     const u8 frame_content_size_flag = descriptor >> 6;
    509     const u8 single_segment_flag = (descriptor >> 5) & 1;
    510     const u8 reserved_bit = (descriptor >> 3) & 1;
    511     const u8 content_checksum_flag = (descriptor >> 2) & 1;
    512     const u8 dictionary_id_flag = descriptor & 3;
    513 
    514     if (reserved_bit != 0) {
    515         CORRUPTION();
    516     }
    517 
    518     header->single_segment_flag = single_segment_flag;
    519     header->content_checksum_flag = content_checksum_flag;
    520 
    521     // decode window size
    522     if (!single_segment_flag) {
    523         // "Provides guarantees on maximum back-reference distance that will be
    524         // used within compressed data. This information is important for
    525         // decoders to allocate enough memory.
    526         //
    527         // Bit numbers  7-3         2-0
    528         // Field name   Exponent    Mantissa"
    529         u8 window_descriptor = (u8)IO_read_bits(in, 8);
    530         u8 exponent = window_descriptor >> 3;
    531         u8 mantissa = window_descriptor & 7;
    532 
    533         // Use the algorithm from the specification to compute window size
    534         // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
    535         size_t window_base = (size_t)1 << (10 + exponent);
    536         size_t window_add = (window_base / 8) * mantissa;
    537         header->window_size = window_base + window_add;
    538     }
    539 
    540     // decode dictionary id if it exists
    541     if (dictionary_id_flag) {
    542         // "This is a variable size field, which contains the ID of the
    543         // dictionary required to properly decode the frame. Note that this
    544         // field is optional. When it's not present, it's up to the caller to
    545         // make sure it uses the correct dictionary. Format is little-endian."
    546         const int bytes_array[] = {0, 1, 2, 4};
    547         const int bytes = bytes_array[dictionary_id_flag];
    548 
    549         header->dictionary_id = (u32)IO_read_bits(in, bytes * 8);
    550     } else {
    551         header->dictionary_id = 0;
    552     }
    553 
    554     // decode frame content size if it exists
    555     if (single_segment_flag || frame_content_size_flag) {
    556         // "This is the original (uncompressed) size. This information is
    557         // optional. The Field_Size is provided according to value of
    558         // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
    559         // present), 1, 2, 4 or 8 bytes. Format is little-endian."
    560         //
    561         // if frame_content_size_flag == 0 but single_segment_flag is set, we
    562         // still have a 1 byte field
    563         const int bytes_array[] = {1, 2, 4, 8};
    564         const int bytes = bytes_array[frame_content_size_flag];
    565 
    566         header->frame_content_size = IO_read_bits(in, bytes * 8);
    567         if (bytes == 2) {
    568             // "When Field_Size is 2, the offset of 256 is added."
    569             header->frame_content_size += 256;
    570         }
    571     } else {
    572         header->frame_content_size = 0;
    573     }
    574 
    575     if (single_segment_flag) {
    576         // "The Window_Descriptor byte is optional. It is absent when
    577         // Single_Segment_flag is set. In this case, the maximum back-reference
    578         // distance is the content size itself, which can be any value from 1 to
    579         // 2^64-1 bytes (16 EB)."
    580         header->window_size = header->frame_content_size;
    581     }
    582 }
    583 
    584 /// Decompress the data from a frame block by block
    585 static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
    586                             istream_t *const in) {
    587     // "A frame encapsulates one or multiple blocks. Each block can be
    588     // compressed or not, and has a guaranteed maximum content size, which
    589     // depends on frame parameters. Unlike frames, each block depends on
    590     // previous blocks for proper decoding. However, each block can be
    591     // decompressed without waiting for its successor, allowing streaming
    592     // operations."
    593     int last_block = 0;
    594     do {
    595         // "Last_Block
    596         //
    597         // The lowest bit signals if this block is the last one. Frame ends
    598         // right after this block.
    599         //
    600         // Block_Type and Block_Size
    601         //
    602         // The next 2 bits represent the Block_Type, while the remaining 21 bits
    603         // represent the Block_Size. Format is little-endian."
    604         last_block = (int)IO_read_bits(in, 1);
    605         const int block_type = (int)IO_read_bits(in, 2);
    606         const size_t block_len = IO_read_bits(in, 21);
    607 
    608         switch (block_type) {
    609         case 0: {
    610             // "Raw_Block - this is an uncompressed block. Block_Size is the
    611             // number of bytes to read and copy."
    612             const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
    613             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
    614 
    615             // Copy the raw data into the output
    616             memcpy(write_ptr, read_ptr, block_len);
    617 
    618             ctx->current_total_output += block_len;
    619             break;
    620         }
    621         case 1: {
    622             // "RLE_Block - this is a single byte, repeated N times. In which
    623             // case, Block_Size is the size to regenerate, while the
    624             // "compressed" block is just 1 byte (the byte to repeat)."
    625             const u8 *const read_ptr = IO_get_read_ptr(in, 1);
    626             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
    627 
    628             // Copy `block_len` copies of `read_ptr[0]` to the output
    629             memset(write_ptr, read_ptr[0], block_len);
    630 
    631             ctx->current_total_output += block_len;
    632             break;
    633         }
    634         case 2: {
    635             // "Compressed_Block - this is a Zstandard compressed block,
    636             // detailed in another section of this specification. Block_Size is
    637             // the compressed size.
    638 
    639             // Create a sub-stream for the block
    640             istream_t block_stream = IO_make_sub_istream(in, block_len);
    641             decompress_block(ctx, out, &block_stream);
    642             break;
    643         }
    644         case 3:
    645             // "Reserved - this is not a block. This value cannot be used with
    646             // current version of this specification."
    647             CORRUPTION();
    648             break;
    649         default:
    650             IMPOSSIBLE();
    651         }
    652     } while (!last_block);
    653 
    654     if (ctx->header.content_checksum_flag) {
    655         // This program does not support checking the checksum, so skip over it
    656         // if it's present
    657         IO_advance_input(in, 4);
    658     }
    659 }
    660 /******* END FRAME DECODING ***************************************************/
    661 
    662 /******* BLOCK DECOMPRESSION **************************************************/
    663 static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
    664                              istream_t *const in) {
    665     // "A compressed block consists of 2 sections :
    666     //
    667     // Literals_Section
    668     // Sequences_Section"
    669 
    670 
    671     // Part 1: decode the literals block
    672     u8 *literals = NULL;
    673     const size_t literals_size = decode_literals(ctx, in, &literals);
    674 
    675     // Part 2: decode the sequences block
    676     sequence_command_t *sequences = NULL;
    677     const size_t num_sequences =
    678         decode_sequences(ctx, in, &sequences);
    679 
    680     // Part 3: combine literals and sequence commands to generate output
    681     execute_sequences(ctx, out, literals, literals_size, sequences,
    682                       num_sequences);
    683     free(literals);
    684     free(sequences);
    685 }
    686 /******* END BLOCK DECOMPRESSION **********************************************/
    687 
    688 /******* LITERALS DECODING ****************************************************/
    689 static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
    690                                      const int block_type,
    691                                      const int size_format);
    692 static size_t decode_literals_compressed(frame_context_t *const ctx,
    693                                          istream_t *const in,
    694                                          u8 **const literals,
    695                                          const int block_type,
    696                                          const int size_format);
    697 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
    698 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
    699                                     int *const num_symbs);
    700 
    701 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
    702                               u8 **const literals) {
    703     // "Literals can be stored uncompressed or compressed using Huffman prefix
    704     // codes. When compressed, an optional tree description can be present,
    705     // followed by 1 or 4 streams."
    706     //
    707     // "Literals_Section_Header
    708     //
    709     // Header is in charge of describing how literals are packed. It's a
    710     // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
    711     // little-endian convention."
    712     //
    713     // "Literals_Block_Type
    714     //
    715     // This field uses 2 lowest bits of first byte, describing 4 different block
    716     // types"
    717     //
    718     // size_format takes between 1 and 2 bits
    719     int block_type = (int)IO_read_bits(in, 2);
    720     int size_format = (int)IO_read_bits(in, 2);
    721 
    722     if (block_type <= 1) {
    723         // Raw or RLE literals block
    724         return decode_literals_simple(in, literals, block_type,
    725                                       size_format);
    726     } else {
    727         // Huffman compressed literals
    728         return decode_literals_compressed(ctx, in, literals, block_type,
    729                                           size_format);
    730     }
    731 }
    732 
    733 /// Decodes literals blocks in raw or RLE form
    734 static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
    735                                      const int block_type,
    736                                      const int size_format) {
    737     size_t size;
    738     switch (size_format) {
    739     // These cases are in the form ?0
    740     // In this case, the ? bit is actually part of the size field
    741     case 0:
    742     case 2:
    743         // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
    744         IO_rewind_bits(in, 1);
    745         size = IO_read_bits(in, 5);
    746         break;
    747     case 1:
    748         // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
    749         size = IO_read_bits(in, 12);
    750         break;
    751     case 3:
    752         // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
    753         size = IO_read_bits(in, 20);
    754         break;
    755     default:
    756         // Size format is in range 0-3
    757         IMPOSSIBLE();
    758     }
    759 
    760     if (size > MAX_LITERALS_SIZE) {
    761         CORRUPTION();
    762     }
    763 
    764     *literals = malloc(size);
    765     if (!*literals) {
    766         BAD_ALLOC();
    767     }
    768 
    769     switch (block_type) {
    770     case 0: {
    771         // "Raw_Literals_Block - Literals are stored uncompressed."
    772         const u8 *const read_ptr = IO_get_read_ptr(in, size);
    773         memcpy(*literals, read_ptr, size);
    774         break;
    775     }
    776     case 1: {
    777         // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
    778         const u8 *const read_ptr = IO_get_read_ptr(in, 1);
    779         memset(*literals, read_ptr[0], size);
    780         break;
    781     }
    782     default:
    783         IMPOSSIBLE();
    784     }
    785 
    786     return size;
    787 }
    788 
    789 /// Decodes Huffman compressed literals
    790 static size_t decode_literals_compressed(frame_context_t *const ctx,
    791                                          istream_t *const in,
    792                                          u8 **const literals,
    793                                          const int block_type,
    794                                          const int size_format) {
    795     size_t regenerated_size, compressed_size;
    796     // Only size_format=0 has 1 stream, so default to 4
    797     int num_streams = 4;
    798     switch (size_format) {
    799     case 0:
    800         // "A single stream. Both Compressed_Size and Regenerated_Size use 10
    801         // bits (0-1023)."
    802         num_streams = 1;
    803     // Fall through as it has the same size format
    804         /* fallthrough */
    805     case 1:
    806         // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
    807         // (0-1023)."
    808         regenerated_size = IO_read_bits(in, 10);
    809         compressed_size = IO_read_bits(in, 10);
    810         break;
    811     case 2:
    812         // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
    813         // (0-16383)."
    814         regenerated_size = IO_read_bits(in, 14);
    815         compressed_size = IO_read_bits(in, 14);
    816         break;
    817     case 3:
    818         // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
    819         // (0-262143)."
    820         regenerated_size = IO_read_bits(in, 18);
    821         compressed_size = IO_read_bits(in, 18);
    822         break;
    823     default:
    824         // Impossible
    825         IMPOSSIBLE();
    826     }
    827     if (regenerated_size > MAX_LITERALS_SIZE) {
    828         CORRUPTION();
    829     }
    830 
    831     *literals = malloc(regenerated_size);
    832     if (!*literals) {
    833         BAD_ALLOC();
    834     }
    835 
    836     ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
    837     istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
    838 
    839     if (block_type == 2) {
    840         // Decode the provided Huffman table
    841         // "This section is only present when Literals_Block_Type type is
    842         // Compressed_Literals_Block (2)."
    843 
    844         HUF_free_dtable(&ctx->literals_dtable);
    845         decode_huf_table(&ctx->literals_dtable, &huf_stream);
    846     } else {
    847         // If the previous Huffman table is being repeated, ensure it exists
    848         if (!ctx->literals_dtable.symbols) {
    849             CORRUPTION();
    850         }
    851     }
    852 
    853     size_t symbols_decoded;
    854     if (num_streams == 1) {
    855         symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
    856     } else {
    857         symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
    858     }
    859 
    860     if (symbols_decoded != regenerated_size) {
    861         CORRUPTION();
    862     }
    863 
    864     return regenerated_size;
    865 }
    866 
    867 // Decode the Huffman table description
    868 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
    869     // "All literal values from zero (included) to last present one (excluded)
    870     // are represented by Weight with values from 0 to Max_Number_of_Bits."
    871 
    872     // "This is a single byte value (0-255), which describes how to decode the list of weights."
    873     const u8 header = IO_read_bits(in, 8);
    874 
    875     u8 weights[HUF_MAX_SYMBS];
    876     memset(weights, 0, sizeof(weights));
    877 
    878     int num_symbs;
    879 
    880     if (header >= 128) {
    881         // "This is a direct representation, where each Weight is written
    882         // directly as a 4 bits field (0-15). The full representation occupies
    883         // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
    884         // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
    885         // 127"
    886         num_symbs = header - 127;
    887         const size_t bytes = (num_symbs + 1) / 2;
    888 
    889         const u8 *const weight_src = IO_get_read_ptr(in, bytes);
    890 
    891         for (int i = 0; i < num_symbs; i++) {
    892             // "They are encoded forward, 2
    893             // weights to a byte with the first weight taking the top four bits
    894             // and the second taking the bottom four (e.g. the following
    895             // operations could be used to read the weights: Weight[0] =
    896             // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
    897             if (i % 2 == 0) {
    898                 weights[i] = weight_src[i / 2] >> 4;
    899             } else {
    900                 weights[i] = weight_src[i / 2] & 0xf;
    901             }
    902         }
    903     } else {
    904         // The weights are FSE encoded, decode them before we can construct the
    905         // table
    906         istream_t fse_stream = IO_make_sub_istream(in, header);
    907         ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
    908         fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
    909     }
    910 
    911     // Construct the table using the decoded weights
    912     HUF_init_dtable_usingweights(dtable, weights, num_symbs);
    913 }
    914 
    915 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
    916                                     int *const num_symbs) {
    917     const int MAX_ACCURACY_LOG = 7;
    918 
    919     FSE_dtable dtable;
    920 
    921     // "An FSE bitstream starts by a header, describing probabilities
    922     // distribution. It will create a Decoding Table. For a list of Huffman
    923     // weights, maximum accuracy is 7 bits."
    924     FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
    925 
    926     // Decode the weights
    927     *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
    928 
    929     FSE_free_dtable(&dtable);
    930 }
    931 /******* END LITERALS DECODING ************************************************/
    932 
    933 /******* SEQUENCE DECODING ****************************************************/
    934 /// The combination of FSE states needed to decode sequences
    935 typedef struct {
    936     FSE_dtable ll_table;
    937     FSE_dtable of_table;
    938     FSE_dtable ml_table;
    939 
    940     u16 ll_state;
    941     u16 of_state;
    942     u16 ml_state;
    943 } sequence_states_t;
    944 
    945 /// Different modes to signal to decode_seq_tables what to do
    946 typedef enum {
    947     seq_literal_length = 0,
    948     seq_offset = 1,
    949     seq_match_length = 2,
    950 } seq_part_t;
    951 
    952 typedef enum {
    953     seq_predefined = 0,
    954     seq_rle = 1,
    955     seq_fse = 2,
    956     seq_repeat = 3,
    957 } seq_mode_t;
    958 
    959 /// The predefined FSE distribution tables for `seq_predefined` mode
    960 static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
    961     4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,  1,  2,  2,
    962     2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
    963 static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
    964     1, 1, 1, 1, 1, 1, 2, 2, 2, 1,  1,  1,  1,  1, 1,
    965     1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
    966 static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
    967     1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1,  1,  1,  1,  1,  1,  1, 1,
    968     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 1,
    969     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
    970 
    971 /// The sequence decoding baseline and number of additional bits to read/add
    972 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
    973 static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
    974     0,  1,  2,   3,   4,   5,    6,    7,    8,    9,     10,    11,
    975     12, 13, 14,  15,  16,  18,   20,   22,   24,   28,    32,    40,
    976     48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
    977 static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
    978     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  1,  1,
    979     1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    980 
    981 static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
    982     3,  4,   5,   6,   7,    8,    9,    10,   11,    12,    13,   14, 15, 16,
    983     17, 18,  19,  20,  21,   22,   23,   24,   25,    26,    27,   28, 29, 30,
    984     31, 32,  33,  34,  35,   37,   39,   41,   43,    47,    51,   59, 67, 83,
    985     99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
    986 static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
    987     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,
    988     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  1,  1,  1, 1,
    989     2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    990 
    991 /// Offset decoding is simpler so we just need a maximum code value
    992 static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52};
    993 
    994 static void decompress_sequences(frame_context_t *const ctx,
    995                                  istream_t *const in,
    996                                  sequence_command_t *const sequences,
    997                                  const size_t num_sequences);
    998 static sequence_command_t decode_sequence(sequence_states_t *const state,
    999                                           const u8 *const src,
   1000                                           i64 *const offset,
   1001                                           int lastSequence);
   1002 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
   1003                                const seq_part_t type, const seq_mode_t mode);
   1004 
   1005 static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
   1006                                sequence_command_t **const sequences) {
   1007     // "A compressed block is a succession of sequences . A sequence is a
   1008     // literal copy command, followed by a match copy command. A literal copy
   1009     // command specifies a length. It is the number of bytes to be copied (or
   1010     // extracted) from the literal section. A match copy command specifies an
   1011     // offset and a length. The offset gives the position to copy from, which
   1012     // can be within a previous block."
   1013 
   1014     size_t num_sequences;
   1015 
   1016     // "Number_of_Sequences
   1017     //
   1018     // This is a variable size field using between 1 and 3 bytes. Let's call its
   1019     // first byte byte0."
   1020     u8 header = IO_read_bits(in, 8);
   1021     if (header < 128) {
   1022         // "Number_of_Sequences = byte0 . Uses 1 byte."
   1023         num_sequences = header;
   1024     } else if (header < 255) {
   1025         // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
   1026         num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
   1027     } else {
   1028         // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
   1029         num_sequences = IO_read_bits(in, 16) + 0x7F00;
   1030     }
   1031 
   1032     if (num_sequences == 0) {
   1033         // "There are no sequences. The sequence section stops there."
   1034         *sequences = NULL;
   1035         return 0;
   1036     }
   1037 
   1038     *sequences = malloc(num_sequences * sizeof(sequence_command_t));
   1039     if (!*sequences) {
   1040         BAD_ALLOC();
   1041     }
   1042 
   1043     decompress_sequences(ctx, in, *sequences, num_sequences);
   1044     return num_sequences;
   1045 }
   1046 
   1047 /// Decompress the FSE encoded sequence commands
   1048 static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
   1049                                  sequence_command_t *const sequences,
   1050                                  const size_t num_sequences) {
   1051     // "The Sequences_Section regroup all symbols required to decode commands.
   1052     // There are 3 symbol types : literals lengths, offsets and match lengths.
   1053     // They are encoded together, interleaved, in a single bitstream."
   1054 
   1055     // "Symbol compression modes
   1056     //
   1057     // This is a single byte, defining the compression mode of each symbol
   1058     // type."
   1059     //
   1060     // Bit number : Field name
   1061     // 7-6        : Literals_Lengths_Mode
   1062     // 5-4        : Offsets_Mode
   1063     // 3-2        : Match_Lengths_Mode
   1064     // 1-0        : Reserved
   1065     u8 compression_modes = IO_read_bits(in, 8);
   1066 
   1067     if ((compression_modes & 3) != 0) {
   1068         // Reserved bits set
   1069         CORRUPTION();
   1070     }
   1071 
   1072     // "Following the header, up to 3 distribution tables can be described. When
   1073     // present, they are in this order :
   1074     //
   1075     // Literals lengths
   1076     // Offsets
   1077     // Match Lengths"
   1078     // Update the tables we have stored in the context
   1079     decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
   1080                      (compression_modes >> 6) & 3);
   1081 
   1082     decode_seq_table(&ctx->of_dtable, in, seq_offset,
   1083                      (compression_modes >> 4) & 3);
   1084 
   1085     decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
   1086                      (compression_modes >> 2) & 3);
   1087 
   1088 
   1089     sequence_states_t states;
   1090 
   1091     // Initialize the decoding tables
   1092     {
   1093         states.ll_table = ctx->ll_dtable;
   1094         states.of_table = ctx->of_dtable;
   1095         states.ml_table = ctx->ml_dtable;
   1096     }
   1097 
   1098     const size_t len = IO_istream_len(in);
   1099     const u8 *const src = IO_get_read_ptr(in, len);
   1100 
   1101     // "After writing the last bit containing information, the compressor writes
   1102     // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
   1103     const int padding = 8 - highest_set_bit(src[len - 1]);
   1104     // The offset starts at the end because FSE streams are read backwards
   1105     i64 bit_offset = (i64)(len * 8 - (size_t)padding);
   1106 
   1107     // "The bitstream starts with initial state values, each using the required
   1108     // number of bits in their respective accuracy, decoded previously from
   1109     // their normalized distribution.
   1110     //
   1111     // It starts by Literals_Length_State, followed by Offset_State, and finally
   1112     // Match_Length_State."
   1113     FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
   1114     FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
   1115     FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
   1116 
   1117     for (size_t i = 0; i < num_sequences; i++) {
   1118         // Decode sequences one by one
   1119         sequences[i] = decode_sequence(&states, src, &bit_offset, i==num_sequences-1);
   1120     }
   1121 
   1122     if (bit_offset != 0) {
   1123         CORRUPTION();
   1124     }
   1125 }
   1126 
   1127 // Decode a single sequence and update the state
   1128 static sequence_command_t decode_sequence(sequence_states_t *const states,
   1129                                           const u8 *const src,
   1130                                           i64 *const offset,
   1131                                           int lastSequence) {
   1132     // "Each symbol is a code in its own context, which specifies Baseline and
   1133     // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
   1134     // additional bits in the same bitstream."
   1135 
   1136     // Decode symbols, but don't update states
   1137     const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
   1138     const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
   1139     const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
   1140 
   1141     // Offset doesn't need a max value as it's not decoded using a table
   1142     if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
   1143         ml_code > SEQ_MAX_CODES[seq_match_length]) {
   1144         CORRUPTION();
   1145     }
   1146 
   1147     // Read the interleaved bits
   1148     sequence_command_t seq;
   1149     // "Decoding starts by reading the Number_of_Bits required to decode Offset.
   1150     // It then does the same for Match_Length, and then for Literals_Length."
   1151     seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
   1152 
   1153     seq.match_length =
   1154         SEQ_MATCH_LENGTH_BASELINES[ml_code] +
   1155         STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
   1156 
   1157     seq.literal_length =
   1158         SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
   1159         STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
   1160 
   1161     // "If it is not the last sequence in the block, the next operation is to
   1162     // update states. Using the rules pre-calculated in the decoding tables,
   1163     // Literals_Length_State is updated, followed by Match_Length_State, and
   1164     // then Offset_State."
   1165     // If the stream is complete don't read bits to update state
   1166     if (!lastSequence) {
   1167         FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
   1168         FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
   1169         FSE_update_state(&states->of_table, &states->of_state, src, offset);
   1170     }
   1171 
   1172     return seq;
   1173 }
   1174 
   1175 /// Given a sequence part and table mode, decode the FSE distribution
   1176 /// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
   1177 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
   1178                              const seq_part_t type, const seq_mode_t mode) {
   1179     // Constant arrays indexed by seq_part_t
   1180     const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
   1181                                                 SEQ_OFFSET_DEFAULT_DIST,
   1182                                                 SEQ_MATCH_LENGTH_DEFAULT_DIST};
   1183     const size_t default_distribution_lengths[] = {36, 29, 53};
   1184     const size_t default_distribution_accuracies[] = {6, 5, 6};
   1185 
   1186     const size_t max_accuracies[] = {9, 8, 9};
   1187 
   1188     if (mode != seq_repeat) {
   1189         // Free old one before overwriting
   1190         FSE_free_dtable(table);
   1191     }
   1192 
   1193     switch (mode) {
   1194     case seq_predefined: {
   1195         // "Predefined_Mode : uses a predefined distribution table."
   1196         const i16 *distribution = default_distributions[type];
   1197         const size_t symbs = default_distribution_lengths[type];
   1198         const size_t accuracy_log = default_distribution_accuracies[type];
   1199 
   1200         FSE_init_dtable(table, distribution, symbs, accuracy_log);
   1201         break;
   1202     }
   1203     case seq_rle: {
   1204         // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
   1205         const u8 symb = IO_get_read_ptr(in, 1)[0];
   1206         FSE_init_dtable_rle(table, symb);
   1207         break;
   1208     }
   1209     case seq_fse: {
   1210         // "FSE_Compressed_Mode : standard FSE compression. A distribution table
   1211         // will be present "
   1212         FSE_decode_header(table, in, max_accuracies[type]);
   1213         break;
   1214     }
   1215     case seq_repeat:
   1216         // "Repeat_Mode : reuse distribution table from previous compressed
   1217         // block."
   1218         // Nothing to do here, table will be unchanged
   1219         if (!table->symbols) {
   1220             // This mode is invalid if we don't already have a table
   1221             CORRUPTION();
   1222         }
   1223         break;
   1224     default:
   1225         // Impossible, as mode is from 0-3
   1226         IMPOSSIBLE();
   1227         break;
   1228     }
   1229 
   1230 }
   1231 /******* END SEQUENCE DECODING ************************************************/
   1232 
   1233 /******* SEQUENCE EXECUTION ***************************************************/
   1234 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
   1235                               const u8 *const literals,
   1236                               const size_t literals_len,
   1237                               const sequence_command_t *const sequences,
   1238                               const size_t num_sequences) {
   1239     istream_t litstream = IO_make_istream(literals, literals_len);
   1240 
   1241     u64 *const offset_hist = ctx->previous_offsets;
   1242     size_t total_output = ctx->current_total_output;
   1243 
   1244     for (size_t i = 0; i < num_sequences; i++) {
   1245         const sequence_command_t seq = sequences[i];
   1246         {
   1247             const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
   1248             total_output += literals_size;
   1249         }
   1250 
   1251         size_t const offset = compute_offset(seq, offset_hist);
   1252 
   1253         size_t const match_length = seq.match_length;
   1254 
   1255         execute_match_copy(ctx, offset, match_length, total_output, out);
   1256 
   1257         total_output += match_length;
   1258     }
   1259 
   1260     // Copy any leftover literals
   1261     {
   1262         size_t len = IO_istream_len(&litstream);
   1263         copy_literals(len, &litstream, out);
   1264         total_output += len;
   1265     }
   1266 
   1267     ctx->current_total_output = total_output;
   1268 }
   1269 
   1270 static u32 copy_literals(const size_t literal_length, istream_t *litstream,
   1271                          ostream_t *const out) {
   1272     // If the sequence asks for more literals than are left, the
   1273     // sequence must be corrupted
   1274     if (literal_length > IO_istream_len(litstream)) {
   1275         CORRUPTION();
   1276     }
   1277 
   1278     u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
   1279     const u8 *const read_ptr =
   1280          IO_get_read_ptr(litstream, literal_length);
   1281     // Copy literals to output
   1282     memcpy(write_ptr, read_ptr, literal_length);
   1283 
   1284     return literal_length;
   1285 }
   1286 
   1287 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
   1288     size_t offset;
   1289     // Offsets are special, we need to handle the repeat offsets
   1290     if (seq.offset <= 3) {
   1291         // "The first 3 values define a repeated offset and we will call
   1292         // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
   1293         // They are sorted in recency order, with Repeated_Offset1 meaning
   1294         // 'most recent one'".
   1295 
   1296         // Use 0 indexing for the array
   1297         u32 idx = seq.offset - 1;
   1298         if (seq.literal_length == 0) {
   1299             // "There is an exception though, when current sequence's
   1300             // literals length is 0. In this case, repeated offsets are
   1301             // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
   1302             // Repeated_Offset2 becomes Repeated_Offset3, and
   1303             // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
   1304             idx++;
   1305         }
   1306 
   1307         if (idx == 0) {
   1308             offset = offset_hist[0];
   1309         } else {
   1310             // If idx == 3 then literal length was 0 and the offset was 3,
   1311             // as per the exception listed above
   1312             offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
   1313 
   1314             // If idx == 1 we don't need to modify offset_hist[2], since
   1315             // we're using the second-most recent code
   1316             if (idx > 1) {
   1317                 offset_hist[2] = offset_hist[1];
   1318             }
   1319             offset_hist[1] = offset_hist[0];
   1320             offset_hist[0] = offset;
   1321         }
   1322     } else {
   1323         // When it's not a repeat offset:
   1324         // "if (Offset_Value > 3) offset = Offset_Value - 3;"
   1325         offset = seq.offset - 3;
   1326 
   1327         // Shift back history
   1328         offset_hist[2] = offset_hist[1];
   1329         offset_hist[1] = offset_hist[0];
   1330         offset_hist[0] = offset;
   1331     }
   1332     return offset;
   1333 }
   1334 
   1335 static void execute_match_copy(frame_context_t *const ctx, size_t offset,
   1336                               size_t match_length, size_t total_output,
   1337                               ostream_t *const out) {
   1338     u8 *write_ptr = IO_get_write_ptr(out, match_length);
   1339     if (total_output <= ctx->header.window_size) {
   1340         // In this case offset might go back into the dictionary
   1341         if (offset > total_output + ctx->dict_content_len) {
   1342             // The offset goes beyond even the dictionary
   1343             CORRUPTION();
   1344         }
   1345 
   1346         if (offset > total_output) {
   1347             // "The rest of the dictionary is its content. The content act
   1348             // as a "past" in front of data to compress or decompress, so it
   1349             // can be referenced in sequence commands."
   1350             const size_t dict_copy =
   1351                 MIN(offset - total_output, match_length);
   1352             const size_t dict_offset =
   1353                 ctx->dict_content_len - (offset - total_output);
   1354 
   1355             memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
   1356             write_ptr += dict_copy;
   1357             match_length -= dict_copy;
   1358         }
   1359     } else if (offset > ctx->header.window_size) {
   1360         CORRUPTION();
   1361     }
   1362 
   1363     // We must copy byte by byte because the match length might be larger
   1364     // than the offset
   1365     // ex: if the output so far was "abc", a command with offset=3 and
   1366     // match_length=6 would produce "abcabcabc" as the new output
   1367     for (size_t j = 0; j < match_length; j++) {
   1368         *write_ptr = *(write_ptr - offset);
   1369         write_ptr++;
   1370     }
   1371 }
   1372 /******* END SEQUENCE EXECUTION ***********************************************/
   1373 
   1374 /******* OUTPUT SIZE COUNTING *************************************************/
   1375 /// Get the decompressed size of an input stream so memory can be allocated in
   1376 /// advance.
   1377 /// This implementation assumes `src` points to a single ZSTD-compressed frame
   1378 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
   1379     istream_t in = IO_make_istream(src, src_len);
   1380 
   1381     // get decompressed size from ZSTD frame header
   1382     {
   1383         const u32 magic_number = (u32)IO_read_bits(&in, 32);
   1384 
   1385         if (magic_number == ZSTD_MAGIC_NUMBER) {
   1386             // ZSTD frame
   1387             frame_header_t header;
   1388             parse_frame_header(&header, &in);
   1389 
   1390             if (header.frame_content_size == 0 && !header.single_segment_flag) {
   1391                 // Content size not provided, we can't tell
   1392                 return (size_t)-1;
   1393             }
   1394 
   1395             return header.frame_content_size;
   1396         } else {
   1397             // not a real frame or skippable frame
   1398             ERROR("ZSTD frame magic number did not match");
   1399         }
   1400     }
   1401 }
   1402 /******* END OUTPUT SIZE COUNTING *********************************************/
   1403 
   1404 /******* DICTIONARY PARSING ***************************************************/
   1405 dictionary_t* create_dictionary(void) {
   1406     dictionary_t* const dict = calloc(1, sizeof(dictionary_t));
   1407     if (!dict) {
   1408         BAD_ALLOC();
   1409     }
   1410     return dict;
   1411 }
   1412 
   1413 /// Free an allocated dictionary
   1414 void free_dictionary(dictionary_t *const dict) {
   1415     HUF_free_dtable(&dict->literals_dtable);
   1416     FSE_free_dtable(&dict->ll_dtable);
   1417     FSE_free_dtable(&dict->of_dtable);
   1418     FSE_free_dtable(&dict->ml_dtable);
   1419 
   1420     free(dict->content);
   1421 
   1422     memset(dict, 0, sizeof(dictionary_t));
   1423 
   1424     free(dict);
   1425 }
   1426 
   1427 
   1428 #if !defined(ZDEC_NO_DICTIONARY)
   1429 #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
   1430 #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
   1431 
   1432 static void init_dictionary_content(dictionary_t *const dict,
   1433                                     istream_t *const in);
   1434 
   1435 void parse_dictionary(dictionary_t *const dict, const void *src,
   1436                              size_t src_len) {
   1437     const u8 *byte_src = (const u8 *)src;
   1438     memset(dict, 0, sizeof(dictionary_t));
   1439     if (src == NULL) { /* cannot initialize dictionary with null src */
   1440         NULL_SRC();
   1441     }
   1442     if (src_len < 8) {
   1443         DICT_SIZE_ERROR();
   1444     }
   1445 
   1446     istream_t in = IO_make_istream(byte_src, src_len);
   1447 
   1448     const u32 magic_number = IO_read_bits(&in, 32);
   1449     if (magic_number != 0xEC30A437) {
   1450         // raw content dict
   1451         IO_rewind_bits(&in, 32);
   1452         init_dictionary_content(dict, &in);
   1453         return;
   1454     }
   1455 
   1456     dict->dictionary_id = IO_read_bits(&in, 32);
   1457 
   1458     // "Entropy_Tables : following the same format as the tables in compressed
   1459     // blocks. They are stored in following order : Huffman tables for literals,
   1460     // FSE table for offsets, FSE table for match lengths, and FSE table for
   1461     // literals lengths. It's finally followed by 3 offset values, populating
   1462     // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
   1463     // little-endian each, for a total of 12 bytes. Each recent offset must have
   1464     // a value < dictionary size."
   1465     decode_huf_table(&dict->literals_dtable, &in);
   1466     decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
   1467     decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
   1468     decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);
   1469 
   1470     // Read in the previous offset history
   1471     dict->previous_offsets[0] = IO_read_bits(&in, 32);
   1472     dict->previous_offsets[1] = IO_read_bits(&in, 32);
   1473     dict->previous_offsets[2] = IO_read_bits(&in, 32);
   1474 
   1475     // Ensure the provided offsets aren't too large
   1476     // "Each recent offset must have a value < dictionary size."
   1477     for (int i = 0; i < 3; i++) {
   1478         if (dict->previous_offsets[i] > src_len) {
   1479             ERROR("Dictionary corrupted");
   1480         }
   1481     }
   1482 
   1483     // "Content : The rest of the dictionary is its content. The content act as
   1484     // a "past" in front of data to compress or decompress, so it can be
   1485     // referenced in sequence commands."
   1486     init_dictionary_content(dict, &in);
   1487 }
   1488 
   1489 static void init_dictionary_content(dictionary_t *const dict,
   1490                                     istream_t *const in) {
   1491     // Copy in the content
   1492     dict->content_size = IO_istream_len(in);
   1493     dict->content = malloc(dict->content_size);
   1494     if (!dict->content) {
   1495         BAD_ALLOC();
   1496     }
   1497 
   1498     const u8 *const content = IO_get_read_ptr(in, dict->content_size);
   1499 
   1500     memcpy(dict->content, content, dict->content_size);
   1501 }
   1502 
   1503 static void HUF_copy_dtable(HUF_dtable *const dst,
   1504                             const HUF_dtable *const src) {
   1505     if (src->max_bits == 0) {
   1506         memset(dst, 0, sizeof(HUF_dtable));
   1507         return;
   1508     }
   1509 
   1510     const size_t size = (size_t)1 << src->max_bits;
   1511     dst->max_bits = src->max_bits;
   1512 
   1513     dst->symbols = malloc(size);
   1514     dst->num_bits = malloc(size);
   1515     if (!dst->symbols || !dst->num_bits) {
   1516         BAD_ALLOC();
   1517     }
   1518 
   1519     memcpy(dst->symbols, src->symbols, size);
   1520     memcpy(dst->num_bits, src->num_bits, size);
   1521 }
   1522 
   1523 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
   1524     if (src->accuracy_log == 0) {
   1525         memset(dst, 0, sizeof(FSE_dtable));
   1526         return;
   1527     }
   1528 
   1529     size_t size = (size_t)1 << src->accuracy_log;
   1530     dst->accuracy_log = src->accuracy_log;
   1531 
   1532     dst->symbols = malloc(size);
   1533     dst->num_bits = malloc(size);
   1534     dst->new_state_base = malloc(size * sizeof(u16));
   1535     if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
   1536         BAD_ALLOC();
   1537     }
   1538 
   1539     memcpy(dst->symbols, src->symbols, size);
   1540     memcpy(dst->num_bits, src->num_bits, size);
   1541     memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
   1542 }
   1543 
   1544 /// A dictionary acts as initializing values for the frame context before
   1545 /// decompression, so we implement it by applying it's predetermined
   1546 /// tables and content to the context before beginning decompression
   1547 static void frame_context_apply_dict(frame_context_t *const ctx,
   1548                                      const dictionary_t *const dict) {
   1549     // If the content pointer is NULL then it must be an empty dict
   1550     if (!dict || !dict->content)
   1551         return;
   1552 
   1553     // If the requested dictionary_id is non-zero, the correct dictionary must
   1554     // be present
   1555     if (ctx->header.dictionary_id != 0 &&
   1556         ctx->header.dictionary_id != dict->dictionary_id) {
   1557         ERROR("Wrong dictionary provided");
   1558     }
   1559 
   1560     // Copy the dict content to the context for references during sequence
   1561     // execution
   1562     ctx->dict_content = dict->content;
   1563     ctx->dict_content_len = dict->content_size;
   1564 
   1565     // If it's a formatted dict copy the precomputed tables in so they can
   1566     // be used in the table repeat modes
   1567     if (dict->dictionary_id != 0) {
   1568         // Deep copy the entropy tables so they can be freed independently of
   1569         // the dictionary struct
   1570         HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
   1571         FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
   1572         FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
   1573         FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
   1574 
   1575         // Copy the repeated offsets
   1576         memcpy(ctx->previous_offsets, dict->previous_offsets,
   1577                sizeof(ctx->previous_offsets));
   1578     }
   1579 }
   1580 
   1581 #else  // ZDEC_NO_DICTIONARY is defined
   1582 
   1583 static void frame_context_apply_dict(frame_context_t *const ctx,
   1584                                      const dictionary_t *const dict) {
   1585     (void)ctx;
   1586     if (dict && dict->content) ERROR("dictionary not supported");
   1587 }
   1588 
   1589 #endif
   1590 /******* END DICTIONARY PARSING ***********************************************/
   1591 
   1592 /******* IO STREAM OPERATIONS *************************************************/
   1593 
   1594 /// Reads `num` bits from a bitstream, and updates the internal offset
   1595 static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
   1596     if (num_bits > 64 || num_bits <= 0) {
   1597         ERROR("Attempt to read an invalid number of bits");
   1598     }
   1599 
   1600     const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
   1601     const size_t full_bytes = (num_bits + in->bit_offset) / 8;
   1602     if (bytes > in->len) {
   1603         INP_SIZE();
   1604     }
   1605 
   1606     const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
   1607 
   1608     in->bit_offset = (num_bits + in->bit_offset) % 8;
   1609     in->ptr += full_bytes;
   1610     in->len -= full_bytes;
   1611 
   1612     return result;
   1613 }
   1614 
   1615 /// If a non-zero number of bits have been read from the current byte, advance
   1616 /// the offset to the next byte
   1617 static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
   1618     if (num_bits < 0) {
   1619         ERROR("Attempting to rewind stream by a negative number of bits");
   1620     }
   1621 
   1622     // move the offset back by `num_bits` bits
   1623     const int new_offset = in->bit_offset - num_bits;
   1624     // determine the number of whole bytes we have to rewind, rounding up to an
   1625     // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
   1626     const i64 bytes = -(new_offset - 7) / 8;
   1627 
   1628     in->ptr -= bytes;
   1629     in->len += bytes;
   1630     // make sure the resulting `bit_offset` is positive, as mod in C does not
   1631     // convert numbers from negative to positive (e.g. -22 % 8 == -6)
   1632     in->bit_offset = ((new_offset % 8) + 8) % 8;
   1633 }
   1634 
   1635 /// If the remaining bits in a byte will be unused, advance to the end of the
   1636 /// byte
   1637 static inline void IO_align_stream(istream_t *const in) {
   1638     if (in->bit_offset != 0) {
   1639         if (in->len == 0) {
   1640             INP_SIZE();
   1641         }
   1642         in->ptr++;
   1643         in->len--;
   1644         in->bit_offset = 0;
   1645     }
   1646 }
   1647 
   1648 /// Write the given byte into the output stream
   1649 static inline void IO_write_byte(ostream_t *const out, u8 symb) {
   1650     if (out->len == 0) {
   1651         OUT_SIZE();
   1652     }
   1653 
   1654     out->ptr[0] = symb;
   1655     out->ptr++;
   1656     out->len--;
   1657 }
   1658 
   1659 /// Returns the number of bytes left to be read in this stream.  The stream must
   1660 /// be byte aligned.
   1661 static inline size_t IO_istream_len(const istream_t *const in) {
   1662     return in->len;
   1663 }
   1664 
   1665 /// Returns a pointer where `len` bytes can be read, and advances the internal
   1666 /// state.  The stream must be byte aligned.
   1667 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
   1668     if (len > in->len) {
   1669         INP_SIZE();
   1670     }
   1671     if (in->bit_offset != 0) {
   1672         ERROR("Attempting to operate on a non-byte aligned stream");
   1673     }
   1674     const u8 *const ptr = in->ptr;
   1675     in->ptr += len;
   1676     in->len -= len;
   1677 
   1678     return ptr;
   1679 }
   1680 /// Returns a pointer to write `len` bytes to, and advances the internal state
   1681 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
   1682     if (len > out->len) {
   1683         OUT_SIZE();
   1684     }
   1685     u8 *const ptr = out->ptr;
   1686     out->ptr += len;
   1687     out->len -= len;
   1688 
   1689     return ptr;
   1690 }
   1691 
   1692 /// Advance the inner state by `len` bytes
   1693 static inline void IO_advance_input(istream_t *const in, size_t len) {
   1694     if (len > in->len) {
   1695          INP_SIZE();
   1696     }
   1697     if (in->bit_offset != 0) {
   1698         ERROR("Attempting to operate on a non-byte aligned stream");
   1699     }
   1700 
   1701     in->ptr += len;
   1702     in->len -= len;
   1703 }
   1704 
   1705 /// Returns an `ostream_t` constructed from the given pointer and length
   1706 static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
   1707     return (ostream_t) { out, len };
   1708 }
   1709 
   1710 /// Returns an `istream_t` constructed from the given pointer and length
   1711 static inline istream_t IO_make_istream(const u8 *in, size_t len) {
   1712     return (istream_t) { in, len, 0 };
   1713 }
   1714 
   1715 /// Returns an `istream_t` with the same base as `in`, and length `len`
   1716 /// Then, advance `in` to account for the consumed bytes
   1717 /// `in` must be byte aligned
   1718 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
   1719     // Consume `len` bytes of the parent stream
   1720     const u8 *const ptr = IO_get_read_ptr(in, len);
   1721 
   1722     // Make a substream using the pointer to those `len` bytes
   1723     return IO_make_istream(ptr, len);
   1724 }
   1725 /******* END IO STREAM OPERATIONS *********************************************/
   1726 
   1727 /******* BITSTREAM OPERATIONS *************************************************/
   1728 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
   1729 static inline u64 read_bits_LE(const u8 *src, const int num_bits,
   1730                                const size_t offset) {
   1731     if (num_bits > 64) {
   1732         ERROR("Attempt to read an invalid number of bits");
   1733     }
   1734 
   1735     // Skip over bytes that aren't in range
   1736     src += offset / 8;
   1737     size_t bit_offset = offset % 8;
   1738     u64 res = 0;
   1739 
   1740     int shift = 0;
   1741     int left = num_bits;
   1742     while (left > 0) {
   1743         u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
   1744         // Read the next byte, shift it to account for the offset, and then mask
   1745         // out the top part if we don't need all the bits
   1746         res += (((u64)*src++ >> bit_offset) & mask) << shift;
   1747         shift += 8 - bit_offset;
   1748         left -= 8 - bit_offset;
   1749         bit_offset = 0;
   1750     }
   1751 
   1752     return res;
   1753 }
   1754 
   1755 /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
   1756 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
   1757 /// `src + offset`.  If the offset becomes negative, the extra bits at the
   1758 /// bottom are filled in with `0` bits instead of reading from before `src`.
   1759 static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
   1760                                    i64 *const offset) {
   1761     *offset = *offset - bits;
   1762     size_t actual_off = *offset;
   1763     size_t actual_bits = bits;
   1764     // Don't actually read bits from before the start of src, so if `*offset <
   1765     // 0` fix actual_off and actual_bits to reflect the quantity to read
   1766     if (*offset < 0) {
   1767         actual_bits += *offset;
   1768         actual_off = 0;
   1769     }
   1770     u64 res = read_bits_LE(src, actual_bits, actual_off);
   1771 
   1772     if (*offset < 0) {
   1773         // Fill in the bottom "overflowed" bits with 0's
   1774         res = -*offset >= 64 ? 0 : (res << -*offset);
   1775     }
   1776     return res;
   1777 }
   1778 /******* END BITSTREAM OPERATIONS *********************************************/
   1779 
   1780 /******* BIT COUNTING OPERATIONS **********************************************/
   1781 /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
   1782 /// `num`, or `-1` if `num == 0`.
   1783 static inline int highest_set_bit(const u64 num) {
   1784     for (int i = 63; i >= 0; i--) {
   1785         if (((u64)1 << i) <= num) {
   1786             return i;
   1787         }
   1788     }
   1789     return -1;
   1790 }
   1791 /******* END BIT COUNTING OPERATIONS ******************************************/
   1792 
   1793 /******* HUFFMAN PRIMITIVES ***************************************************/
   1794 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
   1795                                    u16 *const state, const u8 *const src,
   1796                                    i64 *const offset) {
   1797     // Look up the symbol and number of bits to read
   1798     const u8 symb = dtable->symbols[*state];
   1799     const u8 bits = dtable->num_bits[*state];
   1800     const u16 rest = STREAM_read_bits(src, bits, offset);
   1801     // Shift `bits` bits out of the state, keeping the low order bits that
   1802     // weren't necessary to determine this symbol.  Then add in the new bits
   1803     // read from the stream.
   1804     *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
   1805 
   1806     return symb;
   1807 }
   1808 
   1809 static inline void HUF_init_state(const HUF_dtable *const dtable,
   1810                                   u16 *const state, const u8 *const src,
   1811                                   i64 *const offset) {
   1812     // Read in a full `dtable->max_bits` bits to initialize the state
   1813     const u8 bits = dtable->max_bits;
   1814     *state = STREAM_read_bits(src, bits, offset);
   1815 }
   1816 
   1817 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
   1818                                      ostream_t *const out,
   1819                                      istream_t *const in) {
   1820     const size_t len = IO_istream_len(in);
   1821     if (len == 0) {
   1822         INP_SIZE();
   1823     }
   1824     const u8 *const src = IO_get_read_ptr(in, len);
   1825 
   1826     // "Each bitstream must be read backward, that is starting from the end down
   1827     // to the beginning. Therefore it's necessary to know the size of each
   1828     // bitstream.
   1829     //
   1830     // It's also necessary to know exactly which bit is the latest. This is
   1831     // detected by a final bit flag : the highest bit of latest byte is a
   1832     // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
   1833     // final-bit-flag itself is not part of the useful bitstream. Hence, the
   1834     // last byte contains between 0 and 7 useful bits."
   1835     const int padding = 8 - highest_set_bit(src[len - 1]);
   1836 
   1837     // Offset starts at the end because HUF streams are read backwards
   1838     i64 bit_offset = len * 8 - padding;
   1839     u16 state;
   1840 
   1841     HUF_init_state(dtable, &state, src, &bit_offset);
   1842 
   1843     size_t symbols_written = 0;
   1844     while (bit_offset > -dtable->max_bits) {
   1845         // Iterate over the stream, decoding one symbol at a time
   1846         IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
   1847         symbols_written++;
   1848     }
   1849     // "The process continues up to reading the required number of symbols per
   1850     // stream. If a bitstream is not entirely and exactly consumed, hence
   1851     // reaching exactly its beginning position with all bits consumed, the
   1852     // decoding process is considered faulty."
   1853 
   1854     // When all symbols have been decoded, the final state value shouldn't have
   1855     // any data from the stream, so it should have "read" dtable->max_bits from
   1856     // before the start of `src`
   1857     // Therefore `offset`, the edge to start reading new bits at, should be
   1858     // dtable->max_bits before the start of the stream
   1859     if (bit_offset != -dtable->max_bits) {
   1860         CORRUPTION();
   1861     }
   1862 
   1863     return symbols_written;
   1864 }
   1865 
   1866 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
   1867                                      ostream_t *const out, istream_t *const in) {
   1868     // "Compressed size is provided explicitly : in the 4-streams variant,
   1869     // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
   1870     // value represents the compressed size of one stream, in order. The last
   1871     // stream size is deducted from total compressed size and from previously
   1872     // decoded stream sizes"
   1873     const size_t csize1 = IO_read_bits(in, 16);
   1874     const size_t csize2 = IO_read_bits(in, 16);
   1875     const size_t csize3 = IO_read_bits(in, 16);
   1876 
   1877     istream_t in1 = IO_make_sub_istream(in, csize1);
   1878     istream_t in2 = IO_make_sub_istream(in, csize2);
   1879     istream_t in3 = IO_make_sub_istream(in, csize3);
   1880     istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
   1881 
   1882     size_t total_output = 0;
   1883     // Decode each stream independently for simplicity
   1884     // If we wanted to we could decode all 4 at the same time for speed,
   1885     // utilizing more execution units
   1886     total_output += HUF_decompress_1stream(dtable, out, &in1);
   1887     total_output += HUF_decompress_1stream(dtable, out, &in2);
   1888     total_output += HUF_decompress_1stream(dtable, out, &in3);
   1889     total_output += HUF_decompress_1stream(dtable, out, &in4);
   1890 
   1891     return total_output;
   1892 }
   1893 
   1894 /// Initializes a Huffman table using canonical Huffman codes
   1895 /// For more explanation on canonical Huffman codes see
   1896 /// https://www.cs.scranton.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
   1897 /// Codes within a level are allocated in symbol order (i.e. smaller symbols get
   1898 /// earlier codes)
   1899 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
   1900                             const int num_symbs) {
   1901     memset(table, 0, sizeof(HUF_dtable));
   1902     if (num_symbs > HUF_MAX_SYMBS) {
   1903         ERROR("Too many symbols for Huffman");
   1904     }
   1905 
   1906     u8 max_bits = 0;
   1907     u16 rank_count[HUF_MAX_BITS + 1];
   1908     memset(rank_count, 0, sizeof(rank_count));
   1909 
   1910     // Count the number of symbols for each number of bits, and determine the
   1911     // depth of the tree
   1912     for (int i = 0; i < num_symbs; i++) {
   1913         if (bits[i] > HUF_MAX_BITS) {
   1914             ERROR("Huffman table depth too large");
   1915         }
   1916         max_bits = MAX(max_bits, bits[i]);
   1917         rank_count[bits[i]]++;
   1918     }
   1919 
   1920     const size_t table_size = 1 << max_bits;
   1921     table->max_bits = max_bits;
   1922     table->symbols = malloc(table_size);
   1923     table->num_bits = malloc(table_size);
   1924 
   1925     if (!table->symbols || !table->num_bits) {
   1926         free(table->symbols);
   1927         free(table->num_bits);
   1928         BAD_ALLOC();
   1929     }
   1930 
   1931     // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
   1932     // order. Symbols with a Weight of zero are removed. Then, starting from
   1933     // lowest weight, prefix codes are distributed in order."
   1934 
   1935     u32 rank_idx[HUF_MAX_BITS + 1];
   1936     // Initialize the starting codes for each rank (number of bits)
   1937     rank_idx[max_bits] = 0;
   1938     for (int i = max_bits; i >= 1; i--) {
   1939         rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
   1940         // The entire range takes the same number of bits so we can memset it
   1941         memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
   1942     }
   1943 
   1944     if (rank_idx[0] != table_size) {
   1945         CORRUPTION();
   1946     }
   1947 
   1948     // Allocate codes and fill in the table
   1949     for (int i = 0; i < num_symbs; i++) {
   1950         if (bits[i] != 0) {
   1951             // Allocate a code for this symbol and set its range in the table
   1952             const u16 code = rank_idx[bits[i]];
   1953             // Since the code doesn't care about the bottom `max_bits - bits[i]`
   1954             // bits of state, it gets a range that spans all possible values of
   1955             // the lower bits
   1956             const u16 len = 1 << (max_bits - bits[i]);
   1957             memset(&table->symbols[code], i, len);
   1958             rank_idx[bits[i]] += len;
   1959         }
   1960     }
   1961 }
   1962 
   1963 static void HUF_init_dtable_usingweights(HUF_dtable *const table,
   1964                                          const u8 *const weights,
   1965                                          const int num_symbs) {
   1966     // +1 because the last weight is not transmitted in the header
   1967     if (num_symbs + 1 > HUF_MAX_SYMBS) {
   1968         ERROR("Too many symbols for Huffman");
   1969     }
   1970 
   1971     u8 bits[HUF_MAX_SYMBS];
   1972 
   1973     u64 weight_sum = 0;
   1974     for (int i = 0; i < num_symbs; i++) {
   1975         // Weights are in the same range as bit count
   1976         if (weights[i] > HUF_MAX_BITS) {
   1977             CORRUPTION();
   1978         }
   1979         weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
   1980     }
   1981 
   1982     // Find the first power of 2 larger than the sum
   1983     const int max_bits = highest_set_bit(weight_sum) + 1;
   1984     const u64 left_over = ((u64)1 << max_bits) - weight_sum;
   1985     // If the left over isn't a power of 2, the weights are invalid
   1986     if (left_over & (left_over - 1)) {
   1987         CORRUPTION();
   1988     }
   1989 
   1990     // left_over is used to find the last weight as it's not transmitted
   1991     // by inverting 2^(weight - 1) we can determine the value of last_weight
   1992     const int last_weight = highest_set_bit(left_over) + 1;
   1993 
   1994     for (int i = 0; i < num_symbs; i++) {
   1995         // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
   1996         bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
   1997     }
   1998     bits[num_symbs] =
   1999         max_bits + 1 - last_weight; // Last weight is always non-zero
   2000 
   2001     HUF_init_dtable(table, bits, num_symbs + 1);
   2002 }
   2003 
   2004 static void HUF_free_dtable(HUF_dtable *const dtable) {
   2005     free(dtable->symbols);
   2006     free(dtable->num_bits);
   2007     memset(dtable, 0, sizeof(HUF_dtable));
   2008 }
   2009 /******* END HUFFMAN PRIMITIVES ***********************************************/
   2010 
   2011 /******* FSE PRIMITIVES *******************************************************/
   2012 /// For more description of FSE see
   2013 /// https://github.com/Cyan4973/FiniteStateEntropy/
   2014 
   2015 /// Allow a symbol to be decoded without updating state
   2016 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
   2017                                  const u16 state) {
   2018     return dtable->symbols[state];
   2019 }
   2020 
   2021 /// Consumes bits from the input and uses the current state to determine the
   2022 /// next state
   2023 static inline void FSE_update_state(const FSE_dtable *const dtable,
   2024                                     u16 *const state, const u8 *const src,
   2025                                     i64 *const offset) {
   2026     const u8 bits = dtable->num_bits[*state];
   2027     const u16 rest = STREAM_read_bits(src, bits, offset);
   2028     *state = dtable->new_state_base[*state] + rest;
   2029 }
   2030 
   2031 /// Decodes a single FSE symbol and updates the offset
   2032 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
   2033                                    u16 *const state, const u8 *const src,
   2034                                    i64 *const offset) {
   2035     const u8 symb = FSE_peek_symbol(dtable, *state);
   2036     FSE_update_state(dtable, state, src, offset);
   2037     return symb;
   2038 }
   2039 
   2040 static inline void FSE_init_state(const FSE_dtable *const dtable,
   2041                                   u16 *const state, const u8 *const src,
   2042                                   i64 *const offset) {
   2043     // Read in a full `accuracy_log` bits to initialize the state
   2044     const u8 bits = dtable->accuracy_log;
   2045     *state = STREAM_read_bits(src, bits, offset);
   2046 }
   2047 
   2048 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
   2049                                           ostream_t *const out,
   2050                                           istream_t *const in) {
   2051     const size_t len = IO_istream_len(in);
   2052     if (len == 0) {
   2053         INP_SIZE();
   2054     }
   2055     const u8 *const src = IO_get_read_ptr(in, len);
   2056 
   2057     // "Each bitstream must be read backward, that is starting from the end down
   2058     // to the beginning. Therefore it's necessary to know the size of each
   2059     // bitstream.
   2060     //
   2061     // It's also necessary to know exactly which bit is the latest. This is
   2062     // detected by a final bit flag : the highest bit of latest byte is a
   2063     // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
   2064     // final-bit-flag itself is not part of the useful bitstream. Hence, the
   2065     // last byte contains between 0 and 7 useful bits."
   2066     const int padding = 8 - highest_set_bit(src[len - 1]);
   2067     i64 offset = len * 8 - padding;
   2068 
   2069     u16 state1, state2;
   2070     // "The first state (State1) encodes the even indexed symbols, and the
   2071     // second (State2) encodes the odd indexes. State1 is initialized first, and
   2072     // then State2, and they take turns decoding a single symbol and updating
   2073     // their state."
   2074     FSE_init_state(dtable, &state1, src, &offset);
   2075     FSE_init_state(dtable, &state2, src, &offset);
   2076 
   2077     // Decode until we overflow the stream
   2078     // Since we decode in reverse order, overflowing the stream is offset going
   2079     // negative
   2080     size_t symbols_written = 0;
   2081     while (1) {
   2082         // "The number of symbols to decode is determined by tracking bitStream
   2083         // overflow condition: If updating state after decoding a symbol would
   2084         // require more bits than remain in the stream, it is assumed the extra
   2085         // bits are 0. Then, the symbols for each of the final states are
   2086         // decoded and the process is complete."
   2087         IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
   2088         symbols_written++;
   2089         if (offset < 0) {
   2090             // There's still a symbol to decode in state2
   2091             IO_write_byte(out, FSE_peek_symbol(dtable, state2));
   2092             symbols_written++;
   2093             break;
   2094         }
   2095 
   2096         IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
   2097         symbols_written++;
   2098         if (offset < 0) {
   2099             // There's still a symbol to decode in state1
   2100             IO_write_byte(out, FSE_peek_symbol(dtable, state1));
   2101             symbols_written++;
   2102             break;
   2103         }
   2104     }
   2105 
   2106     return symbols_written;
   2107 }
   2108 
   2109 static void FSE_init_dtable(FSE_dtable *const dtable,
   2110                             const i16 *const norm_freqs, const int num_symbs,
   2111                             const int accuracy_log) {
   2112     if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
   2113         ERROR("FSE accuracy too large");
   2114     }
   2115     if (num_symbs > FSE_MAX_SYMBS) {
   2116         ERROR("Too many symbols for FSE");
   2117     }
   2118 
   2119     dtable->accuracy_log = accuracy_log;
   2120 
   2121     const size_t size = (size_t)1 << accuracy_log;
   2122     dtable->symbols = malloc(size * sizeof(u8));
   2123     dtable->num_bits = malloc(size * sizeof(u8));
   2124     dtable->new_state_base = malloc(size * sizeof(u16));
   2125 
   2126     if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
   2127         BAD_ALLOC();
   2128     }
   2129 
   2130     // Used to determine how many bits need to be read for each state,
   2131     // and where the destination range should start
   2132     // Needs to be u16 because max value is 2 * max number of symbols,
   2133     // which can be larger than a byte can store
   2134     u16 state_desc[FSE_MAX_SYMBS];
   2135 
   2136     // "Symbols are scanned in their natural order for "less than 1"
   2137     // probabilities. Symbols with this probability are being attributed a
   2138     // single cell, starting from the end of the table. These symbols define a
   2139     // full state reset, reading Accuracy_Log bits."
   2140     int high_threshold = size;
   2141     for (int s = 0; s < num_symbs; s++) {
   2142         // Scan for low probability symbols to put at the top
   2143         if (norm_freqs[s] == -1) {
   2144             dtable->symbols[--high_threshold] = s;
   2145             state_desc[s] = 1;
   2146         }
   2147     }
   2148 
   2149     // "All remaining symbols are sorted in their natural order. Starting from
   2150     // symbol 0 and table position 0, each symbol gets attributed as many cells
   2151     // as its probability. Cell allocation is spread, not linear."
   2152     // Place the rest in the table
   2153     const u16 step = (size >> 1) + (size >> 3) + 3;
   2154     const u16 mask = size - 1;
   2155     u16 pos = 0;
   2156     for (int s = 0; s < num_symbs; s++) {
   2157         if (norm_freqs[s] <= 0) {
   2158             continue;
   2159         }
   2160 
   2161         state_desc[s] = norm_freqs[s];
   2162 
   2163         for (int i = 0; i < norm_freqs[s]; i++) {
   2164             // Give `norm_freqs[s]` states to symbol s
   2165             dtable->symbols[pos] = s;
   2166             // "A position is skipped if already occupied, typically by a "less
   2167             // than 1" probability symbol."
   2168             do {
   2169                 pos = (pos + step) & mask;
   2170             } while (pos >=
   2171                      high_threshold);
   2172             // Note: no other collision checking is necessary as `step` is
   2173             // coprime to `size`, so the cycle will visit each position exactly
   2174             // once
   2175         }
   2176     }
   2177     if (pos != 0) {
   2178         CORRUPTION();
   2179     }
   2180 
   2181     // Now we can fill baseline and num bits
   2182     for (size_t i = 0; i < size; i++) {
   2183         u8 symbol = dtable->symbols[i];
   2184         u16 next_state_desc = state_desc[symbol]++;
   2185         // Fills in the table appropriately, next_state_desc increases by symbol
   2186         // over time, decreasing number of bits
   2187         dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
   2188         // Baseline increases until the bit threshold is passed, at which point
   2189         // it resets to 0
   2190         dtable->new_state_base[i] =
   2191             ((u16)next_state_desc << dtable->num_bits[i]) - size;
   2192     }
   2193 }
   2194 
   2195 /// Decode an FSE header as defined in the Zstandard format specification and
   2196 /// use the decoded frequencies to initialize a decoding table.
   2197 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
   2198                                 const int max_accuracy_log) {
   2199     // "An FSE distribution table describes the probabilities of all symbols
   2200     // from 0 to the last present one (included) on a normalized scale of 1 <<
   2201     // Accuracy_Log .
   2202     //
   2203     // It's a bitstream which is read forward, in little-endian fashion. It's
   2204     // not necessary to know its exact size, since it will be discovered and
   2205     // reported by the decoding process.
   2206     if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
   2207         ERROR("FSE accuracy too large");
   2208     }
   2209 
   2210     // The bitstream starts by reporting on which scale it operates.
   2211     // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
   2212     // and match lengths is 9, and for offsets is 8. Higher values are
   2213     // considered errors."
   2214     const int accuracy_log = 5 + IO_read_bits(in, 4);
   2215     if (accuracy_log > max_accuracy_log) {
   2216         ERROR("FSE accuracy too large");
   2217     }
   2218 
   2219     // "Then follows each symbol value, from 0 to last present one. The number
   2220     // of bits used by each field is variable. It depends on :
   2221     //
   2222     // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
   2223     // and presuming 100 probabilities points have already been distributed, the
   2224     // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
   2225     // Therefore, it must read log2sup(156) == 8 bits.
   2226     //
   2227     // Value decoded : small values use 1 less bit : example : Presuming values
   2228     // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
   2229     // in an 8-bits field. They are used this way : first 99 values (hence from
   2230     // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
   2231 
   2232     i32 remaining = 1 << accuracy_log;
   2233     i16 frequencies[FSE_MAX_SYMBS];
   2234 
   2235     int symb = 0;
   2236     while (remaining > 0 && symb < FSE_MAX_SYMBS) {
   2237         // Log of the number of possible values we could read
   2238         int bits = highest_set_bit(remaining + 1) + 1;
   2239 
   2240         u16 val = IO_read_bits(in, bits);
   2241 
   2242         // Try to mask out the lower bits to see if it qualifies for the "small
   2243         // value" threshold
   2244         const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
   2245         const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
   2246 
   2247         if ((val & lower_mask) < threshold) {
   2248             IO_rewind_bits(in, 1);
   2249             val = val & lower_mask;
   2250         } else if (val > lower_mask) {
   2251             val = val - threshold;
   2252         }
   2253 
   2254         // "Probability is obtained from Value decoded by following formula :
   2255         // Proba = value - 1"
   2256         const i16 proba = (i16)val - 1;
   2257 
   2258         // "It means value 0 becomes negative probability -1. -1 is a special
   2259         // probability, which means "less than 1". Its effect on distribution
   2260         // table is described in next paragraph. For the purpose of calculating
   2261         // cumulated distribution, it counts as one."
   2262         remaining -= proba < 0 ? -proba : proba;
   2263 
   2264         frequencies[symb] = proba;
   2265         symb++;
   2266 
   2267         // "When a symbol has a probability of zero, it is followed by a 2-bits
   2268         // repeat flag. This repeat flag tells how many probabilities of zeroes
   2269         // follow the current one. It provides a number ranging from 0 to 3. If
   2270         // it is a 3, another 2-bits repeat flag follows, and so on."
   2271         if (proba == 0) {
   2272             // Read the next two bits to see how many more 0s
   2273             int repeat = IO_read_bits(in, 2);
   2274 
   2275             while (1) {
   2276                 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
   2277                     frequencies[symb++] = 0;
   2278                 }
   2279                 if (repeat == 3) {
   2280                     repeat = IO_read_bits(in, 2);
   2281                 } else {
   2282                     break;
   2283                 }
   2284             }
   2285         }
   2286     }
   2287     IO_align_stream(in);
   2288 
   2289     // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
   2290     // is complete. If the last symbol makes cumulated total go above 1 <<
   2291     // Accuracy_Log, distribution is considered corrupted."
   2292     if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
   2293         CORRUPTION();
   2294     }
   2295 
   2296     // Initialize the decoding table using the determined weights
   2297     FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
   2298 }
   2299 
   2300 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
   2301     dtable->symbols = malloc(sizeof(u8));
   2302     dtable->num_bits = malloc(sizeof(u8));
   2303     dtable->new_state_base = malloc(sizeof(u16));
   2304 
   2305     if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
   2306         BAD_ALLOC();
   2307     }
   2308 
   2309     // This setup will always have a state of 0, always return symbol `symb`,
   2310     // and never consume any bits
   2311     dtable->symbols[0] = symb;
   2312     dtable->num_bits[0] = 0;
   2313     dtable->new_state_base[0] = 0;
   2314     dtable->accuracy_log = 0;
   2315 }
   2316 
   2317 static void FSE_free_dtable(FSE_dtable *const dtable) {
   2318     free(dtable->symbols);
   2319     free(dtable->num_bits);
   2320     free(dtable->new_state_base);
   2321     memset(dtable, 0, sizeof(FSE_dtable));
   2322 }
   2323 /******* END FSE PRIMITIVES ***************************************************/
   2324