]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - crypto/openssh/packet.c
Update llvm, clang, lld and lldb to release_39 branch r287912.
[FreeBSD/FreeBSD.git] / crypto / openssh / packet.c
1 /* $OpenBSD: packet.c,v 1.229 2016/02/17 22:20:14 djm Exp $ */
2 /*
3  * Author: Tatu Ylonen <ylo@cs.hut.fi>
4  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
5  *                    All rights reserved
6  * This file contains code implementing the packet protocol and communication
7  * with the other side.  This same code is used both on client and server side.
8  *
9  * As far as I am concerned, the code I have written for this software
10  * can be used freely for any purpose.  Any derived versions of this
11  * software must be clearly marked as such, and if the derived work is
12  * incompatible with the protocol description in the RFC file, it must be
13  * called by a name other than "ssh" or "Secure Shell".
14  *
15  *
16  * SSH2 packet format added by Markus Friedl.
17  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
18  *
19  * Redistribution and use in source and binary forms, with or without
20  * modification, are permitted provided that the following conditions
21  * are met:
22  * 1. Redistributions of source code must retain the above copyright
23  *    notice, this list of conditions and the following disclaimer.
24  * 2. Redistributions in binary form must reproduce the above copyright
25  *    notice, this list of conditions and the following disclaimer in the
26  *    documentation and/or other materials provided with the distribution.
27  *
28  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
29  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
30  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
31  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
32  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
33  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
34  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
35  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
36  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
37  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38  */
39
40 #include "includes.h"
41 __RCSID("$FreeBSD$");
42  
43 #include <sys/param.h>  /* MIN roundup */
44 #include <sys/types.h>
45 #include "openbsd-compat/sys-queue.h"
46 #include <sys/socket.h>
47 #ifdef HAVE_SYS_TIME_H
48 # include <sys/time.h>
49 #endif
50
51 #include <netinet/in.h>
52 #include <netinet/ip.h>
53 #include <arpa/inet.h>
54
55 #include <errno.h>
56 #include <stdarg.h>
57 #include <stdio.h>
58 #include <stdlib.h>
59 #include <string.h>
60 #include <unistd.h>
61 #include <limits.h>
62 #include <signal.h>
63 #include <time.h>
64
65 #include <zlib.h>
66
67 #include "buffer.h"     /* typedefs XXX */
68 #include "key.h"        /* typedefs XXX */
69
70 #include "xmalloc.h"
71 #include "crc32.h"
72 #include "deattack.h"
73 #include "compat.h"
74 #include "ssh1.h"
75 #include "ssh2.h"
76 #include "cipher.h"
77 #include "sshkey.h"
78 #include "kex.h"
79 #include "digest.h"
80 #include "mac.h"
81 #include "log.h"
82 #include "canohost.h"
83 #include "misc.h"
84 #include "channels.h"
85 #include "ssh.h"
86 #include "packet.h"
87 #include "ssherr.h"
88 #include "sshbuf.h"
89 #include "blacklist_client.h"
90
91 #ifdef PACKET_DEBUG
92 #define DBG(x) x
93 #else
94 #define DBG(x)
95 #endif
96
97 #define PACKET_MAX_SIZE (256 * 1024)
98
/*
 * Per-direction packet accounting.  struct session_state holds one instance
 * for received packets (p_read) and one for sent packets (p_send).
 */
struct packet_state {
	u_int32_t seqnr;	/* sequence number; p_read.seqnr is fed to mac_compute() */
	u_int32_t packets;	/* packet count; zeroed in ssh_alloc_session_state() */
	u_int64_t blocks;	/* NOTE(review): presumably cipher blocks, for volume rekeying */
	u_int64_t bytes;	/* byte count reported by ssh_packet_get_bytes() */
};
105
/*
 * A queued packet: its message type and payload.  Instances are linked into
 * session_state.outgoing.
 */
struct packet {
	TAILQ_ENTRY(packet) next;	/* linkage in session_state.outgoing */
	u_char type;			/* SSH message type */
	struct sshbuf *payload;		/* message body */
};
111
/*
 * Packet-layer state of one connection.  Allocated and initialized by
 * ssh_alloc_session_state(), reachable as ssh->state, and released by
 * ssh_packet_close().
 */
struct session_state {
	/*
	 * This variable contains the file descriptors used for
	 * communicating with the other side.  connection_in is used for
	 * reading; connection_out for writing.  These can be the same
	 * descriptor, in which case it is assumed to be a socket.
	 */
	int connection_in;
	int connection_out;

	/* Protocol flags for the remote side. */
	u_int remote_protocol_flags;

	/* Encryption context for receiving data.  Only used for decryption. */
	struct sshcipher_ctx receive_context;

	/* Encryption context for sending data.  Only used for encryption. */
	struct sshcipher_ctx send_context;

	/* Buffer for raw input data from the socket. */
	struct sshbuf *input;

	/* Buffer for raw output data going to the socket. */
	struct sshbuf *output;

	/* Buffer for the partial outgoing packet being constructed. */
	struct sshbuf *outgoing_packet;

	/* Buffer for the incoming packet currently being processed. */
	struct sshbuf *incoming_packet;

	/*
	 * Scratch buffer for packet compression/decompression; created
	 * lazily by ssh_packet_init_compression().
	 */
	struct sshbuf *compression_buffer;

	/* Incoming/outgoing compression dictionaries */
	z_stream compression_in_stream;
	z_stream compression_out_stream;
	/* Set to 1 once inflateInit()/deflateInit() succeeded on the stream. */
	int compression_in_started;
	int compression_out_started;
	/*
	 * Count of unexpected zlib errors; a nonzero count makes
	 * ssh_packet_close() skip inflateEnd()/deflateEnd().
	 */
	int compression_in_failures;
	int compression_out_failures;

	/*
	 * Flag indicating whether packet compression/decompression is
	 * enabled.
	 */
	int packet_compression;

	/* default maximum packet size (32768 after allocation) */
	u_int max_packet_size;

	/* Flag indicating whether this module has been initialized. */
	int initialized;

	/* Set to true if the connection is interactive. */
	int interactive_mode;

	/* Set to true if we are the server side. */
	int server_side;

	/* Set to true if we are authenticated. */
	int after_authentication;

	int keep_alive_timeouts;

	/*
	 * The maximum time that we will wait to send or receive a packet,
	 * in milliseconds; -1 means no timeout.
	 */
	int packet_timeout_ms;

	/* Session key information for Encryption and MAC */
	struct newkeys *newkeys[MODE_MAX];
	struct packet_state p_read, p_send;

	/* Volume-based rekeying */
	u_int64_t max_blocks_in, max_blocks_out, rekey_limit;

	/* Time-based rekeying */
	u_int32_t rekey_interval;	/* how often in seconds */
	time_t rekey_time;	/* time of last rekeying */

	/* Session key for protocol v1 */
	u_char ssh1_key[SSH_SESSION_KEY_LENGTH];
	u_int ssh1_keylen;

	/* roundup current message to extra_pad bytes */
	u_char extra_pad;

	/* XXX discard incoming data after MAC error */
	u_int packet_discard;		/* bytes still to be discarded */
	struct sshmac *packet_discard_mac;	/* MAC run over the discarded data */

	/* Used in packet_read_poll2() */
	u_int packlen;

	/*
	 * Used in packet_send2.  Starts at 1 so that packets are queued
	 * until the initial key exchange has completed.
	 */
	int rekeying;

	/* Used in packet_set_interactive */
	int set_interactive_called;

	/* Used in packet_set_maxsize */
	int set_maxsize_called;

	/* One-off warning about weak ciphers */
	int cipher_warning_done;

	/* SSH1 CRC compensation attack detector */
	struct deattack_ctx deattack;

