From d16e32385e056ce4bb8479a5380e78eadc0f82d3 Mon Sep 17 00:00:00 2001 From: Dan Lynch Date: Tue, 19 May 2026 20:44:38 +0000 Subject: [PATCH 1/2] feat(graphile-llm): add inference usage logging to metering plugin Adds inline INSERT into usage_log_inference table after billing record_usage calls in both meteredEmbed and meteredChat functions. Changes: - config-cache.ts: Add InferenceLogConfig type and resolution from inference_log_module metaschema table (cached alongside billing config) - metering.ts: Add InferenceLogEntry type, logInferenceUsage helper, and calls after billing in meteredEmbed/meteredChat (including quota_exceeded events). Add databaseId, actorId, inferenceLog to MeteringContext. Add embeddingModel, chatModel, provider to MeteringOptions. - metering-plugin.ts: Wire databaseId, actorId, inferenceLog into MeteringContext. Pass model names to MeteringOptions. - index.ts: Export new types (InferenceLogEntry, InferenceLogConfig) and logInferenceUsage function. Gracefully skips if inference_log_module is not provisioned. TODO: dual-write to child (generated) database for platform aggregation. --- graphile/graphile-llm/src/config-cache.ts | 52 +++++- graphile/graphile-llm/src/index.ts | 6 +- graphile/graphile-llm/src/metering.ts | 165 +++++++++++++++++- .../src/plugins/metering-plugin.ts | 8 + 4 files changed, 225 insertions(+), 6 deletions(-) diff --git a/graphile/graphile-llm/src/config-cache.ts b/graphile/graphile-llm/src/config-cache.ts index 10ab45c42..a022d4d8c 100644 --- a/graphile/graphile-llm/src/config-cache.ts +++ b/graphile/graphile-llm/src/config-cache.ts @@ -43,12 +43,24 @@ export interface BillingConfig { publicSchema: string; } +/** + * Inference log table metadata resolved from the inference_log_module. 
+ */ +export interface InferenceLogConfig { + /** Schema containing the usage_log_inference table */ + schema: string; + /** Name of the inference log table */ + tableName: string; +} + /** * Per-database cached configuration for the LLM billing integration. */ export interface LlmBillingCacheEntry { /** Billing function references (null if billing_module not provisioned) */ billing: BillingConfig | null; + /** Inference log table references (null if inference_log_module not provisioned) */ + inferenceLog: InferenceLogConfig | null; } // ─── SQL Queries ──────────────────────────────────────────────────────────── @@ -71,6 +83,18 @@ const BILLING_MODULE_SQL = ` LIMIT 1 `; +/** + * Resolve the inference log module's schema and table name. + */ +const INFERENCE_LOG_MODULE_SQL = ` + SELECT + s.schema_name AS schema, + ilm.inference_log_table_name AS table_name + FROM metaschema_modules_public.inference_log_module ilm + JOIN metaschema_public.schema s ON ilm.schema_id = s.id + WHERE ilm.database_id = $1 + LIMIT 1 +`; // ─── Cache ────────────────────────────────────────────────────────────────── const billingCache = new ModuleConfigCache({ @@ -89,6 +113,27 @@ const SCHEMA_EXISTS_SQL = ` SELECT 1 FROM information_schema.schemata WHERE schema_name = $1 LIMIT 1 `; +async function resolveInferenceLogConfig( + pgClient: PgClient, + databaseId: string, +): Promise<InferenceLogConfig | null> { + try { + const schemaCheck = await pgClient.query(SCHEMA_EXISTS_SQL, ['metaschema_modules_public']); + if (schemaCheck.rows.length === 0) return null; + + const result = await pgClient.query(INFERENCE_LOG_MODULE_SQL, [databaseId]); + const row = result.rows[0]; + if (!row?.schema || !row?.table_name) return null; + + return { + schema: row.schema as string, + tableName: row.table_name as string, + }; + } catch { + return null; + } +} + async function resolveBillingConfig( pgClient: PgClient, databaseId: string, @@ -133,9 +178,12 @@ export async function getLlmBillingConfig( const cached = 
billingCache.get(databaseId); if (cached) return cached; - const billing = await resolveBillingConfig(pgClient, databaseId); + const [billing, inferenceLog] = await Promise.all([ + resolveBillingConfig(pgClient, databaseId), + resolveInferenceLogConfig(pgClient, databaseId), + ]); - const entry: LlmBillingCacheEntry = { billing }; + const entry: LlmBillingCacheEntry = { billing, inferenceLog }; billingCache.set(databaseId, entry); return entry; } diff --git a/graphile/graphile-llm/src/index.ts b/graphile/graphile-llm/src/index.ts index 8d6fe0bf6..110d05787 100644 --- a/graphile/graphile-llm/src/index.ts +++ b/graphile/graphile-llm/src/index.ts @@ -65,8 +65,8 @@ export { } from './chat'; // Metering utilities (for custom integration) -export { meteredEmbed, meteredChat, QuotaExceededError } from './metering'; -export type { MeteringContext, MeteringOptions, MeterResult, WithPgClient } from './metering'; +export { meteredEmbed, meteredChat, logInferenceUsage, QuotaExceededError } from './metering'; +export type { MeteringContext, MeteringOptions, MeterResult, WithPgClient, InferenceLogEntry } from './metering'; // Config cache (for custom integration) export { @@ -74,7 +74,7 @@ export { invalidateLlmBillingConfig, getLlmBillingCacheStats, } from './config-cache'; -export type { BillingConfig, LlmBillingCacheEntry, PgClient } from './config-cache'; +export type { BillingConfig, LlmBillingCacheEntry, InferenceLogConfig, PgClient } from './config-cache'; // Types export type { diff --git a/graphile/graphile-llm/src/metering.ts b/graphile/graphile-llm/src/metering.ts index 520c20f0f..e46a42470 100644 --- a/graphile/graphile-llm/src/metering.ts +++ b/graphile/graphile-llm/src/metering.ts @@ -18,7 +18,7 @@ * resolved from `billing_module` metaschema and cached by `config-cache.ts`. 
*/ -import type { PgClient, BillingConfig } from './config-cache'; +import type { PgClient, BillingConfig, InferenceLogConfig } from './config-cache'; import type { EmbedderFunction, ChatFunction, ChatMessage, ChatOptions } from './types'; // ─── Types ────────────────────────────────────────────────────────────────── @@ -43,6 +43,12 @@ export interface MeteringContext { entityId: string; /** Per-request correlation ID (from request.id pgSetting) */ requestId: string | null; + /** Database UUID from JWT claims */ + databaseId: string; + /** Actor (user) ID from JWT claims */ + actorId: string | null; + /** Inference log table config (null if inference_log_module not provisioned) */ + inferenceLog: InferenceLogConfig | null; } export interface MeteringOptions { @@ -52,6 +58,12 @@ export interface MeteringOptions { chatMeterSlug?: string; /** Whether to skip metering entirely (e.g. for local dev). Default: false */ skipMetering?: boolean; + /** Embedding model name (for inference log) */ + embeddingModel?: string; + /** Chat model name (for inference log) */ + chatModel?: string; + /** Provider name (for inference log) */ + provider?: string; } export interface MeterResult { @@ -113,6 +125,73 @@ async function recordUsage( } } +// ─── Inference Usage Log ──────────────────────────────────────────────────── + +export interface InferenceLogEntry { + databaseId: string; + entityId: string; + actorId: string | null; + model: string; + provider: string | null; + requestType: 'embedding' | 'chat' | 'rag'; + inputTokens: number; + outputTokens: number; + totalTokens: number; + latencyMs: number; + ragEnabled: boolean; + chunksRetrieved: number | null; + embeddingModel: string | null; + embeddingLatencyMs: number | null; + status: 'success' | 'quota_exceeded' | 'provider_error' | 'timeout'; + errorType: string | null; +} + +/** + * Write a row to the usage_log_inference table. + * Gracefully skips if the inference_log_module is not provisioned. 
+ * + * TODO: Also write to child (generated) database when dual-write is needed. + */ +export async function logInferenceUsage( + ctx: MeteringContext, + entry: InferenceLogEntry, +): Promise<void> { + if (!ctx.inferenceLog) return; + + const { schema, tableName } = ctx.inferenceLog; + const sql = `INSERT INTO "${schema}"."${tableName}" ( + database_id, entity_id, actor_id, + model, provider, request_type, + input_tokens, output_tokens, total_tokens, + latency_ms, rag_enabled, chunks_retrieved, + embedding_model, embedding_latency_ms, + status, error_type + ) VALUES ( + $1, $2, $3, + $4, $5, $6, + $7, $8, $9, + $10, $11, $12, + $13, $14, + $15, $16 + )`; + + try { + await ctx.withPgClient(ctx.pgSettings, async (pgClient) => { + await pgClient.query(sql, [ + entry.databaseId, entry.entityId, entry.actorId, + entry.model, entry.provider, entry.requestType, + entry.inputTokens, entry.outputTokens, entry.totalTokens, + entry.latencyMs, entry.ragEnabled, entry.chunksRetrieved, + entry.embeddingModel, entry.embeddingLatencyMs, + entry.status, entry.errorType, + ]); + }); + } catch (e: unknown) { + const message = e instanceof Error ? e.message : String(e); + console.warn(`[graphile-llm] inference log INSERT failed (non-fatal): ${message}`); + } +} + // ─── Metered Embedder ─────────────────────────────────────────────────────── /** @@ -172,6 +251,26 @@ export async function meteredEmbed( } if (!allowed) { + const estimatedTokens = Math.ceil(text.length / 4); + logInferenceUsage(ctx, { + databaseId: ctx.databaseId, + entityId: ctx.entityId, + actorId: ctx.actorId, + model: options.embeddingModel ?? meterSlug, + provider: options.provider ?? null, + requestType: 'embedding', + inputTokens: estimatedTokens, + outputTokens: 0, + totalTokens: estimatedTokens, + latencyMs: Date.now() - startTime, + ragEnabled: false, + chunksRetrieved: null, + embeddingModel: options.embeddingModel ?? 
null, + embeddingLatencyMs: null, + status: 'quota_exceeded', + errorType: null, + }).catch(() => {}); + return { result: null, metered: true, @@ -185,6 +284,7 @@ export async function meteredEmbed( const latencyMs = Date.now() - startTime; // Record actual usage (input_chars as the metered amount) + const actualTokens = Math.ceil(text.length / 4); ctx.withPgClient(ctx.pgSettings, async (pgClient) => { await recordUsage(pgClient, ctx.billing, ctx.entityId, meterSlug, text.length, { request_id: ctx.requestId, @@ -194,6 +294,26 @@ export async function meteredEmbed( }); }).catch(() => {}); + // Log to inference usage table + logInferenceUsage(ctx, { + databaseId: ctx.databaseId, + entityId: ctx.entityId, + actorId: ctx.actorId, + model: options.embeddingModel ?? meterSlug, + provider: options.provider ?? null, + requestType: 'embedding', + inputTokens: actualTokens, + outputTokens: 0, + totalTokens: actualTokens, + latencyMs, + ragEnabled: false, + chunksRetrieved: null, + embeddingModel: options.embeddingModel ?? null, + embeddingLatencyMs: latencyMs, + status: 'success', + errorType: null, + }).catch(() => {}); + return { result, metered: true, @@ -258,6 +378,26 @@ export async function meteredChat( } if (!allowed) { + const estimatedInputTokens = Math.ceil(messages.reduce((sum, m) => sum + m.content.length, 0) / 4); + logInferenceUsage(ctx, { + databaseId: ctx.databaseId, + entityId: ctx.entityId, + actorId: ctx.actorId, + model: meteringOptions.chatModel ?? meterSlug, + provider: meteringOptions.provider ?? 
null, + requestType: 'chat', + inputTokens: estimatedInputTokens, + outputTokens: 0, + totalTokens: estimatedInputTokens, + latencyMs: Date.now() - startTime, + ragEnabled: false, + chunksRetrieved: null, + embeddingModel: null, + embeddingLatencyMs: null, + status: 'quota_exceeded', + errorType: null, + }).catch(() => {}); + return { result: null, metered: true, @@ -272,6 +412,9 @@ export async function meteredChat( // Record actual usage (input + output chars as the metered amount) const inputChars = messages.reduce((sum, m) => sum + m.content.length, 0); + const estimatedInputTokens = Math.ceil(inputChars / 4); + const actualOutputTokens = Math.ceil(result.length / 4); + const actualTotal = estimatedInputTokens + actualOutputTokens; ctx.withPgClient(ctx.pgSettings, async (pgClient) => { await recordUsage(pgClient, ctx.billing, ctx.entityId, meterSlug, inputChars + result.length, { request_id: ctx.requestId, @@ -282,6 +425,26 @@ export async function meteredChat( }); }).catch(() => {}); + // Log to inference usage table + logInferenceUsage(ctx, { + databaseId: ctx.databaseId, + entityId: ctx.entityId, + actorId: ctx.actorId, + model: meteringOptions.chatModel ?? meterSlug, + provider: meteringOptions.provider ?? 
null, + requestType: 'chat', + inputTokens: estimatedInputTokens, + outputTokens: actualOutputTokens, + totalTokens: actualTotal, + latencyMs, + ragEnabled: false, + chunksRetrieved: null, + embeddingModel: null, + embeddingLatencyMs: null, + status: 'success', + errorType: null, + }).catch(() => {}); + return { result, metered: true, diff --git a/graphile/graphile-llm/src/plugins/metering-plugin.ts b/graphile/graphile-llm/src/plugins/metering-plugin.ts index 1c332fa71..61c76be80 100644 --- a/graphile/graphile-llm/src/plugins/metering-plugin.ts +++ b/graphile/graphile-llm/src/plugins/metering-plugin.ts @@ -67,16 +67,19 @@ async function buildMeteringContext( const entityId = resolveEntityId(pgSettings); const databaseId = pgSettings['jwt.claims.database_id'] ?? null; const requestId = pgSettings['request.id'] ?? null; + const actorId = pgSettings['jwt.claims.user_id'] ?? null; if (!entityId || !databaseId) return null; const withPgClient: WithPgClient | undefined = graphqlContext?.withPgClient; if (!withPgClient) return null; let billingConfig = null; + let inferenceLogConfig = null; try { await withPgClient(pgSettings, async (pgClient: PgClient) => { const entry = await getLlmBillingConfig(pgClient, databaseId); billingConfig = entry.billing; + inferenceLogConfig = entry.inferenceLog; }); } catch { return null; @@ -90,6 +93,9 @@ async function buildMeteringContext( billing: billingConfig, entityId, requestId, + databaseId, + actorId, + inferenceLog: inferenceLogConfig, }; } @@ -173,6 +179,8 @@ export function createLlmMeteringPlugin( embeddingMeterSlug: embeddingSlug, chatMeterSlug: chatSlug, skipMetering, + embeddingModel: embeddingModel ?? undefined, + chatModel: chatModel ?? undefined, }; // Replace the embedder with a metered version. 
From fb532f3709edd507bf6d7e496365d80d612c647e Mon Sep 17 00:00:00 2001 From: Dan Lynch Date: Thu, 21 May 2026 03:02:14 +0000 Subject: [PATCH 2/2] refactor: rename estimatedTokens to placeholderAmountTokens with comment --- graphile/graphile-llm/src/metering.ts | 36 +++++++++++++----------- graphql/server/src/middleware/llm-api.ts | 14 ++++----- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/graphile/graphile-llm/src/metering.ts b/graphile/graphile-llm/src/metering.ts index e46a42470..15e82800b 100644 --- a/graphile/graphile-llm/src/metering.ts +++ b/graphile/graphile-llm/src/metering.ts @@ -251,7 +251,8 @@ export async function meteredEmbed( } if (!allowed) { - const estimatedTokens = Math.ceil(text.length / 4); + // Placeholder: replace with actual provider token counts once generateWithUsage() is approved + const placeholderAmountTokens = Math.ceil(text.length / 4); logInferenceUsage(ctx, { databaseId: ctx.databaseId, entityId: ctx.entityId, @@ -259,9 +260,9 @@ export async function meteredEmbed( model: options.embeddingModel ?? meterSlug, provider: options.provider ?? 
null, requestType: 'embedding', - inputTokens: estimatedTokens, + inputTokens: placeholderAmountTokens, outputTokens: 0, - totalTokens: estimatedTokens, + totalTokens: placeholderAmountTokens, latencyMs: Date.now() - startTime, ragEnabled: false, chunksRetrieved: null, @@ -283,8 +284,8 @@ export async function meteredEmbed( const result = await embedder(text); const latencyMs = Date.now() - startTime; - // Record actual usage (input_chars as the metered amount) - const actualTokens = Math.ceil(text.length / 4); + // Placeholder: replace with actual provider token counts once generateWithUsage() is approved + const placeholderAmountTokens = Math.ceil(text.length / 4); ctx.withPgClient(ctx.pgSettings, async (pgClient) => { await recordUsage(pgClient, ctx.billing, ctx.entityId, meterSlug, text.length, { request_id: ctx.requestId, @@ -302,9 +303,9 @@ export async function meteredEmbed( model: options.embeddingModel ?? meterSlug, provider: options.provider ?? null, requestType: 'embedding', - inputTokens: actualTokens, + inputTokens: placeholderAmountTokens, outputTokens: 0, - totalTokens: actualTokens, + totalTokens: placeholderAmountTokens, latencyMs, ragEnabled: false, chunksRetrieved: null, @@ -378,7 +379,8 @@ export async function meteredChat( } if (!allowed) { - const estimatedInputTokens = Math.ceil(messages.reduce((sum, m) => sum + m.content.length, 0) / 4); + // Placeholder: replace with actual provider token counts once generateWithUsage() is approved + const placeholderInputTokens = Math.ceil(messages.reduce((sum, m) => sum + m.content.length, 0) / 4); logInferenceUsage(ctx, { databaseId: ctx.databaseId, entityId: ctx.entityId, @@ -386,9 +388,9 @@ export async function meteredChat( model: meteringOptions.chatModel ?? meterSlug, provider: meteringOptions.provider ?? 
null, requestType: 'chat', - inputTokens: estimatedInputTokens, + inputTokens: placeholderInputTokens, outputTokens: 0, - totalTokens: estimatedInputTokens, + totalTokens: placeholderInputTokens, latencyMs: Date.now() - startTime, ragEnabled: false, chunksRetrieved: null, @@ -410,11 +412,11 @@ export async function meteredChat( const result = await chat(messages, chatOptions); const latencyMs = Date.now() - startTime; - // Record actual usage (input + output chars as the metered amount) + // Placeholder: replace with actual provider token counts once generateWithUsage() is approved const inputChars = messages.reduce((sum, m) => sum + m.content.length, 0); - const estimatedInputTokens = Math.ceil(inputChars / 4); - const actualOutputTokens = Math.ceil(result.length / 4); - const actualTotal = estimatedInputTokens + actualOutputTokens; + const placeholderInputTokens = Math.ceil(inputChars / 4); + const placeholderOutputTokens = Math.ceil(result.length / 4); + const placeholderTotalTokens = placeholderInputTokens + placeholderOutputTokens; ctx.withPgClient(ctx.pgSettings, async (pgClient) => { await recordUsage(pgClient, ctx.billing, ctx.entityId, meterSlug, inputChars + result.length, { request_id: ctx.requestId, @@ -433,9 +435,9 @@ export async function meteredChat( model: meteringOptions.chatModel ?? meterSlug, provider: meteringOptions.provider ?? 
null, requestType: 'chat', - inputTokens: estimatedInputTokens, - outputTokens: actualOutputTokens, - totalTokens: actualTotal, + inputTokens: placeholderInputTokens, + outputTokens: placeholderOutputTokens, + totalTokens: placeholderTotalTokens, latencyMs, ragEnabled: false, chunksRetrieved: null, diff --git a/graphql/server/src/middleware/llm-api.ts b/graphql/server/src/middleware/llm-api.ts index 531664d84..355a8f2b7 100644 --- a/graphql/server/src/middleware/llm-api.ts +++ b/graphql/server/src/middleware/llm-api.ts @@ -68,10 +68,10 @@ interface SendMessageBody { } /** - * Estimate token count from text length (~4 chars per token for English). - * Used as fallback when the provider doesn't return actual counts. + * Placeholder: replace with actual provider token counts once generateWithUsage() is approved. + * Estimates ~4 chars per token for English text. */ -function estimateTokens(text: string): number { +function placeholderAmountTokens(text: string): number { return Math.ceil(text.length / 4); } @@ -94,8 +94,8 @@ async function callWithUsage( content = await client.generate(input); } - const inputTokens = estimateTokens(promptText); - const outputTokens = estimateTokens(content); + const inputTokens = placeholderAmountTokens(promptText); + const outputTokens = placeholderAmountTokens(content); return { content, usage: { input: inputTokens, output: outputTokens, totalTokens: inputTokens + outputTokens }, @@ -526,8 +526,8 @@ async function handleSendMessage( const content = streamedContent; const promptText = llmMessages.map(m => m.content).join(' '); const latencyMs = Date.now() - startTime; - const inputTokens = estimateTokens(promptText); - const outputTokens = estimateTokens(content); + const inputTokens = placeholderAmountTokens(promptText); + const outputTokens = placeholderAmountTokens(content); const totalTokens = inputTokens + outputTokens; // Send [DONE] marker