mod_pubsub_mqtt/mqtt.lib.lua
changeset 5837 58df53eefa28
parent 5118 d2a84e6aed2b
--- a/mod_pubsub_mqtt/mqtt.lib.lua	Tue Jan 30 14:26:14 2024 +0000
+++ b/mod_pubsub_mqtt/mqtt.lib.lua	Wed Feb 07 11:57:30 2024 +0000
@@ -1,4 +1,4 @@
-local bit = require "bit";
+local bit = require "util.bitcompat";
 
 local stream_mt = {};
 stream_mt.__index = stream_mt;
@@ -29,10 +29,25 @@
 	return self:read_bytes(len), len+2;
 end
 
+function stream_mt:read_word()
+	local len1, len2 = self:read_bytes(2):byte(1,2);
+	local result = bit.lshift(len1, 8) + len2;
+	module:log("debug", "read_word(%02x, %02x) = %04x (%d)", len1, len2, result, result);
+	return result;
+end
+
+local function hasbit(byte, n_bit)
+	return bit.band(byte, 2^n_bit) ~= 0;
+end
+
+local function encode_string(str)
+	return string.char(bit.band(#str, 0xff00), bit.band(#str, 0x00ff))..str;
+end
+
 local packet_type_codes = {
 	"connect", "connack",
 	"publish", "puback", "pubrec", "pubrel", "pubcomp",
-	"subscribe", "subak", "unsubscribe", "unsuback",
+	"subscribe", "suback", "unsubscribe", "unsuback",
 	"pingreq", "pingresp",
 	"disconnect"
 };
@@ -59,9 +74,46 @@
 			packet.type = nil; -- Invalid packet
 		else
 			packet.version = self:read_bytes(1):byte();
-			packet.connect_flags = self:read_bytes(1):byte();
-			packet.keepalive_timer = self:read_bytes(1):byte();
+			module:log("debug", "ver: %02x", packet.version);
+			if packet.version ~= 0x04 then
+				module:log("warn", "MQTT version mismatch (got %02x, we support %02x", packet.version, 0x04);
+			end
+			local flags = self:read_bytes(1):byte();
+			module:log("debug", "flags: %02x", flags);
+			packet.keepalive_timer = self:read_bytes(2):byte();
+			module:log("debug", "keepalive: %d", packet.keepalive_timer);
+			packet.connect_flags = {};
 			length = length - 11;
+			packet.connect_flags = {
+				clean_session = hasbit(flags, 1);
+				will = hasbit(flags, 2);
+				will_qos = bit.band(bit.rshift(flags, 2), 0x02);
+				will_retain = hasbit(flags, 5);
+				user_name = hasbit(flags, 7);
+				password = hasbit(flags, 6);
+			};
+			module:log("debug", "%s", require "util.serialization".serialize(packet.connect_flags, "debug"));
+			module:log("debug", "Reading client_id...");
+			packet.client_id = self:read_string();
+			if packet.connect_flags.will then
+				module:log("debug", "Reading will...");
+				packet.will = {
+					topic = self:read_string();
+					message = self:read_string();
+					qos = packet.connect_flags.will_qos;
+					retain = packet.connect_flags.will_retain;
+				};
+			end
+			if packet.connect_flags.user_name then
+				module:log("debug", "Reading username...");
+				packet.username = self:read_string();
+			end
+			if packet.connect_flags.password then
+				module:log("debug", "Reading password...");
+				packet.password = self:read_string();
+			end
+			module:log("debug", "Done parsing connect!");
+			length = 0; -- No payload left
 		end
 	elseif packet.type == "publish" then
 		packet.topic = self:read_string();
@@ -87,6 +139,7 @@
 	if length > 0 then
 		packet.data = self:read_bytes(length);
 	end
+	module:log("debug", "MQTT packet complete!");
 	return packet;
 end
 
@@ -102,7 +155,6 @@
 end
 
 function stream_mt:feed(data)
-	module:log("debug", "Feeding %d bytes", #data);
 	local packets = {};
 	local packet = self.parser(data);
 	while packet do
@@ -135,10 +187,10 @@
 		packet.data = string.char(bit.band(#topic, 0xff00), bit.band(#topic, 0x00ff))..topic..packet.data;
 	elseif packet.type == "suback" then
 		local t = {};
-		for _, topic in ipairs(packet.topics) do
-			table.insert(t, string.char(bit.band(#topic, 0xff00), bit.band(#topic, 0x00ff))..topic.."\000");
+		for i, result_code in ipairs(packet.results) do
+			table.insert(t, string.char(result_code));
 		end
-		packet.data = table.concat(t);
+		packet.data = packet.id..table.concat(t);
 	end
 
 	-- Get length