--- a/mod_firewall/mod_firewall.lua Sat Feb 25 16:53:45 2017 +0000
+++ b/mod_firewall/mod_firewall.lua Sat Feb 25 16:54:52 2017 +0000
@@ -1,7 +1,9 @@
+local lfs = require "lfs";
local resolve_relative_path = require "core.configmanager".resolve_relative_path;
local logger = require "util.logger".init;
local it = require "util.iterators";
+local set = require "util.set";
local definitions = module:shared("definitions");
local active_definitions = {
@@ -549,45 +551,110 @@
return resolve_relative_path(relative_to, script_path);
end
+-- [filename] = { last_modified = ..., events_hooked = { [name] = handler } }
+local loaded_scripts = {};
+
function load_script(script)
script = resolve_script_path(script);
- local chain_functions, err = compile_firewall_rules(script)
+ local last_modified = (lfs.attributes(script) or {}).modification or os.time();
+ if loaded_scripts[script] then
+ if loaded_scripts[script].last_modified == last_modified then
+ return; -- Already loaded, and source file hasn't changed
+ end
+ module:log("debug", "Reloading %s", script);
+ -- Already loaded, but the source file has changed
+ -- unload it now, and we'll load the new version below
+ unload_script(script, true);
+ end
+ local chain_functions, err = compile_firewall_rules(script);
if not chain_functions then
module:log("error", "Error compiling %s: %s", script, err or "unknown error");
- else
- for chain, handler_code in pairs(chain_functions) do
- local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
- if not new_handler then
- module:log("error", "Compilation error for %s: %s", script, err);
- else
- local chain_definition = chains[chain];
- if chain_definition and chain_definition.type == "event" then
- local handler = new_handler(chain_definition.pass_return);
- for _, event_name in ipairs(chain_definition) do
- module:hook(event_name, handler, chain_definition.priority);
- end
- elseif not chain:sub(1, 5) == "user/" then
- module:log("warn", "Unknown chain %q", chain);
+ return;
+ end
+
+ -- Loop through the chains in the script, and for each chain attach the compiled code to the
+ -- relevant events, keeping track in events_hooked so we can cleanly unload later
+ local events_hooked = {};
+ for chain, handler_code in pairs(chain_functions) do
+ local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
+ if not new_handler then
+ module:log("error", "Compilation error for %s: %s", script, err);
+ else
+ local chain_definition = chains[chain];
+ if chain_definition and chain_definition.type == "event" then
+ local handler = new_handler(chain_definition.pass_return);
+ for _, event_name in ipairs(chain_definition) do
+ events_hooked[event_name] = handler;
+ module:hook(event_name, handler, chain_definition.priority);
end
- module:hook("firewall/chains/"..chain, new_handler(false));
+ elseif not chain:sub(1, 5) == "user/" then
+ module:log("warn", "Unknown chain %q", chain);
end
+ local event_name, handler = "firewall/chains/"..chain, new_handler(false);
+ events_hooked[event_name] = handler;
+ module:hook(event_name, handler);
end
end
+ loaded_scripts[script] = { last_modified = last_modified, events_hooked = events_hooked };
+ module:log("debug", "Loaded %s", script);
+end
+
+function unload_script(script, is_reload)
+ script = resolve_script_path(script);
+ local script_info = loaded_scripts[script];
+ if not script_info then
+ return; -- Script not loaded
+ end
+ local events_hooked = script_info.events_hooked;
+ for event_name, event_handler in pairs(events_hooked) do
+ module:unhook(event_name, event_handler);
+ events_hooked[event_name] = nil;
+ end
+ loaded_scripts[script] = nil;
+ if not is_reload then
+ module:log("debug", "Unloaded %s", script);
+ end
+end
+
+-- Given a set of scripts (e.g. from config) figure out which ones need to
+-- be loaded, which are already loaded but need unloading, and which to reload
+function load_unload_scripts(script_list)
+ local wanted_scripts = script_list / resolve_script_path;
+ local currently_loaded = set.new(it.to_array(it.keys(loaded_scripts)));
+ local scripts_to_unload = currently_loaded - wanted_scripts;
+ for script in wanted_scripts do
+ -- If the script is already loaded, this is fine - it will
+ -- reload the script for us if the file has changed
+ load_script(script);
+ end
+ for script in scripts_to_unload do
+ unload_script(script);
+ end
end
function module.load()
if not prosody.arg then return end -- Don't run in prosodyctl
- active_definitions = {};
local firewall_scripts = module:get_option_set("firewall_scripts", {});
- for script in firewall_scripts do
- load_script(script);
- end
+ load_unload_scripts(firewall_scripts);
-- Replace contents of definitions table (shared) with active definitions
for k in it.keys(definitions) do definitions[k] = nil; end
for k,v in pairs(active_definitions) do definitions[k] = v; end
end
+function module.save()
+ return { active_definitions = active_definitions, loaded_scripts = loaded_scripts };
+end
+
+function module.restore(state)
+ active_definitions = state.active_definitions;
+ loaded_scripts = state.loaded_scripts;
+end
+
+module:hook_global("config-reloaded", function ()
+ load_unload_scripts(module:get_option_set("firewall_scripts", {}));
+end);
+
function module.command(arg)
if not arg[1] or arg[1] == "--help" then
require"util.prosodyctl".show_usage([[mod_firewall <firewall.pfw>]], [[Compile files with firewall rules to Lua code]]);