]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - lib/libmp/mpasbn.c
Upgrade our copies of clang, llvm, lld, lldb, compiler-rt and libc++ to
[FreeBSD/FreeBSD.git] / lib / libmp / mpasbn.c
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2001 Dima Dorfman.
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26  * SUCH DAMAGE.
27  */
28
29 /*
30  * This is the traditional Berkeley MP library implemented in terms of
31  * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
32  * is meant to be as compatible with the latter as feasible.
33  *
34  * There seems to be a lack of documentation for the Berkeley MP
35  * interface.  All I could find was libgmp documentation (which didn't
36  * talk about the semantics of the functions) and an old SunOS 4.1
37  * manual page from 1989.  The latter wasn't very detailed, either,
38  * but at least described what the function's arguments were.  In
39  * general the interface seems to be archaic, somewhat poorly
40  * designed, and poorly, if at all, documented.  It is considered
41  * harmful.
42  *
43  * Miscellaneous notes on this implementation:
44  *
45  *  - The SunOS manual page mentioned above indicates that if an error
46  *  occurs, the library should "produce messages and core images."
47  *  Given that most of the functions don't have return values (and
48  *  thus no sane way of alerting the caller to an error), this seems
49  *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
50  *  respectively, then abort().
51  *
52  *  - All the functions which take an argument to be "filled in"
53  *  assume that the argument has been initialized by one of the *tom()
54  *  routines before being passed to it.  I never saw this documented
55  *  anywhere, but this seems to be consistent with the way this
56  *  library is used.
57  *
58  *  - msqrt() is the only routine which had to be implemented which
59  *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
60  *  It was implemented by hand using Newton's recursive formula.
61  *  Doing it this way, although more error-prone, has the positive
62  *  sideaffect of testing a lot of other functions; if msqrt()
63  *  produces the correct results, most of the other routines will as
64  *  well.
65  *
66  *  - Internal-use-only routines (i.e., those defined here statically
67  *  and not in mp.h) have an underscore prepended to their name (this
68  *  is more for aesthetical reasons than technical).  All such
69  *  routines take an extra argument, 'msg', that denotes what they
70  *  should call themselves in an error message.  This is so a user
71  *  doesn't get an error message from a function they didn't call.
72  */
73
74 #include <sys/cdefs.h>
75 __FBSDID("$FreeBSD$");
76
77 #include <ctype.h>
78 #include <err.h>
79 #include <errno.h>
80 #include <stdio.h>
81 #include <stdlib.h>
82 #include <string.h>
83
84 #include <openssl/crypto.h>
85 #include <openssl/err.h>
86
87 #include "mp.h"
88
89 #define MPERR(s)        do { warn s; abort(); } while (0)
90 #define MPERRX(s)       do { warnx s; abort(); } while (0)
91 #define BN_ERRCHECK(msg, expr) do {             \
92         if (!(expr)) _bnerr(msg);               \
93 } while (0)
94
95 static void _bnerr(const char *);
96 static MINT *_dtom(const char *, const char *);
97 static MINT *_itom(const char *, short);
98 static void _madd(const char *, const MINT *, const MINT *, MINT *);
99 static int _mcmpa(const char *, const MINT *, const MINT *);
100 static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
101                 BN_CTX *);
102 static void _mfree(const char *, MINT *);
103 static void _moveb(const char *, const BIGNUM *, MINT *);
104 static void _movem(const char *, const MINT *, MINT *);
105 static void _msub(const char *, const MINT *, const MINT *, MINT *);
106 static char *_mtod(const char *, const MINT *);
107 static char *_mtox(const char *, const MINT *);
108 static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
109 static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
110 static MINT *_xtom(const char *, const char *);
111
112 /*
113  * Report an error from one of the BN_* functions using MPERRX.
114  */
115 static void
116 _bnerr(const char *msg)
117 {
118
119         ERR_load_crypto_strings();
120         MPERRX(("%s: %s", msg, ERR_reason_error_string(ERR_get_error())));
121 }
122
123 /*
124  * Convert a decimal string to an MINT.
125  */
126 static MINT *
127 _dtom(const char *msg, const char *s)
128 {
129         MINT *mp;
130
131         mp = malloc(sizeof(*mp));
132         if (mp == NULL)
133                 MPERR(("%s", msg));
134         mp->bn = BN_new();
135         if (mp->bn == NULL)
136                 _bnerr(msg);
137         BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
138         return (mp);
139 }
140
141 /*
142  * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
143  */
144 void
145 mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
146 {
147         BIGNUM b;
148         BN_CTX *c;
149
150         c = BN_CTX_new();
151         if (c == NULL)
152                 _bnerr("gcd");
153         BN_init(&b);
154         BN_ERRCHECK("gcd", BN_gcd(&b, mp1->bn, mp2->bn, c));
155         _moveb("gcd", &b, rmp);
156         BN_free(&b);
157         BN_CTX_free(c);
158 }
159
160 /*
161  * Make an MINT out of a short integer.  Return value must be mfree()'d.
162  */
163 static MINT *
164 _itom(const char *msg, short n)
165 {
166         MINT *mp;
167         char *s;
168
169         asprintf(&s, "%x", n);
170         if (s == NULL)
171                 MPERR(("%s", msg));
172         mp = _xtom(msg, s);
173         free(s);
174         return (mp);
175 }
176
177 MINT *
178 mp_itom(short n)
179 {
180
181         return (_itom("itom", n));
182 }
183
184 /*
185  * Compute rmp=mp1+mp2.
186  */
187 static void
188 _madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
189 {
190         BIGNUM b;
191
192         BN_init(&b);
193         BN_ERRCHECK(msg, BN_add(&b, mp1->bn, mp2->bn));
194         _moveb(msg, &b, rmp);
195         BN_free(&b);
196 }
197
198 void
199 mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
200 {
201
202         _madd("madd", mp1, mp2, rmp);
203 }
204
205 /*
206  * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
207  */
208 int
209 mp_mcmp(const MINT *mp1, const MINT *mp2)
210 {
211
212         return (BN_cmp(mp1->bn, mp2->bn));
213 }
214
215 /*
216  * Same as mcmp but compares absolute values.
217  */
218 static int
219 _mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
220 {
221
222         return (BN_ucmp(mp1->bn, mp2->bn));
223 }
224
225 /*
226  * Compute qmp=nmp/dmp and rmp=nmp%dmp.
227  */
228 static void
229 _mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
230     BN_CTX *c)
231 {
232         BIGNUM q, r;
233
234         BN_init(&r);
235         BN_init(&q);
236         BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
237         _moveb(msg, &q, qmp);
238         _moveb(msg, &r, rmp);
239         BN_free(&q);
240         BN_free(&r);
241 }
242
243 void
244 mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
245 {
246         BN_CTX *c;
247
248         c = BN_CTX_new();
249         if (c == NULL)
250                 _bnerr("mdiv");
251         _mdiv("mdiv", nmp, dmp, qmp, rmp, c);
252         BN_CTX_free(c);
253 }
254
255 /*
256  * Free memory associated with an MINT.
257  */
258 static void
259 _mfree(const char *msg __unused, MINT *mp)
260 {
261
262         BN_clear(mp->bn);
263         BN_free(mp->bn);
264         free(mp);
265 }
266
267 void
268 mp_mfree(MINT *mp)
269 {
270
271         _mfree("mfree", mp);
272 }
273
274 /*
275  * Read an integer from standard input and stick the result in mp.
276  * The input is treated to be in base 10.  This must be the silliest
277  * API in existence; why can't the program read in a string and call
278  * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
279  * exported.)
280  */
281 void
282 mp_min(MINT *mp)
283 {
284         MINT *rmp;
285         char *line, *nline;
286         size_t linelen;
287
288         line = fgetln(stdin, &linelen);
289         if (line == NULL)
290                 MPERR(("min"));
291         nline = malloc(linelen + 1);
292         if (nline == NULL)
293                 MPERR(("min"));
294         memcpy(nline, line, linelen);
295         nline[linelen] = '\0';
296         rmp = _dtom("min", nline);
297         _movem("min", rmp, mp);
298         _mfree("min", rmp);
299         free(nline);
300 }
301
302 /*
303  * Print the value of mp to standard output in base 10.  See blurb
304  * above min() for why this is so useless.
305  */
306 void
307 mp_mout(const MINT *mp)
308 {
309         char *s;
310
311         s = _mtod("mout", mp);
312         printf("%s", s);
313         free(s);
314 }
315
316 /*
317  * Set the value of tmp to the value of smp (i.e., tmp=smp).
318  */
319 void
320 mp_move(const MINT *smp, MINT *tmp)
321 {
322
323         _movem("move", smp, tmp);
324 }
325
326
327 /*
328  * Internal routine to set the value of tmp to that of sbp.
329  */
330 static void
331 _moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
332 {
333
334         BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
335 }
336
337 /*
338  * Internal routine to set the value of tmp to that of smp.
339  */
340 static void
341 _movem(const char *msg, const MINT *smp, MINT *tmp)
342 {
343
344         BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
345 }
346
347 /*
348  * Compute the square root of nmp and put the result in xmp.  The
349  * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
350  *
351  * Note that the OpenSSL BIGNUM library does not have a square root
352  * function, so this had to be implemented by hand using Newton's
353  * recursive formula:
354  *
355  *              x = (x + (n / x)) / 2
356  *
357  * where x is the square root of the positive number n.  In the
358  * beginning, x should be a reasonable guess, but the value 1,
359  * although suboptimal, works, too; this is that is used below.
360  */
361 void
362 mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
363 {
364         BN_CTX *c;
365         MINT *tolerance;
366         MINT *ox, *x;
367         MINT *z1, *z2, *z3;
368         short i;
369
370         c = BN_CTX_new();
371         if (c == NULL)
372                 _bnerr("msqrt");
373         tolerance = _itom("msqrt", 1);
374         x = _itom("msqrt", 1);
375         ox = _itom("msqrt", 0);
376         z1 = _itom("msqrt", 0);
377         z2 = _itom("msqrt", 0);
378         z3 = _itom("msqrt", 0);
379         do {
380                 _movem("msqrt", x, ox);
381                 _mdiv("msqrt", nmp, x, z1, z2, c);
382                 _madd("msqrt", x, z1, z2);
383                 _sdiv("msqrt", z2, 2, x, &i, c);
384                 _msub("msqrt", ox, x, z3);
385         } while (_mcmpa("msqrt", z3, tolerance) == 1);
386         _movem("msqrt", x, xmp);
387         _mult("msqrt", x, x, z1, c);
388         _msub("msqrt", nmp, z1, z2);
389         _movem("msqrt", z2, rmp);
390         _mfree("msqrt", tolerance);
391         _mfree("msqrt", ox);
392         _mfree("msqrt", x);
393         _mfree("msqrt", z1);
394         _mfree("msqrt", z2);
395         _mfree("msqrt", z3);
396         BN_CTX_free(c);
397 }
398
399 /*
400  * Compute rmp=mp1-mp2.
401  */
402 static void
403 _msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
404 {
405         BIGNUM b;
406
407         BN_init(&b);
408         BN_ERRCHECK(msg, BN_sub(&b, mp1->bn, mp2->bn));
409         _moveb(msg, &b, rmp);
410         BN_free(&b);
411 }
412
413 void
414 mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
415 {
416
417         _msub("msub", mp1, mp2, rmp);
418 }
419
420 /*
421  * Return a decimal representation of mp.  Return value must be
422  * free()'d.
423  */
424 static char *
425 _mtod(const char *msg, const MINT *mp)
426 {
427         char *s, *s2;
428
429         s = BN_bn2dec(mp->bn);
430         if (s == NULL)
431                 _bnerr(msg);
432         asprintf(&s2, "%s", s);
433         if (s2 == NULL)
434                 MPERR(("%s", msg));
435         OPENSSL_free(s);
436         return (s2);
437 }
438
439 /*
440  * Return a hexadecimal representation of mp.  Return value must be
441  * free()'d.
442  */
443 static char *
444 _mtox(const char *msg, const MINT *mp)
445 {
446         char *p, *s, *s2;
447         int len;
448
449         s = BN_bn2hex(mp->bn);
450         if (s == NULL)
451                 _bnerr(msg);
452         asprintf(&s2, "%s", s);
453         if (s2 == NULL)
454                 MPERR(("%s", msg));
455         OPENSSL_free(s);
456
457         /*
458          * This is a kludge for libgmp compatibility.  The latter's
459          * implementation of this function returns lower-case letters,
460          * but BN_bn2hex returns upper-case.  Some programs (e.g.,
461          * newkey(1)) are sensitive to this.  Although it's probably
462          * their fault, it's nice to be compatible.
463          */
464         len = strlen(s2);
465         for (p = s2; p < s2 + len; p++)
466                 *p = tolower(*p);
467
468         return (s2);
469 }
470
471 char *
472 mp_mtox(const MINT *mp)
473 {
474
475         return (_mtox("mtox", mp));
476 }
477
478 /*
479  * Compute rmp=mp1*mp2.
480  */
481 static void
482 _mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
483 {
484         BIGNUM b;
485
486         BN_init(&b);
487         BN_ERRCHECK(msg, BN_mul(&b, mp1->bn, mp2->bn, c));
488         _moveb(msg, &b, rmp);
489         BN_free(&b);
490 }
491
492 void
493 mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
494 {
495         BN_CTX *c;
496
497         c = BN_CTX_new();
498         if (c == NULL)
499                 _bnerr("mult");
500         _mult("mult", mp1, mp2, rmp, c);
501         BN_CTX_free(c);
502 }
503
504 /*
505  * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
506  * means 'raise to power', not 'bitwise XOR'.)
507  */
508 void
509 mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
510 {
511         BIGNUM b;
512         BN_CTX *c;
513
514         c = BN_CTX_new();
515         if (c == NULL)
516                 _bnerr("pow");
517         BN_init(&b);
518         BN_ERRCHECK("pow", BN_mod_exp(&b, bmp->bn, emp->bn, mmp->bn, c));
519         _moveb("pow", &b, rmp);
520         BN_free(&b);
521         BN_CTX_free(c);
522 }
523
524 /*
525  * Compute rmp=bmp^e.  (See note above pow().)
526  */
527 void
528 mp_rpow(const MINT *bmp, short e, MINT *rmp)
529 {
530         MINT *emp;
531         BIGNUM b;
532         BN_CTX *c;
533
534         c = BN_CTX_new();
535         if (c == NULL)
536                 _bnerr("rpow");
537         BN_init(&b);
538         emp = _itom("rpow", e);
539         BN_ERRCHECK("rpow", BN_exp(&b, bmp->bn, emp->bn, c));
540         _moveb("rpow", &b, rmp);
541         _mfree("rpow", emp);
542         BN_free(&b);
543         BN_CTX_free(c);
544 }
545
546 /*
547  * Compute qmp=nmp/d and ro=nmp%d.
548  */
549 static void
550 _sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
551     BN_CTX *c)
552 {
553         MINT *dmp, *rmp;
554         BIGNUM q, r;
555         char *s;
556
557         BN_init(&q);
558         BN_init(&r);
559         dmp = _itom(msg, d);
560         rmp = _itom(msg, 0);
561         BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
562         _moveb(msg, &q, qmp);
563         _moveb(msg, &r, rmp);
564         s = _mtox(msg, rmp);
565         errno = 0;
566         *ro = strtol(s, NULL, 16);
567         if (errno != 0)
568                 MPERR(("%s underflow or overflow", msg));
569         free(s);
570         _mfree(msg, dmp);
571         _mfree(msg, rmp);
572         BN_free(&r);
573         BN_free(&q);
574 }
575
576 void
577 mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
578 {
579         BN_CTX *c;
580
581         c = BN_CTX_new();
582         if (c == NULL)
583                 _bnerr("sdiv");
584         _sdiv("sdiv", nmp, d, qmp, ro, c);
585         BN_CTX_free(c);
586 }
587
588 /*
589  * Convert a hexadecimal string to an MINT.
590  */
591 static MINT *
592 _xtom(const char *msg, const char *s)
593 {
594         MINT *mp;
595
596         mp = malloc(sizeof(*mp));
597         if (mp == NULL)
598                 MPERR(("%s", msg));
599         mp->bn = BN_new();
600         if (mp->bn == NULL)
601                 _bnerr(msg);
602         BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
603         return (mp);
604 }
605
606 MINT *
607 mp_xtom(const char *s)
608 {
609
610         return (_xtom("xtom", s));
611 }