plugins/mod_admin_shell.lua
changeset 11889 197642f9972f
parent 11854 bfa85965106e
child 11890 b0b258e092da
--- a/plugins/mod_admin_shell.lua	Sat Nov 06 18:45:44 2021 +0100
+++ b/plugins/mod_admin_shell.lua	Wed Nov 10 17:59:35 2021 +0100
@@ -42,6 +42,10 @@
 local format_number = require "util.human.units".format;
 local format_table = require "util.human.io".table;
 
+local function capitalize(s)
+	return (s:gsub("^%a", string.upper):gsub("_", " "));
+end
+
 local commands = module:shared("commands")
 local def_env = module:shared("env");
 local default_env_mt = { __index = def_env };
@@ -205,15 +209,13 @@
 		print [[config - Reloading the configuration, etc.]]
 		print [[console - Help regarding the console itself]]
 	elseif section == "c2s" then
-		print [[c2s:show(jid) - Show all client sessions with the specified JID (or all if no JID given)]]
-		print [[c2s:show_insecure() - Show all unencrypted client connections]]
-		print [[c2s:show_secure() - Show all encrypted client connections]]
-		print [[c2s:show_tls() - Show TLS cipher info for encrypted sessions]]
+		print [[c2s:show(jid, columns) - Show all client sessions with the specified JID (or all if no JID given)]]
+		print [[c2s:show_tls(jid) - Show TLS cipher info for encrypted sessions]]
 		print [[c2s:count() - Count sessions without listing them]]
 		print [[c2s:close(jid) - Close all sessions for the specified JID]]
 		print [[c2s:closeall() - Close all active c2s connections ]]
 	elseif section == "s2s" then
-		print [[s2s:show(domain) - Show all s2s connections for the given domain (or all if no domain given)]]
+		print [[s2s:show(domain, columns) - Show all s2s connections for the given domain (or all if no domain given)]]
 		print [[s2s:show_tls(domain) - Show TLS cipher info for encrypted sessions]]
 		print [[s2s:close(from, to) - Close a connection from one domain to another]]
 		print [[s2s:closeall(host) - Close all the incoming/outgoing s2s sessions to specified host]]
@@ -582,101 +584,6 @@
 	return ok, (ok and "Config reloaded (you may need to reload modules to take effect)") or tostring(err);
 end
 
