diff --git a/sdk_v2/cpp/CMakeLists.txt b/sdk_v2/cpp/CMakeLists.txt index c203deec6..4fabe806d 100644 --- a/sdk_v2/cpp/CMakeLists.txt +++ b/sdk_v2/cpp/CMakeLists.txt @@ -149,8 +149,11 @@ set(FOUNDRY_LOCAL_SOURCES src/inferencing/generative/chat/chat_session.cc src/inferencing/generative/chat/chat_template.cc src/configuration.cc + src/download/blob_download_state.cc src/download/blob_downloader.cc + src/download/cross_process_file_lock.cc src/download/download_manager.cc + src/download/file_writer.cc src/download/inference_model_writer.cc src/download/model_registry_client.cc src/ep_detection/cuda_ep_bootstrapper.cc diff --git a/sdk_v2/cpp/src/download/blob_download_state.cc b/sdk_v2/cpp/src/download/blob_download_state.cc new file mode 100644 index 000000000..1cf8ae9b3 --- /dev/null +++ b/sdk_v2/cpp/src/download/blob_download_state.cc @@ -0,0 +1,381 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "download/blob_download_state.h" +#include "logger.h" + +#include +#include +#include +#include + +namespace fl { + +namespace { + +constexpr const char* kStateFileExtension = ".dlstate"; + +// On-disk format (little-endian throughout): +// bytes | field +// -------|-------------------------------------------------------- +// 0..3 | magic "FLDS" +// 4 | version (currently 1) +// 5..12 | blob_size (int64) +// 13..16 | chunk_size (int32) +// 17..20 | total_chunks (int32) +// 21..24 | bitmap_byte_aligned_start (int32) +// 25..28 | highest_completed_chunk (int32) +// 29..32 | completed_count (int32) +// 33..40 | last_modified_unix_ms (int64) +// 41..44 | trunc_bitmap_byte_len (uint32) +// 45.. | trunc_bitmap_byte_len bytes of bitmap data, copied directly out of +// full_completion_bitmap starting at the byte offset implied by +// bitmap_byte_aligned_start. 
+constexpr char kMagic[4] = {'F', 'L', 'D', 'S'};
+constexpr uint8_t kVersion = 1;
+
+constexpr int32_t kBitsPerWord = 64;
+
+template
+void WriteLE(std::ostream& out, T value) {
+  static_assert(std::is_trivially_copyable_v);
+  unsigned char buf[sizeof(T)];
+  std::memcpy(buf, &value, sizeof(T));  // NOTE: host byte order; equals the documented LE layout only on little-endian hosts
+  out.write(reinterpret_cast(buf), sizeof(T));
+}
+
+template
+bool ReadLE(std::istream& in, T& out_value) {
+  static_assert(std::is_trivially_copyable_v);
+  unsigned char buf[sizeof(T)];
+  in.read(reinterpret_cast(buf), sizeof(T));
+  if (!in) {
+    return false;
+  }
+  std::memcpy(&out_value, buf, sizeof(T));
+  return true;
+}
+
+int64_t NowUnixMs() {
+  return std::chrono::duration_cast(
+             std::chrono::system_clock::now().time_since_epoch())
+      .count();
+}
+
+}  // namespace
+
+std::filesystem::path BlobDownloadState::GetStateFilePath(const std::filesystem::path& local_file_path) {
+  auto p = local_file_path;
+  p += kStateFileExtension;
+  return p;
+}
+
+std::unique_ptr BlobDownloadState::CreateNew(std::string blob_name,
+                                             std::filesystem::path local_file_path,
+                                             int64_t blob_size,
+                                             int32_t chunk_size,
+                                             int32_t total_chunks) {
+  auto state = std::make_unique();
+  state->blob_name = std::move(blob_name);
+  state->local_file_path = local_file_path.string();
+  state->blob_size = blob_size;
+  state->chunk_size = chunk_size;
+  state->total_chunks = total_chunks;
+  state->bitmap_byte_aligned_start = 0;
+  state->highest_completed_chunk = -1;
+  state->completed_count = 0;
+  state->last_modified_unix_ms = NowUnixMs();
+  auto words = static_cast((total_chunks + kBitsPerWord - 1) / kBitsPerWord);
+  state->full_completion_bitmap.assign(words, 0);
+  return state;
+}
+
+std::unique_ptr BlobDownloadState::LoadState(std::string blob_name,
+                                             std::filesystem::path local_file_path,
+                                             int64_t expected_blob_size,
+                                             int32_t expected_chunk_size,
+                                             int32_t expected_total_chunks,
+                                             ILogger* logger) {
+  auto state_path = GetStateFilePath(local_file_path);
+  std::error_code ec;
+  if (!std::filesystem::exists(state_path, ec)) {
+    return nullptr;
+  }
+
+  std::ifstream in(state_path, std::ios::binary);
+  if (!in) {
+    if (logger) {
+      logger->Log(LogLevel::Warning, "Could not open download state file: " + state_path.string());
+    }
+    return nullptr;
+  }
+
+  char magic[4]{};
+  in.read(magic, 4);
+  uint8_t version = 0;
+  if (!in || std::memcmp(magic, kMagic, 4) != 0 || !ReadLE(in, version) || version != kVersion) {
+    if (logger) {
+      logger->Log(LogLevel::Warning,
+                  "Download state file " + state_path.string() + " has unexpected magic/version; ignoring");
+    }
+    return nullptr;
+  }
+
+  int64_t blob_size = 0;
+  int32_t chunk_size = 0;
+  int32_t total_chunks = 0;
+  int32_t bitmap_byte_aligned_start = 0;
+  int32_t highest_completed_chunk = 0;
+  int32_t completed_count = 0;
+  int64_t last_modified_unix_ms = 0;
+  uint32_t trunc_len = 0;
+  if (!ReadLE(in, blob_size) || !ReadLE(in, chunk_size) || !ReadLE(in, total_chunks) ||
+      !ReadLE(in, bitmap_byte_aligned_start) || !ReadLE(in, highest_completed_chunk) ||
+      !ReadLE(in, completed_count) || !ReadLE(in, last_modified_unix_ms) || !ReadLE(in, trunc_len)) {
+    if (logger) {
+      logger->Log(LogLevel::Warning, "Download state header truncated: " + state_path.string());
+    }
+    return nullptr;
+  }
+
+  // Sanity / compatibility checks.
+  if (blob_size != expected_blob_size || chunk_size != expected_chunk_size ||
+      total_chunks != expected_total_chunks) {
+    if (logger) {
+      logger->Log(LogLevel::Information,
+                  "Download state for " + state_path.string() +
+                      " is incompatible with current blob layout; starting fresh");
+    }
+    return nullptr;
+  }
+  if (bitmap_byte_aligned_start < 0 || bitmap_byte_aligned_start % 8 != 0 ||
+      bitmap_byte_aligned_start > total_chunks || completed_count < 0 ||
+      completed_count > total_chunks || highest_completed_chunk < -1 ||
+      highest_completed_chunk >= total_chunks) {
+    if (logger) {
+      logger->Log(LogLevel::Warning, "Download state header values out of range: " + state_path.string());
+    }
+    return nullptr;
+  }
+
+  auto words_total = static_cast((total_chunks + kBitsPerWord - 1) / kBitsPerWord);
+  std::vector bitmap(words_total, 0);
+
+  // The prefix of fully-completed chunks below bitmap_byte_aligned_start is
+  // implied — fill those bits.
+  size_t implicit_full_words = static_cast(bitmap_byte_aligned_start) / kBitsPerWord;
+  for (size_t i = 0; i < implicit_full_words && i < bitmap.size(); ++i) {
+    bitmap[i] = ~uint64_t{0};
+  }
+  // Any remaining "implicit" bits inside a partial word (between
+  // implicit_full_words*64 and bitmap_byte_aligned_start).
+  if (size_t partial_bits = static_cast(bitmap_byte_aligned_start) % kBitsPerWord;
+      partial_bits > 0 && implicit_full_words < bitmap.size()) {
+    bitmap[implicit_full_words] |= (uint64_t{1} << partial_bits) - 1;
+  }
+
+  if (trunc_len > 0) {
+    // Copy serialized bytes directly into the bitmap starting at the byte
+    // position implied by bitmap_byte_aligned_start.
+    size_t byte_offset = static_cast(bitmap_byte_aligned_start) / 8;
+    auto* dest = reinterpret_cast(bitmap.data()) + byte_offset;
+    auto dest_capacity = bitmap.size() * sizeof(uint64_t) - byte_offset;
+    if (trunc_len > dest_capacity) {
+      if (logger) {
+        logger->Log(LogLevel::Warning,
+                    "Download state bitmap length exceeds expected capacity: " + state_path.string());
+      }
+      return nullptr;
+    }
+    in.read(reinterpret_cast(dest), trunc_len);
+    if (!in) {
+      if (logger) {
+        logger->Log(LogLevel::Warning,
+                    "Download state bitmap payload truncated: " + state_path.string());
+      }
+      return nullptr;
+    }
+  }
+
+  auto state = std::make_unique();
+  state->blob_name = std::move(blob_name);
+  state->local_file_path = local_file_path.string();
+  state->blob_size = blob_size;
+  state->chunk_size = chunk_size;
+  state->total_chunks = total_chunks;
+  state->bitmap_byte_aligned_start = bitmap_byte_aligned_start;
+  state->highest_completed_chunk = highest_completed_chunk;
+  state->completed_count = completed_count;
+  state->last_modified_unix_ms = last_modified_unix_ms;
+  state->full_completion_bitmap = std::move(bitmap);
+
+  if (logger) {
+    logger->Log(LogLevel::Information,
+                "Loaded download state " + state_path.string() + ": " +
+                    std::to_string(completed_count) + "/" + std::to_string(total_chunks) +
+                    " chunks already done");
+  }
+  return state;
+}
+
+int64_t BlobDownloadState::CalculateDownloadedSize() const noexcept {
+  int64_t bytes = static_cast(completed_count) * chunk_size;
+  // If the final chunk is partial and was completed, adjust the overcount.
+  if (highest_completed_chunk == total_chunks - 1 && chunk_size > 0) {
+    auto remainder = blob_size % chunk_size;
+    if (remainder != 0) {
+      bytes -= (chunk_size - remainder);
+    }
+  }
+  return bytes;
+}
+
+bool BlobDownloadState::IsChunkComplete(int32_t chunk_idx) const noexcept {
+  if (chunk_idx < 0 || chunk_idx >= total_chunks) {
+    return false;
+  }
+  if (chunk_idx < bitmap_byte_aligned_start) {
+    // Below the truncation point — implicitly complete.
+    return true;
+  }
+  auto word_idx = static_cast(chunk_idx) / kBitsPerWord;
+  auto bit_idx = static_cast(chunk_idx) % kBitsPerWord;
+  if (word_idx >= full_completion_bitmap.size()) {
+    return false;
+  }
+  return (full_completion_bitmap[word_idx] & (uint64_t{1} << bit_idx)) != 0;
+}
+
+void BlobDownloadState::MarkChunkComplete(int32_t chunk_idx) {
+  if (chunk_idx < 0 || chunk_idx >= total_chunks) {
+    return;
+  }
+  if (IsChunkComplete(chunk_idx)) {
+    return;
+  }
+  if (chunk_idx > highest_completed_chunk) {
+    highest_completed_chunk = chunk_idx;
+  }
+  auto word_idx = static_cast(chunk_idx) / kBitsPerWord;
+  auto bit_idx = static_cast(chunk_idx) % kBitsPerWord;
+  full_completion_bitmap[word_idx] |= (uint64_t{1} << bit_idx);
+  ++completed_count;
+}
+
+std::vector BlobDownloadState::GetPendingChunks() const {
+  std::vector pending;
+  pending.reserve(static_cast(total_chunks - completed_count));
+  for (int32_t i = bitmap_byte_aligned_start; i < total_chunks; ++i) {
+    if (!IsChunkComplete(i)) {
+      pending.push_back(i);
+    }
+  }
+  return pending;
+}
+
+void BlobDownloadState::SaveState(ILogger* logger) {
+  // Advance bitmap_byte_aligned_start past the completed prefix. Scan from the
+  // base of the word containing it, because the offsets added below are word-relative.
+  int32_t new_start = (bitmap_byte_aligned_start / kBitsPerWord) * kBitsPerWord;
+  size_t word_idx = static_cast(new_start) / kBitsPerWord;
+  while (word_idx < full_completion_bitmap.size() &&
+         full_completion_bitmap[word_idx] == ~uint64_t{0}) {
+    new_start += kBitsPerWord;
+    ++word_idx;
+  }
+  // Within the first not-fully-set word, advance to the lowest 0 bit and round
+  // down to a byte boundary (8 bits) so reload-then-resume re-reads on a clean
+  // alignment.
+  if (word_idx < full_completion_bitmap.size()) {
+    uint64_t inverted = ~full_completion_bitmap[word_idx];
+    int trailing_zero = 0;
+    while (trailing_zero < kBitsPerWord && ((inverted >> trailing_zero) & 1) == 0) {
+      ++trailing_zero;
+    }
+    new_start += trailing_zero;
+  }
+  new_start = (new_start / 8) * 8;
+  if (new_start > total_chunks) {
+    new_start = (total_chunks / 8) * 8;
+  }
+  if (new_start > bitmap_byte_aligned_start) {
+    bitmap_byte_aligned_start = new_start;
+  }
+
+  last_modified_unix_ms = NowUnixMs();
+
+  auto state_path = GetStateFilePath(local_file_path);
+  auto tmp_path = state_path;
+  tmp_path += ".tmp";
+
+  // Compute the serialized bitmap payload: bytes from bitmap_byte_aligned_start
+  // up to (highest_completed_chunk + 1), rounded up to the nearest byte.
+  uint32_t trunc_len = 0;
+  if (highest_completed_chunk >= bitmap_byte_aligned_start) {
+    int32_t bit_count = highest_completed_chunk - bitmap_byte_aligned_start + 1;
+    trunc_len = static_cast((bit_count + 7) / 8);
+  }
+  size_t byte_offset = static_cast(bitmap_byte_aligned_start) / 8;
+
+  {
+    std::ofstream out(tmp_path, std::ios::binary | std::ios::trunc);
+    if (!out) {
+      if (logger) {
+        logger->Log(LogLevel::Warning, "Failed to open download state tmp file: " + tmp_path.string());
+      }
+      return;
+    }
+    out.write(kMagic, 4);
+    WriteLE(out, kVersion);
+    WriteLE(out, blob_size);
+    WriteLE(out, chunk_size);
+    WriteLE(out, total_chunks);
+    WriteLE(out, bitmap_byte_aligned_start);
+    WriteLE(out, highest_completed_chunk);
+    WriteLE(out, completed_count);
+    WriteLE(out, last_modified_unix_ms);
+    WriteLE(out, trunc_len);
+    if (trunc_len > 0) {
+      auto* src = reinterpret_cast(full_completion_bitmap.data()) + byte_offset;
+      out.write(reinterpret_cast(src), trunc_len);
+    }
+    if (!out.flush()) {  // flush first so buffered write errors are seen before the rename
+      if (logger) {
+        logger->Log(LogLevel::Warning, "Failed to write download state tmp file: " + tmp_path.string());
+      }
+      return;
+    }
+  }
+
+  std::error_code ec;
+  std::filesystem::rename(tmp_path, state_path, ec);
+  if (ec) {
+    // std::filesystem::rename atomically replaces the destination on every
+    // platform we target (POSIX rename(2); Windows MoveFileExW with
+    // MOVEFILE_REPLACE_EXISTING). If it still fails, the cause is transient
+    // (e.g. a brief sharing violation on Windows or a flaky network FS) —
+    // do NOT delete state_path as a fallback; that loses the only intact
+    // copy of the resume bitmap. Instead, drop the tmp file and let the
+    // next SaveState call retry from the up-to-date in-memory state.
+    std::error_code rm_ec;
+    std::filesystem::remove(tmp_path, rm_ec);
+    if (logger) {
+      logger->Log(LogLevel::Warning,
+                  "Failed to commit download state file: " + tmp_path.string() + " -> " +
+                      state_path.string() + " (" + ec.message() +
+                      "); previous state retained, will retry on next save");
+    }
+  }
+}
+
+void BlobDownloadState::DeleteState(const std::filesystem::path& local_file_path, ILogger* logger) {
+  auto state_path = GetStateFilePath(local_file_path);
+  std::error_code ec;
+  std::filesystem::remove(state_path, ec);
+  if (ec && logger) {
+    logger->Log(LogLevel::Warning,
+                "Failed to delete download state file: " + state_path.string() + " (" +
+                    ec.message() + ")");
+  }
+}
+
+}  // namespace fl
diff --git a/sdk_v2/cpp/src/download/blob_download_state.h b/sdk_v2/cpp/src/download/blob_download_state.h
new file mode 100644
index 000000000..66cc69dbf
--- /dev/null
+++ b/sdk_v2/cpp/src/download/blob_download_state.h
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace fl {
+
+class ILogger;
+
+/// Per-blob download progress, persisted next to the data file as `.dlstate`.
+///
+/// Each chunk completion flips a bit in `full_completion_bitmap`. On resume,
+/// `GetPendingChunks` enumerates only the chunks whose bits are still 0.
+///
+/// The serialized form stores only the bitmap suffix starting at
+/// `bitmap_byte_aligned_start` — the prefix of fully-completed chunks is
+/// implied. This keeps the on-disk state proportional to the *unfinished*
+/// range, not the total file size.
+///
+/// On-disk layout is a small fixed-width little-endian binary header followed
+/// by the truncated bitmap bytes; see `blob_download_state.cc` for the exact
+/// field order. Chosen over JSON for speed and compactness; the file is purely
+/// internal cache state, never inspected by users.
+class BlobDownloadState {
+ public:
+  /// Identity of the blob (populated by caller; not serialized).
+  std::string blob_name;
+  std::string local_file_path;
+
+  /// Set by `CreateNew` / `LoadState`; serialized for resume integrity checks.
+  int64_t blob_size = 0;
+  int32_t chunk_size = 0;
+  int32_t total_chunks = 0;
+
+  /// Chunk index below which every chunk is treated as complete. Always a
+  /// multiple of 8 — bits for chunks below this index are implied and are not
+  /// serialized. The in-memory bitmap is still indexed by absolute chunk index.
+  int32_t bitmap_byte_aligned_start = 0;
+
+  /// Highest chunk index completed so far. -1 if no chunks are done yet.
+  int32_t highest_completed_chunk = -1;
+
+  /// Cached count for O(1) `IsComplete()`.
+  int32_t completed_count = 0;
+
+  /// Unix epoch milliseconds; refreshed on every save.
+  int64_t last_modified_unix_ms = 0;
+
+  /// Bit set indexed by absolute chunk index: chunk `i` is word `i / 64`,
+  /// bit `i % 64`. Sized for `total_chunks` by `CreateNew` / `LoadState`;
+  /// `MarkChunkComplete` sets bits in place and never resizes it.
+  std::vector full_completion_bitmap;
+
+  /// Sidecar path for `local_file_path`.
+  static std::filesystem::path GetStateFilePath(const std::filesystem::path& local_file_path);
+
+  /// Construct a fresh state for a new download. Bitmap sized for `total_chunks`.
+  static std::unique_ptr CreateNew(std::string blob_name,
+                                   std::filesystem::path local_file_path,
+                                   int64_t blob_size,
+                                   int32_t chunk_size,
+                                   int32_t total_chunks);
+
+  /// Load existing state from `.dlstate`. Returns nullptr if
+  /// the file does not exist, is corrupted, or has incompatible
+  /// `blob_size` / `chunk_size` / `total_chunks` (caller-provided values are
+  /// authoritative — a mismatch means the blob has been reconfigured upstream
+  /// and the partial download is no longer valid).
+  static std::unique_ptr LoadState(std::string blob_name,
+                                   std::filesystem::path local_file_path,
+                                   int64_t expected_blob_size,
+                                   int32_t expected_chunk_size,
+                                   int32_t expected_total_chunks,
+                                   ILogger* logger = nullptr);
+
+  /// All chunks downloaded.
+  bool IsComplete() const noexcept { return completed_count == total_chunks; }
+
+  /// Sum of bytes already written. Accounts for the final chunk being smaller
+  /// than `chunk_size` when blob_size is not chunk-aligned.
+  int64_t CalculateDownloadedSize() const noexcept;
+
+  /// Whether `chunk_idx` is already marked complete.
+  bool IsChunkComplete(int32_t chunk_idx) const noexcept;
+
+  /// Mark `chunk_idx` complete. Does not lock internally: callers on concurrent
+  /// worker tasks must hold `mutex()` around the call. Idempotent.
+  void MarkChunkComplete(int32_t chunk_idx);
+
+  /// Enumerate chunks in [0, total_chunks) that are not yet complete.
+  std::vector GetPendingChunks() const;
+
+  /// Advance `bitmap_byte_aligned_start`, then write current state to `.dlstate`
+  /// via a `.tmp` file and rename. Best-effort: I/O errors are logged, not thrown —
+  /// a failed save just means the next resume will replay a few chunks.
+  void SaveState(ILogger* logger = nullptr);
+
+  /// Remove the sidecar; called on successful completion.
+  static void DeleteState(const std::filesystem::path& local_file_path,
+                          ILogger* logger = nullptr);
+
+  /// Mutex protecting concurrent `MarkChunkComplete` / `SaveState` calls from
+  /// the chunk worker pool.
+  std::mutex& mutex() noexcept { return mutex_; }
+
+ private:
+  mutable std::mutex mutex_;
+};
+
+}  // namespace fl
diff --git a/sdk_v2/cpp/src/download/blob_downloader.cc b/sdk_v2/cpp/src/download/blob_downloader.cc
index 73b41b173..a2a0da61d 100644
--- a/sdk_v2/cpp/src/download/blob_downloader.cc
+++ b/sdk_v2/cpp/src/download/blob_downloader.cc
@@ -1,15 +1,20 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
#include "download/blob_downloader.h" +#include "download/blob_download_state.h" +#include "download/file_writer.h" #include "exception.h" +#include "logger.h" #include "util/path_safety.h" #include #include #include +#include #include #include #include +#include #include #include #include @@ -19,10 +24,32 @@ namespace fl { +namespace { + +/// Streaming buffer size used by the production chunk downloader. Matches the +/// 64 KB-ish granularity Stream.CopyTo uses in .NET, capping per-worker peak +/// memory at this many bytes regardless of chunk size. +constexpr size_t kStreamingBufferBytes = 64 * 1024; + +} // namespace + // ======================================================================== // AzureBlobDownloader — real Azure Storage SDK implementation // ======================================================================== +/// Per-blob shared state passed to the protected virtuals. The production +/// virtuals dereference `blob_client` / `azure_ctx`; tests can ignore them. +/// `cancel_flag` is flipped by the orchestrator on the first chunk failure so +/// workers exit promptly without waiting for Azure SDK timeouts. 
+struct AzureBlobDownloader::ChunkContext { + Azure::Storage::Blobs::BlobClient* blob_client; + Azure::Core::Context* azure_ctx; + std::atomic* cancel_flag; +}; + +AzureBlobDownloader::AzureBlobDownloader(ILogger* logger, FileWriterKind writer_kind) + : logger_(logger), writer_kind_(writer_kind) {} + std::vector AzureBlobDownloader::ListBlobs(const std::string& sas_uri) { try { auto container_client = Azure::Storage::Blobs::BlobContainerClient(sas_uri); @@ -44,6 +71,62 @@ std::vector AzureBlobDownloader::ListBlobs(const std::string& sas_ } } +int64_t AzureBlobDownloader::GetBlobSize(ChunkContext& ctx) { + auto props = ctx.blob_client->GetProperties({}, *ctx.azure_ctx).Value; + return props.BlobSize; +} + +std::atomic* AzureBlobDownloader::GetCancelFlag(ChunkContext& ctx) { + return ctx.cancel_flag; +} + +void AzureBlobDownloader::DownloadChunkStreaming( + ChunkContext& ctx, int64_t offset, int64_t size, std::vector& scratch, + const std::function& sink) { + Azure::Storage::Blobs::DownloadBlobOptions range_opts; + range_opts.Range = Azure::Core::Http::HttpRange{offset, size}; + auto result = ctx.blob_client->Download(range_opts, *ctx.azure_ctx); + auto& body_stream = *result.Value.BodyStream; + + if (scratch.size() < kStreamingBufferBytes) { + scratch.resize(kStreamingBufferBytes); + } + + int64_t remaining = size; + while (remaining > 0) { + size_t to_read = + static_cast(std::min(remaining, static_cast(scratch.size()))); + size_t got = body_stream.Read(scratch.data(), to_read, *ctx.azure_ctx); + if (got == 0) { + // Zero-byte read before reaching `size` means the server closed early. + // Treat as a hard error rather than silently writing a truncated chunk. 
+ FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "short read from blob stream at offset " + std::to_string(offset) + ": got " + + std::to_string(size - remaining) + " of " + std::to_string(size) + " bytes"); + } + sink(scratch.data(), got); + remaining -= static_cast(got); + } +} + +namespace { + +/// Pre-allocate `local_path` to `blob_size` bytes if it does not already exist +/// at the expected size. Allows concurrent chunk writes to seek without races +/// and avoids re-zeroing a file we're resuming. +/// +/// Used only for the empty-blob case below; the writers' `Open` method handles +/// pre-allocation for the streaming chunked path. +void EnsureEmptyBlobFile(const std::string& local_path) { + std::ofstream f(local_path, std::ios::binary); + if (!f.is_open()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to create empty blob file: " + local_path); + } +} + +} // namespace + void AzureBlobDownloader::DownloadBlob(const std::string& sas_uri, const std::string& blob_name, const std::string& local_path, @@ -64,155 +147,188 @@ void AzureBlobDownloader::DownloadBlob(const std::string& sas_uri, auto container_client = Azure::Storage::Blobs::BlobContainerClient(sas_uri, client_options); auto blob_client = container_client.GetBlobClient(blob_name); - // Context provides cooperative cancellation across all SDK operations. - Azure::Core::Context ctx; + // Single shared Azure context for the whole blob; calling Cancel() on it + // propagates into every in-flight chunk read. + Azure::Core::Context azure_ctx; + // Internal cancel flag flipped by the orchestrator on first chunk failure + // or by external cancellation; checked by workers between iterations. 
+ std::atomic internal_cancel{false}; - // Get blob size - auto props = blob_client.GetProperties({}, ctx).Value; - int64_t blob_size = props.BlobSize; + ChunkContext chunk_ctx{&blob_client, &azure_ctx, &internal_cancel}; - if (blob_size == 0) { - // Empty blob — just create the file - std::ofstream f(local_path, std::ios::binary); - if (!f.is_open()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to create empty blob file: " + local_path); - } + int64_t blob_size = GetBlobSize(chunk_ctx); + if (blob_size == 0) { + EnsureEmptyBlobFile(local_path); + BlobDownloadState::DeleteState(local_path, logger_); return; } // 2MB chunk size matching C# constexpr int64_t kChunkSize = 2 * 1024 * 1024; - int64_t num_chunks = (blob_size + kChunkSize - 1) / kChunkSize; - - // Pre-allocate the file to the full blob size. - // This lets concurrent chunk writes seek to their offset without a resize race. - { - std::ofstream f(local_path, std::ios::binary); - if (!f.is_open()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to open blob file for pre-allocation: " + local_path); - } + int32_t num_chunks = static_cast((blob_size + kChunkSize - 1) / kChunkSize); + + // Resume from existing sidecar if it matches the current blob layout. + auto state = BlobDownloadState::LoadState(blob_name, local_path, blob_size, + static_cast(kChunkSize), + num_chunks, logger_); + if (!state) { + state = BlobDownloadState::CreateNew(blob_name, local_path, blob_size, + static_cast(kChunkSize), num_chunks); + } + + // Track cumulative bytes for progress reporting; seed with bytes already + // present on disk so percent stays monotonic across resume. 
+ std::atomic bytes_completed{state->CalculateDownloadedSize()}; + if (bytes_written_cb && bytes_completed.load() > 0) { + bytes_written_cb(bytes_completed.load()); + } - f.seekp(blob_size - 1); - f.put('\0'); - f.close(); - if (f.fail()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to pre-allocate blob file: " + local_path + - " (size=" + std::to_string(blob_size) + ")"); + auto pending = state->GetPendingChunks(); + if (pending.empty()) { + // Already complete on disk — drop the sidecar. + BlobDownloadState::DeleteState(local_path, logger_); + if (bytes_written_cb) { + bytes_written_cb(blob_size); } + return; + } + + // Open the file writer once for the whole download. Open() pre-allocates + // the file to blob_size if needed, preserving any existing bytes from a + // resume. Concurrent WriteAt calls to disjoint ranges are thread-safe + // (lock-free for Positional, mutex-guarded for MutexFstream). + std::unique_ptr writer = (writer_kind_ == FileWriterKind::MutexFstream) + ? MakeMutexFstreamFileWriter() + : MakePositionalFileWriter(); + writer->Open(local_path, blob_size); + + // Save the sidecar roughly every 2% of chunks, with a floor of 10. + const int32_t save_interval = std::max(10, num_chunks / 50); + std::atomic chunks_since_save{0}; + + std::mutex error_mutex; + std::exception_ptr first_error; + + // Worker pool: workers race to claim from `pending` via atomic fetch_add. + // On any failure, the first worker to fail records the error, sets + // internal_cancel, and calls azure_ctx.Cancel(); other workers see the + // signal and exit fast. + std::atomic next_pending_idx{0}; + int worker_count = std::min(max_concurrency, static_cast(pending.size())); + if (worker_count < 1) { + worker_count = 1; } + std::vector> workers; + workers.reserve(static_cast(worker_count)); + + auto worker_body = [&]() { + // Per-worker scratch buffer reused across every chunk this worker + // handles. 
Streaming downloads fill the scratch in 64 KB pieces and + // forward each piece to the sink, so total transient memory is bounded + // by `worker_count * kStreamingBufferBytes` regardless of chunk size. + std::vector scratch(kStreamingBufferBytes); + + while (true) { + // External cancellation drains the pool as fast as the SDK can unwind. + if (cancelled && cancelled->load(std::memory_order_relaxed)) { + if (!internal_cancel.exchange(true)) { + azure_ctx.Cancel(); + } + return; + } + if (internal_cancel.load(std::memory_order_relaxed)) { + return; + } - // Track cumulative bytes for progress reporting - std::atomic bytes_completed{0}; + size_t i = next_pending_idx.fetch_add(1, std::memory_order_relaxed); + if (i >= pending.size()) { + return; + } + int32_t chunk_idx = pending[i]; + int64_t offset = static_cast(chunk_idx) * kChunkSize; + int64_t size = std::min(kChunkSize, blob_size - offset); + + // Sink advances a per-chunk write cursor and forwards each piece to + // the file writer. The writer is responsible for any synchronization + // needed across concurrent workers; we don't take a mutex here. + int64_t written = 0; + auto sink = [&](const uint8_t* data, size_t len) { + writer->WriteAt(offset + written, data, len); + written += static_cast(len); + }; + + try { + DownloadChunkStreaming(chunk_ctx, offset, size, scratch, sink); + } catch (...) { + std::lock_guard lock(error_mutex); + if (!first_error) { + first_error = std::current_exception(); + } + if (!internal_cancel.exchange(true)) { + azure_ctx.Cancel(); + } + return; + } - // Mutex protects concurrent writes to different offsets in the same file. - // Each chunk opens the file, seeks, and writes — the mutex prevents interleaved I/O. - std::mutex file_mutex; + int64_t new_total = bytes_completed.fetch_add(size, std::memory_order_relaxed) + size; + if (bytes_written_cb) { + bytes_written_cb(new_total); + } - // Download chunks concurrently using a bounded pool of async tasks. 
- // We launch up to max_concurrency tasks at a time, then wait for the batch to complete. - for (int64_t batch_start = 0; batch_start < num_chunks; batch_start += max_concurrency) { - // Check cancellation between batches - if (cancelled && cancelled->load(std::memory_order_relaxed)) { - ctx.Cancel(); - FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled"); + bool should_save = false; + { + std::lock_guard lock(state->mutex()); + state->MarkChunkComplete(chunk_idx); + int32_t inc = chunks_since_save.fetch_add(1, std::memory_order_relaxed) + 1; + if (inc >= save_interval) { + chunks_since_save.store(0, std::memory_order_relaxed); + should_save = true; + } + } + if (should_save) { + std::lock_guard lock(state->mutex()); + state->SaveState(logger_); + } } + }; - int64_t batch_end = std::min(batch_start + max_concurrency, num_chunks); - std::vector> futures; - futures.reserve(static_cast(batch_end - batch_start)); - - for (int64_t chunk_idx = batch_start; chunk_idx < batch_end; ++chunk_idx) { - int64_t offset = chunk_idx * kChunkSize; - int64_t size = std::min(kChunkSize, blob_size - offset); - - futures.push_back(std::async(std::launch::async, - [&blob_client, &local_path, &file_mutex, &bytes_completed, &bytes_written_cb, - &ctx, offset, size]() { - // Download this range from the blob. - // Retry and backoff are handled by the SDK's retry policy. 
- Azure::Storage::Blobs::DownloadBlobOptions range_opts; - range_opts.Range = Azure::Core::Http::HttpRange{offset, size}; - auto result = blob_client.Download(range_opts, ctx); - auto& body_stream = *result.Value.BodyStream; - - // Read the body into a local buffer - std::vector buffer(static_cast(size)); - size_t total_read = 0; - while (total_read < static_cast(size)) { - size_t bytes_read = body_stream.Read( - buffer.data() + total_read, - static_cast(size) - total_read, - ctx); - - if (bytes_read == 0) { - break; - } - - total_read += bytes_read; - } - - // a zero-byte read before reaching `size` indicates the server closed early. - // Treat as a hard error rather than silently writing a truncated chunk. - if (total_read < static_cast(size)) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "short read from blob stream: got " + - std::to_string(total_read) + " of " + - std::to_string(size) + " bytes at offset " + - std::to_string(offset)); - } - - // Write the chunk to the file at the correct offset - { - std::lock_guard lock(file_mutex); - std::ofstream f(local_path, - std::ios::binary | std::ios::in | std::ios::out); - if (!f.is_open()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to open blob file for write: " + local_path); - } - - f.seekp(offset); - f.write(reinterpret_cast(buffer.data()), - static_cast(total_read)); - if (f.fail()) { - FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, - "failed to write blob chunk to " + local_path + - " at offset " + std::to_string(offset) + - " (" + std::to_string(total_read) + " bytes)"); - } - } - - // Report progress - bytes_completed += static_cast(total_read); - if (bytes_written_cb) { - bytes_written_cb(bytes_completed.load()); - } - })); - } + for (int w = 0; w < worker_count; ++w) { + workers.push_back(std::async(std::launch::async, worker_body)); + } - // Wait for all tasks in this batch, cancelling context on failure + for (auto& f : workers) { try { - for (auto& f : futures) { - f.get(); - } + f.get(); } catch (...) 
{ - // Cancel remaining in-flight downloads so futures complete quickly - ctx.Cancel(); - for (auto& f : futures) { - try { - if (f.valid()) { - f.get(); - } - } catch (...) { - } + // Worker bodies should already have routed exceptions through + // first_error, but stay defensive in case std::async signals one. + std::lock_guard lock(error_mutex); + if (!first_error) { + first_error = std::current_exception(); } - throw; + internal_cancel.store(true, std::memory_order_relaxed); } } + + // Release the OS handle before persisting / deleting the sidecar so any + // observer that watches the data file sees a fully-closed handle. + writer->Close(); + + if (first_error || (cancelled && cancelled->load(std::memory_order_relaxed))) { + // Persist what we have so the next attempt resumes from here. + { + std::lock_guard lock(state->mutex()); + state->SaveState(logger_); + } + if (cancelled && cancelled->load(std::memory_order_relaxed)) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled"); + } + std::rethrow_exception(first_error); + } + + // All chunks done — sidecar is no longer needed. + BlobDownloadState::DeleteState(local_path, logger_); } catch (const Azure::Core::OperationCancelledException&) { FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled"); } catch (const Azure::Core::RequestFailedException& e) { @@ -258,6 +374,34 @@ bool EndsWith(const std::string& str, const std::string& suffix) { }); } +/// Returns false if a file at `local_path` already matches the blob's expected +/// `content_length` exactly AND has no `.dlstate` sidecar — in which case the +/// caller can skip the download. Returns true (download needed) for any of: +/// missing file, size mismatch, sidecar present (file may be pre-allocated +/// with holes), or filesystem-stat errors (treat as "redownload to be safe"). 
+bool IsDownloadNeeded(const BlobItemInfo& blob, const std::string& local_path) {
+  std::error_code ec;
+  auto status = std::filesystem::status(local_path, ec);
+  if (ec || !std::filesystem::exists(status) || !std::filesystem::is_regular_file(status)) {
+    return true;
+  }
+  auto size = std::filesystem::file_size(local_path, ec);
+  if (ec) {
+    return true;
+  }
+  if (static_cast<int64_t>(size) != blob.content_length) {
+    return true;
+  }
+  // Size matches, but a sidecar means a previous run pre-allocated the file and
+  // aborted mid-download (the file has holes), so AzureBlobDownloader must
+  // resume. A failed sidecar stat counts as "present": redownload to be safe.
+  auto sidecar = BlobDownloadState::GetStateFilePath(local_path);
+  if (std::filesystem::exists(sidecar, ec) || ec) {
+    return true;
+  }
+  return false;
+}
+
 } // anonymous namespace
 
 void DownloadBlobsToDirectory(IBlobDownloader& downloader,
@@ -315,15 +459,43 @@ void DownloadBlobsToDirectory(IBlobDownloader& downloader,
               return a.first.content_length < b.first.content_length;
             });
 
-  // Step 4: Calculate total size for progress
+  // Step 4: Calculate total size across every in-scope blob, including those
+  // already present on disk — so 100% always means "every byte is local".
   int64_t total_size = 0;
   for (const auto& [blob, _] : blobs_to_download) {
     total_size += blob.content_length;
   }
 
-  // Step 4.5: Emit 0% so callers know the download has started
+  // Step 4.25: Skip blobs already present at the expected size. Their bytes
+  // count toward "downloaded" so the percentage stays accurate when this is a
+  // resume of a partially-completed download.
+  int64_t skipped_bytes = 0;
+  blobs_to_download.erase(
+      std::remove_if(blobs_to_download.begin(), blobs_to_download.end(),
+                     [&skipped_bytes](const auto& pair) {
+                       if (IsDownloadNeeded(pair.first, pair.second)) {
+                         return false;
+                       }
+                       skipped_bytes += pair.first.content_length;
+                       return true;
+                     }),
+      blobs_to_download.end());
+
+  // Step 4.5: Emit initial progress reflecting any already-on-disk bytes.
+ // If everything was skipped, emit 100% directly and return. + if (blobs_to_download.empty()) { + if (options.progress) { + options.progress(100.0f); + } + return; + } + if (options.progress) { - int result = options.progress(0.0f); + float initial_percent = total_size > 0 + ? static_cast(skipped_bytes) / + static_cast(total_size) * 100.0f + : 0.0f; + int result = options.progress(initial_percent); if (result != 0) { FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "download cancelled by user callback return value"); } @@ -333,7 +505,9 @@ void DownloadBlobsToDirectory(IBlobDownloader& downloader, // The cancellation flag is set when the progress callback returns non-zero. // It is shared with chunk download threads so they can exit promptly. std::atomic cancelled{false}; - std::atomic total_downloaded_bytes{0}; + // Seed with skipped bytes so per-chunk progress callbacks compute the right + // overall percentage. + std::atomic total_downloaded_bytes{skipped_bytes}; for (const auto& [blob, local_path] : blobs_to_download) { // Check cancellation between blobs diff --git a/sdk_v2/cpp/src/download/blob_downloader.h b/sdk_v2/cpp/src/download/blob_downloader.h index f43774a16..4fc7412f2 100644 --- a/sdk_v2/cpp/src/download/blob_downloader.h +++ b/sdk_v2/cpp/src/download/blob_downloader.h @@ -11,6 +11,8 @@ namespace fl { +class ILogger; + /// Progress callback: percent is 0.0 to 100.0. Return 0 to continue, non-zero to cancel. using DownloadProgressFn = std::function; @@ -56,9 +58,38 @@ class IBlobDownloader { std::atomic* cancelled = nullptr) = 0; }; +/// Strategy for writing downloaded blob chunks to the local file. Both +/// strategies are thread-safe across concurrent calls to disjoint ranges. +/// +/// - `Positional`: lock-free `pwrite` / `WriteFile`+`OVERLAPPED`. Default and +/// recommended; lets the OS arbitrate concurrent writes to disjoint ranges +/// instead of taking a user-space mutex. 
+/// - `MutexFstream`: single shared `std::fstream` guarded by an internal +/// mutex. Provided for benchmarking and as a portable fallback. +enum class FileWriterKind { + Positional, + MutexFstream, +}; + /// Azure Storage Blobs SDK-based implementation of IBlobDownloader. +/// +/// Implements resumable downloads: a `.dlstate` sidecar tracks which 2 MB +/// chunks have completed, and DownloadBlob picks up where a prior aborted run +/// left off. A linked cancellation token cascades the first chunk-level +/// failure to every other in-flight chunk so the worker pool drains quickly. +/// +/// Chunks stream from the blob client into the local file in ~64 KB pieces +/// via a sink callback, so each worker holds a single 64 KB scratch buffer +/// instead of allocating a full chunk's worth of bytes per request. This +/// caps peak memory at roughly `max_concurrency * 64 KB` regardless of how +/// large the blob or the chunk size is. class AzureBlobDownloader : public IBlobDownloader { public: + /// `logger` is used for diagnostics only (state file save/load events). May be null. + /// `writer_kind` chooses the on-disk write strategy; see `FileWriterKind`. + explicit AzureBlobDownloader(ILogger* logger = nullptr, + FileWriterKind writer_kind = FileWriterKind::Positional); + std::vector ListBlobs(const std::string& sas_uri) override; void DownloadBlob(const std::string& sas_uri, @@ -67,6 +98,45 @@ class AzureBlobDownloader : public IBlobDownloader { int max_concurrency, BlobBytesWrittenFn bytes_written_cb = nullptr, std::atomic* cancelled = nullptr) override; + + protected: + /// Opaque per-blob context. Defined in `blob_downloader.cc`; holds the Azure + /// SDK BlobClient + Context pointers used by the production virtuals. Test + /// subclasses can ignore this argument and use only the explicit parameters. + struct ChunkContext; + + /// Return the blob size in bytes. Production calls `BlobClient::GetProperties`. 
+ /// Test subclasses can override to return a constant without touching Azure. + virtual int64_t GetBlobSize(ChunkContext& ctx); + + /// Read `size` bytes starting at `offset` from the blob and forward them + /// piecewise to `sink`. The production implementation pulls from the blob + /// client referenced by `ctx`; test subclasses can override to inject + /// chunk-level failures or slow reads. + /// + /// `scratch` is a per-worker reusable buffer (default 64 KB) — implementers + /// may resize it but should avoid allocating one-buffer-per-chunk. `sink` + /// must be invoked with strictly contiguous ranges; the cumulative byte + /// count delivered to `sink` must equal `size` on success. + /// + /// Must throw on failure. Implementations should observe the cancellation + /// flag accessible via `ctx` and exit promptly when cancellation is requested. + virtual void DownloadChunkStreaming(ChunkContext& ctx, + int64_t offset, + int64_t size, + std::vector& scratch, + const std::function& sink); + + /// Accessor for test subclasses overriding `DownloadChunkStreaming`. Returns + /// the shared cancellation flag — when set by the orchestrator (e.g. after + /// another chunk fails), in-flight chunk simulations should observe it and + /// exit promptly. Production code doesn't need this directly: cancellation + /// is routed through `Azure::Core::Context::Cancel()`. + std::atomic* GetCancelFlag(ChunkContext& ctx); + + private: + ILogger* logger_ = nullptr; + FileWriterKind writer_kind_ = FileWriterKind::Positional; }; /// High-level download function: enumerate, filter, and download all blobs from a SAS URI. diff --git a/sdk_v2/cpp/src/download/cross_process_file_lock.cc b/sdk_v2/cpp/src/download/cross_process_file_lock.cc new file mode 100644 index 000000000..33eeb2150 --- /dev/null +++ b/sdk_v2/cpp/src/download/cross_process_file_lock.cc @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "download/cross_process_file_lock.h"
+#include "exception.h"
+#include "logger.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include
+#include
+#else
+#include
+#include
+#include
+#include
+#include <sys/stat.h>
+#endif
+
+namespace fl {
+
+namespace {
+
+constexpr const char* kLockFileName = ".download.lock";
+
+/// `PID:,Time:\n` — mirrors what C# writes
+/// (CrossProcessFileLock.cs:68) so the lock file is recognizable across SDKs.
+std::string FormatProcessInfo() {
+#ifdef _WIN32
+  auto pid = static_cast(_getpid());
+#else
+  auto pid = static_cast(getpid());
+#endif
+  auto t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
+  std::tm tm{};
+#ifdef _WIN32
+  gmtime_s(&tm, &t);
+#else
+  gmtime_r(&t, &tm);
+#endif
+  std::ostringstream oss;
+  oss << "PID:" << pid << ",Time:" << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ") << '\n';
+  return oss.str();
+}
+
+} // namespace
+
+// Platform-specific resource handle. The destructor here is the only thing
+// that releases the lock; CrossProcessFileLock's destructor is defaulted.
+#ifdef _WIN32
+struct CrossProcessFileLock::State {
+  HANDLE handle;
+  ~State() {
+    if (handle != INVALID_HANDLE_VALUE) {
+      // FILE_FLAG_DELETE_ON_CLOSE removes the file when the last handle closes.
+      CloseHandle(handle);
+    }
+  }
+};
+#else
+struct CrossProcessFileLock::State {
+  int fd;
+  std::filesystem::path path;
+  ~State() {
+    if (fd >= 0) {
+      // Unlink before close so the file disappears at the same instant the
+      // lock releases; a concurrent acquirer simply recreates it.
+      ::unlink(path.c_str());
+      ::close(fd);
+    }
+  }
+};
+#endif
+
+CrossProcessFileLock::CrossProcessFileLock(std::filesystem::path path,
+                                           std::unique_ptr state,
+                                           ILogger* logger)
+    : path_(std::move(path)), state_(std::move(state)), logger_(logger) {}
+
+CrossProcessFileLock::~CrossProcessFileLock() {
+  // Release the OS handle first so the "released" log message is accurate.
+  state_.reset();
+  if (logger_) {
+    logger_->Log(LogLevel::Debug, "CrossProcessFileLock released: " + path_.string());
+  }
+}
+
+// Non-blocking attempt to take the exclusive lock file inside `directory`
+// (created if missing). Returns the RAII lock on success, nullptr when the
+// lock is currently held elsewhere, and throws FOUNDRY_LOCAL_ERROR_INTERNAL
+// for any other open/lock failure. `logger` may be null.
+std::unique_ptr<CrossProcessFileLock> CrossProcessFileLock::TryAcquireForDirectory(
+    const std::filesystem::path& directory, ILogger* logger) {
+  std::error_code ec;
+  std::filesystem::create_directories(directory, ec);
+  // Best-effort: if create_directories failed, the platform open below will
+  // surface a clearer error message.
+
+  auto lock_path = directory / kLockFileName;
+  std::unique_ptr<State> state;
+
+#ifdef _WIN32
+  // dwShareMode=0 blocks any other open (cross- and in-process) until this
+  // handle closes. FILE_FLAG_DELETE_ON_CLOSE pairs OPEN_ALWAYS into a
+  // self-cleaning lock that doesn't require unlink-then-close races.
+  auto wide = lock_path.wstring();
+  HANDLE handle = CreateFileW(wide.c_str(),
+                              GENERIC_READ | GENERIC_WRITE,
+                              0,
+                              nullptr,
+                              OPEN_ALWAYS,
+                              FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE,
+                              nullptr);
+  if (handle == INVALID_HANDLE_VALUE) {
+    DWORD err = GetLastError();
+    if (err == ERROR_SHARING_VIOLATION || err == ERROR_LOCK_VIOLATION || err == ERROR_ACCESS_DENIED) {
+      // ACCESS_DENIED can surface on FILE_SHARE_NONE collisions when the
+      // existing handle has narrower access rights — treat as contention.
+      return nullptr;
+    }
+    FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL,
+             "CreateFileW failed for lock '" + lock_path.string() +
+                 "' (GetLastError=" + std::to_string(err) + ")");
+  }
+
+  auto info = FormatProcessInfo();
+  DWORD written = 0;
+  WriteFile(handle, info.data(), static_cast<DWORD>(info.size()), &written, nullptr);
+  FlushFileBuffers(handle);
+
+  state = std::unique_ptr<State>(new State{handle});
+#else
+  // The holder unlinks the lock file on release, so the path can be swapped
+  // underneath us: we may open the old file, the holder then unlinks and
+  // closes it, and our flock() succeeds on an orphaned inode while another
+  // process creates and locks a fresh file at the same path. After locking,
+  // confirm the path still names the inode we hold; if it does not, drop the
+  // stale descriptor and retry against the current file.
+  constexpr int kMaxStaleRetries = 8;
+  int fd = -1;
+  for (int attempt = 0; attempt < kMaxStaleRetries && fd < 0; ++attempt) {
+    int candidate = ::open(lock_path.c_str(), O_CREAT | O_RDWR | O_CLOEXEC, 0644);
+    if (candidate < 0) {
+      // Capture errno before building the message; allocation may clobber it.
+      int open_err = errno;
+      FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL,
+               "open failed for lock '" + lock_path.string() + "' (errno=" + std::to_string(open_err) + ")");
+    }
+    if (::flock(candidate, LOCK_EX | LOCK_NB) != 0) {
+      int err = errno;
+      ::close(candidate);
+      if (err == EWOULDBLOCK || err == EAGAIN) {
+        return nullptr;
+      }
+      FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL,
+               "flock failed for '" + lock_path.string() + "' (errno=" + std::to_string(err) + ")");
+    }
+
+    struct stat held {};
+    struct stat on_disk {};
+    const bool path_is_ours = ::fstat(candidate, &held) == 0 &&
+                              ::stat(lock_path.c_str(), &on_disk) == 0 &&
+                              held.st_dev == on_disk.st_dev &&
+                              held.st_ino == on_disk.st_ino;
+    if (path_is_ours) {
+      fd = candidate;
+    } else {
+      // Closing releases the flock on the orphaned inode.
+      ::close(candidate);
+    }
+  }
+  if (fd < 0) {
+    // The file kept being replaced under us; report contention so pollers retry.
+    return nullptr;
+  }
+
+  (void)::ftruncate(fd, 0);
+  auto info = FormatProcessInfo();
+  (void)::write(fd, info.data(), info.size());
+
+  state = std::unique_ptr<State>(new State{fd, lock_path});
+#endif
+
+  if (logger) {
+    logger->Log(LogLevel::Debug, "CrossProcessFileLock acquired: " + lock_path.string());
+  }
+  return std::unique_ptr<CrossProcessFileLock>(
+      new CrossProcessFileLock(std::move(lock_path), std::move(state), logger));
+}
+
+std::unique_ptr WaitForLockForDirectory(
+    const std::filesystem::path& directory,
+    const CancellationPredicate& is_cancelled,
+    ILogger* logger,
+    std::chrono::milliseconds poll_interval,
+    std::chrono::milliseconds timeout) {
+  auto deadline = std::chrono::steady_clock::now() + timeout;
+  // Poll cancellation in slices of at most 100 ms so a long poll interval
+  // (1.25 s default) doesn't keep a cancelling caller waiting.
+ constexpr std::chrono::milliseconds kCancelSlice{100}; + while (true) { + if (is_cancelled && is_cancelled()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "lock acquisition cancelled"); + } + auto lock = CrossProcessFileLock::TryAcquireForDirectory(directory, logger); + if (lock) { + return lock; + } + if (std::chrono::steady_clock::now() >= deadline) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "timed out waiting for cross-process download lock on '" + directory.string() + "'"); + } + auto slice_end = std::chrono::steady_clock::now() + poll_interval; + while (std::chrono::steady_clock::now() < slice_end) { + if (is_cancelled && is_cancelled()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "lock acquisition cancelled"); + } + std::this_thread::sleep_for(std::min(kCancelSlice, poll_interval)); + } + } +} + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/cross_process_file_lock.h b/sdk_v2/cpp/src/download/cross_process_file_lock.h new file mode 100644 index 000000000..2c771b9c8 --- /dev/null +++ b/sdk_v2/cpp/src/download/cross_process_file_lock.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include + +namespace fl { + +class ILogger; + +/// RAII exclusive lock backed by an OS-level file lock on +/// `/.download.lock`. Serializes model downloads across processes +/// that share a cache directory. A crash while holding the lock may leave a +/// zero-byte file behind; the next acquirer reopens and re-locks, so the leak +/// is harmless. +class CrossProcessFileLock { + public: + /// Non-blocking acquisition. Returns nullptr if another process currently + /// holds the lock. Creates `directory` if missing. Throws fl::Exception on + /// unexpected errors (permission denied, etc.). 
+ static std::unique_ptr TryAcquireForDirectory( + const std::filesystem::path& directory, + ILogger* logger = nullptr); + + ~CrossProcessFileLock(); + + CrossProcessFileLock(const CrossProcessFileLock&) = delete; + CrossProcessFileLock& operator=(const CrossProcessFileLock&) = delete; + CrossProcessFileLock(CrossProcessFileLock&&) = delete; + CrossProcessFileLock& operator=(CrossProcessFileLock&&) = delete; + + /// Path to the lock file (for diagnostics / tests). + const std::filesystem::path& path() const noexcept { return path_; } + + private: + struct State; // Platform-specific; defined in the .cc. + + CrossProcessFileLock(std::filesystem::path path, std::unique_ptr state, ILogger* logger); + + std::filesystem::path path_; + std::unique_ptr state_; + ILogger* logger_; +}; + +/// Returning true aborts WaitForLockForDirectory with FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED. +using CancellationPredicate = std::function; + +/// Polls TryAcquireForDirectory until the lock is acquired, `is_cancelled()` +/// returns true, or `timeout` elapses. +/// Throws FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED on cancellation, or +/// FOUNDRY_LOCAL_ERROR_INTERNAL on timeout. +std::unique_ptr WaitForLockForDirectory( + const std::filesystem::path& directory, + const CancellationPredicate& is_cancelled, + ILogger* logger = nullptr, + std::chrono::milliseconds poll_interval = std::chrono::milliseconds{1250}, + std::chrono::milliseconds timeout = std::chrono::hours{3}); + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/download_manager.cc b/sdk_v2/cpp/src/download/download_manager.cc index df5059a35..93a764e4b 100644 --- a/sdk_v2/cpp/src/download/download_manager.cc +++ b/sdk_v2/cpp/src/download/download_manager.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#include "download/download_manager.h" +#include "download/cross_process_file_lock.h" #include "download/inference_model_writer.h" #include "exception.h" +#include "log_level.h" +#include "logger.h" #include "util/path_safety.h" #include @@ -154,8 +157,9 @@ DownloadManager::DownloadManager(std::string cache_directory, std::string catalo ILogger& logger) : cache_directory_(std::move(cache_directory)), max_concurrency_(max_concurrency), + logger_(logger), registry_client_(std::make_unique(std::move(catalog_region), logger)), - blob_downloader_(std::make_unique()) {} + blob_downloader_(std::make_unique(&logger)) {} DownloadManager::~DownloadManager() = default; @@ -218,7 +222,7 @@ std::string DownloadManager::DownloadModel(const ModelInfo& info, auto model_path = ComputeModelPath(info); - // Check if already downloaded (before validating URI — cached models don't need one). + // Fast path: serve the cache without taking the cross-process lock. // A valid cache hit requires: directory exists, no in-progress signal file, and // inference_model.json is present (written by DownloadModel on successful completion). auto signal_path = std::filesystem::path(model_path) / kDownloadSignalFileName; @@ -237,9 +241,38 @@ std::string DownloadManager::DownloadModel(const ModelInfo& info, FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, "cannot download model: empty URI (asset_id)"); } - // Create output directory + // Create output directory before taking the cross-process lock, since the lock + // file lives inside it. std::filesystem::create_directories(model_path); + // Serialize across processes that share this cache directory. Inside the + // running process `download_mutex_` already prevents reentry; the file lock + // protects against a second SDK instance (e.g. another service or CLI) racing + // on the same model directory. + auto cancel_pred = [&progress_cb]() -> bool { + // progress_cb returning non-zero is the SDK's cancellation signal. 
Reusing + // it here also acts as a periodic heartbeat (0%) while we wait for the + // other process to finish. + return progress_cb && progress_cb(0.0f) != 0; + }; + auto lock = CrossProcessFileLock::TryAcquireForDirectory(model_path, &logger_); + if (!lock) { + logger_.Log(LogLevel::Information, + "Model download is being performed by another process. Waiting on lock at '" + + model_path + "'..."); + lock = WaitForLockForDirectory(model_path, cancel_pred, &logger_); + } + + // Another process may have just completed the download we were waiting on. + // Re-check the cache now that we hold the lock. + if (std::filesystem::exists(model_path) && !std::filesystem::exists(signal_path) && + HasInferenceModelJson(model_path)) { + if (progress_cb) { + progress_cb(100.0f); + } + return ResolveEffectiveModelPath(model_path); + } + // Create download signal file { std::ofstream signal(signal_path); diff --git a/sdk_v2/cpp/src/download/download_manager.h b/sdk_v2/cpp/src/download/download_manager.h index 42a4e69b7..91bd22b56 100644 --- a/sdk_v2/cpp/src/download/download_manager.h +++ b/sdk_v2/cpp/src/download/download_manager.h @@ -3,9 +3,11 @@ #pragma once #include "download/blob_downloader.h" +#include "download/cross_process_file_lock.h" #include "download/model_registry_client.h" #include "model_info.h" +#include #include #include #include @@ -65,6 +67,7 @@ class DownloadManager { std::string cache_directory_; int max_concurrency_; + ILogger& logger_; std::unique_ptr registry_client_; std::unique_ptr blob_downloader_; diff --git a/sdk_v2/cpp/src/download/file_writer.cc b/sdk_v2/cpp/src/download/file_writer.cc new file mode 100644 index 000000000..a1936a62c --- /dev/null +++ b/sdk_v2/cpp/src/download/file_writer.cc @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "download/file_writer.h" +#include "exception.h" + +#include + +#include +#include +#include +#include + +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#include +#include +#include +#include +#endif + +namespace fl { + +namespace { + +namespace fs = std::filesystem; + +/// Ensure the data file exists at exactly `expected_size`. Skips truncation +/// if the file is already at that size — the resume path relies on this. +void EnsureFileExistsAtSize(const fs::path& path, int64_t expected_size) { + std::error_code ec; + auto cur_size = fs::file_size(path, ec); + if (!ec) { + if (cur_size == static_cast(expected_size)) { + return; + } + // File exists but is the wrong size — fall through to recreate. + } else if (ec != std::errc::no_such_file_or_directory) { + // Some other stat error (permission, transient NFS hiccup, AV scanner + // holding a handle, etc.). Don't blow away a potentially-intact file + // just because we couldn't read its size; surface the error instead so + // the caller can retry and the existing on-disk progress is preserved. 
+ FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to stat blob file: " + path.string() + " (" + ec.message() + ")"); + } + + std::ofstream f(path, std::ios::binary); + if (!f.is_open()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to open blob file for pre-allocation: " + path.string()); + } + if (expected_size > 0) { + f.seekp(expected_size - 1); + f.put('\0'); + } + f.close(); + if (f.fail()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "failed to pre-allocate blob file: " + path.string() + + " (size=" + std::to_string(expected_size) + ")"); + } +} + +#ifdef _WIN32 + +class WindowsPositionalFileWriter : public IFileWriter { + public: + ~WindowsPositionalFileWriter() override { Close(); } + + void Open(const fs::path& path, int64_t expected_size) override { + EnsureFileExistsAtSize(path, expected_size); + // FILE_SHARE_READ | FILE_SHARE_WRITE so the lock file / other tools can + // peek at the partial file without us erroring; positional WriteFile is + // safe regardless of share mode. + handle_ = ::CreateFileW(path.wstring().c_str(), GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE, nullptr, OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, nullptr); + if (handle_ == INVALID_HANDLE_VALUE) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "PositionalFileWriter open failed for " + path.string() + + " (Win32 err " + std::to_string(::GetLastError()) + ")"); + } + } + + void WriteAt(int64_t offset, const uint8_t* data, size_t len) override { + // Concurrent WriteFile calls with distinct OVERLAPPED offsets on the same + // handle are safe for non-overlapping ranges; the kernel orders them. + while (len > 0) { + OVERLAPPED ov{}; + ov.Offset = static_cast(static_cast(offset) & 0xFFFFFFFFULL); + ov.OffsetHigh = static_cast((static_cast(offset) >> 32) & 0xFFFFFFFFULL); + DWORD to_write = static_cast(len > 0x7FFFFFFFu ? 
0x7FFFFFFFu : len); + DWORD written = 0; + if (!::WriteFile(handle_, data, to_write, &written, &ov)) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "PositionalFileWriter write failed at offset " + std::to_string(offset) + + " (Win32 err " + std::to_string(::GetLastError()) + ")"); + } + if (written == 0) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "PositionalFileWriter short write at offset " + std::to_string(offset)); + } + offset += static_cast(written); + data += written; + len -= written; + } + } + + void Close() override { + if (handle_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(handle_); + handle_ = INVALID_HANDLE_VALUE; + } + } + + private: + HANDLE handle_ = INVALID_HANDLE_VALUE; +}; + +#else // POSIX + +class PosixPositionalFileWriter : public IFileWriter { + public: + ~PosixPositionalFileWriter() override { Close(); } + + void Open(const fs::path& path, int64_t expected_size) override { + EnsureFileExistsAtSize(path, expected_size); + fd_ = ::open(path.c_str(), O_RDWR | O_CLOEXEC); + if (fd_ < 0) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "PositionalFileWriter open failed for " + path.string() + + " (errno " + std::to_string(errno) + ")"); + } + } + + void WriteAt(int64_t offset, const uint8_t* data, size_t len) override { + while (len > 0) { + ssize_t n = ::pwrite(fd_, data, len, static_cast(offset)); + if (n < 0) { + if (errno == EINTR) continue; + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "PositionalFileWriter pwrite failed at offset " + std::to_string(offset) + + " (errno " + std::to_string(errno) + ")"); + } + if (n == 0) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "PositionalFileWriter short pwrite at offset " + std::to_string(offset)); + } + offset += n; + data += n; + len -= static_cast(n); + } + } + + void Close() override { + if (fd_ >= 0) { + ::close(fd_); + fd_ = -1; + } + } + + private: + int fd_ = -1; +}; + +#endif + +class MutexFstreamFileWriter : public IFileWriter { + public: + ~MutexFstreamFileWriter() override { Close(); } + + void 
Open(const fs::path& path, int64_t expected_size) override { + EnsureFileExistsAtSize(path, expected_size); + file_.open(path, std::ios::binary | std::ios::in | std::ios::out); + if (!file_.is_open()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "MutexFstreamFileWriter open failed for " + path.string()); + } + } + + void WriteAt(int64_t offset, const uint8_t* data, size_t len) override { + std::lock_guard lock(mutex_); + // Clear any sticky failbit from a prior call so this write's diagnostic + // reflects what actually went wrong here, not a stale earlier failure. + file_.clear(); + file_.seekp(offset); + file_.write(reinterpret_cast(data), static_cast(len)); + if (file_.fail()) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, + "MutexFstreamFileWriter write failed at offset " + std::to_string(offset)); + } + } + + void Close() override { + if (file_.is_open()) { + file_.close(); + } + } + + private: + std::fstream file_; + std::mutex mutex_; +}; + +} // namespace + +std::unique_ptr MakePositionalFileWriter() { +#ifdef _WIN32 + return std::make_unique(); +#else + return std::make_unique(); +#endif +} + +std::unique_ptr MakeMutexFstreamFileWriter() { + return std::make_unique(); +} + +} // namespace fl diff --git a/sdk_v2/cpp/src/download/file_writer.h b/sdk_v2/cpp/src/download/file_writer.h new file mode 100644 index 000000000..eacc498e2 --- /dev/null +++ b/sdk_v2/cpp/src/download/file_writer.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include + +namespace fl { + +/// Thread-safe positional writer for blob downloads. +/// +/// Workers in a single download claim disjoint chunks, so concurrent +/// `WriteAt` calls always target non-overlapping byte ranges. An +/// implementation may serialize internally (e.g. via a mutex) or rely on the +/// OS to allow lock-free concurrent positional writes — the contract is the +/// same either way. 
+class IFileWriter { + public: + virtual ~IFileWriter() = default; + + /// Make `path` exist at exactly `expected_size` bytes. If the file already + /// exists at that size, leave its contents intact (so the resume path can + /// pick up where it left off). Called once before the first `WriteAt`. + virtual void Open(const std::filesystem::path& path, int64_t expected_size) = 0; + + /// Write `len` bytes from `data` starting at byte offset `offset`. + /// Thread-safe across overlapping or disjoint ranges — concurrent calls to + /// disjoint ranges complete without coordination from the caller. + virtual void WriteAt(int64_t offset, const uint8_t* data, size_t len) = 0; + + /// Release the underlying OS handle. Implicitly called by the destructor. + virtual void Close() = 0; +}; + +/// Backed by `pwrite` (POSIX) or `WriteFile`+`OVERLAPPED` (Windows). Concurrent +/// `WriteAt` calls to disjoint ranges proceed in parallel — no internal +/// mutex. The recommended default. +std::unique_ptr MakePositionalFileWriter(); + +/// Backed by a single `std::fstream` guarded by an internal mutex. Provided +/// for comparison with `MakePositionalFileWriter` and as a portable fallback +/// if a platform's positional-write semantics ever change. 
+std::unique_ptr MakeMutexFstreamFileWriter(); + +} // namespace fl diff --git a/sdk_v2/cpp/test/CMakeLists.txt b/sdk_v2/cpp/test/CMakeLists.txt index e14a9fc4d..1503fa0f5 100644 --- a/sdk_v2/cpp/test/CMakeLists.txt +++ b/sdk_v2/cpp/test/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(foundry_local_tests internal_api/audio/audio_transcription_contract_test.cc internal_api/audio/pcm_utils_test.cc internal_api/base_model_catalog_test.cc + internal_api/blob_download_state_test.cc internal_api/c_api_test.cc internal_api/callback_handler_test.cc internal_api/catalog_cache_test.cc @@ -21,6 +22,7 @@ add_executable(foundry_local_tests internal_api/chat_completions_test.cc internal_api/chat_completions_converter_test.cc internal_api/configuration_test.cc + internal_api/cross_process_file_lock_test.cc internal_api/download_test.cc internal_api/embeddings/contracts_embeddings_test.cc internal_api/embeddings/fp16_test.cc @@ -28,6 +30,7 @@ add_executable(foundry_local_tests internal_api/exception_test.cc internal_api/execution_provider_test.cc internal_api/file_uri_test.cc + internal_api/file_writer_test.cc internal_api/genai_config_test.cc internal_api/http_retry_test.cc internal_api/item_test.cc diff --git a/sdk_v2/cpp/test/internal_api/blob_download_state_test.cc b/sdk_v2/cpp/test/internal_api/blob_download_state_test.cc new file mode 100644 index 000000000..9e4770120 --- /dev/null +++ b/sdk_v2/cpp/test/internal_api/blob_download_state_test.cc @@ -0,0 +1,251 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "download/blob_download_state.h" + +#include + +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +using namespace fl; + +namespace { + +class TempDir { + public: + TempDir() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist; + path_ = fs::temp_directory_path() / ("fl_dlstate_test_" + std::to_string(dist(gen))); + fs::create_directories(path_); + } + + ~TempDir() { + std::error_code ec; + fs::remove_all(path_, ec); + } + + const fs::path& path() const { return path_; } + + private: + fs::path path_; +}; + +constexpr int64_t kBlobSize = 20 * 1024 * 1024; // 20 MiB +constexpr int32_t kChunkSize = 2 * 1024 * 1024; // 2 MiB +constexpr int32_t kNumChunks = 10; + +} // namespace + +TEST(BlobDownloadStateTest, GetStateFilePathAppendsDlstate) { + fs::path p = "C:/some/file.bin"; + EXPECT_EQ(BlobDownloadState::GetStateFilePath(p).string(), + (fs::path("C:/some/file.bin.dlstate")).string()); +} + +TEST(BlobDownloadStateTest, CreateNewInitializesEmptyBitmap) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + ASSERT_NE(s, nullptr); + EXPECT_EQ(s->blob_size, kBlobSize); + EXPECT_EQ(s->chunk_size, kChunkSize); + EXPECT_EQ(s->total_chunks, kNumChunks); + EXPECT_EQ(s->completed_count, 0); + EXPECT_EQ(s->highest_completed_chunk, -1); + EXPECT_EQ(s->bitmap_byte_aligned_start, 0); + EXPECT_FALSE(s->IsComplete()); + EXPECT_EQ(s->CalculateDownloadedSize(), 0); + EXPECT_EQ(s->GetPendingChunks().size(), static_cast(kNumChunks)); +} + +TEST(BlobDownloadStateTest, MarkChunkCompleteUpdatesBitmapAndCounter) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(3); + EXPECT_TRUE(s->IsChunkComplete(3)); + EXPECT_FALSE(s->IsChunkComplete(2)); + EXPECT_EQ(s->completed_count, 1); + 
EXPECT_EQ(s->highest_completed_chunk, 3); + EXPECT_EQ(s->CalculateDownloadedSize(), kChunkSize); +} + +TEST(BlobDownloadStateTest, MarkChunkCompleteIsIdempotent) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(5); + s->MarkChunkComplete(5); + s->MarkChunkComplete(5); + EXPECT_EQ(s->completed_count, 1); +} + +TEST(BlobDownloadStateTest, CalculateDownloadedSizeAccountsForPartialFinalChunk) { + TempDir d; + auto local = d.path() / "blob.bin"; + constexpr int64_t kOddBlobSize = 5 * 1024 * 1024 + 17; // last chunk is 17 bytes + constexpr int32_t kOddNumChunks = 3; + auto s = BlobDownloadState::CreateNew("blob", local, kOddBlobSize, kChunkSize, kOddNumChunks); + for (int32_t i = 0; i < kOddNumChunks; ++i) { + s->MarkChunkComplete(i); + } + EXPECT_TRUE(s->IsComplete()); + EXPECT_EQ(s->CalculateDownloadedSize(), kOddBlobSize); +} + +TEST(BlobDownloadStateTest, GetPendingChunksReturnsGaps) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i : {0, 1, 2, 5, 7}) { + s->MarkChunkComplete(i); + } + auto pending = s->GetPendingChunks(); + std::vector expected{3, 4, 6, 8, 9}; + EXPECT_EQ(pending, expected); +} + +TEST(BlobDownloadStateTest, SaveAndLoadRoundTrip) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i : {0, 2, 4, 6, 8}) { + s->MarkChunkComplete(i); + } + s->SaveState(); + } + auto loaded = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks); + ASSERT_NE(loaded, nullptr); + EXPECT_EQ(loaded->completed_count, 5); + EXPECT_EQ(loaded->highest_completed_chunk, 8); + for (int32_t i : {0, 2, 4, 6, 8}) { + EXPECT_TRUE(loaded->IsChunkComplete(i)) << "chunk " << i; + } + for (int32_t i : {1, 3, 5, 7, 9}) { + 
EXPECT_FALSE(loaded->IsChunkComplete(i)) << "chunk " << i; + } + std::vector expected{1, 3, 5, 7, 9}; + EXPECT_EQ(loaded->GetPendingChunks(), expected); +} + +TEST(BlobDownloadStateTest, SaveStateAdvancesBitmapByteAlignedStart) { + TempDir d; + auto local = d.path() / "blob.bin"; + // Use a large enough total that whole-word advance is meaningful. + constexpr int32_t kBigNumChunks = 200; + constexpr int64_t kBigBlobSize = static_cast(kBigNumChunks) * kChunkSize; + auto s = BlobDownloadState::CreateNew("blob", local, kBigBlobSize, kChunkSize, kBigNumChunks); + // Complete the first 80 chunks (10 full bytes worth). + for (int32_t i = 0; i < 80; ++i) { + s->MarkChunkComplete(i); + } + s->SaveState(); + // 64 bits = 1 full word; next 16 bits in word 1. Aligned start lands on + // 80 (multiple of 8). + EXPECT_EQ(s->bitmap_byte_aligned_start, 80); + + // Reload and verify the implicit prefix is still considered complete. + auto loaded = BlobDownloadState::LoadState("blob", local, kBigBlobSize, kChunkSize, kBigNumChunks); + ASSERT_NE(loaded, nullptr); + for (int32_t i = 0; i < 80; ++i) { + EXPECT_TRUE(loaded->IsChunkComplete(i)); + } + for (int32_t i = 80; i < kBigNumChunks; ++i) { + EXPECT_FALSE(loaded->IsChunkComplete(i)); + } + EXPECT_EQ(loaded->completed_count, 80); +} + +TEST(BlobDownloadStateTest, LoadStateReturnsNullWhenFileMissing) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, LoadStateRejectsBadMagic) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto sidecar = BlobDownloadState::GetStateFilePath(local); + { + std::ofstream f(sidecar, std::ios::binary); + f << "ZZZZ"; // wrong magic + f.put(static_cast(0)); // version + for (int i = 0; i < 64; ++i) f.put(0); // padding + } + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks); + EXPECT_EQ(s, nullptr); +} + 
+TEST(BlobDownloadStateTest, LoadStateRejectsBlobSizeMismatch) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(); + } + // Reload with a *different* expected blob_size — should be rejected. + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize + 1, kChunkSize, kNumChunks); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, LoadStateRejectsChunkSizeMismatch) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(); + } + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize + 1, kNumChunks); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, LoadStateRejectsTotalChunksMismatch) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(); + } + auto s = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks + 1); + EXPECT_EQ(s, nullptr); +} + +TEST(BlobDownloadStateTest, DeleteStateRemovesSidecar) { + TempDir d; + auto local = d.path() / "blob.bin"; + { + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + s->MarkChunkComplete(0); + s->SaveState(); + } + EXPECT_TRUE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + BlobDownloadState::DeleteState(local); + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + // Re-deletion when the file is already absent is a no-op (best-effort). 
+ BlobDownloadState::DeleteState(local); +} + +TEST(BlobDownloadStateTest, IsCompleteFlipsTrueWhenAllChunksMarked) { + TempDir d; + auto local = d.path() / "blob.bin"; + auto s = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i = 0; i < kNumChunks; ++i) { + EXPECT_FALSE(s->IsComplete()); + s->MarkChunkComplete(i); + } + EXPECT_TRUE(s->IsComplete()); + EXPECT_EQ(s->GetPendingChunks().size(), 0u); +} diff --git a/sdk_v2/cpp/test/internal_api/cross_process_file_lock_test.cc b/sdk_v2/cpp/test/internal_api/cross_process_file_lock_test.cc new file mode 100644 index 000000000..a6e38fdfc --- /dev/null +++ b/sdk_v2/cpp/test/internal_api/cross_process_file_lock_test.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "download/cross_process_file_lock.h" + +#include "exception.h" + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +using namespace fl; + +namespace { + +/// Per-test temp directory. Auto-cleans on destruction so a flaky test never +/// leaks lock files into the system temp dir. 
+class TempDir { + public: + TempDir() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist; + path_ = fs::temp_directory_path() / ("fl_lock_test_" + std::to_string(dist(gen))); + fs::create_directories(path_); + } + + ~TempDir() { + std::error_code ec; + fs::remove_all(path_, ec); + } + + const fs::path& path() const { return path_; } + + private: + fs::path path_; +}; + +} // namespace + +TEST(CrossProcessFileLockTest, TryAcquireSucceedsForFreshDirectory) { + TempDir dir; + + auto lock = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + + ASSERT_NE(lock, nullptr); + EXPECT_TRUE(fs::exists(lock->path())); + EXPECT_EQ(lock->path().parent_path(), dir.path()); + EXPECT_EQ(lock->path().filename(), ".download.lock"); +} + +TEST(CrossProcessFileLockTest, ReleaseOnDestructionRemovesLockFile) { + TempDir dir; + fs::path lock_file; + + { + auto lock = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + ASSERT_NE(lock, nullptr); + lock_file = lock->path(); + EXPECT_TRUE(fs::exists(lock_file)); + } + + // After RAII release the lock file should be gone (Win FILE_FLAG_DELETE_ON_CLOSE, + // POSIX explicit unlink in destructor). 
+ EXPECT_FALSE(fs::exists(lock_file)); +} + +TEST(CrossProcessFileLockTest, SecondAcquireReturnsNullWhileFirstIsHeld) { + TempDir dir; + auto first = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + ASSERT_NE(first, nullptr); + + auto second = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + EXPECT_EQ(second, nullptr); +} + +TEST(CrossProcessFileLockTest, ReacquireSucceedsAfterRelease) { + TempDir dir; + { + auto first = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + ASSERT_NE(first, nullptr); + } + auto reacquired = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + EXPECT_NE(reacquired, nullptr); +} + +TEST(CrossProcessFileLockTest, CreatesDirectoryIfMissing) { + TempDir parent; + auto missing = parent.path() / "nested" / "model"; + + ASSERT_FALSE(fs::exists(missing)); + + auto lock = CrossProcessFileLock::TryAcquireForDirectory(missing); + + ASSERT_NE(lock, nullptr); + EXPECT_TRUE(fs::is_directory(missing)); + EXPECT_TRUE(fs::exists(missing / ".download.lock")); +} + +TEST(CrossProcessFileLockTest, WaitForLockReturnsImmediatelyWhenAvailable) { + TempDir dir; + + auto start = std::chrono::steady_clock::now(); + auto lock = WaitForLockForDirectory(dir.path(), []() { return false; }); + auto elapsed = std::chrono::steady_clock::now() - start; + + ASSERT_NE(lock, nullptr); + // Fast-path acquisition is expected to finish in well under 100 ms; the + // assertion below uses a looser 500 ms bound (presumably headroom for slow + // or heavily loaded machines). + EXPECT_LT(elapsed, std::chrono::milliseconds(500)); +} + +TEST(CrossProcessFileLockTest, WaitForLockAcquiresAfterHolderReleases) { + TempDir dir; + auto holder = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + ASSERT_NE(holder, nullptr); + + // Release the holder after a short delay on another thread.
+ std::thread releaser([&] { + std::this_thread::sleep_for(std::chrono::milliseconds(300)); + holder.reset(); + }); + + auto start = std::chrono::steady_clock::now(); + auto lock = WaitForLockForDirectory(dir.path(), + []() { return false; }, + /*logger=*/nullptr, + /*poll_interval=*/std::chrono::milliseconds(100), + /*timeout=*/std::chrono::seconds(10)); + auto elapsed = std::chrono::steady_clock::now() - start; + + releaser.join(); + ASSERT_NE(lock, nullptr); + EXPECT_GE(elapsed, std::chrono::milliseconds(200)); + EXPECT_LT(elapsed, std::chrono::seconds(5)); +} + +TEST(CrossProcessFileLockTest, WaitForLockThrowsOnCancellation) { + TempDir dir; + auto holder = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + ASSERT_NE(holder, nullptr); + + std::atomic cancel{false}; + std::thread canceller([&] { + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + cancel.store(true); + }); + + try { + (void)WaitForLockForDirectory(dir.path(), + [&cancel]() { return cancel.load(); }, + /*logger=*/nullptr, + /*poll_interval=*/std::chrono::milliseconds(100), + /*timeout=*/std::chrono::seconds(10)); + canceller.join(); + FAIL() << "expected fl::Exception(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED)"; + } catch (const Exception& ex) { + canceller.join(); + EXPECT_EQ(ex.code(), FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED); + } +} + +TEST(CrossProcessFileLockTest, WaitForLockThrowsOnTimeout) { + TempDir dir; + auto holder = CrossProcessFileLock::TryAcquireForDirectory(dir.path()); + ASSERT_NE(holder, nullptr); + + try { + (void)WaitForLockForDirectory(dir.path(), + []() { return false; }, + /*logger=*/nullptr, + /*poll_interval=*/std::chrono::milliseconds(50), + /*timeout=*/std::chrono::milliseconds(200)); + FAIL() << "expected fl::Exception(FOUNDRY_LOCAL_ERROR_INTERNAL)"; + } catch (const Exception& ex) { + EXPECT_EQ(ex.code(), FOUNDRY_LOCAL_ERROR_INTERNAL); + std::string what = ex.what(); + EXPECT_NE(what.find("timed out"), std::string::npos); + } +} diff --git 
a/sdk_v2/cpp/test/internal_api/download_test.cc b/sdk_v2/cpp/test/internal_api/download_test.cc index 66f189a98..246eb0cf5 100644 --- a/sdk_v2/cpp/test/internal_api/download_test.cc +++ b/sdk_v2/cpp/test/internal_api/download_test.cc @@ -10,6 +10,7 @@ #include "catalog/azure_catalog_client.h" #endif #include "catalog/azure_catalog_models.h" +#include "download/blob_download_state.h" #include "download/blob_downloader.h" #include "download/download_manager.h" #include "download/inference_model_writer.h" @@ -24,9 +25,12 @@ #include #include +#include #include #include #include +#include +#include #include #include #include @@ -390,6 +394,99 @@ TEST(BlobDownloadTest, HandlesEmptyBlobList) { EXPECT_TRUE(mock.downloaded_blobs.empty()); } +// ======================================================================== +// Skip-existing (Increment 1: resumable downloads) +// ======================================================================== + +TEST(BlobDownloadTest, SkipsExistingFilesWithCorrectSize) { + TempDir tmpdir; + // Pre-create one of the blobs at the expected size on disk. + std::ofstream(tmpdir.path() / "weights.safetensors") << std::string(1000, 'X'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"weights.safetensors", 1000}, + {"config.json", 100}, + }; + + BlobDownloadOptions opts; + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + // Only the missing blob should be downloaded. + ASSERT_EQ(mock.downloaded_blobs.size(), 1u); + EXPECT_EQ(mock.downloaded_blobs[0], "config.json"); +} + +TEST(BlobDownloadTest, RedownloadsFilesWithWrongSize) { + TempDir tmpdir; + // Existing file is truncated relative to the expected blob size. 
+ std::ofstream(tmpdir.path() / "weights.safetensors") << std::string(500, 'X'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"weights.safetensors", 1000}, + }; + + BlobDownloadOptions opts; + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + // Wrong-size files should be redownloaded (the mock overwrites them). + ASSERT_EQ(mock.downloaded_blobs.size(), 1u); + EXPECT_EQ(mock.downloaded_blobs[0], "weights.safetensors"); +} + +TEST(BlobDownloadTest, ReportsSkippedBytesInInitialProgress) { + TempDir tmpdir; + // 500 of 1500 bytes already on disk → initial progress should be ~33%. + std::ofstream(tmpdir.path() / "already.bin") << std::string(500, 'X'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"already.bin", 500}, + {"missing.bin", 1000}, + }; + + std::vector progress_values; + BlobDownloadOptions opts; + opts.progress = [&](float pct) { + progress_values.push_back(pct); + return 0; + }; + + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + ASSERT_FALSE(progress_values.empty()); + // First emitted progress reflects the already-on-disk bytes (500/1500 ≈ 33.3%). + EXPECT_NEAR(progress_values.front(), 100.0f * 500.0f / 1500.0f, 0.5f); + // Final progress must hit 100%. 
+ EXPECT_FLOAT_EQ(progress_values.back(), 100.0f); +} + +TEST(BlobDownloadTest, EmitsHundredPercentWhenEverythingIsCached) { + TempDir tmpdir; + std::ofstream(tmpdir.path() / "a.bin") << std::string(100, 'A'); + std::ofstream(tmpdir.path() / "b.bin") << std::string(200, 'B'); + + MockBlobDownloader mock; + mock.blobs_to_return = { + {"a.bin", 100}, + {"b.bin", 200}, + }; + + std::vector progress_values; + BlobDownloadOptions opts; + opts.progress = [&](float pct) { + progress_values.push_back(pct); + return 0; + }; + + DownloadBlobsToDirectory(mock, "https://test.blob/c?sig=x", tmpdir.string(), opts); + + EXPECT_TRUE(mock.downloaded_blobs.empty()); + ASSERT_FALSE(progress_values.empty()); + EXPECT_FLOAT_EQ(progress_values.back(), 100.0f); +} + // ======================================================================== // Path-traversal hardening (security) // ======================================================================== @@ -1113,3 +1210,232 @@ TEST(DownloadManagerTest, AcceptsNormalModelIdAndPublisher) { EXPECT_NO_THROW(manager.IsModelCached(info)); EXPECT_FALSE(manager.IsModelCached(info)); } + +// ======================================================================== +// AzureBlobDownloader resume + cancel-cascade tests +// Use a subclass that overrides the protected GetBlobSize / DownloadChunkStreaming +// virtuals to bypass the real Azure SDK and simulate per-chunk behavior. +// ======================================================================== + +namespace { + +/// Test double for AzureBlobDownloader. Overrides the protected virtuals so +/// chunked-download orchestration can be exercised without network I/O. +class FakeChunkAzureDownloader : public AzureBlobDownloader { + public: + int64_t blob_size = 0; + + /// Per-call hook. Receives the chunk offset and size plus a `sink` callback + /// that forwards bytes to the file writer. 
Allowed to: + /// - call `sink` zero or more times with strictly contiguous, cumulative + /// `size`-byte ranges to simulate a successful chunk + /// - throw to simulate a transient failure (sink calls so far still hit disk) + /// - sleep / poll cancellation + std::function& sink, + std::atomic* cancel_flag)> + chunk_hook; + + std::atomic chunk_call_count{0}; + std::mutex offsets_mutex; + std::vector requested_offsets; + + using AzureBlobDownloader::AzureBlobDownloader; + + protected: + int64_t GetBlobSize(ChunkContext& /*ctx*/) override { return blob_size; } + + void DownloadChunkStreaming(ChunkContext& ctx, int64_t offset, int64_t size, + std::vector& scratch, + const std::function& sink) override { + chunk_call_count.fetch_add(1); + { + std::lock_guard lock(offsets_mutex); + requested_offsets.push_back(offset); + } + if (chunk_hook) { + chunk_hook(offset, size, sink, GetCancelFlag(ctx)); + return; + } + // Default: stream the chunk to the sink in scratch-sized pieces, filled + // with the low byte of the offset for verification. + if (scratch.size() < 64 * 1024) { + scratch.resize(64 * 1024); + } + int64_t remaining = size; + while (remaining > 0) { + size_t to_emit = + static_cast(std::min(remaining, static_cast(scratch.size()))); + std::fill_n(scratch.begin(), to_emit, static_cast(offset & 0xFF)); + sink(scratch.data(), to_emit); + remaining -= static_cast(to_emit); + } + } +}; + +} // namespace + +TEST(AzureBlobDownloaderResumeTest, SkipsChunksAlreadyMarkedCompleteInSidecar) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + // Pre-allocate the data file so the downloader takes the resume path. + { + std::ofstream f(local, std::ios::binary); + f.seekp(kBlobSize - 1); + f.put('\0'); + } + // Pre-write a sidecar: chunks 0..4 done, 5..9 pending. 
+ { + auto state = BlobDownloadState::CreateNew("blob", local, kBlobSize, kChunkSize, kNumChunks); + for (int32_t i = 0; i < 5; ++i) { + state->MarkChunkComplete(i); + } + state->SaveState(); + } + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/2); + + EXPECT_EQ(d.chunk_call_count.load(), 5); + std::sort(d.requested_offsets.begin(), d.requested_offsets.end()); + std::vector expected{5 * int64_t{kChunkSize}, 6 * int64_t{kChunkSize}, + 7 * int64_t{kChunkSize}, 8 * int64_t{kChunkSize}, + 9 * int64_t{kChunkSize}}; + EXPECT_EQ(d.requested_offsets, expected); + + // Sidecar should be gone on full success. + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); +} + +TEST(AzureBlobDownloaderResumeTest, DownloadsAllChunksWhenSidecarMissing) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 4; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/4); + + EXPECT_EQ(d.chunk_call_count.load(), kNumChunks); + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + // Local file is pre-allocated to blob_size during the first pass. + EXPECT_TRUE(fs::exists(local)); + EXPECT_EQ(fs::file_size(local), static_cast(kBlobSize)); +} + +TEST(AzureBlobDownloaderResumeTest, PersistsSidecarOnChunkFailure) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + // Fail when we see the offset of chunk 4 (specifically chosen so several + // chunks land before the failing one across threads). 
+ constexpr int64_t kFailOffset = 4 * int64_t{kChunkSize}; + d.chunk_hook = [&](int64_t offset, int64_t size, + const std::function& sink, + std::atomic* /*cancel_flag*/) { + if (offset == kFailOffset) { + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, "simulated chunk failure"); + } + std::vector buf(static_cast(size), static_cast(offset & 0xFF)); + sink(buf.data(), buf.size()); + }; + + EXPECT_THROW( + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/2), + fl::Exception); + + // The sidecar should be persisted so a subsequent call can resume. + EXPECT_TRUE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + + // No second download is attempted here. Instead, reload the persisted + // sidecar directly and check that it records partial progress: at least + // one chunk completed, but not all of them. + auto retry_state = BlobDownloadState::LoadState("blob", local, kBlobSize, kChunkSize, kNumChunks); + ASSERT_NE(retry_state, nullptr); + EXPECT_GT(retry_state->completed_count, 0); + EXPECT_LT(retry_state->completed_count, kNumChunks); +} + +TEST(AzureBlobDownloaderResumeTest, CleansUpSidecarOnEmptyBlob) { + TempDir tmpdir; + auto local = tmpdir.path() / "empty.bin"; + // Plant a stale sidecar.
+ { + std::ofstream f(BlobDownloadState::GetStateFilePath(local), std::ios::binary); + f << "stale"; + } + + FakeChunkAzureDownloader d; + d.blob_size = 0; // empty + + d.DownloadBlob(/*sas_uri=*/"", "empty", local.string(), /*max_concurrency=*/4); + + EXPECT_TRUE(fs::exists(local)); + EXPECT_EQ(fs::file_size(local), 0u); + EXPECT_FALSE(fs::exists(BlobDownloadState::GetStateFilePath(local))); + EXPECT_EQ(d.chunk_call_count.load(), 0); +} + +TEST(AzureBlobDownloaderResumeTest, ChunkFailureCancelsInFlightPeersFast) { + TempDir tmpdir; + auto local = tmpdir.path() / "blob.bin"; + + constexpr int32_t kChunkSize = 2 * 1024 * 1024; + constexpr int32_t kNumChunks = 10; + constexpr int64_t kBlobSize = static_cast(kNumChunks) * kChunkSize; + constexpr int64_t kFailOffset = 4 * int64_t{kChunkSize}; + + FakeChunkAzureDownloader d; + d.blob_size = kBlobSize; + // The failing chunk throws fast. Every other chunk sleeps for up to 5 s in + // 50-ms slices, polling the cancel flag. If linked cancellation works, they + // observe the flag within one slice of the failure and exit promptly. + d.chunk_hook = [](int64_t offset, int64_t size, + const std::function& sink, + std::atomic* cancel_flag) { + if (offset == kFailOffset) { + // Give other workers a moment to enter their sleep loop before we throw, + // so we're meaningfully testing the cancel-while-in-flight path. 
+ std::this_thread::sleep_for(std::chrono::milliseconds(75)); + FL_THROW(FOUNDRY_LOCAL_ERROR_INTERNAL, "simulated chunk failure"); + } + for (int i = 0; i < 100; ++i) { + if (cancel_flag && cancel_flag->load(std::memory_order_relaxed)) { + FL_THROW(FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED, "cancelled mid-chunk"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + std::vector buf(static_cast(size), 0); + sink(buf.data(), buf.size()); + }; + + auto start = std::chrono::steady_clock::now(); + EXPECT_THROW( + d.DownloadBlob(/*sas_uri=*/"", "blob", local.string(), /*max_concurrency=*/kNumChunks), + fl::Exception); + auto elapsed = std::chrono::steady_clock::now() - start; + auto elapsed_ms = std::chrono::duration_cast(elapsed).count(); + + // Without cancellation, the slow chunks would sleep ~5 s. With it, they + // should all exit within a few hundred ms of the failure (well under 2 s). + EXPECT_LT(elapsed_ms, 2000) + << "Cancel-cascade should drain in-flight peers fast; took " << elapsed_ms << " ms"; +} + diff --git a/sdk_v2/cpp/test/internal_api/file_writer_test.cc b/sdk_v2/cpp/test/internal_api/file_writer_test.cc new file mode 100644 index 000000000..84134d68a --- /dev/null +++ b/sdk_v2/cpp/test/internal_api/file_writer_test.cc @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Tests for the IFileWriter abstraction backing AzureBlobDownloader's chunked +// writes. Exercises both implementations (Positional / MutexFstream) through a +// parametrized fixture so every correctness assertion runs against both. +// +// The "PerfComparison" test prints wall-clock numbers for a representative +// download workload (32 threads streaming 2 MB chunks in 64 KB pieces into a +// 512 MB file) so we can eyeball lock contention deltas without adding a +// separate microbenchmark binary. It is informational — its only EXPECTs are +// that the file ends up at the right size and the measured throughput is positive.
+ +#include "download/file_writer.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; +using namespace fl; + +namespace { + +class TempPath { + public: + TempPath() { + auto base = fs::temp_directory_path(); + std::random_device rd; + std::uniform_int_distribution dist; + path_ = base / ("file_writer_test_" + std::to_string(dist(rd)) + ".bin"); + } + ~TempPath() { + std::error_code ec; + fs::remove(path_, ec); + } + const fs::path& path() const { return path_; } + + private: + fs::path path_; +}; + +std::unique_ptr MakeWriter(const std::string& kind) { + if (kind == "Positional") return MakePositionalFileWriter(); + if (kind == "MutexFstream") return MakeMutexFstreamFileWriter(); + ADD_FAILURE() << "unknown writer kind " << kind; + return nullptr; +} + +class FileWriterTest : public ::testing::TestWithParam {}; + +} // namespace + +TEST_P(FileWriterTest, OpenCreatesFileAtRequestedSize) { + TempPath p; + auto w = MakeWriter(GetParam()); + ASSERT_NE(w, nullptr); + w->Open(p.path(), 4096); + w->Close(); + EXPECT_TRUE(fs::exists(p.path())); + EXPECT_EQ(fs::file_size(p.path()), 4096u); +} + +TEST_P(FileWriterTest, OpenPreservesExistingFileAtSameSize) { + TempPath p; + // Pre-write a sentinel byte the writer must NOT overwrite. + { + std::ofstream f(p.path(), std::ios::binary); + f.seekp(1023); + f.put('\0'); + } + // Plant a known byte at offset 100. + { + std::fstream f(p.path(), std::ios::binary | std::ios::in | std::ios::out); + f.seekp(100); + f.put(static_cast(0xAB)); + } + + auto w = MakeWriter(GetParam()); + ASSERT_NE(w, nullptr); + w->Open(p.path(), 1024); // same size -> must not truncate + w->Close(); + + // Sentinel byte should still be there. 
+ std::ifstream f(p.path(), std::ios::binary); + f.seekg(100); + int byte = f.get(); + EXPECT_EQ(byte, 0xAB); +} + +TEST_P(FileWriterTest, OpenTruncatesIfSizeChanged) { + TempPath p; + { + std::ofstream f(p.path(), std::ios::binary); + f.seekp(100); + f.put(static_cast(0xCD)); + } + EXPECT_EQ(fs::file_size(p.path()), 101u); + + auto w = MakeWriter(GetParam()); + ASSERT_NE(w, nullptr); + w->Open(p.path(), 4096); + w->Close(); + EXPECT_EQ(fs::file_size(p.path()), 4096u); +} + +TEST_P(FileWriterTest, SingleThreadWriteAt) { + TempPath p; + auto w = MakeWriter(GetParam()); + ASSERT_NE(w, nullptr); + w->Open(p.path(), 1024); + + std::vector data(256, 0xEF); + w->WriteAt(512, data.data(), data.size()); + w->Close(); + + std::ifstream f(p.path(), std::ios::binary); + std::vector contents((std::istreambuf_iterator(f)), + std::istreambuf_iterator()); + ASSERT_EQ(contents.size(), 1024u); + for (size_t i = 512; i < 768; ++i) { + EXPECT_EQ(contents[i], 0xEF) << "byte " << i; + } +} + +TEST_P(FileWriterTest, ConcurrentDisjointWritesProduceCorrectFile) { + TempPath p; + constexpr int kThreads = 8; + constexpr int kRegionSize = 256 * 1024; // 256 KB per thread + constexpr int kPieceSize = 16 * 1024; // 16 KB per WriteAt + constexpr int64_t kTotalSize = int64_t{kThreads} * kRegionSize; + static_assert(kRegionSize % kPieceSize == 0, ""); + + auto w = MakeWriter(GetParam()); + ASSERT_NE(w, nullptr); + w->Open(p.path(), kTotalSize); + + std::atomic started{0}; + std::vector workers; + workers.reserve(kThreads); + for (int t = 0; t < kThreads; ++t) { + workers.emplace_back([&, t]() { + std::vector piece(kPieceSize, static_cast(t + 1)); + started.fetch_add(1); + while (started.load() < kThreads) { + // tiny spin to encourage concurrent dispatch + } + const int64_t base = int64_t{t} * kRegionSize; + for (int i = 0; i < kRegionSize / kPieceSize; ++i) { + w->WriteAt(base + int64_t{i} * kPieceSize, piece.data(), piece.size()); + } + }); + } + for (auto& th : workers) th.join(); + 
w->Close(); + + std::ifstream f(p.path(), std::ios::binary); + std::vector contents((std::istreambuf_iterator(f)), + std::istreambuf_iterator()); + ASSERT_EQ(contents.size(), static_cast(kTotalSize)); + for (int t = 0; t < kThreads; ++t) { + const uint8_t expected = static_cast(t + 1); + for (int64_t i = 0; i < kRegionSize; ++i) { + const auto idx = static_cast(int64_t{t} * kRegionSize + i); + if (contents[idx] != expected) { + FAIL() << "mismatch at offset " << idx << " (thread " << t << ", expected " + << static_cast(expected) << ", got " << static_cast(contents[idx]) << ")"; + } + } + } +} + +INSTANTIATE_TEST_SUITE_P(WriterImpls, FileWriterTest, + ::testing::Values("Positional", "MutexFstream"), + [](const ::testing::TestParamInfo& info) { + return info.param; + }); + +// --------------------------------------------------------------------------- +// Perf comparison: print wall-clock for both writer kinds against a workload +// that mirrors AzureBlobDownloader (32 workers each streaming 8 chunks of 2 MB +// in 64 KB sink pieces). 
Run direct: +// foundry_local_tests --gtest_filter=FileWriterPerfComparison.* +// --------------------------------------------------------------------------- + +namespace { + +struct PerfResult { + std::string kind; + int64_t elapsed_ms; + double mb_per_sec; +}; + +PerfResult RunChunkedWorkload(const std::string& kind) { + constexpr int kThreads = 32; + constexpr int kChunksPerThread = 8; + constexpr int kChunkSize = 2 * 1024 * 1024; // 2 MB chunk like the downloader + constexpr int kPieceSize = 64 * 1024; // 64 KB scratch like the downloader + constexpr int64_t kTotalSize = int64_t{kThreads} * kChunksPerThread * kChunkSize; + static_assert(kChunkSize % kPieceSize == 0, ""); + + TempPath p; + auto w = MakeWriter(kind); + if (!w) { + ADD_FAILURE() << "MakeWriter returned null for " << kind; + return {kind, 0, 0.0}; + } + w->Open(p.path(), kTotalSize); + + std::atomic next_chunk{0}; + const int total_chunks = kThreads * kChunksPerThread; + + auto start = std::chrono::steady_clock::now(); + std::vector workers; + workers.reserve(kThreads); + for (int t = 0; t < kThreads; ++t) { + workers.emplace_back([&, t]() { + std::vector scratch(kPieceSize, static_cast(t & 0xFF)); + while (true) { + int i = next_chunk.fetch_add(1, std::memory_order_relaxed); + if (i >= total_chunks) return; + const int64_t chunk_off = int64_t{i} * kChunkSize; + for (int pos = 0; pos < kChunkSize; pos += kPieceSize) { + w->WriteAt(chunk_off + pos, scratch.data(), kPieceSize); + } + } + }); + } + for (auto& th : workers) th.join(); + w->Close(); + auto elapsed = std::chrono::steady_clock::now() - start; + auto ms = std::chrono::duration_cast(elapsed).count(); + + EXPECT_EQ(fs::file_size(p.path()), static_cast(kTotalSize)); + + double mb_per_sec = + static_cast(kTotalSize) / (1024.0 * 1024.0) / (static_cast(ms) / 1000.0); + return {kind, ms, mb_per_sec}; +} + +} // namespace + +TEST(FileWriterPerfComparison, PositionalVsMutexFstream) { + std::vector results; + 
results.push_back(RunChunkedWorkload("Positional")); + results.push_back(RunChunkedWorkload("MutexFstream")); + + std::cout << "\n=== IFileWriter perf comparison ===\n"; + std::cout << "Workload: 32 workers, 8 chunks/worker, 2 MB chunks, 64 KB sink pieces (512 MB total)\n"; + for (const auto& r : results) { + std::cout << " " << r.kind << ": " << r.elapsed_ms << " ms (" + << static_cast(r.mb_per_sec) << " MB/s)\n"; + } + std::cout << "===================================\n" << std::endl; + + // Sanity: both should make positive progress; perf is informational. + for (const auto& r : results) { + EXPECT_GT(r.mb_per_sec, 0.0) << r.kind; + } +}