From bebc8039bf462ea567cc27a4074a9c2857ae58e1 Mon Sep 17 00:00:00 2001
From: Patel230
Date: Thu, 7 May 2026 19:40:53 +0530
Subject: [PATCH 1/2] feat: auto-continuation, orphaned tool_use sanitization,
 weighted routing

- StreamChatWithContinuation: auto-continues on max_tokens (3x, 32K cap)
- StreamChatContinue on EyrieClient for easy integration
- SanitizeMessages: detects orphaned tool_use, injects synthetic error results
- Integrated sanitization into Anthropic and OpenAI Chat/StreamChat
- WeightedProvider: weighted random selection with failover on retriable errors
- StopReason now propagated in both Anthropic and OpenAI stream processors
---
 client/anthropic.go     |   2 +
 client/client.go        |  20 +++
 client/continuation.go  |  93 ++++++++++++++
 client/openai.go        |   2 +
 client/sanitize.go      |  44 +++++++
 client/stream.go        |   8 +-
 client/weighted.go      | 228 +++++++++++++++++++++++++++++++++
 client/weighted_test.go | 276 ++++++++++++++++++++++++++++++++++++++++
 8 files changed, 671 insertions(+), 2 deletions(-)
 create mode 100644 client/sanitize.go
 create mode 100644 client/weighted.go
 create mode 100644 client/weighted_test.go

diff --git a/client/anthropic.go b/client/anthropic.go
index fce7d72..7a616c3 100644
--- a/client/anthropic.go
+++ b/client/anthropic.go
@@ -182,6 +182,7 @@ func parseImageString(img string) (mediaType, data string, isBase64 bool) {
 // This is not implemented here; opts.ResponseFormat is ignored for Anthropic.
 // Future work: implement tool-use-based structured output for Anthropic.
 func (c *AnthropicClient) Chat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*EyrieResponse, error) {
+	messages = SanitizeMessages(messages)
 	if opts.Model == "" {
 		return nil, fmt.Errorf("eyrie: model is required for anthropic")
 	}
@@ -280,6 +281,7 @@
 
 // StreamChat sends a streaming message to Anthropic.
 func (c *AnthropicClient) StreamChat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*StreamResult, error) {
+	messages = SanitizeMessages(messages)
 	if opts.Model == "" {
 		return nil, fmt.Errorf("eyrie: model is required for anthropic")
 	}
diff --git a/client/client.go b/client/client.go
index 35dae52..0ee1bd3 100644
--- a/client/client.go
+++ b/client/client.go
@@ -313,6 +313,26 @@ func (c *EyrieClient) StreamChat(ctx context.Context, messages []EyrieMessage, o
 	return p.StreamChat(ctx, messages, opts)
 }
 
+// StreamChatContinue is like StreamChat but automatically continues if the response
+// hits max_tokens with text-only content. Continuations are transparent to the caller.
+func (c *EyrieClient) StreamChatContinue(ctx context.Context, messages []EyrieMessage, opts ChatOptions, cfg ContinuationConfig) (*StreamResult, error) {
+	if len(messages) == 0 {
+		return nil, fmt.Errorf("eyrie: messages must not be empty")
+	}
+	provider := opts.Provider
+	if provider == "" {
+		provider = c.defaultProvider
+	}
+	p, err := c.getOrCreateProvider(provider)
+	if err != nil {
+		return nil, err
+	}
+	if opts.Model == "" {
+		opts.Model = ResolveDefaultModel(provider)
+	}
+	return StreamChatWithContinuation(ctx, p, messages, opts, cfg)
+}
+
 // Ping checks connectivity to the specified (or default) provider.
 func (c *EyrieClient) Ping(ctx context.Context, provider string) error {
 	if provider == "" {
diff --git a/client/continuation.go b/client/continuation.go
index d68a3fd..b11619e 100644
--- a/client/continuation.go
+++ b/client/continuation.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"strings"
+	"time"
 )
 
 // ContinuationConfig controls output continuation behavior.
@@ -88,3 +89,95 @@ func ChatWithContinuation(ctx context.Context, p Provider, messages []EyrieMessa
 		ToolCalls: finalToolCalls, Usage: finalUsage,
 	}, nil
 }
+
+// StreamChatWithContinuation wraps StreamChat with automatic continuation when
+// the response stops with "max_tokens" and contains only text (no tool calls).
+// It returns a StreamResult whose Events channel transparently continues across
+// multiple LLM calls, emitting a "continuation" event at each boundary.
+func StreamChatWithContinuation(ctx context.Context, p Provider, messages []EyrieMessage, opts ChatOptions, cfg ContinuationConfig) (*StreamResult, error) {
+	if cfg.MaxContinuations <= 0 {
+		cfg.MaxContinuations = 3
+	}
+	if cfg.MaxTotalTokens <= 0 {
+		cfg.MaxTotalTokens = 32000
+	}
+
+	groupID := fmt.Sprintf("cont_%d", time.Now().UnixNano())
+	outCh := make(chan EyrieStreamEvent, streamChannelBuffer)
+	cancelCtx, cancel := context.WithCancel(ctx)
+
+	go func() {
+		defer close(outCh)
+
+		var accumulated strings.Builder
+		var totalCompletionTokens int64
+		var hadToolCalls bool
+		msgs := make([]EyrieMessage, len(messages))
+		copy(msgs, messages)
+
+		for attempt := 0; attempt <= cfg.MaxContinuations; attempt++ {
+			stream, err := p.StreamChat(cancelCtx, msgs, opts)
+			if err != nil {
+				emit(cancelCtx, outCh, EyrieStreamEvent{Type: "error", Error: err.Error()})
+				return
+			}
+
+			var stopReason string
+			for evt := range stream.Events {
+				switch evt.Type {
+				case "content":
+					accumulated.WriteString(evt.Content)
+					emit(cancelCtx, outCh, evt)
+				case "tool_call":
+					hadToolCalls = true
+					emit(cancelCtx, outCh, evt)
+				case "usage":
+					if evt.Usage != nil {
+						totalCompletionTokens += int64(evt.Usage.CompletionTokens)
+					}
+					emit(cancelCtx, outCh, evt)
+				case "done":
+					stopReason = evt.StopReason
+				case "error":
+					emit(cancelCtx, outCh, evt)
+					return
+				default:
+					emit(cancelCtx, outCh, evt)
+				}
+			}
+
+			// Don't continue if: not max_tokens, had tool calls, or hit token cap
+			if stopReason != "max_tokens" && stopReason != "length" {
+				emit(cancelCtx, outCh, EyrieStreamEvent{Type: "done", StopReason: stopReason})
+				return
+			}
+			if hadToolCalls {
+				emit(cancelCtx, outCh, EyrieStreamEvent{Type: "done", StopReason: stopReason})
+				return
+			}
+			if cfg.MaxTotalTokens > 0 && int(totalCompletionTokens) >= cfg.MaxTotalTokens {
+				emit(cancelCtx, outCh, EyrieStreamEvent{Type: "done", StopReason: "max_tokens"})
+				return
+			}
+			if attempt >= cfg.MaxContinuations {
+				emit(cancelCtx, outCh, EyrieStreamEvent{Type: "done", StopReason: "max_tokens"})
+				return
+			}
+
+			// Emit continuation boundary event
+			emit(cancelCtx, outCh, EyrieStreamEvent{
+				Type:       "continuation",
+				Content:    groupID,
+				StopReason: fmt.Sprintf("%d", attempt+1),
+			})
+
+			// Build continuation messages
+			msgs = append(msgs, EyrieMessage{Role: "assistant", Content: accumulated.String()}, EyrieMessage{Role: "user", Content: "Continue."})
+			accumulated.Reset() // reset so each continuation carries only the newest segment, not duplicated prior text
+		}
+
+		emit(cancelCtx, outCh, EyrieStreamEvent{Type: "done", StopReason: "max_tokens"})
+	}()
+
+	return &StreamResult{Events: outCh, cancel: cancel}, nil
+}
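Reviewer note: a minimal sketch of driving the continuation API above end to end. StreamChatContinue, ContinuationConfig, and the event types come from this patch; the helper name, the assumption that `ec` was already constructed with a default provider, and the prompt text are illustrative only.

package main

import (
	"context"
	"fmt"

	"github.com/GrayCodeAI/eyrie/client"
)

// driveContinuation streams a long answer, letting the library stitch
// continuations together; callers only ever see one Events channel.
func driveContinuation(ctx context.Context, ec *client.EyrieClient) error {
	cfg := client.ContinuationConfig{MaxContinuations: 3, MaxTotalTokens: 32000}
	sr, err := ec.StreamChatContinue(ctx, []client.EyrieMessage{
		{Role: "user", Content: "Write a very long report."},
	}, client.ChatOptions{}, cfg)
	if err != nil {
		return err
	}
	defer sr.Close()
	for evt := range sr.Events {
		switch evt.Type {
		case "content":
			fmt.Print(evt.Content)
		case "continuation":
			// Boundary marker: evt.Content carries the continuation group ID,
			// evt.StopReason the 1-based continuation number.
		case "done":
			fmt.Println("\nstop reason:", evt.StopReason)
		}
	}
	return nil
}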
diff --git a/client/openai.go b/client/openai.go
index 3803d6e..6954386 100644
--- a/client/openai.go
+++ b/client/openai.go
@@ -195,6 +195,7 @@ func (c *OpenAIClient) buildRequest(messages []EyrieMessage, opts ChatOptions, s
 
 // Chat sends a non-streaming request.
 func (c *OpenAIClient) Chat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*EyrieResponse, error) {
+	messages = SanitizeMessages(messages)
 	if opts.Model == "" {
 		return nil, fmt.Errorf("eyrie: model is required for %s", c.providerName)
 	}
@@ -249,6 +250,7 @@ func (c *OpenAIClient) Chat(ctx context.Context, messages []EyrieMessage, opts C
 
 // StreamChat sends a streaming request.
 func (c *OpenAIClient) StreamChat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*StreamResult, error) {
+	messages = SanitizeMessages(messages)
 	if opts.Model == "" {
 		return nil, fmt.Errorf("eyrie: model is required for %s", c.providerName)
 	}
diff --git a/client/sanitize.go b/client/sanitize.go
new file mode 100644
index 0000000..f4c6027
--- /dev/null
+++ b/client/sanitize.go
@@ -0,0 +1,44 @@
+package client
+
+// SanitizeMessages inspects messages for orphaned tool_use blocks
+// (assistant messages with tool calls that lack matching tool_result blocks)
+// and injects synthetic error results to prevent 400 errors from providers.
+// This is critical for session resume and compaction scenarios.
+func SanitizeMessages(messages []EyrieMessage) []EyrieMessage {
+	if len(messages) == 0 {
+		return messages
+	}
+
+	// Collect all tool_result IDs
+	resultIDs := make(map[string]bool)
+	for _, msg := range messages {
+		if msg.Role == "user" && msg.ToolResult != nil && msg.ToolResult.ToolUseID != "" {
+			resultIDs[msg.ToolResult.ToolUseID] = true
+		}
+	}
+
+	// Find orphaned tool_use IDs and inject synthetic results
+	var result []EyrieMessage
+	for _, msg := range messages {
+		result = append(result, msg)
+
+		if msg.Role == "assistant" && len(msg.ToolUse) > 0 {
+			for _, tc := range msg.ToolUse {
+				if tc.ID != "" && !resultIDs[tc.ID] {
+					// Inject synthetic error result
+					result = append(result, EyrieMessage{
+						Role: "user",
+						ToolResult: &ToolResult{
+							ToolUseID: tc.ID,
+							Content:   "Tool execution was interrupted",
+							IsError:   true,
+						},
+					})
+					resultIDs[tc.ID] = true
+				}
+			}
+		}
+	}
+
+	return result
+}
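Reviewer note: a small runnable sketch of what the sanitizer above does to a truncated transcript. All types and fields (EyrieMessage, ToolCall, ToolResult) are from this repo; the scenario and printed output are illustrative.

package main

import (
	"fmt"

	"github.com/GrayCodeAI/eyrie/client"
)

// A transcript cut off after an assistant tool call, before the tool result
// was recorded (e.g., a resumed session). SanitizeMessages appends a synthetic
// error result so providers do not reject the history with a 400.
func main() {
	msgs := []client.EyrieMessage{
		{Role: "user", Content: "read file.go"},
		{Role: "assistant", ToolUse: []client.ToolCall{
			{ID: "tc1", Name: "read", Arguments: map[string]interface{}{"path": "file.go"}},
		}},
		// no tool_result for "tc1" -- the session was interrupted here
	}
	for _, m := range client.SanitizeMessages(msgs) {
		if m.ToolResult != nil {
			fmt.Printf("%s: synthetic result for %s (IsError=%v)\n",
				m.Role, m.ToolResult.ToolUseID, m.ToolResult.IsError)
			continue
		}
		fmt.Println(m.Role)
	}
	// Expected: user, assistant, then a user message carrying an injected
	// error result for tc1.
}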
diff --git a/client/stream.go b/client/stream.go
index c5a68e1..2f0afc1 100644
--- a/client/stream.go
+++ b/client/stream.go
@@ -108,6 +108,7 @@ func processAnthropicStream(ctx context.Context, sseEvents <-chan SSEEvent, logg
 		jsonBuf strings.Builder
 	}
 	var currentTool *toolAccum
+	var stopReason string
 
 	for {
 		select {
@@ -190,7 +191,7 @@ func processAnthropicStream(ctx context.Context, sseEvents <-chan SSEEvent, logg
 			}
 
 		case "message_stop":
-			emit(ctx, ch, EyrieStreamEvent{Type: "done"})
+			emit(ctx, ch, EyrieStreamEvent{Type: "done", StopReason: stopReason})
 			return
 
 		case "message_delta":
@@ -204,6 +205,9 @@ func processAnthropicStream(ctx context.Context, sseEvents <-chan SSEEvent, logg
 				} `json:"usage"`
 			}
 			_ = json.Unmarshal([]byte(data), &md)
+			if md.Delta != nil && md.Delta.StopReason != "" {
+				stopReason = md.Delta.StopReason
+			}
 			if md.Usage != nil && md.Usage.OutputTokens > 0 {
 				emit(ctx, ch, EyrieStreamEvent{
 					Type: "usage",
@@ -370,7 +374,7 @@ func processOpenAIStream(ctx context.Context, sseEvents <-chan SSEEvent, logger
 
 		if choice.FinishReason != nil {
 			emitTools()
-			emit(ctx, ch, EyrieStreamEvent{Type: "done"})
+			emit(ctx, ch, EyrieStreamEvent{Type: "done", StopReason: *choice.FinishReason})
 			return
 		}
 	}
diff --git a/client/weighted.go b/client/weighted.go
new file mode 100644
index 0000000..c2b648e
--- /dev/null
+++ b/client/weighted.go
@@ -0,0 +1,228 @@
+package client
+
+import (
+	"context"
+	"fmt"
+	"math/rand"
+	"sort"
+	"strings"
+	"sync"
+	"sync/atomic"
+)
+
+// WeightedProviderConfig associates a Provider with a selection weight.
+type WeightedProviderConfig struct {
+	Provider Provider
+	Weight   float64 // relative weight (e.g., 0.8 for 80%)
+}
+
+// WeightedProvider selects a provider based on configured weights,
+// with automatic failover to remaining providers on retriable errors.
+//
+// WeightedProvider is safe for concurrent use.
+type WeightedProvider struct {
+	configs []normalizedConfig // sorted by descending weight
+	mu      sync.Mutex
+	rng     *rand.Rand
+
+	// stats tracks how many times each provider served a request.
+	stats map[string]*atomic.Int64
+}
+
+// normalizedConfig holds a provider with its normalized (0-1) weight.
+type normalizedConfig struct {
+	provider Provider
+	weight   float64
+}
+
+// Compile-time check that WeightedProvider implements Provider.
+var _ Provider = (*WeightedProvider)(nil)
+
+// NewWeightedProvider creates a WeightedProvider that selects providers
+// based on the configured weights. At least one provider must be supplied.
+// Weights are normalized to sum to 1.0.
+func NewWeightedProvider(configs []WeightedProviderConfig) *WeightedProvider {
+	if len(configs) == 0 {
+		panic("eyrie: WeightedProvider requires at least one provider config")
+	}
+
+	// Compute total weight for normalization.
+	var total float64
+	for _, c := range configs {
+		if c.Weight <= 0 {
+			panic("eyrie: WeightedProvider weights must be positive")
+		}
+		total += c.Weight
+	}
+
+	normalized := make([]normalizedConfig, len(configs))
+	for i, c := range configs {
+		normalized[i] = normalizedConfig{
+			provider: c.Provider,
+			weight:   c.Weight / total,
+		}
+	}
+
+	// Sort by descending weight for failover ordering.
+	sort.Slice(normalized, func(i, j int) bool {
+		return normalized[i].weight > normalized[j].weight
+	})
+
+	stats := make(map[string]*atomic.Int64, len(configs))
+	for _, c := range normalized {
+		if _, ok := stats[c.provider.Name()]; !ok {
+			stats[c.provider.Name()] = &atomic.Int64{}
+		}
+	}
+
+	return &WeightedProvider{
+		configs: normalized,
+		rng:     rand.New(rand.NewSource(rand.Int63())),
+		stats:   stats,
+	}
+}
+
+// Name returns a composite name showing providers and their weights.
+func (wp *WeightedProvider) Name() string {
+	parts := make([]string, len(wp.configs))
+	for i, c := range wp.configs {
+		parts[i] = fmt.Sprintf("%s:%.2f", c.provider.Name(), c.weight)
+	}
+	return "weighted(" + strings.Join(parts, ",") + ")"
+}
+
+// Ping tries to ping each provider, returning nil on the first success.
+func (wp *WeightedProvider) Ping(ctx context.Context) error {
+	var lastErr error
+	for _, c := range wp.configs {
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+		if err := c.provider.Ping(ctx); err != nil {
+			lastErr = err
+			continue
+		}
+		return nil
+	}
+	return fmt.Errorf("eyrie: all weighted providers failed ping: %w", lastErr)
+}
+
+// Chat sends a non-streaming chat request using weighted random selection
+// with failover on retriable errors.
+func (wp *WeightedProvider) Chat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*EyrieResponse, error) {
+	selected := wp.selectProvider()
+	var lastErr error
+
+	// Try the selected provider first.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	resp, err := selected.Chat(ctx, messages, opts)
+	if err == nil {
+		wp.recordSuccess(selected.Name())
+		return resp, nil
+	}
+
+	// If non-retriable, return immediately.
+	if !isRetriableError(err) {
+		return nil, err
+	}
+	lastErr = err
+
+	// Failover: try remaining providers in weight-descending order.
+	for _, c := range wp.configs {
+		if c.provider == selected {
+			continue
+		}
+		if err := ctx.Err(); err != nil {
+			return nil, err
+		}
+		resp, err := c.provider.Chat(ctx, messages, opts)
+		if err == nil {
+			wp.recordSuccess(c.provider.Name())
+			return resp, nil
+		}
+		if !isRetriableError(err) {
+			return nil, err
+		}
+		lastErr = err
+	}
+
+	return nil, fmt.Errorf("eyrie: all weighted providers failed: %w", lastErr)
+}
+
+// StreamChat sends a streaming chat request using weighted random selection
+// with failover on retriable errors.
+func (wp *WeightedProvider) StreamChat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*StreamResult, error) {
+	selected := wp.selectProvider()
+	var lastErr error
+
+	// Try the selected provider first.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	sr, err := selected.StreamChat(ctx, messages, opts)
+	if err == nil {
+		wp.recordSuccess(selected.Name())
+		return sr, nil
+	}
+
+	// If non-retriable, return immediately.
+	if !isRetriableError(err) {
+		return nil, err
+	}
+	lastErr = err
+
+	// Failover: try remaining providers in weight-descending order.
+	for _, c := range wp.configs {
+		if c.provider == selected {
+			continue
+		}
+		if err := ctx.Err(); err != nil {
+			return nil, err
+		}
+		sr, err := c.provider.StreamChat(ctx, messages, opts)
+		if err == nil {
+			wp.recordSuccess(c.provider.Name())
+			return sr, nil
+		}
+		if !isRetriableError(err) {
+			return nil, err
+		}
+		lastErr = err
+	}
+
+	return nil, fmt.Errorf("eyrie: all weighted providers failed streaming: %w", lastErr)
+}
+
+// Stats returns a snapshot of how many times each provider served a request.
+func (wp *WeightedProvider) Stats() map[string]int64 {
+	result := make(map[string]int64, len(wp.stats))
+	for name, counter := range wp.stats {
+		result[name] = counter.Load()
+	}
+	return result
+}
+
+// selectProvider picks a provider based on weighted random selection.
+func (wp *WeightedProvider) selectProvider() Provider {
+	wp.mu.Lock()
+	r := wp.rng.Float64()
+	wp.mu.Unlock()
+
+	var cumulative float64
+	for _, c := range wp.configs {
+		cumulative += c.weight
+		if r < cumulative {
+			return c.provider
+		}
+	}
+	// Fallback to last provider (handles floating point edge case).
+	return wp.configs[len(wp.configs)-1].provider
+}
+
+func (wp *WeightedProvider) recordSuccess(name string) {
+	if counter, ok := wp.stats[name]; ok {
+		counter.Add(1)
+	}
+}
diff --git a/client/weighted_test.go b/client/weighted_test.go
new file mode 100644
index 0000000..7412fb3
--- /dev/null
+++ b/client/weighted_test.go
@@ -0,0 +1,276 @@
+package client
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"testing"
+)
+
+func TestWeightedProviderSingleProvider(t *testing.T) {
+	p := NewMockProvider(MockModeFixed)
+	p.Response = "only one"
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: p, Weight: 1.0},
+	})
+
+	for i := 0; i < 10; i++ {
+		resp, err := wp.Chat(context.Background(), []EyrieMessage{
+			{Role: "user", Content: "hello"},
+		}, ChatOptions{Model: "test"})
+		if err != nil {
+			t.Fatalf("call %d: unexpected error: %v", i, err)
+		}
+		if resp.Content != "only one" {
+			t.Errorf("call %d: expected 'only one', got %q", i, resp.Content)
+		}
+	}
+
+	if p.CallCount() != 10 {
+		t.Errorf("expected 10 calls, got %d", p.CallCount())
+	}
+}
+
+func TestWeightedProviderDistribution(t *testing.T) {
+	// Use named providers so stats can distinguish them.
+	primary := &namedProvider{name: "primary", mock: NewMockProvider(MockModeFixed)}
+	primary.mock.Response = "from primary"
+
+	secondary := &namedProvider{name: "secondary", mock: NewMockProvider(MockModeFixed)}
+	secondary.mock.Response = "from secondary"
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: primary, Weight: 0.8},
+		{Provider: secondary, Weight: 0.2},
+	})
+
+	const iterations = 1000
+	counts := map[string]int{"primary": 0, "secondary": 0}
+
+	for i := 0; i < iterations; i++ {
+		resp, err := wp.Chat(context.Background(), []EyrieMessage{
+			{Role: "user", Content: "hello"},
+		}, ChatOptions{Model: "test"})
+		if err != nil {
+			t.Fatalf("call %d: unexpected error: %v", i, err)
+		}
+		if resp.Content == "from primary" {
+			counts["primary"]++
+		} else if resp.Content == "from secondary" {
+			counts["secondary"]++
+		} else {
+			t.Fatalf("unexpected response: %q", resp.Content)
+		}
+	}
+
+	// Check distribution is roughly 80/20 with tolerance of 8%.
+	primaryRatio := float64(counts["primary"]) / float64(iterations)
+	secondaryRatio := float64(counts["secondary"]) / float64(iterations)
+
+	if math.Abs(primaryRatio-0.8) > 0.08 {
+		t.Errorf("primary ratio %.3f is too far from expected 0.80", primaryRatio)
+	}
+	if math.Abs(secondaryRatio-0.2) > 0.08 {
+		t.Errorf("secondary ratio %.3f is too far from expected 0.20", secondaryRatio)
+	}
+}
+
+func TestWeightedProviderFailoverOnRetriableError(t *testing.T) {
+	// Primary always returns a retriable error.
+	primary := &namedProvider{name: "primary", mock: nil, err: fmt.Errorf("HTTP 503 service unavailable")}
+	// Secondary succeeds.
+	secondary := &namedProvider{name: "secondary", mock: NewMockProvider(MockModeFixed)}
+	secondary.mock.Response = "fallback success"
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: primary, Weight: 1.0},    // will always be selected
+		{Provider: secondary, Weight: 0.01}, // extremely low weight
+	})
+
+	resp, err := wp.Chat(context.Background(), []EyrieMessage{
+		{Role: "user", Content: "hello"},
+	}, ChatOptions{Model: "test"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if resp.Content != "fallback success" {
+		t.Errorf("expected 'fallback success', got %q", resp.Content)
+	}
+}
+
+func TestWeightedProviderNoFailoverOnNonRetriableError(t *testing.T) {
+	// Primary returns a 400 (non-retriable).
+	primary := &namedProvider{name: "primary", mock: nil, err: fmt.Errorf("HTTP 400 bad request")}
+	// Secondary would succeed if reached.
+	secondary := &namedProvider{name: "secondary", mock: NewMockProvider(MockModeFixed)}
+	secondary.mock.Response = "should not reach"
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: primary, Weight: 1.0}, // always selected
+		{Provider: secondary, Weight: 0.01},
+	})
+
+	_, err := wp.Chat(context.Background(), []EyrieMessage{
+		{Role: "user", Content: "hello"},
+	}, ChatOptions{Model: "test"})
+	if err == nil {
+		t.Fatal("expected error for non-retriable 400")
+	}
+	// Secondary should NOT have been called.
+	if secondary.mock.CallCount() != 0 {
+		t.Errorf("secondary was called %d times; should not be called for non-retriable error", secondary.mock.CallCount())
+	}
+}
+
+func TestWeightedProviderNoFailoverOn401(t *testing.T) {
+	// Primary returns a 401 (non-retriable).
+	primary := &namedProvider{name: "primary", mock: nil, err: fmt.Errorf("HTTP 401 unauthorized")}
+	secondary := &namedProvider{name: "secondary", mock: NewMockProvider(MockModeFixed)}
+	secondary.mock.Response = "should not reach"
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: primary, Weight: 1.0},
+		{Provider: secondary, Weight: 0.01},
+	})
+
+	_, err := wp.Chat(context.Background(), []EyrieMessage{
+		{Role: "user", Content: "hello"},
+	}, ChatOptions{Model: "test"})
+	if err == nil {
+		t.Fatal("expected error for non-retriable 401")
+	}
+	if secondary.mock.CallCount() != 0 {
+		t.Errorf("secondary should not be called for 401 error")
+	}
+}
+
+func TestWeightedProviderAllFail(t *testing.T) {
+	p1 := &namedProvider{name: "p1", mock: nil, err: fmt.Errorf("HTTP 503 service unavailable")}
+	p2 := &namedProvider{name: "p2", mock: nil, err: fmt.Errorf("HTTP 502 bad gateway")}
+	p3 := &namedProvider{name: "p3", mock: nil, err: fmt.Errorf("HTTP 500 internal error")}
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: p1, Weight: 0.5},
+		{Provider: p2, Weight: 0.3},
+		{Provider: p3, Weight: 0.2},
+	})
+
+	_, err := wp.Chat(context.Background(), []EyrieMessage{
+		{Role: "user", Content: "hello"},
+	}, ChatOptions{Model: "test"})
+	if err == nil {
+		t.Fatal("expected error when all providers fail")
+	}
+}
+
+func TestWeightedProviderName(t *testing.T) {
+	p1 := &namedProvider{name: "anthropic", mock: NewMockProvider(MockModeFixed)}
+	p2 := &namedProvider{name: "openai", mock: NewMockProvider(MockModeFixed)}
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: p1, Weight: 0.8},
+		{Provider: p2, Weight: 0.2},
+	})
+
+	expected := "weighted(anthropic:0.80,openai:0.20)"
+	if wp.Name() != expected {
+		t.Errorf("expected name %q, got %q", expected, wp.Name())
+	}
+}
+
+func TestWeightedProviderPing(t *testing.T) {
+	p1 := &namedProvider{name: "failing", mock: nil, err: fmt.Errorf("ping failed")}
+	p2 := &namedProvider{name: "ok", mock: NewMockProvider(MockModeFixed)}
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: p1, Weight: 0.8},
+		{Provider: p2, Weight: 0.2},
+	})
+
+	// Should succeed because p2 pings ok (even though p1 fails).
+	if err := wp.Ping(context.Background()); err != nil {
+		t.Fatalf("expected ping to succeed, got: %v", err)
+	}
+}
+
+func TestWeightedProviderStreamFailover(t *testing.T) {
+	primary := &namedProvider{name: "primary", mock: nil, err: fmt.Errorf("HTTP 429 rate limited")}
+	secondary := &namedProvider{name: "secondary", mock: NewMockProvider(MockModeFixed)}
+	secondary.mock.Response = "streamed from secondary"
+
+	wp := NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: primary, Weight: 1.0},
+		{Provider: secondary, Weight: 0.01},
+	})
+
+	sr, err := wp.StreamChat(context.Background(), []EyrieMessage{
+		{Role: "user", Content: "hello"},
+	}, ChatOptions{Model: "test"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	defer sr.Close()
+
+	var content string
+	for evt := range sr.Events {
+		if evt.Type == "content" {
+			content += evt.Content
+		}
+	}
+	if content == "" {
+		t.Error("expected some streamed content")
+	}
+}
+
+func TestWeightedProviderPanicOnEmpty(t *testing.T) {
+	defer func() {
+		if r := recover(); r == nil {
+			t.Error("expected panic with no provider configs")
+		}
+	}()
+	NewWeightedProvider(nil)
+}
+
+func TestWeightedProviderPanicOnZeroWeight(t *testing.T) {
+	defer func() {
+		if r := recover(); r == nil {
+			t.Error("expected panic with zero weight")
+		}
+	}()
+	p := NewMockProvider(MockModeFixed)
+	NewWeightedProvider([]WeightedProviderConfig{
+		{Provider: p, Weight: 0},
+	})
+}
+
+// namedProvider wraps a mock provider with a custom name, used to distinguish
+// providers in stats and test assertions.
+type namedProvider struct {
+	name string
+	mock *MockProvider
+	err  error // if set, all calls return this error
+}
+
+func (n *namedProvider) Name() string { return n.name }
+
+func (n *namedProvider) Ping(_ context.Context) error {
+	if n.err != nil {
+		return n.err
+	}
+	return nil
+}
+
+func (n *namedProvider) Chat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*EyrieResponse, error) {
+	if n.err != nil {
+		return nil, n.err
+	}
+	return n.mock.Chat(ctx, messages, opts)
+}
+
+func (n *namedProvider) StreamChat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*StreamResult, error) {
+	if n.err != nil {
+		return nil, n.err
+	}
+	return n.mock.StreamChat(ctx, messages, opts)
+}
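Reviewer note: a test-style sketch (in package client, like the tests above) of the weighted routing this patch adds. NewWeightedProvider, Stats, and the MockProvider helpers are from this repo; the weights, names, and the sketch itself are illustrative. Note that if both mocks report the same Name(), their Stats counts merge -- the tests use namedProvider to keep them distinct.

package client

import (
	"context"
	"fmt"
	"testing"
)

// An 80/20 traffic split with automatic failover on retriable errors.
// In production the two entries would be real provider clients.
func TestWeightedProviderSketch(t *testing.T) {
	fast := NewMockProvider(MockModeFixed)
	fast.Response = "fast"
	cheap := NewMockProvider(MockModeFixed)
	cheap.Response = "cheap"

	wp := NewWeightedProvider([]WeightedProviderConfig{
		{Provider: fast, Weight: 0.8},  // ~80% of traffic
		{Provider: cheap, Weight: 0.2}, // ~20% of traffic
	})

	resp, err := wp.Chat(context.Background(),
		[]EyrieMessage{{Role: "user", Content: "hi"}},
		ChatOptions{Model: "test"})
	if err != nil {
		t.Fatal(err)
	}
	fmt.Println(resp.Content, wp.Stats()) // per-provider serve counts
}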
From fa9db4e0c2b9f59f5c5da50ccf74cef44ffb6db5 Mon Sep 17 00:00:00 2001
From: Patel230
Date: Fri, 8 May 2026 10:31:07 +0530
Subject: [PATCH 2/2] feat: enhance prompt caching with full provider support

- Add cache_control annotations for tool definitions (last tool cached)
- Handle tool_use/tool_result blocks in cached request builder
- Parse cache_creation_input_tokens and cache_read_input_tokens from Anthropic API
- Parse prompt_tokens_details.cached_tokens from OpenAI API
- Add CacheCreationTokens/CacheReadTokens to unified EyrieUsage
- Bump Go version to 1.26.1 for toolchain consistency
---
 client/anthropic.go  |  54 +++++-------
 client/cache.go      |  69 +++++++++------
 client/cache_test.go | 176 +++++++++++++++++++++++++++++++++++++
 client/client.go     |   8 +-
 client/openai.go     |  11 ++-
 go.mod               |   2 +-
 6 files changed, 253 insertions(+), 67 deletions(-)
 create mode 100644 client/cache_test.go

diff --git a/client/anthropic.go b/client/anthropic.go
index 7a616c3..a555230 100644
--- a/client/anthropic.go
+++ b/client/anthropic.go
@@ -79,9 +79,11 @@ type anthropicResponse struct {
 		Input json.RawMessage `json:"input,omitempty"`
 	} `json:"content"`
 	StopReason string `json:"stop_reason"`
-	Usage struct {
-		InputTokens int `json:"input_tokens"`
-		OutputTokens int `json:"output_tokens"`
+	Usage struct {
+		InputTokens              int `json:"input_tokens"`
+		OutputTokens             int `json:"output_tokens"`
+		CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
+		CacheReadInputTokens     int `json:"cache_read_input_tokens"`
 	} `json:"usage"`
 }
 
@@ -193,25 +195,12 @@ func (c *AnthropicClient) Chat(ctx context.Context, messages []EyrieMessage, opt
 
 	var body []byte
 	if opts.EnableCaching {
-		// Use cached request builder for Anthropic prompt caching support
-		cachedReq := buildAnthropicCachedRequest(messages, opts.Model, maxTokens, opts.Temperature, false)
+		allMessages := messages
 		if opts.System != "" {
-			if existing, ok := cachedReq["system"]; ok && existing != nil {
-				// System already set as cached array; prepend the opts.System
-				_ = existing // already handled by buildAnthropicCachedRequest
-			} else {
-				cachedReq["system"] = []map[string]interface{}{
-					{
-						"type":          "text",
-						"text":          opts.System,
-						"cache_control": map[string]string{"type": "ephemeral"},
-					},
-				}
-			}
-		}
-		if tools := convertToAnthropicTools(opts.Tools); len(tools) > 0 {
-			cachedReq["tools"] = tools
+			allMessages = append([]EyrieMessage{{Role: "system", Content: opts.System}}, allMessages...)
 		}
+		tools := convertToAnthropicTools(opts.Tools)
+		cachedReq := buildAnthropicCachedRequest(allMessages, opts.Model, maxTokens, opts.Temperature, false, tools)
 		body, _ = json.Marshal(cachedReq)
 	} else {
 		msgs, system := buildAnthropicMessages(messages)
@@ -273,8 +262,11 @@ func (c *AnthropicClient) Chat(ctx context.Context, messages []EyrieMessage, opt
 		Content: content, FinishReason: ar.StopReason, ToolCalls: toolCalls, RequestID: requestID,
 		Usage: &EyrieUsage{
-			PromptTokens: ar.Usage.InputTokens, CompletionTokens: ar.Usage.OutputTokens,
-			TotalTokens: ar.Usage.InputTokens + ar.Usage.OutputTokens,
+			PromptTokens:        ar.Usage.InputTokens,
+			CompletionTokens:    ar.Usage.OutputTokens,
+			TotalTokens:         ar.Usage.InputTokens + ar.Usage.OutputTokens,
+			CacheCreationTokens: ar.Usage.CacheCreationInputTokens,
+			CacheReadTokens:     ar.Usage.CacheReadInputTokens,
 		},
 	}, nil
 }
@@ -292,22 +284,12 @@ func (c *AnthropicClient) StreamChat(ctx context.Context, messages []EyrieMessag
 
 	var body []byte
 	if opts.EnableCaching {
-		// Use cached request builder for Anthropic prompt caching support
-		cachedReq := buildAnthropicCachedRequest(messages, opts.Model, maxTokens, opts.Temperature, true)
+		allMessages := messages
 		if opts.System != "" {
-			if _, ok := cachedReq["system"]; !ok {
-				cachedReq["system"] = []map[string]interface{}{
-					{
-						"type":          "text",
-						"text":          opts.System,
-						"cache_control": map[string]string{"type": "ephemeral"},
-					},
-				}
-			}
-		}
-		if tools := convertToAnthropicTools(opts.Tools); len(tools) > 0 {
-			cachedReq["tools"] = tools
+			allMessages = append([]EyrieMessage{{Role: "system", Content: opts.System}}, allMessages...)
 		}
+		tools := convertToAnthropicTools(opts.Tools)
+		cachedReq := buildAnthropicCachedRequest(allMessages, opts.Model, maxTokens, opts.Temperature, true, tools)
 		body, _ = json.Marshal(cachedReq)
 	} else {
 		msgs, system := buildAnthropicMessages(messages)
diff --git a/client/cache.go b/client/cache.go
index cd6bacf..330f843 100644
--- a/client/cache.go
+++ b/client/cache.go
@@ -65,32 +65,18 @@ type anthropicCachedMessage struct {
 }
 
 // buildAnthropicCachedRequest builds an Anthropic request body with cache_control.
-// Use this instead of the standard request builder when prompt caching is desired.
-func buildAnthropicCachedRequest(messages []EyrieMessage, model string, maxTokens int, temperature *float64, stream bool) map[string]interface{} {
-	var system string
-	var msgs []interface{}
+// It reuses buildAnthropicMessages for proper tool_use/tool_result handling,
+// then applies cache_control breakpoints following Anthropic's best practices:
+// - System prompt gets cache_control (cached for all turns)
+// - Second-to-last message gets cache_control (caches conversation prefix)
+// - Last tool definition gets cache_control (caches tool schema)
+func buildAnthropicCachedRequest(messages []EyrieMessage, model string, maxTokens int, temperature *float64, stream bool, tools []anthropicTool) map[string]interface{} {
+	msgs, system := buildAnthropicMessages(messages)
 
-	for _, m := range messages {
-		if m.Role == "system" {
-			system = m.Content
-			continue
-		}
-		msgs = append(msgs, map[string]interface{}{"role": m.Role, "content": m.Content})
-	}
-
-	// Apply cache breakpoint to second-to-last message
+	// Apply cache breakpoint to second-to-last non-system message
 	if len(msgs) >= 2 {
 		idx := len(msgs) - 2
-		if msg, ok := msgs[idx].(map[string]interface{}); ok {
-			content := msg["content"].(string)
-			msg["content"] = []map[string]interface{}{
-				{
-					"type":          "text",
-					"text":          content,
-					"cache_control": map[string]string{"type": "ephemeral"},
-				},
-			}
-		}
+		applyCacheBreakpointToMessage(msgs[idx])
 	}
 
 	req := map[string]interface{}{
@@ -102,14 +88,47 @@ func buildAnthropicCachedRequest(messages []EyrieMessage, model string, maxToken
 	if system != "" {
 		req["system"] = []map[string]interface{}{
 			{
-				"type": "text",
-				"text": system,
+				"type":          "text",
+				"text":          system,
 				"cache_control": map[string]string{"type": "ephemeral"},
 			},
 		}
 	}
+	if len(tools) > 0 {
+		toolMaps := make([]map[string]interface{}, len(tools))
+		for i, t := range tools {
+			toolMaps[i] = map[string]interface{}{
+				"name":         t.Name,
+				"description":  t.Description,
+				"input_schema": t.InputSchema,
+			}
+		}
+		// Annotate last tool with cache_control
+		toolMaps[len(toolMaps)-1]["cache_control"] = map[string]string{"type": "ephemeral"}
+		req["tools"] = toolMaps
+	}
 	if temperature != nil {
 		req["temperature"] = *temperature
 	}
 	return req
 }
+
+// applyCacheBreakpointToMessage adds cache_control to a message's content.
+// Handles both string content and array content (tool_use/tool_result blocks).
+func applyCacheBreakpointToMessage(msg map[string]interface{}) {
+	content := msg["content"]
+	switch c := content.(type) {
+	case string:
+		msg["content"] = []map[string]interface{}{
+			{
+				"type":          "text",
+				"text":          c,
+				"cache_control": map[string]string{"type": "ephemeral"},
+			},
+		}
+	case []map[string]interface{}:
+		if len(c) > 0 {
+			c[len(c)-1]["cache_control"] = map[string]string{"type": "ephemeral"}
+		}
+	}
+}
diff --git a/client/cache_test.go b/client/cache_test.go
new file mode 100644
index 0000000..40d1b3a
--- /dev/null
+++ b/client/cache_test.go
@@ -0,0 +1,176 @@
+package client
+
+import (
+	"encoding/json"
+	"testing"
+)
+
+func TestBuildAnthropicCachedRequest_BasicMessages(t *testing.T) {
+	messages := []EyrieMessage{
+		{Role: "system", Content: "You are helpful."},
+		{Role: "user", Content: "Hello"},
+		{Role: "assistant", Content: "Hi there!"},
+		{Role: "user", Content: "How are you?"},
+	}
+	req := buildAnthropicCachedRequest(messages, "claude-sonnet-4-20250514", 4096, nil, false, nil)
+
+	// System should be array with cache_control
+	system, ok := req["system"].([]map[string]interface{})
+	if !ok || len(system) != 1 {
+		t.Fatal("expected system as array with one element")
+	}
+	if system[0]["cache_control"] == nil {
+		t.Fatal("expected cache_control on system")
+	}
+	if system[0]["text"] != "You are helpful." {
+		t.Fatal("system text mismatch")
+	}
+
+	// Messages: second-to-last (index 1, assistant) should have cache_control
+	msgs := req["messages"].([]map[string]interface{})
+	if len(msgs) != 3 { // 3 non-system messages
+		t.Fatalf("expected 3 messages, got %d", len(msgs))
+	}
+
+	// Second to last message (index 1 = assistant "Hi there!") should be array with cache_control
+	assistantContent, ok := msgs[1]["content"].([]map[string]interface{})
+	if !ok {
+		t.Fatal("expected assistant content to be array after cache breakpoint")
+	}
+	if assistantContent[0]["cache_control"] == nil {
+		t.Fatal("expected cache_control on second-to-last message")
+	}
+}
+
+func TestBuildAnthropicCachedRequest_ToolUsePropagated(t *testing.T) {
+	messages := []EyrieMessage{
+		{Role: "user", Content: "read file.go"},
+		{Role: "assistant", Content: "", ToolUse: []ToolCall{
+			{ID: "tc1", Name: "read", Arguments: map[string]interface{}{"path": "file.go"}},
+		}},
+		{Role: "user", Content: "", ToolResult: &ToolResult{ToolUseID: "tc1", Content: "package main"}},
+		{Role: "user", Content: "now edit it"},
+	}
+	req := buildAnthropicCachedRequest(messages, "claude-sonnet-4-20250514", 4096, nil, false, nil)
+
+	msgs := req["messages"].([]map[string]interface{})
+	if len(msgs) != 4 {
+		t.Fatalf("expected 4 messages, got %d", len(msgs))
+	}
+
+	// Verify tool_use message (index 1) preserved
+	assistantMsg := msgs[1]
+	content, ok := assistantMsg["content"].([]map[string]interface{})
+	if !ok {
+		t.Fatal("expected assistant tool_use as array content")
+	}
+	found := false
+	for _, block := range content {
+		if block["type"] == "tool_use" {
+			found = true
+			break
+		}
+	}
+	if !found {
+		t.Fatal("expected tool_use block in assistant message")
+	}
+
+	// Verify tool_result message (index 2) is the cached one (second-to-last)
+	toolResultMsg := msgs[2]
+	trContent, ok := toolResultMsg["content"].([]map[string]interface{})
+	if !ok {
+		t.Fatal("expected tool_result as array content")
+	}
+	if trContent[len(trContent)-1]["cache_control"] == nil {
+		t.Fatal("expected cache_control on second-to-last message (tool_result)")
+	}
+}
+
+func TestBuildAnthropicCachedRequest_ToolsAnnotated(t *testing.T) {
+	messages := []EyrieMessage{
+		{Role: "user", Content: "hello"},
+	}
+	tools := []anthropicTool{
+		{Name: "read", Description: "Read a file", InputSchema: map[string]interface{}{"type": "object"}},
+		{Name: "write", Description: "Write a file", InputSchema: map[string]interface{}{"type": "object"}},
+		{Name: "bash", Description: "Run command", InputSchema: map[string]interface{}{"type": "object"}},
+	}
+	req := buildAnthropicCachedRequest(messages, "claude-sonnet-4-20250514", 4096, nil, false, tools)
+
+	toolMaps, ok := req["tools"].([]map[string]interface{})
+	if !ok || len(toolMaps) != 3 {
+		t.Fatalf("expected 3 tools, got %v", req["tools"])
+	}
+
+	// Only the LAST tool should have cache_control
+	if toolMaps[0]["cache_control"] != nil {
+		t.Fatal("first tool should not have cache_control")
+	}
+	if toolMaps[1]["cache_control"] != nil {
+		t.Fatal("second tool should not have cache_control")
+	}
+	if toolMaps[2]["cache_control"] == nil {
+		t.Fatal("last tool must have cache_control")
+	}
+}
+
+func TestCacheUsageParsing(t *testing.T) {
+	responseJSON := `{
+		"id": "msg_123",
+		"content": [{"type": "text", "text": "Hello!"}],
+		"stop_reason": "end_turn",
+		"usage": {
+			"input_tokens": 100,
+			"output_tokens": 50,
+			"cache_creation_input_tokens": 1000,
+			"cache_read_input_tokens": 800
+		}
+	}`
+
+	var ar anthropicResponse
+	if err := json.Unmarshal([]byte(responseJSON), &ar); err != nil {
+		t.Fatalf("failed to unmarshal: %v", err)
+	}
+
+	if ar.Usage.CacheCreationInputTokens != 1000 {
+		t.Fatalf("expected cache_creation=1000, got %d", ar.Usage.CacheCreationInputTokens)
+	}
+	if ar.Usage.CacheReadInputTokens != 800 {
+		t.Fatalf("expected cache_read=800, got %d", ar.Usage.CacheReadInputTokens)
+	}
+
+	// Verify it propagates to EyrieUsage
+	usage := &EyrieUsage{
+		PromptTokens:        ar.Usage.InputTokens,
+		CompletionTokens:    ar.Usage.OutputTokens,
+		TotalTokens:         ar.Usage.InputTokens + ar.Usage.OutputTokens,
+		CacheCreationTokens: ar.Usage.CacheCreationInputTokens,
+		CacheReadTokens:     ar.Usage.CacheReadInputTokens,
+	}
+	if usage.CacheCreationTokens != 1000 || usage.CacheReadTokens != 800 {
+		t.Fatal("cache tokens not propagated correctly")
+	}
+}
+
+func TestBuildAnthropicCachedRequest_NoSystem(t *testing.T) {
+	messages := []EyrieMessage{
+		{Role: "user", Content: "Hello"},
+		{Role: "assistant", Content: "Hi"},
+		{Role: "user", Content: "Bye"},
+	}
+	req := buildAnthropicCachedRequest(messages, "claude-sonnet-4-20250514", 4096, nil, false, nil)
+
+	if _, ok := req["system"]; ok {
+		t.Fatal("should not have system key when no system message")
+	}
+}
+
+func TestBuildAnthropicCachedRequest_StreamFlag(t *testing.T) {
+	messages := []EyrieMessage{
+		{Role: "user", Content: "Hello"},
+	}
+	req := buildAnthropicCachedRequest(messages, "claude-sonnet-4-20250514", 4096, nil, true, nil)
+	if req["stream"] != true {
+		t.Fatal("expected stream=true")
+	}
+}
diff --git a/client/client.go b/client/client.go
index 0ee1bd3..facbd36 100644
--- a/client/client.go
+++ b/client/client.go
@@ -67,9 +67,11 @@ type EyrieTool struct {
 
 // EyrieUsage tracks token usage.
 type EyrieUsage struct {
-	PromptTokens int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens int `json:"total_tokens"`
+	PromptTokens        int `json:"prompt_tokens"`
+	CompletionTokens    int `json:"completion_tokens"`
+	TotalTokens         int `json:"total_tokens"`
+	CacheCreationTokens int `json:"cache_creation_tokens,omitempty"`
+	CacheReadTokens     int `json:"cache_read_tokens,omitempty"`
 }
 
 // EyrieResponse is the response from a chat call.
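Reviewer note: how the new usage fields are expected to surface to callers. EnableCaching and the EyrieUsage field names are from this patch; the client-level Chat wrapper is assumed to exist as the non-streaming counterpart of StreamChat, and the helper name is illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/GrayCodeAI/eyrie/client"
)

// reportCacheUsage sketches cache accounting after a cached call: the first
// call writes a prefix (CacheCreationTokens > 0); a follow-up call with the
// same prefix reads it back (CacheReadTokens > 0).
func reportCacheUsage(ctx context.Context, ec *client.EyrieClient, msgs []client.EyrieMessage) error {
	resp, err := ec.Chat(ctx, msgs, client.ChatOptions{ // ec.Chat assumed; mirrors StreamChat
		Provider:      "anthropic",
		EnableCaching: true,
	})
	if err != nil {
		return err
	}
	if u := resp.Usage; u != nil {
		fmt.Printf("prompt=%d completion=%d cache_write=%d cache_read=%d\n",
			u.PromptTokens, u.CompletionTokens, u.CacheCreationTokens, u.CacheReadTokens)
	}
	return nil
}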
diff --git a/client/openai.go b/client/openai.go
index 6954386..459fa45 100644
--- a/client/openai.go
+++ b/client/openai.go
@@ -92,6 +92,9 @@ type openaiResponse struct {
 		PromptTokens     int `json:"prompt_tokens"`
 		CompletionTokens int `json:"completion_tokens"`
 		TotalTokens      int `json:"total_tokens"`
+		PromptTokensDetails *struct {
+			CachedTokens int `json:"cached_tokens"`
+		} `json:"prompt_tokens_details,omitempty"`
 	} `json:"usage"`
 }
 
@@ -241,8 +244,12 @@ func (c *OpenAIClient) Chat(ctx context.Context, messages []EyrieMessage, opts C
 	}
 	if or.Usage != nil {
 		result.Usage = &EyrieUsage{
-			PromptTokens: or.Usage.PromptTokens, CompletionTokens: or.Usage.CompletionTokens,
-			TotalTokens: or.Usage.TotalTokens,
+			PromptTokens:     or.Usage.PromptTokens,
+			CompletionTokens: or.Usage.CompletionTokens,
+			TotalTokens:      or.Usage.TotalTokens,
+		}
+		if or.Usage.PromptTokensDetails != nil {
+			result.Usage.CacheReadTokens = or.Usage.PromptTokensDetails.CachedTokens
 		}
 	}
 	return result, nil
diff --git a/go.mod b/go.mod
index 4f165cb..cf92108 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,3 @@
 module github.com/GrayCodeAI/eyrie
 
-go 1.23
+go 1.26.1
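Reviewer note: to make the breakpoint placement concrete, a test-style sketch (in package client, since buildAnthropicCachedRequest is unexported) that prints the request the cached builder produces. The builder signature, anthropicTool fields, and model string mirror cache_test.go above; the printed-output comment is illustrative, not captured output.

package client

import (
	"encoding/json"
	"fmt"
	"testing"
)

// With a system prompt, several turns, and two tools, cache_control should
// appear on the system block, on the second-to-last message, and on the last
// tool only.
func TestCachedRequestShapeSketch(t *testing.T) {
	msgs := []EyrieMessage{
		{Role: "system", Content: "You are helpful."},
		{Role: "user", Content: "Hello"},
		{Role: "assistant", Content: "Hi!"},
		{Role: "user", Content: "Summarize our chat."},
	}
	tools := []anthropicTool{
		{Name: "read", Description: "Read a file", InputSchema: map[string]interface{}{"type": "object"}},
		{Name: "bash", Description: "Run a command", InputSchema: map[string]interface{}{"type": "object"}},
	}
	req := buildAnthropicCachedRequest(msgs, "claude-sonnet-4-20250514", 4096, nil, false, tools)
	out, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(out))
	// Expect (abridged): system[0].cache_control, messages[1].content[last].cache_control,
	// and tools[1].cache_control all set to {"type":"ephemeral"}.
}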