]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/kex.c
Merge llvm-project release/17.x llvmorg-17.0.3-0-g888437e1b600
[FreeBSD/FreeBSD.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.181 2023/08/28 03:28:43 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66
67 /* prototype */
68 static int kex_choose_conf(struct ssh *);
69 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72         "KEX algorithms",
73         "host key algorithms",
74         "ciphers ctos",
75         "ciphers stoc",
76         "MACs ctos",
77         "MACs stoc",
78         "compression ctos",
79         "compression stoc",
80         "languages ctos",
81         "languages stoc",
82 };
83
84 struct kexalg {
85         char *name;
86         u_int type;
87         int ec_nid;
88         int hash_alg;
89 };
90 static const struct kexalg kexalgs[] = {
91 #ifdef WITH_OPENSSL
92         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
93         { KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
94         { KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
95         { KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
96         { KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
97         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
98 #ifdef HAVE_EVP_SHA256
99         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
100 #endif /* HAVE_EVP_SHA256 */
101 #ifdef OPENSSL_HAS_ECC
102         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
103             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
104         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
105             SSH_DIGEST_SHA384 },
106 # ifdef OPENSSL_HAS_NISTP521
107         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
108             SSH_DIGEST_SHA512 },
109 # endif /* OPENSSL_HAS_NISTP521 */
110 #endif /* OPENSSL_HAS_ECC */
111 #endif /* WITH_OPENSSL */
112 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
113         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
114         { KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
115 #ifdef USE_SNTRUP761X25519
116         { KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
117             SSH_DIGEST_SHA512 },
118 #endif
119 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
120         { NULL, 0, -1, -1},
121 };
122
123 char *
124 kex_alg_list(char sep)
125 {
126         char *ret = NULL, *tmp;
127         size_t nlen, rlen = 0;
128         const struct kexalg *k;
129
130         for (k = kexalgs; k->name != NULL; k++) {
131                 if (ret != NULL)
132                         ret[rlen++] = sep;
133                 nlen = strlen(k->name);
134                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135                         free(ret);
136                         return NULL;
137                 }
138                 ret = tmp;
139                 memcpy(ret + rlen, k->name, nlen + 1);
140                 rlen += nlen;
141         }
142         return ret;
143 }
144
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148         const struct kexalg *k;
149
150         for (k = kexalgs; k->name != NULL; k++) {
151                 if (strcmp(k->name, name) == 0)
152                         return k;
153         }
154         return NULL;
155 }
156
157 /* Validate KEX method name list */
158 int
159 kex_names_valid(const char *names)
160 {
161         char *s, *cp, *p;
162
163         if (names == NULL || strcmp(names, "") == 0)
164                 return 0;
165         if ((s = cp = strdup(names)) == NULL)
166                 return 0;
167         for ((p = strsep(&cp, ",")); p && *p != '\0';
168             (p = strsep(&cp, ","))) {
169                 if (kex_alg_by_name(p) == NULL) {
170                         error("Unsupported KEX algorithm \"%.100s\"", p);
171                         free(s);
172                         return 0;
173                 }
174         }
175         debug3("kex names ok: [%s]", names);
176         free(s);
177         return 1;
178 }
179
180 /*
181  * Concatenate algorithm names, avoiding duplicates in the process.
182  * Caller must free returned string.
183  */
184 char *
185 kex_names_cat(const char *a, const char *b)
186 {
187         char *ret = NULL, *tmp = NULL, *cp, *p, *m;
188         size_t len;
189
190         if (a == NULL || *a == '\0')
191                 return strdup(b);
192         if (b == NULL || *b == '\0')
193                 return strdup(a);
194         if (strlen(b) > 1024*1024)
195                 return NULL;
196         len = strlen(a) + strlen(b) + 2;
197         if ((tmp = cp = strdup(b)) == NULL ||
198             (ret = calloc(1, len)) == NULL) {
199                 free(tmp);
200                 return NULL;
201         }
202         strlcpy(ret, a, len);
203         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
204                 if ((m = match_list(ret, p, NULL)) != NULL) {
205                         free(m);
206                         continue; /* Algorithm already present */
207                 }
208                 if (strlcat(ret, ",", len) >= len ||
209                     strlcat(ret, p, len) >= len) {
210                         free(tmp);
211                         free(ret);
212                         return NULL; /* Shouldn't happen */
213                 }
214         }
215         free(tmp);
216         return ret;
217 }
218
219 /*
220  * Assemble a list of algorithms from a default list and a string from a
221  * configuration file. The user-provided string may begin with '+' to
222  * indicate that it should be appended to the default, '-' that the
223  * specified names should be removed, or '^' that they should be placed
224  * at the head.
225  */
226 int
227 kex_assemble_names(char **listp, const char *def, const char *all)
228 {
229         char *cp, *tmp, *patterns;
230         char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
231         int r = SSH_ERR_INTERNAL_ERROR;
232
233         if (listp == NULL || def == NULL || all == NULL)
234                 return SSH_ERR_INVALID_ARGUMENT;
235
236         if (*listp == NULL || **listp == '\0') {
237                 if ((*listp = strdup(def)) == NULL)
238                         return SSH_ERR_ALLOC_FAIL;
239                 return 0;
240         }
241
242         list = *listp;
243         *listp = NULL;
244         if (*list == '+') {
245                 /* Append names to default list */
246                 if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
247                         r = SSH_ERR_ALLOC_FAIL;
248                         goto fail;
249                 }
250                 free(list);
251                 list = tmp;
252         } else if (*list == '-') {
253                 /* Remove names from default list */
254                 if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
255                         r = SSH_ERR_ALLOC_FAIL;
256                         goto fail;
257                 }
258                 free(list);
259                 /* filtering has already been done */
260                 return 0;
261         } else if (*list == '^') {
262                 /* Place names at head of default list */
263                 if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
264                         r = SSH_ERR_ALLOC_FAIL;
265                         goto fail;
266                 }
267                 free(list);
268                 list = tmp;
269         } else {
270                 /* Explicit list, overrides default - just use "list" as is */
271         }
272
273         /*
274          * The supplied names may be a pattern-list. For the -list case,
275          * the patterns are applied above. For the +list and explicit list
276          * cases we need to do it now.
277          */
278         ret = NULL;
279         if ((patterns = opatterns = strdup(list)) == NULL) {
280                 r = SSH_ERR_ALLOC_FAIL;
281                 goto fail;
282         }
283         /* Apply positive (i.e. non-negated) patterns from the list */
284         while ((cp = strsep(&patterns, ",")) != NULL) {
285                 if (*cp == '!') {
286                         /* negated matches are not supported here */
287                         r = SSH_ERR_INVALID_ARGUMENT;
288                         goto fail;
289                 }
290                 free(matching);
291                 if ((matching = match_filter_allowlist(all, cp)) == NULL) {
292                         r = SSH_ERR_ALLOC_FAIL;
293                         goto fail;
294                 }
295                 if ((tmp = kex_names_cat(ret, matching)) == NULL) {
296                         r = SSH_ERR_ALLOC_FAIL;
297                         goto fail;
298                 }
299                 free(ret);
300                 ret = tmp;
301         }
302         if (ret == NULL || *ret == '\0') {
303                 /* An empty name-list is an error */
304                 /* XXX better error code? */
305                 r = SSH_ERR_INVALID_ARGUMENT;
306                 goto fail;
307         }
308
309         /* success */
310         *listp = ret;
311         ret = NULL;
312         r = 0;
313
314  fail:
315         free(matching);
316         free(opatterns);
317         free(list);
318         free(ret);
319         return r;
320 }
321
322 /*
323  * Fill out a proposal array with dynamically allocated values, which may
324  * be modified as required for compatibility reasons.
325  * Any of the options may be NULL, in which case the default is used.
326  * Array contents must be freed by calling kex_proposal_free_entries.
327  */
328 void
329 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
330     const char *kexalgos, const char *ciphers, const char *macs,
331     const char *comp, const char *hkalgs)
332 {
333         const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
334         const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
335         const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
336         u_int i;
337
338         if (prop == NULL)
339                 fatal_f("proposal missing");
340
341         for (i = 0; i < PROPOSAL_MAX; i++) {
342                 switch(i) {
343                 case PROPOSAL_KEX_ALGS:
344                         prop[i] = compat_kex_proposal(ssh,
345                             kexalgos ? kexalgos : defprop[i]);
346                         break;
347                 case PROPOSAL_ENC_ALGS_CTOS:
348                 case PROPOSAL_ENC_ALGS_STOC:
349                         prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
350                         break;
351                 case PROPOSAL_MAC_ALGS_CTOS:
352                 case PROPOSAL_MAC_ALGS_STOC:
353                         prop[i]  = xstrdup(macs ? macs : defprop[i]);
354                         break;
355                 case PROPOSAL_COMP_ALGS_CTOS:
356                 case PROPOSAL_COMP_ALGS_STOC:
357                         prop[i] = xstrdup(comp ? comp : defprop[i]);
358                         break;
359                 case PROPOSAL_SERVER_HOST_KEY_ALGS:
360                         prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
361                         break;
362                 default:
363                         prop[i] = xstrdup(defprop[i]);
364                 }
365         }
366 }
367
368 void
369 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
370 {
371         u_int i;
372
373         for (i = 0; i < PROPOSAL_MAX; i++)
374                 free(prop[i]);
375 }
376
377 /* put algorithm proposal into buffer */
378 int
379 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
380 {
381         u_int i;
382         int r;
383
384         sshbuf_reset(b);
385
386         /*
387          * add a dummy cookie, the cookie will be overwritten by
388          * kex_send_kexinit(), each time a kexinit is set
389          */
390         for (i = 0; i < KEX_COOKIE_LEN; i++) {
391                 if ((r = sshbuf_put_u8(b, 0)) != 0)
392                         return r;
393         }
394         for (i = 0; i < PROPOSAL_MAX; i++) {
395                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
396                         return r;
397         }
398         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
399             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
400                 return r;
401         return 0;
402 }
403
404 /* parse buffer and return algorithm proposal */
405 int
406 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
407 {
408         struct sshbuf *b = NULL;
409         u_char v;
410         u_int i;
411         char **proposal = NULL;
412         int r;
413
414         *propp = NULL;
415         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
416                 return SSH_ERR_ALLOC_FAIL;
417         if ((b = sshbuf_fromb(raw)) == NULL) {
418                 r = SSH_ERR_ALLOC_FAIL;
419                 goto out;
420         }
421         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
422                 error_fr(r, "consume cookie");
423                 goto out;
424         }
425         /* extract kex init proposal strings */
426         for (i = 0; i < PROPOSAL_MAX; i++) {
427                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
428                         error_fr(r, "parse proposal %u", i);
429                         goto out;
430                 }
431                 debug2("%s: %s", proposal_names[i], proposal[i]);
432         }
433         /* first kex follows / reserved */
434         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
435             (r = sshbuf_get_u32(b, &i)) != 0) { /* reserved */
436                 error_fr(r, "parse");
437                 goto out;
438         }
439         if (first_kex_follows != NULL)
440                 *first_kex_follows = v;
441         debug2("first_kex_follows %d ", v);
442         debug2("reserved %u ", i);
443         r = 0;
444         *propp = proposal;
445  out:
446         if (r != 0 && proposal != NULL)
447                 kex_prop_free(proposal);
448         sshbuf_free(b);
449         return r;
450 }
451
452 void
453 kex_prop_free(char **proposal)
454 {
455         u_int i;
456
457         if (proposal == NULL)
458                 return;
459         for (i = 0; i < PROPOSAL_MAX; i++)
460                 free(proposal[i]);
461         free(proposal);
462 }
463
464 int
465 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
466 {
467         int r;
468
469         error("kex protocol error: type %d seq %u", type, seq);
470         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
471             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
472             (r = sshpkt_send(ssh)) != 0)
473                 return r;
474         return 0;
475 }
476
477 static void
478 kex_reset_dispatch(struct ssh *ssh)
479 {
480         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
481             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
482 }
483
484 static int
485 kex_send_ext_info(struct ssh *ssh)
486 {
487         int r;
488         char *algs;
489
490         debug("Sending SSH2_MSG_EXT_INFO");
491         if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
492                 return SSH_ERR_ALLOC_FAIL;
493         /* XXX filter algs list by allowed pubkey/hostbased types */
494         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
495             (r = sshpkt_put_u32(ssh, 3)) != 0 ||
496             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
497             (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
498             (r = sshpkt_put_cstring(ssh,
499             "publickey-hostbound@openssh.com")) != 0 ||
500             (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
501             (r = sshpkt_put_cstring(ssh, "ping@openssh.com")) != 0 ||
502             (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
503             (r = sshpkt_send(ssh)) != 0) {
504                 error_fr(r, "compose");
505                 goto out;
506         }
507         /* success */
508         r = 0;
509  out:
510         free(algs);
511         return r;
512 }
513
514 int
515 kex_send_newkeys(struct ssh *ssh)
516 {
517         int r;
518
519         kex_reset_dispatch(ssh);
520         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
521             (r = sshpkt_send(ssh)) != 0)
522                 return r;
523         debug("SSH2_MSG_NEWKEYS sent");
524         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
525         if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
526                 if ((r = kex_send_ext_info(ssh)) != 0)
527                         return r;
528         debug("expecting SSH2_MSG_NEWKEYS");
529         return 0;
530 }
531
532 /* Check whether an ext_info value contains the expected version string */
533 static int
534 kex_ext_info_check_ver(struct kex *kex, const char *name,
535     const u_char *val, size_t len, const char *want_ver, u_int flag)
536 {
537         if (memchr(val, '\0', len) != NULL) {
538                 error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
539                 return SSH_ERR_INVALID_FORMAT;
540         }
541         debug_f("%s=<%s>", name, val);
542         if (strcmp(val, want_ver) == 0)
543                 kex->flags |= flag;
544         else
545                 debug_f("unsupported version of %s extension", name);
546         return 0;
547 }
548
549 int
550 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
551 {
552         struct kex *kex = ssh->kex;
553         u_int32_t i, ninfo;
554         char *name;
555         u_char *val;
556         size_t vlen;
557         int r;
558
559         debug("SSH2_MSG_EXT_INFO received");
560         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
561         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
562                 return r;
563         if (ninfo >= 1024) {
564                 error("SSH2_MSG_EXT_INFO with too many entries, expected "
565                     "<=1024, received %u", ninfo);
566                 return SSH_ERR_INVALID_FORMAT;
567         }
568         for (i = 0; i < ninfo; i++) {
569                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
570                         return r;
571                 if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
572                         free(name);
573                         return r;
574                 }
575                 if (strcmp(name, "server-sig-algs") == 0) {
576                         /* Ensure no \0 lurking in value */
577                         if (memchr(val, '\0', vlen) != NULL) {
578                                 error_f("nul byte in %s", name);
579                                 free(name);
580                                 free(val);
581                                 return SSH_ERR_INVALID_FORMAT;
582                         }
583                         debug_f("%s=<%s>", name, val);
584                         kex->server_sig_algs = val;
585                         val = NULL;
586                 } else if (strcmp(name,
587                     "publickey-hostbound@openssh.com") == 0) {
588                         if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
589                             "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
590                                 free(name);
591                                 free(val);
592                                 return r;
593                         }
594                 } else if (strcmp(name, "ping@openssh.com") == 0) {
595                         if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
596                             "0", KEX_HAS_PING)) != 0) {
597                                 free(name);
598                                 free(val);
599                                 return r;
600                         }
601                 } else
602                         debug_f("%s (unrecognised)", name);
603                 free(name);
604                 free(val);
605         }
606         return sshpkt_get_end(ssh);
607 }
608
609 static int
610 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
611 {
612         struct kex *kex = ssh->kex;
613         int r;
614
615         debug("SSH2_MSG_NEWKEYS received");
616         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
617         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
618         if ((r = sshpkt_get_end(ssh)) != 0)
619                 return r;
620         if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
621                 return r;
622         kex->done = 1;
623         kex->flags &= ~KEX_INITIAL;
624         sshbuf_reset(kex->peer);
625         /* sshbuf_reset(kex->my); */
626         kex->flags &= ~KEX_INIT_SENT;
627         free(kex->name);
628         kex->name = NULL;
629         return 0;
630 }
631
632 int
633 kex_send_kexinit(struct ssh *ssh)
634 {
635         u_char *cookie;
636         struct kex *kex = ssh->kex;
637         int r;
638
639         if (kex == NULL) {
640                 error_f("no kex");
641                 return SSH_ERR_INTERNAL_ERROR;
642         }
643         if (kex->flags & KEX_INIT_SENT)
644                 return 0;
645         kex->done = 0;
646
647         /* generate a random cookie */
648         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
649                 error_f("bad kex length: %zu < %d",
650                     sshbuf_len(kex->my), KEX_COOKIE_LEN);
651                 return SSH_ERR_INVALID_FORMAT;
652         }
653         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
654                 error_f("buffer error");
655                 return SSH_ERR_INTERNAL_ERROR;
656         }
657         arc4random_buf(cookie, KEX_COOKIE_LEN);
658
659         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
660             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
661             (r = sshpkt_send(ssh)) != 0) {
662                 error_fr(r, "compose reply");
663                 return r;
664         }
665         debug("SSH2_MSG_KEXINIT sent");
666         kex->flags |= KEX_INIT_SENT;
667         return 0;
668 }
669
670 int
671 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
672 {
673         struct kex *kex = ssh->kex;
674         const u_char *ptr;
675         u_int i;
676         size_t dlen;
677         int r;
678
679         debug("SSH2_MSG_KEXINIT received");
680         if (kex == NULL) {
681                 error_f("no kex");
682                 return SSH_ERR_INTERNAL_ERROR;
683         }
684         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
685         ptr = sshpkt_ptr(ssh, &dlen);
686         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
687                 return r;
688
689         /* discard packet */
690         for (i = 0; i < KEX_COOKIE_LEN; i++) {
691                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
692                         error_fr(r, "discard cookie");
693                         return r;
694                 }
695         }
696         for (i = 0; i < PROPOSAL_MAX; i++) {
697                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
698                         error_fr(r, "discard proposal");
699                         return r;
700                 }
701         }
702         /*
703          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
704          * KEX method has the server move first, but a server might be using
705          * a custom method or one that we otherwise don't support. We should
706          * be prepared to remember first_kex_follows here so we can eat a
707          * packet later.
708          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
709          * for cases where the server *doesn't* go first. I guess we should
710          * ignore it when it is set for these cases, which is what we do now.
711          */
712         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
713             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
714             (r = sshpkt_get_end(ssh)) != 0)
715                         return r;
716
717         if (!(kex->flags & KEX_INIT_SENT))
718                 if ((r = kex_send_kexinit(ssh)) != 0)
719                         return r;
720         if ((r = kex_choose_conf(ssh)) != 0)
721                 return r;
722
723         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
724                 return (kex->kex[kex->kex_type])(ssh);
725
726         error_f("unknown kex type %u", kex->kex_type);
727         return SSH_ERR_INTERNAL_ERROR;
728 }
729
730 struct kex *
731 kex_new(void)
732 {
733         struct kex *kex;
734
735         if ((kex = calloc(1, sizeof(*kex))) == NULL ||
736             (kex->peer = sshbuf_new()) == NULL ||
737             (kex->my = sshbuf_new()) == NULL ||
738             (kex->client_version = sshbuf_new()) == NULL ||
739             (kex->server_version = sshbuf_new()) == NULL ||
740             (kex->session_id = sshbuf_new()) == NULL) {
741                 kex_free(kex);
742                 return NULL;
743         }
744         return kex;
745 }
746
747 void
748 kex_free_newkeys(struct newkeys *newkeys)
749 {
750         if (newkeys == NULL)
751                 return;
752         if (newkeys->enc.key) {
753                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
754                 free(newkeys->enc.key);
755                 newkeys->enc.key = NULL;
756         }
757         if (newkeys->enc.iv) {
758                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
759                 free(newkeys->enc.iv);
760                 newkeys->enc.iv = NULL;
761         }
762         free(newkeys->enc.name);
763         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
764         free(newkeys->comp.name);
765         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
766         mac_clear(&newkeys->mac);
767         if (newkeys->mac.key) {
768                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
769                 free(newkeys->mac.key);
770                 newkeys->mac.key = NULL;
771         }
772         free(newkeys->mac.name);
773         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
774         freezero(newkeys, sizeof(*newkeys));
775 }
776
777 void
778 kex_free(struct kex *kex)
779 {
780         u_int mode;
781
782         if (kex == NULL)
783                 return;
784
785 #ifdef WITH_OPENSSL
786         DH_free(kex->dh);
787 #ifdef OPENSSL_HAS_ECC
788         EC_KEY_free(kex->ec_client_key);
789 #endif /* OPENSSL_HAS_ECC */
790 #endif /* WITH_OPENSSL */
791         for (mode = 0; mode < MODE_MAX; mode++) {
792                 kex_free_newkeys(kex->newkeys[mode]);
793                 kex->newkeys[mode] = NULL;
794         }
795         sshbuf_free(kex->peer);
796         sshbuf_free(kex->my);
797         sshbuf_free(kex->client_version);
798         sshbuf_free(kex->server_version);
799         sshbuf_free(kex->client_pub);
800         sshbuf_free(kex->session_id);
801         sshbuf_free(kex->initial_sig);
802         sshkey_free(kex->initial_hostkey);
803         free(kex->failed_choice);
804         free(kex->hostkey_alg);
805         free(kex->name);
806         free(kex);
807 }
808
809 int
810 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
811 {
812         int r;
813
814         if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
815                 return r;
816         ssh->kex->flags = KEX_INITIAL;
817         kex_reset_dispatch(ssh);
818         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
819         return 0;
820 }
821
822 int
823 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
824 {
825         int r;
826
827         if ((r = kex_ready(ssh, proposal)) != 0)
828                 return r;
829         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
830                 kex_free(ssh->kex);
831                 ssh->kex = NULL;
832                 return r;
833         }
834         return 0;
835 }
836
837 /*
838  * Request key re-exchange, returns 0 on success or a ssherr.h error
839  * code otherwise. Must not be called if KEX is incomplete or in-progress.
840  */
841 int
842 kex_start_rekex(struct ssh *ssh)
843 {
844         if (ssh->kex == NULL) {
845                 error_f("no kex");
846                 return SSH_ERR_INTERNAL_ERROR;
847         }
848         if (ssh->kex->done == 0) {
849                 error_f("requested twice");
850                 return SSH_ERR_INTERNAL_ERROR;
851         }
852         ssh->kex->done = 0;
853         return kex_send_kexinit(ssh);
854 }
855
856 static int
857 choose_enc(struct sshenc *enc, char *client, char *server)
858 {
859         char *name = match_list(client, server, NULL);
860
861         if (name == NULL)
862                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
863         if ((enc->cipher = cipher_by_name(name)) == NULL) {
864                 error_f("unsupported cipher %s", name);
865                 free(name);
866                 return SSH_ERR_INTERNAL_ERROR;
867         }
868         enc->name = name;
869         enc->enabled = 0;
870         enc->iv = NULL;
871         enc->iv_len = cipher_ivlen(enc->cipher);
872         enc->key = NULL;
873         enc->key_len = cipher_keylen(enc->cipher);
874         enc->block_size = cipher_blocksize(enc->cipher);
875         return 0;
876 }
877
878 static int
879 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
880 {
881         char *name = match_list(client, server, NULL);
882
883         if (name == NULL)
884                 return SSH_ERR_NO_MAC_ALG_MATCH;
885         if (mac_setup(mac, name) < 0) {
886                 error_f("unsupported MAC %s", name);
887                 free(name);
888                 return SSH_ERR_INTERNAL_ERROR;
889         }
890         mac->name = name;
891         mac->key = NULL;
892         mac->enabled = 0;
893         return 0;
894 }
895
896 static int
897 choose_comp(struct sshcomp *comp, char *client, char *server)
898 {
899         char *name = match_list(client, server, NULL);
900
901         if (name == NULL)
902                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
903 #ifdef WITH_ZLIB
904         if (strcmp(name, "zlib@openssh.com") == 0) {
905                 comp->type = COMP_DELAYED;
906         } else if (strcmp(name, "zlib") == 0) {
907                 comp->type = COMP_ZLIB;
908         } else
909 #endif  /* WITH_ZLIB */
910         if (strcmp(name, "none") == 0) {
911                 comp->type = COMP_NONE;
912         } else {
913                 error_f("unsupported compression scheme %s", name);
914                 free(name);
915                 return SSH_ERR_INTERNAL_ERROR;
916         }
917         comp->name = name;
918         return 0;
919 }
920
921 static int
922 choose_kex(struct kex *k, char *client, char *server)
923 {
924         const struct kexalg *kexalg;
925
926         k->name = match_list(client, server, NULL);
927
928         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
929         if (k->name == NULL)
930                 return SSH_ERR_NO_KEX_ALG_MATCH;
931         if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
932                 error_f("unsupported KEX method %s", k->name);
933                 return SSH_ERR_INTERNAL_ERROR;
934         }
935         k->kex_type = kexalg->type;
936         k->hash_alg = kexalg->hash_alg;
937         k->ec_nid = kexalg->ec_nid;
938         return 0;
939 }
940
941 static int
942 choose_hostkeyalg(struct kex *k, char *client, char *server)
943 {
944         free(k->hostkey_alg);
945         k->hostkey_alg = match_list(client, server, NULL);
946
947         debug("kex: host key algorithm: %s",
948             k->hostkey_alg ? k->hostkey_alg : "(no match)");
949         if (k->hostkey_alg == NULL)
950                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
951         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
952         if (k->hostkey_type == KEY_UNSPEC) {
953                 error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
954                 return SSH_ERR_INTERNAL_ERROR;
955         }
956         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
957         return 0;
958 }
959
960 static int
961 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
962 {
963         static int check[] = {
964                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
965         };
966         int *idx;
967         char *p;
968
969         for (idx = &check[0]; *idx != -1; idx++) {
970                 if ((p = strchr(my[*idx], ',')) != NULL)
971                         *p = '\0';
972                 if ((p = strchr(peer[*idx], ',')) != NULL)
973                         *p = '\0';
974                 if (strcmp(my[*idx], peer[*idx]) != 0) {
975                         debug2("proposal mismatch: my %s peer %s",
976                             my[*idx], peer[*idx]);
977                         return (0);
978                 }
979         }
980         debug2("proposals match");
981         return (1);
982 }
983
984 /* returns non-zero if proposal contains any algorithm from algs */
985 static int
986 has_any_alg(const char *proposal, const char *algs)
987 {
988         char *cp;
989
990         if ((cp = match_list(proposal, algs, NULL)) == NULL)
991                 return 0;
992         free(cp);
993         return 1;
994 }
995
996 static int
997 kex_choose_conf(struct ssh *ssh)
998 {
999         struct kex *kex = ssh->kex;
1000         struct newkeys *newkeys;
1001         char **my = NULL, **peer = NULL;
1002         char **cprop, **sprop;
1003         int nenc, nmac, ncomp;
1004         u_int mode, ctos, need, dh_need, authlen;
1005         int r, first_kex_follows;
1006
1007         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1008         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1009                 goto out;
1010         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1011         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1012                 goto out;
1013
1014         if (kex->server) {
1015                 cprop=peer;
1016                 sprop=my;
1017         } else {
1018                 cprop=my;
1019                 sprop=peer;
1020         }
1021
1022         /* Check whether client supports ext_info_c */
1023         if (kex->server && (kex->flags & KEX_INITIAL)) {
1024                 char *ext;
1025
1026                 ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
1027                 kex->ext_info_c = (ext != NULL);
1028                 free(ext);
1029         }
1030
1031         /* Check whether client supports rsa-sha2 algorithms */
1032         if (kex->server && (kex->flags & KEX_INITIAL)) {
1033                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1034                     "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1035                         kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1036                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1037                     "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1038                         kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1039         }
1040
1041         /* Algorithm Negotiation */
1042         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1043             sprop[PROPOSAL_KEX_ALGS])) != 0) {
1044                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1045                 peer[PROPOSAL_KEX_ALGS] = NULL;
1046                 goto out;
1047         }
1048         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1049             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1050                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1051                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1052                 goto out;
1053         }
1054         for (mode = 0; mode < MODE_MAX; mode++) {
1055                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1056                         r = SSH_ERR_ALLOC_FAIL;
1057                         goto out;
1058                 }
1059                 kex->newkeys[mode] = newkeys;
1060                 ctos = (!kex->server && mode == MODE_OUT) ||
1061                     (kex->server && mode == MODE_IN);
1062                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1063                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1064                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1065                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1066                     sprop[nenc])) != 0) {
1067                         kex->failed_choice = peer[nenc];
1068                         peer[nenc] = NULL;
1069                         goto out;
1070                 }
1071                 authlen = cipher_authlen(newkeys->enc.cipher);
1072                 /* ignore mac for authenticated encryption */
1073                 if (authlen == 0 &&
1074                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1075                     sprop[nmac])) != 0) {
1076                         kex->failed_choice = peer[nmac];
1077                         peer[nmac] = NULL;
1078                         goto out;
1079                 }
1080                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1081                     sprop[ncomp])) != 0) {
1082                         kex->failed_choice = peer[ncomp];
1083                         peer[ncomp] = NULL;
1084                         goto out;
1085                 }
1086                 debug("kex: %s cipher: %s MAC: %s compression: %s",
1087                     ctos ? "client->server" : "server->client",
1088                     newkeys->enc.name,
1089                     authlen == 0 ? newkeys->mac.name : "<implicit>",
1090                     newkeys->comp.name);
1091         }
1092         need = dh_need = 0;
1093         for (mode = 0; mode < MODE_MAX; mode++) {
1094                 newkeys = kex->newkeys[mode];
1095                 need = MAXIMUM(need, newkeys->enc.key_len);
1096                 need = MAXIMUM(need, newkeys->enc.block_size);
1097                 need = MAXIMUM(need, newkeys->enc.iv_len);
1098                 need = MAXIMUM(need, newkeys->mac.key_len);
1099                 dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1100                 dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1101                 dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1102                 dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1103         }
1104         /* XXX need runden? */
1105         kex->we_need = need;
1106         kex->dh_need = dh_need;
1107
1108         /* ignore the next message if the proposals do not match */
1109         if (first_kex_follows && !proposals_match(my, peer))
1110                 ssh->dispatch_skip_packets = 1;
1111         r = 0;
1112  out:
1113         kex_prop_free(my);
1114         kex_prop_free(peer);
1115         return r;
1116 }
1117
1118 static int
1119 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1120     const struct sshbuf *shared_secret, u_char **keyp)
1121 {
1122         struct kex *kex = ssh->kex;
1123         struct ssh_digest_ctx *hashctx = NULL;
1124         char c = id;
1125         u_int have;
1126         size_t mdsz;
1127         u_char *digest;
1128         int r;
1129
1130         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1131                 return SSH_ERR_INVALID_ARGUMENT;
1132         if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1133                 r = SSH_ERR_ALLOC_FAIL;
1134                 goto out;
1135         }
1136
1137         /* K1 = HASH(K || H || "A" || session_id) */
1138         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1139             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1140             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1141             ssh_digest_update(hashctx, &c, 1) != 0 ||
1142             ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1143             ssh_digest_final(hashctx, digest, mdsz) != 0) {
1144                 r = SSH_ERR_LIBCRYPTO_ERROR;
1145                 error_f("KEX hash failed");
1146                 goto out;
1147         }
1148         ssh_digest_free(hashctx);
1149         hashctx = NULL;
1150
1151         /*
1152          * expand key:
1153          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1154          * Key = K1 || K2 || ... || Kn
1155          */
1156         for (have = mdsz; need > have; have += mdsz) {
1157                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1158                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1159                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1160                     ssh_digest_update(hashctx, digest, have) != 0 ||
1161                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1162                         error_f("KDF failed");
1163                         r = SSH_ERR_LIBCRYPTO_ERROR;
1164                         goto out;
1165                 }
1166                 ssh_digest_free(hashctx);
1167                 hashctx = NULL;
1168         }
1169 #ifdef DEBUG_KEX
1170         fprintf(stderr, "key '%c'== ", c);
1171         dump_digest("key", digest, need);
1172 #endif
1173         *keyp = digest;
1174         digest = NULL;
1175         r = 0;
1176  out:
1177         free(digest);
1178         ssh_digest_free(hashctx);
1179         return r;
1180 }
1181
1182 #define NKEYS   6
1183 int
1184 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1185     const struct sshbuf *shared_secret)
1186 {
1187         struct kex *kex = ssh->kex;
1188         u_char *keys[NKEYS];
1189         u_int i, j, mode, ctos;
1190         int r;
1191
1192         /* save initial hash as session id */
1193         if ((kex->flags & KEX_INITIAL) != 0) {
1194                 if (sshbuf_len(kex->session_id) != 0) {
1195                         error_f("already have session ID at kex");
1196                         return SSH_ERR_INTERNAL_ERROR;
1197                 }
1198                 if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1199                         return r;
1200         } else if (sshbuf_len(kex->session_id) == 0) {
1201                 error_f("no session ID in rekex");
1202                 return SSH_ERR_INTERNAL_ERROR;
1203         }
1204         for (i = 0; i < NKEYS; i++) {
1205                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1206                     shared_secret, &keys[i])) != 0) {
1207                         for (j = 0; j < i; j++)
1208                                 free(keys[j]);
1209                         return r;
1210                 }
1211         }
1212         for (mode = 0; mode < MODE_MAX; mode++) {
1213                 ctos = (!kex->server && mode == MODE_OUT) ||
1214                     (kex->server && mode == MODE_IN);
1215                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1216                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1217                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1218         }
1219         return 0;
1220 }
1221
1222 int
1223 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1224 {
1225         struct kex *kex = ssh->kex;
1226
1227         *pubp = NULL;
1228         *prvp = NULL;
1229         if (kex->load_host_public_key == NULL ||
1230             kex->load_host_private_key == NULL) {
1231                 error_f("missing hostkey loader");
1232                 return SSH_ERR_INVALID_ARGUMENT;
1233         }
1234         *pubp = kex->load_host_public_key(kex->hostkey_type,
1235             kex->hostkey_nid, ssh);
1236         *prvp = kex->load_host_private_key(kex->hostkey_type,
1237             kex->hostkey_nid, ssh);
1238         if (*pubp == NULL)
1239                 return SSH_ERR_NO_HOSTKEY_LOADED;
1240         return 0;
1241 }
1242
1243 int
1244 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1245 {
1246         struct kex *kex = ssh->kex;
1247
1248         if (kex->verify_host_key == NULL) {
1249                 error_f("missing hostkey verifier");
1250                 return SSH_ERR_INVALID_ARGUMENT;
1251         }
1252         if (server_host_key->type != kex->hostkey_type ||
1253             (kex->hostkey_type == KEY_ECDSA &&
1254             server_host_key->ecdsa_nid != kex->hostkey_nid))
1255                 return SSH_ERR_KEY_TYPE_MISMATCH;
1256         if (kex->verify_host_key(server_host_key, ssh) == -1)
1257                 return  SSH_ERR_SIGNATURE_INVALID;
1258         return 0;
1259 }
1260
1261 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1262 void
1263 dump_digest(const char *msg, const u_char *digest, int len)
1264 {
1265         fprintf(stderr, "%s\n", msg);
1266         sshbuf_dump_data(digest, len, stderr);
1267 }
1268 #endif
1269
1270 /*
1271  * Send a plaintext error message to the peer, suffixed by \r\n.
1272  * Only used during banner exchange, and there only for the server.
1273  */
1274 static void
1275 send_error(struct ssh *ssh, char *msg)
1276 {
1277         char *crnl = "\r\n";
1278
1279         if (!ssh->kex->server)
1280                 return;
1281
1282         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1283             msg, strlen(msg)) != strlen(msg) ||
1284             atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1285             crnl, strlen(crnl)) != strlen(crnl))
1286                 error_f("write: %.100s", strerror(errno));
1287 }
1288
1289 /*
1290  * Sends our identification string and waits for the peer's. Will block for
1291  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1292  * Returns on 0 success or a ssherr.h code on failure.
1293  */
1294 int
1295 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1296     const char *version_addendum)
1297 {
1298         int remote_major, remote_minor, mismatch, oerrno = 0;
1299         size_t len, n;
1300         int r, expect_nl;
1301         u_char c;
1302         struct sshbuf *our_version = ssh->kex->server ?
1303             ssh->kex->server_version : ssh->kex->client_version;
1304         struct sshbuf *peer_version = ssh->kex->server ?
1305             ssh->kex->client_version : ssh->kex->server_version;
1306         char *our_version_string = NULL, *peer_version_string = NULL;
1307         char *cp, *remote_version = NULL;
1308
1309         /* Prepare and send our banner */
1310         sshbuf_reset(our_version);
1311         if (version_addendum != NULL && *version_addendum == '\0')
1312                 version_addendum = NULL;
1313         if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1314             PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1315             version_addendum == NULL ? "" : " ",
1316             version_addendum == NULL ? "" : version_addendum)) != 0) {
1317                 oerrno = errno;
1318                 error_fr(r, "sshbuf_putf");
1319                 goto out;
1320         }
1321
1322         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1323             sshbuf_mutable_ptr(our_version),
1324             sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1325                 oerrno = errno;
1326                 debug_f("write: %.100s", strerror(errno));
1327                 r = SSH_ERR_SYSTEM_ERROR;
1328                 goto out;
1329         }
1330         if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1331                 oerrno = errno;
1332                 error_fr(r, "sshbuf_consume_end");
1333                 goto out;
1334         }
1335         our_version_string = sshbuf_dup_string(our_version);
1336         if (our_version_string == NULL) {
1337                 error_f("sshbuf_dup_string failed");
1338                 r = SSH_ERR_ALLOC_FAIL;
1339                 goto out;
1340         }
1341         debug("Local version string %.100s", our_version_string);
1342
1343         /* Read other side's version identification. */
1344         for (n = 0; ; n++) {
1345                 if (n >= SSH_MAX_PRE_BANNER_LINES) {
1346                         send_error(ssh, "No SSH identification string "
1347                             "received.");
1348                         error_f("No SSH version received in first %u lines "
1349                             "from server", SSH_MAX_PRE_BANNER_LINES);
1350                         r = SSH_ERR_INVALID_FORMAT;
1351                         goto out;
1352                 }
1353                 sshbuf_reset(peer_version);
1354                 expect_nl = 0;
1355                 for (;;) {
1356                         if (timeout_ms > 0) {
1357                                 r = waitrfd(ssh_packet_get_connection_in(ssh),
1358                                     &timeout_ms, NULL);
1359                                 if (r == -1 && errno == ETIMEDOUT) {
1360                                         send_error(ssh, "Timed out waiting "
1361                                             "for SSH identification string.");
1362                                         error("Connection timed out during "
1363                                             "banner exchange");
1364                                         r = SSH_ERR_CONN_TIMEOUT;
1365                                         goto out;
1366                                 } else if (r == -1) {
1367                                         oerrno = errno;
1368                                         error_f("%s", strerror(errno));
1369                                         r = SSH_ERR_SYSTEM_ERROR;
1370                                         goto out;
1371                                 }
1372                         }
1373
1374                         len = atomicio(read, ssh_packet_get_connection_in(ssh),
1375                             &c, 1);
1376                         if (len != 1 && errno == EPIPE) {
1377                                 verbose_f("Connection closed by remote host");
1378                                 r = SSH_ERR_CONN_CLOSED;
1379                                 goto out;
1380                         } else if (len != 1) {
1381                                 oerrno = errno;
1382                                 error_f("read: %.100s", strerror(errno));
1383                                 r = SSH_ERR_SYSTEM_ERROR;
1384                                 goto out;
1385                         }
1386                         if (c == '\r') {
1387                                 expect_nl = 1;
1388                                 continue;
1389                         }
1390                         if (c == '\n')
1391                                 break;
1392                         if (c == '\0' || expect_nl) {
1393                                 verbose_f("banner line contains invalid "
1394                                     "characters");
1395                                 goto invalid;
1396                         }
1397                         if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1398                                 oerrno = errno;
1399                                 error_fr(r, "sshbuf_put");
1400                                 goto out;
1401                         }
1402                         if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1403                                 verbose_f("banner line too long");
1404                                 goto invalid;
1405                         }
1406                 }
1407                 /* Is this an actual protocol banner? */
1408                 if (sshbuf_len(peer_version) > 4 &&
1409                     memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1410                         break;
1411                 /* If not, then just log the line and continue */
1412                 if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1413                         error_f("sshbuf_dup_string failed");
1414                         r = SSH_ERR_ALLOC_FAIL;
1415                         goto out;
1416                 }
1417                 /* Do not accept lines before the SSH ident from a client */
1418                 if (ssh->kex->server) {
1419                         verbose_f("client sent invalid protocol identifier "
1420                             "\"%.256s\"", cp);
1421                         free(cp);
1422                         goto invalid;
1423                 }
1424                 debug_f("banner line %zu: %s", n, cp);
1425                 free(cp);
1426         }
1427         peer_version_string = sshbuf_dup_string(peer_version);
1428         if (peer_version_string == NULL)
1429                 fatal_f("sshbuf_dup_string failed");
1430         /* XXX must be same size for sscanf */
1431         if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1432                 error_f("calloc failed");
1433                 r = SSH_ERR_ALLOC_FAIL;
1434                 goto out;
1435         }
1436
1437         /*
1438          * Check that the versions match.  In future this might accept
1439          * several versions and set appropriate flags to handle them.
1440          */
1441         if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1442             &remote_major, &remote_minor, remote_version) != 3) {
1443                 error("Bad remote protocol version identification: '%.100s'",
1444                     peer_version_string);
1445  invalid:
1446                 send_error(ssh, "Invalid SSH identification string.");
1447                 r = SSH_ERR_INVALID_FORMAT;
1448                 goto out;
1449         }
1450         debug("Remote protocol version %d.%d, remote software version %.100s",
1451             remote_major, remote_minor, remote_version);
1452         compat_banner(ssh, remote_version);
1453
1454         mismatch = 0;
1455         switch (remote_major) {
1456         case 2:
1457                 break;
1458         case 1:
1459                 if (remote_minor != 99)
1460                         mismatch = 1;
1461                 break;
1462         default:
1463                 mismatch = 1;
1464                 break;
1465         }
1466         if (mismatch) {
1467                 error("Protocol major versions differ: %d vs. %d",
1468                     PROTOCOL_MAJOR_2, remote_major);
1469                 send_error(ssh, "Protocol major versions differ.");
1470                 r = SSH_ERR_NO_PROTOCOL_VERSION;
1471                 goto out;
1472         }
1473
1474         if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1475                 logit("probed from %s port %d with %s.  Don't panic.",
1476                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1477                     peer_version_string);
1478                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1479                 goto out;
1480         }
1481         if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1482                 logit("scanned from %s port %d with %s.  Don't panic.",
1483                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1484                     peer_version_string);
1485                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1486                 goto out;
1487         }
1488         /* success */
1489         r = 0;
1490  out:
1491         free(our_version_string);
1492         free(peer_version_string);
1493         free(remote_version);
1494         if (r == SSH_ERR_SYSTEM_ERROR)
1495                 errno = oerrno;
1496         return r;
1497 }
1498