|
1 local bit = require "prosody.util.bitcompat"; |
|
2 |
|
-- Methods shared by every trie instance; instances get them via trie_mt.
local trie_methods = {};
local trie_mt = {};
trie_mt.__index = trie_methods;
|
5 |
|
--- Allocate an empty trie node.
-- Children are stored under numeric byte keys; a node that ends a stored key
-- additionally carries `terminal = true` and its `value` (see trie_methods:set).
local function new_node()
	local node = {};
	return node;
end
|
9 |
|
--- Store `value` under the byte string `item`, creating nodes along the way.
-- @tparam string item key; each byte selects one level of the trie
-- @param value value to associate with the key
function trie_methods:set(item, value)
	local node = self.root;
	for pos = 1, #item do
		local byte = item:byte(pos);
		local next_node = node[byte];
		if not next_node then
			next_node = new_node();
			node[byte] = next_node;
		end
		node = next_node;
	end
	-- Mark the node reached by the last byte as the end of a stored key.
	node.terminal, node.value = true, value;
end
|
22 |
|
--- Recursively delete the key `item` below `node`, pruning emptied nodes.
-- Bytes of `item` from position `i` onward select the path. The node reached
-- after the last byte has its `terminal`/`value` fields cleared.
-- @param node current trie node
-- @tparam string item key being removed
-- @tparam number i index of the next byte of `item` to follow
-- @return `node` if it still has any fields afterwards, otherwise nil so the
--   parent can drop its reference to it
local function _remove(node, item, i)
	if i <= #item then
		local byte = item:byte(i);
		local child = node[byte];
		if child then
			-- Replace (or drop) the child with whatever survives the removal.
			node[byte] = _remove(child, item, i + 1);
		end
	elseif node.terminal then
		node.terminal = nil;
		node.value = nil;
	end
	if next(node) == nil then
		return nil;
	end
	return node;
end
|
46 |
|
--- Delete `item` from the trie, pruning nodes that become empty.
-- self.root itself is never replaced, even when it ends up empty.
-- @tparam string item key to delete
-- @return the root node, or nil if the root is now empty
function trie_methods:remove(item)
	local root = self.root;
	return _remove(root, item, 1);
end
|
50 |
|
--- Look up the value stored for `item`.
-- With `partial` set, the value of the longest stored key that is a prefix of
-- `item` (including `item` itself) is returned.
-- @tparam string item key to look up
-- @tparam[opt] boolean partial accept the longest stored prefix of `item`
-- @return the stored value (or nil), and the number of leading bytes of
--   `item` that could be followed through the trie
function trie_methods:get(item, partial)
	local value;
	local node = self.root;
	local len = #item;
	for i = 1, len do
		-- Remember the deepest stored prefix seen so far.
		if partial and node.terminal then
			value = node.value;
		end
		local c = item:byte(i);
		node = node[c];
		if not node then
			return value, i - 1;
		end
	end
	-- The whole key was walked. If no key ends exactly here, fall back to the
	-- longest prefix match found on the way down instead of returning nil.
	if partial and not node.terminal then
		return value, len;
	end
	return node.value, len;
end
|
67 |
|
--- Insert `item` as a set member; it is stored with the value `true`.
-- @tparam string item key to insert
function trie_methods:add(item)
	local present = true;
	return self:set(item, present);
end
|
71 |
|
--- Report whether `item` (or, with `partial`, a stored prefix of it) has a
-- non-nil value in the trie.
-- @tparam string item key to test
-- @tparam[opt] boolean partial forwarded to trie_methods:get
-- @treturn boolean
function trie_methods:contains(item, partial)
	local value = self:get(item, partial);
	return value ~= nil;
end
|
75 |
|
--- Number of leading bytes of `item` that can be followed through the trie,
-- whether or not a stored key ends at that depth.
-- @tparam string item key to walk
-- @treturn number
function trie_methods:longest_prefix(item)
	local _, matched = self:get(item);
	return matched;
