net.server_epoll: Factor out TLS initialization into a method
authorKim Alvefur <zash@zash.se>
Tue, 13 Jul 2021 14:20:24 +0200
changeset 11676 79f8e29e88a0
parent 11675 4e4e26e3df8d
child 11677 3ab8496579f1
net.server_epoll: Factor out TLS initialization into a method So there's :startls(), :inittls() and :tlshandshake() :starttls() prepares for plain -> TLS upgrade and ensures that the (unencrypted) write buffer is drained before proceeding. :inittls() wraps the connection and does things like SNI, DANE etc. :tlshandshake() steps the TLS negotiation forward until it completes
net/server_epoll.lua
--- a/net/server_epoll.lua	Tue Jul 13 02:05:35 2021 +0200
+++ b/net/server_epoll.lua	Tue Jul 13 14:20:24 2021 +0200
@@ -561,8 +561,8 @@
 		if self.ondrain == interface.starttls then
 			self.ondrain = nil;
 		end
-		self.onwritable = interface.tlshandshake;
-		self.onreadable = interface.tlshandshake;
+		self.onwritable = interface.inittls;
+		self.onreadable = interface.inittls;
 		self:set(true, true);
 		self:setreadtimeout(cfg.ssl_handshake_timeout);
 		self:setwritetimeout(cfg.ssl_handshake_timeout);
@@ -570,52 +570,55 @@
 	end
 end
 
+function interface:inittls(tls_ctx)
+	if self._tls then return end
+	if tls_ctx then self.tls_ctx = tls_ctx; end
+	self._tls = true;
+	self:debug("Starting TLS now");
+	self:del();
+	self:updatenames(); -- Can't getpeer/sockname after wrap()
+	local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
+	if not ok then
+		conn, err = ok, conn;
+		self:debug("Failed to initialize TLS: %s", err);
+	end
+	if not conn then
+		self:on("disconnect", err);
+		self:destroy();
+		return conn, err;
+	end
+	conn:settimeout(0);
+	self.conn = conn;
+	if conn.sni then
+		if self.servername then
+			conn:sni(self.servername);
+		elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+			conn:sni(self._server.hosts, true);
+		end
+	end
+	if self.extra and self.extra.tlsa and conn.settlsa then
+		-- TODO Error handling
+		if not conn:setdane(self.servername or self.extra.dane_hostname) then
+			self:debug("Could not enable DANE on connection");
+		else
+			self:debug("Enabling DANE with %d TLSA records", #self.extra.tlsa);
+			self:noise("DANE hostname is %q", self.servername or self.extra.dane_hostname);
+			for _, tlsa in ipairs(self.extra.tlsa) do
+				self:noise("TLSA: %q", tlsa);
+				conn:settlsa(tlsa.use, tlsa.select, tlsa.match, tlsa.data);
+			end
+		end
+	end
+	self:on("starttls");
+	self.ondrain = nil;
+	self.onwritable = interface.tlshandshake;
+	self.onreadable = interface.tlshandshake;
+	return self:init();
+end
+
 function interface:tlshandshake()
 	self:setwritetimeout(false);
 	self:setreadtimeout(false);
-	if not self._tls then
-		self._tls = true;
-		self:debug("Starting TLS now");
-		self:del();
-		self:updatenames(); -- Can't getpeer/sockname after wrap()
-		local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
-		if not ok then
-			conn, err = ok, conn;
-			self:debug("Failed to initialize TLS: %s", err);
-		end
-		if not conn then
-			self:on("disconnect", err);
-			self:destroy();
-			return conn, err;
-		end
-		conn:settimeout(0);
-		self.conn = conn;
-		if conn.sni then
-			if self.servername then
-				conn:sni(self.servername);
-			elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
-				conn:sni(self._server.hosts, true);
-			end
-		end
-		if self.extra and self.extra.tlsa and conn.settlsa then
-			-- TODO Error handling
-			if not conn:setdane(self.servername or self.extra.dane_hostname) then
-				self:debug("Could not enable DANE on connection");
-			else
-				self:debug("Enabling DANE with %d TLSA records", #self.extra.tlsa);
-				self:noise("DANE hostname is %q", self.servername or self.extra.dane_hostname);
-				for _, tlsa in ipairs(self.extra.tlsa) do
-					self:noise("TLSA: %q", tlsa);
-					conn:settlsa(tlsa.use, tlsa.select, tlsa.match, tlsa.data);
-				end
-			end
-		end
-		self:on("starttls");
-		self.ondrain = nil;
-		self.onwritable = interface.tlshandshake;
-		self.onreadable = interface.tlshandshake;
-		return self:init();
-	end
 	self:noise("Continuing TLS handshake");
 	local ok, err = self.conn:dohandshake();
 	if ok then
@@ -697,7 +700,10 @@
 	client:debug("New connection %s on server %s", client, self);
 	if self.tls_direct then
 		client:add(true, true);
-		client:starttls(self.tls_ctx);
+		if client:inittls(self.tls_ctx) then
+			client:setreadtimeout(cfg.ssl_handshake_timeout);
+			client:setwritetimeout(cfg.ssl_handshake_timeout);
+		end
 	else
 		client:add(true, false);
 		client:onconnect();