mod_firewall: General stanza filtering plugin with a declarative rule-based syntax
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/mod_firewall/actions.lib.lua Wed Apr 03 16:11:20 2013 +0100
@@ -0,0 +1,158 @@
+local action_handlers = {};
+
+-- Takes an XML string and returns a code string that builds that stanza
+-- using st.stanza()
+local function compile_xml(data)
+ local code = {};
+ local first, short_close = true, nil;
+ for tagline, text in data:gmatch("<([^>]+)>([^<]*)") do
+ if tagline:sub(-1,-1) == "/" then
+ tagline = tagline:sub(1, -2);
+ short_close = true;
+ end
+ if tagline:sub(1,1) == "/" then
+ code[#code+1] = (":up()");
+ else
+ local name, attr = tagline:match("^(%S*)%s*(.*)$");
+ local attr_str = {};
+ for k, _, v in attr:gmatch("(%S+)=([\"'])([^%2]-)%2") do
+ if #attr_str == 0 then
+ table.insert(attr_str, ", { ");
+ else
+ table.insert(attr_str, ", ");
+ end
+ if k:match("^%a%w*$") then
+ table.insert(attr_str, string.format("%s = %q", k, v));
+ else
+ table.insert(attr_str, string.format("[%q] = %q", k, v));
+ end
+ end
+ if #attr_str > 0 then
+ table.insert(attr_str, " }");
+ end
+ if first then
+ code[#code+1] = (string.format("st.stanza(%q %s)", name, #attr_str>0 and table.concat(attr_str) or ", nil"));
+ first = nil;
+ else
+ code[#code+1] = (string.format(":tag(%q%s)", name, table.concat(attr_str)));
+ end
+ end
+ if text and text:match("%S") then
+ code[#code+1] = (string.format(":text(%q)", text));
+ elseif short_close then
+ short_close = nil;
+ code[#code+1] = (":up()");
+ end
+ end
+ return table.concat(code, "");
+end
+
+
+function action_handlers.DROP()
+ return "log('debug', 'Firewall dropping stanza: %s', tostring(stanza)); return true;";
+end
+
+function action_handlers.STRIP(tag_desc)
+ local code = {};
+ local name, xmlns = tag_desc:match("^(%S+) (.+)$");
+ if not name then
+ name, xmlns = tag_desc, nil;
+ end
+ if name == "*" then
+ name = nil;
+ end
+ code[#code+1] = ("local stanza_xmlns = stanza.attr.xmlns; ");
+ code[#code+1] = "stanza:maptags(function (tag) if ";
+ if name then
+ code[#code+1] = ("tag.name == %q and "):format(name);
+ end
+ if xmlns then
+ code[#code+1] = ("(tag.attr.xmlns or stanza_xmlns) == %q "):format(xmlns);
+ else
+ code[#code+1] = ("tag.attr.xmlns == stanza_xmlns ");
+ end
+ code[#code+1] = "then return nil; end return tag; end );";
+ return table.concat(code);
+end
+
+function action_handlers.INJECT(tag)
+ return "stanza:add_child("..compile_xml(tag)..")", { "st" };
+end
+
+local error_types = {
+ ["bad-request"] = "modify";
+ ["conflict"] = "cancel";
+ ["feature-not-implemented"] = "cancel";
+ ["forbidden"] = "auth";
+ ["gone"] = "cancel";
+ ["internal-server-error"] = "cancel";
+ ["item-not-found"] = "cancel";
+ ["jid-malformed"] = "modify";
+ ["not-acceptable"] = "modify";
+ ["not-allowed"] = "cancel";
+ ["not-authorized"] = "auth";
+ ["payment-required"] = "auth";
+ ["policy-violation"] = "modify";
+ ["recipient-unavailable"] = "wait";
+ ["redirect"] = "modify";
+ ["registration-required"] = "auth";
+ ["remote-server-not-found"] = "cancel";
+ ["remote-server-timeout"] = "wait";
+ ["resource-constraint"] = "wait";
+ ["service-unavailable"] = "cancel";
+ ["subscription-required"] = "auth";
+ ["undefined-condition"] = "cancel";
+ ["unexpected-request"] = "wait";
+};
+
+
+local function route_modify(make_new, to, drop)
+ local reroute, deps = "session.send(newstanza)", { "st" };
+ if to then
+ reroute = ("newstanza.attr.to = %q; core_post_stanza(session, newstanza)"):format(to);
+ deps[#deps+1] = "core_post_stanza";
+ end
+ return ([[local newstanza = st.%s; %s; %s; ]])
+ :format(make_new, reroute, drop and "return true" or ""), deps;
+end
+
+function action_handlers.BOUNCE(with)
+ local error = with and with:match("^%S+") or "service-unavailable";
+ local error_type = error:match(":(%S+)");
+ if not error_type then
+ error_type = error_types[error] or "cancel";
+ else
+ error = error:match("^[^:]+");
+ end
+ error, error_type = string.format("%q", error), string.format("%q", error_type);
+ local text = with and with:match(" %((.+)%)$");
+ if text then
+ text = string.format("%q", text);
+ else
+ text = "nil";
+ end
+ return route_modify(("error_reply(stanza, %s, %s, %s)"):format(error_type, error, text), nil, true);
+end
+
+function action_handlers.REDIRECT(where)
+ return route_modify("clone(stanza)", where, true, true);
+end
+
+function action_handlers.COPY(where)
+ return route_modify("clone(stanza)", where, true, false);
+end
+
+function action_handlers.LOG(string)
+ local level = string:match("^%[(%a+)%]") or "info";
+ string = string:gsub("^%[%a+%] ?", "");
+ return (("log(%q, %q)"):format(level, string)
+ :gsub("$top", [["..stanza:top_tag().."]])
+ :gsub("$stanza", [["..stanza.."]])
+ :gsub("$(%b())", [["..%1.."]]));
+end
+
+function action_handlers.RULEDEP(dep)
+ return "", { dep };
+end
+
+return action_handlers;
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/mod_firewall/conditions.lib.lua Wed Apr 03 16:11:20 2013 +0100
@@ -0,0 +1,94 @@
+local condition_handlers = {};
+
+local jid = require "util.jid";
+
+-- Return a code string for a condition that checks whether the contents
+-- of variable with the name 'name' matches any of the values in the
+-- comma/space/pipe delimited list 'values'.
+local function compile_comparison_list(name, values)
+ local conditions = {};
+ for value in values:gmatch("[^%s,|]+") do
+ table.insert(conditions, ("%s == %q"):format(name, value));
+ end
+ return table.concat(conditions, " or ");
+end
+
+function condition_handlers.KIND(kind)
+ return compile_comparison_list("name", kind), { "name" };
+end
+
+local wildcard_equivs = { ["*"] = ".*", ["?"] = "." };
+
+local function compile_jid_match_part(part, match)
+ if not match then
+ return part.." == nil"
+ end
+ local pattern = match:match("<(.*)>");
+ -- TODO: Support Lua pattern matching (main issue syntax... << >>?)
+ if pattern then
+ if pattern ~= "*" then
+ return ("%s:match(%q)"):format(part, pattern:gsub(".", wildcard_equivs));
+ end
+ else
+ return ("%s == %q"):format(part, match);
+ end
+end
+
+local function compile_jid_match(which, match_jid)
+ local match_node, match_host, match_resource = jid.split(match_jid);
+ local conditions = {
+ compile_jid_match_part(which.."_node", match_node);
+ compile_jid_match_part(which.."_host", match_host);
+ match_resource and compile_jid_match_part(which.."_resource", match_resource) or nil;
+ };
+ return table.concat(conditions, " and ");
+end
+
+function condition_handlers.TO(to)
+ return compile_jid_match("to", to), { "split_to" };
+end
+
+function condition_handlers.FROM(from)
+ return compile_jid_match("from", from), { "split_from" };
+end
+
+function condition_handlers.TYPE(type)
+ return compile_comparison_list("type", type), { "type" };
+end
+
+function condition_handlers.ENTERING(zone)
+ return ("(zones[%q] and (zones[%q][to_host] or "
+ .."zones[%q][to] or "
+ .."zones[%q][bare_to]))"
+ )
+ :format(zone, zone, zone, zone), { "split_to", "bare_to" };
+end
+
+function condition_handlers.LEAVING(zone)
+ return ("zones[%q] and (zones[%q][from_host] or "
+ .."(zones[%q][from] or "
+ .."zones[%q][bare_from]))")
+ :format(zone, zone, zone, zone), { "split_from", "bare_from" };
+end
+
+function condition_handlers.PAYLOAD(payload_ns)
+ return ("stanza:get_child(nil, %q)"):format(payload_ns);
+end
+
+function condition_handlers.FROM_GROUP(group_name)
+ return ("group_contains(%q, bare_from)"):format(group_name), { "group_contains", "bare_from" };
+end
+
+function condition_handlers.TO_GROUP(group_name)
+ return ("group_contains(%q, bare_to)"):format(group_name), { "group_contains", "bare_to" };
+end
+
+function condition_handlers.FROM_ADMIN_OF(host)
+ return ("is_admin(bare_from, %s)"):format(host ~= "*" and host or nil), { "is_admin", "bare_from" };
+end
+
+function condition_handlers.TO_ADMIN_OF(host)
+ return ("is_admin(bare_to, %s)"):format(host ~= "*" and host or nil), { "is_admin", "bare_to" };
+end
+
+return condition_handlers;
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/mod_firewall/mod_firewall.lua Wed Apr 03 16:11:20 2013 +0100
@@ -0,0 +1,271 @@
+
+local resolve_relative_path = require "core.configmanager".resolve_relative_path;
+local logger = require "util.logger".init;
+local set = require "util.set";
+local add_filter = require "util.filters".add_filter;
+
+
+zones = {};
+local zones = zones;
+setmetatable(zones, {
+ __index = function (zones, zone)
+ local t = { [zone] = true };
+ rawset(zones, zone, t);
+ return t;
+ end;
+});
+
+local chains = {
+ preroute = {
+ type = "event";
+ priority = 0.1;
+ "pre-message/bare", "pre-message/full", "pre-message/host";
+ "pre-presence/bare", "pre-presence/full", "pre-presence/host";
+ "pre-iq/bare", "pre-iq/full", "pre-iq/host";
+ };
+ deliver = {
+ type = "event";
+ priority = 0.1;
+ "message/bare", "message/full", "message/host";
+ "presence/bare", "presence/full", "presence/host";
+ "iq/bare", "iq/full", "iq/host";
+ };
+ deliver_remote = {
+ type = "event"; "route/remote";
+ priority = 0.1;
+ };
+};
+
+-- Dependency locations:
+-- <type lib>
+-- <type global>
+-- function handler()
+-- <local deps>
+-- if <conditions> then
+-- <actions>
+-- end
+-- end
+
+local available_deps = {
+ st = { global_code = [[local st = require "util.stanza"]]};
+ jid_split = {
+ global_code = [[local jid_split = require "util.jid".split;]];
+ };
+ jid_bare = {
+ global_code = [[local jid_bare = require "util.jid".bare;]];
+ };
+ to = { local_code = [[local to = stanza.attr.to;]] };
+ from = { local_code = [[local from = stanza.attr.from;]] };
+ type = { local_code = [[local type = stanza.attr.type;]] };
+ name = { local_code = [[local name = stanza.name]] };
+ split_to = { -- The stanza's split to address
+ depends = { "jid_split", "to" };
+ local_code = [[local to_node, to_host, to_resource = jid_split(to);]];
+ };
+ split_from = { -- The stanza's split from address
+ depends = { "jid_split", "from" };
+ local_code = [[local from_node, from_host, from_resource = jid_split(from);]];
+ };
+ bare_to = { depends = { "jid_bare", "to" }, local_code = "local bare_to = jid_bare(to)"};
+ bare_from = { depends = { "jid_bare", "from" }, local_code = "local bare_from = jid_bare(from)"};
+ group_contains = {
+ global_code = [[local group_contains = module:depends("groups").group_contains]];
+ };
+ is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]};
+ core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] };
+};
+
+local function include_dep(dep, code)
+ local dep_info = available_deps[dep];
+ if not dep_info then
+ module:log("error", "Dependency not found: %s", dep);
+ return;
+ end
+ if code.included_deps[dep] then
+ if code.included_deps[dep] ~= true then
+ module:log("error", "Circular dependency on %s", dep);
+ end
+ return;
+ end
+ code.included_deps[dep] = false; -- Pending flag (used to detect circular references)
+ for _, dep_dep in ipairs(dep_info.depends or {}) do
+ include_dep(dep_dep, code);
+ end
+ if dep_info.global_code then
+ table.insert(code.global_header, dep_info.global_code);
+ end
+ if dep_info.local_code then
+ table.insert(code, "\n\t-- "..dep.."\n\t"..dep_info.local_code.."\n\n\t");
+ end
+ code.included_deps[dep] = true;
+end
+
+local condition_handlers = module:require("conditions");
+local action_handlers = module:require("actions");
+
+local function new_rule(ruleset, chain)
+ assert(chain, "no chain specified");
+ local rule = { conditions = {}, actions = {}, deps = {} };
+ table.insert(ruleset[chain], rule);
+ return rule;
+end
+
+local function compile_firewall_rules(filename)
+ local line_no = 0;
+
+ local ruleset = {
+ deliver = {};
+ };
+
+ local chain = "deliver"; -- Default chain
+ local rule;
+
+ local file, err = io.open(filename);
+ if not file then return nil, err; end
+
+ local state; -- nil -> "rules" -> "actions" -> nil -> ...
+
+ local line_hold;
+ for line in file:lines() do
+ line = line:match("^%s*(.-)%s*$");
+ if line_hold and line:sub(-1,-1) ~= "\\" then
+ line = line_hold..line;
+ line_hold = nil;
+ elseif line:sub(-1,-1) == "\\" then
+ line_hold = (line_hold or "")..line:sub(1,-2);
+ end
+ line_no = line_no + 1;
+
+ if line_hold or line:match("^[#;]") then
+ -- No action; comment or partial line
+ elseif line == "" then
+ if state == "rules" then
+ return nil, ("Expected an action on line %d for preceding criteria")
+ :format(line_no);
+ end
+ state = nil;
+ elseif not(state) and line:match("^::") then
+ chain = line:gsub("^::%s*", "");
+ ruleset[chain] = ruleset[chain] or {};
+ elseif not(state) and line:match("^ZONE ") then
+ local zone_name = line:match("^ZONE ([^:]+)");
+ local zone_members = line:match("^ZONE .-: ?(.*)");
+ local zone_member_list = {};
+ for member in zone_members:gmatch("[^, ]+") do
+ zone_member_list[#zone_member_list+1] = member;
+ end
+ zones[zone_name] = set.new(zone_member_list)._items;
+ elseif line:match("^[^%s:]+[%.=]") then
+ -- Action
+ if state == nil then
+ -- This is a standalone action with no conditions
+ rule = new_rule(ruleset, chain);
+ end
+ state = "actions";
+ -- Action handlers?
+ local action = line:match("^%P+");
+ if not action_handlers[action] then
+ return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>");
+ end
+ table.insert(rule.actions, "-- "..line)
+ local action_string, action_deps = action_handlers[action](line:match("=(.+)$"));
+ table.insert(rule.actions, action_string);
+ for _, dep in ipairs(action_deps or {}) do
+ table.insert(rule.deps, dep);
+ end
+ elseif state == "actions" then -- state is actions but action pattern did not match
+ state = nil; -- Awaiting next rule, etc.
+ table.insert(ruleset[chain], rule);
+ rule = nil;
+ else
+ if not state then
+ state = "rules";
+ rule = new_rule(ruleset, chain);
+ end
+ -- Check standard modifiers for the condition (e.g. NOT)
+ local negated;
+ local condition = line:match("^[^:=%.]*");
+ if condition:match("%f[%w]NOT%f[^%w]") then
+ local s, e = condition:match("%f[%w]()NOT()%f[^%w]");
+ condition = (condition:sub(1,s-1)..condition:sub(e+1, -1)):match("^%s*(.-)%s*$");
+ negated = true;
+ end
+ condition = condition:gsub(" ", "");
+ if not condition_handlers[condition] then
+ return nil, ("Unknown condition on line %d: %s"):format(line_no, condition);
+ end
+ -- Get the code for this condition
+ local condition_code, condition_deps = condition_handlers[condition](line:match(":%s?(.+)$"));
+ if negated then condition_code = "not("..condition_code..")"; end
+ table.insert(rule.conditions, condition_code);
+ for _, dep in ipairs(condition_deps or {}) do
+ table.insert(rule.deps, dep);
+ end
+ end
+ end
+
+ -- Compile ruleset and return complete code
+
+ local chain_handlers = {};
+
+ -- Loop through the chains in the parsed ruleset (e.g. incoming, outgoing)
+ for chain_name, rules in pairs(ruleset) do
+ local code = { included_deps = {}, global_header = {} };
+ -- This inner loop assumes chain is an event-based, not a filter-based
+ -- chain (filter-based will be added later)
+ for _, rule in ipairs(rules) do
+ for _, dep in ipairs(rule.deps) do
+ include_dep(dep, code);
+ end
+ local rule_code = "if ("..table.concat(rule.conditions, ") and (")..") then\n\t"
+ ..table.concat(rule.actions, "\n\t")
+ .."\n end\n";
+ table.insert(code, rule_code);
+ end
+
+ assert(chains[chain_name].type == "event", "Only event chains supported at the moment")
+
+ local code_string = [[return function (zones, log)
+ ]]..table.concat(code.global_header, "\n")..[[
+ local db = require 'util.debug'
+ return function (event)
+ local stanza, session = event.stanza, event.origin;
+
+ ]]..table.concat(code, " ")..[[
+ end;
+ end]];
+
+ print(code_string)
+
+ -- Prepare event handler function
+ local chunk, err = loadstring(code_string, "="..filename);
+ if not chunk then
+ return nil, "Error compiling (probably a compiler bug, please report): "..err;
+ end
+ chunk = chunk()(zones, logger(filename)); -- Returns event handler with 'zones' upvalue.
+ chain_handlers[chain_name] = chunk;
+ end
+
+ return chain_handlers;
+end
+
+function module.load()
+ local firewall_scripts = module:get_option_set("firewall_scripts", {});
+ for script in firewall_scripts do
+ script = resolve_relative_path(script) or script;
+ local chain_functions, err = compile_firewall_rules(script)
+
+ if not chain_functions then
+ module:log("error", "Error compiling %s: %s", script, err or "unknown error");
+ else
+ for chain, handler in pairs(chain_functions) do
+ local chain_definition = chains[chain];
+ if chain_definition.type == "event" then
+ for _, event_name in ipairs(chain_definition) do
+ module:hook(event_name, handler, chain_definition.priority);
+ end
+ end
+ end
+ end
+ end
+end