util/pubsub.lua
author Kim Alvefur <zash@zash.se>
Sun, 05 Aug 2018 15:17:00 +0200
changeset 9120 a19fdc6e4f09
parent 9110 6e42ef9c805c
child 9132 7721794e9e93
permissions -rw-r--r--
util.pubsub: Apply defaults metatable before config check (thanks pep.) Makes it so that the callback sees the default if it’s not in the form, which makes it easier to validate.

local events = require "util.events";
local cache = require "util.cache";

-- Metatable attached to every service object created by new().
local service_mt = {};

-- Fallback values for options missing from the config table given to new().
local default_config = {
	-- Builds the item container for a node: a cache sized by "max_items".
	itemstore = function (config, _)
		return cache.new(config["max_items"]);
	end;
	-- Does nothing.
	broadcaster = function () end;
	-- Accepts every item.
	itemcheck = function ()
		return true;
	end;
	-- Returns nothing, i.e. no service-level affiliation.
	get_affiliation = function () end;
	-- Returns the JID unchanged.
	normalize_jid = function (jid)
		return jid;
	end;
	capabilities = {};
};
local default_config_mt = { __index = default_config };

-- Fallback values for per-node configuration.
local default_node_config = {
	persist_items = false;
	max_items = 20;
	access_model = "open";
};
local default_node_config_mt = { __index = default_node_config };

-- Storage helper functions

--- Fetch a node record from the service's node store.
-- The record's config table (a fresh one if the record has none) gets the
-- service's node defaults as its __index fallback.
-- @param service service object; reads service.config.nodestore and service.node_defaults
-- @param node_name key to look up in the node store
-- @return the record returned by the store, with its config field set up
local function load_node_from_store(service, node_name)
	local stored = service.config.nodestore:get(node_name);
	local node_config = stored.config or {};
	stored.config = setmetatable(node_config, { __index = service.node_defaults });
	return stored;
end

--- Write a node to the service's node store, keyed by the node's name.
-- Only the name, config, subscribers and affiliations fields are stored.
-- @param service service object; reads service.config.nodestore
-- @param node node object to persist
-- @return whatever the store's set method returns
local function save_node_to_store(service, node)
	local record = {
		name = node.name;
		config = node.config;
		subscribers = node.subscribers;
		affiliations = node.affiliations;
	};
	return service.config.nodestore:set(node.name, record);
end

--- Remove a node from the service's node store by storing nil under its name.
-- @param service service object; reads service.config.nodestore
-- @param node_name key of the node to remove
-- @return whatever the store's set method returns
local function delete_node_in_store(service, node_name)
	local store = service.config.nodestore;
	return store:set(node_name, nil);
end

-- Create and return a new service object
--- Create and return a new service object.
-- Options missing from `config` fall back to default_config, and
-- `config.node_defaults` (or an empty table) falls back to default_node_config.
-- Note that both tables passed in get a metatable set on them.
--
-- If a node store is configured and exposes a `users` iterator, every stored
-- node is loaded, given an item store, and has its subscribers registered in
-- the service-wide subscription index.
-- @param config optional table of service options
-- @return the new service object
local function new(config)
	config = config or {};

	local service = setmetatable({
		config = setmetatable(config, default_config_mt);
		node_defaults = setmetatable(config.node_defaults or {}, default_node_config_mt);
		affiliations = {};
		subscriptions = {};
		nodes = {};
		data = {};
		events = events.new();
	}, service_mt);

	-- Load nodes from storage, if we have a store and it supports iterating over stored items
	if config.nodestore and config.nodestore.users then
		for node_name in config.nodestore:users() do
			local node = load_node_from_store(service, node_name);
			service.nodes[node_name] = node;
			service.data[node_name] = config.itemstore(node.config, node_name);

			-- Rebuild the normalized-JID -> JID -> node index that
			-- add_subscription() maintains, so that get_subscriptions()
			-- also sees subscriptions restored from storage.
			if node.subscribers then
				for jid in pairs(node.subscribers) do
					local normal_jid = service.config.normalize_jid(jid);
					local subs = service.subscriptions[normal_jid];
					if not subs then
						subs = {};
						service.subscriptions[normal_jid] = subs;
					end
					local jid_subs = subs[jid];
					if not jid_subs then
						jid_subs = {};
						subs[jid] = jid_subs;
					end
					jid_subs[node_name] = true;
				end
			end
		end
	end

	return service;
