--- a/mod_firewall/conditions.lib.lua Sat Apr 06 21:47:46 2013 +0200
+++ b/mod_firewall/conditions.lib.lua Sat Apr 06 22:20:59 2013 +0100
@@ -170,4 +170,8 @@
return table.concat(conditions, " or "), { "time:hour,min" };
end
+function condition_handlers.LIMIT(name)
+ return ("not throttle_%s:poll(1)"):format(name), { "throttle:"..name };
+end
+
return condition_handlers;
--- a/mod_firewall/mod_firewall.lua Sat Apr 06 21:47:46 2013 +0200
+++ b/mod_firewall/mod_firewall.lua Sat Apr 06 22:20:59 2013 +0100
@@ -2,17 +2,12 @@
local resolve_relative_path = require "core.configmanager".resolve_relative_path;
local logger = require "util.logger".init;
local set = require "util.set";
+local it = require "util.iterators";
local add_filter = require "util.filters".add_filter;
+local new_throttle = require "util.throttle".create;
-zones = {};
-local zones = zones;
-setmetatable(zones, {
- __index = function (zones, zone)
- local t = { [zone] = true };
- rawset(zones, zone, t);
- return t;
- end;
-});
+local zones, throttles = module:shared("zones", "throttles");
+local active_zones, active_throttles = {}, {};
local chains = {
preroute = {
@@ -35,6 +30,10 @@
};
};
+local function idsafe(name)
+ return not not name:match("^%a[%w_]*$")
+end
+
-- Dependency locations:
-- <type lib>
-- <type global>
@@ -73,7 +72,7 @@
is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]};
core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] };
zone = { global_code = function (zone)
- assert(zone:match("^%a[%w_]*$"), "Invalid zone name: "..zone);
+ assert(idsafe(zone), "Invalid zone name: "..zone);
return ("local zone_%s = zones[%q] or {};"):format(zone, zone);
end };
date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
@@ -84,6 +83,13 @@
end
return table.concat(defs, " ");
end, depends = { "date_time" }; };
+ throttle = {
+ global_code = function (throttle)
+ assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
+ assert(throttles[throttle], "Unknown rate limit: "..throttle);
+ return ("local throttle_%s = throttles.%s;"):format(throttle, throttle);
+ end;
+ };
};
local function include_dep(dep, code)
@@ -188,6 +194,14 @@
zone_member_list[#zone_member_list+1] = member;
end
zones[zone_name] = set.new(zone_member_list)._items;
+ table.insert(active_zones, zone_name);
+ elseif not(state) and line:match("^RATE ") then
+ local name = line:match("^RATE ([^:]+)");
+ assert(idsafe(name), "Invalid rate limit name: "..name);
+ local rate = assert(tonumber(line:match(":%s*([%d.]+)")), "Unable to parse rate");
+ local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1;
+ throttles[name] = new_throttle(rate*burst, burst);
+ table.insert(active_throttles, name);
elseif line:match("^[^%s:]+[%.=]") then
-- Action
if state == nil then
@@ -265,7 +279,7 @@
table.insert(code, rule_code);
end
- local code_string = [[return function (zones, fire_event, log)
+ local code_string = [[return function (zones, throttles, fire_event, log)
]]..table.concat(code.global_header, "\n")..[[
local db = require 'util.debug'
return function (event)
@@ -291,11 +305,17 @@
local function fire_event(name, data)
return module:fire_event(name, data);
end
- chunk = chunk()(zones, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
+ chunk = chunk()(zones, throttles, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
return chunk;
end
+local function cleanup(t, active_list)
+ local unused = set.new(it.to_array(it.keys(t))) - set.new(active_list);
+ for k in unused do t[k] = nil; end
+end
+
function module.load()
+ active_zones, active_throttles = {}, {};
local firewall_scripts = module:get_option_set("firewall_scripts", {});
for script in firewall_scripts do
script = resolve_relative_path(prosody.paths.config, script);
@@ -322,4 +342,7 @@
end
end
end
+ -- Remove entries from tables that are no longer in use
+ cleanup(zones, active_zones);
+ cleanup(throttles, active_throttles);
end