]> CyberLeo.Net >> Repos - FreeBSD/releng/10.3.git/blob - crypto/openssh/kex.c
Fix OpenSSH remote Denial of Service vulnerability. [SA-16:33]
[FreeBSD/releng/10.3.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.117 2016/02/08 10:57:07 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28 #include <sys/param.h>  /* MAX roundup */
29
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35
36 #ifdef WITH_OPENSSL
37 #include <openssl/crypto.h>
38 #endif
39
40 #include "ssh2.h"
41 #include "packet.h"
42 #include "compat.h"
43 #include "cipher.h"
44 #include "sshkey.h"
45 #include "kex.h"
46 #include "log.h"
47 #include "mac.h"
48 #include "match.h"
49 #include "misc.h"
50 #include "dispatch.h"
51 #include "monitor.h"
52
53 #include "ssherr.h"
54 #include "sshbuf.h"
55 #include "digest.h"
56
57 #if OPENSSL_VERSION_NUMBER >= 0x00907000L
58 # if defined(HAVE_EVP_SHA256)
59 # define evp_ssh_sha256 EVP_sha256
60 # else
61 extern const EVP_MD *evp_ssh_sha256(void);
62 # endif
63 #endif
64
65 /* prototype */
66 static int kex_choose_conf(struct ssh *);
67 static int kex_input_newkeys(int, u_int32_t, void *);
68
69 static const char *proposal_names[PROPOSAL_MAX] = {
70         "KEX algorithms",
71         "host key algorithms",
72         "ciphers ctos",
73         "ciphers stoc",
74         "MACs ctos",
75         "MACs stoc",
76         "compression ctos",
77         "compression stoc",
78         "languages ctos",
79         "languages stoc",
80 };
81
82 struct kexalg {
83         char *name;
84         u_int type;
85         int ec_nid;
86         int hash_alg;
87 };
88 static const struct kexalg kexalgs[] = {
89 #ifdef WITH_OPENSSL
90         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
91         { KEX_DH14, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
92         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
93 #ifdef HAVE_EVP_SHA256
94         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
95 #endif /* HAVE_EVP_SHA256 */
96 #ifdef OPENSSL_HAS_ECC
97         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
98             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
99         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
100             SSH_DIGEST_SHA384 },
101 # ifdef OPENSSL_HAS_NISTP521
102         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
103             SSH_DIGEST_SHA512 },
104 # endif /* OPENSSL_HAS_NISTP521 */
105 #endif /* OPENSSL_HAS_ECC */
106 #endif /* WITH_OPENSSL */
107 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
108         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
109 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
110         { NULL, -1, -1, -1},
111 };
112
113 char *
114 kex_alg_list(char sep)
115 {
116         char *ret = NULL, *tmp;
117         size_t nlen, rlen = 0;
118         const struct kexalg *k;
119
120         for (k = kexalgs; k->name != NULL; k++) {
121                 if (ret != NULL)
122                         ret[rlen++] = sep;
123                 nlen = strlen(k->name);
124                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
125                         free(ret);
126                         return NULL;
127                 }
128                 ret = tmp;
129                 memcpy(ret + rlen, k->name, nlen + 1);
130                 rlen += nlen;
131         }
132         return ret;
133 }
134
135 static const struct kexalg *
136 kex_alg_by_name(const char *name)
137 {
138         const struct kexalg *k;
139
140         for (k = kexalgs; k->name != NULL; k++) {
141                 if (strcmp(k->name, name) == 0)
142                         return k;
143         }
144         return NULL;
145 }
146
147 /* Validate KEX method name list */
148 int
149 kex_names_valid(const char *names)
150 {
151         char *s, *cp, *p;
152
153         if (names == NULL || strcmp(names, "") == 0)
154                 return 0;
155         if ((s = cp = strdup(names)) == NULL)
156                 return 0;
157         for ((p = strsep(&cp, ",")); p && *p != '\0';
158             (p = strsep(&cp, ","))) {
159                 if (kex_alg_by_name(p) == NULL) {
160                         error("Unsupported KEX algorithm \"%.100s\"", p);
161                         free(s);
162                         return 0;
163                 }
164         }
165         debug3("kex names ok: [%s]", names);
166         free(s);
167         return 1;
168 }
169
170 /*
171  * Concatenate algorithm names, avoiding duplicates in the process.
172  * Caller must free returned string.
173  */
174 char *
175 kex_names_cat(const char *a, const char *b)
176 {
177         char *ret = NULL, *tmp = NULL, *cp, *p;
178         size_t len;
179
180         if (a == NULL || *a == '\0')
181                 return NULL;
182         if (b == NULL || *b == '\0')
183                 return strdup(a);
184         if (strlen(b) > 1024*1024)
185                 return NULL;
186         len = strlen(a) + strlen(b) + 2;
187         if ((tmp = cp = strdup(b)) == NULL ||
188             (ret = calloc(1, len)) == NULL) {
189                 free(tmp);
190                 return NULL;
191         }
192         strlcpy(ret, a, len);
193         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
194                 if (match_list(ret, p, NULL) != NULL)
195                         continue; /* Algorithm already present */
196                 if (strlcat(ret, ",", len) >= len ||
197                     strlcat(ret, p, len) >= len) {
198                         free(tmp);
199                         free(ret);
200                         return NULL; /* Shouldn't happen */
201                 }
202         }
203         free(tmp);
204         return ret;
205 }
206
207 /*
208  * Assemble a list of algorithms from a default list and a string from a
209  * configuration file. The user-provided string may begin with '+' to
210  * indicate that it should be appended to the default.
211  */
212 int
213 kex_assemble_names(const char *def, char **list)
214 {
215         char *ret;
216
217         if (list == NULL || *list == NULL || **list == '\0') {
218                 *list = strdup(def);
219                 return 0;
220         }
221         if (**list != '+') {
222                 return 0;
223         }
224
225         if ((ret = kex_names_cat(def, *list + 1)) == NULL)
226                 return SSH_ERR_ALLOC_FAIL;
227         free(*list);
228         *list = ret;
229         return 0;
230 }
231
232 /* put algorithm proposal into buffer */
233 int
234 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
235 {
236         u_int i;
237         int r;
238
239         sshbuf_reset(b);
240
241         /*
242          * add a dummy cookie, the cookie will be overwritten by
243          * kex_send_kexinit(), each time a kexinit is set
244          */
245         for (i = 0; i < KEX_COOKIE_LEN; i++) {
246                 if ((r = sshbuf_put_u8(b, 0)) != 0)
247                         return r;
248         }
249         for (i = 0; i < PROPOSAL_MAX; i++) {
250                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
251                         return r;
252         }
253         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
254             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
255                 return r;
256         return 0;
257 }
258
259 /* parse buffer and return algorithm proposal */
260 int
261 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
262 {
263         struct sshbuf *b = NULL;
264         u_char v;
265         u_int i;
266         char **proposal = NULL;
267         int r;
268
269         *propp = NULL;
270         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
271                 return SSH_ERR_ALLOC_FAIL;
272         if ((b = sshbuf_fromb(raw)) == NULL) {
273                 r = SSH_ERR_ALLOC_FAIL;
274                 goto out;
275         }
276         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
277                 goto out;
278         /* extract kex init proposal strings */
279         for (i = 0; i < PROPOSAL_MAX; i++) {
280                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
281                         goto out;
282                 debug2("%s: %s", proposal_names[i], proposal[i]);
283         }
284         /* first kex follows / reserved */
285         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
286             (r = sshbuf_get_u32(b, &i)) != 0)   /* reserved */
287                 goto out;
288         if (first_kex_follows != NULL)
289                 *first_kex_follows = v;
290         debug2("first_kex_follows %d ", v);
291         debug2("reserved %u ", i);
292         r = 0;
293         *propp = proposal;
294  out:
295         if (r != 0 && proposal != NULL)
296                 kex_prop_free(proposal);
297         sshbuf_free(b);
298         return r;
299 }
300
301 void
302 kex_prop_free(char **proposal)
303 {
304         u_int i;
305
306         if (proposal == NULL)
307                 return;
308         for (i = 0; i < PROPOSAL_MAX; i++)
309                 free(proposal[i]);
310         free(proposal);
311 }
312
313 /* ARGSUSED */
314 static int
315 kex_protocol_error(int type, u_int32_t seq, void *ctxt)
316 {
317         struct ssh *ssh = active_state; /* XXX */
318         int r;
319
320         error("kex protocol error: type %d seq %u", type, seq);
321         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
322             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
323             (r = sshpkt_send(ssh)) != 0)
324                 return r;
325         return 0;
326 }
327
328 static void
329 kex_reset_dispatch(struct ssh *ssh)
330 {
331         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
332             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
333         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
334 }
335
336 static int
337 kex_send_ext_info(struct ssh *ssh)
338 {
339         int r;
340
341         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
342             (r = sshpkt_put_u32(ssh, 1)) != 0 ||
343             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
344             (r = sshpkt_put_cstring(ssh, "rsa-sha2-256,rsa-sha2-512")) != 0 ||
345             (r = sshpkt_send(ssh)) != 0)
346                 return r;
347         return 0;
348 }
349
350 int
351 kex_send_newkeys(struct ssh *ssh)
352 {
353         int r;
354
355         kex_reset_dispatch(ssh);
356         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
357             (r = sshpkt_send(ssh)) != 0)
358                 return r;
359         debug("SSH2_MSG_NEWKEYS sent");
360         debug("expecting SSH2_MSG_NEWKEYS");
361         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
362         if (ssh->kex->ext_info_c)
363                 if ((r = kex_send_ext_info(ssh)) != 0)
364                         return r;
365         return 0;
366 }
367
368 int
369 kex_input_ext_info(int type, u_int32_t seq, void *ctxt)
370 {
371         struct ssh *ssh = ctxt;
372         struct kex *kex = ssh->kex;
373         u_int32_t i, ninfo;
374         char *name, *val, *found;
375         int r;
376
377         debug("SSH2_MSG_EXT_INFO received");
378         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
379         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
380                 return r;
381         for (i = 0; i < ninfo; i++) {
382                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
383                         return r;
384                 if ((r = sshpkt_get_cstring(ssh, &val, NULL)) != 0) {
385                         free(name);
386                         return r;
387                 }
388                 debug("%s: %s=<%s>", __func__, name, val);
389                 if (strcmp(name, "server-sig-algs") == 0) {
390                         found = match_list("rsa-sha2-256", val, NULL);
391                         if (found) {
392                                 kex->rsa_sha2 = 256;
393                                 free(found);
394                         }
395                         found = match_list("rsa-sha2-512", val, NULL);
396                         if (found) {
397                                 kex->rsa_sha2 = 512;
398                                 free(found);
399                         }
400                 }
401                 free(name);
402                 free(val);
403         }
404         return sshpkt_get_end(ssh);
405 }
406
407 static int
408 kex_input_newkeys(int type, u_int32_t seq, void *ctxt)
409 {
410         struct ssh *ssh = ctxt;
411         struct kex *kex = ssh->kex;
412         int r;
413
414         debug("SSH2_MSG_NEWKEYS received");
415         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
416         if ((r = sshpkt_get_end(ssh)) != 0)
417                 return r;
418         kex->done = 1;
419         sshbuf_reset(kex->peer);
420         /* sshbuf_reset(kex->my); */
421         kex->flags &= ~KEX_INIT_SENT;
422         free(kex->name);
423         kex->name = NULL;
424         return 0;
425 }
426
427 int
428 kex_send_kexinit(struct ssh *ssh)
429 {
430         u_char *cookie;
431         struct kex *kex = ssh->kex;
432         int r;
433
434         if (kex == NULL)
435                 return SSH_ERR_INTERNAL_ERROR;
436         if (kex->flags & KEX_INIT_SENT)
437                 return 0;
438         kex->done = 0;
439
440         /* generate a random cookie */
441         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
442                 return SSH_ERR_INVALID_FORMAT;
443         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
444                 return SSH_ERR_INTERNAL_ERROR;
445         arc4random_buf(cookie, KEX_COOKIE_LEN);
446
447         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
448             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
449             (r = sshpkt_send(ssh)) != 0)
450                 return r;
451         debug("SSH2_MSG_KEXINIT sent");
452         kex->flags |= KEX_INIT_SENT;
453         return 0;
454 }
455
456 /* ARGSUSED */
457 int
458 kex_input_kexinit(int type, u_int32_t seq, void *ctxt)
459 {
460         struct ssh *ssh = ctxt;
461         struct kex *kex = ssh->kex;
462         const u_char *ptr;
463         u_int i;
464         size_t dlen;
465         int r;
466
467         debug("SSH2_MSG_KEXINIT received");
468         if (kex == NULL)
469                 return SSH_ERR_INVALID_ARGUMENT;
470
471         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
472         ptr = sshpkt_ptr(ssh, &dlen);
473         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
474                 return r;
475
476         /* discard packet */
477         for (i = 0; i < KEX_COOKIE_LEN; i++)
478                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
479                         return r;
480         for (i = 0; i < PROPOSAL_MAX; i++)
481                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
482                         return r;
483         /*
484          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
485          * KEX method has the server move first, but a server might be using
486          * a custom method or one that we otherwise don't support. We should
487          * be prepared to remember first_kex_follows here so we can eat a
488          * packet later.
489          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
490          * for cases where the server *doesn't* go first. I guess we should
491          * ignore it when it is set for these cases, which is what we do now.
492          */
493         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
494             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
495             (r = sshpkt_get_end(ssh)) != 0)
496                         return r;
497
498         if (!(kex->flags & KEX_INIT_SENT))
499                 if ((r = kex_send_kexinit(ssh)) != 0)
500                         return r;
501         if ((r = kex_choose_conf(ssh)) != 0)
502                 return r;
503
504         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
505                 return (kex->kex[kex->kex_type])(ssh);
506
507         return SSH_ERR_INTERNAL_ERROR;
508 }
509
510 int
511 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
512 {
513         struct kex *kex;
514         int r;
515
516         *kexp = NULL;
517         if ((kex = calloc(1, sizeof(*kex))) == NULL)
518                 return SSH_ERR_ALLOC_FAIL;
519         if ((kex->peer = sshbuf_new()) == NULL ||
520             (kex->my = sshbuf_new()) == NULL) {
521                 r = SSH_ERR_ALLOC_FAIL;
522                 goto out;
523         }
524         if ((r = kex_prop2buf(kex->my, proposal)) != 0)
525                 goto out;
526         kex->done = 0;
527         kex_reset_dispatch(ssh);
528         r = 0;
529         *kexp = kex;
530  out:
531         if (r != 0)
532                 kex_free(kex);
533         return r;
534 }
535
536 void
537 kex_free_newkeys(struct newkeys *newkeys)
538 {
539         if (newkeys == NULL)
540                 return;
541         if (newkeys->enc.key) {
542                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
543                 free(newkeys->enc.key);
544                 newkeys->enc.key = NULL;
545         }
546         if (newkeys->enc.iv) {
547                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
548                 free(newkeys->enc.iv);
549                 newkeys->enc.iv = NULL;
550         }
551         free(newkeys->enc.name);
552         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
553         free(newkeys->comp.name);
554         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
555         mac_clear(&newkeys->mac);
556         if (newkeys->mac.key) {
557                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
558                 free(newkeys->mac.key);
559                 newkeys->mac.key = NULL;
560         }
561         free(newkeys->mac.name);
562         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
563         explicit_bzero(newkeys, sizeof(*newkeys));
564         free(newkeys);
565 }
566
567 void
568 kex_free(struct kex *kex)
569 {
570         u_int mode;
571
572 #ifdef WITH_OPENSSL
573         if (kex->dh)
574                 DH_free(kex->dh);
575 #ifdef OPENSSL_HAS_ECC
576         if (kex->ec_client_key)
577                 EC_KEY_free(kex->ec_client_key);
578 #endif /* OPENSSL_HAS_ECC */
579 #endif /* WITH_OPENSSL */
580         for (mode = 0; mode < MODE_MAX; mode++) {
581                 kex_free_newkeys(kex->newkeys[mode]);
582                 kex->newkeys[mode] = NULL;
583         }
584         sshbuf_free(kex->peer);
585         sshbuf_free(kex->my);
586         free(kex->session_id);
587         free(kex->client_version_string);
588         free(kex->server_version_string);
589         free(kex->failed_choice);
590         free(kex->hostkey_alg);
591         free(kex->name);
592         free(kex);
593 }
594
595 int
596 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
597 {
598         int r;
599
600         if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
601                 return r;
602         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
603                 kex_free(ssh->kex);
604                 ssh->kex = NULL;
605                 return r;
606         }
607         return 0;
608 }
609
610 /*
611  * Request key re-exchange, returns 0 on success or a ssherr.h error
612  * code otherwise. Must not be called if KEX is incomplete or in-progress.
613  */
614 int
615 kex_start_rekex(struct ssh *ssh)
616 {
617         if (ssh->kex == NULL) {
618                 error("%s: no kex", __func__);
619                 return SSH_ERR_INTERNAL_ERROR;
620         }
621         if (ssh->kex->done == 0) {
622                 error("%s: requested twice", __func__);
623                 return SSH_ERR_INTERNAL_ERROR;
624         }
625         ssh->kex->done = 0;
626         return kex_send_kexinit(ssh);
627 }
628
629 static int
630 choose_enc(struct sshenc *enc, char *client, char *server)
631 {
632         char *name = match_list(client, server, NULL);
633
634         if (name == NULL)
635                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
636         if ((enc->cipher = cipher_by_name(name)) == NULL)
637                 return SSH_ERR_INTERNAL_ERROR;
638         enc->name = name;
639         enc->enabled = 0;
640         enc->iv = NULL;
641         enc->iv_len = cipher_ivlen(enc->cipher);
642         enc->key = NULL;
643         enc->key_len = cipher_keylen(enc->cipher);
644         enc->block_size = cipher_blocksize(enc->cipher);
645         return 0;
646 }
647
648 static int
649 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
650 {
651         char *name = match_list(client, server, NULL);
652
653         if (name == NULL)
654                 return SSH_ERR_NO_MAC_ALG_MATCH;
655         if (mac_setup(mac, name) < 0)
656                 return SSH_ERR_INTERNAL_ERROR;
657         /* truncate the key */
658         if (ssh->compat & SSH_BUG_HMAC)
659                 mac->key_len = 16;
660         mac->name = name;
661         mac->key = NULL;
662         mac->enabled = 0;
663         return 0;
664 }
665
666 static int
667 choose_comp(struct sshcomp *comp, char *client, char *server)
668 {
669         char *name = match_list(client, server, NULL);
670
671         if (name == NULL)
672                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
673         if (strcmp(name, "zlib@openssh.com") == 0) {
674                 comp->type = COMP_DELAYED;
675         } else if (strcmp(name, "zlib") == 0) {
676                 comp->type = COMP_ZLIB;
677         } else if (strcmp(name, "none") == 0) {
678                 comp->type = COMP_NONE;
679         } else {
680                 return SSH_ERR_INTERNAL_ERROR;
681         }
682         comp->name = name;
683         return 0;
684 }
685
686 static int
687 choose_kex(struct kex *k, char *client, char *server)
688 {
689         const struct kexalg *kexalg;
690
691         k->name = match_list(client, server, NULL);
692
693         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
694         if (k->name == NULL)
695                 return SSH_ERR_NO_KEX_ALG_MATCH;
696         if ((kexalg = kex_alg_by_name(k->name)) == NULL)
697                 return SSH_ERR_INTERNAL_ERROR;
698         k->kex_type = kexalg->type;
699         k->hash_alg = kexalg->hash_alg;
700         k->ec_nid = kexalg->ec_nid;
701         return 0;
702 }
703
704 static int
705 choose_hostkeyalg(struct kex *k, char *client, char *server)
706 {
707         k->hostkey_alg = match_list(client, server, NULL);
708
709         debug("kex: host key algorithm: %s",
710             k->hostkey_alg ? k->hostkey_alg : "(no match)");
711         if (k->hostkey_alg == NULL)
712                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
713         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
714         if (k->hostkey_type == KEY_UNSPEC)
715                 return SSH_ERR_INTERNAL_ERROR;
716         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
717         return 0;
718 }
719
720 static int
721 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
722 {
723         static int check[] = {
724                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
725         };
726         int *idx;
727         char *p;
728
729         for (idx = &check[0]; *idx != -1; idx++) {
730                 if ((p = strchr(my[*idx], ',')) != NULL)
731                         *p = '\0';
732                 if ((p = strchr(peer[*idx], ',')) != NULL)
733                         *p = '\0';
734                 if (strcmp(my[*idx], peer[*idx]) != 0) {
735                         debug2("proposal mismatch: my %s peer %s",
736                             my[*idx], peer[*idx]);
737                         return (0);
738                 }
739         }
740         debug2("proposals match");
741         return (1);
742 }
743
744 static int
745 kex_choose_conf(struct ssh *ssh)
746 {
747         struct kex *kex = ssh->kex;
748         struct newkeys *newkeys;
749         char **my = NULL, **peer = NULL;
750         char **cprop, **sprop;
751         int nenc, nmac, ncomp;
752         u_int mode, ctos, need, dh_need, authlen;
753         int r, first_kex_follows;
754
755         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
756         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
757                 goto out;
758         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
759         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
760                 goto out;
761
762         if (kex->server) {
763                 cprop=peer;
764                 sprop=my;
765         } else {
766                 cprop=my;
767                 sprop=peer;
768         }
769
770         /* Check whether client supports ext_info_c */
771         if (kex->server) {
772                 char *ext;
773
774                 ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
775                 if (ext) {
776                         kex->ext_info_c = 1;
777                         free(ext);
778                 }
779         }
780
781         /* Algorithm Negotiation */
782         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
783             sprop[PROPOSAL_KEX_ALGS])) != 0) {
784                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
785                 peer[PROPOSAL_KEX_ALGS] = NULL;
786                 goto out;
787         }
788         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
789             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
790                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
791                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
792                 goto out;
793         }
794         for (mode = 0; mode < MODE_MAX; mode++) {
795                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
796                         r = SSH_ERR_ALLOC_FAIL;
797                         goto out;
798                 }
799                 kex->newkeys[mode] = newkeys;
800                 ctos = (!kex->server && mode == MODE_OUT) ||
801                     (kex->server && mode == MODE_IN);
802                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
803                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
804                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
805                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
806                     sprop[nenc])) != 0) {
807                         kex->failed_choice = peer[nenc];
808                         peer[nenc] = NULL;
809                         goto out;
810                 }
811                 authlen = cipher_authlen(newkeys->enc.cipher);
812                 /* ignore mac for authenticated encryption */
813                 if (authlen == 0 &&
814                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
815                     sprop[nmac])) != 0) {
816                         kex->failed_choice = peer[nmac];
817                         peer[nmac] = NULL;
818                         goto out;
819                 }
820                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
821                     sprop[ncomp])) != 0) {
822                         kex->failed_choice = peer[ncomp];
823                         peer[ncomp] = NULL;
824                         goto out;
825                 }
826                 debug("kex: %s cipher: %s MAC: %s compression: %s",
827                     ctos ? "client->server" : "server->client",
828                     newkeys->enc.name,
829                     authlen == 0 ? newkeys->mac.name : "<implicit>",
830                     newkeys->comp.name);
831         }
832         need = dh_need = 0;
833         for (mode = 0; mode < MODE_MAX; mode++) {
834                 newkeys = kex->newkeys[mode];
835                 need = MAX(need, newkeys->enc.key_len);
836                 need = MAX(need, newkeys->enc.block_size);
837                 need = MAX(need, newkeys->enc.iv_len);
838                 need = MAX(need, newkeys->mac.key_len);
839                 dh_need = MAX(dh_need, cipher_seclen(newkeys->enc.cipher));
840                 dh_need = MAX(dh_need, newkeys->enc.block_size);
841                 dh_need = MAX(dh_need, newkeys->enc.iv_len);
842                 dh_need = MAX(dh_need, newkeys->mac.key_len);
843         }
844         /* XXX need runden? */
845         kex->we_need = need;
846         kex->dh_need = dh_need;
847
848         /* ignore the next message if the proposals do not match */
849         if (first_kex_follows && !proposals_match(my, peer) &&
850             !(ssh->compat & SSH_BUG_FIRSTKEX))
851                 ssh->dispatch_skip_packets = 1;
852         r = 0;
853  out:
854         kex_prop_free(my);
855         kex_prop_free(peer);
856         return r;
857 }
858
859 static int
860 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
861     const struct sshbuf *shared_secret, u_char **keyp)
862 {
863         struct kex *kex = ssh->kex;
864         struct ssh_digest_ctx *hashctx = NULL;
865         char c = id;
866         u_int have;
867         size_t mdsz;
868         u_char *digest;
869         int r;
870
871         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
872                 return SSH_ERR_INVALID_ARGUMENT;
873         if ((digest = calloc(1, roundup(need, mdsz))) == NULL) {
874                 r = SSH_ERR_ALLOC_FAIL;
875                 goto out;
876         }
877
878         /* K1 = HASH(K || H || "A" || session_id) */
879         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
880             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
881             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
882             ssh_digest_update(hashctx, &c, 1) != 0 ||
883             ssh_digest_update(hashctx, kex->session_id,
884             kex->session_id_len) != 0 ||
885             ssh_digest_final(hashctx, digest, mdsz) != 0) {
886                 r = SSH_ERR_LIBCRYPTO_ERROR;
887                 goto out;
888         }
889         ssh_digest_free(hashctx);
890         hashctx = NULL;
891
892         /*
893          * expand key:
894          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
895          * Key = K1 || K2 || ... || Kn
896          */
897         for (have = mdsz; need > have; have += mdsz) {
898                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
899                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
900                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
901                     ssh_digest_update(hashctx, digest, have) != 0 ||
902                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
903                         r = SSH_ERR_LIBCRYPTO_ERROR;
904                         goto out;
905                 }
906                 ssh_digest_free(hashctx);
907                 hashctx = NULL;
908         }
909 #ifdef DEBUG_KEX
910         fprintf(stderr, "key '%c'== ", c);
911         dump_digest("key", digest, need);
912 #endif
913         *keyp = digest;
914         digest = NULL;
915         r = 0;
916  out:
917         free(digest);
918         ssh_digest_free(hashctx);
919         return r;
920 }
921
922 #define NKEYS   6
923 int
924 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
925     const struct sshbuf *shared_secret)
926 {
927         struct kex *kex = ssh->kex;
928         u_char *keys[NKEYS];
929         u_int i, j, mode, ctos;
930         int r;
931
932         for (i = 0; i < NKEYS; i++) {
933                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
934                     shared_secret, &keys[i])) != 0) {
935                         for (j = 0; j < i; j++)
936                                 free(keys[j]);
937                         return r;
938                 }
939         }
940         for (mode = 0; mode < MODE_MAX; mode++) {
941                 ctos = (!kex->server && mode == MODE_OUT) ||
942                     (kex->server && mode == MODE_IN);
943                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
944                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
945                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
946         }
947         return 0;
948 }
949
950 #ifdef WITH_OPENSSL
951 int
952 kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
953     const BIGNUM *secret)
954 {
955         struct sshbuf *shared_secret;
956         int r;
957
958         if ((shared_secret = sshbuf_new()) == NULL)
959                 return SSH_ERR_ALLOC_FAIL;
960         if ((r = sshbuf_put_bignum2(shared_secret, secret)) == 0)
961                 r = kex_derive_keys(ssh, hash, hashlen, shared_secret);
962         sshbuf_free(shared_secret);
963         return r;
964 }
965 #endif
966
967 #ifdef WITH_SSH1
968 int
969 derive_ssh1_session_id(BIGNUM *host_modulus, BIGNUM *server_modulus,
970     u_int8_t cookie[8], u_int8_t id[16])
971 {
972         u_int8_t hbuf[2048], sbuf[2048], obuf[SSH_DIGEST_MAX_LENGTH];
973         struct ssh_digest_ctx *hashctx = NULL;
974         size_t hlen, slen;
975         int r;
976
977         hlen = BN_num_bytes(host_modulus);
978         slen = BN_num_bytes(server_modulus);
979         if (hlen < (512 / 8) || (u_int)hlen > sizeof(hbuf) ||
980             slen < (512 / 8) || (u_int)slen > sizeof(sbuf))
981                 return SSH_ERR_KEY_BITS_MISMATCH;
982         if (BN_bn2bin(host_modulus, hbuf) <= 0 ||
983             BN_bn2bin(server_modulus, sbuf) <= 0) {
984                 r = SSH_ERR_LIBCRYPTO_ERROR;
985                 goto out;
986         }
987         if ((hashctx = ssh_digest_start(SSH_DIGEST_MD5)) == NULL) {
988                 r = SSH_ERR_ALLOC_FAIL;
989                 goto out;
990         }
991         if (ssh_digest_update(hashctx, hbuf, hlen) != 0 ||
992             ssh_digest_update(hashctx, sbuf, slen) != 0 ||
993             ssh_digest_update(hashctx, cookie, 8) != 0 ||
994             ssh_digest_final(hashctx, obuf, sizeof(obuf)) != 0) {
995                 r = SSH_ERR_LIBCRYPTO_ERROR;
996                 goto out;
997         }
998         memcpy(id, obuf, ssh_digest_bytes(SSH_DIGEST_MD5));
999         r = 0;
1000  out:
1001         ssh_digest_free(hashctx);
1002         explicit_bzero(hbuf, sizeof(hbuf));
1003         explicit_bzero(sbuf, sizeof(sbuf));
1004         explicit_bzero(obuf, sizeof(obuf));
1005         return r;
1006 }
1007 #endif
1008
1009 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1010 void
1011 dump_digest(char *msg, u_char *digest, int len)
1012 {
1013         fprintf(stderr, "%s\n", msg);
1014         sshbuf_dump_data(digest, len, stderr);
1015 }
1016 #endif