	/* Queue of struct packet; presumably packets held back during rekeying */
	TAILQ_HEAD(, packet) outgoing;
};
222
223 struct ssh *
224 ssh_alloc_session_state(void)
225 {
226         struct ssh *ssh = NULL;
227         struct session_state *state = NULL;
228
229         if ((ssh = calloc(1, sizeof(*ssh))) == NULL ||
230             (state = calloc(1, sizeof(*state))) == NULL ||
231             (state->input = sshbuf_new()) == NULL ||
232             (state->output = sshbuf_new()) == NULL ||
233             (state->outgoing_packet = sshbuf_new()) == NULL ||
234             (state->incoming_packet = sshbuf_new()) == NULL)
235                 goto fail;
236         TAILQ_INIT(&state->outgoing);
237         TAILQ_INIT(&ssh->private_keys);
238         TAILQ_INIT(&ssh->public_keys);
239         state->connection_in = -1;
240         state->connection_out = -1;
241         state->max_packet_size = 32768;
242         state->packet_timeout_ms = -1;
243         state->p_send.packets = state->p_read.packets = 0;
244         state->initialized = 1;
245         /*
246          * ssh_packet_send2() needs to queue packets until
247          * we've done the initial key exchange.
248          */
249         state->rekeying = 1;
250         ssh->state = state;
251         return ssh;
252  fail:
253         if (state) {
254                 sshbuf_free(state->input);
255                 sshbuf_free(state->output);
256                 sshbuf_free(state->incoming_packet);
257                 sshbuf_free(state->outgoing_packet);
258                 free(state);
259         }
260         free(ssh);
261         return NULL;
262 }
263
264 /* Returns nonzero if rekeying is in progress */
265 int
266 ssh_packet_is_rekeying(struct ssh *ssh)
267 {
268         return compat20 &&
269             (ssh->state->rekeying || (ssh->kex != NULL && ssh->kex->done == 0));
270 }
271
272 /*
273  * Sets the descriptors used for communication.  Disables encryption until
274  * packet_set_encryption_key is called.
275  */
276 struct ssh *
277 ssh_packet_set_connection(struct ssh *ssh, int fd_in, int fd_out)
278 {
279         struct session_state *state;
280         const struct sshcipher *none = cipher_by_name("none");
281         int r;
282
283         if (none == NULL) {
284                 error("%s: cannot load cipher 'none'", __func__);
285                 return NULL;
286         }
287         if (ssh == NULL)
288                 ssh = ssh_alloc_session_state();
289         if (ssh == NULL) {
290                 error("%s: cound not allocate state", __func__);
291                 return NULL;
292         }
293         state = ssh->state;
294         state->connection_in = fd_in;
295         state->connection_out = fd_out;
296         if ((r = cipher_init(&state->send_context, none,
297             (const u_char *)"", 0, NULL, 0, CIPHER_ENCRYPT)) != 0 ||
298             (r = cipher_init(&state->receive_context, none,
299             (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) {
300                 error("%s: cipher_init failed: %s", __func__, ssh_err(r));
301                 free(ssh);
302                 return NULL;
303         }
304         state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL;
305         deattack_init(&state->deattack);
306         /*
307          * Cache the IP address of the remote connection for use in error
308          * messages that might be generated after the connection has closed.
309          */
310         (void)ssh_remote_ipaddr(ssh);
311         return ssh;
312 }
313
314 void
315 ssh_packet_set_timeout(struct ssh *ssh, int timeout, int count)
316 {
317         struct session_state *state = ssh->state;
318
319         if (timeout <= 0 || count <= 0) {
320                 state->packet_timeout_ms = -1;
321                 return;
322         }
323         if ((INT_MAX / 1000) / count < timeout)
324                 state->packet_timeout_ms = INT_MAX;
325         else
326                 state->packet_timeout_ms = timeout * count * 1000;
327 }
328
329 int
330 ssh_packet_stop_discard(struct ssh *ssh)
331 {
332         struct session_state *state = ssh->state;
333         int r;
334
335         if (state->packet_discard_mac) {
336                 char buf[1024];
337
338                 memset(buf, 'a', sizeof(buf));
339                 while (sshbuf_len(state->incoming_packet) <
340                     PACKET_MAX_SIZE)
341                         if ((r = sshbuf_put(state->incoming_packet, buf,
342                             sizeof(buf))) != 0)
343                                 return r;
344                 (void) mac_compute(state->packet_discard_mac,
345                     state->p_read.seqnr,
346                     sshbuf_ptr(state->incoming_packet), PACKET_MAX_SIZE,
347                     NULL, 0);
348         }
349         logit("Finished discarding for %.200s port %d",
350             ssh_remote_ipaddr(ssh), ssh_remote_port(ssh));
351         return SSH_ERR_MAC_INVALID;
352 }
353
354 static int
355 ssh_packet_start_discard(struct ssh *ssh, struct sshenc *enc,
356     struct sshmac *mac, u_int packet_length, u_int discard)
357 {
358         struct session_state *state = ssh->state;
359         int r;
360
361         if (enc == NULL || !cipher_is_cbc(enc->cipher) || (mac && mac->etm)) {
362                 if ((r = sshpkt_disconnect(ssh, "Packet corrupt")) != 0)
363                         return r;
364                 return SSH_ERR_MAC_INVALID;
365         }
366         if (packet_length != PACKET_MAX_SIZE && mac && mac->enabled)
367                 state->packet_discard_mac = mac;
368         if (sshbuf_len(state->input) >= discard &&
369            (r = ssh_packet_stop_discard(ssh)) != 0)
370                 return r;
371         state->packet_discard = discard - sshbuf_len(state->input);
372         return 0;
373 }
374
375 /* Returns 1 if remote host is connected via socket, 0 if not. */
376
377 int
378 ssh_packet_connection_is_on_socket(struct ssh *ssh)
379 {
380         struct session_state *state = ssh->state;
381         struct sockaddr_storage from, to;
382         socklen_t fromlen, tolen;
383
384         /* filedescriptors in and out are the same, so it's a socket */
385         if (state->connection_in == state->connection_out)
386                 return 1;
387         fromlen = sizeof(from);
388         memset(&from, 0, sizeof(from));
389         if (getpeername(state->connection_in, (struct sockaddr *)&from,
390             &fromlen) < 0)
391                 return 0;
392         tolen = sizeof(to);
393         memset(&to, 0, sizeof(to));
394         if (getpeername(state->connection_out, (struct sockaddr *)&to,
395             &tolen) < 0)
396                 return 0;
397         if (fromlen != tolen || memcmp(&from, &to, fromlen) != 0)
398                 return 0;
399         if (from.ss_family != AF_INET && from.ss_family != AF_INET6)
400                 return 0;
401         return 1;
402 }
403
404 void
405 ssh_packet_get_bytes(struct ssh *ssh, u_int64_t *ibytes, u_int64_t *obytes)
406 {
407         if (ibytes)
408                 *ibytes = ssh->state->p_read.bytes;
409         if (obytes)
410                 *obytes = ssh->state->p_send.bytes;
411 }
412
413 int
414 ssh_packet_connection_af(struct ssh *ssh)
415 {
416         struct sockaddr_storage to;
417         socklen_t tolen = sizeof(to);
418
419         memset(&to, 0, sizeof(to));
420         if (getsockname(ssh->state->connection_out, (struct sockaddr *)&to,
421             &tolen) < 0)
422                 return 0;
423 #ifdef IPV4_IN_IPV6
424         if (to.ss_family == AF_INET6 &&
425             IN6_IS_ADDR_V4MAPPED(&((struct sockaddr_in6 *)&to)->sin6_addr))
426                 return AF_INET;
427 #endif
428         return to.ss_family;
429 }
430
431 /* Sets the connection into non-blocking mode. */
432
433 void
434 ssh_packet_set_nonblocking(struct ssh *ssh)
435 {
436         /* Set the socket into non-blocking mode. */
437         set_nonblock(ssh->state->connection_in);
438
439         if (ssh->state->connection_out != ssh->state->connection_in)
440                 set_nonblock(ssh->state->connection_out);
441 }
442
443 /* Returns the socket used for reading. */
444
445 int
446 ssh_packet_get_connection_in(struct ssh *ssh)
447 {
448         return ssh->state->connection_in;
449 }
450
451 /* Returns the descriptor used for writing. */
452
453 int
454 ssh_packet_get_connection_out(struct ssh *ssh)
455 {
456         return ssh->state->connection_out;
457 }
458
459 /*
460  * Returns the IP-address of the remote host as a string.  The returned
461  * string must not be freed.
462  */
463
464 const char *
465 ssh_remote_ipaddr(struct ssh *ssh)
466 {
467         const int sock = ssh->state->connection_in;
468
469         /* Check whether we have cached the ipaddr. */
470         if (ssh->remote_ipaddr == NULL) {
471                 if (ssh_packet_connection_is_on_socket(ssh)) {
472                         ssh->remote_ipaddr = get_peer_ipaddr(sock);
473                         ssh->remote_port = get_sock_port(sock, 0);
474                 } else {
475                         ssh->remote_ipaddr = strdup("UNKNOWN");
476                         ssh->remote_port = 0;
477                 }
478         }
479         return ssh->remote_ipaddr;
480 }
481
482 /* Returns the port number of the remote host. */
483
484 int
485 ssh_remote_port(struct ssh *ssh)
486 {
487         (void)ssh_remote_ipaddr(ssh); /* Will lookup and cache. */
488         return ssh->remote_port;
489 }
490
491 /* Closes the connection and clears and frees internal data structures. */
492
493 void
494 ssh_packet_close(struct ssh *ssh)
495 {
496         struct session_state *state = ssh->state;
497         int r;
498         u_int mode;
499
500         if (!state->initialized)
501                 return;
502         state->initialized = 0;
503         if (state->connection_in == state->connection_out) {
504                 shutdown(state->connection_out, SHUT_RDWR);
505                 close(state->connection_out);
506         } else {
507                 close(state->connection_in);
508                 close(state->connection_out);
509         }
510         sshbuf_free(state->input);
511         sshbuf_free(state->output);
512         sshbuf_free(state->outgoing_packet);
513         sshbuf_free(state->incoming_packet);
514         for (mode = 0; mode < MODE_MAX; mode++)
515                 kex_free_newkeys(state->newkeys[mode]);
516         if (state->compression_buffer) {
517                 sshbuf_free(state->compression_buffer);
518                 if (state->compression_out_started) {
519                         z_streamp stream = &state->compression_out_stream;
520                         debug("compress outgoing: "
521                             "raw data %llu, compressed %llu, factor %.2f",
522                                 (unsigned long long)stream->total_in,
523                                 (unsigned long long)stream->total_out,
524                                 stream->total_in == 0 ? 0.0 :
525                                 (double) stream->total_out / stream->total_in);
526                         if (state->compression_out_failures == 0)
527                                 deflateEnd(stream);
528                 }
529                 if (state->compression_in_started) {
530                         z_streamp stream = &state->compression_out_stream;
531                         debug("compress incoming: "
532                             "raw data %llu, compressed %llu, factor %.2f",
533                             (unsigned long long)stream->total_out,
534                             (unsigned long long)stream->total_in,
535                             stream->total_out == 0 ? 0.0 :
536                             (double) stream->total_in / stream->total_out);
537                         if (state->compression_in_failures == 0)
538                                 inflateEnd(stream);
539                 }
540         }
541         if ((r = cipher_cleanup(&state->send_context)) != 0)
542                 error("%s: cipher_cleanup failed: %s", __func__, ssh_err(r));
543         if ((r = cipher_cleanup(&state->receive_context)) != 0)
544                 error("%s: cipher_cleanup failed: %s", __func__, ssh_err(r));
545         free(ssh->remote_ipaddr);
546         ssh->remote_ipaddr = NULL;
547         free(ssh->state);
548         ssh->state = NULL;
549 }
550
551 /* Sets remote side protocol flags. */
552
553 void
554 ssh_packet_set_protocol_flags(struct ssh *ssh, u_int protocol_flags)
555 {
556         ssh->state->remote_protocol_flags = protocol_flags;
557 }
558
559 /* Returns the remote protocol flags set earlier by the above function. */
560
561 u_int
562 ssh_packet_get_protocol_flags(struct ssh *ssh)
563 {
564         return ssh->state->remote_protocol_flags;
565 }
566
/*
 * Packet compression.  ssh_packet_start_compression() (below) starts packet
 * compression from the next packet on in both directions.  Its level is the
 * compression level, from 1 (fastest) to 9 (slowest, best), as in gzip.
 */
571
572 static int
573 ssh_packet_init_compression(struct ssh *ssh)
574 {
575         if (!ssh->state->compression_buffer &&
576            ((ssh->state->compression_buffer = sshbuf_new()) == NULL))
577                 return SSH_ERR_ALLOC_FAIL;
578         return 0;
579 }
580
581 static int
582 start_compression_out(struct ssh *ssh, int level)
583 {
584         if (level < 1 || level > 9)
585                 return SSH_ERR_INVALID_ARGUMENT;
586         debug("Enabling compression at level %d.", level);
587         if (ssh->state->compression_out_started == 1)
588                 deflateEnd(&ssh->state->compression_out_stream);
589         switch (deflateInit(&ssh->state->compression_out_stream, level)) {
590         case Z_OK:
591                 ssh->state->compression_out_started = 1;
592                 break;
593         case Z_MEM_ERROR:
594                 return SSH_ERR_ALLOC_FAIL;
595         default:
596                 return SSH_ERR_INTERNAL_ERROR;
597         }
598         return 0;
599 }
600
601 static int
602 start_compression_in(struct ssh *ssh)
603 {
604         if (ssh->state->compression_in_started == 1)
605                 inflateEnd(&ssh->state->compression_in_stream);
606         switch (inflateInit(&ssh->state->compression_in_stream)) {
607         case Z_OK:
608                 ssh->state->compression_in_started = 1;
609                 break;
610         case Z_MEM_ERROR:
611                 return SSH_ERR_ALLOC_FAIL;
612         default:
613                 return SSH_ERR_INTERNAL_ERROR;
614         }
615         return 0;
616 }
617
618 int
619 ssh_packet_start_compression(struct ssh *ssh, int level)
620 {
621         int r;
622
623         if (ssh->state->packet_compression && !compat20)
624                 return SSH_ERR_INTERNAL_ERROR;
625         ssh->state->packet_compression = 1;
626         if ((r = ssh_packet_init_compression(ssh)) != 0 ||
627             (r = start_compression_in(ssh)) != 0 ||
628             (r = start_compression_out(ssh, level)) != 0)
629                 return r;
630         return 0;
631 }
632
633 /* XXX remove need for separate compression buffer */
634 static int
635 compress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
636 {
637         u_char buf[4096];
638         int r, status;
639
640         if (ssh->state->compression_out_started != 1)
641                 return SSH_ERR_INTERNAL_ERROR;
642
643         /* This case is not handled below. */
644         if (sshbuf_len(in) == 0)
645                 return 0;
646
647         /* Input is the contents of the input buffer. */
648         if ((ssh->state->compression_out_stream.next_in =
649             sshbuf_mutable_ptr(in)) == NULL)
650                 return SSH_ERR_INTERNAL_ERROR;
651         ssh->state->compression_out_stream.avail_in = sshbuf_len(in);
652
653         /* Loop compressing until deflate() returns with avail_out != 0. */
654         do {
655                 /* Set up fixed-size output buffer. */
656                 ssh->state->compression_out_stream.next_out = buf;
657                 ssh->state->compression_out_stream.avail_out = sizeof(buf);
658
659                 /* Compress as much data into the buffer as possible. */
660                 status = deflate(&ssh->state->compression_out_stream,
661                     Z_PARTIAL_FLUSH);
662                 switch (status) {
663                 case Z_MEM_ERROR:
664                         return SSH_ERR_ALLOC_FAIL;
665                 case Z_OK:
666                         /* Append compressed data to output_buffer. */
667                         if ((r = sshbuf_put(out, buf, sizeof(buf) -
668                             ssh->state->compression_out_stream.avail_out)) != 0)
669                                 return r;
670                         break;
671                 case Z_STREAM_ERROR:
672                 default:
673                         ssh->state->compression_out_failures++;
674                         return SSH_ERR_INVALID_FORMAT;
675                 }
676         } while (ssh->state->compression_out_stream.avail_out == 0);
677         return 0;
678 }
679
680 static int
681 uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
682 {
683         u_char buf[4096];
684         int r, status;
685
686         if (ssh->state->compression_in_started != 1)
687                 return SSH_ERR_INTERNAL_ERROR;
688
689         if ((ssh->state->compression_in_stream.next_in =
690             sshbuf_mutable_ptr(in)) == NULL)
691                 return SSH_ERR_INTERNAL_ERROR;
692         ssh->state->compression_in_stream.avail_in = sshbuf_len(in);
693
694         for (;;) {
695                 /* Set up fixed-size output buffer. */
696                 ssh->state->compression_in_stream.next_out = buf;
697                 ssh->state->compression_in_stream.avail_out = sizeof(buf);
698
699                 status = inflate(&ssh->state->compression_in_stream,
700                     Z_PARTIAL_FLUSH);
701                 switch (status) {
702                 case Z_OK:
703                         if ((r = sshbuf_put(out, buf, sizeof(buf) -
704                             ssh->state->compression_in_stream.avail_out)) != 0)
705                                 return r;
706                         break;
707                 case Z_BUF_ERROR:
708                         /*
709                          * Comments in zlib.h say that we should keep calling
710                          * inflate() until we get an error.  This appears to
711                          * be the error that we get.
712                          */
713                         return 0;
714                 case Z_DATA_ERROR:
715                         return SSH_ERR_INVALID_FORMAT;
716                 case Z_MEM_ERROR:
717                         return SSH_ERR_ALLOC_FAIL;
718                 case Z_STREAM_ERROR:
719                 default:
720                         ssh->state->compression_in_failures++;
721                         return SSH_ERR_INTERNAL_ERROR;
722                 }
723         }
724         /* NOTREACHED */
725 }
726
727 /* Serialise compression state into a blob for privsep */
728 static int
729 ssh_packet_get_compress_state(struct sshbuf *m, struct ssh *ssh)
730 {
731         struct session_state *state = ssh->state;
732         struct sshbuf *b;
733         int r;
734
735         if ((b = sshbuf_new()) == NULL)
736                 return SSH_ERR_ALLOC_FAIL;
737         if (state->compression_in_started) {
738                 if ((r = sshbuf_put_string(b, &state->compression_in_stream,
739                     sizeof(state->compression_in_stream))) != 0)
740                         goto out;
741         } else if ((r = sshbuf_put_string(b, NULL, 0)) != 0)
742                 goto out;
743         if (state->compression_out_started) {
744                 if ((r = sshbuf_put_string(b, &state->compression_out_stream,
745                     sizeof(state->compression_out_stream))) != 0)
746                         goto out;
747         } else if ((r = sshbuf_put_string(b, NULL, 0)) != 0)
748                 goto out;
749         r = sshbuf_put_stringb(m, b);
750  out:
751         sshbuf_free(b);
752         return r;
753 }
754
755 /* Deserialise compression state from a blob for privsep */
756 static int
757 ssh_packet_set_compress_state(struct ssh *ssh, struct sshbuf *m)
758 {
759         struct session_state *state = ssh->state;
760         struct sshbuf *b = NULL;
761         int r;
762         const u_char *inblob, *outblob;
763         size_t inl, outl;
764
765         if ((r = sshbuf_froms(m, &b)) != 0)
766                 goto out;
767         if ((r = sshbuf_get_string_direct(b, &inblob, &inl)) != 0 ||
768             (r = sshbuf_get_string_direct(b, &outblob, &outl)) != 0)
769                 goto out;
770         if (inl == 0)
771                 state->compression_in_started = 0;
772         else if (inl != sizeof(state->compression_in_stream)) {
773                 r = SSH_ERR_INTERNAL_ERROR;
774                 goto out;
775         } else {
776                 state->compression_in_started = 1;
777                 memcpy(&state->compression_in_stream, inblob, inl);
778         }
779         if (outl == 0)
780                 state->compression_out_started = 0;
781         else if (outl != sizeof(state->compression_out_stream)) {
782                 r = SSH_ERR_INTERNAL_ERROR;
783                 goto out;
784         } else {
785                 state->compression_out_started = 1;
786                 memcpy(&state->compression_out_stream, outblob, outl);
787         }
788         r = 0;
789  out:
790         sshbuf_free(b);
791         return r;
792 }
793
794 void
795 ssh_packet_set_compress_hooks(struct ssh *ssh, void *ctx,
796     void *(*allocfunc)(void *, u_int, u_int),
797     void (*freefunc)(void *, void *))
798 {
799         ssh->state->compression_out_stream.zalloc = (alloc_func)allocfunc;
800         ssh->state->compression_out_stream.zfree = (free_func)freefunc;
801         ssh->state->compression_out_stream.opaque = ctx;
802         ssh->state->compression_in_stream.zalloc = (alloc_func)allocfunc;
803         ssh->state->compression_in_stream.zfree = (free_func)freefunc;
804         ssh->state->compression_in_stream.opaque = ctx;
805 }
806
807 /*
808  * Causes any further packets to be encrypted using the given key.  The same
809  * key is used for both sending and reception.  However, both directions are
810  * encrypted independently of each other.
811  */
812
813 void
814 ssh_packet_set_encryption_key(struct ssh *ssh, const u_char *key, u_int keylen, int number)
815 {
816 #ifndef WITH_SSH1
817         fatal("no SSH protocol 1 support");
818 #else /* WITH_SSH1 */
819         struct session_state *state = ssh->state;
820         const struct sshcipher *cipher = cipher_by_number(number);
821         int r;
822         const char *wmsg;
823
824         if (cipher == NULL)
825                 fatal("%s: unknown cipher number %d", __func__, number);
826         if (keylen < 20)
827                 fatal("%s: keylen too small: %d", __func__, keylen);
828         if (keylen > SSH_SESSION_KEY_LENGTH)
829                 fatal("%s: keylen too big: %d", __func__, keylen);
830         memcpy(state->ssh1_key, key, keylen);
831         state->ssh1_keylen = keylen;
832         if ((r = cipher_init(&state->send_context, cipher, key, keylen,
833             NULL, 0, CIPHER_ENCRYPT)) != 0 ||
834             (r = cipher_init(&state->receive_context, cipher, key, keylen,
835             NULL, 0, CIPHER_DECRYPT) != 0))
836                 fatal("%s: cipher_init failed: %s", __func__, ssh_err(r));
837         if (!state->cipher_warning_done &&
838             ((wmsg = cipher_warning_message(&state->send_context)) != NULL ||
839             (wmsg = cipher_warning_message(&state->send_context)) != NULL)) {
840                 error("Warning: %s", wmsg);
841                 state->cipher_warning_done = 1;
842         }
843 #endif /* WITH_SSH1 */
844 }
845
846 /*
847  * Finalizes and sends the packet.  If the encryption key has been set,
848  * encrypts the packet before sending.
849  */
850
851 int
852 ssh_packet_send1(struct ssh *ssh)
853 {
854         struct session_state *state = ssh->state;
855         u_char buf[8], *cp;
856         int r, padding, len;
857         u_int checksum;
858
859         /*
860          * If using packet compression, compress the payload of the outgoing
861          * packet.
862          */
863         if (state->packet_compression) {
864                 sshbuf_reset(state->compression_buffer);
865                 /* Skip padding. */
866                 if ((r = sshbuf_consume(state->outgoing_packet, 8)) != 0)
867                         goto out;
868                 /* padding */
869                 if ((r = sshbuf_put(state->compression_buffer,
870                     "\0\0\0\0\0\0\0\0", 8)) != 0)
871                         goto out;
872                 if ((r = compress_buffer(ssh, state->outgoing_packet,
873                     state->compression_buffer)) != 0)
874                         goto out;
875                 sshbuf_reset(state->outgoing_packet);
876                 if ((r = sshbuf_putb(state->outgoing_packet,
877                     state->compression_buffer)) != 0)
878                         goto out;
879         }
880         /* Compute packet length without padding (add checksum, remove padding). */
881         len = sshbuf_len(state->outgoing_packet) + 4 - 8;
882
883         /* Insert padding. Initialized to zero in packet_start1() */
884         padding = 8 - len % 8;
885         if (!state->send_context.plaintext) {
886                 cp = sshbuf_mutable_ptr(state->outgoing_packet);
887                 if (cp == NULL) {
888                         r = SSH_ERR_INTERNAL_ERROR;
889                         goto out;
890                 }
891                 arc4random_buf(cp + 8 - padding, padding);
892         }
893         if ((r = sshbuf_consume(state->outgoing_packet, 8 - padding)) != 0)
894                 goto out;
895
896         /* Add check bytes. */
897         checksum = ssh_crc32(sshbuf_ptr(state->outgoing_packet),
898             sshbuf_len(state->outgoing_packet));
899         POKE_U32(buf, checksum);
900         if ((r = sshbuf_put(state->outgoing_packet, buf, 4)) != 0)
901                 goto out;
902
903 #ifdef PACKET_DEBUG
904         fprintf(stderr, "packet_send plain: ");
905         sshbuf_dump(state->outgoing_packet, stderr);
906 #endif
907
908         /* Append to output. */
909         POKE_U32(buf, len);
910         if ((r = sshbuf_put(state->output, buf, 4)) != 0)
911                 goto out;
912         if ((r = sshbuf_reserve(state->output,
913             sshbuf_len(state->outgoing_packet), &cp)) != 0)
914                 goto out;
915         if ((r = cipher_crypt(&state->send_context, 0, cp,
916             sshbuf_ptr(state->outgoing_packet),
917             sshbuf_len(state->outgoing_packet), 0, 0)) != 0)
918                 goto out;
919
920 #ifdef PACKET_DEBUG
921         fprintf(stderr, "encrypted: ");
922         sshbuf_dump(state->output, stderr);
923 #endif
924         state->p_send.packets++;
925         state->p_send.bytes += len +
926             sshbuf_len(state->outgoing_packet);
927         sshbuf_reset(state->outgoing_packet);
928
929         /*
930          * Note that the packet is now only buffered in output.  It won't be
931          * actually sent until ssh_packet_write_wait or ssh_packet_write_poll
932          * is called.
933          */
934         r = 0;
935  out:
936         return r;
937 }
938
939 int
940 ssh_set_newkeys(struct ssh *ssh, int mode)
941 {
942         struct session_state *state = ssh->state;
943         struct sshenc *enc;
944         struct sshmac *mac;
945         struct sshcomp *comp;
946         struct sshcipher_ctx *cc;
947         u_int64_t *max_blocks;
948         const char *wmsg;
949         int r, crypt_type;
950
951         debug2("set_newkeys: mode %d", mode);
952
953         if (mode == MODE_OUT) {
954                 cc = &state->send_context;
955                 crypt_type = CIPHER_ENCRYPT;
956                 state->p_send.packets = state->p_send.blocks = 0;
957                 max_blocks = &state->max_blocks_out;
958         } else {
959                 cc = &state->receive_context;
960                 crypt_type = CIPHER_DECRYPT;
961                 state->p_read.packets = state->p_read.blocks = 0;
962                 max_blocks = &state->max_blocks_in;
963         }
964         if (state->newkeys[mode] != NULL) {
965                 debug("set_newkeys: rekeying, input %llu bytes %llu blocks, "
966                    "output %llu bytes %llu blocks",
967                    (unsigned long long)state->p_read.bytes,
968                    (unsigned long long)state->p_read.blocks,
969                    (unsigned long long)state->p_send.bytes,
970                    (unsigned long long)state->p_send.blocks);
971                 if ((r = cipher_cleanup(cc)) != 0)
972                         return r;
973                 enc  = &state->newkeys[mode]->enc;
974                 mac  = &state->newkeys[mode]->mac;
975                 comp = &state->newkeys[mode]->comp;
976                 mac_clear(mac);
977                 explicit_bzero(enc->iv,  enc->iv_len);
978                 explicit_bzero(enc->key, enc->key_len);
979                 explicit_bzero(mac->key, mac->key_len);
980                 free(enc->name);
981                 free(enc->iv);
982                 free(enc->key);
983                 free(mac->name);
984                 free(mac->key);
985                 free(comp->name);
986                 free(state->newkeys[mode]);
987         }
988         /* move newkeys from kex to state */
989         if ((state->newkeys[mode] = ssh->kex->newkeys[mode]) == NULL)
990                 return SSH_ERR_INTERNAL_ERROR;
991         ssh->kex->newkeys[mode] = NULL;
992         enc  = &state->newkeys[mode]->enc;
993         mac  = &state->newkeys[mode]->mac;
994         comp = &state->newkeys[mode]->comp;
995         if (cipher_authlen(enc->cipher) == 0) {
996                 if ((r = mac_init(mac)) != 0)
997                         return r;
998         }
999         mac->enabled = 1;
1000         DBG(debug("cipher_init_context: %d", mode));
1001         if ((r = cipher_init(cc, enc->cipher, enc->key, enc->key_len,
1002             enc->iv, enc->iv_len, crypt_type)) != 0)
1003                 return r;
1004         if (!state->cipher_warning_done &&
1005             (wmsg = cipher_warning_message(cc)) != NULL) {
1006                 error("Warning: %s", wmsg);
1007                 state->cipher_warning_done = 1;
1008         }
1009         /* Deleting the keys does not gain extra security */
1010         /* explicit_bzero(enc->iv,  enc->block_size);
1011            explicit_bzero(enc->key, enc->key_len);
1012            explicit_bzero(mac->key, mac->key_len); */
1013         if ((comp->type == COMP_ZLIB ||
1014             (comp->type == COMP_DELAYED &&
1015              state->after_authentication)) && comp->enabled == 0) {
1016                 if ((r = ssh_packet_init_compression(ssh)) < 0)
1017                         return r;
1018                 if (mode == MODE_OUT) {
1019                         if ((r = start_compression_out(ssh, 6)) != 0)
1020                                 return r;
1021                 } else {
1022                         if ((r = start_compression_in(ssh)) != 0)
1023                                 return r;
1024                 }
1025                 comp->enabled = 1;
1026         }
1027         /*
1028          * The 2^(blocksize*2) limit is too expensive for 3DES,
1029          * blowfish, etc, so enforce a 1GB limit for small blocksizes.
1030          */
1031         if (enc->block_size >= 16)
1032                 *max_blocks = (u_int64_t)1 << (enc->block_size*2);
1033         else
1034                 *max_blocks = ((u_int64_t)1 << 30) / enc->block_size;
1035         if (state->rekey_limit)
1036                 *max_blocks = MIN(*max_blocks,
1037                     state->rekey_limit / enc->block_size);
1038         debug("rekey after %llu blocks", (unsigned long long)*max_blocks);
1039         return 0;
1040 }
1041
1042 #define MAX_PACKETS     (1U<<31)
1043 static int
1044 ssh_packet_need_rekeying(struct ssh *ssh, u_int outbound_packet_len)
1045 {
1046         struct session_state *state = ssh->state;
1047         u_int32_t out_blocks;
1048
1049         /* XXX client can't cope with rekeying pre-auth */
1050         if (!state->after_authentication)
1051                 return 0;
1052
1053         /* Haven't keyed yet or KEX in progress. */
1054         if (ssh->kex == NULL || ssh_packet_is_rekeying(ssh))
1055                 return 0;
1056
1057         /* Peer can't rekey */
1058         if (ssh->compat & SSH_BUG_NOREKEY)
1059                 return 0;
1060
1061         /*
1062          * Permit one packet in or out per rekey - this allows us to
1063          * make progress when rekey limits are very small.
1064          */
1065         if (state->p_send.packets == 0 && state->p_read.packets == 0)
1066                 return 0;
1067
1068         /* Time-based rekeying */
1069         if (state->rekey_interval != 0 &&
1070             state->rekey_time + state->rekey_interval <= monotime())
1071                 return 1;
1072
1073         /* Always rekey when MAX_PACKETS sent in either direction */
1074         if (state->p_send.packets > MAX_PACKETS ||
1075             state->p_read.packets > MAX_PACKETS)
1076                 return 1;
1077
1078         /* Rekey after (cipher-specific) maxiumum blocks */
1079         out_blocks = roundup(outbound_packet_len,
1080             state->newkeys[MODE_OUT]->enc.block_size);
1081         return (state->max_blocks_out &&
1082             (state->p_send.blocks + out_blocks > state->max_blocks_out)) ||
1083             (state->max_blocks_in &&
1084             (state->p_read.blocks > state->max_blocks_in));
1085 }
1086
1087 /*
1088  * Delayed compression for SSH2 is enabled after authentication:
1089  * This happens on the server side after a SSH2_MSG_USERAUTH_SUCCESS is sent,
1090  * and on the client side after a SSH2_MSG_USERAUTH_SUCCESS is received.
1091  */
1092 static int
1093 ssh_packet_enable_delayed_compress(struct ssh *ssh)
1094 {
1095         struct session_state *state = ssh->state;
1096         struct sshcomp *comp = NULL;
1097         int r, mode;
1098
1099         /*
1100          * Remember that we are past the authentication step, so rekeying
1101          * with COMP_DELAYED will turn on compression immediately.
1102          */
1103         state->after_authentication = 1;
1104         for (mode = 0; mode < MODE_MAX; mode++) {
1105                 /* protocol error: USERAUTH_SUCCESS received before NEWKEYS */
1106                 if (state->newkeys[mode] == NULL)
1107                         continue;
1108                 comp = &state->newkeys[mode]->comp;
1109                 if (comp && !comp->enabled && comp->type == COMP_DELAYED) {
1110                         if ((r = ssh_packet_init_compression(ssh)) != 0)
1111                                 return r;
1112                         if (mode == MODE_OUT) {
1113                                 if ((r = start_compression_out(ssh, 6)) != 0)
1114                                         return r;
1115                         } else {
1116                                 if ((r = start_compression_in(ssh)) != 0)
1117                                         return r;
1118                         }
1119                         comp->enabled = 1;
1120                 }
1121         }
1122         return 0;
1123 }
1124
1125 /* Used to mute debug logging for noisy packet types */
1126 static int
1127 ssh_packet_log_type(u_char type)
1128 {
1129         switch (type) {
1130         case SSH2_MSG_CHANNEL_DATA:
1131         case SSH2_MSG_CHANNEL_EXTENDED_DATA:
1132         case SSH2_MSG_CHANNEL_WINDOW_ADJUST:
1133                 return 0;
1134         default:
1135                 return 1;
1136         }
1137 }
1138
1139 /*
1140  * Finalize packet in SSH2 format (compress, mac, encrypt, enqueue)
1141  */
1142 int
1143 ssh_packet_send2_wrapped(struct ssh *ssh)
1144 {
1145         struct session_state *state = ssh->state;
1146         u_char type, *cp, macbuf[SSH_DIGEST_MAX_LENGTH];
1147         u_char padlen, pad = 0;
1148         u_int authlen = 0, aadlen = 0;
1149         u_int len;
1150         struct sshenc *enc   = NULL;
1151         struct sshmac *mac   = NULL;
1152         struct sshcomp *comp = NULL;
1153         int r, block_size;
1154
1155         if (state->newkeys[MODE_OUT] != NULL) {
1156                 enc  = &state->newkeys[MODE_OUT]->enc;
1157                 mac  = &state->newkeys[MODE_OUT]->mac;
1158                 comp = &state->newkeys[MODE_OUT]->comp;
1159                 /* disable mac for authenticated encryption */
1160                 if ((authlen = cipher_authlen(enc->cipher)) != 0)
1161                         mac = NULL;
1162         }
1163         block_size = enc ? enc->block_size : 8;
1164         aadlen = (mac && mac->enabled && mac->etm) || authlen ? 4 : 0;
1165
1166         type = (sshbuf_ptr(state->outgoing_packet))[5];
1167         if (ssh_packet_log_type(type))
1168                 debug3("send packet: type %u", type);
1169 #ifdef PACKET_DEBUG
1170         fprintf(stderr, "plain:     ");
1171         sshbuf_dump(state->outgoing_packet, stderr);
1172 #endif
1173
1174         if (comp && comp->enabled) {
1175                 len = sshbuf_len(state->outgoing_packet);
1176                 /* skip header, compress only payload */
1177                 if ((r = sshbuf_consume(state->outgoing_packet, 5)) != 0)
1178                         goto out;
1179                 sshbuf_reset(state->compression_buffer);
1180                 if ((r = compress_buffer(ssh, state->outgoing_packet,
1181                     state->compression_buffer)) != 0)
1182                         goto out;
1183                 sshbuf_reset(state->outgoing_packet);
1184                 if ((r = sshbuf_put(state->outgoing_packet,
1185                     "\0\0\0\0\0", 5)) != 0 ||
1186                     (r = sshbuf_putb(state->outgoing_packet,
1187                     state->compression_buffer)) != 0)
1188                         goto out;
1189                 DBG(debug("compression: raw %d compressed %zd", len,
1190                     sshbuf_len(state->outgoing_packet)));
1191         }
1192
1193         /* sizeof (packet_len + pad_len + payload) */
1194         len = sshbuf_len(state->outgoing_packet);
1195
1196         /*
1197          * calc size of padding, alloc space, get random data,
1198          * minimum padding is 4 bytes
1199          */
1200         len -= aadlen; /* packet length is not encrypted for EtM modes */
1201         padlen = block_size - (len % block_size);
1202         if (padlen < 4)
1203                 padlen += block_size;
1204         if (state->extra_pad) {
1205                 /* will wrap if extra_pad+padlen > 255 */
1206                 state->extra_pad =
1207                     roundup(state->extra_pad, block_size);
1208                 pad = state->extra_pad -
1209                     ((len + padlen) % state->extra_pad);
1210                 DBG(debug3("%s: adding %d (len %d padlen %d extra_pad %d)",
1211                     __func__, pad, len, padlen, state->extra_pad));
1212                 padlen += pad;
1213                 state->extra_pad = 0;
1214         }
1215         if ((r = sshbuf_reserve(state->outgoing_packet, padlen, &cp)) != 0)
1216                 goto out;
1217         if (enc && !state->send_context.plaintext) {
1218                 /* random padding */
1219                 arc4random_buf(cp, padlen);
1220         } else {
1221                 /* clear padding */
1222                 explicit_bzero(cp, padlen);
1223         }
1224         /* sizeof (packet_len + pad_len + payload + padding) */
1225         len = sshbuf_len(state->outgoing_packet);
1226         cp = sshbuf_mutable_ptr(state->outgoing_packet);
1227         if (cp == NULL) {
1228                 r = SSH_ERR_INTERNAL_ERROR;
1229                 goto out;
1230         }
1231         /* packet_length includes payload, padding and padding length field */
1232         POKE_U32(cp, len - 4);
1233         cp[4] = padlen;
1234         DBG(debug("send: len %d (includes padlen %d, aadlen %d)",
1235             len, padlen, aadlen));
1236
1237         /* compute MAC over seqnr and packet(length fields, payload, padding) */
1238         if (mac && mac->enabled && !mac->etm) {
1239                 if ((r = mac_compute(mac, state->p_send.seqnr,
1240                     sshbuf_ptr(state->outgoing_packet), len,
1241                     macbuf, sizeof(macbuf))) != 0)
1242                         goto out;
1243                 DBG(debug("done calc MAC out #%d", state->p_send.seqnr));
1244         }
1245         /* encrypt packet and append to output buffer. */
1246         if ((r = sshbuf_reserve(state->output,
1247             sshbuf_len(state->outgoing_packet) + authlen, &cp)) != 0)
1248                 goto out;
1249         if ((r = cipher_crypt(&state->send_context, state->p_send.seqnr, cp,
1250             sshbuf_ptr(state->outgoing_packet),
1251             len - aadlen, aadlen, authlen)) != 0)
1252                 goto out;
1253         /* append unencrypted MAC */
1254         if (mac && mac->enabled) {
1255                 if (mac->etm) {
1256                         /* EtM: compute mac over aadlen + cipher text */
1257                         if ((r = mac_compute(mac, state->p_send.seqnr,
1258                             cp, len, macbuf, sizeof(macbuf))) != 0)
1259                                 goto out;
1260                         DBG(debug("done calc MAC(EtM) out #%d",
1261                             state->p_send.seqnr));
1262                 }
1263                 if ((r = sshbuf_put(state->output, macbuf, mac->mac_len)) != 0)
1264                         goto out;
1265         }
1266 #ifdef PACKET_DEBUG
1267         fprintf(stderr, "encrypted: ");
1268         sshbuf_dump(state->output, stderr);
1269 #endif
1270         /* increment sequence number for outgoing packets */
1271         if (++state->p_send.seqnr == 0)
1272                 logit("outgoing seqnr wraps around");
1273         if (++state->p_send.packets == 0)
1274                 if (!(ssh->compat & SSH_BUG_NOREKEY))
1275                         return SSH_ERR_NEED_REKEY;
1276         state->p_send.blocks += len / block_size;
1277         state->p_send.bytes += len;
1278         sshbuf_reset(state->outgoing_packet);
1279
1280         if (type == SSH2_MSG_NEWKEYS)
1281                 r = ssh_set_newkeys(ssh, MODE_OUT);
1282         else if (type == SSH2_MSG_USERAUTH_SUCCESS && state->server_side)
1283                 r = ssh_packet_enable_delayed_compress(ssh);
1284         else
1285                 r = 0;
1286  out:
1287         return r;
1288 }
1289
1290 /* returns non-zero if the specified packet type is usec by KEX */
1291 static int
1292 ssh_packet_type_is_kex(u_char type)
1293 {
1294         return
1295             type >= SSH2_MSG_TRANSPORT_MIN &&
1296             type <= SSH2_MSG_TRANSPORT_MAX &&
1297             type != SSH2_MSG_SERVICE_REQUEST &&
1298             type != SSH2_MSG_SERVICE_ACCEPT &&
1299             type != SSH2_MSG_EXT_INFO;
1300 }
1301
1302 int
1303 ssh_packet_send2(struct ssh *ssh)
1304 {
1305         struct session_state *state = ssh->state;
1306         struct packet *p;
1307         u_char type;
1308         int r, need_rekey;
1309
1310         if (sshbuf_len(state->outgoing_packet) < 6)
1311                 return SSH_ERR_INTERNAL_ERROR;
1312         type = sshbuf_ptr(state->outgoing_packet)[5];
1313         need_rekey = !ssh_packet_type_is_kex(type) &&
1314             ssh_packet_need_rekeying(ssh, sshbuf_len(state->outgoing_packet));
1315
1316         /*
1317          * During rekeying we can only send key exchange messages.
1318          * Queue everything else.
1319          */
1320         if ((need_rekey || state->rekeying) && !ssh_packet_type_is_kex(type)) {
1321                 if (need_rekey)
1322                         debug3("%s: rekex triggered", __func__);
1323                 debug("enqueue packet: %u", type);
1324                 p = calloc(1, sizeof(*p));
1325                 if (p == NULL)
1326                         return SSH_ERR_ALLOC_FAIL;
1327                 p->type = type;
1328                 p->payload = state->outgoing_packet;
1329                 TAILQ_INSERT_TAIL(&state->outgoing, p, next);
1330                 state->outgoing_packet = sshbuf_new();
1331                 if (state->outgoing_packet == NULL)
1332                         return SSH_ERR_ALLOC_FAIL;
1333                 if (need_rekey) {
1334                         /*
1335                          * This packet triggered a rekey, so send the
1336                          * KEXINIT now.
1337                          * NB. reenters this function via kex_start_rekex().
1338                          */
1339                         return kex_start_rekex(ssh);
1340                 }
1341                 return 0;
1342         }
1343
1344         /* rekeying starts with sending KEXINIT */
1345         if (type == SSH2_MSG_KEXINIT)
1346                 state->rekeying = 1;
1347
1348         if ((r = ssh_packet_send2_wrapped(ssh)) != 0)
1349                 return r;
1350
1351         /* after a NEWKEYS message we can send the complete queue */
1352         if (type == SSH2_MSG_NEWKEYS) {
1353                 state->rekeying = 0;
1354                 state->rekey_time = monotime();
1355                 while ((p = TAILQ_FIRST(&state->outgoing))) {
1356                         type = p->type;
1357                         /*
1358                          * If this packet triggers a rekex, then skip the
1359                          * remaining packets in the queue for now.
1360                          * NB. re-enters this function via kex_start_rekex.
1361                          */
1362                         if (ssh_packet_need_rekeying(ssh,
1363                             sshbuf_len(p->payload))) {
1364                                 debug3("%s: queued packet triggered rekex",
1365                                     __func__);
1366                                 return kex_start_rekex(ssh);
1367                         }
1368                         debug("dequeue packet: %u", type);
1369                         sshbuf_free(state->outgoing_packet);
1370                         state->outgoing_packet = p->payload;
1371                         TAILQ_REMOVE(&state->outgoing, p, next);
1372                         memset(p, 0, sizeof(*p));
1373                         free(p);
1374                         if ((r = ssh_packet_send2_wrapped(ssh)) != 0)
1375                                 return r;
1376                 }
1377         }
1378         return 0;
1379 }
1380
1381 /*
1382  * Waits until a packet has been received, and returns its type.  Note that
1383  * no other data is processed until this returns, so this function should not
1384  * be used during the interactive session.
1385  */
1386
1387 int
1388 ssh_packet_read_seqnr(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1389 {
1390         struct session_state *state = ssh->state;
1391         int len, r, ms_remain;
1392         fd_set *setp;
1393         char buf[8192];
1394         struct timeval timeout, start, *timeoutp = NULL;
1395
1396         DBG(debug("packet_read()"));
1397
1398         setp = calloc(howmany(state->connection_in + 1,
1399             NFDBITS), sizeof(fd_mask));
1400         if (setp == NULL)
1401                 return SSH_ERR_ALLOC_FAIL;
1402
1403         /*
1404          * Since we are blocking, ensure that all written packets have
1405          * been sent.
1406          */
1407         if ((r = ssh_packet_write_wait(ssh)) != 0)
1408                 goto out;
1409
1410         /* Stay in the loop until we have received a complete packet. */
1411         for (;;) {
1412                 /* Try to read a packet from the buffer. */
1413                 r = ssh_packet_read_poll_seqnr(ssh, typep, seqnr_p);
1414                 if (r != 0)
1415                         break;
1416                 if (!compat20 && (
1417                     *typep == SSH_SMSG_SUCCESS
1418                     || *typep == SSH_SMSG_FAILURE
1419                     || *typep == SSH_CMSG_EOF
1420                     || *typep == SSH_CMSG_EXIT_CONFIRMATION))
1421                         if ((r = sshpkt_get_end(ssh)) != 0)
1422                                 break;
1423                 /* If we got a packet, return it. */
1424                 if (*typep != SSH_MSG_NONE)
1425                         break;
1426                 /*
1427                  * Otherwise, wait for some data to arrive, add it to the
1428                  * buffer, and try again.
1429                  */
1430                 memset(setp, 0, howmany(state->connection_in + 1,
1431                     NFDBITS) * sizeof(fd_mask));
1432                 FD_SET(state->connection_in, setp);
1433
1434                 if (state->packet_timeout_ms > 0) {
1435                         ms_remain = state->packet_timeout_ms;
1436                         timeoutp = &timeout;
1437                 }
1438                 /* Wait for some data to arrive. */
1439                 for (;;) {
1440                         if (state->packet_timeout_ms != -1) {
1441                                 ms_to_timeval(&timeout, ms_remain);
1442                                 gettimeofday(&start, NULL);
1443                         }
1444                         if ((r = select(state->connection_in + 1, setp,
1445                             NULL, NULL, timeoutp)) >= 0)
1446                                 break;
1447                         if (errno != EAGAIN && errno != EINTR &&
1448                             errno != EWOULDBLOCK)
1449                                 break;
1450                         if (state->packet_timeout_ms == -1)
1451                                 continue;
1452                         ms_subtract_diff(&start, &ms_remain);
1453                         if (ms_remain <= 0) {
1454                                 r = 0;
1455                                 break;
1456                         }
1457                 }
1458                 if (r == 0)
1459                         return SSH_ERR_CONN_TIMEOUT;
1460                 /* Read data from the socket. */
1461                 len = read(state->connection_in, buf, sizeof(buf));
1462                 if (len == 0) {
1463                         r = SSH_ERR_CONN_CLOSED;
1464                         goto out;
1465                 }
1466                 if (len < 0) {
1467                         r = SSH_ERR_SYSTEM_ERROR;
1468                         goto out;
1469                 }
1470
1471                 /* Append it to the buffer. */
1472                 if ((r = ssh_packet_process_incoming(ssh, buf, len)) != 0)
1473                         goto out;
1474         }
1475  out:
1476         free(setp);
1477         return r;
1478 }
1479
1480 int
1481 ssh_packet_read(struct ssh *ssh)
1482 {
1483         u_char type;
1484         int r;
1485
1486         if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0)
1487                 fatal("%s: %s", __func__, ssh_err(r));
1488         return type;
1489 }
1490
1491 /*
1492  * Waits until a packet has been received, verifies that its type matches
1493  * that given, and gives a fatal error and exits if there is a mismatch.
1494  */
1495
1496 int
1497 ssh_packet_read_expect(struct ssh *ssh, u_int expected_type)
1498 {
1499         int r;
1500         u_char type;
1501
1502         if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0)
1503                 return r;
1504         if (type != expected_type) {
1505                 if ((r = sshpkt_disconnect(ssh,
1506                     "Protocol error: expected packet type %d, got %d",
1507                     expected_type, type)) != 0)
1508                         return r;
1509                 return SSH_ERR_PROTOCOL_ERROR;
1510         }
1511         return 0;
1512 }
1513
1514 /* Checks if a full packet is available in the data received so far via
1515  * packet_process_incoming.  If so, reads the packet; otherwise returns
1516  * SSH_MSG_NONE.  This does not wait for data from the connection.
1517  *
1518  * SSH_MSG_DISCONNECT is handled specially here.  Also,
1519  * SSH_MSG_IGNORE messages are skipped by this function and are never returned
1520  * to higher levels.
1521  */
1522
1523 int
1524 ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
1525 {
1526         struct session_state *state = ssh->state;
1527         u_int len, padded_len;
1528         const char *emsg;
1529         const u_char *cp;
1530         u_char *p;
1531         u_int checksum, stored_checksum;
1532         int r;
1533
1534         *typep = SSH_MSG_NONE;
1535
1536         /* Check if input size is less than minimum packet size. */
1537         if (sshbuf_len(state->input) < 4 + 8)
1538                 return 0;
1539         /* Get length of incoming packet. */
1540         len = PEEK_U32(sshbuf_ptr(state->input));
1541         if (len < 1 + 2 + 2 || len > 256 * 1024) {
1542                 if ((r = sshpkt_disconnect(ssh, "Bad packet length %u",
1543                     len)) != 0)
1544                         return r;
1545                 return SSH_ERR_CONN_CORRUPT;
1546         }
1547         padded_len = (len + 8) & ~7;
1548
1549         /* Check if the packet has been entirely received. */
1550         if (sshbuf_len(state->input) < 4 + padded_len)
1551                 return 0;
1552
1553         /* The entire packet is in buffer. */
1554
1555         /* Consume packet length. */
1556         if ((r = sshbuf_consume(state->input, 4)) != 0)
1557                 goto out;
1558
1559         /*
1560          * Cryptographic attack detector for ssh
1561          * (C)1998 CORE-SDI, Buenos Aires Argentina
1562          * Ariel Futoransky(futo@core-sdi.com)
1563          */
1564         if (!state->receive_context.plaintext) {
1565                 emsg = NULL;
1566                 switch (detect_attack(&state->deattack,
1567                     sshbuf_ptr(state->input), padded_len)) {
1568                 case DEATTACK_OK:
1569                         break;
1570                 case DEATTACK_DETECTED:
1571                         emsg = "crc32 compensation attack detected";
1572                         break;
1573                 case DEATTACK_DOS_DETECTED:
1574                         emsg = "deattack denial of service detected";
1575                         break;
1576                 default:
1577                         emsg = "deattack error";
1578                         break;
1579                 }
1580                 if (emsg != NULL) {
1581                         error("%s", emsg);
1582                         if ((r = sshpkt_disconnect(ssh, "%s", emsg)) != 0 ||
1583                             (r = ssh_packet_write_wait(ssh)) != 0)
1584                                         return r;
1585                         return SSH_ERR_CONN_CORRUPT;
1586                 }
1587         }
1588
1589         /* Decrypt data to incoming_packet. */
1590         sshbuf_reset(state->incoming_packet);
1591         if ((r = sshbuf_reserve(state->incoming_packet, padded_len, &p)) != 0)
1592                 goto out;
1593         if ((r = cipher_crypt(&state->receive_context, 0, p,
1594             sshbuf_ptr(state->input), padded_len, 0, 0)) != 0)
1595                 goto out;
1596
1597         if ((r = sshbuf_consume(state->input, padded_len)) != 0)
1598                 goto out;
1599
1600 #ifdef PACKET_DEBUG
1601         fprintf(stderr, "read_poll plain: ");
1602         sshbuf_dump(state->incoming_packet, stderr);
1603 #endif
1604
1605         /* Compute packet checksum. */
1606         checksum = ssh_crc32(sshbuf_ptr(state->incoming_packet),
1607             sshbuf_len(state->incoming_packet) - 4);
1608
1609         /* Skip padding. */
1610         if ((r = sshbuf_consume(state->incoming_packet, 8 - len % 8)) != 0)
1611                 goto out;
1612
1613         /* Test check bytes. */
1614         if (len != sshbuf_len(state->incoming_packet)) {
1615                 error("%s: len %d != sshbuf_len %zd", __func__,
1616                     len, sshbuf_len(state->incoming_packet));
1617                 if ((r = sshpkt_disconnect(ssh, "invalid packet length")) != 0 ||
1618                     (r = ssh_packet_write_wait(ssh)) != 0)
1619                         return r;
1620                 return SSH_ERR_CONN_CORRUPT;
1621         }
1622
1623         cp = sshbuf_ptr(state->incoming_packet) + len - 4;
1624         stored_checksum = PEEK_U32(cp);
1625         if (checksum != stored_checksum) {
1626                 error("Corrupted check bytes on input");
1627                 if ((r = sshpkt_disconnect(ssh, "connection corrupted")) != 0 ||
1628                     (r = ssh_packet_write_wait(ssh)) != 0)
1629                         return r;
1630                 return SSH_ERR_CONN_CORRUPT;
1631         }
1632         if ((r = sshbuf_consume_end(state->incoming_packet, 4)) < 0)
1633                 goto out;
1634
1635         if (state->packet_compression) {
1636                 sshbuf_reset(state->compression_buffer);
1637                 if ((r = uncompress_buffer(ssh, state->incoming_packet,
1638                     state->compression_buffer)) != 0)
1639                         goto out;
1640                 sshbuf_reset(state->incoming_packet);
1641                 if ((r = sshbuf_putb(state->incoming_packet,
1642                     state->compression_buffer)) != 0)
1643                         goto out;
1644         }
1645         state->p_read.packets++;
1646         state->p_read.bytes += padded_len + 4;
1647         if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
1648                 goto out;
1649         if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) {
1650                 error("Invalid ssh1 packet type: %d", *typep);
1651                 if ((r = sshpkt_disconnect(ssh, "invalid packet type")) != 0 ||
1652                     (r = ssh_packet_write_wait(ssh)) != 0)
1653                         return r;
1654                 return SSH_ERR_PROTOCOL_ERROR;
1655         }
1656         r = 0;
1657  out:
1658         return r;
1659 }
1660
1661 int
1662 ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1663 {
1664         struct session_state *state = ssh->state;
1665         u_int padlen, need;
1666         u_char *cp, macbuf[SSH_DIGEST_MAX_LENGTH];
1667         u_int maclen, aadlen = 0, authlen = 0, block_size;
1668         struct sshenc *enc   = NULL;
1669         struct sshmac *mac   = NULL;
1670         struct sshcomp *comp = NULL;
1671         int r;
1672
1673         *typep = SSH_MSG_NONE;
1674
1675         if (state->packet_discard)
1676                 return 0;
1677
1678         if (state->newkeys[MODE_IN] != NULL) {
1679                 enc  = &state->newkeys[MODE_IN]->enc;
1680                 mac  = &state->newkeys[MODE_IN]->mac;
1681                 comp = &state->newkeys[MODE_IN]->comp;
1682                 /* disable mac for authenticated encryption */
1683                 if ((authlen = cipher_authlen(enc->cipher)) != 0)
1684                         mac = NULL;
1685         }
1686         maclen = mac && mac->enabled ? mac->mac_len : 0;
1687         block_size = enc ? enc->block_size : 8;
1688         aadlen = (mac && mac->enabled && mac->etm) || authlen ? 4 : 0;
1689
1690         if (aadlen && state->packlen == 0) {
1691                 if (cipher_get_length(&state->receive_context,
1692                     &state->packlen, state->p_read.seqnr,
1693                     sshbuf_ptr(state->input), sshbuf_len(state->input)) != 0)
1694                         return 0;
1695                 if (state->packlen < 1 + 4 ||
1696                     state->packlen > PACKET_MAX_SIZE) {
1697 #ifdef PACKET_DEBUG
1698                         sshbuf_dump(state->input, stderr);
1699 #endif
1700                         logit("Bad packet length %u.", state->packlen);
1701                         if ((r = sshpkt_disconnect(ssh, "Packet corrupt")) != 0)
1702                                 return r;
1703                         return SSH_ERR_CONN_CORRUPT;
1704                 }
1705                 sshbuf_reset(state->incoming_packet);
1706         } else if (state->packlen == 0) {
1707                 /*
1708                  * check if input size is less than the cipher block size,
1709                  * decrypt first block and extract length of incoming packet
1710                  */
1711                 if (sshbuf_len(state->input) < block_size)
1712                         return 0;
1713                 sshbuf_reset(state->incoming_packet);
1714                 if ((r = sshbuf_reserve(state->incoming_packet, block_size,
1715                     &cp)) != 0)
1716                         goto out;
1717                 if ((r = cipher_crypt(&state->receive_context,
1718                     state->p_send.seqnr, cp, sshbuf_ptr(state->input),
1719                     block_size, 0, 0)) != 0)
1720                         goto out;
1721                 state->packlen = PEEK_U32(sshbuf_ptr(state->incoming_packet));
1722                 if (state->packlen < 1 + 4 ||
1723                     state->packlen > PACKET_MAX_SIZE) {
1724 #ifdef PACKET_DEBUG
1725                         fprintf(stderr, "input: \n");
1726                         sshbuf_dump(state->input, stderr);
1727                         fprintf(stderr, "incoming_packet: \n");
1728                         sshbuf_dump(state->incoming_packet, stderr);
1729 #endif
1730                         logit("Bad packet length %u.", state->packlen);
1731                         return ssh_packet_start_discard(ssh, enc, mac,
1732                             state->packlen, PACKET_MAX_SIZE);
1733                 }
1734                 if ((r = sshbuf_consume(state->input, block_size)) != 0)
1735                         goto out;
1736         }
1737         DBG(debug("input: packet len %u", state->packlen+4));
1738
1739         if (aadlen) {
1740                 /* only the payload is encrypted */
1741                 need = state->packlen;
1742         } else {
1743                 /*
1744                  * the payload size and the payload are encrypted, but we
1745                  * have a partial packet of block_size bytes
1746                  */
1747                 need = 4 + state->packlen - block_size;
1748         }
1749         DBG(debug("partial packet: block %d, need %d, maclen %d, authlen %d,"
1750             " aadlen %d", block_size, need, maclen, authlen, aadlen));
1751         if (need % block_size != 0) {
1752                 logit("padding error: need %d block %d mod %d",
1753                     need, block_size, need % block_size);
1754                 return ssh_packet_start_discard(ssh, enc, mac,
1755                     state->packlen, PACKET_MAX_SIZE - block_size);
1756         }
1757         /*
1758          * check if the entire packet has been received and
1759          * decrypt into incoming_packet:
1760          * 'aadlen' bytes are unencrypted, but authenticated.
1761          * 'need' bytes are encrypted, followed by either
1762          * 'authlen' bytes of authentication tag or
1763          * 'maclen' bytes of message authentication code.
1764          */
1765         if (sshbuf_len(state->input) < aadlen + need + authlen + maclen)
1766                 return 0;
1767 #ifdef PACKET_DEBUG
1768         fprintf(stderr, "read_poll enc/full: ");
1769         sshbuf_dump(state->input, stderr);
1770 #endif
1771         /* EtM: compute mac over encrypted input */
1772         if (mac && mac->enabled && mac->etm) {
1773                 if ((r = mac_compute(mac, state->p_read.seqnr,
1774                     sshbuf_ptr(state->input), aadlen + need,
1775                     macbuf, sizeof(macbuf))) != 0)
1776                         goto out;
1777         }
1778         if ((r = sshbuf_reserve(state->incoming_packet, aadlen + need,
1779             &cp)) != 0)
1780                 goto out;
1781         if ((r = cipher_crypt(&state->receive_context, state->p_read.seqnr, cp,
1782             sshbuf_ptr(state->input), need, aadlen, authlen)) != 0)
1783                 goto out;
1784         if ((r = sshbuf_consume(state->input, aadlen + need + authlen)) != 0)
1785                 goto out;
1786         /*
1787          * compute MAC over seqnr and packet,
1788          * increment sequence number for incoming packet
1789          */
1790         if (mac && mac->enabled) {
1791                 if (!mac->etm)
1792                         if ((r = mac_compute(mac, state->p_read.seqnr,
1793                             sshbuf_ptr(state->incoming_packet),
1794                             sshbuf_len(state->incoming_packet),
1795                             macbuf, sizeof(macbuf))) != 0)
1796                                 goto out;
1797                 if (timingsafe_bcmp(macbuf, sshbuf_ptr(state->input),
1798                     mac->mac_len) != 0) {
1799                         logit("Corrupted MAC on input.");
1800                         if (need > PACKET_MAX_SIZE)
1801                                 return SSH_ERR_INTERNAL_ERROR;
1802                         return ssh_packet_start_discard(ssh, enc, mac,
1803                             state->packlen, PACKET_MAX_SIZE - need);
1804                 }
1805
1806                 DBG(debug("MAC #%d ok", state->p_read.seqnr));
1807                 if ((r = sshbuf_consume(state->input, mac->mac_len)) != 0)
1808                         goto out;
1809         }
1810         if (seqnr_p != NULL)
1811                 *seqnr_p = state->p_read.seqnr;
1812         if (++state->p_read.seqnr == 0)
1813                 logit("incoming seqnr wraps around");
1814         if (++state->p_read.packets == 0)
1815                 if (!(ssh->compat & SSH_BUG_NOREKEY))
1816                         return SSH_ERR_NEED_REKEY;
1817         state->p_read.blocks += (state->packlen + 4) / block_size;
1818         state->p_read.bytes += state->packlen + 4;
1819
1820         /* get padlen */
1821         padlen = sshbuf_ptr(state->incoming_packet)[4];
1822         DBG(debug("input: padlen %d", padlen));
1823         if (padlen < 4) {
1824                 if ((r = sshpkt_disconnect(ssh,
1825                     "Corrupted padlen %d on input.", padlen)) != 0 ||
1826                     (r = ssh_packet_write_wait(ssh)) != 0)
1827                         return r;
1828                 return SSH_ERR_CONN_CORRUPT;
1829         }
1830
1831         /* skip packet size + padlen, discard padding */
1832         if ((r = sshbuf_consume(state->incoming_packet, 4 + 1)) != 0 ||
1833             ((r = sshbuf_consume_end(state->incoming_packet, padlen)) != 0))
1834                 goto out;
1835
1836         DBG(debug("input: len before de-compress %zd",
1837             sshbuf_len(state->incoming_packet)));
1838         if (comp && comp->enabled) {
1839                 sshbuf_reset(state->compression_buffer);
1840                 if ((r = uncompress_buffer(ssh, state->incoming_packet,
1841                     state->compression_buffer)) != 0)
1842                         goto out;
1843                 sshbuf_reset(state->incoming_packet);
1844                 if ((r = sshbuf_putb(state->incoming_packet,
1845                     state->compression_buffer)) != 0)
1846                         goto out;
1847                 DBG(debug("input: len after de-compress %zd",
1848                     sshbuf_len(state->incoming_packet)));
1849         }
1850         /*
1851          * get packet type, implies consume.
1852          * return length of payload (without type field)
1853          */
1854         if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
1855                 goto out;
1856         if (ssh_packet_log_type(*typep))
1857                 debug3("receive packet: type %u", *typep);
1858         if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) {
1859                 if ((r = sshpkt_disconnect(ssh,
1860                     "Invalid ssh2 packet type: %d", *typep)) != 0 ||
1861                     (r = ssh_packet_write_wait(ssh)) != 0)
1862                         return r;
1863                 return SSH_ERR_PROTOCOL_ERROR;
1864         }
1865         if (*typep == SSH2_MSG_NEWKEYS)
1866                 r = ssh_set_newkeys(ssh, MODE_IN);
1867         else if (*typep == SSH2_MSG_USERAUTH_SUCCESS && !state->server_side)
1868                 r = ssh_packet_enable_delayed_compress(ssh);
1869         else
1870                 r = 0;
1871 #ifdef PACKET_DEBUG
1872         fprintf(stderr, "read/plain[%d]:\r\n", *typep);
1873         sshbuf_dump(state->incoming_packet, stderr);
1874 #endif
1875         /* reset for next packet */
1876         state->packlen = 0;
1877
1878         /* do we need to rekey? */
1879         if (ssh_packet_need_rekeying(ssh, 0)) {
1880                 debug3("%s: rekex triggered", __func__);
1881                 if ((r = kex_start_rekex(ssh)) != 0)
1882                         return r;
1883         }
1884  out:
1885         return r;
1886 }
1887
1888 int
1889 ssh_packet_read_poll_seqnr(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1890 {
1891         struct session_state *state = ssh->state;
1892         u_int reason, seqnr;
1893         int r;
1894         u_char *msg;
1895
1896         for (;;) {
1897                 msg = NULL;
1898                 if (compat20) {
1899                         r = ssh_packet_read_poll2(ssh, typep, seqnr_p);
1900                         if (r != 0)
1901                                 return r;
1902                         if (*typep) {
1903                                 state->keep_alive_timeouts = 0;
1904                                 DBG(debug("received packet type %d", *typep));
1905                         }
1906                         switch (*typep) {
1907                         case SSH2_MSG_IGNORE:
1908                                 debug3("Received SSH2_MSG_IGNORE");
1909                                 break;
1910                         case SSH2_MSG_DEBUG:
1911                                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||
1912                                     (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
1913                                     (r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
1914                                         free(msg);
1915                                         return r;
1916                                 }
1917                                 debug("Remote: %.900s", msg);
1918                                 free(msg);
1919                                 break;
1920                         case SSH2_MSG_DISCONNECT:
1921                                 if ((r = sshpkt_get_u32(ssh, &reason)) != 0 ||
1922                                     (r = sshpkt_get_string(ssh, &msg, NULL)) != 0)
1923                                         return r;
1924                                 /* Ignore normal client exit notifications */
1925                                 do_log2(ssh->state->server_side &&
1926                                     reason == SSH2_DISCONNECT_BY_APPLICATION ?
1927                                     SYSLOG_LEVEL_INFO : SYSLOG_LEVEL_ERROR,
1928                                     "Received disconnect from %s port %d:"
1929                                     "%u: %.400s", ssh_remote_ipaddr(ssh),
1930                                     ssh_remote_port(ssh), reason, msg);
1931                                 free(msg);
1932                                 return SSH_ERR_DISCONNECTED;
1933                         case SSH2_MSG_UNIMPLEMENTED:
1934                                 if ((r = sshpkt_get_u32(ssh, &seqnr)) != 0)
1935                                         return r;
1936                                 debug("Received SSH2_MSG_UNIMPLEMENTED for %u",
1937                                     seqnr);
1938                                 break;
1939                         default:
1940                                 return 0;
1941                         }
1942                 } else {
1943                         r = ssh_packet_read_poll1(ssh, typep);
1944                         switch (*typep) {
1945                         case SSH_MSG_NONE:
1946                                 return SSH_MSG_NONE;
1947                         case SSH_MSG_IGNORE:
1948                                 break;
1949                         case SSH_MSG_DEBUG:
1950                                 if ((r = sshpkt_get_string(ssh, &msg, NULL)) != 0)
1951                                         return r;
1952                                 debug("Remote: %.900s", msg);
1953                                 free(msg);
1954                                 break;
1955                         case SSH_MSG_DISCONNECT:
1956                                 if ((r = sshpkt_get_string(ssh, &msg, NULL)) != 0)
1957                                         return r;
1958                                 logit("Received disconnect from %s port %d: "
1959                                     "%.400s", ssh_remote_ipaddr(ssh),
1960                                     ssh_remote_port(ssh), msg);
1961                                 free(msg);
1962                                 return SSH_ERR_DISCONNECTED;
1963                         default:
1964                                 DBG(debug("received packet type %d", *typep));
1965                                 return 0;
1966                         }
1967                 }
1968         }
1969 }
1970
1971 /*
1972  * Buffers the given amount of input characters.  This is intended to be used
1973  * together with packet_read_poll.
1974  */
1975
1976 int
1977 ssh_packet_process_incoming(struct ssh *ssh, const char *buf, u_int len)
1978 {
1979         struct session_state *state = ssh->state;
1980         int r;
1981
1982         if (state->packet_discard) {
1983                 state->keep_alive_timeouts = 0; /* ?? */
1984                 if (len >= state->packet_discard) {
1985                         if ((r = ssh_packet_stop_discard(ssh)) != 0)
1986                                 return r;
1987                 }
1988                 state->packet_discard -= len;
1989                 return 0;
1990         }
1991         if ((r = sshbuf_put(ssh->state->input, buf, len)) != 0)
1992                 return r;
1993
1994         return 0;
1995 }
1996
1997 int
1998 ssh_packet_remaining(struct ssh *ssh)
1999 {
2000         return sshbuf_len(ssh->state->incoming_packet);
2001 }
2002
2003 /*
2004  * Sends a diagnostic message from the server to the client.  This message
2005  * can be sent at any time (but not while constructing another message). The
2006  * message is printed immediately, but only if the client is being executed
2007  * in verbose mode.  These messages are primarily intended to ease debugging
2008  * authentication problems.   The length of the formatted message must not
2009  * exceed 1024 bytes.  This will automatically call ssh_packet_write_wait.
2010  */
2011 void
2012 ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
2013 {
2014         char buf[1024];
2015         va_list args;
2016         int r;
2017
2018         if (compat20 && (ssh->compat & SSH_BUG_DEBUG))
2019                 return;
2020
2021         va_start(args, fmt);
2022         vsnprintf(buf, sizeof(buf), fmt, args);
2023         va_end(args);
2024
2025         if (compat20) {
2026                 if ((r = sshpkt_start(ssh, SSH2_MSG_DEBUG)) != 0 ||
2027                     (r = sshpkt_put_u8(ssh, 0)) != 0 || /* always display */
2028                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
2029                     (r = sshpkt_put_cstring(ssh, "")) != 0 ||
2030                     (r = sshpkt_send(ssh)) != 0)
2031                         fatal("%s: %s", __func__, ssh_err(r));
2032         } else {
2033                 if ((r = sshpkt_start(ssh, SSH_MSG_DEBUG)) != 0 ||
2034                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
2035                     (r = sshpkt_send(ssh)) != 0)
2036                         fatal("%s: %s", __func__, ssh_err(r));
2037         }
2038         if ((r = ssh_packet_write_wait(ssh)) != 0)
2039                 fatal("%s: %s", __func__, ssh_err(r));
2040 }
2041
2042 /*
2043  * Pretty-print connection-terminating errors and exit.
2044  */
2045 void
2046 sshpkt_fatal(struct ssh *ssh, const char *tag, int r)
2047 {
2048         switch (r) {
2049         case SSH_ERR_CONN_CLOSED:
2050                 logit("Connection closed by %.200s port %d",
2051                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh));
2052                 cleanup_exit(255);
2053         case SSH_ERR_CONN_TIMEOUT:
2054                 logit("Connection %s %.200s port %d timed out",
2055                     ssh->state->server_side ? "from" : "to",
2056                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh));
2057                 cleanup_exit(255);
2058         case SSH_ERR_DISCONNECTED:
2059                 logit("Disconnected from %.200s port %d",
2060                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh));
2061                 cleanup_exit(255);
2062         case SSH_ERR_SYSTEM_ERROR:
2063                 if (errno == ECONNRESET) {
2064                         logit("Connection reset by %.200s port %d",
2065                             ssh_remote_ipaddr(ssh), ssh_remote_port(ssh));
2066                         cleanup_exit(255);
2067                 }
2068                 /* FALLTHROUGH */
2069         case SSH_ERR_NO_CIPHER_ALG_MATCH:
2070         case SSH_ERR_NO_MAC_ALG_MATCH:
2071         case SSH_ERR_NO_COMPRESS_ALG_MATCH:
2072         case SSH_ERR_NO_KEX_ALG_MATCH:
2073         case SSH_ERR_NO_HOSTKEY_ALG_MATCH:
2074                 if (ssh && ssh->kex && ssh->kex->failed_choice) {
2075                         BLACKLIST_NOTIFY(BLACKLIST_AUTH_FAIL);
2076                         fatal("Unable to negotiate with %.200s port %d: %s. "
2077                             "Their offer: %s", ssh_remote_ipaddr(ssh),
2078                             ssh_remote_port(ssh), ssh_err(r),
2079                             ssh->kex->failed_choice);
2080                 }
2081                 /* FALLTHROUGH */
2082         default:
2083                 fatal("%s%sConnection %s %.200s port %d: %s",
2084                     tag != NULL ? tag : "", tag != NULL ? ": " : "",
2085                     ssh->state->server_side ? "from" : "to",
2086                     ssh_remote_ipaddr(ssh), ssh_remote_port(ssh), ssh_err(r));
2087         }
2088 }
2089
/*
 * Log the formatted reason, send it to the peer in a disconnect message,
 * wait for it to be written, close the connection and exit; this function
 * never returns.  The reason is truncated to 1023 bytes (100 bytes in the
 * local log line) and should not contain a newline.  Send failures are
 * reported through sshpkt_fatal().
 */
