util/sql.lua
changeset 5494 9916f0a2d178
child 5739 2e9ad9cb206f
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/util/sql.lua	Thu Apr 18 11:13:40 2013 +0100
@@ -0,0 +1,340 @@
+
+local setmetatable, getmetatable = setmetatable, getmetatable;
+local ipairs, unpack, select = ipairs, unpack, select;
+local tonumber, tostring = tonumber, tostring;
+local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local t_concat = table.concat;
+local s_char = string.char;
+local log = require "util.logger".init("sql");
+
+local DBI = require "DBI";
+-- This loads all available drivers while globals are unlocked
+-- LuaDBI should be fixed to not set globals.
+DBI.Drivers();
+local build_url = require "socket.url".build;
+
+module("sql")
+
+local column_mt = {};
+local table_mt = {};
+local query_mt = {};
+--local op_mt = {};
+local index_mt = {};
+
+function is_column(x) return getmetatable(x)==column_mt; end
+function is_index(x) return getmetatable(x)==index_mt; end
+function is_table(x) return getmetatable(x)==table_mt; end
+function is_query(x) return getmetatable(x)==query_mt; end
+--function is_op(x) return getmetatable(x)==op_mt; end
+--function expr(...) return setmetatable({...}, op_mt); end
+function Integer(n) return "Integer()" end
+function String(n) return "String()" end
+
+--[[local ops = {
+	__add = function(a, b) return "("..a.."+"..b..")" end;
+	__sub = function(a, b) return "("..a.."-"..b..")" end;
+	__mul = function(a, b) return "("..a.."*"..b..")" end;
+	__div = function(a, b) return "("..a.."/"..b..")" end;
+	__mod = function(a, b) return "("..a.."%"..b..")" end;
+	__pow = function(a, b) return "POW("..a..","..b..")" end;
+	__unm = function(a) return "NOT("..a..")" end;
+	__len = function(a) return "COUNT("..a..")" end;
+	__eq = function(a, b) return "("..a.."=="..b..")" end;
+	__lt = function(a, b) return "("..a.."<"..b..")" end;
+	__le = function(a, b) return "("..a.."<="..b..")" end;
+};
+
+local functions = {
+	
+};
+
+local cmap = {
+	[Integer] = Integer();
+	[String] = String();
+};]]
+
+function Column(definition)
+	return setmetatable(definition, column_mt);
+end
+function Table(definition)
+	local c = {}
+	for i,col in ipairs(definition) do
+		if is_column(col) then
+			c[i], c[col.name] = col, col;
+		elseif is_index(col) then
+			col.table = definition.name;
+		end
+	end
+	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
+end
+function Index(definition)
+	return setmetatable(definition, index_mt);
+end
+
+function table_mt:__tostring()
+	local s = { 'name="'..self.__table__.name..'"' }
+	for i,col in ipairs(self.__table__) do
+		s[#s+1] = tostring(col);
+	end
+	return 'Table{ '..t_concat(s, ", ")..' }'
+end
+table_mt.__index = {};
+function table_mt.__index:create(engine)
+	return engine:_create_table(self);
+end
+function table_mt:__call(...)
+	-- TODO
+end
+function column_mt:__tostring()
+	return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
+end
+function index_mt:__tostring()
+	local s = 'Index{ name="'..self.name..'"';
+	for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
+	return s..' }';
+--	return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
+end
+--
+
+local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
+local function parse_url(url)
+	local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
+	assert(scheme, "Invalid URL format");
+	local username, password, host, port;
+	local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
+	if not authpart then hostpart = secondpart; end
+	if authpart then
+		username, password = authpart:match("([^:]*):(.*)");
+		username = username or authpart;
+		password = password and urldecode(password);
+	end
+	if hostpart then
+		host, port = hostpart:match("([^:]*):(.*)");
+		host = host or hostpart;
+		port = port and assert(tonumber(port), "Invalid URL format");
+	end
+	return {
+		scheme = scheme:lower();
+		username = username; password = password;
+		host = host; port = port;
+		database = #database > 0 and database or nil;
+	};
+end
+
+--[[local session = {};
+
+function session.query(...)
+	local rets = {...};
+	local query = setmetatable({ __rets = rets, __filters }, query_mt);
+	return query;
+end
+--
+
+local function db2uri(params)
+	return build_url{
+		scheme = params.driver,
+		user = params.username,
+		password = params.password,
+		host = params.host,
+		port = params.port,
+		path = params.database,
+	};
+end]]
+
+local engine = {};
+function engine:connect()
+	if self.conn then return true; end
+
+	local params = self.params;
+	assert(params.driver, "no driver")
+	local dbh, err = DBI.Connect(
+		params.driver, params.database,
+		params.username, params.password,
+		params.host, params.port
+	);
+	if not dbh then return nil, err; end
+	dbh:autocommit(false); -- don't commit automatically
+	self.conn = dbh;
+	self.prepared = {};
+	return true;
+end
+function engine:execute(sql, ...)
+	local success, err = self:connect();
+	if not success then return success, err; end
+	local prepared = self.prepared;
+
+	local stmt = prepared[sql];
+	if not stmt then
+		local err;
+		stmt, err = self.conn:prepare(sql);
+		if not stmt then return stmt, err; end
+		prepared[sql] = stmt;
+	end
+
+	local success, err = stmt:execute(...);
+	if not success then return success, err; end
+	return stmt;
+end
+
+local result_mt = { __index = {
+	affected = function(self) return self.__affected; end;
+	rowcount = function(self) return self.__rowcount; end;
+} };
+
+function engine:execute_query(sql, ...)
+	if self.params.driver == "PostgreSQL" then
+		sql = sql:gsub("`", "\"");
+	end
+	local stmt = assert(self.conn:prepare(sql));
+	assert(stmt:execute(...));
+	return stmt:rows();
+end
+function engine:execute_update(sql, ...)
+	if self.params.driver == "PostgreSQL" then
+		sql = sql:gsub("`", "\"");
+	end
+	local prepared = self.prepared;
+	local stmt = prepared[sql];
+	if not stmt then
+		stmt = assert(self.conn:prepare(sql));
+		prepared[sql] = stmt;
+	end
+	assert(stmt:execute(...));
+	return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+end
+engine.insert = engine.execute_update;
+engine.select = engine.execute_query;
+engine.delete = engine.execute_update;
+engine.update = engine.execute_update;
+function engine:_transaction(func, ...)
+	if not self.conn then
+		local a,b = self:connect();
+		if not a then return a,b; end
+	end
+	--assert(not self.__transaction, "Recursive transactions not allowed");
+	local args, n_args = {...}, select("#", ...);
+	local function f() return func(unpack(args, 1, n_args)); end
+	self.__transaction = true;
+	local success, a, b, c = xpcall(f, debug_traceback);
+	self.__transaction = nil;
+	if success then
+		log("debug", "SQL transaction success [%s]", tostring(func));
+		local ok, err = self.conn:commit();
+		if not ok then return ok, err; end -- commit failed
+		return success, a, b, c;
+	else
+		log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
+		if self.conn then self.conn:rollback(); end
+		return success, a;
+	end
+end
+function engine:transaction(...)
+	local a,b = self:_transaction(...);
+	if not a then
+		local conn = self.conn;
+		if not conn or not conn:ping() then
+			self.conn = nil;
+			a,b = self:_transaction(...);
+		end
+	end
+	return a,b;
+end
+function engine:_create_index(index)
+	local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
+	for i=1,#index do
+		sql = sql.."`"..index[i].."`";
+		if i ~= #index then sql = sql..", "; end
+	end
+	sql = sql..");"
+	if self.params.driver == "PostgreSQL" then
+		sql = sql:gsub("`", "\"");
+	elseif self.params.driver == "MySQL" then
+		sql = sql:gsub("`([,)])", "`(20)%1");
+	end
+	--print(sql);
+	return self:execute(sql);
+end
+function engine:_create_table(table)
+	local sql = "CREATE TABLE `"..table.name.."` (";
+	for i,col in ipairs(table.c) do
+		sql = sql.."`"..col.name.."` "..col.type;
+		if col.nullable == false then sql = sql.." NOT NULL"; end
+		if i ~= #table.c then sql = sql..", "; end
+	end
+	sql = sql.. ");"
+	if self.params.driver == "PostgreSQL" then
+		sql = sql:gsub("`", "\"");
+	end
+	local success,err = self:execute(sql);
+	if not success then return success,err; end
+	for i,v in ipairs(table.__table__) do
+		if is_index(v) then
+			self:_create_index(v);
+		end
+	end
+	return success;
+end
+local engine_mt = { __index = engine };
+
+local function db2uri(params)
+	return build_url{
+		scheme = params.driver,
+		user = params.username,
+		password = params.password,
+		host = params.host,
+		port = params.port,
+		path = params.database,
+	};
+end
+local engine_cache = {}; -- TODO make weak valued
+function create_engine(self, params)
+	local url = db2uri(params);
+	if not engine_cache[url] then
+		local engine = setmetatable({ url = url, params = params }, engine_mt);
+		engine_cache[url] = engine;
+	end
+	return engine_cache[url];
+end
+
+
+--[[Users = Table {
+	name="users";
+	Column { name="user_id", type=String(), primary_key=true };
+};
+print(Users)
+print(Users.c.user_id)]]
+
+--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
+--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
+
+local i = 0;
+for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
+	i = i+1;
+	print(i);
+	for k,v in pairs(row) do
+		print("",k,v);
+	end
+end
+print("---")
+
+Prosody = Table {
+	name="prosody";
+	Column { name="host", type="TEXT", nullable=false };
+	Column { name="user", type="TEXT", nullable=false };
+	Column { name="store", type="TEXT", nullable=false };
+	Column { name="key", type="TEXT", nullable=false };
+	Column { name="type", type="TEXT", nullable=false };
+	Column { name="value", type="TEXT", nullable=false };
+	Index { name="prosody_index", "host", "user", "store", "key" };
+};
+--print(Prosody);
+assert(engine:transaction(function()
+	assert(Prosody:create(engine));
+end));
+
+for row in assert(engine:execute("select user from prosody")):rows(true) do
+	print("username:", row['username'])
+end
+--result.close();]]
+
+return _M;