--- a/mod_rest/mod_rest.lua Wed Feb 26 18:04:17 2020 +0000
+++ b/mod_rest/mod_rest.lua Wed Feb 26 18:36:40 2020 +0000
@@ -4,34 +4,41 @@
--
-- This file is MIT/X11 licensed.
+local encodings = require "util.encodings";
+local base64 = encodings.base64;
local errors = require "util.error";
local http = require "net.http";
local id = require "util.id";
local jid = require "util.jid";
local json = require "util.json";
local st = require "util.stanza";
+local um = require "core.usermanager";
local xml = require "util.xml";
-local allow_any_source = module:get_host_type() == "component";
-local validate_from_addresses = module:get_option_boolean("validate_from_addresses", true);
-local secret = assert(module:get_option_string("rest_credentials"), "rest_credentials is a required setting");
-local auth_type = assert(secret:match("^%S+"), "Format of rest_credentials MUST be like 'Bearer secret'");
-assert(auth_type == "Bearer" or auth_type == "Basic", "Only 'Bearer' and 'Basic' are supported in rest_credentials");
+local jsonmap = module:require"jsonmap";
+
+local tokens = module:depends("authtokens");
+
+local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" });
-local jsonmap = module:require"jsonmap";
+local www_authenticate_header;
+do
+ local header, realm = {}, module.host.."/"..module.name;
+ for mech in auth_mechanisms do
+ header[#header+1] = ("%s realm=%q"):format(mech, realm);
+ end
+ www_authenticate_header = table.concat(header, ", ");
+end
+
-- Bearer token
local function check_credentials(request)
- return request.headers.authorization == secret;
-end
-if secret == "Basic" and module:get_host_type() == "local" then
- local um = require "core.usermanager";
- local encodings = require "util.encodings";
- local base64 = encodings.base64;
+ local auth_type, auth_data = string.match(request.headers.authorization, "^(%S+)%s(.+)$");
+ if not (auth_type and auth_data) or not auth_mechanisms:contains(auth_type) then
+ return false;
+ end
- function check_credentials(request)
- local creds = string.match(request.headers.authorization, "^Basic%s+([A-Za-z0-9+/]+=?=?)%s*$");
- if not creds then return false; end
- creds = base64.decode(creds);
+ if auth_type == "Basic" then
+ local creds = base64.decode(auth_data);
if not creds then return false; end
local username, password = string.match(creds, "^([^:]+):(.*)$");
if not username then return false; end
@@ -40,8 +47,15 @@
if not um.test_password(username, module.host, password) then
return false;
end
- return jid.join(username, module.host);
+ return { username = username, host = module.host };
+ elseif auth_type == "Bearer" then
+ local token_info = tokens.get_token_info(auth_data);
+ if not token_info or not token_info.session then
+ return false;
+ end
+ return token_info.session;
end
+ return nil;
end
local function parse(mimetype, data)
@@ -84,18 +98,18 @@
local function handle_post(event)
local request, response = event.request, event.response;
- local from = module.host;
+ local from;
+ local origin;
+
if not request.headers.authorization then
- response.headers.www_authenticate = ("%s realm=%q"):format(auth_type, module.host.."/"..module.name);
+ response.headers.www_authenticate = www_authenticate_header;
return 401;
else
- local authz = check_credentials(request);
- if not authz then
+ origin = check_credentials(request);
+ if not origin then
return 401;
end
- if type(authz) == "string" then
- from = authz;
- end
+ from = jid.join(origin.username, origin.host, origin.resource);
end
local payload, err = parse(request.headers.content_type, request.body);
if not payload then
@@ -111,13 +125,15 @@
if not to then
return errors.new({ code = 422, text = "Invalid destination JID" });
end
- if allow_any_source and payload.attr.from then
- from = jid.prep(payload.attr.from);
- if not from then
+ if payload.attr.from then
+ local requested_from = jid.prep(payload.attr.from);
+ if not requested_from then
return errors.new({ code = 422, text = "Invalid source JID" });
end
- if validate_from_addresses and not jid.compare(from, module.host) then
- return errors.new({ code = 403, text = "Source JID must belong to current host" });
+ if jid.compare(requested_from, from) then
+ from = requested_from;
+ else
+ return errors.new({ code = 403, text = "Not authorized to send from "..requested_from });
end
end
payload.attr = {
@@ -130,12 +146,15 @@
module:log("debug", "Received[rest]: %s", payload:top_tag());
local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type)
if payload.name == "iq" then
+ function origin.send(stanza)
+ prosody.core_route_stanza(nil, stanza);
+ end
if payload.attr.type ~= "get" and payload.attr.type ~= "set" then
return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" });
elseif #payload.tags ~= 1 then
return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" });
end
- return module:send_iq(payload):next(
+ return module:send_iq(payload, origin):next(
function (result)
module:log("debug", "Sending[rest]: %s", result.stanza:top_tag());
response.headers.content_type = send_type;
@@ -154,7 +173,6 @@
end
end);
else
- local origin = {};
function origin.send(stanza)
module:log("debug", "Sending[rest]: %s", stanza:top_tag());
response.headers.content_type = send_type;