Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 118 additions & 67 deletions src/core/verify-credentials.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { exportJWK, generateKeyPair, SignJWT } from 'jose'
import { exportJWK, generateKeyPair, generateSecret, SignJWT } from 'jose'
import {
afterEach,
beforeAll,
Expand Down Expand Up @@ -277,39 +277,63 @@ describe('verifyCredentials', () => {

describe('user mode', () => {
let jwks: JsonWebKeySet
let validToken: string
let validTokens: string[]

beforeAll(async () => {
// Asymmetric JWK
const { privateKey, publicKey } = await generateKeyPair('RS256')
const publicJwk = await exportJWK(publicKey)
publicJwk.alg = 'RS256'
publicJwk.use = 'sig'
jwks = { keys: [publicJwk] }
publicJwk.kid = 'asymmetric-key-id'

// Symmetric Shared Secret JWK
const jwtSecret = await generateSecret('HS256', {
extractable: true,
})
const symmetricJwk = await exportJWK(jwtSecret)
symmetricJwk.alg = 'HS256'
symmetricJwk.kid = 'symmetric-shared-secret-key-id'

jwks = { keys: [publicJwk, symmetricJwk] }

const getValidToken = async (
key: CryptoKey | Uint8Array<ArrayBufferLike>,
alg: string,
kid: string,
) =>
await new SignJWT({
sub: 'user-123',
role: 'authenticated',
email: 'test@example.com',
})
.setProtectedHeader({ alg, kid })
.setIssuedAt()
.setExpirationTime('1h')
.sign(key)

validToken = await new SignJWT({
sub: 'user-123',
role: 'authenticated',
email: 'test@example.com',
})
.setProtectedHeader({ alg: 'RS256' })
.setIssuedAt()
.setExpirationTime('1h')
.sign(privateKey)
validTokens = [
await getValidToken(privateKey, publicJwk.alg, publicJwk.kid),
await getValidToken(jwtSecret, symmetricJwk.alg, symmetricJwk.kid),
]
})

it('succeeds with valid JWT', async () => {
const creds: Credentials = { token: validToken, apikey: null }
const result = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({ jwks }),
})
expect(result.error).toBeNull()
expect(result.data!.authMode).toBe('user')
expect(result.data!.keyName).toBeNull()
expect(result.data!.userClaims!.id).toBe('user-123')
expect(result.data!.userClaims!.email).toBe('test@example.com')
expect(result.data!.jwtClaims!.sub).toBe('user-123')
expect(result.data!.token).toBe(validToken)
for (let index = 0; index < validTokens.length; index++) {
const token = validTokens[index]
const creds: Credentials = { token, apikey: null }
const result = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({ jwks }),
})
expect(result.error).toBeNull()
expect(result.data!.authMode).toBe('user')
expect(result.data!.keyName).toBeNull()
expect(result.data!.userClaims!.id).toBe('user-123')
expect(result.data!.userClaims!.email).toBe('test@example.com')
expect(result.data!.jwtClaims!.sub).toBe('user-123')
expect(result.data!.token).toBe(token)
}
})

it('fails with invalid JWT', async () => {
Expand Down Expand Up @@ -356,28 +380,47 @@ describe('verifyCredentials', () => {
})

describe('user mode with remote JWKS URL', () => {
let privateKey: CryptoKey
let jwks: JsonWebKeySet
let validToken: string
let validTokens: string[]
let fetchMock: ReturnType<typeof vi.fn>

beforeAll(async () => {
const keyPair = await generateKeyPair('RS256')
privateKey = keyPair.privateKey
const publicJwk = await exportJWK(keyPair.publicKey)
// Asymmetric JWK
const { privateKey, publicKey } = await generateKeyPair('RS256')
const publicJwk = await exportJWK(publicKey)
publicJwk.alg = 'RS256'
publicJwk.use = 'sig'
publicJwk.kid = 'remote-key-1'
jwks = { keys: [publicJwk] }

validToken = await new SignJWT({
sub: 'user-remote',
role: 'authenticated',
})
.setProtectedHeader({ alg: 'RS256', kid: 'remote-key-1' })
.setIssuedAt()
.setExpirationTime('1h')
.sign(privateKey)
// Symmetric Shared Secret JWK
const jwtSecret = await generateSecret('HS256', {
extractable: true,
})
const symmetricJwk = await exportJWK(jwtSecret)
symmetricJwk.alg = 'HS256'
symmetricJwk.kid = 'remote-key-2'

jwks = { keys: [publicJwk, symmetricJwk] }

const getValidToken = async (
key: CryptoKey | Uint8Array<ArrayBufferLike>,
alg: string,
kid: string,
) =>
await new SignJWT({
sub: 'user-123',
role: 'authenticated',
email: 'test@example.com',
})
.setProtectedHeader({ alg, kid })
.setIssuedAt()
.setExpirationTime('1h')
.sign(key)

validTokens = [
await getValidToken(privateKey, publicJwk.alg, publicJwk.kid),
await getValidToken(jwtSecret, symmetricJwk.alg, symmetricJwk.kid),
]
})

beforeEach(() => {
Expand All @@ -396,40 +439,47 @@ describe('verifyCredentials', () => {
})

it('fetches keys from the URL and verifies a valid JWT', async () => {
const creds: Credentials = { token: validToken, apikey: null }
const result = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({
jwks: new URL(
'https://jwks-fetch-success.example/auth/v1/.well-known/jwks.json',
),
}),
})
expect(result.error).toBeNull()
expect(result.data!.userClaims!.id).toBe('user-remote')
expect(fetchMock).toHaveBeenCalledTimes(1)
for (let index = 0; index < validTokens.length; index++) {
const token = validTokens[index]
const creds: Credentials = { token, apikey: null }
const result = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({
jwks: new URL(
'https://jwks-fetch-success.example/auth/v1/.well-known/jwks.json',
),
}),
})
expect(result.error).toBeNull()
expect(result.data!.userClaims!.id).toBe('user-remote')
expect(fetchMock).toHaveBeenCalledTimes(1)
}
})

it('reuses the cached resolver for the same URL across requests', async () => {
// Distinct URL so jose's per-resolver cooldown is fresh for this test
const jwksUrl = new URL('https://jwks-cache.example/jwks.json')
const creds: Credentials = { token: validToken, apikey: null }
for (let index = 0; index < validTokens.length; index++) {
const token = validTokens[index]

const first = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({ jwks: jwksUrl }),
})
const second = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({ jwks: jwksUrl }),
})
const creds: Credentials = { token, apikey: null }

expect(first.error).toBeNull()
expect(second.error).toBeNull()
// jose's cooldownDuration (default 30s) keeps the second call from re-fetching.
// What we're guarding against is *re-creating* the resolver on every request,
// which would re-fetch every time.
expect(fetchMock).toHaveBeenCalledTimes(1)
const first = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({ jwks: jwksUrl }),
})
const second = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({ jwks: jwksUrl }),
})

