From d237efd325cd7bdd73f5eb111c769470238dce6e Mon Sep 17 00:00:00 2001
From: Etienne Dechamps <etienne@edechamps.fr>
Date: Sun, 10 May 2015 19:28:11 +0100
Subject: [PATCH] Only read one record at a time in sptps_receive_data().

sptps_receive_data() always consumes the entire buffer passed to it,
which is somewhat inflexible. This commit improves the interface so that
sptps_receive_data() consumes at most one record. The goal is to allow
non-SPTPS stuff to be interleaved with SPTPS records in a single TCP
stream.
---
 src/meta.c  |  10 ++++-
 src/sptps.c | 124 ++++++++++++++++++++++++++--------------------------
 src/sptps.h |   2 +-
 3 files changed, 72 insertions(+), 64 deletions(-)

diff --git a/src/meta.c b/src/meta.c
index 1c29fe9c..0849d3cd 100644
--- a/src/meta.c
+++ b/src/meta.c
@@ -159,8 +159,14 @@ bool receive_meta(connection_t *c) {
 	}
 
 	do {
-		if(c->protocol_minor >= 2)
-			return sptps_receive_data(&c->sptps, bufp, inlen);
+		if(c->protocol_minor >= 2) {
+			int len = sptps_receive_data(&c->sptps, bufp, inlen);
+			if(!len)
+				return false;
+			bufp += len;
+			inlen -= len;
+			continue;
+		}
 
 		if(!c->status.decryptin) {
 			endp = memchr(bufp, '\n', inlen);
diff --git a/src/sptps.c b/src/sptps.c
index 4a9683f2..e5946ed6 100644
--- a/src/sptps.c
+++ b/src/sptps.c
@@ -495,90 +495,92 @@ static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len
 }
 
 // Receive incoming data. Check if it contains a complete record, if so, handle it.
-bool sptps_receive_data(sptps_t *s, const void *data, size_t len) {
+size_t sptps_receive_data(sptps_t *s, const void *data, size_t len) {
+	size_t total_read = 0;
+
 	if(!s->state)
 		return error(s, EIO, "Invalid session state zero");
 
 	if(s->datagram)
-		return sptps_receive_data_datagram(s, data, len);
-
-	while(len) {
-		// First read the 2 length bytes.
-		if(s->buflen < 2) {
-			size_t toread = 2 - s->buflen;
-			if(toread > len)
-				toread = len;
+		return sptps_receive_data_datagram(s, data, len) ? len : false;
 
-			memcpy(s->inbuf + s->buflen, data, toread);
+	// First read the 2 length bytes.
+	if(s->buflen < 2) {
+		size_t toread = 2 - s->buflen;
+		if(toread > len)
+			toread = len;
 
-			s->buflen += toread;
-			len -= toread;
-			data += toread;
+		memcpy(s->inbuf + s->buflen, data, toread);
 
-			// Exit early if we don't have the full length.
-			if(s->buflen < 2)
-				return true;
+		total_read += toread;
+		s->buflen += toread;
+		len -= toread;
+		data += toread;
 
-			// Get the length bytes
+		// Exit early if we don't have the full length.
+		if(s->buflen < 2)
+			return total_read;
 
-			memcpy(&s->reclen, s->inbuf, 2);
-			s->reclen = ntohs(s->reclen);
+		// Get the length bytes
 
-			// If we have the length bytes, ensure our buffer can hold the whole request.
-			s->inbuf = realloc(s->inbuf, s->reclen + 19UL);
-			if(!s->inbuf)
-				return error(s, errno, strerror(errno));
+		memcpy(&s->reclen, s->inbuf, 2);
+		s->reclen = ntohs(s->reclen);
 
-			// Exit early if we have no more data to process.
-			if(!len)
-				return true;
-		}
+		// If we have the length bytes, ensure our buffer can hold the whole request.
+		s->inbuf = realloc(s->inbuf, s->reclen + 19UL);
+		if(!s->inbuf)
+			return error(s, errno, strerror(errno));
 
-		// Read up to the end of the record.
-		size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen;
-		if(toread > len)
-			toread = len;
+		// Exit early if we have no more data to process.
+		if(!len)
+			return total_read;
+	}
 
-		memcpy(s->inbuf + s->buflen, data, toread);
-		s->buflen += toread;
-		len -= toread;
-		data += toread;
+	// Read up to the end of the record.
+	size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen;
+	if(toread > len)
+		toread = len;
 
-		// If we don't have a whole record, exit.
-		if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL))
-			return true;
+	memcpy(s->inbuf + s->buflen, data, toread);
+	total_read += toread;
+	s->buflen += toread;
+	len -= toread;
+	data += toread;
 
-		// Update sequence number.
+	// If we don't have a whole record, exit.
+	if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL))
+		return total_read;
 
-		uint32_t seqno = s->inseqno++;
+	// Update sequence number.
 
-		// Check HMAC and decrypt.
-		if(s->instate) {
-			if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL))
-				return error(s, EINVAL, "Failed to decrypt and verify record");
-		}
+	uint32_t seqno = s->inseqno++;
 
-		// Append a NULL byte for safety.
-		s->inbuf[s->reclen + 3UL] = 0;
+	// Check HMAC and decrypt.
+	if(s->instate) {
+		if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL))
+			return error(s, EINVAL, "Failed to decrypt and verify record");
+	}
 
-		uint8_t type = s->inbuf[2];
+	// Append a NULL byte for safety.
+	s->inbuf[s->reclen + 3UL] = 0;
 
-		if(type < SPTPS_HANDSHAKE) {
-			if(!s->instate)
-				return error(s, EIO, "Application record received before handshake finished");
-			if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen))
-				return false;
-		} else if(type == SPTPS_HANDSHAKE) {
-			if(!receive_handshake(s, s->inbuf + 3, s->reclen))
-				return false;
-		} else {
-			return error(s, EIO, "Invalid record type %d", type);
-		}
+	uint8_t type = s->inbuf[2];
 
-		s->buflen = 0;
+	if(type < SPTPS_HANDSHAKE) {
+		if(!s->instate)
+			return error(s, EIO, "Application record received before handshake finished");
+		if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen))
+			return false;
+	} else if(type == SPTPS_HANDSHAKE) {
+		if(!receive_handshake(s, s->inbuf + 3, s->reclen))
+			return false;
+	} else {
+		return error(s, EIO, "Invalid record type %d", type);
 	}
 
-	return true;
+	s->buflen = 0;
+
+	return total_read;
 }
 
 // Start a SPTPS session.
diff --git a/src/sptps.h b/src/sptps.h
index a2633bd1..75a95651 100644
--- a/src/sptps.h
+++ b/src/sptps.h
@@ -88,7 +88,7 @@ extern void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap
 extern bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t *mykey, ecdsa_t *hiskey, const void *label, size_t labellen, send_data_t send_data, receive_record_t receive_record);
 extern bool sptps_stop(sptps_t *s);
 extern bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len);
-extern bool sptps_receive_data(sptps_t *s, const void *data, size_t len);
+extern size_t sptps_receive_data(sptps_t *s, const void *data, size_t len);
 extern bool sptps_force_kex(sptps_t *s);
 extern bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len);
 
-- 
2.39.5