]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - sys/contrib/zstd/tests/decodecorpus.c
Merge llvm trunk r321414 to contrib/llvm.
[FreeBSD/FreeBSD.git] / sys / contrib / zstd / tests / decodecorpus.c
1 /*
2  * Copyright (c) 2017-present, Yann Collet, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10
11 #include <limits.h>
12 #include <math.h>
13 #include <stddef.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <time.h>
18
19 #include "zstd.h"
20 #include "zstd_internal.h"
21 #include "mem.h"
22 #define ZDICT_STATIC_LINKING_ONLY
23 #include "zdict.h"
24
25 // Direct access to internal compression functions is required
26 #include "zstd_compress.c"
27
28 #define XXH_STATIC_LINKING_ONLY
29 #include "xxhash.h"     /* XXH64 */
30
31 #ifndef MIN
32     #define MIN(a, b) ((a) < (b) ? (a) : (b))
33 #endif
34
35 #ifndef MAX_PATH
36     #ifdef PATH_MAX
37         #define MAX_PATH PATH_MAX
38     #else
39         #define MAX_PATH 256
40     #endif
41 #endif
42
43 /*-************************************
44 *  DISPLAY Macros
45 **************************************/
46 #define DISPLAY(...)          fprintf(stderr, __VA_ARGS__)
47 #define DISPLAYLEVEL(l, ...)  if (g_displayLevel>=l) { DISPLAY(__VA_ARGS__); }
48 static U32 g_displayLevel = 2;
49
50 #define DISPLAYUPDATE(...)                                                     \
51     do {                                                                       \
52         if ((clockSpan(g_displayClock) > g_refreshRate) ||                     \
53             (g_displayLevel >= 4)) {                                           \
54             g_displayClock = clock();                                          \
55             DISPLAY(__VA_ARGS__);                                              \
56             if (g_displayLevel >= 4) fflush(stderr);                           \
57         }                                                                      \
58     } while (0)
59 static const clock_t g_refreshRate = CLOCKS_PER_SEC / 6;
60 static clock_t g_displayClock = 0;
61
62 static clock_t clockSpan(clock_t cStart)
63 {
64     return clock() - cStart;   /* works even when overflow; max span ~ 30mn */
65 }
66
67 #define CHECKERR(code)                                                         \
68     do {                                                                       \
69         if (ZSTD_isError(code)) {                                              \
70             DISPLAY("Error occurred while generating data: %s\n",              \
71                     ZSTD_getErrorName(code));                                  \
72             exit(1);                                                           \
73         }                                                                      \
74     } while (0)
75
76 /*-*******************************************************
77 *  Random function
78 *********************************************************/
79 static unsigned RAND(unsigned* src)
80 {
81 #define RAND_rotl32(x,r) ((x << r) | (x >> (32 - r)))
82     static const U32 prime1 = 2654435761U;
83     static const U32 prime2 = 2246822519U;
84     U32 rand32 = *src;
85     rand32 *= prime1;
86     rand32 += prime2;
87     rand32  = RAND_rotl32(rand32, 13);
88     *src = rand32;
89     return RAND_rotl32(rand32, 27);
90 #undef RAND_rotl32
91 }
92
93 #define DISTSIZE (8192)
94
95 /* Write `size` bytes into `ptr`, all of which are less than or equal to `maxSymb` */
96 static void RAND_bufferMaxSymb(U32* seed, void* ptr, size_t size, int maxSymb)
97 {
98     size_t i;
99     BYTE* op = ptr;
100
101     for (i = 0; i < size; i++) {
102         op[i] = (BYTE) (RAND(seed) % (maxSymb + 1));
103     }
104 }
105
106 /* Write `size` random bytes into `ptr` */
107 static void RAND_buffer(U32* seed, void* ptr, size_t size)
108 {
109     size_t i;
110     BYTE* op = ptr;
111
112     for (i = 0; i + 4 <= size; i += 4) {
113         MEM_writeLE32(op + i, RAND(seed));
114     }
115     for (; i < size; i++) {
116         op[i] = RAND(seed) & 0xff;
117     }
118 }
119
120 /* Write `size` bytes into `ptr` following the distribution `dist` */
121 static void RAND_bufferDist(U32* seed, BYTE* dist, void* ptr, size_t size)
122 {
123     size_t i;
124     BYTE* op = ptr;
125
126     for (i = 0; i < size; i++) {
127         op[i] = dist[RAND(seed) % DISTSIZE];
128     }
129 }
130
131 /* Generate a random distribution where the frequency of each symbol follows a
132  * geometric distribution defined by `weight`
133  * `dist` should have size at least `DISTSIZE` */
134 static void RAND_genDist(U32* seed, BYTE* dist, double weight)
135 {
136     size_t i = 0;
137     size_t statesLeft = DISTSIZE;
138     BYTE symb = (BYTE) (RAND(seed) % 256);
139     BYTE step = (BYTE) ((RAND(seed) % 256) | 1); /* force it to be odd so it's relatively prime to 256 */
140
141     while (i < DISTSIZE) {
142         size_t states = ((size_t)(weight * statesLeft)) + 1;
143         size_t j;
144         for (j = 0; j < states && i < DISTSIZE; j++, i++) {
145             dist[i] = symb;
146         }
147
148         symb += step;
149         statesLeft -= states;
150     }
151 }
152
153 /* Generates a random number in the range [min, max) */
154 static inline U32 RAND_range(U32* seed, U32 min, U32 max)
155 {
156     return (RAND(seed) % (max-min)) + min;
157 }
158
159 #define ROUND(x) ((U32)(x + 0.5))
160
161 /* Generates a random number in an exponential distribution with mean `mean` */
162 static double RAND_exp(U32* seed, double mean)
163 {
164     double const u = RAND(seed) / (double) UINT_MAX;
165     return log(1-u) * (-mean);
166 }
167
168 /*-*******************************************************
169 *  Constants and Structs
170 *********************************************************/
171 const char *BLOCK_TYPES[] = {"raw", "rle", "compressed"};
172
173 #define MAX_DECOMPRESSED_SIZE_LOG 20
174 #define MAX_DECOMPRESSED_SIZE (1ULL << MAX_DECOMPRESSED_SIZE_LOG)
175
176 #define MAX_WINDOW_LOG 22 /* Recommended support is 8MB, so limit to 4MB + mantissa */
177
178 #define MIN_SEQ_LEN (3)
179 #define MAX_NB_SEQ ((ZSTD_BLOCKSIZE_MAX + MIN_SEQ_LEN - 1) / MIN_SEQ_LEN)
180
181 BYTE CONTENT_BUFFER[MAX_DECOMPRESSED_SIZE];
182 BYTE FRAME_BUFFER[MAX_DECOMPRESSED_SIZE * 2];
183 BYTE LITERAL_BUFFER[ZSTD_BLOCKSIZE_MAX];
184
185 seqDef SEQUENCE_BUFFER[MAX_NB_SEQ];
186 BYTE SEQUENCE_LITERAL_BUFFER[ZSTD_BLOCKSIZE_MAX]; /* storeSeq expects a place to copy literals to */
187 BYTE SEQUENCE_LLCODE[ZSTD_BLOCKSIZE_MAX];
188 BYTE SEQUENCE_MLCODE[ZSTD_BLOCKSIZE_MAX];
189 BYTE SEQUENCE_OFCODE[ZSTD_BLOCKSIZE_MAX];
190
191 unsigned WKSP[1024];
192
193 typedef struct {
194     size_t contentSize; /* 0 means unknown (unless contentSize == windowSize == 0) */
195     unsigned windowSize; /* contentSize >= windowSize means single segment */
196 } frameHeader_t;
197
198 /* For repeat modes */
199 typedef struct {
200     U32 rep[ZSTD_REP_NUM];
201
202     int hufInit;
203     /* the distribution used in the previous block for repeat mode */
204     BYTE hufDist[DISTSIZE];
205     U32 hufTable [256]; /* HUF_CElt is an incomplete type */
206
207     int fseInit;
208     FSE_CTable offcodeCTable  [FSE_CTABLE_SIZE_U32(OffFSELog, MaxOff)];
209     FSE_CTable matchlengthCTable[FSE_CTABLE_SIZE_U32(MLFSELog, MaxML)];
210     FSE_CTable litlengthCTable  [FSE_CTABLE_SIZE_U32(LLFSELog, MaxLL)];
211
212     /* Symbols that were present in the previous distribution, for use with
213      * set_repeat */
214     BYTE litlengthSymbolSet[36];
215     BYTE offsetSymbolSet[29];
216     BYTE matchlengthSymbolSet[53];
217 } cblockStats_t;
218
219 typedef struct {
220     void* data;
221     void* dataStart;
222     void* dataEnd;
223
224     void* src;
225     void* srcStart;
226     void* srcEnd;
227
228     frameHeader_t header;
229
230     cblockStats_t stats;
231     cblockStats_t oldStats; /* so they can be rolled back if uncompressible */
232 } frame_t;
233
234 typedef struct {
235     int useDict;
236     U32 dictID;
237     size_t dictContentSize;
238     BYTE* dictContent;
239 } dictInfo;
240
241 typedef enum {
242   gt_frame = 0,  /* generate frames */
243   gt_block,      /* generate compressed blocks without block/frame headers */
244 } genType_e;
245
246 /*-*******************************************************
247 *  Global variables (set from command line)
248 *********************************************************/
249 U32 g_maxDecompressedSizeLog = MAX_DECOMPRESSED_SIZE_LOG;  /* <= 20 */
250 U32 g_maxBlockSize = ZSTD_BLOCKSIZE_MAX;                       /* <= 128 KB */
251
252 /*-*******************************************************
253 *  Generator Functions
254 *********************************************************/
255
256 struct {
257     int contentSize; /* force the content size to be present */
258 } opts; /* advanced options on generation */
259
260 /* Generate and write a random frame header */
261 static void writeFrameHeader(U32* seed, frame_t* frame, dictInfo info)
262 {
263     BYTE* const op = frame->data;
264     size_t pos = 0;
265     frameHeader_t fh;
266
267     BYTE windowByte = 0;
268
269     int singleSegment = 0;
270     int contentSizeFlag = 0;
271     int fcsCode = 0;
272
273     memset(&fh, 0, sizeof(fh));
274
275     /* generate window size */
276     {
277         /* Follow window algorithm from specification */
278         int const exponent = RAND(seed) % (MAX_WINDOW_LOG - 10);
279         int const mantissa = RAND(seed) % 8;
280         windowByte = (BYTE) ((exponent << 3) | mantissa);
281         fh.windowSize = (1U << (exponent + 10));
282         fh.windowSize += fh.windowSize / 8 * mantissa;
283     }
284
285     {
286         /* Generate random content size */
287         size_t highBit;
288         if (RAND(seed) & 7 && g_maxDecompressedSizeLog > 7) {
289             /* do content of at least 128 bytes */
290             highBit = 1ULL << RAND_range(seed, 7, g_maxDecompressedSizeLog);
291         } else if (RAND(seed) & 3) {
292             /* do small content */
293             highBit = 1ULL << RAND_range(seed, 0, MIN(7, 1U << g_maxDecompressedSizeLog));
294         } else {
295             /* 0 size frame */
296             highBit = 0;
297         }
298         fh.contentSize = highBit ? highBit + (RAND(seed) % highBit) : 0;
299
300         /* provide size sometimes */
301         contentSizeFlag = opts.contentSize | (RAND(seed) & 1);
302
303         if (contentSizeFlag && (fh.contentSize == 0 || !(RAND(seed) & 7))) {
304             /* do single segment sometimes */
305             fh.windowSize = (U32) fh.contentSize;
306             singleSegment = 1;
307         }
308     }
309
310     if (contentSizeFlag) {
311         /* Determine how large fcs field has to be */
312         int minFcsCode = (fh.contentSize >= 256) +
313                                (fh.contentSize >= 65536 + 256) +
314                                (fh.contentSize > 0xFFFFFFFFU);
315         if (!singleSegment && !minFcsCode) {
316             minFcsCode = 1;
317         }
318         fcsCode = minFcsCode + (RAND(seed) % (4 - minFcsCode));
319         if (fcsCode == 1 && fh.contentSize < 256) fcsCode++;
320     }
321
322     /* write out the header */
323     MEM_writeLE32(op + pos, ZSTD_MAGICNUMBER);
324     pos += 4;
325
326     {
327         /*
328          * fcsCode: 2-bit flag specifying how many bytes used to represent Frame_Content_Size (bits 7-6)
329          * singleSegment: 1-bit flag describing if data must be regenerated within a single continuous memory segment. (bit 5)
330          * contentChecksumFlag: 1-bit flag that is set if frame includes checksum at the end -- set to 1 below (bit 2)
331          * dictBits: 2-bit flag describing how many bytes Dictionary_ID uses -- set to 3 (bits 1-0)
332          * For more information: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_header
333          */
334         int const dictBits = info.useDict ? 3 : 0;
335         BYTE const frameHeaderDescriptor =
336                 (BYTE) ((fcsCode << 6) | (singleSegment << 5) | (1 << 2) | dictBits);
337         op[pos++] = frameHeaderDescriptor;
338     }
339
340     if (!singleSegment) {
341         op[pos++] = windowByte;
342     }
343     if (info.useDict) {
344         MEM_writeLE32(op + pos, (U32) info.dictID);
345         pos += 4;
346     }
347     if (contentSizeFlag) {
348         switch (fcsCode) {
349         default: /* Impossible */
350         case 0: op[pos++] = (BYTE) fh.contentSize; break;
351         case 1: MEM_writeLE16(op + pos, (U16) (fh.contentSize - 256)); pos += 2; break;
352         case 2: MEM_writeLE32(op + pos, (U32) fh.contentSize); pos += 4; break;
353         case 3: MEM_writeLE64(op + pos, (U64) fh.contentSize); pos += 8; break;
354         }
355     }
356
357     DISPLAYLEVEL(3, " frame content size:\t%u\n", (U32)fh.contentSize);
358     DISPLAYLEVEL(3, " frame window size:\t%u\n", fh.windowSize);
359     DISPLAYLEVEL(3, " content size flag:\t%d\n", contentSizeFlag);
360     DISPLAYLEVEL(3, " single segment flag:\t%d\n", singleSegment);
361
362     frame->data = op + pos;
363     frame->header = fh;
364 }
365
366 /* Write a literal block in either raw or RLE form, return the literals size */
367 static size_t writeLiteralsBlockSimple(U32* seed, frame_t* frame, size_t contentSize)
368 {
369     BYTE* op = (BYTE*)frame->data;
370     int const type = RAND(seed) % 2;
371     int const sizeFormatDesc = RAND(seed) % 8;
372     size_t litSize;
373     size_t maxLitSize = MIN(contentSize, g_maxBlockSize);
374
375     if (sizeFormatDesc == 0) {
376         /* Size_FormatDesc = ?0 */
377         maxLitSize = MIN(maxLitSize, 31);
378     } else if (sizeFormatDesc <= 4) {
379         /* Size_FormatDesc = 01 */
380         maxLitSize = MIN(maxLitSize, 4095);
381     } else {
382         /* Size_Format = 11 */
383         maxLitSize = MIN(maxLitSize, 1048575);
384     }
385
386     litSize = RAND(seed) % (maxLitSize + 1);
387     if (frame->src == frame->srcStart && litSize == 0) {
388         litSize = 1; /* no empty literals if there's nothing preceding this block */
389     }
390     if (litSize + 3 > contentSize) {
391         litSize = contentSize; /* no matches shorter than 3 are allowed */
392     }
393     /* use smallest size format that fits */
394     if (litSize < 32) {
395         op[0] = (type | (0 << 2) | (litSize << 3)) & 0xff;
396         op += 1;
397     } else if (litSize < 4096) {
398         op[0] = (type | (1 << 2) | (litSize << 4)) & 0xff;
399         op[1] = (litSize >> 4) & 0xff;
400         op += 2;
401     } else {
402         op[0] = (type | (3 << 2) | (litSize << 4)) & 0xff;
403         op[1] = (litSize >> 4) & 0xff;
404         op[2] = (litSize >> 12) & 0xff;
405         op += 3;
406     }
407
408     if (type == 0) {
409         /* Raw literals */
410         DISPLAYLEVEL(4, "   raw literals\n");
411
412         RAND_buffer(seed, LITERAL_BUFFER, litSize);
413         memcpy(op, LITERAL_BUFFER, litSize);
414         op += litSize;
415     } else {
416         /* RLE literals */
417         BYTE const symb = (BYTE) (RAND(seed) % 256);
418
419         DISPLAYLEVEL(4, "   rle literals: 0x%02x\n", (U32)symb);
420
421         memset(LITERAL_BUFFER, symb, litSize);
422         op[0] = symb;
423         op++;
424     }
425
426     frame->data = op;
427
428     return litSize;
429 }
430
431 /* Generate a Huffman header for the given source */
432 static size_t writeHufHeader(U32* seed, HUF_CElt* hufTable, void* dst, size_t dstSize,
433                                  const void* src, size_t srcSize)
434 {
435     BYTE* const ostart = (BYTE*)dst;
436     BYTE* op = ostart;
437
438     unsigned huffLog = 11;
439     U32 maxSymbolValue = 255;
440
441     U32 count[HUF_SYMBOLVALUE_MAX+1];
442
443     /* Scan input and build symbol stats */
444     {   size_t const largest = FSE_count_wksp (count, &maxSymbolValue, (const BYTE*)src, srcSize, WKSP);
445         if (largest == srcSize) { *ostart = ((const BYTE*)src)[0]; return 0; }   /* single symbol, rle */
446         if (largest <= (srcSize >> 7)+1) return 0;   /* Fast heuristic : not compressible enough */
447     }
448
449     /* Build Huffman Tree */
450     /* Max Huffman log is 11, min is highbit(maxSymbolValue)+1 */
451     huffLog = RAND_range(seed, ZSTD_highbit32(maxSymbolValue)+1, huffLog+1);
452     DISPLAYLEVEL(6, "     huffman log: %u\n", huffLog);
453     {   size_t const maxBits = HUF_buildCTable_wksp (hufTable, count, maxSymbolValue, huffLog, WKSP, sizeof(WKSP));
454         CHECKERR(maxBits);
455         huffLog = (U32)maxBits;
456     }
457
458     /* Write table description header */
459     {   size_t const hSize = HUF_writeCTable (op, dstSize, hufTable, maxSymbolValue, huffLog);
460         if (hSize + 12 >= srcSize) return 0;   /* not useful to try compression */
461         op += hSize;
462     }
463
464     return op - ostart;
465 }
466
467 /* Write a Huffman coded literals block and return the literals size */
468 static size_t writeLiteralsBlockCompressed(U32* seed, frame_t* frame, size_t contentSize)
469 {
470     BYTE* origop = (BYTE*)frame->data;
471     BYTE* opend = (BYTE*)frame->dataEnd;
472     BYTE* op;
473     BYTE* const ostart = origop;
474     int const sizeFormat = RAND(seed) % 4;
475     size_t litSize;
476     size_t hufHeaderSize = 0;
477     size_t compressedSize = 0;
478     size_t maxLitSize = MIN(contentSize-3, g_maxBlockSize);
479
480     symbolEncodingType_e hType;
481
482     if (contentSize < 64) {
483         /* make sure we get reasonably-sized literals for compression */
484         return ERROR(GENERIC);
485     }
486
487     DISPLAYLEVEL(4, "   compressed literals\n");
488
489     switch (sizeFormat) {
490     case 0: /* fall through, size is the same as case 1 */
491     case 1:
492         maxLitSize = MIN(maxLitSize, 1023);
493         origop += 3;
494         break;
495     case 2:
496         maxLitSize = MIN(maxLitSize, 16383);
497         origop += 4;
498         break;
499     case 3:
500         maxLitSize = MIN(maxLitSize, 262143);
501         origop += 5;
502         break;
503     default:; /* impossible */
504     }
505
506     do {
507         op = origop;
508         do {
509             litSize = RAND(seed) % (maxLitSize + 1);
510         } while (litSize < 32); /* avoid small literal sizes */
511         if (litSize + 3 > contentSize) {
512             litSize = contentSize; /* no matches shorter than 3 are allowed */
513         }
514
515         /* most of the time generate a new distribution */
516         if ((RAND(seed) & 3) || !frame->stats.hufInit) {
517             do {
518                 if (RAND(seed) & 3) {
519                     /* add 10 to ensure some compressability */
520                     double const weight = ((RAND(seed) % 90) + 10) / 100.0;
521
522                     DISPLAYLEVEL(5, "    distribution weight: %d%%\n",
523                                  (int)(weight * 100));
524
525                     RAND_genDist(seed, frame->stats.hufDist, weight);
526                 } else {
527                     /* sometimes do restricted range literals to force
528                      * non-huffman headers */
529                     DISPLAYLEVEL(5, "    small range literals\n");
530                     RAND_bufferMaxSymb(seed, frame->stats.hufDist, DISTSIZE,
531                                        15);
532                 }
533                 RAND_bufferDist(seed, frame->stats.hufDist, LITERAL_BUFFER,
534                                 litSize);
535
536                 /* generate the header from the distribution instead of the
537                  * actual data to avoid bugs with symbols that were in the
538                  * distribution but never showed up in the output */
539                 hufHeaderSize = writeHufHeader(
540                         seed, (HUF_CElt*)frame->stats.hufTable, op, opend - op,
541                         frame->stats.hufDist, DISTSIZE);
542                 CHECKERR(hufHeaderSize);
543                 /* repeat until a valid header is written */
544             } while (hufHeaderSize == 0);
545             op += hufHeaderSize;
546             hType = set_compressed;
547
548             frame->stats.hufInit = 1;
549         } else {
550             /* repeat the distribution/table from last time */
551             DISPLAYLEVEL(5, "    huffman repeat stats\n");
552             RAND_bufferDist(seed, frame->stats.hufDist, LITERAL_BUFFER,
553                             litSize);
554             hufHeaderSize = 0;
555             hType = set_repeat;
556         }
557
558         do {
559             compressedSize =
560                     sizeFormat == 0
561                             ? HUF_compress1X_usingCTable(
562                                       op, opend - op, LITERAL_BUFFER, litSize,
563                                       (HUF_CElt*)frame->stats.hufTable)
564                             : HUF_compress4X_usingCTable(
565                                       op, opend - op, LITERAL_BUFFER, litSize,
566                                       (HUF_CElt*)frame->stats.hufTable);
567             CHECKERR(compressedSize);
568             /* this only occurs when it could not compress or similar */
569         } while (compressedSize <= 0);
570
571         op += compressedSize;
572
573         compressedSize += hufHeaderSize;
574         DISPLAYLEVEL(5, "    regenerated size: %u\n", (U32)litSize);
575         DISPLAYLEVEL(5, "    compressed size: %u\n", (U32)compressedSize);
576         if (compressedSize >= litSize) {
577             DISPLAYLEVEL(5, "     trying again\n");
578             /* if we have to try again, reset the stats so we don't accidentally
579              * try to repeat a distribution we just made */
580             frame->stats = frame->oldStats;
581         } else {
582             break;
583         }
584     } while (1);
585
586     /* write header */
587     switch (sizeFormat) {
588     case 0: /* fall through, size is the same as case 1 */
589     case 1: {
590         U32 const header = hType | (sizeFormat << 2) | ((U32)litSize << 4) |
591                            ((U32)compressedSize << 14);
592         MEM_writeLE24(ostart, header);
593         break;
594     }
595     case 2: {
596         U32 const header = hType | (sizeFormat << 2) | ((U32)litSize << 4) |
597                            ((U32)compressedSize << 18);
598         MEM_writeLE32(ostart, header);
599         break;
600     }
601     case 3: {
602         U32 const header = hType | (sizeFormat << 2) | ((U32)litSize << 4) |
603                            ((U32)compressedSize << 22);
604         MEM_writeLE32(ostart, header);
605         ostart[4] = (BYTE)(compressedSize >> 10);
606         break;
607     }
608     default:; /* impossible */
609     }
610
611     frame->data = op;
612     return litSize;
613 }
614
615 static size_t writeLiteralsBlock(U32* seed, frame_t* frame, size_t contentSize)
616 {
617     /* only do compressed for larger segments to avoid compressibility issues */
618     if (RAND(seed) & 7 && contentSize >= 64) {
619         return writeLiteralsBlockCompressed(seed, frame, contentSize);
620     } else {
621         return writeLiteralsBlockSimple(seed, frame, contentSize);
622     }
623 }
624
625 static inline void initSeqStore(seqStore_t *seqStore) {
626     seqStore->sequencesStart = SEQUENCE_BUFFER;
627     seqStore->litStart = SEQUENCE_LITERAL_BUFFER;
628     seqStore->llCode = SEQUENCE_LLCODE;
629     seqStore->mlCode = SEQUENCE_MLCODE;
630     seqStore->ofCode = SEQUENCE_OFCODE;
631
632     ZSTD_resetSeqStore(seqStore);
633 }
634
635 /* Randomly generate sequence commands */
636 static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore,
637                                 size_t contentSize, size_t literalsSize, dictInfo info)
638 {
639     /* The total length of all the matches */
640     size_t const remainingMatch = contentSize - literalsSize;
641     size_t excessMatch = 0;
642     U32 numSequences = 0;
643
644     U32 i;
645
646
647     const BYTE* literals = LITERAL_BUFFER;
648     BYTE* srcPtr = frame->src;
649
650     if (literalsSize != contentSize) {
651         /* each match must be at least MIN_SEQ_LEN, so this is the maximum
652          * number of sequences we can have */
653         U32 const maxSequences = (U32)remainingMatch / MIN_SEQ_LEN;
654         numSequences = (RAND(seed) % maxSequences) + 1;
655
656         /* the extra match lengths we have to allocate to each sequence */
657         excessMatch = remainingMatch - numSequences * MIN_SEQ_LEN;
658     }
659
660     DISPLAYLEVEL(5, "    total match lengths: %u\n", (U32)remainingMatch);
661     for (i = 0; i < numSequences; i++) {
662         /* Generate match and literal lengths by exponential distribution to
663          * ensure nice numbers */
664         U32 matchLen =
665                 MIN_SEQ_LEN +
666                 ROUND(RAND_exp(seed, excessMatch / (double)(numSequences - i)));
667         U32 literalLen =
668                 (RAND(seed) & 7)
669                         ? ROUND(RAND_exp(seed,
670                                          literalsSize /
671                                                  (double)(numSequences - i)))
672                         : 0;
673         /* actual offset, code to send, and point to copy up to when shifting
674          * codes in the repeat offsets history */
675         U32 offset, offsetCode, repIndex;
676
677         /* bounds checks */
678         matchLen = (U32) MIN(matchLen, excessMatch + MIN_SEQ_LEN);
679         literalLen = MIN(literalLen, (U32) literalsSize);
680         if (i == 0 && srcPtr == frame->srcStart && literalLen == 0) literalLen = 1;
681         if (i + 1 == numSequences) matchLen = MIN_SEQ_LEN + (U32) excessMatch;
682
683         memcpy(srcPtr, literals, literalLen);
684         srcPtr += literalLen;
685         do {
686             if (RAND(seed) & 7) {
687                 /* do a normal offset */
688                 U32 const dataDecompressed = (U32)((BYTE*)srcPtr-(BYTE*)frame->srcStart);
689                 offset = (RAND(seed) %
690                           MIN(frame->header.windowSize,
691                               (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart))) +
692                          1;
693                 if (info.useDict && (RAND(seed) & 1) && i + 1 != numSequences && dataDecompressed < frame->header.windowSize) {
694                     /* need to occasionally generate offsets that go past the start */
695                     /* including i+1 != numSequences because the last sequences has to adhere to predetermined contentSize */
696                     U32 lenPastStart = (RAND(seed) % info.dictContentSize) + 1;
697                     offset = (U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart)+lenPastStart;
698                     if (offset > frame->header.windowSize) {
699                         if (lenPastStart < MIN_SEQ_LEN) {
700                             /* when offset > windowSize, matchLen bound by end of dictionary (lenPastStart) */
701                             /* this also means that lenPastStart must be greater than MIN_SEQ_LEN */
702                             /* make sure lenPastStart does not go past dictionary start though */
703                             lenPastStart = MIN(lenPastStart+MIN_SEQ_LEN, (U32)info.dictContentSize);
704                             offset = (U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart) + lenPastStart;
705                         }
706                         {
707                             U32 const matchLenBound = MIN(frame->header.windowSize, lenPastStart);
708                             matchLen = MIN(matchLen, matchLenBound);
709                         }
710                     }
711                 }
712                 offsetCode = offset + ZSTD_REP_MOVE;
713                 repIndex = 2;
714             } else {
715                 /* do a repeat offset */
716                 offsetCode = RAND(seed) % 3;
717                 if (literalLen > 0) {
718                     offset = frame->stats.rep[offsetCode];
719                     repIndex = offsetCode;
720                 } else {
721                     /* special case */
722                     offset = offsetCode == 2 ? frame->stats.rep[0] - 1
723                                            : frame->stats.rep[offsetCode + 1];
724                     repIndex = MIN(2, offsetCode + 1);
725                 }
726             }
727         } while (((!info.useDict) && (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart))) || offset == 0);
728
729         {
730             size_t j;
731             BYTE* const dictEnd = info.dictContent + info.dictContentSize;
732             for (j = 0; j < matchLen; j++) {
733                 if ((U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart) < offset) {
734                     /* copy from dictionary instead of literals */
735                     size_t const dictOffset = offset - (srcPtr - (BYTE*)frame->srcStart);
736                     *srcPtr = *(dictEnd - dictOffset);
737                 }
738                 else {
739                     *srcPtr = *(srcPtr-offset);
740                 }
741                 srcPtr++;
742             }
743         }
744
745         {   int r;
746             for (r = repIndex; r > 0; r--) {
747                 frame->stats.rep[r] = frame->stats.rep[r - 1];
748             }
749             frame->stats.rep[0] = offset;
750         }
751
752         DISPLAYLEVEL(6, "      LL: %5u OF: %5u ML: %5u", literalLen, offset, matchLen);
753         DISPLAYLEVEL(7, " srcPos: %8u seqNb: %3u",
754                      (U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart), i);
755         DISPLAYLEVEL(6, "\n");
756         if (offsetCode < 3) {
757             DISPLAYLEVEL(7, "        repeat offset: %d\n", repIndex);
758         }
759         /* use libzstd sequence handling */
760         ZSTD_storeSeq(seqStore, literalLen, literals, offsetCode,
761                       matchLen - MINMATCH);
762
763         literalsSize -= literalLen;
764         excessMatch -= (matchLen - MIN_SEQ_LEN);
765         literals += literalLen;
766     }
767
768     memcpy(srcPtr, literals, literalsSize);
769     srcPtr += literalsSize;
770     DISPLAYLEVEL(6, "      excess literals: %5u", (U32)literalsSize);
771     DISPLAYLEVEL(7, " srcPos: %8u", (U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart));
772     DISPLAYLEVEL(6, "\n");
773
774     return numSequences;
775 }
776
777 static void initSymbolSet(const BYTE* symbols, size_t len, BYTE* set, BYTE maxSymbolValue)
778 {
779     size_t i;
780
781     memset(set, 0, (size_t)maxSymbolValue+1);
782
783     for (i = 0; i < len; i++) {
784         set[symbols[i]] = 1;
785     }
786 }
787
788 static int isSymbolSubset(const BYTE* symbols, size_t len, const BYTE* set, BYTE maxSymbolValue)
789 {
790     size_t i;
791
792     for (i = 0; i < len; i++) {
793         if (symbols[i] > maxSymbolValue || !set[symbols[i]]) {
794             return 0;
795         }
796     }
797     return 1;
798 }
799
800 static size_t writeSequences(U32* seed, frame_t* frame, seqStore_t* seqStorePtr,
801                              size_t nbSeq)
802 {
803     /* This code is mostly copied from ZSTD_compressSequences in zstd_compress.c */
804     U32 count[MaxSeq+1];
805     S16 norm[MaxSeq+1];
806     FSE_CTable* CTable_LitLength = frame->stats.litlengthCTable;
807     FSE_CTable* CTable_OffsetBits = frame->stats.offcodeCTable;
808     FSE_CTable* CTable_MatchLength = frame->stats.matchlengthCTable;
809     U32 LLtype, Offtype, MLtype;   /* compressed, raw or rle */
810     const seqDef* const sequences = seqStorePtr->sequencesStart;
811     const BYTE* const ofCodeTable = seqStorePtr->ofCode;
812     const BYTE* const llCodeTable = seqStorePtr->llCode;
813     const BYTE* const mlCodeTable = seqStorePtr->mlCode;
814     BYTE* const oend = (BYTE*)frame->dataEnd;
815     BYTE* op = (BYTE*)frame->data;
816     BYTE* seqHead;
817     BYTE scratchBuffer[1<<MAX(MLFSELog,LLFSELog)];
818
819     /* literals compressing block removed so that can be done separately */
820
821     /* Sequences Header */
822     if ((oend-op) < 3 /*max nbSeq Size*/ + 1 /*seqHead */) return ERROR(dstSize_tooSmall);
823     if (nbSeq < 0x7F) *op++ = (BYTE)nbSeq;
824     else if (nbSeq < LONGNBSEQ) op[0] = (BYTE)((nbSeq>>8) + 0x80), op[1] = (BYTE)nbSeq, op+=2;
825     else op[0]=0xFF, MEM_writeLE16(op+1, (U16)(nbSeq - LONGNBSEQ)), op+=3;
826
827     /* seqHead : flags for FSE encoding type */
828     seqHead = op++;
829
830     if (nbSeq==0) {
831         frame->data = op;
832
833         return 0;
834     }
835
836     /* convert length/distances into codes */
837     ZSTD_seqToCodes(seqStorePtr);
838
839     /* CTable for Literal Lengths */
840     {   U32 max = MaxLL;
841         size_t const mostFrequent = FSE_countFast_wksp(count, &max, llCodeTable, nbSeq, WKSP);
842         if (mostFrequent == nbSeq) {
843             /* do RLE if we have the chance */
844             *op++ = llCodeTable[0];
845             FSE_buildCTable_rle(CTable_LitLength, (BYTE)max);
846             LLtype = set_rle;
847         } else if (frame->stats.fseInit && !(RAND(seed) & 3) &&
848                    isSymbolSubset(llCodeTable, nbSeq,
849                                   frame->stats.litlengthSymbolSet, 35)) {
850             /* maybe do repeat mode if we're allowed to */
851             LLtype = set_repeat;
852         } else if (!(RAND(seed) & 3)) {
853             /* maybe use the default distribution */
854             FSE_buildCTable_wksp(CTable_LitLength, LL_defaultNorm, MaxLL, LL_defaultNormLog, scratchBuffer, sizeof(scratchBuffer));
855             LLtype = set_basic;
856         } else {
857             /* fall back on a full table */
858             size_t nbSeq_1 = nbSeq;
859             const U32 tableLog = FSE_optimalTableLog(LLFSELog, nbSeq, max);
860             if (count[llCodeTable[nbSeq-1]]>1) { count[llCodeTable[nbSeq-1]]--; nbSeq_1--; }
861             FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max);
862             { size_t const NCountSize = FSE_writeNCount(op, oend-op, norm, max, tableLog);   /* overflow protected */
863               if (FSE_isError(NCountSize)) return ERROR(GENERIC);
864               op += NCountSize; }
865             FSE_buildCTable_wksp(CTable_LitLength, norm, max, tableLog, scratchBuffer, sizeof(scratchBuffer));
866             LLtype = set_compressed;
867     }   }
868
869     /* CTable for Offsets */
870     /* see Literal Lengths for descriptions of mode choices */
871     {   U32 max = MaxOff;
872         size_t const mostFrequent = FSE_countFast_wksp(count, &max, ofCodeTable, nbSeq, WKSP);
873         if (mostFrequent == nbSeq) {
874             *op++ = ofCodeTable[0];
875             FSE_buildCTable_rle(CTable_OffsetBits, (BYTE)max);
876             Offtype = set_rle;
877         } else if (frame->stats.fseInit && !(RAND(seed) & 3) &&
878                    isSymbolSubset(ofCodeTable, nbSeq,
879                                   frame->stats.offsetSymbolSet, 28)) {
880             Offtype = set_repeat;
881         } else if (!(RAND(seed) & 3)) {
882             FSE_buildCTable_wksp(CTable_OffsetBits, OF_defaultNorm, DefaultMaxOff, OF_defaultNormLog, scratchBuffer, sizeof(scratchBuffer));
883             Offtype = set_basic;
884         } else {
885             size_t nbSeq_1 = nbSeq;
886             const U32 tableLog = FSE_optimalTableLog(OffFSELog, nbSeq, max);
887             if (count[ofCodeTable[nbSeq-1]]>1) { count[ofCodeTable[nbSeq-1]]--; nbSeq_1--; }
888             FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max);
889             { size_t const NCountSize = FSE_writeNCount(op, oend-op, norm, max, tableLog);   /* overflow protected */
890               if (FSE_isError(NCountSize)) return ERROR(GENERIC);
891               op += NCountSize; }
892             FSE_buildCTable_wksp(CTable_OffsetBits, norm, max, tableLog, scratchBuffer, sizeof(scratchBuffer));
893             Offtype = set_compressed;
894     }   }
895
896     /* CTable for MatchLengths */
897     /* see Literal Lengths for descriptions of mode choices */
898     {   U32 max = MaxML;
899         size_t const mostFrequent = FSE_countFast_wksp(count, &max, mlCodeTable, nbSeq, WKSP);
900         if (mostFrequent == nbSeq) {
901             *op++ = *mlCodeTable;
902             FSE_buildCTable_rle(CTable_MatchLength, (BYTE)max);
903             MLtype = set_rle;
904         } else if (frame->stats.fseInit && !(RAND(seed) & 3) &&
905                    isSymbolSubset(mlCodeTable, nbSeq,
906                                   frame->stats.matchlengthSymbolSet, 52)) {
907             MLtype = set_repeat;
908         } else if (!(RAND(seed) & 3)) {
909             /* sometimes do default distribution */
910             FSE_buildCTable_wksp(CTable_MatchLength, ML_defaultNorm, MaxML, ML_defaultNormLog, scratchBuffer, sizeof(scratchBuffer));
911             MLtype = set_basic;
912         } else {
913             /* fall back on table */
914             size_t nbSeq_1 = nbSeq;
915             const U32 tableLog = FSE_optimalTableLog(MLFSELog, nbSeq, max);
916             if (count[mlCodeTable[nbSeq-1]]>1) { count[mlCodeTable[nbSeq-1]]--; nbSeq_1--; }
917             FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max);
918             { size_t const NCountSize = FSE_writeNCount(op, oend-op, norm, max, tableLog);   /* overflow protected */
919               if (FSE_isError(NCountSize)) return ERROR(GENERIC);
920               op += NCountSize; }
921             FSE_buildCTable_wksp(CTable_MatchLength, norm, max, tableLog, scratchBuffer, sizeof(scratchBuffer));
922             MLtype = set_compressed;
923     }   }
924     frame->stats.fseInit = 1;
925     initSymbolSet(llCodeTable, nbSeq, frame->stats.litlengthSymbolSet, 35);
926     initSymbolSet(ofCodeTable, nbSeq, frame->stats.offsetSymbolSet, 28);
927     initSymbolSet(mlCodeTable, nbSeq, frame->stats.matchlengthSymbolSet, 52);
928
929     DISPLAYLEVEL(5, "    LL type: %d OF type: %d ML type: %d\n", LLtype, Offtype, MLtype);
930
931     *seqHead = (BYTE)((LLtype<<6) + (Offtype<<4) + (MLtype<<2));
932
933     /* Encoding Sequences */
934     {   BIT_CStream_t blockStream;
935         FSE_CState_t  stateMatchLength;
936         FSE_CState_t  stateOffsetBits;
937         FSE_CState_t  stateLitLength;
938
939         CHECK_E(BIT_initCStream(&blockStream, op, oend-op), dstSize_tooSmall); /* not enough space remaining */
940
941         /* first symbols */
942         FSE_initCState2(&stateMatchLength, CTable_MatchLength, mlCodeTable[nbSeq-1]);
943         FSE_initCState2(&stateOffsetBits,  CTable_OffsetBits,  ofCodeTable[nbSeq-1]);
944         FSE_initCState2(&stateLitLength,   CTable_LitLength,   llCodeTable[nbSeq-1]);
945         BIT_addBits(&blockStream, sequences[nbSeq-1].litLength, LL_bits[llCodeTable[nbSeq-1]]);
946         if (MEM_32bits()) BIT_flushBits(&blockStream);
947         BIT_addBits(&blockStream, sequences[nbSeq-1].matchLength, ML_bits[mlCodeTable[nbSeq-1]]);
948         if (MEM_32bits()) BIT_flushBits(&blockStream);
949         BIT_addBits(&blockStream, sequences[nbSeq-1].offset, ofCodeTable[nbSeq-1]);
950         BIT_flushBits(&blockStream);
951
952         {   size_t n;
953             for (n=nbSeq-2 ; n<nbSeq ; n--) {      /* intentional underflow */
954                 BYTE const llCode = llCodeTable[n];
955                 BYTE const ofCode = ofCodeTable[n];
956                 BYTE const mlCode = mlCodeTable[n];
957                 U32  const llBits = LL_bits[llCode];
958                 U32  const ofBits = ofCode;                                     /* 32b*/  /* 64b*/
959                 U32  const mlBits = ML_bits[mlCode];
960                                                                                 /* (7)*/  /* (7)*/
961                 FSE_encodeSymbol(&blockStream, &stateOffsetBits, ofCode);       /* 15 */  /* 15 */
962                 FSE_encodeSymbol(&blockStream, &stateMatchLength, mlCode);      /* 24 */  /* 24 */
963                 if (MEM_32bits()) BIT_flushBits(&blockStream);                  /* (7)*/
964                 FSE_encodeSymbol(&blockStream, &stateLitLength, llCode);        /* 16 */  /* 33 */
965                 if (MEM_32bits() || (ofBits+mlBits+llBits >= 64-7-(LLFSELog+MLFSELog+OffFSELog)))
966                     BIT_flushBits(&blockStream);                                /* (7)*/
967                 BIT_addBits(&blockStream, sequences[n].litLength, llBits);
968                 if (MEM_32bits() && ((llBits+mlBits)>24)) BIT_flushBits(&blockStream);
969                 BIT_addBits(&blockStream, sequences[n].matchLength, mlBits);
970                 if (MEM_32bits()) BIT_flushBits(&blockStream);                  /* (7)*/
971                 BIT_addBits(&blockStream, sequences[n].offset, ofBits);         /* 31 */
972                 BIT_flushBits(&blockStream);                                    /* (7)*/
973         }   }
974
975         FSE_flushCState(&blockStream, &stateMatchLength);
976         FSE_flushCState(&blockStream, &stateOffsetBits);
977         FSE_flushCState(&blockStream, &stateLitLength);
978
979         {   size_t const streamSize = BIT_closeCStream(&blockStream);
980             if (streamSize==0) return ERROR(dstSize_tooSmall);   /* not enough space */
981             op += streamSize;
982     }   }
983
984     frame->data = op;
985
986     return 0;
987 }
988
989 static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize,
990                                   size_t literalsSize, dictInfo info)
991 {
992     seqStore_t seqStore;
993     size_t numSequences;
994
995
996     initSeqStore(&seqStore);
997
998     /* randomly generate sequences */
999     numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, info);
1000     /* write them out to the frame data */
1001     CHECKERR(writeSequences(seed, frame, &seqStore, numSequences));
1002
1003     return numSequences;
1004 }
1005
1006 static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, dictInfo info)
1007 {
1008     BYTE* const blockStart = (BYTE*)frame->data;
1009     size_t literalsSize;
1010     size_t nbSeq;
1011
1012     DISPLAYLEVEL(4, "  compressed block:\n");
1013
1014     literalsSize = writeLiteralsBlock(seed, frame, contentSize);
1015
1016     DISPLAYLEVEL(4, "   literals size: %u\n", (U32)literalsSize);
1017
1018     nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, info);
1019
1020     DISPLAYLEVEL(4, "   number of sequences: %u\n", (U32)nbSeq);
1021
1022     return (BYTE*)frame->data - blockStart;
1023 }
1024
1025 static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
1026                        int lastBlock, dictInfo info)
1027 {
1028     int const blockTypeDesc = RAND(seed) % 8;
1029     size_t blockSize;
1030     int blockType;
1031
1032     BYTE *const header = (BYTE*)frame->data;
1033     BYTE *op = header + 3;
1034
1035     DISPLAYLEVEL(4, " block:\n");
1036     DISPLAYLEVEL(4, "  block content size: %u\n", (U32)contentSize);
1037     DISPLAYLEVEL(4, "  last block: %s\n", lastBlock ? "yes" : "no");
1038
1039     if (blockTypeDesc == 0) {
1040         /* Raw data frame */
1041
1042         RAND_buffer(seed, frame->src, contentSize);
1043         memcpy(op, frame->src, contentSize);
1044
1045         op += contentSize;
1046         blockType = 0;
1047         blockSize = contentSize;
1048     } else if (blockTypeDesc == 1) {
1049         /* RLE */
1050         BYTE const symbol = RAND(seed) & 0xff;
1051
1052         op[0] = symbol;
1053         memset(frame->src, symbol, contentSize);
1054
1055         op++;
1056         blockType = 1;
1057         blockSize = contentSize;
1058     } else {
1059         /* compressed, most common */
1060         size_t compressedSize;
1061         blockType = 2;
1062
1063         frame->oldStats = frame->stats;
1064
1065         frame->data = op;
1066         compressedSize = writeCompressedBlock(seed, frame, contentSize, info);
1067         if (compressedSize >= contentSize) {   /* compressed block must be strictly smaller than uncompressed one */
1068             blockType = 0;
1069             memcpy(op, frame->src, contentSize);
1070
1071             op += contentSize;
1072             blockSize = contentSize; /* fall back on raw block if data doesn't
1073                                         compress */
1074
1075             frame->stats = frame->oldStats; /* don't update the stats */
1076         } else {
1077             op += compressedSize;
1078             blockSize = compressedSize;
1079         }
1080     }
1081     frame->src = (BYTE*)frame->src + contentSize;
1082
1083     DISPLAYLEVEL(4, "  block type: %s\n", BLOCK_TYPES[blockType]);
1084     DISPLAYLEVEL(4, "  block size field: %u\n", (U32)blockSize);
1085
1086     header[0] = (BYTE) ((lastBlock | (blockType << 1) | (blockSize << 3)) & 0xff);
1087     MEM_writeLE16(header + 1, (U16) (blockSize >> 5));
1088
1089     frame->data = op;
1090 }
1091
1092 static void writeBlocks(U32* seed, frame_t* frame, dictInfo info)
1093 {
1094     size_t contentLeft = frame->header.contentSize;
1095     size_t const maxBlockSize = MIN(g_maxBlockSize, frame->header.windowSize);
1096     while (1) {
1097         /* 1 in 4 chance of ending frame */
1098         int const lastBlock = contentLeft > maxBlockSize ? 0 : !(RAND(seed) & 3);
1099         size_t blockContentSize;
1100         if (lastBlock) {
1101             blockContentSize = contentLeft;
1102         } else {
1103             if (contentLeft > 0 && (RAND(seed) & 7)) {
1104                 /* some variable size block */
1105                 blockContentSize = RAND(seed) % (MIN(maxBlockSize, contentLeft)+1);
1106             } else if (contentLeft > maxBlockSize && (RAND(seed) & 1)) {
1107                 /* some full size block */
1108                 blockContentSize = maxBlockSize;
1109             } else {
1110                 /* some empty block */
1111                 blockContentSize = 0;
1112             }
1113         }
1114
1115         writeBlock(seed, frame, blockContentSize, lastBlock, info);
1116
1117         contentLeft -= blockContentSize;
1118         if (lastBlock) break;
1119     }
1120 }
1121
1122 static void writeChecksum(frame_t* frame)
1123 {
1124     /* write checksum so implementations can verify their output */
1125     U64 digest = XXH64(frame->srcStart, (BYTE*)frame->src-(BYTE*)frame->srcStart, 0);
1126     DISPLAYLEVEL(3, "  checksum: %08x\n", (U32)digest);
1127     MEM_writeLE32(frame->data, (U32)digest);
1128     frame->data = (BYTE*)frame->data + 4;
1129 }
1130
1131 static void outputBuffer(const void* buf, size_t size, const char* const path)
1132 {
1133     /* write data out to file */
1134     const BYTE* ip = (const BYTE*)buf;
1135     FILE* out;
1136     if (path) {
1137         out = fopen(path, "wb");
1138     } else {
1139         out = stdout;
1140     }
1141     if (!out) {
1142         fprintf(stderr, "Failed to open file at %s: ", path);
1143         perror(NULL);
1144         exit(1);
1145     }
1146
1147     {   size_t fsize = size;
1148         size_t written = 0;
1149         while (written < fsize) {
1150             written += fwrite(ip + written, 1, fsize - written, out);
1151             if (ferror(out)) {
1152                 fprintf(stderr, "Failed to write to file at %s: ", path);
1153                 perror(NULL);
1154                 exit(1);
1155             }
1156         }
1157     }
1158
1159     if (path) {
1160         fclose(out);
1161     }
1162 }
1163
1164 static void initFrame(frame_t* fr)
1165 {
1166     memset(fr, 0, sizeof(*fr));
1167     fr->data = fr->dataStart = FRAME_BUFFER;
1168     fr->dataEnd = FRAME_BUFFER + sizeof(FRAME_BUFFER);
1169     fr->src = fr->srcStart = CONTENT_BUFFER;
1170     fr->srcEnd = CONTENT_BUFFER + sizeof(CONTENT_BUFFER);
1171
1172     /* init repeat codes */
1173     fr->stats.rep[0] = 1;
1174     fr->stats.rep[1] = 4;
1175     fr->stats.rep[2] = 8;
1176 }
1177
1178 /**
1179  * Generated a single zstd compressed block with no block/frame header.
1180  * Returns the final seed.
1181  */
1182 static U32 generateCompressedBlock(U32 seed, frame_t* frame, dictInfo info)
1183 {
1184     size_t blockContentSize;
1185     int blockWritten = 0;
1186     BYTE* op;
1187     DISPLAYLEVEL(4, "block seed: %u\n", seed);
1188     initFrame(frame);
1189     op = (BYTE*)frame->data;
1190
1191     while (!blockWritten) {
1192         size_t cSize;
1193         /* generate window size */
1194         {   int const exponent = RAND(&seed) % (MAX_WINDOW_LOG - 10);
1195             int const mantissa = RAND(&seed) % 8;
1196             frame->header.windowSize = (1U << (exponent + 10));
1197             frame->header.windowSize += (frame->header.windowSize / 8) * mantissa;
1198         }
1199
1200         /* generate content size */
1201         {   size_t const maxBlockSize = MIN(g_maxBlockSize, frame->header.windowSize);
1202             if (RAND(&seed) & 15) {
1203                 /* some full size blocks */
1204                 blockContentSize = maxBlockSize;
1205             } else if (RAND(&seed) & 7 && g_maxBlockSize >= (1U << 7)) {
1206                 /* some small blocks <= 128 bytes*/
1207                 blockContentSize = RAND(&seed) % (1U << 7);
1208             } else {
1209                 /* some variable size blocks */
1210                 blockContentSize = RAND(&seed) % maxBlockSize;
1211             }
1212         }
1213
1214         /* try generating a compressed block */
1215         frame->oldStats = frame->stats;
1216         frame->data = op;
1217         cSize = writeCompressedBlock(&seed, frame, blockContentSize, info);
1218         if (cSize >= blockContentSize) {  /* compressed size must be strictly smaller than decompressed size : https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#blocks */
1219             /* data doesn't compress -- try again */
1220             frame->stats = frame->oldStats; /* don't update the stats */
1221             DISPLAYLEVEL(5, "   can't compress block : try again \n");
1222         } else {
1223             blockWritten = 1;
1224             DISPLAYLEVEL(4, "   block size: %u \n", (U32)cSize);
1225             frame->src = (BYTE*)frame->src + blockContentSize;
1226         }
1227     }
1228     return seed;
1229 }
1230
1231 /* Return the final seed */
1232 static U32 generateFrame(U32 seed, frame_t* fr, dictInfo info)
1233 {
1234     /* generate a complete frame */
1235     DISPLAYLEVEL(3, "frame seed: %u\n", seed);
1236     initFrame(fr);
1237
1238     writeFrameHeader(&seed, fr, info);
1239     writeBlocks(&seed, fr, info);
1240     writeChecksum(fr);
1241
1242     return seed;
1243 }
1244
1245 /*_*******************************************************
1246 *  Dictionary Helper Functions
1247 *********************************************************/
1248 /* returns 0 if successful, otherwise returns 1 upon error */
1249 static int genRandomDict(U32 dictID, U32 seed, size_t dictSize, BYTE* fullDict)
1250 {
1251     /* allocate space for samples */
1252     int ret = 0;
1253     unsigned const numSamples = 4;
1254     size_t sampleSizes[4];
1255     BYTE* const samples = malloc(5000*sizeof(BYTE));
1256     if (samples == NULL) {
1257         DISPLAY("Error: could not allocate space for samples\n");
1258         return 1;
1259     }
1260
1261     /* generate samples */
1262     {   unsigned literalValue = 1;
1263         unsigned samplesPos = 0;
1264         size_t currSize = 1;
1265         while (literalValue <= 4) {
1266             sampleSizes[literalValue - 1] = currSize;
1267             {   size_t k;
1268                 for (k = 0; k < currSize; k++) {
1269                     *(samples + (samplesPos++)) = (BYTE)literalValue;
1270             }   }
1271             literalValue++;
1272             currSize *= 16;
1273     }   }
1274
1275     {   size_t dictWriteSize = 0;
1276         ZDICT_params_t zdictParams;
1277         size_t const headerSize = MAX(dictSize/4, 256);
1278         size_t const dictContentSize = dictSize - headerSize;
1279         BYTE* const dictContent = fullDict + headerSize;
1280         if (dictContentSize < ZDICT_CONTENTSIZE_MIN || dictSize < ZDICT_DICTSIZE_MIN) {
1281             DISPLAY("Error: dictionary size is too small\n");
1282             ret = 1;
1283             goto exitGenRandomDict;
1284         }
1285
1286         /* init dictionary params */
1287         memset(&zdictParams, 0, sizeof(zdictParams));
1288         zdictParams.dictID = dictID;
1289         zdictParams.notificationLevel = 1;
1290
1291         /* fill in dictionary content */
1292         RAND_buffer(&seed, (void*)dictContent, dictContentSize);
1293
1294         /* finalize dictionary with random samples */
1295         dictWriteSize = ZDICT_finalizeDictionary(fullDict, dictSize,
1296                                     dictContent, dictContentSize,
1297                                     samples, sampleSizes, numSamples,
1298                                     zdictParams);
1299
1300         if (ZDICT_isError(dictWriteSize)) {
1301             DISPLAY("Could not finalize dictionary: %s\n", ZDICT_getErrorName(dictWriteSize));
1302             ret = 1;
1303         }
1304     }
1305
1306 exitGenRandomDict:
1307     free(samples);
1308     return ret;
1309 }
1310
1311 static dictInfo initDictInfo(int useDict, size_t dictContentSize, BYTE* dictContent, U32 dictID){
1312     /* allocate space statically */
1313     dictInfo dictOp;
1314     memset(&dictOp, 0, sizeof(dictOp));
1315     dictOp.useDict = useDict;
1316     dictOp.dictContentSize = dictContentSize;
1317     dictOp.dictContent = dictContent;
1318     dictOp.dictID = dictID;
1319     return dictOp;
1320 }
1321
1322 /*-*******************************************************
1323 *  Test Mode
1324 *********************************************************/
1325
1326 BYTE DECOMPRESSED_BUFFER[MAX_DECOMPRESSED_SIZE];
1327
1328 static size_t testDecodeSimple(frame_t* fr)
1329 {
1330     /* test decoding the generated data with the simple API */
1331     size_t const ret = ZSTD_decompress(DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1332                            fr->dataStart, (BYTE*)fr->data - (BYTE*)fr->dataStart);
1333
1334     if (ZSTD_isError(ret)) return ret;
1335
1336     if (memcmp(DECOMPRESSED_BUFFER, fr->srcStart,
1337                (BYTE*)fr->src - (BYTE*)fr->srcStart) != 0) {
1338         return ERROR(corruption_detected);
1339     }
1340
1341     return ret;
1342 }
1343
1344 static size_t testDecodeStreaming(frame_t* fr)
1345 {
1346     /* test decoding the generated data with the streaming API */
1347     ZSTD_DStream* zd = ZSTD_createDStream();
1348     ZSTD_inBuffer in;
1349     ZSTD_outBuffer out;
1350     size_t ret;
1351
1352     if (!zd) return ERROR(memory_allocation);
1353
1354     in.src = fr->dataStart;
1355     in.pos = 0;
1356     in.size = (BYTE*)fr->data - (BYTE*)fr->dataStart;
1357
1358     out.dst = DECOMPRESSED_BUFFER;
1359     out.pos = 0;
1360     out.size = ZSTD_DStreamOutSize();
1361
1362     ZSTD_initDStream(zd);
1363     while (1) {
1364         ret = ZSTD_decompressStream(zd, &out, &in);
1365         if (ZSTD_isError(ret)) goto cleanup; /* error */
1366         if (ret == 0) break; /* frame is done */
1367
1368         /* force decoding to be done in chunks */
1369         out.size += MIN(ZSTD_DStreamOutSize(), MAX_DECOMPRESSED_SIZE - out.size);
1370     }
1371
1372     ret = out.pos;
1373
1374     if (memcmp(out.dst, fr->srcStart, out.pos) != 0) {
1375         return ERROR(corruption_detected);
1376     }
1377
1378 cleanup:
1379     ZSTD_freeDStream(zd);
1380     return ret;
1381 }
1382
1383 static size_t testDecodeWithDict(U32 seed, genType_e genType)
1384 {
1385     /* create variables */
1386     size_t const dictSize = RAND(&seed) % (10 << 20) + ZDICT_DICTSIZE_MIN + ZDICT_CONTENTSIZE_MIN;
1387     U32 const dictID = RAND(&seed);
1388     size_t errorDetected = 0;
1389     BYTE* const fullDict = malloc(dictSize);
1390     if (fullDict == NULL) {
1391         return ERROR(GENERIC);
1392     }
1393
1394     /* generate random dictionary */
1395     if (genRandomDict(dictID, seed, dictSize, fullDict)) {  /* return 0 on success */
1396         errorDetected = ERROR(GENERIC);
1397         goto dictTestCleanup;
1398     }
1399
1400
1401     {   frame_t fr;
1402         dictInfo info;
1403         ZSTD_DCtx* const dctx = ZSTD_createDCtx();
1404         size_t ret;
1405
1406         /* get dict info */
1407         {   size_t const headerSize = MAX(dictSize/4, 256);
1408             size_t const dictContentSize = dictSize-headerSize;
1409             BYTE* const dictContent = fullDict+headerSize;
1410             info = initDictInfo(1, dictContentSize, dictContent, dictID);
1411         }
1412
1413         /* manually decompress and check difference */
1414         if (genType == gt_frame) {
1415             /* Test frame */
1416             generateFrame(seed, &fr, info);
1417             ret = ZSTD_decompress_usingDict(dctx, DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1418                                             fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart,
1419                                             fullDict, dictSize);
1420         } else {
1421             /* Test block */
1422             generateCompressedBlock(seed, &fr, info);
1423             ret = ZSTD_decompressBegin_usingDict(dctx, fullDict, dictSize);
1424             if (ZSTD_isError(ret)) {
1425                 errorDetected = ret;
1426                 ZSTD_freeDCtx(dctx);
1427                 goto dictTestCleanup;
1428             }
1429             ret = ZSTD_decompressBlock(dctx, DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1430                                        fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart);
1431         }
1432         ZSTD_freeDCtx(dctx);
1433
1434         if (ZSTD_isError(ret)) {
1435             errorDetected = ret;
1436             goto dictTestCleanup;
1437         }
1438
1439         if (memcmp(DECOMPRESSED_BUFFER, fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart) != 0) {
1440             errorDetected = ERROR(corruption_detected);
1441             goto dictTestCleanup;
1442         }
1443     }
1444
1445 dictTestCleanup:
1446     free(fullDict);
1447     return errorDetected;
1448 }
1449
1450 static size_t testDecodeRawBlock(frame_t* fr)
1451 {
1452     ZSTD_DCtx* dctx = ZSTD_createDCtx();
1453     size_t ret = ZSTD_decompressBegin(dctx);
1454     if (ZSTD_isError(ret)) return ret;
1455
1456     ret = ZSTD_decompressBlock(
1457             dctx,
1458             DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1459             fr->dataStart, (BYTE*)fr->data - (BYTE*)fr->dataStart);
1460     ZSTD_freeDCtx(dctx);
1461     if (ZSTD_isError(ret)) return ret;
1462
1463     if (memcmp(DECOMPRESSED_BUFFER, fr->srcStart,
1464                (BYTE*)fr->src - (BYTE*)fr->srcStart) != 0) {
1465         return ERROR(corruption_detected);
1466     }
1467
1468     return ret;
1469 }
1470
1471 static int runBlockTest(U32* seed)
1472 {
1473     frame_t fr;
1474     U32 const seedCopy = *seed;
1475     {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1476         *seed = generateCompressedBlock(*seed, &fr, info);
1477     }
1478
1479     {   size_t const r = testDecodeRawBlock(&fr);
1480         if (ZSTD_isError(r)) {
1481             DISPLAY("Error in block mode on test seed %u: %s\n", seedCopy,
1482                     ZSTD_getErrorName(r));
1483             return 1;
1484         }
1485     }
1486
1487     {   size_t const r = testDecodeWithDict(*seed, gt_block);
1488         if (ZSTD_isError(r)) {
1489             DISPLAY("Error in block mode with dictionary on test seed %u: %s\n",
1490                     seedCopy, ZSTD_getErrorName(r));
1491             return 1;
1492         }
1493     }
1494     return 0;
1495 }
1496
1497 static int runFrameTest(U32* seed)
1498 {
1499     frame_t fr;
1500     U32 const seedCopy = *seed;
1501     {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1502         *seed = generateFrame(*seed, &fr, info);
1503     }
1504
1505     {   size_t const r = testDecodeSimple(&fr);
1506         if (ZSTD_isError(r)) {
1507             DISPLAY("Error in simple mode on test seed %u: %s\n",
1508                     seedCopy, ZSTD_getErrorName(r));
1509             return 1;
1510         }
1511     }
1512     {   size_t const r = testDecodeStreaming(&fr);
1513         if (ZSTD_isError(r)) {
1514             DISPLAY("Error in streaming mode on test seed %u: %s\n",
1515                     seedCopy, ZSTD_getErrorName(r));
1516             return 1;
1517         }
1518     }
1519     {   size_t const r = testDecodeWithDict(*seed, gt_frame);  /* avoid big dictionaries */
1520         if (ZSTD_isError(r)) {
1521             DISPLAY("Error in dictionary mode on test seed %u: %s\n",
1522                     seedCopy, ZSTD_getErrorName(r));
1523             return 1;
1524         }
1525     }
1526     return 0;
1527 }
1528
1529 static int runTestMode(U32 seed, unsigned numFiles, unsigned const testDurationS,
1530                        genType_e genType)
1531 {
1532     unsigned fnum;
1533
1534     clock_t const startClock = clock();
1535     clock_t const maxClockSpan = testDurationS * CLOCKS_PER_SEC;
1536
1537     if (numFiles == 0 && !testDurationS) numFiles = 1;
1538
1539     DISPLAY("seed: %u\n", seed);
1540
1541     for (fnum = 0; fnum < numFiles || clockSpan(startClock) < maxClockSpan; fnum++) {
1542         if (fnum < numFiles)
1543             DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
1544         else
1545             DISPLAYUPDATE("\r%u           ", fnum);
1546
1547         {   int const ret = (genType == gt_frame) ?
1548                             runFrameTest(&seed) :
1549                             runBlockTest(&seed);
1550             if (ret) return ret;
1551         }
1552     }
1553
1554     DISPLAY("\r%u tests completed: ", fnum);
1555     DISPLAY("OK\n");
1556
1557     return 0;
1558 }
1559
1560 /*-*******************************************************
1561 *  File I/O
1562 *********************************************************/
1563
1564 static int generateFile(U32 seed, const char* const path,
1565                         const char* const origPath, genType_e genType)
1566 {
1567     frame_t fr;
1568
1569     DISPLAY("seed: %u\n", seed);
1570
1571     {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1572         if (genType == gt_frame) {
1573             generateFrame(seed, &fr, info);
1574         } else {
1575             generateCompressedBlock(seed, &fr, info);
1576         }
1577     }
1578     outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path);
1579     if (origPath) {
1580         outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, origPath);
1581     }
1582     return 0;
1583 }
1584
1585 static int generateCorpus(U32 seed, unsigned numFiles, const char* const path,
1586                           const char* const origPath, genType_e genType)
1587 {
1588     char outPath[MAX_PATH];
1589     unsigned fnum;
1590
1591     DISPLAY("seed: %u\n", seed);
1592
1593     for (fnum = 0; fnum < numFiles; fnum++) {
1594         frame_t fr;
1595
1596         DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
1597
1598         {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1599             if (genType == gt_frame) {
1600                 seed = generateFrame(seed, &fr, info);
1601             } else {
1602                 seed = generateCompressedBlock(seed, &fr, info);
1603             }
1604         }
1605
1606         if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
1607             DISPLAY("Error: path too long\n");
1608             return 1;
1609         }
1610         outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, outPath);
1611
1612         if (origPath) {
1613             if (snprintf(outPath, MAX_PATH, "%s/z%06u", origPath, fnum) + 1 > MAX_PATH) {
1614                 DISPLAY("Error: path too long\n");
1615                 return 1;
1616             }
1617             outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, outPath);
1618         }
1619     }
1620
1621     DISPLAY("\r%u/%u      \n", fnum, numFiles);
1622
1623     return 0;
1624 }
1625
1626 static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const path,
1627                                   const char* const origPath, const size_t dictSize,
1628                                   genType_e genType)
1629 {
1630     char outPath[MAX_PATH];
1631     BYTE* fullDict;
1632     U32 const dictID = RAND(&seed);
1633     int errorDetected = 0;
1634
1635     if (snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) {
1636         DISPLAY("Error: path too long\n");
1637         return 1;
1638     }
1639
1640     /* allocate space for the dictionary */
1641     fullDict = malloc(dictSize);
1642     if (fullDict == NULL) {
1643         DISPLAY("Error: could not allocate space for full dictionary.\n");
1644         return 1;
1645     }
1646
1647     /* randomly generate the dictionary */
1648     {   int const ret = genRandomDict(dictID, seed, dictSize, fullDict);
1649         if (ret != 0) {
1650             errorDetected = ret;
1651             goto dictCleanup;
1652         }
1653     }
1654
1655     /* write out dictionary */
1656     if (numFiles != 0) {
1657         if (snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) {
1658             DISPLAY("Error: dictionary path too long\n");
1659             errorDetected = 1;
1660             goto dictCleanup;
1661         }
1662         outputBuffer(fullDict, dictSize, outPath);
1663     }
1664     else {
1665         outputBuffer(fullDict, dictSize, "dictionary");
1666     }
1667
1668     /* generate random compressed/decompressed files */
1669     {   unsigned fnum;
1670         for (fnum = 0; fnum < MAX(numFiles, 1); fnum++) {
1671             frame_t fr;
1672             DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
1673             {
1674                 size_t const headerSize = MAX(dictSize/4, 256);
1675                 size_t const dictContentSize = dictSize-headerSize;
1676                 BYTE* const dictContent = fullDict+headerSize;
1677                 dictInfo const info = initDictInfo(1, dictContentSize, dictContent, dictID);
1678                 if (genType == gt_frame) {
1679                     seed = generateFrame(seed, &fr, info);
1680                 } else {
1681                     seed = generateCompressedBlock(seed, &fr, info);
1682                 }
1683             }
1684
1685             if (numFiles != 0) {
1686                 if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
1687                     DISPLAY("Error: path too long\n");
1688                     errorDetected = 1;
1689                     goto dictCleanup;
1690                 }
1691                 outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, outPath);
1692
1693                 if (origPath) {
1694                     if (snprintf(outPath, MAX_PATH, "%s/z%06u", origPath, fnum) + 1 > MAX_PATH) {
1695                         DISPLAY("Error: path too long\n");
1696                         errorDetected = 1;
1697                         goto dictCleanup;
1698                     }
1699                     outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, outPath);
1700                 }
1701             }
1702             else {
1703                 outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path);
1704                 if (origPath) {
1705                     outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, origPath);
1706                 }
1707             }
1708         }
1709     }
1710
1711 dictCleanup:
1712     free(fullDict);
1713     return errorDetected;
1714 }
1715
1716
1717 /*_*******************************************************
1718 *  Command line
1719 *********************************************************/
1720 static U32 makeSeed(void)
1721 {
1722     U32 t = (U32) time(NULL);
1723     return XXH32(&t, sizeof(t), 0) % 65536;
1724 }
1725
1726 static unsigned readInt(const char** argument)
1727 {
1728     unsigned val = 0;
1729     while ((**argument>='0') && (**argument<='9')) {
1730         val *= 10;
1731         val += **argument - '0';
1732         (*argument)++;
1733     }
1734     return val;
1735 }
1736
1737 static void usage(const char* programName)
1738 {
1739     DISPLAY( "Usage :\n");
1740     DISPLAY( "      %s [args]\n", programName);
1741     DISPLAY( "\n");
1742     DISPLAY( "Arguments :\n");
1743     DISPLAY( " -p<path> : select output path (default:stdout)\n");
1744     DISPLAY( "                in multiple files mode this should be a directory\n");
1745     DISPLAY( " -o<path> : select path to output original file (default:no output)\n");
1746     DISPLAY( "                in multiple files mode this should be a directory\n");
1747     DISPLAY( " -s#      : select seed (default:random based on time)\n");
1748     DISPLAY( " -n#      : number of files to generate (default:1)\n");
1749     DISPLAY( " -t       : activate test mode (test files against libzstd instead of outputting them)\n");
1750     DISPLAY( " -T#      : length of time to run tests for\n");
1751     DISPLAY( " -v       : increase verbosity level (default:0, max:7)\n");
1752     DISPLAY( " -h/H     : display help/long help and exit\n");
1753 }
1754
1755 static void advancedUsage(const char* programName)
1756 {
1757     usage(programName);
1758     DISPLAY( "\n");
1759     DISPLAY( "Advanced arguments        :\n");
1760     DISPLAY( " --content-size           : always include the content size in the frame header\n");
1761     DISPLAY( " --use-dict=#             : include a dictionary used to decompress the corpus\n");
1762     DISPLAY( " --gen-blocks             : generate raw compressed blocks without block/frame headers\n");
1763     DISPLAY( " --max-block-size-log=#   : max block size log, must be in range [2, 17]\n");
1764     DISPLAY( " --max-content-size-log=# : max content size log, must be <= 20\n");
1765     DISPLAY( "                            (this is ignored with gen-blocks)\n");
1766 }
1767
1768 /*! readU32FromChar() :
1769     @return : unsigned integer value read from input in `char` format
1770     allows and interprets K, KB, KiB, M, MB and MiB suffix.
1771     Will also modify `*stringPtr`, advancing it to position where it stopped reading.
1772     Note : function result can overflow if digit string > MAX_UINT */
1773 static unsigned readU32FromChar(const char** stringPtr)
1774 {
1775     unsigned result = 0;
1776     while ((**stringPtr >='0') && (**stringPtr <='9'))
1777         result *= 10, result += **stringPtr - '0', (*stringPtr)++ ;
1778     if ((**stringPtr=='K') || (**stringPtr=='M')) {
1779         result <<= 10;
1780         if (**stringPtr=='M') result <<= 10;
1781         (*stringPtr)++ ;
1782         if (**stringPtr=='i') (*stringPtr)++;
1783         if (**stringPtr=='B') (*stringPtr)++;
1784     }
1785     return result;
1786 }
1787
1788 /** longCommandWArg() :
1789  *  check if *stringPtr is the same as longCommand.
1790  *  If yes, @return 1 and advances *stringPtr to the position which immediately follows longCommand.
1791  *  @return 0 and doesn't modify *stringPtr otherwise.
1792  */
1793 static unsigned longCommandWArg(const char** stringPtr, const char* longCommand)
1794 {
1795     size_t const comSize = strlen(longCommand);
1796     int const result = !strncmp(*stringPtr, longCommand, comSize);
1797     if (result) *stringPtr += comSize;
1798     return result;
1799 }
1800
1801 int main(int argc, char** argv)
1802 {
1803     U32 seed = 0;
1804     int seedset = 0;
1805     unsigned numFiles = 0;
1806     unsigned testDuration = 0;
1807     int testMode = 0;
1808     const char* path = NULL;
1809     const char* origPath = NULL;
1810     int useDict = 0;
1811     unsigned dictSize = (10 << 10); /* 10 kB default */
1812     genType_e genType = gt_frame;
1813
1814     int argNb;
1815
1816     /* Check command line */
1817     for (argNb=1; argNb<argc; argNb++) {
1818         const char* argument = argv[argNb];
1819         if(!argument) continue;   /* Protection if argument empty */
1820
1821         /* Handle commands. Aggregated commands are allowed */
1822         if (argument[0]=='-') {
1823             argument++;
1824             while (*argument!=0) {
1825                 switch(*argument)
1826                 {
1827                 case 'h':
1828                     usage(argv[0]);
1829                     return 0;
1830                 case 'H':
1831                     advancedUsage(argv[0]);
1832                     return 0;
1833                 case 'v':
1834                     argument++;
1835                     g_displayLevel++;
1836                     break;
1837                 case 's':
1838                     argument++;
1839                     seedset=1;
1840                     seed = readInt(&argument);
1841                     break;
1842                 case 'n':
1843                     argument++;
1844                     numFiles = readInt(&argument);
1845                     break;
1846                 case 'T':
1847                     argument++;
1848                     testDuration = readInt(&argument);
1849                     if (*argument == 'm') {
1850                         testDuration *= 60;
1851                         argument++;
1852                         if (*argument == 'n') argument++;
1853                     }
1854                     break;
1855                 case 'o':
1856                     argument++;
1857                     origPath = argument;
1858                     argument += strlen(argument);
1859                     break;
1860                 case 'p':
1861                     argument++;
1862                     path = argument;
1863                     argument += strlen(argument);
1864                     break;
1865                 case 't':
1866                     argument++;
1867                     testMode = 1;
1868                     break;
1869                 case '-':
1870                     argument++;
1871                     if (strcmp(argument, "content-size") == 0) {
1872                         opts.contentSize = 1;
1873                     } else if (longCommandWArg(&argument, "use-dict=")) {
1874                         dictSize = readU32FromChar(&argument);
1875                         useDict = 1;
1876                     } else if (strcmp(argument, "gen-blocks") == 0) {
1877                         genType = gt_block;
1878                     } else if (longCommandWArg(&argument, "max-block-size-log=")) {
1879                         U32 value = readU32FromChar(&argument);
1880                         if (value >= 2 && value <= ZSTD_BLOCKSIZE_MAX) {
1881                             g_maxBlockSize = 1U << value;
1882                         }
1883                     } else if (longCommandWArg(&argument, "max-content-size-log=")) {
1884                         U32 value = readU32FromChar(&argument);
1885                         g_maxDecompressedSizeLog =
1886                                 MIN(MAX_DECOMPRESSED_SIZE_LOG, value);
1887                     } else {
1888                         advancedUsage(argv[0]);
1889                         return 1;
1890                     }
1891                     argument += strlen(argument);
1892                     break;
1893                 default:
1894                     usage(argv[0]);
1895                     return 1;
1896     }   }   }   }   /* for (argNb=1; argNb<argc; argNb++) */
1897
1898     if (!seedset) {
1899         seed = makeSeed();
1900     }
1901
1902     if (testMode) {
1903         return runTestMode(seed, numFiles, testDuration, genType);
1904     } else {
1905         if (testDuration) {
1906             DISPLAY("Error: -T requires test mode (-t)\n\n");
1907             usage(argv[0]);
1908             return 1;
1909         }
1910     }
1911
1912     if (!path) {
1913         DISPLAY("Error: path is required in file generation mode\n");
1914         usage(argv[0]);
1915         return 1;
1916     }
1917
1918     if (numFiles == 0 && useDict == 0) {
1919         return generateFile(seed, path, origPath, genType);
1920     } else if (useDict == 0){
1921         return generateCorpus(seed, numFiles, path, origPath, genType);
1922     } else {
1923         /* should generate files with a dictionary */
1924         return generateCorpusWithDict(seed, numFiles, path, origPath, dictSize, genType);
1925     }
1926
1927 }