plugins/mod_smacks.lua
changeset 12806 4a8740e01813
parent 12804 06ba2f8cee47
child 12807 2e12290820e8
--- a/plugins/mod_smacks.lua	Mon Dec 12 20:40:23 2022 +0100
+++ b/plugins/mod_smacks.lua	Mon Dec 12 07:10:54 2022 +0100
@@ -2,7 +2,7 @@
 --
 -- Copyright (C) 2010-2015 Matthew Wild
 -- Copyright (C) 2010 Waqas Hussain
--- Copyright (C) 2012-2021 Kim Alvefur
+-- Copyright (C) 2012-2022 Kim Alvefur
 -- Copyright (C) 2012 Thijs Alkemade
 -- Copyright (C) 2014 Florian Zeitz
 -- Copyright (C) 2016-2020 Thilo Molitor
@@ -10,6 +10,7 @@
 -- This project is MIT/X11 licensed. Please see the
 -- COPYING file in the source package for more information.
 --
+-- TODO unify sendq and smqueue
 
 local tonumber = tonumber;
 local tostring = tostring;
@@ -83,6 +84,22 @@
 local old_session_registry = module:open_store("smacks_h", "map");
 local session_registry = module:shared "/*/smacks/resumption-tokens"; -- > user@host/resumption-token --> resource
 