end

--- Service methods

-- Table holding the service methods defined below. It is installed as the
-- __index of service_mt, the metatable new() sets on each service object.
local service = {};
service_mt.__index = service;

--- Compare two JIDs after passing both through the configured normalize_jid.
-- @return true if the normalized forms are equal, false otherwise
function service:jids_equal(jid1, jid2)
	local normalize_jid = self.config.normalize_jid;
	local first, second = normalize_jid(jid1), normalize_jid(jid2);
	return first == second;
end

--- Decide whether `actor` is allowed to perform `action` on `node`.
-- An actor of `true` is always allowed. Otherwise the actor's affiliation is
-- taken from the node, then from the service, then from the default given by
-- the node's access model ("none" if that yields nothing). The node's own
-- capability table is consulted before the service-wide one; the first one
-- with an explicit (non-nil) entry for the action decides.
-- @param node node name (may be nil for service-level actions)
-- @param actor JID of the acting entity, or true
-- @param action capability name
-- @return the capability entry that was found, or false if none was
function service:may(node, actor, action)
	if actor == true then
		return true;
	end

	local node_obj = self.nodes[node];

	local node_aff;
	if node_obj then
		node_aff = node_obj.affiliations[actor]
			or node_obj.affiliations[self.config.normalize_jid(actor)];
	end

	local service_aff = self.affiliations[actor];
	if not service_aff then
		service_aff = self.config.get_affiliation(actor, node, action);
	end

	local default_aff = self:get_default_affiliation(node, actor) or "none";
	local effective_aff = node_aff or service_aff or default_aff;

	-- Returns the entry for this affiliation and action, or nil when absent
	local function lookup(capabilities)
		local caps = capabilities[effective_aff];
		if caps then
			return caps[action];
		end
	end

	-- Check if node allows/forbids it
	local node_capabilities = node_obj and node_obj.capabilities;
	if node_capabilities then
		local verdict = lookup(node_capabilities);
		if verdict ~= nil then
			return verdict;
		end
	end

	-- Check service-wide capabilities instead
	local verdict = lookup(self.config.capabilities);
	if verdict ~= nil then
		return verdict;
	end

	return false;
end

--- Work out the affiliation implied by a node's access model.
-- The access model comes from the node's config when the node exists, else
-- from the service's node defaults. "open" maps to "none" and "whitelist"
-- to "restricted"; any other model is looked up in config.access_models and,
-- if a check function is found, its result for the actor is used.
-- @param node node name
-- @param actor JID passed to a custom access model check
-- @param action unused
-- @return an affiliation, or nothing if the model did not yield one
function service:get_default_affiliation(node, actor, action) -- luacheck: ignore 212
	local node_obj = self.nodes[node];
	local access_model = node_obj and node_obj.config.access_model;
	if not access_model then
		access_model = self.node_defaults.access_model;
	end

	-- Access models handled here rather than by config.access_models
	local builtin = { open = "none"; whitelist = "restricted" };
	local builtin_aff = builtin[access_model];
	if builtin_aff then
		return builtin_aff;
	end

	local models = self.config.access_models;
	local check = models and models[access_model];
	if check then
		local aff = check(actor);
		if aff then
			return aff;
		end
	end
end

