Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 64 additions & 24 deletions sdk_v2/cpp/src/inferencing/generative/chat/chat_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ SessionType ChatSession::Type() const {
return SessionType::kChat;
}

void ChatSession::Cancel() {
Session::Cancel();

std::lock_guard<std::mutex> lock(active_generator_mutex_);
if (active_generator_ != nullptr) {
active_generator_->Cancel();
}
}

/// Register `generator` as the generator that Cancel() should forward to.
/// The pointer is stored as-is (no ownership is taken) under active_generator_mutex_.
/// @param generator Generator about to enter a generation loop.
void ChatSession::SetActiveGenerator(OnnxChatGenerator* generator) {
  const std::lock_guard<std::mutex> guard(active_generator_mutex_);
  active_generator_ = generator;
}

/// Unregister `generator` as the active generator.
/// The stored pointer is reset to nullptr only when it still equals `generator`; if a different
/// generator is registered at that point, the registration is left untouched.
/// @param generator Generator that has finished (or aborted) its generation loop.
void ChatSession::ClearActiveGenerator(OnnxChatGenerator* generator) {
  const std::lock_guard<std::mutex> guard(active_generator_mutex_);
  if (active_generator_ != generator) {
    return;
  }
  active_generator_ = nullptr;
}

/// Populate session_options_ from the supplied key/value pairs.
/// The conversion is delegated to SearchOptions::FromParameters; its result replaces the
/// previously stored session options.
/// @param options Session-level options as key/value pairs.
void ChatSession::SetSessionOptionsImpl(const KeyValuePairs& options) {
session_options_ = SearchOptions::FromParameters(options);
}
Expand Down Expand Up @@ -574,22 +595,32 @@ void ChatSession::ProcessRequestImpl(const Request& request, Response& response)
}
};