expect(first.error).toBeNull()
expect(second.error).toBeNull()
// jose's cooldownDuration (default 30s) keeps the second call from re-fetching.
// What we're guarding against is *re-creating* the resolver on every request,
// which would re-fetch every time.
expect(fetchMock).toHaveBeenCalledTimes(1)
}
})

it('rejects an invalid JWT verified against the remote JWKS', async () => {
Expand All @@ -446,7 +496,7 @@ describe('verifyCredentials', () => {

it('rejects when the remote JWKS endpoint fails', async () => {
fetchMock.mockResolvedValueOnce(new Response('boom', { status: 500 }))
const creds: Credentials = { token: validToken, apikey: null }
const creds: Credentials = { token: validTokens.at(0)!, apikey: null }
const result = await verifyCredentials(creds, {
auth: 'user',
env: makeEnv({
Expand Down Expand Up @@ -488,8 +538,9 @@ describe('verifyCredentials', () => {
})
})

const tokenA = validTokens.at(0)! // matches 'RS256' token
const a = await verifyCredentials(
{ token: validToken, apikey: null },
{ token: tokenA, apikey: null },
{ auth: 'user', env: makeEnv({ jwks: urlA }) },
)
const b = await verifyCredentials(
Expand Down
50 changes: 45 additions & 5 deletions src/core/verify-credentials.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import {
createLocalJWKSet,
createRemoteJWKSet,
decodeProtectedHeader,
importJWK,
JSONWebKeySet,
JWTPayload,
jwtVerify,
type JWTVerifyGetKey,
} from 'jose'
Expand Down Expand Up @@ -91,7 +95,10 @@ function jwtClaimsToUserClaims(jwtClaims: JWTClaims): UserClaims {

const INVALID = Symbol('invalid')

let remoteJwksResolver: { url: string; resolver: JWTVerifyGetKey } | undefined =
export type JwksResolver = JWTVerifyGetKey & {
jwks: () => JSONWebKeySet | undefined
}
let remoteJwksResolver: { url: string; resolver: JwksResolver } | undefined =
undefined

/**
Expand All @@ -104,15 +111,25 @@ let remoteJwksResolver: { url: string; resolver: JWTVerifyGetKey } | undefined =
*
* @internal
*/
function getJwksResolver(jwks: JsonWebKeySet | URL): JWTVerifyGetKey {
function getJwksResolver(jwks: JsonWebKeySet | URL): JwksResolver {
if (jwks instanceof URL) {
const url = jwks.toString()
if (remoteJwksResolver?.url !== url) {
remoteJwksResolver = { url, resolver: createRemoteJWKSet(jwks) }
}
return remoteJwksResolver.resolver
}
return createLocalJWKSet(jwks)

const localJwkSet = createLocalJWKSet(jwks)
function localJwtVerifyGetKey(...args: Parameters<typeof localJwkSet>) {
return localJwkSet(...args)
}

const localJwksResolver: JwksResolver = Object.assign(localJwtVerifyGetKey, {
jwks: () => jwks,
})

return localJwksResolver
}

/**
Expand Down Expand Up @@ -215,8 +232,31 @@ async function tryMode(
if (credentials.token.startsWith('sb_')) return null
if (!env.jwks) return null
try {
const jwkSet = getJwksResolver(env.jwks)
const { payload } = await jwtVerify(credentials.token, jwkSet)
const jwkResolver = getJwksResolver(env.jwks)
const { alg, kid } = decodeProtectedHeader(credentials.token)
if (!alg || !kid) {
return INVALID
}

let payload: JWTPayload | null = null

// Symmetric algorithm requires importing the shared secret
if (alg === 'HS256') {
const jwk = jwkResolver
.jwks()
?.keys.find((key) => key.alg === alg && key.kid === kid)
if (!jwk) {
return INVALID
}
const sharedSecret = await importJWK(jwk, 'HS256')

const verify = await jwtVerify(credentials.token, sharedSecret)
payload = verify.payload
} else {
const verify = await jwtVerify(credentials.token, jwkResolver)
payload = verify.payload
}

if (typeof payload.sub !== 'string') {
return INVALID
}
Expand Down
Loading