end
|
79 |
|
--- Record the subnet `item`/`bits`.
-- The key is the first ceil(bits/8) bytes of the packed address. The node's
-- value is an ascending list of the prefix lengths stored under that key.
-- @param item an IP object; only its `packed` byte string is read
-- @tparam number bits prefix length
function trie_methods:add_subnet(item, bits)
	item = item.packed:sub(1, math.ceil(bits/8));
	local existing = self:get(item);
	if not existing then
		existing = { bits };
		return self:set(item, existing);
	end

	-- Keep the list sorted: ignore duplicates, insert before the first larger
	-- entry, and append when every entry is smaller.
	for i = 1, #existing do
		local v = existing[i];
		if v == bits then
			return; -- Already in there
		elseif v > bits then
			table.insert(existing, i, bits);
			return;
		end
	end
	existing[#existing + 1] = bits;
end
|
99 |
|
--- Forget the subnet `item`/`bits` previously recorded with add_subnet.
-- Once no prefix lengths remain under the key, the key is removed entirely.
-- @param item an IP object; only its `packed` byte string is read
-- @tparam number bits prefix length
function trie_methods:remove_subnet(item, bits)
	local key = item.packed:sub(1, math.ceil(bits/8));
	local lengths = self:get(key);
	if not lengths then
		return;
	end

	-- Linear search. Entries are expected in ascending order (as add_subnet
	-- inserts them), so stop as soon as a larger entry is seen.
	for idx, len in ipairs(lengths) do
		if len > bits then
			return; -- Stop search
		elseif len == bits then
			table.remove(lengths, idx);
			break;
		end
	end

	if #lengths == 0 then
		self:remove(key);
	end
end
|
122 |
|
--- Test whether an address lies inside any subnet recorded with add_subnet.
-- add_subnet stores a /bits subnet under the first ceil(bits/8) bytes of its
-- packed address, with the list of prefix lengths as the node's value.
-- Walking the address byte by byte, a match is either:
--  * a terminal node reached through exact byte matches, or
--  * a terminal child whose prefix ends part-way through the current byte and
--    agrees with it in that byte's significant high-order bits.
-- @param item an IP object; only its `packed` byte string is read
-- @treturn boolean true if some stored subnet contains the address
function trie_methods:has_ip(item)
	local packed = item.packed;
	local node = self.root;
	for i = 1, #packed do
		if node.terminal then
			-- Every byte so far matched exactly, so this subnet contains the address.
			return true;
		end

		local c = packed:byte(i);
		-- Check prefixes that end inside byte i. This must happen even when an
		-- exact child for `c` exists, because that branch may lead nowhere.
		for child_byte, child_node in pairs(node) do
			if type(child_byte) == "number" and child_node.terminal then
				local bits = child_node.value;
				for j = 1, #bits do
					-- Number of significant (high-order) bits within byte i.
					local b = bits[j] - ((i - 1) * 8);
					if b < 8 then
						-- Keep only the top b bits of the byte: 0xFF << (8 - b).
						local mask = bit.bnot(2^(8 - b) - 1);
						if bit.band(bit.bxor(c, child_byte), mask) == 0 then
							return true;
						end
					end
				end
			end
		end

		node = node[c];
		if not node then
			return false;
		end
	end
	-- All bytes matched exactly; a full-length prefix (e.g. /32) ends here.
	return node.terminal == true;
end
|
154 |
|
--- Create an empty trie whose methods come from trie_methods.
-- @return a new trie object
local function new()
	local trie = { root = new_node() };
	return setmetatable(trie, trie_mt);
end
|
160 |
|
--- Check whether `o` has the trie metatable (i.e. was made by new()).
-- @param o any value
-- @treturn boolean
local function is_trie(o)
	local mt = getmetatable(o);
	return mt == trie_mt;
end
|
164 |
|
-- Public module interface.
return {
	new = new,
	is_trie = is_trie,
};