From 159ab6e6c15c3d981b8118ab797a2c0d7ea9c41e Mon Sep 17 00:00:00 2001 From: rich Date: Mon, 27 Apr 2026 21:11:55 +0800 Subject: [PATCH] refactor(ai-adapters): keep provider failures typed --- src/crates/ai-adapters/src/client/sse.rs | 30 +- src/crates/ai-adapters/src/lib.rs | 2 + src/crates/ai-adapters/src/provider_error.rs | 386 ++++++++++++++++++ .../src/stream/stream_handler/anthropic.rs | 53 +-- .../src/stream/stream_handler/gemini.rs | 43 +- .../src/stream/stream_handler/openai.rs | 57 +-- .../src/stream/stream_handler/responses.rs | 54 +-- .../src/agentic/execution/round_executor.rs | 50 ++- .../src/agentic/execution/stream_processor.rs | 5 +- src/crates/core/src/util/errors.rs | 84 ++++ .../openai/provider_error_with_code.sse | 3 + .../core/tests/stream_processor_openai.rs | 33 +- 12 files changed, 681 insertions(+), 119 deletions(-) create mode 100644 src/crates/ai-adapters/src/provider_error.rs create mode 100644 src/crates/core/tests/fixtures/stream/openai/provider_error_with_code.sse diff --git a/src/crates/ai-adapters/src/client/sse.rs b/src/crates/ai-adapters/src/client/sse.rs index d2445f415..14783a581 100644 --- a/src/crates/ai-adapters/src/client/sse.rs +++ b/src/crates/ai-adapters/src/client/sse.rs @@ -1,7 +1,8 @@ use crate::client::utils::elapsed_ms_u64; use crate::client::StreamResponse; +use crate::provider_error::ProviderError; use crate::stream::UnifiedResponse; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Error, Result}; use chrono::{DateTime, Utc}; use log::{debug, error, warn}; use reqwest::{ @@ -81,8 +82,10 @@ where .text() .await .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); - error!("{} client error {}: {}", label, status, error_text); - return Err(anyhow!("{} client error {}: {}", label, status, error_text)); + let provider_error = + ProviderError::from_http_error(label, status.as_u16(), &error_text); + error!("{} client error {}: {}", label, status, provider_error); + return Err(Error::new(provider_error)); } if status.is_success() { @@ -100,7 +103,9 @@ where .text() .await .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); - let error = anyhow!("{} error {}: {}", label, status, error_text); + let provider_error = + ProviderError::from_http_error(label, status.as_u16(), &error_text); + let error = Error::new(provider_error); warn!( "{} request failed: {}ms, attempt {}/{}, error: {}", label, @@ -162,14 +167,17 @@ where }); } - let error_msg = format!( - "{} failed after {} attempts: {}", - label, - max_tries, - last_error.unwrap_or_else(|| anyhow!("Unknown error")) - ); + let error_msg = match &last_error { + Some(error) => format!("{} failed after {} attempts: {}", label, max_tries, error), + None => format!( + "{} failed after {} attempts: Unknown error", + label, max_tries + ), + }; error!("{}", error_msg); - Err(anyhow!(error_msg)) + Err(last_error + .unwrap_or_else(|| anyhow!("Unknown error")) + .context(format!("{} failed after {} attempts", label, max_tries))) } #[cfg(test)] diff --git a/src/crates/ai-adapters/src/lib.rs b/src/crates/ai-adapters/src/lib.rs index 31cf28e0b..04af92a5d 100644 --- a/src/crates/ai-adapters/src/lib.rs +++ b/src/crates/ai-adapters/src/lib.rs @@ -1,12 +1,14 @@ #![doc = include_str!("../README.md")] pub mod client; +pub mod provider_error; pub mod providers; pub mod stream; pub mod tool_call_accumulator; pub mod types; pub use client::{AIClient, StreamOptions, StreamResponse}; +pub use provider_error::{ProviderError, ProviderErrorKind}; pub use stream::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; pub use types::{ resolve_request_url, AIConfig, ConnectionTestMessageCode, ConnectionTestResult, GeminiResponse, diff --git a/src/crates/ai-adapters/src/provider_error.rs b/src/crates/ai-adapters/src/provider_error.rs new file mode 100644 index 000000000..0f3ab49b5 --- /dev/null +++ b/src/crates/ai-adapters/src/provider_error.rs @@ -0,0 +1,386 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::error::Error; +use std::fmt; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProviderErrorKind { + Network, + Auth, + RateLimit, + ContextOverflow, + Timeout, + ProviderQuota, + ProviderBilling, + ProviderUnavailable, + Permission, + InvalidRequest, + ContentPolicy, + ModelError, + Unknown, +} + +impl ProviderErrorKind { + pub fn is_retryable(self) -> bool { + matches!( + self, + ProviderErrorKind::Network + | ProviderErrorKind::RateLimit + | ProviderErrorKind::Timeout + | ProviderErrorKind::ProviderUnavailable + ) + } + + fn as_str(self) -> &'static str { + match self { + ProviderErrorKind::Network => "network", + ProviderErrorKind::Auth => "auth", + ProviderErrorKind::RateLimit => "rate_limit", + ProviderErrorKind::ContextOverflow => "context_overflow", + ProviderErrorKind::Timeout => "timeout", + ProviderErrorKind::ProviderQuota => "provider_quota", + ProviderErrorKind::ProviderBilling => "provider_billing", + ProviderErrorKind::ProviderUnavailable => "provider_unavailable", + ProviderErrorKind::Permission => "permission", + ProviderErrorKind::InvalidRequest => "invalid_request", + ProviderErrorKind::ContentPolicy => "content_policy", + ProviderErrorKind::ModelError => "model_error", + ProviderErrorKind::Unknown => "unknown", + } + } +} + +impl fmt::Display for ProviderErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ProviderError { + provider: Option, + kind: ProviderErrorKind, + code: Option, + message: String, + request_id: Option, + http_status: Option, +} + +impl ProviderError { + pub fn builder(message: impl Into) -> ProviderErrorBuilder { + ProviderErrorBuilder { + error: ProviderError { + provider: None, + kind: ProviderErrorKind::Unknown, + code: None, + message: message.into(), + request_id: None, + http_status: None, + }, + } + } + + pub fn from_error_payload(provider: &str, payload: &Value) -> Option { + let error = payload.get("error")?; + let request_id = payload + .get("request_id") + .or_else(|| payload.get("requestId")) + .and_then(json_scalar_to_string); + + if let Some(message) = error.as_str() { + return Some( + Self::builder(message) + .provider(provider) + .kind(classify_provider_error(None, message, None)) + .maybe_request_id(request_id) + .build(), + ); + } + + let error_object = error.as_object()?; + let code = error_object.get("code").and_then(json_scalar_to_string); + let message = error_object + .get("message") + .and_then(|value| value.as_str()) + .or_else(|| error_object.get("error").and_then(|value| value.as_str())) + .unwrap_or("Provider returned an error"); + let http_status = error_object + .get("status") + .and_then(|value| value.as_u64()) + .and_then(|status| u16::try_from(status).ok()) + .or_else(|| { + payload + .get("status") + .and_then(|value| value.as_u64()) + .and_then(|status| u16::try_from(status).ok()) + }); + + Some( + Self::builder(message) + .provider(provider) + .kind(classify_provider_error( + code.as_deref(), + message, + http_status, + )) + .maybe_code(code) + .maybe_request_id(request_id) + .maybe_http_status(http_status) + .build(), + ) + } + + pub fn from_http_error(provider: &str, status: u16, body: &str) -> Self { + let parsed = serde_json::from_str::(body) + .ok() + .and_then(|value| Self::from_error_payload(provider, &value)); + + if let Some(error) = parsed { + return error.with_http_status(status); + } + + Self::builder(body.trim()) + .provider(provider) + .kind(classify_provider_error(None, body, Some(status))) + .http_status(status) + .build() + } + + pub fn provider(&self) -> Option<&str> { + self.provider.as_deref() + } + + pub fn kind(&self) -> ProviderErrorKind { + self.kind + } + + pub fn code(&self) -> Option<&str> { + self.code.as_deref() + } + + pub fn message(&self) -> &str { + &self.message + } + + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub fn http_status(&self) -> Option { + self.http_status + } + + pub fn is_retryable(&self) -> bool { + self.kind.is_retryable() + } + + fn with_http_status(mut self, status: u16) -> Self { + self.http_status = Some(status); + let kind_with_status = + classify_provider_error(self.code.as_deref(), &self.message, Some(status)); + if matches!( + self.kind, + ProviderErrorKind::Unknown | ProviderErrorKind::ModelError + ) && !matches!( + kind_with_status, + ProviderErrorKind::Unknown | ProviderErrorKind::ModelError + ) { + self.kind = kind_with_status; + } + self + } +} + +impl fmt::Display for ProviderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Provider error")?; + if let Some(provider) = &self.provider { + write!(f, ": provider={provider}")?; + } + write!(f, ", kind={}", self.kind)?; + if let Some(code) = &self.code { + write!(f, ", code={code}")?; + } + if let Some(request_id) = &self.request_id { + write!(f, ", request_id={request_id}")?; + } + if let Some(http_status) = self.http_status { + write!(f, ", http_status={http_status}")?; + } + write!( + f, + ", retryable={}, message={}", + self.is_retryable(), + self.message + ) + } +} + +impl Error for ProviderError {} + +pub struct ProviderErrorBuilder { + error: ProviderError, +} + +impl ProviderErrorBuilder { + pub fn provider(mut self, provider: impl Into) -> Self { + self.error.provider = Some(provider.into()); + self + } + + pub fn kind(mut self, kind: ProviderErrorKind) -> Self { + self.error.kind = kind; + self + } + + pub fn code(mut self, code: impl Into) -> Self { + self.error.code = Some(code.into()); + self + } + + pub fn request_id(mut self, request_id: impl Into) -> Self { + self.error.request_id = Some(request_id.into()); + self + } + + pub fn http_status(mut self, http_status: u16) -> Self { + self.error.http_status = Some(http_status); + self + } + + pub fn build(self) -> ProviderError { + self.error + } + + fn maybe_code(self, code: Option) -> Self { + if let Some(code) = code { + self.code(code) + } else { + self + } + } + + fn maybe_request_id(self, request_id: Option) -> Self { + if let Some(request_id) = request_id { + self.request_id(request_id) + } else { + self + } + } + + fn maybe_http_status(self, http_status: Option) -> Self { + if let Some(http_status) = http_status { + self.http_status(http_status) + } else { + self + } + } +} + +fn json_scalar_to_string(value: &Value) -> Option { + if let Some(value) = value.as_str() { + return Some(value.to_string()); + } + if let Some(value) = value.as_i64() { + return Some(value.to_string()); + } + if let Some(value) = value.as_u64() { + return Some(value.to_string()); + } + if let Some(value) = value.as_bool() { + return Some(value.to_string()); + } + None +} + +fn classify_provider_error( + code: Option<&str>, + message: &str, + http_status: Option, +) -> ProviderErrorKind { + let message = message.to_lowercase(); + let code = code.unwrap_or_default().to_lowercase(); + + if matches!(http_status, Some(401)) || matches!(code.as_str(), "1000" | "1002") { + ProviderErrorKind::Auth + } else if matches!(http_status, Some(403)) || code == "1220" { + ProviderErrorKind::Permission + } else if matches!(http_status, Some(429)) || code == "1302" || message.contains("rate limit") { + ProviderErrorKind::RateLimit + } else if matches!(http_status, Some(402)) + || matches!(code.as_str(), "1113" | "insufficient_quota") + { + ProviderErrorKind::ProviderQuota + } else if code == "1309" || message.contains("billing") || message.contains("subscription") { + ProviderErrorKind::ProviderBilling + } else if matches!(http_status, Some(500..=599)) + || code == "1305" + || message.contains("overloaded") + || message.contains("temporarily unavailable") + || message.contains("service unavailable") + { + ProviderErrorKind::ProviderUnavailable + } else if code == "1301" + || message.contains("content policy") + || message.contains("content_filter") + { + ProviderErrorKind::ContentPolicy + } else if matches!(http_status, Some(400 | 413 | 422)) + || matches!(code.as_str(), "1210" | "1211" | "435") + || message.contains("invalid request") + || message.contains("invalid parameter") + || message.contains("model not found") + { + ProviderErrorKind::InvalidRequest + } else if message.contains("context window") + || message.contains("context length") + || message.contains("token limit") + { + ProviderErrorKind::ContextOverflow + } else if message.contains("timeout") || message.contains("timed out") { + ProviderErrorKind::Timeout + } else if message.contains("connection reset") + || message.contains("broken pipe") + || message.contains("stream closed") + { + ProviderErrorKind::Network + } else if code.is_empty() && http_status.is_none() { + ProviderErrorKind::Unknown + } else { + ProviderErrorKind::ModelError + } +} + +#[cfg(test)] +mod tests { + use super::{ProviderError, ProviderErrorKind}; + + #[test] + fn parses_json_http_error_body_into_provider_error() { + let error = ProviderError::from_http_error( + "OpenAI Streaming API", + 429, + r#"{"error":{"code":"rate_limit_exceeded","message":"too many requests"},"request_id":"req_http_429"}"#, + ); + + assert_eq!(error.provider(), Some("OpenAI Streaming API")); + assert_eq!(error.kind(), ProviderErrorKind::RateLimit); + assert_eq!(error.code(), Some("rate_limit_exceeded")); + assert_eq!(error.message(), "too many requests"); + assert_eq!(error.request_id(), Some("req_http_429")); + assert_eq!(error.http_status(), Some(429)); + assert!(error.is_retryable()); + } + + #[test] + fn classifies_plain_http_client_error_without_json() { + let error = ProviderError::from_http_error("Responses API", 401, "unauthorized"); + + assert_eq!(error.provider(), Some("Responses API")); + assert_eq!(error.kind(), ProviderErrorKind::Auth); + assert_eq!(error.message(), "unauthorized"); + assert_eq!(error.http_status(), Some(401)); + assert!(!error.is_retryable()); + } +} diff --git a/src/crates/ai-adapters/src/stream/stream_handler/anthropic.rs b/src/crates/ai-adapters/src/stream/stream_handler/anthropic.rs index 788b4772d..34925d697 100644 --- a/src/crates/ai-adapters/src/stream/stream_handler/anthropic.rs +++ b/src/crates/ai-adapters/src/stream/stream_handler/anthropic.rs @@ -1,12 +1,13 @@ use super::inline_think::InlineThinkParser; use super::stream_stats::StreamStats; -use super::{TimedStreamItem, next_stream_item}; +use super::{next_stream_item, TimedStreamItem}; +use crate::provider_error::ProviderError; use crate::stream::types::anthropic::{ AnthropicSSEError, ContentBlock, ContentBlockDelta, ContentBlockStart, MessageDelta, MessageStart, Usage, }; use crate::stream::types::unified::UnifiedResponse; -use anyhow::{Result, anyhow}; +use anyhow::{anyhow, Error, Result}; use eventsource_stream::Eventsource; use log::{error, trace}; use reqwest::Response; @@ -72,11 +73,11 @@ pub async fn handle_anthropic_stream( let _ = tx.send(format!("[{}] {}", event_type, data)); } - if let Some(error_msg) = format_provider_error_from_sse_message(&event_type, &data) { + if let Some(provider_error) = format_provider_error_from_sse_message(&event_type, &data) { stats.increment("error:provider_message"); stats.log_summary("provider_error_message_received"); - error!("{}", error_msg); - let _ = tx_event.send(Err(anyhow!(error_msg))); + error!("{}", provider_error); + let _ = tx_event.send(Err(Error::new(provider_error))); return; } @@ -205,36 +206,13 @@ pub async fn handle_anthropic_stream( } } -fn format_provider_error_from_sse_message(event_type: &str, data: &str) -> Option { +fn format_provider_error_from_sse_message(event_type: &str, data: &str) -> Option { if event_type != "message" { return None; } let value: serde_json::Value = serde_json::from_str(data).ok()?; - let error = value.get("error")?.as_object()?; - let code = error - .get("code") - .and_then(|value| value.as_str()) - .map(str::to_string) - .or_else(|| error.get("code").map(|value| value.to_string()))?; - let message = error - .get("message") - .and_then(|value| value.as_str()) - .unwrap_or("Provider returned an error"); - let request_id = value - .get("request_id") - .or_else(|| value.get("requestId")) - .and_then(|value| value.as_str()); - - let mut formatted = format!( - "Provider error: provider=anthropic_compatible, code={}, message={}", - code, message - ); - if let Some(request_id) = request_id { - formatted.push_str(&format!(", request_id={}", request_id)); - } - - Some(formatted) + ProviderError::from_error_payload("anthropic_compatible", &value) } fn emit_normalized_response( @@ -257,17 +235,20 @@ fn emit_normalized_response( #[cfg(test)] mod tests { use super::format_provider_error_from_sse_message; + use crate::provider_error::ProviderErrorKind; #[test] - fn extracts_glm_business_error_from_message_event() { + fn extracts_structured_glm_business_error_from_message_event() { let raw = r#"{"error":{"code":"1113","message":"余额不足或无可用资源包,请充值。"},"request_id":"20260425142416"}"#; - let formatted = format_provider_error_from_sse_message("message", raw).unwrap(); + let error = format_provider_error_from_sse_message("message", raw).unwrap(); - assert!(formatted.contains("Provider error")); - assert!(formatted.contains("code=1113")); - assert!(formatted.contains("余额不足或无可用资源包")); - assert!(formatted.contains("request_id=20260425142416")); + assert_eq!(error.provider(), Some("anthropic_compatible")); + assert_eq!(error.kind(), ProviderErrorKind::ProviderQuota); + assert_eq!(error.code(), Some("1113")); + assert_eq!(error.message(), "余额不足或无可用资源包,请充值。"); + assert_eq!(error.request_id(), Some("20260425142416")); + assert!(!error.is_retryable()); } #[test] diff --git a/src/crates/ai-adapters/src/stream/stream_handler/gemini.rs b/src/crates/ai-adapters/src/stream/stream_handler/gemini.rs index 4889e7243..5bf42b926 100644 --- a/src/crates/ai-adapters/src/stream/stream_handler/gemini.rs +++ b/src/crates/ai-adapters/src/stream/stream_handler/gemini.rs @@ -1,8 +1,9 @@ use super::stream_stats::StreamStats; use super::{next_stream_item, TimedStreamItem}; +use crate::provider_error::ProviderError; use crate::stream::types::gemini::GeminiSSEData; use crate::stream::types::unified::UnifiedResponse; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Error, Result}; use eventsource_stream::Eventsource; use log::{error, trace}; use reqwest::Response; @@ -61,15 +62,8 @@ impl GeminiToolCallState { } } -fn extract_api_error_message(event_json: &Value) -> Option { - let error = event_json.get("error")?; - if let Some(message) = error.get("message").and_then(Value::as_str) { - return Some(message.to_string()); - } - if let Some(message) = error.as_str() { - return Some(message.to_string()); - } - Some("Gemini streaming request failed".to_string()) +fn extract_api_error(event_json: &Value) -> Option { + ProviderError::from_error_payload("gemini", event_json) } pub async fn handle_gemini_stream( @@ -140,12 +134,12 @@ pub async fn handle_gemini_stream( } }; - if let Some(message) = extract_api_error_message(&event_json) { - let error_msg = format!("Gemini SSE API error: {}, data: {}", message, raw); + if let Some(api_error) = extract_api_error(&event_json) { + let error_msg = format!("Gemini SSE API error: {}, data: {}", api_error, raw); stats.increment("error:api"); stats.log_summary("sse_api_error"); error!("{}", error_msg); - let _ = tx_event.send(Err(anyhow!(error_msg))); + let _ = tx_event.send(Err(Error::new(api_error))); return; } @@ -193,7 +187,8 @@ pub async fn handle_gemini_stream( #[cfg(test)] mod tests { - use super::GeminiToolCallState; + use super::{extract_api_error, GeminiToolCallState}; + use crate::provider_error::ProviderErrorKind; use crate::stream::types::unified::UnifiedToolCall; #[test] @@ -280,4 +275,24 @@ mod tests { assert_ne!(first.id, second.id); } + + #[test] + fn extracts_structured_gemini_api_error() { + let event = serde_json::json!({ + "error": { + "code": 429, + "message": "rate limit exceeded" + }, + "request_id": "gemini_req_1" + }); + + let error = extract_api_error(&event).expect("provider error"); + + assert_eq!(error.provider(), Some("gemini")); + assert_eq!(error.kind(), ProviderErrorKind::RateLimit); + assert_eq!(error.code(), Some("429")); + assert_eq!(error.message(), "rate limit exceeded"); + assert_eq!(error.request_id(), Some("gemini_req_1")); + assert!(error.is_retryable()); + } } diff --git a/src/crates/ai-adapters/src/stream/stream_handler/openai.rs b/src/crates/ai-adapters/src/stream/stream_handler/openai.rs index ef670581e..a4abe5061 100644 --- a/src/crates/ai-adapters/src/stream/stream_handler/openai.rs +++ b/src/crates/ai-adapters/src/stream/stream_handler/openai.rs @@ -1,9 +1,10 @@ use super::inline_think::InlineThinkParser; use super::stream_stats::StreamStats; use super::{next_stream_item, TimedStreamItem}; +use crate::provider_error::ProviderError; use crate::stream::types::openai::{OpenAISSEData, OpenAIToolCallArgumentsNormalizer}; use crate::stream::types::unified::UnifiedResponse; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Error, Result}; use eventsource_stream::Eventsource; use log::{error, trace, warn}; use reqwest::Response; @@ -48,15 +49,8 @@ fn is_valid_chat_completion_chunk_weak(event_json: &Value) -> bool { ) } -fn extract_sse_api_error_message(event_json: &Value) -> Option { - let error = event_json.get("error")?; - if let Some(message) = error.get("message").and_then(|value| value.as_str()) { - return Some(message.to_string()); - } - if let Some(message) = error.as_str() { - return Some(message.to_string()); - } - Some("An error occurred during streaming".to_string()) +fn extract_sse_api_error(event_json: &Value) -> Option { + ProviderError::from_error_payload("openai_compatible", event_json) } /// Convert a byte stream into a structured response stream @@ -144,12 +138,12 @@ pub async fn handle_openai_stream( } }; - if let Some(api_error_message) = extract_sse_api_error_message(&event_json) { - let error_msg = format!("SSE API error: {}, data: {}", api_error_message, raw); + if let Some(api_error) = extract_sse_api_error(&event_json) { + let error_msg = format!("SSE API error: {}, data: {}", api_error, raw); stats.increment("error:api"); stats.log_summary("sse_api_error"); error!("{}", error_msg); - let _ = tx_event.send(Err(anyhow!(error_msg))); + let _ = tx_event.send(Err(Error::new(api_error))); return; } @@ -235,7 +229,8 @@ pub async fn handle_openai_stream( #[cfg(test)] mod tests { - use super::{extract_sse_api_error_message, is_valid_chat_completion_chunk_weak}; + use super::{extract_sse_api_error, is_valid_chat_completion_chunk_weak}; + use crate::provider_error::ProviderErrorKind; #[test] fn weak_filter_accepts_chat_completion_chunk() { @@ -262,27 +257,37 @@ mod tests { } #[test] - fn extracts_api_error_message_from_object_shape() { + fn extracts_structured_provider_error_from_object_shape() { let event = serde_json::json!({ "error": { + "code": "1305", "message": "provider error" - } + }, + "request_id": "req_1305" }); - assert_eq!( - extract_sse_api_error_message(&event).as_deref(), - Some("provider error") - ); + + let error = extract_sse_api_error(&event).expect("provider error"); + + assert_eq!(error.provider(), Some("openai_compatible")); + assert_eq!(error.kind(), ProviderErrorKind::ProviderUnavailable); + assert_eq!(error.code(), Some("1305")); + assert_eq!(error.message(), "provider error"); + assert_eq!(error.request_id(), Some("req_1305")); + assert!(error.is_retryable()); } #[test] - fn extracts_api_error_message_from_string_shape() { + fn extracts_structured_provider_error_from_string_shape() { let event = serde_json::json!({ "error": "provider error" }); - assert_eq!( - extract_sse_api_error_message(&event).as_deref(), - Some("provider error") - ); + + let error = extract_sse_api_error(&event).expect("provider error"); + + assert_eq!(error.provider(), Some("openai_compatible")); + assert_eq!(error.kind(), ProviderErrorKind::Unknown); + assert_eq!(error.message(), "provider error"); + assert_eq!(error.code(), None); } #[test] @@ -290,6 +295,6 @@ mod tests { let event = serde_json::json!({ "object": "chat.completion.chunk" }); - assert!(extract_sse_api_error_message(&event).is_none()); + assert!(extract_sse_api_error(&event).is_none()); } } diff --git a/src/crates/ai-adapters/src/stream/stream_handler/responses.rs b/src/crates/ai-adapters/src/stream/stream_handler/responses.rs index 3913574d4..71a5b1ec8 100644 --- a/src/crates/ai-adapters/src/stream/stream_handler/responses.rs +++ b/src/crates/ai-adapters/src/stream/stream_handler/responses.rs @@ -1,10 +1,11 @@ use super::stream_stats::StreamStats; use super::{next_stream_item, TimedStreamItem}; +use crate::provider_error::ProviderError; use crate::stream::types::responses::{ parse_responses_output_item, ResponsesCompleted, ResponsesDone, ResponsesStreamEvent, }; use crate::stream::types::unified::UnifiedResponse; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Error, Result}; use eventsource_stream::Eventsource; use log::{error, trace}; use reqwest::Response; @@ -157,7 +158,7 @@ fn handle_function_call_output_item_done( ); } -fn extract_api_error_message(event_json: &Value) -> Option { +fn extract_api_error(event_json: &Value) -> Option { let response = event_json.get("response")?; let error = response.get("error")?; @@ -165,14 +166,16 @@ fn extract_api_error_message(event_json: &Value) -> Option { return None; } - if let Some(message) = error.get("message").and_then(Value::as_str) { - return Some(message.to_string()); - } - if let Some(message) = error.as_str() { - return Some(message.to_string()); + let mut provider_payload = serde_json::Map::new(); + provider_payload.insert("error".to_string(), error.clone()); + if let Some(request_id) = event_json + .get("request_id") + .or_else(|| event_json.get("requestId")) + { + provider_payload.insert("request_id".to_string(), request_id.clone()); } - Some("An error occurred during responses streaming".to_string()) + ProviderError::from_error_payload("responses", &Value::Object(provider_payload)) } pub async fn handle_responses_stream( @@ -244,15 +247,12 @@ pub async fn handle_responses_stream( } }; - if let Some(api_error_message) = extract_api_error_message(&event_json) { - let error_msg = format!( - "Responses SSE API error: {}, data: {}", - api_error_message, raw - ); + if let Some(api_error) = extract_api_error(&event_json) { + let error_msg = format!("Responses SSE API error: {}, data: {}", api_error, raw); stats.increment("error:api"); stats.log_summary("sse_api_error"); error!("{}", error_msg); - let _ = tx_event.send(Err(anyhow!(error_msg))); + let _ = tx_event.send(Err(Error::new(api_error))); return; } @@ -539,28 +539,36 @@ pub async fn handle_responses_stream( #[cfg(test)] mod tests { use super::{ - super::stream_stats::StreamStats, extract_api_error_message, - handle_function_call_output_item_done, InProgressToolCall, + super::stream_stats::StreamStats, extract_api_error, handle_function_call_output_item_done, + InProgressToolCall, }; + use crate::provider_error::ProviderErrorKind; use serde_json::json; use std::collections::HashMap; use tokio::sync::mpsc; #[test] - fn extracts_api_error_message_from_response_error() { + fn extracts_structured_api_error_from_response_error() { let event = json!({ "type": "response.failed", + "request_id": "req_resp_1", "response": { + "id": "resp_1", "error": { + "code": "insufficient_quota", "message": "provider error" } } }); - assert_eq!( - extract_api_error_message(&event).as_deref(), - Some("provider error") - ); + let error = extract_api_error(&event).expect("provider error"); + + assert_eq!(error.provider(), Some("responses")); + assert_eq!(error.kind(), ProviderErrorKind::ProviderQuota); + assert_eq!(error.code(), Some("insufficient_quota")); + assert_eq!(error.message(), "provider error"); + assert_eq!(error.request_id(), Some("req_resp_1")); + assert!(!error.is_retryable()); } #[test] @@ -572,7 +580,7 @@ mod tests { } }); - assert!(extract_api_error_message(&event).is_none()); + assert!(extract_api_error(&event).is_none()); } #[test] @@ -585,7 +593,7 @@ mod tests { } }); - assert!(extract_api_error_message(&event).is_none()); + assert!(extract_api_error(&event).is_none()); } #[test] diff --git a/src/crates/core/src/agentic/execution/round_executor.rs b/src/crates/core/src/agentic/execution/round_executor.rs index 03d26e6e5..c00426f94 100644 --- a/src/crates/core/src/agentic/execution/round_executor.rs +++ b/src/crates/core/src/agentic/execution/round_executor.rs @@ -114,10 +114,11 @@ impl RoundExecutor { { Ok(response) => response, Err(e) => { - error!("AI request failed: {}", e); - let err_msg = e.to_string(); - let can_retry = attempt_index < max_attempts - 1 - && Self::is_transient_network_error(&err_msg); + let ai_error = BitFunError::from_ai_adapter_error(e); + error!("AI request failed: {}", ai_error); + let err_msg = ai_error.to_string(); + let can_retry = + attempt_index < max_attempts - 1 && Self::is_retryable_ai_error(&ai_error); if can_retry { let delay_ms = Self::retry_delay_ms(attempt_index); warn!( @@ -133,7 +134,7 @@ impl RoundExecutor { attempt_index += 1; continue; } - return Err(BitFunError::AIClient(err_msg)); + return Err(ai_error); } }; @@ -247,7 +248,7 @@ impl RoundExecutor { let err_msg = stream_err.error.to_string(); let can_retry = !stream_err.has_effective_output && attempt_index < max_attempts - 1 - && Self::is_transient_network_error(&err_msg); + && Self::is_retryable_ai_error(&stream_err.error); if can_retry { let delay_ms = Self::retry_delay_ms(attempt_index); warn!( @@ -668,6 +669,15 @@ impl RoundExecutor { Self::RETRY_BASE_DELAY_MS * (1u64 << attempt_index.min(3)) } + fn is_retryable_ai_error(error: &BitFunError) -> bool { + match error { + BitFunError::AIProvider(provider_error) => provider_error.is_retryable(), + BitFunError::AIClient(error_message) => Self::is_transient_network_error(error_message), + BitFunError::Timeout(_) => true, + _ => false, + } + } + fn is_transient_network_error(error_message: &str) -> bool { let msg = error_message.to_lowercase(); @@ -745,6 +755,8 @@ impl RoundExecutor { #[cfg(test)] mod tests { use super::RoundExecutor; + use crate::util::errors::BitFunError; + use bitfun_ai_adapters::{ProviderError, ProviderErrorKind}; #[test] fn detects_transient_stream_transport_error() { @@ -780,6 +792,32 @@ mod tests { assert!(!RoundExecutor::is_transient_network_error(billing)); } + #[test] + fn retries_structured_provider_unavailable_error() { + let err = BitFunError::AIProvider( + ProviderError::builder("provider temporarily overloaded") + .provider("openai_compatible") + .kind(ProviderErrorKind::ProviderUnavailable) + .code("1305") + .build(), + ); + + assert!(RoundExecutor::is_retryable_ai_error(&err)); + } + + #[test] + fn rejects_structured_provider_quota_error() { + let err = BitFunError::AIProvider( + ProviderError::builder("insufficient balance") + .provider("anthropic_compatible") + .kind(ProviderErrorKind::ProviderQuota) + .code("1113") + .build(), + ); + + assert!(!RoundExecutor::is_retryable_ai_error(&err)); + } + #[test] fn detects_interrupted_invalid_tool_only_recovery() { let result = crate::agentic::execution::stream_processor::StreamResult { diff --git a/src/crates/core/src/agentic/execution/stream_processor.rs b/src/crates/core/src/agentic/execution/stream_processor.rs index 10dd00052..be6d7ba5c 100644 --- a/src/crates/core/src/agentic/execution/stream_processor.rs +++ b/src/crates/core/src/agentic/execution/stream_processor.rs @@ -728,7 +728,8 @@ impl StreamProcessor { break; } TimedStreamItem::Item(Err(e)) => { - let error_msg = format!("Stream processing error: {}", e); + let stream_error = BitFunError::from_ai_adapter_error(e); + let error_msg = format!("Stream processing error: {}", stream_error); error!("{}", error_msg); if ctx.can_recover_as_partial_result() { flush_sse_on_error(&sse_collector, &error_msg).await; @@ -742,7 +743,7 @@ impl StreamProcessor { flush_sse_on_error(&sse_collector, &error_msg).await; self.graceful_shutdown_from_ctx(&mut ctx, error_msg.clone()).await; return Err(StreamProcessError::new( - BitFunError::AIClient(error_msg), + stream_error, ctx.has_effective_output, )); } diff --git a/src/crates/core/src/util/errors.rs b/src/crates/core/src/util/errors.rs index 5428521c1..f278be56e 100644 --- a/src/crates/core/src/util/errors.rs +++ b/src/crates/core/src/util/errors.rs @@ -2,6 +2,7 @@ //! //! Provide unified error types and handling for the whole application +use bitfun_ai_adapters::{ProviderError, ProviderErrorKind}; use bitfun_events::agentic::{AiErrorDetail, ErrorCategory}; use serde::Serialize; use thiserror::Error; @@ -21,6 +22,9 @@ pub enum BitFunError { #[error("AI client error: {0}")] AIClient(String), + #[error("AI provider error: {0}")] + AIProvider(ProviderError), + #[error("Session error: {0}")] Session(String), @@ -154,10 +158,19 @@ impl BitFunError { Self::Cancelled(msg.into()) } + pub fn from_ai_adapter_error(error: anyhow::Error) -> Self { + if let Some(provider_error) = error.downcast_ref::() { + return Self::AIProvider(provider_error.clone()); + } + + Self::AIClient(error.to_string()) + } + /// Infer an error category from this error for frontend-friendly classification. pub fn error_category(&self) -> ErrorCategory { match self { BitFunError::AIClient(msg) => classify_ai_error(msg), + BitFunError::AIProvider(err) => provider_error_kind_to_category(err.kind()), BitFunError::Timeout(_) => ErrorCategory::Timeout, BitFunError::Cancelled(_) => ErrorCategory::Unknown, _ => ErrorCategory::Unknown, @@ -168,6 +181,19 @@ impl BitFunError { pub fn error_detail(&self) -> AiErrorDetail { let category = self.error_category(); let message = self.to_string(); + if let BitFunError::AIProvider(err) = self { + return AiErrorDetail { + category: category.clone(), + provider: err.provider().map(str::to_string), + provider_code: err.code().map(str::to_string), + provider_message: Some(err.message().to_string()), + request_id: err.request_id().map(str::to_string), + http_status: err.http_status(), + retryable: Some(err.is_retryable()), + action_hints: action_hints_for_category(&category), + }; + } + AiErrorDetail { category: category.clone(), provider: extract_error_field(&message, "provider"), @@ -181,6 +207,24 @@ impl BitFunError { } } +fn provider_error_kind_to_category(kind: ProviderErrorKind) -> ErrorCategory { + match kind { + ProviderErrorKind::Network => ErrorCategory::Network, + ProviderErrorKind::Auth => ErrorCategory::Auth, + ProviderErrorKind::RateLimit => ErrorCategory::RateLimit, + ProviderErrorKind::ContextOverflow => ErrorCategory::ContextOverflow, + ProviderErrorKind::Timeout => ErrorCategory::Timeout, + ProviderErrorKind::ProviderQuota => ErrorCategory::ProviderQuota, + ProviderErrorKind::ProviderBilling => ErrorCategory::ProviderBilling, + ProviderErrorKind::ProviderUnavailable => ErrorCategory::ProviderUnavailable, + ProviderErrorKind::Permission => ErrorCategory::Permission, + ProviderErrorKind::InvalidRequest => ErrorCategory::InvalidRequest, + ProviderErrorKind::ContentPolicy => ErrorCategory::ContentPolicy, + ProviderErrorKind::ModelError => ErrorCategory::ModelError, + ProviderErrorKind::Unknown => ErrorCategory::Unknown, + } +} + /// Classify an AI client error message into a structured category. fn classify_ai_error(msg: &str) -> ErrorCategory { let m = msg.to_lowercase(); @@ -442,6 +486,7 @@ impl From for BitFunError { #[cfg(test)] mod tests { use super::BitFunError; + use bitfun_ai_adapters::{ProviderError, ProviderErrorKind}; use bitfun_events::agentic::ErrorCategory; #[test] @@ -472,4 +517,43 @@ mod tests { assert_eq!(err.error_category(), ErrorCategory::ProviderUnavailable); } + + #[test] + fn builds_error_detail_directly_from_structured_provider_error() { + let err = BitFunError::AIProvider( + ProviderError::builder("OpenAI-compatible provider is overloaded") + .provider("openai_compatible") + .kind(ProviderErrorKind::ProviderUnavailable) + .code("1305") + .request_id("req_1305") + .build(), + ); + + let detail = err.error_detail(); + + assert_eq!(detail.category, ErrorCategory::ProviderUnavailable); + assert_eq!(detail.provider.as_deref(), Some("openai_compatible")); + assert_eq!(detail.provider_code.as_deref(), Some("1305")); + assert_eq!( + detail.provider_message.as_deref(), + Some("OpenAI-compatible provider is overloaded") + ); + assert_eq!(detail.request_id.as_deref(), Some("req_1305")); + assert_eq!(detail.retryable, Some(true)); + } + + #[test] + fn preserves_structured_provider_error_through_anyhow_context() { + let provider_error = ProviderError::builder("provider temporarily overloaded") + .provider("openai_compatible") + .kind(ProviderErrorKind::ProviderUnavailable) + .code("1305") + .build(); + let anyhow_error = anyhow::Error::new(provider_error).context("request failed after retry"); + + let err = BitFunError::from_ai_adapter_error(anyhow_error); + + assert_eq!(err.error_category(), ErrorCategory::ProviderUnavailable); + assert!(matches!(err, BitFunError::AIProvider(_))); + } } diff --git a/src/crates/core/tests/fixtures/stream/openai/provider_error_with_code.sse b/src/crates/core/tests/fixtures/stream/openai/provider_error_with_code.sse new file mode 100644 index 000000000..e0cb2e437 --- /dev/null +++ b/src/crates/core/tests/fixtures/stream/openai/provider_error_with_code.sse @@ -0,0 +1,3 @@ +data: {"error":{"code":"1305","message":"provider temporarily overloaded"},"request_id":"req_1305"} + +data: [DONE] diff --git a/src/crates/core/tests/stream_processor_openai.rs b/src/crates/core/tests/stream_processor_openai.rs index 3cb6b5c83..9f064dbe1 100644 --- a/src/crates/core/tests/stream_processor_openai.rs +++ b/src/crates/core/tests/stream_processor_openai.rs @@ -1,6 +1,6 @@ mod common; -use bitfun_core::agentic::events::{AgenticEvent, ToolEventData}; +use bitfun_core::agentic::events::{AgenticEvent, ErrorCategory, ToolEventData}; use common::sse_fixture_server::FixtureSseServerOptions; use common::stream_test_harness::{ run_stream_fixture, run_stream_fixture_with_options, StreamFixtureProvider, @@ -8,6 +8,37 @@ use common::stream_test_harness::{ }; use serde_json::json; +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn openai_fixture_preserves_structured_provider_error_detail() { + let output = run_stream_fixture( + StreamFixtureProvider::OpenAi, + "stream/openai/provider_error_with_code.sse", + FixtureSseServerOptions::default(), + ) + .await; + + let err = output + .result + .expect_err("provider error should fail stream"); + let detail = err.error.error_detail(); + + assert_eq!( + detail.category, + ErrorCategory::ProviderUnavailable, + "error={:?}, detail={:?}", + err.error, + detail + ); + assert_eq!(detail.provider.as_deref(), Some("openai_compatible")); + assert_eq!(detail.provider_code.as_deref(), Some("1305")); + assert_eq!( + detail.provider_message.as_deref(), + Some("provider temporarily overloaded") + ); + assert_eq!(detail.request_id.as_deref(), Some("req_1305")); + assert_eq!(detail.retryable, Some(true)); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn openai_fixture_keeps_collecting_tool_args_across_usage_chunks() { let output = run_stream_fixture(