Home | History | Annotate | Line # | Download | only in regression
method.c revision 1.1
      1 /*
      2  * Copyright (c) Meta Platforms, Inc. and affiliates.
      3  * All rights reserved.
      4  *
      5  * This source code is licensed under both the BSD-style license (found in the
      6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
      7  * in the COPYING file in the root directory of this source tree).
      8  * You may select, at your option, one of the above-listed licenses.
      9  */
     10 
     11 #include "method.h"
     12 
     13 #include <stdio.h>
     14 #include <stdlib.h>
     15 
     16 #define ZSTD_STATIC_LINKING_ONLY
     17 #include <zstd.h>
     18 
/** Smaller of two values. Each argument may be evaluated more than once. */
#define MIN(a, b) ((a) < (b) ? (a) : (b))
     20 
/** Path of the zstd command line binary; NULL until method_set_zstdcli(). */
static char const* g_zstdcli = NULL;

/**
 * Remembers which zstd CLI binary cli_compress() should run.
 * The pointer is stored as-is; no copy of the string is made.
 */
void method_set_zstdcli(char const* zstdcli) {
    g_zstdcli = zstdcli;
}
     26 
/**
 * Macro to get a pointer of type, given ptr, which points at the member
 * variable with the given name, member. A NULL ptr maps to NULL.
 *
 *     method_state_t* base = ...;
 *     buffer_state_t* state = container_of(base, buffer_state_t, base);
 *
 * ptr is evaluated twice, so it must not have side effects. It is fully
 * parenthesized so that an expression argument (e.g. `c ? p : q`) is
 * compared against NULL as a whole.
 */
