]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/kex.c
Import zstandard 1.1.4 in base
[FreeBSD/FreeBSD.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.127 2016/10/10 19:28:48 markus Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28
29 #include <signal.h>
30 #include <stdarg.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34
35 #ifdef WITH_OPENSSL
36 #include <openssl/crypto.h>
37 #include <openssl/dh.h>
38 #endif
39
40 #include "ssh2.h"
41 #include "packet.h"
42 #include "compat.h"
43 #include "cipher.h"
44 #include "sshkey.h"
45 #include "kex.h"
46 #include "log.h"
47 #include "mac.h"
48 #include "match.h"
49 #include "misc.h"
50 #include "dispatch.h"
51 #include "monitor.h"
52
53 #include "ssherr.h"
54 #include "sshbuf.h"
55 #include "digest.h"
56
57 #if OPENSSL_VERSION_NUMBER >= 0x00907000L
58 # if defined(HAVE_EVP_SHA256)
59 # define evp_ssh_sha256 EVP_sha256
60 # else
61 extern const EVP_MD *evp_ssh_sha256(void);
62 # endif
63 #endif
64
65 /* prototype */
66 static int kex_choose_conf(struct ssh *);
67 static int kex_input_newkeys(int, u_int32_t, void *);
68
69 static const char *proposal_names[PROPOSAL_MAX] = {
70         "KEX algorithms",
71         "host key algorithms",
72         "ciphers ctos",
73         "ciphers stoc",
74         "MACs ctos",
75         "MACs stoc",
76         "compression ctos",
77         "compression stoc",
78         "languages ctos",
79         "languages stoc",
80 };
81
82 struct kexalg {
83         char *name;
84         u_int type;
85         int ec_nid;
86         int hash_alg;
87 };
88 static const struct kexalg kexalgs[] = {
89 #ifdef WITH_OPENSSL
90         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
91         { KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
92         { KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
93         { KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
94         { KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
95         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
96 #ifdef HAVE_EVP_SHA256
97         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
98 #endif /* HAVE_EVP_SHA256 */
99 #ifdef OPENSSL_HAS_ECC
100         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
101             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
102         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
103             SSH_DIGEST_SHA384 },
104 # ifdef OPENSSL_HAS_NISTP521
105         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
106             SSH_DIGEST_SHA512 },
107 # endif /* OPENSSL_HAS_NISTP521 */
108 #endif /* OPENSSL_HAS_ECC */
109 #endif /* WITH_OPENSSL */
110 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
111         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
112         { KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
113 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
114         { NULL, -1, -1, -1},
115 };
116
117 char *
118 kex_alg_list(char sep)
119 {
120         char *ret = NULL, *tmp;
121         size_t nlen, rlen = 0;
122         const struct kexalg *k;
123
124         for (k = kexalgs; k->name != NULL; k++) {
125                 if (ret != NULL)
126                         ret[rlen++] = sep;
127                 nlen = strlen(k->name);
128                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
129                         free(ret);
130                         return NULL;
131                 }
132                 ret = tmp;
133                 memcpy(ret + rlen, k->name, nlen + 1);
134                 rlen += nlen;
135         }
136         return ret;
137 }
138
139 static const struct kexalg *
140 kex_alg_by_name(const char *name)
141 {
142         const struct kexalg *k;
143
144         for (k = kexalgs; k->name != NULL; k++) {
145                 if (strcmp(k->name, name) == 0)
146                         return k;
147         }
148         return NULL;
149 }
150
151 /* Validate KEX method name list */
152 int
153 kex_names_valid(const char *names)
154 {
155         char *s, *cp, *p;
156
157         if (names == NULL || strcmp(names, "") == 0)
158                 return 0;
159         if ((s = cp = strdup(names)) == NULL)
160                 return 0;
161         for ((p = strsep(&cp, ",")); p && *p != '\0';
162             (p = strsep(&cp, ","))) {
163                 if (kex_alg_by_name(p) == NULL) {
164                         error("Unsupported KEX algorithm \"%.100s\"", p);
165                         free(s);
166                         return 0;
167                 }
168         }
169         debug3("kex names ok: [%s]", names);
170         free(s);
171         return 1;
172 }
173
174 /*
175  * Concatenate algorithm names, avoiding duplicates in the process.
176  * Caller must free returned string.
177  */
178 char *
179 kex_names_cat(const char *a, const char *b)
180 {
181         char *ret = NULL, *tmp = NULL, *cp, *p;
182         size_t len;
183
184         if (a == NULL || *a == '\0')
185                 return NULL;
186         if (b == NULL || *b == '\0')
187                 return strdup(a);
188         if (strlen(b) > 1024*1024)
189                 return NULL;
190         len = strlen(a) + strlen(b) + 2;
191         if ((tmp = cp = strdup(b)) == NULL ||
192             (ret = calloc(1, len)) == NULL) {
193                 free(tmp);
194                 return NULL;
195         }
196         strlcpy(ret, a, len);
197         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
198                 if (match_list(ret, p, NULL) != NULL)
199                         continue; /* Algorithm already present */
200                 if (strlcat(ret, ",", len) >= len ||
201                     strlcat(ret, p, len) >= len) {
202                         free(tmp);
203                         free(ret);
204                         return NULL; /* Shouldn't happen */
205                 }
206         }
207         free(tmp);
208         return ret;
209 }
210
211 /*
212  * Assemble a list of algorithms from a default list and a string from a
213  * configuration file. The user-provided string may begin with '+' to
214  * indicate that it should be appended to the default.
215  */
216 int
217 kex_assemble_names(const char *def, char **list)
218 {
219         char *ret;
220
221         if (list == NULL || *list == NULL || **list == '\0') {
222                 *list = strdup(def);
223                 return 0;
224         }
225         if (**list != '+') {
226                 return 0;
227         }
228
229         if ((ret = kex_names_cat(def, *list + 1)) == NULL)
230                 return SSH_ERR_ALLOC_FAIL;
231         free(*list);
232         *list = ret;
233         return 0;
234 }
235
236 /* put algorithm proposal into buffer */
237 int
238 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
239 {
240         u_int i;
241         int r;
242
243         sshbuf_reset(b);
244
245         /*
246          * add a dummy cookie, the cookie will be overwritten by
247          * kex_send_kexinit(), each time a kexinit is set
248          */
249         for (i = 0; i < KEX_COOKIE_LEN; i++) {
250                 if ((r = sshbuf_put_u8(b, 0)) != 0)
251                         return r;
252         }
253         for (i = 0; i < PROPOSAL_MAX; i++) {
254                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
255                         return r;
256         }
257         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
258             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
259                 return r;
260         return 0;
261 }
262
263 /* parse buffer and return algorithm proposal */
264 int
265 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
266 {
267         struct sshbuf *b = NULL;
268         u_char v;
269         u_int i;
270         char **proposal = NULL;
271         int r;
272
273         *propp = NULL;
274         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
275                 return SSH_ERR_ALLOC_FAIL;
276         if ((b = sshbuf_fromb(raw)) == NULL) {
277                 r = SSH_ERR_ALLOC_FAIL;
278                 goto out;
279         }
280         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
281                 goto out;
282         /* extract kex init proposal strings */
283         for (i = 0; i < PROPOSAL_MAX; i++) {
284                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
285                         goto out;
286                 debug2("%s: %s", proposal_names[i], proposal[i]);
287         }
288         /* first kex follows / reserved */
289         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
290             (r = sshbuf_get_u32(b, &i)) != 0)   /* reserved */
291                 goto out;
292         if (first_kex_follows != NULL)
293                 *first_kex_follows = v;
294         debug2("first_kex_follows %d ", v);
295         debug2("reserved %u ", i);
296         r = 0;
297         *propp = proposal;
298  out:
299         if (r != 0 && proposal != NULL)
300                 kex_prop_free(proposal);
301         sshbuf_free(b);
302         return r;
303 }
304
305 void
306 kex_prop_free(char **proposal)
307 {
308         u_int i;
309
310         if (proposal == NULL)
311                 return;
312         for (i = 0; i < PROPOSAL_MAX; i++)
313                 free(proposal[i]);
314         free(proposal);
315 }
316
317 /* ARGSUSED */
318 static int
319 kex_protocol_error(int type, u_int32_t seq, void *ctxt)
320 {
321         struct ssh *ssh = active_state; /* XXX */
322         int r;
323
324         error("kex protocol error: type %d seq %u", type, seq);
325         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
326             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
327             (r = sshpkt_send(ssh)) != 0)
328                 return r;
329         return 0;
330 }
331
332 static void
333 kex_reset_dispatch(struct ssh *ssh)
334 {
335         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
336             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
337         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
338 }
339
340 static int
341 kex_send_ext_info(struct ssh *ssh)
342 {
343         int r;
344         char *algs;
345
346         if ((algs = sshkey_alg_list(0, 1, ',')) == NULL)
347                 return SSH_ERR_ALLOC_FAIL;
348         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
349             (r = sshpkt_put_u32(ssh, 1)) != 0 ||
350             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
351             (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
352             (r = sshpkt_send(ssh)) != 0)
353                 goto out;
354         /* success */
355         r = 0;
356  out:
357         free(algs);
358         return r;
359 }
360
361 int
362 kex_send_newkeys(struct ssh *ssh)
363 {
364         int r;
365
366         kex_reset_dispatch(ssh);
367         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
368             (r = sshpkt_send(ssh)) != 0)
369                 return r;
370         debug("SSH2_MSG_NEWKEYS sent");
371         debug("expecting SSH2_MSG_NEWKEYS");
372         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
373         if (ssh->kex->ext_info_c)
374                 if ((r = kex_send_ext_info(ssh)) != 0)
375                         return r;
376         return 0;
377 }
378
379 int
380 kex_input_ext_info(int type, u_int32_t seq, void *ctxt)
381 {
382         struct ssh *ssh = ctxt;
383         struct kex *kex = ssh->kex;
384         u_int32_t i, ninfo;
385         char *name, *val, *found;
386         int r;
387
388         debug("SSH2_MSG_EXT_INFO received");
389         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
390         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
391                 return r;
392         for (i = 0; i < ninfo; i++) {
393                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
394                         return r;
395                 if ((r = sshpkt_get_cstring(ssh, &val, NULL)) != 0) {
396                         free(name);
397                         return r;
398                 }
399                 debug("%s: %s=<%s>", __func__, name, val);
400                 if (strcmp(name, "server-sig-algs") == 0) {
401                         found = match_list("rsa-sha2-256", val, NULL);
402                         if (found) {
403                                 kex->rsa_sha2 = 256;
404                                 free(found);
405                         }
406                         found = match_list("rsa-sha2-512", val, NULL);
407                         if (found) {
408                                 kex->rsa_sha2 = 512;
409                                 free(found);
410                         }
411                 }
412                 free(name);
413                 free(val);
414         }
415         return sshpkt_get_end(ssh);
416 }
417
418 static int
419 kex_input_newkeys(int type, u_int32_t seq, void *ctxt)
420 {
421         struct ssh *ssh = ctxt;
422         struct kex *kex = ssh->kex;
423         int r;
424
425         debug("SSH2_MSG_NEWKEYS received");
426         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
427         if ((r = sshpkt_get_end(ssh)) != 0)
428                 return r;
429         if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
430                 return r;
431         kex->done = 1;
432         sshbuf_reset(kex->peer);
433         /* sshbuf_reset(kex->my); */
434         kex->flags &= ~KEX_INIT_SENT;
435         free(kex->name);
436         kex->name = NULL;
437         return 0;
438 }
439
440 int
441 kex_send_kexinit(struct ssh *ssh)
442 {
443         u_char *cookie;
444         struct kex *kex = ssh->kex;
445         int r;
446
447         if (kex == NULL)
448                 return SSH_ERR_INTERNAL_ERROR;
449         if (kex->flags & KEX_INIT_SENT)
450                 return 0;
451         kex->done = 0;
452
453         /* generate a random cookie */
454         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
455                 return SSH_ERR_INVALID_FORMAT;
456         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
457                 return SSH_ERR_INTERNAL_ERROR;
458         arc4random_buf(cookie, KEX_COOKIE_LEN);
459
460         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
461             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
462             (r = sshpkt_send(ssh)) != 0)
463                 return r;
464         debug("SSH2_MSG_KEXINIT sent");
465         kex->flags |= KEX_INIT_SENT;
466         return 0;
467 }
468
469 /* ARGSUSED */
470 int
471 kex_input_kexinit(int type, u_int32_t seq, void *ctxt)
472 {
473         struct ssh *ssh = ctxt;
474         struct kex *kex = ssh->kex;
475         const u_char *ptr;
476         u_int i;
477         size_t dlen;
478         int r;
479
480         debug("SSH2_MSG_KEXINIT received");
481         if (kex == NULL)
482                 return SSH_ERR_INVALID_ARGUMENT;
483
484         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
485         ptr = sshpkt_ptr(ssh, &dlen);
486         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
487                 return r;
488
489         /* discard packet */
490         for (i = 0; i < KEX_COOKIE_LEN; i++)
491                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
492                         return r;
493         for (i = 0; i < PROPOSAL_MAX; i++)
494                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
495                         return r;
496         /*
497          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
498          * KEX method has the server move first, but a server might be using
499          * a custom method or one that we otherwise don't support. We should
500          * be prepared to remember first_kex_follows here so we can eat a
501          * packet later.
502          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
503          * for cases where the server *doesn't* go first. I guess we should
504          * ignore it when it is set for these cases, which is what we do now.
505          */
506         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
507             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
508             (r = sshpkt_get_end(ssh)) != 0)
509                         return r;
510
511         if (!(kex->flags & KEX_INIT_SENT))
512                 if ((r = kex_send_kexinit(ssh)) != 0)
513                         return r;
514         if ((r = kex_choose_conf(ssh)) != 0)
515                 return r;
516
517         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
518                 return (kex->kex[kex->kex_type])(ssh);
519
520         return SSH_ERR_INTERNAL_ERROR;
521 }
522
523 int
524 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
525 {
526         struct kex *kex;
527         int r;
528
529         *kexp = NULL;
530         if ((kex = calloc(1, sizeof(*kex))) == NULL)
531                 return SSH_ERR_ALLOC_FAIL;
532         if ((kex->peer = sshbuf_new()) == NULL ||
533             (kex->my = sshbuf_new()) == NULL) {
534                 r = SSH_ERR_ALLOC_FAIL;
535                 goto out;
536         }
537         if ((r = kex_prop2buf(kex->my, proposal)) != 0)
538                 goto out;
539         kex->done = 0;
540         kex_reset_dispatch(ssh);
541         r = 0;
542         *kexp = kex;
543  out:
544         if (r != 0)
545                 kex_free(kex);
546         return r;
547 }
548
549 void
550 kex_free_newkeys(struct newkeys *newkeys)
551 {
552         if (newkeys == NULL)
553                 return;
554         if (newkeys->enc.key) {
555                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
556                 free(newkeys->enc.key);
557                 newkeys->enc.key = NULL;
558         }
559         if (newkeys->enc.iv) {
560                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
561                 free(newkeys->enc.iv);
562                 newkeys->enc.iv = NULL;
563         }
564         free(newkeys->enc.name);
565         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
566         free(newkeys->comp.name);
567         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
568         mac_clear(&newkeys->mac);
569         if (newkeys->mac.key) {
570                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
571                 free(newkeys->mac.key);
572                 newkeys->mac.key = NULL;
573         }
574         free(newkeys->mac.name);
575         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
576         explicit_bzero(newkeys, sizeof(*newkeys));
577         free(newkeys);
578 }
579
580 void
581 kex_free(struct kex *kex)
582 {
583         u_int mode;
584
585 #ifdef WITH_OPENSSL
586         if (kex->dh)
587                 DH_free(kex->dh);
588 #ifdef OPENSSL_HAS_ECC
589         if (kex->ec_client_key)
590                 EC_KEY_free(kex->ec_client_key);
591 #endif /* OPENSSL_HAS_ECC */
592 #endif /* WITH_OPENSSL */
593         for (mode = 0; mode < MODE_MAX; mode++) {
594                 kex_free_newkeys(kex->newkeys[mode]);
595                 kex->newkeys[mode] = NULL;
596         }
597         sshbuf_free(kex->peer);
598         sshbuf_free(kex->my);
599         free(kex->session_id);
600         free(kex->client_version_string);
601         free(kex->server_version_string);
602         free(kex->failed_choice);
603         free(kex->hostkey_alg);
604         free(kex->name);
605         free(kex);
606 }
607
608 int
609 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
610 {
611         int r;
612
613         if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
614                 return r;
615         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
616                 kex_free(ssh->kex);
617                 ssh->kex = NULL;
618                 return r;
619         }
620         return 0;
621 }
622
623 /*
624  * Request key re-exchange, returns 0 on success or a ssherr.h error
625  * code otherwise. Must not be called if KEX is incomplete or in-progress.
626  */
627 int
628 kex_start_rekex(struct ssh *ssh)
629 {
630         if (ssh->kex == NULL) {
631                 error("%s: no kex", __func__);
632                 return SSH_ERR_INTERNAL_ERROR;
633         }
634         if (ssh->kex->done == 0) {
635                 error("%s: requested twice", __func__);
636                 return SSH_ERR_INTERNAL_ERROR;
637         }
638         ssh->kex->done = 0;
639         return kex_send_kexinit(ssh);
640 }
641
642 static int
643 choose_enc(struct sshenc *enc, char *client, char *server)
644 {
645         char *name = match_list(client, server, NULL);
646
647         if (name == NULL)
648                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
649         if ((enc->cipher = cipher_by_name(name)) == NULL)
650                 return SSH_ERR_INTERNAL_ERROR;
651         enc->name = name;
652         enc->enabled = 0;
653         enc->iv = NULL;
654         enc->iv_len = cipher_ivlen(enc->cipher);
655         enc->key = NULL;
656         enc->key_len = cipher_keylen(enc->cipher);
657         enc->block_size = cipher_blocksize(enc->cipher);
658         return 0;
659 }
660
661 static int
662 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
663 {
664         char *name = match_list(client, server, NULL);
665
666         if (name == NULL)
667                 return SSH_ERR_NO_MAC_ALG_MATCH;
668         if (mac_setup(mac, name) < 0)
669                 return SSH_ERR_INTERNAL_ERROR;
670         /* truncate the key */
671         if (ssh->compat & SSH_BUG_HMAC)
672                 mac->key_len = 16;
673         mac->name = name;
674         mac->key = NULL;
675         mac->enabled = 0;
676         return 0;
677 }
678
679 static int
680 choose_comp(struct sshcomp *comp, char *client, char *server)
681 {
682         char *name = match_list(client, server, NULL);
683
684         if (name == NULL)
685                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
686         if (strcmp(name, "zlib@openssh.com") == 0) {
687                 comp->type = COMP_DELAYED;
688         } else if (strcmp(name, "zlib") == 0) {
689                 comp->type = COMP_ZLIB;
690         } else if (strcmp(name, "none") == 0) {
691                 comp->type = COMP_NONE;
692         } else {
693                 return SSH_ERR_INTERNAL_ERROR;
694         }
695         comp->name = name;
696         return 0;
697 }
698
699 static int
700 choose_kex(struct kex *k, char *client, char *server)
701 {
702         const struct kexalg *kexalg;
703
704         k->name = match_list(client, server, NULL);
705
706         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
707         if (k->name == NULL)
708                 return SSH_ERR_NO_KEX_ALG_MATCH;
709         if ((kexalg = kex_alg_by_name(k->name)) == NULL)
710                 return SSH_ERR_INTERNAL_ERROR;
711         k->kex_type = kexalg->type;
712         k->hash_alg = kexalg->hash_alg;
713         k->ec_nid = kexalg->ec_nid;
714         return 0;
715 }
716
717 static int
718 choose_hostkeyalg(struct kex *k, char *client, char *server)
719 {
720         k->hostkey_alg = match_list(client, server, NULL);
721
722         debug("kex: host key algorithm: %s",
723             k->hostkey_alg ? k->hostkey_alg : "(no match)");
724         if (k->hostkey_alg == NULL)
725                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
726         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
727         if (k->hostkey_type == KEY_UNSPEC)
728                 return SSH_ERR_INTERNAL_ERROR;
729         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
730         return 0;
731 }
732
733 static int
734 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
735 {
736         static int check[] = {
737                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
738         };
739         int *idx;
740         char *p;
741
742         for (idx = &check[0]; *idx != -1; idx++) {
743                 if ((p = strchr(my[*idx], ',')) != NULL)
744                         *p = '\0';
745                 if ((p = strchr(peer[*idx], ',')) != NULL)
746                         *p = '\0';
747                 if (strcmp(my[*idx], peer[*idx]) != 0) {
748                         debug2("proposal mismatch: my %s peer %s",
749                             my[*idx], peer[*idx]);
750                         return (0);
751                 }
752         }
753         debug2("proposals match");
754         return (1);
755 }
756
757 static int
758 kex_choose_conf(struct ssh *ssh)
759 {
760         struct kex *kex = ssh->kex;
761         struct newkeys *newkeys;
762         char **my = NULL, **peer = NULL;
763         char **cprop, **sprop;
764         int nenc, nmac, ncomp;
765         u_int mode, ctos, need, dh_need, authlen;
766         int r, first_kex_follows;
767
768         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
769         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
770                 goto out;
771         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
772         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
773                 goto out;
774
775         if (kex->server) {
776                 cprop=peer;
777                 sprop=my;
778         } else {
779                 cprop=my;
780                 sprop=peer;
781         }
782
783         /* Check whether client supports ext_info_c */
784         if (kex->server) {
785                 char *ext;
786
787                 ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
788                 kex->ext_info_c = (ext != NULL);
789                 free(ext);
790         }
791
792         /* Algorithm Negotiation */
793         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
794             sprop[PROPOSAL_KEX_ALGS])) != 0) {
795                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
796                 peer[PROPOSAL_KEX_ALGS] = NULL;
797                 goto out;
798         }
799         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
800             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
801                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
802                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
803                 goto out;
804         }
805         for (mode = 0; mode < MODE_MAX; mode++) {
806                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
807                         r = SSH_ERR_ALLOC_FAIL;
808                         goto out;
809                 }
810                 kex->newkeys[mode] = newkeys;
811                 ctos = (!kex->server && mode == MODE_OUT) ||
812                     (kex->server && mode == MODE_IN);
813                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
814                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
815                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
816                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
817                     sprop[nenc])) != 0) {
818                         kex->failed_choice = peer[nenc];
819                         peer[nenc] = NULL;
820                         goto out;
821                 }
822                 authlen = cipher_authlen(newkeys->enc.cipher);
823                 /* ignore mac for authenticated encryption */
824                 if (authlen == 0 &&
825                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
826                     sprop[nmac])) != 0) {
827                         kex->failed_choice = peer[nmac];
828                         peer[nmac] = NULL;
829                         goto out;
830                 }
831                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
832                     sprop[ncomp])) != 0) {
833                         kex->failed_choice = peer[ncomp];
834                         peer[ncomp] = NULL;
835                         goto out;
836                 }
837                 debug("kex: %s cipher: %s MAC: %s compression: %s",
838                     ctos ? "client->server" : "server->client",
839                     newkeys->enc.name,
840                     authlen == 0 ? newkeys->mac.name : "<implicit>",
841                     newkeys->comp.name);
842         }
843         need = dh_need = 0;
844         for (mode = 0; mode < MODE_MAX; mode++) {
845                 newkeys = kex->newkeys[mode];
846                 need = MAXIMUM(need, newkeys->enc.key_len);
847                 need = MAXIMUM(need, newkeys->enc.block_size);
848                 need = MAXIMUM(need, newkeys->enc.iv_len);
849                 need = MAXIMUM(need, newkeys->mac.key_len);
850                 dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
851                 dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
852                 dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
853                 dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
854         }
855         /* XXX need runden? */
856         kex->we_need = need;
857         kex->dh_need = dh_need;
858
859         /* ignore the next message if the proposals do not match */
860         if (first_kex_follows && !proposals_match(my, peer) &&
861             !(ssh->compat & SSH_BUG_FIRSTKEX))
862                 ssh->dispatch_skip_packets = 1;
863         r = 0;
864  out:
865         kex_prop_free(my);
866         kex_prop_free(peer);
867         return r;
868 }
869
870 static int
871 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
872     const struct sshbuf *shared_secret, u_char **keyp)
873 {
874         struct kex *kex = ssh->kex;
875         struct ssh_digest_ctx *hashctx = NULL;
876         char c = id;
877         u_int have;
878         size_t mdsz;
879         u_char *digest;
880         int r;
881
882         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
883                 return SSH_ERR_INVALID_ARGUMENT;
884         if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
885                 r = SSH_ERR_ALLOC_FAIL;
886                 goto out;
887         }
888
889         /* K1 = HASH(K || H || "A" || session_id) */
890         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
891             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
892             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
893             ssh_digest_update(hashctx, &c, 1) != 0 ||
894             ssh_digest_update(hashctx, kex->session_id,
895             kex->session_id_len) != 0 ||
896             ssh_digest_final(hashctx, digest, mdsz) != 0) {
897                 r = SSH_ERR_LIBCRYPTO_ERROR;
898                 goto out;
899         }
900         ssh_digest_free(hashctx);
901         hashctx = NULL;
902
903         /*
904          * expand key:
905          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
906          * Key = K1 || K2 || ... || Kn
907          */
908         for (have = mdsz; need > have; have += mdsz) {
909                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
910                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
911                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
912                     ssh_digest_update(hashctx, digest, have) != 0 ||
913                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
914                         r = SSH_ERR_LIBCRYPTO_ERROR;
915                         goto out;
916                 }
917                 ssh_digest_free(hashctx);
918                 hashctx = NULL;
919         }
920 #ifdef DEBUG_KEX
921         fprintf(stderr, "key '%c'== ", c);
922         dump_digest("key", digest, need);
923 #endif
924         *keyp = digest;
925         digest = NULL;
926         r = 0;
927  out:
928         free(digest);
929         ssh_digest_free(hashctx);
930         return r;
931 }
932
933 #define NKEYS   6
934 int
935 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
936     const struct sshbuf *shared_secret)
937 {
938         struct kex *kex = ssh->kex;
939         u_char *keys[NKEYS];
940         u_int i, j, mode, ctos;
941         int r;
942
943         for (i = 0; i < NKEYS; i++) {
944                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
945                     shared_secret, &keys[i])) != 0) {
946                         for (j = 0; j < i; j++)
947                                 free(keys[j]);
948                         return r;
949                 }
950         }
951         for (mode = 0; mode < MODE_MAX; mode++) {
952                 ctos = (!kex->server && mode == MODE_OUT) ||
953                     (kex->server && mode == MODE_IN);
954                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
955                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
956                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
957         }
958         return 0;
959 }
960
961 #ifdef WITH_OPENSSL
962 int
963 kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
964     const BIGNUM *secret)
965 {
966         struct sshbuf *shared_secret;
967         int r;
968
969         if ((shared_secret = sshbuf_new()) == NULL)
970                 return SSH_ERR_ALLOC_FAIL;
971         if ((r = sshbuf_put_bignum2(shared_secret, secret)) == 0)
972                 r = kex_derive_keys(ssh, hash, hashlen, shared_secret);
973         sshbuf_free(shared_secret);
974         return r;
975 }
976 #endif
977
978 #ifdef WITH_SSH1
979 int
980 derive_ssh1_session_id(BIGNUM *host_modulus, BIGNUM *server_modulus,
981     u_int8_t cookie[8], u_int8_t id[16])
982 {
983         u_int8_t hbuf[2048], sbuf[2048], obuf[SSH_DIGEST_MAX_LENGTH];
984         struct ssh_digest_ctx *hashctx = NULL;
985         size_t hlen, slen;
986         int r;
987
988         hlen = BN_num_bytes(host_modulus);
989         slen = BN_num_bytes(server_modulus);
990         if (hlen < (512 / 8) || (u_int)hlen > sizeof(hbuf) ||
991             slen < (512 / 8) || (u_int)slen > sizeof(sbuf))
992                 return SSH_ERR_KEY_BITS_MISMATCH;
993         if (BN_bn2bin(host_modulus, hbuf) <= 0 ||
994             BN_bn2bin(server_modulus, sbuf) <= 0) {
995                 r = SSH_ERR_LIBCRYPTO_ERROR;
996                 goto out;
997         }
998         if ((hashctx = ssh_digest_start(SSH_DIGEST_MD5)) == NULL) {
999                 r = SSH_ERR_ALLOC_FAIL;
1000                 goto out;
1001         }
1002         if (ssh_digest_update(hashctx, hbuf, hlen) != 0 ||
1003             ssh_digest_update(hashctx, sbuf, slen) != 0 ||
1004             ssh_digest_update(hashctx, cookie, 8) != 0 ||
1005             ssh_digest_final(hashctx, obuf, sizeof(obuf)) != 0) {
1006                 r = SSH_ERR_LIBCRYPTO_ERROR;
1007                 goto out;
1008         }
1009         memcpy(id, obuf, ssh_digest_bytes(SSH_DIGEST_MD5));
1010         r = 0;
1011  out:
1012         ssh_digest_free(hashctx);
1013         explicit_bzero(hbuf, sizeof(hbuf));
1014         explicit_bzero(sbuf, sizeof(sbuf));
1015         explicit_bzero(obuf, sizeof(obuf));
1016         return r;
1017 }
1018 #endif
1019
1020 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1021 void
1022 dump_digest(char *msg, u_char *digest, int len)
1023 {
1024         fprintf(stderr, "%s\n", msg);
1025         sshbuf_dump_data(digest, len, stderr);
1026 }
1027 #endif