diff --git a/graphql/server/package.json b/graphql/server/package.json index 19633f5fcb..4060a4c268 100644 --- a/graphql/server/package.json +++ b/graphql/server/package.json @@ -45,6 +45,7 @@ "@constructive-io/csrf": "workspace:^", "@constructive-io/graphql-env": "workspace:^", "@constructive-io/graphql-types": "workspace:^", + "@constructive-io/oauth": "workspace:^", "@constructive-io/s3-utils": "workspace:^", "@constructive-io/upload-names": "workspace:^", "@constructive-io/url-domains": "workspace:^", @@ -57,6 +58,7 @@ "cors": "^2.8.6", "deepmerge": "^4.3.1", "express": "^5.2.1", + "express-rate-limit": "^8.5.1", "gql-ast": "workspace:^", "grafast": "1.0.0", "grafserv": "1.0.0", diff --git a/graphql/server/src/middleware/__tests__/oauth.test.ts b/graphql/server/src/middleware/__tests__/oauth.test.ts new file mode 100644 index 0000000000..8e4f4f1dc3 --- /dev/null +++ b/graphql/server/src/middleware/__tests__/oauth.test.ts @@ -0,0 +1,964 @@ +import { Request, Response } from 'express'; +import crypto from 'crypto'; + +// Mock dependencies before importing the module +jest.mock('@constructive-io/oauth', () => ({ + OAuthClient: jest.fn().mockImplementation(() => ({ + getAuthorizationUrl: jest.fn().mockReturnValue({ + url: 'https://accounts.google.com/o/oauth2/v2/auth?client_id=test', + state: 'mock-state', + }), + handleCallback: jest.fn(), + })), +})); + +jest.mock('pg-cache', () => ({ + getPgPool: jest.fn().mockReturnValue({ + query: jest.fn(), + connect: jest.fn(), + }), +})); + +jest.mock('@pgpmjs/logger', () => ({ + Logger: jest.fn().mockImplementation(() => ({ + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + })), +})); + +jest.mock('express-rate-limit', () => { + return jest.fn(() => (_req: any, _res: any, next: any) => next()); +}); + +// Import after mocks +import { createOAuthRoutes } from '../oauth'; +import { OAuthClient } from '@constructive-io/oauth'; +import { getPgPool } from 'pg-cache'; + +describe('OAuth Middleware', () => { + const mockOpts = { + oauth: { + baseUrl: 'https://app.example.com', + allowSignup: true, + requireVerifiedEmail: true, + }, + pg: { + database: 'test_db', + }, + }; + + const mockProviderRow = { + slug: 'google', + kind: 'oidc', + display_name: 'Google', + enabled: true, + client_id: 'test-client-id', + client_secret: 'test-client-secret', + authorization_url: null as string | null, + token_url: null as string | null, + userinfo_url: null as string | null, + scopes: ['openid', 'email', 'profile'], + pkce_enabled: true, + }; + + beforeEach(() => { + jest.clearAllMocks(); + process.env.OAUTH_SECRET = 'test-secret-key-for-testing'; + }); + + afterEach(() => { + delete process.env.OAUTH_SECRET; + }); + + const createValidState = () => { + const statePayload = { + redirect_uri: '/dashboard', + provider: 'google', + nonce: crypto.randomBytes(16).toString('hex'), + exp: Date.now() + 10 * 60 * 1000, + }; + const json = JSON.stringify(statePayload); + const sig = crypto.createHmac('sha256', 'test-secret-key-for-testing').update(json).digest('base64url'); + return Buffer.from(json).toString('base64url') + '.' + sig; + }; + + const setupMockQuery = (responses: any[]) => { + const mockQuery = jest.fn(); + responses.forEach((response, index) => { + if (response instanceof Error) { + mockQuery.mockRejectedValueOnce(response); + } else { + mockQuery.mockResolvedValueOnce(response); + } + }); + // Create a mock client for pool.connect() + const mockClient = { + query: mockQuery, + release: jest.fn(), + }; + (getPgPool as jest.Mock).mockReturnValue({ + query: mockQuery, + connect: jest.fn().mockResolvedValue(mockClient), + }); + return mockQuery; + }; + + const mockRequestHelpers = { + get: jest.fn().mockReturnValue('localhost:3000'), + protocol: 'http', + }; + + describe('createOAuthRoutes', () => { + it('always creates routes (providers come from database)', () => { + const router = createOAuthRoutes(mockOpts as any); + // Should have 4 routes: /providers, /error, /:provider, /:provider/callback + expect(router.stack.length).toBe(4); + }); + + it('applies rate limiting to OAuth initiation and callback routes', () => { + const router = createOAuthRoutes(mockOpts as any); + + const initiateRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider' && layer.route?.methods?.get + ); + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const providersRoute = router.stack.find( + (layer: any) => layer.route?.path === '/providers' + ); + + // Initiate and callback routes should have 2 handlers (rate limiter + handler) + expect(initiateRoute!.route.stack.length).toBe(2); + expect(callbackRoute!.route.stack.length).toBe(2); + + // Providers route should have 1 handler (no rate limiting) + expect(providersRoute!.route.stack.length).toBe(1); + }); + }); + + describe('Providers Endpoint', () => { + it('returns empty list when no API context', async () => { + const router = createOAuthRoutes(mockOpts as any); + + const req = {} as Request; + const res = { + json: jest.fn(), + } as unknown as Response; + + const providersRoute = router.stack.find( + (layer: any) => layer.route?.path === '/providers' + ); + const handler = providersRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.json).toHaveBeenCalledWith({ providers: [] }); + }); + + it('returns providers from database', async () => { + setupMockQuery([{ rows: [{ slug: 'google' }, { slug: 'github' }] }]); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + }, + } as unknown as Request; + const res = { + json: jest.fn(), + } as unknown as Response; + + const providersRoute = router.stack.find( + (layer: any) => layer.route?.path === '/providers' + ); + const handler = providersRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.json).toHaveBeenCalledWith({ providers: ['google', 'github'] }); + }); + }); + + describe('OAuth Initiation', () => { + it('rejects when API context is missing', async () => { + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { redirect_uri: '/dashboard' }, + } as unknown as Request; + + const res = { + redirect: jest.fn(), + } as unknown as Response; + + const initiateRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider' && layer.route?.methods?.get + ); + const handler = initiateRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('error=API_NOT_CONFIGURED') + ); + }); + + it('rejects when provider not found in database', async () => { + // Query 1: getEncryptedSecretsSchema returns schema + // Query 2: getIdentityProvider returns empty (not found) + setupMockQuery([ + { rows: [{ encrypted_schema: 'test_encrypted' }] }, + { rows: [] }, + ]); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { redirect_uri: '/dashboard' }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + }, + } as unknown as Request; + + const res = { + redirect: jest.fn(), + } as unknown as Response; + + const initiateRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider' && layer.route?.methods?.get + ); + const handler = initiateRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('error=PROVIDER_NOT_CONFIGURED') + ); + }); + + it('initiates OAuth with signed state cookie when provider found', async () => { + // Query 1: getEncryptedSecretsSchema + // Query 2: getIdentityProvider + setupMockQuery([ + { rows: [{ encrypted_schema: 'test_encrypted' }] }, + { rows: [mockProviderRow] }, + ]); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { redirect_uri: '/dashboard' }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + }, + } as unknown as Request; + + const cookies: Record = {}; + const res = { + cookie: jest.fn((name, value, opts) => { + cookies[name] = { value, opts }; + }), + redirect: jest.fn(), + } as unknown as Response; + + const initiateRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider' && layer.route?.methods?.get + ); + const handler = initiateRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.cookie).toHaveBeenCalledWith( + 'oauth_state', + expect.any(String), + expect.objectContaining({ + httpOnly: true, + sameSite: 'lax', + }) + ); + expect(res.redirect).toHaveBeenCalled(); + }); + }); + + describe('OAuth Callback', () => { + const createMockReqRes = (overrides: Partial<{ + params: any; + query: any; + cookies: any; + api: any; + }> = {}) => { + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { code: 'auth-code', state: 'valid-state' }, + cookies: { oauth_state: 'valid-state' }, + api: { + rlsModule: { + privateSchema: { schemaName: 'auth_private' }, + }, + dbname: 'tenant_db', + authSettings: {}, + }, + ...overrides, + } as unknown as Request; + + const res = { + clearCookie: jest.fn(), + cookie: jest.fn(), + redirect: jest.fn(), + status: jest.fn().mockReturnThis(), + json: jest.fn(), + } as unknown as Response; + + return { req, res }; + }; + + it('rejects when state does not match', async () => { + const router = createOAuthRoutes(mockOpts as any); + const { req, res } = createMockReqRes({ + query: { code: 'auth-code', state: 'invalid-state' }, + cookies: { oauth_state: 'different-state' }, + }); + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('error=INVALID_STATE') + ); + }); + + it('rejects when API context is missing', async () => { + const router = createOAuthRoutes(mockOpts as any); + const validState = createValidState(); + + const { req, res } = createMockReqRes({ + query: { code: 'auth-code', state: validState }, + cookies: { oauth_state: validState }, + api: undefined, + }); + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('error=API_NOT_CONFIGURED') + ); + }); + + it('rejects when provider not found in database', async () => { + // Query 1: getEncryptedSecretsSchema returns schema + // Query 2: getIdentityProvider returns empty (not found) + setupMockQuery([ + { rows: [{ encrypted_schema: 'test_encrypted' }] }, + { rows: [] }, + ]); + + const router = createOAuthRoutes(mockOpts as any); + const validState = createValidState(); + + const { req, res } = createMockReqRes({ + query: { code: 'auth-code', state: validState }, + cookies: { oauth_state: validState }, + }); + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('error=PROVIDER_NOT_CONFIGURED') + ); + }); + + it('handles OAuth provider errors', async () => { + const router = createOAuthRoutes(mockOpts as any); + const { req, res } = createMockReqRes({ + query: { + error: 'access_denied', + error_description: 'User denied access' + }, + }); + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('error=access_denied') + ); + }); + }); + + describe('Shadow Attack Defense', () => { + it('rejects signup with unverified email when requireVerifiedEmail is true', async () => { + const validState = createValidState(); + + // Mock OAuthClient to return unverified email + const mockOAuthClient = { + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn().mockResolvedValue({ + provider: 'google', + providerId: '123456', + email: 'attacker@example.com', + emailVerified: false, + name: 'Attacker', + picture: null, + raw: {}, + }), + }; + (OAuthClient as jest.Mock).mockImplementation(() => mockOAuthClient); + + // Query 1: getEncryptedSecretsSchema (pool.query) + // Query 2: getIdentityProvider (pool.query) + // Query 3: set_config for JWT context (dbClient.query) + // Query 4: sign_in_identity throws NOT_FOUND (dbClient.query) + const mockClientQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{}] }) // set_config + .mockRejectedValueOnce(new Error('IDENTITY_ACCOUNT_NOT_FOUND')); + const mockClient = { + query: mockClientQuery, + release: jest.fn(), + }; + const mockPoolQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ encrypted_schema: 'test_encrypted' }] }) + .mockResolvedValueOnce({ rows: [mockProviderRow] }); + (getPgPool as jest.Mock).mockReturnValue({ + query: mockPoolQuery, + connect: jest.fn().mockResolvedValue(mockClient), + }); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { code: 'auth-code', state: validState }, + cookies: { + oauth_state: validState, + constructive_device_token: 'device-token', + }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + authSettings: {}, + }, + } as unknown as Request; + + const res = { + clearCookie: jest.fn(), + cookie: jest.fn(), + redirect: jest.fn(), + } as unknown as Response; + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('error=EMAIL_NOT_VERIFIED') + ); + }); + + it('allows signup with verified email (same-origin - cookie mode)', async () => { + const validState = createValidState(); + + // Mock OAuthClient to return verified email + const mockOAuthClient = { + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn().mockResolvedValue({ + provider: 'google', + providerId: '123456', + email: 'user@example.com', + emailVerified: true, + name: 'User', + picture: null, + raw: {}, + }), + }; + (OAuthClient as jest.Mock).mockImplementation(() => mockOAuthClient); + + // Query 1: getEncryptedSecretsSchema (pool.query) + // Query 2: getIdentityProvider (pool.query) + // Query 3: set_config for JWT context (dbClient.query) + // Query 4: sign_in_identity throws NOT_FOUND (dbClient.query) + // Query 5: sign_up_identity succeeds (dbClient.query) + // (No query 6 for same-origin - no generateCrossOriginToken) + const mockClientQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{}] }) // set_config + .mockRejectedValueOnce(new Error('IDENTITY_ACCOUNT_NOT_FOUND')) + .mockResolvedValueOnce({ + rows: [{ + access_token: 'new-access-token', + user_id: 'user-123', + out_device_token: null, + mfa_required: false, + }], + }); + const mockClient = { + query: mockClientQuery, + release: jest.fn(), + }; + const mockPoolQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ encrypted_schema: 'test_encrypted' }] }) + .mockResolvedValueOnce({ rows: [mockProviderRow] }); + (getPgPool as jest.Mock).mockReturnValue({ + query: mockPoolQuery, + connect: jest.fn().mockResolvedValue(mockClient), + }); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { code: 'auth-code', state: validState }, + cookies: { oauth_state: validState }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + authSettings: {}, + }, + } as unknown as Request; + + const res = { + clearCookie: jest.fn(), + cookie: jest.fn(), + redirect: jest.fn(), + } as unknown as Response; + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + // Same-origin: should set session cookie, no token in URL + expect(res.cookie).toHaveBeenCalledWith( + 'constructive_session', + 'new-access-token', + expect.any(Object) + ); + expect(res.redirect).toHaveBeenCalledWith('/dashboard'); + }); + + it('allows signup with verified email (cross-origin - token mode)', async () => { + // Create state with cross-origin redirect_uri + const crossOriginStatePayload = { + redirect_uri: 'http://frontend.example.com/auth/callback', + provider: 'google', + nonce: crypto.randomBytes(16).toString('hex'), + exp: Date.now() + 10 * 60 * 1000, + }; + const json = JSON.stringify(crossOriginStatePayload); + const sig = crypto.createHmac('sha256', 'test-secret-key-for-testing').update(json).digest('base64url'); + const crossOriginState = Buffer.from(json).toString('base64url') + '.' + sig; + + const mockOAuthClient = { + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn().mockResolvedValue({ + provider: 'google', + providerId: '123456', + email: 'user@example.com', + emailVerified: true, + name: 'User', + picture: null, + raw: {}, + }), + }; + (OAuthClient as jest.Mock).mockImplementation(() => mockOAuthClient); + + // Query 1: getEncryptedSecretsSchema (pool.query) + // Query 2: getIdentityProvider (pool.query) + // Query 6: generateCrossOriginToken UPDATE (pool.query, cross-origin only) + const mockPoolQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ encrypted_schema: 'test_encrypted' }] }) + .mockResolvedValueOnce({ rows: [mockProviderRow] }) + .mockResolvedValueOnce({ rows: [{ id: 'credential-id' }] }); + // Query 3: set_config for JWT context (dbClient.query) + // Query 4: sign_in_identity throws NOT_FOUND (dbClient.query) + // Query 5: sign_up_identity succeeds (dbClient.query) + const mockClientQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{}] }) // set_config + .mockRejectedValueOnce(new Error('IDENTITY_ACCOUNT_NOT_FOUND')) + .mockResolvedValueOnce({ + rows: [{ + access_token: 'new-access-token', + user_id: 'user-123', + out_device_token: null, + mfa_required: false, + }], + }); + const mockClient = { + query: mockClientQuery, + release: jest.fn(), + }; + (getPgPool as jest.Mock).mockReturnValue({ + query: mockPoolQuery, + connect: jest.fn().mockResolvedValue(mockClient), + }); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { code: 'auth-code', state: crossOriginState }, + cookies: { oauth_state: crossOriginState }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + authSettings: {}, + }, + } as unknown as Request; + + const res = { + clearCookie: jest.fn(), + cookie: jest.fn(), + redirect: jest.fn(), + } as unknown as Response; + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + // Cross-origin: should NOT set session cookie, should redirect with token + expect(res.cookie).not.toHaveBeenCalled(); + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('http://frontend.example.com/auth/callback?token=') + ); + }); + }); + + describe('MFA Flow', () => { + it('redirects to MFA page when mfa_required is true', async () => { + const validState = createValidState(); + + const mockOAuthClient = { + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn().mockResolvedValue({ + provider: 'google', + providerId: '123456', + email: 'user@example.com', + emailVerified: true, + name: 'User', + picture: null, + raw: {}, + }), + }; + (OAuthClient as jest.Mock).mockImplementation(() => mockOAuthClient); + + // Query 1: getEncryptedSecretsSchema (pool.query) + // Query 2: getIdentityProvider (pool.query) + const mockPoolQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ encrypted_schema: 'test_encrypted' }] }) + .mockResolvedValueOnce({ rows: [mockProviderRow] }); + // Query 3: set_config for JWT context (dbClient.query) + // Query 4: sign_in_identity returns MFA required (dbClient.query) + const mockClientQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{}] }) // set_config + .mockResolvedValueOnce({ + rows: [{ + mfa_required: true, + mfa_challenge_token: 'mfa-challenge-token-123', + user_id: 'user-123', + }], + }); + const mockClient = { + query: mockClientQuery, + release: jest.fn(), + }; + (getPgPool as jest.Mock).mockReturnValue({ + query: mockPoolQuery, + connect: jest.fn().mockResolvedValue(mockClient), + }); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { code: 'auth-code', state: validState }, + cookies: { oauth_state: validState }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + authSettings: {}, + }, + } as unknown as Request; + + const res = { + clearCookie: jest.fn(), + cookie: jest.fn(), + redirect: jest.fn(), + } as unknown as Response; + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('/auth/mfa') + ); + expect(res.redirect).toHaveBeenCalledWith( + expect.stringContaining('token=mfa-challenge-token-123') + ); + }); + }); + + describe('Multi-Tenancy', () => { + it('connects to correct tenant database based on req.api.dbname', async () => { + const validState = createValidState(); + + const mockOAuthClient = { + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn().mockResolvedValue({ + provider: 'google', + providerId: '123456', + email: 'user@example.com', + emailVerified: true, + name: 'User', + picture: null, + raw: {}, + }), + }; + (OAuthClient as jest.Mock).mockImplementation(() => mockOAuthClient); + + // Query 1: getEncryptedSecretsSchema (pool.query) + // Query 2: getIdentityProvider (pool.query) + const mockPoolQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ encrypted_schema: 'test_encrypted' }] }) + .mockResolvedValueOnce({ rows: [mockProviderRow] }); + // Query 3: set_config for JWT context (dbClient.query) + // Query 4: sign_in_identity succeeds (dbClient.query) + // (No query 5 for same-origin - no generateCrossOriginToken) + const mockClientQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{}] }) // set_config + .mockResolvedValueOnce({ + rows: [{ + access_token: 'tenant-access-token', + user_id: 'user-123', + out_device_token: null, + mfa_required: false, + }], + }); + const mockClient = { + query: mockClientQuery, + release: jest.fn(), + }; + (getPgPool as jest.Mock).mockReturnValue({ + query: mockPoolQuery, + connect: jest.fn().mockResolvedValue(mockClient), + }); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { code: 'auth-code', state: validState }, + cookies: { oauth_state: validState }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_acme_db', + authSettings: {}, + }, + } as unknown as Request; + + const res = { + clearCookie: jest.fn(), + cookie: jest.fn(), + redirect: jest.fn(), + } as unknown as Response; + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + expect(getPgPool).toHaveBeenCalledWith( + expect.objectContaining({ + database: 'tenant_acme_db', + }) + ); + }); + + it('uses correct private schema for each tenant', async () => { + const validState = createValidState(); + + const mockOAuthClient = { + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn().mockResolvedValue({ + provider: 'google', + providerId: '123456', + email: 'user@example.com', + emailVerified: true, + name: 'User', + picture: null, + raw: {}, + }), + }; + (OAuthClient as jest.Mock).mockImplementation(() => mockOAuthClient); + + // Query 1: getEncryptedSecretsSchema (pool.query) + // Query 2: getIdentityProvider (pool.query) + const mockPoolQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ encrypted_schema: 'test_encrypted' }] }) + .mockResolvedValueOnce({ rows: [mockProviderRow] }); + // Query 3: set_config for JWT context (dbClient.query) + // Query 4: sign_in_identity succeeds (dbClient.query) + // (No query 5 for same-origin - no generateCrossOriginToken) + const mockClientQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{}] }) // set_config + .mockResolvedValueOnce({ + rows: [{ + access_token: 'tenant-access-token', + user_id: 'user-123', + out_device_token: null, + mfa_required: false, + }], + }); + const mockClient = { + query: mockClientQuery, + release: jest.fn(), + }; + (getPgPool as jest.Mock).mockReturnValue({ + query: mockPoolQuery, + connect: jest.fn().mockResolvedValue(mockClient), + }); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { code: 'auth-code', state: validState }, + cookies: { oauth_state: validState }, + api: { + rlsModule: { privateSchema: { schemaName: 'custom_auth_schema' } }, + dbname: 'tenant_db', + authSettings: {}, + }, + } as unknown as Request; + + const res = { + clearCookie: jest.fn(), + cookie: jest.fn(), + redirect: jest.fn(), + } as unknown as Response; + + const callbackRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider/callback' + ); + const handler = callbackRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + // Second query is getIdentityProvider which uses custom_auth_schema + expect(mockPoolQuery).toHaveBeenCalledWith( + expect.stringContaining('"custom_auth_schema".identity_providers'), + expect.any(Array) + ); + }); + + it('reads provider config from tenant database', async () => { + const validState = createValidState(); + + const mockOAuthClient = { + getAuthorizationUrl: jest.fn().mockReturnValue({ + url: 'https://accounts.google.com/o/oauth2/v2/auth', + state: validState, + }), + handleCallback: jest.fn(), + }; + (OAuthClient as jest.Mock).mockImplementation(() => mockOAuthClient); + + const tenantProviderRow = { + ...mockProviderRow, + client_id: 'tenant-specific-client-id', + client_secret: 'tenant-specific-secret', + }; + + const mockQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ encrypted_schema: 'test_encrypted' }] }) + .mockResolvedValueOnce({ rows: [tenantProviderRow] }); + (getPgPool as jest.Mock).mockReturnValue({ query: mockQuery }); + + const router = createOAuthRoutes(mockOpts as any); + + const req = { + ...mockRequestHelpers, + params: { provider: 'google' }, + query: { redirect_uri: '/dashboard' }, + api: { + rlsModule: { privateSchema: { schemaName: 'auth_private' } }, + dbname: 'tenant_db', + }, + } as unknown as Request; + + const res = { + cookie: jest.fn(), + redirect: jest.fn(), + } as unknown as Response; + + const initiateRoute = router.stack.find( + (layer: any) => layer.route?.path === '/:provider' && layer.route?.methods?.get + ); + const handler = initiateRoute!.route.stack.slice(-1)[0].handle; + + await handler(req, res, jest.fn()); + + // Verify OAuthClient was created with tenant's credentials + expect(OAuthClient).toHaveBeenCalledWith( + expect.objectContaining({ + providers: { + google: { + clientId: 'tenant-specific-client-id', + clientSecret: 'tenant-specific-secret', + }, + }, + }) + ); + }); + }); +}); diff --git a/graphql/server/src/middleware/oauth.ts b/graphql/server/src/middleware/oauth.ts new file mode 100644 index 0000000000..8ae53c07bd --- /dev/null +++ b/graphql/server/src/middleware/oauth.ts @@ -0,0 +1,591 @@ +import crypto from 'crypto'; +import { Router, Request, Response } from 'express'; +import rateLimit from 'express-rate-limit'; +import { OAuthClient, OAuthProfile } from '@constructive-io/oauth'; +import { Logger } from '@pgpmjs/logger'; +import { Pool } from 'pg'; +import { getPgPool } from 'pg-cache'; +import type { ConstructiveOptions } from '@constructive-io/graphql-types'; + +import { + DEVICE_TOKEN_COOKIE_NAME, + getSessionCookieConfig, + getDeviceTokenCookieConfig, + setSessionCookie, + setDeviceTokenCookie, +} from './cookie'; + +const log = new Logger('oauth'); + +const OAUTH_STATE_COOKIE = 'oauth_state'; +const OAUTH_STATE_MAX_AGE = 10 * 60 * 1000; // 10 minutes + +// ============================================================================= +// Signed State Utilities +// ============================================================================= + +interface StatePayload { + redirect_uri: string; + provider: string; + nonce: string; + exp: number; +} + +function getStateSecret(): string { + const secret = process.env.OAUTH_SECRET; + if (!secret) { + throw new Error('OAUTH_SECRET environment variable is required'); + } + return secret; +} + +function createSignedState(payload: { redirect_uri: string; provider: string }): string { + const data: StatePayload = { + ...payload, + nonce: crypto.randomBytes(16).toString('hex'), + exp: Date.now() + OAUTH_STATE_MAX_AGE, + }; + const json = JSON.stringify(data); + const sig = crypto.createHmac('sha256', getStateSecret()).update(json).digest('base64url'); + return Buffer.from(json).toString('base64url') + '.' + sig; +} + +function verifySignedState(state: string): StatePayload | null { + try { + const [payloadB64, sig] = state.split('.'); + if (!payloadB64 || !sig) return null; + + const json = Buffer.from(payloadB64, 'base64url').toString(); + const expectedSig = crypto.createHmac('sha256', getStateSecret()).update(json).digest('base64url'); + + if (!crypto.timingSafeEqual(Buffer.from(sig), Buffer.from(expectedSig))) { + return null; + } + + const data = JSON.parse(json) as StatePayload; + if (data.exp < Date.now()) { + return null; + } + + return data; + } catch { + return null; + } +} + +// ============================================================================= +// Identity Provider Database Functions +// ============================================================================= + +async function getEncryptedSecretsSchema( + pool: Pool, + privateSchema: string +): Promise { + const sql = ` + SELECT es.schema_name as encrypted_schema + FROM metaschema_public.schema ps + JOIN metaschema_modules_public.encrypted_secrets_module esm ON esm.database_id = ps.database_id + JOIN metaschema_public.schema es ON es.id = esm.schema_id + WHERE ps.schema_name = $1 + LIMIT 1 + `; + const result = await pool.query(sql, [privateSchema]); + return result.rows[0]?.encrypted_schema || null; +} + +interface IdentityProviderConfig { + slug: string; + kind: 'oauth2' | 'oidc'; + display_name: string; + enabled: boolean; + client_id: string; + client_secret: string; + authorization_url: string | null; + token_url: string | null; + userinfo_url: string | null; + scopes: string[]; + pkce_enabled: boolean; +} + +async function getEnabledProviders( + pool: Pool, + privateSchema: string +): Promise { + const sql = ` + SELECT slug FROM "${privateSchema}".identity_providers + WHERE enabled = true AND client_id IS NOT NULL AND client_secret_id IS NOT NULL + `; + const result = await pool.query(sql); + return result.rows.map((row: any) => row.slug); +} + +async function getIdentityProvider( + pool: Pool, + privateSchema: string, + encryptedSecretsSchema: string, + providerSlug: string +): Promise { + const sql = ` + SELECT + ip.slug, + ip.kind, + ip.display_name, + ip.enabled, + ip.client_id, + "${encryptedSecretsSchema}".get(ip.client_secret_id, 'oauth_client_secret') as client_secret, + ip.authorization_url, + ip.token_url, + ip.userinfo_url, + ip.scopes, + ip.pkce_enabled + FROM "${privateSchema}".identity_providers ip + WHERE ip.slug = $1 AND ip.enabled = true + `; + + const result = await pool.query(sql, [providerSlug]); + if (result.rows.length === 0) { + return null; + } + + const row = result.rows[0]; + if (!row.client_id || !row.client_secret) { + return null; + } + + return { + slug: row.slug, + kind: row.kind, + display_name: row.display_name, + enabled: row.enabled, + client_id: row.client_id, + client_secret: row.client_secret, + authorization_url: row.authorization_url, + token_url: row.token_url, + userinfo_url: row.userinfo_url, + scopes: row.scopes || [], + pkce_enabled: row.pkce_enabled ?? true, + }; +} + +function createOAuthClientForProvider( + providerConfig: IdentityProviderConfig, + baseUrl: string +): OAuthClient { + return new OAuthClient({ + providers: { + [providerConfig.slug]: { + clientId: providerConfig.client_id, + clientSecret: providerConfig.client_secret, + }, + }, + baseUrl, + callbackPath: '/auth/{provider}/callback', + }); +} + +// ============================================================================= +// Database Functions +// ============================================================================= + +interface SignInIdentityResult { + id?: string; + user_id?: string; + access_token?: string; + access_token_expires_at?: string; + is_verified?: boolean; + totp_enabled?: boolean; + mfa_required?: boolean; + mfa_challenge_token?: string; + out_device_token?: string; +} + +async function generateCrossOriginToken( + pool: Pool, + privateSchema: string, + accessToken: string +): Promise { + const otToken = crypto.randomBytes(32).toString('base64url'); + + const sql = ` + UPDATE "${privateSchema}".session_credentials + SET ot_token = $1 + WHERE secret_hash = digest($2::text, 'sha256') + RETURNING id + `; + + const result = await pool.query(sql, [otToken, accessToken]); + if (result.rows.length === 0) { + throw new Error('Failed to set cross-origin token'); + } + + return otToken; +} + +// ============================================================================= +// OAuth Routes +// ============================================================================= + +function getBaseUrl(req: Request): string { + const protocol = req.protocol || 'http'; + const host = req.get('host') || 'localhost:3000'; + return `${protocol}://${host}`; +} + +export function createOAuthRoutes(opts: ConstructiveOptions): Router { + const router = Router(); + const oauthConfig = opts.oauth; + + const errorRedirectPath = oauthConfig?.errorRedirectPath ?? '/auth/error'; + const allowSignup = oauthConfig?.allowSignup ?? true; + const requireVerifiedEmail = oauthConfig?.requireVerifiedEmail ?? true; + + // Rate limiters for OAuth endpoints (disabled in development/test) + const skipRateLimit = process.env.NODE_ENV === 'development' || process.env.NODE_ENV === 'test'; + + const oauthInitLimiter = rateLimit({ + windowMs: 60 * 1000, // 1 minute + max: 10, // 10 requests per minute per IP + skip: () => skipRateLimit, + message: { error: 'TOO_MANY_REQUESTS', message: 'Too many OAuth requests, please try again later' }, + standardHeaders: true, + legacyHeaders: false, + }); + + const oauthCallbackLimiter = rateLimit({ + windowMs: 60 * 1000, // 1 minute + max: 30, // 30 requests per minute per IP + skip: () => skipRateLimit, + message: { error: 'TOO_MANY_REQUESTS', message: 'Too many OAuth callback requests, please try again later' }, + standardHeaders: true, + legacyHeaders: false, + }); + + // GET /auth/providers - List available providers from database + router.get('/providers', async (req: Request, res: Response) => { + if (!req.api?.rlsModule?.privateSchema?.schemaName) { + return res.json({ providers: [] }); + } + + const privateSchema = req.api.rlsModule.privateSchema.schemaName; + const dbname = req.api.dbname; + + try { + const pool = getPgPool({ ...opts.pg, database: dbname }); + const providers = await getEnabledProviders(pool, privateSchema); + res.json({ providers }); + } catch (error) { + log.error('[oauth] Failed to fetch providers:', error); + res.json({ providers: [] }); + } + }); + + // GET /auth/error - Error page (must be before /:provider to avoid conflict) + // Pass to next middleware stack (outside this router) for frontend to handle + router.get('/error', (req: Request, res: Response, next) => { + next('router'); + }); + + // GET /auth/:provider - Initiate OAuth flow + router.get('/:provider', oauthInitLimiter, async (req: Request, res: Response) => { + const { provider } = req.params; + const redirectUri = (req.query.redirect_uri as string) || '/'; + + // Check if API context is available + if (!req.api?.rlsModule?.privateSchema?.schemaName) { + log.error(`[oauth] No API context available for ${provider} initiation`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'API_NOT_CONFIGURED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + const privateSchema = req.api.rlsModule.privateSchema.schemaName; + const dbname = req.api.dbname; + + try { + const pool = getPgPool({ ...opts.pg, database: dbname }); + + // Look up encrypted secrets schema from metaschema + const encryptedSchema = await getEncryptedSecretsSchema(pool, privateSchema); + if (!encryptedSchema) { + log.error(`[oauth] Could not resolve encrypted_secrets schema for ${privateSchema}`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'SCHEMA_NOT_CONFIGURED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + const providerConfig = await getIdentityProvider(pool, privateSchema, encryptedSchema, provider); + + if (!providerConfig) { + log.warn(`[oauth] Provider ${provider} not found or not configured`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'PROVIDER_NOT_CONFIGURED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + const state = createSignedState({ redirect_uri: redirectUri, provider }); + + res.cookie(OAUTH_STATE_COOKIE, state, { + httpOnly: true, + secure: process.env.NODE_ENV === 'production', + maxAge: OAUTH_STATE_MAX_AGE, + sameSite: 'lax', + }); + + const client = createOAuthClientForProvider(providerConfig, getBaseUrl(req)); + const { url } = client.getAuthorizationUrl({ provider, state }); + log.info(`[oauth] Initiating OAuth flow for provider: ${provider}`); + res.redirect(url); + } catch (error) { + log.error(`[oauth] Failed to initiate OAuth for ${provider}:`, error); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'OAUTH_INIT_FAILED'); + errorUrl.searchParams.set('provider', provider); + res.redirect(errorUrl.toString()); + } + }); + + // GET /auth/:provider/callback - Handle OAuth callback + router.get('/:provider/callback', oauthCallbackLimiter, async (req: Request, res: Response) => { + const { provider } = req.params; + const { code, state, error: oauthError, error_description } = req.query; + + const storedState = req.cookies[OAUTH_STATE_COOKIE]; + res.clearCookie(OAUTH_STATE_COOKIE); + + // Handle OAuth provider errors + if (oauthError) { + log.warn(`[oauth] Provider ${provider} returned error: ${oauthError}`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', oauthError as string); + if (error_description) { + errorUrl.searchParams.set('error_description', error_description as string); + } + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + // Verify state + if (state !== storedState) { + log.warn(`[oauth] State mismatch for ${provider}`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'INVALID_STATE'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + const statePayload = verifySignedState(storedState); + if (!statePayload) { + log.warn(`[oauth] Invalid or expired state for ${provider}`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'INVALID_STATE'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + const { redirect_uri: redirectUri } = statePayload; + + // Check if API context is available + if (!req.api?.rlsModule?.privateSchema?.schemaName) { + log.error(`[oauth] No API context available for ${provider} callback`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'API_NOT_CONFIGURED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + const privateSchema = req.api.rlsModule.privateSchema.schemaName; + const dbname = req.api.dbname; + const authSettings = req.api.authSettings; + + try { + const pool = getPgPool({ ...opts.pg, database: dbname }); + + // Look up encrypted secrets schema from metaschema + const encryptedSchema = await getEncryptedSecretsSchema(pool, privateSchema); + if (!encryptedSchema) { + log.error(`[oauth] Could not resolve encrypted_secrets schema for ${privateSchema}`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'SCHEMA_NOT_CONFIGURED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + // Get provider config from database + const providerConfig = await getIdentityProvider(pool, privateSchema, encryptedSchema, provider); + if (!providerConfig) { + log.error(`[oauth] Provider ${provider} not found in database`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'PROVIDER_NOT_CONFIGURED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + // Create OAuth client with provider config from database + const client = createOAuthClientForProvider(providerConfig, getBaseUrl(req)); + + // Exchange code for profile + const profile = await client.handleCallback({ provider, code: code as string }); + log.info(`[oauth] Got profile for ${provider}: ${profile.email}`); + + // Get device token from cookie + const deviceToken = req.cookies[DEVICE_TOKEN_COOKIE_NAME] ?? null; + + // Calculate target origin for cross-origin flow + const currentOrigin = getBaseUrl(req); + let targetOrigin: string; + try { + const redirectUrl = new URL(redirectUri, currentOrigin); + targetOrigin = redirectUrl.origin; + } catch { + targetOrigin = currentOrigin; + } + + // Use a dedicated database client to ensure JWT context is available for sign_in_identity + // - user_agent: from browser request (same browser will call signInCrossOrigin) + // - origin: target origin (where the token will be exchanged) + const userAgent = req.get('user-agent') || ''; + const dbClient = await pool.connect(); + + let result: SignInIdentityResult; + + try { + // Set JWT context on this connection (false = session-level, persists across queries) + await dbClient.query(` + SELECT set_config('jwt.claims.user_agent', $1, false), + set_config('jwt.claims.origin', $2, false) + `, [userAgent, targetOrigin]); + + // Try sign_in_identity first (using same client) + const details = { + provider: profile.provider, + sub: profile.providerId, + email: profile.email, + email_verified: profile.emailVerified, + name: profile.name, + picture: profile.picture, + raw_userinfo: profile.raw, + }; + + const signInSql = ` + SELECT * FROM "${privateSchema}".sign_in_identity( + $1::text, $2::text, $3::jsonb, $4::text, 'access_token'::text, $5::boolean, $6::text + ) + `; + + try { + const signInResult = await dbClient.query(signInSql, [ + profile.provider, + profile.providerId, + JSON.stringify(details), + profile.email, + true, + deviceToken, + ]); + + result = signInResult.rows[0] || {}; + } catch (err: any) { + const errorMessage = err.message || ''; + + // Handle IDENTITY_ACCOUNT_NOT_FOUND - try signup + if (!errorMessage.includes('IDENTITY_ACCOUNT_NOT_FOUND')) { + throw err; + } + + log.info(`[oauth] Account not found for ${profile.email}, attempting signup`); + + if (!allowSignup) { + log.warn(`[oauth] Signup disabled, rejecting ${profile.email}`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'SIGNUP_DISABLED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + // Shadow attack defense: require verified email for auto-signup + if (requireVerifiedEmail && !profile.emailVerified) { + log.warn(`[oauth] Rejecting unverified email for signup: ${profile.email}`); + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'EMAIL_NOT_VERIFIED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + + // Call sign_up_identity (using same client with JWT context) + const signUpSql = ` + SELECT * FROM "${privateSchema}".sign_up_identity( + $1::text, $2::text, $3::text, $4::jsonb, 'access_token'::text, $5::boolean, $6::text + ) + `; + + const signUpResult = await dbClient.query(signUpSql, [ + profile.provider, + profile.providerId, + profile.email, + JSON.stringify(details), + true, + deviceToken, + ]); + + result = signUpResult.rows[0] || {}; + } + } finally { + dbClient.release(); + } + + // Handle MFA required + if (result.mfa_required && result.mfa_challenge_token) { + log.info(`[oauth] MFA required for ${profile.email}`); + const mfaUrl = new URL('/auth/mfa', getBaseUrl(req)); + mfaUrl.searchParams.set('token', result.mfa_challenge_token); + mfaUrl.searchParams.set('redirect_uri', redirectUri); + return res.redirect(mfaUrl.toString()); + } + + // Success + if (!result.access_token) { + throw new Error('No access token returned from sign_in_identity'); + } + + // Determine if this is a cross-origin request + // Cookie mode and Token mode are mutually exclusive (Better Auth design) + const isCrossOrigin = targetOrigin !== currentOrigin; + + if (isCrossOrigin) { + // Cross-origin: Token mode only + // Generate one-time token for frontend to exchange via signInCrossOrigin + // Frontend stores access_token in localStorage + const otToken = await generateCrossOriginToken(pool, privateSchema, result.access_token); + const redirectUrl = new URL(redirectUri, currentOrigin); + redirectUrl.searchParams.set('token', otToken); + log.info(`[oauth] OAuth success for ${profile.email}, cross-origin redirect with one-time token`); + return res.redirect(redirectUrl.toString()); + } else { + // Same-origin: Cookie mode only + // Set httpOnly cookies, no token in URL + const sessionConfig = getSessionCookieConfig(authSettings, true); + setSessionCookie(res, result.access_token, sessionConfig); + + if (result.out_device_token) { + const deviceConfig = getDeviceTokenCookieConfig(authSettings); + setDeviceTokenCookie(res, result.out_device_token, deviceConfig); + } + + log.info(`[oauth] OAuth success for ${profile.email}, same-origin redirect with cookie`); + return res.redirect(redirectUri); + } + + } catch (error: any) { + log.error(`[oauth] Callback failed for ${provider}:`, error); + + const errorUrl = new URL(errorRedirectPath, getBaseUrl(req)); + errorUrl.searchParams.set('error', 'CALLBACK_FAILED'); + errorUrl.searchParams.set('provider', provider); + return res.redirect(errorUrl.toString()); + } + }); + + return router; +} diff --git a/graphql/server/src/server.ts b/graphql/server/src/server.ts index 8235c6c123..ae94bf676b 100644 --- a/graphql/server/src/server.ts +++ b/graphql/server/src/server.ts @@ -37,6 +37,7 @@ import { createRequestLogger } from './middleware/observability/request-logger'; // Auth cookie handling is done via AuthCookiePlugin in grafserv import { createCaptchaMiddleware } from './middleware/captcha'; import { parseCookieValue, SESSION_COOKIE_NAME } from './middleware/cookie'; +import { createOAuthRoutes } from './middleware/oauth'; import { createUploadAuthenticateMiddleware, uploadRoute } from './middleware/upload'; import { createLlmApiRouter } from './middleware/llm-api'; import { startDebugSampler } from './diagnostics/debug-sampler'; @@ -163,6 +164,7 @@ class Server { app.use(requestIp.mw()); app.use(requestLogger); app.use(api); + app.use('/auth', createOAuthRoutes(effectiveOpts)); app.post('/upload', uploadAuthenticate, ...uploadRoute); app.use(authenticate); app.use(createCaptchaMiddleware()); diff --git a/graphql/types/src/constructive.ts b/graphql/types/src/constructive.ts index fc64c547c0..fdbe2f0155 100644 --- a/graphql/types/src/constructive.ts +++ b/graphql/types/src/constructive.ts @@ -20,6 +20,20 @@ import { } from './graphile'; import { LlmOptions } from './llm'; +/** + * OAuth routes configuration + */ +export interface OAuthRoutesConfig { + /** @deprecated baseUrl is now derived from request headers for multi-tenant support */ + baseUrl?: string; + /** Path to redirect on error (default: /auth/error) */ + errorRedirectPath?: string; + /** Allow signup via OAuth (default: true) */ + allowSignup?: boolean; + /** Require verified email for OAuth signup (default: true) */ + requireVerifiedEmail?: boolean; +} + /** * GraphQL-specific options for Constructive */ @@ -59,6 +73,8 @@ export interface ConstructiveOptions extends PgpmOptions, ConstructiveGraphQLOpt jobs?: JobsConfig; /** LLM provider configuration (embeddings, chat, RAG) */ llm?: LlmOptions; + /** OAuth routes configuration */ + oauth?: OAuthRoutesConfig; } /** diff --git a/graphql/types/src/index.ts b/graphql/types/src/index.ts index a66eb0bf63..da9ada2dd5 100644 --- a/graphql/types/src/index.ts +++ b/graphql/types/src/index.ts @@ -12,6 +12,7 @@ export { export { ConstructiveGraphQLOptions, ConstructiveOptions, + OAuthRoutesConfig, constructiveGraphqlDefaults, constructiveDefaults } from './constructive'; diff --git a/packages/oauth/src/oauth-client.ts b/packages/oauth/src/oauth-client.ts index 3f53bdf292..da1184a1c5 100644 --- a/packages/oauth/src/oauth-client.ts +++ b/packages/oauth/src/oauth-client.ts @@ -182,9 +182,9 @@ export class OAuthClient { if (response.ok) { const emails = await response.json(); - const email = extractPrimaryEmail(emails); - if (email) { - return { ...profile, email }; + const extracted = extractPrimaryEmail(emails); + if (extracted) { + return { ...profile, email: extracted.email, emailVerified: extracted.verified }; } } } catch { diff --git a/packages/oauth/src/providers/facebook.ts b/packages/oauth/src/providers/facebook.ts index 56c146dd1c..a17f9963bc 100644 --- a/packages/oauth/src/providers/facebook.ts +++ b/packages/oauth/src/providers/facebook.ts @@ -27,6 +27,7 @@ export const facebookProvider: OAuthProviderConfig = { provider: 'facebook', providerId: profile.id, email: profile.email || null, + emailVerified: !!profile.email, // Facebook emails are verified by platform name: profile.name || null, picture: profile.picture?.data?.url || null, raw: data, diff --git a/packages/oauth/src/providers/github.ts b/packages/oauth/src/providers/github.ts index 04ba3a5290..2fdad65340 100644 --- a/packages/oauth/src/providers/github.ts +++ b/packages/oauth/src/providers/github.ts @@ -28,6 +28,7 @@ export const githubProvider: OAuthProviderConfig = { provider: 'github', providerId: String(profile.id), email: profile.email || null, + emailVerified: false, // GitHub requires separate /user/emails call for verification status name: profile.name || profile.login || null, picture: profile.avatar_url || null, raw: data, @@ -37,10 +38,16 @@ export const githubProvider: OAuthProviderConfig = { export const GITHUB_EMAILS_URL = 'https://api.github.com/user/emails'; -export function extractPrimaryEmail(emails: GitHubEmail[]): string | null { +export interface ExtractedEmail { + email: string; + verified: boolean; +} + +export function extractPrimaryEmail(emails: GitHubEmail[]): ExtractedEmail | null { const primary = emails.find((e) => e.primary && e.verified); - if (primary) return primary.email; + if (primary) return { email: primary.email, verified: true }; const verified = emails.find((e) => e.verified); - if (verified) return verified.email; - return emails[0]?.email || null; + if (verified) return { email: verified.email, verified: true }; + if (emails[0]) return { email: emails[0].email, verified: emails[0].verified }; + return null; } diff --git a/packages/oauth/src/providers/google.ts b/packages/oauth/src/providers/google.ts index eaeac0a11d..8c86cbb11d 100644 --- a/packages/oauth/src/providers/google.ts +++ b/packages/oauth/src/providers/google.ts @@ -24,6 +24,7 @@ export const googleProvider: OAuthProviderConfig = { provider: 'google', providerId: profile.sub, email: profile.email || null, + emailVerified: profile.email_verified ?? false, name: profile.name || null, picture: profile.picture || null, raw: data, diff --git a/packages/oauth/src/providers/linkedin.ts b/packages/oauth/src/providers/linkedin.ts index a7658c859b..a92fd0bfde 100644 --- a/packages/oauth/src/providers/linkedin.ts +++ b/packages/oauth/src/providers/linkedin.ts @@ -24,6 +24,7 @@ export const linkedinProvider: OAuthProviderConfig = { provider: 'linkedin', providerId: profile.sub, email: profile.email || null, + emailVerified: profile.email_verified ?? false, name: profile.name || null, picture: profile.picture || null, raw: data, diff --git a/packages/oauth/src/types.ts b/packages/oauth/src/types.ts index db4ef07437..9665bddab2 100644 --- a/packages/oauth/src/types.ts +++ b/packages/oauth/src/types.ts @@ -14,6 +14,7 @@ export interface OAuthProfile { provider: string; providerId: string; email: string | null; + emailVerified: boolean; name: string | null; picture: string | null; raw: unknown; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index b782c2687a..771b621d28 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1470,6 +1470,9 @@ importers: '@constructive-io/graphql-types': specifier: workspace:^ version: link:../types/dist + '@constructive-io/oauth': + specifier: workspace:^ + version: link:../../packages/oauth/dist '@constructive-io/s3-utils': specifier: workspace:^ version: link:../../uploads/s3-utils/dist @@ -1506,6 +1509,9 @@ importers: express: specifier: ^5.2.1 version: 5.2.1 + express-rate-limit: + specifier: ^8.5.1 + version: 8.5.2(express@5.2.1) gql-ast: specifier: workspace:^ version: link:../gql-ast/dist @@ -4518,24 +4524,28 @@ packages: engines: {node: '>= 10'} cpu: [arm64] os: [linux] + libc: [glibc] '@nx/nx-linux-arm64-musl@20.8.3': resolution: {integrity: sha512-LTTGzI8YVPlF1v0YlVf+exM+1q7rpsiUbjTTHJcfHFRU5t4BsiZD54K19Y1UBg1XFx5cwhEaIomSmJ88RwPPVQ==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] + libc: [musl] '@nx/nx-linux-x64-gnu@20.8.3': resolution: {integrity: sha512-SlA4GtXvQbSzSIWLgiIiLBOjdINPOUR/im+TUbaEMZ8wiGrOY8cnk0PVt95TIQJVBeXBCeb5HnoY0lHJpMOODg==} engines: {node: '>= 10'} cpu: [x64] os: [linux] + libc: [glibc] '@nx/nx-linux-x64-musl@20.8.3': resolution: {integrity: sha512-MNzkEwPktp5SQH9dJDH2wP9hgG9LsBDhKJXJfKw6sUI/6qz5+/aAjFziKy+zBnhU4AO1yXt5qEWzR8lDcIriVQ==} engines: {node: '>= 10'} cpu: [x64] os: [linux] + libc: [musl] '@nx/nx-win32-arm64-msvc@20.8.3': resolution: {integrity: sha512-qUV7CyXKwRCM/lkvyS6Xa1MqgAuK5da6w27RAehh7LATBUKn1I4/M7DGn6L7ERCxpZuh1TrDz9pUzEy0R+Ekkg==} @@ -4654,48 +4664,56 @@ packages: engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [linux] + libc: [glibc] '@oxfmt/binding-linux-arm64-musl@0.42.0': resolution: {integrity: sha512-+JA0YMlSdDqmacygGi2REp57c3fN+tzARD8nwsukx9pkCHK+6DkbAA9ojS4lNKsiBjIW8WWa0pBrBWhdZEqfuw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [linux] + libc: [musl] '@oxfmt/binding-linux-ppc64-gnu@0.42.0': resolution: {integrity: sha512-VfnET0j4Y5mdfCzh5gBt0NK28lgn5DKx+8WgSMLYYeSooHhohdbzwAStLki9pNuGy51y4I7IoW8bqwAaCMiJQg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [ppc64] os: [linux] + libc: [glibc] '@oxfmt/binding-linux-riscv64-gnu@0.42.0': resolution: {integrity: sha512-gVlCbmBkB0fxBWbhBj9rcxezPydsQHf4MFKeHoTSPicOQ+8oGeTQgQ8EeesSybWeiFPVRx3bgdt4IJnH6nOjAA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [riscv64] os: [linux] + libc: [glibc] '@oxfmt/binding-linux-riscv64-musl@0.42.0': resolution: {integrity: sha512-zN5OfstL0avgt/IgvRu0zjQzVh/EPkcLzs33E9LMAzpqlLWiPWeMDZyMGFlSRGOdDjuNmlZBCgj0pFnK5u32TQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [riscv64] os: [linux] + libc: [musl] '@oxfmt/binding-linux-s390x-gnu@0.42.0': resolution: {integrity: sha512-9X6+H2L0qMc2sCAgO9HS03bkGLMKvOFjmEdchaFlany3vNZOjnVui//D8k/xZAtQv2vaCs1reD5KAgPoIU4msA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [s390x] os: [linux] + libc: [glibc] '@oxfmt/binding-linux-x64-gnu@0.42.0': resolution: {integrity: sha512-BajxJ6KQvMMdpXGPWhBGyjb2Jvx4uec0w+wi6TJZ6Tv7+MzPwe0pO8g5h1U0jyFgoaF7mDl6yKPW3ykWcbUJRw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [linux] + libc: [glibc] '@oxfmt/binding-linux-x64-musl@0.42.0': resolution: {integrity: sha512-0wV284I6vc5f0AqAhgAbHU2935B4bVpncPoe5n/WzVZY/KnHgqxC8iSFGeSyLWEgstFboIcWkOPck7tqbdHkzA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [linux] + libc: [musl] '@oxfmt/binding-openharmony-arm64@0.42.0': resolution: {integrity: sha512-p4BG6HpGnhfgHk1rzZfyR6zcWkE7iLrWxyehHfXUy4Qa5j3e0roglFOdP/Nj5cJJ58MA3isQ5dlfkW2nNEpolw==} @@ -5229,66 +5247,79 @@ packages: resolution: {integrity: sha512-F8sWbhZ7tyuEfsmOxwc2giKDQzN3+kuBLPwwZGyVkLlKGdV1nvnNwYD0fKQ8+XS6hp9nY7B+ZeK01EBUE7aHaw==} cpu: [arm] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.57.1': resolution: {integrity: sha512-rGfNUfn0GIeXtBP1wL5MnzSj98+PZe/AXaGBCRmT0ts80lU5CATYGxXukeTX39XBKsxzFpEeK+Mrp9faXOlmrw==} cpu: [arm] os: [linux] + libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.57.1': resolution: {integrity: sha512-MMtej3YHWeg/0klK2Qodf3yrNzz6CGjo2UntLvk2RSPlhzgLvYEB3frRvbEF2wRKh1Z2fDIg9KRPe1fawv7C+g==} cpu: [arm64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.57.1': resolution: {integrity: sha512-1a/qhaaOXhqXGpMFMET9VqwZakkljWHLmZOX48R0I/YLbhdxr1m4gtG1Hq7++VhVUmf+L3sTAf9op4JlhQ5u1Q==} cpu: [arm64] os: [linux] + libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.57.1': resolution: {integrity: sha512-QWO6RQTZ/cqYtJMtxhkRkidoNGXc7ERPbZN7dVW5SdURuLeVU7lwKMpo18XdcmpWYd0qsP1bwKPf7DNSUinhvA==} cpu: [loong64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-loong64-musl@4.57.1': resolution: {integrity: sha512-xpObYIf+8gprgWaPP32xiN5RVTi/s5FCR+XMXSKmhfoJjrpRAjCuuqQXyxUa/eJTdAE6eJ+KDKaoEqjZQxh3Gw==} cpu: [loong64] os: [linux] + libc: [musl] '@rollup/rollup-linux-ppc64-gnu@4.57.1': resolution: {integrity: sha512-4BrCgrpZo4hvzMDKRqEaW1zeecScDCR+2nZ86ATLhAoJ5FQ+lbHVD3ttKe74/c7tNT9c6F2viwB3ufwp01Oh2w==} cpu: [ppc64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-ppc64-musl@4.57.1': resolution: {integrity: sha512-NOlUuzesGauESAyEYFSe3QTUguL+lvrN1HtwEEsU2rOwdUDeTMJdO5dUYl/2hKf9jWydJrO9OL/XSSf65R5+Xw==} cpu: [ppc64] os: [linux] + libc: [musl] '@rollup/rollup-linux-riscv64-gnu@4.57.1': resolution: {integrity: sha512-ptA88htVp0AwUUqhVghwDIKlvJMD/fmL/wrQj99PRHFRAG6Z5nbWoWG4o81Nt9FT+IuqUQi+L31ZKAFeJ5Is+A==} cpu: [riscv64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.57.1': resolution: {integrity: sha512-S51t7aMMTNdmAMPpBg7OOsTdn4tySRQvklmL3RpDRyknk87+Sp3xaumlatU+ppQ+5raY7sSTcC2beGgvhENfuw==} cpu: [riscv64] os: [linux] + libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.57.1': resolution: {integrity: sha512-Bl00OFnVFkL82FHbEqy3k5CUCKH6OEJL54KCyx2oqsmZnFTR8IoNqBF+mjQVcRCT5sB6yOvK8A37LNm/kPJiZg==} cpu: [s390x] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.57.1': resolution: {integrity: sha512-ABca4ceT4N+Tv/GtotnWAeXZUZuM/9AQyCyKYyKnpk4yoA7QIAuBt6Hkgpw8kActYlew2mvckXkvx0FfoInnLg==} cpu: [x64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-musl@4.57.1': resolution: {integrity: sha512-HFps0JeGtuOR2convgRRkHCekD7j+gdAuXM+/i6kGzQtFhlCtQkpwtNzkNj6QhCDp7DRJ7+qC/1Vg2jt5iSOFw==} cpu: [x64] os: [linux] + libc: [musl] '@rollup/rollup-openbsd-x64@4.57.1': resolution: {integrity: sha512-H+hXEv9gdVQuDTgnqD+SQffoWoc0Of59AStSzTEj/feWTBAnSfSD3+Dql1ZruJQxmykT/JVY0dE8Ka7z0DH1hw==} @@ -5969,41 +6000,49 @@ packages: resolution: {integrity: sha512-34gw7PjDGB9JgePJEmhEqBhWvCiiWCuXsL9hYphDF7crW7UgI05gyBAi6MF58uGcMOiOqSJ2ybEeCvHcq0BCmQ==} cpu: [arm64] os: [linux] + libc: [glibc] '@unrs/resolver-binding-linux-arm64-musl@1.11.1': resolution: {integrity: sha512-RyMIx6Uf53hhOtJDIamSbTskA99sPHS96wxVE/bJtePJJtpdKGXO1wY90oRdXuYOGOTuqjT8ACccMc4K6QmT3w==} cpu: [arm64] os: [linux] + libc: [musl] '@unrs/resolver-binding-linux-ppc64-gnu@1.11.1': resolution: {integrity: sha512-D8Vae74A4/a+mZH0FbOkFJL9DSK2R6TFPC9M+jCWYia/q2einCubX10pecpDiTmkJVUH+y8K3BZClycD8nCShA==} cpu: [ppc64] os: [linux] + libc: [glibc] '@unrs/resolver-binding-linux-riscv64-gnu@1.11.1': resolution: {integrity: sha512-frxL4OrzOWVVsOc96+V3aqTIQl1O2TjgExV4EKgRY09AJ9leZpEg8Ak9phadbuX0BA4k8U5qtvMSQQGGmaJqcQ==} cpu: [riscv64] os: [linux] + libc: [glibc] '@unrs/resolver-binding-linux-riscv64-musl@1.11.1': resolution: {integrity: sha512-mJ5vuDaIZ+l/acv01sHoXfpnyrNKOk/3aDoEdLO/Xtn9HuZlDD6jKxHlkN8ZhWyLJsRBxfv9GYM2utQ1SChKew==} cpu: [riscv64] os: [linux] + libc: [musl] '@unrs/resolver-binding-linux-s390x-gnu@1.11.1': resolution: {integrity: sha512-kELo8ebBVtb9sA7rMe1Cph4QHreByhaZ2QEADd9NzIQsYNQpt9UkM9iqr2lhGr5afh885d/cB5QeTXSbZHTYPg==} cpu: [s390x] os: [linux] + libc: [glibc] '@unrs/resolver-binding-linux-x64-gnu@1.11.1': resolution: {integrity: sha512-C3ZAHugKgovV5YvAMsxhq0gtXuwESUKc5MhEtjBpLoHPLYM+iuwSj3lflFwK3DPm68660rZ7G8BMcwSro7hD5w==} cpu: [x64] os: [linux] + libc: [glibc] '@unrs/resolver-binding-linux-x64-musl@1.11.1': resolution: {integrity: sha512-rV0YSoyhK2nZ4vEswT/QwqzqQXw5I6CjoaYMOX0TqBlWhojUf8P94mvI7nuJTeaCkkds3QE4+zS8Ko+GdXuZtA==} cpu: [x64] os: [linux] + libc: [musl] '@unrs/resolver-binding-wasm32-wasi@1.11.1': resolution: {integrity: sha512-5u4RkfxJm+Ng7IWgkzi3qrFOvLvQYnPBmjmZQ8+szTK/b31fQCnleNl1GgEt7nIsZRIf5PLhPwT0WM+q45x/UQ==} @@ -7180,6 +7219,12 @@ packages: exponential-backoff@3.1.3: resolution: {integrity: sha512-ZgEeZXj30q+I0EN+CbSSpIyPaJ5HVQD18Z1m+u1FXbAeT94mr1zw50q4q6jiiC447Nl/YTcIYSAftiGqetwXCA==} + express-rate-limit@8.5.2: + resolution: {integrity: sha512-5Kb34ipNX694DH48vN9irak1Qx30nb0PLYHXfJgw4YEjiC3ZEmZJhwOp+VfiCYwFzvFTdB9QkArYS5kXa2cx2A==} + engines: {node: '>= 16'} + peerDependencies: + express: '>= 4.11' + express@5.2.1: resolution: {integrity: sha512-hIS4idWWai69NezIdRt2xFVofaF4j+6INOpJlVOLDO8zXGpUVEVzIYk12UUi2JzjEzWL3IOAxcTubgz9Po0yXw==} engines: {node: '>= 18'} @@ -7786,6 +7831,10 @@ packages: resolution: {integrity: sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==} engines: {node: '>= 12'} + ip-address@10.2.0: + resolution: {integrity: sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==} + engines: {node: '>= 12'} + ipaddr.js@1.9.1: resolution: {integrity: sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==} engines: {node: '>= 0.10'} @@ -15106,6 +15155,11 @@ snapshots: exponential-backoff@3.1.3: {} + express-rate-limit@8.5.2(express@5.2.1): + dependencies: + express: 5.2.1 + ip-address: 10.2.0 + express@5.2.1: dependencies: accepts: 2.0.0 @@ -15907,6 +15961,8 @@ snapshots: ip-address@10.1.0: {} + ip-address@10.2.0: {} + ipaddr.js@1.9.1: {} ipv6-normalize@1.0.1: {}