Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 9 additions & 28 deletions sdk_v2/cpp/src/ep_detection/cuda_ep_bootstrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
// Licensed under the MIT License.
#include "ep_detection/cuda_ep_bootstrapper.h"

#include "ep_detection/ep_utils.h"
#include "logger.h"
#include "util/file_lock.h"
#include "http/http_download.h"
#include "util/sha256.h"
#include "util/zip_extract.h"

#include <fmt/format.h>
Expand Down Expand Up @@ -61,31 +61,6 @@ constexpr ExpectedBinary kExpectedBinaries[] = {
constexpr const char* kRegistrationName = "Foundry.CUDA";
constexpr const char* kCudaProviderDll = "onnxruntime_providers_cuda.dll";

/// Verify all expected binaries exist and have correct SHA256 hashes.
bool VerifyPackage(const std::filesystem::path& dir, fl::ILogger& logger) {
for (const auto& expected : kExpectedBinaries) {
auto file_path = dir / expected.filename;

if (!std::filesystem::exists(file_path)) {
return false;
}

auto hash = fl::Sha256File(file_path);

// Case-insensitive comparison
std::string expected_hash(expected.sha256);
if (!std::equal(hash.begin(), hash.end(), expected_hash.begin(), expected_hash.end(),
[](char a, char b) { return std::toupper(a) == std::toupper(b); })) {
logger.Log(fl::LogLevel::Warning,
fmt::format("CUDA EP: hash mismatch for {}: got {}, expected {}",
expected.filename, hash, expected.sha256));
return false;
}
}

return true;
}

} // anonymous namespace

namespace fl {
Expand Down Expand Up @@ -127,7 +102,10 @@ bool CudaEpBootstrapper::DownloadAndRegister(bool force,
FileLock lock(lock_path);

// Check if package already exists and is valid
if (VerifyPackage(ep_dir, logger)) {
if (fl::VerifyEpPackage(ep_dir,
{{kExpectedBinaries[0].filename, kExpectedBinaries[0].sha256},
{kExpectedBinaries[1].filename, kExpectedBinaries[1].sha256}},
"CUDA EP", logger)) {
logger.Log(LogLevel::Information, "CUDA EP: package already valid, skipping download");
} else {
// Clean up any partial install
Expand Down Expand Up @@ -170,7 +148,10 @@ bool CudaEpBootstrapper::DownloadAndRegister(bool force,
std::filesystem::remove(zip_path);

// Verify
if (!VerifyPackage(ep_dir, logger)) {
if (!fl::VerifyEpPackage(ep_dir,
{{kExpectedBinaries[0].filename, kExpectedBinaries[0].sha256},
{kExpectedBinaries[1].filename, kExpectedBinaries[1].sha256}},
"CUDA EP", logger)) {
logger.Log(LogLevel::Warning, "CUDA EP: verification failed after download");
return false;
}
Expand Down
54 changes: 54 additions & 0 deletions sdk_v2/cpp/src/ep_detection/ep_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include "logger.h"
#include "util/sha256.h"

#include <fmt/format.h>

#include <algorithm>
#include <cctype>
#include <filesystem>
#include <initializer_list>
#include <string>
#include <string_view>
#include <utility>

namespace fl {

/// Verify a set of binaries in @p dir all exist and match their expected SHA-256 hashes.
///
/// @param dir Directory containing the extracted EP binaries.
/// @param expected List of (filename, expected_sha256_hex) pairs.
/// @param ep_name EP name used in warning log messages (e.g. "CUDA EP").
/// @param logger Logger for diagnostic output.
/// @return true if every file exists and its hash matches; false otherwise.
inline bool VerifyEpPackage(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be inlined?

const std::filesystem::path& dir,
std::initializer_list<std::pair<std::string_view, std::string_view>> expected,
std::string_view ep_name,
ILogger& logger) {
for (const auto& [filename, expected_hash] : expected) {
auto file_path = dir / filename;

if (!std::filesystem::exists(file_path)) {
return false;
}

auto hash = Sha256File(file_path);

// Case-insensitive hex comparison
if (!std::equal(hash.begin(), hash.end(), expected_hash.begin(), expected_hash.end(),
[](char a, char b) { return std::toupper(a) == std::toupper(b); })) {
logger.Log(LogLevel::Warning,
fmt::format("{}: hash mismatch for {}: got {}, expected {}",
ep_name, filename, hash, expected_hash));
return false;
}
}

return true;
}

} // namespace fl
Loading