]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/sshkey-xmss.c
Optionally bind ktls threads to NUMA domains
[FreeBSD/FreeBSD.git] / crypto / openssh / sshkey-xmss.c
1 /* $OpenBSD: sshkey-xmss.c,v 1.3 2018/07/09 21:59:10 markus Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27 #ifdef WITH_XMSS
28
29 #include <sys/types.h>
30 #include <sys/uio.h>
31
32 #include <stdio.h>
33 #include <string.h>
34 #include <unistd.h>
35 #include <fcntl.h>
36 #include <errno.h>
37 #ifdef HAVE_SYS_FILE_H
38 # include <sys/file.h>
39 #endif
40
41 #include "ssh2.h"
42 #include "ssherr.h"
43 #include "sshbuf.h"
44 #include "cipher.h"
45 #include "sshkey.h"
46 #include "sshkey-xmss.h"
47 #include "atomicio.h"
48
49 #include "xmss_fast.h"
50
51 /* opaque internal XMSS state */
52 #define XMSS_MAGIC              "xmss-state-v1"
53 #define XMSS_CIPHERNAME         "aes256-gcm@openssh.com"
54 struct ssh_xmss_state {
55         xmss_params     params;
56         u_int32_t       n, w, h, k;
57
58         bds_state       bds;
59         u_char          *stack;
60         u_int32_t       stackoffset;
61         u_char          *stacklevels;
62         u_char          *auth;
63         u_char          *keep;
64         u_char          *th_nodes;
65         u_char          *retain;
66         treehash_inst   *treehash;
67
68         u_int32_t       idx;            /* state read from file */
69         u_int32_t       maxidx;         /* restricted # of signatures */
70         int             have_state;     /* .state file exists */
71         int             lockfd;         /* locked in sshkey_xmss_get_state() */
72         int             allow_update;   /* allow sshkey_xmss_update_state() */
73         char            *enc_ciphername;/* encrypt state with cipher */
74         u_char          *enc_keyiv;     /* encrypt state with key */
75         u_int32_t       enc_keyiv_len;  /* length of enc_keyiv */
76 };
77
78 int      sshkey_xmss_init_bds_state(struct sshkey *);
79 int      sshkey_xmss_init_enc_key(struct sshkey *, const char *);
80 void     sshkey_xmss_free_bds(struct sshkey *);
81 int      sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
82             int *, sshkey_printfn *);
83 int      sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
84             struct sshbuf **);
85 int      sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
86             struct sshbuf **);
87 int      sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
88 int      sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
89
/*
 * Hand a printf-style message to the callback named "pr" in the
 * enclosing scope; does nothing when "pr" is NULL.
 */
#define PRINT(...) do { if (pr) pr(__VA_ARGS__); } while (0)
91
92 int
93 sshkey_xmss_init(struct sshkey *key, const char *name)
94 {
95         struct ssh_xmss_state *state;
96
97         if (key->xmss_state != NULL)
98                 return SSH_ERR_INVALID_FORMAT;
99         if (name == NULL)
100                 return SSH_ERR_INVALID_FORMAT;
101         state = calloc(sizeof(struct ssh_xmss_state), 1);
102         if (state == NULL)
103                 return SSH_ERR_ALLOC_FAIL;
104         if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
105                 state->n = 32;
106                 state->w = 16;
107                 state->h = 10;
108         } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
109                 state->n = 32;
110                 state->w = 16;
111                 state->h = 16;
112         } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
113                 state->n = 32;
114                 state->w = 16;
115                 state->h = 20;
116         } else {
117                 free(state);
118                 return SSH_ERR_KEY_TYPE_UNKNOWN;
119         }
120         if ((key->xmss_name = strdup(name)) == NULL) {
121                 free(state);
122                 return SSH_ERR_ALLOC_FAIL;
123         }
124         state->k = 2;   /* XXX hardcoded */
125         state->lockfd = -1;
126         if (xmss_set_params(&state->params, state->n, state->h, state->w,
127             state->k) != 0) {
128                 free(state);
129                 return SSH_ERR_INVALID_FORMAT;
130         }
131         key->xmss_state = state;
132         return 0;
133 }
134
135 void
136 sshkey_xmss_free_state(struct sshkey *key)
137 {
138         struct ssh_xmss_state *state = key->xmss_state;
139
140         sshkey_xmss_free_bds(key);
141         if (state) {
142                 if (state->enc_keyiv) {
143                         explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
144                         free(state->enc_keyiv);
145                 }
146                 free(state->enc_ciphername);
147                 free(state);
148         }
149         key->xmss_state = NULL;
150 }
151
/* Magic string written at the start of a serialized state. */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Element counts of the BDS buffers for a state pointer "x", derived
 * from its h, n and k members.  The argument is parenthesized so any
 * pointer-valued expression (e.g. &st) can be passed.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