-local function common_info(session, line)
-	if session.id then
-		line[#line+1] = "["..session.id.."]"
-	else
-		line[#line+1] = "["..session.type..(tostring(session):match("%x*$")).."]"
-	end
-end
-
-local function session_flags(session, line)
-	line = line or {};
-	common_info(session, line);
-	if session.type == "c2s" then
-		local status, priority = "unavailable", tostring(session.priority or "-");
-		if session.presence then
-			status = session.presence:get_child_text("show") or "available";
-		end
-		line[#line+1] = status.."("..priority..")";
-	end
-	if session.cert_identity_status == "valid" then
-		line[#line+1] = "(authenticated)";
-	end
-	if session.dialback_key then
-		line[#line+1] = "(dialback)";
-	end
-	if session.external_auth then
-		line[#line+1] = "(SASL)";
-	end
-	if session.secure then
-		line[#line+1] = "(encrypted)";
-	end
-	if session.compressed then
-		line[#line+1] = "(compressed)";
-	end
-	if session.smacks then
-		line[#line+1] = "(sm)";
-	end
-	if session.state then
-		if type(session.csi_counter) == "number" then
-			line[#line+1] = string.format("(csi:%s queue #%d)", session.state, session.csi_counter);
-		else
-			line[#line+1] = string.format("(csi:%s)", session.state);
-		end
-	end
-	if session.ip and session.ip:match(":") then
-		line[#line+1] = "(IPv6)";
-	end
-	if session.remote then
-		line[#line+1] = "(remote)";
-	end
-	if session.incoming and session.outgoing then
-		line[#line+1] = "(bidi)";
-	elseif session.is_bidi or session.bidi_session then
-		line[#line+1] = "(bidi)";
-	end
-	if session.bosh_version then
-		line[#line+1] = "(bosh)";
-	end
-	if session.websocket_request then
-		line[#line+1] = "(websocket)";
-	end
-	return table.concat(line, " ");
-end
-
-local function tls_info(session, line)
-	line = line or {};
-	common_info(session, line);
-	if session.secure then
-		local sock = session.conn and session.conn.socket and session.conn:socket();
-		if sock then
-			local info = sock.info and sock:info();
-			if info then
-				line[#line+1] = ("(%s with %s)"):format(info.protocol, info.cipher);
-			else
-				-- TLS session might not be ready yet
-				line[#line+1] = "(cipher info unavailable)";
-			end
-			if sock.getsniname then
-				local name = sock:getsniname();
-				if name then
-					line[#line+1] = ("(SNI:%q)"):format(name);
-				end
-			end
-			if sock.getalpn then
-				local proto = sock:getalpn();
-				if proto then
-					line[#line+1] = ("(ALPN:%q)"):format(proto);
-				end
-			end
-		end
-	else
-		line[#line+1] = "(insecure)";
-	end
-	return table.concat(line, " ");
-end
-
 def_env.c2s = {};
 
 local function get_jid(session)
@@ -700,16 +607,16 @@
 	return c2s;
 end
 
+local function _sort_by_jid(a, b)
+	if a.host == b.host then
+		if a.username == b.username then return (a.resource or "") > (b.resource or ""); end
+		return (a.username or "") > (b.username or "");
+	end
+	return _sort_hosts(a.host or "", b.host or "");
+end
+
 local function show_c2s(callback)
-	get_c2s():sort(function(a, b)
-		if a.host == b.host then
-			if a.username == b.username then
-				return (a.resource or "") > (b.resource or "");
-			end
-			return (a.username or "") > (b.username or "");
-		end
-		return _sort_hosts(a.host or "", b.host or "");
-	end):map(function (session)
+	get_c2s():sort(_sort_by_jid):map(function (session)
 		callback(get_jid(session), session)
 	end);
 end
@@ -719,47 +626,223 @@
 	return true, "Total: "..  #c2s .." clients";
 end
 
-function def_env.c2s:show(match_jid, annotate)
-	local print, count = self.session.print, 0;
-	annotate = annotate or session_flags;
-	local curr_host = false;
-	show_c2s(function (jid, session)
-		if curr_host ~= session.host then
-			curr_host = session.host;
-			print(curr_host or "(not connected to any host yet)");
-		end
-		if (not match_jid) or jid:match(match_jid) then
-			count = count + 1;
-			print(annotate(session, { "  ", jid }));
-		end
-	end);
-	return true, "Total: "..count.." clients";
+local function get_s2s_hosts(session) --> local,remote
+	if session.direction == "outgoing" then
+		return session.host or session.from_host, session.to_host;
+	elseif session.direction == "incoming" then
+		return session.host or session.to_host, session.from_host;
+	end
 end
 
-function def_env.c2s:show_insecure(match_jid)
-	local print, count = self.session.print, 0;
-	show_c2s(function (jid, session)
-		if ((not match_jid) or jid:match(match_jid)) and not session.secure then
-			count = count + 1;
-			print(jid);
+local available_columns = {
+	jid = {
+		title = "JID";
+		width = 32;
+		key = "full_jid";
+		mapper = function(full_jid, session) return full_jid or get_jid(session) end;
+	};
+	host = {
+		title = "Host";
+		key = "host";
+		width = 22;
+		mapper = function(host, session)
+			if host ~= "" then return host; end
+			return get_s2s_hosts(session) or "?";
+		end;
+	};
+	remote = {
+		title = "Remote";
+		width = 22;
+		mapper = function(_, session)
+			return select(2, get_s2s_hosts(session));
+		end;
+	};
+	dir = {
+		title = "Dir";
+		width = 3;
+		key = "direction";
+		mapper = function (dir)
+			if dir == "outgoing" then return "-->"; end
+			if dir == "incoming" then return "<--"; end
+			return ""
+		end;
+	};
+	id = { title = "Session ID"; width = 20; key = "id" };
+	type = { title = "Type"; width = #"c2s_unauthed"; key = "type" };
+	method = {
+		title = "Method";
+		width = 10;
+		mapper = function(_, session)
+			if session.bosh_version then
+				return "BOSH";
+			elseif session.websocket_request then
+				return "WebSocket";
+			else
+				return "TCP";
+			end
+		end;
+	};
+	ipv = {
+		title = "IPv";
+		width = 4;
+		key = "ip";
+		mapper = function(ip) return ip:find(":") and "IPv6" or "IPv4"; end;
+	};
+	ip = { title = "IP address"; width = 40; key = "ip" };
+	status = {
+		title = "Status";
+		width = 11;
+		key = "presence";
+		mapper = function(p)
+			if not p or p == "" then return "unavailable"; end
+			return p:get_child_text("show") or "available";
+		end;
+	};
+	secure = {
+		title = "Security";
+		key = "conn";
+		width = 11;
+		mapper = function(conn, session)
+			if not session.secure then return "insecure"; end
+			if conn == "" or not conn:ssl() then return "secure" end
+			local sock = conn ~= "" and conn:socket();
+			if not sock then return "unknown TLS"; end
+			local tls_info = sock.info and sock:info();
+			return tls_info and tls_info.protocol or "unknown TLS";
+		end;
+	};
+	encryption = {
+		title = "Encryption";
+		width = 30;
+		key = "conn";
+		mapper = function(conn)
+			local sock = conn ~= "" and conn:socket();
+			local info = sock and sock.info and sock:info();
+			if info then return info.cipher end
+			return ""
+		end;
+	};
+	cert = {
+		title = "Certificate";
+		key = "cert_identity_status";
+		mapper = function(cert_status, session)
+			if cert_status ~= "" then return capitalize(cert_status); end
+			if session.cert_chain_status == "Invalid" then
+				local cert_errors = set.new(session.cert_chain_errors[1]);
+				if cert_errors:contains("certificate has expired") then
+					return "Expired";
+				elseif cert_errors:contains("self signed certificate") then
+					return "Self-signed";
+				end
+				return "Untrusted";
+			elseif session.cert_identity_status == "invalid" then
+				return "Mismatched";
+			end
+			return "Not validated";
+		end;
+	};
+	sni = {
+		title = "SNI";
+		width = 22;
+		mapper = function(_, session)
+			if not session.conn then return "" end
+			local sock = session.conn:socket();
+			return sock and sock.getsniname and sock:getsniname() or "";
+		end;
+	};
+	alpn = {
+		title = "ALPN";
+		width = 11;
+		mapper = function(_, session)
+			if not session.conn then return "" end
+			local sock = session.conn:socket();
+			return sock and sock.getalpn and sock:getalpn() or "";
+		end;
+	};
+	smacks = {
+		title = "SM";
+		key = "smacks";
+		width = 11;
+		mapper = function(smacks_xmlns, session)
+			if smacks_xmlns == "" then return "no"; end
+			if session.hibernating then return "hibernating"; end
+			return "yes";
+		end;
+	};
+	smacks_queue = {
+		title = "SM Queue";
+		key = "outgoing_stanza_queue";
+		width = 8;
+		align = "right";
+		mapper = function (queue)
+			return tostring(#queue);
 		end
-	end);
-	return true, "Total: "..count.." insecure client connections";
+	};
+	csi = {
+		title = "CSI State";
+		key = "state";
+		-- TODO include counter
+	};
+	s2s_sasl = {
+		title = "SASL";
+		key = "external_auth";
+		width = 10;
+		mapper = capitalize
+	};
+	dialback = {
+		title = "Dialback";
+		key = "dialback_key";
+		width = 13;
+		mapper = function (dialback_key, session)
+			if dialback_key == "" then
+				if session.type == "s2sin" or session.type == "s2sout" then
+					return "Not used";
+				end
+				return "Not initiated";
+			elseif session.type == "s2sin_unauthed" or session.type == "s2sout_unauthed" then
+				return "Initiated";
+			else
+				return "Completed";
+			end
+		end
+	};
+};
+
+local function get_colspec(colspec, default)
+	local columns = {};
+	for i, col in pairs(colspec or default) do
+		if type(col) == "string" then
+			columns[i] = available_columns[col] or { title = capitalize(col); width = 20; key = col };
+		elseif type(col) ~= "table" then
+			return false, ("argument %d: expected string|table but got %s"):format(i, type(col));
+		else
+			columns[i] = col;
+		end
+	end
+
+	return columns;
 end
 
-function def_env.c2s:show_secure(match_jid)
-	local print, count = self.session.print, 0;
-	show_c2s(function (jid, session)
-		if ((not match_jid) or jid:match(match_jid)) and session.secure then
-			count = count + 1;
-			print(jid);
-		end
-	end);
-	return true, "Total: "..count.." secure client connections";
+function def_env.c2s:show(match_jid, colspec)
+	local print = self.session.print;
+	local columns = get_colspec(colspec, { "id"; "jid"; "ipv"; "status"; "secure"; "smacks"; "csi" });
+	local row = format_table(columns, 120);
+
+	local function match(session)
+		local jid = get_jid(session)
+		return (not match_jid) or jid:match(match_jid)
+	end
+
+	print(row());
+
+	for _, session in ipairs(get_c2s():filter(match):sort(_sort_by_jid)) do
+		print(row(session));
+	end
+	return true;
 end
 
 function def_env.c2s:show_tls(match_jid)
-	return self:show(match_jid, tls_info);
+	return self:show(match_jid, { "jid"; "id"; "secure"; "encryption" });
 end
 
 local function build_reason(text, condition)
@@ -794,71 +877,35 @@
 
 
 def_env.s2s = {};
-function def_env.s2s:show(match_jid, annotate)
-	local print = self.session.print;
-	annotate = annotate or session_flags;
-
-	local count_in, count_out = 0,0;
-	local s2s_list = { };
-
-	local s2s_sessions = module:shared"/*/s2s/sessions";
-	for _, session in pairs(s2s_sessions) do
-		local remotehost, localhost, direction;
-		if session.direction == "outgoing" then
-			direction = "->";
-			count_out = count_out + 1;
-			remotehost, localhost = session.to_host or "?", session.from_host or "?";
-		else
-			direction = "<-";
-			count_in = count_in + 1;
-			remotehost, localhost = session.from_host or "?", session.to_host or "?";
-		end
-		local sess_lines = { l = localhost, r = remotehost,
-			annotate(session, { "", direction, remotehost or "?" })};
+local function _sort_s2s(a, b)
+	local a_local, a_remote = get_s2s_hosts(a);
+	local b_local, b_remote = get_s2s_hosts(b);
+	if (a_local or "") == (b_local or "") then return _sort_hosts(a_remote or "", b_remote or ""); end
+	return _sort_hosts(a_local or "", b_local or "");
+end
 
-		if (not match_jid) or remotehost:match(match_jid) or localhost:match(match_jid) then
-			table.insert(s2s_list, sess_lines);
-			-- luacheck: ignore 421/print
-			local print = function (s) table.insert(sess_lines, "        "..s); end
-			if session.sendq then
-				print("There are "..#session.sendq.." queued outgoing stanzas for this connection");
-			end
-			if session.type == "s2sout_unauthed" then
-				if session.notopen then
-					print("The <stream> has not yet been opened");
-				elseif not session.dialback_key then
-					print("Dialback has not been initiated yet");
-				elseif session.dialback_key then
-					print("Dialback has been requested, but no result received");
-				end
-			end
-			if session.type == "s2sin_unauthed" then
-				print("Connection not yet authenticated");
-			elseif session.type == "s2sin" then
-				for name in pairs(session.hosts) do
-					if name ~= session.from_host then
-						print("also hosts "..tostring(name));
-					end
-				end
-			end
-		end
+function def_env.s2s:show(match_jid, colspec)
+	local print = self.session.print;
+	local columns = get_colspec(colspec, { "id"; "host"; "dir"; "remote"; "ipv"; "secure"; "s2s_sasl"; "dialback" });
+	local row = format_table(columns, 132);
+
+	local function match(session)
+		local host, remote = get_s2s_hosts(session);
+		return not match_jid or (host or ""):match(match_jid) or (remote or ""):match(match_jid);
 	end
 
-	-- Sort by local host, then remote host
-	table.sort(s2s_list, function(a,b)
-		if a.l == b.l then return _sort_hosts(a.r, b.r); end
-		return _sort_hosts(a.l, b.l);
-	end);
-	local lasthost;
-	for _, sess_lines in ipairs(s2s_list) do
-		if sess_lines.l ~= lasthost then print(sess_lines.l); lasthost=sess_lines.l end
-		for _, line in ipairs(sess_lines) do print(line); end
+	local s2s_sessions = array(iterators.values(module:shared"/*/s2s/sessions")):filter(match):sort(_sort_s2s);
+
+	print(row());
+
+	for _, session in ipairs(s2s_sessions) do
+		print(row(session));
 	end
-	return true, "Total: "..count_out.." outgoing, "..count_in.." incoming connections";
+	return true; -- TODO counts
 end
 
 function def_env.s2s:show_tls(match_jid)
-	return self:show(match_jid, tls_info);
+	return self:show(match_jid, { "id"; "host"; "dir"; "remote"; "secure"; "encryption"; "cert" });
 end
 
 local function print_subject(print, subject)