#define container_of(ptr, type, member) \
    ((type*)((ptr) == NULL ? NULL : (char*)(ptr) - offsetof(type, member)))
     36 
     37 /** State to reuse the same buffers between compression calls. */
     38 typedef struct {
     39     method_state_t base;
     40     data_buffers_t inputs; /**< The input buffer for each file. */
     41     data_buffer_t dictionary; /**< The dictionary. */
     42     data_buffer_t compressed; /**< The compressed data buffer. */
     43     data_buffer_t decompressed; /**< The decompressed data buffer. */
     44 } buffer_state_t;
     45 
     46 static size_t buffers_max_size(data_buffers_t buffers) {
     47     size_t max = 0;
     48     for (size_t i = 0; i < buffers.size; ++i) {
     49         if (buffers.buffers[i].size > max)
     50             max = buffers.buffers[i].size;
     51     }
     52     return max;
     53 }
     54 
     55 static method_state_t* buffer_state_create(data_t const* data) {
     56     buffer_state_t* state = (buffer_state_t*)calloc(1, sizeof(buffer_state_t));
     57     if (state == NULL)
     58         return NULL;
     59     state->base.data = data;
     60     state->inputs = data_buffers_get(data);
     61     state->dictionary = data_buffer_get_dict(data);
     62     size_t const max_size = buffers_max_size(state->inputs);
     63     state->compressed = data_buffer_create(ZSTD_compressBound(max_size));
     64     state->decompressed = data_buffer_create(max_size);
     65     return &state->base;
     66 }
     67 
     68 static void buffer_state_destroy(method_state_t* base) {
     69     if (base == NULL)
     70         return;
     71     buffer_state_t* state = container_of(base, buffer_state_t, base);
     72     free(state);
     73 }
     74 
     75 static int buffer_state_bad(
     76     buffer_state_t const* state,
     77     config_t const* config) {
     78     if (state == NULL) {
     79         fprintf(stderr, "buffer_state_t is NULL\n");
     80         return 1;
     81     }
     82     if (state->inputs.size == 0 || state->compressed.data == NULL ||
     83         state->decompressed.data == NULL) {
     84         fprintf(stderr, "buffer state allocation failure\n");
     85         return 1;
     86     }
     87     if (config->use_dictionary && state->dictionary.data == NULL) {
     88         fprintf(stderr, "dictionary loading failed\n");
     89         return 1;
     90     }
     91     return 0;
     92 }
     93 
     94 static result_t simple_compress(method_state_t* base, config_t const* config) {
     95     buffer_state_t* state = container_of(base, buffer_state_t, base);
     96 
     97     if (buffer_state_bad(state, config))
     98         return result_error(result_error_system_error);
     99 
    100     /* Keep the tests short by skipping directories, since behavior shouldn't
    101      * change.
    102      */
    103     if (base->data->type != data_type_file)
    104         return result_error(result_error_skip);
    105 
    106     if (config->advanced_api_only)
    107         return result_error(result_error_skip);
    108 
    109     if (config->use_dictionary || config->no_pledged_src_size)
    110         return result_error(result_error_skip);
    111 
    112     /* If the config doesn't specify a level, skip. */
    113     int const level = config_get_level(config);
    114     if (level == CONFIG_NO_LEVEL)
    115         return result_error(result_error_skip);
    116 
    117     data_buffer_t const input = state->inputs.buffers[0];
    118 
    119     /* Compress, decompress, and check the result. */
    120     state->compressed.size = ZSTD_compress(
    121         state->compressed.data,
    122         state->compressed.capacity,
    123         input.data,
    124         input.size,
    125         level);
    126     if (ZSTD_isError(state->compressed.size))
    127         return result_error(result_error_compression_error);
    128 
    129     state->decompressed.size = ZSTD_decompress(
    130         state->decompressed.data,
    131         state->decompressed.capacity,
    132         state->compressed.data,
    133         state->compressed.size);
    134     if (ZSTD_isError(state->decompressed.size))
    135         return result_error(result_error_decompression_error);
    136     if (data_buffer_compare(input, state->decompressed))
    137         return result_error(result_error_round_trip_error);
    138 
    139     result_data_t data;
    140     data.total_size = state->compressed.size;
    141     return result_data(data);
    142 }
    143 
    144 static result_t compress_cctx_compress(
    145     method_state_t* base,
    146     config_t const* config) {
    147     buffer_state_t* state = container_of(base, buffer_state_t, base);
    148 
    149     if (buffer_state_bad(state, config))
    150         return result_error(result_error_system_error);
    151 
    152     if (config->no_pledged_src_size)
    153         return result_error(result_error_skip);
    154 
    155     if (base->data->type != data_type_dir)
    156         return result_error(result_error_skip);
    157 
    158     if (config->advanced_api_only)
    159         return result_error(result_error_skip);
    160 
    161     int const level = config_get_level(config);
    162 
    163     ZSTD_CCtx* cctx = ZSTD_createCCtx();
    164     ZSTD_DCtx* dctx = ZSTD_createDCtx();
    165     if (cctx == NULL || dctx == NULL) {
    166         fprintf(stderr, "context creation failed\n");
    167         return result_error(result_error_system_error);
    168     }
    169 
    170     result_t result;
    171     result_data_t data = {.total_size = 0};
    172     for (size_t i = 0; i < state->inputs.size; ++i) {
    173         data_buffer_t const input = state->inputs.buffers[i];
    174         ZSTD_parameters const params =
    175             config_get_zstd_params(config, input.size, state->dictionary.size);
    176 
    177         if (level == CONFIG_NO_LEVEL)
    178             state->compressed.size = ZSTD_compress_advanced(
    179                 cctx,
    180                 state->compressed.data,
    181                 state->compressed.capacity,
    182                 input.data,
    183                 input.size,
    184                 config->use_dictionary ? state->dictionary.data : NULL,
    185                 config->use_dictionary ? state->dictionary.size : 0,
    186                 params);
    187         else if (config->use_dictionary)
    188             state->compressed.size = ZSTD_compress_usingDict(
    189                 cctx,
    190                 state->compressed.data,
    191                 state->compressed.capacity,
    192                 input.data,
    193                 input.size,
    194                 state->dictionary.data,
    195                 state->dictionary.size,
    196                 level);
    197         else
    198             state->compressed.size = ZSTD_compressCCtx(
    199                 cctx,
    200                 state->compressed.data,
    201                 state->compressed.capacity,
    202                 input.data,
    203                 input.size,
    204                 level);
    205 
    206         if (ZSTD_isError(state->compressed.size)) {
    207             result = result_error(result_error_compression_error);
    208             goto out;
    209         }
    210 
    211         if (config->use_dictionary)
    212             state->decompressed.size = ZSTD_decompress_usingDict(
    213                 dctx,
    214                 state->decompressed.data,
    215                 state->decompressed.capacity,
    216                 state->compressed.data,
    217                 state->compressed.size,
    218                 state->dictionary.data,
    219                 state->dictionary.size);
    220         else
    221             state->decompressed.size = ZSTD_decompressDCtx(
    222                 dctx,
    223                 state->decompressed.data,
    224                 state->decompressed.capacity,
    225                 state->compressed.data,
    226                 state->compressed.size);
    227         if (ZSTD_isError(state->decompressed.size)) {
    228             result = result_error(result_error_decompression_error);
    229             goto out;
    230         }
    231         if (data_buffer_compare(input, state->decompressed)) {
    232             result = result_error(result_error_round_trip_error);
    233             goto out;
    234         }
    235 
    236         data.total_size += state->compressed.size;
    237     }
    238 
    239     result = result_data(data);
    240 out:
    241     ZSTD_freeCCtx(cctx);
    242     ZSTD_freeDCtx(dctx);
    243     return result;
    244 }
    245 
    246 /** Generic state creation function. */
    247 static method_state_t* method_state_create(data_t const* data) {
    248     method_state_t* state = (method_state_t*)malloc(sizeof(method_state_t));
    249     if (state == NULL)
    250         return NULL;
    251     state->data = data;
    252     return state;
    253 }
    254 
    255 static void method_state_destroy(method_state_t* state) {
    256     free(state);
    257 }
    258 
    259 static result_t cli_compress(method_state_t* state, config_t const* config) {
    260     if (config->cli_args == NULL)
    261         return result_error(result_error_skip);
    262 
    263     if (config->advanced_api_only)
    264         return result_error(result_error_skip);
    265 
    266     /* We don't support no pledged source size with directories. Too slow. */
    267     if (state->data->type == data_type_dir && config->no_pledged_src_size)
    268         return result_error(result_error_skip);
    269 
    270     if (g_zstdcli == NULL)
    271         return result_error(result_error_system_error);
    272 
    273     /* '<zstd>' -cqr <args> [-D '<dict>'] '<file/dir>' */
    274     char cmd[1024];
    275     size_t const cmd_size = snprintf(
    276         cmd,
    277         sizeof(cmd),
    278         "'%s' -cqr %s %s%s%s %s '%s'",
    279         g_zstdcli,
    280         config->cli_args,
    281         config->use_dictionary ? "-D '" : "",
    282         config->use_dictionary ? state->data->dict.path : "",
    283         config->use_dictionary ? "'" : "",
    284         config->no_pledged_src_size ? "<" : "",
    285         state->data->data.path);
    286     if (cmd_size >= sizeof(cmd)) {
    287         fprintf(stderr, "command too large: %s\n", cmd);
    288         return result_error(result_error_system_error);
    289     }
    290     FILE* zstd = popen(cmd, "r");
    291     if (zstd == NULL) {
    292         fprintf(stderr, "failed to popen command: %s\n", cmd);
    293         return result_error(result_error_system_error);
    294     }
    295 
    296     char out[4096];
    297     size_t total_size = 0;
    298     while (1) {
    299         size_t const size = fread(out, 1, sizeof(out), zstd);
    300         total_size += size;
    301         if (size != sizeof(out))
    302             break;
    303     }
    304     if (ferror(zstd) || pclose(zstd) != 0) {
    305         fprintf(stderr, "zstd failed with command: %s\n", cmd);
    306         return result_error(result_error_compression_error);
    307     }
    308 
    309     result_data_t const data = {.total_size = total_size};
    310     return result_data(data);
    311 }
    312 
    313 static int advanced_config(
    314     ZSTD_CCtx* cctx,
    315     buffer_state_t* state,
    316     config_t const* config) {
    317     ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
    318     for (size_t p = 0; p < config->param_values.size; ++p) {
    319         param_value_t const pv = config->param_values.data[p];
    320         if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, pv.param, pv.value))) {
    321             return 1;
    322         }
    323     }
    324     if (config->use_dictionary) {
    325         if (ZSTD_isError(ZSTD_CCtx_loadDictionary(
    326                 cctx, state->dictionary.data, state->dictionary.size))) {
    327             return 1;
    328         }
    329     }
    330     return 0;
    331 }
    332 
    333 static result_t advanced_one_pass_compress_output_adjustment(
    334     method_state_t* base,
    335     config_t const* config,
    336     size_t const subtract) {
    337     buffer_state_t* state = container_of(base, buffer_state_t, base);
    338 
    339     if (buffer_state_bad(state, config))
    340         return result_error(result_error_system_error);
    341 
    342     ZSTD_CCtx* cctx = ZSTD_createCCtx();
    343     result_t result;
    344 
    345     if (!cctx || advanced_config(cctx, state, config)) {
    346         result = result_error(result_error_compression_error);
    347         goto out;
    348     }
    349 
    350     result_data_t data = {.total_size = 0};
    351     for (size_t i = 0; i < state->inputs.size; ++i) {
    352         data_buffer_t const input = state->inputs.buffers[i];
    353 
    354         if (!config->no_pledged_src_size) {
    355             if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
    356                 result = result_error(result_error_compression_error);
    357                 goto out;
    358             }
    359         }
    360         size_t const size = ZSTD_compress2(
    361             cctx,
    362             state->compressed.data,
    363             ZSTD_compressBound(input.size) - subtract,
    364             input.data,
    365             input.size);
    366         if (ZSTD_isError(size)) {
    367             result = result_error(result_error_compression_error);
    368             goto out;
    369         }
    370         data.total_size += size;
    371     }
    372 
    373     result = result_data(data);
    374 out:
    375     ZSTD_freeCCtx(cctx);
    376     return result;
    377 }
    378 
    379 static result_t advanced_one_pass_compress(
    380     method_state_t* base,
    381     config_t const* config) {
    382   return advanced_one_pass_compress_output_adjustment(base, config, 0);
    383 }
    384 
    385 static result_t advanced_one_pass_compress_small_output(
    386     method_state_t* base,
    387     config_t const* config) {
    388   return advanced_one_pass_compress_output_adjustment(base, config, 1);
    389 }
    390 
    391 static result_t advanced_streaming_compress(
    392     method_state_t* base,
    393     config_t const* config) {
    394     buffer_state_t* state = container_of(base, buffer_state_t, base);
    395 
    396     if (buffer_state_bad(state, config))
    397         return result_error(result_error_system_error);
    398 
    399     ZSTD_CCtx* cctx = ZSTD_createCCtx();
    400     result_t result;
    401 
    402     if (!cctx || advanced_config(cctx, state, config)) {
    403         result = result_error(result_error_compression_error);
    404         goto out;
    405     }
    406 
    407     result_data_t data = {.total_size = 0};
    408     for (size_t i = 0; i < state->inputs.size; ++i) {
    409         data_buffer_t input = state->inputs.buffers[i];
    410 
    411         if (!config->no_pledged_src_size) {
    412             if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
    413                 result = result_error(result_error_compression_error);
    414                 goto out;
    415             }
    416         }
    417 
    418         while (input.size > 0) {
    419             ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
    420             input.data += in.size;
    421             input.size -= in.size;
    422             ZSTD_EndDirective const op =
    423                 input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
    424             size_t ret = 0;
    425             while (in.pos < in.size || (op == ZSTD_e_end && ret != 0)) {
    426                 ZSTD_outBuffer out = {state->compressed.data,
    427                                       MIN(state->compressed.capacity, 1024)};
    428                 ret = ZSTD_compressStream2(cctx, &out, &in, op);
    429                 if (ZSTD_isError(ret)) {
    430                     result = result_error(result_error_compression_error);
    431                     goto out;
    432                 }
    433                 data.total_size += out.pos;
    434             }
    435         }
    436     }
    437 
    438     result = result_data(data);
    439 out:
    440     ZSTD_freeCCtx(cctx);
    441     return result;
    442 }
    443 
    444 static int init_cstream(
    445     buffer_state_t* state,
    446     ZSTD_CStream* zcs,
    447     config_t const* config,
    448     int const advanced,
    449     ZSTD_CDict** cdict)
    450 {
    451     size_t zret;
    452     if (advanced) {
    453         ZSTD_parameters const params = config_get_zstd_params(config, 0, 0);
    454         ZSTD_CDict* dict = NULL;
    455         if (cdict) {
    456             if (!config->use_dictionary)
    457               return 1;
    458             *cdict = ZSTD_createCDict_advanced(
    459                 state->dictionary.data,
    460                 state->dictionary.size,
    461                 ZSTD_dlm_byRef,
    462                 ZSTD_dct_auto,
    463                 params.cParams,
    464                 ZSTD_defaultCMem);
    465             if (!*cdict) {
    466                 return 1;
    467             }
    468             zret = ZSTD_initCStream_usingCDict_advanced(
    469                 zcs, *cdict, params.fParams, ZSTD_CONTENTSIZE_UNKNOWN);
    470         } else {
    471             zret = ZSTD_initCStream_advanced(
    472                 zcs,
    473                 config->use_dictionary ? state->dictionary.data : NULL,
    474                 config->use_dictionary ? state->dictionary.size : 0,
    475                 params,
    476                 ZSTD_CONTENTSIZE_UNKNOWN);
    477         }
    478     } else {
    479         int const level = config_get_level(config);
    480         if (level == CONFIG_NO_LEVEL)
    481             return 1;
    482         if (cdict) {
    483             if (!config->use_dictionary)
    484               return 1;
    485             *cdict = ZSTD_createCDict(
    486                 state->dictionary.data,
    487                 state->dictionary.size,
    488                 level);
    489             if (!*cdict) {
    490                 return 1;
    491             }
    492             zret = ZSTD_initCStream_usingCDict(zcs, *cdict);
    493         } else if (config->use_dictionary) {
    494             zret = ZSTD_initCStream_usingDict(
    495                 zcs,
    496                 state->dictionary.data,
    497                 state->dictionary.size,
    498                 level);
    499         } else {
    500             zret = ZSTD_initCStream(zcs, level);
    501         }
    502     }
    503     if (ZSTD_isError(zret)) {
    504         return 1;
    505     }
    506     return 0;
    507 }
    508 
    509 static result_t old_streaming_compress_internal(
    510     method_state_t* base,
    511     config_t const* config,
    512     int const advanced,
    513     int const cdict) {
    514   buffer_state_t* state = container_of(base, buffer_state_t, base);
    515 
    516   if (buffer_state_bad(state, config))
    517     return result_error(result_error_system_error);
    518 
    519 
    520   ZSTD_CStream* zcs = ZSTD_createCStream();
    521   ZSTD_CDict* cd = NULL;
    522   result_t result;
    523   if (zcs == NULL) {
    524     result = result_error(result_error_compression_error);
    525     goto out;
    526   }
    527   if (!advanced && config_get_level(config) == CONFIG_NO_LEVEL) {
    528     result = result_error(result_error_skip);
    529     goto out;
    530   }
    531   if (cdict && !config->use_dictionary) {
    532     result = result_error(result_error_skip);
    533     goto out;
    534   }
    535   if (config->advanced_api_only) {
    536     result = result_error(result_error_skip);
    537     goto out;
    538   }
    539   if (init_cstream(state, zcs, config, advanced, cdict ? &cd : NULL)) {
    540     result = result_error(result_error_compression_error);
    541     goto out;
    542   }
    543 
    544   result_data_t data = {.total_size = 0};
    545   for (size_t i = 0; i < state->inputs.size; ++i) {
    546     data_buffer_t input = state->inputs.buffers[i];
    547     size_t zret = ZSTD_resetCStream(
    548         zcs,
    549         config->no_pledged_src_size ? ZSTD_CONTENTSIZE_UNKNOWN : input.size);
    550     if (ZSTD_isError(zret)) {
    551       result = result_error(result_error_compression_error);
    552       goto out;
    553     }
    554 
    555     while (input.size > 0) {
    556       ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
    557       input.data += in.size;
    558       input.size -= in.size;
    559       ZSTD_EndDirective const op =
    560           input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
    561       zret = 0;
    562       while (in.pos < in.size || (op == ZSTD_e_end && zret != 0)) {
    563         ZSTD_outBuffer out = {state->compressed.data,
    564                               MIN(state->compressed.capacity, 1024)};
    565         if (op == ZSTD_e_continue || in.pos < in.size)
    566           zret = ZSTD_compressStream(zcs, &out, &in);
    567         else
    568           zret = ZSTD_endStream(zcs, &out);
    569         if (ZSTD_isError(zret)) {
    570           result = result_error(result_error_compression_error);
    571           goto out;
    572         }
    573         data.total_size += out.pos;
    574       }
    575     }
    576   }
    577 
    578   result = result_data(data);
    579 out:
    580     ZSTD_freeCStream(zcs);
    581     ZSTD_freeCDict(cd);
    582     return result;
    583 }
    584 
    585 static result_t old_streaming_compress(
    586     method_state_t* base,
    587     config_t const* config)
    588 {
    589     return old_streaming_compress_internal(
    590         base, config, /* advanced */ 0, /* cdict */ 0);
    591 }
    592 
    593 static result_t old_streaming_compress_advanced(
    594     method_state_t* base,
    595     config_t const* config)
    596 {
    597     return old_streaming_compress_internal(
    598         base, config, /* advanced */ 1, /* cdict */ 0);
    599 }
    600 
    601 static result_t old_streaming_compress_cdict(
    602     method_state_t* base,
    603     config_t const* config)
    604 {
    605     return old_streaming_compress_internal(
    606         base, config, /* advanced */ 0, /* cdict */ 1);
    607 }
    608 
    609 static result_t old_streaming_compress_cdict_advanced(
    610     method_state_t* base,
    611     config_t const* config)
    612 {
    613     return old_streaming_compress_internal(
    614         base, config, /* advanced */ 1, /* cdict */ 1);
    615 }
    616 
    617 method_t const simple = {
    618     .name = "compress simple",
    619     .create = buffer_state_create,
    620     .compress = simple_compress,
    621     .destroy = buffer_state_destroy,
    622 };
    623 
    624 method_t const compress_cctx = {
    625     .name = "compress cctx",
    626     .create = buffer_state_create,
    627     .compress = compress_cctx_compress,
    628     .destroy = buffer_state_destroy,
    629 };
    630 
    631 method_t const advanced_one_pass = {
    632     .name = "advanced one pass",
    633     .create = buffer_state_create,
    634     .compress = advanced_one_pass_compress,
    635     .destroy = buffer_state_destroy,
    636 };
    637 
    638 method_t const advanced_one_pass_small_out = {
    639     .name = "advanced one pass small out",
    640     .create = buffer_state_create,
    641     .compress = advanced_one_pass_compress,
    642     .destroy = buffer_state_destroy,
    643 };
    644 
    645 method_t const advanced_streaming = {
    646     .name = "advanced streaming",
    647     .create = buffer_state_create,
    648     .compress = advanced_streaming_compress,
    649     .destroy = buffer_state_destroy,
    650 };
    651 
    652 method_t const old_streaming = {
    653     .name = "old streaming",
    654     .create = buffer_state_create,
    655     .compress = old_streaming_compress,
    656     .destroy = buffer_state_destroy,
    657 };
    658 
    659 method_t const old_streaming_advanced = {
    660     .name = "old streaming advanced",
    661     .create = buffer_state_create,
    662     .compress = old_streaming_compress_advanced,
    663     .destroy = buffer_state_destroy,
    664 };
    665 
    666 method_t const old_streaming_cdict = {
    667     .name = "old streaming cdict",
    668     .create = buffer_state_create,
    669     .compress = old_streaming_compress_cdict,
    670     .destroy = buffer_state_destroy,
    671 };
    672 
    673 method_t const old_streaming_advanced_cdict = {
    674     .name = "old streaming advanced cdict",
    675     .create = buffer_state_create,
    676     .compress = old_streaming_compress_cdict_advanced,
    677     .destroy = buffer_state_destroy,
    678 };
    679 
    680 method_t const cli = {
    681     .name = "zstdcli",
    682     .create = method_state_create,
    683     .compress = cli_compress,
    684     .destroy = method_state_destroy,
    685 };
    686 
    687 static method_t const* g_methods[] = {
    688     &simple,
    689     &compress_cctx,
    690     &cli,
    691     &advanced_one_pass,
    692     &advanced_one_pass_small_out,
    693     &advanced_streaming,
    694     &old_streaming,
    695     &old_streaming_advanced,
    696     &old_streaming_cdict,
    697     &old_streaming_advanced_cdict,
    698     NULL,
    699 };
    700 
    701 method_t const* const* methods = g_methods;
    702