while (!cached_generator_->IsDone() && !request.canceled) {
cached_generator_->GenerateNextToken();
std::string token = cached_generator_->Decode();
++output_tokens;

if (!token.empty()) {
text += token;
emit_segments(splitter.Push(token));
}
OnnxChatGenerator* active_generator = cached_generator_.get();
SetActiveGenerator(active_generator);
try {
while (!cached_generator_->IsDone() && !request.canceled && !IsCancellationRequested()) {
cached_generator_->GenerateNextToken();
std::string token = cached_generator_->Decode();
++output_tokens;

if (!token.empty()) {
text += token;
emit_segments(splitter.Push(token));
}

// Enforce max_output_tokens — with use_full_context the OGA max_length
// is the entire context window, so we must cap output ourselves.
if (max_output > 0 && output_tokens >= max_output) {
break;
// Enforce max_output_tokens — with use_full_context the OGA max_length
// is the entire context window, so we must cap output ourselves.
if (max_output > 0 && output_tokens >= max_output) {
break;
}
}
} catch (...) {
ClearActiveGenerator(active_generator);
throw;
}
ClearActiveGenerator(active_generator);

const bool canceled = request.canceled || IsCancellationRequested();

// End-of-stream: drain the reasoning splitter first so any final DEFAULT bytes feed into the tool accumulator,
// then drain the tool accumulator.
Expand All @@ -598,17 +629,17 @@ void ChatSession::ProcessRequestImpl(const Request& request, Response& response)

int total_tokens = cached_generator_->TokenCount();

if (request.canceled) {
if (canceled) {
// Rewind the generator to undo this turn's input. The generator remains valid
// for the next attempt — the caller can re-send the same input.
cached_generator_->RewindTo(pre_turn_token_count);
}

ProcessGeneratedOutput(std::move(text), cached_tool_ctx_, effective_options, request.canceled,
ProcessGeneratedOutput(std::move(text), cached_tool_ctx_, effective_options, canceled,
response, prompt_tokens, total_tokens, std::move(streamed_tool_calls));

// Commit input messages + assistant reply to history only on success (not cancelled)
if (!request.canceled) {
if (!canceled) {
// LARK grammar (tool-call-only mode) is a single-shot finite parse. If generation was truncated while grammar was
// active, the parser is in an unrecoverable state. Additionally, a completed grammar signals EOS — IsDone() would
// return true on the next turn. Invalidate after any grammar-guided generation so the next turn rebuilds.
Expand Down Expand Up @@ -794,15 +825,24 @@ void ChatSession::ProcessChatCompletionsJson(const std::string& request_json, co

// Generate token-by-token
std::string text;
while (!generator->IsDone() && !original_request.canceled) {
generator->GenerateNextToken();
std::string token = generator->Decode();

if (!token.empty()) {
text += token;
process_segments(splitter.Push(token));
SetActiveGenerator(generator.get());
try {
while (!generator->IsDone() && !original_request.canceled && !IsCancellationRequested()) {
generator->GenerateNextToken();
std::string token = generator->Decode();

if (!token.empty()) {
text += token;
process_segments(splitter.Push(token));
}
}
} catch (...) {
ClearActiveGenerator(generator.get());
throw;
}
ClearActiveGenerator(generator.get());

const bool canceled = original_request.canceled || IsCancellationRequested();

// Drain any buffered partial-marker bytes at end-of-stream. Reasoning splitter first so any final DEFAULT bytes
// feed into the tool accumulator; then drain the tool accumulator.
Expand All @@ -818,7 +858,7 @@ void ChatSession::ProcessChatCompletionsJson(const std::string& request_json, co
// Process the generated output into response items (MessageItem, ToolCallItem, etc.)
// This also updates finish_reason, and usage on the response. Streamed-parsed tool calls are reused so call_ids
// stay stable across stream deltas and the final ChatCompletionResponse.
ProcessGeneratedOutput(std::move(text), tool_ctx, options, original_request.canceled,
ProcessGeneratedOutput(std::move(text), tool_ctx, options, canceled,
response, prompt_tokens, total_tokens, std::move(streamed_tool_calls));

// Emit final streaming chunk with finish_reason
Expand Down
8 changes: 8 additions & 0 deletions sdk_v2/cpp/src/inferencing/generative/chat/chat_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "logger.h"

#include <memory>
#include <mutex>
#include <string>
#include <vector>

Expand Down Expand Up @@ -69,6 +70,8 @@ class ChatSession : public Session {
/// @param count Number of turns to undo. Must be <= TurnCount().
void UndoTurns(size_t count) override;

/// Cancel the session: sets the base Session cancellation flag and, if a generator is currently
/// registered as active, forwards the cancel request to it.
void Cancel() override;

private:
// populate session_options_
void SetSessionOptionsImpl(const KeyValuePairs& options) override;
Expand Down Expand Up @@ -101,6 +104,9 @@ class ChatSession : public Session {
void ProcessChatCompletionsJson(const std::string& request_json, const Request& original_request,
Response& response);

/// Record `generator` as the generator Cancel() should forward to. Takes active_generator_mutex_.
void SetActiveGenerator(OnnxChatGenerator* generator);
/// Reset the active generator to nullptr, but only if it is still `generator`. Takes active_generator_mutex_.
void ClearActiveGenerator(OnnxChatGenerator* generator);

/// Commit input messages and assistant reply to history after a successful turn.
void CommitTurn(std::vector<MessageItem>&& new_messages, const Response& response,
int pre_turn_token_count, int post_turn_token_count);
Expand All @@ -120,6 +126,8 @@ class ChatSession : public Session {
// Cached generator for continuous decoding (non-JSON path only).
// Null until first non-JSON ProcessRequestImpl call.
std::unique_ptr<OnnxChatGenerator> cached_generator_;
// Guards every read and write of active_generator_.
std::mutex active_generator_mutex_;
// Non-owning pointer to the generator registered via SetActiveGenerator(), or nullptr when none is
// registered. Cancel() forwards to it while holding active_generator_mutex_.
OnnxChatGenerator* active_generator_ = nullptr;

// Tool context used when creating the cached generator.
// Reused for subsequent turns to maintain tool definition consistency.
Expand Down
1 change: 1 addition & 0 deletions sdk_v2/cpp/src/inferencing/model_load_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ bool ModelLoadManager::UnloadModel(std::string_view model_id) {

// Erasing destroys the GenAIModelInstance, which destroys OGA objects in reverse order.
loaded_models_.erase(it);
logger_.Log(LogLevel::Information, fmt::format("unloaded model: {}", id_str));
return true;
}

Expand Down
17 changes: 17 additions & 0 deletions sdk_v2/cpp/src/inferencing/session/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <nlohmann/json.hpp>

#include <memory>
#include <utility>

namespace fl {

Expand All @@ -30,6 +31,18 @@ Session::Session(const fl::Model& catalog_model, ILogger& logger, ITelemetry& te

Session::~Session() = default;

/// Move constructor.
///
/// tool_definitions_, session_options_ and callback_fn_ are moved out of `other`; the remaining
/// listed members are copied from it. std::atomic is neither copyable nor movable, so the cancel
/// flag is transferred by loading other's current value (relaxed) and constructing a new atomic.
///
/// request_mutex_ is intentionally or accidentally absent from the initializer list, so it takes the
/// in-class default initializer declared in session.h (a newly allocated std::mutex) instead of
/// other's mutex.
/// NOTE(review): that default initializer calls std::make_unique inside a noexcept constructor —
/// confirm that terminating on allocation failure is acceptable, or move other's mutex instead.
Session::Session(Session&& other) noexcept
: catalog_model_(other.catalog_model_),
logger_(other.logger_),
telemetry_(other.telemetry_),
tool_definitions_(std::move(other.tool_definitions_)),
session_options_(std::move(other.session_options_)),
callback_fn_(std::move(other.callback_fn_)),
callback_user_data_(other.callback_user_data_),
allow_concurrent_requests_(other.allow_concurrent_requests_),
cancel_requested_(other.cancel_requested_.load(std::memory_order_relaxed)) {
}

std::unique_ptr<Session> Session::Create(const fl::Model& model) {
auto& mgr = Manager::Instance();
auto& telemetry = mgr.GetTelemetry();
Expand Down Expand Up @@ -114,4 +127,8 @@ void Session::ProcessRequest(const Request& request, Response& response) {
}
}

/// Base implementation of Cancel(): sets cancel_requested_ so that IsCancellationRequested()
/// returns true from now on. Only the flag is written here (relaxed store); this function does not
/// itself interrupt any work in progress.
void Session::Cancel() {
cancel_requested_.store(true, std::memory_order_relaxed);
}

} // namespace fl
11 changes: 10 additions & 1 deletion sdk_v2/cpp/src/inferencing/session/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#pragma once

#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
Expand Down Expand Up @@ -34,7 +35,7 @@ class Session {
public:
virtual ~Session();

Session(Session&&) = default;
Session(Session&& other) noexcept;
Session& operator=(Session&&) = delete;

Session(const Session&) = delete;
Expand All @@ -52,6 +53,9 @@ class Session {
/// in-flight callbacks and ensures the Response is fully populated on return.
void ProcessRequest(const Request& request, Response& response);

/// Signal the active request, if any, to stop as soon as possible.
virtual void Cancel();

/// Add a tool definition to this session.
/// @throws fl::Exception if tool_def.json_schema is not valid JSON.
void AddToolDefinition(ToolDefinition tool_def);
Expand Down Expand Up @@ -108,6 +112,10 @@ class Session {

ILogger& Logger() { return logger_; }

/// Returns true once Cancel() has set the cancellation flag. Performs a relaxed atomic load.
bool IsCancellationRequested() const {
return cancel_requested_.load(std::memory_order_relaxed);
}

virtual void SetSessionOptionsImpl(const KeyValuePairs& /*options*/) {}

/// Merge session-level options with per-request options.
Expand Down Expand Up @@ -152,6 +160,7 @@ class Session {
StreamingCallbackFn callback_fn_;
void* callback_user_data_ = nullptr;
const bool allow_concurrent_requests_;
// Cancellation flag: written by Cancel() and read by IsCancellationRequested(), both without a lock.
// NOTE(review): no code visible in this change resets the flag to false — confirm whether a
// session is expected to accept further requests after Cancel().
std::atomic<bool> cancel_requested_{false};
mutable std::unique_ptr<std::mutex> request_mutex_ = std::make_unique<std::mutex>();
};

Expand Down
19 changes: 15 additions & 4 deletions sdk_v2/cpp/src/inferencing/session/session_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

#include "exception.h"
#include "inferencing/generative/chat/chat_session.h"
#include "inferencing/session/session.h"

#include <cassert>
#include <fmt/format.h>
#include <vector>

namespace fl {

Expand All @@ -22,11 +24,11 @@ SessionManager::~SessionManager() {
}

/// Add `session` to the set of tracked sessions.
///
/// Fails through FL_THROW(FOUNDRY_LOCAL_ERROR_INVALID_USAGE, ...) when shutting_down_ is set.
/// @param session Session to track; its address is stored, so it must outlive its registration.
void SessionManager::Register(Session& session) {
  // Acquire mutex_ exactly once, before the shutdown check, so that the check and the insert form a
  // single critical section with respect to other users of mutex_ (CancelAll() copies sessions_
  // under the same mutex). Declaring the guard a second time further down would be a redefinition.
  const std::lock_guard<std::mutex> lock(mutex_);
  if (shutting_down_.load()) {
    FL_THROW(FOUNDRY_LOCAL_ERROR_INVALID_USAGE, "cannot create session during shutdown");
  }

  sessions_.insert(&session);
}

Expand All @@ -53,11 +55,20 @@ void SessionManager::CancelAll() {
// Clear cache — frees idle cached sessions so they don't block drain.
ClearCache();

std::lock_guard<std::mutex> lock(mutex_);
std::vector<Session*> sessions;
{
std::lock_guard<std::mutex> lock(mutex_);
sessions.assign(sessions_.begin(), sessions_.end());
}

logger_.Log(LogLevel::Information,
fmt::format("SessionManager: cancelling all sessions ({} active)", sessions_.size()));
fmt::format("SessionManager: cancelling all sessions ({} active)", sessions.size()));

// Future (Phase 3): iterate sessions_ and call Cancel() on each
for (auto* session : sessions) {
if (session != nullptr) {
session->Cancel();
}
}
}

void SessionManager::WaitForDrain(std::chrono::milliseconds timeout) {
Expand Down
Loading