plugins/mod_smacks.lua
changeset 12060 e62025f949f9
parent 12058 0116fa57f05c
child 12063 70a55fbe447c
--- a/plugins/mod_smacks.lua	Tue Dec 14 19:58:53 2021 +0100
+++ b/plugins/mod_smacks.lua	Tue Dec 14 20:00:45 2021 +0100
@@ -13,13 +13,12 @@
 
 local tonumber = tonumber;
 local tostring = tostring;
-local math_min = math.min;
 local os_time = os.time;
-local t_remove = table.remove;
 
 local datetime = require "util.datetime";
 local add_filter = require "util.filters".add_filter;
 local jid = require "util.jid";
+local smqueue = require "util.smqueue";
 local st = require "util.stanza";
 local timer = require "util.timer";
 local uuid_generate = require "util.uuid".generate;
@@ -37,6 +36,7 @@
 local sm2_attr = { xmlns = xmlns_sm2 };
 local sm3_attr = { xmlns = xmlns_sm3 };
 
+local queue_size = module:get_option_number("smacks_max_queue_size", 500);
 local resume_timeout = module:get_option_number("smacks_hibernation_time", 600);
 local s2s_smacks = module:get_option_boolean("smacks_enabled_s2s", true);
 local s2s_resend = module:get_option_boolean("smacks_s2s_resend", false);
@@ -51,13 +51,22 @@
 local old_session_registry = module:open_store("smacks_h", "map");
 local session_registry = module:shared "/*/smacks/resumption-tokens"; -- > user@host/resumption-token --> resource
 
+local ack_errors = require"util.error".init("mod_smacks", xmlns_sm3, {
+	head = { condition = "undefined-condition"; text = "Client acknowledged more stanzas than sent by server" };
+	tail = { condition = "undefined-condition"; text = "Client acknowledged less stanzas than already acknowledged" };
+	pop = { condition = "internal-server-error"; text = "Something went wrong with Stream Management" };
+	overflow = { condition = "resource-constraint", text = "Too many unacked stanzas remaining, session can't be resumed" }
+});
+
+-- COMPAT note the use of compatibilty wrapper in events (queue:table())
+
 local function ack_delayed(session, stanza)
 	-- fire event only if configured to do so and our session is not already hibernated or destroyed
 	if delayed_ack_timeout > 0 and session.awaiting_ack
 	and not session.hibernating and not session.destroyed then
 		session.log("debug", "Firing event 'smacks-ack-delayed', queue = %d",
-			session.outgoing_stanza_queue and #session.outgoing_stanza_queue or 0);
-		module:fire_event("smacks-ack-delayed", {origin = session, queue = session.outgoing_stanza_queue, stanza = stanza});
+			session.outgoing_stanza_queue and session.outgoing_stanza_queue:count_unacked() or 0);
+		module:fire_event("smacks-ack-delayed", {origin = session, queue = session.outgoing_stanza_queue:table(), stanza = stanza});
 	end
 	session.delayed_ack_timer = nil;
 end
@@ -101,7 +110,7 @@
 	if session.awaiting_ack then return end -- already waiting
 	if force then return force end
 	local queue = session.outgoing_stanza_queue;
-	local expected_h = session.last_acknowledged_stanza + #queue;
+	local expected_h = session.last_acknowledged_stanza + queue:count_unacked();
 	local max_unacked = max_unacked_stanzas;
 	if session.state == "inactive" then
 		max_unacked = max_inactive_unacked_stanzas;
@@ -109,18 +118,18 @@
 	-- this check of last_requested_h prevents ack-loops if missbehaving clients report wrong
 	-- stanza counts. it is set when an <r> is really sent (e.g. inside timer), preventing any
 	-- further requests until a higher h-value would be expected.
-	return #queue > max_unacked and expected_h ~= session.last_requested_h;
+	return queue:count_unacked() > max_unacked and expected_h ~= session.last_requested_h;
 end
 
 local function request_ack(session, reason)
 	local queue = session.outgoing_stanza_queue;
-	session.log("debug", "Sending <r> (inside timer, before send) from %s - #queue=%d", reason, #queue);
+	session.log("debug", "Sending <r> (inside timer, before send) from %s - #queue=%d", reason, queue:count_unacked());
 	(session.sends2s or session.send)(st.stanza("r", { xmlns = session.smacks }))
 	if session.destroyed then return end -- sending something can trigger destruction
 	session.awaiting_ack = true;
 	-- expected_h could be lower than this expression e.g. more stanzas added to the queue meanwhile)