void
ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...)
{
	static int disconnecting = 0;
	char reason[1024];
	va_list ap;
	int r;

	/* A failure while disconnecting must not re-enter this function. */
	if (disconnecting)
		fatal("packet_disconnect called recursively.");
	disconnecting = 1;

	va_start(ap, fmt);
	vsnprintf(reason, sizeof(reason), fmt, ap);
	va_end(ap);

	logit("Disconnecting: %.100s", reason);

	/* Queue the disconnect message, then flush it before closing. */
	if ((r = sshpkt_disconnect(ssh, "%s", reason)) != 0 ||
	    (r = ssh_packet_write_wait(ssh)) != 0)
		sshpkt_fatal(ssh, __func__, r);

	ssh_packet_close(ssh);
	cleanup_exit(255);
}
2133
2134 /*
2135  * Checks if there is any buffered output, and tries to write some of
2136  * the output.
2137  */
2138 int
2139 ssh_packet_write_poll(struct ssh *ssh)
2140 {
2141         struct session_state *state = ssh->state;
2142         int len = sshbuf_len(state->output);
2143         int r;
2144
2145         if (len > 0) {
2146                 len = write(state->connection_out,
2147                     sshbuf_ptr(state->output), len);
2148                 if (len == -1) {
2149                         if (errno == EINTR || errno == EAGAIN ||
2150                             errno == EWOULDBLOCK)
2151                                 return 0;
2152                         return SSH_ERR_SYSTEM_ERROR;
2153                 }
2154                 if (len == 0)
2155                         return SSH_ERR_CONN_CLOSED;
2156                 if ((r = sshbuf_consume(state->output, len)) != 0)
2157                         return r;
2158         }
2159         return 0;
2160 }
2161
2162 /*
2163  * Calls packet_write_poll repeatedly until all pending output data has been
2164  * written.
2165  */
2166 int
2167 ssh_packet_write_wait(struct ssh *ssh)
2168 {
2169         fd_set *setp;
2170         int ret, r, ms_remain = 0;
2171         struct timeval start, timeout, *timeoutp = NULL;
2172         struct session_state *state = ssh->state;
2173
2174         setp = calloc(howmany(state->connection_out + 1,
2175             NFDBITS), sizeof(fd_mask));
2176         if (setp == NULL)
2177                 return SSH_ERR_ALLOC_FAIL;
2178         if ((r = ssh_packet_write_poll(ssh)) != 0) {
2179                 free(setp);
2180                 return r;
2181         }
2182         while (ssh_packet_have_data_to_write(ssh)) {
2183                 memset(setp, 0, howmany(state->connection_out + 1,
2184                     NFDBITS) * sizeof(fd_mask));
2185                 FD_SET(state->connection_out, setp);
2186
2187                 if (state->packet_timeout_ms > 0) {
2188                         ms_remain = state->packet_timeout_ms;
2189                         timeoutp = &timeout;
2190                 }
2191                 for (;;) {
2192                         if (state->packet_timeout_ms != -1) {
2193                                 ms_to_timeval(&timeout, ms_remain);
2194                                 gettimeofday(&start, NULL);
2195                         }
2196                         if ((ret = select(state->connection_out + 1,
2197                             NULL, setp, NULL, timeoutp)) >= 0)
2198                                 break;
2199                         if (errno != EAGAIN && errno != EINTR &&
2200                             errno != EWOULDBLOCK)
2201                                 break;
2202                         if (state->packet_timeout_ms == -1)
2203                                 continue;
2204                         ms_subtract_diff(&start, &ms_remain);
2205                         if (ms_remain <= 0) {
2206                                 ret = 0;
2207                                 break;
2208                         }
2209                 }
2210                 if (ret == 0) {
2211                         free(setp);
2212                         return SSH_ERR_CONN_TIMEOUT;
2213                 }
2214                 if ((r = ssh_packet_write_poll(ssh)) != 0) {
2215                         free(setp);
2216                         return r;
2217                 }
2218         }
2219         free(setp);
2220         return 0;
2221 }
2222
2223 /* Returns true if there is buffered data to write to the connection. */
2224
2225 int
2226 ssh_packet_have_data_to_write(struct ssh *ssh)
2227 {
2228         return sshbuf_len(ssh->state->output) != 0;
2229 }
2230
2231 /* Returns true if there is not too much data to write to the connection. */
2232
2233 int
2234 ssh_packet_not_very_much_data_to_write(struct ssh *ssh)
2235 {
2236         if (ssh->state->interactive_mode)
2237                 return sshbuf_len(ssh->state->output) < 16384;
2238         else
2239                 return sshbuf_len(ssh->state->output) < 128 * 1024;
2240 }
2241
2242 void
2243 ssh_packet_set_tos(struct ssh *ssh, int tos)
2244 {
2245 #ifndef IP_TOS_IS_BROKEN
2246         if (!ssh_packet_connection_is_on_socket(ssh))
2247                 return;
2248         switch (ssh_packet_connection_af(ssh)) {
2249 # ifdef IP_TOS
2250         case AF_INET:
2251                 debug3("%s: set IP_TOS 0x%02x", __func__, tos);
2252                 if (setsockopt(ssh->state->connection_in,
2253                     IPPROTO_IP, IP_TOS, &tos, sizeof(tos)) < 0)
2254                         error("setsockopt IP_TOS %d: %.100s:",
2255                             tos, strerror(errno));
2256                 break;
2257 # endif /* IP_TOS */
2258 # ifdef IPV6_TCLASS
2259         case AF_INET6:
2260                 debug3("%s: set IPV6_TCLASS 0x%02x", __func__, tos);
2261                 if (setsockopt(ssh->state->connection_in,
2262                     IPPROTO_IPV6, IPV6_TCLASS, &tos, sizeof(tos)) < 0)
2263                         error("setsockopt IPV6_TCLASS %d: %.100s:",
2264                             tos, strerror(errno));
2265                 break;
2266 # endif /* IPV6_TCLASS */
2267         }
2268 #endif /* IP_TOS_IS_BROKEN */
2269 }
2270
2271 /* Informs that the current session is interactive.  Sets IP flags for that. */
2272
2273 void
2274 ssh_packet_set_interactive(struct ssh *ssh, int interactive, int qos_interactive, int qos_bulk)
2275 {
2276         struct session_state *state = ssh->state;
2277
2278         if (state->set_interactive_called)
2279                 return;
2280         state->set_interactive_called = 1;
2281
2282         /* Record that we are in interactive mode. */
2283         state->interactive_mode = interactive;
2284
2285         /* Only set socket options if using a socket.  */
2286         if (!ssh_packet_connection_is_on_socket(ssh))
2287                 return;
2288         set_nodelay(state->connection_in);
2289         ssh_packet_set_tos(ssh, interactive ? qos_interactive :
2290             qos_bulk);
2291 }
2292
2293 /* Returns true if the current connection is interactive. */
2294
2295 int
2296 ssh_packet_is_interactive(struct ssh *ssh)
2297 {
2298         return ssh->state->interactive_mode;
2299 }
2300
2301 int
2302 ssh_packet_set_maxsize(struct ssh *ssh, u_int s)
2303 {
2304         struct session_state *state = ssh->state;
2305
2306         if (state->set_maxsize_called) {
2307                 logit("packet_set_maxsize: called twice: old %d new %d",
2308                     state->max_packet_size, s);
2309                 return -1;
2310         }
2311         if (s < 4 * 1024 || s > 1024 * 1024) {
2312                 logit("packet_set_maxsize: bad size %d", s);
2313                 return -1;
2314         }
2315         state->set_maxsize_called = 1;
2316         debug("packet_set_maxsize: setting to %d", s);
2317         state->max_packet_size = s;
2318         return s;
2319 }
2320
2321 int
2322 ssh_packet_inc_alive_timeouts(struct ssh *ssh)
2323 {
2324         return ++ssh->state->keep_alive_timeouts;
2325 }
2326
2327 void
2328 ssh_packet_set_alive_timeouts(struct ssh *ssh, int ka)
2329 {
2330         ssh->state->keep_alive_timeouts = ka;
2331 }
2332
2333 u_int
2334 ssh_packet_get_maxsize(struct ssh *ssh)
2335 {
2336         return ssh->state->max_packet_size;
2337 }
2338
2339 /*
2340  * 9.2.  Ignored Data Message
2341  *
2342  *   byte      SSH_MSG_IGNORE
2343  *   string    data
2344  *
2345  * All implementations MUST understand (and ignore) this message at any
2346  * time (after receiving the protocol version). No implementation is
2347  * required to send them. This message can be used as an additional
2348  * protection measure against advanced traffic analysis techniques.
2349  */
2350 void
2351 ssh_packet_send_ignore(struct ssh *ssh, int nbytes)
2352 {
2353         u_int32_t rnd = 0;
2354         int r, i;
2355
2356         if ((r = sshpkt_start(ssh, compat20 ?
2357             SSH2_MSG_IGNORE : SSH_MSG_IGNORE)) != 0 ||
2358             (r = sshpkt_put_u32(ssh, nbytes)) != 0)
2359                 fatal("%s: %s", __func__, ssh_err(r));
2360         for (i = 0; i < nbytes; i++) {
2361                 if (i % 4 == 0)
2362                         rnd = arc4random();
2363                 if ((r = sshpkt_put_u8(ssh, (u_char)rnd & 0xff)) != 0)
2364                         fatal("%s: %s", __func__, ssh_err(r));
2365                 rnd >>= 8;
2366         }
2367 }
2368
2369 void
2370 ssh_packet_set_rekey_limits(struct ssh *ssh, u_int64_t bytes, time_t seconds)
2371 {
2372         debug3("rekey after %llu bytes, %d seconds", (unsigned long long)bytes,
2373             (int)seconds);
2374         ssh->state->rekey_limit = bytes;
2375         ssh->state->rekey_interval = seconds;
2376 }
2377
2378 time_t
2379 ssh_packet_get_rekey_timeout(struct ssh *ssh)
2380 {
2381         time_t seconds;
2382
2383         seconds = ssh->state->rekey_time + ssh->state->rekey_interval -
2384             monotime();
2385         return (seconds <= 0 ? 1 : seconds);
2386 }
2387
2388 void
2389 ssh_packet_set_server(struct ssh *ssh)
2390 {
2391         ssh->state->server_side = 1;
2392 }
2393
2394 void
2395 ssh_packet_set_authenticated(struct ssh *ssh)
2396 {
2397         ssh->state->after_authentication = 1;
2398 }
2399
2400 void *
2401 ssh_packet_get_input(struct ssh *ssh)
2402 {
2403         return (void *)ssh->state->input;
2404 }
2405
2406 void *
2407 ssh_packet_get_output(struct ssh *ssh)
2408 {
2409         return (void *)ssh->state->output;
2410 }
2411
2412 /* Reset after_authentication and reset compression in post-auth privsep */
2413 static int
2414 ssh_packet_set_postauth(struct ssh *ssh)
2415 {
2416         struct sshcomp *comp;
2417         int r, mode;
2418
2419         debug("%s: called", __func__);
2420         /* This was set in net child, but is not visible in user child */
2421         ssh->state->after_authentication = 1;
2422         ssh->state->rekeying = 0;
2423         for (mode = 0; mode < MODE_MAX; mode++) {
2424                 if (ssh->state->newkeys[mode] == NULL)
2425                         continue;
2426                 comp = &ssh->state->newkeys[mode]->comp;
2427                 if (comp && comp->enabled &&
2428                     (r = ssh_packet_init_compression(ssh)) != 0)
2429                         return r;
2430         }
2431         return 0;
2432 }
2433
2434 /* Packet state (de-)serialization for privsep */
2435
2436 /* turn kex into a blob for packet state serialization */
2437 static int
2438 kex_to_blob(struct sshbuf *m, struct kex *kex)
2439 {
2440         int r;
2441
2442         if ((r = sshbuf_put_string(m, kex->session_id,
2443             kex->session_id_len)) != 0 ||
2444             (r = sshbuf_put_u32(m, kex->we_need)) != 0 ||
2445             (r = sshbuf_put_u32(m, kex->hostkey_type)) != 0 ||
2446             (r = sshbuf_put_u32(m, kex->kex_type)) != 0 ||
2447             (r = sshbuf_put_stringb(m, kex->my)) != 0 ||
2448             (r = sshbuf_put_stringb(m, kex->peer)) != 0 ||
2449             (r = sshbuf_put_u32(m, kex->flags)) != 0 ||
2450             (r = sshbuf_put_cstring(m, kex->client_version_string)) != 0 ||
2451             (r = sshbuf_put_cstring(m, kex->server_version_string)) != 0)
2452                 return r;
2453         return 0;
2454 }
2455
/*
 * Turn key exchange results into a blob for packet state serialization.
 *
 * Serializes ssh->state->newkeys[mode] into a temporary buffer and appends
 * it to m as a single length-prefixed string.  The inner layout, which
 * newkeys_from_blob() parses, is:
 *
 *   enc:  name, raw cipher pointer bytes, enabled, block_size, key, iv
 *   mac:  name, enabled, key    (only when cipher_authlen() == 0)
 *   comp: type, enabled, name
 *
 * Returns 0 on success, SSH_ERR_INTERNAL_ERROR if no new keys are set for
 * mode, SSH_ERR_ALLOC_FAIL, or an error from cipher_get_keyiv()/sshbuf.
 */