+local function track_session(session, id) -- register `session` for resumption under token `id` (default: the token it already has)
+	session.resumption_token = id or session.resumption_token; -- an omitted id must not wipe the token the registry key below is built from
+	session_registry[jid.join(session.username, session.host, session.resumption_token)] = session;
+end
+
+local function save_old_session(session) -- drop the live registry entry and store { h, t } for the token; session.resumption_token itself is left untouched
+	session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil;
+	return old_session_registry:set(session.username, session.resumption_token,
+		{ h = session.handled_stanza_count; t = os.time() }) -- h = handled stanza count, t = time of saving; result of set() is returned
+end
+
+local function clear_old_session(session, id) -- forget token `id` (default: the session's own token) in both the live and the stored registry
+	session_registry[jid.join(session.username, session.host, id or session.resumption_token)] = nil;
+	return old_session_registry:set(session.username, id or session.resumption_token, nil) -- returns whatever the store's set() returns
+end
+
 local ack_errors = require"util.error".init("mod_smacks", xmlns_sm3, {
 	head = { condition = "undefined-condition"; text = "Client acknowledged more stanzas than sent by server" };
 	tail = { condition = "undefined-condition"; text = "Client acknowledged less stanzas than already acknowledged" };
@@ -90,6 +107,16 @@
 	overflow = { condition = "resource-constraint", text = "Too many unacked stanzas remaining, session can't be resumed" }
 });
 
+local enable_errors = require "util.error".init("mod_smacks", xmlns_sm3, { -- errors for failed enable/resume attempts; instantiate with enable_errors.new(key)
+	already_enabled = { condition = "unexpected-request", text = "Stream management is already enabled" };
+	bind_required = { condition = "unexpected-request", text = "Client must bind a resource before enabling stream management" };
+	unavailable = { condition = "service-unavailable", text = "Stream management is not available for this stream" };
+	-- Resumption
+	expired = { condition = "item-not-found", text = "Session expired, and cannot be resumed" };
+	already_bound = { condition = "unexpected-request", text = "Cannot resume another session after a resource is bound" };
+	unknown_session = { condition = "item-not-found", text = "Unknown session" };
+});
+
 -- COMPAT note the use of compatibility wrapper in events (queue:table())
 
 local function ack_delayed(session, stanza)
@@ -104,18 +131,18 @@
 end
 
 local function can_do_smacks(session, advertise_only)
-	if session.smacks then return false, "unexpected-request", "Stream management is already enabled"; end
+	if session.smacks then return false, enable_errors.new("already_enabled"); end -- every failure returns false plus an enable_errors error object
 
 	local session_type = session.type;
 	if session.username then
 		if not(advertise_only) and not(session.resource) then -- Fail unless we're only advertising sm
-			return false, "unexpected-request", "Client must bind a resource before enabling stream management";
+			return false, enable_errors.new("bind_required");
 		end
 		return true;
 	elseif s2s_smacks and (session_type == "s2sin" or session_type == "s2sout") then
 		return true;
 	end
-	return false, "service-unavailable", "Stream management is not available for this stream";
+	return false, enable_errors.new("unavailable"); -- no session.username, and not an s2s stream with s2s_smacks set
 end
 
 module:hook("stream-features",
@@ -155,13 +182,12 @@
 
 local function request_ack(session, reason)
 	local queue = session.outgoing_stanza_queue;
-	session.log("debug", "Sending <r> (inside timer, before send) from %s - #queue=%d", reason, queue:count_unacked());
+	session.log("debug", "Sending <r> from %s - #queue=%d", reason, queue:count_unacked());
 	session.awaiting_ack = true;
 	(session.sends2s or session.send)(st.stanza("r", { xmlns = session.smacks }))
 	if session.destroyed then return end -- sending something can trigger destruction
 	-- expected_h could be lower than this expression e.g. more stanzas added to the queue meanwhile)
 	session.last_requested_h = queue:count_acked() + queue:count_unacked();
-	session.log("debug", "Sending <r> (inside timer, after send) from %s - #queue=%d", reason, queue:count_unacked());
 	if not session.delayed_ack_timer then
 		session.delayed_ack_timer = timer.add_task(delayed_ack_timeout, function()
 			ack_delayed(session, nil); -- we don't know if this is the only new stanza in the queue
@@ -180,7 +206,6 @@
 	-- supposed to be nil.
 	-- However, when using mod_smacks with mod_websocket, then mod_websocket's
 	-- stanzas/out filter can get called before this one and adds the xmlns.
-	if session.resending_unacked then return stanza end
 	if not session.smacks then return stanza end
 	local is_stanza = st.is_stanza(stanza) and
 		(not stanza.attr.xmlns or stanza.attr.xmlns == 'jabber:client')
@@ -234,8 +259,7 @@
 	if session.smacks == nil then return end
 	if session.resumption_token then
 		session.log("debug", "Revoking resumption token");
-		session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil;
-		old_session_registry:set(session.username, session.resumption_token, nil);
+		clear_old_session(session);
 		session.resumption_token = nil;
 	else
 		session.log("debug", "Session not resumable");
@@ -274,17 +298,16 @@
 	return session;
 end
 
-function handle_enable(session, stanza, xmlns_sm)
-	local ok, err, err_text = can_do_smacks(session);
+function do_enable(session, stanza)
+	local ok, err = can_do_smacks(session);
 	if not ok then
-		session.log("warn", "Failed to enable smacks: %s", err_text); -- TODO: XEP doesn't say we can send error text, should it?
-		(session.sends2s or session.send)(st.stanza("failed", { xmlns = xmlns_sm }):tag(err, { xmlns = xmlns_errors}));
-		return true;
+		session.log("warn", "Failed to enable smacks: %s", err.text); -- TODO: XEP doesn't say we can send error text, should it?
+		return nil, err;
 	end
 
 	if session.username then
 		local old_sessions, err = all_old_sessions:get(session.username);
-		module:log("debug", "Old sessions: %q", old_sessions)
+		session.log("debug", "Old sessions: %q", old_sessions)
 		if old_sessions then
 			local keep, count = {}, 0;
 			for token, info in it.sorted_pairs(old_sessions, function(a, b)
@@ -296,54 +319,73 @@
 			end
 			all_old_sessions:set(session.username, keep);
 		elseif err then
-			module:log("error", "Unable to retrieve old resumption counters: %s", err);
+			session.log("error", "Unable to retrieve old resumption counters: %s", err);
 		end
 	end
 
-	module:log("debug", "Enabling stream management");
-	session.smacks = xmlns_sm;
-
-	wrap_session(session, false);
-
-	local resume_max;
 	local resume_token;
 	local resume = stanza.attr.resume;
 	if (resume == "true" or resume == "1") and session.username then
 		-- resumption on s2s is not currently supported
 		resume_token = new_id();
-		session_registry[jid.join(session.username, session.host, resume_token)] = session;
-		session.resumption_token = resume_token;
-		resume_max = tostring(resume_timeout);
 	end
-	(session.sends2s or session.send)(st.stanza("enabled", { xmlns = xmlns_sm, id = resume_token, resume = resume, max = resume_max }));
+
+	return {
+		type = "enabled";
+		id = resume_token;
+		resume_max = resume_token and tostring(resume_timeout) or nil;
+		session = session;
+		finish = function ()
+			session.log("debug", "Enabling stream management");
+
+			session.smacks = stanza.attr.xmlns;
+			if resume_token then
+				track_session(session, resume_token);
+			end
+			wrap_session(session, false);
+		end;
+	};
+end
+
+function handle_enable(session, stanza, xmlns_sm) -- handles <enable/>: replies with <enabled/> or <failed/>; returns true either way
+	local enabled, err = do_enable(session, stanza);
+	if not enabled then
+		(session.sends2s or session.send)(st.stanza("failed", { xmlns = xmlns_sm }):add_error(err)); -- NOTE(review): handle_resume adds the condition with a direct :tag() instead of add_error() - confirm the difference is intended
+		return true;
+	end
+
+	(session.sends2s or session.send)(st.stanza("enabled", {
+		xmlns = xmlns_sm;
+		id = enabled.id;
+		resume = enabled.id and "true" or nil; -- COMPAT w/ Conversations 2.10.10 requires 'true' not '1'
+		max = enabled.resume_max;
+	}));
+
+	session.smacks = xmlns_sm;
+	enabled.finish(); -- sets session.smacks from the stanza, tracks the resumption token (if any) and calls wrap_session()
+
 	return true;
 end
 module:hook_tag(xmlns_sm2, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm2); end, 100);
 module:hook_tag(xmlns_sm3, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm3); end, 100);
 
-module:hook_tag("http://etherx.jabber.org/streams", "features",
-		function (session, stanza)
-			-- Needs to be done after flushing sendq since those aren't stored as
-			-- stanzas and counting them is weird.
-			-- TODO unify sendq and smqueue
-			timer.add_task(1e-6, function ()
-				if can_do_smacks(session) then
-					if stanza:get_child("sm", xmlns_sm3) then
-						session.sends2s(st.stanza("enable", sm3_attr));
-						session.smacks = xmlns_sm3;
-					elseif stanza:get_child("sm", xmlns_sm2) then
-						session.sends2s(st.stanza("enable", sm2_attr));
-						session.smacks = xmlns_sm2;
-					else
-						return;
-					end
-					wrap_session_out(session, false);
-				end
-			end);
-		end);
+module:hook_tag("http://etherx.jabber.org/streams", "features", function(session, stanza) -- remember the <sm/> feature from received stream features
+	if can_do_smacks(session) then
+		session.smacks_feature = stanza:get_child("sm", xmlns_sm3) or stanza:get_child("sm", xmlns_sm2); -- sm3 is tried before sm2
+	end
+end);
+
+module:hook("s2sout-established", function (event) -- send <enable/> on an outgoing s2s stream for which an <sm/> feature was recorded
+	local session = event.session;
+	if not session.smacks_feature then return end
+
+	session.smacks = session.smacks_feature.attr.xmlns; -- use the namespace of the recorded <sm/> feature
+	wrap_session_out(session, false);
+	session.sends2s(st.stanza("enable", { xmlns = session.smacks }));
+end);
 
 function handle_enabled(session, stanza, xmlns_sm) -- luacheck: ignore 212/stanza
-	module:log("debug", "Enabling stream management");
+	session.log("debug", "Enabling stream management");
 	session.smacks = xmlns_sm;
 
 	wrap_session_in(session, false);
@@ -357,10 +399,10 @@
 
 function handle_r(origin, stanza, xmlns_sm) -- luacheck: ignore 212/stanza
 	if not origin.smacks then
-		module:log("debug", "Received ack request from non-smack-enabled session");
+		origin.log("debug", "Received ack request from non-smack-enabled session");
 		return;
 	end
-	module:log("debug", "Received ack request, acking for %d", origin.handled_stanza_count);
+	origin.log("debug", "Received ack request, acking for %d", origin.handled_stanza_count);
 	-- Reply with <a>
 	(origin.sends2s or origin.send)(st.stanza("a", { xmlns = xmlns_sm, h = format_h(origin.handled_stanza_count) }));
 	-- piggyback our own ack request if needed (see request_ack_if_needed() for explanation of last_requested_h)
@@ -413,13 +455,14 @@
 	local queue = session.outgoing_stanza_queue;
 	local unacked = queue:count_unacked()
 	if unacked > 0 then
+		local error_from = jid.join(session.username, session.host or module.host);
 		tx_dropped_stanzas:sample(unacked);
 		session.smacks = false; -- Disable queueing
 		session.outgoing_stanza_queue = nil;
 		for stanza in queue._queue:consume() do
 			if not module:fire_event("delivery/failure", { session = session, stanza = stanza }) then
 				if stanza.attr.type ~= "error" and stanza.attr.from ~= session.full_jid then
-					local reply = st.error_reply(stanza, "cancel", "recipient-unavailable");
+					local reply = st.error_reply(stanza, "cancel", "recipient-unavailable", nil, error_from);
 					module:send(reply);
 				end
 			end
@@ -486,11 +529,8 @@
 		end
 
 		session.log("debug", "Destroying session for hibernating too long");
-		session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil;
-		old_session_registry:set(session.username, session.resumption_token,
-			{ h = session.handled_stanza_count; t = os.time() });
+		save_old_session(session);
 		session.resumption_token = nil;
-		session.resending_unacked = true; -- stop outgoing_stanza_filter from re-queueing anything anymore
 		sessionmanager.destroy_session(session, "Hibernating too long");
 		sessions_expired(1);
 	end);
