--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/mod_anti_spam/trie.lib.lua Tue Mar 05 18:26:29 2024 +0000
@@ -0,0 +1,168 @@
+local bit = require "prosody.util.bitcompat";
+
+local trie_methods = {};
+local trie_mt = { __index = trie_methods };
+
+local function new_node()
+ return {};
+end
+
+function trie_methods:set(item, value)
+ local node = self.root;
+ for i = 1, #item do
+ local c = item:byte(i);
+ if not node[c] then
+ node[c] = new_node();
+ end
+ node = node[c];
+ end
+ node.terminal = true;
+ node.value = value;
+end
+
+local function _remove(node, item, i)
+ if i > #item then
+ if node.terminal then
+ node.terminal = nil;
+ node.value = nil;
+ end
+ if next(node) ~= nil then
+ return node;
+ end
+ return nil;
+ end
+ local c = item:byte(i);
+ local child = node[c];
+ local ret;
+ if child then
+ ret = _remove(child, item, i+1);
+ node[c] = ret;
+ end
+ if ret == nil and next(node) == nil then
+ return nil;
+ end
+ return node;
+end
+
+function trie_methods:remove(item)
+ return _remove(self.root, item, 1);
+end
+
+function trie_methods:get(item, partial)
+ local value;
+ local node = self.root;
+ local len = #item;
+ for i = 1, len do
+ if partial and node.terminal then
+ value = node.value;
+ end
+ local c = item:byte(i);
+ node = node[c];
+ if not node then
+ return value, i - 1;
+ end
+ end
+ return node.value, len;
+end
+
+function trie_methods:add(item)
+ return self:set(item, true);
+end
+
+function trie_methods:contains(item, partial)
+ return self:get(item, partial) ~= nil;
+end
+
+function trie_methods:longest_prefix(item)
+ return select(2, self:get(item));
+end
+
+function trie_methods:add_subnet(item, bits)
+ item = item.packed:sub(1, math.ceil(bits/8));
+ local existing = self:get(item);
+ if not existing then
+ existing = { bits };
+ return self:set(item, existing);
+ end
+
+ -- Simple insertion sort
+ for i = 1, #existing do
+ local v = existing[i];
+ if v == bits then
+ return; -- Already in there
+ elseif v > bits then
+ table.insert(existing, v, i);
+ return;
+ end
+ end
+end
+
+function trie_methods:remove_subnet(item, bits)
+ item = item.packed:sub(1, math.ceil(bits/8));
+ local existing = self:get(item);
+ if not existing then
+ return;
+ end
+
+ -- Simple insertion sort
+ for i = 1, #existing do
+ local v = existing[i];
+ if v == bits then
+ table.remove(existing, i);
+ break;
+ elseif v > bits then
+ return; -- Stop search
+ end
+ end
+
+ if #existing == 0 then
+ self:remove(item);
+ end
+end
+
+function trie_methods:has_ip(item)
+ item = item.packed;
+ local node = self.root;
+ local len = #item;
+ for i = 1, len do
+ if node.terminal then
+ return true;
+ end
+
+ local c = item:byte(i);
+ local child = node[c];
+ if not child then
+ for child_byte, child_node in pairs(node) do
+ if type(child_byte) == "number" and child_node.terminal then
+ local bits = child_node.value;
+ for j = #bits, 1, -1 do
+ local b = bits[j]-((i-1)*8);
+ if b ~= 8 then
+ local mask = bit.bnot(2^b-1);
+ if bit.band(bit.bxor(c, child_byte), mask) == 0 then
+ return true;
+ end
+ end
+ end
+ end
+ end
+ return false;
+ end
+ node = child;
+ end
+end
+
+local function new()
+ return setmetatable({
+ root = new_node();
+ }, trie_mt);
+end
+
+local function is_trie(o)
+ return getmetatable(o) == trie_mt;
+end
+
+return {
+ new = new;
+ is_trie = is_trie;
+};