util/jwt.lua
changeset 12700 27a72982e331
parent 11565 d2f33b8fdc96
child 12703 b3d0c1457584
equal deleted inserted replaced
12699:6aaa604fdfd5 12700:27a72982e331
     1 local s_gsub = string.gsub;
     1 local s_gsub = string.gsub;
       
     2 local crypto = require "util.crypto";
     2 local json = require "util.json";
     3 local json = require "util.json";
     3 local hashes = require "util.hashes";
     4 local hashes = require "util.hashes";
     4 local base64_encode = require "util.encodings".base64.encode;
     5 local base64_encode = require "util.encodings".base64.encode;
     5 local base64_decode = require "util.encodings".base64.decode;
     6 local base64_decode = require "util.encodings".base64.decode;
     6 local secure_equals = require "util.hashes".equals;
     7 local secure_equals = require "util.hashes".equals;
    11 end
    12 end
    12 local function unb64url(data)
    13 local function unb64url(data)
    13 	return base64_decode(s_gsub(data, "[-_]", b64url_rep).."==");
    14 	return base64_decode(s_gsub(data, "[-_]", b64url_rep).."==");
    14 end
    15 end
    15 
    16 
    16 local static_header = b64url('{"alg":"HS256","typ":"JWT"}') .. '.';
       
    17 
       
    18 local function sign(key, payload)
       
    19 	local encoded_payload = json.encode(payload);
       
    20 	local signed = static_header .. b64url(encoded_payload);
       
    21 	local signature = hashes.hmac_sha256(key, signed);
       
    22 	return signed .. "." .. b64url(signature);
       
    23 end
       
    24 
       
    25 local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$"
    17 local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$"
    26 local function verify(key, blob)
    18 local function decode_jwt(blob, expected_alg)
    27 	local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern);
    19 	local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern);
    28 	if not signed then
    20 	if not signed then
    29 		return nil, "invalid-encoding";
    21 		return nil, "invalid-encoding";
    30 	end
    22 	end
    31 	local header = json.decode(unb64url(bheader));
    23 	local header = json.decode(unb64url(bheader));
    32 	if not header or type(header) ~= "table" then
    24 	if not header or type(header) ~= "table" then
    33 		return nil, "invalid-header";
    25 		return nil, "invalid-header";
    34 	elseif header.alg ~= "HS256" then
    26 	elseif header.alg ~= expected_alg then
    35 		return nil, "unsupported-algorithm";
    27 		return nil, "unsupported-algorithm";
    36 	end
    28 	end
    37 	if not secure_equals(b64url(hashes.hmac_sha256(key, signed)), signature) then
    29 	return signed, signature, bpayload;
    38 		return false, "signature-mismatch";
    30 end
       
    31 
       
    32 local function new_static_header(algorithm_name)
       
    33 	return b64url('{"alg":"'..algorithm_name..'","typ":"JWT"}') .. '.';
       
    34 end
       
    35 
       
    36 -- HS*** family
       
    37 local function new_hmac_algorithm(name, hmac)
       
    38 	local static_header = new_static_header(name);
       
    39 
       
    40 	local function sign(key, payload)
       
    41 		local encoded_payload = json.encode(payload);
       
    42 		local signed = static_header .. b64url(encoded_payload);
       
    43 		local signature = hmac(key, signed);
       
    44 		return signed .. "." .. b64url(signature);
    39 	end
    45 	end
    40 	local payload, err = json.decode(unb64url(bpayload));
    46 
    41 	if err ~= nil then
    47 	local function verify(key, blob)
    42 		return nil, "json-decode-error";
    48 		local signed, signature, raw_payload = decode_jwt(blob, name);
       
    49 		if not signed then return nil, signature; end -- nil, err
       
    50 
       
    51 		if not secure_equals(b64url(hmac(key, signed)), signature) then
       
    52 			return false, "signature-mismatch";
       
    53 		end
       
    54 		local payload, err = json.decode(unb64url(raw_payload));
       
    55 		if err ~= nil then
       
    56 			return nil, "json-decode-error";
       
    57 		end
       
    58 		return true, payload;
    43 	end
    59 	end
    44 	return true, payload;
    60 
       
    61 	local function load_key(key)
       
    62 		assert(type(key) == "string", "key must be string (long, random, secure)");
       
    63 		return key;
       
    64 	end
       
    65 
       
    66 	return { sign = sign, verify = verify, load_key = load_key };
       
    67 end
       
    68 
       
    69 -- ES*** family
       
    70 local function new_ecdsa_algorithm(name, c_sign, c_verify)
       
    71 	local static_header = new_static_header(name);
       
    72 
       
    73 	return {
       
    74 		sign = function (private_key, payload)
       
    75 			local encoded_payload = json.encode(payload);
       
    76 			local signed = static_header .. b64url(encoded_payload);
       
    77 
       
    78 			local der_sig = c_sign(private_key, signed);
       
    79 
       
    80 			local r, s = crypto.parse_ecdsa_signature(der_sig);
       
    81 
       
    82 			return signed.."."..b64url(r..s);
       
    83 		end;
       
    84 
       
    85 	verify = function (public_key, blob)
       
    86 			local signed, signature, raw_payload = decode_jwt(blob, name);
       
    87 			if not signed then return nil, signature; end -- nil, err
       
    88 
       
    89 			local raw_signature = unb64url(signature);
       
    90 
       
    91 			local der_sig = crypto.build_ecdsa_signature(raw_signature:sub(1, 32), raw_signature:sub(33, 64));
       
    92 			if not der_sig then
       
    93 				return false, "signature-mismatch";
       
    94 			end
       
    95 
       
    96 			local verify_ok = c_verify(public_key, signed, der_sig);
       
    97 			if not verify_ok then
       
    98 				return false, "signature-mismatch";
       
    99 			end
       
   100 
       
   101 			local payload, err = json.decode(unb64url(raw_payload));
       
   102 			if err ~= nil then
       
   103 				return nil, "json-decode-error";
       
   104 			end
       
   105 
       
   106 			return true, payload;
       
   107 		end;
       
   108 
       
   109 		load_public_key = function (public_key_pem)
       
   110 			local key = assert(crypto.import_public_pem(public_key_pem));
       
   111 			assert(key:get_type() == "id-ecPublicKey", "incorrect key type");
       
   112 			return key;
       
   113 		end;
       
   114 
       
   115 		load_private_key = function (private_key_pem)
       
   116 			local key = assert(crypto.import_private_pem(private_key_pem));
       
   117 			assert(key:get_type() == "id-ecPublicKey", "incorrect key type");
       
   118 			return key;
       
   119 		end;
       
   120 	};
       
   121 end
       
   122 
       
   123 local algorithms = {
       
   124 	HS256 = new_hmac_algorithm("HS256", hashes.hmac_sha256);
       
   125 	ES256 = new_ecdsa_algorithm("ES256", crypto.ecdsa_sha256_sign, crypto.ecdsa_sha256_verify);
       
   126 };
       
   127 
       
   128 local function new_signer(algorithm, key_input)
       
   129 	local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm);
       
   130 	local key = (impl.load_private_key or impl.load_key)(key_input);
       
   131 	local sign = impl.sign;
       
   132 	return function (payload)
       
   133 		return sign(key, payload);
       
   134 	end
       
   135 end
       
   136 
       
   137 local function new_verifier(algorithm, key_input)
       
   138 	local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm);
       
   139 	local key = (impl.load_public_key or impl.load_key)(key_input);
       
   140 	local verify = impl.verify;
       
   141 	return function (token)
       
   142 		return verify(key, token);
       
   143 	end
    45 end
   144 end
    46 
   145 
    47 return {
   146 return {
    48 	sign = sign;
   147 	new_signer = new_signer;
    49 	verify = verify;
   148 	new_verifier = new_verifier;
       
   149 	-- Deprecated
       
   150 	sign = algorithms.HS256.sign;
       
   151 	verify = algorithms.HS256.verify;
    50 };
   152 };
    51 
   153