@@ -523,17 +563,10 @@
 module:hook("s2sout-destroyed", handle_s2s_destroyed);
 module:hook("s2sin-destroyed", handle_s2s_destroyed);
 
-local function get_session_id(session)
-	return session.id or (tostring(session):match("[a-f0-9]+$"));
-end
-
-function handle_resume(session, stanza, xmlns_sm)
+function do_resume(session, stanza)
 	if session.full_jid then
 		session.log("warn", "Tried to resume after resource binding");
-		session.send(st.stanza("failed", { xmlns = xmlns_sm })
-			:tag("unexpected-request", { xmlns = xmlns_errors })
-		);
-		return true;
+		return nil, enable_errors.new("already_bound");
 	end
 
 	local id = stanza.attr.previd;
@@ -542,112 +575,98 @@
 		local old_session = old_session_registry:get(session.username, id);
 		if old_session then
 			session.log("debug", "Tried to resume old expired session with id %s", id);
-			session.send(st.stanza("failed", { xmlns = xmlns_sm, h = format_h(old_session.h) })
-				:tag("item-not-found", { xmlns = xmlns_errors })
-			);
-			old_session_registry:set(session.username, id, nil);
+			clear_old_session(session, id);
 			resumption_expired(1);
-		else
-			session.log("debug", "Tried to resume non-existent session with id %s", id);
-			session.send(st.stanza("failed", { xmlns = xmlns_sm })
-				:tag("item-not-found", { xmlns = xmlns_errors })
-			);
-		end;
-	else
-		if original_session.hibernating_watchdog then
-			original_session.log("debug", "Letting the watchdog go");
-			original_session.hibernating_watchdog:cancel();
-			original_session.hibernating_watchdog = nil;
-		elseif session.hibernating then
-			original_session.log("error", "Hibernating session has no watchdog!")
+			return nil, enable_errors.new("expired", { h = old_session.h });
 		end
