util.fsm: New utility lib for finite state machines
authorMatthew Wild <mwild1@gmail.com>
Thu, 17 Mar 2022 17:45:27 +0000
changeset 13023 8a2f75e38eb2
parent 13022 9ed4a8502c54
child 13024 3174308d127e
util.fsm: New utility lib for finite state machines
spec/util_fsm_spec.lua
util/fsm.lua
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/spec/util_fsm_spec.lua	Thu Mar 17 17:45:27 2022 +0000
@@ -0,0 +1,250 @@
+describe("util.fsm", function ()
+	local new_fsm = require "util.fsm".new;
+
+	do
+		local fsm = new_fsm({
+			transitions = {
+				{ name = "melt", from = "solid", to = "liquid" };
+				{ name = "freeze", from = "liquid", to = "solid" };
+			};
+		});
+
+		it("works", function ()
+			local water = fsm:init("liquid");
+			water:freeze();
+			assert.equal("solid", water.state);
+			water:melt();
+			assert.equal("liquid", water.state);
+		end);
+
+		it("does not allow invalid transitions", function ()
+			local water = fsm:init("liquid");
+			assert.has_errors(function ()
+				water:melt();
+			end, "Invalid state transition: liquid cannot melt");
+
+			water:freeze();
+			assert.equal("solid", water.state);
+
+			water:melt();
+			assert.equal("liquid", water.state);
+
+			assert.has_errors(function ()
+				water:melt();
+			end, "Invalid state transition: liquid cannot melt");
+		end);
+	end
+
+	it("notifies observers", function ()
+		local n = 0;
+		local has_become_solid = spy.new(function (transition)
+			assert.is_table(transition);
+			assert.equal("solid", transition.to);
+			assert.is_not_nil(transition.instance);
+			n = n + 1;
+			if n == 1 then
+				assert.is_nil(transition.from);
+				assert.is_nil(transition.from_attr);
+			elseif n == 2 then
+				assert.equal("liquid", transition.from);
+				assert.is_nil(transition.from_attr);
+				assert.equal("freeze", transition.name);
+			end
+		end);
+		local is_melting = spy.new(function (transition)
+			assert.is_table(transition);
+			assert.equal("melt", transition.name);
+			assert.is_not_nil(transition.instance);
+		end);
+		local fsm = new_fsm({
+			transitions = {
+				{ name = "melt", from = "solid", to = "liquid" };
+				{ name = "freeze", from = "liquid", to = "solid" };
+			};
+			state_handlers = {
+				solid = has_become_solid;
+			};
+
+			transition_handlers = {
+				melt = is_melting;
+			};
+		});
+
+		local water = fsm:init("liquid");
+		assert.spy(has_become_solid).was_not_called();
+
+		local ice = fsm:init("solid"); --luacheck: ignore 211/ice
+		assert.spy(has_become_solid).was_called(1);
+
+		water:freeze();
+
+		assert.spy(is_melting).was_not_called();
+		water:melt();
+		assert.spy(is_melting).was_called(1);
+	end);
+
+	local function test_machine(fsm_spec, expected_transitions, test_func)
+		fsm_spec.handlers = fsm_spec.handlers or {};
+		fsm_spec.handlers.transitioned = function (transition)
+			local expected_transition = table.remove(expected_transitions, 1);
+			assert.same(expected_transition, {
+				name = transition.name;
+				to = transition.to;
+				to_attr = transition.to_attr;
+				from = transition.from;
+				from_attr = transition.from_attr;
+			});
+		end;
+		local fsm = new_fsm(fsm_spec);
+		test_func(fsm);
+		assert.equal(0, #expected_transitions);
+	end
+
+
+	it("handles transitions with the same name", function ()
+		local expected_transitions = {
+			{ name = nil   , from = "none", to = "A" };
+			{ name = "step", from = "A", to = "B" };
+			{ name = "step", from = "B", to = "C" };
+			{ name = "step", from = "C", to = "D" };
+		};
+
+		test_machine({
+			default_state = "none";
+			transitions = {
+				{ name = "step", from = "A", to = "B" };
+				{ name = "step", from = "B", to = "C" };
+				{ name = "step", from = "C", to = "D" };
+			};
+		}, expected_transitions, function (fsm)
+			local i = fsm:init("A");
+			i:step(); -- B
+			i:step(); -- C
+			i:step(); -- D
+			assert.has_errors(function ()
+				i:step();
+			end, "Invalid state transition: D cannot step");
+		end);
+	end);
+
+	it("handles supports wildcard transitions", function ()
+		local expected_transitions = {
+			{ name = nil   , from = "none", to = "A" };
+			{ name = "step", from = "A", to = "B" };
+			{ name = "step", from = "B", to = "C" };
+			{ name = "reset", from = "C", to = "A" };
+			{ name = "step", from = "A", to = "B" };
+			{ name = "step", from = "B", to = "C" };
+			{ name = "step", from = "C", to = "D" };
+		};
+
+		test_machine({
+			default_state = "none";
+			transitions = {
+				{ name = "step", from = "A", to = "B" };
+				{ name = "step", from = "B", to = "C" };
+				{ name = "step", from = "C", to = "D" };
+				{ name = "reset", from = "*", to = "A" };
+			};
+		}, expected_transitions, function (fsm)
+			local i = fsm:init("A");
+			i:step(); -- B
+			i:step(); -- C
+			i:reset(); -- A
+			i:step(); -- B
+			i:step(); -- C
+			i:step(); -- D
+			assert.has_errors(function ()
+				i:step();
+			end, "Invalid state transition: D cannot step");
+		end);
+	end);
+
+	it("supports specifying multiple from states", function ()
+		local expected_transitions = {
+			{ name = nil   , from = "none", to = "A" };
+			{ name = "step", from = "A", to = "B" };
+			{ name = "step", from = "B", to = "C" };
+			{ name = "reset", from = "C", to = "A" };
+			{ name = "step", from = "A", to = "B" };
+			{ name = "step", from = "B", to = "C" };
+			{ name = "step", from = "C", to = "D" };
+		};
+
+		test_machine({
+			default_state = "none";
+			transitions = {
+				{ name = "step", from = "A", to = "B" };
+				{ name = "step", from = "B", to = "C" };
+				{ name = "step", from = "C", to = "D" };
+				{ name = "reset", from = {"B", "C", "D"}, to = "A" };
+			};
+		}, expected_transitions, function (fsm)
+			local i = fsm:init("A");
+			i:step(); -- B
+			i:step(); -- C
+			i:reset(); -- A
+			assert.has_errors(function ()
+				i:reset();
+			end, "Invalid state transition: A cannot reset");
+			i:step(); -- B
+			i:step(); -- C
+			i:step(); -- D
+			assert.has_errors(function ()
+				i:step();
+			end, "Invalid state transition: D cannot step");
+		end);
+	end);
+
+	it("handles transitions with the same start/end state", function ()
+		local expected_transitions = {
+			{ name = nil   , from = "none", to = "A" };
+			{ name = "step", from = "A", to = "B" };
+			{ name = "step", from = "B", to = "B" };
+			{ name = "step", from = "B", to = "B" };
+		};
+
+		test_machine({
+			default_state = "none";
+			transitions = {
+				{ name = "step", from = "A", to = "B" };
+				{ name = "step", from = "B", to = "B" };
+			};
+		}, expected_transitions, function (fsm)
+			local i = fsm:init("A");
+			i:step(); -- B
+			i:step(); -- B
+			i:step(); -- B
+		end);
+	end);
+
+	it("can identify instances of a specific fsm", function ()
+		local fsm1 = new_fsm({ default_state = "a" });
+		local fsm2 = new_fsm({ default_state = "a" });
+
+		local i1 = fsm1:init();
+		local i2 = fsm2:init();
+
+		assert.truthy(fsm1:is_instance(i1));
+		assert.truthy(fsm2:is_instance(i2));
+
+		assert.falsy(fsm1:is_instance(i2));
+		assert.falsy(fsm2:is_instance(i1));
+	end);
+
+	it("errors when an invalid initial state is passed", function ()
+		local fsm1 = new_fsm({
+			transitions = {
+				{ name = "", from = "A", to = "B" };
+			};
+		});
+
+		assert.has_no_errors(function ()
+			fsm1:init("A");
+		end);
+
+		assert.has_errors(function ()
+			fsm1:init("C");
+		end);
+	end);
+end);
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/util/fsm.lua	Thu Mar 17 17:45:27 2022 +0000
@@ -0,0 +1,154 @@
+local events = require "util.events";
+
+local fsm_methods = {};
+local fsm_mt = { __index = fsm_methods };
+
+local function is_fsm(o)
+	local mt = getmetatable(o);
+	return mt == fsm_mt;
+end
+
+local function notify_transition(fire_event, transition_event)
+	local ret;
+	ret = fire_event("transition", transition_event);
+	if ret ~= nil then return ret; end
+	if transition_event.from ~= transition_event.to then
+		ret = fire_event("leave/"..transition_event.from, transition_event);
+		if ret ~= nil then return ret; end
+	end
+	ret = fire_event("transition/"..transition_event.name, transition_event);
+	if ret ~= nil then return ret; end
+end
+
+local function notify_transitioned(fire_event, transition_event)
+	if transition_event.to ~= transition_event.from then
+		fire_event("enter/"..transition_event.to, transition_event);
+	end
+	if transition_event.name then
+		fire_event("transitioned/"..transition_event.name, transition_event);
+	end
+	fire_event("transitioned", transition_event);
+end
+
+local function do_transition(name)
+	return function (self, attr)
+		local new_state = self.fsm.states[self.state][name] or self.fsm.states["*"][name];
+		if not new_state then
+			return error(("Invalid state transition: %s cannot %s"):format(self.state, name));
+		end
+
+		local transition_event = {
+			instance = self;
+
+			name = name;
+			to = new_state;
+			to_attr = attr;
+
+			from = self.state;
+			from_attr = self.state_attr;
+		};
+
+		local fire_event = self.fsm.events.fire_event;
+		local ret = notify_transition(fire_event, transition_event);
+		if ret ~= nil then return nil, ret; end
+
+		self.state = new_state;
+		self.state_attr = attr;
+
+		notify_transitioned(fire_event, transition_event);
+		return true;
+	end;
+end
+
+local function new(desc)
+	local self = setmetatable({
+		default_state = desc.default_state;
+		events = events.new();
+	}, fsm_mt);
+
+	-- states[state_name][transition_name] = new_state_name
+	local states = { ["*"] = {} };
+	if desc.default_state then
+		states[desc.default_state] = {};
+	end
+	self.states = states;
+
+	local instance_methods = {};
+	self._instance_mt = { __index = instance_methods };
+
+	for _, transition in ipairs(desc.transitions or {}) do
+		local from_states = transition.from;
+		if type(from_states) ~= "table" then
+			from_states = { from_states };
+		end
+		for _, from in ipairs(from_states) do
+			if not states[from] then
+				states[from] = {};
+			end
+			if not states[transition.to] then
+				states[transition.to] = {};
+			end
+			if states[from][transition.name] then
+				return error(("Duplicate transition in FSM specification: %s from %s"):format(transition.name, from));
+			end
+			states[from][transition.name] = transition.to;
+		end
+
+		-- Add public method to trigger this transition
+		instance_methods[transition.name] = do_transition(transition.name);
+	end
+
+	if desc.state_handlers then
+		for state_name, handler in pairs(desc.state_handlers) do
+			self.events.add_handler("enter/"..state_name, handler);
+		end
+	end
+
+	if desc.transition_handlers then
+		for transition_name, handler in pairs(desc.transition_handlers) do
+			self.events.add_handler("transition/"..transition_name, handler);
+		end
+	end
+
+	if desc.handlers then
+		self.events.add_handlers(desc.handlers);
+	end
+
+	return self;
+end
+
+function fsm_methods:init(state_name, state_attr)
+	local initial_state = assert(state_name or self.default_state, "no initial state specified");
+	if not self.states[initial_state] then
+		return error("Invalid initial state: "..initial_state);
+	end
+	local instance = setmetatable({
+		fsm = self;
+		state = initial_state;
+		state_attr = state_attr;
+	}, self._instance_mt);
+
+	if initial_state ~= self.default_state then
+		local fire_event = self.events.fire_event;
+		notify_transitioned(fire_event, {
+			instance = instance;
+
+			to = initial_state;
+			to_attr = state_attr;
+
+			from = self.default_state;
+		});
+	end
+
+	return instance;
+end
+
+function fsm_methods:is_instance(o)
+	local mt = getmetatable(o);
+	return mt == self._instance_mt;
+end
+
+return {
+	new = new;
+	is_fsm = is_fsm;
+};