local bit = require "prosody.util.bitcompat";
-- Method table shared by all tries; instances reach it through the
-- metatable's __index field.
local trie_methods = {};
local trie_mt = {};
trie_mt.__index = trie_methods;
-- Creates a fresh, empty trie node.
-- Nodes are plain tables: children are stored under numeric byte keys,
-- while the string keys "terminal" and "value" mark a stored item.
local function new_node()
	local node = {};
	return node;
end
-- Stores `value` under the byte string `item`, creating any missing
-- nodes along the path. An existing value for `item` is overwritten.
function trie_methods:set(item, value)
	local cursor = self.root;
	for pos = 1, #item do
		local key = item:byte(pos);
		local next_node = cursor[key];
		if not next_node then
			next_node = new_node();
			cursor[key] = next_node;
		end
		cursor = next_node;
	end
	-- Mark the final node as holding an item
	cursor.value = value;
	cursor.terminal = true;
end
-- Recursive helper for trie removal.
-- Walks `item` starting at byte index `i` below `node`, clears the
-- terminal marker and value on the final node, and prunes nodes that
-- end up empty on the way back up.
-- Returns `node` if it still holds anything, or nil if the caller
-- should drop its reference to it.
local function _remove(node, item, i)
	if i <= #item then
		-- Still descending: recurse into the matching child, if any
		local key = item:byte(i);
		local subtree = node[key];
		local kept;
		if subtree then
			kept = _remove(subtree, item, i + 1);
			node[key] = kept; -- nil here unlinks an emptied child
		end
		if kept == nil and next(node) == nil then
			return nil;
		end
		return node;
	end

	-- Reached the node addressed by the full item
	if node.terminal then
		node.value = nil;
		node.terminal = nil;
	end
	if next(node) == nil then
		return nil;
	end
	return node;
end
-- Removes `item` from the trie, pruning nodes that become empty.
-- Returns whatever _remove returns for the root: the root node if it
-- still holds anything, otherwise nil.
function trie_methods:remove(item)
	local root = self.root;
	return _remove(root, item, 1);
end
-- Looks up `item` in the trie.
--
-- Returns two values: the value found (or nil) and the number of
-- leading bytes of `item` that could be followed through the trie.
--
-- When `partial` is truthy, the value of the longest stored prefix of
-- `item` is returned if `item` itself is not stored. This also covers
-- the case where the walk consumes all of `item` but ends on a node
-- that only exists as part of a longer key; previously the prefix
-- value collected on the way was dropped in that case.
function trie_methods:get(item, partial)
	local value;
	local node = self.root;
	local len = #item;
	for i = 1, len do
		if partial and node.terminal then
			-- Remember the most recent (longest) stored prefix
			value = node.value;
		end
		node = node[item:byte(i)];
		if not node then
			return value, i - 1;
		end
	end
	if node.terminal then
		-- Exact match
		return node.value, len;
	end
	-- No exact match: `value` is nil unless partial matching found a prefix
	return value, len;
end
-- Adds `item` to the trie as a set member, i.e. stores it with the
-- value `true`.
function trie_methods:add(item)
	local member_marker = true;
	return self:set(item, member_marker);
end
-- Returns true if a lookup of `item` yields a non-nil value.
-- `partial` is passed through to get().
function trie_methods:contains(item, partial)
	local found = self:get(item, partial);
	return found ~= nil;
end
-- Returns the number of leading bytes of `item` that can be followed
-- through the trie (the second result of get()).
function trie_methods:longest_prefix(item)
	local _, matched_length = self:get(item);
	return matched_length;
end
-- Registers a subnet in the trie.
--
-- `item` must expose its raw address bytes as the string field
-- `packed` (presumably an IP address object - confirm against callers);
-- `bits` is the prefix length in bits.
--
-- The key is the first ceil(bits/8) bytes of the packed address. The
-- value stored under that key is an ascending list of all prefix
-- lengths registered for it, which has_ip() and remove_subnet() rely on.
function trie_methods:add_subnet(item, bits)
	item = item.packed:sub(1, math.ceil(bits/8));
	local existing = self:get(item);
	if not existing then
		existing = { bits };
		return self:set(item, existing);
	end

	-- Insert into the sorted list, keeping ascending order
	for i = 1, #existing do
		local v = existing[i];
		if v == bits then
			return; -- Already in there
		elseif v > bits then
			-- table.insert takes (list, position, value)
			table.insert(existing, i, bits);
			return;
		end
	end

	-- Larger than everything already stored: append at the end
	existing[#existing + 1] = bits;
end
-- Unregisters a subnet previously added with add_subnet().
--
-- `item` must expose its raw address bytes as the string field
-- `packed`; `bits` is the prefix length in bits. Does nothing if no
-- entry exists for the key. When the last prefix length for a key is
-- removed, the key itself is removed from the trie.
function trie_methods:remove_subnet(item, bits)
	local key = item.packed:sub(1, math.ceil(bits/8));
	local lengths = self:get(key);
	if not lengths then
		return;
	end

	-- Linear search through the ascending list of prefix lengths
	local idx, count = 1, #lengths;
	while idx <= count do
		local candidate = lengths[idx];
		if candidate == bits then
			table.remove(lengths, idx);
			break;
		end
		if candidate > bits then
			return; -- Passed the place it would be, so it is not present
		end
		idx = idx + 1;
	end

	if #lengths == 0 then
		self:remove(key);
	end
end
-- Tests whether an address falls inside any subnet stored with
-- add_subnet().
--
-- `item` must expose its raw address bytes as the string field
-- `packed`. Returns true on a match and false otherwise.
--
-- Subnets are keyed by the first ceil(bits/8) bytes of their address,
-- so at each depth there are two ways to match:
--   * the path followed so far ends on a terminal node (the subnet key
--     is a byte-exact prefix of the address), or
--   * a terminal child covers only the top `b` bits of the next byte
--     and those bits agree with the address.
function trie_methods:has_ip(item)
	local packed = item.packed;
	local node = self.root;
	for i = 1, #packed do
		if node.terminal then
			return true;
		end

		local c = packed:byte(i);
		local consumed = (i - 1) * 8; -- bits covered by the bytes matched so far

		-- Check partial-byte subnets among all children, even when an exact
		-- child for `c` exists: that child may only be part of a longer,
		-- unrelated key.
		for child_byte, child_node in pairs(node) do
			if type(child_byte) == "number" and child_node.terminal then
				local bits = child_node.value;
				-- Values stored through set()/add() are not prefix lists; skip them
				if type(bits) == "table" then
					for j = #bits, 1, -1 do
						local b = bits[j] - consumed; -- significant bits in this byte
						if b ~= 8 then -- full-byte prefixes are handled by descending
							-- Keep the top b bits of the byte, ignore the low 8-b bits
							local mask = bit.bnot(2^(8 - b) - 1);
							if bit.band(bit.bxor(c, child_byte), mask) == 0 then
								return true;
							end
						end
					end
				end
			end
		end

		node = node[c];
		if not node then
			return false;
		end
	end

	-- All bytes consumed: match only if the final node is a stored key
	-- (e.g. a /32 for a 4-byte address)
	return node.terminal == true;
end
-- Creates an empty trie: a table holding a fresh root node, with the
-- trie metatable attached.
local function new()
	local trie = { root = new_node() };
	return setmetatable(trie, trie_mt);
end
-- Returns true if `o` carries the trie metatable.
local function is_trie(o)
	local mt = getmetatable(o);
	return mt == trie_mt;
end
-- Public interface of the module
local exports = {
	new = new;
	is_trie = is_trie;
};
return exports;