-		-- zero age = was not hibernating yet
-		local age = 0;
-		if original_session.hibernating then
-			local now = os_time();
-			age = now - original_session.hibernating;
-		end
-		session.log("debug", "mod_smacks resuming existing session %s...", get_session_id(original_session));
-		original_session.log("debug", "mod_smacks session resumed from %s...", get_session_id(session));
-		-- TODO: All this should move to sessionmanager (e.g. session:replace(new_session))
-		if original_session.conn then
-			original_session.log("debug", "mod_smacks closing an old connection for this session");
-			local conn = original_session.conn;
-			c2s_sessions[conn] = nil;
-			conn:close();
-		end
+		session.log("debug", "Tried to resume non-existent session with id %s", id);
+		return nil, enable_errors.new("unknown_session");
+	end
+
+	-- zero age = was not hibernating yet
+	local age = 0;
+	if original_session.hibernating then
+		local now = os_time();
+		age = now - original_session.hibernating;
+	end
+
+	session.log("debug", "mod_smacks resuming existing session %s...", original_session.id);
+
+	local queue = original_session.outgoing_stanza_queue;
+	local h = tonumber(stanza.attr.h);
+
+	original_session.log("debug", "Pre-resumption #queue = %d", queue:count_unacked())
+	local acked, err = ack_errors.coerce(queue:ack(h)); -- luacheck: ignore 211/acked
+
+	if not err and not queue:resumable() then
+		err = ack_errors.new("overflow");
+	end
+
+	if err then -- the watchdog has not been cancelled yet, so a rejected attempt leaves the old session's expiry timer in place
+		session.log("debug", "Resumption failed: %s", err);
+		return nil, err;
+	end
+	if original_session.hibernating_watchdog then -- cancel the expiry timer only now that the resumption is going ahead
+		original_session.log("debug", "Letting the watchdog go");
+		original_session.hibernating_watchdog:cancel();
+		original_session.hibernating_watchdog = nil;
+	elseif original_session.hibernating then -- check the old session; the new `session` is not the hibernating one
+		original_session.log("error", "Hibernating session has no watchdog!")
+	end
+
+	-- Update original_session with the parameters (connection, etc.) from the new session
+	sessionmanager.update_session(original_session, session);
 
