]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - packet.c
Vendor import of OpenSSH 6.9p1.
[FreeBSD/FreeBSD.git] / packet.c
1 /* $OpenBSD: packet.c,v 1.212 2015/05/01 07:10:01 djm Exp $ */
2 /*
3  * Author: Tatu Ylonen <ylo@cs.hut.fi>
4  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
5  *                    All rights reserved
6  * This file contains code implementing the packet protocol and communication
7  * with the other side.  This same code is used both on client and server side.
8  *
9  * As far as I am concerned, the code I have written for this software
10  * can be used freely for any purpose.  Any derived versions of this
11  * software must be clearly marked as such, and if the derived work is
12  * incompatible with the protocol description in the RFC file, it must be
13  * called by a name other than "ssh" or "Secure Shell".
14  *
15  *
16  * SSH2 packet format added by Markus Friedl.
17  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
18  *
19  * Redistribution and use in source and binary forms, with or without
20  * modification, are permitted provided that the following conditions
21  * are met:
22  * 1. Redistributions of source code must retain the above copyright
23  *    notice, this list of conditions and the following disclaimer.
24  * 2. Redistributions in binary form must reproduce the above copyright
25  *    notice, this list of conditions and the following disclaimer in the
26  *    documentation and/or other materials provided with the distribution.
27  *
28  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
29  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
30  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
31  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
32  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
33  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
34  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
35  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
36  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
37  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38  */
39
40 #include "includes.h"
41  
42 #include <sys/param.h>  /* MIN roundup */
43 #include <sys/types.h>
44 #include "openbsd-compat/sys-queue.h"
45 #include <sys/socket.h>
46 #ifdef HAVE_SYS_TIME_H
47 # include <sys/time.h>
48 #endif
49
50 #include <netinet/in.h>
51 #include <netinet/ip.h>
52 #include <arpa/inet.h>
53
54 #include <errno.h>
55 #include <stdarg.h>
56 #include <stdio.h>
57 #include <stdlib.h>
58 #include <string.h>
59 #include <unistd.h>
60 #include <limits.h>
61 #include <signal.h>
62 #include <time.h>
63
64 #include <zlib.h>
65
66 #include "buffer.h"     /* typedefs XXX */
67 #include "key.h"        /* typedefs XXX */
68
69 #include "xmalloc.h"
70 #include "crc32.h"
71 #include "deattack.h"
72 #include "compat.h"
73 #include "ssh1.h"
74 #include "ssh2.h"
75 #include "cipher.h"
76 #include "sshkey.h"
77 #include "kex.h"
78 #include "digest.h"
79 #include "mac.h"
80 #include "log.h"
81 #include "canohost.h"
82 #include "misc.h"
83 #include "channels.h"
84 #include "ssh.h"
85 #include "packet.h"
86 #include "roaming.h"
87 #include "ssherr.h"
88 #include "sshbuf.h"
89
90 #ifdef PACKET_DEBUG
91 #define DBG(x) x
92 #else
93 #define DBG(x)
94 #endif
95
96 #define PACKET_MAX_SIZE (256 * 1024)
97
98 struct packet_state {
99         u_int32_t seqnr;
100         u_int32_t packets;
101         u_int64_t blocks;
102         u_int64_t bytes;
103 };
104
105 struct packet {
106         TAILQ_ENTRY(packet) next;
107         u_char type;
108         struct sshbuf *payload;
109 };
110
111 struct session_state {
112         /*
113          * This variable contains the file descriptors used for
114          * communicating with the other side.  connection_in is used for
115          * reading; connection_out for writing.  These can be the same
116          * descriptor, in which case it is assumed to be a socket.
117          */
118         int connection_in;
119         int connection_out;
120
121         /* Protocol flags for the remote side. */
122         u_int remote_protocol_flags;
123
124         /* Encryption context for receiving data.  Only used for decryption. */
125         struct sshcipher_ctx receive_context;
126
127         /* Encryption context for sending data.  Only used for encryption. */
128         struct sshcipher_ctx send_context;
129
130         /* Buffer for raw input data from the socket. */
131         struct sshbuf *input;
132
133         /* Buffer for raw output data going to the socket. */
134         struct sshbuf *output;
135
136         /* Buffer for the partial outgoing packet being constructed. */
137         struct sshbuf *outgoing_packet;
138
139         /* Buffer for the incoming packet currently being processed. */
140         struct sshbuf *incoming_packet;
141
142         /* Scratch buffer for packet compression/decompression. */
143         struct sshbuf *compression_buffer;
144
145         /* Incoming/outgoing compression dictionaries */
146         z_stream compression_in_stream;
147         z_stream compression_out_stream;
148         int compression_in_started;
149         int compression_out_started;
150         int compression_in_failures;
151         int compression_out_failures;
152
153         /*
154          * Flag indicating whether packet compression/decompression is
155          * enabled.
156          */
157         int packet_compression;
158
159         /* default maximum packet size */
160         u_int max_packet_size;
161
162         /* Flag indicating whether this module has been initialized. */
163         int initialized;
164
165         /* Set to true if the connection is interactive. */
166         int interactive_mode;
167
168         /* Set to true if we are the server side. */
169         int server_side;
170
171         /* Set to true if we are authenticated. */
172         int after_authentication;
173
174         int keep_alive_timeouts;
175
176         /* The maximum time that we will wait to send or receive a packet */
177         int packet_timeout_ms;
178
179         /* Session key information for Encryption and MAC */
180         struct newkeys *newkeys[MODE_MAX];
181         struct packet_state p_read, p_send;
182
183         /* Volume-based rekeying */
184         u_int64_t max_blocks_in, max_blocks_out;
185         u_int32_t rekey_limit;
186
187         /* Time-based rekeying */
188         u_int32_t rekey_interval;       /* how often in seconds */
189         time_t rekey_time;      /* time of last rekeying */
190
191         /* Session key for protocol v1 */
192         u_char ssh1_key[SSH_SESSION_KEY_LENGTH];
193         u_int ssh1_keylen;
194
195         /* roundup current message to extra_pad bytes */
196         u_char extra_pad;
197
198         /* XXX discard incoming data after MAC error */
199         u_int packet_discard;
200         struct sshmac *packet_discard_mac;
201
202         /* Used in packet_read_poll2() */
203         u_int packlen;
204
205         /* Used in packet_send2 */
206         int rekeying;
207
208         /* Used in packet_set_interactive */
209         int set_interactive_called;
210
211         /* Used in packet_set_maxsize */
212         int set_maxsize_called;
213
214         /* One-off warning about weak ciphers */
215         int cipher_warning_done;
216
217         /* SSH1 CRC compensation attack detector */
218         struct deattack_ctx deattack;
219
220         TAILQ_HEAD(, packet) outgoing;
221 };
222
223 struct ssh *
224 ssh_alloc_session_state(void)
225 {
226         struct ssh *ssh = NULL;
227         struct session_state *state = NULL;
228
229         if ((ssh = calloc(1, sizeof(*ssh))) == NULL ||
230             (state = calloc(1, sizeof(*state))) == NULL ||
231             (state->input = sshbuf_new()) == NULL ||
232             (state->output = sshbuf_new()) == NULL ||
233             (state->outgoing_packet = sshbuf_new()) == NULL ||
234             (state->incoming_packet = sshbuf_new()) == NULL)
235                 goto fail;
236         TAILQ_INIT(&state->outgoing);
237         TAILQ_INIT(&ssh->private_keys);
238         TAILQ_INIT(&ssh->public_keys);
239         state->connection_in = -1;
240         state->connection_out = -1;
241         state->max_packet_size = 32768;
242         state->packet_timeout_ms = -1;
243         state->p_send.packets = state->p_read.packets = 0;
244         state->initialized = 1;
245         /*
246          * ssh_packet_send2() needs to queue packets until
247          * we've done the initial key exchange.
248          */
249         state->rekeying = 1;
250         ssh->state = state;
251         return ssh;
252  fail:
253         if (state) {
254                 sshbuf_free(state->input);
255                 sshbuf_free(state->output);
256                 sshbuf_free(state->incoming_packet);
257                 sshbuf_free(state->outgoing_packet);
258                 free(state);
259         }
260         free(ssh);
261         return NULL;
262 }
263
264 /*
265  * Sets the descriptors used for communication.  Disables encryption until
266  * packet_set_encryption_key is called.
267  */
268 struct ssh *
269 ssh_packet_set_connection(struct ssh *ssh, int fd_in, int fd_out)
270 {
271         struct session_state *state;
272         const struct sshcipher *none = cipher_by_name("none");
273         int r;
274
275         if (none == NULL) {
276                 error("%s: cannot load cipher 'none'", __func__);
277                 return NULL;
278         }
279         if (ssh == NULL)
280                 ssh = ssh_alloc_session_state();
281         if (ssh == NULL) {
282                 error("%s: cound not allocate state", __func__);
283                 return NULL;
284         }
285         state = ssh->state;
286         state->connection_in = fd_in;
287         state->connection_out = fd_out;
288         if ((r = cipher_init(&state->send_context, none,
289             (const u_char *)"", 0, NULL, 0, CIPHER_ENCRYPT)) != 0 ||
290             (r = cipher_init(&state->receive_context, none,
291             (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) {
292                 error("%s: cipher_init failed: %s", __func__, ssh_err(r));
293                 free(ssh);
294                 return NULL;
295         }
296         state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL;
297         deattack_init(&state->deattack);
298         /*
299          * Cache the IP address of the remote connection for use in error
300          * messages that might be generated after the connection has closed.
301          */
302         (void)ssh_remote_ipaddr(ssh);
303         return ssh;
304 }
305
306 void
307 ssh_packet_set_timeout(struct ssh *ssh, int timeout, int count)
308 {
309         struct session_state *state = ssh->state;
310
311         if (timeout <= 0 || count <= 0) {
312                 state->packet_timeout_ms = -1;
313                 return;
314         }
315         if ((INT_MAX / 1000) / count < timeout)
316                 state->packet_timeout_ms = INT_MAX;
317         else
318                 state->packet_timeout_ms = timeout * count * 1000;
319 }
320
321 int
322 ssh_packet_stop_discard(struct ssh *ssh)
323 {
324         struct session_state *state = ssh->state;
325         int r;
326
327         if (state->packet_discard_mac) {
328                 char buf[1024];
329
330                 memset(buf, 'a', sizeof(buf));
331                 while (sshbuf_len(state->incoming_packet) <
332                     PACKET_MAX_SIZE)
333                         if ((r = sshbuf_put(state->incoming_packet, buf,
334                             sizeof(buf))) != 0)
335                                 return r;
336                 (void) mac_compute(state->packet_discard_mac,
337                     state->p_read.seqnr,
338                     sshbuf_ptr(state->incoming_packet), PACKET_MAX_SIZE,
339                     NULL, 0);
340         }
341         logit("Finished discarding for %.200s", ssh_remote_ipaddr(ssh));
342         return SSH_ERR_MAC_INVALID;
343 }
344
345 static int
346 ssh_packet_start_discard(struct ssh *ssh, struct sshenc *enc,
347     struct sshmac *mac, u_int packet_length, u_int discard)
348 {
349         struct session_state *state = ssh->state;
350         int r;
351
352         if (enc == NULL || !cipher_is_cbc(enc->cipher) || (mac && mac->etm)) {
353                 if ((r = sshpkt_disconnect(ssh, "Packet corrupt")) != 0)
354                         return r;
355                 return SSH_ERR_MAC_INVALID;
356         }
357         if (packet_length != PACKET_MAX_SIZE && mac && mac->enabled)
358                 state->packet_discard_mac = mac;
359         if (sshbuf_len(state->input) >= discard &&
360            (r = ssh_packet_stop_discard(ssh)) != 0)
361                 return r;
362         state->packet_discard = discard - sshbuf_len(state->input);
363         return 0;
364 }
365
366 /* Returns 1 if remote host is connected via socket, 0 if not. */
367
368 int
369 ssh_packet_connection_is_on_socket(struct ssh *ssh)
370 {
371         struct session_state *state = ssh->state;
372         struct sockaddr_storage from, to;
373         socklen_t fromlen, tolen;
374
375         /* filedescriptors in and out are the same, so it's a socket */
376         if (state->connection_in == state->connection_out)
377                 return 1;
378         fromlen = sizeof(from);
379         memset(&from, 0, sizeof(from));
380         if (getpeername(state->connection_in, (struct sockaddr *)&from,
381             &fromlen) < 0)
382                 return 0;
383         tolen = sizeof(to);
384         memset(&to, 0, sizeof(to));
385         if (getpeername(state->connection_out, (struct sockaddr *)&to,
386             &tolen) < 0)
387                 return 0;
388         if (fromlen != tolen || memcmp(&from, &to, fromlen) != 0)
389                 return 0;
390         if (from.ss_family != AF_INET && from.ss_family != AF_INET6)
391                 return 0;
392         return 1;
393 }
394
395 void
396 ssh_packet_get_bytes(struct ssh *ssh, u_int64_t *ibytes, u_int64_t *obytes)
397 {
398         if (ibytes)
399                 *ibytes = ssh->state->p_read.bytes;
400         if (obytes)
401                 *obytes = ssh->state->p_send.bytes;
402 }
403
404 int
405 ssh_packet_connection_af(struct ssh *ssh)
406 {
407         struct sockaddr_storage to;
408         socklen_t tolen = sizeof(to);
409
410         memset(&to, 0, sizeof(to));
411         if (getsockname(ssh->state->connection_out, (struct sockaddr *)&to,
412             &tolen) < 0)
413                 return 0;
414 #ifdef IPV4_IN_IPV6
415         if (to.ss_family == AF_INET6 &&
416             IN6_IS_ADDR_V4MAPPED(&((struct sockaddr_in6 *)&to)->sin6_addr))
417                 return AF_INET;
418 #endif
419         return to.ss_family;
420 }
421
422 /* Sets the connection into non-blocking mode. */
423
424 void
425 ssh_packet_set_nonblocking(struct ssh *ssh)
426 {
427         /* Set the socket into non-blocking mode. */
428         set_nonblock(ssh->state->connection_in);
429
430         if (ssh->state->connection_out != ssh->state->connection_in)
431                 set_nonblock(ssh->state->connection_out);
432 }
433
434 /* Returns the socket used for reading. */
435
436 int
437 ssh_packet_get_connection_in(struct ssh *ssh)
438 {
439         return ssh->state->connection_in;
440 }
441
442 /* Returns the descriptor used for writing. */
443
444 int
445 ssh_packet_get_connection_out(struct ssh *ssh)
446 {
447         return ssh->state->connection_out;
448 }
449
450 /*
451  * Returns the IP-address of the remote host as a string.  The returned
452  * string must not be freed.
453  */
454
455 const char *
456 ssh_remote_ipaddr(struct ssh *ssh)
457 {
458         /* Check whether we have cached the ipaddr. */
459         if (ssh->remote_ipaddr == NULL)
460                 ssh->remote_ipaddr = ssh_packet_connection_is_on_socket(ssh) ?
461                     get_peer_ipaddr(ssh->state->connection_in) :
462                     strdup("UNKNOWN");
463         if (ssh->remote_ipaddr == NULL)
464                 return "UNKNOWN";
465         return ssh->remote_ipaddr;
466 }
467
468 /* Closes the connection and clears and frees internal data structures. */
469
470 void
471 ssh_packet_close(struct ssh *ssh)
472 {
473         struct session_state *state = ssh->state;
474         int r;
475         u_int mode;
476
477         if (!state->initialized)
478                 return;
479         state->initialized = 0;
480         if (state->connection_in == state->connection_out) {
481                 shutdown(state->connection_out, SHUT_RDWR);
482                 close(state->connection_out);
483         } else {
484                 close(state->connection_in);
485                 close(state->connection_out);
486         }
487         sshbuf_free(state->input);
488         sshbuf_free(state->output);
489         sshbuf_free(state->outgoing_packet);
490         sshbuf_free(state->incoming_packet);
491         for (mode = 0; mode < MODE_MAX; mode++)
492                 kex_free_newkeys(state->newkeys[mode]);
493         if (state->compression_buffer) {
494                 sshbuf_free(state->compression_buffer);
495                 if (state->compression_out_started) {
496                         z_streamp stream = &state->compression_out_stream;
497                         debug("compress outgoing: "
498                             "raw data %llu, compressed %llu, factor %.2f",
499                                 (unsigned long long)stream->total_in,
500                                 (unsigned long long)stream->total_out,
501                                 stream->total_in == 0 ? 0.0 :
502                                 (double) stream->total_out / stream->total_in);
503                         if (state->compression_out_failures == 0)
504                                 deflateEnd(stream);
505                 }
506                 if (state->compression_in_started) {
507                         z_streamp stream = &state->compression_out_stream;
508                         debug("compress incoming: "
509                             "raw data %llu, compressed %llu, factor %.2f",
510                             (unsigned long long)stream->total_out,
511                             (unsigned long long)stream->total_in,
512                             stream->total_out == 0 ? 0.0 :
513                             (double) stream->total_in / stream->total_out);
514                         if (state->compression_in_failures == 0)
515                                 inflateEnd(stream);
516                 }
517         }
518         if ((r = cipher_cleanup(&state->send_context)) != 0)
519                 error("%s: cipher_cleanup failed: %s", __func__, ssh_err(r));
520         if ((r = cipher_cleanup(&state->receive_context)) != 0)
521                 error("%s: cipher_cleanup failed: %s", __func__, ssh_err(r));
522         if (ssh->remote_ipaddr) {
523                 free(ssh->remote_ipaddr);
524                 ssh->remote_ipaddr = NULL;
525         }
526         free(ssh->state);
527         ssh->state = NULL;
528 }
529
530 /* Sets remote side protocol flags. */
531
532 void
533 ssh_packet_set_protocol_flags(struct ssh *ssh, u_int protocol_flags)
534 {
535         ssh->state->remote_protocol_flags = protocol_flags;
536 }
537
538 /* Returns the remote protocol flags set earlier by the above function. */
539
540 u_int
541 ssh_packet_get_protocol_flags(struct ssh *ssh)
542 {
543         return ssh->state->remote_protocol_flags;
544 }
545
546 /*
547  * Starts packet compression from the next packet on in both directions.
548  * Level is compression level 1 (fastest) - 9 (slow, best) as in gzip.
549  */
550
551 static int
552 ssh_packet_init_compression(struct ssh *ssh)
553 {
554         if (!ssh->state->compression_buffer &&
555            ((ssh->state->compression_buffer = sshbuf_new()) == NULL))
556                 return SSH_ERR_ALLOC_FAIL;
557         return 0;
558 }
559
560 static int
561 start_compression_out(struct ssh *ssh, int level)
562 {
563         if (level < 1 || level > 9)
564                 return SSH_ERR_INVALID_ARGUMENT;
565         debug("Enabling compression at level %d.", level);
566         if (ssh->state->compression_out_started == 1)
567                 deflateEnd(&ssh->state->compression_out_stream);
568         switch (deflateInit(&ssh->state->compression_out_stream, level)) {
569         case Z_OK:
570                 ssh->state->compression_out_started = 1;
571                 break;
572         case Z_MEM_ERROR:
573                 return SSH_ERR_ALLOC_FAIL;
574         default:
575                 return SSH_ERR_INTERNAL_ERROR;
576         }
577         return 0;
578 }
579
580 static int
581 start_compression_in(struct ssh *ssh)
582 {
583         if (ssh->state->compression_in_started == 1)
584                 inflateEnd(&ssh->state->compression_in_stream);
585         switch (inflateInit(&ssh->state->compression_in_stream)) {
586         case Z_OK:
587                 ssh->state->compression_in_started = 1;
588                 break;
589         case Z_MEM_ERROR:
590                 return SSH_ERR_ALLOC_FAIL;
591         default:
592                 return SSH_ERR_INTERNAL_ERROR;
593         }
594         return 0;
595 }
596
597 int
598 ssh_packet_start_compression(struct ssh *ssh, int level)
599 {
600         int r;
601
602         if (ssh->state->packet_compression && !compat20)
603                 return SSH_ERR_INTERNAL_ERROR;
604         ssh->state->packet_compression = 1;
605         if ((r = ssh_packet_init_compression(ssh)) != 0 ||
606             (r = start_compression_in(ssh)) != 0 ||
607             (r = start_compression_out(ssh, level)) != 0)
608                 return r;
609         return 0;
610 }
611
612 /* XXX remove need for separate compression buffer */
613 static int
614 compress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
615 {
616         u_char buf[4096];
617         int r, status;
618
619         if (ssh->state->compression_out_started != 1)
620                 return SSH_ERR_INTERNAL_ERROR;
621
622         /* This case is not handled below. */
623         if (sshbuf_len(in) == 0)
624                 return 0;
625
626         /* Input is the contents of the input buffer. */
627         if ((ssh->state->compression_out_stream.next_in =
628             sshbuf_mutable_ptr(in)) == NULL)
629                 return SSH_ERR_INTERNAL_ERROR;
630         ssh->state->compression_out_stream.avail_in = sshbuf_len(in);
631
632         /* Loop compressing until deflate() returns with avail_out != 0. */
633         do {
634                 /* Set up fixed-size output buffer. */
635                 ssh->state->compression_out_stream.next_out = buf;
636                 ssh->state->compression_out_stream.avail_out = sizeof(buf);
637
638                 /* Compress as much data into the buffer as possible. */
639                 status = deflate(&ssh->state->compression_out_stream,
640                     Z_PARTIAL_FLUSH);
641                 switch (status) {
642                 case Z_MEM_ERROR:
643                         return SSH_ERR_ALLOC_FAIL;
644                 case Z_OK:
645                         /* Append compressed data to output_buffer. */
646                         if ((r = sshbuf_put(out, buf, sizeof(buf) -
647                             ssh->state->compression_out_stream.avail_out)) != 0)
648                                 return r;
649                         break;
650                 case Z_STREAM_ERROR:
651                 default:
652                         ssh->state->compression_out_failures++;
653                         return SSH_ERR_INVALID_FORMAT;
654                 }
655         } while (ssh->state->compression_out_stream.avail_out == 0);
656         return 0;
657 }
658
659 static int
660 uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
661 {
662         u_char buf[4096];
663         int r, status;
664
665         if (ssh->state->compression_in_started != 1)
666                 return SSH_ERR_INTERNAL_ERROR;
667
668         if ((ssh->state->compression_in_stream.next_in =
669             sshbuf_mutable_ptr(in)) == NULL)
670                 return SSH_ERR_INTERNAL_ERROR;
671         ssh->state->compression_in_stream.avail_in = sshbuf_len(in);
672
673         for (;;) {
674                 /* Set up fixed-size output buffer. */
675                 ssh->state->compression_in_stream.next_out = buf;
676                 ssh->state->compression_in_stream.avail_out = sizeof(buf);
677
678                 status = inflate(&ssh->state->compression_in_stream,
679                     Z_PARTIAL_FLUSH);
680                 switch (status) {
681                 case Z_OK:
682                         if ((r = sshbuf_put(out, buf, sizeof(buf) -
683                             ssh->state->compression_in_stream.avail_out)) != 0)
684                                 return r;
685                         break;
686                 case Z_BUF_ERROR:
687                         /*
688                          * Comments in zlib.h say that we should keep calling
689                          * inflate() until we get an error.  This appears to
690                          * be the error that we get.
691                          */
692                         return 0;
693                 case Z_DATA_ERROR:
694                         return SSH_ERR_INVALID_FORMAT;
695                 case Z_MEM_ERROR:
696                         return SSH_ERR_ALLOC_FAIL;
697                 case Z_STREAM_ERROR:
698                 default:
699                         ssh->state->compression_in_failures++;
700                         return SSH_ERR_INTERNAL_ERROR;
701                 }
702         }
703         /* NOTREACHED */
704 }
705
706 /* Serialise compression state into a blob for privsep */
707 static int
708 ssh_packet_get_compress_state(struct sshbuf *m, struct ssh *ssh)
709 {
710         struct session_state *state = ssh->state;
711         struct sshbuf *b;
712         int r;
713
714         if ((b = sshbuf_new()) == NULL)
715                 return SSH_ERR_ALLOC_FAIL;
716         if (state->compression_in_started) {
717                 if ((r = sshbuf_put_string(b, &state->compression_in_stream,
718                     sizeof(state->compression_in_stream))) != 0)
719                         goto out;
720         } else if ((r = sshbuf_put_string(b, NULL, 0)) != 0)
721                 goto out;
722         if (state->compression_out_started) {
723                 if ((r = sshbuf_put_string(b, &state->compression_out_stream,
724                     sizeof(state->compression_out_stream))) != 0)
725                         goto out;
726         } else if ((r = sshbuf_put_string(b, NULL, 0)) != 0)
727                 goto out;
728         r = sshbuf_put_stringb(m, b);
729  out:
730         sshbuf_free(b);
731         return r;
732 }
733
734 /* Deserialise compression state from a blob for privsep */
735 static int
736 ssh_packet_set_compress_state(struct ssh *ssh, struct sshbuf *m)
737 {
738         struct session_state *state = ssh->state;
739         struct sshbuf *b = NULL;
740         int r;
741         const u_char *inblob, *outblob;
742         size_t inl, outl;
743
744         if ((r = sshbuf_froms(m, &b)) != 0)
745                 goto out;
746         if ((r = sshbuf_get_string_direct(b, &inblob, &inl)) != 0 ||
747             (r = sshbuf_get_string_direct(b, &outblob, &outl)) != 0)
748                 goto out;
749         if (inl == 0)
750                 state->compression_in_started = 0;
751         else if (inl != sizeof(state->compression_in_stream)) {
752                 r = SSH_ERR_INTERNAL_ERROR;
753                 goto out;
754         } else {
755                 state->compression_in_started = 1;
756                 memcpy(&state->compression_in_stream, inblob, inl);
757         }
758         if (outl == 0)
759                 state->compression_out_started = 0;
760         else if (outl != sizeof(state->compression_out_stream)) {
761                 r = SSH_ERR_INTERNAL_ERROR;
762                 goto out;
763         } else {
764                 state->compression_out_started = 1;
765                 memcpy(&state->compression_out_stream, outblob, outl);
766         }
767         r = 0;
768  out:
769         sshbuf_free(b);
770         return r;
771 }
772
773 void
774 ssh_packet_set_compress_hooks(struct ssh *ssh, void *ctx,
775     void *(*allocfunc)(void *, u_int, u_int),
776     void (*freefunc)(void *, void *))
777 {
778         ssh->state->compression_out_stream.zalloc = (alloc_func)allocfunc;
779         ssh->state->compression_out_stream.zfree = (free_func)freefunc;
780         ssh->state->compression_out_stream.opaque = ctx;
781         ssh->state->compression_in_stream.zalloc = (alloc_func)allocfunc;
782         ssh->state->compression_in_stream.zfree = (free_func)freefunc;
783         ssh->state->compression_in_stream.opaque = ctx;
784 }
785
786 /*
787  * Causes any further packets to be encrypted using the given key.  The same
788  * key is used for both sending and reception.  However, both directions are
789  * encrypted independently of each other.
790  */
791
792 void
793 ssh_packet_set_encryption_key(struct ssh *ssh, const u_char *key, u_int keylen, int number)
794 {
795 #ifndef WITH_SSH1
796         fatal("no SSH protocol 1 support");
797 #else /* WITH_SSH1 */
798         struct session_state *state = ssh->state;
799         const struct sshcipher *cipher = cipher_by_number(number);
800         int r;
801         const char *wmsg;
802
803         if (cipher == NULL)
804                 fatal("%s: unknown cipher number %d", __func__, number);
805         if (keylen < 20)
806                 fatal("%s: keylen too small: %d", __func__, keylen);
807         if (keylen > SSH_SESSION_KEY_LENGTH)
808                 fatal("%s: keylen too big: %d", __func__, keylen);
809         memcpy(state->ssh1_key, key, keylen);
810         state->ssh1_keylen = keylen;
811         if ((r = cipher_init(&state->send_context, cipher, key, keylen,
812             NULL, 0, CIPHER_ENCRYPT)) != 0 ||
813             (r = cipher_init(&state->receive_context, cipher, key, keylen,
814             NULL, 0, CIPHER_DECRYPT) != 0))
815                 fatal("%s: cipher_init failed: %s", __func__, ssh_err(r));
816         if (!state->cipher_warning_done &&
817             ((wmsg = cipher_warning_message(&state->send_context)) != NULL ||
818             (wmsg = cipher_warning_message(&state->send_context)) != NULL)) {
819                 error("Warning: %s", wmsg);
820                 state->cipher_warning_done = 1;
821         }
822 #endif /* WITH_SSH1 */
823 }
824
825 /*
826  * Finalizes and sends the packet.  If the encryption key has been set,
827  * encrypts the packet before sending.
828  */
829
830 int
831 ssh_packet_send1(struct ssh *ssh)
832 {
833         struct session_state *state = ssh->state;
834         u_char buf[8], *cp;
835         int r, padding, len;
836         u_int checksum;
837
838         /*
839          * If using packet compression, compress the payload of the outgoing
840          * packet.
841          */
842         if (state->packet_compression) {
843                 sshbuf_reset(state->compression_buffer);
844                 /* Skip padding. */
845                 if ((r = sshbuf_consume(state->outgoing_packet, 8)) != 0)
846                         goto out;
847                 /* padding */
848                 if ((r = sshbuf_put(state->compression_buffer,
849                     "\0\0\0\0\0\0\0\0", 8)) != 0)
850                         goto out;
851                 if ((r = compress_buffer(ssh, state->outgoing_packet,
852                     state->compression_buffer)) != 0)
853                         goto out;
854                 sshbuf_reset(state->outgoing_packet);
855                 if ((r = sshbuf_putb(state->outgoing_packet,
856                     state->compression_buffer)) != 0)
857                         goto out;
858         }
859         /* Compute packet length without padding (add checksum, remove padding). */
860         len = sshbuf_len(state->outgoing_packet) + 4 - 8;
861
862         /* Insert padding. Initialized to zero in packet_start1() */
863         padding = 8 - len % 8;
864         if (!state->send_context.plaintext) {
865                 cp = sshbuf_mutable_ptr(state->outgoing_packet);
866                 if (cp == NULL) {
867                         r = SSH_ERR_INTERNAL_ERROR;
868                         goto out;
869                 }
870                 arc4random_buf(cp + 8 - padding, padding);
871         }
872         if ((r = sshbuf_consume(state->outgoing_packet, 8 - padding)) != 0)
873                 goto out;
874
875         /* Add check bytes. */
876         checksum = ssh_crc32(sshbuf_ptr(state->outgoing_packet),
877             sshbuf_len(state->outgoing_packet));
878         POKE_U32(buf, checksum);
879         if ((r = sshbuf_put(state->outgoing_packet, buf, 4)) != 0)
880                 goto out;
881
882 #ifdef PACKET_DEBUG
883         fprintf(stderr, "packet_send plain: ");
884         sshbuf_dump(state->outgoing_packet, stderr);
885 #endif
886
887         /* Append to output. */
888         POKE_U32(buf, len);
889         if ((r = sshbuf_put(state->output, buf, 4)) != 0)
890                 goto out;
891         if ((r = sshbuf_reserve(state->output,
892             sshbuf_len(state->outgoing_packet), &cp)) != 0)
893                 goto out;
894         if ((r = cipher_crypt(&state->send_context, 0, cp,
895             sshbuf_ptr(state->outgoing_packet),
896             sshbuf_len(state->outgoing_packet), 0, 0)) != 0)
897                 goto out;
898
899 #ifdef PACKET_DEBUG
900         fprintf(stderr, "encrypted: ");
901         sshbuf_dump(state->output, stderr);
902 #endif
903         state->p_send.packets++;
904         state->p_send.bytes += len +
905             sshbuf_len(state->outgoing_packet);
906         sshbuf_reset(state->outgoing_packet);
907
908         /*
909          * Note that the packet is now only buffered in output.  It won't be
910          * actually sent until ssh_packet_write_wait or ssh_packet_write_poll
911          * is called.
912          */
913         r = 0;
914  out:
915         return r;
916 }
917
918 int
919 ssh_set_newkeys(struct ssh *ssh, int mode)
920 {
921         struct session_state *state = ssh->state;
922         struct sshenc *enc;
923         struct sshmac *mac;
924         struct sshcomp *comp;
925         struct sshcipher_ctx *cc;
926         u_int64_t *max_blocks;
927         const char *wmsg;
928         int r, crypt_type;
929
930         debug2("set_newkeys: mode %d", mode);
931
932         if (mode == MODE_OUT) {
933                 cc = &state->send_context;
934                 crypt_type = CIPHER_ENCRYPT;
935                 state->p_send.packets = state->p_send.blocks = 0;
936                 max_blocks = &state->max_blocks_out;
937         } else {
938                 cc = &state->receive_context;
939                 crypt_type = CIPHER_DECRYPT;
940                 state->p_read.packets = state->p_read.blocks = 0;
941                 max_blocks = &state->max_blocks_in;
942         }
943         if (state->newkeys[mode] != NULL) {
944                 debug("set_newkeys: rekeying");
945                 if ((r = cipher_cleanup(cc)) != 0)
946                         return r;
947                 enc  = &state->newkeys[mode]->enc;
948                 mac  = &state->newkeys[mode]->mac;
949                 comp = &state->newkeys[mode]->comp;
950                 mac_clear(mac);
951                 explicit_bzero(enc->iv,  enc->iv_len);
952                 explicit_bzero(enc->key, enc->key_len);
953                 explicit_bzero(mac->key, mac->key_len);
954                 free(enc->name);
955                 free(enc->iv);
956                 free(enc->key);
957                 free(mac->name);
958                 free(mac->key);
959                 free(comp->name);
960                 free(state->newkeys[mode]);
961         }
962         /* move newkeys from kex to state */
963         if ((state->newkeys[mode] = ssh->kex->newkeys[mode]) == NULL)
964                 return SSH_ERR_INTERNAL_ERROR;
965         ssh->kex->newkeys[mode] = NULL;
966         enc  = &state->newkeys[mode]->enc;
967         mac  = &state->newkeys[mode]->mac;
968         comp = &state->newkeys[mode]->comp;
969         if (cipher_authlen(enc->cipher) == 0) {
970                 if ((r = mac_init(mac)) != 0)
971                         return r;
972         }
973         mac->enabled = 1;
974         DBG(debug("cipher_init_context: %d", mode));
975         if ((r = cipher_init(cc, enc->cipher, enc->key, enc->key_len,
976             enc->iv, enc->iv_len, crypt_type)) != 0)
977                 return r;
978         if (!state->cipher_warning_done &&
979             (wmsg = cipher_warning_message(cc)) != NULL) {
980                 error("Warning: %s", wmsg);
981                 state->cipher_warning_done = 1;
982         }
983         /* Deleting the keys does not gain extra security */
984         /* explicit_bzero(enc->iv,  enc->block_size);
985            explicit_bzero(enc->key, enc->key_len);
986            explicit_bzero(mac->key, mac->key_len); */
987         if ((comp->type == COMP_ZLIB ||
988             (comp->type == COMP_DELAYED &&
989              state->after_authentication)) && comp->enabled == 0) {
990                 if ((r = ssh_packet_init_compression(ssh)) < 0)
991                         return r;
992                 if (mode == MODE_OUT) {
993                         if ((r = start_compression_out(ssh, 6)) != 0)
994                                 return r;
995                 } else {
996                         if ((r = start_compression_in(ssh)) != 0)
997                                 return r;
998                 }
999                 comp->enabled = 1;
1000         }
1001         /*
1002          * The 2^(blocksize*2) limit is too expensive for 3DES,
1003          * blowfish, etc, so enforce a 1GB limit for small blocksizes.
1004          */
1005         if (enc->block_size >= 16)
1006                 *max_blocks = (u_int64_t)1 << (enc->block_size*2);
1007         else
1008                 *max_blocks = ((u_int64_t)1 << 30) / enc->block_size;
1009         if (state->rekey_limit)
1010                 *max_blocks = MIN(*max_blocks,
1011                     state->rekey_limit / enc->block_size);
1012         return 0;
1013 }
1014
1015 /*
1016  * Delayed compression for SSH2 is enabled after authentication:
1017  * This happens on the server side after a SSH2_MSG_USERAUTH_SUCCESS is sent,
1018  * and on the client side after a SSH2_MSG_USERAUTH_SUCCESS is received.
1019  */
1020 static int
1021 ssh_packet_enable_delayed_compress(struct ssh *ssh)
1022 {
1023         struct session_state *state = ssh->state;
1024         struct sshcomp *comp = NULL;
1025         int r, mode;
1026
1027         /*
1028          * Remember that we are past the authentication step, so rekeying
1029          * with COMP_DELAYED will turn on compression immediately.
1030          */
1031         state->after_authentication = 1;
1032         for (mode = 0; mode < MODE_MAX; mode++) {
1033                 /* protocol error: USERAUTH_SUCCESS received before NEWKEYS */
1034                 if (state->newkeys[mode] == NULL)
1035                         continue;
1036                 comp = &state->newkeys[mode]->comp;
1037                 if (comp && !comp->enabled && comp->type == COMP_DELAYED) {
1038                         if ((r = ssh_packet_init_compression(ssh)) != 0)
1039                                 return r;
1040                         if (mode == MODE_OUT) {
1041                                 if ((r = start_compression_out(ssh, 6)) != 0)
1042                                         return r;
1043                         } else {
1044                                 if ((r = start_compression_in(ssh)) != 0)
1045                                         return r;
1046                         }
1047                         comp->enabled = 1;
1048                 }
1049         }
1050         return 0;
1051 }
1052
1053 /*
1054  * Finalize packet in SSH2 format (compress, mac, encrypt, enqueue)
1055  */
1056 int
1057 ssh_packet_send2_wrapped(struct ssh *ssh)
1058 {
1059         struct session_state *state = ssh->state;
1060         u_char type, *cp, macbuf[SSH_DIGEST_MAX_LENGTH];
1061         u_char padlen, pad = 0;
1062         u_int authlen = 0, aadlen = 0;
1063         u_int len;
1064         struct sshenc *enc   = NULL;
1065         struct sshmac *mac   = NULL;
1066         struct sshcomp *comp = NULL;
1067         int r, block_size;
1068
1069         if (state->newkeys[MODE_OUT] != NULL) {
1070                 enc  = &state->newkeys[MODE_OUT]->enc;
1071                 mac  = &state->newkeys[MODE_OUT]->mac;
1072                 comp = &state->newkeys[MODE_OUT]->comp;
1073                 /* disable mac for authenticated encryption */
1074                 if ((authlen = cipher_authlen(enc->cipher)) != 0)
1075                         mac = NULL;
1076         }
1077         block_size = enc ? enc->block_size : 8;
1078         aadlen = (mac && mac->enabled && mac->etm) || authlen ? 4 : 0;
1079
1080         type = (sshbuf_ptr(state->outgoing_packet))[5];
1081
1082 #ifdef PACKET_DEBUG
1083         fprintf(stderr, "plain:     ");
1084         sshbuf_dump(state->outgoing_packet, stderr);
1085 #endif
1086
1087         if (comp && comp->enabled) {
1088                 len = sshbuf_len(state->outgoing_packet);
1089                 /* skip header, compress only payload */
1090                 if ((r = sshbuf_consume(state->outgoing_packet, 5)) != 0)
1091                         goto out;
1092                 sshbuf_reset(state->compression_buffer);
1093                 if ((r = compress_buffer(ssh, state->outgoing_packet,
1094                     state->compression_buffer)) != 0)
1095                         goto out;
1096                 sshbuf_reset(state->outgoing_packet);
1097                 if ((r = sshbuf_put(state->outgoing_packet,
1098                     "\0\0\0\0\0", 5)) != 0 ||
1099                     (r = sshbuf_putb(state->outgoing_packet,
1100                     state->compression_buffer)) != 0)
1101                         goto out;
1102                 DBG(debug("compression: raw %d compressed %zd", len,
1103                     sshbuf_len(state->outgoing_packet)));
1104         }
1105
1106         /* sizeof (packet_len + pad_len + payload) */
1107         len = sshbuf_len(state->outgoing_packet);
1108
1109         /*
1110          * calc size of padding, alloc space, get random data,
1111          * minimum padding is 4 bytes
1112          */
1113         len -= aadlen; /* packet length is not encrypted for EtM modes */
1114         padlen = block_size - (len % block_size);
1115         if (padlen < 4)
1116                 padlen += block_size;
1117         if (state->extra_pad) {
1118                 /* will wrap if extra_pad+padlen > 255 */
1119                 state->extra_pad =
1120                     roundup(state->extra_pad, block_size);
1121                 pad = state->extra_pad -
1122                     ((len + padlen) % state->extra_pad);
1123                 DBG(debug3("%s: adding %d (len %d padlen %d extra_pad %d)",
1124                     __func__, pad, len, padlen, state->extra_pad));
1125                 padlen += pad;
1126                 state->extra_pad = 0;
1127         }
1128         if ((r = sshbuf_reserve(state->outgoing_packet, padlen, &cp)) != 0)
1129                 goto out;
1130         if (enc && !state->send_context.plaintext) {
1131                 /* random padding */
1132                 arc4random_buf(cp, padlen);
1133         } else {
1134                 /* clear padding */
1135                 explicit_bzero(cp, padlen);
1136         }
1137         /* sizeof (packet_len + pad_len + payload + padding) */
1138         len = sshbuf_len(state->outgoing_packet);
1139         cp = sshbuf_mutable_ptr(state->outgoing_packet);
1140         if (cp == NULL) {
1141                 r = SSH_ERR_INTERNAL_ERROR;
1142                 goto out;
1143         }
1144         /* packet_length includes payload, padding and padding length field */
1145         POKE_U32(cp, len - 4);
1146         cp[4] = padlen;
1147         DBG(debug("send: len %d (includes padlen %d, aadlen %d)",
1148             len, padlen, aadlen));
1149
1150         /* compute MAC over seqnr and packet(length fields, payload, padding) */
1151         if (mac && mac->enabled && !mac->etm) {
1152                 if ((r = mac_compute(mac, state->p_send.seqnr,
1153                     sshbuf_ptr(state->outgoing_packet), len,
1154                     macbuf, sizeof(macbuf))) != 0)
1155                         goto out;
1156                 DBG(debug("done calc MAC out #%d", state->p_send.seqnr));
1157         }
1158         /* encrypt packet and append to output buffer. */
1159         if ((r = sshbuf_reserve(state->output,
1160             sshbuf_len(state->outgoing_packet) + authlen, &cp)) != 0)
1161                 goto out;
1162         if ((r = cipher_crypt(&state->send_context, state->p_send.seqnr, cp,
1163             sshbuf_ptr(state->outgoing_packet),
1164             len - aadlen, aadlen, authlen)) != 0)
1165                 goto out;
1166         /* append unencrypted MAC */
1167         if (mac && mac->enabled) {
1168                 if (mac->etm) {
1169                         /* EtM: compute mac over aadlen + cipher text */
1170                         if ((r = mac_compute(mac, state->p_send.seqnr,
1171                             cp, len, macbuf, sizeof(macbuf))) != 0)
1172                                 goto out;
1173                         DBG(debug("done calc MAC(EtM) out #%d",
1174                             state->p_send.seqnr));
1175                 }
1176                 if ((r = sshbuf_put(state->output, macbuf, mac->mac_len)) != 0)
1177                         goto out;
1178         }
1179 #ifdef PACKET_DEBUG
1180         fprintf(stderr, "encrypted: ");
1181         sshbuf_dump(state->output, stderr);
1182 #endif
1183         /* increment sequence number for outgoing packets */
1184         if (++state->p_send.seqnr == 0)
1185                 logit("outgoing seqnr wraps around");
1186         if (++state->p_send.packets == 0)
1187                 if (!(ssh->compat & SSH_BUG_NOREKEY))
1188                         return SSH_ERR_NEED_REKEY;
1189         state->p_send.blocks += len / block_size;
1190         state->p_send.bytes += len;
1191         sshbuf_reset(state->outgoing_packet);
1192
1193         if (type == SSH2_MSG_NEWKEYS)
1194                 r = ssh_set_newkeys(ssh, MODE_OUT);
1195         else if (type == SSH2_MSG_USERAUTH_SUCCESS && state->server_side)
1196                 r = ssh_packet_enable_delayed_compress(ssh);
1197         else
1198                 r = 0;
1199  out:
1200         return r;
1201 }
1202
1203 int
1204 ssh_packet_send2(struct ssh *ssh)
1205 {
1206         struct session_state *state = ssh->state;
1207         struct packet *p;
1208         u_char type;
1209         int r;
1210
1211         type = sshbuf_ptr(state->outgoing_packet)[5];
1212
1213         /* during rekeying we can only send key exchange messages */
1214         if (state->rekeying) {
1215                 if ((type < SSH2_MSG_TRANSPORT_MIN) ||
1216                     (type > SSH2_MSG_TRANSPORT_MAX) ||
1217                     (type == SSH2_MSG_SERVICE_REQUEST) ||
1218                     (type == SSH2_MSG_SERVICE_ACCEPT)) {
1219                         debug("enqueue packet: %u", type);
1220                         p = calloc(1, sizeof(*p));
1221                         if (p == NULL)
1222                                 return SSH_ERR_ALLOC_FAIL;
1223                         p->type = type;
1224                         p->payload = state->outgoing_packet;
1225                         TAILQ_INSERT_TAIL(&state->outgoing, p, next);
1226                         state->outgoing_packet = sshbuf_new();
1227                         if (state->outgoing_packet == NULL)
1228                                 return SSH_ERR_ALLOC_FAIL;
1229                         return 0;
1230                 }
1231         }
1232
1233         /* rekeying starts with sending KEXINIT */
1234         if (type == SSH2_MSG_KEXINIT)
1235                 state->rekeying = 1;
1236
1237         if ((r = ssh_packet_send2_wrapped(ssh)) != 0)
1238                 return r;
1239
1240         /* after a NEWKEYS message we can send the complete queue */
1241         if (type == SSH2_MSG_NEWKEYS) {
1242                 state->rekeying = 0;
1243                 state->rekey_time = monotime();
1244                 while ((p = TAILQ_FIRST(&state->outgoing))) {
1245                         type = p->type;
1246                         debug("dequeue packet: %u", type);
1247                         sshbuf_free(state->outgoing_packet);
1248                         state->outgoing_packet = p->payload;
1249                         TAILQ_REMOVE(&state->outgoing, p, next);
1250                         free(p);
1251                         if ((r = ssh_packet_send2_wrapped(ssh)) != 0)
1252                                 return r;
1253                 }
1254         }
1255         return 0;
1256 }
1257
1258 /*
1259  * Waits until a packet has been received, and returns its type.  Note that
1260  * no other data is processed until this returns, so this function should not
1261  * be used during the interactive session.
1262  */
1263
1264 int
1265 ssh_packet_read_seqnr(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1266 {
1267         struct session_state *state = ssh->state;
1268         int len, r, ms_remain, cont;
1269         fd_set *setp;
1270         char buf[8192];
1271         struct timeval timeout, start, *timeoutp = NULL;
1272
1273         DBG(debug("packet_read()"));
1274
1275         setp = (fd_set *)calloc(howmany(state->connection_in + 1,
1276             NFDBITS), sizeof(fd_mask));
1277         if (setp == NULL)
1278                 return SSH_ERR_ALLOC_FAIL;
1279
1280         /*
1281          * Since we are blocking, ensure that all written packets have
1282          * been sent.
1283          */
1284         if ((r = ssh_packet_write_wait(ssh)) != 0)
1285                 goto out;
1286
1287         /* Stay in the loop until we have received a complete packet. */
1288         for (;;) {
1289                 /* Try to read a packet from the buffer. */
1290                 r = ssh_packet_read_poll_seqnr(ssh, typep, seqnr_p);
1291                 if (r != 0)
1292                         break;
1293                 if (!compat20 && (
1294                     *typep == SSH_SMSG_SUCCESS
1295                     || *typep == SSH_SMSG_FAILURE
1296                     || *typep == SSH_CMSG_EOF
1297                     || *typep == SSH_CMSG_EXIT_CONFIRMATION))
1298                         if ((r = sshpkt_get_end(ssh)) != 0)
1299                                 break;
1300                 /* If we got a packet, return it. */
1301                 if (*typep != SSH_MSG_NONE)
1302                         break;
1303                 /*
1304                  * Otherwise, wait for some data to arrive, add it to the
1305                  * buffer, and try again.
1306                  */
1307                 memset(setp, 0, howmany(state->connection_in + 1,
1308                     NFDBITS) * sizeof(fd_mask));
1309                 FD_SET(state->connection_in, setp);
1310
1311                 if (state->packet_timeout_ms > 0) {
1312                         ms_remain = state->packet_timeout_ms;
1313                         timeoutp = &timeout;
1314                 }
1315                 /* Wait for some data to arrive. */
1316                 for (;;) {
1317                         if (state->packet_timeout_ms != -1) {
1318                                 ms_to_timeval(&timeout, ms_remain);
1319                                 gettimeofday(&start, NULL);
1320                         }
1321                         if ((r = select(state->connection_in + 1, setp,
1322                             NULL, NULL, timeoutp)) >= 0)
1323                                 break;
1324                         if (errno != EAGAIN && errno != EINTR &&
1325                             errno != EWOULDBLOCK)
1326                                 break;
1327                         if (state->packet_timeout_ms == -1)
1328                                 continue;
1329                         ms_subtract_diff(&start, &ms_remain);
1330                         if (ms_remain <= 0) {
1331                                 r = 0;
1332                                 break;
1333                         }
1334                 }
1335                 if (r == 0)
1336                         return SSH_ERR_CONN_TIMEOUT;
1337                 /* Read data from the socket. */
1338                 do {
1339                         cont = 0;
1340                         len = roaming_read(state->connection_in, buf,
1341                             sizeof(buf), &cont);
1342                 } while (len == 0 && cont);
1343                 if (len == 0) {
1344                         r = SSH_ERR_CONN_CLOSED;
1345                         goto out;
1346                 }
1347                 if (len < 0) {
1348                         r = SSH_ERR_SYSTEM_ERROR;
1349                         goto out;
1350                 }
1351
1352                 /* Append it to the buffer. */
1353                 if ((r = ssh_packet_process_incoming(ssh, buf, len)) != 0)
1354                         goto out;
1355         }
1356  out:
1357         free(setp);
1358         return r;
1359 }
1360
1361 int
1362 ssh_packet_read(struct ssh *ssh)
1363 {
1364         u_char type;
1365         int r;
1366
1367         if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0)
1368                 fatal("%s: %s", __func__, ssh_err(r));
1369         return type;
1370 }
1371
1372 /*
1373  * Waits until a packet has been received, verifies that its type matches
1374  * that given, and gives a fatal error and exits if there is a mismatch.
1375  */
1376
1377 int
1378 ssh_packet_read_expect(struct ssh *ssh, u_int expected_type)
1379 {
1380         int r;
1381         u_char type;
1382
1383         if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0)
1384                 return r;
1385         if (type != expected_type) {
1386                 if ((r = sshpkt_disconnect(ssh,
1387                     "Protocol error: expected packet type %d, got %d",
1388                     expected_type, type)) != 0)
1389                         return r;
1390                 return SSH_ERR_PROTOCOL_ERROR;
1391         }
1392         return 0;
1393 }
1394
1395 /* Checks if a full packet is available in the data received so far via
1396  * packet_process_incoming.  If so, reads the packet; otherwise returns
1397  * SSH_MSG_NONE.  This does not wait for data from the connection.
1398  *
1399  * SSH_MSG_DISCONNECT is handled specially here.  Also,
1400  * SSH_MSG_IGNORE messages are skipped by this function and are never returned
1401  * to higher levels.
1402  */
1403
1404 int
1405 ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
1406 {
1407         struct session_state *state = ssh->state;
1408         u_int len, padded_len;
1409         const char *emsg;
1410         const u_char *cp;
1411         u_char *p;
1412         u_int checksum, stored_checksum;
1413         int r;
1414
1415         *typep = SSH_MSG_NONE;
1416
1417         /* Check if input size is less than minimum packet size. */
1418         if (sshbuf_len(state->input) < 4 + 8)
1419                 return 0;
1420         /* Get length of incoming packet. */
1421         len = PEEK_U32(sshbuf_ptr(state->input));
1422         if (len < 1 + 2 + 2 || len > 256 * 1024) {
1423                 if ((r = sshpkt_disconnect(ssh, "Bad packet length %u",
1424                     len)) != 0)
1425                         return r;
1426                 return SSH_ERR_CONN_CORRUPT;
1427         }
1428         padded_len = (len + 8) & ~7;
1429
1430         /* Check if the packet has been entirely received. */
1431         if (sshbuf_len(state->input) < 4 + padded_len)
1432                 return 0;
1433
1434         /* The entire packet is in buffer. */
1435
1436         /* Consume packet length. */
1437         if ((r = sshbuf_consume(state->input, 4)) != 0)
1438                 goto out;
1439
1440         /*
1441          * Cryptographic attack detector for ssh
1442          * (C)1998 CORE-SDI, Buenos Aires Argentina
1443          * Ariel Futoransky(futo@core-sdi.com)
1444          */
1445         if (!state->receive_context.plaintext) {
1446                 emsg = NULL;
1447                 switch (detect_attack(&state->deattack,
1448                     sshbuf_ptr(state->input), padded_len)) {
1449                 case DEATTACK_OK:
1450                         break;
1451                 case DEATTACK_DETECTED:
1452                         emsg = "crc32 compensation attack detected";
1453                         break;
1454                 case DEATTACK_DOS_DETECTED:
1455                         emsg = "deattack denial of service detected";
1456                         break;
1457                 default:
1458                         emsg = "deattack error";
1459                         break;
1460                 }
1461                 if (emsg != NULL) {
1462                         error("%s", emsg);
1463                         if ((r = sshpkt_disconnect(ssh, "%s", emsg)) != 0 ||
1464                             (r = ssh_packet_write_wait(ssh)) != 0)
1465                                         return r;
1466                         return SSH_ERR_CONN_CORRUPT;
1467                 }
1468         }
1469
1470         /* Decrypt data to incoming_packet. */
1471         sshbuf_reset(state->incoming_packet);
1472         if ((r = sshbuf_reserve(state->incoming_packet, padded_len, &p)) != 0)
1473                 goto out;
1474         if ((r = cipher_crypt(&state->receive_context, 0, p,
1475             sshbuf_ptr(state->input), padded_len, 0, 0)) != 0)
1476                 goto out;
1477
1478         if ((r = sshbuf_consume(state->input, padded_len)) != 0)
1479                 goto out;
1480
1481 #ifdef PACKET_DEBUG
1482         fprintf(stderr, "read_poll plain: ");
1483         sshbuf_dump(state->incoming_packet, stderr);
1484 #endif
1485
1486         /* Compute packet checksum. */
1487         checksum = ssh_crc32(sshbuf_ptr(state->incoming_packet),
1488             sshbuf_len(state->incoming_packet) - 4);
1489
1490         /* Skip padding. */
1491         if ((r = sshbuf_consume(state->incoming_packet, 8 - len % 8)) != 0)
1492                 goto out;
1493
1494         /* Test check bytes. */
1495         if (len != sshbuf_len(state->incoming_packet)) {
1496                 error("%s: len %d != sshbuf_len %zd", __func__,
1497                     len, sshbuf_len(state->incoming_packet));
1498                 if ((r = sshpkt_disconnect(ssh, "invalid packet length")) != 0 ||
1499                     (r = ssh_packet_write_wait(ssh)) != 0)
1500                         return r;
1501                 return SSH_ERR_CONN_CORRUPT;
1502         }
1503
1504         cp = sshbuf_ptr(state->incoming_packet) + len - 4;
1505         stored_checksum = PEEK_U32(cp);
1506         if (checksum != stored_checksum) {
1507                 error("Corrupted check bytes on input");
1508                 if ((r = sshpkt_disconnect(ssh, "connection corrupted")) != 0 ||
1509                     (r = ssh_packet_write_wait(ssh)) != 0)
1510                         return r;
1511                 return SSH_ERR_CONN_CORRUPT;
1512         }
1513         if ((r = sshbuf_consume_end(state->incoming_packet, 4)) < 0)
1514                 goto out;
1515
1516         if (state->packet_compression) {
1517                 sshbuf_reset(state->compression_buffer);
1518                 if ((r = uncompress_buffer(ssh, state->incoming_packet,
1519                     state->compression_buffer)) != 0)
1520                         goto out;
1521                 sshbuf_reset(state->incoming_packet);
1522                 if ((r = sshbuf_putb(state->incoming_packet,
1523                     state->compression_buffer)) != 0)
1524                         goto out;
1525         }
1526         state->p_read.packets++;
1527         state->p_read.bytes += padded_len + 4;
1528         if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
1529                 goto out;
1530         if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) {
1531                 error("Invalid ssh1 packet type: %d", *typep);
1532                 if ((r = sshpkt_disconnect(ssh, "invalid packet type")) != 0 ||
1533                     (r = ssh_packet_write_wait(ssh)) != 0)
1534                         return r;
1535                 return SSH_ERR_PROTOCOL_ERROR;
1536         }
1537         r = 0;
1538  out:
1539         return r;
1540 }
1541
1542 int
1543 ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1544 {
1545         struct session_state *state = ssh->state;
1546         u_int padlen, need;
1547         u_char *cp, macbuf[SSH_DIGEST_MAX_LENGTH];
1548         u_int maclen, aadlen = 0, authlen = 0, block_size;
1549         struct sshenc *enc   = NULL;
1550         struct sshmac *mac   = NULL;
1551         struct sshcomp *comp = NULL;
1552         int r;
1553
1554         *typep = SSH_MSG_NONE;
1555
1556         if (state->packet_discard)
1557                 return 0;
1558
1559         if (state->newkeys[MODE_IN] != NULL) {
1560                 enc  = &state->newkeys[MODE_IN]->enc;
1561                 mac  = &state->newkeys[MODE_IN]->mac;
1562                 comp = &state->newkeys[MODE_IN]->comp;
1563                 /* disable mac for authenticated encryption */
1564                 if ((authlen = cipher_authlen(enc->cipher)) != 0)
1565                         mac = NULL;
1566         }
1567         maclen = mac && mac->enabled ? mac->mac_len : 0;
1568         block_size = enc ? enc->block_size : 8;
1569         aadlen = (mac && mac->enabled && mac->etm) || authlen ? 4 : 0;
1570
1571         if (aadlen && state->packlen == 0) {
1572                 if (cipher_get_length(&state->receive_context,
1573                     &state->packlen, state->p_read.seqnr,
1574                     sshbuf_ptr(state->input), sshbuf_len(state->input)) != 0)
1575                         return 0;
1576                 if (state->packlen < 1 + 4 ||
1577                     state->packlen > PACKET_MAX_SIZE) {
1578 #ifdef PACKET_DEBUG
1579                         sshbuf_dump(state->input, stderr);
1580 #endif
1581                         logit("Bad packet length %u.", state->packlen);
1582                         if ((r = sshpkt_disconnect(ssh, "Packet corrupt")) != 0)
1583                                 return r;
1584                 }
1585                 sshbuf_reset(state->incoming_packet);
1586         } else if (state->packlen == 0) {
1587                 /*
1588                  * check if input size is less than the cipher block size,
1589                  * decrypt first block and extract length of incoming packet
1590                  */
1591                 if (sshbuf_len(state->input) < block_size)
1592                         return 0;
1593                 sshbuf_reset(state->incoming_packet);
1594                 if ((r = sshbuf_reserve(state->incoming_packet, block_size,
1595                     &cp)) != 0)
1596                         goto out;
1597                 if ((r = cipher_crypt(&state->receive_context,
1598                     state->p_send.seqnr, cp, sshbuf_ptr(state->input),
1599                     block_size, 0, 0)) != 0)
1600                         goto out;
1601                 state->packlen = PEEK_U32(sshbuf_ptr(state->incoming_packet));
1602                 if (state->packlen < 1 + 4 ||
1603                     state->packlen > PACKET_MAX_SIZE) {
1604 #ifdef PACKET_DEBUG
1605                         fprintf(stderr, "input: \n");
1606                         sshbuf_dump(state->input, stderr);
1607                         fprintf(stderr, "incoming_packet: \n");
1608                         sshbuf_dump(state->incoming_packet, stderr);
1609 #endif
1610                         logit("Bad packet length %u.", state->packlen);
1611                         return ssh_packet_start_discard(ssh, enc, mac,
1612                             state->packlen, PACKET_MAX_SIZE);
1613                 }
1614                 if ((r = sshbuf_consume(state->input, block_size)) != 0)
1615                         goto out;
1616         }
1617         DBG(debug("input: packet len %u", state->packlen+4));
1618
1619         if (aadlen) {
1620                 /* only the payload is encrypted */
1621                 need = state->packlen;
1622         } else {
1623                 /*
1624                  * the payload size and the payload are encrypted, but we
1625                  * have a partial packet of block_size bytes
1626                  */
1627                 need = 4 + state->packlen - block_size;
1628         }
1629         DBG(debug("partial packet: block %d, need %d, maclen %d, authlen %d,"
1630             " aadlen %d", block_size, need, maclen, authlen, aadlen));
1631         if (need % block_size != 0) {
1632                 logit("padding error: need %d block %d mod %d",
1633                     need, block_size, need % block_size);
1634                 return ssh_packet_start_discard(ssh, enc, mac,
1635                     state->packlen, PACKET_MAX_SIZE - block_size);
1636         }
1637         /*
1638          * check if the entire packet has been received and
1639          * decrypt into incoming_packet:
1640          * 'aadlen' bytes are unencrypted, but authenticated.
1641          * 'need' bytes are encrypted, followed by either
1642          * 'authlen' bytes of authentication tag or
1643          * 'maclen' bytes of message authentication code.
1644          */
1645         if (sshbuf_len(state->input) < aadlen + need + authlen + maclen)
1646                 return 0;
1647 #ifdef PACKET_DEBUG
1648         fprintf(stderr, "read_poll enc/full: ");
1649         sshbuf_dump(state->input, stderr);
1650 #endif
1651         /* EtM: compute mac over encrypted input */
1652         if (mac && mac->enabled && mac->etm) {
1653                 if ((r = mac_compute(mac, state->p_read.seqnr,
1654                     sshbuf_ptr(state->input), aadlen + need,
1655                     macbuf, sizeof(macbuf))) != 0)
1656                         goto out;
1657         }
1658         if ((r = sshbuf_reserve(state->incoming_packet, aadlen + need,
1659             &cp)) != 0)
1660                 goto out;
1661         if ((r = cipher_crypt(&state->receive_context, state->p_read.seqnr, cp,
1662             sshbuf_ptr(state->input), need, aadlen, authlen)) != 0)
1663                 goto out;
1664         if ((r = sshbuf_consume(state->input, aadlen + need + authlen)) != 0)
1665                 goto out;
1666         /*
1667          * compute MAC over seqnr and packet,
1668          * increment sequence number for incoming packet
1669          */
1670         if (mac && mac->enabled) {
1671                 if (!mac->etm)
1672                         if ((r = mac_compute(mac, state->p_read.seqnr,
1673                             sshbuf_ptr(state->incoming_packet),
1674                             sshbuf_len(state->incoming_packet),
1675                             macbuf, sizeof(macbuf))) != 0)
1676                                 goto out;
1677                 if (timingsafe_bcmp(macbuf, sshbuf_ptr(state->input),
1678                     mac->mac_len) != 0) {
1679                         logit("Corrupted MAC on input.");
1680                         if (need > PACKET_MAX_SIZE)
1681                                 return SSH_ERR_INTERNAL_ERROR;
1682                         return ssh_packet_start_discard(ssh, enc, mac,
1683                             state->packlen, PACKET_MAX_SIZE - need);
1684                 }
1685
1686                 DBG(debug("MAC #%d ok", state->p_read.seqnr));
1687                 if ((r = sshbuf_consume(state->input, mac->mac_len)) != 0)
1688                         goto out;
1689         }
1690         if (seqnr_p != NULL)
1691                 *seqnr_p = state->p_read.seqnr;
1692         if (++state->p_read.seqnr == 0)
1693                 logit("incoming seqnr wraps around");
1694         if (++state->p_read.packets == 0)
1695                 if (!(ssh->compat & SSH_BUG_NOREKEY))
1696                         return SSH_ERR_NEED_REKEY;
1697         state->p_read.blocks += (state->packlen + 4) / block_size;
1698         state->p_read.bytes += state->packlen + 4;
1699
1700         /* get padlen */
1701         padlen = sshbuf_ptr(state->incoming_packet)[4];
1702         DBG(debug("input: padlen %d", padlen));
1703         if (padlen < 4) {
1704                 if ((r = sshpkt_disconnect(ssh,
1705                     "Corrupted padlen %d on input.", padlen)) != 0 ||
1706                     (r = ssh_packet_write_wait(ssh)) != 0)
1707                         return r;
1708                 return SSH_ERR_CONN_CORRUPT;
1709         }
1710
1711         /* skip packet size + padlen, discard padding */
1712         if ((r = sshbuf_consume(state->incoming_packet, 4 + 1)) != 0 ||
1713             ((r = sshbuf_consume_end(state->incoming_packet, padlen)) != 0))
1714                 goto out;
1715
1716         DBG(debug("input: len before de-compress %zd",
1717             sshbuf_len(state->incoming_packet)));
1718         if (comp && comp->enabled) {
1719                 sshbuf_reset(state->compression_buffer);
1720                 if ((r = uncompress_buffer(ssh, state->incoming_packet,
1721                     state->compression_buffer)) != 0)
1722                         goto out;
1723                 sshbuf_reset(state->incoming_packet);
1724                 if ((r = sshbuf_putb(state->incoming_packet,
1725                     state->compression_buffer)) != 0)
1726                         goto out;
1727                 DBG(debug("input: len after de-compress %zd",
1728                     sshbuf_len(state->incoming_packet)));
1729         }
1730         /*
1731          * get packet type, implies consume.
1732          * return length of payload (without type field)
1733          */
1734         if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
1735                 goto out;
1736         if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) {
1737                 if ((r = sshpkt_disconnect(ssh,
1738                     "Invalid ssh2 packet type: %d", *typep)) != 0 ||
1739                     (r = ssh_packet_write_wait(ssh)) != 0)
1740                         return r;
1741                 return SSH_ERR_PROTOCOL_ERROR;
1742         }
1743         if (*typep == SSH2_MSG_NEWKEYS)
1744                 r = ssh_set_newkeys(ssh, MODE_IN);
1745         else if (*typep == SSH2_MSG_USERAUTH_SUCCESS && !state->server_side)
1746                 r = ssh_packet_enable_delayed_compress(ssh);
1747         else
1748                 r = 0;
1749 #ifdef PACKET_DEBUG
1750         fprintf(stderr, "read/plain[%d]:\r\n", *typep);
1751         sshbuf_dump(state->incoming_packet, stderr);
1752 #endif
1753         /* reset for next packet */
1754         state->packlen = 0;
1755  out:
1756         return r;
1757 }
1758
1759 int
1760 ssh_packet_read_poll_seqnr(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1761 {
1762         struct session_state *state = ssh->state;
1763         u_int reason, seqnr;
1764         int r;
1765         u_char *msg;
1766
1767         for (;;) {
1768                 msg = NULL;
1769                 if (compat20) {
1770                         r = ssh_packet_read_poll2(ssh, typep, seqnr_p);
1771                         if (r != 0)
1772                                 return r;
1773                         if (*typep) {
1774                                 state->keep_alive_timeouts = 0;
1775                                 DBG(debug("received packet type %d", *typep));
1776                         }
1777                         switch (*typep) {
1778                         case SSH2_MSG_IGNORE:
1779                                 debug3("Received SSH2_MSG_IGNORE");
1780                                 break;
1781                         case SSH2_MSG_DEBUG:
1782                                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||
1783                                     (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
1784                                     (r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
1785                                         if (msg)
1786                                                 free(msg);
1787                                         return r;
1788                                 }
1789                                 debug("Remote: %.900s", msg);
1790                                 free(msg);
1791                                 break;
1792                         case SSH2_MSG_DISCONNECT:
1793                                 if ((r = sshpkt_get_u32(ssh, &reason)) != 0 ||
1794                                     (r = sshpkt_get_string(ssh, &msg, NULL)) != 0)
1795                                         return r;
1796                                 /* Ignore normal client exit notifications */
1797                                 do_log2(ssh->state->server_side &&
1798                                     reason == SSH2_DISCONNECT_BY_APPLICATION ?
1799                                     SYSLOG_LEVEL_INFO : SYSLOG_LEVEL_ERROR,
1800                                     "Received disconnect from %s: %u: %.400s",
1801                                     ssh_remote_ipaddr(ssh), reason, msg);
1802                                 free(msg);
1803                                 return SSH_ERR_DISCONNECTED;
1804                         case SSH2_MSG_UNIMPLEMENTED:
1805                                 if ((r = sshpkt_get_u32(ssh, &seqnr)) != 0)
1806                                         return r;
1807                                 debug("Received SSH2_MSG_UNIMPLEMENTED for %u",
1808                                     seqnr);
1809                                 break;
1810                         default:
1811                                 return 0;
1812                         }
1813                 } else {
1814                         r = ssh_packet_read_poll1(ssh, typep);
1815                         switch (*typep) {
1816                         case SSH_MSG_NONE:
1817                                 return SSH_MSG_NONE;
1818                         case SSH_MSG_IGNORE:
1819                                 break;
1820                         case SSH_MSG_DEBUG:
1821                                 if ((r = sshpkt_get_string(ssh, &msg, NULL)) != 0)
1822                                         return r;
1823                                 debug("Remote: %.900s", msg);
1824                                 free(msg);
1825                                 break;
1826                         case SSH_MSG_DISCONNECT:
1827                                 if ((r = sshpkt_get_string(ssh, &msg, NULL)) != 0)
1828                                         return r;
1829                                 error("Received disconnect from %s: %.400s",
1830                                     ssh_remote_ipaddr(ssh), msg);
1831                                 free(msg);
1832                                 return SSH_ERR_DISCONNECTED;
1833                         default:
1834                                 DBG(debug("received packet type %d", *typep));
1835                                 return 0;
1836                         }
1837                 }
1838         }
1839 }
1840
1841 /*
1842  * Buffers the given amount of input characters.  This is intended to be used
1843  * together with packet_read_poll.
1844  */
1845
1846 int
1847 ssh_packet_process_incoming(struct ssh *ssh, const char *buf, u_int len)
1848 {
1849         struct session_state *state = ssh->state;
1850         int r;
1851
1852         if (state->packet_discard) {
1853                 state->keep_alive_timeouts = 0; /* ?? */
1854                 if (len >= state->packet_discard) {
1855                         if ((r = ssh_packet_stop_discard(ssh)) != 0)
1856                                 return r;
1857                 }
1858                 state->packet_discard -= len;
1859                 return 0;
1860         }
1861         if ((r = sshbuf_put(ssh->state->input, buf, len)) != 0)
1862                 return r;
1863
1864         return 0;
1865 }
1866
1867 int
1868 ssh_packet_remaining(struct ssh *ssh)
1869 {
1870         return sshbuf_len(ssh->state->incoming_packet);
1871 }
1872
1873 /*
1874  * Sends a diagnostic message from the server to the client.  This message
1875  * can be sent at any time (but not while constructing another message). The
1876  * message is printed immediately, but only if the client is being executed
1877  * in verbose mode.  These messages are primarily intended to ease debugging
1878  * authentication problems.   The length of the formatted message must not
1879  * exceed 1024 bytes.  This will automatically call ssh_packet_write_wait.
1880  */
1881 void
1882 ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
1883 {
1884         char buf[1024];
1885         va_list args;
1886         int r;
1887
1888         if (compat20 && (ssh->compat & SSH_BUG_DEBUG))
1889                 return;
1890
1891         va_start(args, fmt);
1892         vsnprintf(buf, sizeof(buf), fmt, args);
1893         va_end(args);
1894
1895         if (compat20) {
1896                 if ((r = sshpkt_start(ssh, SSH2_MSG_DEBUG)) != 0 ||
1897                     (r = sshpkt_put_u8(ssh, 0)) != 0 || /* always display */
1898                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
1899                     (r = sshpkt_put_cstring(ssh, "")) != 0 ||
1900                     (r = sshpkt_send(ssh)) != 0)
1901                         fatal("%s: %s", __func__, ssh_err(r));
1902         } else {
1903                 if ((r = sshpkt_start(ssh, SSH_MSG_DEBUG)) != 0 ||
1904                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
1905                     (r = sshpkt_send(ssh)) != 0)
1906                         fatal("%s: %s", __func__, ssh_err(r));
1907         }
1908         if ((r = ssh_packet_write_wait(ssh)) != 0)
1909                 fatal("%s: %s", __func__, ssh_err(r));
1910 }
1911
1912 /*
1913  * Pretty-print connection-terminating errors and exit.
1914  */
1915 void
1916 sshpkt_fatal(struct ssh *ssh, const char *tag, int r)
1917 {
1918         switch (r) {
1919         case SSH_ERR_CONN_CLOSED:
1920                 logit("Connection closed by %.200s", ssh_remote_ipaddr(ssh));
1921                 cleanup_exit(255);
1922         case SSH_ERR_CONN_TIMEOUT:
1923                 logit("Connection to %.200s timed out", ssh_remote_ipaddr(ssh));
1924                 cleanup_exit(255);
1925         case SSH_ERR_DISCONNECTED:
1926                 logit("Disconnected from %.200s",
1927                     ssh_remote_ipaddr(ssh));
1928                 cleanup_exit(255);
1929         case SSH_ERR_SYSTEM_ERROR:
1930                 if (errno == ECONNRESET) {
1931                         logit("Connection reset by %.200s",
1932                             ssh_remote_ipaddr(ssh));
1933                         cleanup_exit(255);
1934                 }
1935                 /* FALLTHROUGH */
1936         default:
1937                 fatal("%s%sConnection to %.200s: %s",
1938                     tag != NULL ? tag : "", tag != NULL ? ": " : "",
1939                     ssh_remote_ipaddr(ssh), ssh_err(r));
1940         }
1941 }
1942
/*
 * Log a printf-style reason, send it to the peer in a disconnect packet,
 * wait for it to be written, close the connection and exit with status 255.
 * Never returns.  The reason should not contain a newline; it is truncated
 * to a 1024-byte buffer, and only its first 100 characters are logged
 * locally.  A recursive invocation is fatal.
 */
void
ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...)
{
	static int disconnecting = 0;
	char reason[1024];
	va_list ap;
	int r;

	/* Guard against recursive invocations. */
	if (disconnecting)
		fatal("packet_disconnect called recursively.");
	disconnecting = 1;

	va_start(ap, fmt);
	vsnprintf(reason, sizeof(reason), fmt, ap);
	va_end(ap);

	logit("Disconnecting: %.100s", reason);

	/* Queue the disconnect message and flush it to the peer. */
	r = sshpkt_disconnect(ssh, "%s", reason);
	if (r != 0)
		sshpkt_fatal(ssh, __func__, r);
	r = ssh_packet_write_wait(ssh);
	if (r != 0)
		sshpkt_fatal(ssh, __func__, r);

	ssh_packet_close(ssh);
	cleanup_exit(255);
}
1986
1987 /*
1988  * Checks if there is any buffered output, and tries to write some of
1989  * the output.
1990  */
1991 int
1992 ssh_packet_write_poll(struct ssh *ssh)
1993 {
1994         struct session_state *state = ssh->state;
1995         int len = sshbuf_len(state->output);
1996         int cont, r;
1997
1998         if (len > 0) {
1999                 cont = 0;
2000                 len = roaming_write(state->connection_out,
2001                     sshbuf_ptr(state->output), len, &cont);
2002                 if (len == -1) {
2003                         if (errno == EINTR || errno == EAGAIN ||
2004                             errno == EWOULDBLOCK)
2005                                 return 0;
2006                         return SSH_ERR_SYSTEM_ERROR;
2007                 }
2008                 if (len == 0 && !cont)
2009                         return SSH_ERR_CONN_CLOSED;
2010                 if ((r = sshbuf_consume(state->output, len)) != 0)
2011                         return r;
2012         }
2013         return 0;
2014 }
2015
2016 /*
2017  * Calls packet_write_poll repeatedly until all pending output data has been
2018  * written.
2019  */
2020 int
2021 ssh_packet_write_wait(struct ssh *ssh)
2022 {
2023         fd_set *setp;
2024         int ret, r, ms_remain = 0;
2025         struct timeval start, timeout, *timeoutp = NULL;
2026         struct session_state *state = ssh->state;
2027
2028         setp = (fd_set *)calloc(howmany(state->connection_out + 1,
2029             NFDBITS), sizeof(fd_mask));
2030         if (setp == NULL)
2031                 return SSH_ERR_ALLOC_FAIL;
2032         ssh_packet_write_poll(ssh);
2033         while (ssh_packet_have_data_to_write(ssh)) {
2034                 memset(setp, 0, howmany(state->connection_out + 1,
2035                     NFDBITS) * sizeof(fd_mask));
2036                 FD_SET(state->connection_out, setp);
2037
2038                 if (state->packet_timeout_ms > 0) {
2039                         ms_remain = state->packet_timeout_ms;
2040                         timeoutp = &timeout;
2041                 }
2042                 for (;;) {
2043                         if (state->packet_timeout_ms != -1) {
2044                                 ms_to_timeval(&timeout, ms_remain);
2045                                 gettimeofday(&start, NULL);
2046                         }
2047                         if ((ret = select(state->connection_out + 1,
2048                             NULL, setp, NULL, timeoutp)) >= 0)
2049                                 break;
2050                         if (errno != EAGAIN && errno != EINTR &&
2051                             errno != EWOULDBLOCK)
2052                                 break;
2053                         if (state->packet_timeout_ms == -1)
2054                                 continue;
2055                         ms_subtract_diff(&start, &ms_remain);
2056                         if (ms_remain <= 0) {
2057                                 ret = 0;
2058                                 break;
2059                         }
2060                 }
2061                 if (ret == 0) {
2062                         free(setp);
2063                         return SSH_ERR_CONN_TIMEOUT;
2064                 }
2065                 if ((r = ssh_packet_write_poll(ssh)) != 0) {
2066                         free(setp);
2067                         return r;
2068                 }
2069         }
2070         free(setp);
2071         return 0;
2072 }
2073
2074 /* Returns true if there is buffered data to write to the connection. */
2075
2076 int
2077 ssh_packet_have_data_to_write(struct ssh *ssh)
2078 {
2079         return sshbuf_len(ssh->state->output) != 0;
2080 }
2081
2082 /* Returns true if there is not too much data to write to the connection. */
2083
2084 int
2085 ssh_packet_not_very_much_data_to_write(struct ssh *ssh)
2086 {
2087         if (ssh->state->interactive_mode)
2088                 return sshbuf_len(ssh->state->output) < 16384;
2089         else
2090                 return sshbuf_len(ssh->state->output) < 128 * 1024;
2091 }
2092
2093 void
2094 ssh_packet_set_tos(struct ssh *ssh, int tos)
2095 {
2096 #ifndef IP_TOS_IS_BROKEN
2097         if (!ssh_packet_connection_is_on_socket(ssh))
2098                 return;
2099         switch (ssh_packet_connection_af(ssh)) {
2100 # ifdef IP_TOS
2101         case AF_INET:
2102                 debug3("%s: set IP_TOS 0x%02x", __func__, tos);
2103                 if (setsockopt(ssh->state->connection_in,
2104                     IPPROTO_IP, IP_TOS, &tos, sizeof(tos)) < 0)
2105                         error("setsockopt IP_TOS %d: %.100s:",
2106                             tos, strerror(errno));
2107                 break;
2108 # endif /* IP_TOS */
2109 # ifdef IPV6_TCLASS
2110         case AF_INET6:
2111                 debug3("%s: set IPV6_TCLASS 0x%02x", __func__, tos);
2112                 if (setsockopt(ssh->state->connection_in,
2113                     IPPROTO_IPV6, IPV6_TCLASS, &tos, sizeof(tos)) < 0)
2114                         error("setsockopt IPV6_TCLASS %d: %.100s:",
2115                             tos, strerror(errno));
2116                 break;
2117 # endif /* IPV6_TCLASS */
2118         }
2119 #endif /* IP_TOS_IS_BROKEN */
2120 }
2121
2122 /* Informs that the current session is interactive.  Sets IP flags for that. */
2123
2124 void
2125 ssh_packet_set_interactive(struct ssh *ssh, int interactive, int qos_interactive, int qos_bulk)
2126 {
2127         struct session_state *state = ssh->state;
2128
2129         if (state->set_interactive_called)
2130                 return;
2131         state->set_interactive_called = 1;
2132
2133         /* Record that we are in interactive mode. */
2134         state->interactive_mode = interactive;
2135
2136         /* Only set socket options if using a socket.  */
2137         if (!ssh_packet_connection_is_on_socket(ssh))
2138                 return;
2139         set_nodelay(state->connection_in);
2140         ssh_packet_set_tos(ssh, interactive ? qos_interactive :
2141             qos_bulk);
2142 }
2143
2144 /* Returns true if the current connection is interactive. */
2145
2146 int
2147 ssh_packet_is_interactive(struct ssh *ssh)
2148 {
2149         return ssh->state->interactive_mode;
2150 }
2151
2152 int
2153 ssh_packet_set_maxsize(struct ssh *ssh, u_int s)
2154 {
2155         struct session_state *state = ssh->state;
2156
2157         if (state->set_maxsize_called) {
2158                 logit("packet_set_maxsize: called twice: old %d new %d",
2159                     state->max_packet_size, s);
2160                 return -1;
2161         }
2162         if (s < 4 * 1024 || s > 1024 * 1024) {
2163                 logit("packet_set_maxsize: bad size %d", s);
2164                 return -1;
2165         }
2166         state->set_maxsize_called = 1;
2167         debug("packet_set_maxsize: setting to %d", s);
2168         state->max_packet_size = s;
2169         return s;
2170 }
2171
2172 int
2173 ssh_packet_inc_alive_timeouts(struct ssh *ssh)
2174 {
2175         return ++ssh->state->keep_alive_timeouts;
2176 }
2177
2178 void
2179 ssh_packet_set_alive_timeouts(struct ssh *ssh, int ka)
2180 {
2181         ssh->state->keep_alive_timeouts = ka;
2182 }
2183
2184 u_int
2185 ssh_packet_get_maxsize(struct ssh *ssh)
2186 {
2187         return ssh->state->max_packet_size;
2188 }
2189
2190 /*
2191  * 9.2.  Ignored Data Message
2192  *
2193  *   byte      SSH_MSG_IGNORE
2194  *   string    data
2195  *
2196  * All implementations MUST understand (and ignore) this message at any
2197  * time (after receiving the protocol version). No implementation is
2198  * required to send them. This message can be used as an additional
2199  * protection measure against advanced traffic analysis techniques.
2200  */
2201 void
2202 ssh_packet_send_ignore(struct ssh *ssh, int nbytes)
2203 {
2204         u_int32_t rnd = 0;
2205         int r, i;
2206
2207         if ((r = sshpkt_start(ssh, compat20 ?
2208             SSH2_MSG_IGNORE : SSH_MSG_IGNORE)) != 0 ||
2209             (r = sshpkt_put_u32(ssh, nbytes)) != 0)
2210                 fatal("%s: %s", __func__, ssh_err(r));
2211         for (i = 0; i < nbytes; i++) {
2212                 if (i % 4 == 0)
2213                         rnd = arc4random();
2214                 if ((r = sshpkt_put_u8(ssh, (u_char)rnd & 0xff)) != 0)
2215                         fatal("%s: %s", __func__, ssh_err(r));
2216                 rnd >>= 8;
2217         }
2218 }
2219
2220 #define MAX_PACKETS     (1U<<31)
2221 int
2222 ssh_packet_need_rekeying(struct ssh *ssh)
2223 {
2224         struct session_state *state = ssh->state;
2225
2226         if (ssh->compat & SSH_BUG_NOREKEY)
2227                 return 0;
2228         return
2229             (state->p_send.packets > MAX_PACKETS) ||
2230             (state->p_read.packets > MAX_PACKETS) ||
2231             (state->max_blocks_out &&
2232                 (state->p_send.blocks > state->max_blocks_out)) ||
2233             (state->max_blocks_in &&
2234                 (state->p_read.blocks > state->max_blocks_in)) ||
2235             (state->rekey_interval != 0 && state->rekey_time +
2236                  state->rekey_interval <= monotime());
2237 }
2238
2239 void
2240 ssh_packet_set_rekey_limits(struct ssh *ssh, u_int32_t bytes, time_t seconds)
2241 {
2242         debug3("rekey after %lld bytes, %d seconds", (long long)bytes,
2243             (int)seconds);
2244         ssh->state->rekey_limit = bytes;
2245         ssh->state->rekey_interval = seconds;
2246 }
2247
2248 time_t
2249 ssh_packet_get_rekey_timeout(struct ssh *ssh)
2250 {
2251         time_t seconds;
2252
2253         seconds = ssh->state->rekey_time + ssh->state->rekey_interval -
2254             monotime();
2255         return (seconds <= 0 ? 1 : seconds);
2256 }
2257
2258 void
2259 ssh_packet_set_server(struct ssh *ssh)
2260 {
2261         ssh->state->server_side = 1;
2262 }
2263
2264 void
2265 ssh_packet_set_authenticated(struct ssh *ssh)
2266 {
2267         ssh->state->after_authentication = 1;
2268 }
2269
2270 void *
2271 ssh_packet_get_input(struct ssh *ssh)
2272 {
2273         return (void *)ssh->state->input;
2274 }
2275
2276 void *
2277 ssh_packet_get_output(struct ssh *ssh)
2278 {
2279         return (void *)ssh->state->output;
2280 }
2281
2282 /* XXX TODO update roaming to new API (does not work anyway) */
2283 /*
2284  * Save the state for the real connection, and use a separate state when
2285  * resuming a suspended connection.
2286  */
2287 void
2288 ssh_packet_backup_state(struct ssh *ssh,
2289     struct ssh *backup_state)
2290 {
2291         struct ssh *tmp;
2292
2293         close(ssh->state->connection_in);
2294         ssh->state->connection_in = -1;
2295         close(ssh->state->connection_out);
2296         ssh->state->connection_out = -1;
2297         if (backup_state)
2298                 tmp = backup_state;
2299         else
2300                 tmp = ssh_alloc_session_state();
2301         backup_state = ssh;
2302         ssh = tmp;
2303 }
2304
2305 /* XXX FIXME FIXME FIXME */
2306 /*
2307  * Swap in the old state when resuming a connecion.
2308  */
2309 void
2310 ssh_packet_restore_state(struct ssh *ssh,
2311     struct ssh *backup_state)
2312 {
2313         struct ssh *tmp;
2314         u_int len;
2315         int r;
2316
2317         tmp = backup_state;
2318         backup_state = ssh;
2319         ssh = tmp;
2320         ssh->state->connection_in = backup_state->state->connection_in;
2321         backup_state->state->connection_in = -1;
2322         ssh->state->connection_out = backup_state->state->connection_out;
2323         backup_state->state->connection_out = -1;
2324         len = sshbuf_len(backup_state->state->input);
2325         if (len > 0) {
2326                 if ((r = sshbuf_putb(ssh->state->input,
2327                     backup_state->state->input)) != 0)
2328                         fatal("%s: %s", __func__, ssh_err(r));
2329                 sshbuf_reset(backup_state->state->input);
2330                 add_recv_bytes(len);
2331         }
2332 }
2333
2334 /* Reset after_authentication and reset compression in post-auth privsep */
2335 static int
2336 ssh_packet_set_postauth(struct ssh *ssh)
2337 {
2338         struct sshcomp *comp;
2339         int r, mode;
2340
2341         debug("%s: called", __func__);
2342         /* This was set in net child, but is not visible in user child */
2343         ssh->state->after_authentication = 1;
2344         ssh->state->rekeying = 0;
2345         for (mode = 0; mode < MODE_MAX; mode++) {
2346                 if (ssh->state->newkeys[mode] == NULL)
2347                         continue;
2348                 comp = &ssh->state->newkeys[mode]->comp;
2349                 if (comp && comp->enabled &&
2350                     (r = ssh_packet_init_compression(ssh)) != 0)
2351                         return r;
2352         }
2353         return 0;
2354 }
2355
2356 /* Packet state (de-)serialization for privsep */
2357
2358 /* turn kex into a blob for packet state serialization */
2359 static int
2360 kex_to_blob(struct sshbuf *m, struct kex *kex)
2361 {
2362         int r;
2363
2364         if ((r = sshbuf_put_string(m, kex->session_id,
2365             kex->session_id_len)) != 0 ||
2366             (r = sshbuf_put_u32(m, kex->we_need)) != 0 ||
2367             (r = sshbuf_put_u32(m, kex->hostkey_type)) != 0 ||
2368             (r = sshbuf_put_u32(m, kex->kex_type)) != 0 ||
2369             (r = sshbuf_put_stringb(m, kex->my)) != 0 ||
2370             (r = sshbuf_put_stringb(m, kex->peer)) != 0 ||
2371             (r = sshbuf_put_u32(m, kex->flags)) != 0 ||
2372             (r = sshbuf_put_cstring(m, kex->client_version_string)) != 0 ||
2373             (r = sshbuf_put_cstring(m, kex->server_version_string)) != 0)
2374                 return r;
2375         return 0;
2376 }
2377
2378 /* turn key exchange results into a blob for packet state serialization */
2379 static int
2380 newkeys_to_blob(struct sshbuf *m, struct ssh *ssh, int mode)
2381 {
2382         struct sshbuf *b;
2383         struct sshcipher_ctx *cc;
2384         struct sshcomp *comp;
2385         struct sshenc *enc;
2386         struct sshmac *mac;
2387         struct newkeys *newkey;
2388         int r;
2389
2390         if ((newkey = ssh->state->newkeys[mode]) == NULL)
2391                 return SSH_ERR_INTERNAL_ERROR;
2392         enc = &newkey->enc;
2393         mac = &newkey->mac;
2394         comp = &newkey->comp;
2395         cc = (mode == MODE_OUT) ? &ssh->state->send_context :
2396             &ssh->state->receive_context;
2397         if ((r = cipher_get_keyiv(cc, enc->iv, enc->iv_len)) != 0)
2398                 return r;
2399         if ((b = sshbuf_new()) == NULL)
2400                 return SSH_ERR_ALLOC_FAIL;
2401         /* The cipher struct is constant and shared, you export pointer */
2402         if ((r = sshbuf_put_cstring(b, enc->name)) != 0 ||
2403             (r = sshbuf_put(b, &enc->cipher, sizeof(enc->cipher))) != 0 ||
2404             (r = sshbuf_put_u32(b, enc->enabled)) != 0 ||
2405             (r = sshbuf_put_u32(b, enc->block_size)) != 0 ||
2406             (r = sshbuf_put_string(b, enc->key, enc->key_len)) != 0 ||
2407             (r = sshbuf_put_string(b, enc->iv, enc->iv_len)) != 0)
2408                 goto out;
2409         if (cipher_authlen(enc->cipher) == 0) {
2410                 if ((r = sshbuf_put_cstring(b, mac->name)) != 0 ||
2411                     (r = sshbuf_put_u32(b, mac->enabled)) != 0 ||
2412                     (r = sshbuf_put_string(b, mac->key, mac->key_len)) != 0)
2413                         goto out;
2414         }
2415         if ((r = sshbuf_put_u32(b, comp->type)) != 0 ||
2416             (r = sshbuf_put_u32(b, comp->enabled)) != 0 ||
2417             (r = sshbuf_put_cstring(b, comp->name)) != 0)
2418                 goto out;
2419         r = sshbuf_put_stringb(m, b);
2420  out:
2421         if (b != NULL)
2422                 sshbuf_free(b);
2423         return r;
2424 }
2425
2426 /* serialize packet state into a blob */
2427 int
2428 ssh_packet_get_state(struct ssh *ssh, struct sshbuf *m)
2429 {
2430         struct session_state *state = ssh->state;
2431         u_char *p;
2432         size_t slen, rlen;
2433         int r, ssh1cipher;
2434
2435         if (!compat20) {
2436                 ssh1cipher = cipher_get_number(state->receive_context.cipher);
2437                 slen = cipher_get_keyiv_len(&state->send_context);
2438                 rlen = cipher_get_keyiv_len(&state->receive_context);
2439                 if ((r = sshbuf_put_u32(m, state->remote_protocol_flags)) != 0 ||
2440                     (r = sshbuf_put_u32(m, ssh1cipher)) != 0 ||
2441                     (r = sshbuf_put_string(m, state->ssh1_key, state->ssh1_keylen)) != 0 ||
2442                     (r = sshbuf_put_u32(m, slen)) != 0 ||
2443                     (r = sshbuf_reserve(m, slen, &p)) != 0 ||
2444                     (r = cipher_get_keyiv(&state->send_context, p, slen)) != 0 ||
2445                     (r = sshbuf_put_u32(m, rlen)) != 0 ||
2446                     (r = sshbuf_reserve(m, rlen, &p)) != 0 ||
2447                     (r = cipher_get_keyiv(&state->receive_context, p, rlen)) != 0)
2448                         return r;
2449         } else {
2450                 if ((r = kex_to_blob(m, ssh->kex)) != 0 ||
2451                     (r = newkeys_to_blob(m, ssh, MODE_OUT)) != 0 ||
2452                     (r = newkeys_to_blob(m, ssh, MODE_IN)) != 0 ||
2453                     (r = sshbuf_put_u32(m, state->rekey_limit)) != 0 ||
2454                     (r = sshbuf_put_u32(m, state->rekey_interval)) != 0 ||
2455                     (r = sshbuf_put_u32(m, state->p_send.seqnr)) != 0 ||
2456                     (r = sshbuf_put_u64(m, state->p_send.blocks)) != 0 ||
2457                     (r = sshbuf_put_u32(m, state->p_send.packets)) != 0 ||
2458                     (r = sshbuf_put_u64(m, state->p_send.bytes)) != 0 ||
2459                     (r = sshbuf_put_u32(m, state->p_read.seqnr)) != 0 ||
2460                     (r = sshbuf_put_u64(m, state->p_read.blocks)) != 0 ||
2461                     (r = sshbuf_put_u32(m, state->p_read.packets)) != 0 ||
2462                     (r = sshbuf_put_u64(m, state->p_read.bytes)) != 0)
2463                         return r;
2464         }
2465
2466         slen = cipher_get_keycontext(&state->send_context, NULL);
2467         rlen = cipher_get_keycontext(&state->receive_context, NULL);
2468         if ((r = sshbuf_put_u32(m, slen)) != 0 ||
2469             (r = sshbuf_reserve(m, slen, &p)) != 0)
2470                 return r;
2471         if (cipher_get_keycontext(&state->send_context, p) != (int)slen)
2472                 return SSH_ERR_INTERNAL_ERROR;
2473         if ((r = sshbuf_put_u32(m, rlen)) != 0 ||
2474             (r = sshbuf_reserve(m, rlen, &p)) != 0)
2475                 return r;
2476         if (cipher_get_keycontext(&state->receive_context, p) != (int)rlen)
2477                 return SSH_ERR_INTERNAL_ERROR;
2478
2479         if ((r = ssh_packet_get_compress_state(m, ssh)) != 0 ||
2480             (r = sshbuf_put_stringb(m, state->input)) != 0 ||
2481             (r = sshbuf_put_stringb(m, state->output)) != 0)
2482                 return r;
2483
2484         if (compat20) {
2485                 if ((r = sshbuf_put_u64(m, get_sent_bytes())) != 0 ||
2486                     (r = sshbuf_put_u64(m, get_recv_bytes())) != 0)
2487                         return r;
2488         }
2489         return 0;
2490 }
2491
2492 /* restore key exchange results from blob for packet state de-serialization */
2493 static int
2494 newkeys_from_blob(struct sshbuf *m, struct ssh *ssh, int mode)
2495 {
2496         struct sshbuf *b = NULL;
2497         struct sshcomp *comp;
2498         struct sshenc *enc;
2499         struct sshmac *mac;
2500         struct newkeys *newkey = NULL;
2501         size_t keylen, ivlen, maclen;
2502         int r;
2503
2504         if ((newkey = calloc(1, sizeof(*newkey))) == NULL) {
2505                 r = SSH_ERR_ALLOC_FAIL;
2506                 goto out;
2507         }
2508         if ((r = sshbuf_froms(m, &b)) != 0)
2509                 goto out;
2510 #ifdef DEBUG_PK
2511         sshbuf_dump(b, stderr);
2512 #endif
2513         enc = &newkey->enc;
2514         mac = &newkey->mac;
2515         comp = &newkey->comp;
2516
2517         if ((r = sshbuf_get_cstring(b, &enc->name, NULL)) != 0 ||
2518             (r = sshbuf_get(b, &enc->cipher, sizeof(enc->cipher))) != 0 ||
2519             (r = sshbuf_get_u32(b, (u_int *)&enc->enabled)) != 0 ||
2520             (r = sshbuf_get_u32(b, &enc->block_size)) != 0 ||
2521             (r = sshbuf_get_string(b, &enc->key, &keylen)) != 0 ||
2522             (r = sshbuf_get_string(b, &enc->iv, &ivlen)) != 0)
2523                 goto out;
2524         if (cipher_authlen(enc->cipher) == 0) {
2525                 if ((r = sshbuf_get_cstring(b, &mac->name, NULL)) != 0)
2526                         goto out;
2527                 if ((r = mac_setup(mac, mac->name)) != 0)
2528                         goto out;
2529                 if ((r = sshbuf_get_u32(b, (u_int *)&mac->enabled)) != 0 ||
2530                     (r = sshbuf_get_string(b, &mac->key, &maclen)) != 0)
2531                         goto out;
2532                 if (maclen > mac->key_len) {
2533                         r = SSH_ERR_INVALID_FORMAT;
2534                         goto out;
2535                 }
2536                 mac->key_len = maclen;
2537         }
2538         if ((r = sshbuf_get_u32(b, &comp->type)) != 0 ||
2539             (r = sshbuf_get_u32(b, (u_int *)&comp->enabled)) != 0 ||
2540             (r = sshbuf_get_cstring(b, &comp->name, NULL)) != 0)
2541                 goto out;
2542         if (enc->name == NULL ||
2543             cipher_by_name(enc->name) != enc->cipher) {
2544                 r = SSH_ERR_INVALID_FORMAT;
2545                 goto out;
2546         }
2547         if (sshbuf_len(b) != 0) {
2548                 r = SSH_ERR_INVALID_FORMAT;
2549                 goto out;
2550         }
2551         enc->key_len = keylen;
2552         enc->iv_len = ivlen;
2553         ssh->kex->newkeys[mode] = newkey;
2554         newkey = NULL;
2555         r = 0;
2556  out:
2557         if (newkey != NULL)
2558                 free(newkey);
2559         if (b != NULL)
2560                 sshbuf_free(b);
2561         return r;
2562 }
2563
2564 /* restore kex from blob for packet state de-serialization */
2565 static int
2566 kex_from_blob(struct sshbuf *m, struct kex **kexp)
2567 {
2568         struct kex *kex;
2569         int r;
2570
2571         if ((kex = calloc(1, sizeof(struct kex))) == NULL ||
2572             (kex->my = sshbuf_new()) == NULL ||
2573             (kex->peer = sshbuf_new()) == NULL) {
2574                 r = SSH_ERR_ALLOC_FAIL;
2575                 goto out;
2576         }
2577         if ((r = sshbuf_get_string(m, &kex->session_id, &kex->session_id_len)) != 0 ||
2578             (r = sshbuf_get_u32(m, &kex->we_need)) != 0 ||
2579             (r = sshbuf_get_u32(m, (u_int *)&kex->hostkey_type)) != 0 ||
2580             (r = sshbuf_get_u32(m, &kex->kex_type)) != 0 ||
2581             (r = sshbuf_get_stringb(m, kex->my)) != 0 ||
2582             (r = sshbuf_get_stringb(m, kex->peer)) != 0 ||
2583             (r = sshbuf_get_u32(m, &kex->flags)) != 0 ||
2584             (r = sshbuf_get_cstring(m, &kex->client_version_string, NULL)) != 0 ||
2585             (r = sshbuf_get_cstring(m, &kex->server_version_string, NULL)) != 0)
2586                 goto out;
2587         kex->server = 1;
2588         kex->done = 1;
2589         r = 0;
2590  out:
2591         if (r != 0 || kexp == NULL) {
2592                 if (kex != NULL) {
2593                         if (kex->my != NULL)
2594                                 sshbuf_free(kex->my);
2595                         if (kex->peer != NULL)
2596                                 sshbuf_free(kex->peer);
2597                         free(kex);
2598                 }
2599                 if (kexp != NULL)
2600                         *kexp = NULL;
2601         } else {
2602                 *kexp = kex;
2603         }
2604         return r;
2605 }
2606
2607 /*
2608  * Restore packet state from content of blob 'm' (de-serialization).
2609  * Note that 'm' will be partially consumed on parsing or any other errors.
2610  */
2611 int
2612 ssh_packet_set_state(struct ssh *ssh, struct sshbuf *m)
2613 {
2614         struct session_state *state = ssh->state;
2615         const u_char *ssh1key, *ivin, *ivout, *keyin, *keyout, *input, *output;
2616         size_t ssh1keylen, rlen, slen, ilen, olen;
2617         int r;
2618         u_int ssh1cipher = 0;
2619         u_int64_t sent_bytes = 0, recv_bytes = 0;
2620
2621         if (!compat20) {
2622                 if ((r = sshbuf_get_u32(m, &state->remote_protocol_flags)) != 0 ||
2623                     (r = sshbuf_get_u32(m, &ssh1cipher)) != 0 ||
2624                     (r = sshbuf_get_string_direct(m, &ssh1key, &ssh1keylen)) != 0 ||
2625                     (r = sshbuf_get_string_direct(m, &ivout, &slen)) != 0 ||
2626                     (r = sshbuf_get_string_direct(m, &ivin, &rlen)) != 0)
2627                         return r;
2628                 if (ssh1cipher > INT_MAX)
2629                         return SSH_ERR_KEY_UNKNOWN_CIPHER;
2630                 ssh_packet_set_encryption_key(ssh, ssh1key, ssh1keylen,
2631                     (int)ssh1cipher);
2632                 if (cipher_get_keyiv_len(&state->send_context) != (int)slen ||
2633                     cipher_get_keyiv_len(&state->receive_context) != (int)rlen)
2634                         return SSH_ERR_INVALID_FORMAT;
2635                 if ((r = cipher_set_keyiv(&state->send_context, ivout)) != 0 ||
2636                     (r = cipher_set_keyiv(&state->receive_context, ivin)) != 0)
2637                         return r;
2638         } else {
2639                 if ((r = kex_from_blob(m, &ssh->kex)) != 0 ||
2640                     (r = newkeys_from_blob(m, ssh, MODE_OUT)) != 0 ||
2641                     (r = newkeys_from_blob(m, ssh, MODE_IN)) != 0 ||
2642                     (r = sshbuf_get_u32(m, &state->rekey_limit)) != 0 ||
2643                     (r = sshbuf_get_u32(m, &state->rekey_interval)) != 0 ||
2644                     (r = sshbuf_get_u32(m, &state->p_send.seqnr)) != 0 ||
2645                     (r = sshbuf_get_u64(m, &state->p_send.blocks)) != 0 ||
2646                     (r = sshbuf_get_u32(m, &state->p_send.packets)) != 0 ||
2647                     (r = sshbuf_get_u64(m, &state->p_send.bytes)) != 0 ||
2648                     (r = sshbuf_get_u32(m, &state->p_read.seqnr)) != 0 ||
2649                     (r = sshbuf_get_u64(m, &state->p_read.blocks)) != 0 ||
2650                     (r = sshbuf_get_u32(m, &state->p_read.packets)) != 0 ||
2651                     (r = sshbuf_get_u64(m, &state->p_read.bytes)) != 0)
2652                         return r;
2653                 /*
2654                  * We set the time here so that in post-auth privsep slave we
2655                  * count from the completion of the authentication.
2656                  */
2657                 state->rekey_time = monotime();
2658                 /* XXX ssh_set_newkeys overrides p_read.packets? XXX */
2659                 if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0 ||
2660                     (r = ssh_set_newkeys(ssh, MODE_OUT)) != 0)
2661                         return r;
2662         }
2663         if ((r = sshbuf_get_string_direct(m, &keyout, &slen)) != 0 ||
2664             (r = sshbuf_get_string_direct(m, &keyin, &rlen)) != 0)
2665                 return r;
2666         if (cipher_get_keycontext(&state->send_context, NULL) != (int)slen ||
2667             cipher_get_keycontext(&state->receive_context, NULL) != (int)rlen)
2668                 return SSH_ERR_INVALID_FORMAT;
2669         cipher_set_keycontext(&state->send_context, keyout);
2670         cipher_set_keycontext(&state->receive_context, keyin);
2671
2672         if ((r = ssh_packet_set_compress_state(ssh, m)) != 0 ||
2673             (r = ssh_packet_set_postauth(ssh)) != 0)
2674                 return r;
2675
2676         sshbuf_reset(state->input);
2677         sshbuf_reset(state->output);
2678         if ((r = sshbuf_get_string_direct(m, &input, &ilen)) != 0 ||
2679             (r = sshbuf_get_string_direct(m, &output, &olen)) != 0 ||
2680             (r = sshbuf_put(state->input, input, ilen)) != 0 ||
2681             (r = sshbuf_put(state->output, output, olen)) != 0)
2682                 return r;
2683
2684         if (compat20) {
2685                 if ((r = sshbuf_get_u64(m, &sent_bytes)) != 0 ||
2686                     (r = sshbuf_get_u64(m, &recv_bytes)) != 0)
2687                         return r;
2688                 roam_set_bytes(sent_bytes, recv_bytes);
2689         }
2690         if (sshbuf_len(m))
2691                 return SSH_ERR_INVALID_FORMAT;
2692         debug3("%s: done", __func__);
2693         return 0;
2694 }
2695
2696 /* NEW API */
2697
2698 /* put data to the outgoing packet */
2699
2700 int
2701 sshpkt_put(struct ssh *ssh, const void *v, size_t len)
2702 {
2703         return sshbuf_put(ssh->state->outgoing_packet, v, len);
2704 }
2705
2706 int
2707 sshpkt_putb(struct ssh *ssh, const struct sshbuf *b)
2708 {
2709         return sshbuf_putb(ssh->state->outgoing_packet, b);
2710 }
2711
2712 int
2713 sshpkt_put_u8(struct ssh *ssh, u_char val)
2714 {
2715         return sshbuf_put_u8(ssh->state->outgoing_packet, val);
2716 }
2717
2718 int
2719 sshpkt_put_u32(struct ssh *ssh, u_int32_t val)
2720 {
2721         return sshbuf_put_u32(ssh->state->outgoing_packet, val);
2722 }
2723
2724 int
2725 sshpkt_put_u64(struct ssh *ssh, u_int64_t val)
2726 {
2727         return sshbuf_put_u64(ssh->state->outgoing_packet, val);
2728 }
2729
2730 int
2731 sshpkt_put_string(struct ssh *ssh, const void *v, size_t len)
2732 {
2733         return sshbuf_put_string(ssh->state->outgoing_packet, v, len);
2734 }
2735
2736 int
2737 sshpkt_put_cstring(struct ssh *ssh, const void *v)
2738 {
2739         return sshbuf_put_cstring(ssh->state->outgoing_packet, v);
2740 }
2741
2742 int
2743 sshpkt_put_stringb(struct ssh *ssh, const struct sshbuf *v)
2744 {
2745         return sshbuf_put_stringb(ssh->state->outgoing_packet, v);
2746 }
2747
2748 #ifdef WITH_OPENSSL
2749 #ifdef OPENSSL_HAS_ECC
2750 int
2751 sshpkt_put_ec(struct ssh *ssh, const EC_POINT *v, const EC_GROUP *g)
2752 {
2753         return sshbuf_put_ec(ssh->state->outgoing_packet, v, g);
2754 }
2755 #endif /* OPENSSL_HAS_ECC */
2756
2757 #ifdef WITH_SSH1
2758 int
2759 sshpkt_put_bignum1(struct ssh *ssh, const BIGNUM *v)
2760 {
2761         return sshbuf_put_bignum1(ssh->state->outgoing_packet, v);
2762 }
2763 #endif /* WITH_SSH1 */
2764
2765 int
2766 sshpkt_put_bignum2(struct ssh *ssh, const BIGNUM *v)
2767 {
2768         return sshbuf_put_bignum2(ssh->state->outgoing_packet, v);
2769 }
2770 #endif /* WITH_OPENSSL */
2771
2772 /* fetch data from the incoming packet */
2773
2774 int
2775 sshpkt_get(struct ssh *ssh, void *valp, size_t len)
2776 {
2777         return sshbuf_get(ssh->state->incoming_packet, valp, len);
2778 }
2779
2780 int
2781 sshpkt_get_u8(struct ssh *ssh, u_char *valp)
2782 {
2783         return sshbuf_get_u8(ssh->state->incoming_packet, valp);
2784 }
2785
2786 int
2787 sshpkt_get_u32(struct ssh *ssh, u_int32_t *valp)
2788 {
2789         return sshbuf_get_u32(ssh->state->incoming_packet, valp);
2790 }
2791
2792 int
2793 sshpkt_get_u64(struct ssh *ssh, u_int64_t *valp)
2794 {
2795         return sshbuf_get_u64(ssh->state->incoming_packet, valp);
2796 }
2797
2798 int
2799 sshpkt_get_string(struct ssh *ssh, u_char **valp, size_t *lenp)
2800 {
2801         return sshbuf_get_string(ssh->state->incoming_packet, valp, lenp);
2802 }
2803
2804 int
2805 sshpkt_get_string_direct(struct ssh *ssh, const u_char **valp, size_t *lenp)
2806 {
2807         return sshbuf_get_string_direct(ssh->state->incoming_packet, valp, lenp);
2808 }
2809
2810 int
2811 sshpkt_get_cstring(struct ssh *ssh, char **valp, size_t *lenp)
2812 {
2813         return sshbuf_get_cstring(ssh->state->incoming_packet, valp, lenp);
2814 }
2815
2816 #ifdef WITH_OPENSSL
2817 #ifdef OPENSSL_HAS_ECC
2818 int
2819 sshpkt_get_ec(struct ssh *ssh, EC_POINT *v, const EC_GROUP *g)
2820 {
2821         return sshbuf_get_ec(ssh->state->incoming_packet, v, g);
2822 }
2823 #endif /* OPENSSL_HAS_ECC */
2824
2825 #ifdef WITH_SSH1
2826 int
2827 sshpkt_get_bignum1(struct ssh *ssh, BIGNUM *v)
2828 {
2829         return sshbuf_get_bignum1(ssh->state->incoming_packet, v);
2830 }
2831 #endif /* WITH_SSH1 */
2832
2833 int
2834 sshpkt_get_bignum2(struct ssh *ssh, BIGNUM *v)
2835 {
2836         return sshbuf_get_bignum2(ssh->state->incoming_packet, v);
2837 }
2838 #endif /* WITH_OPENSSL */
2839
2840 int
2841 sshpkt_get_end(struct ssh *ssh)
2842 {
2843         if (sshbuf_len(ssh->state->incoming_packet) > 0)
2844                 return SSH_ERR_UNEXPECTED_TRAILING_DATA;
2845         return 0;
2846 }
2847
2848 const u_char *
2849 sshpkt_ptr(struct ssh *ssh, size_t *lenp)
2850 {
2851         if (lenp != NULL)
2852                 *lenp = sshbuf_len(ssh->state->incoming_packet);
2853         return sshbuf_ptr(ssh->state->incoming_packet);
2854 }
2855
2856 /* start a new packet */
2857
2858 int
2859 sshpkt_start(struct ssh *ssh, u_char type)
2860 {
2861         u_char buf[9];
2862         int len;
2863
2864         DBG(debug("packet_start[%d]", type));
2865         len = compat20 ? 6 : 9;
2866         memset(buf, 0, len - 1);
2867         buf[len - 1] = type;
2868         sshbuf_reset(ssh->state->outgoing_packet);
2869         return sshbuf_put(ssh->state->outgoing_packet, buf, len);
2870 }
2871
2872 /* send it */
2873
2874 int
2875 sshpkt_send(struct ssh *ssh)
2876 {
2877         if (compat20)
2878                 return ssh_packet_send2(ssh);
2879         else
2880                 return ssh_packet_send1(ssh);
2881 }
2882
2883 int
2884 sshpkt_disconnect(struct ssh *ssh, const char *fmt,...)
2885 {
2886         char buf[1024];
2887         va_list args;
2888         int r;
2889
2890         va_start(args, fmt);
2891         vsnprintf(buf, sizeof(buf), fmt, args);
2892         va_end(args);
2893
2894         if (compat20) {
2895                 if ((r = sshpkt_start(ssh, SSH2_MSG_DISCONNECT)) != 0 ||
2896                     (r = sshpkt_put_u32(ssh, SSH2_DISCONNECT_PROTOCOL_ERROR)) != 0 ||
2897                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
2898                     (r = sshpkt_put_cstring(ssh, "")) != 0 ||
2899                     (r = sshpkt_send(ssh)) != 0)
2900                         return r;
2901         } else {
2902                 if ((r = sshpkt_start(ssh, SSH_MSG_DISCONNECT)) != 0 ||
2903                     (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
2904                     (r = sshpkt_send(ssh)) != 0)
2905                         return r;
2906         }
2907         return 0;
2908 }
2909
2910 /* roundup current message to pad bytes */
2911 int
2912 sshpkt_add_padding(struct ssh *ssh, u_char pad)
2913 {
2914         ssh->state->extra_pad = pad;
2915         return 0;
2916 }