net.resolvers.service: Honour record 'weight' when picking SRV targets
authorMatthew Wild <mwild1@gmail.com>
Thu, 17 Mar 2022 18:20:26 +0000
changeset 12405 c029ddcad258
parent 12404 728d1c1dc7db
child 12406 8deb401a21f6
net.resolvers.service: Honour record 'weight' when picking SRV targets #NotHappyEyeballs
net/resolvers/service.lua
spec/net_resolvers_service_spec.lua
--- a/net/resolvers/service.lua	Thu Mar 17 13:15:50 2022 +0100
+++ b/net/resolvers/service.lua	Thu Mar 17 18:20:26 2022 +0000
@@ -2,23 +2,78 @@
 local basic = require "net.resolvers.basic";
 local inet_pton = require "util.net".pton;
 local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
 
 local methods = {};
 local resolver_mt = { __index = methods };
 
+local function new_target_selector(rrset)
+	local rr_count = rrset and #rrset;
+	if not rr_count or rr_count == 0 then
+		rrset = nil;
+	else
+		table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end);
+	end
+	local rrset_pos = 1;
+	local priority_bucket, bucket_total_weight, bucket_len, bucket_used;
+	return function ()
+		if not rrset then return; end
+
+		if not priority_bucket or bucket_used >= bucket_len then
+			if rrset_pos > rr_count then return; end -- Used up all records
+
+			-- Going to start on a new priority now. Gather up all the next
+			-- records with the same priority and add them to priority_bucket
+			priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0;
+			local current_priority;
+			repeat
+				local curr_record = rrset[rrset_pos].srv;
+				if not current_priority then
+					current_priority = curr_record.priority;
+				elseif current_priority ~= curr_record.priority then
+					break;
+				end
+				table.insert(priority_bucket, curr_record);
+				bucket_total_weight = bucket_total_weight + curr_record.weight;
+				bucket_len = bucket_len + 1;
+				rrset_pos = rrset_pos + 1;
+			until rrset_pos > rr_count;
+		end
+
+		bucket_used = bucket_used + 1;
+		local n, running_total = math.random(0, bucket_total_weight), 0;
+		local target_record;
+		for i = 1, bucket_len do
+			local candidate = priority_bucket[i];
+			if candidate then
+				running_total = running_total + candidate.weight;
+				if running_total >= n then
+					target_record = candidate;
+					bucket_total_weight = bucket_total_weight - candidate.weight;
+					priority_bucket[i] = nil;
+					break;
+				end
+			end
+		end
+		return target_record;
+	end;
+end
+
 -- Find the next target to connect to, and
 -- pass it to cb()
 function methods:next(cb)
-	if self.targets then
-		if not self.resolver then
-			if #self.targets == 0 then
+	if self.resolver or self._get_next_target then
+		if not self.resolver then -- Do we have a basic resolver currently?
+			-- We don't, so fetch a new SRV target, create a new basic resolver for it
+			local next_srv_target = self._get_next_target and self._get_next_target();
+			if not next_srv_target then
+				-- No more SRV targets left
 				cb(nil);
 				return;
 			end
-			local next_target = table.remove(self.targets, 1);
-			self.resolver = basic.new(unpack(next_target, 1, 4));
+			-- Create a new basic resolver for this SRV target
+			self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra);
 		end
+		-- Look up the next (basic) target from the current target's resolver
 		self.resolver:next(function (...)
 			if self.resolver then
 				self.last_error = self.resolver.last_error;
@@ -31,6 +86,9 @@
 			end
 		end);
 		return;
+	elseif self.in_progress then
+		cb(nil);
+		return;
 	end
 
 	if not self.hostname then
@@ -39,9 +97,9 @@
 		return;
 	end
 
-	local targets = {};
+	self.in_progress = true;
+
 	local function ready()
-		self.targets = targets;
 		self:next(cb);
 	end
 
@@ -63,7 +121,7 @@
 
 			if #answer == 0 then
 				if self.extra and self.extra.default_port then
-					table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra });
+					self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra);
 				else
 					self.last_error = "zero SRV records found";
 				end
@@ -77,10 +135,7 @@
 				return;
 			end
 