-		local migrated_session_log = session.log;
-		original_session.ip = session.ip;
-		original_session.conn = session.conn;
-		original_session.rawsend = session.rawsend;
-		original_session.rawsend.session = original_session;
-		original_session.rawsend.conn = original_session.conn;
-		original_session.send = session.send;
-		original_session.send.session = original_session;
-		original_session.close = session.close;
-		original_session.filter = session.filter;
-		original_session.filter.session = original_session;
-		original_session.filters = session.filters;
-		original_session.send.filter = original_session.filter;
-		original_session.stream = session.stream;
-		original_session.secure = session.secure;
-		original_session.hibernating = nil;
-		original_session.resumption_counter = (original_session.resumption_counter or 0) + 1;
-		session.log = original_session.log;
-		session.type = original_session.type;
-		wrap_session(original_session, true);
-		-- Inform xmppstream of the new session (passed to its callbacks)
-		original_session.stream:set_session(original_session);
-		-- Similar for connlisteners
-		c2s_sessions[session.conn] = original_session;
+	return {
+		type = "resumed";
+		session = original_session;
+		id = id;
+		-- Return function to complete the resumption and resync unacked stanzas
+		-- This is two steps so we can support SASL2/ISR
+		finish = function ()
+			-- Ok, we need to re-send any stanzas that the client didn't see
+			-- ...they are what is now left in the outgoing stanza queue
+			-- They are sent with original_session.send before wrap_session() below re-adds our handlers,
+			-- presumably so that they are not added to the outgoing queue again (TODO confirm)
 
-		local queue = original_session.outgoing_stanza_queue;
-		local h = tonumber(stanza.attr.h);
+			original_session.log("debug", "resending all unacked stanzas that are still queued after resume, #queue = %d", queue:count_unacked());
+			for _, queued_stanza in queue:resume() do
+				original_session.send(queued_stanza);
+			end
+			original_session.log("debug", "all stanzas resent, enabling stream management on resumed stream, #queue = %d", queue:count_unacked());
 
-		original_session.log("debug", "Pre-resumption #queue = %d", queue:count_unacked())
-		local acked, err = ack_errors.coerce(queue:ack(h)); -- luacheck: ignore 211/acked
-
-		if not err and not queue:resumable() then
-			err = ack_errors.new("overflow");
-		end
+			-- Add our own handlers to the resumed session (filters have been reset in the update)
+			wrap_session(original_session, true);
 
-		if err or not queue:resumable() then
-			original_session.send(st.stanza("failed",
-				{ xmlns = xmlns_sm; h = format_h(original_session.handled_stanza_count); previd = id }));
-			original_session:close(err);
-			return false;
-		end
-
-		original_session.send(st.stanza("resumed", { xmlns = xmlns_sm,
-			h = format_h(original_session.handled_stanza_count), previd = id }));
-
-		-- Ok, we need to re-send any stanzas that the client didn't see
-		-- ...they are what is now left in the outgoing stanza queue
-		-- We have to use the send of "session" because we don't want to add our resent stanzas
-		-- to the outgoing queue again
+			-- Let everyone know that we are no longer hibernating
+			module:fire_event("smacks-hibernation-end", {origin = session, resumed = original_session, queue = queue:table()});
+			original_session.awaiting_ack = nil; -- Don't wait for acks from before the resumption
+			request_ack_now_if_needed(original_session, true, "handle_resume", nil);
+			resumption_age:sample(age);
+		end;
+	};
+end
 
