]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/zstd/tests/roundTripCrash.c
Update zstd to 1.3.0
[FreeBSD/FreeBSD.git] / contrib / zstd / tests / roundTripCrash.c
1 /**
2  * Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree. An additional grant
7  * of patent rights can be found in the PATENTS file in the same directory.
8  */
9
10 /*
11   This program takes a file in input,
12   performs a zstd round-trip test (compression - decompress)
13   compares the result with original
14   and generates a crash (double free) on corruption detection.
15 */
16
17 /*===========================================
18 *   Dependencies
19 *==========================================*/
20 #include <stddef.h>     /* size_t */
21 #include <stdlib.h>     /* malloc, free, exit */
22 #include <stdio.h>      /* fprintf */
23 #include <sys/types.h>  /* stat */
24 #include <sys/stat.h>   /* stat */
25 #include "xxhash.h"
26 #include "zstd.h"
27
28 /*===========================================
29 *   Macros
30 *==========================================*/
31 #define MIN(a,b)  ( (a) < (b) ? (a) : (b) )
32
33 /** roundTripTest() :
34 *   Compresses `srcBuff` into `compressedBuff`,
35 *   then decompresses `compressedBuff` into `resultBuff`.
36 *   Compression level used is derived from first content byte.
37 *   @return : result of decompression, which should be == `srcSize`
38 *          or an error code if either compression or decompression fails.
39 *   Note : `compressedBuffCapacity` should be `>= ZSTD_compressBound(srcSize)`
40 *          for compression to be guaranteed to work */
41 static size_t roundTripTest(void* resultBuff, size_t resultBuffCapacity,
42                             void* compressedBuff, size_t compressedBuffCapacity,
43                       const void* srcBuff, size_t srcBuffSize)
44 {
45     static const int maxClevel = 19;
46     size_t const hashLength = MIN(128, srcBuffSize);
47     unsigned const h32 = XXH32(srcBuff, hashLength, 0);
48     int const cLevel = h32 % maxClevel;
49     size_t const cSize = ZSTD_compress(compressedBuff, compressedBuffCapacity, srcBuff, srcBuffSize, cLevel);
50     if (ZSTD_isError(cSize)) {
51         fprintf(stderr, "Compression error : %s \n", ZSTD_getErrorName(cSize));
52         return cSize;
53     }
54     return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, cSize);
55 }
56
57
58 static size_t checkBuffers(const void* buff1, const void* buff2, size_t buffSize)
59 {
60     const char* ip1 = (const char*)buff1;
61     const char* ip2 = (const char*)buff2;
62     size_t pos;
63
64     for (pos=0; pos<buffSize; pos++)
65         if (ip1[pos]!=ip2[pos])
66             break;
67
68     return pos;
69 }
70
71 static void crash(int errorCode){
72     /* abort if AFL/libfuzzer, exit otherwise */
73     #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION /* could also use __AFL_COMPILER */
74         abort();
75     #else
76         exit(errorCode);
77     #endif
78 }
79
80 static void roundTripCheck(const void* srcBuff, size_t srcBuffSize)
81 {
82     size_t const cBuffSize = ZSTD_compressBound(srcBuffSize);
83     void* cBuff = malloc(cBuffSize);
84     void* rBuff = malloc(cBuffSize);
85
86     if (!cBuff || !rBuff) {
87         fprintf(stderr, "not enough memory ! \n");
88         exit (1);
89     }
90
91     {   size_t const result = roundTripTest(rBuff, cBuffSize, cBuff, cBuffSize, srcBuff, srcBuffSize);
92         if (ZSTD_isError(result)) {
93             fprintf(stderr, "roundTripTest error : %s \n", ZSTD_getErrorName(result));
94             crash(1);
95         }
96         if (result != srcBuffSize) {
97             fprintf(stderr, "Incorrect regenerated size : %u != %u\n", (unsigned)result, (unsigned)srcBuffSize);
98             crash(1);
99         }
100         if (checkBuffers(srcBuff, rBuff, srcBuffSize) != srcBuffSize) {
101             fprintf(stderr, "Silent decoding corruption !!!");
102             crash(1);
103         }
104     }
105
106     free(cBuff);
107     free(rBuff);
108 }
109
110
111 static size_t getFileSize(const char* infilename)
112 {
113     int r;
114 #if defined(_MSC_VER)
115     struct _stat64 statbuf;
116     r = _stat64(infilename, &statbuf);
117     if (r || !(statbuf.st_mode & S_IFREG)) return 0;   /* No good... */
118 #else
119     struct stat statbuf;
120     r = stat(infilename, &statbuf);
121     if (r || !S_ISREG(statbuf.st_mode)) return 0;   /* No good... */
122 #endif
123     return (size_t)statbuf.st_size;
124 }
125
126
127 static int isDirectory(const char* infilename)
128 {
129     int r;
130 #if defined(_MSC_VER)
131     struct _stat64 statbuf;
132     r = _stat64(infilename, &statbuf);
133     if (!r && (statbuf.st_mode & _S_IFDIR)) return 1;
134 #else
135     struct stat statbuf;
136     r = stat(infilename, &statbuf);
137     if (!r && S_ISDIR(statbuf.st_mode)) return 1;
138 #endif
139     return 0;
140 }
141
142
143 /** loadFile() :
144 *   requirement : `buffer` size >= `fileSize` */
145 static void loadFile(void* buffer, const char* fileName, size_t fileSize)
146 {
147     FILE* const f = fopen(fileName, "rb");
148     if (isDirectory(fileName)) {
149         fprintf(stderr, "Ignoring %s directory \n", fileName);
150         exit(2);
151     }
152     if (f==NULL) {
153         fprintf(stderr, "Impossible to open %s \n", fileName);
154         exit(3);
155     }
156     {   size_t const readSize = fread(buffer, 1, fileSize, f);
157         if (readSize != fileSize) {
158             fprintf(stderr, "Error reading %s \n", fileName);
159             exit(5);
160     }   }
161     fclose(f);
162 }
163
164
165 static void fileCheck(const char* fileName)
166 {
167     size_t const fileSize = getFileSize(fileName);
168     void* buffer = malloc(fileSize);
169     if (!buffer) {
170         fprintf(stderr, "not enough memory \n");
171         exit(4);
172     }
173     loadFile(buffer, fileName, fileSize);
174     roundTripCheck(buffer, fileSize);
175     free (buffer);
176 }
177
178 int main(int argCount, const char** argv) {
179     if (argCount < 2) {
180         fprintf(stderr, "Error : no argument : need input file \n");
181         exit(9);
182     }
183     fileCheck(argv[1]);
184     fprintf(stderr, "no pb detected\n");
185     return 0;
186 }