util.sasl.scram: Factor out SHA-1 specific getAuthenticationDatabaseSHA1
authorKim Alvefur <zash@zash.se>
Sun, 13 Jan 2019 14:01:31 +0100
changeset 10220 a51d017e6173
parent 10219 82abf88db13f
child 10221 60b445183d84
util.sasl.scram: Factor out SHA-1 specific getAuthenticationDatabaseSHA1 This makes the code more generic, allowing SHA-1 to be replaced
util/sasl/scram.lua
--- a/util/sasl/scram.lua	Thu Aug 22 22:23:04 2019 +0200
+++ b/util/sasl/scram.lua	Sun Jan 13 14:01:31 2019 +0100
@@ -14,9 +14,7 @@
 local s_match = string.match;
 local type = type
 local base64 = require "util.encodings".base64;
-local hmac_sha1 = require "util.hashes".hmac_sha1;
-local sha1 = require "util.hashes".sha1;
-local Hi = require "util.hashes".scram_Hi_sha1;
+local hashes = require "util.hashes";
 local generate_uuid = require "util.uuid".generate;
 local saslprep = require "util.encodings".stringprep.saslprep;
 local nodeprep = require "util.encodings".stringprep.nodeprep;
@@ -99,20 +97,22 @@
 	return hashname:lower():gsub("-", "_");
 end
 
-local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
-	if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
-		return false, "inappropriate argument types"
+local function get_scram_hasher(H, HMAC, Hi)
+	return function (password, salt, iteration_count)
+		if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
+			return false, "inappropriate argument types"
+		end
+		if iteration_count < 4096 then
+			log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+		end
+		local salted_password = Hi(password, salt, iteration_count);
+		local stored_key = H(HMAC(salted_password, "Client Key"))
+		local server_key = HMAC(salted_password, "Server Key");
+		return true, stored_key, server_key
 	end
-	if iteration_count < 4096 then
-		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
-	end
-	local salted_password = Hi(password, salt, iteration_count);
-	local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
-	local server_key = hmac_sha1(salted_password, "Server Key");
-	return true, stored_key, server_key
 end
 
-local function scram_gen(hash_name, H_f, HMAC_f)
+local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db)
 	local profile_name = "scram_" .. hashprep(hash_name);
 	local function scram_hash(self, message)
 		local support_channel_binding = false;
@@ -177,7 +177,7 @@
 				iteration_count = default_i;
 
 				local succ;
-				succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
+				succ, stored_key, server_key = get_auth_db(password, salt, iteration_count);
 				if not succ then
 					log("error", "Generating authentication database failed. Reason: %s", stored_key);
 					return "failure", "temporary-auth-failure";
@@ -247,22 +247,27 @@
 	return scram_hash;
 end
 
+local auth_db_getters = {}
 local function init(registerMechanism)
-	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
+	local function registerSCRAMMechanism(hash_name, hash, hmac_hash, pbkdf2)
+		local get_auth_db = get_scram_hasher(hash, hmac_hash, pbkdf2);
+		auth_db_getters[hash_name] = get_auth_db;
 		registerMechanism("SCRAM-"..hash_name,
 			{"plain", "scram_"..(hashprep(hash_name))},
-			scram_gen(hash_name:lower(), hash, hmac_hash));
+			scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db));
 
 		-- register channel binding equivalent
 		registerMechanism("SCRAM-"..hash_name.."-PLUS",
 			{"plain", "scram_"..(hashprep(hash_name))},
-			scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+			scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db), {"tls-unique"});
 	end
 
-	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
+	registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1);
 end
 
 return {
-	getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
+	get_hash = get_scram_hasher;
+	hashers = auth_db_getters;
+	getAuthenticationDatabaseSHA1 = get_scram_hasher(hashes.sha1, hashes.sha256, hashes.pbkdf2_hmac_sha1);
 	init = init;
 }