static int
newkeys_to_blob(struct sshbuf *m, struct ssh *ssh, int mode)
{
	struct sshbuf *b;
	struct sshcipher_ctx *cc;
	struct sshcomp *comp;
	struct sshenc *enc;
	struct sshmac *mac;
	struct newkeys *newkey;
	int r;

	if ((newkey = ssh->state->newkeys[mode]) == NULL)
		return SSH_ERR_INTERNAL_ERROR;
	enc = &newkey->enc;
	mac = &newkey->mac;
	comp = &newkey->comp;
	/* Select the live cipher context for this direction. */
	cc = (mode == MODE_OUT) ? &ssh->state->send_context :
	    &ssh->state->receive_context;
	/* Fill enc->iv from the live context so the exported IV is current. */
	if ((r = cipher_get_keyiv(cc, enc->iv, enc->iv_len)) != 0)
		return r;
	if ((b = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	/*
	 * The cipher struct is constant and shared, you export pointer.
	 * newkeys_from_blob() checks the pointer against cipher_by_name().
	 * NOTE(review): this relies on the reader having the same address
	 * layout as this process (e.g. a forked privsep child); confirm.
	 */
	if ((r = sshbuf_put_cstring(b, enc->name)) != 0 ||
	    (r = sshbuf_put(b, &enc->cipher, sizeof(enc->cipher))) != 0 ||
	    (r = sshbuf_put_u32(b, enc->enabled)) != 0 ||
	    (r = sshbuf_put_u32(b, enc->block_size)) != 0 ||
	    (r = sshbuf_put_string(b, enc->key, enc->key_len)) != 0 ||
	    (r = sshbuf_put_string(b, enc->iv, enc->iv_len)) != 0)
		goto out;
	/*
	 * MAC parameters are serialized only when cipher_authlen() is 0.
	 * Presumably a nonzero value means the cipher authenticates itself.
	 */
	if (cipher_authlen(enc->cipher) == 0) {
		if ((r = sshbuf_put_cstring(b, mac->name)) != 0 ||
		    (r = sshbuf_put_u32(b, mac->enabled)) != 0 ||
		    (r = sshbuf_put_string(b, mac->key, mac->key_len)) != 0)
			goto out;
	}
	if ((r = sshbuf_put_u32(b, comp->type)) != 0 ||
	    (r = sshbuf_put_u32(b, comp->enabled)) != 0 ||
	    (r = sshbuf_put_cstring(b, comp->name)) != 0)
		goto out;
	/* Wrap the whole inner buffer as one string in m. */
	r = sshbuf_put_stringb(m, b);
 out:
	sshbuf_free(b);
	return r;
}
2502
/*
 * Serialize packet state into a blob.
 *
 * Appends the packet-layer state of ssh to m, in the layout that
 * ssh_packet_set_state() parses:
 *   protocol 1 (!compat20): remote protocol flags, cipher number, session
 *       key, then the send and receive IV state (each as u32 length +
 *       bytes, i.e. sshbuf string format);
 *   protocol 2: kex blob, MODE_OUT and MODE_IN newkeys blobs, rekey limit
 *       and interval, then seqnr/blocks/packets/bytes for the send and
 *       read directions;
 *   both: send and receive cipher key contexts (u32 length + bytes),
 *       compression state, and the input and output buffers.
 *
 * Returns 0 on success, SSH_ERR_INTERNAL_ERROR if cipher_get_keycontext()
 * does not copy the length it reported, or the failing sshbuf/cipher error.
 */
int
ssh_packet_get_state(struct ssh *ssh, struct sshbuf *m)
{
	struct session_state *state = ssh->state;
	u_char *p;
	size_t slen, rlen;
	int r, ssh1cipher;

	if (!compat20) {
		/* Protocol 1 keying: cipher number, session key, IVs. */
		ssh1cipher = cipher_get_number(state->receive_context.cipher);
		slen = cipher_get_keyiv_len(&state->send_context);
		rlen = cipher_get_keyiv_len(&state->receive_context);
		/* IVs are copied straight into space reserved in m. */
		if ((r = sshbuf_put_u32(m, state->remote_protocol_flags)) != 0 ||
		    (r = sshbuf_put_u32(m, ssh1cipher)) != 0 ||
		    (r = sshbuf_put_string(m, state->ssh1_key, state->ssh1_keylen)) != 0 ||
		    (r = sshbuf_put_u32(m, slen)) != 0 ||
		    (r = sshbuf_reserve(m, slen, &p)) != 0 ||
		    (r = cipher_get_keyiv(&state->send_context, p, slen)) != 0 ||
		    (r = sshbuf_put_u32(m, rlen)) != 0 ||
		    (r = sshbuf_reserve(m, rlen, &p)) != 0 ||
		    (r = cipher_get_keyiv(&state->receive_context, p, rlen)) != 0)
			return r;
	} else {
		/* Protocol 2: kex, per-direction keys, rekey limits, counters. */
		if ((r = kex_to_blob(m, ssh->kex)) != 0 ||
		    (r = newkeys_to_blob(m, ssh, MODE_OUT)) != 0 ||
		    (r = newkeys_to_blob(m, ssh, MODE_IN)) != 0 ||
		    (r = sshbuf_put_u64(m, state->rekey_limit)) != 0 ||
		    (r = sshbuf_put_u32(m, state->rekey_interval)) != 0 ||
		    (r = sshbuf_put_u32(m, state->p_send.seqnr)) != 0 ||
		    (r = sshbuf_put_u64(m, state->p_send.blocks)) != 0 ||
		    (r = sshbuf_put_u32(m, state->p_send.packets)) != 0 ||
		    (r = sshbuf_put_u64(m, state->p_send.bytes)) != 0 ||
		    (r = sshbuf_put_u32(m, state->p_read.seqnr)) != 0 ||
		    (r = sshbuf_put_u64(m, state->p_read.blocks)) != 0 ||
		    (r = sshbuf_put_u32(m, state->p_read.packets)) != 0 ||
		    (r = sshbuf_put_u64(m, state->p_read.bytes)) != 0)
			return r;
	}

	/*
	 * Cipher key contexts: query each length with a NULL destination,
	 * reserve that much space in m, then copy and verify the length.
	 */
	slen = cipher_get_keycontext(&state->send_context, NULL);
	rlen = cipher_get_keycontext(&state->receive_context, NULL);
	if ((r = sshbuf_put_u32(m, slen)) != 0 ||
	    (r = sshbuf_reserve(m, slen, &p)) != 0)
		return r;
	if (cipher_get_keycontext(&state->send_context, p) != (int)slen)
		return SSH_ERR_INTERNAL_ERROR;
	if ((r = sshbuf_put_u32(m, rlen)) != 0 ||
	    (r = sshbuf_reserve(m, rlen, &p)) != 0)
		return r;
	if (cipher_get_keycontext(&state->receive_context, p) != (int)rlen)
		return SSH_ERR_INTERNAL_ERROR;

	/* Compression state, then the input and output buffer contents. */
	if ((r = ssh_packet_get_compress_state(m, ssh)) != 0 ||
	    (r = sshbuf_put_stringb(m, state->input)) != 0 ||
	    (r = sshbuf_put_stringb(m, state->output)) != 0)
		return r;

	return 0;
}
2563
2564 /* restore key exchange results from blob for packet state de-serialization */
2565 static int
2566 newkeys_from_blob(struct sshbuf *m, struct ssh *ssh, int mode)
2567 {
2568         struct sshbuf *b = NULL;
2569         struct sshcomp *comp;
2570         struct sshenc *enc;
2571         struct sshmac *mac;
2572         struct newkeys *newkey = NULL;
2573         size_t keylen, ivlen, maclen;
2574         int r;
2575
2576         if ((newkey = calloc(1, sizeof(*newkey))) == NULL) {
2577                 r = SSH_ERR_ALLOC_FAIL;
2578                 goto out;
2579         }
2580         if ((r = sshbuf_froms(m, &b)) != 0)
2581                 goto out;
2582 #ifdef DEBUG_PK
2583         sshbuf_dump(b, stderr);
2584 #endif
2585         enc = &newkey->enc;
2586         mac = &newkey->mac;
2587         comp = &newkey->comp;
2588
2589         if ((r = sshbuf_get_cstring(b, &enc->name, NULL)) != 0 ||
2590             (r = sshbuf_get(b, &enc->cipher, sizeof(enc->cipher))) != 0 ||
2591             (r = sshbuf_get_u32(b, (u_int *)&enc->enabled)) != 0 ||
2592             (r = sshbuf_get_u32(b, &enc->block_size)) != 0 ||
2593             (r = sshbuf_get_string(b, &enc->key, &keylen)) != 0 ||
2594             (r = sshbuf_get_string(b, &enc->iv, &ivlen)) != 0)
2595                 goto out;
2596         if (cipher_authlen(enc->cipher) == 0) {
2597                 if ((r = sshbuf_get_cstring(b, &mac->name, NULL)) != 0)
2598                         goto out;
2599                 if ((r = mac_setup(mac, mac->name)) != 0)
2600                         goto out;
2601                 if ((r = sshbuf_get_u32(b, (u_int *)&mac->enabled)) != 0 ||
2602                     (r = sshbuf_get_string(b, &mac->key, &maclen)) != 0)
2603                         goto out;
2604                 if (maclen > mac->key_len) {
2605                         r = SSH_ERR_INVALID_FORMAT;
2606                         goto out;
2607                 }
2608                 mac->key_len = maclen;
2609         }
2610         if ((r = sshbuf_get_u32(b, &comp->type)) != 0 ||
2611             (r = sshbuf_get_u32(b, (u_int *)&comp->enabled)) != 0 ||
2612             (r = sshbuf_get_cstring(b, &comp->name, NULL)) != 0)
2613                 goto out;
2614         if (enc->name == NULL ||
2615             cipher_by_name(enc->name) != enc->cipher) {
2616                 r = SSH_ERR_INVALID_FORMAT;
2617                 goto out;
2618         }
2619         if (sshbuf_len(b) != 0) {
2620                 r = SSH_ERR_INVALID_FORMAT;
2621                 goto out;
2622         }
2623         enc->key_len = keylen;
2624         enc->iv_len = ivlen;
2625         ssh->kex->newkeys[mode] = newkey;
2626         newkey = NULL;
2627         r = 0;
2628  out:
2629         free(newkey);
2630         sshbuf_free(b);
2631         return r;
2632 }
2633
2634 /* restore kex from blob for packet state de-serialization */
2635 static int
2636 kex_from_blob(struct sshbuf *m, struct kex **kexp)
2637 {
2638         struct kex *kex;
2639         int r;
2640
2641         if ((kex = calloc(1, sizeof(struct kex))) == NULL ||
2642             (kex->my = sshbuf_new()) == NULL ||
2643             (kex->peer = sshbuf_new()) == NULL) {
2644                 r = SSH_ERR_ALLOC_FAIL;
2645                 goto out;
2646         }
2647         if ((r = sshbuf_get_string(m, &kex->session_id, &kex->session_id_len)) != 0 ||
2648             (r = sshbuf_get_u32(m, &kex->we_need)) != 0 ||
2649             (r = sshbuf_get_u32(m, (u_int *)&kex->hostkey_type)) != 0 ||
2650             (r = sshbuf_get_u32(m, &kex->kex_type)) != 0 ||
2651             (r = sshbuf_get_stringb(m, kex->my)) != 0 ||
2652             (r = sshbuf_get_stringb(m, kex->peer)) != 0 ||
2653             (r = sshbuf_get_u32(m, &kex->flags)) != 0 ||
2654             (r = sshbuf_get_cstring(m, &kex->client_version_string, NULL)) != 0 ||
2655             (r = sshbuf_get_cstring(m, &kex->server_version_string, NULL)) != 0)
2656                 goto out;
2657         kex->server = 1;
2658         kex->done = 1;
2659         r = 0;
2660  out:
2661         if (r != 0 || kexp == NULL) {
2662                 if (kex != NULL) {
2663                         sshbuf_free(kex->my);
2664                         sshbuf_free(kex->peer);
2665                         free(kex);
2666                 }
2667                 if (kexp != NULL)
2668                         *kexp = NULL;
2669         } else {
2670                 *kexp = kex;
2671         }
2672         return r;
2673 }
2674
/*
 * Restore packet state from content of blob 'm' (de-serialization).
 * Note that 'm' will be partially consumed on parsing or any other errors.
 *
 * The expected layout is the one produced by ssh_packet_get_state():
 *   protocol 1 (!compat20): remote protocol flags, cipher number, session
 *       key, then the send and receive IVs as strings;
 *   protocol 2: kex blob, MODE_OUT and MODE_IN newkeys blobs, rekey limit
 *       and interval, then seqnr/blocks/packets/bytes for send and read;
 *   both: send and receive cipher key contexts, compression state, and
 *       the input and output buffer contents.
 * Trailing data after the output buffer is rejected.
 *
 * Returns 0 on success, SSH_ERR_KEY_UNKNOWN_CIPHER for a protocol 1
 * cipher number above INT_MAX, SSH_ERR_INVALID_FORMAT on a length
 * mismatch or trailing data, or the error from the failing call.
 */
int
ssh_packet_set_state(struct ssh *ssh, struct sshbuf *m)
{
	struct session_state *state = ssh->state;
	const u_char *ssh1key, *ivin, *ivout, *keyin, *keyout, *input, *output;
	size_t ssh1keylen, rlen, slen, ilen, olen;
	int r;
	u_int ssh1cipher = 0;

	if (!compat20) {
		/* Protocol 1; the _direct getters yield pointers into 'm'. */
		if ((r = sshbuf_get_u32(m, &state->remote_protocol_flags)) != 0 ||
		    (r = sshbuf_get_u32(m, &ssh1cipher)) != 0 ||
		    (r = sshbuf_get_string_direct(m, &ssh1key, &ssh1keylen)) != 0 ||
		    (r = sshbuf_get_string_direct(m, &ivout, &slen)) != 0 ||
		    (r = sshbuf_get_string_direct(m, &ivin, &rlen)) != 0)
			return r;
		/* Guard the (int) conversion of the cipher number below. */
		if (ssh1cipher > INT_MAX)
			return SSH_ERR_KEY_UNKNOWN_CIPHER;
		ssh_packet_set_encryption_key(ssh, ssh1key, ssh1keylen,
		    (int)ssh1cipher);
		/* Serialized IV lengths must match the newly keyed contexts. */
		if (cipher_get_keyiv_len(&state->send_context) != (int)slen ||
		    cipher_get_keyiv_len(&state->receive_context) != (int)rlen)
			return SSH_ERR_INVALID_FORMAT;
		if ((r = cipher_set_keyiv(&state->send_context, ivout)) != 0 ||
		    (r = cipher_set_keyiv(&state->receive_context, ivin)) != 0)
			return r;
	} else {
		/* Protocol 2: kex, per-direction keys, rekey limits, counters. */
		if ((r = kex_from_blob(m, &ssh->kex)) != 0 ||
		    (r = newkeys_from_blob(m, ssh, MODE_OUT)) != 0 ||
		    (r = newkeys_from_blob(m, ssh, MODE_IN)) != 0 ||
		    (r = sshbuf_get_u64(m, &state->rekey_limit)) != 0 ||
		    (r = sshbuf_get_u32(m, &state->rekey_interval)) != 0 ||
		    (r = sshbuf_get_u32(m, &state->p_send.seqnr)) != 0 ||
		    (r = sshbuf_get_u64(m, &state->p_send.blocks)) != 0 ||
		    (r = sshbuf_get_u32(m, &state->p_send.packets)) != 0 ||
		    (r = sshbuf_get_u64(m, &state->p_send.bytes)) != 0 ||
		    (r = sshbuf_get_u32(m, &state->p_read.seqnr)) != 0 ||
		    (r = sshbuf_get_u64(m, &state->p_read.blocks)) != 0 ||
		    (r = sshbuf_get_u32(m, &state->p_read.packets)) != 0 ||
		    (r = sshbuf_get_u64(m, &state->p_read.bytes)) != 0)
			return r;
		/*
		 * We set the time here so that in post-auth privsep slave we
		 * count from the completion of the authentication.
		 */
		state->rekey_time = monotime();
		/* XXX ssh_set_newkeys overrides p_read.packets? XXX */
		if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0 ||
		    (r = ssh_set_newkeys(ssh, MODE_OUT)) != 0)
			return r;
	}
	/*
	 * Both protocols: cipher key contexts.  Their sizes must match what
	 * the just-installed ciphers report.
	 */
	if ((r = sshbuf_get_string_direct(m, &keyout, &slen)) != 0 ||
	    (r = sshbuf_get_string_direct(m, &keyin, &rlen)) != 0)
		return r;
	if (cipher_get_keycontext(&state->send_context, NULL) != (int)slen ||
	    cipher_get_keycontext(&state->receive_context, NULL) != (int)rlen)
		return SSH_ERR_INVALID_FORMAT;
	/* NOTE(review): any result of cipher_set_keycontext() is ignored. */
	cipher_set_keycontext(&state->send_context, keyout);
	cipher_set_keycontext(&state->receive_context, keyin);

	/* Compression state, then post-auth flags and compression setup. */
	if ((r = ssh_packet_set_compress_state(ssh, m)) != 0 ||
	    (r = ssh_packet_set_postauth(ssh)) != 0)
		return r;

	/* Replace the input/output buffers with the serialized contents. */
	sshbuf_reset(state->input);
	sshbuf_reset(state->output);
	if ((r = sshbuf_get_string_direct(m, &input, &ilen)) != 0 ||
	    (r = sshbuf_get_string_direct(m, &output, &olen)) != 0 ||
	    (r = sshbuf_put(state->input, input, ilen)) != 0 ||
	    (r = sshbuf_put(state->output, output, olen)) != 0)
		return r;

	/* The blob must have been consumed exactly. */
	if (sshbuf_len(m))
		return SSH_ERR_INVALID_FORMAT;
	debug3("%s: done", __func__);
	return 0;
}
2756
2757 /* NEW API */
2758
2759 /* put data to the outgoing packet */
2760
2761 int
2762 sshpkt_put(struct ssh *ssh, const void *v, size_t len)
2763 {
2764         return sshbuf_put(ssh->state->outgoing_packet, v, len);
2765 }
2766
2767 int
2768 sshpkt_putb(struct ssh *ssh, const struct sshbuf *b)
2769 {
2770         return sshbuf_putb(ssh->state->outgoing_packet, b);
2771 }
2772
2773 int
2774 sshpkt_put_u8(struct ssh *ssh, u_char val)
2775 {
2776         return sshbuf_put_u8(ssh->state->outgoing_packet, val);
2777 }
2778
2779 int
2780 sshpkt_put_u32(struct ssh *ssh, u_int32_t val)
2781 {
2782         return sshbuf_put_u32(ssh->state->outgoing_packet, val);
2783 }
2784
2785 int
2786 sshpkt_put_u64(struct ssh *ssh, u_int64_t val)
2787 {
2788         return sshbuf_put_u64(ssh->state->outgoing_packet, val);
2789 }
2790
2791 int
2792 sshpkt_put_string(struct ssh *ssh, const void *v, size_t len)
2793 {
2794         return sshbuf_put_string(ssh->state->outgoing_packet, v, len);
2795 }
2796
2797 int
2798 sshpkt_put_cstring(struct ssh *ssh, const void *v)
2799 {
2800         return sshbuf_put_cstring(ssh->state->outgoing_packet, v);
2801 }
2802
2803 int
2804 sshpkt_put_stringb(struct ssh *ssh, const struct sshbuf *v)
2805 {
2806         return sshbuf_put_stringb(ssh->state->outgoing_packet, v);
2807 }
2808
2809 #ifdef WITH_OPENSSL
2810 #ifdef OPENSSL_HAS_ECC
2811 int
2812 sshpkt_put_ec(struct ssh *ssh, const EC_POINT *v, const EC_GROUP *g)
2813 {
2814         return sshbuf_put_ec(ssh->state->outgoing_packet, v, g);
2815 }
2816 #endif /* OPENSSL_HAS_ECC */
2817
2818 #ifdef WITH_SSH1
2819 int
2820 sshpkt_put_bignum1(struct ssh *ssh, const BIGNUM *v)
2821 {
2822         return sshbuf_put_bignum1(ssh->state->outgoing_packet, v);
2823 }
2824 #endif /* WITH_SSH1 */
2825
2826 int
2827 sshpkt_put_bignum2(struct ssh *ssh, const BIGNUM *v)
2828 {
2829         return sshbuf_put_bignum2(ssh->state->outgoing_packet, v);
2830 }
2831 #endif /* WITH_OPENSSL */
2832
2833 /* fetch data from the incoming packet */
2834
2835 int
2836 sshpkt_get(struct ssh *ssh, void *valp, size_t len)
2837 {
2838         return sshbuf_get(ssh->state->incoming_packet, valp, len);
2839 }
2840
2841 int
2842 sshpkt_get_u8(struct ssh *ssh, u_char *valp)
2843 {
2844         return sshbuf_get_u8(ssh->state->incoming_packet, valp);
2845 }
2846
2847 int
2848 sshpkt_get_u32(struct ssh *ssh, u_int32_t *valp)
2849 {
2850         return sshbuf_get_u32(ssh->state->incoming_packet, valp);
2851 }
2852
2853 int
2854 sshpkt_get_u64(struct ssh *ssh, u_int64_t *valp)
2855 {
2856         return sshbuf_get_u64(ssh->state->incoming_packet, valp);
2857 }
2858
2859 int
2860 sshpkt_get_string(struct ssh *ssh, u_char **valp, size_t *lenp)
2861 {
2862         return sshbuf_get_string(ssh->state->incoming_packet, valp, lenp);
2863 }
2864
2865 int
2866 sshpkt_get_string_direct(struct ssh *ssh, const u_char **valp, size_t *lenp)
2867 {
2868         return sshbuf_get_string_direct(ssh->state->incoming_packet, valp, lenp);
2869 }
2870
2871 int
2872 sshpkt_get_cstring(struct ssh *ssh, char **valp, size_t *lenp)
2873 {
2874         return sshbuf_get_cstring(ssh->state->incoming_packet, valp, lenp);
2875 }
2876
2877 #ifdef WITH_OPENSSL
2878 #ifdef OPENSSL_HAS_ECC
2879 int
2880 sshpkt_get_ec(struct ssh *ssh, EC_POINT *v, const EC_GROUP *g)
2881 {
2882         return sshbuf_get_ec(ssh->state->incoming_packet, v, g);
2883 }
2884 #endif /* OPENSSL_HAS_ECC */
2885
2886 #ifdef WITH_SSH1
2887 int
2888 sshpkt_get_bignum1(struct ssh *ssh, BIGNUM *v)
2889 {
2890         return sshbuf_get_bignum1(ssh->state->incoming_packet, v);
2891 }
2892 #endif /* WITH_SSH1 */
2893
2894 int
2895 sshpkt_get_bignum2(struct ssh *ssh, BIGNUM *v)
2896 {
2897         return sshbuf_get_bignum2(ssh->state->incoming_packet, v);
2898 }
2899 #endif /* WITH_OPENSSL */
2900
2901 int
2902 sshpkt_get_end(struct ssh *ssh)
2903 {
2904         if (sshbuf_len(ssh->state->incoming_packet) > 0)
2905                 return SSH_ERR_UNEXPECTED_TRAILING_DATA;
2906         return 0;
2907 }
2908
2909 const u_char *
2910 sshpkt_ptr(struct ssh *ssh, size_t *lenp)
2911 {
2912         if (lenp != NULL)
2913                 *lenp = sshbuf_len(ssh->state->incoming_packet);
2914         return sshbuf_ptr(ssh->state->incoming_packet);
2915 }
2916
2917 /* start a new packet */
2918
2919 int
2920 sshpkt_start(struct ssh *ssh, u_char type)
2921 {
2922         u_char buf[9];
2923         int len;
2924
2925         DBG(debug("packet_start[%d]", type));
2926         len = compat20 ? 6 : 9;
2927         memset(buf, 0, len - 1);
2928         buf[len - 1] = type;
2929         sshbuf_reset(ssh->state->outgoing_packet);
2930         return sshbuf_put(ssh->state->outgoing_packet, buf, len);
2931 }
2932
2933 /* send it */
2934
2935 int
2936 sshpkt_send(struct ssh *ssh)
2937 {
2938         if (compat20)
2939                 return ssh_packet_send2(ssh);
2940         else
2941                 return ssh_packet_send1(ssh);
2942 }
2943
2944 int
2945 sshpkt_disconnect(struct ssh *ssh, const char *fmt,...)
2946 {
2947         char buf[1024];
2948         va_list args;
2949         int r;
2950
2951         va_start(args, fmt);
2952         vsnprintf(buf, sizeof(buf), fmt, args);
2953         va_end(args);
2954
2955         if (compat20) {
2956                 if ((r = sshpkt_start(ssh, SSH2_MSG_DISCONNECT)) != 0 ||
2957                     (r = sshpkt_put_u32(ssh, SSH2_DISCONNECT_PROTOCOL_ERROR)) != 0 ||
2958                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
2959                     (r = sshpkt_put_cstring(ssh, "")) != 0 ||
2960                     (r = sshpkt_send(ssh)) != 0)
2961                         return r;
2962         } else {
2963                 if ((r = sshpkt_start(ssh, SSH_MSG_DISCONNECT)) != 0 ||
2964                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
2965                     (r = sshpkt_send(ssh)) != 0)
2966                         return r;
2967         }
2968         return 0;
2969 }
2970
2971 /* roundup current message to pad bytes */
2972 int
2973 sshpkt_add_padding(struct ssh *ssh, u_char pad)
2974 {
2975         ssh->state->extra_pad = pad;
2976         return 0;
2977 }