]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/sshkey-xmss.c
ena: Upgrade ena-com to freebsd v2.7.0
[FreeBSD/FreeBSD.git] / crypto / openssh / sshkey-xmss.c
1 /* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27 #ifdef WITH_XMSS
28
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/uio.h>

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#ifdef HAVE_SYS_FILE_H
# include <sys/file.h>
#endif
40
41 #include "ssh2.h"
42 #include "ssherr.h"
43 #include "sshbuf.h"
44 #include "cipher.h"
45 #include "sshkey.h"
46 #include "sshkey-xmss.h"
47 #include "atomicio.h"
48 #include "log.h"
49
50 #include "xmss_fast.h"
51
52 /* opaque internal XMSS state */
53 #define XMSS_MAGIC              "xmss-state-v1"
54 #define XMSS_CIPHERNAME         "aes256-gcm@openssh.com"
55 struct ssh_xmss_state {
56         xmss_params     params;
57         u_int32_t       n, w, h, k;
58
59         bds_state       bds;
60         u_char          *stack;
61         u_int32_t       stackoffset;
62         u_char          *stacklevels;
63         u_char          *auth;
64         u_char          *keep;
65         u_char          *th_nodes;
66         u_char          *retain;
67         treehash_inst   *treehash;
68
69         u_int32_t       idx;            /* state read from file */
70         u_int32_t       maxidx;         /* restricted # of signatures */
71         int             have_state;     /* .state file exists */
72         int             lockfd;         /* locked in sshkey_xmss_get_state() */
73         u_char          allow_update;   /* allow sshkey_xmss_update_state() */
74         char            *enc_ciphername;/* encrypt state with cipher */
75         u_char          *enc_keyiv;     /* encrypt state with key */
76         u_int32_t       enc_keyiv_len;  /* length of enc_keyiv */
77 };
78
/* Forward declarations for helpers used throughout this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *key);
int	 sshkey_xmss_init_enc_key(struct sshkey *key, const char *ciphername);
void	 sshkey_xmss_free_bds(struct sshkey *key);
int	 sshkey_xmss_get_state_from_file(struct sshkey *key,
	    const char *filename, int *have_file, int printerror);
int	 sshkey_xmss_encrypt_state(const struct sshkey *key,
	    struct sshbuf *in, struct sshbuf **outp);
int	 sshkey_xmss_decrypt_state(const struct sshkey *key,
	    struct sshbuf *in, struct sshbuf **outp);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *key,
	    struct sshbuf *b);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *key, struct sshbuf *b);

/*
 * Log at error level, but only if the calling function's "printerror"
 * variable is non-zero.  The caller must have "printerror" in scope.
 */
#define PRINT(...)							\
	do {								\
		if (printerror)						\
			sshlog(__FILE__, __func__, __LINE__, 0,		\
			    SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__);	\
	} while (0)
93
94 int
95 sshkey_xmss_init(struct sshkey *key, const char *name)
96 {
97         struct ssh_xmss_state *state;
98
99         if (key->xmss_state != NULL)
100                 return SSH_ERR_INVALID_FORMAT;
101         if (name == NULL)
102                 return SSH_ERR_INVALID_FORMAT;
103         state = calloc(sizeof(struct ssh_xmss_state), 1);
104         if (state == NULL)
105                 return SSH_ERR_ALLOC_FAIL;
106         if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
107                 state->n = 32;
108                 state->w = 16;
109                 state->h = 10;
110         } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
111                 state->n = 32;
112                 state->w = 16;
113                 state->h = 16;
114         } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
115                 state->n = 32;
116                 state->w = 16;
117                 state->h = 20;
118         } else {
119                 free(state);
120                 return SSH_ERR_KEY_TYPE_UNKNOWN;
121         }
122         if ((key->xmss_name = strdup(name)) == NULL) {
123                 free(state);
124                 return SSH_ERR_ALLOC_FAIL;
125         }
126         state->k = 2;   /* XXX hardcoded */
127         state->lockfd = -1;
128         if (xmss_set_params(&state->params, state->n, state->h, state->w,
129             state->k) != 0) {
130                 free(state);
131                 return SSH_ERR_INVALID_FORMAT;
132         }
133         key->xmss_state = state;
134         return 0;
135 }
136
137 void
138 sshkey_xmss_free_state(struct sshkey *key)
139 {
140         struct ssh_xmss_state *state = key->xmss_state;
141
142         sshkey_xmss_free_bds(key);
143         if (state) {
144                 if (state->enc_keyiv) {
145                         explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
146                         free(state->enc_keyiv);
147                 }
148                 free(state->enc_ciphername);
149                 free(state);
150         }
151         key->xmss_state = NULL;
152 }
153
/*
 * Serialization tag for state produced with k == 2, and the sizes of the
 * BDS buffers (in bytes, or element counts for stacklevels/treehash) derived
 * from a struct ssh_xmss_state pointer "x".  Arguments are parenthesized so
 * any pointer expression may be passed.
 */
