Only read one record at a time in sptps_receive_data().
authorEtienne Dechamps <etienne@edechamps.fr>
Sun, 10 May 2015 18:28:11 +0000 (19:28 +0100)
committerEtienne Dechamps <etienne@edechamps.fr>
Sun, 10 May 2015 20:08:57 +0000 (21:08 +0100)
sptps_receive_data() always consumes the entire buffer passed to it,
which is somewhat inflexible. This commit improves the interface so that
sptps_receive_data() consumes at most one record. The goal is to allow
non-SPTPS stuff to be interleaved with SPTPS records in a single TCP
stream.

src/meta.c
src/sptps.c
src/sptps.h

index 1c29fe9..0849d3c 100644 (file)
@@ -159,8 +159,14 @@ bool receive_meta(connection_t *c) {
        }
 
        do {
-               if(c->protocol_minor >= 2)
-                       return sptps_receive_data(&c->sptps, bufp, inlen);
+               if(c->protocol_minor >= 2) {
+                       int len = sptps_receive_data(&c->sptps, bufp, inlen);
+                       if(!len)
+                               return false;
+                       bufp += len;
+                       inlen -= len;
+                       continue;
+               }
 
                if(!c->status.decryptin) {
                        endp = memchr(bufp, '\n', inlen);
index 4a9683f..e5946ed 100644 (file)
@@ -495,90 +495,92 @@ static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len
 }
 
 // Receive incoming data. Check if it contains a complete record, if so, handle it.
-bool sptps_receive_data(sptps_t *s, const void *data, size_t len) {
+size_t sptps_receive_data(sptps_t *s, const void *data, size_t len) {
+       size_t total_read = 0;
+
        if(!s->state)
                return error(s, EIO, "Invalid session state zero");
 
        if(s->datagram)
-               return sptps_receive_data_datagram(s, data, len);
-
-       while(len) {
-               // First read the 2 length bytes.
-               if(s->buflen < 2) {
-                       size_t toread = 2 - s->buflen;
-                       if(toread > len)
-                               toread = len;
+               return sptps_receive_data_datagram(s, data, len) ? len : false;
 
-                       memcpy(s->inbuf + s->buflen, data, toread);
+       // First read the 2 length bytes.
+       if(s->buflen < 2) {
+               size_t toread = 2 - s->buflen;
+               if(toread > len)
+                       toread = len;
 
-                       s->buflen += toread;
-                       len -= toread;
-                       data += toread;
+               memcpy(s->inbuf + s->buflen, data, toread);
 
-                       // Exit early if we don't have the full length.
-                       if(s->buflen < 2)
-                               return true;
+               total_read += toread;
+               s->buflen += toread;
+               len -= toread;
+               data += toread;
 
-                       // Get the length bytes
+               // Exit early if we don't have the full length.
+               if(s->buflen < 2)
+                       return total_read;
 
-                       memcpy(&s->reclen, s->inbuf, 2);
-                       s->reclen = ntohs(s->reclen);
+               // Get the length bytes
 
-                       // If we have the length bytes, ensure our buffer can hold the whole request.
-                       s->inbuf = realloc(s->inbuf, s->reclen + 19UL);
-                       if(!s->inbuf)
-                               return error(s, errno, strerror(errno));
+               memcpy(&s->reclen, s->inbuf, 2);
+               s->reclen = ntohs(s->reclen);
 
-                       // Exit early if we have no more data to process.
-                       if(!len)
-                               return true;
-               }
+               // If we have the length bytes, ensure our buffer can hold the whole request.
+               s->inbuf = realloc(s->inbuf, s->reclen + 19UL);
+               if(!s->inbuf)
+                       return error(s, errno, strerror(errno));
 
-               // Read up to the end of the record.
-               size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen;
-               if(toread > len)
-                       toread = len;
+               // Exit early if we have no more data to process.
+               if(!len)
+                       return total_read;
+       }
 
-               memcpy(s->inbuf + s->buflen, data, toread);
-               s->buflen += toread;
-               len -= toread;
-               data += toread;
+       // Read up to the end of the record.
+       size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen;
+       if(toread > len)
+               toread = len;
 
-               // If we don't have a whole record, exit.
-               if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL))
-                       return true;
+       memcpy(s->inbuf + s->buflen, data, toread);
+       total_read += toread;
+       s->buflen += toread;
+       len -= toread;
+       data += toread;
 
-               // Update sequence number.
+       // If we don't have a whole record, exit.
+       if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL))
+               return total_read;
 
-               uint32_t seqno = s->inseqno++;
+       // Update sequence number.
 
-               // Check HMAC and decrypt.
-               if(s->instate) {
-                       if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL))
-                               return error(s, EINVAL, "Failed to decrypt and verify record");
-               }
+       uint32_t seqno = s->inseqno++;
 
-               // Append a NULL byte for safety.
-               s->inbuf[s->reclen + 3UL] = 0;
+       // Check HMAC and decrypt.
+       if(s->instate) {
+               if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL))
+                       return error(s, EINVAL, "Failed to decrypt and verify record");
+       }
 
-               uint8_t type = s->inbuf[2];
+       // Append a NULL byte for safety.
+       s->inbuf[s->reclen + 3UL] = 0;
 
-               if(type < SPTPS_HANDSHAKE) {
-                       if(!s->instate)
-                               return error(s, EIO, "Application record received before handshake finished");
-                       if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen))
-                               return false;
-               } else if(type == SPTPS_HANDSHAKE) {
-                       if(!receive_handshake(s, s->inbuf + 3, s->reclen))
-                               return false;
-               } else {
-                       return error(s, EIO, "Invalid record type %d", type);
-               }
+       uint8_t type = s->inbuf[2];
 
-               s->buflen = 0;
+       if(type < SPTPS_HANDSHAKE) {
+               if(!s->instate)
+                       return error(s, EIO, "Application record received before handshake finished");
+               if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen))
+                       return false;
+       } else if(type == SPTPS_HANDSHAKE) {
+               if(!receive_handshake(s, s->inbuf + 3, s->reclen))
+                       return false;
+       } else {
+               return error(s, EIO, "Invalid record type %d", type);
        }
 
-       return true;
+       s->buflen = 0;
+
+       return total_read;
 }
 
 // Start a SPTPS session.
index a2633bd..75a9565 100644 (file)
@@ -88,7 +88,7 @@ extern void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap
 extern bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t *mykey, ecdsa_t *hiskey, const void *label, size_t labellen, send_data_t send_data, receive_record_t receive_record);
 extern bool sptps_stop(sptps_t *s);
 extern bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len);
-extern bool sptps_receive_data(sptps_t *s, const void *data, size_t len);
+extern size_t sptps_receive_data(sptps_t *s, const void *data, size_t len);
 extern bool sptps_force_kex(sptps_t *s);
 extern bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len);