--- Set (or clear, with nil) the affiliation of `jid` on `node`.
-- The JID is normalized before being stored. If a node store is configured
-- the node is saved, and the change is rolled back if saving fails.
-- Afterwards the subscription state is reconciled with the new affiliation:
-- a JID that is not subscribed but may not be unsubscribed gets subscribed,
-- and a JID that is subscribed but may not be subscribed gets unsubscribed.
-- @param node node name
-- @param actor JID of the acting entity, or true to bypass access checks
-- @param jid JID whose affiliation is changed
-- @param affiliation new affiliation value
-- @return true on success; otherwise a falsy value and an error condition
--   ("forbidden", "item-not-found", "internal-server-error", or the error
--   from the subscription change)
function service:set_affiliation(node, actor, jid, affiliation)
	-- Access checking
	if not self:may(node, actor, "set_affiliation") then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if not node_obj then
		return false, "item-not-found";
	end
	jid = self.config.normalize_jid(jid);
	local old_affiliation = node_obj.affiliations[jid];
	node_obj.affiliations[jid] = affiliation;

	if self.config.nodestore then
		local ok = save_node_to_store(self, node_obj);
		if not ok then
			node_obj.affiliations[jid] = old_affiliation;
			return ok, "internal-server-error";
		end
	end

	local _, jid_sub = self:get_subscription(node, true, jid);
	if not jid_sub and not self:may(node, jid, "be_unsubscribed") then
		-- Must be subscribed under the new affiliation
		local ok, err = self:add_subscription(node, true, jid);
		if not ok then
			return ok, err;
		end
	elseif jid_sub and not self:may(node, jid, "be_subscribed") then
		-- No longer allowed to be subscribed: drop the existing subscription.
		-- (Adding a subscription here could only ever fail with "forbidden".)
		local ok, err = self:remove_subscription(node, true, jid);
		if not ok then
			return ok, err;
		end
	end
	return true;
end

--- Subscribe `jid` to `node`.
-- The actor needs "subscribe" when acting for itself (or when actor is true)
-- and "subscribe_other" otherwise; the JID itself needs "be_subscribed".
-- A missing node is created first when config.autocreate_on_subscribe is set.
-- The subscription is recorded on the node and in the service-wide index
-- (normalized JID -> JID -> node), then saved if a node store is configured.
-- @param node node name
-- @param actor JID of the acting entity, or true to act as the service
-- @param jid JID to subscribe
-- @param options optional subscription options; stored as the subscription value
-- @return true on success; otherwise a falsy value and an error condition
function service:add_subscription(node, actor, jid, options)
	-- Access checking
	local acting_for_self = actor == true or jid == actor or self:jids_equal(actor, jid);
	local cap = acting_for_self and "subscribe" or "subscribe_other";
	if not (self:may(node, actor, cap) and self:may(node, jid, "be_subscribed")) then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if not node_obj then
		if not self.config.autocreate_on_subscribe then
			return false, "item-not-found";
		end
		local created, create_err = self:create(node, true);
		if not created then
			return created, create_err;
		end
		node_obj = self.nodes[node];
	end

	local previous = node_obj.subscribers[jid];
	node_obj.subscribers[jid] = options or true;

	-- Register in the index, creating the intermediate tables on demand
	local normal_jid = self.config.normalize_jid(jid);
	local by_jid = self.subscriptions[normal_jid];
	if not by_jid then
		by_jid = {};
		self.subscriptions[normal_jid] = by_jid;
	end
	local by_node = by_jid[jid];
	if not by_node then
		by_node = {};
		by_jid[jid] = by_node;
	end
	by_node[node] = true;

	if self.config.nodestore then
		local saved = save_node_to_store(self, node_obj);
		if not saved then
			-- Put back whatever was there before
			node_obj.subscribers[jid] = previous;
			if previous then
				by_node[node] = true;
			else
				by_node[node] = nil;
			end
			return saved, "internal-server-error";
		end
	end

	self.events.fire_event("subscription-added", { node = node, jid = jid, normalized_jid = normal_jid, options = options });
	return true;
end