-			table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end);
-			for _, record in ipairs(answer) do
-				table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra });
-			end
+			self._get_next_target = new_target_selector(answer);
 		else
 			self.last_error = err;
 		end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/spec/net_resolvers_service_spec.lua	Thu Mar 17 18:20:26 2022 +0000
@@ -0,0 +1,241 @@
+local set = require "util.set";
+
+insulate("net.resolvers.service", function ()
+	local adns = {
+		resolver = function ()
+			return {
+				lookup = function (_, cb, qname, qtype, qclass)
+					if qname == "_xmpp-server._tcp.example.com"
+					   and (qtype or "SRV") == "SRV"
+					   and (qclass or "IN") == "IN" then
+						cb({
+							{ -- 60+35+60
+								srv = { target = "xmpp0-a.example.com", port = 5228, priority = 0, weight = 60 };
+							};
+							{
+								srv = { target = "xmpp0-b.example.com", port = 5216, priority = 0, weight = 35 };
+							};
+							{
+								srv = { target = "xmpp0-c.example.com", port = 5200, priority = 0, weight = 0 };
+							};
+							{
+								srv = { target = "xmpp0-d.example.com", port = 5256, priority = 0, weight = 120 };
+							};
+
+							{
+								srv = { target = "xmpp1-a.example.com", port = 5273, priority = 1, weight = 30 };
+							};
+							{
+								srv = { target = "xmpp1-b.example.com", port = 5274, priority = 1, weight = 30 };
+							};
+
+							{
+								srv = { target = "xmpp2.example.com", port = 5275, priority = 2, weight = 0 };
+							};
+						});
+					elseif qname == "_xmpp-server._tcp.single.example.com"
+					   and (qtype or "SRV") == "SRV"
+					   and (qclass or "IN") == "IN" then
+						cb({
+							{
+								srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
+							};
+						});
+					elseif qname == "_xmpp-server._tcp.half.example.com"
+					   and (qtype or "SRV") == "SRV"
+					   and (qclass or "IN") == "IN" then
+						cb({
+							{
+								srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
+							};
+							{
+								srv = { target = "xmpp0-b.example.com", port = 5270, priority = 0, weight = 1 };
+							};
+						});
+					elseif qtype == "A" then
+						local l = qname:match("%-(%a)%.example.com$") or "1";
+						local d = ("%d"):format(l:byte())
+						cb({
+							{
+								a = "127.0.0."..d;
+							};
+						});
+					elseif qtype == "AAAA" then
+						local l = qname:match("%-(%a)%.example.com$") or "1";
+						local d = ("%04d"):format(l:byte())
+						cb({
+							{
+								aaaa = "fdeb:9619:649e:c7d9::"..d;
+							};
+						});
+					else
+						cb(nil);
+					end
+				end;
+			};
+		end;
+	};
+	package.loaded["net.adns"] = mock(adns);
+	local resolver = require "net.resolvers.service";
+	math.randomseed(os.time());
+	it("works for 99% of deployments", function ()
+		-- Most deployments only have a single SRV record, let's make
+		-- sure that works okay
+
+		local expected_targets = set.new({
+			-- xmpp0-a
+			"tcp4  127.0.0.97  5269";
+			"tcp6  fdeb:9619:649e:c7d9::0097  5269";
+		});
+		local received_targets = set.new({});
+
+		local r = resolver.new("single.example.com", "xmpp-server");
+		local done = false;
+		local function handle_target(...)
+			if ... == nil then
+				done = true;
+				-- No more targets
+				return;
+			end
+			received_targets:add(table.concat({ ... }, "  ", 1, 3));
+		end
+		r:next(handle_target);
+		while not done do
+			r:next(handle_target);
+		end
+
+		-- We should have received all expected targets, and no unexpected
+		-- ones:
+		assert.truthy(set.xor(received_targets, expected_targets):empty());
+	end);
+
+	it("supports A/AAAA fallback", function ()
+		-- Many deployments don't have any SRV records, so we should
+		-- fall back to A/AAAA records instead when that is the case
+
+		local expected_targets = set.new({
+			-- xmpp0-a
+			"tcp4  127.0.0.97  5269";
+			"tcp6  fdeb:9619:649e:c7d9::0097  5269";
+		});
+		local received_targets = set.new({});
+
+		local r = resolver.new("xmpp0-a.example.com", "xmpp-server", "tcp", { default_port = 5269 });
+		local done = false;
+		local function handle_target(...)
+			if ... == nil then
+				done = true;
+				-- No more targets
+				return;
+			end
+			received_targets:add(table.concat({ ... }, "  ", 1, 3));
+		end
+		r:next(handle_target);
+		while not done do
+			r:next(handle_target);
+		end
+
+		-- We should have received all expected targets, and no unexpected
+		-- ones:
+		assert.truthy(set.xor(received_targets, expected_targets):empty());
+	end);
+
+
+	it("works", function ()
+		local expected_targets = set.new({
+			-- xmpp0-a
+			"tcp4  127.0.0.97  5228";
+			"tcp6  fdeb:9619:649e:c7d9::0097  5228";
+			"tcp4  127.0.0.97  5273";
+			"tcp6  fdeb:9619:649e:c7d9::0097  5273";
+
+			-- xmpp0-b
+			"tcp4  127.0.0.98  5274";
+			"tcp6  fdeb:9619:649e:c7d9::0098  5274";
+			"tcp4  127.0.0.98  5216";
+			"tcp6  fdeb:9619:649e:c7d9::0098  5216";
+
+			-- xmpp0-c
+			"tcp4  127.0.0.99  5200";
+			"tcp6  fdeb:9619:649e:c7d9::0099  5200";
+
+			-- xmpp0-d
+			"tcp4  127.0.0.100  5256";
+			"tcp6  fdeb:9619:649e:c7d9::0100  5256";
+
+			-- xmpp2
+			"tcp4  127.0.0.49  5275";
+			"tcp6  fdeb:9619:649e:c7d9::0049  5275";
+
+		});
+		local received_targets = set.new({});
+
+		local r = resolver.new("example.com", "xmpp-server");
+		local done = false;
+		local function handle_target(...)
+			if ... == nil then
+				done = true;
+				-- No more targets
+				return;
+			end
+			received_targets:add(table.concat({ ... }, "  ", 1, 3));
+		end
+		r:next(handle_target);
+		while not done do
+			r:next(handle_target);
+		end
+
+		-- We should have received all expected targets, and no unexpected
+		-- ones:
+		assert.truthy(set.xor(received_targets, expected_targets):empty());
+	end);
+
+	it("balances across weights correctly #slow", function ()
+		-- This mimics many repeated connections to 'example.com' (mock
+		-- records defined above), and records the port number of the
+		-- first target. Therefore it (should) only return priority
+		-- 0 records, and the input data is constructed such that the
+		-- last two digits of the port number represent the percentage
+		-- of times that record should (on average) be picked first.
+
+		-- To prevent random test failures, we test across a handful
+		-- of fixed (randomly selected) seeds.
+		for _, seed in ipairs({ 8401877, 3943829, 7830992 }) do
+			math.randomseed(seed);
+
+			local results = {};
+			local function run()
+				local run_results = {};
+				local r = resolver.new("example.com", "xmpp-server");
+				local function record_target(...)
+					if ... == nil then
+						-- No more targets
+						return;
+					end
+					run_results = { ... };
+				end
+				r:next(record_target);
+				return run_results[3];
+			end
+
+			for _ = 1, 1000 do
+				local port = run();
+				results[port] = (results[port] or 0) + 1;
+			end
+
+			local ports = {};
+			for port in pairs(results) do
+				table.insert(ports, port);
+			end
+			table.sort(ports);
+			for _, port in ipairs(ports) do
+				--print("PORT", port, tostring((results[port]/1000) * 100).."% hits (expected "..tostring(port-5200).."%)");
+				local hit_pct = (results[port]/1000) * 100;
+				local expected_pct = port - 5200;
+				--print(hit_pct, expected_pct, math.abs(hit_pct - expected_pct));
+				assert.is_true(math.abs(hit_pct - expected_pct) < 5);
+			end
+			--print("---");
+		end
+	end);
+end);