]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/kex.c
zfs: merge openzfs/zfs@a4bf6baae
[FreeBSD/FreeBSD.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.184 2023/12/18 14:45:49 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66
67 /* prototype */
68 static int kex_choose_conf(struct ssh *, uint32_t seq);
69 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72         "KEX algorithms",
73         "host key algorithms",
74         "ciphers ctos",
75         "ciphers stoc",
76         "MACs ctos",
77         "MACs stoc",
78         "compression ctos",
79         "compression stoc",
80         "languages ctos",
81         "languages stoc",
82 };
83
84 struct kexalg {
85         char *name;
86         u_int type;
87         int ec_nid;
88         int hash_alg;
89 };
90 static const struct kexalg kexalgs[] = {
91 #ifdef WITH_OPENSSL
92         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
93         { KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
94         { KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
95         { KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
96         { KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
97         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
98 #ifdef HAVE_EVP_SHA256
99         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
100 #endif /* HAVE_EVP_SHA256 */
101 #ifdef OPENSSL_HAS_ECC
102         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
103             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
104         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
105             SSH_DIGEST_SHA384 },
106 # ifdef OPENSSL_HAS_NISTP521
107         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
108             SSH_DIGEST_SHA512 },
109 # endif /* OPENSSL_HAS_NISTP521 */
110 #endif /* OPENSSL_HAS_ECC */
111 #endif /* WITH_OPENSSL */
112 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
113         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
114         { KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
115 #ifdef USE_SNTRUP761X25519
116         { KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
117             SSH_DIGEST_SHA512 },
118 #endif
119 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
120         { NULL, 0, -1, -1},
121 };
122
123 char *
124 kex_alg_list(char sep)
125 {
126         char *ret = NULL, *tmp;
127         size_t nlen, rlen = 0;
128         const struct kexalg *k;
129
130         for (k = kexalgs; k->name != NULL; k++) {
131                 if (ret != NULL)
132                         ret[rlen++] = sep;
133                 nlen = strlen(k->name);
134                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135                         free(ret);
136                         return NULL;
137                 }
138                 ret = tmp;
139                 memcpy(ret + rlen, k->name, nlen + 1);
140                 rlen += nlen;
141         }
142         return ret;
143 }
144
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148         const struct kexalg *k;
149
150         for (k = kexalgs; k->name != NULL; k++) {
151                 if (strcmp(k->name, name) == 0)
152                         return k;
153         }
154         return NULL;
155 }
156
157 /* Validate KEX method name list */
158 int
159 kex_names_valid(const char *names)
160 {
161         char *s, *cp, *p;
162
163         if (names == NULL || strcmp(names, "") == 0)
164                 return 0;
165         if ((s = cp = strdup(names)) == NULL)
166                 return 0;
167         for ((p = strsep(&cp, ",")); p && *p != '\0';
168             (p = strsep(&cp, ","))) {
169                 if (kex_alg_by_name(p) == NULL) {
170                         error("Unsupported KEX algorithm \"%.100s\"", p);
171                         free(s);
172                         return 0;
173                 }
174         }
175         debug3("kex names ok: [%s]", names);
176         free(s);
177         return 1;
178 }
179
180 /* returns non-zero if proposal contains any algorithm from algs */
181 static int
182 has_any_alg(const char *proposal, const char *algs)
183 {
184         char *cp;
185
186         if ((cp = match_list(proposal, algs, NULL)) == NULL)
187                 return 0;
188         free(cp);
189         return 1;
190 }
191
192 /*
193  * Concatenate algorithm names, avoiding duplicates in the process.
194  * Caller must free returned string.
195  */
196 char *
197 kex_names_cat(const char *a, const char *b)
198 {
199         char *ret = NULL, *tmp = NULL, *cp, *p;
200         size_t len;
201
202         if (a == NULL || *a == '\0')
203                 return strdup(b);
204         if (b == NULL || *b == '\0')
205                 return strdup(a);
206         if (strlen(b) > 1024*1024)
207                 return NULL;
208         len = strlen(a) + strlen(b) + 2;
209         if ((tmp = cp = strdup(b)) == NULL ||
210             (ret = calloc(1, len)) == NULL) {
211                 free(tmp);
212                 return NULL;
213         }
214         strlcpy(ret, a, len);
215         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
216                 if (has_any_alg(ret, p))
217                         continue; /* Algorithm already present */
218                 if (strlcat(ret, ",", len) >= len ||
219                     strlcat(ret, p, len) >= len) {
220                         free(tmp);
221                         free(ret);
222                         return NULL; /* Shouldn't happen */
223                 }
224         }
225         free(tmp);
226         return ret;
227 }
228
229 /*
230  * Assemble a list of algorithms from a default list and a string from a
231  * configuration file. The user-provided string may begin with '+' to
232  * indicate that it should be appended to the default, '-' that the
233  * specified names should be removed, or '^' that they should be placed
234  * at the head.
235  */
236 int
237 kex_assemble_names(char **listp, const char *def, const char *all)
238 {
239         char *cp, *tmp, *patterns;
240         char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
241         int r = SSH_ERR_INTERNAL_ERROR;
242
243         if (listp == NULL || def == NULL || all == NULL)
244                 return SSH_ERR_INVALID_ARGUMENT;
245
246         if (*listp == NULL || **listp == '\0') {
247                 if ((*listp = strdup(def)) == NULL)
248                         return SSH_ERR_ALLOC_FAIL;
249                 return 0;
250         }
251
252         list = *listp;
253         *listp = NULL;
254         if (*list == '+') {
255                 /* Append names to default list */
256                 if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
257                         r = SSH_ERR_ALLOC_FAIL;
258                         goto fail;
259                 }
260                 free(list);
261                 list = tmp;
262         } else if (*list == '-') {
263                 /* Remove names from default list */
264                 if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
265                         r = SSH_ERR_ALLOC_FAIL;
266                         goto fail;
267                 }
268                 free(list);
269                 /* filtering has already been done */
270                 return 0;
271         } else if (*list == '^') {
272                 /* Place names at head of default list */
273                 if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
274                         r = SSH_ERR_ALLOC_FAIL;
275                         goto fail;
276                 }
277                 free(list);
278                 list = tmp;
279         } else {
280                 /* Explicit list, overrides default - just use "list" as is */
281         }
282
283         /*
284          * The supplied names may be a pattern-list. For the -list case,
285          * the patterns are applied above. For the +list and explicit list
286          * cases we need to do it now.
287          */
288         ret = NULL;
289         if ((patterns = opatterns = strdup(list)) == NULL) {
290                 r = SSH_ERR_ALLOC_FAIL;
291                 goto fail;
292         }
293         /* Apply positive (i.e. non-negated) patterns from the list */
294         while ((cp = strsep(&patterns, ",")) != NULL) {
295                 if (*cp == '!') {
296                         /* negated matches are not supported here */
297                         r = SSH_ERR_INVALID_ARGUMENT;
298                         goto fail;
299                 }
300                 free(matching);
301                 if ((matching = match_filter_allowlist(all, cp)) == NULL) {
302                         r = SSH_ERR_ALLOC_FAIL;
303                         goto fail;
304                 }
305                 if ((tmp = kex_names_cat(ret, matching)) == NULL) {
306                         r = SSH_ERR_ALLOC_FAIL;
307                         goto fail;
308                 }
309                 free(ret);
310                 ret = tmp;
311         }
312         if (ret == NULL || *ret == '\0') {
313                 /* An empty name-list is an error */
314                 /* XXX better error code? */
315                 r = SSH_ERR_INVALID_ARGUMENT;
316                 goto fail;
317         }
318
319         /* success */
320         *listp = ret;
321         ret = NULL;
322         r = 0;
323
324  fail:
325         free(matching);
326         free(opatterns);
327         free(list);
328         free(ret);
329         return r;
330 }
331
332 /*
333  * Fill out a proposal array with dynamically allocated values, which may
334  * be modified as required for compatibility reasons.
335  * Any of the options may be NULL, in which case the default is used.
336  * Array contents must be freed by calling kex_proposal_free_entries.
337  */
338 void
339 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
340     const char *kexalgos, const char *ciphers, const char *macs,
341     const char *comp, const char *hkalgs)
342 {
343         const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
344         const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
345         const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
346         u_int i;
347         char *cp;
348
349         if (prop == NULL)
350                 fatal_f("proposal missing");
351
352         /* Append EXT_INFO signalling to KexAlgorithms */
353         if (kexalgos == NULL)
354                 kexalgos = defprop[PROPOSAL_KEX_ALGS];
355         if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
356             "ext-info-s,kex-strict-s-v00@openssh.com" :
357             "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
358                 fatal_f("kex_names_cat");
359
360         for (i = 0; i < PROPOSAL_MAX; i++) {
361                 switch(i) {
362                 case PROPOSAL_KEX_ALGS:
363                         prop[i] = compat_kex_proposal(ssh, cp);
364                         break;
365                 case PROPOSAL_ENC_ALGS_CTOS:
366                 case PROPOSAL_ENC_ALGS_STOC:
367                         prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
368                         break;
369                 case PROPOSAL_MAC_ALGS_CTOS:
370                 case PROPOSAL_MAC_ALGS_STOC:
371                         prop[i]  = xstrdup(macs ? macs : defprop[i]);
372                         break;
373                 case PROPOSAL_COMP_ALGS_CTOS:
374                 case PROPOSAL_COMP_ALGS_STOC:
375                         prop[i] = xstrdup(comp ? comp : defprop[i]);
376                         break;
377                 case PROPOSAL_SERVER_HOST_KEY_ALGS:
378                         prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
379                         break;
380                 default:
381                         prop[i] = xstrdup(defprop[i]);
382                 }
383         }
384         free(cp);
385 }
386
387 void
388 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
389 {
390         u_int i;
391
392         for (i = 0; i < PROPOSAL_MAX; i++)
393                 free(prop[i]);
394 }
395
396 /* put algorithm proposal into buffer */
397 int
398 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
399 {
400         u_int i;
401         int r;
402
403         sshbuf_reset(b);
404
405         /*
406          * add a dummy cookie, the cookie will be overwritten by
407          * kex_send_kexinit(), each time a kexinit is set
408          */
409         for (i = 0; i < KEX_COOKIE_LEN; i++) {
410                 if ((r = sshbuf_put_u8(b, 0)) != 0)
411                         return r;
412         }
413         for (i = 0; i < PROPOSAL_MAX; i++) {
414                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
415                         return r;
416         }
417         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
418             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
419                 return r;
420         return 0;
421 }
422
423 /* parse buffer and return algorithm proposal */
424 int
425 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
426 {
427         struct sshbuf *b = NULL;
428         u_char v;
429         u_int i;
430         char **proposal = NULL;
431         int r;
432
433         *propp = NULL;
434         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
435                 return SSH_ERR_ALLOC_FAIL;
436         if ((b = sshbuf_fromb(raw)) == NULL) {
437                 r = SSH_ERR_ALLOC_FAIL;
438                 goto out;
439         }
440         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
441                 error_fr(r, "consume cookie");
442                 goto out;
443         }
444         /* extract kex init proposal strings */
445         for (i = 0; i < PROPOSAL_MAX; i++) {
446                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
447                         error_fr(r, "parse proposal %u", i);
448                         goto out;
449                 }
450                 debug2("%s: %s", proposal_names[i], proposal[i]);
451         }
452         /* first kex follows / reserved */
453         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
454             (r = sshbuf_get_u32(b, &i)) != 0) { /* reserved */
455                 error_fr(r, "parse");
456                 goto out;
457         }
458         if (first_kex_follows != NULL)
459                 *first_kex_follows = v;
460         debug2("first_kex_follows %d ", v);
461         debug2("reserved %u ", i);
462         r = 0;
463         *propp = proposal;
464  out:
465         if (r != 0 && proposal != NULL)
466                 kex_prop_free(proposal);
467         sshbuf_free(b);
468         return r;
469 }
470
471 void
472 kex_prop_free(char **proposal)
473 {
474         u_int i;
475
476         if (proposal == NULL)
477                 return;
478         for (i = 0; i < PROPOSAL_MAX; i++)
479                 free(proposal[i]);
480         free(proposal);
481 }
482
483 int
484 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
485 {
486         int r;
487
488         /* If in strict mode, any unexpected message is an error */
489         if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
490                 ssh_packet_disconnect(ssh, "strict KEX violation: "
491                     "unexpected packet type %u (seqnr %u)", type, seq);
492         }
493         error_f("type %u seq %u", type, seq);
494         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
495             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
496             (r = sshpkt_send(ssh)) != 0)
497                 return r;
498         return 0;
499 }
500
501 static void
502 kex_reset_dispatch(struct ssh *ssh)
503 {
504         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
505             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
506 }
507
508 void
509 kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
510 {
511         char *alg, *oalgs, *algs, *sigalgs;
512         const char *sigalg;
513
514         /*
515          * NB. allowed algorithms may contain certificate algorithms that
516          * map to a specific plain signature type, e.g.
517          * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
518          * We need to be careful here to match these, retain the mapping
519          * and only add each signature algorithm once.
520          */
521         if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
522                 fatal_f("sshkey_alg_list failed");
523         oalgs = algs = xstrdup(allowed_algs);
524         free(ssh->kex->server_sig_algs);
525         ssh->kex->server_sig_algs = NULL;
526         for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
527             (alg = strsep(&algs, ","))) {
528                 if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
529                         continue;
530                 if (!has_any_alg(sigalg, sigalgs))
531                         continue;
532                 /* Don't add an algorithm twice. */
533                 if (ssh->kex->server_sig_algs != NULL &&
534                     has_any_alg(sigalg, ssh->kex->server_sig_algs))
535                         continue;
536                 xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
537         }
538         free(oalgs);
539         free(sigalgs);
540         if (ssh->kex->server_sig_algs == NULL)
541                 ssh->kex->server_sig_algs = xstrdup("");
542 }
543
544 static int
545 kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
546 {
547         int r;
548
549         if (ssh->kex->server_sig_algs == NULL &&
550             (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
551                 return SSH_ERR_ALLOC_FAIL;
552         if ((r = sshbuf_put_u32(m, 3)) != 0 ||
553             (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
554             (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
555             (r = sshbuf_put_cstring(m,
556             "publickey-hostbound@openssh.com")) != 0 ||
557             (r = sshbuf_put_cstring(m, "0")) != 0 ||
558             (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
559             (r = sshbuf_put_cstring(m, "0")) != 0) {
560                 error_fr(r, "compose");
561                 return r;
562         }
563         return 0;
564 }
565
566 static int
567 kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
568 {
569         int r;
570
571         if ((r = sshbuf_put_u32(m, 1)) != 0 ||
572             (r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com")) != 0 ||
573             (r = sshbuf_put_cstring(m, "0")) != 0) {
574                 error_fr(r, "compose");
575                 goto out;
576         }
577         /* success */
578         r = 0;
579  out:
580         return r;
581 }
582
583 static int
584 kex_maybe_send_ext_info(struct ssh *ssh)
585 {
586         int r;
587         struct sshbuf *m = NULL;
588
589         if ((ssh->kex->flags & KEX_INITIAL) == 0)
590                 return 0;
591         if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
592                 return 0;
593
594         /* Compose EXT_INFO packet. */
595         if ((m = sshbuf_new()) == NULL)
596                 fatal_f("sshbuf_new failed");
597         if (ssh->kex->ext_info_c &&
598             (r = kex_compose_ext_info_server(ssh, m)) != 0)
599                 goto fail;
600         if (ssh->kex->ext_info_s &&
601             (r = kex_compose_ext_info_client(ssh, m)) != 0)
602                 goto fail;
603
604         /* Send the actual KEX_INFO packet */
605         debug("Sending SSH2_MSG_EXT_INFO");
606         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
607             (r = sshpkt_putb(ssh, m)) != 0 ||
608             (r = sshpkt_send(ssh)) != 0) {
609                 error_f("send EXT_INFO");
610                 goto fail;
611         }
612
613         r = 0;
614
615  fail:
616         sshbuf_free(m);
617         return r;
618 }
619
620 int
621 kex_server_update_ext_info(struct ssh *ssh)
622 {
623         int r;
624
625         if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
626                 return 0;
627
628         debug_f("Sending SSH2_MSG_EXT_INFO");
629         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
630             (r = sshpkt_put_u32(ssh, 1)) != 0 ||
631             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
632             (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
633             (r = sshpkt_send(ssh)) != 0) {
634                 error_f("send EXT_INFO");
635                 return r;
636         }
637         return 0;
638 }
639
640 int
641 kex_send_newkeys(struct ssh *ssh)
642 {
643         int r;
644
645         kex_reset_dispatch(ssh);
646         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
647             (r = sshpkt_send(ssh)) != 0)
648                 return r;
649         debug("SSH2_MSG_NEWKEYS sent");
650         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
651         if ((r = kex_maybe_send_ext_info(ssh)) != 0)
652                 return r;
653         debug("expecting SSH2_MSG_NEWKEYS");
654         return 0;
655 }
656
657 /* Check whether an ext_info value contains the expected version string */
658 static int
659 kex_ext_info_check_ver(struct kex *kex, const char *name,
660     const u_char *val, size_t len, const char *want_ver, u_int flag)
661 {
662         if (memchr(val, '\0', len) != NULL) {
663                 error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
664                 return SSH_ERR_INVALID_FORMAT;
665         }
666         debug_f("%s=<%s>", name, val);
667         if (strcmp(val, want_ver) == 0)
668                 kex->flags |= flag;
669         else
670                 debug_f("unsupported version of %s extension", name);
671         return 0;
672 }
673
674 static int
675 kex_ext_info_client_parse(struct ssh *ssh, const char *name,
676     const u_char *value, size_t vlen)
677 {
678         int r;
679
680         /* NB. some messages are only accepted in the initial EXT_INFO */
681         if (strcmp(name, "server-sig-algs") == 0) {
682                 /* Ensure no \0 lurking in value */
683                 if (memchr(value, '\0', vlen) != NULL) {
684                         error_f("nul byte in %s", name);
685                         return SSH_ERR_INVALID_FORMAT;
686                 }
687                 debug_f("%s=<%s>", name, value);
688                 free(ssh->kex->server_sig_algs);
689                 ssh->kex->server_sig_algs = xstrdup((const char *)value);
690         } else if (ssh->kex->ext_info_received == 1 &&
691             strcmp(name, "publickey-hostbound@openssh.com") == 0) {
692                 if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
693                     "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
694                         return r;
695                 }
696         } else if (ssh->kex->ext_info_received == 1 &&
697             strcmp(name, "ping@openssh.com") == 0) {
698                 if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
699                     "0", KEX_HAS_PING)) != 0) {
700                         return r;
701                 }
702         } else
703                 debug_f("%s (unrecognised)", name);
704
705         return 0;
706 }
707
708 static int
709 kex_ext_info_server_parse(struct ssh *ssh, const char *name,
710     const u_char *value, size_t vlen)
711 {
712         int r;
713
714         if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
715                 if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
716                     "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
717                         return r;
718                 }
719         } else
720                 debug_f("%s (unrecognised)", name);
721         return 0;
722 }
723
724 int
725 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
726 {
727         struct kex *kex = ssh->kex;
728         const int max_ext_info = kex->server ? 1 : 2;
729         u_int32_t i, ninfo;
730         char *name;
731         u_char *val;
732         size_t vlen;
733         int r;
734
735         debug("SSH2_MSG_EXT_INFO received");
736         if (++kex->ext_info_received > max_ext_info) {
737                 error("too many SSH2_MSG_EXT_INFO messages sent by peer");
738                 return dispatch_protocol_error(type, seq, ssh);
739         }
740         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
741         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
742                 return r;
743         if (ninfo >= 1024) {
744                 error("SSH2_MSG_EXT_INFO with too many entries, expected "
745                     "<=1024, received %u", ninfo);
746                 return dispatch_protocol_error(type, seq, ssh);
747         }
748         for (i = 0; i < ninfo; i++) {
749                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
750                         return r;
751                 if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
752                         free(name);
753                         return r;
754                 }
755                 debug3_f("extension %s", name);
756                 if (kex->server) {
757                         if ((r = kex_ext_info_server_parse(ssh, name,
758                             val, vlen)) != 0)
759                                 return r;
760                 } else {
761                         if ((r = kex_ext_info_client_parse(ssh, name,
762                             val, vlen)) != 0)
763                                 return r;
764                 }
765                 free(name);
766                 free(val);
767         }
768         return sshpkt_get_end(ssh);
769 }
770
771 static int
772 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
773 {
774         struct kex *kex = ssh->kex;
775         int r;
776
777         debug("SSH2_MSG_NEWKEYS received");
778         if (kex->ext_info_c && (kex->flags & KEX_INITIAL) != 0)
779                 ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
780         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
781         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
782         if ((r = sshpkt_get_end(ssh)) != 0)
783                 return r;
784         if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
785                 return r;
786         kex->done = 1;
787         kex->flags &= ~KEX_INITIAL;
788         sshbuf_reset(kex->peer);
789         /* sshbuf_reset(kex->my); */
790         kex->flags &= ~KEX_INIT_SENT;
791         free(kex->name);
792         kex->name = NULL;
793         return 0;
794 }
795
796 int
797 kex_send_kexinit(struct ssh *ssh)
798 {
799         u_char *cookie;
800         struct kex *kex = ssh->kex;
801         int r;
802
803         if (kex == NULL) {
804                 error_f("no kex");
805                 return SSH_ERR_INTERNAL_ERROR;
806         }
807         if (kex->flags & KEX_INIT_SENT)
808                 return 0;
809         kex->done = 0;
810
811         /* generate a random cookie */
812         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
813                 error_f("bad kex length: %zu < %d",
814                     sshbuf_len(kex->my), KEX_COOKIE_LEN);
815                 return SSH_ERR_INVALID_FORMAT;
816         }
817         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
818                 error_f("buffer error");
819                 return SSH_ERR_INTERNAL_ERROR;
820         }
821         arc4random_buf(cookie, KEX_COOKIE_LEN);
822
823         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
824             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
825             (r = sshpkt_send(ssh)) != 0) {
826                 error_fr(r, "compose reply");
827                 return r;
828         }
829         debug("SSH2_MSG_KEXINIT sent");
830         kex->flags |= KEX_INIT_SENT;
831         return 0;
832 }
833
834 int
835 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
836 {
837         struct kex *kex = ssh->kex;
838         const u_char *ptr;
839         u_int i;
840         size_t dlen;
841         int r;
842
843         debug("SSH2_MSG_KEXINIT received");
844         if (kex == NULL) {
845                 error_f("no kex");
846                 return SSH_ERR_INTERNAL_ERROR;
847         }
848         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
849         ptr = sshpkt_ptr(ssh, &dlen);
850         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
851                 return r;
852
853         /* discard packet */
854         for (i = 0; i < KEX_COOKIE_LEN; i++) {
855                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
856                         error_fr(r, "discard cookie");
857                         return r;
858                 }
859         }
860         for (i = 0; i < PROPOSAL_MAX; i++) {
861                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
862                         error_fr(r, "discard proposal");
863                         return r;
864                 }
865         }
866         /*
867          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
868          * KEX method has the server move first, but a server might be using
869          * a custom method or one that we otherwise don't support. We should
870          * be prepared to remember first_kex_follows here so we can eat a
871          * packet later.
872          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
873          * for cases where the server *doesn't* go first. I guess we should
874          * ignore it when it is set for these cases, which is what we do now.
875          */
876         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
877             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
878             (r = sshpkt_get_end(ssh)) != 0)
879                         return r;
880
881         if (!(kex->flags & KEX_INIT_SENT))
882                 if ((r = kex_send_kexinit(ssh)) != 0)
883                         return r;
884         if ((r = kex_choose_conf(ssh, seq)) != 0)
885                 return r;
886
887         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
888                 return (kex->kex[kex->kex_type])(ssh);
889
890         error_f("unknown kex type %u", kex->kex_type);
891         return SSH_ERR_INTERNAL_ERROR;
892 }
893
894 struct kex *
895 kex_new(void)
896 {
897         struct kex *kex;
898
899         if ((kex = calloc(1, sizeof(*kex))) == NULL ||
900             (kex->peer = sshbuf_new()) == NULL ||
901             (kex->my = sshbuf_new()) == NULL ||
902             (kex->client_version = sshbuf_new()) == NULL ||
903             (kex->server_version = sshbuf_new()) == NULL ||
904             (kex->session_id = sshbuf_new()) == NULL) {
905                 kex_free(kex);
906                 return NULL;
907         }
908         return kex;
909 }
910
911 void
912 kex_free_newkeys(struct newkeys *newkeys)
913 {
914         if (newkeys == NULL)
915                 return;
916         if (newkeys->enc.key) {
917                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
918                 free(newkeys->enc.key);
919                 newkeys->enc.key = NULL;
920         }
921         if (newkeys->enc.iv) {
922                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
923                 free(newkeys->enc.iv);
924                 newkeys->enc.iv = NULL;
925         }
926         free(newkeys->enc.name);
927         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
928         free(newkeys->comp.name);
929         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
930         mac_clear(&newkeys->mac);
931         if (newkeys->mac.key) {
932                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
933                 free(newkeys->mac.key);
934                 newkeys->mac.key = NULL;
935         }
936         free(newkeys->mac.name);
937         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
938         freezero(newkeys, sizeof(*newkeys));
939 }
940
941 void
942 kex_free(struct kex *kex)
943 {
944         u_int mode;
945
946         if (kex == NULL)
947                 return;
948
949 #ifdef WITH_OPENSSL
950         DH_free(kex->dh);
951 #ifdef OPENSSL_HAS_ECC
952         EC_KEY_free(kex->ec_client_key);
953 #endif /* OPENSSL_HAS_ECC */
954 #endif /* WITH_OPENSSL */
955         for (mode = 0; mode < MODE_MAX; mode++) {
956                 kex_free_newkeys(kex->newkeys[mode]);
957                 kex->newkeys[mode] = NULL;
958         }
959         sshbuf_free(kex->peer);
960         sshbuf_free(kex->my);
961         sshbuf_free(kex->client_version);
962         sshbuf_free(kex->server_version);
963         sshbuf_free(kex->client_pub);
964         sshbuf_free(kex->session_id);
965         sshbuf_free(kex->initial_sig);
966         sshkey_free(kex->initial_hostkey);
967         free(kex->failed_choice);
968         free(kex->hostkey_alg);
969         free(kex->name);
970         free(kex);
971 }
972
973 int
974 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
975 {
976         int r;
977
978         if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
979                 return r;
980         ssh->kex->flags = KEX_INITIAL;
981         kex_reset_dispatch(ssh);
982         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
983         return 0;
984 }
985
986 int
987 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
988 {
989         int r;
990
991         if ((r = kex_ready(ssh, proposal)) != 0)
992                 return r;
993         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
994                 kex_free(ssh->kex);
995                 ssh->kex = NULL;
996                 return r;
997         }
998         return 0;
999 }
1000
1001 /*
1002  * Request key re-exchange, returns 0 on success or a ssherr.h error
1003  * code otherwise. Must not be called if KEX is incomplete or in-progress.
1004  */
1005 int
1006 kex_start_rekex(struct ssh *ssh)
1007 {
1008         if (ssh->kex == NULL) {
1009                 error_f("no kex");
1010                 return SSH_ERR_INTERNAL_ERROR;
1011         }
1012         if (ssh->kex->done == 0) {
1013                 error_f("requested twice");
1014                 return SSH_ERR_INTERNAL_ERROR;
1015         }
1016         ssh->kex->done = 0;
1017         return kex_send_kexinit(ssh);
1018 }
1019
1020 static int
1021 choose_enc(struct sshenc *enc, char *client, char *server)
1022 {
1023         char *name = match_list(client, server, NULL);
1024
1025         if (name == NULL)
1026                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
1027         if ((enc->cipher = cipher_by_name(name)) == NULL) {
1028                 error_f("unsupported cipher %s", name);
1029                 free(name);
1030                 return SSH_ERR_INTERNAL_ERROR;
1031         }
1032         enc->name = name;
1033         enc->enabled = 0;
1034         enc->iv = NULL;
1035         enc->iv_len = cipher_ivlen(enc->cipher);
1036         enc->key = NULL;
1037         enc->key_len = cipher_keylen(enc->cipher);
1038         enc->block_size = cipher_blocksize(enc->cipher);
1039         return 0;
1040 }
1041
1042 static int
1043 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
1044 {
1045         char *name = match_list(client, server, NULL);
1046
1047         if (name == NULL)
1048                 return SSH_ERR_NO_MAC_ALG_MATCH;
1049         if (mac_setup(mac, name) < 0) {
1050                 error_f("unsupported MAC %s", name);
1051                 free(name);
1052                 return SSH_ERR_INTERNAL_ERROR;
1053         }
1054         mac->name = name;
1055         mac->key = NULL;
1056         mac->enabled = 0;
1057         return 0;
1058 }
1059
1060 static int
1061 choose_comp(struct sshcomp *comp, char *client, char *server)
1062 {
1063         char *name = match_list(client, server, NULL);
1064
1065         if (name == NULL)
1066                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
1067 #ifdef WITH_ZLIB
1068         if (strcmp(name, "zlib@openssh.com") == 0) {
1069                 comp->type = COMP_DELAYED;
1070         } else if (strcmp(name, "zlib") == 0) {
1071                 comp->type = COMP_ZLIB;
1072         } else
1073 #endif  /* WITH_ZLIB */
1074         if (strcmp(name, "none") == 0) {
1075                 comp->type = COMP_NONE;
1076         } else {
1077                 error_f("unsupported compression scheme %s", name);
1078                 free(name);
1079                 return SSH_ERR_INTERNAL_ERROR;
1080         }
1081         comp->name = name;
1082         return 0;
1083 }
1084
1085 static int
1086 choose_kex(struct kex *k, char *client, char *server)
1087 {
1088         const struct kexalg *kexalg;
1089
1090         k->name = match_list(client, server, NULL);
1091
1092         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
1093         if (k->name == NULL)
1094                 return SSH_ERR_NO_KEX_ALG_MATCH;
1095         if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
1096                 error_f("unsupported KEX method %s", k->name);
1097                 return SSH_ERR_INTERNAL_ERROR;
1098         }
1099         k->kex_type = kexalg->type;
1100         k->hash_alg = kexalg->hash_alg;
1101         k->ec_nid = kexalg->ec_nid;
1102         return 0;
1103 }
1104
1105 static int
1106 choose_hostkeyalg(struct kex *k, char *client, char *server)
1107 {
1108         free(k->hostkey_alg);
1109         k->hostkey_alg = match_list(client, server, NULL);
1110
1111         debug("kex: host key algorithm: %s",
1112             k->hostkey_alg ? k->hostkey_alg : "(no match)");
1113         if (k->hostkey_alg == NULL)
1114                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
1115         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
1116         if (k->hostkey_type == KEY_UNSPEC) {
1117                 error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
1118                 return SSH_ERR_INTERNAL_ERROR;
1119         }
1120         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
1121         return 0;
1122 }
1123
1124 static int
1125 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
1126 {
1127         static int check[] = {
1128                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
1129         };
1130         int *idx;
1131         char *p;
1132
1133         for (idx = &check[0]; *idx != -1; idx++) {
1134                 if ((p = strchr(my[*idx], ',')) != NULL)
1135                         *p = '\0';
1136                 if ((p = strchr(peer[*idx], ',')) != NULL)
1137                         *p = '\0';
1138                 if (strcmp(my[*idx], peer[*idx]) != 0) {
1139                         debug2("proposal mismatch: my %s peer %s",
1140                             my[*idx], peer[*idx]);
1141                         return (0);
1142                 }
1143         }
1144         debug2("proposals match");
1145         return (1);
1146 }
1147
1148 static int
1149 kexalgs_contains(char **peer, const char *ext)
1150 {
1151         return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1152 }
1153
1154 static int
1155 kex_choose_conf(struct ssh *ssh, uint32_t seq)
1156 {
1157         struct kex *kex = ssh->kex;
1158         struct newkeys *newkeys;
1159         char **my = NULL, **peer = NULL;
1160         char **cprop, **sprop;
1161         int nenc, nmac, ncomp;
1162         u_int mode, ctos, need, dh_need, authlen;
1163         int r, first_kex_follows;
1164
1165         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1166         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1167                 goto out;
1168         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1169         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1170                 goto out;
1171
1172         if (kex->server) {
1173                 cprop=peer;
1174                 sprop=my;
1175         } else {
1176                 cprop=my;
1177                 sprop=peer;
1178         }
1179
1180         /* Check whether peer supports ext_info/kex_strict */
1181         if ((kex->flags & KEX_INITIAL) != 0) {
1182                 if (kex->server) {
1183                         kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
1184                         kex->kex_strict = kexalgs_contains(peer,
1185                             "kex-strict-c-v00@openssh.com");
1186                 } else {
1187                         kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
1188                         kex->kex_strict = kexalgs_contains(peer,
1189                             "kex-strict-s-v00@openssh.com");
1190                 }
1191                 if (kex->kex_strict) {
1192                         debug3_f("will use strict KEX ordering");
1193                         if (seq != 0)
1194                                 ssh_packet_disconnect(ssh,
1195                                     "strict KEX violation: "
1196                                     "KEXINIT was not the first packet");
1197                 }
1198         }
1199
1200         /* Check whether client supports rsa-sha2 algorithms */
1201         if (kex->server && (kex->flags & KEX_INITIAL)) {
1202                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1203                     "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1204                         kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1205                 if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1206                     "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1207                         kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1208         }
1209
1210         /* Algorithm Negotiation */
1211         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1212             sprop[PROPOSAL_KEX_ALGS])) != 0) {
1213                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1214                 peer[PROPOSAL_KEX_ALGS] = NULL;
1215                 goto out;
1216         }
1217         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1218             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1219                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1220                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1221                 goto out;
1222         }
1223         for (mode = 0; mode < MODE_MAX; mode++) {
1224                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1225                         r = SSH_ERR_ALLOC_FAIL;
1226                         goto out;
1227                 }
1228                 kex->newkeys[mode] = newkeys;
1229                 ctos = (!kex->server && mode == MODE_OUT) ||
1230                     (kex->server && mode == MODE_IN);
1231                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1232                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1233                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1234                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1235                     sprop[nenc])) != 0) {
1236                         kex->failed_choice = peer[nenc];
1237                         peer[nenc] = NULL;
1238                         goto out;
1239                 }
1240                 authlen = cipher_authlen(newkeys->enc.cipher);
1241                 /* ignore mac for authenticated encryption */
1242                 if (authlen == 0 &&
1243                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1244                     sprop[nmac])) != 0) {
1245                         kex->failed_choice = peer[nmac];
1246                         peer[nmac] = NULL;
1247                         goto out;
1248                 }
1249                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1250                     sprop[ncomp])) != 0) {
1251                         kex->failed_choice = peer[ncomp];
1252                         peer[ncomp] = NULL;
1253                         goto out;
1254                 }
1255                 debug("kex: %s cipher: %s MAC: %s compression: %s",
1256                     ctos ? "client->server" : "server->client",
1257                     newkeys->enc.name,
1258                     authlen == 0 ? newkeys->mac.name : "<implicit>",
1259                     newkeys->comp.name);
1260         }
1261         need = dh_need = 0;
1262         for (mode = 0; mode < MODE_MAX; mode++) {
1263                 newkeys = kex->newkeys[mode];
1264                 need = MAXIMUM(need, newkeys->enc.key_len);
1265                 need = MAXIMUM(need, newkeys->enc.block_size);
1266                 need = MAXIMUM(need, newkeys->enc.iv_len);
1267                 need = MAXIMUM(need, newkeys->mac.key_len);
1268                 dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1269                 dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1270                 dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1271                 dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1272         }
1273         /* XXX need runden? */
1274         kex->we_need = need;
1275         kex->dh_need = dh_need;
1276
1277         /* ignore the next message if the proposals do not match */
1278         if (first_kex_follows && !proposals_match(my, peer))
1279                 ssh->dispatch_skip_packets = 1;
1280         r = 0;
1281  out:
1282         kex_prop_free(my);
1283         kex_prop_free(peer);
1284         return r;
1285 }
1286
1287 static int
1288 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1289     const struct sshbuf *shared_secret, u_char **keyp)
1290 {
1291         struct kex *kex = ssh->kex;
1292         struct ssh_digest_ctx *hashctx = NULL;
1293         char c = id;
1294         u_int have;
1295         size_t mdsz;
1296         u_char *digest;
1297         int r;
1298
1299         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1300                 return SSH_ERR_INVALID_ARGUMENT;
1301         if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1302                 r = SSH_ERR_ALLOC_FAIL;
1303                 goto out;
1304         }
1305
1306         /* K1 = HASH(K || H || "A" || session_id) */
1307         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1308             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1309             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1310             ssh_digest_update(hashctx, &c, 1) != 0 ||
1311             ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1312             ssh_digest_final(hashctx, digest, mdsz) != 0) {
1313                 r = SSH_ERR_LIBCRYPTO_ERROR;
1314                 error_f("KEX hash failed");
1315                 goto out;
1316         }
1317         ssh_digest_free(hashctx);
1318         hashctx = NULL;
1319
1320         /*
1321          * expand key:
1322          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1323          * Key = K1 || K2 || ... || Kn
1324          */
1325         for (have = mdsz; need > have; have += mdsz) {
1326                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1327                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1328                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1329                     ssh_digest_update(hashctx, digest, have) != 0 ||
1330                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1331                         error_f("KDF failed");
1332                         r = SSH_ERR_LIBCRYPTO_ERROR;
1333                         goto out;
1334                 }
1335                 ssh_digest_free(hashctx);
1336                 hashctx = NULL;
1337         }
1338 #ifdef DEBUG_KEX
1339         fprintf(stderr, "key '%c'== ", c);
1340         dump_digest("key", digest, need);
1341 #endif
1342         *keyp = digest;
1343         digest = NULL;
1344         r = 0;
1345  out:
1346         free(digest);
1347         ssh_digest_free(hashctx);
1348         return r;
1349 }
1350
1351 #define NKEYS   6
1352 int
1353 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1354     const struct sshbuf *shared_secret)
1355 {
1356         struct kex *kex = ssh->kex;
1357         u_char *keys[NKEYS];
1358         u_int i, j, mode, ctos;
1359         int r;
1360
1361         /* save initial hash as session id */
1362         if ((kex->flags & KEX_INITIAL) != 0) {
1363                 if (sshbuf_len(kex->session_id) != 0) {
1364                         error_f("already have session ID at kex");
1365                         return SSH_ERR_INTERNAL_ERROR;
1366                 }
1367                 if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1368                         return r;
1369         } else if (sshbuf_len(kex->session_id) == 0) {
1370                 error_f("no session ID in rekex");
1371                 return SSH_ERR_INTERNAL_ERROR;
1372         }
1373         for (i = 0; i < NKEYS; i++) {
1374                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1375                     shared_secret, &keys[i])) != 0) {
1376                         for (j = 0; j < i; j++)
1377                                 free(keys[j]);
1378                         return r;
1379                 }
1380         }
1381         for (mode = 0; mode < MODE_MAX; mode++) {
1382                 ctos = (!kex->server && mode == MODE_OUT) ||
1383                     (kex->server && mode == MODE_IN);
1384                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1385                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1386                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1387         }
1388         return 0;
1389 }
1390
1391 int
1392 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1393 {
1394         struct kex *kex = ssh->kex;
1395
1396         *pubp = NULL;
1397         *prvp = NULL;
1398         if (kex->load_host_public_key == NULL ||
1399             kex->load_host_private_key == NULL) {
1400                 error_f("missing hostkey loader");
1401                 return SSH_ERR_INVALID_ARGUMENT;
1402         }
1403         *pubp = kex->load_host_public_key(kex->hostkey_type,
1404             kex->hostkey_nid, ssh);
1405         *prvp = kex->load_host_private_key(kex->hostkey_type,
1406             kex->hostkey_nid, ssh);
1407         if (*pubp == NULL)
1408                 return SSH_ERR_NO_HOSTKEY_LOADED;
1409         return 0;
1410 }
1411
1412 int
1413 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1414 {
1415         struct kex *kex = ssh->kex;
1416
1417         if (kex->verify_host_key == NULL) {
1418                 error_f("missing hostkey verifier");
1419                 return SSH_ERR_INVALID_ARGUMENT;
1420         }
1421         if (server_host_key->type != kex->hostkey_type ||
1422             (kex->hostkey_type == KEY_ECDSA &&
1423             server_host_key->ecdsa_nid != kex->hostkey_nid))
1424                 return SSH_ERR_KEY_TYPE_MISMATCH;
1425         if (kex->verify_host_key(server_host_key, ssh) == -1)
1426                 return  SSH_ERR_SIGNATURE_INVALID;
1427         return 0;
1428 }
1429
1430 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1431 void
1432 dump_digest(const char *msg, const u_char *digest, int len)
1433 {
1434         fprintf(stderr, "%s\n", msg);
1435         sshbuf_dump_data(digest, len, stderr);
1436 }
1437 #endif
1438
1439 /*
1440  * Send a plaintext error message to the peer, suffixed by \r\n.
1441  * Only used during banner exchange, and there only for the server.
1442  */
1443 static void
1444 send_error(struct ssh *ssh, char *msg)
1445 {
1446         char *crnl = "\r\n";
1447
1448         if (!ssh->kex->server)
1449                 return;
1450
1451         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1452             msg, strlen(msg)) != strlen(msg) ||
1453             atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1454             crnl, strlen(crnl)) != strlen(crnl))
1455                 error_f("write: %.100s", strerror(errno));
1456 }
1457
1458 /*
1459  * Sends our identification string and waits for the peer's. Will block for
1460  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1461  * Returns on 0 success or a ssherr.h code on failure.
1462  */
1463 int
1464 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1465     const char *version_addendum)
1466 {
1467         int remote_major, remote_minor, mismatch, oerrno = 0;
1468         size_t len, n;
1469         int r, expect_nl;
1470         u_char c;
1471         struct sshbuf *our_version = ssh->kex->server ?
1472             ssh->kex->server_version : ssh->kex->client_version;
1473         struct sshbuf *peer_version = ssh->kex->server ?
1474             ssh->kex->client_version : ssh->kex->server_version;
1475         char *our_version_string = NULL, *peer_version_string = NULL;
1476         char *cp, *remote_version = NULL;
1477
1478         /* Prepare and send our banner */
1479         sshbuf_reset(our_version);
1480         if (version_addendum != NULL && *version_addendum == '\0')
1481                 version_addendum = NULL;
1482         if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1483             PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1484             version_addendum == NULL ? "" : " ",
1485             version_addendum == NULL ? "" : version_addendum)) != 0) {
1486                 oerrno = errno;
1487                 error_fr(r, "sshbuf_putf");
1488                 goto out;
1489         }
1490
1491         if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1492             sshbuf_mutable_ptr(our_version),
1493             sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1494                 oerrno = errno;
1495                 debug_f("write: %.100s", strerror(errno));
1496                 r = SSH_ERR_SYSTEM_ERROR;
1497                 goto out;
1498         }
1499         if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1500                 oerrno = errno;
1501                 error_fr(r, "sshbuf_consume_end");
1502                 goto out;
1503         }
1504         our_version_string = sshbuf_dup_string(our_version);
1505         if (our_version_string == NULL) {
1506                 error_f("sshbuf_dup_string failed");
1507                 r = SSH_ERR_ALLOC_FAIL;
1508                 goto out;
1509         }
1510         debug("Local version string %.100s", our_version_string);
1511
1512         /* Read other side's version identification. */
1513         for (n = 0; ; n++) {
1514                 if (n >= SSH_MAX_PRE_BANNER_LINES) {
1515                         send_error(ssh, "No SSH identification string "
1516                             "received.");
1517                         error_f("No SSH version received in first %u lines "
1518                             "from server", SSH_MAX_PRE_BANNER_LINES);
1519                         r = SSH_ERR_INVALID_FORMAT;
1520                         goto out;
1521                 }
1522                 sshbuf_reset(peer_version);
1523                 expect_nl = 0;
1524                 for (;;) {
1525                         if (timeout_ms > 0) {
1526                                 r = waitrfd(ssh_packet_get_connection_in(ssh),
1527                                     &timeout_ms, NULL);
1528                                 if (r == -1 && errno == ETIMEDOUT) {
1529                                         send_error(ssh, "Timed out waiting "
1530                                             "for SSH identification string.");
1531                                         error("Connection timed out during "
1532                                             "banner exchange");
1533                                         r = SSH_ERR_CONN_TIMEOUT;
1534                                         goto out;
1535                                 } else if (r == -1) {
1536                                         oerrno = errno;
1537                                         error_f("%s", strerror(errno));
1538                                         r = SSH_ERR_SYSTEM_ERROR;
1539                                         goto out;
1540                                 }
1541                         }
1542
1543                         len = atomicio(read, ssh_packet_get_connection_in(ssh),
1544                             &c, 1);
1545                         if (len != 1 && errno == EPIPE) {
1546                                 verbose_f("Connection closed by remote host");
1547                                 r = SSH_ERR_CONN_CLOSED;
1548                                 goto out;
1549                         } else if (len != 1) {
1550                                 oerrno = errno;
1551                                 error_f("read: %.100s", strerror(errno));
1552                                 r = SSH_ERR_SYSTEM_ERROR;
1553                                 goto out;
1554                         }
1555                         if (c == '\r') {
1556                                 expect_nl = 1;
1557                                 continue;
1558                         }
1559                         if (c == '\n')
1560                                 break;
1561                         if (c == '\0' || expect_nl) {
1562                                 verbose_f("banner line contains invalid "
1563                                     "characters");
1564                                 goto invalid;
1565                         }
1566                         if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1567                                 oerrno = errno;
1568                                 error_fr(r, "sshbuf_put");
1569                                 goto out;
1570                         }
1571                         if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1572                                 verbose_f("banner line too long");
1573                                 goto invalid;
1574                         }
1575                 }
1576                 /* Is this an actual protocol banner? */
1577                 if (sshbuf_len(peer_version) > 4 &&
1578                     memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1579                         break;
1580                 /* If not, then just log the line and continue */
1581                 if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1582                         error_f("sshbuf_dup_string failed");
1583                         r = SSH_ERR_ALLOC_FAIL;
1584                         goto out;
1585                 }
1586                 /* Do not accept lines before the SSH ident from a client */
1587                 if (ssh->kex->server) {
1588                         verbose_f("client sent invalid protocol identifier "
1589                             "\"%.256s\"", cp);
1590                         free(cp);
1591                         goto invalid;
1592                 }
1593                 debug_f("banner line %zu: %s", n, cp);
1594                 free(cp);
1595         }
1596         peer_version_string = sshbuf_dup_string(peer_version);
1597         if (peer_version_string == NULL)
1598                 fatal_f("sshbuf_dup_string failed");
1599         /* XXX must be same size for sscanf */
1600         if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1601                 error_f("calloc failed");
1602                 r = SSH_ERR_ALLOC_FAIL;
1603                 goto out;
1604         }
1605
1606         /*
1607          * Check that the versions match.  In future this might accept
1608          * several versions and set appropriate flags to handle them.
1609          */
1610         if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1611             &remote_major, &remote_minor, remote_version) != 3) {
1612                 error("Bad remote protocol version identification: '%.100s'",
1613                     peer_version_string);
1614  invalid:
1615                 send_error(ssh, "Invalid SSH identification string.");
1616                 r = SSH_ERR_INVALID_FORMAT;
1617                 goto out;
1618         }
1619         debug("Remote protocol version %d.%d, remote software version %.100s",
1620             remote_major, remote_minor, remote_version);
1621         compat_banner(ssh, remote_version);
1622
1623         mismatch = 0;
1624         switch (remote_major) {
1625         case 2:
1626                 break;
1627         case 1:
1628                 if (remote_minor != 99)
1629                         mismatch = 1;
1630                 break;
1631         default:
1632                 mismatch = 1;
1633                 break;
1634         }
1635         if (mismatch) {
1636                 error("Protocol major versions differ: %d vs. %d",
1637                     PROTOCOL_MAJOR_2, remote_major);
1638                 send_error(ssh, "Protocol major versions differ.");
1639                 r = SSH_ERR_NO_PROTOCOL_VERSION;
1640                 goto out;
1641         }
1642
1643         if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1644                 logit("probed from %s port %d with %s.  Don't panic.",
1645                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1646                     peer_version_string);
1647                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1648                 goto out;
1649         }
1650         if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1651                 logit("scanned from %s port %d with %s.  Don't panic.",
1652                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1653                     peer_version_string);
1654                 r = SSH_ERR_CONN_CLOSED; /* XXX */
1655                 goto out;
1656         }
1657         /* success */
1658         r = 0;
1659  out:
1660         free(our_version_string);
1661         free(peer_version_string);
1662         free(remote_version);
1663         if (r == SSH_ERR_SYSTEM_ERROR)
1664                 errno = oerrno;
1665         return r;
1666 }
1667