util/sqlite3.lua
changeset 13150 771eb453e03a
parent 13149 af251471d5ae
child 13151 e560f7c691ce
equal deleted inserted replaced
13149:af251471d5ae 13150:771eb453e03a
     7 local error = error
     7 local error = error
     8 local type = type
     8 local type = type
     9 local t_concat = table.concat;
     9 local t_concat = table.concat;
    10 local t_insert = table.insert;
    10 local t_insert = table.insert;
    11 local s_char = string.char;
    11 local s_char = string.char;
       
    12 local array = require "prosody.util.array";
    12 local log = require "prosody.util.logger".init("sql");
    13 local log = require "prosody.util.logger".init("sql");
    13 
    14 
    14 local lsqlite3 = require "lsqlite3";
    15 local lsqlite3 = require "lsqlite3";
    15 local build_url = require "socket.url".build;
    16 local build_url = require "socket.url".build;
    16 local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE;
    17 local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE;
   192 	local ret = stmt:bind_values(...);
   193 	local ret = stmt:bind_values(...);
   193 	if ret ~= lsqlite3.OK then return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); end
   194 	if ret ~= lsqlite3.OK then return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); end
   194 	return stmt;
   195 	return stmt;
   195 end
   196 end
   196 
   197 
   197 local result_mt = {
       
   198 	__index = {
       
   199 	affected = function(self) return self.__affected; end;
       
   200 	rowcount = function(self) return self.__rowcount; end;
       
   201 	},
       
   202 };
       
   203 
       
   204 local function iterator(table)
   198 local function iterator(table)
   205 	local i = 0;
   199 	local i = 0;
   206 	return function()
   200 	return function()
   207 		i = i + 1;
   201 		i = i + 1;
   208 		local item = table[i];
   202 		local item = table[i];
   209 		if item ~= nil then
   203 		if item ~= nil then
   210 			return item;
   204 			return item;
   211 		end
   205 		end
   212 	end
   206 	end
   213 end
   207 end
       
   208 
       
   209 local result_mt = {
       
   210 	__len = function(self)
       
   211 		return self.__rowcount;
       
   212 	end;
       
   213 	__index = {
       
   214 		affected = function(self)
       
   215 			return self.__affected;
       
   216 		end;
       
   217 		rowcount = function(self)
       
   218 			return self.__rowcount;
       
   219 		end;
       
   220 	};
       
   221 	__call = function(self)
       
   222 		return iterator(self.__data);
       
   223 	end;
       
   224 };
   214 
   225 
   215 local function debugquery(where, sql, ...)
   226 local function debugquery(where, sql, ...)
   216 	local i = 0; local a = {...}
   227 	local i = 0; local a = {...}
   217 	sql = sql:gsub("\n?\t+", " ");
   228 	sql = sql:gsub("\n?\t+", " ");
   218 	log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
   229 	log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
   223 		end
   234 		end
   224 		return tostring(v);
   235 		return tostring(v);
   225 	end)));
   236 	end)));
   226 end
   237 end
   227 
   238 
   228 function engine:execute_query(sql, ...)
   239 function engine:execute_update(sql, ...)
   229 	local prepared = self.prepared;
   240 	local prepared = self.prepared;
   230 	local stmt = prepared[sql];
   241 	local stmt = prepared[sql];
   231 	if stmt and stmt:isopen() then
   242 	if stmt and stmt:isopen() then
   232 		prepared[sql] = nil; -- Can't be used concurrently
   243 		prepared[sql] = nil; -- Can't be used concurrently
   233 	else
   244 	else
   234 		stmt = assert(self.conn:prepare(sql));
   245 		stmt = assert(self.conn:prepare(sql));
   235 	end
   246 	end
   236 	local ret = stmt:bind_values(...);
   247 	local ret = stmt:bind_values(...);
   237 	if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
   248 	if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
   238 	local data, ret = {}
   249 	local data = array();
   239 	while stmt:step() == ROW do
   250 	for row in stmt:rows() do
   240 		t_insert(data, stmt:get_values());
   251 		data:push(array(row));
   241 	end
   252 	end
   242 	-- FIXME Error handling, BUSY, ERROR, MISUSE
   253 	-- FIXME Error handling, BUSY, ERROR, MISUSE
   243 	if stmt:reset() == lsqlite3.OK then
   254 	if stmt:reset() == lsqlite3.OK then
   244 		prepared[sql] = stmt;
   255 		prepared[sql] = stmt;
   245 	end
   256 	end
   246 	return setmetatable({ __data = data }, { __index = result_mt.__index, __call = iterator(data) });
       
   247 end
       
   248 function engine:execute_update(sql, ...)
       
   249 	local prepared = self.prepared;
       
   250 	local stmt = prepared[sql];
       
   251 	if not stmt or not stmt:isopen() then
       
   252 		stmt = assert(self.conn:prepare(sql));
       
   253 	else
       
   254 		prepared[sql] = nil;
       
   255 	end
       
   256 	local ret = stmt:bind_values(...);
       
   257 	if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
       
   258 	local rowcount = 0;
       
   259 	repeat
       
   260 		ret = stmt:step();
       
   261 		if ret == lsqlite3.ROW then
       
   262 			rowcount = rowcount + 1;
       
   263 		end
       
   264 	until ret ~= lsqlite3.ROW;
       
   265 	local affected = self.conn:changes();
   257 	local affected = self.conn:changes();
   266 	if stmt:reset() == lsqlite3.OK then
   258 	return setmetatable({ __affected = affected; __rowcount = #data; __data = data }, result_mt);
   267 		prepared[sql] = stmt;
   259 end
   268 	end
   260 
   269 	return setmetatable({ __affected = affected, __rowcount = rowcount }, result_mt);
   261 function engine:execute_query(sql, ...)
       
   262 	return self:execute_update(sql, ...)()
   270 end
   263 end
   271 
   264 
   272 engine.insert = engine.execute_update;
   265 engine.insert = engine.execute_update;
   273 engine.select = engine.execute_query;
   266 engine.select = engine.execute_query;
   274 engine.delete = engine.execute_update;
   267 engine.delete = engine.execute_update;