#define SSH_XMSS_K2_MAGIC	"k=2"
#define num_stack(x)		(((x)->h + 1) * ((x)->n))
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		(((x)->h >> 1) * ((x)->n))
#define num_th_nodes(x)		(((x)->h - (x)->k) * ((x)->n))
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
162
163 int
164 sshkey_xmss_init_bds_state(struct sshkey *key)
165 {
166         struct ssh_xmss_state *state = key->xmss_state;
167         u_int32_t i;
168
169         state->stackoffset = 0;
170         if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
171             (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
172             (state->auth = calloc(num_auth(state), 1)) == NULL ||
173             (state->keep = calloc(num_keep(state), 1)) == NULL ||
174             (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
175             (state->retain = calloc(num_retain(state), 1)) == NULL ||
176             (state->treehash = calloc(num_treehash(state),
177             sizeof(treehash_inst))) == NULL) {
178                 sshkey_xmss_free_bds(key);
179                 return SSH_ERR_ALLOC_FAIL;
180         }
181         for (i = 0; i < state->h - state->k; i++)
182                 state->treehash[i].node = &state->th_nodes[state->n*i];
183         xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
184             state->stacklevels, state->auth, state->keep, state->treehash,
185             state->retain, 0);
186         return 0;
187 }
188
189 void
190 sshkey_xmss_free_bds(struct sshkey *key)
191 {
192         struct ssh_xmss_state *state = key->xmss_state;
193
194         if (state == NULL)
195                 return;
196         free(state->stack);
197         free(state->stacklevels);
198         free(state->auth);
199         free(state->keep);
200         free(state->th_nodes);
201         free(state->retain);
202         free(state->treehash);
203         state->stack = NULL;
204         state->stacklevels = NULL;
205         state->auth = NULL;
206         state->keep = NULL;
207         state->th_nodes = NULL;
208         state->retain = NULL;
209         state->treehash = NULL;
210 }
211
212 void *
213 sshkey_xmss_params(const struct sshkey *key)
214 {
215         struct ssh_xmss_state *state = key->xmss_state;
216
217         if (state == NULL)
218                 return NULL;
219         return &state->params;
220 }
221
222 void *
223 sshkey_xmss_bds_state(const struct sshkey *key)
224 {
225         struct ssh_xmss_state *state = key->xmss_state;
226
227         if (state == NULL)
228                 return NULL;
229         return &state->bds;
230 }
231
232 int
233 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
234 {
235         struct ssh_xmss_state *state = key->xmss_state;
236
237         if (lenp == NULL)
238                 return SSH_ERR_INVALID_ARGUMENT;
239         if (state == NULL)
240                 return SSH_ERR_INVALID_FORMAT;
241         *lenp = 4 + state->n +
242             state->params.wots_par.keysize +
243             state->h * state->n;
244         return 0;
245 }
246
247 size_t
248 sshkey_xmss_pklen(const struct sshkey *key)
249 {
250         struct ssh_xmss_state *state = key->xmss_state;
251
252         if (state == NULL)
253                 return 0;
254         return state->n * 2;
255 }
256
257 size_t
258 sshkey_xmss_sklen(const struct sshkey *key)
259 {
260         struct ssh_xmss_state *state = key->xmss_state;
261
262         if (state == NULL)
263                 return 0;
264         return state->n * 4 + 4;
265 }
266
267 int
268 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
269 {
270         struct ssh_xmss_state *state = k->xmss_state;
271         const struct sshcipher *cipher;
272         size_t keylen = 0, ivlen = 0;
273
274         if (state == NULL)
275                 return SSH_ERR_INVALID_ARGUMENT;
276         if ((cipher = cipher_by_name(ciphername)) == NULL)
277                 return SSH_ERR_INTERNAL_ERROR;
278         if ((state->enc_ciphername = strdup(ciphername)) == NULL)
279                 return SSH_ERR_ALLOC_FAIL;
280         keylen = cipher_keylen(cipher);
281         ivlen = cipher_ivlen(cipher);
282         state->enc_keyiv_len = keylen + ivlen;
283         if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
284                 free(state->enc_ciphername);
285                 state->enc_ciphername = NULL;
286                 return SSH_ERR_ALLOC_FAIL;
287         }
288         arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
289         return 0;
290 }
291
292 int
293 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
294 {
295         struct ssh_xmss_state *state = k->xmss_state;
296         int r;
297
298         if (state == NULL || state->enc_keyiv == NULL ||
299             state->enc_ciphername == NULL)
300                 return SSH_ERR_INVALID_ARGUMENT;
301         if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
302             (r = sshbuf_put_string(b, state->enc_keyiv,
303             state->enc_keyiv_len)) != 0)
304                 return r;
305         return 0;
306 }
307
308 int
309 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
310 {
311         struct ssh_xmss_state *state = k->xmss_state;
312         size_t len;
313         int r;
314
315         if (state == NULL)
316                 return SSH_ERR_INVALID_ARGUMENT;
317         if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
318             (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
319                 return r;
320         state->enc_keyiv_len = len;
321         return 0;
322 }
323
324 int
325 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
326     enum sshkey_serialize_rep opts)
327 {
328         struct ssh_xmss_state *state = k->xmss_state;
329         u_char have_info = 1;
330         u_int32_t idx;
331         int r;
332
333         if (state == NULL)
334                 return SSH_ERR_INVALID_ARGUMENT;
335         if (opts != SSHKEY_SERIALIZE_INFO)
336                 return 0;
337         idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
338         if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
339             (r = sshbuf_put_u32(b, idx)) != 0 ||
340             (r = sshbuf_put_u32(b, state->maxidx)) != 0)
341                 return r;
342         return 0;
343 }
344
345 int
346 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
347 {
348         struct ssh_xmss_state *state = k->xmss_state;
349         u_char have_info;
350         int r;
351
352         if (state == NULL)
353                 return SSH_ERR_INVALID_ARGUMENT;
354         /* optional */
355         if (sshbuf_len(b) == 0)
356                 return 0;
357         if ((r = sshbuf_get_u8(b, &have_info)) != 0)
358                 return r;
359         if (have_info != 1)
360                 return SSH_ERR_INVALID_ARGUMENT;
361         if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
362             (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
363                 return r;
364         return 0;
365 }
366
367 int
368 sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
369 {
370         int r;
371         const char *name;
372
373         if (bits == 10) {
374                 name = XMSS_SHA2_256_W16_H10_NAME;
375         } else if (bits == 16) {
376                 name = XMSS_SHA2_256_W16_H16_NAME;
377         } else if (bits == 20) {
378                 name = XMSS_SHA2_256_W16_H20_NAME;
379         } else {
380                 name = XMSS_DEFAULT_NAME;
381         }
382         if ((r = sshkey_xmss_init(k, name)) != 0 ||
383             (r = sshkey_xmss_init_bds_state(k)) != 0 ||
384             (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
385                 return r;
386         if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
387             (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
388                 return SSH_ERR_ALLOC_FAIL;
389         }
390         xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
391             sshkey_xmss_params(k));
392         return 0;
393 }
394
395 int
396 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
397     int *have_file, int printerror)
398 {
399         struct sshbuf *b = NULL, *enc = NULL;
400         int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
401         u_int32_t len;
402         unsigned char buf[4], *data = NULL;
403
404         *have_file = 0;
405         if ((fd = open(filename, O_RDONLY)) >= 0) {
406                 *have_file = 1;
407                 if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
408                         PRINT("corrupt state file: %s", filename);
409                         goto done;
410                 }
411                 len = PEEK_U32(buf);
412                 if ((data = calloc(len, 1)) == NULL) {
413                         ret = SSH_ERR_ALLOC_FAIL;
414                         goto done;
415                 }
416                 if (atomicio(read, fd, data, len) != len) {
417                         PRINT("cannot read blob: %s", filename);
418                         goto done;
419                 }
420                 if ((enc = sshbuf_from(data, len)) == NULL) {
421                         ret = SSH_ERR_ALLOC_FAIL;
422                         goto done;
423                 }
424                 sshkey_xmss_free_bds(k);
425                 if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
426                         ret = r;
427                         goto done;
428                 }
429                 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
430                         ret = r;
431                         goto done;
432                 }
433                 ret = 0;
434         }
435 done:
436         if (fd != -1)
437                 close(fd);
438         free(data);
439         sshbuf_free(enc);
440         sshbuf_free(b);
441         return ret;
442 }
443
444 int
445 sshkey_xmss_get_state(const struct sshkey *k, int printerror)
446 {
447         struct ssh_xmss_state *state = k->xmss_state;
448         u_int32_t idx = 0;
449         char *filename = NULL;
450         char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
451         int lockfd = -1, have_state = 0, have_ostate, tries = 0;
452         int ret = SSH_ERR_INVALID_ARGUMENT, r;
453
454         if (state == NULL)
455                 goto done;
456         /*
457          * If maxidx is set, then we are allowed a limited number
458          * of signatures, but don't need to access the disk.
459          * Otherwise we need to deal with the on-disk state.
460          */
461         if (state->maxidx) {
462                 /* xmss_sk always contains the current state */
463                 idx = PEEK_U32(k->xmss_sk);
464                 if (idx < state->maxidx) {
465                         state->allow_update = 1;
466                         return 0;
467                 }
468                 return SSH_ERR_INVALID_ARGUMENT;
469         }
470         if ((filename = k->xmss_filename) == NULL)
471                 goto done;
472         if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
473             asprintf(&statefile, "%s.state", filename) == -1 ||
474             asprintf(&ostatefile, "%s.ostate", filename) == -1) {
475                 ret = SSH_ERR_ALLOC_FAIL;
476                 goto done;
477         }
478         if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
479                 ret = SSH_ERR_SYSTEM_ERROR;
480                 PRINT("cannot open/create: %s", lockfile);
481                 goto done;
482         }
483         while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
484                 if (errno != EWOULDBLOCK) {
485                         ret = SSH_ERR_SYSTEM_ERROR;
486                         PRINT("cannot lock: %s", lockfile);
487                         goto done;
488                 }
489                 if (++tries > 10) {
490                         ret = SSH_ERR_SYSTEM_ERROR;
491                         PRINT("giving up on: %s", lockfile);
492                         goto done;
493                 }
494                 usleep(1000*100*tries);
495         }
496         /* XXX no longer const */
497         if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
498             statefile, &have_state, printerror)) != 0) {
499                 if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
500                     ostatefile, &have_ostate, printerror)) == 0) {
501                         state->allow_update = 1;
502                         r = sshkey_xmss_forward_state(k, 1);
503                         state->idx = PEEK_U32(k->xmss_sk);
504                         state->allow_update = 0;
505                 }
506         }
507         if (!have_state && !have_ostate) {
508                 /* check that bds state is initialized */
509                 if (state->bds.auth == NULL)
510                         goto done;
511                 PRINT("start from scratch idx 0: %u", state->idx);
512         } else if (r != 0) {
513                 ret = r;
514                 goto done;
515         }
516         if (state->idx + 1 < state->idx) {
517                 PRINT("state wrap: %u", state->idx);
518                 goto done;
519         }
520         state->have_state = have_state;
521         state->lockfd = lockfd;
522         state->allow_update = 1;
523         lockfd = -1;
524         ret = 0;
525 done:
526         if (lockfd != -1)
527                 close(lockfd);
528         free(lockfile);
529         free(statefile);
530         free(ostatefile);
531         return ret;
532 }
533
534 int
535 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
536 {
537         struct ssh_xmss_state *state = k->xmss_state;
538         u_char *sig = NULL;
539         size_t required_siglen;
540         unsigned long long smlen;
541         u_char data;
542         int ret, r;
543
544         if (state == NULL || !state->allow_update)
545                 return SSH_ERR_INVALID_ARGUMENT;
546         if (reserve == 0)
547                 return SSH_ERR_INVALID_ARGUMENT;
548         if (state->idx + reserve <= state->idx)
549                 return SSH_ERR_INVALID_ARGUMENT;
550         if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
551                 return r;
552         if ((sig = malloc(required_siglen)) == NULL)
553                 return SSH_ERR_ALLOC_FAIL;
554         while (reserve-- > 0) {
555                 state->idx = PEEK_U32(k->xmss_sk);
556                 smlen = required_siglen;
557                 if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
558                     sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
559                         r = SSH_ERR_INVALID_ARGUMENT;
560                         break;
561                 }
562         }
563         free(sig);
564         return r;
565 }
566
567 int
568 sshkey_xmss_update_state(const struct sshkey *k, int printerror)
569 {
570         struct ssh_xmss_state *state = k->xmss_state;
571         struct sshbuf *b = NULL, *enc = NULL;
572         u_int32_t idx = 0;
573         unsigned char buf[4];
574         char *filename = NULL;
575         char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
576         int fd = -1;
577         int ret = SSH_ERR_INVALID_ARGUMENT;
578
579         if (state == NULL || !state->allow_update)
580                 return ret;
581         if (state->maxidx) {
582                 /* no update since the number of signatures is limited */
583                 ret = 0;
584                 goto done;
585         }
586         idx = PEEK_U32(k->xmss_sk);
587         if (idx == state->idx) {
588                 /* no signature happened, no need to update */
589                 ret = 0;
590                 goto done;
591         } else if (idx != state->idx + 1) {
592                 PRINT("more than one signature happened: idx %u state %u",
593                     idx, state->idx);
594                 goto done;
595         }
596         state->idx = idx;
597         if ((filename = k->xmss_filename) == NULL)
598                 goto done;
599         if (asprintf(&statefile, "%s.state", filename) == -1 ||
600             asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
601             asprintf(&nstatefile, "%s.nstate", filename) == -1) {
602                 ret = SSH_ERR_ALLOC_FAIL;
603                 goto done;
604         }
605         unlink(nstatefile);
606         if ((b = sshbuf_new()) == NULL) {
607                 ret = SSH_ERR_ALLOC_FAIL;
608                 goto done;
609         }
610         if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
611                 PRINT("SERLIALIZE FAILED: %d", ret);
612                 goto done;
613         }
614         if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
615                 PRINT("ENCRYPT FAILED: %d", ret);
616                 goto done;
617         }
618         if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
619                 ret = SSH_ERR_SYSTEM_ERROR;
620                 PRINT("open new state file: %s", nstatefile);
621                 goto done;
622         }
623         POKE_U32(buf, sshbuf_len(enc));
624         if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
625                 ret = SSH_ERR_SYSTEM_ERROR;
626                 PRINT("write new state file hdr: %s", nstatefile);
627                 close(fd);
628                 goto done;
629         }
630         if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
631             sshbuf_len(enc)) {
632                 ret = SSH_ERR_SYSTEM_ERROR;
633                 PRINT("write new state file data: %s", nstatefile);
634                 close(fd);
635                 goto done;
636         }
637         if (fsync(fd) == -1) {
638                 ret = SSH_ERR_SYSTEM_ERROR;
639                 PRINT("sync new state file: %s", nstatefile);
640                 close(fd);
641                 goto done;
642         }
643         if (close(fd) == -1) {
644                 ret = SSH_ERR_SYSTEM_ERROR;
645                 PRINT("close new state file: %s", nstatefile);
646                 goto done;
647         }
648         if (state->have_state) {
649                 unlink(ostatefile);
650                 if (link(statefile, ostatefile)) {
651                         ret = SSH_ERR_SYSTEM_ERROR;
652                         PRINT("backup state %s to %s", statefile, ostatefile);
653                         goto done;
654                 }
655         }
656         if (rename(nstatefile, statefile) == -1) {
657                 ret = SSH_ERR_SYSTEM_ERROR;
658                 PRINT("rename %s to %s", nstatefile, statefile);
659                 goto done;
660         }
661         ret = 0;
662 done:
663         if (state->lockfd != -1) {
664                 close(state->lockfd);
665                 state->lockfd = -1;
666         }
667         if (nstatefile)
668                 unlink(nstatefile);
669         free(statefile);
670         free(ostatefile);
671         free(nstatefile);
672         sshbuf_free(b);
673         sshbuf_free(enc);
674         return ret;
675 }
676
677 int
678 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
679 {
680         struct ssh_xmss_state *state = k->xmss_state;
681         treehash_inst *th;
682         u_int32_t i, node;
683         int r;
684
685         if (state == NULL)
686                 return SSH_ERR_INVALID_ARGUMENT;
687         if (state->stack == NULL)
688                 return SSH_ERR_INVALID_ARGUMENT;
689         state->stackoffset = state->bds.stackoffset;    /* copy back */
690         if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
691             (r = sshbuf_put_u32(b, state->idx)) != 0 ||
692             (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
693             (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
694             (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
695             (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
696             (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
697             (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
698             (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
699             (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
700                 return r;
701         for (i = 0; i < num_treehash(state); i++) {
702                 th = &state->treehash[i];
703                 node = th->node - state->th_nodes;
704                 if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
705                     (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
706                     (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
707                     (r = sshbuf_put_u8(b, th->completed)) != 0 ||
708                     (r = sshbuf_put_u32(b, node)) != 0)
709                         return r;
710         }
711         return 0;
712 }
713
714 int
715 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
716     enum sshkey_serialize_rep opts)
717 {
718         struct ssh_xmss_state *state = k->xmss_state;
719         int r = SSH_ERR_INVALID_ARGUMENT;
720         u_char have_stack, have_filename, have_enc;
721
722         if (state == NULL)
723                 return SSH_ERR_INVALID_ARGUMENT;
724         if ((r = sshbuf_put_u8(b, opts)) != 0)
725                 return r;
726         switch (opts) {
727         case SSHKEY_SERIALIZE_STATE:
728                 r = sshkey_xmss_serialize_state(k, b);
729                 break;
730         case SSHKEY_SERIALIZE_FULL:
731                 if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
732                         return r;
733                 r = sshkey_xmss_serialize_state(k, b);
734                 break;
735         case SSHKEY_SERIALIZE_SHIELD:
736                 /* all of stack/filename/enc are optional */
737                 have_stack = state->stack != NULL;
738                 if ((r = sshbuf_put_u8(b, have_stack)) != 0)
739                         return r;
740                 if (have_stack) {
741                         state->idx = PEEK_U32(k->xmss_sk);      /* update */
742                         if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
743                                 return r;
744                 }
745                 have_filename = k->xmss_filename != NULL;
746                 if ((r = sshbuf_put_u8(b, have_filename)) != 0)
747                         return r;
748                 if (have_filename &&
749                     (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
750                         return r;
751                 have_enc = state->enc_keyiv != NULL;
752                 if ((r = sshbuf_put_u8(b, have_enc)) != 0)
753                         return r;
754                 if (have_enc &&
755                     (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
756                         return r;
757                 if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
758                     (r = sshbuf_put_u8(b, state->allow_update)) != 0)
759                         return r;
760                 break;
761         case SSHKEY_SERIALIZE_DEFAULT:
762                 r = 0;
763                 break;
764         default:
765                 r = SSH_ERR_INVALID_ARGUMENT;
766                 break;
767         }
768         return r;
769 }
770
771 int
772 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
773 {
774         struct ssh_xmss_state *state = k->xmss_state;
775         treehash_inst *th;
776         u_int32_t i, lh, node;
777         size_t ls, lsl, la, lk, ln, lr;
778         char *magic;
779         int r = SSH_ERR_INTERNAL_ERROR;
780
781         if (state == NULL)
782                 return SSH_ERR_INVALID_ARGUMENT;
783         if (k->xmss_sk == NULL)
784                 return SSH_ERR_INVALID_ARGUMENT;
785         if ((state->treehash = calloc(num_treehash(state),
786             sizeof(treehash_inst))) == NULL)
787                 return SSH_ERR_ALLOC_FAIL;
788         if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
789             (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
790             (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
791             (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
792             (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
793             (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
794             (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
795             (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
796             (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
797             (r = sshbuf_get_u32(b, &lh)) != 0)
798                 goto out;
799         if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
800                 r = SSH_ERR_INVALID_ARGUMENT;
801                 goto out;
802         }
803         /* XXX check stackoffset */
804         if (ls != num_stack(state) ||
805             lsl != num_stacklevels(state) ||
806             la != num_auth(state) ||
807             lk != num_keep(state) ||
808             ln != num_th_nodes(state) ||
809             lr != num_retain(state) ||
810             lh != num_treehash(state)) {
811                 r = SSH_ERR_INVALID_ARGUMENT;
812                 goto out;
813         }
814         for (i = 0; i < num_treehash(state); i++) {
815                 th = &state->treehash[i];
816                 if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
817                     (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
818                     (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
819                     (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
820                     (r = sshbuf_get_u32(b, &node)) != 0)
821                         goto out;
822                 if (node < num_th_nodes(state))
823                         th->node = &state->th_nodes[node];
824         }
825         POKE_U32(k->xmss_sk, state->idx);
826         xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
827             state->stacklevels, state->auth, state->keep, state->treehash,
828             state->retain, 0);
829         /* success */
830         r = 0;
831  out:
832         free(magic);
833         return r;
834 }
835
836 int
837 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
838 {
839         struct ssh_xmss_state *state = k->xmss_state;
840         enum sshkey_serialize_rep opts;
841         u_char have_state, have_stack, have_filename, have_enc;
842         int r;
843
844         if ((r = sshbuf_get_u8(b, &have_state)) != 0)
845                 return r;
846
847         opts = have_state;
848         switch (opts) {
849         case SSHKEY_SERIALIZE_DEFAULT:
850                 r = 0;
851                 break;
852         case SSHKEY_SERIALIZE_SHIELD:
853                 if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
854                         return r;
855                 if (have_stack &&
856                     (r = sshkey_xmss_deserialize_state(k, b)) != 0)
857                         return r;
858                 if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
859                         return r;
860                 if (have_filename &&
861                     (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
862                         return r;
863                 if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
864                         return r;
865                 if (have_enc &&
866                     (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
867                         return r;
868                 if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
869                     (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
870                         return r;
871                 break;
872         case SSHKEY_SERIALIZE_STATE:
873                 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
874                         return r;
875                 break;
876         case SSHKEY_SERIALIZE_FULL:
877                 if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
878                     (r = sshkey_xmss_deserialize_state(k, b)) != 0)
879                         return r;
880                 break;
881         default:
882                 r = SSH_ERR_INVALID_FORMAT;
883                 break;
884         }
885         return r;
886 }
887
888 int
889 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
890    struct sshbuf **retp)
891 {
892         struct ssh_xmss_state *state = k->xmss_state;
893         struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
894         struct sshcipher_ctx *ciphercontext = NULL;
895         const struct sshcipher *cipher;
896         u_char *cp, *key, *iv = NULL;
897         size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
898         int r = SSH_ERR_INTERNAL_ERROR;
899
900         if (retp != NULL)
901                 *retp = NULL;
902         if (state == NULL ||
903             state->enc_keyiv == NULL ||
904             state->enc_ciphername == NULL)
905                 return SSH_ERR_INTERNAL_ERROR;
906         if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
907                 r = SSH_ERR_INTERNAL_ERROR;
908                 goto out;
909         }
910         blocksize = cipher_blocksize(cipher);
911         keylen = cipher_keylen(cipher);
912         ivlen = cipher_ivlen(cipher);
913         authlen = cipher_authlen(cipher);
914         if (state->enc_keyiv_len != keylen + ivlen) {
915                 r = SSH_ERR_INVALID_FORMAT;
916                 goto out;
917         }
918         key = state->enc_keyiv;
919         if ((encrypted = sshbuf_new()) == NULL ||
920             (encoded = sshbuf_new()) == NULL ||
921             (padded = sshbuf_new()) == NULL ||
922             (iv = malloc(ivlen)) == NULL) {
923                 r = SSH_ERR_ALLOC_FAIL;
924                 goto out;
925         }
926
927         /* replace first 4 bytes of IV with index to ensure uniqueness */
928         memcpy(iv, key + keylen, ivlen);
929         POKE_U32(iv, state->idx);
930
931         if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
932             (r = sshbuf_put_u32(encoded, state->idx)) != 0)
933                 goto out;
934
935         /* padded state will be encrypted */
936         if ((r = sshbuf_putb(padded, b)) != 0)
937                 goto out;
938         i = 0;
939         while (sshbuf_len(padded) % blocksize) {
940                 if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
941                         goto out;
942         }
943         encrypted_len = sshbuf_len(padded);
944
945         /* header including the length of state is used as AAD */
946         if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
947                 goto out;
948         aadlen = sshbuf_len(encoded);
949
950         /* concat header and state */
951         if ((r = sshbuf_putb(encoded, padded)) != 0)
952                 goto out;
953
954         /* reserve space for encryption of encoded data plus auth tag */
955         /* encrypt at offset addlen */
956         if ((r = sshbuf_reserve(encrypted,
957             encrypted_len + aadlen + authlen, &cp)) != 0 ||
958             (r = cipher_init(&ciphercontext, cipher, key, keylen,
959             iv, ivlen, 1)) != 0 ||
960             (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
961             encrypted_len, aadlen, authlen)) != 0)
962                 goto out;
963
964         /* success */
965         r = 0;
966  out:
967         if (retp != NULL) {
968                 *retp = encrypted;
969                 encrypted = NULL;
970         }
971         sshbuf_free(padded);
972         sshbuf_free(encoded);
973         sshbuf_free(encrypted);
974         cipher_free(ciphercontext);
975         free(iv);
976         return r;
977 }
978
979 int
980 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
981    struct sshbuf **retp)
982 {
983         struct ssh_xmss_state *state = k->xmss_state;
984         struct sshbuf *copy = NULL, *decrypted = NULL;
985         struct sshcipher_ctx *ciphercontext = NULL;
986         const struct sshcipher *cipher = NULL;
987         u_char *key, *iv = NULL, *dp;
988         size_t keylen, ivlen, authlen, aadlen;
989         u_int blocksize, encrypted_len, index;
990         int r = SSH_ERR_INTERNAL_ERROR;
991
992         if (retp != NULL)
993                 *retp = NULL;
994         if (state == NULL ||
995             state->enc_keyiv == NULL ||
996             state->enc_ciphername == NULL)
997                 return SSH_ERR_INTERNAL_ERROR;
998         if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
999                 r = SSH_ERR_INVALID_FORMAT;
1000                 goto out;
1001         }
1002         blocksize = cipher_blocksize(cipher);
1003         keylen = cipher_keylen(cipher);
1004         ivlen = cipher_ivlen(cipher);
1005         authlen = cipher_authlen(cipher);
1006         if (state->enc_keyiv_len != keylen + ivlen) {
1007                 r = SSH_ERR_INTERNAL_ERROR;
1008                 goto out;
1009         }
1010         key = state->enc_keyiv;
1011
1012         if ((copy = sshbuf_fromb(encoded)) == NULL ||
1013             (decrypted = sshbuf_new()) == NULL ||
1014             (iv = malloc(ivlen)) == NULL) {
1015                 r = SSH_ERR_ALLOC_FAIL;
1016                 goto out;
1017         }
1018
1019         /* check magic */
1020         if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1021             memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1022                 r = SSH_ERR_INVALID_FORMAT;
1023                 goto out;
1024         }
1025         /* parse public portion */
1026         if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1027             (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1028             (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1029                 goto out;
1030
1031         /* check size of encrypted key blob */
1032         if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1033                 r = SSH_ERR_INVALID_FORMAT;
1034                 goto out;
1035         }
1036         /* check that an appropriate amount of auth data is present */
1037         if (sshbuf_len(encoded) < authlen ||
1038             sshbuf_len(encoded) - authlen < encrypted_len) {
1039                 r = SSH_ERR_INVALID_FORMAT;
1040                 goto out;
1041         }
1042
1043         aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1044
1045         /* replace first 4 bytes of IV with index to ensure uniqueness */
1046         memcpy(iv, key + keylen, ivlen);
1047         POKE_U32(iv, index);
1048
1049         /* decrypt private state of key */
1050         if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1051             (r = cipher_init(&ciphercontext, cipher, key, keylen,
1052             iv, ivlen, 0)) != 0 ||
1053             (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1054             encrypted_len, aadlen, authlen)) != 0)
1055                 goto out;
1056
1057         /* there should be no trailing data */
1058         if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1059                 goto out;
1060         if (sshbuf_len(encoded) != 0) {
1061                 r = SSH_ERR_INVALID_FORMAT;
1062                 goto out;
1063         }
1064
1065         /* remove AAD */
1066         if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1067                 goto out;
1068         /* XXX encrypted includes unchecked padding */
1069
1070         /* success */
1071         r = 0;
1072         if (retp != NULL) {
1073                 *retp = decrypted;
1074                 decrypted = NULL;
1075         }
1076  out:
1077         cipher_free(ciphercontext);
1078         sshbuf_free(copy);
1079         sshbuf_free(decrypted);
1080         free(iv);
1081         return r;
1082 }
1083
1084 u_int32_t
1085 sshkey_xmss_signatures_left(const struct sshkey *k)
1086 {
1087         struct ssh_xmss_state *state = k->xmss_state;
1088         u_int32_t idx;
1089
1090         if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1091             state->maxidx) {
1092                 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1093                 if (idx < state->maxidx)
1094                         return state->maxidx - idx;
1095         }
1096         return 0;
1097 }
1098
1099 int
1100 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1101 {
1102         struct ssh_xmss_state *state = k->xmss_state;
1103
1104         if (sshkey_type_plain(k->type) != KEY_XMSS)
1105                 return SSH_ERR_INVALID_ARGUMENT;
1106         if (maxsign == 0)
1107                 return 0;
1108         if (state->idx + maxsign < state->idx)
1109                 return SSH_ERR_INVALID_ARGUMENT;
1110         state->maxidx = state->idx + maxsign;
1111         return 0;
1112 }
1113 #endif /* WITH_XMSS */