160
161 int
162 sshkey_xmss_init_bds_state(struct sshkey *key)
163 {
164         struct ssh_xmss_state *state = key->xmss_state;
165         u_int32_t i;
166
167         state->stackoffset = 0;
168         if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
169             (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
170             (state->auth = calloc(num_auth(state), 1)) == NULL ||
171             (state->keep = calloc(num_keep(state), 1)) == NULL ||
172             (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
173             (state->retain = calloc(num_retain(state), 1)) == NULL ||
174             (state->treehash = calloc(num_treehash(state),
175             sizeof(treehash_inst))) == NULL) {
176                 sshkey_xmss_free_bds(key);
177                 return SSH_ERR_ALLOC_FAIL;
178         }
179         for (i = 0; i < state->h - state->k; i++)
180                 state->treehash[i].node = &state->th_nodes[state->n*i];
181         xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
182             state->stacklevels, state->auth, state->keep, state->treehash,
183             state->retain, 0);
184         return 0;
185 }
186
187 void
188 sshkey_xmss_free_bds(struct sshkey *key)
189 {
190         struct ssh_xmss_state *state = key->xmss_state;
191
192         if (state == NULL)
193                 return;
194         free(state->stack);
195         free(state->stacklevels);
196         free(state->auth);
197         free(state->keep);
198         free(state->th_nodes);
199         free(state->retain);
200         free(state->treehash);
201         state->stack = NULL;
202         state->stacklevels = NULL;
203         state->auth = NULL;
204         state->keep = NULL;
205         state->th_nodes = NULL;
206         state->retain = NULL;
207         state->treehash = NULL;
208 }
209
210 void *
211 sshkey_xmss_params(const struct sshkey *key)
212 {
213         struct ssh_xmss_state *state = key->xmss_state;
214
215         if (state == NULL)
216                 return NULL;
217         return &state->params;
218 }
219
220 void *
221 sshkey_xmss_bds_state(const struct sshkey *key)
222 {
223         struct ssh_xmss_state *state = key->xmss_state;
224
225         if (state == NULL)
226                 return NULL;
227         return &state->bds;
228 }
229
230 int
231 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
232 {
233         struct ssh_xmss_state *state = key->xmss_state;
234
235         if (lenp == NULL)
236                 return SSH_ERR_INVALID_ARGUMENT;
237         if (state == NULL)
238                 return SSH_ERR_INVALID_FORMAT;
239         *lenp = 4 + state->n +
240             state->params.wots_par.keysize +
241             state->h * state->n;
242         return 0;
243 }
244
245 size_t
246 sshkey_xmss_pklen(const struct sshkey *key)
247 {
248         struct ssh_xmss_state *state = key->xmss_state;
249
250         if (state == NULL)
251                 return 0;
252         return state->n * 2;
253 }
254
255 size_t
256 sshkey_xmss_sklen(const struct sshkey *key)
257 {
258         struct ssh_xmss_state *state = key->xmss_state;
259
260         if (state == NULL)
261                 return 0;
262         return state->n * 4 + 4;
263 }
264
265 int
266 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
267 {
268         struct ssh_xmss_state *state = k->xmss_state;
269         const struct sshcipher *cipher;
270         size_t keylen = 0, ivlen = 0;
271
272         if (state == NULL)
273                 return SSH_ERR_INVALID_ARGUMENT;
274         if ((cipher = cipher_by_name(ciphername)) == NULL)
275                 return SSH_ERR_INTERNAL_ERROR;
276         if ((state->enc_ciphername = strdup(ciphername)) == NULL)
277                 return SSH_ERR_ALLOC_FAIL;
278         keylen = cipher_keylen(cipher);
279         ivlen = cipher_ivlen(cipher);
280         state->enc_keyiv_len = keylen + ivlen;
281         if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
282                 free(state->enc_ciphername);
283                 state->enc_ciphername = NULL;
284                 return SSH_ERR_ALLOC_FAIL;
285         }
286         arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
287         return 0;
288 }
289
290 int
291 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
292 {
293         struct ssh_xmss_state *state = k->xmss_state;
294         int r;
295
296         if (state == NULL || state->enc_keyiv == NULL ||
297             state->enc_ciphername == NULL)
298                 return SSH_ERR_INVALID_ARGUMENT;
299         if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
300             (r = sshbuf_put_string(b, state->enc_keyiv,
301             state->enc_keyiv_len)) != 0)
302                 return r;
303         return 0;
304 }
305
306 int
307 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
308 {
309         struct ssh_xmss_state *state = k->xmss_state;
310         size_t len;
311         int r;
312
313         if (state == NULL)
314                 return SSH_ERR_INVALID_ARGUMENT;
315         if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
316             (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
317                 return r;
318         state->enc_keyiv_len = len;
319         return 0;
320 }
321
322 int
323 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
324     enum sshkey_serialize_rep opts)
325 {
326         struct ssh_xmss_state *state = k->xmss_state;
327         u_char have_info = 1;
328         u_int32_t idx;
329         int r;
330
331         if (state == NULL)
332                 return SSH_ERR_INVALID_ARGUMENT;
333         if (opts != SSHKEY_SERIALIZE_INFO)
334                 return 0;
335         idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
336         if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
337             (r = sshbuf_put_u32(b, idx)) != 0 ||
338             (r = sshbuf_put_u32(b, state->maxidx)) != 0)
339                 return r;
340         return 0;
341 }
342
343 int
344 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
345 {
346         struct ssh_xmss_state *state = k->xmss_state;
347         u_char have_info;
348         int r;
349
350         if (state == NULL)
351                 return SSH_ERR_INVALID_ARGUMENT;
352         /* optional */
353         if (sshbuf_len(b) == 0)
354                 return 0;
355         if ((r = sshbuf_get_u8(b, &have_info)) != 0)
356                 return r;
357         if (have_info != 1)
358                 return SSH_ERR_INVALID_ARGUMENT;
359         if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
360             (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
361                 return r;
362         return 0;
363 }
364
365 int
366 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
367 {
368         int r;
369         const char *name;
370
371         if (bits == 10) {
372                 name = XMSS_SHA2_256_W16_H10_NAME;
373         } else if (bits == 16) {
374                 name = XMSS_SHA2_256_W16_H16_NAME;
375         } else if (bits == 20) {
376                 name = XMSS_SHA2_256_W16_H20_NAME;
377         } else {
378                 name = XMSS_DEFAULT_NAME;
379         }
380         if ((r = sshkey_xmss_init(k, name)) != 0 ||
381             (r = sshkey_xmss_init_bds_state(k)) != 0 ||
382             (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
383                 return r;
384         if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
385             (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
386                 return SSH_ERR_ALLOC_FAIL;
387         }
388         xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
389             sshkey_xmss_params(k));
390         return 0;
391 }
392
393 int
394 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
395     int *have_file, sshkey_printfn *pr)
396 {
397         struct sshbuf *b = NULL, *enc = NULL;
398         int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
399         u_int32_t len;
400         unsigned char buf[4], *data = NULL;
401
402         *have_file = 0;
403         if ((fd = open(filename, O_RDONLY)) >= 0) {
404                 *have_file = 1;
405                 if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
406                         PRINT("%s: corrupt state file: %s", __func__, filename);
407                         goto done;
408                 }
409                 len = PEEK_U32(buf);
410                 if ((data = calloc(len, 1)) == NULL) {
411                         ret = SSH_ERR_ALLOC_FAIL;
412                         goto done;
413                 }
414                 if (atomicio(read, fd, data, len) != len) {
415                         PRINT("%s: cannot read blob: %s", __func__, filename);
416                         goto done;
417                 }
418                 if ((enc = sshbuf_from(data, len)) == NULL) {
419                         ret = SSH_ERR_ALLOC_FAIL;
420                         goto done;
421                 }
422                 sshkey_xmss_free_bds(k);
423                 if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
424                         ret = r;
425                         goto done;
426                 }
427                 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
428                         ret = r;
429                         goto done;
430                 }
431                 ret = 0;
432         }
433 done:
434         if (fd != -1)
435                 close(fd);
436         free(data);
437         sshbuf_free(enc);
438         sshbuf_free(b);
439         return ret;
440 }
441
442 int
443 sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
444 {
445         struct ssh_xmss_state *state = k->xmss_state;
446         u_int32_t idx = 0;
447         char *filename = NULL;
448         char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
449         int lockfd = -1, have_state = 0, have_ostate, tries = 0;
450         int ret = SSH_ERR_INVALID_ARGUMENT, r;
451
452         if (state == NULL)
453                 goto done;
454         /*
455          * If maxidx is set, then we are allowed a limited number
456          * of signatures, but don't need to access the disk.
457          * Otherwise we need to deal with the on-disk state.
458          */
459         if (state->maxidx) {
460                 /* xmss_sk always contains the current state */
461                 idx = PEEK_U32(k->xmss_sk);
462                 if (idx < state->maxidx) {
463                         state->allow_update = 1;
464                         return 0;
465                 }
466                 return SSH_ERR_INVALID_ARGUMENT;
467         }
468         if ((filename = k->xmss_filename) == NULL)
469                 goto done;
470         if (asprintf(&lockfile, "%s.lock", filename) < 0 ||
471             asprintf(&statefile, "%s.state", filename) < 0 ||
472             asprintf(&ostatefile, "%s.ostate", filename) < 0) {
473                 ret = SSH_ERR_ALLOC_FAIL;
474                 goto done;
475         }
476         if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) < 0) {
477                 ret = SSH_ERR_SYSTEM_ERROR;
478                 PRINT("%s: cannot open/create: %s", __func__, lockfile);
479                 goto done;
480         }
481         while (flock(lockfd, LOCK_EX|LOCK_NB) < 0) {
482                 if (errno != EWOULDBLOCK) {
483                         ret = SSH_ERR_SYSTEM_ERROR;
484                         PRINT("%s: cannot lock: %s", __func__, lockfile);
485                         goto done;
486                 }
487                 if (++tries > 10) {
488                         ret = SSH_ERR_SYSTEM_ERROR;
489                         PRINT("%s: giving up on: %s", __func__, lockfile);
490                         goto done;
491                 }
492                 usleep(1000*100*tries);
493         }
494         /* XXX no longer const */
495         if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
496             statefile, &have_state, pr)) != 0) {
497                 if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
498                     ostatefile, &have_ostate, pr)) == 0) {
499                         state->allow_update = 1;
500                         r = sshkey_xmss_forward_state(k, 1);
501                         state->idx = PEEK_U32(k->xmss_sk);
502                         state->allow_update = 0;
503                 }
504         }
505         if (!have_state && !have_ostate) {
506                 /* check that bds state is initialized */
507                 if (state->bds.auth == NULL)
508                         goto done;
509                 PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
510         } else if (r != 0) {
511                 ret = r;
512                 goto done;
513         }
514         if (state->idx + 1 < state->idx) {
515                 PRINT("%s: state wrap: %u", __func__, state->idx);
516                 goto done;
517         }
518         state->have_state = have_state;
519         state->lockfd = lockfd;
520         state->allow_update = 1;
521         lockfd = -1;
522         ret = 0;
523 done:
524         if (lockfd != -1)
525                 close(lockfd);
526         free(lockfile);
527         free(statefile);
528         free(ostatefile);
529         return ret;
530 }
531
532 int
533 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
534 {
535         struct ssh_xmss_state *state = k->xmss_state;
536         u_char *sig = NULL;
537         size_t required_siglen;
538         unsigned long long smlen;
539         u_char data;
540         int ret, r;
541
542         if (state == NULL || !state->allow_update)
543                 return SSH_ERR_INVALID_ARGUMENT;
544         if (reserve == 0)
545                 return SSH_ERR_INVALID_ARGUMENT;
546         if (state->idx + reserve <= state->idx)
547                 return SSH_ERR_INVALID_ARGUMENT;
548         if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
549                 return r;
550         if ((sig = malloc(required_siglen)) == NULL)
551                 return SSH_ERR_ALLOC_FAIL;
552         while (reserve-- > 0) {
553                 state->idx = PEEK_U32(k->xmss_sk);
554                 smlen = required_siglen;
555                 if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
556                     sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
557                         r = SSH_ERR_INVALID_ARGUMENT;
558                         break;
559                 }
560         }
561         free(sig);
562         return r;
563 }
564
565 int
566 sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
567 {
568         struct ssh_xmss_state *state = k->xmss_state;
569         struct sshbuf *b = NULL, *enc = NULL;
570         u_int32_t idx = 0;
571         unsigned char buf[4];
572         char *filename = NULL;
573         char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
574         int fd = -1;
575         int ret = SSH_ERR_INVALID_ARGUMENT;
576
577         if (state == NULL || !state->allow_update)
578                 return ret;
579         if (state->maxidx) {
580                 /* no update since the number of signatures is limited */
581                 ret = 0;
582                 goto done;
583         }
584         idx = PEEK_U32(k->xmss_sk);
585         if (idx == state->idx) {
586                 /* no signature happened, no need to update */
587                 ret = 0;
588                 goto done;
589         } else if (idx != state->idx + 1) {
590                 PRINT("%s: more than one signature happened: idx %u state %u",
591                      __func__, idx, state->idx);
592                 goto done;
593         }
594         state->idx = idx;
595         if ((filename = k->xmss_filename) == NULL)
596                 goto done;
597         if (asprintf(&statefile, "%s.state", filename) < 0 ||
598             asprintf(&ostatefile, "%s.ostate", filename) < 0 ||
599             asprintf(&nstatefile, "%s.nstate", filename) < 0) {
600                 ret = SSH_ERR_ALLOC_FAIL;
601                 goto done;
602         }
603         unlink(nstatefile);
604         if ((b = sshbuf_new()) == NULL) {
605                 ret = SSH_ERR_ALLOC_FAIL;
606                 goto done;
607         }
608         if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
609                 PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
610                 goto done;
611         }
612         if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
613                 PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
614                 goto done;
615         }
616         if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) < 0) {
617                 ret = SSH_ERR_SYSTEM_ERROR;
618                 PRINT("%s: open new state file: %s", __func__, nstatefile);
619                 goto done;
620         }
621         POKE_U32(buf, sshbuf_len(enc));
622         if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
623                 ret = SSH_ERR_SYSTEM_ERROR;
624                 PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
625                 close(fd);
626                 goto done;
627         }
628         if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
629             sshbuf_len(enc)) {
630                 ret = SSH_ERR_SYSTEM_ERROR;
631                 PRINT("%s: write new state file data: %s", __func__, nstatefile);
632                 close(fd);
633                 goto done;
634         }
635         if (fsync(fd) < 0) {
636                 ret = SSH_ERR_SYSTEM_ERROR;
637                 PRINT("%s: sync new state file: %s", __func__, nstatefile);
638                 close(fd);
639                 goto done;
640         }
641         if (close(fd) < 0) {
642                 ret = SSH_ERR_SYSTEM_ERROR;
643                 PRINT("%s: close new state file: %s", __func__, nstatefile);
644                 goto done;
645         }
646         if (state->have_state) {
647                 unlink(ostatefile);
648                 if (link(statefile, ostatefile)) {
649                         ret = SSH_ERR_SYSTEM_ERROR;
650                         PRINT("%s: backup state %s to %s", __func__, statefile,
651                             ostatefile);
652                         goto done;
653                 }
654         }
655         if (rename(nstatefile, statefile) < 0) {
656                 ret = SSH_ERR_SYSTEM_ERROR;
657                 PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
658                 goto done;
659         }
660         ret = 0;
661 done:
662         if (state->lockfd != -1) {
663                 close(state->lockfd);
664                 state->lockfd = -1;
665         }
666         if (nstatefile)
667                 unlink(nstatefile);
668         free(statefile);
669         free(ostatefile);
670         free(nstatefile);
671         sshbuf_free(b);
672         sshbuf_free(enc);
673         return ret;
674 }
675
676 int
677 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
678 {
679         struct ssh_xmss_state *state = k->xmss_state;
680         treehash_inst *th;
681         u_int32_t i, node;
682         int r;
683
684         if (state == NULL)
685                 return SSH_ERR_INVALID_ARGUMENT;
686         if (state->stack == NULL)
687                 return SSH_ERR_INVALID_ARGUMENT;
688         state->stackoffset = state->bds.stackoffset;    /* copy back */
689         if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
690             (r = sshbuf_put_u32(b, state->idx)) != 0 ||
691             (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
692             (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
693             (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
694             (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
695             (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
696             (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
697             (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
698             (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
699                 return r;
700         for (i = 0; i < num_treehash(state); i++) {
701                 th = &state->treehash[i];
702                 node = th->node - state->th_nodes;
703                 if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
704                     (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
705                     (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
706                     (r = sshbuf_put_u8(b, th->completed)) != 0 ||
707                     (r = sshbuf_put_u32(b, node)) != 0)
708                         return r;
709         }
710         return 0;
711 }
712
713 int
714 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
715     enum sshkey_serialize_rep opts)
716 {
717         struct ssh_xmss_state *state = k->xmss_state;
718         int r = SSH_ERR_INVALID_ARGUMENT;
719
720         if (state == NULL)
721                 return SSH_ERR_INVALID_ARGUMENT;
722         if ((r = sshbuf_put_u8(b, opts)) != 0)
723                 return r;
724         switch (opts) {
725         case SSHKEY_SERIALIZE_STATE:
726                 r = sshkey_xmss_serialize_state(k, b);
727                 break;
728         case SSHKEY_SERIALIZE_FULL:
729                 if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
730                         break;
731                 r = sshkey_xmss_serialize_state(k, b);
732                 break;
733         case SSHKEY_SERIALIZE_DEFAULT:
734                 r = 0;
735                 break;
736         default:
737                 r = SSH_ERR_INVALID_ARGUMENT;
738                 break;
739         }
740         return r;
741 }
742
743 int
744 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
745 {
746         struct ssh_xmss_state *state = k->xmss_state;
747         treehash_inst *th;
748         u_int32_t i, lh, node;
749         size_t ls, lsl, la, lk, ln, lr;
750         char *magic;
751         int r;
752
753         if (state == NULL)
754                 return SSH_ERR_INVALID_ARGUMENT;
755         if (k->xmss_sk == NULL)
756                 return SSH_ERR_INVALID_ARGUMENT;
757         if ((state->treehash = calloc(num_treehash(state),
758             sizeof(treehash_inst))) == NULL)
759                 return SSH_ERR_ALLOC_FAIL;
760         if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
761             (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
762             (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
763             (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
764             (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
765             (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
766             (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
767             (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
768             (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
769             (r = sshbuf_get_u32(b, &lh)) != 0)
770                 return r;
771         if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0)
772                 return SSH_ERR_INVALID_ARGUMENT;
773         /* XXX check stackoffset */
774         if (ls != num_stack(state) ||
775             lsl != num_stacklevels(state) ||
776             la != num_auth(state) ||
777             lk != num_keep(state) ||
778             ln != num_th_nodes(state) ||
779             lr != num_retain(state) ||
780             lh != num_treehash(state))
781                 return SSH_ERR_INVALID_ARGUMENT;
782         for (i = 0; i < num_treehash(state); i++) {
783                 th = &state->treehash[i];
784                 if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
785                     (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
786                     (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
787                     (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
788                     (r = sshbuf_get_u32(b, &node)) != 0)
789                         return r;
790                 if (node < num_th_nodes(state))
791                         th->node = &state->th_nodes[node];
792         }
793         POKE_U32(k->xmss_sk, state->idx);
794         xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
795             state->stacklevels, state->auth, state->keep, state->treehash,
796             state->retain, 0);
797         return 0;
798 }
799
800 int
801 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
802 {
803         enum sshkey_serialize_rep opts;
804         u_char have_state;
805         int r;
806
807         if ((r = sshbuf_get_u8(b, &have_state)) != 0)
808                 return r;
809
810         opts = have_state;
811         switch (opts) {
812         case SSHKEY_SERIALIZE_DEFAULT:
813                 r = 0;
814                 break;
815         case SSHKEY_SERIALIZE_STATE:
816                 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
817                         return r;
818                 break;
819         case SSHKEY_SERIALIZE_FULL:
820                 if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
821                     (r = sshkey_xmss_deserialize_state(k, b)) != 0)
822                         return r;
823                 break;
824         default:
825                 r = SSH_ERR_INVALID_FORMAT;
826                 break;
827         }
828         return r;
829 }
830
831 int
832 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
833    struct sshbuf **retp)
834 {
835         struct ssh_xmss_state *state = k->xmss_state;
836         struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
837         struct sshcipher_ctx *ciphercontext = NULL;
838         const struct sshcipher *cipher;
839         u_char *cp, *key, *iv = NULL;
840         size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
841         int r = SSH_ERR_INTERNAL_ERROR;
842
843         if (retp != NULL)
844                 *retp = NULL;
845         if (state == NULL ||
846             state->enc_keyiv == NULL ||
847             state->enc_ciphername == NULL)
848                 return SSH_ERR_INTERNAL_ERROR;
849         if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
850                 r = SSH_ERR_INTERNAL_ERROR;
851                 goto out;
852         }
853         blocksize = cipher_blocksize(cipher);
854         keylen = cipher_keylen(cipher);
855         ivlen = cipher_ivlen(cipher);
856         authlen = cipher_authlen(cipher);
857         if (state->enc_keyiv_len != keylen + ivlen) {
858                 r = SSH_ERR_INVALID_FORMAT;
859                 goto out;
860         }
861         key = state->enc_keyiv;
862         if ((encrypted = sshbuf_new()) == NULL ||
863             (encoded = sshbuf_new()) == NULL ||
864             (padded = sshbuf_new()) == NULL ||
865             (iv = malloc(ivlen)) == NULL) {
866                 r = SSH_ERR_ALLOC_FAIL;
867                 goto out;
868         }
869
870         /* replace first 4 bytes of IV with index to ensure uniqueness */
871         memcpy(iv, key + keylen, ivlen);
872         POKE_U32(iv, state->idx);
873
874         if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
875             (r = sshbuf_put_u32(encoded, state->idx)) != 0)
876                 goto out;
877
878         /* padded state will be encrypted */
879         if ((r = sshbuf_putb(padded, b)) != 0)
880                 goto out;
881         i = 0;
882         while (sshbuf_len(padded) % blocksize) {
883                 if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
884                         goto out;
885         }
886         encrypted_len = sshbuf_len(padded);
887
888         /* header including the length of state is used as AAD */
889         if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
890                 goto out;
891         aadlen = sshbuf_len(encoded);
892
893         /* concat header and state */
894         if ((r = sshbuf_putb(encoded, padded)) != 0)
895                 goto out;
896
897         /* reserve space for encryption of encoded data plus auth tag */
898         /* encrypt at offset addlen */
899         if ((r = sshbuf_reserve(encrypted,
900             encrypted_len + aadlen + authlen, &cp)) != 0 ||
901             (r = cipher_init(&ciphercontext, cipher, key, keylen,
902             iv, ivlen, 1)) != 0 ||
903             (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
904             encrypted_len, aadlen, authlen)) != 0)
905                 goto out;
906
907         /* success */
908         r = 0;
909  out:
910         if (retp != NULL) {
911                 *retp = encrypted;
912                 encrypted = NULL;
913         }
914         sshbuf_free(padded);
915         sshbuf_free(encoded);
916         sshbuf_free(encrypted);
917         cipher_free(ciphercontext);
918         free(iv);
919         return r;
920 }
921
922 int
923 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
924    struct sshbuf **retp)
925 {
926         struct ssh_xmss_state *state = k->xmss_state;
927         struct sshbuf *copy = NULL, *decrypted = NULL;
928         struct sshcipher_ctx *ciphercontext = NULL;
929         const struct sshcipher *cipher = NULL;
930         u_char *key, *iv = NULL, *dp;
931         size_t keylen, ivlen, authlen, aadlen;
932         u_int blocksize, encrypted_len, index;
933         int r = SSH_ERR_INTERNAL_ERROR;
934
935         if (retp != NULL)
936                 *retp = NULL;
937         if (state == NULL ||
938             state->enc_keyiv == NULL ||
939             state->enc_ciphername == NULL)
940                 return SSH_ERR_INTERNAL_ERROR;
941         if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
942                 r = SSH_ERR_INVALID_FORMAT;
943                 goto out;
944         }
945         blocksize = cipher_blocksize(cipher);
946         keylen = cipher_keylen(cipher);
947         ivlen = cipher_ivlen(cipher);
948         authlen = cipher_authlen(cipher);
949         if (state->enc_keyiv_len != keylen + ivlen) {
950                 r = SSH_ERR_INTERNAL_ERROR;
951                 goto out;
952         }
953         key = state->enc_keyiv;
954
955         if ((copy = sshbuf_fromb(encoded)) == NULL ||
956             (decrypted = sshbuf_new()) == NULL ||
957             (iv = malloc(ivlen)) == NULL) {
958                 r = SSH_ERR_ALLOC_FAIL;
959                 goto out;
960         }
961
962         /* check magic */
963         if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
964             memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
965                 r = SSH_ERR_INVALID_FORMAT;
966                 goto out;
967         }
968         /* parse public portion */
969         if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
970             (r = sshbuf_get_u32(encoded, &index)) != 0 ||
971             (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
972                 goto out;
973
974         /* check size of encrypted key blob */
975         if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
976                 r = SSH_ERR_INVALID_FORMAT;
977                 goto out;
978         }
979         /* check that an appropriate amount of auth data is present */
980         if (sshbuf_len(encoded) < encrypted_len + authlen) {
981                 r = SSH_ERR_INVALID_FORMAT;
982                 goto out;
983         }
984
985         aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
986
987         /* replace first 4 bytes of IV with index to ensure uniqueness */
988         memcpy(iv, key + keylen, ivlen);
989         POKE_U32(iv, index);
990
991         /* decrypt private state of key */
992         if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
993             (r = cipher_init(&ciphercontext, cipher, key, keylen,
994             iv, ivlen, 0)) != 0 ||
995             (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
996             encrypted_len, aadlen, authlen)) != 0)
997                 goto out;
998
999         /* there should be no trailing data */
1000         if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1001                 goto out;
1002         if (sshbuf_len(encoded) != 0) {
1003                 r = SSH_ERR_INVALID_FORMAT;
1004                 goto out;
1005         }
1006
1007         /* remove AAD */
1008         if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1009                 goto out;
1010         /* XXX encrypted includes unchecked padding */
1011
1012         /* success */
1013         r = 0;
1014         if (retp != NULL) {
1015                 *retp = decrypted;
1016                 decrypted = NULL;
1017         }
1018  out:
1019         cipher_free(ciphercontext);
1020         sshbuf_free(copy);
1021         sshbuf_free(decrypted);
1022         free(iv);
1023         return r;
1024 }
1025
1026 u_int32_t
1027 sshkey_xmss_signatures_left(const struct sshkey *k)
1028 {
1029         struct ssh_xmss_state *state = k->xmss_state;
1030         u_int32_t idx;
1031
1032         if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1033             state->maxidx) {
1034                 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1035                 if (idx < state->maxidx)
1036                         return state->maxidx - idx;
1037         }
1038         return 0;
1039 }
1040
1041 int
1042 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1043 {
1044         struct ssh_xmss_state *state = k->xmss_state;
1045
1046         if (sshkey_type_plain(k->type) != KEY_XMSS)
1047                 return SSH_ERR_INVALID_ARGUMENT;
1048         if (maxsign == 0)
1049                 return 0;
1050         if (state->idx + maxsign < state->idx)
1051                 return SSH_ERR_INVALID_ARGUMENT;
1052         state->maxidx = state->idx + maxsign;
1053         return 0;
1054 }
1055 #endif /* WITH_XMSS */