net/unbound.lua
changeset 10966 92f30e8ecdfc
child 10971 67aabf83230b
equal deleted inserted replaced
10965:f93dce30089a 10966:92f30e8ecdfc
       
     1 -- libunbound based net.adns replacement for Prosody IM
       
     2 -- Copyright (C) 2013-2015 Kim Alvefur
       
     3 --
       
     4 -- This file is MIT licensed.
       
     5 --
       
     6 -- luacheck: ignore prosody
       
     7 
       
     8 local setmetatable = setmetatable;
       
     9 local tostring = tostring;
       
    10 local t_concat = table.concat;
       
    11 local s_format = string.format;
       
    12 local s_lower = string.lower;
       
    13 local s_upper = string.upper;
       
    -- Tiny stub functions used wherever a callback or method slot has to be
    -- filled but no real behaviour is needed:
    --   noop()  returns nothing
    --   zero()  returns 0
    --   truop() returns true
    local function noop() end
    local function zero() return 0; end
    local function truop() return true; end
       
    17 
       
    18 local log = require "util.logger".init("unbound");
       
    19 local net_server = require "net.server";
       
    20 local libunbound = require"lunbound";
       
    21 local have_promise, promise = pcall(require, "util.promise");
       
    22 
       
    23 local gettime = require"socket".gettime;
       
    24 local dns_utils = require"util.dns";
       
    25 local classes, types, errors = dns_utils.classes, dns_utils.types, dns_utils.errors;
       
    26 local parsers = dns_utils.parsers;
       
    27 
       
    -- Complete a configuration table with the values from libunbound.config.
    -- Only options that are absent (nil) in `conf` are filled in, so a value
    -- explicitly set to false is kept. `conf` is modified in place and returned.
    -- A nil or false `conf` is handed back untouched.
    local function add_defaults(conf)
    	if not conf then
    		return conf;
    	end
    	for name, fallback in pairs(libunbound.config) do
    		if conf[name] == nil then
    			conf[name] = fallback;
    		end
    	end
    	return conf;
    end
       
    38 
       
    -- Options passed to libunbound.new(). Stays nil unless the `prosody`
    -- global is present.
    local unbound_config;
    if prosody then
    	local configmanager = require "core.configmanager";
    	-- (Re)read the "unbound" option for "*" and fill in missing defaults.
    	local function load_config()
    		unbound_config = add_defaults(configmanager.get("*", "unbound"));
    	end
    	load_config();
    	-- Refresh the stored options whenever the configuration is reloaded.
    	prosody.events.add_handler("config-reloaded", load_config);
    end
       
    47 -- Note: libunbound will default to using root hints if resolvconf is unset
       
    48 
       
    -- Register a libunbound context with the network backend so that
    -- unbound:process() gets called by it. Three backend flavours are
    -- recognised, tried in this order:
    --   1. server.watchfd:   the context itself is handed over with a callback.
    --   2. server.addevent:  a read event is registered on unbound:getfd().
    --   3. server.wrapclient: a minimal socket-like table exposing the fd is
    --      wrapped, and incoming data triggers processing.
    -- Returns whatever the chosen backend function returns, or nothing when
    -- the backend offers none of the three.
    local function connect_server(unbound, server)
    	-- Callback for the backends that do not look at a return value.
    	local function process()
    		unbound:process();
    	end

    	if server.watchfd then
    		return server.watchfd(unbound, process);
    	end

    	if server.event and server.addevent then
    		local EV_READ = server.event.EV_READ;
    		return server.addevent(unbound:getfd(), EV_READ, function ()
    			unbound:process();
    			-- NOTE(review): presumably returning EV_READ keeps the event armed; confirm against the backend.
    			return EV_READ;
    		end)
    	end

    	if server.wrapclient then
    		-- Just enough of a socket for wrapclient; only getfd does real work.
    		local fake_socket = {
    			getfd = function()
    				return unbound:getfd();
    			end,
    			send = zero,
    			receive = noop,
    			settimeout = noop,
    			close = truop,
    		};
    		local listeners = {
    			onincoming = process,
    			onconnect = noop,
    			ondisconnect = noop,
    			onreadtimeout = truop,
    		};
    		return server.wrapclient(fake_socket, "dns", 0, listeners, "*a");
    	end
    end
       
    86 
       
    -- The libunbound context used by every query in this module, created from
    -- unbound_config (nil when no configuration was loaded).
    local unbound = libunbound.new(unbound_config);

    -- Handle returned by connect_server() for that context; may be nil when
    -- the backend is unsupported. purge() closes and replaces it.
    local server_conn = connect_server(unbound, net_server);
       
    90 
       
    -- Metatable applied to prepared answers (see prep_answer).
    local answer_mt = {};

    -- Text rendering of an answer: a "Status: ..." header (with ", Secure" or
    -- ", Bogus: <reason>" appended when applicable) followed by one
    -- tab-separated "qname class type value" line per record.
    -- The result is cached on the answer as `_string`.
    function answer_mt.__tostring(self)
    	local cached = self._string;
    	if cached then
    		return cached;
    	end

    	local header = s_format("Status: %s", errors[self.status]);
    	if self.secure then
    		header = header .. ", Secure";
    	elseif self.bogus then
    		header = header .. s_format(", Bogus: %s", self.bogus);
    	end

    	local lines = { header };
    	for i = 1, #self do
    		-- Name lookups stay inside the loop: they only happen when there
    		-- is at least one record to print.
    		lines[#lines + 1] = self.qname .. "\t" .. classes[self.qclass] .. "\t" .. types[self.qtype] .. "\t" .. tostring(self[i]);
    	end

    	local rendered = t_concat(lines, "\n");
    	self._string = rendered;
    	return rendered;
    end
       
   109 
       
   -- User callbacks of asynchronous queries that have not completed yet,
   -- keyed by the query id returned from unbound:resolve_async().
   local waiting_queries = {};
       
   111 
       
   -- Convert a raw libunbound result table into the answer object given to
   -- callers. Modifies `a` in place:
   --   * a.status / a.class / a.type are set from the errors, classes and
   --     types tables, looked up by a.rcode, a.qclass and a.qtype;
   --   * each record a[i] is replaced by a small table holding the parsed
   --     value under the lower-cased type name (e.g. `.a` for type "A"),
   --     which falls back to `a` for other fields and prints as that value;
   --   * if a.bogus is set, all records are dropped instead.
   -- Returns `a` with answer_mt applied, or nothing when `a` is nil/false.
   local function prep_answer(a)
   	if not a then
   		return;
   	end

   	local type_name = types[a.qtype];
   	a.status = errors[a.rcode];
   	a.class = classes[a.qclass];
   	a.type = type_name;

   	local field = s_lower(type_name);
   	local parse = parsers[type_name];
   	local rr_mt = {
   		__index = a,
   		__tostring = function(rr)
   			return tostring(rr[field]);
   		end,
   	};

   	if a.bogus then
   		-- Discard bogus data
   		for i = #a, 1, -1 do
   			a[i] = nil;
   		end
   	else
   		for i = 1, #a do
   			a[i] = setmetatable({ [field] = parse(a[i]) }, rr_mt);
   		end
   	end

   	return setmetatable(a, answer_mt);
   end
       
   132 
       
   -- Start an asynchronous DNS query.
   --   callback: function(answer, err), called once when the query finishes.
   --             `answer` has been run through prep_answer(); on failure it is
   --             nil and `err` describes the problem. Errors raised by the
   --             callback are caught and logged.
   --   qname:    name to resolve.
   --   qtype:    record type name, case-insensitive, defaults to "A".
   --   qclass:   class name, case-insensitive, defaults to "IN".
   -- Returns the query id from unbound:resolve_async() (usable with cancel()),
   -- or a falsy value plus an error if the query could not be started.
   local function lookup(callback, qname, qtype, qclass)
   	qtype = qtype and s_upper(qtype) or "A";
   	qclass = qclass and s_upper(qclass) or "IN";
   	local ntype, nclass = types[qtype], classes[qclass];
   	local startedat = gettime();
   	local ret;
   	local function callback_wrapper(a, err)
   		local gotdataat = gettime();
   		-- The query is done, stop tracking it before notifying the caller.
   		waiting_queries[ret] = nil;
   		if a then
   			prep_answer(a);
   			log("debug", "Results for %s %s %s: %s (%s, %f sec)", qname, qclass, qtype, a.rcode == 0 and (#a .. " items") or a.status,
   				a.secure and "Secure" or a.bogus or "Insecure", gotdataat - startedat); -- Insecure as in unsigned
   		else
   			log("error", "Results for %s %s %s: %s", qname, qclass, qtype, tostring(err));
   		end
   		local ok, cerr = pcall(callback, a, err);
   		if not ok then log("error", "Error in callback: %s", cerr); end
   	end
   	log("debug", "Resolve %s %s %s", qname, qclass, qtype);
   	local err;
   	ret, err = unbound:resolve_async(callback_wrapper, qname, ntype, nclass);
   	if ret then
   		-- Remember the user callback so cancel() can notify it.
   		waiting_queries[ret] = callback;
   	else
   		-- The error text is data, not a format string: pass it as an
   		-- argument so any '%' in it is logged literally.
   		log("warn", "%s", tostring(err));
   	end
   	return ret, err;
   end
       
   162 
       
   -- Blocking DNS query.
   --   qname:  name to resolve.
   --   qtype:  record type name, case-insensitive, defaults to "A".
   --   qclass: class name, case-insensitive, defaults to "IN".
   -- Returns the answer after prep_answer(), or the falsy result and error
   -- from unbound:resolve() when resolution fails.
   local function lookup_sync(qname, qtype, qclass)
   	local type_name = s_upper(qtype or "A");
   	local class_name = s_upper(qclass or "IN");
   	local answer, err = unbound:resolve(qname, types[type_name], classes[class_name]);
   	if answer then
   		return prep_answer(answer);
   	end
   	return answer, err;
   end
       
   171 
       
   -- Cancel an asynchronous query started with lookup().
   --   id: the query id returned by lookup().
   -- The query is cancelled in libunbound and, if it was still being tracked,
   -- its callback is invoked once with (nil, "canceled").
   -- Always returns true, including for unknown ids.
   local function cancel(id)
   	local cb = waiting_queries[id];
   	unbound:cancel(id);
   	if cb then
   		-- Stop tracking the query before notifying: a callback that calls
   		-- cancel() again must not fire twice, and one that throws must
   		-- not leave a stale entry behind.
   		waiting_queries[id] = nil;
   		-- Same protection as for normal completion in lookup(): a broken
   		-- callback is logged instead of propagating to the caller of
   		-- cancel() (or aborting purge() half-way).
   		local ok, cerr = pcall(cb, nil, "canceled");
   		if not ok then log("error", "Error in callback: %s", cerr); end
   	end
   	return true;
   end
       
   181 
       
   182 -- Reinitiate libunbound context, drops cache
       
   -- Cancel every pending query, close the current server connection (if
   -- any) and replace the libunbound context with a fresh one built from
   -- unbound_config. Always returns true.
   local function purge()
   	-- cancel() removes the entry it is given; clearing existing keys
   	-- while iterating with pairs() is permitted.
   	for pending_id in pairs(waiting_queries) do
   		cancel(pending_id);
   	end

   	if server_conn then
   		server_conn:close();
   	end

   	unbound = libunbound.new(unbound_config);
   	server_conn = connect_server(unbound, net_server);
   	return true;
   end
       
   190 
       
   -- Placeholder for API entries this module does not provide.
   -- Always raises "not implemented".
   local function not_implemented()
   	error("not implemented");
   end
       
   194 -- Public API
       
   local _M = {
   	-- Asynchronous interface.
   	lookup = lookup,
   	cancel = cancel,
   	-- Not provided by this module; calling it raises an error.
   	new_async_socket = not_implemented,

   	-- Synchronous interface plus entries that exist only as stubs.
   	dns = {
   		-- Functional entries.
   		lookup = lookup_sync,
   		cancel = cancel,
   		purge = purge,

   		-- Stubs: callable, but they do nothing and return nothing.
   		cache = noop,
   		socket_wrapper_set = noop,
   		settimeout = noop,
   		query = noop,
   		random = noop,
   		peek = noop,

   		-- Record type and class tables from util.dns.
   		types = types,
   		classes = classes,
   	},
   };
       
   214 
       
   -- Promise-returning variant of lookup(); only defined when util.promise
   -- could be loaded, otherwise it stays nil.
   local lookup_promise;
   if have_promise then
   	-- The first argument is ignored (the function is used as a method).
   	-- The promise is rejected with the error if the query fails or cannot
   	-- be started, and resolved with the answer otherwise.
   	lookup_promise = function (_, qname, qtype, qclass)
   		local function executor(resolve, reject)
   			local function on_result(answer, err)
   				if err then
   					return reject(err);
   				end
   				return resolve(answer);
   			end
   			local query_id, query_err = lookup(on_result, qname, qtype, qclass);
   			if not query_id then
   				reject(query_err);
   			end
   		end
   		return promise.new(executor);
   	end
   end
       
   231 
       
   -- Resolver object returned by _M.resolver(); its functions are written
   -- to be called as methods and ignore the receiver.
   local wrapper = {};

   -- Same as lookup(), minus the ignored first argument.
   function wrapper.lookup(_, callback, qname, qtype, qclass)
   	return lookup(callback, qname, qtype, qclass);
   end

   -- nil when util.promise is unavailable.
   wrapper.lookup_promise = lookup_promise;

   -- Stub with do-nothing settimeout and closeall.
   wrapper._resolver = {
   	settimeout = function () end,
   	closeall = function () end,
   };
       
   242 
       
   -- Return the resolver object; every call yields the same shared table.
   function _M.resolver()
   	return wrapper;
   end
       
   244 
       
   245 return _M;