diff --git a/src/core/verify-credentials.test.ts b/src/core/verify-credentials.test.ts index 5f950ab..6790368 100644 --- a/src/core/verify-credentials.test.ts +++ b/src/core/verify-credentials.test.ts @@ -1,4 +1,4 @@ -import { exportJWK, generateKeyPair, SignJWT } from 'jose' +import { exportJWK, generateKeyPair, generateSecret, SignJWT } from 'jose' import { afterEach, beforeAll, @@ -277,39 +277,63 @@ describe('verifyCredentials', () => { describe('user mode', () => { let jwks: JsonWebKeySet - let validToken: string + let validTokens: string[] beforeAll(async () => { + // Asymmetric JWK const { privateKey, publicKey } = await generateKeyPair('RS256') const publicJwk = await exportJWK(publicKey) publicJwk.alg = 'RS256' publicJwk.use = 'sig' - jwks = { keys: [publicJwk] } + publicJwk.kid = 'asymmetric-key-id' + + // Symmetric Shared Secret JWK + const jwtSecret = await generateSecret('HS256', { + extractable: true, + }) + const symmetricJwk = await exportJWK(jwtSecret) + symmetricJwk.alg = 'HS256' + symmetricJwk.kid = 'symmetric-shared-secret-key-id' + + jwks = { keys: [publicJwk, symmetricJwk] } + + const getValidToken = async ( + key: CryptoKey | Uint8Array, + alg: string, + kid: string, + ) => + await new SignJWT({ + sub: 'user-123', + role: 'authenticated', + email: 'test@example.com', + }) + .setProtectedHeader({ alg, kid }) + .setIssuedAt() + .setExpirationTime('1h') + .sign(key) - validToken = await new SignJWT({ - sub: 'user-123', - role: 'authenticated', - email: 'test@example.com', - }) - .setProtectedHeader({ alg: 'RS256' }) - .setIssuedAt() - .setExpirationTime('1h') - .sign(privateKey) + validTokens = [ + await getValidToken(privateKey, publicJwk.alg, publicJwk.kid), + await getValidToken(jwtSecret, symmetricJwk.alg, symmetricJwk.kid), + ] }) it('succeeds with valid JWT', async () => { - const creds: Credentials = { token: validToken, apikey: null } - const result = await verifyCredentials(creds, { - auth: 'user', - env: makeEnv({ jwks }), - }) - expect(result.error).toBeNull() - expect(result.data!.authMode).toBe('user') - expect(result.data!.keyName).toBeNull() - expect(result.data!.userClaims!.id).toBe('user-123') - expect(result.data!.userClaims!.email).toBe('test@example.com') - expect(result.data!.jwtClaims!.sub).toBe('user-123') - expect(result.data!.token).toBe(validToken) + for (let index = 0; index < validTokens.length; index++) { + const token = validTokens[index] + const creds: Credentials = { token, apikey: null } + const result = await verifyCredentials(creds, { + auth: 'user', + env: makeEnv({ jwks }), + }) + expect(result.error).toBeNull() + expect(result.data!.authMode).toBe('user') + expect(result.data!.keyName).toBeNull() + expect(result.data!.userClaims!.id).toBe('user-123') + expect(result.data!.userClaims!.email).toBe('test@example.com') + expect(result.data!.jwtClaims!.sub).toBe('user-123') + expect(result.data!.token).toBe(token) + } }) it('fails with invalid JWT', async () => { @@ -356,28 +380,47 @@ describe('verifyCredentials', () => { }) describe('user mode with remote JWKS URL', () => { - let privateKey: CryptoKey let jwks: JsonWebKeySet - let validToken: string + let validTokens: string[] let fetchMock: ReturnType beforeAll(async () => { - const keyPair = await generateKeyPair('RS256') - privateKey = keyPair.privateKey - const publicJwk = await exportJWK(keyPair.publicKey) + // Asymmetric JWK + const { privateKey, publicKey } = await generateKeyPair('RS256') + const publicJwk = await exportJWK(publicKey) publicJwk.alg = 'RS256' publicJwk.use = 'sig' publicJwk.kid = 'remote-key-1' - jwks = { keys: [publicJwk] } - validToken = await new SignJWT({ - sub: 'user-remote', - role: 'authenticated', - }) - .setProtectedHeader({ alg: 'RS256', kid: 'remote-key-1' }) - .setIssuedAt() - .setExpirationTime('1h') - .sign(privateKey) + // Symmetric Shared Secret JWK + const jwtSecret = await generateSecret('HS256', { + extractable: true, + }) + const symmetricJwk = await exportJWK(jwtSecret) + symmetricJwk.alg = 'HS256' + symmetricJwk.kid = 'remote-key-2' + + jwks = { keys: [publicJwk, symmetricJwk] } + + const getValidToken = async ( + key: CryptoKey | Uint8Array, + alg: string, + kid: string, + ) => + await new SignJWT({ + sub: 'user-123', + role: 'authenticated', + email: 'test@example.com', + }) + .setProtectedHeader({ alg, kid }) + .setIssuedAt() + .setExpirationTime('1h') + .sign(key) + + validTokens = [ + await getValidToken(privateKey, publicJwk.alg, publicJwk.kid), + await getValidToken(jwtSecret, symmetricJwk.alg, symmetricJwk.kid), + ] }) beforeEach(() => { @@ -396,40 +439,47 @@ describe('verifyCredentials', () => { }) it('fetches keys from the URL and verifies a valid JWT', async () => { - const creds: Credentials = { token: validToken, apikey: null } - const result = await verifyCredentials(creds, { - auth: 'user', - env: makeEnv({ - jwks: new URL( - 'https://jwks-fetch-success.example/auth/v1/.well-known/jwks.json', - ), - }), - }) - expect(result.error).toBeNull() - expect(result.data!.userClaims!.id).toBe('user-remote') - expect(fetchMock).toHaveBeenCalledTimes(1) + for (let index = 0; index < validTokens.length; index++) { + const token = validTokens[index] + const creds: Credentials = { token, apikey: null } + const result = await verifyCredentials(creds, { + auth: 'user', + env: makeEnv({ + jwks: new URL( + 'https://jwks-fetch-success.example/auth/v1/.well-known/jwks.json', + ), + }), + }) + expect(result.error).toBeNull() + expect(result.data!.userClaims!.id).toBe('user-remote') + expect(fetchMock).toHaveBeenCalledTimes(1) + } }) it('reuses the cached resolver for the same URL across requests', async () => { // Distinct URL so jose's per-resolver cooldown is fresh for this test const jwksUrl = new URL('https://jwks-cache.example/jwks.json') - const creds: Credentials = { token: validToken, apikey: null } + for (let index = 0; index < validTokens.length; index++) { + const token = validTokens[index] - const first = await verifyCredentials(creds, { - auth: 'user', - env: makeEnv({ jwks: jwksUrl }), - }) - const second = await verifyCredentials(creds, { - auth: 'user', - env: makeEnv({ jwks: jwksUrl }), - }) + const creds: Credentials = { token, apikey: null } - expect(first.error).toBeNull() - expect(second.error).toBeNull() - // jose's cooldownDuration (default 30s) keeps the second call from re-fetching. - // What we're guarding against is *re-creating* the resolver on every request, - // which would re-fetch every time. - expect(fetchMock).toHaveBeenCalledTimes(1) + const first = await verifyCredentials(creds, { + auth: 'user', + env: makeEnv({ jwks: jwksUrl }), + }) + const second = await verifyCredentials(creds, { + auth: 'user', + env: makeEnv({ jwks: jwksUrl }), + }) + + expect(first.error).toBeNull() + expect(second.error).toBeNull() + // jose's cooldownDuration (default 30s) keeps the second call from re-fetching. + // What we're guarding against is *re-creating* the resolver on every request, + // which would re-fetch every time. + expect(fetchMock).toHaveBeenCalledTimes(1) + } }) it('rejects an invalid JWT verified against the remote JWKS', async () => { @@ -446,7 +496,7 @@ describe('verifyCredentials', () => { it('rejects when the remote JWKS endpoint fails', async () => { fetchMock.mockResolvedValueOnce(new Response('boom', { status: 500 })) - const creds: Credentials = { token: validToken, apikey: null } + const creds: Credentials = { token: validTokens.at(0)!, apikey: null } const result = await verifyCredentials(creds, { auth: 'user', env: makeEnv({ @@ -488,8 +538,9 @@ describe('verifyCredentials', () => { }) }) + const tokenA = validTokens.at(0)! // matches 'RS256' token const a = await verifyCredentials( - { token: validToken, apikey: null }, + { token: tokenA, apikey: null }, { auth: 'user', env: makeEnv({ jwks: urlA }) }, ) const b = await verifyCredentials( diff --git a/src/core/verify-credentials.ts b/src/core/verify-credentials.ts index 13344cb..4c5b5df 100644 --- a/src/core/verify-credentials.ts +++ b/src/core/verify-credentials.ts @@ -1,6 +1,10 @@ import { createLocalJWKSet, createRemoteJWKSet, + decodeProtectedHeader, + importJWK, + JSONWebKeySet, + JWTPayload, jwtVerify, type JWTVerifyGetKey, } from 'jose' @@ -91,7 +95,10 @@ function jwtClaimsToUserClaims(jwtClaims: JWTClaims): UserClaims { const INVALID = Symbol('invalid') -let remoteJwksResolver: { url: string; resolver: JWTVerifyGetKey } | undefined = +export type JwksResolver = JWTVerifyGetKey & { + jwks: () => JSONWebKeySet | undefined +} +let remoteJwksResolver: { url: string; resolver: JwksResolver } | undefined = undefined /** @@ -104,7 +111,7 @@ let remoteJwksResolver: { url: string; resolver: JWTVerifyGetKey } | undefined = * * @internal */ -function getJwksResolver(jwks: JsonWebKeySet | URL): JWTVerifyGetKey { +function getJwksResolver(jwks: JsonWebKeySet | URL): JwksResolver { if (jwks instanceof URL) { const url = jwks.toString() if (remoteJwksResolver?.url !== url) { @@ -112,7 +119,17 @@ function getJwksResolver(jwks: JsonWebKeySet | URL): JWTVerifyGetKey { } return remoteJwksResolver.resolver } - return createLocalJWKSet(jwks) + + const localJwkSet = createLocalJWKSet(jwks) + function localJwtVerifyGetKey(...args: Parameters) { + return localJwkSet(...args) + } + + const localJwksResolver: JwksResolver = Object.assign(localJwtVerifyGetKey, { + jwks: () => jwks, + }) + + return localJwksResolver } /** @@ -215,8 +232,31 @@ async function tryMode( if (credentials.token.startsWith('sb_')) return null if (!env.jwks) return null try { - const jwkSet = getJwksResolver(env.jwks) - const { payload } = await jwtVerify(credentials.token, jwkSet) + const jwkResolver = getJwksResolver(env.jwks) + const { alg, kid } = decodeProtectedHeader(credentials.token) + if (!alg || !kid) { + return INVALID + } + + let payload: JWTPayload | null = null + + // Symmetric algorithm requires importing the shared secret + if (alg === 'HS256') { + const jwk = jwkResolver + .jwks() + ?.keys.find((key) => key.alg === alg && key.kid === kid) + if (!jwk) { + return INVALID + } + const sharedSecret = await importJWK(jwk, 'HS256') + + const verify = await jwtVerify(credentials.token, sharedSecret) + payload = verify.payload + } else { + const verify = await jwtVerify(credentials.token, jwkResolver) + payload = verify.payload + } + if (typeof payload.sub !== 'string') { return INVALID }