--- Unsubscribe `jid` from `node`.
-- The actor needs "unsubscribe" when acting for itself (or when actor is
-- true) and "unsubscribe_other" otherwise; the JID needs "be_unsubscribed".
-- The node is saved (when a node store is configured) before the
-- service-wide index is pruned, so a failed save only has to restore the
-- node's subscriber entry and never touches index tables that were removed.
-- @param node node name
-- @param actor JID of the acting entity, or true to act as the service
-- @param jid JID to unsubscribe
-- @return true on success; otherwise a falsy value and an error condition
--   ("forbidden", "item-not-found", "not-subscribed", "internal-server-error")
function service:remove_subscription(node, actor, jid)
	-- Access checking
	local cap;
	if actor == true or jid == actor or self:jids_equal(actor, jid) then
		cap = "unsubscribe";
	else
		cap = "unsubscribe_other";
	end
	if not self:may(node, actor, cap) then
		return false, "forbidden";
	end
	if not self:may(node, jid, "be_unsubscribed") then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if not node_obj then
		return false, "item-not-found";
	end
	local old_subscription = node_obj.subscribers[jid];
	if not old_subscription then
		return false, "not-subscribed";
	end
	node_obj.subscribers[jid] = nil;

	-- Persist first: only the node object is written to the store, so the
	-- index can safely be updated once the save is known to have succeeded.
	if self.config.nodestore then
		local ok = save_node_to_store(self, node_obj);
		if not ok then
			node_obj.subscribers[jid] = old_subscription;
			return ok, "internal-server-error";
		end
	end

	-- Drop the node from the index and prune tables that became empty
	local normal_jid = self.config.normalize_jid(jid);
	local subs = self.subscriptions[normal_jid];
	if subs then
		local jid_subs = subs[jid];
		if jid_subs then
			jid_subs[node] = nil;
			if next(jid_subs) == nil then
				subs[jid] = nil;
			end
		end
		if next(subs) == nil then
			self.subscriptions[normal_jid] = nil;
		end
	end

	self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid });
	return true;
end

--- Look up the subscription of `jid` on `node`.
-- The actor needs "get_subscription" when asking about itself (or when actor
-- is true) and "get_subscription_other" otherwise.
-- @param node node name
-- @param actor JID of the acting entity, or true
-- @param jid JID whose subscription is requested
-- @return true and the stored subscription value (nil if not subscribed),
--   or false and "forbidden" / "item-not-found"
function service:get_subscription(node, actor, jid)
	-- Access checking
	local cap = "get_subscription_other";
	if actor == true or jid == actor or self:jids_equal(actor, jid) then
		cap = "get_subscription";
	end
	if not self:may(node, actor, cap) then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if node_obj then
		return true, node_obj.subscribers[jid];
	end
	return false, "item-not-found";
end

--- Create a new node.
-- The node's config is the `options` table (or a fresh one) with the
-- service's node defaults as fallback. The node is saved if a node store is
-- configured, an item store is created for it and "node-created" is fired.
-- Unless the actor is `true`, the actor is then made "owner" of the node; if
-- that fails the node and its item store are removed from the service again.
-- @param node node name
-- @param actor JID of the creating entity, or true
-- @param options optional node configuration table
-- @return true on success; otherwise a falsy value and an error condition
function service:create(node, actor, options)
	-- Access checking
	if not self:may(node, actor, "create") then
		return false, "forbidden";
	end
	--
	if self.nodes[node] then
		return false, "conflict";
	end

	local node_obj = {
		name = node;
		subscribers = {};
		config = setmetatable(options or {}, {__index=self.node_defaults});
		affiliations = {};
	};
	self.nodes[node] = node_obj;

	if self.config.nodestore then
		local saved = save_node_to_store(self, node_obj);
		if not saved then
			self.nodes[node] = nil;
			return saved, "internal-server-error";
		end
	end

	self.data[node] = self.config.itemstore(self.nodes[node].config, node);
	self.events.fire_event("node-created", { node = node, actor = actor });

	if actor == true then
		return true;
	end

	local ok, err = self:set_affiliation(node, true, actor, "owner");
	if not ok then
		self.nodes[node] = nil;
		self.data[node] = nil;
		return ok, err;
	end
	return true;
end

