mod_anti_spam/trie.lib.lua
changeset 5863 259ffdbf8906
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mod_anti_spam/trie.lib.lua	Tue Mar 05 18:26:29 2024 +0000
@@ -0,0 +1,168 @@
+local bit = require "prosody.util.bitcompat";
+
+local trie_methods = {};
+local trie_mt = { __index = trie_methods };
+
+local function new_node()
+	return {};
+end
+
+function trie_methods:set(item, value)
+	local node = self.root;
+	for i = 1, #item do
+		local c = item:byte(i);
+		if not node[c] then
+			node[c] = new_node();
+		end
+		node = node[c];
+	end
+	node.terminal = true;
+	node.value = value;
+end
+
+local function _remove(node, item, i)
+	if i > #item then
+		if node.terminal then
+			node.terminal = nil;
+			node.value = nil;
+		end
+		if next(node) ~= nil then
+			return node;
+		end
+		return nil;
+	end
+	local c = item:byte(i);
+	local child = node[c];
+	local ret;
+	if child then
+		ret = _remove(child, item, i+1);
+		node[c] = ret;
+	end
+	if ret == nil and next(node) == nil then
+		return nil;
+	end
+	return node;
+end
+
+function trie_methods:remove(item)
+	return _remove(self.root, item, 1);
+end
+
+function trie_methods:get(item, partial)
+	local value;
+	local node = self.root;
+	local len = #item;
+	for i = 1, len do
+		if partial and node.terminal then
+			value = node.value;
+		end
+		local c = item:byte(i);
+		node = node[c];
+		if not node then
+			return value, i - 1;
+		end
+	end
+	return node.value, len;
+end
+
+function trie_methods:add(item)
+	return self:set(item, true);
+end
+
+function trie_methods:contains(item, partial)
+	return self:get(item, partial) ~= nil;
+end
+
+function trie_methods:longest_prefix(item)
+	return select(2, self:get(item));
+end
+
+function trie_methods:add_subnet(item, bits)
+	item = item.packed:sub(1, math.ceil(bits/8));
+	local existing = self:get(item);
+	if not existing then
+		existing = { bits };
+		return self:set(item, existing);
+	end
+
+	-- Simple insertion sort
+	for i = 1, #existing do
+		local v = existing[i];
+		if v == bits then
+			return; -- Already in there
+		elseif v > bits then
+			table.insert(existing, v, i);
+			return;
+		end
+	end
+end
+
+function trie_methods:remove_subnet(item, bits)
+	item = item.packed:sub(1, math.ceil(bits/8));
+	local existing = self:get(item);
+	if not existing then
+		return;
+	end
+
+	-- Simple insertion sort
+	for i = 1, #existing do
+		local v = existing[i];
+		if v == bits then
+			table.remove(existing, i);
+			break;
+		elseif v > bits then
+			return; -- Stop search
+		end
+	end
+
+	if #existing == 0 then
+		self:remove(item);
+	end
+end
+
+function trie_methods:has_ip(item)
+	item = item.packed;
+	local node = self.root;
+	local len = #item;
+	for i = 1, len do
+		if node.terminal then
+			return true;
+		end
+
+		local c = item:byte(i);
+		local child = node[c];
+		if not child then
+			for child_byte, child_node in pairs(node) do
+				if type(child_byte) == "number" and child_node.terminal then
+					local bits = child_node.value;
+					for j = #bits, 1, -1 do
+						local b = bits[j]-((i-1)*8);
+						if b ~= 8 then
+							local mask = bit.bnot(2^b-1);
+							if bit.band(bit.bxor(c, child_byte), mask) == 0 then
+								return true;
+							end
+						end
+					end
+				end
+			end
+			return false;
+		end
+		node = child;
+	end
+end
+
+local function new()
+	return setmetatable({
+		root = new_node();
+	}, trie_mt);
+end
+
+local function is_trie(o)
+	return getmetatable(o) == trie_mt;
+end
+
+return {
+	new = new;
+	is_trie = is_trie;
+};