diff --git a/contrib/diagnose_corruption/check_flipped_bits.c b/contrib/diagnose_corruption/check_flipped_bits.c index 23b53287..58f7a45d 100644 --- a/contrib/diagnose_corruption/check_flipped_bits.c +++ b/contrib/diagnose_corruption/check_flipped_bits.c @@ -9,6 +9,7 @@ */ #include "zstd.h" +#include "zstd_errors.h" #include #include @@ -27,6 +28,9 @@ typedef struct { size_t output_size; ZSTD_DCtx* dctx; + + int success_count; + int error_counts[ZSTD_error_maxCode]; } stuff_t; static void free_stuff(stuff_t* stuff) { @@ -40,6 +44,19 @@ static void usage(void) { exit(1); } +static void print_summary(stuff_t* stuff) { + int error_code; + fprintf(stderr, "%9d successful decompressions\n", stuff->success_count); + for (error_code = 0; error_code < ZSTD_error_maxCode; error_code++) { + int count = stuff->error_counts[error_code]; + if (count) { + fprintf( + stderr, "%9d failed decompressions with message: %s\n", + count, ZSTD_getErrorString(error_code)); + } + } +} + static char* readFile(const char* filename, size_t* size) { struct stat statbuf; int ret; @@ -121,6 +138,9 @@ static int init_stuff(stuff_t* stuff, int argc, char *argv[]) { return 0; } + stuff->success_count = 0; + memset(stuff->error_counts, 0, sizeof(stuff->error_counts)); + return 1; } @@ -138,6 +158,12 @@ static int test_decompress(stuff_t* stuff) { ret = ZSTD_decompressStream(dctx, &out, &in); if (ZSTD_isError(ret)) { + unsigned int code = ZSTD_getErrorCode(ret); + if (code >= ZSTD_error_maxCode) { + fprintf(stderr, "Received unexpected error code!\n"); + exit(1); + } + stuff->error_counts[code]++; /* fprintf( stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret)); @@ -146,6 +172,7 @@ static int test_decompress(stuff_t* stuff) { } } + stuff->success_count++; return 1; } @@ -155,7 +182,7 @@ static int perturb_bits(stuff_t* stuff) { for (pos = 0; pos < stuff->input_size; pos++) { unsigned char old_val = stuff->input[pos]; if (pos % 1000 == 0) { - fprintf(stderr, "Perturbing byte %zu\n", pos); + fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size); } for (bit = 0; bit < 8; bit++) { unsigned char new_val = old_val ^ (1 << bit); @@ -179,7 +206,7 @@ static int perturb_bytes(stuff_t* stuff) { for (pos = 0; pos < stuff->input_size; pos++) { unsigned char old_val = stuff->input[pos]; if (pos % 1000 == 0) { - fprintf(stderr, "Perturbing byte %zu\n", pos); + fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size); } for (new_val = 0; new_val < 256; new_val++) { stuff->perturbed[pos] = new_val; @@ -213,6 +240,8 @@ int main(int argc, char* argv[]) { perturb_bytes(&stuff); + print_summary(&stuff); + free_stuff(&stuff); return 0;