-		session.log("debug", "resending all unacked stanzas that are still queued after resume, #queue = %d", queue:count_unacked());
-		-- FIXME Which session is it that the queue filter sees?
-		session.resending_unacked = true;
-		original_session.resending_unacked = true;
-		for _, queued_stanza in queue:resume() do
-			session.send(queued_stanza);
-		end
-		session.resending_unacked = nil;
-		original_session.resending_unacked = nil;
-		session.log("debug", "all stanzas resent, now disabling send() in this migrated session, #queue = %d", queue:count_unacked());
-		function session.send(stanza) -- luacheck: ignore 432
-			migrated_session_log("error", "Tried to send stanza on old session migrated by smacks resume (maybe there is a bug?): %s", tostring(stanza));
-			return false;
-		end
-		module:fire_event("smacks-hibernation-end", {origin = session, resumed = original_session, queue = queue:table()});
-		original_session.awaiting_ack = nil; -- Don't wait for acks from before the resumption
-		request_ack_now_if_needed(original_session, true, "handle_resume", nil);
-		resumption_age:sample(age);
+function handle_resume(session, stanza, xmlns_sm) -- handles <resume/>: replies with <resumed/> or <failed/>; returns true either way
+	local resumed, err = do_resume(session, stanza);
+	if not resumed then
+		session.send(st.stanza("failed", { xmlns = xmlns_sm, h = format_h(err.context.h) }) -- do_resume passes h in the context only with its "expired" error
+			:tag(err.condition, { xmlns = xmlns_errors }));
+		return true;
 	end
+
+	session = resumed.session; -- from here on `session` is the session returned by do_resume
+
+	-- Inform client of successful resumption
+	session.send(st.stanza("resumed", { xmlns = xmlns_sm,
+		h = format_h(session.handled_stanza_count), previd = resumed.id }));
+
+	-- Complete resume (sync stanzas, etc.)
+	resumed.finish();
+
 	return true;
 end
+
 module:hook_tag(xmlns_sm2, "resume", function (session, stanza) return handle_resume(session, stanza, xmlns_sm2); end);
 module:hook_tag(xmlns_sm3, "resume", function (session, stanza) return handle_resume(session, stanza, xmlns_sm3); end);
 
@@ -702,8 +721,7 @@
 	for _, user in pairs(local_sessions) do
 		for _, session in pairs(user.sessions) do
 			if session.resumption_token then
-				if old_session_registry:set(session.username, session.resumption_token,
-					{ h = session.handled_stanza_count; t = os.time() }) then
+				if save_old_session(session) then
 					session.resumption_token = nil;
 
 					-- Deal with unacked stanzas