From 5b38569be279b225ea83e4501084dd7472ff441b Mon Sep 17 00:00:00 2001 From: Human <5217366+BrainSlugs83@users.noreply.github.com> Date: Thu, 4 Jun 2026 14:27:29 -0700 Subject: [PATCH] Fix cancel teardown shutdown crash Avoid the 0xC0000005 teardown AV by clearing the OGA log callback before model/ORT teardown, draining canceled sessions before unload, and leaving EP/env globals alive for the process lifetime. Wire cancellation through SessionManager, Session, ChatSession, and web service request tracking so shutdown stops in-flight work before unloading models. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../generative/chat/chat_session.cc | 88 ++++++++++++++----- .../generative/chat/chat_session.h | 8 ++ .../cpp/src/inferencing/model_load_manager.cc | 1 + sdk_v2/cpp/src/inferencing/session/session.cc | 17 ++++ sdk_v2/cpp/src/inferencing/session/session.h | 11 ++- .../inferencing/session/session_manager.cc | 19 +++- sdk_v2/cpp/src/manager.cc | 57 ++++++++---- sdk_v2/cpp/src/service/web_service.h | 17 ++-- 8 files changed, 158 insertions(+), 60 deletions(-) diff --git a/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.cc b/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.cc index eb76ef94d..97f490228 100644 --- a/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.cc +++ b/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.cc @@ -77,6 +77,27 @@ SessionType ChatSession::Type() const { return SessionType::kChat; } +void ChatSession::Cancel() { + Session::Cancel(); + + std::lock_guard lock(active_generator_mutex_); + if (active_generator_ != nullptr) { + active_generator_->Cancel(); + } +} + +void ChatSession::SetActiveGenerator(OnnxChatGenerator* generator) { + std::lock_guard lock(active_generator_mutex_); + active_generator_ = generator; +} + +void ChatSession::ClearActiveGenerator(OnnxChatGenerator* generator) { + std::lock_guard lock(active_generator_mutex_); + if (active_generator_ == generator) { + active_generator_ = nullptr; + } +} + void ChatSession::SetSessionOptionsImpl(const KeyValuePairs& options) { session_options_ = SearchOptions::FromParameters(options); } @@ -574,22 +595,32 @@ void ChatSession::ProcessRequestImpl(const Request& request, Response& response) } }; - while (!cached_generator_->IsDone() && !request.canceled) { - cached_generator_->GenerateNextToken(); - std::string token = cached_generator_->Decode(); - ++output_tokens; - - if (!token.empty()) { - text += token; - emit_segments(splitter.Push(token)); - } + OnnxChatGenerator* active_generator = cached_generator_.get(); + SetActiveGenerator(active_generator); + try { + while (!cached_generator_->IsDone() && !request.canceled && !IsCancellationRequested()) { + cached_generator_->GenerateNextToken(); + std::string token = cached_generator_->Decode(); + ++output_tokens; + + if (!token.empty()) { + text += token; + emit_segments(splitter.Push(token)); + } - // Enforce max_output_tokens — with use_full_context the OGA max_length - // is the entire context window, so we must cap output ourselves. - if (max_output > 0 && output_tokens >= max_output) { - break; + // Enforce max_output_tokens — with use_full_context the OGA max_length + // is the entire context window, so we must cap output ourselves. + if (max_output > 0 && output_tokens >= max_output) { + break; + } } + } catch (...) { + ClearActiveGenerator(active_generator); + throw; } + ClearActiveGenerator(active_generator); + + const bool canceled = request.canceled || IsCancellationRequested(); // End-of-stream: drain the reasoning splitter first so any final DEFAULT bytes feed into the tool accumulator, // then drain the tool accumulator. @@ -598,17 +629,17 @@ void ChatSession::ProcessRequestImpl(const Request& request, Response& response) int total_tokens = cached_generator_->TokenCount(); - if (request.canceled) { + if (canceled) { // Rewind the generator to undo this turn's input. The generator remains valid // for the next attempt — the caller can re-send the same input. cached_generator_->RewindTo(pre_turn_token_count); } - ProcessGeneratedOutput(std::move(text), cached_tool_ctx_, effective_options, request.canceled, + ProcessGeneratedOutput(std::move(text), cached_tool_ctx_, effective_options, canceled, response, prompt_tokens, total_tokens, std::move(streamed_tool_calls)); // Commit input messages + assistant reply to history only on success (not cancelled) - if (!request.canceled) { + if (!canceled) { // LARK grammar (tool-call-only mode) is a single-shot finite parse. If generation was truncated while grammar was // active, the parser is in an unrecoverable state. Additionally, a completed grammar signals EOS — IsDone() would // return true on the next turn. Invalidate after any grammar-guided generation so the next turn rebuilds. @@ -794,15 +825,24 @@ void ChatSession::ProcessChatCompletionsJson(const std::string& request_json, co // Generate token-by-token std::string text; - while (!generator->IsDone() && !original_request.canceled) { - generator->GenerateNextToken(); - std::string token = generator->Decode(); - - if (!token.empty()) { - text += token; - process_segments(splitter.Push(token)); + SetActiveGenerator(generator.get()); + try { + while (!generator->IsDone() && !original_request.canceled && !IsCancellationRequested()) { + generator->GenerateNextToken(); + std::string token = generator->Decode(); + + if (!token.empty()) { + text += token; + process_segments(splitter.Push(token)); + } } + } catch (...) { + ClearActiveGenerator(generator.get()); + throw; } + ClearActiveGenerator(generator.get()); + + const bool canceled = original_request.canceled || IsCancellationRequested(); // Drain any buffered partial-marker bytes at end-of-stream. Reasoning splitter first so any final DEFAULT bytes // feed into the tool accumulator; then drain the tool accumulator. @@ -818,7 +858,7 @@ void ChatSession::ProcessChatCompletionsJson(const std::string& request_json, co // Process the generated output into response items (MessageItem, ToolCallItem, etc.) // This also updates finish_reason, and usage on the response. Streamed-parsed tool calls are reused so call_ids // stay stable across stream deltas and the final ChatCompletionResponse. - ProcessGeneratedOutput(std::move(text), tool_ctx, options, original_request.canceled, + ProcessGeneratedOutput(std::move(text), tool_ctx, options, canceled, response, prompt_tokens, total_tokens, std::move(streamed_tool_calls)); // Emit final streaming chunk with finish_reason diff --git a/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.h b/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.h index 4bd761bc3..ef275d2a5 100644 --- a/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.h +++ b/sdk_v2/cpp/src/inferencing/generative/chat/chat_session.h @@ -10,6 +10,7 @@ #include "logger.h" #include +#include #include #include @@ -69,6 +70,8 @@ class ChatSession : public Session { /// @param count Number of turns to undo. Must be <= TurnCount(). void UndoTurns(size_t count) override; + void Cancel() override; + private: // populate session_options_ void SetSessionOptionsImpl(const KeyValuePairs& options) override; @@ -101,6 +104,9 @@ class ChatSession : public Session { void ProcessChatCompletionsJson(const std::string& request_json, const Request& original_request, Response& response); + void SetActiveGenerator(OnnxChatGenerator* generator); + void ClearActiveGenerator(OnnxChatGenerator* generator); + /// Commit input messages and assistant reply to history after a successful turn. void CommitTurn(std::vector&& new_messages, const Response& response, int pre_turn_token_count, int post_turn_token_count); @@ -120,6 +126,8 @@ class ChatSession : public Session { // Cached generator for continuous decoding (non-JSON path only). // Null until first non-JSON ProcessRequestImpl call. std::unique_ptr cached_generator_; + std::mutex active_generator_mutex_; + OnnxChatGenerator* active_generator_ = nullptr; // Tool context used when creating the cached generator. // Reused for subsequent turns to maintain tool definition consistency. diff --git a/sdk_v2/cpp/src/inferencing/model_load_manager.cc b/sdk_v2/cpp/src/inferencing/model_load_manager.cc index 0bc321b9b..9dcacb978 100644 --- a/sdk_v2/cpp/src/inferencing/model_load_manager.cc +++ b/sdk_v2/cpp/src/inferencing/model_load_manager.cc @@ -194,6 +194,7 @@ bool ModelLoadManager::UnloadModel(std::string_view model_id) { // Erasing destroys the GenAIModelInstance, which destroys OGA objects in reverse order. loaded_models_.erase(it); + logger_.Log(LogLevel::Information, fmt::format("unloaded model: {}", id_str)); return true; } diff --git a/sdk_v2/cpp/src/inferencing/session/session.cc b/sdk_v2/cpp/src/inferencing/session/session.cc index e8f90ee69..a7765ae6a 100644 --- a/sdk_v2/cpp/src/inferencing/session/session.cc +++ b/sdk_v2/cpp/src/inferencing/session/session.cc @@ -17,6 +17,7 @@ #include #include +#include namespace fl { @@ -30,6 +31,18 @@ Session::Session(const fl::Model& catalog_model, ILogger& logger, ITelemetry& te Session::~Session() = default; +Session::Session(Session&& other) noexcept + : catalog_model_(other.catalog_model_), + logger_(other.logger_), + telemetry_(other.telemetry_), + tool_definitions_(std::move(other.tool_definitions_)), + session_options_(std::move(other.session_options_)), + callback_fn_(std::move(other.callback_fn_)), + callback_user_data_(other.callback_user_data_), + allow_concurrent_requests_(other.allow_concurrent_requests_), + cancel_requested_(other.cancel_requested_.load(std::memory_order_relaxed)) { +} + std::unique_ptr Session::Create(const fl::Model& model) { auto& mgr = Manager::Instance(); auto& telemetry = mgr.GetTelemetry(); @@ -114,4 +127,8 @@ void Session::ProcessRequest(const Request& request, Response& response) { } } +void Session::Cancel() { + cancel_requested_.store(true, std::memory_order_relaxed); +} + } // namespace fl diff --git a/sdk_v2/cpp/src/inferencing/session/session.h b/sdk_v2/cpp/src/inferencing/session/session.h index f3ec0e5f1..941958556 100644 --- a/sdk_v2/cpp/src/inferencing/session/session.h +++ b/sdk_v2/cpp/src/inferencing/session/session.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include #include #include @@ -34,7 +35,7 @@ class Session { public: virtual ~Session(); - Session(Session&&) = default; + Session(Session&& other) noexcept; Session& operator=(Session&&) = delete; Session(const Session&) = delete; @@ -52,6 +53,9 @@ class Session { /// in-flight callbacks and ensures the Response is fully populated on return. void ProcessRequest(const Request& request, Response& response); + /// Signal the active request, if any, to stop as soon as possible. + virtual void Cancel(); + /// Add a tool definition to this session. /// @throws fl::Exception if tool_def.json_schema is not valid JSON. void AddToolDefinition(ToolDefinition tool_def); @@ -108,6 +112,10 @@ class Session { ILogger& Logger() { return logger_; } + bool IsCancellationRequested() const { + return cancel_requested_.load(std::memory_order_relaxed); + } + virtual void SetSessionOptionsImpl(const KeyValuePairs& /*options*/) {} /// Merge session-level options with per-request options. @@ -152,6 +160,7 @@ class Session { StreamingCallbackFn callback_fn_; void* callback_user_data_ = nullptr; const bool allow_concurrent_requests_; + std::atomic cancel_requested_{false}; mutable std::unique_ptr request_mutex_ = std::make_unique(); }; diff --git a/sdk_v2/cpp/src/inferencing/session/session_manager.cc b/sdk_v2/cpp/src/inferencing/session/session_manager.cc index a48e81bb6..8d2e9091c 100644 --- a/sdk_v2/cpp/src/inferencing/session/session_manager.cc +++ b/sdk_v2/cpp/src/inferencing/session/session_manager.cc @@ -4,9 +4,11 @@ #include "exception.h" #include "inferencing/generative/chat/chat_session.h" +#include "inferencing/session/session.h" #include #include +#include namespace fl { @@ -22,11 +24,11 @@ SessionManager::~SessionManager() { } void SessionManager::Register(Session& session) { + std::lock_guard lock(mutex_); if (shutting_down_.load()) { FL_THROW(FOUNDRY_LOCAL_ERROR_INVALID_USAGE, "cannot create session during shutdown"); } - std::lock_guard lock(mutex_); sessions_.insert(&session); } @@ -53,11 +55,20 @@ void SessionManager::CancelAll() { // Clear cache — frees idle cached sessions so they don't block drain. ClearCache(); - std::lock_guard lock(mutex_); + std::vector sessions; + { + std::lock_guard lock(mutex_); + sessions.assign(sessions_.begin(), sessions_.end()); + } + logger_.Log(LogLevel::Information, - fmt::format("SessionManager: cancelling all sessions ({} active)", sessions_.size())); + fmt::format("SessionManager: cancelling all sessions ({} active)", sessions.size())); - // Future (Phase 3): iterate sessions_ and call Cancel() on each + for (auto* session : sessions) { + if (session != nullptr) { + session->Cancel(); + } + } } void SessionManager::WaitForDrain(std::chrono::milliseconds timeout) { diff --git a/sdk_v2/cpp/src/manager.cc b/sdk_v2/cpp/src/manager.cc index f7c1e4bb1..a5e9ad79f 100644 --- a/sdk_v2/cpp/src/manager.cc +++ b/sdk_v2/cpp/src/manager.cc @@ -312,38 +312,51 @@ Manager::~Manager() { // we unregister EPs and release the env. C++ would destroy these in reverse // declaration order after this function returns, but the env release below // requires they be gone *now*. + if (s_oga_logger.load(std::memory_order_acquire) != nullptr) { + logger_->Log(LogLevel::Information, "Manager teardown: clearing OGA log callback (before ORT/EP teardown)"); + SetOgaLogCallback(nullptr); + logger_->Log(LogLevel::Information, "Manager teardown: cleared OGA log callback"); + } + #ifdef FOUNDRY_LOCAL_HAS_WEB_SERVICE + logger_->Log(LogLevel::Information, "Manager teardown: resetting web_service_"); web_service_.reset(); + logger_->Log(LogLevel::Information, "Manager teardown: reset web_service_"); #endif + logger_->Log(LogLevel::Information, "Manager teardown: resetting session_manager_"); session_manager_.reset(); + logger_->Log(LogLevel::Information, "Manager teardown: reset session_manager_"); + logger_->Log(LogLevel::Information, "Manager teardown: resetting model_load_manager_"); model_load_manager_.reset(); + logger_->Log(LogLevel::Information, "Manager teardown: reset model_load_manager_"); + logger_->Log(LogLevel::Information, "Manager teardown: resetting download_manager_"); download_manager_.reset(); + logger_->Log(LogLevel::Information, "Manager teardown: reset download_manager_"); + logger_->Log(LogLevel::Information, "Manager teardown: resetting catalog_"); catalog_.reset(); + logger_->Log(LogLevel::Information, "Manager teardown: reset catalog_"); + logger_->Log(LogLevel::Information, "Manager teardown: resetting telemetry_"); telemetry_.reset(); + logger_->Log(LogLevel::Information, "Manager teardown: reset telemetry_"); + logger_->Log(LogLevel::Information, "Manager teardown: resetting ep_detector_"); ep_detector_.reset(); + logger_->Log(LogLevel::Information, "Manager teardown: reset ep_detector_"); - // Unregister EPs we registered, then drop our OrtEnv refcount. Best-effort: - // log failures but don't throw from a destructor. + // Keep ORT global EP/env state alive for process lifetime. ORT GenAI currently + // has process-exit teardown races after in-process model/service disposal. + // The loaded models and sessions are already gone above; avoiding global + // EP/env teardown prevents a late AV in embedded hosts. if (ort_api_ != nullptr && ort_env_ != nullptr) { for (const auto& name : registered_ep_libraries_) { - OrtStatus* status = ort_api_->UnregisterExecutionProviderLibrary(ort_env_, name.c_str()); - if (status != nullptr) { - const char* msg = ort_api_->GetErrorMessage(status); - logger_->Log(LogLevel::Warning, - std::string("EP unregister: UnregisterExecutionProviderLibrary failed for '") + - name + "': " + (msg ? msg : "unknown")); - ort_api_->ReleaseStatus(status); - } + logger_->Log(LogLevel::Information, + std::string("Manager teardown: leaving EP library registered '") + name + "'"); } - ort_api_->ReleaseEnv(ort_env_); + logger_->Log(LogLevel::Information, + "Manager teardown: leaving OrtEnv alive to avoid ORT GenAI teardown AV after EP unregister"); ort_env_ = nullptr; } - if (s_oga_logger.load(std::memory_order_acquire) != nullptr) { - SetOgaLogCallback(nullptr); - } - logger_->Log(LogLevel::Information, "Manager is being disposed."); } @@ -496,10 +509,6 @@ void Manager::Shutdown() { logger_->Log(LogLevel::Information, "Shutdown requested"); - if (web_service_running_) { - StopWebService(); - } - // Order matters: // 1. Reject new loads so callers gated on IsShutdownRequested can stop early. // 2. Cancel + drain HTTP-tracked sessions (web service path). @@ -508,8 +517,18 @@ void Manager::Shutdown() { // caller can't block process shutdown indefinitely. model_load_manager_->RejectNewLoads(); session_manager_->CancelAll(); + if (web_service_running_) { + StopWebService(); + } session_manager_->WaitForDrain(); + if (s_oga_logger.load(std::memory_order_acquire) != nullptr) { + logger_->Log(LogLevel::Information, "Shutdown: clearing OGA log callback before model unload"); + SetOgaLogCallback(nullptr); + logger_->Log(LogLevel::Information, "Shutdown: cleared OGA log callback before model unload"); + } + logger_->Log(LogLevel::Information, "Shutdown: unloading all models"); model_load_manager_->UnloadAll(); + logger_->Log(LogLevel::Information, "Shutdown: unloaded all models"); } bool Manager::IsShutdownRequested() const { diff --git a/sdk_v2/cpp/src/service/web_service.h b/sdk_v2/cpp/src/service/web_service.h index b7d72bd17..73252855a 100644 --- a/sdk_v2/cpp/src/service/web_service.h +++ b/sdk_v2/cpp/src/service/web_service.h @@ -21,7 +21,7 @@ class ResponseStore; /// Tracks streaming threads so they can be joined on shutdown. /// Handlers call Track() instead of std::thread::detach(). -/// Threads call Remove() when done to clean up immediately. +/// Threads call Remove() when done; ownership stays with the tracker so Stop() can join. class StreamingThreadTracker { public: /// Take ownership of a streaming thread. @@ -30,20 +30,13 @@ class StreamingThreadTracker { threads_.push_back(std::move(t)); } - /// Called from within a thread to untrack itself after work is done. - /// Detaches the thread (can't join itself) and removes the entry. + /// Called from within a thread after work is done. Do not detach/erase here: + /// JoinAll() owns reaping so Remove() cannot race it and hide a joinable thread. void Remove(std::thread::id id) { - std::lock_guard lock(mutex_); - for (auto it = threads_.begin(); it != threads_.end(); ++it) { - if (it->get_id() == id) { - it->detach(); - threads_.erase(it); - return; - } - } + (void)id; } - /// Join all remaining threads. Called by WebService::Stop(). + /// Join all tracked threads. Called by WebService::Stop(). /// Moves entries out before joining to avoid deadlock with Remove(). void JoinAll() { std::vector local;