diff --git a/apps/api/src/controls/controls.service.spec.ts b/apps/api/src/controls/controls.service.spec.ts new file mode 100644 index 000000000..c30abc154 --- /dev/null +++ b/apps/api/src/controls/controls.service.spec.ts @@ -0,0 +1,270 @@ +import { NotFoundException } from '@nestjs/common'; +import { Test, TestingModule } from '@nestjs/testing'; +import { ControlsService } from './controls.service'; + +jest.mock('../auth/auth.server', () => ({ + auth: { api: { getSession: jest.fn() } }, +})); + +jest.mock('@trycompai/auth', () => ({ + statement: { control: ['create', 'read', 'update', 'delete'] }, + BUILT_IN_ROLE_PERMISSIONS: {}, +})); + +jest.mock('./sync-custom-framework-links', () => ({ + syncDirectLinksToCustomFrameworks: jest.fn().mockResolvedValue(undefined), +})); + +const mockDb = { + frameworkInstance: { findUnique: jest.fn() }, + control: { findUnique: jest.fn(), update: jest.fn() }, + policy: { findMany: jest.fn() }, + task: { findMany: jest.fn() }, + evidenceFormSetting: { findMany: jest.fn() }, + evidenceSubmission: { groupBy: jest.fn() }, + frameworkControlPolicyLink: { createMany: jest.fn() }, + frameworkControlTaskLink: { createMany: jest.fn() }, +}; + +jest.mock('@db', () => ({ + db: new Proxy( + {}, + { + get(_target, prop) { + return mockDb[prop] ?? {}; + }, + }, + ), + EvidenceFormType: {}, + Prisma: { SortOrder: { asc: 'asc', desc: 'desc' } }, +})); + +describe('ControlsService', () => { + let service: ControlsService; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ControlsService], + }).compile(); + + service = module.get(ControlsService); + jest.clearAllMocks(); + }); + + describe('findOne with frameworkInstanceId', () => { + const orgId = 'org_1'; + const controlId = 'ctrl_1'; + + const policyA = { + id: 'pol_a', + name: 'Policy A', + status: 'published', + archivedAt: null, + }; + const policyB = { + id: 'pol_b', + name: 'Policy B', + status: 'draft', + archivedAt: null, + }; + const taskA = { + id: 'task_a', + title: 'Task A', + status: 'done', + archivedAt: null, + }; + const taskB = { + id: 'task_b', + title: 'Task B', + status: 'todo', + archivedAt: null, + }; + + beforeEach(() => { + mockDb.evidenceFormSetting.findMany.mockResolvedValue([]); + mockDb.evidenceSubmission.groupBy.mockResolvedValue([]); + }); + + describe('custom framework', () => { + const frameworkInstanceId = 'fi_custom_1'; + + beforeEach(() => { + mockDb.frameworkInstance.findUnique.mockResolvedValue({ + id: frameworkInstanceId, + customFrameworkId: 'cf_1', + }); + }); + + it('should include directly-linked policies/tasks when no framework-scoped links exist', async () => { + mockDb.control.findUnique.mockResolvedValue({ + id: controlId, + organizationId: orgId, + policies: [policyA, policyB], + tasks: [taskA, taskB], + controlDocumentTypes: [], + frameworkPolicyLinks: [], + frameworkTaskLinks: [], + frameworkDocumentLinks: [], + requirementsMapped: [], + }); + + const result = await service.findOne( + controlId, + orgId, + frameworkInstanceId, + ); + + expect(result.policies).toEqual([policyA, policyB]); + expect(result.tasks).toEqual([taskA, taskB]); + expect(result.progress.total).toBe(4); + }); + + it('should deduplicate when policies exist in both direct and framework-scoped links', async () => { + mockDb.control.findUnique.mockResolvedValue({ + id: controlId, + organizationId: orgId, + policies: [policyA, policyB], + tasks: [taskA], + controlDocumentTypes: [], + frameworkPolicyLinks: [{ policy: policyA }], + frameworkTaskLinks: [{ task: taskA }, { task: taskB }], + frameworkDocumentLinks: [], + requirementsMapped: [], + }); + + const result = await service.findOne( + controlId, + orgId, + frameworkInstanceId, + ); + + expect(result.policies).toHaveLength(2); + expect(result.tasks).toHaveLength(2); + }); + + it('should include direct document types', async () => { + mockDb.control.findUnique.mockResolvedValue({ + id: controlId, + organizationId: orgId, + policies: [], + tasks: [], + controlDocumentTypes: [{ formType: 'SOC2_TYPE2' }], + frameworkPolicyLinks: [], + frameworkTaskLinks: [], + frameworkDocumentLinks: [], + requirementsMapped: [], + }); + + const result = await service.findOne( + controlId, + orgId, + frameworkInstanceId, + ); + + expect(result.controlDocumentTypes).toHaveLength(1); + expect(result.controlDocumentTypes[0].formType).toBe('SOC2_TYPE2'); + }); + }); + + describe('built-in framework', () => { + const frameworkInstanceId = 'fi_builtin_1'; + + beforeEach(() => { + mockDb.frameworkInstance.findUnique.mockResolvedValue({ + id: frameworkInstanceId, + customFrameworkId: null, + }); + }); + + it('should only show framework-scoped links, not direct links', async () => { + mockDb.control.findUnique.mockResolvedValue({ + id: controlId, + organizationId: orgId, + policies: [policyA, policyB], + tasks: [taskA, taskB], + controlDocumentTypes: [{ formType: 'SOC2_TYPE2' }], + frameworkPolicyLinks: [{ policy: policyA }], + frameworkTaskLinks: [{ task: taskA }], + frameworkDocumentLinks: [], + requirementsMapped: [], + }); + + const result = await service.findOne( + controlId, + orgId, + frameworkInstanceId, + ); + + expect(result.policies).toEqual([policyA]); + expect(result.tasks).toEqual([taskA]); + expect(result.controlDocumentTypes).toHaveLength(0); + }); + }); + + it('should throw NotFoundException when control does not exist', async () => { + mockDb.frameworkInstance.findUnique.mockResolvedValue({ + id: 'fi_1', + customFrameworkId: null, + }); + mockDb.control.findUnique.mockResolvedValue(null); + + await expect( + service.findOne(controlId, orgId, 'fi_1'), + ).rejects.toThrow(NotFoundException); + }); + }); + + describe('linkPolicies', () => { + const { syncDirectLinksToCustomFrameworks } = jest.requireMock( + './sync-custom-framework-links', + ); + + it('should sync to custom frameworks when linking without frameworkInstanceId', async () => { + mockDb.control.findUnique.mockResolvedValue({ id: 'ctrl_1' }); + mockDb.policy.findMany.mockResolvedValue([{ id: 'pol_1' }]); + mockDb.control.update.mockResolvedValue({}); + + await service.linkPolicies('ctrl_1', 'org_1', ['pol_1']); + + expect(syncDirectLinksToCustomFrameworks).toHaveBeenCalledWith({ + controlId: 'ctrl_1', + organizationId: 'org_1', + }); + }); + + it('should NOT sync when linking with frameworkInstanceId', async () => { + mockDb.control.findUnique.mockResolvedValue({ id: 'ctrl_1' }); + mockDb.policy.findMany.mockResolvedValue([{ id: 'pol_1' }]); + mockDb.frameworkInstance.findUnique.mockResolvedValue({ + id: 'fi_1', + customFrameworkId: null, + }); + mockDb.frameworkControlPolicyLink.createMany.mockResolvedValue({ + count: 1, + }); + + await service.linkPolicies('ctrl_1', 'org_1', ['pol_1'], 'fi_1'); + + expect(syncDirectLinksToCustomFrameworks).not.toHaveBeenCalled(); + }); + }); + + describe('linkTasks', () => { + const { syncDirectLinksToCustomFrameworks } = jest.requireMock( + './sync-custom-framework-links', + ); + + it('should sync to custom frameworks when linking without frameworkInstanceId', async () => { + mockDb.control.findUnique.mockResolvedValue({ id: 'ctrl_1' }); + mockDb.task.findMany.mockResolvedValue([{ id: 'task_1' }]); + mockDb.control.update.mockResolvedValue({}); + + await service.linkTasks('ctrl_1', 'org_1', ['task_1']); + + expect(syncDirectLinksToCustomFrameworks).toHaveBeenCalledWith({ + controlId: 'ctrl_1', + organizationId: 'org_1', + }); + }); + }); +}); diff --git a/apps/api/src/controls/controls.service.ts b/apps/api/src/controls/controls.service.ts index 9476b28ba..ace5ca48c 100644 --- a/apps/api/src/controls/controls.service.ts +++ b/apps/api/src/controls/controls.service.ts @@ -5,6 +5,8 @@ import { } from '@nestjs/common'; import { db, EvidenceFormType, Prisma } from '@db'; import { CreateControlDto } from './dto/create-control.dto'; +import { deduplicateById, deduplicateByFormType } from '../utils/deduplicate'; +import { syncDirectLinksToCustomFrameworks } from './sync-custom-framework-links'; // A CustomRequirement is valid for a given FrameworkInstance when its parent // matches: either it lives on the FI's CustomFramework, or it was attached @@ -205,10 +207,14 @@ export class ControlsService { organizationId: string, frameworkInstanceId: string, ) { - await this.ensureFrameworkInstance(frameworkInstanceId, organizationId); + const fi = await this.ensureFrameworkInstance(frameworkInstanceId, organizationId); + const isCustomFramework = fi.customFrameworkId !== null; const control = await db.control.findUnique({ where: { id: controlId, organizationId }, include: { + policies: { where: { archivedAt: null } }, + tasks: { where: { archivedAt: null } }, + controlDocumentTypes: true, frameworkPolicyLinks: { where: { frameworkInstanceId, @@ -243,9 +249,17 @@ export class ControlsService { throw new NotFoundException('Control not found'); } - const policies = control.frameworkPolicyLinks.map((link) => link.policy); - const tasks = control.frameworkTaskLinks.map((link) => link.task); - const controlDocumentTypes = control.frameworkDocumentLinks; + const frameworkPolicies = control.frameworkPolicyLinks.map((link) => link.policy); + const frameworkTasks = control.frameworkTaskLinks.map((link) => link.task); + const directPolicies = isCustomFramework ? (control.policies ?? []) : []; + const directTasks = isCustomFramework ? (control.tasks ?? []) : []; + const policies = deduplicateById([...frameworkPolicies, ...directPolicies]); + const tasks = deduplicateById([...frameworkTasks, ...directTasks]); + const directDocTypes = isCustomFramework ? control.controlDocumentTypes : []; + const controlDocumentTypes = deduplicateByFormType([ + ...control.frameworkDocumentLinks, + ...directDocTypes, + ]); const formTypes = controlDocumentTypes.map((d) => d.formType); const notRelevantSettings = formTypes.length > 0 @@ -287,6 +301,9 @@ export class ControlsService { frameworkPolicyLinks, frameworkTaskLinks, frameworkDocumentLinks, + policies: _policies, + tasks: _tasks, + controlDocumentTypes: _controlDocumentTypes, ...controlData } = control; @@ -460,6 +477,14 @@ export class ControlsService { }); } + if (scopedRequirementMappings.length > 0) { + await syncDirectLinksToCustomFrameworks({ + controlId: control.id, + organizationId, + client: tx, + }); + } + return control; }); } @@ -610,7 +635,7 @@ export class ControlsService { ) { const frameworkInstance = await db.frameworkInstance.findUnique({ where: { id: frameworkInstanceId, organizationId }, - select: { id: true }, + select: { id: true, customFrameworkId: true }, }); if (!frameworkInstance) { throw new NotFoundException('Framework instance not found'); @@ -649,6 +674,7 @@ export class ControlsService { where: { id: controlId }, data: { policies: { connect: policies.map((p) => ({ id: p.id })) } }, }); + await syncDirectLinksToCustomFrameworks({ controlId, organizationId }); } return { count: policies.length }; @@ -685,6 +711,7 @@ export class ControlsService { where: { id: controlId }, data: { tasks: { connect: tasks.map((t) => ({ id: t.id })) } }, }); + await syncDirectLinksToCustomFrameworks({ controlId, organizationId }); } return { count: tasks.length }; @@ -810,6 +837,7 @@ export class ControlsService { data: formTypes.map((formType) => ({ controlId, formType })), skipDuplicates: true, }); + await syncDirectLinksToCustomFrameworks({ controlId, organizationId }); return { count: result.count }; } @@ -827,10 +855,36 @@ export class ControlsService { }); return { success: true }; } - await db.controlDocumentType.deleteMany({ - where: { controlId, formType }, + return db.$transaction(async (tx) => { + const deleted = await tx.controlDocumentType.deleteMany({ + where: { controlId, formType }, + }); + if (deleted.count === 0) return { success: true }; + const customFiIds = await tx.requirementMap.findMany({ + where: { + controlId, + archivedAt: null, + frameworkInstance: { + organizationId, + customFrameworkId: { not: null }, + }, + }, + select: { frameworkInstanceId: true }, + distinct: ['frameworkInstanceId'], + }); + if (customFiIds.length > 0) { + await tx.frameworkControlDocumentTypeLink.deleteMany({ + where: { + controlId, + formType, + frameworkInstanceId: { + in: customFiIds.map((r) => r.frameworkInstanceId), + }, + }, + }); + } + return { success: true }; }); - return { success: true }; } async delete(controlId: string, organizationId: string) { diff --git a/apps/api/src/controls/sync-custom-framework-links.spec.ts b/apps/api/src/controls/sync-custom-framework-links.spec.ts new file mode 100644 index 000000000..20b732906 --- /dev/null +++ b/apps/api/src/controls/sync-custom-framework-links.spec.ts @@ -0,0 +1,152 @@ +import { syncDirectLinksToCustomFrameworks } from './sync-custom-framework-links'; + +const mockDb = { + frameworkInstance: { count: jest.fn() }, + requirementMap: { findMany: jest.fn() }, + control: { findUnique: jest.fn() }, + frameworkControlPolicyLink: { createMany: jest.fn() }, + frameworkControlTaskLink: { createMany: jest.fn() }, + frameworkControlDocumentTypeLink: { createMany: jest.fn() }, +}; + +jest.mock('@db', () => ({ + db: new Proxy( + {}, + { + get(_target, prop) { + return mockDb[prop] ?? {}; + }, + }, + ), + Prisma: {}, +})); + +describe('syncDirectLinksToCustomFrameworks', () => { + beforeEach(() => jest.clearAllMocks()); + + it('should skip entirely when org has no custom frameworks', async () => { + mockDb.frameworkInstance.count.mockResolvedValue(0); + + await syncDirectLinksToCustomFrameworks({ + controlId: 'ctrl_1', + organizationId: 'org_1', + }); + + expect(mockDb.requirementMap.findMany).not.toHaveBeenCalled(); + expect(mockDb.control.findUnique).not.toHaveBeenCalled(); + }); + + it('should do nothing when control is not mapped to any custom framework', async () => { + mockDb.frameworkInstance.count.mockResolvedValue(1); + mockDb.requirementMap.findMany.mockResolvedValue([]); + + await syncDirectLinksToCustomFrameworks({ + controlId: 'ctrl_1', + organizationId: 'org_1', + }); + + expect(mockDb.control.findUnique).not.toHaveBeenCalled(); + }); + + it('should create framework-scoped links for all custom FIs', async () => { + mockDb.frameworkInstance.count.mockResolvedValue(2); + mockDb.requirementMap.findMany.mockResolvedValue([ + { frameworkInstanceId: 'fi_1' }, + { frameworkInstanceId: 'fi_2' }, + ]); + mockDb.control.findUnique.mockResolvedValue({ + id: 'ctrl_1', + policies: [{ id: 'pol_a' }, { id: 'pol_b' }], + tasks: [{ id: 'task_a' }], + controlDocumentTypes: [{ formType: 'SOC2_TYPE2' }], + }); + mockDb.frameworkControlPolicyLink.createMany.mockResolvedValue({ + count: 4, + }); + mockDb.frameworkControlTaskLink.createMany.mockResolvedValue({ count: 2 }); + mockDb.frameworkControlDocumentTypeLink.createMany.mockResolvedValue({ + count: 2, + }); + + await syncDirectLinksToCustomFrameworks({ + controlId: 'ctrl_1', + organizationId: 'org_1', + }); + + expect(mockDb.frameworkControlPolicyLink.createMany).toHaveBeenCalledWith({ + data: [ + { + frameworkInstanceId: 'fi_1', + controlId: 'ctrl_1', + policyId: 'pol_a', + }, + { + frameworkInstanceId: 'fi_1', + controlId: 'ctrl_1', + policyId: 'pol_b', + }, + { + frameworkInstanceId: 'fi_2', + controlId: 'ctrl_1', + policyId: 'pol_a', + }, + { + frameworkInstanceId: 'fi_2', + controlId: 'ctrl_1', + policyId: 'pol_b', + }, + ], + skipDuplicates: true, + }); + + expect(mockDb.frameworkControlTaskLink.createMany).toHaveBeenCalledWith({ + data: [ + { frameworkInstanceId: 'fi_1', controlId: 'ctrl_1', taskId: 'task_a' }, + { frameworkInstanceId: 'fi_2', controlId: 'ctrl_1', taskId: 'task_a' }, + ], + skipDuplicates: true, + }); + + expect( + mockDb.frameworkControlDocumentTypeLink.createMany, + ).toHaveBeenCalledWith({ + data: [ + { + frameworkInstanceId: 'fi_1', + controlId: 'ctrl_1', + formType: 'SOC2_TYPE2', + }, + { + frameworkInstanceId: 'fi_2', + controlId: 'ctrl_1', + formType: 'SOC2_TYPE2', + }, + ], + skipDuplicates: true, + }); + }); + + it('should skip empty direct relationships', async () => { + mockDb.frameworkInstance.count.mockResolvedValue(1); + mockDb.requirementMap.findMany.mockResolvedValue([ + { frameworkInstanceId: 'fi_1' }, + ]); + mockDb.control.findUnique.mockResolvedValue({ + id: 'ctrl_1', + policies: [], + tasks: [], + controlDocumentTypes: [], + }); + + await syncDirectLinksToCustomFrameworks({ + controlId: 'ctrl_1', + organizationId: 'org_1', + }); + + expect(mockDb.frameworkControlPolicyLink.createMany).not.toHaveBeenCalled(); + expect(mockDb.frameworkControlTaskLink.createMany).not.toHaveBeenCalled(); + expect( + mockDb.frameworkControlDocumentTypeLink.createMany, + ).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/controls/sync-custom-framework-links.ts b/apps/api/src/controls/sync-custom-framework-links.ts new file mode 100644 index 000000000..bd6c679d4 --- /dev/null +++ b/apps/api/src/controls/sync-custom-framework-links.ts @@ -0,0 +1,92 @@ +import { db, Prisma } from '@db'; + +type DbClient = Prisma.TransactionClient | typeof db; + +export async function syncDirectLinksToCustomFrameworks({ + controlId, + organizationId, + client, +}: { + controlId: string; + organizationId: string; + client?: DbClient; +}) { + const prisma = client ?? db; + + const hasCustomFrameworks = await prisma.frameworkInstance.count({ + where: { organizationId, customFrameworkId: { not: null } }, + }); + if (hasCustomFrameworks === 0) return; + + const customFiIds = await prisma.requirementMap.findMany({ + where: { + controlId, + archivedAt: null, + frameworkInstance: { + organizationId, + customFrameworkId: { not: null }, + }, + }, + select: { frameworkInstanceId: true }, + distinct: ['frameworkInstanceId'], + }); + + if (customFiIds.length === 0) return; + + const control = await prisma.control.findUnique({ + where: { id: controlId, organizationId }, + include: { + policies: { + where: { archivedAt: null }, + select: { id: true }, + }, + tasks: { + where: { archivedAt: null }, + select: { id: true }, + }, + controlDocumentTypes: { + select: { formType: true }, + }, + }, + }); + + if (!control) return; + + const fiIds = customFiIds.map((r) => r.frameworkInstanceId); + + await Promise.all([ + control.policies.length > 0 && + prisma.frameworkControlPolicyLink.createMany({ + data: fiIds.flatMap((frameworkInstanceId) => + control.policies.map((p) => ({ + frameworkInstanceId, + controlId, + policyId: p.id, + })), + ), + skipDuplicates: true, + }), + control.tasks.length > 0 && + prisma.frameworkControlTaskLink.createMany({ + data: fiIds.flatMap((frameworkInstanceId) => + control.tasks.map((t) => ({ + frameworkInstanceId, + controlId, + taskId: t.id, + })), + ), + skipDuplicates: true, + }), + control.controlDocumentTypes.length > 0 && + prisma.frameworkControlDocumentTypeLink.createMany({ + data: fiIds.flatMap((frameworkInstanceId) => + control.controlDocumentTypes.map((d) => ({ + frameworkInstanceId, + controlId, + formType: d.formType, + })), + ), + skipDuplicates: true, + }), + ]); +} diff --git a/apps/api/src/frameworks/frameworks.service.ts b/apps/api/src/frameworks/frameworks.service.ts index 4806a850e..66217295e 100644 --- a/apps/api/src/frameworks/frameworks.service.ts +++ b/apps/api/src/frameworks/frameworks.service.ts @@ -5,6 +5,8 @@ import { NotFoundException, } from '@nestjs/common'; import { db, type EvidenceFormType } from '@db'; +import { deduplicateById, deduplicateByFormType } from '../utils/deduplicate'; +import { syncDirectLinksToCustomFrameworks } from '../controls/sync-custom-framework-links'; import { tasks } from '@trigger.dev/sdk'; import { @@ -32,6 +34,50 @@ type RequirementDef = { kind: 'platform' | 'custom'; }; +function mergeControlLinks( + control: { + id: string; + frameworkPolicyLinks: { policy: { id: string; name: string; status: string } }[]; + frameworkDocumentLinks: { formType: EvidenceFormType }[]; + policies: { id: string; name: string; status: string }[]; + controlDocumentTypes: { formType: EvidenceFormType }[]; + [key: string]: unknown; + }, + opts: { + isCustomFramework: boolean; + frameworkInstanceId: string; + notRelevantFormTypes: Set; + }, +) { + const { + frameworkPolicyLinks, + frameworkDocumentLinks, + policies: directPolicies, + controlDocumentTypes: directDocTypes, + ...rest + } = control; + const frameworkPolicies = frameworkPolicyLinks.map((link) => link.policy); + const extraPolicies = opts.isCustomFramework ? directPolicies : []; + const extraDocTypes = opts.isCustomFramework + ? directDocTypes.map((d) => ({ + ...d, + frameworkInstanceId: opts.frameworkInstanceId, + controlId: control.id, + })) + : []; + return { + ...rest, + policies: deduplicateById([...frameworkPolicies, ...extraPolicies]), + controlDocumentTypes: deduplicateByFormType([ + ...(frameworkDocumentLinks || []), + ...extraDocTypes, + ]).map((documentType) => ({ + ...documentType, + isNotRelevant: opts.notRelevantFormTypes.has(documentType.formType), + })), + }; +} + @Injectable() export class FrameworksService { private readonly logger = new Logger(FrameworksService.name); @@ -156,6 +202,11 @@ export class FrameworksService { include: { control: { include: { + policies: { + where: { archivedAt: null }, + select: { id: true, name: true, status: true }, + }, + controlDocumentTypes: true, frameworkPolicyLinks: { where: { policy: { archivedAt: null } }, include: { @@ -185,41 +236,38 @@ export class FrameworksService { await this.getNotRelevantFormTypes(organizationId); const frameworksWithControls = frameworkInstances.map((fi: any) => { + const isCustomFramework = fi.customFrameworkId !== null; const controlsMap = new Map(); for (const rm of fi.requirementsMapped || []) { if (rm.control && !controlsMap.has(rm.control.id)) { const { - requirementsMapped: _, - frameworkPolicyLinks, - frameworkDocumentLinks, + requirementsMapped: _reqs, frameworkControlFamilies, - ...controlData + ...controlForMerge } = rm.control; - const policyLinks = rm.control.frameworkPolicyLinks.filter( - (link: { frameworkInstanceId: string }) => - link.frameworkInstanceId === fi.id, - ); - const documentLinks = rm.control.frameworkDocumentLinks.filter( - (link: { frameworkInstanceId: string }) => - link.frameworkInstanceId === fi.id, - ); + const scopedControl = { + ...controlForMerge, + frameworkPolicyLinks: controlForMerge.frameworkPolicyLinks.filter( + (link: { frameworkInstanceId: string }) => + link.frameworkInstanceId === fi.id, + ), + frameworkDocumentLinks: controlForMerge.frameworkDocumentLinks.filter( + (link: { frameworkInstanceId: string }) => + link.frameworkInstanceId === fi.id, + ), + }; + const merged = mergeControlLinks(scopedControl, { + isCustomFramework, + frameworkInstanceId: fi.id, + notRelevantFormTypes, + }); const familyEntry = (frameworkControlFamilies ?? []).find( (f: { frameworkInstanceId: string }) => f.frameworkInstanceId === fi.id, ); controlsMap.set(rm.control.id, { - ...controlData, + ...merged, controlFamily: familyEntry?.controlFamily ?? null, - policies: policyLinks.map( - (link: { policy: { id: string; name: string; status: string } }) => - link.policy, - ), - controlDocumentTypes: documentLinks.map( - (documentType: { formType: EvidenceFormType }) => ({ - ...documentType, - isNotRelevant: notRelevantFormTypes.has(documentType.formType), - }), - ), requirementsMapped: rm.control.requirementsMapped || [], }); } @@ -232,41 +280,90 @@ export class FrameworksService { return frameworksWithControls; } - const [tasks, evidenceSubmissions] = await Promise.all([ - db.task.findMany({ - where: { - organizationId, - archivedAt: null, - frameworkControlLinks: { - some: { frameworkInstance: { organizationId } }, + const hasCustomFrameworks = frameworkInstances.some( + (fi: any) => fi.customFrameworkId !== null, + ); + const allControlIds = hasCustomFrameworks + ? [ + ...new Set( + frameworksWithControls.flatMap((fw: any) => + fw.controls.map((c: any) => c.id), + ), + ), + ] + : []; + + const [frameworkTasks, directTasks, evidenceSubmissions] = await Promise.all( + [ + db.task.findMany({ + where: { + organizationId, + archivedAt: null, + frameworkControlLinks: { + some: { frameworkInstance: { organizationId } }, + }, }, - }, - include: { - frameworkControlLinks: { - where: { frameworkInstance: { organizationId } }, - include: { control: true }, + include: { + frameworkControlLinks: { + where: { frameworkInstance: { organizationId } }, + include: { control: true }, + }, }, - }, - }), - db.evidenceSubmission.findMany({ - where: { organizationId }, - select: { formType: true, submittedAt: true }, - }), - ]); + }), + hasCustomFrameworks && allControlIds.length > 0 + ? db.task.findMany({ + where: { + organizationId, + archivedAt: null, + controls: { + some: { id: { in: allControlIds as string[] } }, + }, + }, + include: { + controls: { + where: { id: { in: allControlIds as string[] } }, + }, + }, + }) + : Promise.resolve([]), + db.evidenceSubmission.findMany({ + where: { organizationId }, + select: { formType: true, submittedAt: true }, + }), + ], + ); - return frameworksWithControls.map((fw: any) => ({ - ...fw, - complianceScore: computeFrameworkComplianceScore( - fw, - tasks.map(({ frameworkControlLinks, ...task }) => ({ + return frameworksWithControls.map((fw: any) => { + const isCustomFw = fw.customFrameworkId !== null; + const fwControlIds = new Set(fw.controls.map((c: any) => c.id)); + const mappedFrameworkTasks = frameworkTasks.map( + ({ frameworkControlLinks, ...task }) => ({ ...task, controls: frameworkControlLinks .filter((link) => link.frameworkInstanceId === fw.id) .map((link) => link.control), - })), - evidenceSubmissions, - ), - })); + }), + ); + const mappedDirectTasks = isCustomFw + ? directTasks.map(({ controls, ...task }: (typeof directTasks)[number]) => ({ + ...task, + controls: (controls as any[]).filter((c) => fwControlIds.has(c.id)), + })) + : []; + const allTasks = deduplicateById([ + ...mappedDirectTasks, + ...mappedFrameworkTasks, + ].filter((t) => t.controls.length > 0)); + + return { + ...fw, + complianceScore: computeFrameworkComplianceScore( + fw, + allTasks, + evidenceSubmissions, + ), + }; + }); } async findOne(frameworkInstanceId: string, organizationId: string) { @@ -280,6 +377,11 @@ export class FrameworksService { include: { control: { include: { + policies: { + where: { archivedAt: null }, + select: { id: true, name: true, status: true }, + }, + controlDocumentTypes: true, frameworkPolicyLinks: { where: { frameworkInstanceId, @@ -311,31 +413,24 @@ export class FrameworksService { throw new NotFoundException('Framework instance not found'); } + const isCustomFramework = fi.customFrameworkId !== null; const notRelevantFormTypes = await this.getNotRelevantFormTypes(organizationId); + const mergeOpts = { isCustomFramework, frameworkInstanceId, notRelevantFormTypes }; const controlsMap = new Map(); for (const rm of fi.requirementsMapped) { if (rm.control && !controlsMap.has(rm.control.id)) { const { - requirementsMapped: _, - frameworkPolicyLinks, - frameworkDocumentLinks, + requirementsMapped: _reqs, frameworkControlFamilies, - ...controlData + ...controlForMerge } = rm.control; + const merged = mergeControlLinks(controlForMerge, mergeOpts); controlsMap.set(rm.control.id, { - ...controlData, + ...merged, controlFamily: frameworkControlFamilies?.[0]?.controlFamily ?? null, - policies: - rm.control.frameworkPolicyLinks?.map((link) => link.policy) || [], requirementsMapped: rm.control.requirementsMapped || [], - controlDocumentTypes: (rm.control.frameworkDocumentLinks || []).map( - (documentType) => ({ - ...documentType, - isNotRelevant: notRelevantFormTypes.has(documentType.formType), - }), - ), }); } } @@ -349,9 +444,11 @@ export class FrameworksService { } } + const controlIds = Array.from(controlsMap.keys()); const [ requirementDefinitions, - tasks, + frameworkTasks, + directTasks, requirementMaps, evidenceSubmissions, ] = await Promise.all([ @@ -369,6 +466,20 @@ export class FrameworksService { }, }, }), + isCustomFramework && controlIds.length > 0 + ? db.task.findMany({ + where: { + organizationId, + archivedAt: null, + controls: { some: { id: { in: controlIds } } }, + }, + include: { + controls: { + where: { id: { in: controlIds } }, + }, + }, + }) + : Promise.resolve([]), db.requirementMap.findMany({ where: { frameworkInstanceId, archivedAt: null }, include: { control: true }, @@ -385,14 +496,28 @@ export class FrameworksService { : Promise.resolve([]), ]); + const mappedFrameworkTasks = frameworkTasks.map( + ({ frameworkControlLinks, ...task }) => ({ + ...task, + controls: frameworkControlLinks.map((link) => link.control), + }), + ); + const mappedDirectTasks = directTasks.map( + ({ controls, ...task }: (typeof directTasks)[number]) => ({ + ...task, + controls, + }), + ); + const allTasks = deduplicateById([ + ...mappedFrameworkTasks, + ...mappedDirectTasks, + ]); + return { ...rest, controls: Array.from(controlsMap.values()), requirementDefinitions, - tasks: tasks.map(({ frameworkControlLinks, ...task }) => ({ - ...task, - controls: frameworkControlLinks.map((link) => link.control), - })), + tasks: allTasks, requirementMaps, evidenceSubmissions, }; @@ -601,6 +726,17 @@ export class FrameworksService { skipDuplicates: true, }); + if (fi.customFrameworkId) { + await Promise.all( + controls.map((c) => + syncDirectLinksToCustomFrameworks({ + controlId: c.id, + organizationId, + }), + ), + ); + } + return { count: result.count }; } @@ -781,67 +917,103 @@ export class FrameworksService { throw new NotFoundException('Framework instance not found'); } + const isCustomFramework = fi.customFrameworkId !== null; const allReqDefs = await this.loadRequirementDefinitions(fi); const requirement = allReqDefs.find((r) => r.id === requirementKey); if (!requirement) { throw new NotFoundException('Requirement not found'); } - const [relatedControls, tasks, notRelevantFormTypes] = await Promise.all([ - db.requirementMap.findMany({ - where: { - frameworkInstanceId, - archivedAt: null, - ...(requirement.kind === 'custom' - ? { customRequirementId: requirementKey } - : { requirementId: requirementKey }), - }, - include: { - control: { - include: { - frameworkPolicyLinks: { - where: { - frameworkInstanceId, - policy: { archivedAt: null }, + const [relatedControls, frameworkTasks, notRelevantFormTypes] = + await Promise.all([ + db.requirementMap.findMany({ + where: { + frameworkInstanceId, + archivedAt: null, + ...(requirement.kind === 'custom' + ? { customRequirementId: requirementKey } + : { requirementId: requirementKey }), + }, + include: { + control: { + include: { + policies: { + where: { archivedAt: null }, + select: { id: true, name: true, status: true }, }, - include: { - policy: { - select: { id: true, name: true, status: true }, + controlDocumentTypes: true, + frameworkPolicyLinks: { + where: { + frameworkInstanceId, + policy: { archivedAt: null }, + }, + include: { + policy: { + select: { id: true, name: true, status: true }, + }, }, }, - }, - frameworkDocumentLinks: { - where: { frameworkInstanceId }, - }, - frameworkControlFamilies: { - where: { frameworkInstanceId }, - select: { controlFamily: true }, - take: 1, + frameworkDocumentLinks: { + where: { frameworkInstanceId }, + }, + frameworkControlFamilies: { + where: { frameworkInstanceId }, + select: { controlFamily: true }, + take: 1, + }, }, }, }, - }, - }), - db.task.findMany({ - where: { - organizationId, - archivedAt: null, - frameworkControlLinks: { some: { frameworkInstanceId } }, - }, - include: { - frameworkControlLinks: { - where: { frameworkInstanceId }, - include: { control: true }, + }), + db.task.findMany({ + where: { + organizationId, + archivedAt: null, + frameworkControlLinks: { some: { frameworkInstanceId } }, }, + include: { + frameworkControlLinks: { + where: { frameworkInstanceId }, + include: { control: true }, + }, + }, + }), + this.getNotRelevantFormTypes(organizationId), + ]); + + const controlIds = relatedControls.map((rc) => rc.control.id); + const directTasks = + isCustomFramework && controlIds.length > 0 + ? await db.task.findMany({ + where: { + organizationId, + archivedAt: null, + controls: { some: { id: { in: controlIds } } }, + }, + include: { + controls: { where: { id: { in: controlIds } } }, + }, + }) + : []; + + const mergeOpts = { isCustomFramework, frameworkInstanceId, notRelevantFormTypes }; + const mappedRelatedControls = relatedControls.map((relatedControl) => { + const { frameworkControlFamilies, ...controlForMerge } = + relatedControl.control; + return { + ...relatedControl, + control: { + ...mergeControlLinks(controlForMerge, mergeOpts), + controlFamily: + frameworkControlFamilies?.[0]?.controlFamily ?? null, }, - }), - this.getNotRelevantFormTypes(organizationId), - ]); + }; + }); const formTypes = new Set(); - for (const rc of relatedControls) { - for (const dt of rc.control.frameworkDocumentLinks || []) { - if (notRelevantFormTypes.has(dt.formType)) continue; + for (const rc of mappedRelatedControls) { + for (const dt of rc.control.controlDocumentTypes || []) { + if (dt.isNotRelevant) continue; formTypes.add(dt.formType); } } @@ -862,34 +1034,21 @@ export class FrameworksService { .filter((r) => r.id !== requirementKey) .map((r) => ({ id: r.id, name: r.name })); - return { - requirement, - relatedControls: relatedControls.map((relatedControl) => ({ - ...relatedControl, - control: (() => { - const { - frameworkPolicyLinks, - frameworkDocumentLinks, - frameworkControlFamilies, - ...control - } = relatedControl.control; - return { - ...control, - controlFamily: frameworkControlFamilies?.[0]?.controlFamily ?? null, - policies: frameworkPolicyLinks.map((link) => link.policy), - controlDocumentTypes: frameworkDocumentLinks.map( - (documentType) => ({ - ...documentType, - isNotRelevant: notRelevantFormTypes.has(documentType.formType), - }), - ), - }; - })(), - })), - tasks: tasks.map(({ frameworkControlLinks, ...task }) => ({ + const mappedFrameworkTasks = frameworkTasks.map( + ({ frameworkControlLinks, ...task }) => ({ ...task, controls: frameworkControlLinks.map((link) => link.control), - })), + }), + ); + const mappedDirectTasks = directTasks.map(({ controls, ...task }) => ({ + ...task, + controls, + })); + + return { + requirement, + relatedControls: mappedRelatedControls, + tasks: deduplicateById([...mappedFrameworkTasks, ...mappedDirectTasks]), evidenceSubmissions, siblingRequirements, }; diff --git a/apps/api/src/policies/policies.controller.ts b/apps/api/src/policies/policies.controller.ts index 0eab46dc3..553dba274 100644 --- a/apps/api/src/policies/policies.controller.ts +++ b/apps/api/src/policies/policies.controller.ts @@ -780,13 +780,29 @@ export class PoliciesController { @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, ) { - await db.policy.update({ - where: { id, organizationId }, - data: { - controls: { - disconnect: { id: controlId }, + await db.$transaction(async (tx) => { + const before = await tx.policy.findUnique({ + where: { id, organizationId }, + select: { + controls: { where: { id: controlId }, select: { id: true } }, }, - }, + }); + await tx.policy.update({ + where: { id, organizationId }, + data: { controls: { disconnect: { id: controlId } } }, + }); + if (before?.controls.length) { + await tx.frameworkControlPolicyLink.deleteMany({ + where: { + controlId, + policyId: id, + frameworkInstance: { + organizationId, + customFrameworkId: { not: null }, + }, + }, + }); + } }); return { diff --git a/apps/api/src/utils/deduplicate.ts b/apps/api/src/utils/deduplicate.ts new file mode 100644 index 000000000..f98f53336 --- /dev/null +++ b/apps/api/src/utils/deduplicate.ts @@ -0,0 +1,22 @@ +export function deduplicateBy( + items: T[], + key: (item: T) => string, +): T[] { + const seen = new Set(); + return items.filter((item) => { + const k = key(item); + if (seen.has(k)) return false; + seen.add(k); + return true; + }); +} + +export function deduplicateById(items: T[]): T[] { + return deduplicateBy(items, (item) => item.id); +} + +export function deduplicateByFormType( + items: T[], +): T[] { + return deduplicateBy(items, (item) => item.formType); +}