--- Delete a node.
-- The node is removed from the node store first (when one is configured), so
-- that a storage failure leaves the in-memory node, its items and its
-- subscriptions untouched. After that the node and its item store are
-- dropped, its subscribers are removed from the service-wide subscription
-- index, "node-deleted" is fired and the broadcaster is called with "delete".
-- @param node node name
-- @param actor JID of the acting entity, or true
-- @return true on success; otherwise a falsy value and an error ("forbidden",
--   "item-not-found", or the error returned by the node store)
function service:delete(node, actor)
	-- Access checking
	if not self:may(node, actor, "delete") then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if not node_obj then
		return false, "item-not-found";
	end

	if self.config.nodestore then
		local ok, err = delete_node_in_store(self, node);
		if not ok then
			-- Nothing has been changed in memory yet, so there is nothing to undo
			return ok, err;
		end
	end

	self.nodes[node] = nil;
	if self.data[node] and self.data[node].clear then
		self.data[node]:clear();
	end
	self.data[node] = nil;

	-- Remove the node from the subscription index; stale entries would make
	-- get_subscriptions() look up a node that no longer exists.
	for jid in pairs(node_obj.subscribers) do
		local normal_jid = self.config.normalize_jid(jid);
		local subs = self.subscriptions[normal_jid];
		local jid_subs = subs and subs[jid];
		if jid_subs then
			jid_subs[node] = nil;
			if next(jid_subs) == nil then
				subs[jid] = nil;
			end
			if next(subs) == nil then
				self.subscriptions[normal_jid] = nil;
			end
		end
	end

	self.events.fire_event("node-deleted", { node = node, actor = actor });
	self.config.broadcaster("delete", node, node_obj.subscribers, nil, actor, node_obj, self);
	return true;
end

--- Publish an item to a node.
-- A missing node is created first when config.autocreate_on_publish is set.
-- The item has to pass config.itemcheck and is then stored under `id`; if the
-- item store's set method returns a string, that string is used as the id in
-- the "item-published" event. Finally the broadcaster is called with "items".
-- @param node node name
-- @param actor JID of the publishing entity, or true
-- @param id item id
-- @param item the item to store and broadcast
-- @return true on success; otherwise a falsy value and an error condition
function service:publish(node, actor, id, item)
	-- Access checking
	if not self:may(node, actor, "publish") then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if not node_obj and not self.config.autocreate_on_publish then
		return false, "item-not-found";
	end
	if not node_obj then
		local created, create_err = self:create(node, true);
		if not created then
			return created, create_err;
		end
		node_obj = self.nodes[node];
	end

	if not self.config.itemcheck(item) then
		return nil, "internal-server-error";
	end

	local stored = self.data[node]:set(id, item);
	if not stored then
		return nil, "internal-server-error";
	end
	-- A string result replaces the id reported in the event
	if type(stored) == "string" then
		id = stored;
	end

	self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
	self.config.broadcaster("items", node, node_obj.subscribers, item, actor, node_obj, self);
	return true;
end

--- Remove a single item from a node.
-- Fires "item-retracted" and, when `retract` is given, passes it to the
-- broadcaster as the payload of an "items" notification.
-- @param node node name
-- @param actor JID of the acting entity, or true
-- @param id id of the item to remove
-- @param retract optional payload to broadcast to subscribers
-- @return true on success; false and "forbidden" / "item-not-found", or nil
--   and "internal-server-error" if the item store refuses the removal
function service:retract(node, actor, id, retract)
	-- Access checking
	if not self:may(node, actor, "retract") then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	local existing = node_obj and self.data[node]:get(id);
	if not existing then
		return false, "item-not-found";
	end

	local removed = self.data[node]:set(id, nil);
	if not removed then
		return nil, "internal-server-error";
	end

	self.events.fire_event("item-retracted", { node = node, actor = actor, id = id });
	if retract then
		self.config.broadcaster("items", node, node_obj.subscribers, retract, actor, node_obj, self);
	end
	return true;
end

--- Remove all items from a node.
-- Requires the "retract" capability. The item store is cleared through its
-- clear method when it has one, otherwise it is replaced by a fresh store.
-- Fires "node-purged" and, when `notify` is set, calls the broadcaster with
-- "purge".
-- @param node node name
-- @param actor JID of the acting entity, or true
-- @param notify whether subscribers should be notified
-- @return true on success; false and "forbidden" / "item-not-found" otherwise
function service:purge(node, actor, notify)
	-- Access checking
	if not self:may(node, actor, "retract") then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if not node_obj then
		return false, "item-not-found";
	end

	local node_data = self.data[node];
	if node_data and node_data.clear then
		node_data:clear();
	else
		self.data[node] = self.config.itemstore(node_obj.config, node);
	end

	self.events.fire_event("node-purged", { node = node, actor = actor });
	if notify then
		self.config.broadcaster("purge", node, node_obj.subscribers, nil, actor, node_obj, self);
	end
	return true;
