]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/kex.c
ssh: Update to OpenSSH 9.3p1
[FreeBSD/FreeBSD.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.178 2023/03/12 10:40:39 dtucker Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66
67 /* prototype */
68 static int kex_choose_conf(struct ssh *);
69 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72         "KEX algorithms",
73         "host key algorithms",
74         "ciphers ctos",
75         "ciphers stoc",
76         "MACs ctos",
77         "MACs stoc",
78         "compression ctos",
79         "compression stoc",
80         "languages ctos",
81         "languages stoc",
82 };
83
84 struct kexalg {
85         char *name;
86         u_int type;
87         int ec_nid;
88         int hash_alg;
89 };
90 static const struct kexalg kexalgs[] = {
91 #ifdef WITH_OPENSSL
92         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
93         { KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
94         { KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
95         { KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
96         { KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
97         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
98 #ifdef HAVE_EVP_SHA256
99         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
100 #endif /* HAVE_EVP_SHA256 */
101 #ifdef OPENSSL_HAS_ECC
102         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
103             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
104         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
105             SSH_DIGEST_SHA384 },
106 # ifdef OPENSSL_HAS_NISTP521
107         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
108             SSH_DIGEST_SHA512 },
109 # endif /* OPENSSL_HAS_NISTP521 */
110 #endif /* OPENSSL_HAS_ECC */
111 #endif /* WITH_OPENSSL */
112 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
113         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
114         { KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
115 #ifdef USE_SNTRUP761X25519
116         { KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
117             SSH_DIGEST_SHA512 },
118 #endif
119 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
120         { NULL, 0, -1, -1},
121 };
122
123 char *
124 kex_alg_list(char sep)
125 {
126         char *ret = NULL, *tmp;
127         size_t nlen, rlen = 0;
128         const struct kexalg *k;
129
130         for (k = kexalgs; k->name != NULL; k++) {
131                 if (ret != NULL)
132                         ret[rlen++] = sep;
133                 nlen = strlen(k->name);
134                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135                         free(ret);
136                         return NULL;
137                 }
138                 ret = tmp;
139                 memcpy(ret + rlen, k->name, nlen + 1);
140                 rlen += nlen;
141         }
142         return ret;
143 }
144
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148         const struct kexalg *k;
149
150         for (k = kexalgs; k->name != NULL; k++) {
151                 if (strcmp(k->name, name) == 0)
152                         return k;
153         }
154         return NULL;
155 }
156
157 /* Validate KEX method name list */
158 int
159 kex_names_valid(const char *names)
160 {
161         char *s, *cp, *p;
162
163         if (names == NULL || strcmp(names, "") == 0)
164                 return 0;
165         if ((s = cp = strdup(names)) == NULL)
166                 return 0;
167         for ((p = strsep(&cp, ",")); p && *p != '\0';
168             (p = strsep(&cp, ","))) {
169                 if (kex_alg_by_name(p) == NULL) {
170                         error("Unsupported KEX algorithm \"%.100s\"", p);
171                         free(s);
172                         return 0;
173                 }
174         }
175         debug3("kex names ok: [%s]", names);
176         free(s);
177         return 1;
178 }
179
180 /*
181  * Concatenate algorithm names, avoiding duplicates in the process.
182  * Caller must free returned string.
183  */
184 char *
185 kex_names_cat(const char *a, const char *b)
186 {
187         char *ret = NULL, *tmp = NULL, *cp, *p, *m;
188         size_t len;
189
190         if (a == NULL || *a == '\0')
191                 return strdup(b);
192         if (b == NULL || *b == '\0')
193                 return strdup(a);
194         if (strlen(b) > 1024*1024)
195                 return NULL;
196         len = strlen(a) + strlen(b) + 2;
197         if ((tmp = cp = strdup(b)) == NULL ||
198             (ret = calloc(1, len)) == NULL) {
199                 free(tmp);
200                 return NULL;
201         }
202         strlcpy(ret, a, len);
203         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
204                 if ((m = match_list(ret, p, NULL)) != NULL) {
205                         free(m);
206                         continue; /* Algorithm already present */
207                 }
208                 if (strlcat(ret, ",", len) >= len ||
209                     strlcat(ret, p, len) >= len) {
210                         free(tmp);
211                         free(ret);
212                         return NULL; /* Shouldn't happen */
213                 }
214         }
215         free(tmp);
216         return ret;
217 }
218
219 /*
220  * Assemble a list of algorithms from a default list and a string from a
221  * configuration file. The user-provided string may begin with '+' to
222  * indicate that it should be appended to the default, '-' that the
223  * specified names should be removed, or '^' that they should be placed
224  * at the head.
225  */
226 int
227 kex_assemble_names(char **listp, const char *def, const char *all)
228 {
229         char *cp, *tmp, *patterns;
230         char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
231         int r = SSH_ERR_INTERNAL_ERROR;
232
233         if (listp == NULL || def == NULL || all == NULL)
234                 return SSH_ERR_INVALID_ARGUMENT;
235
236         if (*listp == NULL || **listp == '\0') {
237                 if ((*listp = strdup(def)) == NULL)
238                         return SSH_ERR_ALLOC_FAIL;
239                 return 0;
240         }
241
242         list = *listp;
243         *listp = NULL;
244         if (*list == '+') {
245                 /* Append names to default list */
246                 if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
247                         r = SSH_ERR_ALLOC_FAIL;
248                         goto fail;
249                 }
250                 free(list);
251                 list = tmp;
252         } else if (*list == '-') {
253                 /* Remove names from default list */
254                 if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
255                         r = SSH_ERR_ALLOC_FAIL;
256                         goto fail;
257                 }
258                 free(list);
259                 /* filtering has already been done */
260                 return 0;
261         } else if (*list == '^') {
262                 /* Place names at head of default list */
263                 if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
264                         r = SSH_ERR_ALLOC_FAIL;
265                         goto fail;
266                 }
267                 free(list);
268                 list = tmp;
269         } else {
270                 /* Explicit list, overrides default - just use "list" as is */
271         }
272
273         /*
274          * The supplied names may be a pattern-list. For the -list case,
275          * the patterns are applied above. For the +list and explicit list
276          * cases we need to do it now.
277          */
278         ret = NULL;
279         if ((patterns = opatterns = strdup(list)) == NULL) {
280                 r = SSH_ERR_ALLOC_FAIL;
281                 goto fail;
282         }
283         /* Apply positive (i.e. non-negated) patterns from the list */
284         while ((cp = strsep(&patterns, ",")) != NULL) {
285                 if (*cp == '!') {
286                         /* negated matches are not supported here */
287                         r = SSH_ERR_INVALID_ARGUMENT;
288                         goto fail;
289                 }
290                 free(matching);
291                 if ((matching = match_filter_allowlist(all, cp)) == NULL) {
292                         r = SSH_ERR_ALLOC_FAIL;
293                         goto fail;
294                 }
295                 if ((tmp = kex_names_cat(ret, matching)) == NULL) {
296                         r = SSH_ERR_ALLOC_FAIL;
297                         goto fail;
298                 }
299                 free(ret);
300                 ret = tmp;
301         }
302         if (ret == NULL || *ret == '\0') {
303                 /* An empty name-list is an error */
304                 /* XXX better error code? */
305                 r = SSH_ERR_INVALID_ARGUMENT;
306                 goto fail;
307         }
308
309         /* success */
310         *listp = ret;
311         ret = NULL;
312         r = 0;
313
314  fail:
315         free(matching);
316         free(opatterns);
317         free(list);
318         free(ret);
319         return r;
320 }
321
322 /*
323  * Fill out a proposal array with dynamically allocated values, which may
324  * be modified as required for compatibility reasons.
325  * Any of the options may be NULL, in which case the default is used.
326  * Array contents must be freed by calling kex_proposal_free_entries.
327  */
328 void
329 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
330     const char *kexalgos, const char *ciphers, const char *macs,
331     const char *comp, const char *hkalgs)
332 {
333         const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
334         const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
335         const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
336         u_int i;
337
338         if (prop == NULL)
339                 fatal_f("proposal missing");
340
341         for (i = 0; i < PROPOSAL_MAX; i++) {
342                 switch(i) {
343                 case PROPOSAL_KEX_ALGS:
344                         prop[i] = compat_kex_proposal(ssh,
345                             kexalgos ? kexalgos : defprop[i]);
346                         break;
347                 case PROPOSAL_ENC_ALGS_CTOS:
348                 case PROPOSAL_ENC_ALGS_STOC:
349                         prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
350                         break;
351                 case PROPOSAL_MAC_ALGS_CTOS:
352                 case PROPOSAL_MAC_ALGS_STOC:
353                         prop[i]  = xstrdup(macs ? macs : defprop[i]);
354                         break;
355                 case PROPOSAL_COMP_ALGS_CTOS:
356                 case PROPOSAL_COMP_ALGS_STOC:
357                         prop[i] = xstrdup(comp ? comp : defprop[i]);
358                         break;
359                 case PROPOSAL_SERVER_HOST_KEY_ALGS:
360                         prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
361                         break;
362                 default:
363                         prop[i] = xstrdup(defprop[i]);
364                 }
365         }
366 }
367
368 void
369 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
370 {
371         u_int i;
372
373         for (i = 0; i < PROPOSAL_MAX; i++)
374                 free(prop[i]);
375 }
376
377 /* put algorithm proposal into buffer */
378 int
379 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
380 {
381         u_int i;
382         int r;
383
384         sshbuf_reset(b);
385
386         /*
387          * add a dummy cookie, the cookie will be overwritten by
388          * kex_send_kexinit(), each time a kexinit is set
389          */
390         for (i = 0; i < KEX_COOKIE_LEN; i++) {
391                 if ((r = sshbuf_put_u8(b, 0)) != 0)
392                         return r;
393         }
394         for (i = 0; i < PROPOSAL_MAX; i++) {
395                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
396                         return r;
397         }
398         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
399             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
400                 return r;
401         return 0;
402 }
403
404 /* parse buffer and return algorithm proposal */
405 int
406 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
407 {
408         struct sshbuf *b = NULL;
409         u_char v;
410         u_int i;
411         char **proposal = NULL;
412         int r;
413
414         *propp = NULL;
415         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
416                 return SSH_ERR_ALLOC_FAIL;
417         if ((b = sshbuf_fromb(raw)) == NULL) {
418                 r = SSH_ERR_ALLOC_FAIL;
419                 goto out;
420         }
421         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
422                 error_fr(r, "consume cookie");
423                 goto out;
424         }
425         /* extract kex init proposal strings */
426         for (i = 0; i < PROPOSAL_MAX; i++) {
427                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
428                         error_fr(r, "parse proposal %u", i);
429                         goto out;
430                 }
431                 debug2("%s: %s", proposal_names[i], proposal[i]);
432         }
433         /* first kex follows / reserved */
434         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
435             (r = sshbuf_get_u32(b, &i)) != 0) { /* reserved */
436                 error_fr(r, "parse");
437                 goto out;
438         }
439         if (first_kex_follows != NULL)
440                 *first_kex_follows = v;
441         debug2("first_kex_follows %d ", v);
442         debug2("reserved %u ", i);
443         r = 0;
444         *propp = proposal;
445  out:
446         if (r != 0 && proposal != NULL)
447                 kex_prop_free(proposal);
448         sshbuf_free(b);
449         return r;
450 }
451
452 void
453 kex_prop_free(char **proposal)
454 {
455         u_int i;
456
457         if (proposal == NULL)
458                 return;
459         for (i = 0; i < PROPOSAL_MAX; i++)
460                 free(proposal[i]);
461         free(proposal);
462 }
463
464 int
465 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
466 {
467         int r;
468
469         error("kex protocol error: type %d seq %u", type, seq);
470         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
471             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
472             (r = sshpkt_send(ssh)) != 0)
473                 return r;
474         return 0;
475 }
476
477 static void
478 kex_reset_dispatch(struct ssh *ssh)
479 {
480         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
481             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
482 }
483
484 static int
485 kex_send_ext_info(struct ssh *ssh)
486 {
487         int r;
488         char *algs;
489
490         debug("Sending SSH2_MSG_EXT_INFO");
491         if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
492                 return SSH_ERR_ALLOC_FAIL;
493         /* XXX filter algs list by allowed pubkey/hostbased types */
494         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
495             (r = sshpkt_put_u32(ssh, 2)) != 0 ||
496             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
497             (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
498             (r = sshpkt_put_cstring(ssh,
499             "publickey-hostbound@openssh.com")) != 0 ||
500             (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
501             (r = sshpkt_send(ssh)) != 0) {
502                 error_fr(r, "compose");
503                 goto out;
504         }
505         /* success */
506         r = 0;
507  out:
508         free(algs);
509         return r;
510 }
511
512 int
513 kex_send_newkeys(struct ssh *ssh)
514 {
515         int r;
516
517         kex_reset_dispatch(ssh);
518         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
519             (r = sshpkt_send(ssh)) != 0)
520                 return r;
521         debug("SSH2_MSG_NEWKEYS sent");
522         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
523         if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
524                 if ((r = kex_send_ext_info(ssh)) != 0)
525                         return r;
526         debug("expecting SSH2_MSG_NEWKEYS");
527         return 0;
528 }
529
530 int
531 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
532 {
533         struct kex *kex = ssh->kex;
534         u_int32_t i, ninfo;
535         char *name;
536         u_char *val;
537         size_t vlen;
538         int r;
539
540         debug("SSH2_MSG_EXT_INFO received");
541         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
542         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
543                 return r;
544         if (ninfo >= 1024) {
545                 error("SSH2_MSG_EXT_INFO with too many entries, expected "
546                     "<=1024, received %u", ninfo);
547                 return SSH_ERR_INVALID_FORMAT;
548         }
549         for (i = 0; i < ninfo; i++) {
550                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
551                         return r;
552                 if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
553                         free(name);
554                         return r;
555                 }
556                 if (strcmp(name, "server-sig-algs") == 0) {
557                         /* Ensure no \0 lurking in value */
558                         if (memchr(val, '\0', vlen) != NULL) {
559                                 error_f("nul byte in %s", name);
560                                 return SSH_ERR_INVALID_FORMAT;
561                         }
562                         debug_f("%s=<%s>", name, val);
563                         kex->server_sig_algs = val;
564                         val = NULL;
565                 } else if (strcmp(name,
566                     "publickey-hostbound@openssh.com") == 0) {
567                         /* XXX refactor */
568                         /* Ensure no \0 lurking in value */
569                         if (memchr(val, '\0', vlen) != NULL) {
570                                 error_f("nul byte in %s", name);
571                                 return SSH_ERR_INVALID_FORMAT;
572                         }
573                         debug_f("%s=<%s>", name, val);
574                         if (strcmp(val, "0") == 0)
575                                 kex->flags |= KEX_HAS_PUBKEY_HOSTBOUND;
576                         else {
577                                 debug_f("unsupported version of %s extension",
578                                     name);
579                         }
580                 } else
581                         debug_f("%s (unrecognised)", name);
582                 free(name);
583                 free(val);
584         }
585         return sshpkt_get_end(ssh);
586 }
587
588 static int
589 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
590 {
591         struct kex *kex = ssh->kex;
592         int r;
593
594         debug("SSH2_MSG_NEWKEYS received");
595         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
596         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
597         if ((r = sshpkt_get_end(ssh)) != 0)
598                 return r;
599         if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
600                 return r;
601         kex->done = 1;
602         kex->flags &= ~KEX_INITIAL;
603         sshbuf_reset(kex->peer);
604         /* sshbuf_reset(kex->my); */
605         kex->flags &= ~KEX_INIT_SENT;
606         free(kex->name);
607         kex->name = NULL;
608         return 0;
609 }
610
611 int
612 kex_send_kexinit(struct ssh *ssh)
613 {
614         u_char *cookie;
615         struct kex *kex = ssh->kex;
616         int r;
617
618         if (kex == NULL) {
619                 error_f("no kex");
620                 return SSH_ERR_INTERNAL_ERROR;
621         }
622         if (kex->flags & KEX_INIT_SENT)
623                 return 0;
624         kex->done = 0;
625
626         /* generate a random cookie */
627         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
628                 error_f("bad kex length: %zu < %d",
629                     sshbuf_len(kex->my), KEX_COOKIE_LEN);
630                 return SSH_ERR_INVALID_FORMAT;
631         }
632         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
633                 error_f("buffer error");
634                 return SSH_ERR_INTERNAL_ERROR;
635         }
636         arc4random_buf(cookie, KEX_COOKIE_LEN);
637
638         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
639             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
640             (r = sshpkt_send(ssh)) != 0) {
641                 error_fr(r, "compose reply");
642                 return r;
643         }
644         debug("SSH2_MSG_KEXINIT sent");
645         kex->flags |= KEX_INIT_SENT;
646         return 0;
647 }
648
649 int
650 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
651 {
652         struct kex *kex = ssh->kex;
653         const u_char *ptr;
654         u_int i;
655         size_t dlen;
656         int r;
657
658         debug("SSH2_MSG_KEXINIT received");
659         if (kex == NULL) {
660                 error_f("no kex");
661                 return SSH_ERR_INTERNAL_ERROR;
662         }
663         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
664         ptr = sshpkt_ptr(ssh, &dlen);
665         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
666                 return r;
667
668         /* discard packet */
669         for (i = 0; i < KEX_COOKIE_LEN; i++) {
670                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
671                         error_fr(r, "discard cookie");
672                         return r;
673                 }
674         }
675         for (i = 0; i < PROPOSAL_MAX; i++) {
676                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
677                         error_fr(r, "discard proposal");
678                         return r;
679                 }
680         }
681         /*
682          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
683          * KEX method has the server move first, but a server might be using
684          * a custom method or one that we otherwise don't support. We should
685          * be prepared to remember first_kex_follows here so we can eat a
686          * packet later.
687          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
688          * for cases where the server *doesn't* go first. I guess we should
689          * ignore it when it is set for these cases, which is what we do now.
690          */
691         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
692             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
693             (r = sshpkt_get_end(ssh)) != 0)
694                         return r;
695
696         if (!(kex->flags & KEX_INIT_SENT))
697                 if ((r = kex_send_kexinit(ssh)) != 0)
698                         return r;
699         if ((r = kex_choose_conf(ssh)) != 0)
700                 return r;
701
702         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
703                 return (kex->kex[kex->kex_type])(ssh);
704
705         error_f("unknown kex type %u", kex->kex_type);
706         return SSH_ERR_INTERNAL_ERROR;
707 }
708
709 struct kex *
710 kex_new(void)
711 {
712         struct kex *kex;
713
714         if ((kex = calloc(1, sizeof(*kex))) == NULL ||
715             (kex->peer = sshbuf_new()) == NULL ||
716             (kex->my = sshbuf_new()) == NULL ||
717             (kex->client_version = sshbuf_new()) == NULL ||
718             (kex->server_version = sshbuf_new()) == NULL ||
719             (kex->session_id = sshbuf_new()) == NULL) {
720                 kex_free(kex);
721                 return NULL;
722         }
723         return kex;
724 }
725
726 void
727 kex_free_newkeys(struct newkeys *newkeys)
728 {
729         if (newkeys == NULL)
730                 return;
731         if (newkeys->enc.key) {
732                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
733                 free(newkeys->enc.key);
734                 newkeys->enc.key = NULL;
735         }
736         if (newkeys->enc.iv) {
737                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
738                 free(newkeys->enc.iv);
739                 newkeys->enc.iv = NULL;
740         }
741         free(newkeys->enc.name);
742         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
743         free(newkeys->comp.name);
744         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
745         mac_clear(&newkeys->mac);
746         if (newkeys->mac.key) {
747                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
748                 free(newkeys->mac.key);
749                 newkeys->mac.key = NULL;
750         }
751         free(newkeys->mac.name);
752         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
753         freezero(newkeys, sizeof(*newkeys));
754 }
755
756 void
757 kex_free(struct kex *kex)
758 {
759         u_int mode;
760
761         if (kex == NULL)
762                 return;
763
764 #ifdef WITH_OPENSSL
765         DH_free(kex->dh);
766 #ifdef OPENSSL_HAS_ECC
767         EC_KEY_free(kex->ec_client_key);
768 #endif /* OPENSSL_HAS_ECC */
769 #endif /* WITH_OPENSSL */
770         for (mode = 0; mode < MODE_MAX; mode++) {
771                 kex_free_newkeys(kex->newkeys[mode]);
772                 kex->newkeys[mode] = NULL;
773         }
774         sshbuf_free(kex->peer);
775         sshbuf_free(kex->my);
776         sshbuf_free(kex->client_version);
777         sshbuf_free(kex->server_version);
778         sshbuf_free(kex->client_pub);
779         sshbuf_free(kex->session_id);
780         sshbuf_free(kex->initial_sig);
781         sshkey_free(kex->initial_hostkey);
782         free(kex->failed_choice);
783         free(kex->hostkey_alg);
784         free(kex->name);
785         free(kex);
786 }
787
788 int
789 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
790 {
791         int r;
792
793         if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
794                 return r;
795         ssh->kex->flags = KEX_INITIAL;
796         kex_reset_dispatch(ssh);
797         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
798         return 0;
799 }
800
801 int
802 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
803 {
804         int r;
805
806         if ((r = kex_ready(ssh, proposal)) != 0)
807                 return r;
808         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
809                 kex_free(ssh->kex);
810                 ssh->kex = NULL;
811                 return r;
812         }
813         return 0;
814 }
815
816 /*
817  * Request key re-exchange, returns 0 on success or a ssherr.h error
818  * code otherwise. Must not be called if KEX is incomplete or in-progress.
819  */
820 int
821 kex_start_rekex(struct ssh *ssh)
822 {
823         if (ssh->kex == NULL) {
824                 error_f("no kex");
825                 return SSH_ERR_INTERNAL_ERROR;
826         }
827         if (ssh->kex->done == 0) {
828                 error_f("requested twice");
829                 return SSH_ERR_INTERNAL_ERROR;
830         }
831         ssh->kex->done = 0;
832         return kex_send_kexinit(ssh);
833 }
834
835 static int
836 choose_enc(struct sshenc *enc, char *client, char *server)
837 {
838         char *name = match_list(client, server, NULL);
839
840         if (name == NULL)
841                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
842         if ((enc->cipher = cipher_by_name(name)) == NULL) {
843                 error_f("unsupported cipher %s", name);
844                 free(name);
845                 return SSH_ERR_INTERNAL_ERROR;
846         }
847         enc->name = name;
848         enc->enabled = 0;
849         enc->iv = NULL;
850         enc->iv_len = cipher_ivlen(enc->cipher);
851         enc->key = NULL;
852         enc->key_len = cipher_keylen(enc->cipher);
853         enc->block_size = cipher_blocksize(enc->cipher);
854         return 0;
855 }
856
857 static int
858 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
859 {
860         char *name = match_list(client, server, NULL);
861
862         if (name == NULL)
863                 return SSH_ERR_NO_MAC_ALG_MATCH;
864         if (mac_setup(mac, name) < 0) {
865                 error_f("unsupported MAC %s", name);
866                 free(name);
867                 return SSH_ERR_INTERNAL_ERROR;
868         }
869         mac->name = name;
870         mac->key = NULL;
871         mac->enabled = 0;
872         return 0;
873 }
874
875 static int
876 choose_comp(struct sshcomp *comp, char *client, char *server)
877 {
878         char *name = match_list(client, server, NULL);
879
880         if (name == NULL)
881                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
882 #ifdef WITH_ZLIB
883         if (strcmp(name, "zlib@openssh.com") == 0) {
884                 comp->type = COMP_DELAYED;
885         } else if (strcmp(name, "zlib") == 0) {
886                 comp->type = COMP_ZLIB;
887         } else
888 #endif  /* WITH_ZLIB */
889         if (strcmp(name, "none") == 0) {
890                 comp->type = COMP_NONE;
891         } else {
892                 error_f("unsupported compression scheme %s", name);
893                 free(name);
894                 return SSH_ERR_INTERNAL_ERROR;
895         }
896         comp->name = name;
897         return 0;
898 }
899
900 static int
901 choose_kex(struct kex *k, char *client, char *server)
902 {
903         const struct kexalg *kexalg;
904
905         k->name = match_list(client, server, NULL);
906
907         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
908         if (k->name == NULL)
909                 return SSH_ERR_NO_KEX_ALG_MATCH;
910         if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
911                 error_f("unsupported KEX method %s", k->name);
912                 return SSH_ERR_INTERNAL_ERROR;
913         }
914         k->kex_type = kexalg->type;
915         k->hash_alg = kexalg->hash_alg;
916         k->ec_nid = kexalg->ec_nid;
917         return 0;
918 }
919
920 static int
921 choose_hostkeyalg(struct kex *k, char *client, char *server)
922 {
923         free(k->hostkey_alg);
924         k->hostkey_alg = match_list(client, server, NULL);
925
926         debug("kex: host key algorithm: %s",
927             k->hostkey_alg ? k->hostkey_alg : "(no match)");
928         if (k->hostkey_alg == NULL)
929                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
930         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
931         if (k->hostkey_type == KEY_UNSPEC) {
932                 error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
933                 return SSH_ERR_INTERNAL_ERROR;
934         }
935         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
936         return 0;
937 }
938
939 static int
940 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
941 {
942         static int check[] = {
943                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
944         };
945         int *idx;
946         char *p;
947
948         for (idx = &check[0]; *idx != -1; idx++) {
949                 if ((p = strchr(my[*idx], ',')) != NULL)
950                         *p = '\0';
951                 if ((p = strchr(peer[*idx], ',')) != NULL)
952                         *p = '\0';
953                 if (strcmp(my[*idx], peer[*idx]) != 0) {
954                         debug2("proposal mismatch: my %s peer %s",
955                             my[*idx], peer[*idx]);
956                         return (0);
957                 }
958         }
959         debug2("proposals match");
960         return (1);
961 }
962
963 /* returns non-zero if proposal contains any algorithm from algs */
964 static int
965 has_any_alg(const char *proposal, const char *algs)
966 {
967         char *cp;
968
969         if ((cp = match_list(proposal, algs, NULL)) == NULL)
970                 return 0;
971         free(cp);
972         return 1;
973 }
974
975 static int
976 kex_choose_conf(struct ssh *ssh)
977 {
978         struct kex *kex = ssh->kex;
979         struct newkeys *newkeys;
980         char **my = NULL, **peer = NULL;
981         char **cprop, **sprop;
982         int nenc, nmac, ncomp;
983         u_int mode, ctos, need, dh_need, authlen;
984         int r, first_kex_follows;
985
986         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
987         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
988                 goto out;
989         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
990         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
991                 goto out;
992
993         if (kex->server) {
994                 cprop=peer;
995                 sprop=my;
996         } else {
997                 cprop=my;
998                 sprop=peer;
999         }
1000
1001         /* Check whether client supports ext_info_c */
1002         if (kex->server && (kex->flags & KEX_INITIAL)) {
1003                 char *ext;
1004
1005                 ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
1006                 kex->ext_info_c = (ext != NULL);
1007                 free(ext);
1008         }
1009
1010         /* Check whether client supports rsa-sha2 algorithms */
1011         if (kex->server && (kex->flags & KEX_INITIAL)) {
1012                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1013                     "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1014                         kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1015                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1016                     "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1017                         kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1018         }
1019
1020         /* Algorithm Negotiation */
1021         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1022             sprop[PROPOSAL_KEX_ALGS])) != 0) {
1023                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1024                 peer[PROPOSAL_KEX_ALGS] = NULL;
1025                 goto out;
1026         }
1027         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1028             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1029                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1030                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1031                 goto out;
1032         }
1033         for (mode = 0; mode < MODE_MAX; mode++) {
1034                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1035                         r = SSH_ERR_ALLOC_FAIL;
1036                         goto out;
1037                 }
1038                 kex->newkeys[mode] = newkeys;
1039                 ctos = (!kex->server && mode == MODE_OUT) ||
1040                     (kex->server && mode == MODE_IN);
1041                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1042                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1043                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1044                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1045                     sprop[nenc])) != 0) {
1046                         kex->failed_choice = peer[nenc];
1047                         peer[nenc] = NULL;
1048                         goto out;
1049                 }
1050                 authlen = cipher_authlen(newkeys->enc.cipher);
1051                 /* ignore mac for authenticated encryption */
1052                 if (authlen == 0 &&
1053                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1054                     sprop[nmac])) != 0) {
1055                         kex->failed_choice = peer[nmac];
1056                         peer[nmac] = NULL;
1057                         goto out;
1058                 }
1059                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1060                     sprop[ncomp])) != 0) {
1061                         kex->failed_choice = peer[ncomp];
1062                         peer[ncomp] = NULL;
1063                         goto out;
1064                 }
1065                 debug("kex: %s cipher: %s MAC: %s compression: %s",
1066                     ctos ? "client->server" : "server->client",
1067                     newkeys->enc.name,
1068                     authlen == 0 ? newkeys->mac.name : "<implicit>",
1069                     newkeys->comp.name);
1070         }
1071         need = dh_need = 0;
1072         for (mode = 0; mode < MODE_MAX; mode++) {
1073                 newkeys = kex->newkeys[mode];
1074                 need = MAXIMUM(need, newkeys->enc.key_len);
1075                 need = MAXIMUM(need, newkeys->enc.block_size);
1076                 need = MAXIMUM(need, newkeys->enc.iv_len);
1077                 need = MAXIMUM(need, newkeys->mac.key_len);
1078                 dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1079                 dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1080                 dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1081                 dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1082         }
1083         /* XXX need runden? */
1084         kex->we_need = need;
1085         kex->dh_need = dh_need;
1086
1087         /* ignore the next message if the proposals do not match */
1088         if (first_kex_follows && !proposals_match(my, peer))
1089                 ssh->dispatch_skip_packets = 1;
1090         r = 0;
1091  out:
1092         kex_prop_free(my);
1093         kex_prop_free(peer);
1094         return r;
1095 }
1096
1097 static int
1098 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1099     const struct sshbuf *shared_secret, u_char **keyp)
1100 {
1101         struct kex *kex = ssh->kex;
1102         struct ssh_digest_ctx *hashctx = NULL;
1103         char c = id;
1104         u_int have;
1105         size_t mdsz;
1106         u_char *digest;
1107         int r;
1108
1109         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1110                 return SSH_ERR_INVALID_ARGUMENT;
1111         if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1112                 r = SSH_ERR_ALLOC_FAIL;
1113                 goto out;
1114         }
1115
1116         /* K1 = HASH(K || H || "A" || session_id) */
1117         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1118             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1119             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1120             ssh_digest_update(hashctx, &c, 1) != 0 ||
1121             ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1122             ssh_digest_final(hashctx, digest, mdsz) != 0) {
1123                 r = SSH_ERR_LIBCRYPTO_ERROR;
1124                 error_f("KEX hash failed");
1125                 goto out;
1126         }
1127         ssh_digest_free(hashctx);
1128         hashctx = NULL;
1129
1130         /*
1131          * expand key:
1132          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1133          * Key = K1 || K2 || ... || Kn
1134          */
1135         for (have = mdsz; need > have; have += mdsz) {
1136                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1137                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1138                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1139                     ssh_digest_update(hashctx, digest, have) != 0 ||
1140                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1141                         error_f("KDF failed");
1142                         r = SSH_ERR_LIBCRYPTO_ERROR;
1143                         goto out;
1144                 }
1145                 ssh_digest_free(hashctx);
1146                 hashctx = NULL;
1147         }
1148 #ifdef DEBUG_KEX
1149         fprintf(stderr, "key '%c'== ", c);
1150         dump_digest("key", digest, need);
1151 #endif
1152         *keyp = digest;
1153         digest = NULL;
1154         r = 0;
1155  out:
1156         free(digest);
1157         ssh_digest_free(hashctx);
1158         return r;
1159 }
1160
1161 #define NKEYS   6
1162 int
1163 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1164     const struct sshbuf *shared_secret)
1165 {
1166         struct kex *kex = ssh->kex;
1167         u_char *keys[NKEYS];
1168         u_int i, j, mode, ctos;
1169         int r;
1170
1171         /* save initial hash as session id */
1172         if ((kex->flags & KEX_INITIAL) != 0) {
1173                 if (sshbuf_len(kex->session_id) != 0) {
1174                         error_f("already have session ID at kex");
1175                         return SSH_ERR_INTERNAL_ERROR;
1176                 }
1177                 if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1178                         return r;
1179         } else if (sshbuf_len(kex->session_id) == 0) {
1180                 error_f("no session ID in rekex");
1181                 return SSH_ERR_INTERNAL_ERROR;
1182         }
1183         for (i = 0; i < NKEYS; i++) {
1184                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1185                     shared_secret, &keys[i])) != 0) {
1186                         for (j = 0; j < i; j++)
1187                                 free(keys[j]);
1188                         return r;
1189                 }
1190         }
1191         for (mode = 0; mode < MODE_MAX; mode++) {
1192                 ctos = (!kex->server && mode == MODE_OUT) ||
1193                     (kex->server && mode == MODE_IN);
1194                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1195                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1196                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1197         }
1198         return 0;
1199 }
1200
1201 int
1202 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1203 {
1204         struct kex *kex = ssh->kex;
1205
1206         *pubp = NULL;
1207         *prvp = NULL;
1208         if (kex->load_host_public_key == NULL ||
1209             kex->load_host_private_key == NULL) {
1210                 error_f("missing hostkey loader");
1211                 return SSH_ERR_INVALID_ARGUMENT;
1212         }
1213         *pubp = kex->load_host_public_key(kex->hostkey_type,
1214             kex->hostkey_nid, ssh);
1215         *prvp = kex->load_host_private_key(kex->hostkey_type,
1216             kex->hostkey_nid, ssh);
1217         if (*pubp == NULL)
1218                 return SSH_ERR_NO_HOSTKEY_LOADED;
1219         return 0;
1220 }
1221
1222 int
1223 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1224 {
1225         struct kex *kex = ssh->kex;
1226
1227         if (kex->verify_host_key == NULL) {
1228                 error_f("missing hostkey verifier");
1229                 return SSH_ERR_INVALID_ARGUMENT;
1230         }
1231         if (server_host_key->type != kex->hostkey_type ||
1232             (kex->hostkey_type == KEY_ECDSA &&
1233             server_host_key->ecdsa_nid != kex->hostkey_nid))
1234                 return SSH_ERR_KEY_TYPE_MISMATCH;
1235         if (kex->verify_host_key(server_host_key, ssh) == -1)
1236                 return  SSH_ERR_SIGNATURE_INVALID;
1237         return 0;
1238 }
1239
1240 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1241 void
1242 dump_digest(const char *msg, const u_char *digest, int len)
1243 {
1244         fprintf(stderr, "%s\n", msg);
1245         sshbuf_dump_data(digest, len, stderr);
1246 }
1247 #endif
1248
1249 /*
1250  * Send a plaintext error message to the peer, suffixed by \r\n.
1251  * Only used during banner exchange, and there only for the server.
1252  */
1253 static void
1254 send_error(struct ssh *ssh, char *msg)
1255 {
1256         char *crnl = "\r\n";
1257
1258         if (!ssh->kex->server)
1259                 return;
1260
1261         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1262             msg, strlen(msg)) != strlen(msg) ||
1263             atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1264             crnl, strlen(crnl)) != strlen(crnl))
1265                 error_f("write: %.100s", strerror(errno));
1266 }
1267
1268 /*
1269  * Sends our identification string and waits for the peer's. Will block for
1270  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1271  * Returns on 0 success or a ssherr.h code on failure.
1272  */
1273 int
1274 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1275     const char *version_addendum)
1276 {
1277         int remote_major, remote_minor, mismatch, oerrno = 0;
1278         size_t len, n;
1279         int r, expect_nl;
1280         u_char c;
1281         struct sshbuf *our_version = ssh->kex->server ?
1282             ssh->kex->server_version : ssh->kex->client_version;
1283         struct sshbuf *peer_version = ssh->kex->server ?
1284             ssh->kex->client_version : ssh->kex->server_version;
1285         char *our_version_string = NULL, *peer_version_string = NULL;
1286         char *cp, *remote_version = NULL;
1287
1288         /* Prepare and send our banner */
1289         sshbuf_reset(our_version);
1290         if (version_addendum != NULL && *version_addendum == '\0')
1291                 version_addendum = NULL;
1292         if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1293             PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1294             version_addendum == NULL ? "" : " ",
1295             version_addendum == NULL ? "" : version_addendum)) != 0) {
1296                 oerrno = errno;
1297                 error_fr(r, "sshbuf_putf");
1298                 goto out;
1299         }
1300
1301         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1302             sshbuf_mutable_ptr(our_version),
1303             sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1304                 oerrno = errno;
1305                 debug_f("write: %.100s", strerror(errno));
1306                 r = SSH_ERR_SYSTEM_ERROR;
1307                 goto out;
1308         }
1309         if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1310                 oerrno = errno;
1311                 error_fr(r, "sshbuf_consume_end");
1312                 goto out;
1313         }
1314         our_version_string = sshbuf_dup_string(our_version);
1315         if (our_version_string == NULL) {
1316                 error_f("sshbuf_dup_string failed");
1317                 r = SSH_ERR_ALLOC_FAIL;
1318                 goto out;
1319         }
1320         debug("Local version string %.100s", our_version_string);
1321
1322         /* Read other side's version identification. */
1323         for (n = 0; ; n++) {
1324                 if (n >= SSH_MAX_PRE_BANNER_LINES) {
1325                         send_error(ssh, "No SSH identification string "
1326                             "received.");
1327                         error_f("No SSH version received in first %u lines "
1328                             "from server", SSH_MAX_PRE_BANNER_LINES);
1329                         r = SSH_ERR_INVALID_FORMAT;
1330                         goto out;
1331                 }
1332                 sshbuf_reset(peer_version);
1333                 expect_nl = 0;
1334                 for (;;) {
1335                         if (timeout_ms > 0) {
1336                                 r = waitrfd(ssh_packet_get_connection_in(ssh),
1337                                     &timeout_ms);
1338                                 if (r == -1 && errno == ETIMEDOUT) {
1339                                         send_error(ssh, "Timed out waiting "
1340                                             "for SSH identification string.");
1341                                         error("Connection timed out during "
1342                                             "banner exchange");
1343                                         r = SSH_ERR_CONN_TIMEOUT;
1344                                         goto out;
1345                                 } else if (r == -1) {
1346                                         oerrno = errno;
1347                                         error_f("%s", strerror(errno));
1348                                         r = SSH_ERR_SYSTEM_ERROR;
1349                                         goto out;
1350                                 }
1351                         }
1352
1353                         len = atomicio(read, ssh_packet_get_connection_in(ssh),
1354                             &c, 1);
1355                         if (len != 1 && errno == EPIPE) {
1356                                 error_f("Connection closed by remote host");
1357                                 r = SSH_ERR_CONN_CLOSED;
1358                                 goto out;
1359                         } else if (len != 1) {
1360                                 oerrno = errno;
1361                                 error_f("read: %.100s", strerror(errno));
1362                                 r = SSH_ERR_SYSTEM_ERROR;
1363                                 goto out;
1364                         }
1365                         if (c == '\r') {
1366                                 expect_nl = 1;
1367                                 continue;
1368                         }
1369                         if (c == '\n')
1370                                 break;
1371                         if (c == '\0' || expect_nl) {
1372                                 error_f("banner line contains invalid "
1373                                     "characters");
1374                                 goto invalid;
1375                         }
1376                         if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1377                                 oerrno = errno;
1378                                 error_fr(r, "sshbuf_put");
1379                                 goto out;
1380                         }
1381                         if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1382                                 error_f("banner line too long");
1383                                 goto invalid;
1384                         }
1385                 }
1386                 /* Is this an actual protocol banner? */
1387                 if (sshbuf_len(peer_version) > 4 &&
1388                     memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1389                         break;
1390                 /* If not, then just log the line and continue */
1391                 if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1392                         error_f("sshbuf_dup_string failed");
1393                         r = SSH_ERR_ALLOC_FAIL;
1394                         goto out;
1395                 }
1396                 /* Do not accept lines before the SSH ident from a client */
1397                 if (ssh->kex->server) {
1398                         error_f("client sent invalid protocol identifier "
1399                             "\"%.256s\"", cp);
1400                         free(cp);
1401                         goto invalid;
1402                 }
1403                 debug_f("banner line %zu: %s", n, cp);
1404                 free(cp);
1405         }
1406         peer_version_string = sshbuf_dup_string(peer_version);
1407         if (peer_version_string == NULL)
1408                 fatal_f("sshbuf_dup_string failed");
1409         /* XXX must be same size for sscanf */
1410         if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1411                 error_f("calloc failed");
1412                 r = SSH_ERR_ALLOC_FAIL;
1413                 goto out;
1414         }
1415
1416         /*
1417          * Check that the versions match.  In future this might accept
1418          * several versions and set appropriate flags to handle them.
1419          */
1420         if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1421             &remote_major, &remote_minor, remote_version) != 3) {
1422                 error("Bad remote protocol version identification: '%.100s'",
1423                     peer_version_string);
1424  invalid:
1425                 send_error(ssh, "Invalid SSH identification string.");
1426                 r = SSH_ERR_INVALID_FORMAT;
1427                 goto out;
1428         }
1429         debug("Remote protocol version %d.%d, remote software version %.100s",
1430             remote_major, remote_minor, remote_version);
1431         compat_banner(ssh, remote_version);
1432
1433         mismatch = 0;
1434         switch (remote_major) {
1435         case 2:
1436                 break;
1437         case 1:
1438                 if (remote_minor != 99)
1439                         mismatch = 1;
1440                 break;
1441         default:
1442                 mismatch = 1;
1443                 break;
1444         }
1445         if (mismatch) {
1446                 error("Protocol major versions differ: %d vs. %d",
1447                     PROTOCOL_MAJOR_2, remote_major);
1448                 send_error(ssh, "Protocol major versions differ.");
1449                 r = SSH_ERR_NO_PROTOCOL_VERSION;
1450                 goto out;
1451         }
1452
1453         if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1454                 logit("probed from %s port %d with %s.  Don't panic.",
1455                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1456                     peer_version_string);
1457                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1458                 goto out;
1459         }
1460         if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1461                 logit("scanned from %s port %d with %s.  Don't panic.",
1462                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1463                     peer_version_string);
1464                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1465                 goto out;
1466         }
1467         /* success */
1468         r = 0;
1469  out:
1470         free(our_version_string);
1471         free(peer_version_string);
1472         free(remote_version);
1473         if (r == SSH_ERR_SYSTEM_ERROR)
1474                 errno = oerrno;
1475         return r;
1476 }
1477