-	session.last_requested_h = session.last_acknowledged_stanza + #queue;
-	session.log("debug", "Sending <r> (inside timer, after send) from %s - #queue=%d", reason, #queue);
+	session.last_requested_h = session.last_acknowledged_stanza + queue:count_unacked();
+	session.log("debug", "Sending <r> (inside timer, after send) from %s - #queue=%d", reason, queue:count_unacked());
 	if not session.delayed_ack_timer then
 		session.delayed_ack_timer = timer.add_task(delayed_ack_timeout, function()
 			ack_delayed(session, nil); -- we don't know if this is the only new stanza in the queue
@@ -150,7 +159,7 @@
 	if session.state == "inactive" then
 		max_unacked = max_inactive_unacked_stanzas;
 	end
-	if #queue > max_unacked and session.awaiting_ack and session.delayed_ack_timer == nil then
+	if queue:count_unacked() > max_unacked and session.awaiting_ack and session.delayed_ack_timer == nil then
 		session.log("debug", "Calling ack_delayed directly (still waiting for ack)");
 		ack_delayed(session, stanza); -- this is the only new stanza in the queue --> provide it to other modules
 	end
@@ -178,10 +187,12 @@
 			});
 		end
 
-		queue[#queue+1] = cached_stanza;
+		queue:push(cached_stanza);
+
 		if session.hibernating then
 			session.log("debug", "hibernating since %s, stanza queued", datetime.datetime(session.hibernating));
-			module:fire_event("smacks-hibernation-stanza-queued", {origin = session, queue = queue, stanza = cached_stanza});
+			-- FIXME queue implementation changed, anything depending on it being an array will break
+			module:fire_event("smacks-hibernation-stanza-queued", {origin = session, queue = queue:table(), stanza = cached_stanza});
 			return nil;
 		end
 	end
@@ -198,7 +209,7 @@
 
 local function wrap_session_out(session, resume)
 	if not resume then
-		session.outgoing_stanza_queue = {};
+		session.outgoing_stanza_queue = smqueue.new(queue_size);
 		session.last_acknowledged_stanza = 0;
 	end
 
@@ -324,31 +335,26 @@
 		origin.delayed_ack_timer = nil;
 	end
 	-- Remove handled stanzas from outgoing_stanza_queue
-	-- origin.log("debug", "ACK: h=%s, last=%s", stanza.attr.h or "", origin.last_acknowledged_stanza or "");
 	local h = tonumber(stanza.attr.h);
 	if not h then
 		origin:close{ condition = "invalid-xml"; text = "Missing or invalid 'h' attribute"; };
 		return;
 	end
-	local handled_stanza_count = h-origin.last_acknowledged_stanza;
 	local queue = origin.outgoing_stanza_queue;
-	if handled_stanza_count > #queue then
-		origin.log("warn", "The client says it handled %d new stanzas, but we only sent %d :)",
-			handled_stanza_count, #queue);
-		origin.log("debug", "Client h: %d, our h: %d", tonumber(stanza.attr.h), origin.last_acknowledged_stanza);
-		for i=1,#queue do
-			origin.log("debug", "Q item %d: %s", i, tostring(queue[i]));
+	local handled_stanza_count = h-queue:count_acked();
+	local acked, err = ack_errors.coerce(queue:ack(h)); -- luacheck: ignore 211/acked
+	if err then
+		origin.log("warn", "The client says it handled %d new stanzas, but we sent %d :)",
+			handled_stanza_count, queue:count_unacked());
+		origin.log("debug", "Client h: %d, our h: %d", tonumber(stanza.attr.h), queue:count_acked());
+		for i, item in queue._queue:items() do
+			origin.log("debug", "Q item %d: %s", i, item);
 		end
-		origin:close{ condition = "undefined-condition"; text = "Client acknowledged more stanzas than sent by server"; };
+		origin:close(err);
 		return;
 	end
 
-	for _=1,math_min(handled_stanza_count,#queue) do
-		t_remove(origin.outgoing_stanza_queue, 1);
-	end
-
-	origin.log("debug", "#queue = %d", #queue);
-	origin.last_acknowledged_stanza = origin.last_acknowledged_stanza + handled_stanza_count;
+	origin.log("debug", "#queue = %d", queue:count_unacked());
 	request_ack_now_if_needed(origin, false, "handle_a", nil)
 	return true;
 end
@@ -357,12 +363,13 @@
 
 local function handle_unacked_stanzas(session)
 	local queue = session.outgoing_stanza_queue;
-	if #queue > 0 then
-		session.outgoing_stanza_queue = {};
-		for i=1,#queue do
-			if not module:fire_event("delivery/failure", { session = session, stanza = queue[i] }) then
-				if queue[i].attr.type ~= "error" and queue[i].attr.from ~= session.full_jid then
-					local reply = st.error_reply(queue[i], "cancel", "recipient-unavailable");
+	if queue:count_unacked() > 0 then
+		session.smacks = false; -- Disable queueing
+		session.outgoing_stanza_queue = nil;
+		for stanza in queue._queue:consume() do
+			if not module:fire_event("delivery/failure", { session = session, stanza = stanza }) then
+				if stanza.attr.type ~= "error" and stanza.attr.to ~= session.full_jid then
+					local reply = st.error_reply(stanza, "cancel", "recipient-unavailable");
 					core_process_stanza(session, reply);
 				end
 			end
@@ -416,8 +423,8 @@
 	if not session.smacks then return end
 		if not session.resumption_token then
 			local queue = session.outgoing_stanza_queue;
-			if #queue > 0 then
-				session.log("debug", "Destroying session with %d unacked stanzas", #queue);
+		if queue:count_unacked() > 0 then
+			session.log("debug", "Destroying session with %d unacked stanzas", queue:count_unacked());
 				handle_unacked_stanzas(session);
 			end
 		return
@@ -440,18 +447,18 @@
 	if session.conn then
 		session.conn:close();
 	end
-	module:fire_event("smacks-hibernation-start", { origin = session; queue = session.outgoing_stanza_queue });
+	module:fire_event("smacks-hibernation-start", { origin = session; queue = session.outgoing_stanza_queue:table() });
 	return true; -- Postpone destruction for now
 end);
 
 local function handle_s2s_destroyed(event)
 	local session = event.session;
 	local queue = session.outgoing_stanza_queue;
-	if queue and #queue > 0 then
-		session.log("warn", "Destroying session with %d unacked stanzas", #queue);
+	if queue and queue:count_unacked() > 0 then
+		session.log("warn", "Destroying session with %d unacked stanzas", queue:count_unacked());
 		if s2s_resend then
-			for i = 1, #queue do
-				module:send(queue[i]);
+			for stanza in queue:consume() do
+				module:send(stanza);
 			end
 			session.outgoing_stanza_queue = nil;
 		else
@@ -505,6 +512,7 @@
 			c2s_sessions[conn] = nil;
 			conn:close();
 		end
+
 		local migrated_session_log = session.log;
 		original_session.ip = session.ip;
 		original_session.conn = session.conn;
@@ -530,33 +538,46 @@
 		-- Similar for connlisteners
 		c2s_sessions[session.conn] = original_session;
 
+		local queue = original_session.outgoing_stanza_queue;
+		local h = tonumber(stanza.attr.h);
+
+		original_session.log("debug", "Pre-resumption #queue = %d", queue:count_unacked())
+		local acked, err = ack_errors.coerce(queue:ack(h)); -- luacheck: ignore 211/acked
+
+		if not err and not queue:resumable() then
+			err = ack_errors.new("overflow");
+		end
+
+		if err or not queue:resumable() then
+			original_session.send(st.stanza("failed",
+				{ xmlns = xmlns_sm; h = format_h(original_session.handled_stanza_count); previd = id }));
+			original_session:close(err);
+			return false;
+		end
+
 		original_session.send(st.stanza("resumed", { xmlns = xmlns_sm,
 			h = format_h(original_session.handled_stanza_count), previd = id }));
 
-		-- Fake an <a> with the h of the <resume/> from the client
-		original_session:dispatch_stanza(st.stanza("a", { xmlns = xmlns_sm,
-			h = stanza.attr.h }));
-
 		-- Ok, we need to re-send any stanzas that the client didn't see
 		-- ...they are what is now left in the outgoing stanza queue
 		-- We have to use the send of "session" because we don't want to add our resent stanzas
 		-- to the outgoing queue again
-		local queue = original_session.outgoing_stanza_queue;
-		session.log("debug", "resending all unacked stanzas that are still queued after resume, #queue = %d", #queue);
+
+		session.log("debug", "resending all unacked stanzas that are still queued after resume, #queue = %d", queue:count_unacked());
 		-- FIXME Which session is it that the queue filter sees?
 		session.resending_unacked = true;
 		original_session.resending_unacked = true;
-		for i=1,#queue do
-			session.send(queue[i]);
+		for _, queued_stanza in queue:resume() do
+			session.send(queued_stanza);
 		end
 		session.resending_unacked = nil;
 		original_session.resending_unacked = nil;
-		session.log("debug", "all stanzas resent, now disabling send() in this migrated session, #queue = %d", #queue);
+		session.log("debug", "all stanzas resent, now disabling send() in this migrated session, #queue = %d", queue:count_unacked());
 		function session.send(stanza) -- luacheck: ignore 432
 			migrated_session_log("error", "Tried to send stanza on old session migrated by smacks resume (maybe there is a bug?): %s", tostring(stanza));
 			return false;
 		end
-		module:fire_event("smacks-hibernation-end", {origin = session, resumed = original_session, queue = queue});
+		module:fire_event("smacks-hibernation-end", {origin = session, resumed = original_session, queue = queue:table()});
 		request_ack_if_needed(original_session, true, "handle_resume", nil);
 	end
 	return true;