end

--- Fetch items from a node.
-- The result table lists item ids in its array part and maps each id to its
-- item. With `id` given, it holds at most that one item.
-- @param node node name
-- @param actor JID of the requesting entity, or true
-- @param id optional id restricting the result to a single item
-- @return true and the result table; false and "forbidden" / "item-not-found"
function service:get_items(node, actor, id)
	-- Access checking
	if not self:may(node, actor, "get_items") then
		return false, "forbidden";
	end
	--
	local node_obj = self.nodes[node];
	if not node_obj then
		return false, "item-not-found";
	end

	local node_data = self.data[node];
	if not id then
		-- No restriction: collect everything the item store iterates over
		local data = {};
		for key, value in node_data:items() do
			data[#data + 1] = key;
			data[key] = value;
		end
		return true, data;
	end

	-- Restrict results to a single specific item
	local with_id = node_data:get(id);
	if with_id then
		return true, { id, [id] = with_id };
	end
	return true, { };
end

--- Fetch the last item of a node, as reported by the item store's tail method.
-- @param node node name
-- @param actor JID of the requesting entity, or true
-- @return true followed by whatever the item store's tail method returns;
--   false and "forbidden" / "item-not-found" otherwise
function service:get_last_item(node, actor)
	-- Access checking
	if not self:may(node, actor, "get_items") then
		return false, "forbidden";
	end
	--
	-- An unknown node has no item store; report it like the other methods do
	-- instead of indexing a nil value.
	if not self.nodes[node] then
		return false, "item-not-found";
	end
	return true, self.data[node]:tail();
end

--- List the nodes of the service.
-- The "get_nodes" capability is checked without a node.
-- @param actor JID of the requesting entity, or true
-- @return true and the service's node table (not a copy), or false and "forbidden"
function service:get_nodes(actor)
	-- Access checking
	local allowed = self:may(nil, actor, "get_nodes");
	if allowed then
		return true, self.nodes;
	end
	return false, "forbidden";
end

--- Append subscription entries from one index branch to a result list.
-- Each appended entry has the fields `node`, `jid` and `subscription` (the
-- value stored in the node's subscribers table for that JID).
-- @param ret array to append to
-- @param serv service object; its `nodes` table is used to resolve node names
-- @param subs table mapping a subscribed JID to a table keyed by node name
-- @param node if set, only subscriptions to this node are returned
-- @param node_obj node object belonging to `node`
local function flatten_subscriptions(ret, serv, subs, node, node_obj)
	local nodes = serv.nodes;
	for subscribed_jid, subscribed_nodes in pairs(subs) do
		if node then -- Return only subscriptions to this node
			if subscribed_nodes[node] then
				ret[#ret+1] = {
					node = node;
					jid = subscribed_jid;
					subscription = node_obj.subscribers[subscribed_jid];
				};
			end
		else -- Return subscriptions to all nodes
			for subscribed_node in pairs(subscribed_nodes) do
				local subscribed_node_obj = nodes[subscribed_node];
				-- Skip index entries whose node no longer exists
				-- rather than indexing nil
				if subscribed_node_obj then
					ret[#ret+1] = {
						node = subscribed_node;
						jid = subscribed_jid;
						subscription = subscribed_node_obj.subscribers[subscribed_jid];
					};
				end
			end
		end
	end
end

--- List subscriptions, optionally limited to one node and/or one JID.
-- The actor needs "get_subscriptions" when asking about itself (or when
-- actor is true) and "get_subscriptions_other" otherwise. With a nil `jid`
-- the whole subscription index is walked; otherwise only the branch for the
-- normalized JID is.
-- @param node optional node name to restrict the result to
-- @param actor JID of the requesting entity, or true
-- @param jid optional JID to restrict the result to
-- @return true and an array of { node, jid, subscription } entries;
--   false and "forbidden" / "item-not-found" otherwise
function service:get_subscriptions(node, actor, jid)
	-- Access checking
	local asking_for_self = actor == true or jid == actor or self:jids_equal(actor, jid);
	local cap = asking_for_self and "get_subscriptions" or "get_subscriptions_other";
	if not self:may(node, actor, cap) then
		return false, "forbidden";
	end
	--
	local node_obj;
	if node then
		node_obj = self.nodes[node];
		if not node_obj then
			return false, "item-not-found";
		end
	end

	local ret = {};
	if jid ~= nil then
		-- The subscription value is taken from the node object, which saves
		-- a get_subscription() call for each node.
		local subs = self.subscriptions[self.config.normalize_jid(jid)];
		if subs then
			flatten_subscriptions(ret, self, subs, node, node_obj);
		end
	else
		for _, subs in pairs(self.subscriptions) do
			flatten_subscriptions(ret, self, subs, node, node_obj);
		end
	end
	return true, ret;
end

-- Access models only affect 'none' affiliation caps, service/default access level...
--- Replace the node-specific capability table of a node.
-- Requires the "configure" capability.
-- @param node node name
-- @param actor JID of the acting entity, or true
-- @param capabilities new capability table for the node
-- @return true on success; false and "forbidden" / "item-not-found" otherwise
function service:set_node_capabilities(node, actor, capabilities)
	-- Access checking
	if self:may(node, actor, "configure") then
		local node_obj = self.nodes[node];
		if node_obj then
			node_obj.capabilities = capabilities;
			return true;
		end
		return false, "item-not-found";
	end
	return false, "forbidden";
end

--- Replace the configuration of a node.
-- The node defaults are attached to `new_config` before the optional
-- config.check_node_config callback runs, so the callback sees defaults for
-- options that were left out. The node is saved if a node store is
-- configured (the old config is restored if saving fails). A changed
-- "persist_items" replaces the item store; otherwise a changed "max_items"
-- resizes it.
-- @param node node name
-- @param actor JID of the acting entity, or true
-- @param new_config table with the new configuration; gets a metatable set
-- @return true on success; otherwise a falsy value and an error condition
--   ("forbidden", "item-not-found", "not-acceptable", "internal-server-error")
function service:set_node_config(node, actor, new_config)
	if not self:may(node, actor, "configure") then
		return false, "forbidden";
	end

	local node_obj = self.nodes[node];
	if not node_obj then
		return false, "item-not-found";
	end

	setmetatable(new_config, {__index=self.node_defaults});

	local validate = self.config.check_node_config;
	if validate and not validate(node, actor, new_config) then
		return false, "not-acceptable";
	end

	local old_config = node_obj.config;
	node_obj.config = new_config;

	if self.config.nodestore then
		local saved = save_node_to_store(self, node_obj);
		if not saved then
			node_obj.config = old_config;
			return saved, "internal-server-error";
		end
	end

	local current_config = node_obj.config;
	if old_config["persist_items"] ~= current_config["persist_items"] then
		self.data[node] = self.config.itemstore(self.nodes[node].config, node);
	elseif old_config["max_items"] ~= current_config["max_items"] then
		self.data[node]:resize(self.nodes[node].config["max_items"]);
	end

	return true;
end

--- Return a flat copy of a node's effective configuration.
-- The copy is layered the same way lookups on the node's config resolve:
-- built-in defaults first, then the service's own node defaults, then the
-- options set on the node itself.
-- @param node node name
-- @param actor JID of the requesting entity, or true
-- @return true and a new table with the configuration;
--   false and "forbidden" / "item-not-found" otherwise
function service:get_node_config(node, actor)
	if not self:may(node, actor, "get_configuration") then
		return false, "forbidden";
	end

	local node_obj = self.nodes[node];
	if not node_obj then
		return false, "item-not-found";
	end

	local config_table = {};
	for k, v in pairs(default_node_config) do
		config_table[k] = v;
	end
	-- Service-level overrides of the defaults; without this layer the result
	-- would disagree with what node_obj.config[k] actually yields.
	for k, v in pairs(self.node_defaults) do
		config_table[k] = v;
	end
	for k, v in pairs(node_obj.config) do
		config_table[k] = v;
	end

	return true, config_table;
end

-- Module interface: only the service constructor is exported.
local exports = { new = new };
return exports;