--- a/mod_storage_s3/mod_storage_s3.lua Fri Nov 10 00:26:17 2023 +0100
+++ b/mod_storage_s3/mod_storage_s3.lua Sat Nov 11 17:01:29 2023 +0100
@@ -25,22 +25,13 @@
local access_key = module:get_option_string("s3_access_key");
local secret_key = module:get_option_string("s3_secret_key");
-function driver:open(store, typ)
- local mt = self[typ or "keyval"]
- if not mt then
- return nil, "unsupported-store";
- end
- return setmetatable({ store = store; bucket = bucket; type = typ }, mt);
-end
-
-local keyval = { };
-driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };
-
local aws4_format = "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s";
-local function new_request(method, path, query, payload)
- local request = url.parse(base_uri);
- request.path = path;
+local function aws_auth(event)
+ local request, options = event.request, event.options;
+ local method = options.method or "GET";
+ local query = options.query;
+ local payload = options.body;
local payload_type = nil;
if st.is_stanza(payload) then
@@ -50,6 +41,7 @@
payload_type = "application/json";
payload = json.encode(payload);
end
+ options.body = payload;
local payload_hash = sha256(payload or "", true);
@@ -112,7 +104,27 @@
headers["Authorization"] = string.format(aws4_format, access_key, scope, signed_headers, signature);
- return http.request(url.build(request), { method = method; headers = headers; body = payload });
+ options.headers = headers;
+end
+
+function driver:open(store, typ)
+ local mt = self[typ or "keyval"]
+ if not mt then
+ return nil, "unsupported-store";
+ end
+ local httpclient = http.new({});
+ httpclient.events.add_handler("pre-request", aws_auth);
+ return setmetatable({ store = store; bucket = bucket; type = typ; http = httpclient }, mt);
+end
+
+local keyval = { };
+driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };
+
+local function new_request(self, method, path, query, payload)
+ local request = url.parse(base_uri);
+ request.path = path;
+
+ return self.http:request(url.build(request), { method = method; body = payload; query = query });
end
-- coerce result back into Prosody data type
@@ -147,22 +159,22 @@
end
function keyval:get(user)
- return async.wait_for(new_request("GET", self:_path(user)):next(on_result));
+ return async.wait_for(new_request(self, "GET", self:_path(user)):next(on_result));
end
function keyval:set(user, data)
if data == nil or (type(data) == "table" and next(data) == nil) then
- return async.wait_for(new_request("DELETE", self:_path(user)));
+ return async.wait_for(new_request(self, "DELETE", self:_path(user)));
end
- return async.wait_for(new_request("PUT", self:_path(user), nil, data));
+ return async.wait_for(new_request(self, "PUT", self:_path(user), nil, data));
end
function keyval:users()
local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true });
local prefix = url.build_path({ jid.escape(module.host); jid.escape(self.store); is_directory = true });
- local list_result, err = async.wait_for(new_request("GET", bucket_path, { prefix = prefix }))
+ local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, { prefix = prefix }))
if err or list_result.code ~= 200 then
return nil, err;
end
@@ -208,7 +220,7 @@
wrapper:tag("delay", { xmlns = "urn:xmpp:delay"; stamp = dt.datetime(when) }):up();
wrapper:add_direct_child(value);
key = key or new_uuid();
- return async.wait_for(new_request("PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r)
+ return async.wait_for(new_request(self, "PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r)
if r.code == 200 then
return key;
else
@@ -232,7 +244,7 @@
end
prefix = url.build_path(prefix);
- local list_result, err = async.wait_for(new_request("GET", bucket_path, {
+ local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, {
prefix = prefix;
["max-keys"] = query["max"] and tostring(query["max"]);
}));
@@ -276,7 +288,7 @@
return nil;
end
-- luacheck: ignore 431/err
- local value, err = async.wait_for(new_request("GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result));
+ local value, err = async.wait_for(new_request(self, "GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result));
if not value then
module:log("error", "%s", err);
return nil;