]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/kex.c
THIS BRANCH IS OBSOLETE, PLEASE READ:
[FreeBSD/FreeBSD.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.141 2018/07/09 13:37:10 sf Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28
29 #include <signal.h>
30 #include <stdarg.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34
35 #ifdef WITH_OPENSSL
36 #include <openssl/crypto.h>
37 #include <openssl/dh.h>
38 #endif
39
40 #include "ssh2.h"
41 #include "packet.h"
42 #include "compat.h"
43 #include "cipher.h"
44 #include "sshkey.h"
45 #include "kex.h"
46 #include "log.h"
47 #include "mac.h"
48 #include "match.h"
49 #include "misc.h"
50 #include "dispatch.h"
51 #include "monitor.h"
52
53 #include "ssherr.h"
54 #include "sshbuf.h"
55 #include "digest.h"
56
57 /* prototype */
58 static int kex_choose_conf(struct ssh *);
59 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
60
61 static const char *proposal_names[PROPOSAL_MAX] = {
62         "KEX algorithms",
63         "host key algorithms",
64         "ciphers ctos",
65         "ciphers stoc",
66         "MACs ctos",
67         "MACs stoc",
68         "compression ctos",
69         "compression stoc",
70         "languages ctos",
71         "languages stoc",
72 };
73
74 struct kexalg {
75         char *name;
76         u_int type;
77         int ec_nid;
78         int hash_alg;
79 };
80 static const struct kexalg kexalgs[] = {
81 #ifdef WITH_OPENSSL
82         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
83         { KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
84         { KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
85         { KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
86         { KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
87         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
88 #ifdef HAVE_EVP_SHA256
89         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
90 #endif /* HAVE_EVP_SHA256 */
91 #ifdef OPENSSL_HAS_ECC
92         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
93             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
94         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
95             SSH_DIGEST_SHA384 },
96 # ifdef OPENSSL_HAS_NISTP521
97         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
98             SSH_DIGEST_SHA512 },
99 # endif /* OPENSSL_HAS_NISTP521 */
100 #endif /* OPENSSL_HAS_ECC */
101 #endif /* WITH_OPENSSL */
102 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
103         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
104         { KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
105 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
106         { NULL, -1, -1, -1},
107 };
108
109 char *
110 kex_alg_list(char sep)
111 {
112         char *ret = NULL, *tmp;
113         size_t nlen, rlen = 0;
114         const struct kexalg *k;
115
116         for (k = kexalgs; k->name != NULL; k++) {
117                 if (ret != NULL)
118                         ret[rlen++] = sep;
119                 nlen = strlen(k->name);
120                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
121                         free(ret);
122                         return NULL;
123                 }
124                 ret = tmp;
125                 memcpy(ret + rlen, k->name, nlen + 1);
126                 rlen += nlen;
127         }
128         return ret;
129 }
130
131 static const struct kexalg *
132 kex_alg_by_name(const char *name)
133 {
134         const struct kexalg *k;
135
136         for (k = kexalgs; k->name != NULL; k++) {
137                 if (strcmp(k->name, name) == 0)
138                         return k;
139         }
140         return NULL;
141 }
142
143 /* Validate KEX method name list */
144 int
145 kex_names_valid(const char *names)
146 {
147         char *s, *cp, *p;
148
149         if (names == NULL || strcmp(names, "") == 0)
150                 return 0;
151         if ((s = cp = strdup(names)) == NULL)
152                 return 0;
153         for ((p = strsep(&cp, ",")); p && *p != '\0';
154             (p = strsep(&cp, ","))) {
155                 if (kex_alg_by_name(p) == NULL) {
156                         error("Unsupported KEX algorithm \"%.100s\"", p);
157                         free(s);
158                         return 0;
159                 }
160         }
161         debug3("kex names ok: [%s]", names);
162         free(s);
163         return 1;
164 }
165
166 /*
167  * Concatenate algorithm names, avoiding duplicates in the process.
168  * Caller must free returned string.
169  */
170 char *
171 kex_names_cat(const char *a, const char *b)
172 {
173         char *ret = NULL, *tmp = NULL, *cp, *p, *m;
174         size_t len;
175
176         if (a == NULL || *a == '\0')
177                 return strdup(b);
178         if (b == NULL || *b == '\0')
179                 return strdup(a);
180         if (strlen(b) > 1024*1024)
181                 return NULL;
182         len = strlen(a) + strlen(b) + 2;
183         if ((tmp = cp = strdup(b)) == NULL ||
184             (ret = calloc(1, len)) == NULL) {
185                 free(tmp);
186                 return NULL;
187         }
188         strlcpy(ret, a, len);
189         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
190                 if ((m = match_list(ret, p, NULL)) != NULL) {
191                         free(m);
192                         continue; /* Algorithm already present */
193                 }
194                 if (strlcat(ret, ",", len) >= len ||
195                     strlcat(ret, p, len) >= len) {
196                         free(tmp);
197                         free(ret);
198                         return NULL; /* Shouldn't happen */
199                 }
200         }
201         free(tmp);
202         return ret;
203 }
204
205 /*
206  * Assemble a list of algorithms from a default list and a string from a
207  * configuration file. The user-provided string may begin with '+' to
208  * indicate that it should be appended to the default or '-' that the
209  * specified names should be removed.
210  */
211 int
212 kex_assemble_names(char **listp, const char *def, const char *all)
213 {
214         char *cp, *tmp, *patterns;
215         char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
216         int r = SSH_ERR_INTERNAL_ERROR;
217
218         if (listp == NULL || *listp == NULL || **listp == '\0') {
219                 if ((*listp = strdup(def)) == NULL)
220                         return SSH_ERR_ALLOC_FAIL;
221                 return 0;
222         }
223
224         list = *listp;
225         *listp = NULL;
226         if (*list == '+') {
227                 /* Append names to default list */
228                 if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
229                         r = SSH_ERR_ALLOC_FAIL;
230                         goto fail;
231                 }
232                 free(list);
233                 list = tmp;
234         } else if (*list == '-') {
235                 /* Remove names from default list */
236                 if ((*listp = match_filter_blacklist(def, list + 1)) == NULL) {
237                         r = SSH_ERR_ALLOC_FAIL;
238                         goto fail;
239                 }
240                 free(list);
241                 /* filtering has already been done */
242                 return 0;
243         } else {
244                 /* Explicit list, overrides default - just use "list" as is */
245         }
246
247         /*
248          * The supplied names may be a pattern-list. For the -list case,
249          * the patterns are applied above. For the +list and explicit list
250          * cases we need to do it now.
251          */
252         ret = NULL;
253         if ((patterns = opatterns = strdup(list)) == NULL) {
254                 r = SSH_ERR_ALLOC_FAIL;
255                 goto fail;
256         }
257         /* Apply positive (i.e. non-negated) patterns from the list */
258         while ((cp = strsep(&patterns, ",")) != NULL) {
259                 if (*cp == '!') {
260                         /* negated matches are not supported here */
261                         r = SSH_ERR_INVALID_ARGUMENT;
262                         goto fail;
263                 }
264                 free(matching);
265                 if ((matching = match_filter_whitelist(all, cp)) == NULL) {
266                         r = SSH_ERR_ALLOC_FAIL;
267                         goto fail;
268                 }
269                 if ((tmp = kex_names_cat(ret, matching)) == NULL) {
270                         r = SSH_ERR_ALLOC_FAIL;
271                         goto fail;
272                 }
273                 free(ret);
274                 ret = tmp;
275         }
276         if (ret == NULL || *ret == '\0') {
277                 /* An empty name-list is an error */
278                 /* XXX better error code? */
279                 r = SSH_ERR_INVALID_ARGUMENT;
280                 goto fail;
281         }
282
283         /* success */
284         *listp = ret;
285         ret = NULL;
286         r = 0;
287
288  fail:
289         free(matching);
290         free(opatterns);
291         free(list);
292         free(ret);
293         return r;
294 }
295
296 /* put algorithm proposal into buffer */
297 int
298 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
299 {
300         u_int i;
301         int r;
302
303         sshbuf_reset(b);
304
305         /*
306          * add a dummy cookie, the cookie will be overwritten by
307          * kex_send_kexinit(), each time a kexinit is set
308          */
309         for (i = 0; i < KEX_COOKIE_LEN; i++) {
310                 if ((r = sshbuf_put_u8(b, 0)) != 0)
311                         return r;
312         }
313         for (i = 0; i < PROPOSAL_MAX; i++) {
314                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
315                         return r;
316         }
317         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
318             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
319                 return r;
320         return 0;
321 }
322
323 /* parse buffer and return algorithm proposal */
324 int
325 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
326 {
327         struct sshbuf *b = NULL;
328         u_char v;
329         u_int i;
330         char **proposal = NULL;
331         int r;
332
333         *propp = NULL;
334         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
335                 return SSH_ERR_ALLOC_FAIL;
336         if ((b = sshbuf_fromb(raw)) == NULL) {
337                 r = SSH_ERR_ALLOC_FAIL;
338                 goto out;
339         }
340         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
341                 goto out;
342         /* extract kex init proposal strings */
343         for (i = 0; i < PROPOSAL_MAX; i++) {
344                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
345                         goto out;
346                 debug2("%s: %s", proposal_names[i], proposal[i]);
347         }
348         /* first kex follows / reserved */
349         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
350             (r = sshbuf_get_u32(b, &i)) != 0)   /* reserved */
351                 goto out;
352         if (first_kex_follows != NULL)
353                 *first_kex_follows = v;
354         debug2("first_kex_follows %d ", v);
355         debug2("reserved %u ", i);
356         r = 0;
357         *propp = proposal;
358  out:
359         if (r != 0 && proposal != NULL)
360                 kex_prop_free(proposal);
361         sshbuf_free(b);
362         return r;
363 }
364
365 void
366 kex_prop_free(char **proposal)
367 {
368         u_int i;
369
370         if (proposal == NULL)
371                 return;
372         for (i = 0; i < PROPOSAL_MAX; i++)
373                 free(proposal[i]);
374         free(proposal);
375 }
376
377 /* ARGSUSED */
378 static int
379 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
380 {
381         int r;
382
383         error("kex protocol error: type %d seq %u", type, seq);
384         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
385             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
386             (r = sshpkt_send(ssh)) != 0)
387                 return r;
388         return 0;
389 }
390
391 static void
392 kex_reset_dispatch(struct ssh *ssh)
393 {
394         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
395             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
396 }
397
398 static int
399 kex_send_ext_info(struct ssh *ssh)
400 {
401         int r;
402         char *algs;
403
404         if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
405                 return SSH_ERR_ALLOC_FAIL;
406         /* XXX filter algs list by allowed pubkey/hostbased types */
407         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
408             (r = sshpkt_put_u32(ssh, 1)) != 0 ||
409             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
410             (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
411             (r = sshpkt_send(ssh)) != 0)
412                 goto out;
413         /* success */
414         r = 0;
415  out:
416         free(algs);
417         return r;
418 }
419
420 int
421 kex_send_newkeys(struct ssh *ssh)
422 {
423         int r;
424
425         kex_reset_dispatch(ssh);
426         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
427             (r = sshpkt_send(ssh)) != 0)
428                 return r;
429         debug("SSH2_MSG_NEWKEYS sent");
430         debug("expecting SSH2_MSG_NEWKEYS");
431         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
432         if (ssh->kex->ext_info_c)
433                 if ((r = kex_send_ext_info(ssh)) != 0)
434                         return r;
435         return 0;
436 }
437
438 int
439 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
440 {
441         struct kex *kex = ssh->kex;
442         u_int32_t i, ninfo;
443         char *name;
444         u_char *val;
445         size_t vlen;
446         int r;
447
448         debug("SSH2_MSG_EXT_INFO received");
449         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
450         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
451                 return r;
452         for (i = 0; i < ninfo; i++) {
453                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
454                         return r;
455                 if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
456                         free(name);
457                         return r;
458                 }
459                 if (strcmp(name, "server-sig-algs") == 0) {
460                         /* Ensure no \0 lurking in value */
461                         if (memchr(val, '\0', vlen) != NULL) {
462                                 error("%s: nul byte in %s", __func__, name);
463                                 return SSH_ERR_INVALID_FORMAT;
464                         }
465                         debug("%s: %s=<%s>", __func__, name, val);
466                         kex->server_sig_algs = val;
467                         val = NULL;
468                 } else
469                         debug("%s: %s (unrecognised)", __func__, name);
470                 free(name);
471                 free(val);
472         }
473         return sshpkt_get_end(ssh);
474 }
475
476 static int
477 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
478 {
479         struct kex *kex = ssh->kex;
480         int r;
481
482         debug("SSH2_MSG_NEWKEYS received");
483         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
484         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
485         if ((r = sshpkt_get_end(ssh)) != 0)
486                 return r;
487         if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
488                 return r;
489         kex->done = 1;
490         sshbuf_reset(kex->peer);
491         /* sshbuf_reset(kex->my); */
492         kex->flags &= ~KEX_INIT_SENT;
493         free(kex->name);
494         kex->name = NULL;
495         return 0;
496 }
497
498 int
499 kex_send_kexinit(struct ssh *ssh)
500 {
501         u_char *cookie;
502         struct kex *kex = ssh->kex;
503         int r;
504
505         if (kex == NULL)
506                 return SSH_ERR_INTERNAL_ERROR;
507         if (kex->flags & KEX_INIT_SENT)
508                 return 0;
509         kex->done = 0;
510
511         /* generate a random cookie */
512         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
513                 return SSH_ERR_INVALID_FORMAT;
514         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
515                 return SSH_ERR_INTERNAL_ERROR;
516         arc4random_buf(cookie, KEX_COOKIE_LEN);
517
518         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
519             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
520             (r = sshpkt_send(ssh)) != 0)
521                 return r;
522         debug("SSH2_MSG_KEXINIT sent");
523         kex->flags |= KEX_INIT_SENT;
524         return 0;
525 }
526
527 /* ARGSUSED */
528 int
529 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
530 {
531         struct kex *kex = ssh->kex;
532         const u_char *ptr;
533         u_int i;
534         size_t dlen;
535         int r;
536
537         debug("SSH2_MSG_KEXINIT received");
538         if (kex == NULL)
539                 return SSH_ERR_INVALID_ARGUMENT;
540
541         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
542         ptr = sshpkt_ptr(ssh, &dlen);
543         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
544                 return r;
545
546         /* discard packet */
547         for (i = 0; i < KEX_COOKIE_LEN; i++)
548                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
549                         return r;
550         for (i = 0; i < PROPOSAL_MAX; i++)
551                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
552                         return r;
553         /*
554          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
555          * KEX method has the server move first, but a server might be using
556          * a custom method or one that we otherwise don't support. We should
557          * be prepared to remember first_kex_follows here so we can eat a
558          * packet later.
559          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
560          * for cases where the server *doesn't* go first. I guess we should
561          * ignore it when it is set for these cases, which is what we do now.
562          */
563         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
564             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
565             (r = sshpkt_get_end(ssh)) != 0)
566                         return r;
567
568         if (!(kex->flags & KEX_INIT_SENT))
569                 if ((r = kex_send_kexinit(ssh)) != 0)
570                         return r;
571         if ((r = kex_choose_conf(ssh)) != 0)
572                 return r;
573
574         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
575                 return (kex->kex[kex->kex_type])(ssh);
576
577         return SSH_ERR_INTERNAL_ERROR;
578 }
579
580 int
581 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
582 {
583         struct kex *kex;
584         int r;
585
586         *kexp = NULL;
587         if ((kex = calloc(1, sizeof(*kex))) == NULL)
588                 return SSH_ERR_ALLOC_FAIL;
589         if ((kex->peer = sshbuf_new()) == NULL ||
590             (kex->my = sshbuf_new()) == NULL) {
591                 r = SSH_ERR_ALLOC_FAIL;
592                 goto out;
593         }
594         if ((r = kex_prop2buf(kex->my, proposal)) != 0)
595                 goto out;
596         kex->done = 0;
597         kex_reset_dispatch(ssh);
598         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
599         r = 0;
600         *kexp = kex;
601  out:
602         if (r != 0)
603                 kex_free(kex);
604         return r;
605 }
606
607 void
608 kex_free_newkeys(struct newkeys *newkeys)
609 {
610         if (newkeys == NULL)
611                 return;
612         if (newkeys->enc.key) {
613                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
614                 free(newkeys->enc.key);
615                 newkeys->enc.key = NULL;
616         }
617         if (newkeys->enc.iv) {
618                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
619                 free(newkeys->enc.iv);
620                 newkeys->enc.iv = NULL;
621         }
622         free(newkeys->enc.name);
623         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
624         free(newkeys->comp.name);
625         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
626         mac_clear(&newkeys->mac);
627         if (newkeys->mac.key) {
628                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
629                 free(newkeys->mac.key);
630                 newkeys->mac.key = NULL;
631         }
632         free(newkeys->mac.name);
633         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
634         explicit_bzero(newkeys, sizeof(*newkeys));
635         free(newkeys);
636 }
637
638 void
639 kex_free(struct kex *kex)
640 {
641         u_int mode;
642
643 #ifdef WITH_OPENSSL
644         DH_free(kex->dh);
645 #ifdef OPENSSL_HAS_ECC
646         EC_KEY_free(kex->ec_client_key);
647 #endif /* OPENSSL_HAS_ECC */
648 #endif /* WITH_OPENSSL */
649         for (mode = 0; mode < MODE_MAX; mode++) {
650                 kex_free_newkeys(kex->newkeys[mode]);
651                 kex->newkeys[mode] = NULL;
652         }
653         sshbuf_free(kex->peer);
654         sshbuf_free(kex->my);
655         free(kex->session_id);
656         free(kex->client_version_string);
657         free(kex->server_version_string);
658         free(kex->failed_choice);
659         free(kex->hostkey_alg);
660         free(kex->name);
661         free(kex);
662 }
663
664 int
665 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
666 {
667         int r;
668
669         if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
670                 return r;
671         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
672                 kex_free(ssh->kex);
673                 ssh->kex = NULL;
674                 return r;
675         }
676         return 0;
677 }
678
679 /*
680  * Request key re-exchange, returns 0 on success or a ssherr.h error
681  * code otherwise. Must not be called if KEX is incomplete or in-progress.
682  */
683 int
684 kex_start_rekex(struct ssh *ssh)
685 {
686         if (ssh->kex == NULL) {
687                 error("%s: no kex", __func__);
688                 return SSH_ERR_INTERNAL_ERROR;
689         }
690         if (ssh->kex->done == 0) {
691                 error("%s: requested twice", __func__);
692                 return SSH_ERR_INTERNAL_ERROR;
693         }
694         ssh->kex->done = 0;
695         return kex_send_kexinit(ssh);
696 }
697
698 static int
699 choose_enc(struct sshenc *enc, char *client, char *server)
700 {
701         char *name = match_list(client, server, NULL);
702
703         if (name == NULL)
704                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
705         if ((enc->cipher = cipher_by_name(name)) == NULL) {
706                 free(name);
707                 return SSH_ERR_INTERNAL_ERROR;
708         }
709         enc->name = name;
710         enc->enabled = 0;
711         enc->iv = NULL;
712         enc->iv_len = cipher_ivlen(enc->cipher);
713         enc->key = NULL;
714         enc->key_len = cipher_keylen(enc->cipher);
715         enc->block_size = cipher_blocksize(enc->cipher);
716         return 0;
717 }
718
719 static int
720 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
721 {
722         char *name = match_list(client, server, NULL);
723
724         if (name == NULL)
725                 return SSH_ERR_NO_MAC_ALG_MATCH;
726         if (mac_setup(mac, name) < 0) {
727                 free(name);
728                 return SSH_ERR_INTERNAL_ERROR;
729         }
730         mac->name = name;
731         mac->key = NULL;
732         mac->enabled = 0;
733         return 0;
734 }
735
736 static int
737 choose_comp(struct sshcomp *comp, char *client, char *server)
738 {
739         char *name = match_list(client, server, NULL);
740
741         if (name == NULL)
742                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
743         if (strcmp(name, "zlib@openssh.com") == 0) {
744                 comp->type = COMP_DELAYED;
745         } else if (strcmp(name, "zlib") == 0) {
746                 comp->type = COMP_ZLIB;
747         } else if (strcmp(name, "none") == 0) {
748                 comp->type = COMP_NONE;
749         } else {
750                 free(name);
751                 return SSH_ERR_INTERNAL_ERROR;
752         }
753         comp->name = name;
754         return 0;
755 }
756
757 static int
758 choose_kex(struct kex *k, char *client, char *server)
759 {
760         const struct kexalg *kexalg;
761
762         k->name = match_list(client, server, NULL);
763
764         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
765         if (k->name == NULL)
766                 return SSH_ERR_NO_KEX_ALG_MATCH;
767         if ((kexalg = kex_alg_by_name(k->name)) == NULL)
768                 return SSH_ERR_INTERNAL_ERROR;
769         k->kex_type = kexalg->type;
770         k->hash_alg = kexalg->hash_alg;
771         k->ec_nid = kexalg->ec_nid;
772         return 0;
773 }
774
775 static int
776 choose_hostkeyalg(struct kex *k, char *client, char *server)
777 {
778         k->hostkey_alg = match_list(client, server, NULL);
779
780         debug("kex: host key algorithm: %s",
781             k->hostkey_alg ? k->hostkey_alg : "(no match)");
782         if (k->hostkey_alg == NULL)
783                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
784         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
785         if (k->hostkey_type == KEY_UNSPEC)
786                 return SSH_ERR_INTERNAL_ERROR;
787         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
788         return 0;
789 }
790
791 static int
792 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
793 {
794         static int check[] = {
795                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
796         };
797         int *idx;
798         char *p;
799
800         for (idx = &check[0]; *idx != -1; idx++) {
801                 if ((p = strchr(my[*idx], ',')) != NULL)
802                         *p = '\0';
803                 if ((p = strchr(peer[*idx], ',')) != NULL)
804                         *p = '\0';
805                 if (strcmp(my[*idx], peer[*idx]) != 0) {
806                         debug2("proposal mismatch: my %s peer %s",
807                             my[*idx], peer[*idx]);
808                         return (0);
809                 }
810         }
811         debug2("proposals match");
812         return (1);
813 }
814
815 static int
816 kex_choose_conf(struct ssh *ssh)
817 {
818         struct kex *kex = ssh->kex;
819         struct newkeys *newkeys;
820         char **my = NULL, **peer = NULL;
821         char **cprop, **sprop;
822         int nenc, nmac, ncomp;
823         u_int mode, ctos, need, dh_need, authlen;
824         int r, first_kex_follows;
825
826         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
827         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
828                 goto out;
829         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
830         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
831                 goto out;
832
833         if (kex->server) {
834                 cprop=peer;
835                 sprop=my;
836         } else {
837                 cprop=my;
838                 sprop=peer;
839         }
840
841         /* Check whether client supports ext_info_c */
842         if (kex->server) {
843                 char *ext;
844
845                 ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
846                 kex->ext_info_c = (ext != NULL);
847                 free(ext);
848         }
849
850         /* Algorithm Negotiation */
851         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
852             sprop[PROPOSAL_KEX_ALGS])) != 0) {
853                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
854                 peer[PROPOSAL_KEX_ALGS] = NULL;
855                 goto out;
856         }
857         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
858             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
859                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
860                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
861                 goto out;
862         }
863         for (mode = 0; mode < MODE_MAX; mode++) {
864                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
865                         r = SSH_ERR_ALLOC_FAIL;
866                         goto out;
867                 }
868                 kex->newkeys[mode] = newkeys;
869                 ctos = (!kex->server && mode == MODE_OUT) ||
870                     (kex->server && mode == MODE_IN);
871                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
872                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
873                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
874                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
875                     sprop[nenc])) != 0) {
876                         kex->failed_choice = peer[nenc];
877                         peer[nenc] = NULL;
878                         goto out;
879                 }
880                 authlen = cipher_authlen(newkeys->enc.cipher);
881                 /* ignore mac for authenticated encryption */
882                 if (authlen == 0 &&
883                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
884                     sprop[nmac])) != 0) {
885                         kex->failed_choice = peer[nmac];
886                         peer[nmac] = NULL;
887                         goto out;
888                 }
889                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
890                     sprop[ncomp])) != 0) {
891                         kex->failed_choice = peer[ncomp];
892                         peer[ncomp] = NULL;
893                         goto out;
894                 }
895                 debug("kex: %s cipher: %s MAC: %s compression: %s",
896                     ctos ? "client->server" : "server->client",
897                     newkeys->enc.name,
898                     authlen == 0 ? newkeys->mac.name : "<implicit>",
899                     newkeys->comp.name);
900         }
901         need = dh_need = 0;
902         for (mode = 0; mode < MODE_MAX; mode++) {
903                 newkeys = kex->newkeys[mode];
904                 need = MAXIMUM(need, newkeys->enc.key_len);
905                 need = MAXIMUM(need, newkeys->enc.block_size);
906                 need = MAXIMUM(need, newkeys->enc.iv_len);
907                 need = MAXIMUM(need, newkeys->mac.key_len);
908                 dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
909                 dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
910                 dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
911                 dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
912         }
913         /* XXX need runden? */
914         kex->we_need = need;
915         kex->dh_need = dh_need;
916
917         /* ignore the next message if the proposals do not match */
918         if (first_kex_follows && !proposals_match(my, peer))
919                 ssh->dispatch_skip_packets = 1;
920         r = 0;
921  out:
922         kex_prop_free(my);
923         kex_prop_free(peer);
924         return r;
925 }
926
927 static int
928 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
929     const struct sshbuf *shared_secret, u_char **keyp)
930 {
931         struct kex *kex = ssh->kex;
932         struct ssh_digest_ctx *hashctx = NULL;
933         char c = id;
934         u_int have;
935         size_t mdsz;
936         u_char *digest;
937         int r;
938
939         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
940                 return SSH_ERR_INVALID_ARGUMENT;
941         if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
942                 r = SSH_ERR_ALLOC_FAIL;
943                 goto out;
944         }
945
946         /* K1 = HASH(K || H || "A" || session_id) */
947         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
948             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
949             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
950             ssh_digest_update(hashctx, &c, 1) != 0 ||
951             ssh_digest_update(hashctx, kex->session_id,
952             kex->session_id_len) != 0 ||
953             ssh_digest_final(hashctx, digest, mdsz) != 0) {
954                 r = SSH_ERR_LIBCRYPTO_ERROR;
955                 goto out;
956         }
957         ssh_digest_free(hashctx);
958         hashctx = NULL;
959
960         /*
961          * expand key:
962          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
963          * Key = K1 || K2 || ... || Kn
964          */
965         for (have = mdsz; need > have; have += mdsz) {
966                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
967                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
968                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
969                     ssh_digest_update(hashctx, digest, have) != 0 ||
970                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
971                         r = SSH_ERR_LIBCRYPTO_ERROR;
972                         goto out;
973                 }
974                 ssh_digest_free(hashctx);
975                 hashctx = NULL;
976         }
977 #ifdef DEBUG_KEX
978         fprintf(stderr, "key '%c'== ", c);
979         dump_digest("key", digest, need);
980 #endif
981         *keyp = digest;
982         digest = NULL;
983         r = 0;
984  out:
985         free(digest);
986         ssh_digest_free(hashctx);
987         return r;
988 }
989
990 #define NKEYS   6
991 int
992 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
993     const struct sshbuf *shared_secret)
994 {
995         struct kex *kex = ssh->kex;
996         u_char *keys[NKEYS];
997         u_int i, j, mode, ctos;
998         int r;
999
1000         for (i = 0; i < NKEYS; i++) {
1001                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1002                     shared_secret, &keys[i])) != 0) {
1003                         for (j = 0; j < i; j++)
1004                                 free(keys[j]);
1005                         return r;
1006                 }
1007         }
1008         for (mode = 0; mode < MODE_MAX; mode++) {
1009                 ctos = (!kex->server && mode == MODE_OUT) ||
1010                     (kex->server && mode == MODE_IN);
1011                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1012                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1013                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1014         }
1015         return 0;
1016 }
1017
1018 #ifdef WITH_OPENSSL
1019 int
1020 kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
1021     const BIGNUM *secret)
1022 {
1023         struct sshbuf *shared_secret;
1024         int r;
1025
1026         if ((shared_secret = sshbuf_new()) == NULL)
1027                 return SSH_ERR_ALLOC_FAIL;
1028         if ((r = sshbuf_put_bignum2(shared_secret, secret)) == 0)
1029                 r = kex_derive_keys(ssh, hash, hashlen, shared_secret);
1030         sshbuf_free(shared_secret);
1031         return r;
1032 }
1033 #endif
1034
1035
1036 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1037 void
1038 dump_digest(char *msg, u_char *digest, int len)
1039 {
1040         fprintf(stderr, "%s\n", msg);
1041         sshbuf_dump_data(digest, len, stderr);
1042 }
1043 #endif