1 1.1 christos /* ****************************************************************** 2 1.1 christos * Common functions of New Generation Entropy library 3 1.1 christos * Copyright (c) Meta Platforms, Inc. and affiliates. 4 1.1 christos * 5 1.1 christos * You can contact the author at : 6 1.1 christos * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy 7 1.1 christos * - Public forum : https://groups.google.com/forum/#!forum/lz4c 8 1.1 christos * 9 1.1 christos * This source code is licensed under both the BSD-style license (found in the 10 1.1 christos * LICENSE file in the root directory of this source tree) and the GPLv2 (found 11 1.1 christos * in the COPYING file in the root directory of this source tree). 12 1.1 christos * You may select, at your option, one of the above-listed licenses. 13 1.1 christos ****************************************************************** */ 14 1.1 christos 15 1.1 christos /* ************************************* 16 1.1 christos * Dependencies 17 1.1 christos ***************************************/ 18 1.1 christos #include "mem.h" 19 1.1 christos #include "error_private.h" /* ERR_*, ERROR */ 20 1.1 christos #define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */ 21 1.1 christos #include "fse.h" 22 1.1 christos #include "huf.h" 23 1.1 christos #include "bits.h" /* ZSDT_highbit32, ZSTD_countTrailingZeros32 */ 24 1.1 christos 25 1.1 christos 26 1.1 christos /*=== Version ===*/ 27 1.1 christos unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; } 28 1.1 christos 29 1.1 christos 30 1.1 christos /*=== Error Management ===*/ 31 1.1 christos unsigned FSE_isError(size_t code) { return ERR_isError(code); } 32 1.1 christos const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); } 33 1.1 christos 34 1.1 christos unsigned HUF_isError(size_t code) { return ERR_isError(code); } 35 1.1 christos const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); } 36 1.1 christos 37 1.1 christos 38 1.1 christos /*-************************************************************** 39 1.1 christos * FSE NCount encoding-decoding 40 1.1 christos ****************************************************************/ 41 1.1 christos FORCE_INLINE_TEMPLATE 42 1.1 christos size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, 43 1.1 christos const void* headerBuffer, size_t hbSize) 44 1.1 christos { 45 1.1 christos const BYTE* const istart = (const BYTE*) headerBuffer; 46 1.1 christos const BYTE* const iend = istart + hbSize; 47 1.1 christos const BYTE* ip = istart; 48 1.1 christos int nbBits; 49 1.1 christos int remaining; 50 1.1 christos int threshold; 51 1.1 christos U32 bitStream; 52 1.1 christos int bitCount; 53 1.1 christos unsigned charnum = 0; 54 1.1 christos unsigned const maxSV1 = *maxSVPtr + 1; 55 1.1 christos int previous0 = 0; 56 1.1 christos 57 1.1 christos if (hbSize < 8) { 58 1.1 christos /* This function only works when hbSize >= 8 */ 59 1.1 christos char buffer[8] = {0}; 60 1.1 christos ZSTD_memcpy(buffer, headerBuffer, hbSize); 61 1.1 christos { size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr, 62 1.1 christos buffer, sizeof(buffer)); 63 1.1 christos if (FSE_isError(countSize)) return countSize; 64 1.1 christos if (countSize > hbSize) return ERROR(corruption_detected); 65 1.1 christos return countSize; 66 1.1 christos } } 67 1.1 christos assert(hbSize >= 8); 68 1.1 christos 69 1.1 christos /* init */ 70 1.1 christos ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0])); /* all symbols not present in NCount have a frequency of 0 */ 71 1.1 christos bitStream = MEM_readLE32(ip); 72 1.1 christos nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */ 73 1.1 christos if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge); 74 1.1 christos bitStream >>= 4; 75 1.1 christos bitCount = 4; 76 1.1 christos *tableLogPtr = nbBits; 77 1.1 christos remaining = (1<<nbBits)+1; 78 1.1 christos threshold = 1<<nbBits; 79 1.1 christos nbBits++; 80 1.1 christos 81 1.1 christos for (;;) { 82 1.1 christos if (previous0) { 83 1.1 christos /* Count the number of repeats. Each time the 84 1.1 christos * 2-bit repeat code is 0b11 there is another 85 1.1 christos * repeat. 86 1.1 christos * Avoid UB by setting the high bit to 1. 87 1.1 christos */ 88 1.1 christos int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; 89 1.1 christos while (repeats >= 12) { 90 1.1 christos charnum += 3 * 12; 91 1.1 christos if (LIKELY(ip <= iend-7)) { 92 1.1 christos ip += 3; 93 1.1 christos } else { 94 1.1 christos bitCount -= (int)(8 * (iend - 7 - ip)); 95 1.1 christos bitCount &= 31; 96 1.1 christos ip = iend - 4; 97 1.1 christos } 98 1.1 christos bitStream = MEM_readLE32(ip) >> bitCount; 99 1.1 christos repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; 100 1.1 christos } 101 1.1 christos charnum += 3 * repeats; 102 1.1 christos bitStream >>= 2 * repeats; 103 1.1 christos bitCount += 2 * repeats; 104 1.1 christos 105 1.1 christos /* Add the final repeat which isn't 0b11. */ 106 1.1 christos assert((bitStream & 3) < 3); 107 1.1 christos charnum += bitStream & 3; 108 1.1 christos bitCount += 2; 109 1.1 christos 110 1.1 christos /* This is an error, but break and return an error 111 1.1 christos * at the end, because returning out of a loop makes 112 1.1 christos * it harder for the compiler to optimize. 113 1.1 christos */ 114 1.1 christos if (charnum >= maxSV1) break; 115 1.1 christos 116 1.1 christos /* We don't need to set the normalized count to 0 117 1.1 christos * because we already memset the whole buffer to 0. 118 1.1 christos */ 119 1.1 christos 120 1.1 christos if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { 121 1.1 christos assert((bitCount >> 3) <= 3); /* For first condition to work */ 122 1.1 christos ip += bitCount>>3; 123 1.1 christos bitCount &= 7; 124 1.1 christos } else { 125 1.1 christos bitCount -= (int)(8 * (iend - 4 - ip)); 126 1.1 christos bitCount &= 31; 127 1.1 christos ip = iend - 4; 128 1.1 christos } 129 1.1 christos bitStream = MEM_readLE32(ip) >> bitCount; 130 1.1 christos } 131 1.1 christos { 132 1.1 christos int const max = (2*threshold-1) - remaining; 133 1.1 christos int count; 134 1.1 christos 135 1.1 christos if ((bitStream & (threshold-1)) < (U32)max) { 136 1.1 christos count = bitStream & (threshold-1); 137 1.1 christos bitCount += nbBits-1; 138 1.1 christos } else { 139 1.1 christos count = bitStream & (2*threshold-1); 140 1.1 christos if (count >= threshold) count -= max; 141 1.1 christos bitCount += nbBits; 142 1.1 christos } 143 1.1 christos 144 1.1 christos count--; /* extra accuracy */ 145 1.1 christos /* When it matters (small blocks), this is a 146 1.1 christos * predictable branch, because we don't use -1. 147 1.1 christos */ 148 1.1 christos if (count >= 0) { 149 1.1 christos remaining -= count; 150 1.1 christos } else { 151 1.1 christos assert(count == -1); 152 1.1 christos remaining += count; 153 1.1 christos } 154 1.1 christos normalizedCounter[charnum++] = (short)count; 155 1.1 christos previous0 = !count; 156 1.1 christos 157 1.1 christos assert(threshold > 1); 158 1.1 christos if (remaining < threshold) { 159 1.1 christos /* This branch can be folded into the 160 1.1 christos * threshold update condition because we 161 1.1 christos * know that threshold > 1. 162 1.1 christos */ 163 1.1 christos if (remaining <= 1) break; 164 1.1 christos nbBits = ZSTD_highbit32(remaining) + 1; 165 1.1 christos threshold = 1 << (nbBits - 1); 166 1.1 christos } 167 1.1 christos if (charnum >= maxSV1) break; 168 1.1 christos 169 1.1 christos if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { 170 1.1 christos ip += bitCount>>3; 171 1.1 christos bitCount &= 7; 172 1.1 christos } else { 173 1.1 christos bitCount -= (int)(8 * (iend - 4 - ip)); 174 1.1 christos bitCount &= 31; 175 1.1 christos ip = iend - 4; 176 1.1 christos } 177 1.1 christos bitStream = MEM_readLE32(ip) >> bitCount; 178 1.1 christos } } 179 1.1 christos if (remaining != 1) return ERROR(corruption_detected); 180 1.1 christos /* Only possible when there are too many zeros. */ 181 1.1 christos if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall); 182 1.1 christos if (bitCount > 32) return ERROR(corruption_detected); 183 1.1 christos *maxSVPtr = charnum-1; 184 1.1 christos 185 1.1 christos ip += (bitCount+7)>>3; 186 1.1 christos return ip-istart; 187 1.1 christos } 188 1.1 christos 189 1.1 christos /* Avoids the FORCE_INLINE of the _body() function. */ 190 1.1 christos static size_t FSE_readNCount_body_default( 191 1.1 christos short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, 192 1.1 christos const void* headerBuffer, size_t hbSize) 193 1.1 christos { 194 1.1 christos return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); 195 1.1 christos } 196 1.1 christos 197 1.1 christos #if DYNAMIC_BMI2 198 1.1 christos BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2( 199 1.1 christos short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, 200 1.1 christos const void* headerBuffer, size_t hbSize) 201 1.1 christos { 202 1.1 christos return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); 203 1.1 christos } 204 1.1 christos #endif 205 1.1 christos 206 1.1 christos size_t FSE_readNCount_bmi2( 207 1.1 christos short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, 208 1.1 christos const void* headerBuffer, size_t hbSize, int bmi2) 209 1.1 christos { 210 1.1 christos #if DYNAMIC_BMI2 211 1.1 christos if (bmi2) { 212 1.1 christos return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); 213 1.1 christos } 214 1.1 christos #endif 215 1.1 christos (void)bmi2; 216 1.1 christos return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); 217 1.1 christos } 218 1.1 christos 219 1.1 christos size_t FSE_readNCount( 220 1.1 christos short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, 221 1.1 christos const void* headerBuffer, size_t hbSize) 222 1.1 christos { 223 1.1 christos return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize, /* bmi2 */ 0); 224 1.1 christos } 225 1.1 christos 226 1.1 christos 227 1.1 christos /*! HUF_readStats() : 228 1.1 christos Read compact Huffman tree, saved by HUF_writeCTable(). 229 1.1 christos `huffWeight` is destination buffer. 230 1.1 christos `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32. 231 1.1 christos @return : size read from `src` , or an error Code . 232 1.1 christos Note : Needed by HUF_readCTable() and HUF_readDTableX?() . 233 1.1 christos */ 234 1.1 christos size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, 235 1.1 christos U32* nbSymbolsPtr, U32* tableLogPtr, 236 1.1 christos const void* src, size_t srcSize) 237 1.1 christos { 238 1.1 christos U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32]; 239 1.1 christos return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0); 240 1.1 christos } 241 1.1 christos 242 1.1 christos FORCE_INLINE_TEMPLATE size_t 243 1.1 christos HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats, 244 1.1 christos U32* nbSymbolsPtr, U32* tableLogPtr, 245 1.1 christos const void* src, size_t srcSize, 246 1.1 christos void* workSpace, size_t wkspSize, 247 1.1 christos int bmi2) 248 1.1 christos { 249 1.1 christos U32 weightTotal; 250 1.1 christos const BYTE* ip = (const BYTE*) src; 251 1.1 christos size_t iSize; 252 1.1 christos size_t oSize; 253 1.1 christos 254 1.1 christos if (!srcSize) return ERROR(srcSize_wrong); 255 1.1 christos iSize = ip[0]; 256 1.1 christos /* ZSTD_memset(huffWeight, 0, hwSize); *//* is not necessary, even though some analyzer complain ... */ 257 1.1 christos 258 1.1 christos if (iSize >= 128) { /* special header */ 259 1.1 christos oSize = iSize - 127; 260 1.1 christos iSize = ((oSize+1)/2); 261 1.1 christos if (iSize+1 > srcSize) return ERROR(srcSize_wrong); 262 1.1 christos if (oSize >= hwSize) return ERROR(corruption_detected); 263 1.1 christos ip += 1; 264 1.1 christos { U32 n; 265 1.1 christos for (n=0; n<oSize; n+=2) { 266 1.1 christos huffWeight[n] = ip[n/2] >> 4; 267 1.1 christos huffWeight[n+1] = ip[n/2] & 15; 268 1.1 christos } } } 269 1.1 christos else { /* header compressed with FSE (normal case) */ 270 1.1 christos if (iSize+1 > srcSize) return ERROR(srcSize_wrong); 271 1.1 christos /* max (hwSize-1) values decoded, as last one is implied */ 272 1.1 christos oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2); 273 1.1 christos if (FSE_isError(oSize)) return oSize; 274 1.1 christos } 275 1.1 christos 276 1.1 christos /* collect weight stats */ 277 1.1 christos ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32)); 278 1.1 christos weightTotal = 0; 279 1.1 christos { U32 n; for (n=0; n<oSize; n++) { 280 1.1 christos if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected); 281 1.1 christos rankStats[huffWeight[n]]++; 282 1.1 christos weightTotal += (1 << huffWeight[n]) >> 1; 283 1.1 christos } } 284 1.1 christos if (weightTotal == 0) return ERROR(corruption_detected); 285 1.1 christos 286 1.1 christos /* get last non-null symbol weight (implied, total must be 2^n) */ 287 1.1 christos { U32 const tableLog = ZSTD_highbit32(weightTotal) + 1; 288 1.1 christos if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected); 289 1.1 christos *tableLogPtr = tableLog; 290 1.1 christos /* determine last weight */ 291 1.1 christos { U32 const total = 1 << tableLog; 292 1.1 christos U32 const rest = total - weightTotal; 293 1.1 christos U32 const verif = 1 << ZSTD_highbit32(rest); 294 1.1 christos U32 const lastWeight = ZSTD_highbit32(rest) + 1; 295 1.1 christos if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */ 296 1.1 christos huffWeight[oSize] = (BYTE)lastWeight; 297 1.1 christos rankStats[lastWeight]++; 298 1.1 christos } } 299 1.1 christos 300 1.1 christos /* check tree construction validity */ 301 1.1 christos if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected); /* by construction : at least 2 elts of rank 1, must be even */ 302 1.1 christos 303 1.1 christos /* results */ 304 1.1 christos *nbSymbolsPtr = (U32)(oSize+1); 305 1.1 christos return iSize+1; 306 1.1 christos } 307 1.1 christos 308 1.1 christos /* Avoids the FORCE_INLINE of the _body() function. */ 309 1.1 christos static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats, 310 1.1 christos U32* nbSymbolsPtr, U32* tableLogPtr, 311 1.1 christos const void* src, size_t srcSize, 312 1.1 christos void* workSpace, size_t wkspSize) 313 1.1 christos { 314 1.1 christos return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0); 315 1.1 christos } 316 1.1 christos 317 1.1 christos #if DYNAMIC_BMI2 318 1.1 christos static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats, 319 1.1 christos U32* nbSymbolsPtr, U32* tableLogPtr, 320 1.1 christos const void* src, size_t srcSize, 321 1.1 christos void* workSpace, size_t wkspSize) 322 1.1 christos { 323 1.1 christos return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 1); 324 1.1 christos } 325 1.1 christos #endif 326 1.1 christos 327 1.1 christos size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, 328 1.1 christos U32* nbSymbolsPtr, U32* tableLogPtr, 329 1.1 christos const void* src, size_t srcSize, 330 1.1 christos void* workSpace, size_t wkspSize, 331 1.1 christos int flags) 332 1.1 christos { 333 1.1 christos #if DYNAMIC_BMI2 334 1.1 christos if (flags & HUF_flags_bmi2) { 335 1.1 christos return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); 336 1.1 christos } 337 1.1 christos #endif 338 1.1 christos (void)flags; 339 1.1 christos return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); 340 1.1 christos } 341