util.stanza: Add add_error() to simplify adding error tags to existing stanzas
authorMatthew Wild <mwild1@gmail.com>
Mon, 29 Aug 2022 14:59:46 +0100
changeset 12691 5b69ecaf3427
parent 12690 5f182bccf33f
child 12692 36ba170c4fd0
util.stanza: Add add_error() to simplify adding error tags to existing stanzas Some fiddling is required now in error_reply() to ensure the cursor is in the same place as before this change (a lot of code apparently uses that feature).
spec/util_stanza_spec.lua
util/stanza.lua
--- a/spec/util_stanza_spec.lua	Sun Aug 28 07:51:50 2022 +0100
+++ b/spec/util_stanza_spec.lua	Mon Aug 29 14:59:46 2022 +0100
@@ -314,6 +314,20 @@
 		end)
 	end)
 
+	describe("#add_error()", function ()
+		describe("basics", function ()
+			local s = st.stanza("custom", { xmlns = "urn:example:foo" });
+			local e = s:add_error("cancel", "not-acceptable", "UNACCEPTABLE!!!! ONE MILLION YEARS DUNGEON!")
+				:tag("dungeon", { xmlns = "urn:uuid:c9026187-5b05-4e70-b265-c3b6338a7d0f", period="1000000years"});
+			assert.equal(s, e);
+			local typ, cond, text, extra = e:get_error();
+			assert.equal("cancel", typ);
+			assert.equal("not-acceptable", cond);
+			assert.equal("UNACCEPTABLE!!!! ONE MILLION YEARS DUNGEON!", text);
+			assert.is_nil(extra);
+		end)
+	end)
+
 	describe("should reject #invalid", function ()
 		local invalid_names = {
 			["empty string"] = "", ["characters"] = "<>";
--- a/util/stanza.lua	Sun Aug 28 07:51:50 2022 +0100
+++ b/util/stanza.lua	Mon Aug 29 14:59:46 2022 +0100
@@ -29,6 +29,7 @@
 local do_pretty_printing, termcolours = pcall(require, "util.termcolours");
 
 local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
+local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
 
 local _ENV = nil;
 -- luacheck: std none
@@ -396,6 +397,33 @@
 	return error_type, condition or "undefined-condition", text, extra_tag;
 end
 
+function stanza_mt.add_error(stanza, error_type, condition, error_message, error_by)
+	local extra;
+	if type(error_type) == "table" then -- an util.error or similar object
+		if type(error_type.extra) == "table" then
+			extra = error_type.extra;
+		end
+		if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end
+		error_type, condition, error_message = error_type.type, error_type.condition, error_type.text;
+	end
+	if stanza.attr.from == error_by then
+		error_by = nil;
+	end
+	stanza:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here
+	:tag(condition, xmpp_stanzas_attr);
+	if extra and condition == "gone" and type(extra.uri) == "string" then
+		stanza:text(extra.uri);
+	end
+	stanza:up();
+	if error_message then stanza:text_tag("text", error_message, xmpp_stanzas_attr); end
+	if extra and is_stanza(extra.tag) then
+		stanza:add_child(extra.tag);
+	elseif extra and extra.namespace and extra.condition then
+		stanza:tag(extra.condition, { xmlns = extra.namespace }):up();
+	end
+	return stanza:up();
+end
+
 local function preserialize(stanza)
 	local s = { name = stanza.name, attr = stanza.attr };
 	for _, child in ipairs(stanza) do
@@ -470,7 +498,6 @@
 		});
 end
 
-local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
 local function error_reply(orig, error_type, condition, error_message, error_by)
 	if not is_stanza(orig) then
 		error("bad argument to error_reply: expected stanza, got "..type(orig));
@@ -479,30 +506,9 @@
 	end
 	local t = reply(orig);
 	t.attr.type = "error";
-	local extra;
-	if type(error_type) == "table" then -- an util.error or similar object
-		if type(error_type.extra) == "table" then
-			extra = error_type.extra;
-		end
-		if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end
-		error_type, condition, error_message = error_type.type, error_type.condition, error_type.text;
-	end
-	if t.attr.from == error_by then
-		error_by = nil;
-	end
-	t:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here
-	:tag(condition, xmpp_stanzas_attr);
-	if extra and condition == "gone" and type(extra.uri) == "string" then
-		t:text(extra.uri);
-	end
-	t:up();
-	if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end
-	if extra and is_stanza(extra.tag) then
-		t:add_child(extra.tag);
-	elseif extra and extra.namespace and extra.condition then
-		t:tag(extra.condition, { xmlns = extra.namespace }):up();
-	end
-	return t; -- stanza ready for adding app-specific errors
+	t:add_error(error_type, condition, error_message, error_by);
+	t.last_add = { t[1] }; -- ready to add application-specific errors
+	return t;
 end
 
 local function presence(attr)