mod_anti_spam/trie.lib.lua
changeset 5863 259ffdbf8906
equal deleted inserted replaced
5862:761142ee0ff2 5863:259ffdbf8906
       
     1 local bit = require "prosody.util.bitcompat";
       
     2 
       
     -- Methods shared by every trie instance; instances reach them through the
     -- metatable's __index so that `t:get(...)` etc. resolve here.
     local trie_methods = {};
     local trie_mt = {};
     trie_mt.__index = trie_methods;
       
     5 
       
     -- Create a fresh, empty trie node. Children are later stored under the
     -- integer byte values produced by string.byte; nodes that end a stored key
     -- additionally receive `terminal` and `value` fields.
     local new_node = function()
     	local node = {};
     	return node;
     end
       
     9 
       
    --- Store `value` under the byte-string key `item`.
    -- Intermediate nodes are created on demand, one per byte of `item`. The
    -- node reached after the last byte is flagged `terminal` so lookups can
    -- tell stored keys apart from nodes that only lie on the path to longer keys.
    -- @param item byte string used as the key
    -- @param value value to associate with the key
    function trie_methods:set(item, value)
    	local cursor = self.root;
    	for pos = 1, #item do
    		local byte = item:byte(pos);
    		local next_node = cursor[byte];
    		if not next_node then
    			next_node = new_node();
    			cursor[byte] = next_node;
    		end
    		cursor = next_node;
    	end
    	cursor.terminal = true;
    	cursor.value = value;
    end
       
    22 
       
    --- Recursively remove the key `item` below `node`, pruning emptied nodes.
    -- @param node current trie node
    -- @param item byte-string key being removed
    -- @param i 1-based position in `item` of the byte selecting the next child
    -- @return `node` if it still holds anything (children or a stored key),
    --   otherwise nil so the caller can drop its reference to it
    local function _remove(node, item, i)
    	if i <= #item then
    		local byte = item:byte(i);
    		local child = node[byte];
    		if child then
    			-- Re-link the child, or unlink it when it came back empty (nil).
    			node[byte] = _remove(child, item, i + 1);
    		end
    	elseif node.terminal then
    		-- End of the key: this node no longer holds a stored key.
    		node.terminal = nil;
    		node.value = nil;
    	end
    	if next(node) == nil then
    		return nil;
    	end
    	return node;
    end
       
    46 
       
    --- Remove the byte-string key `item`, pruning nodes left empty.
    -- @return the root node, or nil when the trie no longer holds anything
    --   (self.root itself is kept, so the trie remains usable)
    function trie_methods:remove(item)
    	local root = self.root;
    	return _remove(root, item, 1);
    end
       
    50 
       
    --- Look up the byte-string key `item`.
    -- @param item byte string key
    -- @param partial when truthy and `item` itself is not stored, fall back to
    --   the value of the longest stored key that is a prefix of `item`
    -- @return the value found (nil if none), and the number of leading bytes
    --   of `item` that could be followed along existing trie paths
    function trie_methods:get(item, partial)
    	-- Value of the deepest terminal node passed so far (partial mode only).
    	local value;
    	local node = self.root;
    	local len = #item;
    	for i = 1, len do
    		if partial and node.terminal then
    			value = node.value;
    		end
    		local c = item:byte(i);
    		node = node[c];
    		if not node then
    			return value, i - 1;
    		end
    	end
    	if partial and not node.terminal then
    		-- `item` only names an interior node (a prefix of longer keys): fall
    		-- back to the longest stored prefix, exactly as when the walk stops
    		-- early. Previously this returned nil, so e.g. with "a" and "abc"
    		-- stored, get("ab", true) missed the "a" match.
    		return value, len;
    	end
    	return node.value, len;
    end
       
    67 
       
    --- Store `item` as a plain member of the trie, using `true` as its value.
    function trie_methods.add(self, item)
    	return self:set(item, true);
    end
       
    71 
       
    --- Return true if get(item, partial) yields a non-nil value.
    -- Without `partial` this means `item` itself is stored with a non-nil
    -- value; with `partial`, whatever get() reports as a prefix match counts.
    function trie_methods:contains(item, partial)
    	local found = self:get(item, partial);
    	return found ~= nil;
    end
       
    75 
       
    --- Return how many leading bytes of `item` can be followed along existing
    -- trie paths. The path need not end at a stored key.
    function trie_methods:longest_prefix(item)
    	local _, depth = self:get(item);
    	return depth;
    end
       
    79 
       
    --- Record the subnet `item`/`bits` in the trie.
    -- The key is the first ceil(bits/8) bytes of the packed address. Its value
    -- is an ascending list of the prefix lengths stored under that key.
    -- NOTE(review): assumes `item` is an IP object exposing a `.packed` byte
    -- string (e.g. prosody util.ip); confirm against callers.
    -- @param item IP address object
    -- @param bits prefix length of the subnet
    function trie_methods:add_subnet(item, bits)
    	item = item.packed:sub(1, math.ceil(bits/8));
    	local existing = self:get(item);
    	if not existing then
    		existing = { bits };
    		return self:set(item, existing);
    	end
    
    	-- Sorted insertion: place `bits` before the first larger entry, keeping
    	-- the list ascending (remove_subnet stops its search at a larger entry).
    	for i = 1, #existing do
    		local v = existing[i];
    		if v == bits then
    			return; -- Already in there
    		elseif v > bits then
    			-- table.insert(list, pos, value); the arguments were swapped.
    			table.insert(existing, i, bits);
    			return;
    		end
    	end
    	-- Larger than every stored length: append instead of dropping it.
    	existing[#existing + 1] = bits;
    end
       
    99 
       
   --- Forget the prefix length `bits` recorded for the subnet `item`.
   -- The key is the first ceil(bits/8) bytes of `item.packed`. Once the list
   -- of prefix lengths under that key becomes empty, the key is removed.
   -- NOTE(review): assumes `item` exposes a `.packed` byte string; confirm.
   function trie_methods:remove_subnet(item, bits)
   	local key = item.packed:sub(1, math.ceil(bits/8));
   	local lengths = self:get(key);
   	if not lengths then
   		return;
   	end
   
   	-- Linear search of the ascending list; a larger entry means `bits`
   	-- cannot appear later, so nothing is removed.
   	for idx, len in ipairs(lengths) do
   		if len > bits then
   			return;
   		elseif len == bits then
   			table.remove(lengths, idx);
   			break;
   		end
   	end
   
   	if #lengths == 0 then
   		self:remove(key);
   	end
   end
       
   122 
       
   -- Return true if some child of `node` terminates a subnet whose prefix ends
   -- part-way through the current byte and that prefix covers byte `c`.
   -- `consumed` is the number of address bits matched before this byte.
   local function _partial_byte_match(node, c, consumed)
   	for child_byte, child_node in pairs(node) do
   		if type(child_byte) == "number" and child_node.terminal then
   			local bits = child_node.value;
   			for j = #bits, 1, -1 do
   				-- Number of network bits that fall inside this byte.
   				local b = bits[j] - consumed;
   				-- b == 8 needs an exact byte match, handled by descending.
   				if b > 0 and b < 8 then
   					-- Keep only the top `b` bits of the byte. The old mask,
   					-- bnot(2^b - 1), compared the low 8-b bits instead.
   					local mask = bit.bnot(2^(8 - b) - 1);
   					if bit.band(bit.bxor(c, child_byte), mask) == 0 then
   						return true;
   					end
   				end
   			end
   		end
   	end
   	return false;
   end
   
   --- Return true if the IP address `item` lies in any subnet stored with
   -- add_subnet, false otherwise.
   -- NOTE(review): assumes `item` exposes a `.packed` byte string; confirm.
   -- Partial-byte subnets are checked at every depth, not only when the exact
   -- child is missing, so a covering subnet is not missed just because a more
   -- specific sibling path exists. This scans the children at each visited level.
   function trie_methods:has_ip(item)
   	local packed = item.packed;
   	local node = self.root;
   	for i = 1, #packed do
   		if node.terminal then
   			-- A stored key matched byte-for-byte, so its subnet covers the IP.
   			return true;
   		end
   		local c = packed:byte(i);
   		if _partial_byte_match(node, c, (i - 1) * 8) then
   			return true;
   		end
   		node = node[c];
   		if not node then
   			return false;
   		end
   	end
   	-- Whole address consumed: a full-length entry (e.g. /32) ends here.
   	return node.terminal == true;
   end
       
   154 
       
   --- Construct an empty trie with a fresh root node.
   local function new()
   	local trie = { root = new_node() };
   	return setmetatable(trie, trie_mt);
   end
       
   160 
       
   --- Return true if `o` carries the trie metatable, i.e. was built by new().
   local function is_trie(o)
   	local mt = getmetatable(o);
   	return mt == trie_mt;
   end
       
   164 
       
   -- Public module interface.
   return {
   	is_trie = is_trie,
   	new = new,
   };