]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/kex.c
ena: Upgrade ena-com to freebsd v2.7.0
[FreeBSD/FreeBSD.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.181 2023/08/28 03:28:43 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66
67 /* prototype */
68 static int kex_choose_conf(struct ssh *, uint32_t seq);
69 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72         "KEX algorithms",
73         "host key algorithms",
74         "ciphers ctos",
75         "ciphers stoc",
76         "MACs ctos",
77         "MACs stoc",
78         "compression ctos",
79         "compression stoc",
80         "languages ctos",
81         "languages stoc",
82 };
83
84 struct kexalg {
85         char *name;
86         u_int type;
87         int ec_nid;
88         int hash_alg;
89 };
90 static const struct kexalg kexalgs[] = {
91 #ifdef WITH_OPENSSL
92         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
93         { KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
94         { KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
95         { KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
96         { KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
97         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
98 #ifdef HAVE_EVP_SHA256
99         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
100 #endif /* HAVE_EVP_SHA256 */
101 #ifdef OPENSSL_HAS_ECC
102         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
103             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
104         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
105             SSH_DIGEST_SHA384 },
106 # ifdef OPENSSL_HAS_NISTP521
107         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
108             SSH_DIGEST_SHA512 },
109 # endif /* OPENSSL_HAS_NISTP521 */
110 #endif /* OPENSSL_HAS_ECC */
111 #endif /* WITH_OPENSSL */
112 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
113         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
114         { KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
115 #ifdef USE_SNTRUP761X25519
116         { KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
117             SSH_DIGEST_SHA512 },
118 #endif
119 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
120         { NULL, 0, -1, -1},
121 };
122
123 char *
124 kex_alg_list(char sep)
125 {
126         char *ret = NULL, *tmp;
127         size_t nlen, rlen = 0;
128         const struct kexalg *k;
129
130         for (k = kexalgs; k->name != NULL; k++) {
131                 if (ret != NULL)
132                         ret[rlen++] = sep;
133                 nlen = strlen(k->name);
134                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135                         free(ret);
136                         return NULL;
137                 }
138                 ret = tmp;
139                 memcpy(ret + rlen, k->name, nlen + 1);
140                 rlen += nlen;
141         }
142         return ret;
143 }
144
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148         const struct kexalg *k;
149
150         for (k = kexalgs; k->name != NULL; k++) {
151                 if (strcmp(k->name, name) == 0)
152                         return k;
153         }
154         return NULL;
155 }
156
157 /* Validate KEX method name list */
158 int
159 kex_names_valid(const char *names)
160 {
161         char *s, *cp, *p;
162
163         if (names == NULL || strcmp(names, "") == 0)
164                 return 0;
165         if ((s = cp = strdup(names)) == NULL)
166                 return 0;
167         for ((p = strsep(&cp, ",")); p && *p != '\0';
168             (p = strsep(&cp, ","))) {
169                 if (kex_alg_by_name(p) == NULL) {
170                         error("Unsupported KEX algorithm \"%.100s\"", p);
171                         free(s);
172                         return 0;
173                 }
174         }
175         debug3("kex names ok: [%s]", names);
176         free(s);
177         return 1;
178 }
179
180 /* returns non-zero if proposal contains any algorithm from algs */
181 static int
182 has_any_alg(const char *proposal, const char *algs)
183 {
184         char *cp;
185
186         if ((cp = match_list(proposal, algs, NULL)) == NULL)
187                 return 0;
188         free(cp);
189         return 1;
190 }
191
192 /*
193  * Concatenate algorithm names, avoiding duplicates in the process.
194  * Caller must free returned string.
195  */
196 char *
197 kex_names_cat(const char *a, const char *b)
198 {
199         char *ret = NULL, *tmp = NULL, *cp, *p;
200         size_t len;
201
202         if (a == NULL || *a == '\0')
203                 return strdup(b);
204         if (b == NULL || *b == '\0')
205                 return strdup(a);
206         if (strlen(b) > 1024*1024)
207                 return NULL;
208         len = strlen(a) + strlen(b) + 2;
209         if ((tmp = cp = strdup(b)) == NULL ||
210             (ret = calloc(1, len)) == NULL) {
211                 free(tmp);
212                 return NULL;
213         }
214         strlcpy(ret, a, len);
215         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
216                 if (has_any_alg(ret, p))
217                         continue; /* Algorithm already present */
218                 if (strlcat(ret, ",", len) >= len ||
219                     strlcat(ret, p, len) >= len) {
220                         free(tmp);
221                         free(ret);
222                         return NULL; /* Shouldn't happen */
223                 }
224         }
225         free(tmp);
226         return ret;
227 }
228
229 /*
230  * Assemble a list of algorithms from a default list and a string from a
231  * configuration file. The user-provided string may begin with '+' to
232  * indicate that it should be appended to the default, '-' that the
233  * specified names should be removed, or '^' that they should be placed
234  * at the head.
235  */
236 int
237 kex_assemble_names(char **listp, const char *def, const char *all)
238 {
239         char *cp, *tmp, *patterns;
240         char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
241         int r = SSH_ERR_INTERNAL_ERROR;
242
243         if (listp == NULL || def == NULL || all == NULL)
244                 return SSH_ERR_INVALID_ARGUMENT;
245
246         if (*listp == NULL || **listp == '\0') {
247                 if ((*listp = strdup(def)) == NULL)
248                         return SSH_ERR_ALLOC_FAIL;
249                 return 0;
250         }
251
252         list = *listp;
253         *listp = NULL;
254         if (*list == '+') {
255                 /* Append names to default list */
256                 if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
257                         r = SSH_ERR_ALLOC_FAIL;
258                         goto fail;
259                 }
260                 free(list);
261                 list = tmp;
262         } else if (*list == '-') {
263                 /* Remove names from default list */
264                 if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
265                         r = SSH_ERR_ALLOC_FAIL;
266                         goto fail;
267                 }
268                 free(list);
269                 /* filtering has already been done */
270                 return 0;
271         } else if (*list == '^') {
272                 /* Place names at head of default list */
273                 if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
274                         r = SSH_ERR_ALLOC_FAIL;
275                         goto fail;
276                 }
277                 free(list);
278                 list = tmp;
279         } else {
280                 /* Explicit list, overrides default - just use "list" as is */
281         }
282
283         /*
284          * The supplied names may be a pattern-list. For the -list case,
285          * the patterns are applied above. For the +list and explicit list
286          * cases we need to do it now.
287          */
288         ret = NULL;
289         if ((patterns = opatterns = strdup(list)) == NULL) {
290                 r = SSH_ERR_ALLOC_FAIL;
291                 goto fail;
292         }
293         /* Apply positive (i.e. non-negated) patterns from the list */
294         while ((cp = strsep(&patterns, ",")) != NULL) {
295                 if (*cp == '!') {
296                         /* negated matches are not supported here */
297                         r = SSH_ERR_INVALID_ARGUMENT;
298                         goto fail;
299                 }
300                 free(matching);
301                 if ((matching = match_filter_allowlist(all, cp)) == NULL) {
302                         r = SSH_ERR_ALLOC_FAIL;
303                         goto fail;
304                 }
305                 if ((tmp = kex_names_cat(ret, matching)) == NULL) {
306                         r = SSH_ERR_ALLOC_FAIL;
307                         goto fail;
308                 }
309                 free(ret);
310                 ret = tmp;
311         }
312         if (ret == NULL || *ret == '\0') {
313                 /* An empty name-list is an error */
314                 /* XXX better error code? */
315                 r = SSH_ERR_INVALID_ARGUMENT;
316                 goto fail;
317         }
318
319         /* success */
320         *listp = ret;
321         ret = NULL;
322         r = 0;
323
324  fail:
325         free(matching);
326         free(opatterns);
327         free(list);
328         free(ret);
329         return r;
330 }
331
332 /*
333  * Fill out a proposal array with dynamically allocated values, which may
334  * be modified as required for compatibility reasons.
335  * Any of the options may be NULL, in which case the default is used.
336  * Array contents must be freed by calling kex_proposal_free_entries.
337  */
338 void
339 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
340     const char *kexalgos, const char *ciphers, const char *macs,
341     const char *comp, const char *hkalgs)
342 {
343         const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
344         const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
345         const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
346         u_int i;
347         char *cp;
348
349         if (prop == NULL)
350                 fatal_f("proposal missing");
351
352         /* Append EXT_INFO signalling to KexAlgorithms */
353         if (kexalgos == NULL)
354                 kexalgos = defprop[PROPOSAL_KEX_ALGS];
355         if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
356             "kex-strict-s-v00@openssh.com" :
357             "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
358                 fatal_f("kex_names_cat");
359
360         for (i = 0; i < PROPOSAL_MAX; i++) {
361                 switch(i) {
362                 case PROPOSAL_KEX_ALGS:
363                         prop[i] = compat_kex_proposal(ssh, cp);
364                         break;
365                 case PROPOSAL_ENC_ALGS_CTOS:
366                 case PROPOSAL_ENC_ALGS_STOC:
367                         prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
368                         break;
369                 case PROPOSAL_MAC_ALGS_CTOS:
370                 case PROPOSAL_MAC_ALGS_STOC:
371                         prop[i]  = xstrdup(macs ? macs : defprop[i]);
372                         break;
373                 case PROPOSAL_COMP_ALGS_CTOS:
374                 case PROPOSAL_COMP_ALGS_STOC:
375                         prop[i] = xstrdup(comp ? comp : defprop[i]);
376                         break;
377                 case PROPOSAL_SERVER_HOST_KEY_ALGS:
378                         prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
379                         break;
380                 default:
381                         prop[i] = xstrdup(defprop[i]);
382                 }
383         }
384         free(cp);
385 }
386
387 void
388 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
389 {
390         u_int i;
391
392         for (i = 0; i < PROPOSAL_MAX; i++)
393                 free(prop[i]);
394 }
395
396 /* put algorithm proposal into buffer */
397 int
398 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
399 {
400         u_int i;
401         int r;
402
403         sshbuf_reset(b);
404
405         /*
406          * add a dummy cookie, the cookie will be overwritten by
407          * kex_send_kexinit(), each time a kexinit is set
408          */
409         for (i = 0; i < KEX_COOKIE_LEN; i++) {
410                 if ((r = sshbuf_put_u8(b, 0)) != 0)
411                         return r;
412         }
413         for (i = 0; i < PROPOSAL_MAX; i++) {
414                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
415                         return r;
416         }
417         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
418             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
419                 return r;
420         return 0;
421 }
422
423 /* parse buffer and return algorithm proposal */
424 int
425 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
426 {
427         struct sshbuf *b = NULL;
428         u_char v;
429         u_int i;
430         char **proposal = NULL;
431         int r;
432
433         *propp = NULL;
434         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
435                 return SSH_ERR_ALLOC_FAIL;
436         if ((b = sshbuf_fromb(raw)) == NULL) {
437                 r = SSH_ERR_ALLOC_FAIL;
438                 goto out;
439         }
440         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
441                 error_fr(r, "consume cookie");
442                 goto out;
443         }
444         /* extract kex init proposal strings */
445         for (i = 0; i < PROPOSAL_MAX; i++) {
446                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
447                         error_fr(r, "parse proposal %u", i);
448                         goto out;
449                 }
450                 debug2("%s: %s", proposal_names[i], proposal[i]);
451         }
452         /* first kex follows / reserved */
453         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
454             (r = sshbuf_get_u32(b, &i)) != 0) { /* reserved */
455                 error_fr(r, "parse");
456                 goto out;
457         }
458         if (first_kex_follows != NULL)
459                 *first_kex_follows = v;
460         debug2("first_kex_follows %d ", v);
461         debug2("reserved %u ", i);
462         r = 0;
463         *propp = proposal;
464  out:
465         if (r != 0 && proposal != NULL)
466                 kex_prop_free(proposal);
467         sshbuf_free(b);
468         return r;
469 }
470
471 void
472 kex_prop_free(char **proposal)
473 {
474         u_int i;
475
476         if (proposal == NULL)
477                 return;
478         for (i = 0; i < PROPOSAL_MAX; i++)
479                 free(proposal[i]);
480         free(proposal);
481 }
482
483 int
484 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
485 {
486         int r;
487
488         /* If in strict mode, any unexpected message is an error */
489         if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
490                 ssh_packet_disconnect(ssh, "strict KEX violation: "
491                     "unexpected packet type %u (seqnr %u)", type, seq);
492         }
493         error_f("type %u seq %u", type, seq);
494         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
495             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
496             (r = sshpkt_send(ssh)) != 0)
497                 return r;
498         return 0;
499 }
500
501 static void
502 kex_reset_dispatch(struct ssh *ssh)
503 {
504         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
505             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
506 }
507
508 static int
509 kex_send_ext_info(struct ssh *ssh)
510 {
511         int r;
512         char *algs;
513
514         debug("Sending SSH2_MSG_EXT_INFO");
515         if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
516                 return SSH_ERR_ALLOC_FAIL;
517         /* XXX filter algs list by allowed pubkey/hostbased types */
518         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
519             (r = sshpkt_put_u32(ssh, 3)) != 0 ||
520             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
521             (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
522             (r = sshpkt_put_cstring(ssh,
523             "publickey-hostbound@openssh.com")) != 0 ||
524             (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
525             (r = sshpkt_put_cstring(ssh, "ping@openssh.com")) != 0 ||
526             (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
527             (r = sshpkt_send(ssh)) != 0) {
528                 error_fr(r, "compose");
529                 goto out;
530         }
531         /* success */
532         r = 0;
533  out:
534         free(algs);
535         return r;
536 }
537
538 int
539 kex_send_newkeys(struct ssh *ssh)
540 {
541         int r;
542
543         kex_reset_dispatch(ssh);
544         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
545             (r = sshpkt_send(ssh)) != 0)
546                 return r;
547         debug("SSH2_MSG_NEWKEYS sent");
548         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
549         if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
550                 if ((r = kex_send_ext_info(ssh)) != 0)
551                         return r;
552         debug("expecting SSH2_MSG_NEWKEYS");
553         return 0;
554 }
555
556 /* Check whether an ext_info value contains the expected version string */
557 static int
558 kex_ext_info_check_ver(struct kex *kex, const char *name,
559     const u_char *val, size_t len, const char *want_ver, u_int flag)
560 {
561         if (memchr(val, '\0', len) != NULL) {
562                 error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
563                 return SSH_ERR_INVALID_FORMAT;
564         }
565         debug_f("%s=<%s>", name, val);
566         if (strcmp(val, want_ver) == 0)
567                 kex->flags |= flag;
568         else
569                 debug_f("unsupported version of %s extension", name);
570         return 0;
571 }
572
573 int
574 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
575 {
576         struct kex *kex = ssh->kex;
577         u_int32_t i, ninfo;
578         char *name;
579         u_char *val;
580         size_t vlen;
581         int r;
582
583         debug("SSH2_MSG_EXT_INFO received");
584         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
585         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
586                 return r;
587         if (ninfo >= 1024) {
588                 error("SSH2_MSG_EXT_INFO with too many entries, expected "
589                     "<=1024, received %u", ninfo);
590                 return dispatch_protocol_error(type, seq, ssh);
591         }
592         for (i = 0; i < ninfo; i++) {
593                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
594                         return r;
595                 if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
596                         free(name);
597                         return r;
598                 }
599                 if (strcmp(name, "server-sig-algs") == 0) {
600                         /* Ensure no \0 lurking in value */
601                         if (memchr(val, '\0', vlen) != NULL) {
602                                 error_f("nul byte in %s", name);
603                                 free(name);
604                                 free(val);
605                                 return SSH_ERR_INVALID_FORMAT;
606                         }
607                         debug_f("%s=<%s>", name, val);
608                         kex->server_sig_algs = val;
609                         val = NULL;
610                 } else if (strcmp(name,
611                     "publickey-hostbound@openssh.com") == 0) {
612                         if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
613                             "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
614                                 free(name);
615                                 free(val);
616                                 return r;
617                         }
618                 } else if (strcmp(name, "ping@openssh.com") == 0) {
619                         if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
620                             "0", KEX_HAS_PING)) != 0) {
621                                 free(name);
622                                 free(val);
623                                 return r;
624                         }
625                 } else
626                         debug_f("%s (unrecognised)", name);
627                 free(name);
628                 free(val);
629         }
630         return sshpkt_get_end(ssh);
631 }
632
633 static int
634 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
635 {
636         struct kex *kex = ssh->kex;
637         int r;
638
639         debug("SSH2_MSG_NEWKEYS received");
640         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
641         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
642         if ((r = sshpkt_get_end(ssh)) != 0)
643                 return r;
644         if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
645                 return r;
646         kex->done = 1;
647         kex->flags &= ~KEX_INITIAL;
648         sshbuf_reset(kex->peer);
649         /* sshbuf_reset(kex->my); */
650         kex->flags &= ~KEX_INIT_SENT;
651         free(kex->name);
652         kex->name = NULL;
653         return 0;
654 }
655
656 int
657 kex_send_kexinit(struct ssh *ssh)
658 {
659         u_char *cookie;
660         struct kex *kex = ssh->kex;
661         int r;
662
663         if (kex == NULL) {
664                 error_f("no kex");
665                 return SSH_ERR_INTERNAL_ERROR;
666         }
667         if (kex->flags & KEX_INIT_SENT)
668                 return 0;
669         kex->done = 0;
670
671         /* generate a random cookie */
672         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
673                 error_f("bad kex length: %zu < %d",
674                     sshbuf_len(kex->my), KEX_COOKIE_LEN);
675                 return SSH_ERR_INVALID_FORMAT;
676         }
677         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
678                 error_f("buffer error");
679                 return SSH_ERR_INTERNAL_ERROR;
680         }
681         arc4random_buf(cookie, KEX_COOKIE_LEN);
682
683         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
684             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
685             (r = sshpkt_send(ssh)) != 0) {
686                 error_fr(r, "compose reply");
687                 return r;
688         }
689         debug("SSH2_MSG_KEXINIT sent");
690         kex->flags |= KEX_INIT_SENT;
691         return 0;
692 }
693
694 int
695 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
696 {
697         struct kex *kex = ssh->kex;
698         const u_char *ptr;
699         u_int i;
700         size_t dlen;
701         int r;
702
703         debug("SSH2_MSG_KEXINIT received");
704         if (kex == NULL) {
705                 error_f("no kex");
706                 return SSH_ERR_INTERNAL_ERROR;
707         }
708         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
709         ptr = sshpkt_ptr(ssh, &dlen);
710         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
711                 return r;
712
713         /* discard packet */
714         for (i = 0; i < KEX_COOKIE_LEN; i++) {
715                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
716                         error_fr(r, "discard cookie");
717                         return r;
718                 }
719         }
720         for (i = 0; i < PROPOSAL_MAX; i++) {
721                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
722                         error_fr(r, "discard proposal");
723                         return r;
724                 }
725         }
726         /*
727          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
728          * KEX method has the server move first, but a server might be using
729          * a custom method or one that we otherwise don't support. We should
730          * be prepared to remember first_kex_follows here so we can eat a
731          * packet later.
732          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
733          * for cases where the server *doesn't* go first. I guess we should
734          * ignore it when it is set for these cases, which is what we do now.
735          */
736         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
737             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
738             (r = sshpkt_get_end(ssh)) != 0)
739                         return r;
740
741         if (!(kex->flags & KEX_INIT_SENT))
742                 if ((r = kex_send_kexinit(ssh)) != 0)
743                         return r;
744         if ((r = kex_choose_conf(ssh, seq)) != 0)
745                 return r;
746
747         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
748                 return (kex->kex[kex->kex_type])(ssh);
749
750         error_f("unknown kex type %u", kex->kex_type);
751         return SSH_ERR_INTERNAL_ERROR;
752 }
753
754 struct kex *
755 kex_new(void)
756 {
757         struct kex *kex;
758
759         if ((kex = calloc(1, sizeof(*kex))) == NULL ||
760             (kex->peer = sshbuf_new()) == NULL ||
761             (kex->my = sshbuf_new()) == NULL ||
762             (kex->client_version = sshbuf_new()) == NULL ||
763             (kex->server_version = sshbuf_new()) == NULL ||
764             (kex->session_id = sshbuf_new()) == NULL) {
765                 kex_free(kex);
766                 return NULL;
767         }
768         return kex;
769 }
770
771 void
772 kex_free_newkeys(struct newkeys *newkeys)
773 {
774         if (newkeys == NULL)
775                 return;
776         if (newkeys->enc.key) {
777                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
778                 free(newkeys->enc.key);
779                 newkeys->enc.key = NULL;
780         }
781         if (newkeys->enc.iv) {
782                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
783                 free(newkeys->enc.iv);
784                 newkeys->enc.iv = NULL;
785         }
786         free(newkeys->enc.name);
787         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
788         free(newkeys->comp.name);
789         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
790         mac_clear(&newkeys->mac);
791         if (newkeys->mac.key) {
792                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
793                 free(newkeys->mac.key);
794                 newkeys->mac.key = NULL;
795         }
796         free(newkeys->mac.name);
797         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
798         freezero(newkeys, sizeof(*newkeys));
799 }
800
801 void
802 kex_free(struct kex *kex)
803 {
804         u_int mode;
805
806         if (kex == NULL)
807                 return;
808
809 #ifdef WITH_OPENSSL
810         DH_free(kex->dh);
811 #ifdef OPENSSL_HAS_ECC
812         EC_KEY_free(kex->ec_client_key);
813 #endif /* OPENSSL_HAS_ECC */
814 #endif /* WITH_OPENSSL */
815         for (mode = 0; mode < MODE_MAX; mode++) {
816                 kex_free_newkeys(kex->newkeys[mode]);
817                 kex->newkeys[mode] = NULL;
818         }
819         sshbuf_free(kex->peer);
820         sshbuf_free(kex->my);
821         sshbuf_free(kex->client_version);
822         sshbuf_free(kex->server_version);
823         sshbuf_free(kex->client_pub);
824         sshbuf_free(kex->session_id);
825         sshbuf_free(kex->initial_sig);
826         sshkey_free(kex->initial_hostkey);
827         free(kex->failed_choice);
828         free(kex->hostkey_alg);
829         free(kex->name);
830         free(kex);
831 }
832
833 int
834 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
835 {
836         int r;
837
838         if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
839                 return r;
840         ssh->kex->flags = KEX_INITIAL;
841         kex_reset_dispatch(ssh);
842         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
843         return 0;
844 }
845
846 int
847 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
848 {
849         int r;
850
851         if ((r = kex_ready(ssh, proposal)) != 0)
852                 return r;
853         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
854                 kex_free(ssh->kex);
855                 ssh->kex = NULL;
856                 return r;
857         }
858         return 0;
859 }
860
861 /*
862  * Request key re-exchange, returns 0 on success or a ssherr.h error
863  * code otherwise. Must not be called if KEX is incomplete or in-progress.
864  */
865 int
866 kex_start_rekex(struct ssh *ssh)
867 {
868         if (ssh->kex == NULL) {
869                 error_f("no kex");
870                 return SSH_ERR_INTERNAL_ERROR;
871         }
872         if (ssh->kex->done == 0) {
873                 error_f("requested twice");
874                 return SSH_ERR_INTERNAL_ERROR;
875         }
876         ssh->kex->done = 0;
877         return kex_send_kexinit(ssh);
878 }
879
880 static int
881 choose_enc(struct sshenc *enc, char *client, char *server)
882 {
883         char *name = match_list(client, server, NULL);
884
885         if (name == NULL)
886                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
887         if ((enc->cipher = cipher_by_name(name)) == NULL) {
888                 error_f("unsupported cipher %s", name);
889                 free(name);
890                 return SSH_ERR_INTERNAL_ERROR;
891         }
892         enc->name = name;
893         enc->enabled = 0;
894         enc->iv = NULL;
895         enc->iv_len = cipher_ivlen(enc->cipher);
896         enc->key = NULL;
897         enc->key_len = cipher_keylen(enc->cipher);
898         enc->block_size = cipher_blocksize(enc->cipher);
899         return 0;
900 }
901
902 static int
903 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
904 {
905         char *name = match_list(client, server, NULL);
906
907         if (name == NULL)
908                 return SSH_ERR_NO_MAC_ALG_MATCH;
909         if (mac_setup(mac, name) < 0) {
910                 error_f("unsupported MAC %s", name);
911                 free(name);
912                 return SSH_ERR_INTERNAL_ERROR;
913         }
914         mac->name = name;
915         mac->key = NULL;
916         mac->enabled = 0;
917         return 0;
918 }
919
920 static int
921 choose_comp(struct sshcomp *comp, char *client, char *server)
922 {
923         char *name = match_list(client, server, NULL);
924
925         if (name == NULL)
926                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
927 #ifdef WITH_ZLIB
928         if (strcmp(name, "zlib@openssh.com") == 0) {
929                 comp->type = COMP_DELAYED;
930         } else if (strcmp(name, "zlib") == 0) {
931                 comp->type = COMP_ZLIB;
932         } else
933 #endif  /* WITH_ZLIB */
934         if (strcmp(name, "none") == 0) {
935                 comp->type = COMP_NONE;
936         } else {
937                 error_f("unsupported compression scheme %s", name);
938                 free(name);
939                 return SSH_ERR_INTERNAL_ERROR;
940         }
941         comp->name = name;
942         return 0;
943 }
944
945 static int
946 choose_kex(struct kex *k, char *client, char *server)
947 {
948         const struct kexalg *kexalg;
949
950         k->name = match_list(client, server, NULL);
951
952         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
953         if (k->name == NULL)
954                 return SSH_ERR_NO_KEX_ALG_MATCH;
955         if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
956                 error_f("unsupported KEX method %s", k->name);
957                 return SSH_ERR_INTERNAL_ERROR;
958         }
959         k->kex_type = kexalg->type;
960         k->hash_alg = kexalg->hash_alg;
961         k->ec_nid = kexalg->ec_nid;
962         return 0;
963 }
964
965 static int
966 choose_hostkeyalg(struct kex *k, char *client, char *server)
967 {
968         free(k->hostkey_alg);
969         k->hostkey_alg = match_list(client, server, NULL);
970
971         debug("kex: host key algorithm: %s",
972             k->hostkey_alg ? k->hostkey_alg : "(no match)");
973         if (k->hostkey_alg == NULL)
974                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
975         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
976         if (k->hostkey_type == KEY_UNSPEC) {
977                 error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
978                 return SSH_ERR_INTERNAL_ERROR;
979         }
980         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
981         return 0;
982 }
983
984 static int
985 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
986 {
987         static int check[] = {
988                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
989         };
990         int *idx;
991         char *p;
992
993         for (idx = &check[0]; *idx != -1; idx++) {
994                 if ((p = strchr(my[*idx], ',')) != NULL)
995                         *p = '\0';
996                 if ((p = strchr(peer[*idx], ',')) != NULL)
997                         *p = '\0';
998                 if (strcmp(my[*idx], peer[*idx]) != 0) {
999                         debug2("proposal mismatch: my %s peer %s",
1000                             my[*idx], peer[*idx]);
1001                         return (0);
1002                 }
1003         }
1004         debug2("proposals match");
1005         return (1);
1006 }
1007
1008 static int
1009 kexalgs_contains(char **peer, const char *ext)
1010 {
1011         return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1012 }
1013
1014 static int
1015 kex_choose_conf(struct ssh *ssh, uint32_t seq)
1016 {
1017         struct kex *kex = ssh->kex;
1018         struct newkeys *newkeys;
1019         char **my = NULL, **peer = NULL;
1020         char **cprop, **sprop;
1021         int nenc, nmac, ncomp;
1022         u_int mode, ctos, need, dh_need, authlen;
1023         int r, first_kex_follows;
1024
1025         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1026         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1027                 goto out;
1028         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1029         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1030                 goto out;
1031
1032         if (kex->server) {
1033                 cprop=peer;
1034                 sprop=my;
1035         } else {
1036                 cprop=my;
1037                 sprop=peer;
1038         }
1039
1040         /* Check whether peer supports ext_info/kex_strict */
1041         if ((kex->flags & KEX_INITIAL) != 0) {
1042                 if (kex->server) {
1043                         kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
1044                         kex->kex_strict = kexalgs_contains(peer,
1045                             "kex-strict-c-v00@openssh.com");
1046                 } else {
1047                         kex->kex_strict = kexalgs_contains(peer,
1048                             "kex-strict-s-v00@openssh.com");
1049                 }
1050                 if (kex->kex_strict) {
1051                         debug3_f("will use strict KEX ordering");
1052                         if (seq != 0)
1053                                 ssh_packet_disconnect(ssh,
1054                                     "strict KEX violation: "
1055                                     "KEXINIT was not the first packet");
1056                 }
1057         }
1058
1059         /* Check whether client supports rsa-sha2 algorithms */
1060         if (kex->server && (kex->flags & KEX_INITIAL)) {
1061                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1062                     "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1063                         kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1064                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1065                     "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1066                         kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1067         }
1068
1069         /* Algorithm Negotiation */
1070         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1071             sprop[PROPOSAL_KEX_ALGS])) != 0) {
1072                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1073                 peer[PROPOSAL_KEX_ALGS] = NULL;
1074                 goto out;
1075         }
1076         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1077             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1078                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1079                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1080                 goto out;
1081         }
1082         for (mode = 0; mode < MODE_MAX; mode++) {
1083                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1084                         r = SSH_ERR_ALLOC_FAIL;
1085                         goto out;
1086                 }
1087                 kex->newkeys[mode] = newkeys;
1088                 ctos = (!kex->server && mode == MODE_OUT) ||
1089                     (kex->server && mode == MODE_IN);
1090                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1091                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1092                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1093                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1094                     sprop[nenc])) != 0) {
1095                         kex->failed_choice = peer[nenc];
1096                         peer[nenc] = NULL;
1097                         goto out;
1098                 }
1099                 authlen = cipher_authlen(newkeys->enc.cipher);
1100                 /* ignore mac for authenticated encryption */
1101                 if (authlen == 0 &&
1102                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1103                     sprop[nmac])) != 0) {
1104                         kex->failed_choice = peer[nmac];
1105                         peer[nmac] = NULL;
1106                         goto out;
1107                 }
1108                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1109                     sprop[ncomp])) != 0) {
1110                         kex->failed_choice = peer[ncomp];
1111                         peer[ncomp] = NULL;
1112                         goto out;
1113                 }
1114                 debug("kex: %s cipher: %s MAC: %s compression: %s",
1115                     ctos ? "client->server" : "server->client",
1116                     newkeys->enc.name,
1117                     authlen == 0 ? newkeys->mac.name : "<implicit>",
1118                     newkeys->comp.name);
1119         }
1120         need = dh_need = 0;
1121         for (mode = 0; mode < MODE_MAX; mode++) {
1122                 newkeys = kex->newkeys[mode];
1123                 need = MAXIMUM(need, newkeys->enc.key_len);
1124                 need = MAXIMUM(need, newkeys->enc.block_size);
1125                 need = MAXIMUM(need, newkeys->enc.iv_len);
1126                 need = MAXIMUM(need, newkeys->mac.key_len);
1127                 dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1128                 dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1129                 dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1130                 dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1131         }
1132         /* XXX need runden? */
1133         kex->we_need = need;
1134         kex->dh_need = dh_need;
1135
1136         /* ignore the next message if the proposals do not match */
1137         if (first_kex_follows && !proposals_match(my, peer))
1138                 ssh->dispatch_skip_packets = 1;
1139         r = 0;
1140  out:
1141         kex_prop_free(my);
1142         kex_prop_free(peer);
1143         return r;
1144 }
1145
1146 static int
1147 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1148     const struct sshbuf *shared_secret, u_char **keyp)
1149 {
1150         struct kex *kex = ssh->kex;
1151         struct ssh_digest_ctx *hashctx = NULL;
1152         char c = id;
1153         u_int have;
1154         size_t mdsz;
1155         u_char *digest;
1156         int r;
1157
1158         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1159                 return SSH_ERR_INVALID_ARGUMENT;
1160         if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1161                 r = SSH_ERR_ALLOC_FAIL;
1162                 goto out;
1163         }
1164
1165         /* K1 = HASH(K || H || "A" || session_id) */
1166         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1167             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1168             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1169             ssh_digest_update(hashctx, &c, 1) != 0 ||
1170             ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1171             ssh_digest_final(hashctx, digest, mdsz) != 0) {
1172                 r = SSH_ERR_LIBCRYPTO_ERROR;
1173                 error_f("KEX hash failed");
1174                 goto out;
1175         }
1176         ssh_digest_free(hashctx);
1177         hashctx = NULL;
1178
1179         /*
1180          * expand key:
1181          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1182          * Key = K1 || K2 || ... || Kn
1183          */
1184         for (have = mdsz; need > have; have += mdsz) {
1185                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1186                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1187                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1188                     ssh_digest_update(hashctx, digest, have) != 0 ||
1189                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1190                         error_f("KDF failed");
1191                         r = SSH_ERR_LIBCRYPTO_ERROR;
1192                         goto out;
1193                 }
1194                 ssh_digest_free(hashctx);
1195                 hashctx = NULL;
1196         }
1197 #ifdef DEBUG_KEX
1198         fprintf(stderr, "key '%c'== ", c);
1199         dump_digest("key", digest, need);
1200 #endif
1201         *keyp = digest;
1202         digest = NULL;
1203         r = 0;
1204  out:
1205         free(digest);
1206         ssh_digest_free(hashctx);
1207         return r;
1208 }
1209
1210 #define NKEYS   6
1211 int
1212 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1213     const struct sshbuf *shared_secret)
1214 {
1215         struct kex *kex = ssh->kex;
1216         u_char *keys[NKEYS];
1217         u_int i, j, mode, ctos;
1218         int r;
1219
1220         /* save initial hash as session id */
1221         if ((kex->flags & KEX_INITIAL) != 0) {
1222                 if (sshbuf_len(kex->session_id) != 0) {
1223                         error_f("already have session ID at kex");
1224                         return SSH_ERR_INTERNAL_ERROR;
1225                 }
1226                 if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1227                         return r;
1228         } else if (sshbuf_len(kex->session_id) == 0) {
1229                 error_f("no session ID in rekex");
1230                 return SSH_ERR_INTERNAL_ERROR;
1231         }
1232         for (i = 0; i < NKEYS; i++) {
1233                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1234                     shared_secret, &keys[i])) != 0) {
1235                         for (j = 0; j < i; j++)
1236                                 free(keys[j]);
1237                         return r;
1238                 }
1239         }
1240         for (mode = 0; mode < MODE_MAX; mode++) {
1241                 ctos = (!kex->server && mode == MODE_OUT) ||
1242                     (kex->server && mode == MODE_IN);
1243                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1244                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1245                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1246         }
1247         return 0;
1248 }
1249
1250 int
1251 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1252 {
1253         struct kex *kex = ssh->kex;
1254
1255         *pubp = NULL;
1256         *prvp = NULL;
1257         if (kex->load_host_public_key == NULL ||
1258             kex->load_host_private_key == NULL) {
1259                 error_f("missing hostkey loader");
1260                 return SSH_ERR_INVALID_ARGUMENT;
1261         }
1262         *pubp = kex->load_host_public_key(kex->hostkey_type,
1263             kex->hostkey_nid, ssh);
1264         *prvp = kex->load_host_private_key(kex->hostkey_type,
1265             kex->hostkey_nid, ssh);
1266         if (*pubp == NULL)
1267                 return SSH_ERR_NO_HOSTKEY_LOADED;
1268         return 0;
1269 }
1270
1271 int
1272 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1273 {
1274         struct kex *kex = ssh->kex;
1275
1276         if (kex->verify_host_key == NULL) {
1277                 error_f("missing hostkey verifier");
1278                 return SSH_ERR_INVALID_ARGUMENT;
1279         }
1280         if (server_host_key->type != kex->hostkey_type ||
1281             (kex->hostkey_type == KEY_ECDSA &&
1282             server_host_key->ecdsa_nid != kex->hostkey_nid))
1283                 return SSH_ERR_KEY_TYPE_MISMATCH;
1284         if (kex->verify_host_key(server_host_key, ssh) == -1)
1285                 return  SSH_ERR_SIGNATURE_INVALID;
1286         return 0;
1287 }
1288
1289 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1290 void
1291 dump_digest(const char *msg, const u_char *digest, int len)
1292 {
1293         fprintf(stderr, "%s\n", msg);
1294         sshbuf_dump_data(digest, len, stderr);
1295 }
1296 #endif
1297
1298 /*
1299  * Send a plaintext error message to the peer, suffixed by \r\n.
1300  * Only used during banner exchange, and there only for the server.
1301  */
1302 static void
1303 send_error(struct ssh *ssh, char *msg)
1304 {
1305         char *crnl = "\r\n";
1306
1307         if (!ssh->kex->server)
1308                 return;
1309
1310         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1311             msg, strlen(msg)) != strlen(msg) ||
1312             atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1313             crnl, strlen(crnl)) != strlen(crnl))
1314                 error_f("write: %.100s", strerror(errno));
1315 }
1316
1317 /*
1318  * Sends our identification string and waits for the peer's. Will block for
1319  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1320  * Returns on 0 success or a ssherr.h code on failure.
1321  */
1322 int
1323 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1324     const char *version_addendum)
1325 {
1326         int remote_major, remote_minor, mismatch, oerrno = 0;
1327         size_t len, n;
1328         int r, expect_nl;
1329         u_char c;
1330         struct sshbuf *our_version = ssh->kex->server ?
1331             ssh->kex->server_version : ssh->kex->client_version;
1332         struct sshbuf *peer_version = ssh->kex->server ?
1333             ssh->kex->client_version : ssh->kex->server_version;
1334         char *our_version_string = NULL, *peer_version_string = NULL;
1335         char *cp, *remote_version = NULL;
1336
1337         /* Prepare and send our banner */
1338         sshbuf_reset(our_version);
1339         if (version_addendum != NULL && *version_addendum == '\0')
1340                 version_addendum = NULL;
1341         if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1342             PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1343             version_addendum == NULL ? "" : " ",
1344             version_addendum == NULL ? "" : version_addendum)) != 0) {
1345                 oerrno = errno;
1346                 error_fr(r, "sshbuf_putf");
1347                 goto out;
1348         }
1349
1350         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1351             sshbuf_mutable_ptr(our_version),
1352             sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1353                 oerrno = errno;
1354                 debug_f("write: %.100s", strerror(errno));
1355                 r = SSH_ERR_SYSTEM_ERROR;
1356                 goto out;
1357         }
1358         if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1359                 oerrno = errno;
1360                 error_fr(r, "sshbuf_consume_end");
1361                 goto out;
1362         }
1363         our_version_string = sshbuf_dup_string(our_version);
1364         if (our_version_string == NULL) {
1365                 error_f("sshbuf_dup_string failed");
1366                 r = SSH_ERR_ALLOC_FAIL;
1367                 goto out;
1368         }
1369         debug("Local version string %.100s", our_version_string);
1370
1371         /* Read other side's version identification. */
1372         for (n = 0; ; n++) {
1373                 if (n >= SSH_MAX_PRE_BANNER_LINES) {
1374                         send_error(ssh, "No SSH identification string "
1375                             "received.");
1376                         error_f("No SSH version received in first %u lines "
1377                             "from server", SSH_MAX_PRE_BANNER_LINES);
1378                         r = SSH_ERR_INVALID_FORMAT;
1379                         goto out;
1380                 }
1381                 sshbuf_reset(peer_version);
1382                 expect_nl = 0;
1383                 for (;;) {
1384                         if (timeout_ms > 0) {
1385                                 r = waitrfd(ssh_packet_get_connection_in(ssh),
1386                                     &timeout_ms, NULL);
1387                                 if (r == -1 && errno == ETIMEDOUT) {
1388                                         send_error(ssh, "Timed out waiting "
1389                                             "for SSH identification string.");
1390                                         error("Connection timed out during "
1391                                             "banner exchange");
1392                                         r = SSH_ERR_CONN_TIMEOUT;
1393                                         goto out;
1394                                 } else if (r == -1) {
1395                                         oerrno = errno;
1396                                         error_f("%s", strerror(errno));
1397                                         r = SSH_ERR_SYSTEM_ERROR;
1398                                         goto out;
1399                                 }
1400                         }
1401
1402                         len = atomicio(read, ssh_packet_get_connection_in(ssh),
1403                             &c, 1);
1404                         if (len != 1 && errno == EPIPE) {
1405                                 verbose_f("Connection closed by remote host");
1406                                 r = SSH_ERR_CONN_CLOSED;
1407                                 goto out;
1408                         } else if (len != 1) {
1409                                 oerrno = errno;
1410                                 error_f("read: %.100s", strerror(errno));
1411                                 r = SSH_ERR_SYSTEM_ERROR;
1412                                 goto out;
1413                         }
1414                         if (c == '\r') {
1415                                 expect_nl = 1;
1416                                 continue;
1417                         }
1418                         if (c == '\n')
1419                                 break;
1420                         if (c == '\0' || expect_nl) {
1421                                 verbose_f("banner line contains invalid "
1422                                     "characters");
1423                                 goto invalid;
1424                         }
1425                         if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1426                                 oerrno = errno;
1427                                 error_fr(r, "sshbuf_put");
1428                                 goto out;
1429                         }
1430                         if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1431                                 verbose_f("banner line too long");
1432                                 goto invalid;
1433                         }
1434                 }
1435                 /* Is this an actual protocol banner? */
1436                 if (sshbuf_len(peer_version) > 4 &&
1437                     memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1438                         break;
1439                 /* If not, then just log the line and continue */
1440                 if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1441                         error_f("sshbuf_dup_string failed");
1442                         r = SSH_ERR_ALLOC_FAIL;
1443                         goto out;
1444                 }
1445                 /* Do not accept lines before the SSH ident from a client */
1446                 if (ssh->kex->server) {
1447                         verbose_f("client sent invalid protocol identifier "
1448                             "\"%.256s\"", cp);
1449                         free(cp);
1450                         goto invalid;
1451                 }
1452                 debug_f("banner line %zu: %s", n, cp);
1453                 free(cp);
1454         }
1455         peer_version_string = sshbuf_dup_string(peer_version);
1456         if (peer_version_string == NULL)
1457                 fatal_f("sshbuf_dup_string failed");
1458         /* XXX must be same size for sscanf */
1459         if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1460                 error_f("calloc failed");
1461                 r = SSH_ERR_ALLOC_FAIL;
1462                 goto out;
1463         }
1464
1465         /*
1466          * Check that the versions match.  In future this might accept
1467          * several versions and set appropriate flags to handle them.
1468          */
1469         if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1470             &remote_major, &remote_minor, remote_version) != 3) {
1471                 error("Bad remote protocol version identification: '%.100s'",
1472                     peer_version_string);
1473  invalid:
1474                 send_error(ssh, "Invalid SSH identification string.");
1475                 r = SSH_ERR_INVALID_FORMAT;
1476                 goto out;
1477         }
1478         debug("Remote protocol version %d.%d, remote software version %.100s",
1479             remote_major, remote_minor, remote_version);
1480         compat_banner(ssh, remote_version);
1481
1482         mismatch = 0;
1483         switch (remote_major) {
1484         case 2:
1485                 break;
1486         case 1:
1487                 if (remote_minor != 99)
1488                         mismatch = 1;
1489                 break;
1490         default:
1491                 mismatch = 1;
1492                 break;
1493         }
1494         if (mismatch) {
1495                 error("Protocol major versions differ: %d vs. %d",
1496                     PROTOCOL_MAJOR_2, remote_major);
1497                 send_error(ssh, "Protocol major versions differ.");
1498                 r = SSH_ERR_NO_PROTOCOL_VERSION;
1499                 goto out;
1500         }
1501
1502         if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1503                 logit("probed from %s port %d with %s.  Don't panic.",
1504                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1505                     peer_version_string);
1506                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1507                 goto out;
1508         }
1509         if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1510                 logit("scanned from %s port %d with %s.  Don't panic.",
1511                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1512                     peer_version_string);
1513                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1514                 goto out;
1515         }
1516         /* success */
1517         r = 0;
1518  out:
1519         free(our_version_string);
1520         free(peer_version_string);
1521         free(remote_version);
1522         if (r == SSH_ERR_SYSTEM_ERROR)
1523                 errno = oerrno;
1524         return r;
1525 }
1526