diff --git a/benchmark_kmeans.py b/benchmark_kmeans.py new file mode 100644 index 0000000000..cfabe7570a --- /dev/null +++ b/benchmark_kmeans.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +r"""KMeans fit+predict benchmark: baseline / cuTile / flash-kmeans. + +Single impl (activate the target conda env first): + python benchmark_kmeans.py --impl baseline|cutile|flash --n N --d D --k K \\ + --max-iter 5 --tol 1e-4 --seed 42 \\ + --warmup-fit 1 --iters-fit 3 --warmup-pred 1 --iters-pred 3 + +Compare (subprocess per impl; export env vars, then --compare): + export BENCH_CONDA=/path/to/miniforge3 + export BENCH_ENV_BASE=cuvs_2608_base + export BENCH_ENV_CUTILE=cuvs_2608 + export BENCH_ENV_FLASH=cuvs_2608_base + python benchmark_kmeans.py --compare --n 33554432 --d 32 --k 64 \\ + --max-iter 5 --tol 1e-4 --seed 42 \\ + --warmup-fit 1 --iters-fit 3 --warmup-pred 1 --iters-pred 3 + +Smoke test (small shape, single impl): + conda activate cuvs_2608 + python benchmark_kmeans.py --impl cutile --n 10000 --d 32 --k 8 \\ + --max-iter 2 --tol 1e-4 --seed 42 \\ + --warmup-fit 0 --iters-fit 1 --warmup-pred 0 --iters-pred 1 + +Required for --compare (no defaults): + BENCH_CONDA path to miniforge/conda root + BENCH_ENV_BASE conda env name for baseline libcuvs + BENCH_ENV_CUTILE conda env name for cuTile libcuvs + BENCH_ENV_FLASH conda env name for flash-kmeans +""" + +from __future__ import annotations + +import argparse +import os +import re +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +ROOT = Path(__file__).resolve().parent +IMPLS = ("baseline", "cutile", "flash") + + +def _require_env(name: str) -> str: + val = os.environ.get(name) + if not val: + raise SystemExit(f"required environment variable {name} is not set") + return val + + +def _impl_config() -> dict[str, dict]: + conda = Path(_require_env("BENCH_CONDA")) + return { + "baseline": { + "bench_mode": "cuvs_base", + "conda": conda, + "conda_env": _require_env("BENCH_ENV_BASE"), + }, + "cutile": { + "bench_mode": "cuvs_cutile", + "conda": conda, + "conda_env": _require_env("BENCH_ENV_CUTILE"), + }, + "flash": { + "bench_mode": "flash", + "conda": conda, + "conda_env": _require_env("BENCH_ENV_FLASH"), + }, + } + + +@dataclass +class BenchResult: + impl: str + fit_median_ms: float | None = None + predict_median_ms: float | None = None + n_iter: int | None = None + inertia: float | None = None + error: str | None = None + + +def median(xs: list[float]) -> float: + import numpy as np + + return float(np.median(xs)) + + +def run_benchmark( + bench_mode: str, + n: int, + d: int, + k: int, + *, + max_iter: int, + tol: float, + seed: int, + warmup_fit: int, + iters_fit: int, + warmup_pred: int, + iters_pred: int, +) -> BenchResult: + import numpy as np + + rng = np.random.default_rng(seed) + init_centroids_host = rng.standard_normal((k, d), dtype=np.float32) + x_host = rng.standard_normal((n, d), dtype=np.float32) + input_gib = n * d * 4 / (1024**3) + + label = { + "cuvs_base": "baseline", + "cuvs_cutile": "cutile", + "flash": "flash", + }[bench_mode] + print( + f"=== N={n:,} D={d} K={k:,} iters={max_iter} input={input_gib:.2f} GiB ===", + flush=True, + ) + + if bench_mode in ("cuvs_base", "cuvs_cutile"): + from cuda.bindings import runtime as cudart + from pylibraft.common import device_ndarray + + from cuvs.cluster.kmeans import KMeansParams, fit, predict + + def sync(): + cudart.cudaDeviceSynchronize() + + x = device_ndarray(x_host) + params = KMeansParams( + n_clusters=k, + max_iter=max_iter, + tol=tol, + metric="sqeuclidean", + hierarchical=False, + init_method="Array", + n_init=1, + ) + + for _ in range(warmup_fit): + fit( + params, x, centroids=device_ndarray(init_centroids_host.copy()) + ) + sync() + + fit_times: list[float] = [] + n_iter = 0 + inertia = 0.0 + for _ in range(iters_fit): + t0 = time.perf_counter() + _, inertia, n_iter = fit( + params, x, centroids=device_ndarray(init_centroids_host.copy()) + ) + sync() + fit_times.append((time.perf_counter() - t0) * 1e3) + + centroids, _, _ = fit( + params, x, centroids=device_ndarray(init_centroids_host.copy()) + ) + sync() + + for _ in range(warmup_pred): + predict(params, x, centroids) + sync() + + pred_times: list[float] = [] + for _ in range(iters_pred): + t0 = time.perf_counter() + predict(params, x, centroids) + sync() + pred_times.append((time.perf_counter() - t0) * 1e3) + + print(f"impl={label} init=Array", flush=True) + print(f"fit_median_ms={median(fit_times):.2f}", flush=True) + print(f"predict_median_ms={median(pred_times):.2f}", flush=True) + print(f"n_iter={n_iter} inertia={inertia:.6g}", flush=True) + return BenchResult( + impl=label, + fit_median_ms=median(fit_times), + predict_median_ms=median(pred_times), + n_iter=n_iter, + inertia=inertia, + ) + + if bench_mode == "flash": + import torch + from flash_kmeans.assign_euclid_triton import euclid_assign_triton + from flash_kmeans.kmeans_triton_impl import batch_kmeans_Euclid + + def sync(): + torch.cuda.synchronize() + + x = torch.from_numpy(x_host).cuda() + init_c = ( + torch.from_numpy(init_centroids_host.copy()).cuda().unsqueeze(0) + ) + + def run_fit(init): + x_b = x.unsqueeze(0) + _, centroids_b, _ = batch_kmeans_Euclid( + x_b, + k, + max_iters=max_iter, + tol=tol, + init_centroids=init, + verbose=False, + ) + return centroids_b + + for _ in range(warmup_fit): + run_fit(init_c.clone()) + sync() + + fit_times = [] + for _ in range(iters_fit): + t0 = time.perf_counter() + run_fit(init_c.clone()) + sync() + fit_times.append((time.perf_counter() - t0) * 1e3) + + centroids_b = run_fit(init_c.clone()) + sync() + + x_b = x.unsqueeze(0) + x_sq = (x_b**2).sum(dim=-1) + + for _ in range(warmup_pred): + euclid_assign_triton(x_b, centroids_b, x_sq) + sync() + + pred_times = [] + for _ in range(iters_pred): + t0 = time.perf_counter() + euclid_assign_triton(x_b, centroids_b, x_sq) + sync() + pred_times.append((time.perf_counter() - t0) * 1e3) + + print("impl=flash-kmeans init=Array", flush=True) + print(f"fit_median_ms={median(fit_times):.2f}", flush=True) + print(f"predict_median_ms={median(pred_times):.2f}", flush=True) + return BenchResult( + impl="flash", + fit_median_ms=median(fit_times), + predict_median_ms=median(pred_times), + ) + + raise ValueError(f"unknown bench_mode={bench_mode!r}") + + +def _parse_output(text: str, impl: str) -> BenchResult: + fit_m = re.search(r"^fit_median_ms=([0-9.]+)", text, re.M) + pred_m = re.search(r"^predict_median_ms=([0-9.]+)", text, re.M) + if not fit_m or not pred_m: + return BenchResult(impl=impl, error=text.strip() or "no output") + n_iter_m = re.search(r"^n_iter=([0-9]+)", text, re.M) + inertia_m = re.search(r"^inertia=([0-9.eE+-]+)", text, re.M) + return BenchResult( + impl=impl, + fit_median_ms=float(fit_m.group(1)), + predict_median_ms=float(pred_m.group(1)), + n_iter=int(n_iter_m.group(1)) if n_iter_m else None, + inertia=float(inertia_m.group(1)) if inertia_m else None, + ) + + +def _run_subprocess( + impl: str, + n: int, + d: int, + k: int, + args: argparse.Namespace, +) -> BenchResult: + cfg = _impl_config()[impl] + conda = cfg["conda"] + env_exports = " ".join( + f'export {key}="{val}"' + for key, val in ( + ( + "CUDA_VISIBLE_DEVICES", + os.environ.get("CUDA_VISIBLE_DEVICES", ""), + ), + ("MAX_ITER", args.max_iter), + ("TOL", args.tol), + ("SEED", args.seed), + ("WARMUP_FIT", args.warmup_fit), + ("ITERS_FIT", args.iters_fit), + ("WARMUP_PRED", args.warmup_pred), + ("ITERS_PRED", args.iters_pred), + ) + if val != "" + ) + cmd = f""" +set -eo pipefail +source "{conda}/etc/profile.d/conda.sh" +conda activate "{cfg["conda_env"]}" +{env_exports} +python3 "{ROOT / "benchmark_kmeans.py"}" --impl {impl} --n {n} --d {d} --k {k} \\ + --max-iter {args.max_iter} --tol {args.tol} --seed {args.seed} \\ + --warmup-fit {args.warmup_fit} --iters-fit {args.iters_fit} \\ + --warmup-pred {args.warmup_pred} --iters-pred {args.iters_pred} +""" + proc = subprocess.run(["bash", "-lc", cmd], capture_output=True, text=True) + out = proc.stdout + proc.stderr + if proc.returncode != 0: + return BenchResult( + impl=impl, error=out.strip() or f"exit {proc.returncode}" + ) + return _parse_output(out, impl) + + +def _speedup(base: float, other: float) -> str: + if other <= 0: + return "n/a" + return f"{base / other:.2f}x" + + +def print_compare_table( + results: list[BenchResult], n: int, d: int, k: int +) -> None: + print(f"\n######## compare N={n} D={d} K={k} ########") + print(f"{'impl':<10} {'fit_ms':>10} {'pred_ms':>10} {'notes'}") + print("-" * 50) + by_impl = {r.impl: r for r in results} + for impl in IMPLS: + r = by_impl.get(impl) + if r is None: + print(f"{impl:<10} {'—':>10} {'—':>10} missing") + continue + if r.error: + print( + f"{impl:<10} {'FAIL':>10} {'FAIL':>10} {r.error.splitlines()[-1][:40]}" + ) + continue + print( + f"{impl:<10} {r.fit_median_ms:10.2f} {r.predict_median_ms:10.2f}" + ) + + flash = by_impl.get("flash") + cutile = by_impl.get("cutile") + if flash and cutile and flash.fit_median_ms and cutile.fit_median_ms: + if not flash.error and not cutile.error: + print( + f"\nflash vs cutile fit: {_speedup(cutile.fit_median_ms, flash.fit_median_ms)}" + f" predict: {_speedup(cutile.predict_median_ms, flash.predict_median_ms)}" + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--compare", action="store_true", help="run baseline, cutile, flash" + ) + parser.add_argument("--impl", choices=IMPLS, help="single impl") + parser.add_argument("--n", type=int, required=True) + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--k", type=int, required=True) + parser.add_argument("--max-iter", type=int, required=True) + parser.add_argument("--tol", type=float, required=True) + parser.add_argument("--seed", type=int, required=True) + parser.add_argument("--warmup-fit", type=int, required=True) + parser.add_argument("--iters-fit", type=int, required=True) + parser.add_argument("--warmup-pred", type=int, required=True) + parser.add_argument("--iters-pred", type=int, required=True) + args = parser.parse_args() + + if args.compare: + if args.impl: + parser.error("--compare and --impl are mutually exclusive") + _impl_config() # validate required env before launching subprocesses + results = [ + _run_subprocess(impl, args.n, args.d, args.k, args) + for impl in IMPLS + ] + print_compare_table(results, args.n, args.d, args.k) + return 0 if all(r.error is None for r in results) else 1 + + if not args.impl: + parser.error("set --impl for single-run mode, or use --compare") + + bench_mode = { + "baseline": "cuvs_base", + "cutile": "cuvs_cutile", + "flash": "flash", + }[args.impl] + + try: + run_benchmark( + bench_mode, + args.n, + args.d, + args.k, + max_iter=args.max_iter, + tol=args.tol, + seed=args.seed, + warmup_fit=args.warmup_fit, + iters_fit=args.iters_fit, + warmup_pred=args.warmup_pred, + iters_pred=args.iters_pred, + ) + except Exception as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/conda/recipes/libcuvs/recipe.yaml b/conda/recipes/libcuvs/recipe.yaml index aa7a37db44..93f31f8cf2 100644 --- a/conda/recipes/libcuvs/recipe.yaml +++ b/conda/recipes/libcuvs/recipe.yaml @@ -80,6 +80,7 @@ cache: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -117,6 +118,7 @@ outputs: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -179,6 +181,7 @@ outputs: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -240,6 +243,7 @@ outputs: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -299,6 +303,7 @@ outputs: - openblas # required by some CPU algos in benchmarks - cuda-cudart-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 227c2906cc..84979d05c1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -957,6 +957,41 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_FILE_FORMAT "${CMAKE_CURRENT_BINARY_DIR}/src/distance/detail/pairwise_matrix/dispatch_rbf_inst_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_op_@op_abbrev@.cu" ) + + include(cmake/modules/generate_cutile_kernels.cmake) + set(fused_1nn_cutile_dir + "${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/fused_distance_nn/cutile" + ) + set(cutile_fused_1nn_generated_dir + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/fused_1nn/cutile" + ) + generate_cutile_kernels( + cutile_fused_1nn_files + KERNEL_DIR + "${fused_1nn_cutile_dir}" + KERNEL_BASENAME + "fused_1nn" + KERNEL_PYTHON + "fused_1nn_kernel.py" + EXPORT_SCRIPT + "export_fused_1nn.py" + OUTPUT_DIRECTORY + "${cutile_fused_1nn_generated_dir}" + MATRIX_JSON_FILE + "${fused_1nn_cutile_dir}/fused_1nn_cutile_matrix.json" + FRAGMENT_TAG_FORMAT_CUBIN + "cuvs::distance::detail::fragment_tag_fused_1nn_cubin, cuvs::detail::jit_lto::@arch_tag@>" + FRAGMENT_TAG_FORMAT_TILEIR + "cuvs::distance::detail::fragment_tag_fused_1nn_tileir>" + FRAGMENT_TAG_HEADER_FILES + "" + "" + "" + ) + if(NOT DEFINED CUVS_CUTILE_ENABLED) + set(CUVS_CUTILE_ENABLED 0) + endif() + target_compile_definitions(cuvs_cpp_headers INTERFACE CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED}) generate_inst_matrix( cagra_build_inst_files MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/cagra_build_matrix.json" @@ -1147,6 +1182,8 @@ if(NOT BUILD_CPU_ONLY) src/util/host_memory.cpp src/detail/jit_lto/AlgorithmLauncher.cpp src/detail/jit_lto/AlgorithmPlanner.cpp + src/detail/jit_lto/LTOAlgorithmPlanner.cpp + src/detail/jit_lto/TileAlgorithmPlanner.cpp src/detail/jit_lto/FragmentEntry.cpp src/detail/jit_lto/nvjitlink_checker.cpp src/detail/jit_lto/NVRTCLTOFragmentCompiler.cpp @@ -1234,6 +1271,8 @@ if(NOT BUILD_CPU_ONLY) src/stats/trustworthiness_score.cu ${CUVS_MG_ALGOS} ${jit_lto_files} + ${cutile_fused_1nn_files} + $<$:src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu> ) set_target_properties( @@ -1255,8 +1294,9 @@ if(NOT BUILD_CPU_ONLY) ) target_compile_definitions( - cuvs_objs PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> - $<$:NVTX_ENABLED> + cuvs_objs + PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED} ) target_link_libraries( @@ -1275,6 +1315,7 @@ if(NOT BUILD_CPU_ONLY) "$" INTERFACE "$" PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src" "${CMAKE_CURRENT_BINARY_DIR}/src" + "${cutile_fused_1nn_generated_dir}" ) # Endian detection diff --git a/cpp/cmake/modules/generate_cutile_kernels.cmake b/cpp/cmake/modules/generate_cutile_kernels.cmake new file mode 100644 index 0000000000..9cd8a207c8 --- /dev/null +++ b/cpp/cmake/modules/generate_cutile_kernels.cmake @@ -0,0 +1,272 @@ +# ============================================================================= +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on +# ============================================================================= + +include_guard(GLOBAL) + +include(${CMAKE_CURRENT_LIST_DIR}/compute_matrix_product.cmake) + +function(generate_cutile_kernels_stub) + set(CUVS_CUTILE_ENABLED + 0 + PARENT_SCOPE + ) +endfunction() + +function(_cutile_fragment_tag_header_files output_var) + set(${output_var} "") + foreach(_header IN LISTS ARGN) + if(NOT _header MATCHES "^(\".*\"|<.*>)$") + set(_header "\"${_header}\"") + endif() + string(APPEND ${output_var} "#include ${_header}\n") + endforeach() + set(${output_var} + "${${output_var}}" + PARENT_SCOPE + ) +endfunction() + +function(_cutile_kernels_setup) + set(options) + set(one_value MATRIX_JSON_FILE OUTPUT_DIRECTORY) + set(multi_value) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + find_package(Python3 REQUIRED COMPONENTS Interpreter) + find_package(CUDAToolkit REQUIRED) + + if(CUDAToolkit_VERSION VERSION_LESS 13.0) + message( + STATUS + "cuTile embedded kernels require CUDA 13.0+; skipping cuTile generation (found ${CUDAToolkit_VERSION})." + ) + set(_CUTILE_SETUP_OK + FALSE + PARENT_SCOPE + ) + return() + endif() + + find_program( + CUTILE_BIN2C + NAMES bin2c + PATHS ${CUDAToolkit_BIN_DIR} REQUIRED + ) + + execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import cuda.tile" + RESULT_VARIABLE _cutile_import_result + OUTPUT_QUIET ERROR_QUIET + ) + if(NOT _cutile_import_result EQUAL 0) + message( + FATAL_ERROR + "cuda.tile (cuTile Python) is required to build cuTile embedded kernels. " + "Install it in the active Python environment, e.g. pip install cuda-tile[tileiras]." + ) + endif() + + set_property( + DIRECTORY + PROPERTY CMAKE_CONFIGURE_DEPENDS "${_CUTILE_MATRIX_JSON_FILE}" + APPEND + ) + + file(MAKE_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}") + + set(Python3_EXECUTABLE + "${Python3_EXECUTABLE}" + PARENT_SCOPE + ) + set(CUTILE_BIN2C + "${CUTILE_BIN2C}" + PARENT_SCOPE + ) + set(_CUTILE_SETUP_OK + TRUE + PARENT_SCOPE + ) +endfunction() + +function(_cutile_generate_matrix_tiles_header header_path matrix_json_file) + file(READ "${matrix_json_file}" _matrix_json) + string(JSON _tile0 GET "${_matrix_json}" 0 "_tile" 0) + string(JSON _tile_m GET "${_tile0}" "tile_m") + string(JSON _tile_n GET "${_tile0}" "tile_n") + string(JSON _tile_k GET "${_tile0}" "tile_k") + file( + WRITE "${header_path}" + "/* + * Generated from ${matrix_json_file} by generate_cutile_kernels.cmake — do not edit. + */ +#pragma once + +#include + +namespace cuvs::distance::detail { + +using fused_1nn_matrix_tile = cutile_tile_config<${_tile_m}, ${_tile_n}, ${_tile_k}>; + +} // namespace cuvs::distance::detail +" + ) +endfunction() + +function(process_cutile_matrix_entry source_list_var) + set(options) + set(one_value KERNEL_DIR KERNEL_BASENAME KERNEL_PYTHON EXPORT_SCRIPT OUTPUT_DIRECTORY + FRAGMENT_TAG_FORMAT_CUBIN FRAGMENT_TAG_FORMAT_TILEIR MATRIX_JSON_ENTRY + ) + set(multi_value FRAGMENT_TAG_HEADER_FILES) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + find_package(Python3 REQUIRED COMPONENTS Interpreter) + + populate_matrix_variables("${_CUTILE_MATRIX_JSON_ENTRY}") + + if(register STREQUAL "cubin") + string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT_CUBIN}" fragment_tag @ONLY) + set(bin2c_symbol embedded_cubin) + set(fragment_entry_type "StaticCubinFragmentEntry") + elseif(register STREQUAL "tileir") + string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT_TILEIR}" fragment_tag @ONLY) + set(bin2c_symbol embedded_tileir) + set(fragment_entry_type "StaticTileIrBytecodeFragmentEntry") + else() + message(FATAL_ERROR "Unknown cuTile register kind '${register}'") + endif() + + _cutile_fragment_tag_header_files(fragment_tag_header_files ${_CUTILE_FRAGMENT_TAG_HEADER_FILES}) + + string(CONFIGURE "${artifact_basename}" _artifact_basename @ONLY) + set(_artifact_stem "${_CUTILE_KERNEL_BASENAME}_${_artifact_basename}") + set(_artifact_file "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_stem}.${artifact_ext}") + set(_embedded_header "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_stem}_${register}.h") + set(_fragment_cpp "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_stem}_${register}.cpp") + set(embedded_header_file "${_artifact_stem}_${register}.h") + + set(_python_args + --format + "${output_format}" + --data-type + "${data_type}" + --metric + "${metric}" + --index-type + "${index_type}" + --tile-m + "${tile_m}" + --tile-n + "${tile_n}" + --tile-k + "${tile_k}" + --gpu-code + "${gpu_code}" + ) + if(DEFINED bytecode_version AND NOT "${bytecode_version}" STREQUAL "") + list(APPEND _python_args --bytecode-version "${bytecode_version}") + endif() + + add_custom_command( + OUTPUT "${_artifact_file}" + COMMAND "${Python3_EXECUTABLE}" "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" + "${_artifact_file}" ${_python_args} + WORKING_DIRECTORY "${_CUTILE_KERNEL_DIR}" + DEPENDS "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" + "${_CUTILE_KERNEL_DIR}/${_CUTILE_KERNEL_PYTHON}" + COMMENT "Exporting cuTile ${_CUTILE_KERNEL_BASENAME} ${output_format} ${data_type}" + VERBATIM + ) + + add_custom_command( + OUTPUT "${_embedded_header}" + COMMAND "${CUTILE_BIN2C}" --const --name ${bin2c_symbol} --static "${_artifact_file}" > + "${_embedded_header}" + DEPENDS "${_artifact_file}" + VERBATIM + ) + + configure_file( + "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/register_cutile_fragment.cpp.in" "${_fragment_cpp}" @ONLY + ) + list(APPEND ${source_list_var} "${_embedded_header}" "${_fragment_cpp}") + set(${source_list_var} + "${${source_list_var}}" + PARENT_SCOPE + ) +endfunction() + +function(generate_cutile_kernels source_list_var) + set(options) + set(one_value KERNEL_DIR KERNEL_BASENAME KERNEL_PYTHON EXPORT_SCRIPT OUTPUT_DIRECTORY + MATRIX_JSON_FILE FRAGMENT_TAG_FORMAT_CUBIN FRAGMENT_TAG_FORMAT_TILEIR + ) + set(multi_value FRAGMENT_TAG_HEADER_FILES) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + if(NOT _CUTILE_KERNEL_BASENAME) + message(FATAL_ERROR "generate_cutile_kernels: KERNEL_BASENAME is required") + endif() + if(NOT _CUTILE_KERNEL_PYTHON) + set(_CUTILE_KERNEL_PYTHON "fused_1nn_kernel.py") + endif() + + _cutile_kernels_setup( + MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}" OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" + ) + if(NOT _CUTILE_SETUP_OK) + generate_cutile_kernels_stub() + set(${source_list_var} + "" + PARENT_SCOPE + ) + return() + endif() + + compute_matrix_product(matrix_product MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}") + + set(_matrix_tiles_header "${_CUTILE_OUTPUT_DIRECTORY}/fused_1nn_cutile_tiles.hpp") + _cutile_generate_matrix_tiles_header("${_matrix_tiles_header}" "${_CUTILE_MATRIX_JSON_FILE}") + + string(JSON len LENGTH "${matrix_product}") + math(EXPR last "${len} - 1") + + # cmake-lint: disable=C0103,E1120 + foreach(i RANGE "${last}") + string(JSON matrix_json_entry GET "${matrix_product}" "${i}") + process_cutile_matrix_entry( + "${source_list_var}" + KERNEL_DIR + "${_CUTILE_KERNEL_DIR}" + KERNEL_BASENAME + "${_CUTILE_KERNEL_BASENAME}" + KERNEL_PYTHON + "${_CUTILE_KERNEL_PYTHON}" + EXPORT_SCRIPT + "${_CUTILE_EXPORT_SCRIPT}" + OUTPUT_DIRECTORY + "${_CUTILE_OUTPUT_DIRECTORY}" + FRAGMENT_TAG_FORMAT_CUBIN + "${_CUTILE_FRAGMENT_TAG_FORMAT_CUBIN}" + FRAGMENT_TAG_FORMAT_TILEIR + "${_CUTILE_FRAGMENT_TAG_FORMAT_TILEIR}" + FRAGMENT_TAG_HEADER_FILES + ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} + MATRIX_JSON_ENTRY + "${matrix_json_entry}" + ) + endforeach() + + set(CUVS_CUTILE_ENABLED + 1 + PARENT_SCOPE + ) + set(${source_list_var} + "${${source_list_var}}" + PARENT_SCOPE + ) +endfunction() diff --git a/cpp/cmake/modules/register_cutile_fragment.cpp.in b/cpp/cmake/modules/register_cutile_fragment.cpp.in new file mode 100644 index 0000000000..de0472a779 --- /dev/null +++ b/cpp/cmake/modules/register_cutile_fragment.cpp.in @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "@embedded_header_file@" +#include + +@fragment_tag_header_files@ + + namespace +{ + using fragment_tag = @fragment_tag@; + using fragment_entry = @fragment_entry_type@; + +} // namespace + +template <> +const uint8_t* const fragment_entry::data = @bin2c_symbol@; + +template <> +const size_t fragment_entry::length = sizeof(@bin2c_symbol@); + +template <> +const int fragment_entry::tile_m = @tile_m@; + +template <> +const int fragment_entry::tile_n = @tile_n@; + +template <> +const int fragment_entry::tile_k = @tile_k@; diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp index 7f275b1285..7ff8487d20 100644 --- a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -19,6 +20,7 @@ struct LauncherJitCache { std::shared_mutex mutex; std::unordered_map> launchers; + std::unordered_set build_failed; }; struct AlgorithmPlanner { @@ -27,9 +29,32 @@ struct AlgorithmPlanner { { } + virtual ~AlgorithmPlanner() = default; + std::shared_ptr get_launcher(); + /** Returns nullptr when no module can be loaded for the current device (does not RAFT_FAIL). */ + std::shared_ptr try_get_launcher(); + std::string entrypoint; + + protected: + virtual std::shared_ptr build() = 0; + + virtual std::string get_planner_key() const = 0; + + std::shared_ptr read_cache(std::string const& launch_key) const; + + LauncherJitCache& jit_cache_; +}; + +/** Links embedded LTO fatbin fragments at runtime via nvJitLink. */ +struct LTOAlgorithmPlanner : AlgorithmPlanner { + LTOAlgorithmPlanner(std::string entrypoint, LauncherJitCache& jit_cache) + : AlgorithmPlanner(std::move(entrypoint), jit_cache) + { + } + std::vector> fragments; template >> @@ -45,16 +70,41 @@ struct AlgorithmPlanner { } protected: - /** Extra link-time option strings passed to nvJitLink. Base build() - * always passes "-lto" and "-arch=sm_XX" first; derived planners may append here in their - * constructor body. */ + /** Extra link-time option strings passed to nvJitLink. */ std::vector linktime_extra_options; - private: - std::string get_fragments_key() const; - std::shared_ptr build(); + std::string get_planner_key() const override; - std::shared_ptr read_cache(std::string const& launch_key) const; + std::shared_ptr build() override; +}; - LauncherJitCache& jit_cache_; +/** Loads prebuilt cubins or TileIR bytecode via cudaLibraryLoadData. */ +struct TileAlgorithmPlanner : AlgorithmPlanner { + TileAlgorithmPlanner(std::string entrypoint, LauncherJitCache& jit_cache) + : AlgorithmPlanner(std::move(entrypoint), jit_cache) + { + } + + template + void add_static_fragment() + { + cubin_fragments_.push_back(std::make_unique>()); + } + + template + void add_static_tileir_fragment() + { + tileir_fragment_ = std::make_unique>(); + } + + /** Tile geometry from the cubin or TileIR fragment that would load on this device. */ + CutileTileConfig tile_config() const; + + protected: + std::vector> cubin_fragments_; + std::unique_ptr tileir_fragment_; + + std::string get_planner_key() const override; + + std::shared_ptr build() override; }; diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp index 35aa46633c..0961595f8d 100644 --- a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp +++ b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp @@ -62,3 +62,108 @@ struct UDFFatbinFragment final : FatbinFragmentEntry { std::string key_; std::vector bytes_; }; + +/** cuTile GEMM-style block geometry embedded in generated Static*FragmentEntry specializations. */ +struct CutileTileConfig { + int tile_m; + int tile_n; + int tile_k; +}; + +/** Embedded CUDA binary module (cubin), loaded directly via cudaLibraryLoadData. */ +struct CubinFragmentEntry { + virtual ~CubinFragmentEntry() = default; + + virtual const uint8_t* get_data() const = 0; + + virtual size_t get_length() const = 0; + + virtual const char* get_key() const = 0; + + virtual int get_cc_major() const = 0; + + virtual int get_cc_minor() const = 0; + + virtual int get_tile_m() const { return 0; } + + virtual int get_tile_n() const { return 0; } + + virtual int get_tile_k() const { return 0; } +}; + +template +struct StaticCubinFragmentEntry final : CubinFragmentEntry { + const uint8_t* get_data() const override { return StaticCubinFragmentEntry::data; } + + size_t get_length() const override { return StaticCubinFragmentEntry::length; } + + const char* get_key() const override + { + return typeid(StaticCubinFragmentEntry).name(); + } + + int get_cc_major() const override { return FragmentTag::cc_major; } + + int get_cc_minor() const override { return FragmentTag::cc_minor; } + + int get_tile_m() const override { return tile_m; } + + int get_tile_n() const override { return tile_n; } + + int get_tile_k() const override { return tile_k; } + + static const int tile_m; + static const int tile_n; + static const int tile_k; + + static const uint8_t* const data; + static const size_t length; +}; + +/** Embedded TileIR bytecode, JIT-compiled by the driver when no matching cubin exists. */ +struct TileIrBytecodeFragmentEntry { + virtual ~TileIrBytecodeFragmentEntry() = default; + + virtual const uint8_t* get_data() const = 0; + + virtual size_t get_length() const = 0; + + virtual const char* get_key() const = 0; + + virtual int get_tile_m() const { return 0; } + + virtual int get_tile_n() const { return 0; } + + virtual int get_tile_k() const { return 0; } +}; + +template +struct StaticTileIrBytecodeFragmentEntry final : TileIrBytecodeFragmentEntry { + const uint8_t* get_data() const override + { + return StaticTileIrBytecodeFragmentEntry::data; + } + + size_t get_length() const override + { + return StaticTileIrBytecodeFragmentEntry::length; + } + + const char* get_key() const override + { + return typeid(StaticTileIrBytecodeFragmentEntry).name(); + } + + int get_tile_m() const override { return tile_m; } + + int get_tile_n() const override { return tile_n; } + + int get_tile_k() const override { return tile_k; } + + static const int tile_m; + static const int tile_n; + static const int tile_k; + + static const uint8_t* const data; + static const size_t length; +}; diff --git a/cpp/include/cuvs/detail/jit_lto/cutile_arch_tags.hpp b/cpp/include/cuvs/detail/jit_lto/cutile_arch_tags.hpp new file mode 100644 index 0000000000..2c915a278b --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/cutile_arch_tags.hpp @@ -0,0 +1,52 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_CUTILE_ENABLED +#define CUVS_CUTILE_ENABLED 0 +#endif + +namespace cuvs::detail::jit_lto { + +#if CUVS_CUTILE_ENABLED + +/** Must stay in sync with cuTile matrix _arch entries and planner add_static_fragment calls. */ +struct cutile_arch_8_0 { + static constexpr int cc_major = 8; + static constexpr int cc_minor = 0; +}; + +struct cutile_arch_8_6 { + static constexpr int cc_major = 8; + static constexpr int cc_minor = 6; +}; + +struct cutile_arch_9_0 { + static constexpr int cc_major = 9; + static constexpr int cc_minor = 0; +}; + +struct cutile_arch_12_0 { + static constexpr int cc_major = 12; + static constexpr int cc_minor = 0; +}; + +inline bool is_embedded_cubin_arch(int cc_major, int cc_minor) +{ + if (cc_major == 8 && cc_minor == 0) { return true; } + if (cc_major == 8 && cc_minor == 6) { return true; } + if (cc_major == 9 && cc_minor == 0) { return true; } + if (cc_major == 12 && cc_minor == 0) { return true; } + return false; +} + +#else + +inline bool is_embedded_cubin_arch(int, int) { return false; } + +#endif + +} // namespace cuvs::detail::jit_lto diff --git a/cpp/include/cuvs/detail/jit_lto/cutile_module.hpp b/cpp/include/cuvs/detail/jit_lto/cutile_module.hpp new file mode 100644 index 0000000000..dff0f472a7 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/cutile_module.hpp @@ -0,0 +1,75 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace cuvs::detail::jit_lto { + +struct CutileModuleImage { + const uint8_t* data; + size_t size; +}; + +inline bool get_device_compute_capability(int& cc_major, int& cc_minor) +{ + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { return false; } + if (cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device) != cudaSuccess) { + return false; + } + if (cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device) != cudaSuccess) { + return false; + } + return true; +} + +/** Selects a prebuilt cubin for the device CC, or embedded TileIR when the driver can JIT it. */ +inline std::optional resolve_cutile_module_image( + int cc_major, + int cc_minor, + int driver_version, + const std::vector>& cubin_fragments, + const TileIrBytecodeFragmentEntry* tileir_fragment) +{ + for (const auto& fragment : cubin_fragments) { + if (fragment->get_cc_major() == cc_major && fragment->get_cc_minor() == cc_minor) { + return CutileModuleImage{fragment->get_data(), fragment->get_length()}; + } + } + if (tileir_fragment != nullptr && tileir_fallback_available(driver_version)) { + return CutileModuleImage{tileir_fragment->get_data(), tileir_fragment->get_length()}; + } + return std::nullopt; +} + +inline std::shared_ptr load_cutile_launcher(const CutileModuleImage& image, + const std::string& kernel_symbol) +{ + cudaLibrary_t library{}; + RAFT_CUDA_TRY( + cudaLibraryLoadData(&library, image.data, nullptr, nullptr, 0, nullptr, nullptr, 0)); + + cudaKernel_t kernel{}; + RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, kernel_symbol.c_str())); + + return std::make_shared(kernel, library); +} + +} // namespace cuvs::detail::jit_lto diff --git a/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp new file mode 100644 index 0000000000..c6afe16b5c --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp @@ -0,0 +1,109 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include +#include + +namespace cuvs::distance::detail { + +struct metric_tag_ip {}; +struct metric_tag_l2 {}; +struct metric_tag_cos {}; + +template +struct cutile_tile_config { + static constexpr int tile_m = TileM; + static constexpr int tile_n = TileN; + static constexpr int tile_k = TileK; +}; + +template +struct fused_1nn_metric_tag; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_ip; +}; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_l2; +}; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_l2; +}; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_cos; +}; + +/** Whether sqrt is applied when packing distance into KVP output. */ +template +constexpr bool fused_1nn_apply_sqrt_at_pack(bool is_sqrt) +{ + if constexpr (Metric == cuvs::distance::DistanceType::L2Expanded || + Metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return is_sqrt; + } else { + return false; + } +} + +template +using fused_1nn_metric_tag_t = typename fused_1nn_metric_tag::type; + +template +struct fused_1nn_data_tag; + +template <> +struct fused_1nn_data_tag { + using type = cuvs::neighbors::detail::tag_f; +}; + +template <> +struct fused_1nn_data_tag { + using type = cuvs::neighbors::detail::tag_h; +}; + +template +using fused_1nn_data_tag_t = typename fused_1nn_data_tag::type; + +template +struct fused_1nn_index_tag; + +template <> +struct fused_1nn_index_tag { + using type = cuvs::neighbors::detail::tag_index_i32; +}; + +template <> +struct fused_1nn_index_tag { + using type = cuvs::neighbors::detail::tag_index_i64; +}; + +template +using fused_1nn_index_tag_t = typename fused_1nn_index_tag::type; + +template +struct fragment_tag_fused_1nn_cubin { + static constexpr int cc_major = ArchTag::cc_major; + static constexpr int cc_minor = ArchTag::cc_minor; +}; + +template +struct fragment_tag_fused_1nn_tileir {}; + +} // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp new file mode 100644 index 0000000000..f114233179 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp @@ -0,0 +1,108 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_CUTILE_ENABLED +#define CUVS_CUTILE_ENABLED 0 +#endif + +#include +#include + +#include + +namespace cuvs::detail::jit_lto { + +/** Minimum CUDA driver version (from cudaDriverGetVersion) for TileIR JIT of embedded bytecode. */ +inline constexpr int kMinTileIrJitDriverVersion = 13010; // CUDA 13.1 / driver >= 590.44 + +/** Minimum CUDA runtime version (from cudaRuntimeGetVersion) for cuTile integration. */ +inline constexpr int kMinCutileRuntimeVersion = 13000; + +inline constexpr bool library_built_with_cutile() +{ +#if CUVS_CUTILE_ENABLED + return true; +#else + return false; +#endif +} + +inline bool runtime_cuda13_or_newer() +{ + int runtime_version = 0; + if (cudaRuntimeGetVersion(&runtime_version) != cudaSuccess) { return false; } + return runtime_version >= kMinCutileRuntimeVersion; +} + +/** True when this build embeds cuTile artifacts and the runtime is CUDA 13+. */ +inline bool cutile_integration_enabled() +{ + return library_built_with_cutile() && runtime_cuda13_or_newer(); +} + +/** True when this build embeds a prebuilt cubin for the given compute capability. */ +inline bool has_embedded_cubin_for_arch(int cc_major, int cc_minor) +{ + return is_embedded_cubin_arch(cc_major, cc_minor); +} + +/** True when the driver can JIT-compile embedded TileIR bytecode at load time. */ +inline bool tileir_fallback_available(int driver_version) +{ + return driver_version >= kMinTileIrJitDriverVersion; +} + +/** + * True when a cuTile launch may be attempted for the given device: cuTile is enabled, the runtime + * is CUDA 13+, and either a matching embedded cubin exists (no driver JIT required) or the driver + * can JIT the embedded TileIR bytecode fallback. + */ +#if CUVS_CUTILE_ENABLED +inline bool cutile_launch_available_for_arch(int cc_major, int cc_minor, int driver_version) +{ + if (!runtime_cuda13_or_newer()) { return false; } + if (has_embedded_cubin_for_arch(cc_major, cc_minor)) { return true; } + return tileir_fallback_available(driver_version); +} +#else +inline constexpr bool cutile_launch_available_for_arch(int, int, int) { return false; } +#endif + +inline bool query_driver_version(int& driver_version) +{ + return cudaDriverGetVersion(&driver_version) == cudaSuccess; +} + +inline bool query_current_device_arch(int& cc_major, int& cc_minor) +{ + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { return false; } + if (cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device) != cudaSuccess) { + return false; + } + if (cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device) != cudaSuccess) { + return false; + } + return true; +} + +#if CUVS_CUTILE_ENABLED +inline bool cutile_launch_available_on_current_device() +{ + int cc_major = 0; + int cc_minor = 0; + int driver_version = 0; + if (!query_current_device_arch(cc_major, cc_minor)) { return false; } + if (!query_driver_version(driver_version)) { return false; } + return cutile_launch_available_for_arch(cc_major, cc_minor, driver_version); +} +#else +/** Compile-time false when cuTile is not built; use in if constexpr to skip cuTile-only paths. */ +inline constexpr bool cutile_launch_available_on_current_device() { return false; } +#endif + +} // namespace cuvs::detail::jit_lto diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 635e8813bd..757108312f 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -682,8 +682,8 @@ void kmeans_fit( DataT* cur_centroids_ptr = cur_centroids_buf.data(); DataT* new_centroids_ptr = new_centroids_buf.data(); - auto minClusterAndDistance = raft::make_device_vector, IndexT>( - handle, streaming_batch_size); + auto nearest_idx = raft::make_device_vector(handle, streaming_batch_size); + auto nearest_dist = raft::make_device_vector(handle, streaming_batch_size); auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); auto batch_weights_buf = raft::make_device_vector(handle, streaming_batch_size); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); @@ -853,8 +853,10 @@ void kmeans_fit( auto batch_weights_view = cur_batch_weights(static_cast(data_batch.offset()), wt_data, cur_batch_size); - auto minCAD_view = raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle(), cur_batch_size); + auto nearest_idx_view = + raft::make_device_vector_view(nearest_idx.data_handle(), cur_batch_size); + auto nearest_dist_view = + raft::make_device_vector_view(nearest_dist.data_handle(), cur_batch_size); if constexpr (!data_on_device) { if (need_compute_norms) { @@ -883,7 +885,8 @@ void kmeans_fit( metric, iter_params.batch_samples, iter_params.batch_centroids, - minCAD_view, + nearest_idx_view, + nearest_dist_view, l2_const_view, L2NormBuf_OR_DistBuf, ws, @@ -1071,8 +1074,7 @@ void kmeans_predict(raft::resources const& handle, raft::make_const_mdspan(weight.view())); } - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); + auto nearest_dist = raft::make_device_vector(handle, n_samples); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // L2 norm of X: ||x||^2 @@ -1082,50 +1084,35 @@ void kmeans_predict(raft::resources const& handle, raft::linalg::norm(handle, X, L2NormX.view()); } - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' auto l2normx_view = raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - pams.metric, - pams.batch_samples, - pams.batch_centroids, - workspace); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute(handle, + X, + centroids, + labels, + nearest_dist.view(), + l2normx_view, + L2NormBuf_OR_DistBuf, + pams.metric, + pams.batch_samples, + pams.batch_centroids, + workspace); - // calculate cluster cost phi_x(C) rmm::device_scalar clusterCostD(stream); - raft::linalg::map( - handle, - minClusterAndDistance.view(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_const_mdspan(weight.view())); + raft::linalg::map(handle, + nearest_dist.view(), + raft::mul_op{}, + raft::make_const_mdspan(nearest_dist.view()), + raft::make_const_mdspan(weight.view())); cuvs::cluster::kmeans::detail::computeClusterCost( handle, - minClusterAndDistance.view(), + nearest_dist.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, + raft::identity_op{}, raft::add_op{}); - raft::linalg::map( - handle, labels, raft::key_op{}, raft::make_const_mdspan(minClusterAndDistance.view())); - inertia[0] = clusterCostD.value(stream); } diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 7fac255810..007c462247 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -98,55 +98,118 @@ inline std::enable_if_t> predict_core( raft::make_device_matrix_view(centers, n_clusters, dim); auto X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( - handle, mr, raft::make_extents(n_rows)); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X_view, - centroids_view, - minClusterAndDistance.view(), - X_norm_view, - L2NormBuf_OR_DistBuf, - params.metric, - 0, // batch_samples (unused for fused reduction) - 0, // batch_centroids (unused for fused reduction) - workspace); - - // Copy keys to output labels - raft::linalg::map(handle, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_device_vector_view(labels, n_rows), - raft::compose_op, raft::key_op>()); + auto nearest_dist = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + + if constexpr (std::is_same_v) { + auto labels_view = raft::make_device_vector_view(labels, n_rows); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + labels_view, + nearest_dist.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, // batch_samples (unused for fused reduction) + 0, // batch_centroids (unused for fused reduction) + workspace); + } else { + auto nearest_idx = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + nearest_idx.view(), + nearest_dist.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, + 0, + workspace); + raft::copy( + handle, raft::make_device_vector_view(labels, n_rows), nearest_idx.view()); + } break; } case cuvs::distance::DistanceType::InnerProduct: { - // TODO: pass buffer - rmm::device_uvector distances(n_rows * n_clusters, stream, mr); + if (uses_fused_distance_nn( + use_fused(handle, n_rows, n_clusters, dim, params.metric))) { + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream, mr); + rmm::device_uvector workspace(0, stream, mr); + + auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); + auto centroids_view = + raft::make_device_matrix_view(centers, n_clusters, dim); + auto X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); + + auto nearest_dist = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + + if constexpr (std::is_same_v) { + auto labels_view = raft::make_device_vector_view(labels, n_rows); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + labels_view, + nearest_dist.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, + 0, + workspace); + } else { + auto nearest_idx = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + nearest_idx.view(), + nearest_dist.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, + 0, + workspace); + raft::copy(handle, + raft::make_device_vector_view(labels, n_rows), + nearest_idx.view()); + } + } else { + rmm::device_uvector distances(n_rows * n_clusters, stream, mr); - MathT alpha = -1.0; - MathT beta = 0.0; + MathT alpha = -1.0; + MathT beta = 0.0; - raft::linalg::gemm(handle, - true, - false, - n_clusters, - n_rows, - dim, - &alpha, - centers, - dim, - dataset, - dim, - &beta, - distances.data(), - n_clusters, - stream); + raft::linalg::gemm(handle, + true, + false, + n_clusters, + n_rows, + dim, + &alpha, + centers, + dim, + dataset, + dim, + &beta, + distances.data(), + n_clusters, + stream); - auto distances_const_view = raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters); - auto labels_view = raft::make_device_vector_view(labels, n_rows); - raft::matrix::argmin(handle, distances_const_view, labels_view); + auto distances_const_view = + raft::make_device_matrix_view( + distances.data(), n_rows, n_clusters); + auto labels_view = raft::make_device_vector_view(labels, n_rows); + raft::matrix::argmin(handle, distances_const_view, labels_view); + } break; } default: { @@ -185,14 +248,19 @@ auto calc_minibatch_size(const raft::resources& handle, size_t mem_per_row = 0; switch (metric) { case distance::DistanceType::L2Expanded: - case distance::DistanceType::L2SqrtExpanded: { - if (use_fused(handle, n_rows, n_clusters, dim)) { - // fusedL2NN needs a mutex and a key-value pair for each row. - mem_per_row += sizeof(int); - mem_per_row += sizeof(raft::KeyValuePair); - } else { - // unfused path needs a full GEMM output (distance matrix row). - mem_per_row += sizeof(MathT) * n_clusters; + case distance::DistanceType::L2SqrtExpanded: + case distance::DistanceType::InnerProduct: { + switch (use_fused(handle, n_rows, n_clusters, dim, metric)) { + case FusedDistancePath::FusedCutile: break; + case FusedDistancePath::FusedCutlass: + // fusedDistanceNNMinReduce CUTLASS fallback: mutex workspace + scratch KVP per row. + mem_per_row += sizeof(int); + mem_per_row += sizeof(raft::KeyValuePair); + break; + case FusedDistancePath::Unfused: + // unfused / GEMM+argmin path needs a full distance matrix row. + mem_per_row += sizeof(MathT) * n_clusters; + break; } } break; // Other metrics require storing a distance matrix. diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index ba98dadca6..f0d9bde801 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -7,6 +7,7 @@ #include "../../distance/distance.cuh" #include #include +#include #include #include @@ -57,29 +58,67 @@ namespace cuvs::cluster::kmeans::detail { +template +inline constexpr bool is_cutile_fused_data_type_v = + std::is_same_v || std::is_same_v; + +/** Which fused-distance implementation minCluster* will use (or Unfused). */ +enum class FusedDistancePath : std::uint8_t { + /** unfusedDistanceNNMinReduce or batched pairwise distance. */ + Unfused = 0, + /** fusedDistanceNNMinReduce via cuTile; no CUTLASS mutex / KVP scratch. */ + FusedCutile, + /** fusedDistanceNNMinReduce via legacy CUTLASS; needs mutex workspace + KVP scratch. */ + FusedCutlass, +}; + +inline constexpr bool uses_fused_distance_nn(FusedDistancePath path) +{ + return path != FusedDistancePath::Unfused; +} + +inline constexpr bool needs_cutlass_kvp_scratch(FusedDistancePath path) +{ + return path == FusedDistancePath::FusedCutlass; +} + +inline constexpr bool needs_fused_mutex_workspace(FusedDistancePath path) +{ + return path == FusedDistancePath::FusedCutlass; +} + /** - * @brief Returns true if the fused distance NN implementation should be used. + * @brief Selects the fused-distance assignment path for KMeans. * - * On Ampere (SM <= 8.x) always use fused. - * On Hopper (SM 9.x) use fused when m or n >= 4096. - * On Blackwell (SM >= 10.x) use unfused. + * Float/half: cuTile when the build and device support it. Otherwise L2/L2Sqrt/Cosine may use + * legacy CUTLASS fused on Ampere/Hopper (large enough problems). InnerProduct without cuTile uses + * Unfused. Double never uses cuTile; keeps historical CUTLASS/unfused heuristics on pre-Blackwell + * GPUs. */ template -bool use_fused(const raft::resources& handle, IdxT m, IdxT n, IdxT k) +FusedDistancePath use_fused( + const raft::resources& handle, IdxT m, IdxT n, IdxT k, cuvs::distance::DistanceType metric) { + (void)k; cudaDeviceProp prop; prop = raft::resource::get_device_properties(handle); - if (prop.major <= 8) { - // Use fused for Ampere or before - return true; - } else if (prop.major == 9 && (m >= 4096 || n >= 4096)) { - // On Hopper if m, n are bigger than 4096, use fused - return true; - } else if (prop.major >= 10) { - // On Blackwell onwards, use unfused - return false; + + if constexpr (is_cutile_fused_data_type_v) { + if constexpr (cuvs::detail::jit_lto::library_built_with_cutile()) { + if (cuvs::detail::jit_lto::cutile_launch_available_on_current_device()) { + return FusedDistancePath::FusedCutile; + } + } + if (metric == cuvs::distance::DistanceType::InnerProduct) { return FusedDistancePath::Unfused; } + if (prop.major <= 8) { return FusedDistancePath::FusedCutlass; } + if (prop.major == 9 && (m >= 4096 || n >= 4096)) { return FusedDistancePath::FusedCutlass; } + return FusedDistancePath::Unfused; } - return false; + + if (prop.major >= 10) { return FusedDistancePath::Unfused; } + if (prop.major <= 8) { return FusedDistancePath::FusedCutlass; } + if (prop.major == 9 && (m >= 4096 || n >= 4096)) { return FusedDistancePath::FusedCutlass; } + return FusedDistancePath::Unfused; } template @@ -370,33 +409,32 @@ void shuffleAndGather(raft::resources const& handle, stream); } -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' +// Calculates nearest centroid index and distance for every sample in input 'X'. template -void minClusterAndDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace); - -#define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ - extern template void minClusterAndDistanceCompute( \ - raft::resources const& handle, \ - raft::device_matrix_view X, \ - raft::device_matrix_view centroids, \ - raft::device_vector_view, IndexT> minClusterAndDistance, \ - raft::device_vector_view L2NormX, \ - rmm::device_uvector& L2NormBuf_OR_DistBuf, \ - cuvs::distance::DistanceType metric, \ - int batch_samples, \ - int batch_centroids, \ +void minClusterAndDistanceCompute(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace); + +#define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ + extern template void minClusterAndDistanceCompute( \ + raft::resources const& handle, \ + raft::device_matrix_view X, \ + raft::device_matrix_view centroids, \ + raft::device_vector_view nearest_idx, \ + raft::device_vector_view nearest_dist, \ + raft::device_vector_view L2NormX, \ + rmm::device_uvector& L2NormBuf_OR_DistBuf, \ + cuvs::distance::DistanceType metric, \ + int batch_samples, \ + int batch_centroids, \ rmm::device_uvector& workspace); EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) @@ -455,22 +493,16 @@ void countSamplesInCluster(raft::resources const& handle, // stores (key, value) pair corresponding to each sample where // - key is the index of nearest cluster // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store distance matrix, destructor releases the resource + auto nearest_idx = raft::make_device_vector(handle, n_samples); + auto nearest_dist = raft::make_device_vector(handle, n_samples); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, X, (raft::device_matrix_view)centroids, - minClusterAndDistance.view(), + nearest_idx.view(), + nearest_dist.view(), L2NormX, L2NormBuf_OR_DistBuf, params.metric, @@ -478,12 +510,8 @@ void countSamplesInCluster(raft::resources const& handle, params.batch_centroids, workspace); - cuda::transform_iterator itr(minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}); - - // count # of samples in each cluster countLabels(handle, - itr, + nearest_idx.data_handle(), sampleCountInCluster.data_handle(), (IndexT)n_samples, (IndexT)n_clusters, @@ -668,7 +696,8 @@ __device__ void check_convergence(raft::device_scalar_view clusteri * @param[in] batch_samples_param Batch-samples param forwarded to minClusterAndDistanceCompute * @param[in] batch_centroids_param Batch-centroids param forwarded to * minClusterAndDistanceCompute - * @param[inout] minClusterAndDistance Work buffer [batch_size] + * @param[inout] nearest_idx Nearest cluster index per sample [batch_size] + * @param[inout] nearest_dist Nearest distance per sample [batch_size] * @param[in] L2NormBatch Precomputed data norms [batch_size] * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch * @param[inout] workspace Resizable scratch @@ -677,29 +706,30 @@ __device__ void check_convergence(raft::device_scalar_view clusteri * @param[inout] clustering_cost Running cost scalar (device) (added into) */ template -void process_batch( - raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view batch_weights, - raft::device_matrix_view centroids, - cuvs::distance::DistanceType metric, - int batch_samples_param, - int batch_centroids_param, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormBatch, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace, - raft::device_matrix_view centroid_sums, - raft::device_vector_view weight_per_cluster, - raft::device_scalar_view clustering_cost, - rmm::device_uvector& batch_workspace) +void process_batch(raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view batch_weights, + raft::device_matrix_view centroids, + cuvs::distance::DistanceType metric, + int batch_samples_param, + int batch_centroids_param, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormBatch, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + raft::device_scalar_view clustering_cost, + rmm::device_uvector& batch_workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); minClusterAndDistanceCompute(handle, batch_data, centroids, - minClusterAndDistance, + nearest_idx, + nearest_dist, L2NormBatch, L2NormBuf_OR_DistBuf, metric, @@ -707,36 +737,30 @@ void process_batch( batch_centroids_param, workspace); - KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> - labels_itr(minClusterAndDistance.data_handle(), conversion_op); - compute_centroid_adjustments(handle, batch_data, batch_weights, - labels_itr, + nearest_idx.data_handle(), static_cast(centroid_sums.extent(0)), centroid_sums, weight_per_cluster, batch_workspace, /*reset_sums=*/false); - raft::linalg::map( - handle, - minClusterAndDistance, - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance), - batch_weights); + auto weighted_dist = raft::make_device_vector(handle, nearest_dist.extent(0)); + raft::linalg::map(handle, + weighted_dist.view(), + raft::mul_op{}, + raft::make_const_mdspan(nearest_dist), + raft::make_const_mdspan(batch_weights)); auto batch_cost = raft::make_device_scalar(handle, DataT{0}); - computeClusterCost( - handle, minClusterAndDistance, workspace, batch_cost.view(), raft::value_op{}, raft::add_op{}); + computeClusterCost(handle, + weighted_dist.view(), + workspace, + batch_cost.view(), + raft::identity_op{}, + raft::add_op{}); raft::linalg::add(clustering_cost.data_handle(), clustering_cost.data_handle(), batch_cost.data_handle(), diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index ce3ca5a1fe..955d51c2a9 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -539,11 +539,8 @@ void fit(const raft::resources& handle, THROW("unknown initialization method to select initial centers"); } - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); + auto nearest_idx = raft::make_device_vector(handle, n_samples); + auto nearest_dist = raft::make_device_vector(handle, n_samples); // temporary buffer to store L2 norm of centroids or distance matrix, // destructor releases the resource @@ -577,15 +574,11 @@ void fit(const raft::resources& handle, auto const_centroids = raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' cuvs::cluster::kmeans::min_cluster_and_distance(handle, X, const_centroids, - minClusterAndDistance.view(), + nearest_idx.view(), + nearest_dist.view(), L2NormX.view(), L2NormBuf_OR_DistBuf, params.metric, @@ -595,9 +588,7 @@ void fit(const raft::resources& handle, workspace.resize(n_samples, stream); - cuda::transform_iterator keys_itr( - minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}); + const IndexT* keys_itr = nearest_idx.data_handle(); raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), X.extent(1), keys_itr, @@ -696,35 +687,24 @@ void fit(const raft::resources& handle, raft::make_device_vector_view(centroids.data_handle(), newCentroids.size()), raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); - bool done = false; - rmm::device_scalar> clusterCostD(stream); + bool done = false; + auto clusterCostD = raft::make_device_scalar(handle, DataT{0}); // calculate cluster cost phi_x(C) cuvs::cluster::kmeans::cluster_cost( handle, - minClusterAndDistance.view(), + nearest_dist.view(), workspace, - raft::make_device_scalar_view(clusterCostD.data()), - cuda::proclaim_return_type>( - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - })); + clusterCostD.view(), + cuda::proclaim_return_type( + [] __device__(const DataT& a, const DataT& b) { return a + b; })); // Cluster cost phi_x(C) from all ranks - comm.allreduce(&(clusterCostD.data()->value), - &(clusterCostD.data()->value), - 1, - raft::comms::op_t::SUM, - stream); + comm.allreduce( + clusterCostD.data_handle(), clusterCostD.data_handle(), 1, raft::comms::op_t::SUM, stream); DataT curClusteringCost = 0; - raft::copy(handle, - raft::make_host_scalar_view(&curClusteringCost), - raft::make_device_scalar_view(&(clusterCostD.data()->value))); + raft::copy(handle, raft::make_host_scalar_view(&curClusteringCost), clusterCostD.view()); ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, "An error occurred in the distributed operation. This can result " diff --git a/cpp/src/cluster/detail/kmeans_mg_batched.cuh b/cpp/src/cluster/detail/kmeans_mg_batched.cuh index 98fed41636..ccc89991a6 100644 --- a/cpp/src/cluster/detail/kmeans_mg_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_mg_batched.cuh @@ -157,9 +157,9 @@ void mnmg_fit(const raft::resources& handle, auto sqrd_norm_error_dev = raft::make_device_scalar(dev_res, DataT{0}); IndexT alloc_batch_size = has_data ? streaming_batch_size : IndexT{1}; auto batch_weights = raft::make_device_vector(dev_res, alloc_batch_size); - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(dev_res, alloc_batch_size); - auto L2NormBatch = raft::make_device_vector(dev_res, alloc_batch_size); + auto nearest_idx = raft::make_device_vector(dev_res, alloc_batch_size); + auto nearest_dist = raft::make_device_vector(dev_res, alloc_batch_size); + auto L2NormBatch = raft::make_device_vector(dev_res, alloc_batch_size); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); rmm::device_uvector workspace(0, stream); rmm::device_uvector batch_workspace(0, stream); @@ -353,9 +353,10 @@ void mnmg_fit(const raft::resources& handle, auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch_view); - auto minClusterAndDistance_view = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle(), current_batch_size); + auto nearest_idx_view = raft::make_device_vector_view( + nearest_idx.data_handle(), current_batch_size); + auto nearest_dist_view = raft::make_device_vector_view( + nearest_dist.data_handle(), current_batch_size); cuvs::cluster::kmeans::detail::process_batch( dev_res, @@ -365,7 +366,8 @@ void mnmg_fit(const raft::resources& handle, metric, params.batch_samples, params.batch_centroids, - minClusterAndDistance_view, + nearest_idx_view, + nearest_dist_view, L2NormBatch_const, L2NormBuf_OR_DistBuf, workspace, diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index b15119599e..d01f48fc0a 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -11,31 +11,105 @@ namespace cuvs::cluster::kmeans::detail { -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroids[key]'. +namespace { + +template +__global__ void unpack_kvp_to_soa(IndexT* nearest_idx, + DataT* nearest_dist, + const raft::KeyValuePair* kvp, + IndexT n) +{ + IndexT i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + if (nearest_idx != nullptr) { nearest_idx[i] = kvp[i].key; } + if (nearest_dist != nullptr) { nearest_dist[i] = kvp[i].value; } + } +} + +template +void unpack_kvp(raft::resources const& handle, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view, IndexT> kvp) +{ + auto stream = raft::resource::get_cuda_stream(handle); + auto n = static_cast(kvp.extent(0)); + int blks = static_cast((n + 255) / 256); + unpack_kvp_to_soa<<>>( + nearest_idx.data_handle(), nearest_dist.data_handle(), kvp.data_handle(), n); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace + template -void minClusterAndDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) +void minClusterAndDistanceCompute(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace) { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded; + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = centroids.extent(0); + const bool is_l2_cos = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + const FusedDistancePath fused_path = + use_fused(handle, n_samples, n_clusters, n_features, metric); + + if (uses_fused_distance_nn(fused_path)) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - if (is_fused) { + if (is_l2_cos) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, centroids, centroidsNorm, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, centroids, centroidsNorm); + } + } + + auto centroidsNormConst = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + raft::KeyValuePair* cutlass_kvp_scratch = nullptr; + rmm::device_uvector> temp_kvp(0, stream); + if (needs_cutlass_kvp_scratch(fused_path)) { + temp_kvp.resize(n_samples, stream); + cutlass_kvp_scratch = temp_kvp.data(); + workspace.resize(sizeof(int) * n_samples, stream); + } + + cuvs::distance::fusedDistanceNNMinReduce( + nearest_idx.data_handle(), + nearest_dist.data_handle(), + X.data_handle(), + centroids.data_handle(), + L2NormX.data_handle(), + centroidsNormConst.data_handle(), + n_samples, + n_clusters, + n_features, + needs_fused_mutex_workspace(fused_path) ? (void*)workspace.data() : nullptr, + metric != cuvs::distance::DistanceType::L2Expanded, + true, + true, + metric, + 0.0f, + cutlass_kvp_scratch, + stream); + } else if (is_l2_cos) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); @@ -48,21 +122,23 @@ void minClusterAndDistanceCompute( handle, centroids, centroidsNorm); } - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance, initial_value); - - bool should_use_fused = - use_fused(handle, n_samples, n_clusters, n_features); + auto centroidsNormConst = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - if (should_use_fused) { - workspace.resize((sizeof(int)) * n_samples, stream); + workspace.resize(sizeof(DataT) * n_samples * n_clusters, stream); + auto temp_kvp = + raft::make_device_vector, IndexT>(handle, n_samples); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::matrix::fill(handle, temp_kvp.view(), initial_value); - cuvs::distance::fusedDistanceNNMinReduce, IndexT>( - minClusterAndDistance.data_handle(), + cuvs::distance:: + unfusedDistanceNNMinReduce, IndexT>( + handle, + temp_kvp.data_handle(), X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), - centroidsNorm.data_handle(), + centroidsNormConst.data_handle(), n_samples, n_clusters, n_features, @@ -73,84 +149,44 @@ void minClusterAndDistanceCompute( metric, 0.0f, stream); - } else { - workspace.resize(sizeof(DataT) * n_samples * n_clusters, stream); - - cuvs::distance:: - unfusedDistanceNNMinReduce, IndexT>( - handle, - minClusterAndDistance.data_handle(), - X.data_handle(), - centroids.data_handle(), - L2NormX.data_handle(), - centroidsNorm.data_handle(), - n_samples, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - true, - metric, - 0.0f, - stream); - } + unpack_kvp(handle, nearest_idx, nearest_dist, raft::make_const_mdspan(temp_kvp.view())); } else { auto dataBatchSize = getDataBatchSize(batch_samples, n_samples); auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + auto temp_kvp = + raft::make_device_vector, IndexT>(handle, n_samples); raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance, initial_value); + raft::matrix::fill(handle, temp_kvp.view(), initial_value); - // tile over the input dataset for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - // datasetView [ns x n_features] - view representing the current batch of - // input dataset auto datasetView = raft::make_device_matrix_view( X.data_handle() + (dIdx * n_features), ns, n_features); - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle() + dIdx, ns); + auto temp_kvp_view = raft::make_device_vector_view, IndexT>( + temp_kvp.data_handle() + dIdx, ns); - // tile over the centroids for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - // centroidsView [nc x n_features] - view representing the current batch - // of centroids auto centroidsView = raft::make_device_matrix_view( centroids.data_handle() + (cIdx * n_features), nc, n_features); - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch auto pairwiseDistanceView = raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - // calculate pairwise distance between current tile of cluster centroids - // and input dataset pairwise_distance_kmeans( handle, datasetView, centroidsView, pairwiseDistanceView, metric); - // argmin reduction returning pair - // calculates the closest centroid and the distance to the closest - // centroid raft::linalg::coalescedReduction( - minClusterAndDistanceView.data_handle(), + temp_kvp_view.data_handle(), pairwiseDistanceView.data_handle(), pairwiseDistanceView.extent(1), pairwiseDistanceView.extent(0), @@ -167,20 +203,23 @@ void minClusterAndDistanceCompute( raft::identity_op{}); } } + + unpack_kvp(handle, nearest_idx, nearest_dist, raft::make_const_mdspan(temp_kvp.view())); } } -#define INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ - template void minClusterAndDistanceCompute( \ - raft::resources const& handle, \ - raft::device_matrix_view X, \ - raft::device_matrix_view centroids, \ - raft::device_vector_view, IndexT> minClusterAndDistance, \ - raft::device_vector_view L2NormX, \ - rmm::device_uvector& L2NormBuf_OR_DistBuf, \ - cuvs::distance::DistanceType metric, \ - int batch_samples, \ - int batch_centroids, \ +#define INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ + template void minClusterAndDistanceCompute( \ + raft::resources const& handle, \ + raft::device_matrix_view X, \ + raft::device_matrix_view centroids, \ + raft::device_vector_view nearest_idx, \ + raft::device_vector_view nearest_dist, \ + raft::device_vector_view L2NormX, \ + rmm::device_uvector& L2NormBuf_OR_DistBuf, \ + cuvs::distance::DistanceType metric, \ + int batch_samples, \ + int batch_centroids, \ rmm::device_uvector& workspace); INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) @@ -207,13 +246,17 @@ void minClusterDistanceCompute(raft::resources const& handle, auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded; + const bool is_l2_cos = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); - if (is_fused) { + const FusedDistancePath fused_path = + is_l2_cos ? use_fused(handle, n_samples, n_clusters, n_features, metric) + : FusedDistancePath::Unfused; + + if (uses_fused_distance_nn(fused_path)) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); @@ -233,9 +276,16 @@ void minClusterDistanceCompute(raft::resources const& handle, centroidsNorm); } - workspace.resize(sizeof(int) * n_samples, stream); + raft::KeyValuePair* cutlass_kvp_scratch = nullptr; + rmm::device_uvector> temp_kvp(0, stream); + if (needs_cutlass_kvp_scratch(fused_path)) { + temp_kvp.resize(n_samples, stream); + cutlass_kvp_scratch = temp_kvp.data(); + workspace.resize(sizeof(int) * n_samples, stream); + } - cuvs::distance::fusedDistanceNNMinReduce( + cuvs::distance::fusedDistanceNNMinReduce( + nullptr, minClusterDistance.data_handle(), X.data_handle(), centroids.data_handle(), @@ -244,12 +294,13 @@ void minClusterDistanceCompute(raft::resources const& handle, n_samples, n_clusters, n_features, - (void*)workspace.data(), + needs_fused_mutex_workspace(fused_path) ? (void*)workspace.data() : nullptr, metric != cuvs::distance::DistanceType::L2Expanded, - false, + true, true, metric, 0.0f, + cutlass_kvp_scratch, stream); } else { auto dataBatchSize = getDataBatchSize(batch_samples, n_samples); @@ -260,8 +311,6 @@ void minClusterDistanceCompute(raft::resources const& handle, auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); @@ -271,7 +320,6 @@ void minClusterDistanceCompute(raft::resources const& handle, auto minClusterDistanceView = raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - // tile over the centroids for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 06da1fc1de..003604769b 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -486,22 +486,23 @@ void cluster_cost( * */ template -void min_cluster_and_distance( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) +void min_cluster_and_distance(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace) { cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute(handle, X, centroids, - minClusterAndDistance, + nearest_idx, + nearest_dist, L2NormX, L2NormBuf_OR_DistBuf, metric, diff --git a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp index 7416ea396d..486d6f1aa5 100644 --- a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp +++ b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp @@ -3,33 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include -#include #include #include -#include #include #include -#include #include -#include - -#include "cuda_runtime.h" -#include "nvJitLink.h" #include #include -std::string AlgorithmPlanner::get_fragments_key() const -{ - std::string key = ""; - for (const auto& fragment : this->fragments) { - key += fragment->get_key(); - } - return key; -} - std::shared_ptr AlgorithmPlanner::read_cache(std::string const& launch_key) const { auto& launchers = jit_cache_.launchers; @@ -38,79 +21,37 @@ std::shared_ptr AlgorithmPlanner::read_cache(std::string cons return nullptr; } -std::shared_ptr AlgorithmPlanner::get_launcher() +std::shared_ptr AlgorithmPlanner::try_get_launcher() { - auto& launchers = jit_cache_.launchers; - auto launch_key = this->get_fragments_key(); + auto launch_key = this->get_planner_key(); - if (auto hit = read_cache(launch_key)) { return hit; } + { + std::shared_lock read_lock(jit_cache_.mutex); + if (jit_cache_.build_failed.count(launch_key)) { return nullptr; } + if (auto hit = read_cache(launch_key)) { return hit; } + } std::unique_lock write_lock(jit_cache_.mutex); - if (auto it = launchers.find(launch_key); it != launchers.end()) { return it->second; } + if (jit_cache_.build_failed.count(launch_key)) { return nullptr; } + if (auto it = jit_cache_.launchers.find(launch_key); it != jit_cache_.launchers.end()) { + return it->second; + } - std::string log_message = - "JIT compiling launcher for kernel: " + this->entrypoint + " and device functions: "; - for (const auto& fragment : this->fragments) { - log_message += std::string{fragment->get_key()} + ","; + RAFT_LOG_DEBUG("Building launcher for kernel entrypoint: %s", this->entrypoint.c_str()); + auto launcher = this->build(); + if (!launcher) { + jit_cache_.build_failed.insert(launch_key); + return nullptr; } - log_message.pop_back(); - RAFT_LOG_DEBUG("%s", log_message.c_str()); - auto launcher = this->build(); - launchers[launch_key] = launcher; + jit_cache_.launchers[launch_key] = launcher; return launcher; } -std::shared_ptr AlgorithmPlanner::build() +std::shared_ptr AlgorithmPlanner::get_launcher() { - int device = 0; - int major = 0; - int minor = 0; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - - std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor)); - - // Load the generated LTO IR and link them together - nvJitLinkHandle handle; - std::vector lopts; - lopts.reserve(2 + linktime_extra_options.size()); - lopts.push_back("-lto"); - lopts.push_back(archs.c_str()); - for (auto const& opt : linktime_extra_options) { - lopts.push_back(opt.c_str()); - } - auto result = nvJitLinkCreate(&handle, static_cast(lopts.size()), lopts.data()); - check_nvjitlink_result(handle, result); - - for (const auto& frag : this->fragments) { - frag->add_to(handle); + auto launcher = try_get_launcher(); + if (!launcher) { + RAFT_FAIL("Failed to build launcher for kernel entrypoint: %s", this->entrypoint.c_str()); } - - // Call to nvJitLinkComplete causes linker to link together all the LTO-IR - // modules perform any optimizations and generate cubin from it. - result = nvJitLinkComplete(handle); - check_nvjitlink_result(handle, result); - - // get cubin from nvJitLink - size_t cubin_size; - result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size); - check_nvjitlink_result(handle, result); - - std::unique_ptr cubin{new char[cubin_size]}; - result = nvJitLinkGetLinkedCubin(handle, cubin.get()); - check_nvjitlink_result(handle, result); - - result = nvJitLinkDestroy(&handle); - RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed"); - - // cubin is linked, so now load it - cudaLibrary_t library; - RAFT_CUDA_TRY( - cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0)); - - cudaKernel_t kernel; - RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, this->entrypoint.c_str())); - - return std::make_shared(kernel, library); + return launcher; } diff --git a/cpp/src/detail/jit_lto/LTOAlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/LTOAlgorithmPlanner.cpp new file mode 100644 index 0000000000..da7c0408b4 --- /dev/null +++ b/cpp/src/detail/jit_lto/LTOAlgorithmPlanner.cpp @@ -0,0 +1,76 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include +#include + +#include +#include + +#include "cuda_runtime.h" +#include "nvJitLink.h" + +#include + +std::string LTOAlgorithmPlanner::get_planner_key() const +{ + std::string key; + for (const auto& fragment : this->fragments) { + key += fragment->get_key(); + } + return key; +} + +std::shared_ptr LTOAlgorithmPlanner::build() +{ + int device = 0; + int major = 0; + int minor = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor)); + + nvJitLinkHandle handle; + std::vector lopts; + lopts.reserve(2 + linktime_extra_options.size()); + lopts.push_back("-lto"); + lopts.push_back(archs.c_str()); + for (auto const& opt : linktime_extra_options) { + lopts.push_back(opt.c_str()); + } + auto result = nvJitLinkCreate(&handle, static_cast(lopts.size()), lopts.data()); + check_nvjitlink_result(handle, result); + + for (const auto& frag : this->fragments) { + frag->add_to(handle); + } + + result = nvJitLinkComplete(handle); + check_nvjitlink_result(handle, result); + + size_t cubin_size; + result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size); + check_nvjitlink_result(handle, result); + + std::unique_ptr cubin{new char[cubin_size]}; + result = nvJitLinkGetLinkedCubin(handle, cubin.get()); + check_nvjitlink_result(handle, result); + + result = nvJitLinkDestroy(&handle); + RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed"); + + cudaLibrary_t library; + RAFT_CUDA_TRY( + cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + + cudaKernel_t kernel; + RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, this->entrypoint.c_str())); + + return std::make_shared(kernel, library); +} diff --git a/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp new file mode 100644 index 0000000000..1487abb239 --- /dev/null +++ b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp @@ -0,0 +1,82 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include +#include + +#include +#include + +namespace { + +template +CutileTileConfig tile_config_from_fragment(const FragmentT* fragment, const std::string& entrypoint) +{ + if (fragment == nullptr) { + RAFT_FAIL("cuTile planner '%s' has no registered fragments", entrypoint.c_str()); + } + const int tile_m = fragment->get_tile_m(); + const int tile_n = fragment->get_tile_n(); + const int tile_k = fragment->get_tile_k(); + if (tile_m <= 0 || tile_n <= 0 || tile_k <= 0) { + RAFT_FAIL( + "cuTile planner '%s' is missing tile geometry in its static fragment (check " + "register_cutile_fragment.cpp generation)", + entrypoint.c_str()); + } + return CutileTileConfig{tile_m, tile_n, tile_k}; +} + +} // namespace + +std::string TileAlgorithmPlanner::get_planner_key() const +{ + std::string key = this->entrypoint; + for (const auto& fragment : cubin_fragments_) { + key += fragment->get_key(); + } + if (tileir_fragment_) { key += tileir_fragment_->get_key(); } + return key; +} + +CutileTileConfig TileAlgorithmPlanner::tile_config() const +{ + int cc_major = 0; + int cc_minor = 0; + if (cuvs::detail::jit_lto::get_device_compute_capability(cc_major, cc_minor)) { + for (const auto& fragment : cubin_fragments_) { + if (fragment->get_cc_major() == cc_major && fragment->get_cc_minor() == cc_minor) { + return tile_config_from_fragment(fragment.get(), entrypoint); + } + } + } + + if (tileir_fragment_) { return tile_config_from_fragment(tileir_fragment_.get(), entrypoint); } + + if (!cubin_fragments_.empty()) { + return tile_config_from_fragment(cubin_fragments_.front().get(), entrypoint); + } + + RAFT_FAIL("cuTile planner '%s' has no registered fragments", entrypoint.c_str()); +} + +std::shared_ptr TileAlgorithmPlanner::build() +{ + int cc_major = 0; + int cc_minor = 0; + if (!cuvs::detail::jit_lto::get_device_compute_capability(cc_major, cc_minor)) { return nullptr; } + + int driver_version = 0; + if (cudaDriverGetVersion(&driver_version) != cudaSuccess) { return nullptr; } + + auto image = cuvs::detail::jit_lto::resolve_cutile_module_image( + cc_major, cc_minor, driver_version, cubin_fragments_, tileir_fragment_.get()); + if (!image) { return nullptr; } + + return cuvs::detail::jit_lto::load_cutile_launcher(*image, this->entrypoint); +} diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh index f9dbd968ec..a2ec5422dd 100644 --- a/cpp/src/distance/detail/fused_distance_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -1,11 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include "distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op +#include "fused_distance_nn/cutile/fused_1nn_tile.hpp" #include "fused_distance_nn/cutlass_base.cuh" #include "fused_distance_nn/fused_cosine_nn.cuh" #include "fused_distance_nn/fused_l2_nn.cuh" @@ -13,6 +14,7 @@ #include "fused_distance_nn/simt_kernel.cuh" #include "pairwise_distance_base.cuh" // PairwiseDistances #include +#include #include // raft::KeyValuePair #include // raft::identity_op #include // Policy @@ -27,13 +29,9 @@ namespace distance { namespace detail { -template -void fusedDistanceNNImpl(OutT* min, +template +void fusedDistanceNNImpl(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -49,36 +47,77 @@ void fusedDistanceNNImpl(OutT* min, bool isRowMajor, cuvs::distance::DistanceType metric, float metric_arg, + raft::KeyValuePair* cutlass_kvp_scratch, cudaStream_t stream) { - // The kernel policy is determined by fusedDistanceNN. typedef Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); + typedef raft::KeyValuePair KVP; constexpr auto maxVal = std::numeric_limits::max(); - typedef raft::KeyValuePair KVPair; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if constexpr (is_fused_1nn_cutile_data_v) { + if constexpr (cuvs::detail::jit_lto::library_built_with_cutile()) { + if (try_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, metric, sqrt, stream)) { + return; + } + } + } + + RAFT_EXPECTS(cutlass_kvp_scratch != nullptr, "CUTLASS fused 1-NN requires a scratch KVP buffer"); + if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); + initFused1nnOutput(nearest_idx, nearest_dist, m, std::numeric_limits::max(), stream); } + MinAndDistanceReduceOpImpl cutlass_redOp; + cutlass_redOp.out_kvp = cutlass_kvp_scratch; + initialize( + cutlass_kvp_scratch, m, maxVal, cutlass_redOp, stream); + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + switch (metric) { case cuvs::distance::DistanceType::CosineExpanded: - fusedCosineNN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + fusedCosineNN(nearest_idx, + nearest_dist, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + cutlass_redOp, + pairRedOp, + sqrt, + cutlass_kvp_scratch, + stream); break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: - // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. - fusedL2NNImpl( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + fusedL2NNImpl(nearest_idx, + nearest_dist, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + cutlass_redOp, + pairRedOp, + sqrt, + false, + cutlass_kvp_scratch, + stream); break; + case cuvs::distance::DistanceType::InnerProduct: break; default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; } + + unpackFused1nnKvpToSoa(nearest_idx, nearest_dist, cutlass_kvp_scratch, m, stream); } } // namespace detail diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py new file mode 100644 index 0000000000..1211456c9e --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +"""Export fused 1-NN cuTile kernels to cubin or TileIR bytecode.""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Literal + +import cuda.tile as ct +from cuda.tile.compilation import ( + ArrayConstraint, + CallingConvention, + ConstantConstraint, + KernelSignature, + ScalarConstraint, + export_kernel, +) + +from fused_1nn_kernel import ( + INDEX_TYPES, + METRICS, + index_abbrev, + kernel_symbol, + make_kernel, + metric_abbrev, +) + +DEFAULT_TILEIR_BYTECODE_VERSION = "13.1" +# cuTile requires a gpu_code even for TileIR bytecode export: it selects the compilation +# target / feature set for lowering, not the runtime architecture (the driver JITs at load). +DEFAULT_TILEIR_EXPORT_GPU_CODE = "sm_80" + + +def _dtype_for(data_type: str): + if data_type == "half": + return ct.float16 + if data_type == "float": + return ct.float32 + raise ValueError(f"Unsupported data_type {data_type!r}") + + +def _data_abbrev(data_type: str) -> str: + return {"half": "h", "float": "f"}[data_type] + + +def _elem_stride_divisible_for_tma(elem_dtype) -> tuple[int, int]: + """Row stride (dim 0) divisible enough for 16-byte TMA access; last dim stride 1.""" + bytes_per_elem = 2 if elem_dtype == ct.float16 else 4 + return (16 // bytes_per_elem, 1) + + +def _cuvs_matrix_constraint(elem_dtype): + """Row-major device matrices for cuVS KMeans benchmarks. + + Assumes raft/cupy-style contiguous layout: stride[-1]==1, stride[0]==D, + 16-byte base alignment, and row pitch 16-byte aligned (float32 D%4==0, + float16 D%8==0). Applies to both points and centroids matrices. + + shape_divisible_by is (1, 1); tail tiles are masked in the kernel. + Odd D or general layouts need a separate relaxed export profile. + """ + return ArrayConstraint( + elem_dtype, + ndim=2, + index_dtype=ct.int32, + stride_lower_bound_incl=(0, None), + alias_groups=(), + may_alias_internally=False, + stride_constant=(None, 1), + stride_divisible_by=_elem_stride_divisible_for_tma(elem_dtype), + shape_divisible_by=(1, 1), + base_addr_divisible_by=16, + ) + + +def _cuvs_vector_constraint(elem_dtype): + """1-D device vectors: contiguous, 16-byte base. Length need not be divisible by 16.""" + return ArrayConstraint( + elem_dtype, + ndim=1, + index_dtype=ct.int32, + stride_lower_bound_incl=(None,), + alias_groups=(), + may_alias_internally=False, + stride_constant=(1,), + stride_divisible_by=(1,), + shape_divisible_by=(1,), + base_addr_divisible_by=16, + ) + + +def _relaxed_matrix_constraint(elem_dtype): + """Deprecated alias; use _cuvs_matrix_constraint.""" + return _cuvs_matrix_constraint(elem_dtype) + + +def _relaxed_vector_constraint(elem_dtype, *, tma_friendly: bool = False): + """Deprecated alias; use _cuvs_vector_constraint.""" + del tma_friendly + return _cuvs_vector_constraint(elem_dtype) + + +def _kernel_signature( + data_type: str, + metric: str, + index_type: str, + tile_m: int, + tile_n: int, + tile_k: int, +) -> KernelSignature: + elem = _dtype_for(data_type) + matrix = _cuvs_matrix_constraint(elem) + norm_array = _cuvs_vector_constraint(elem) + idx_elem = ct.int32 if index_type == "int32" else ct.int64 + idx_array = _cuvs_vector_constraint(idx_elem) + dist_array = _cuvs_vector_constraint(elem) + + abbrev = _data_abbrev(data_type) + symbol = kernel_symbol( + abbrev, metric_abbrev(metric), index_abbrev(index_type) + ) + + return KernelSignature( + parameters=[ + matrix, + matrix, + norm_array, + norm_array, + idx_array, + dist_array, + ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), + ConstantConstraint(tile_m), + ConstantConstraint(tile_n), + ConstantConstraint(tile_k), + ], + calling_convention=CallingConvention.cutile_python_v1(), + ).with_symbol(symbol) + + +def export_binary( + output_file: Path, + *, + output_format: Literal["cubin", "tileir_bytecode"], + data_type: str, + metric: str, + index_type: str, + tile_m: int, + tile_n: int, + tile_k: int, + gpu_code: str, + bytecode_version: str | None = None, +) -> str: + kernel = make_kernel( + data_type, metric, tile_m, tile_n, tile_k, index_type=index_type + ) + signature = _kernel_signature( + data_type, metric, index_type, tile_m, tile_n, tile_k + ) + + export_kwargs = { + "kernel": kernel, + "signatures": [signature], + "output_file": str(output_file), + "gpu_code": gpu_code, + "output_format": output_format, + } + if output_format == "tileir_bytecode": + export_kwargs["bytecode_version"] = ( + bytecode_version or DEFAULT_TILEIR_BYTECODE_VERSION + ) + + export_kernel(**export_kwargs) + + return signature.symbol + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("output_file", type=Path) + parser.add_argument( + "--format", choices=("cubin", "tileir_bytecode"), default="cubin" + ) + parser.add_argument( + "--data-type", choices=("half", "float"), required=True + ) + parser.add_argument("--metric", choices=METRICS, required=True) + parser.add_argument("--index-type", choices=INDEX_TYPES, required=True) + parser.add_argument("--tile-m", type=int, required=True) + parser.add_argument("--tile-n", type=int, required=True) + parser.add_argument("--tile-k", type=int, required=True) + parser.add_argument( + "--gpu-code", + default=DEFAULT_TILEIR_EXPORT_GPU_CODE, + help="Target SM for cubin export, or compile hint for TileIR bytecode export", + ) + parser.add_argument( + "--bytecode-version", default=DEFAULT_TILEIR_BYTECODE_VERSION + ) + args = parser.parse_args() + + print( + export_binary( + args.output_file, + output_format=args.format, + data_type=args.data_type, + metric=args.metric, + index_type=args.index_type, + tile_m=args.tile_m, + tile_n=args.tile_n, + tile_k=args.tile_k, + gpu_code=args.gpu_code, + bytecode_version=args.bytecode_version, + ) + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json new file mode 100644 index 0000000000..7d9b723b39 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json @@ -0,0 +1,95 @@ +[ + { + "_data": [ + { + "data_type": "half", + "data_abbrev": "h" + }, + { + "data_type": "float", + "data_abbrev": "f" + } + ], + "_metric": [ + { + "metric": "inner_product", + "metric_abbrev": "ip" + }, + { + "metric": "l2_expanded", + "metric_abbrev": "l2" + }, + { + "metric": "cosine_expanded", + "metric_abbrev": "cos" + } + ], + "_index": [ + { + "index_type": "int32", + "index_abbrev": "i32" + }, + { + "index_type": "int64", + "index_abbrev": "i64" + } + ], + "_tile": [ + { + "tile_m": 256, + "tile_n": 64, + "tile_k": 32 + } + ], + "_export": [ + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_80", + "cc_major": 8, + "cc_minor": 0, + "arch_tag": "cutile_arch_8_0" + }, + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_86", + "cc_major": 8, + "cc_minor": 6, + "arch_tag": "cutile_arch_8_6" + }, + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_90", + "cc_major": 9, + "cc_minor": 0, + "arch_tag": "cutile_arch_9_0" + }, + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_120", + "cc_major": 12, + "cc_minor": 0, + "arch_tag": "cutile_arch_12_0" + }, + { + "output_format": "tileir_bytecode", + "artifact_ext": "tilebc", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@", + "register": "tileir", + "gpu_code": "sm_80", + "bytecode_version": "13.1" + } + ] + } +] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py new file mode 100644 index 0000000000..b2ff25555b --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +"""cuTile fused GEMM + 1-NN kernels (InnerProduct, L2Expanded, CosineExpanded).""" + +from __future__ import annotations + +import cuda.tile as ct + +ConstInt = ct.Constant[int] + +# Default tile geometry; overridden per export via make_kernel(..., tile_m, tile_n, tile_k). +DEFAULT_TILE_M = 256 +DEFAULT_TILE_N = 64 +DEFAULT_TILE_K = 32 + +METRICS = ("inner_product", "l2_expanded", "cosine_expanded") +INDEX_TYPES = ("int32", "int64") + + +def _idx_dtype(index_type: str): + if index_type == "int32": + return ct.int32 + if index_type == "int64": + return ct.int64 + raise ValueError(f"Unsupported index_type {index_type!r}") + + +def make_kernel( + data_type: str, + metric: str, + tile_m: int = DEFAULT_TILE_M, + tile_n: int = DEFAULT_TILE_N, + tile_k: int = DEFAULT_TILE_K, + *, + index_type: str = "int32", +): + """Build a cuTile kernel with metric, index width, and tile sizes baked in at compile time.""" + if data_type not in ("half", "float"): + raise ValueError(f"Unsupported data_type {data_type!r}") + if metric not in METRICS: + raise ValueError(f"Unsupported metric {metric!r}") + if index_type not in INDEX_TYPES: + raise ValueError(f"Unsupported index_type {index_type!r}") + + acc_dtype = ct.float32 + idx_dtype = _idx_dtype(index_type) + out_dist_dtype = ct.float16 if data_type == "half" else ct.float32 + is_ip = metric == "inner_product" + is_l2 = metric == "l2_expanded" + is_cos = metric == "cosine_expanded" + + @ct.kernel + def fused_1nn_kernel( + A, + B, + A_norm, + B_norm, + OutIdx, + OutDist, + M, + N, + K, + apply_sqrt, + store_idx, + tm: ConstInt, + tn: ConstInt, + tk: ConstInt, + ): + bidm = ct.bid(0) + + if is_ip: + best_dist = ct.full((tm,), -3.4e38, acc_dtype) + else: + best_dist = ct.full((tm,), 3.4e38, acc_dtype) + best_idx = ct.zeros((tm,), idx_dtype) + + num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) + num_tiles_n = ct.num_tiles(B, axis=0, shape=(tn, tk)) + zero_pad = ct.PaddingMode.ZERO + + for n in range(num_tiles_n): + accumulator = ct.full((tm, tn), 0, dtype=acc_dtype) + + for k in range(num_tiles_k): + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + + a = ct.load( + A, index=(bidm, k), shape=(tm, tk), padding_mode=zero_pad + ).astype(dtype) + b_T = ct.load( + B, index=(n, k), shape=(tn, tk), padding_mode=zero_pad + ).astype(dtype) + + accumulator = ct.mma(a, ct.transpose(b_T), accumulator) + + if is_ip: + score = accumulator + elif is_l2 or is_cos: + a_norm = ct.load( + A_norm, index=(bidm,), shape=(tm,), padding_mode=zero_pad + ) + b_norm = ct.load( + B_norm, index=(n,), shape=(tn,), padding_mode=zero_pad + ) + if is_l2: + # L2 expanded: ||x||^2 + ||y||^2 - 2 * dot(x, y); norms are squared. + score = ( + a_norm[:, None] + b_norm[None, :] - (2.0 * accumulator) + ) + elif is_cos: + # Cosine expanded distance: 1 - dot / (||x|| * ||y||); norms are L2 (not squared). + denom = a_norm[:, None] * b_norm[None, :] + score = 1.0 - (accumulator / denom) + + # Only the final N-tile can include zero-padded centroid columns. + if n == num_tiles_n - 1: + col = ct.arange(tn, dtype=idx_dtype) + global_col = (n * tn + col).astype(idx_dtype) + valid = global_col < N + if is_ip: + score = ct.where(valid[None, :], score, -3.4e38) + else: + score = ct.where(valid[None, :], score, 3.4e38) + + if is_ip: + curr_best = ct.max(score, axis=1) + curr_idx = ct.argmax(score, axis=1) + update = curr_best > best_dist + best_dist = ct.where(update, curr_best, best_dist) + else: + curr_best = ct.min(score, axis=1) + curr_idx = ct.argmin(score, axis=1) + update = curr_best < best_dist + best_dist = ct.where(update, curr_best, best_dist) + + best_idx = ct.where( + update, (n * tn + curr_idx).astype(idx_dtype), best_idx + ) + + out_dist = best_dist + if is_l2: + out_dist = ct.where(apply_sqrt != 0, ct.sqrt(best_dist), best_dist) + if store_idx != 0: + ct.store(OutIdx, index=(bidm,), tile=best_idx) + ct.store(OutDist, index=(bidm,), tile=out_dist.astype(out_dist_dtype)) + + return fused_1nn_kernel + + +def kernel_symbol( + data_abbrev: str, metric_abbrev: str, index_abbrev: str +) -> str: + """Must stay in sync with fused_1nn_kernel_entrypoint() in fused_1nn_planner.hpp.""" + return f"fused_1nn_{data_abbrev}_{metric_abbrev}_{index_abbrev}" + + +def metric_abbrev(metric: str) -> str: + return { + "inner_product": "ip", + "l2_expanded": "l2", + "cosine_expanded": "cos", + }[metric] + + +def index_abbrev(index_type: str) -> str: + return {"int32": "i32", "int64": "i64"}[index_type] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp new file mode 100644 index 0000000000..017fb72d48 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include +#include +#include +#include + +#include "fused_1nn_cutile_tiles.hpp" + +namespace cuvs::distance::detail { + +/** Must match kernel_symbol() in fused_1nn_kernel.py (export uses with_symbol). */ +template +inline const char* fused_1nn_kernel_entrypoint() +{ + constexpr bool is_i32 = std::is_same_v; + constexpr bool is_i64 = std::is_same_v; + static_assert(is_i32 || is_i64, "unsupported fused 1-NN cuTile index width"); + + if constexpr (std::is_same_v && + std::is_same_v) { + return is_i32 ? "fused_1nn_f_ip_i32" : "fused_1nn_f_ip_i64"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return is_i32 ? "fused_1nn_f_l2_i32" : "fused_1nn_f_l2_i64"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return is_i32 ? "fused_1nn_f_cos_i32" : "fused_1nn_f_cos_i64"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return is_i32 ? "fused_1nn_h_ip_i32" : "fused_1nn_h_ip_i64"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return is_i32 ? "fused_1nn_h_l2_i32" : "fused_1nn_h_l2_i64"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return is_i32 ? "fused_1nn_h_cos_i32" : "fused_1nn_h_cos_i64"; + } else { + static_assert(sizeof(DataTag) == 0, "unsupported fused 1-NN cuTile data/metric combination"); + return ""; + } +} + +template +struct Fused1nnTilePlanner : TileAlgorithmPlanner { + using DataTag = fused_1nn_data_tag_t; + using MetricTag = fused_1nn_metric_tag_t; + using IndexTag = fused_1nn_index_tag_t; + + inline static LauncherJitCache launcher_jit_cache{}; + + Fused1nnTilePlanner() + : TileAlgorithmPlanner(fused_1nn_kernel_entrypoint(), + launcher_jit_cache) + { + } + + /** Registers embedded cubin modules (one per SM); see register_cutile_fragment.cpp object files. + */ + void add_entrypoint() + { + using cuvs::detail::jit_lto::cutile_arch_12_0; + using cuvs::detail::jit_lto::cutile_arch_8_0; + using cuvs::detail::jit_lto::cutile_arch_8_6; + using cuvs::detail::jit_lto::cutile_arch_9_0; + + this->add_static_fragment>(); + this->add_static_fragment>(); + this->add_static_fragment>(); + this->add_static_fragment>(); + } + + void add_tileir_fallback() + { + this->add_static_tileir_fragment< + fragment_tag_fused_1nn_tileir>(); + } +}; + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu new file mode 100644 index 0000000000..cf01c12ce1 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu @@ -0,0 +1,214 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "fused_1nn_tile.hpp" + +#include "fused_1nn_planner.hpp" + +#include +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { + +namespace { + +template +bool launch_fused_1nn_tile(IdxT* nearest_idx, + DataT* nearest_dist, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + bool is_sqrt, + cudaStream_t stream) +{ + if constexpr (!std::is_same_v && !std::is_same_v) { return false; } + + if (nearest_dist == nullptr) { return false; } + + Fused1nnTilePlanner planner; + planner.add_entrypoint(); + planner.add_tileir_fallback(); + const CutileTileConfig tile_cfg = planner.tile_config(); + auto launcher = planner.try_get_launcher(); + if (!launcher) { return false; } + + const bool apply_sqrt = fused_1nn_apply_sqrt_at_pack(is_sqrt); + + int64_t shape_x[2] = {m, k}; + int64_t stride_x[2] = {k, 1}; + int64_t shape_y[2] = {n, k}; + int64_t stride_y[2] = {k, 1}; + int64_t shape_xn = m; + int64_t stride_xn = 1; + int64_t shape_yn = n; + int64_t stride_yn = 1; + int64_t shape_idx = m; + int64_t stride_idx = 1; + int64_t shape_dist = m; + int64_t stride_dist = 1; + + int64_t M = m, N = n, K = k; + + void* x_ptr = const_cast(x); + void* y_ptr = const_cast(y); + void* xn_ptr = const_cast(xn); + void* yn_ptr = const_cast(yn); + // OutIdx must be a valid device pointer for the launch ABI; when store_idx is 0 the kernel + // does not write it (dist-only callers pass nearest_dist as a stand-in). + const int64_t store_idx = nearest_idx != nullptr ? 1 : 0; + void* idx_ptr = + nearest_idx != nullptr ? static_cast(nearest_idx) : static_cast(nearest_dist); + void* dist_ptr = nearest_dist; + + const int64_t tile_m = tile_cfg.tile_m; + dim3 grid((m + tile_m - 1) / tile_m, 1, 1); + dim3 block(1, 1, 1); + + using fused_1nn_cutile_kernel_t = void(void*, + int64_t, + int64_t, + int64_t, + int64_t, + void*, + int64_t, + int64_t, + int64_t, + int64_t, + void*, + int64_t, + int64_t, + void*, + int64_t, + int64_t, + void*, + int64_t, + int64_t, + void*, + int64_t, + int64_t, + int64_t, + int64_t, + int64_t, + int64_t, + int64_t); + launcher->template dispatch(stream, + grid, + block, + 0, + x_ptr, + shape_x[0], + shape_x[1], + stride_x[0], + stride_x[1], + y_ptr, + shape_y[0], + shape_y[1], + stride_y[0], + stride_y[1], + xn_ptr, + shape_xn, + stride_xn, + yn_ptr, + shape_yn, + stride_yn, + idx_ptr, + shape_idx, + stride_idx, + dist_ptr, + shape_dist, + stride_dist, + M, + N, + K, + static_cast(apply_sqrt), + store_idx); + RAFT_CUDA_TRY(cudaGetLastError()); + return true; +} + +template +bool try_fused_1nn_tile_dispatch(IdxT* nearest_idx, + DataT* nearest_dist, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + bool is_sqrt, + cudaStream_t stream) +{ + switch (metric) { + case cuvs::distance::DistanceType::InnerProduct: + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); + case cuvs::distance::DistanceType::L2Expanded: + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); + case cuvs::distance::DistanceType::L2SqrtExpanded: + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); + case cuvs::distance::DistanceType::CosineExpanded: + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); + default: return false; + } +} + +} // namespace + +template + requires is_fused_1nn_cutile_data_v +bool try_fused_1nn_tile(IdxT* nearest_idx, + DataT* nearest_dist, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + bool is_sqrt, + cudaStream_t stream) +{ + if (!cuvs::detail::jit_lto::cutile_launch_available_on_current_device()) { return false; } + return try_fused_1nn_tile_dispatch( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, metric, is_sqrt, stream); +} + +#define CUVS_INST_TRY_FUSED_1NN_TILE(DataT, IdxT) \ + template CUVS_EXPORT bool try_fused_1nn_tile(IdxT*, \ + DataT*, \ + const DataT*, \ + const DataT*, \ + const DataT*, \ + const DataT*, \ + IdxT, \ + IdxT, \ + IdxT, \ + cuvs::distance::DistanceType, \ + bool, \ + cudaStream_t) + +CUVS_INST_TRY_FUSED_1NN_TILE(float, int); +CUVS_INST_TRY_FUSED_1NN_TILE(float, int64_t); +CUVS_INST_TRY_FUSED_1NN_TILE(half, int); +CUVS_INST_TRY_FUSED_1NN_TILE(half, int64_t); + +#undef CUVS_INST_TRY_FUSED_1NN_TILE + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp new file mode 100644 index 0000000000..4c0964631c --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include + +#include +#include + +#ifndef CUVS_CUTILE_ENABLED +#define CUVS_CUTILE_ENABLED 0 +#endif + +namespace cuvs { +namespace distance { +namespace detail { + +template +inline constexpr bool is_fused_1nn_cutile_data_v = + std::is_same_v || std::is_same_v; + +#if CUVS_CUTILE_ENABLED +template + requires is_fused_1nn_cutile_data_v +bool try_fused_1nn_tile(IdxT* nearest_idx, + DataT* nearest_dist, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + bool is_sqrt, + cudaStream_t stream); +#else +template +bool try_fused_1nn_tile(IdxT*, + DataT*, + const DataT*, + const DataT*, + const DataT*, + const DataT*, + IdxT, + IdxT, + IdxT, + cuvs::distance::DistanceType, + bool, + cudaStream_t) +{ + return false; +} +#endif + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh index 12f4f17cac..cc16d8a2e1 100644 --- a/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -1,11 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "../distance_ops/cosine.cuh" // ops::l2_exp_distance_op +#include "../distance_ops/cosine.cuh" // ops::cosine_distance_op #include "../pairwise_distance_base.cuh" // PairwiseDistances #include "cutlass_base.cuh" #include "helper_structs.cuh" @@ -24,13 +24,9 @@ namespace distance { namespace detail { -template -void fusedCosineNN(OutT* min, +template +void fusedCosineNN(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -42,15 +38,20 @@ void fusedCosineNN(OutT* min, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + raft::KeyValuePair* cutlass_out, cudaStream_t stream) { - // The kernel policy is determined by fusedL2NN. typedef Policy P; dim3 blk(P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); typedef raft::KeyValuePair KVPair; + if (cutlass_out == nullptr) { + initFused1nnOutput(nearest_idx, nearest_dist, m, maxVal, stream); + RAFT_CUDA_TRY(cudaGetLastError()); + } + namespace arch = raft::util::arch; using AccT = DataT; ops::cosine_distance_op distance_op{}; @@ -58,7 +59,7 @@ void fusedCosineNN(OutT* min, raft::identity_op fin_op{}; auto kernel = fusedDistanceNNkernel; - // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 void* kernel_ptr = reinterpret_cast(kernel); auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. using cosineOp = cuvs::distance::detail::ops::cosine_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; kvp_cg_min_reduce_op_ cg_reduce_op; cosineOp cosine_dist_op; @@ -86,7 +82,7 @@ void fusedCosineNN(OutT* min, cutlassFusedDistanceNN(m, n, shmemSize, kernel); kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + cutlass_out, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } } diff --git a/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh index f1aad72110..8c532e2932 100644 --- a/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -24,13 +24,9 @@ namespace distance { namespace detail { -template -void fusedL2NNImpl(OutT* min, +template +void fusedL2NNImpl(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -43,19 +39,17 @@ void fusedL2NNImpl(OutT* min, KVPReduceOpT pairRedOp, bool sqrt, bool initOutBuffer, + raft::KeyValuePair* cutlass_out, cudaStream_t stream) { - // The kernel policy is determined by fusedL2NN. typedef Policy P; dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); typedef raft::KeyValuePair KVPair; - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); + if (initOutBuffer && cutlass_out == nullptr) { + initFused1nnOutput(nearest_idx, nearest_dist, m, maxVal, stream); RAFT_CUDA_TRY(cudaGetLastError()); } @@ -66,7 +60,7 @@ void fusedL2NNImpl(OutT* min, raft::identity_op fin_op{}; auto kernel = fusedDistanceNNkernel; - // Get pointer to fp32 SIMT kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 void* kernel_ptr = reinterpret_cast(kernel); auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. using L2Op = cuvs::distance::detail::ops::l2_exp_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; kvp_cg_min_reduce_op_ cg_reduce_op; L2Op L2_dist_op(sqrt); @@ -95,7 +83,7 @@ void fusedL2NNImpl(OutT* min, cutlassFusedDistanceNN(m, n, shmemSize, kernel); kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + cutlass_out, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } } diff --git a/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh index 762c720568..3bd78ba5ab 100644 --- a/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -32,20 +32,43 @@ struct KVPMinReduceImpl { }; // KVPMinReduce +/** Writes fused 1-NN results to separate idx/dist arrays (dist may be null). */ template struct MinAndDistanceReduceOpImpl { typedef typename raft::KeyValuePair KVP; + LabelT* out_idx{nullptr}; + DataT* out_dist{nullptr}; + /** When set, CUTLASS/SIMT global merge writes here instead of SoA (caller unpacks). */ + KVP* out_kvp{nullptr}; + + DI void merge(LabelT rid, const KVP& other) const + { + if (out_kvp != nullptr) { + if (other.value < out_kvp[rid].value) { out_kvp[rid] = other; } + } else if (out_dist != nullptr) { + if (other.value < out_dist[rid]) { + out_dist[rid] = other.value; + if (out_idx != nullptr) { out_idx[rid] = other.key; } + } + } else if (out_idx != nullptr) { + // Idx-only output: dist must still be tracked for multi-tile merge; caller must provide + // out_dist or use a single-pass backend (cuTile). KMeans always passes both buffers. + out_idx[rid] = other.key; + } + } + DI void operator()(LabelT rid, KVP* out, const KVP& other) const { - if (other.value < out->value) { + if (out != nullptr && other.value < out->value) { out->key = other.key; out->value = other.value; } } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const { - if (other.value < out->value) { + if (out != nullptr && other.value < out->value) { out->key = other.key; out->value = other.value; } @@ -53,35 +76,41 @@ struct MinAndDistanceReduceOpImpl { DI void operator()(LabelT rid, DataT* out, const KVP& other) const { - if (other.value < *out) { *out = other.value; } + if (out != nullptr && other.value < *out) { *out = other.value; } } DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const { - if (other.value < *out) { *out = other.value; } + if (out != nullptr && other.value < *out) { *out = other.value; } } DI void operator()(LabelT rid, DataT* out, const DataT& other) const { - if (other < *out) { *out = other; } + if (out != nullptr && other < *out) { *out = other; } } DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const { - if (other < *out) { *out = other; } + if (out != nullptr && other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const + { + if (out != nullptr) { *out = maxVal; } } - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; - out->key = 0xfffffff0; + out->key = LabelT(0); } - DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(DataT& /*out*/, LabelT /*idx*/) const {} + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } }; @@ -96,6 +125,53 @@ struct MinReduceOpImpl { DI void init(DataT* out, DataT maxVal) { *out = maxVal; } }; +template +RAFT_KERNEL initFused1nnOutputKernel(IdxT* nearest_idx, DataT* nearest_dist, IdxT m, DataT maxVal) +{ + IdxT tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { + if (nearest_idx != nullptr) { nearest_idx[tid] = IdxT(0); } + if (nearest_dist != nullptr) { nearest_dist[tid] = maxVal; } + } +} + +template +void initFused1nnOutput( + IdxT* nearest_idx, DataT* nearest_dist, IdxT m, DataT maxVal, cudaStream_t stream) +{ + if (nearest_idx == nullptr && nearest_dist == nullptr) { return; } + auto blks = raft::ceildiv(m, 256); + initFused1nnOutputKernel + <<>>(nearest_idx, nearest_dist, m, maxVal); +} + +template +RAFT_KERNEL unpackFused1nnKvpToSoaKernel(IdxT* nearest_idx, + DataT* nearest_dist, + const raft::KeyValuePair* kvp, + IdxT n) +{ + IdxT i = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (i < n) { + if (nearest_idx != nullptr) { nearest_idx[i] = kvp[i].key; } + if (nearest_dist != nullptr) { nearest_dist[i] = kvp[i].value; } + } +} + +template +void unpackFused1nnKvpToSoa(IdxT* nearest_idx, + DataT* nearest_dist, + const raft::KeyValuePair* kvp, + IdxT m, + cudaStream_t stream) +{ + if (nearest_idx == nullptr && nearest_dist == nullptr) { return; } + auto blks = raft::ceildiv(m, 256); + unpackFused1nnKvpToSoaKernel + <<>>(nearest_idx, nearest_dist, kvp, m); + RAFT_CUDA_TRY(cudaGetLastError()); +} + template RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { @@ -106,15 +182,13 @@ RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) template void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) { - auto blks = raft::ceildiv(m, 256); - initKernel<<>>(min, m, maxVal, redOp); + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); } // cg::reduce functor for FusedDistanceNN used in its cutlass version // to output the min distance value & key(loc id). -// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h -// store_with_byte_offset() passed to cg::reduce() & select_reduce. -template +template struct kvp_cg_min_reduce_op { typedef typename raft::KeyValuePair KVP; @@ -122,7 +196,6 @@ struct kvp_cg_min_reduce_op { using AccTypeT = AccType; using IndexT = Index; - // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } diff --git a/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index caa6a36d53..0d9f5333af 100644 --- a/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -437,10 +437,8 @@ class PredicatedTileIteratorReducedVec { __syncthreads(); if (row < total_rows) { - volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - if ((block_start_row_first_tile_ + row) < extent_row_) { - user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); + user_params.red_op_.merge(block_start_row_first_tile_ + row, row_local_min); } } diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp index 0d00b3eca6..f89a383596 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp @@ -20,7 +20,7 @@ template -struct PairwiseMatrixPlanner : AlgorithmPlanner { +struct PairwiseMatrixPlanner : LTOAlgorithmPlanner { using DistanceTag = DistanceTag_; using DataTag = DataTag_; using AccTag = AccTag_; @@ -33,7 +33,7 @@ struct PairwiseMatrixPlanner : AlgorithmPlanner { inline static LauncherJitCache launcher_jit_cache{}; - PairwiseMatrixPlanner() : AlgorithmPlanner(kPairwiseMatrixJitEntrypoint, launcher_jit_cache) {} + PairwiseMatrixPlanner() : LTOAlgorithmPlanner(kPairwiseMatrixJitEntrypoint, launcher_jit_cache) {} void add_entrypoint() { diff --git a/cpp/src/distance/fused_distance_nn-inl.cuh b/cpp/src/distance/fused_distance_nn-inl.cuh index 3fa80a9b60..13c4faa472 100644 --- a/cpp/src/distance/fused_distance_nn-inl.cuh +++ b/cpp/src/distance/fused_distance_nn-inl.cuh @@ -28,48 +28,10 @@ namespace distance { * \ingroup fused_l2_nn * @{ */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. - * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream - */ -template -void fusedDistanceNN(OutT* min, + +template +void fusedDistanceNN(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -85,12 +47,10 @@ void fusedDistanceNN(OutT* min, bool isRowMajor, cuvs::distance::DistanceType metric, float metric_arg, + raft::KeyValuePair* cutlass_kvp_scratch, cudaStream_t stream) { ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. bool is_skinny = k < 32; size_t bytes = sizeof(DataT) * k; @@ -100,10 +60,10 @@ void fusedDistanceNN(OutT* min, if (is_skinny) { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -119,14 +79,15 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } else { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -142,16 +103,17 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { if (is_skinny) { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -167,14 +129,15 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } else { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -190,15 +153,16 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } } else { if (is_skinny) { detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -214,13 +178,14 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } else { detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -236,44 +201,23 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } } } /** - * @brief Wrapper around fusedDistanceNN with minimum reduction operators. - * - * fusedDistanceNN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). + * @brief Fused GEMM + 1-NN minimum reduction. * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. - * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream + * @param[out] nearest_idx Nearest neighbor index per row, length `m` (required). + * @param[out] nearest_dist Minimum distance per row, length `m` (optional, may be null). + * @param[in] cutlass_kvp_scratch Temp KVP buffer, length `m`; required when CUTLASS/SIMT runs. + * Unused when cuTile handles the launch. */ -template -void fusedDistanceNNMinReduce(OutT* min, +template +void fusedDistanceNNMinReduce(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -287,28 +231,33 @@ void fusedDistanceNNMinReduce(OutT* min, bool isRowMajor, cuvs::distance::DistanceType metric, float metric_arg, + raft::KeyValuePair* cutlass_kvp_scratch, cudaStream_t stream) { MinAndDistanceReduceOp redOp; + redOp.out_idx = nearest_idx; + redOp.out_dist = nearest_dist; KVPMinReduce pairRedOp; - fusedDistanceNN(min, - x, - y, - xn, - yn, - m, - n, - k, - workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); + fusedDistanceNN(nearest_idx, + nearest_dist, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + cutlass_kvp_scratch, + stream); } /** @} */ diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 0c3ed64d13..b44a7f044e 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -25,7 +25,7 @@ template -struct CagraPlannerBase : AlgorithmPlanner { +struct CagraPlannerBase : LTOAlgorithmPlanner { using DataTag = DataTag_; using IndexTag = IndexTag_; using DistanceTag = DistanceTag_; @@ -34,7 +34,7 @@ struct CagraPlannerBase : AlgorithmPlanner { using SampleFilterJitTag = SampleFilterJitTag_; explicit CagraPlannerBase(std::string entrypoint, LauncherJitCache& jit_cache) - : AlgorithmPlanner(std::move(entrypoint), jit_cache) + : LTOAlgorithmPlanner(std::move(entrypoint), jit_cache) { } diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp index ed8191016b..7899d970ab 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp @@ -14,10 +14,10 @@ namespace cuvs::neighbors::ivf_flat::detail { -struct InterleavedScanPlanner : AlgorithmPlanner { +struct InterleavedScanPlanner : LTOAlgorithmPlanner { inline static LauncherJitCache launcher_jit_cache{}; - InterleavedScanPlanner() : AlgorithmPlanner("interleaved_scan", launcher_jit_cache) {} + InterleavedScanPlanner() : LTOAlgorithmPlanner("interleaved_scan", launcher_jit_cache) {} template void add_entrypoint() diff --git a/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp b/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp index 0621966cad..7152aaeebd 100644 --- a/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp +++ b/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp @@ -12,10 +12,10 @@ namespace cuvs::neighbors::ivf_pq::detail { -struct ComputeSimilarityPlanner : AlgorithmPlanner { +struct ComputeSimilarityPlanner : LTOAlgorithmPlanner { inline static LauncherJitCache launcher_jit_cache{}; - ComputeSimilarityPlanner() : AlgorithmPlanner("compute_similarity", launcher_jit_cache) {} + ComputeSimilarityPlanner() : LTOAlgorithmPlanner("compute_similarity", launcher_jit_cache) {} template void add_entrypoint() diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp index 05ea34532e..5dc47dc612 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp @@ -13,10 +13,10 @@ namespace cuvs::neighbors::ivf_sq::detail { -struct IvfSqScanPlanner : AlgorithmPlanner { +struct IvfSqScanPlanner : LTOAlgorithmPlanner { inline static LauncherJitCache launcher_jit_cache{}; - IvfSqScanPlanner() : AlgorithmPlanner("ivf_sq_scan", launcher_jit_cache) {} + IvfSqScanPlanner() : LTOAlgorithmPlanner("ivf_sq_scan", launcher_jit_cache) {} template void add_entrypoint() diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 9b96f94bf0..d4e0099035 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -120,7 +120,7 @@ ConfigureTest( ConfigureTest( NAME CLUSTER_TEST PATH cluster/kmeans.cu cluster/kmeans_balanced.cu cluster/kmeans_find_k.cu cluster/linkage.cu - cluster/connect_knn.cu cluster/spectral.cu + cluster/connect_knn.cu cluster/spectral.cu cluster/soa_unpack_trace.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/distance_nn.cu b/cpp/tests/neighbors/distance_nn.cu index f31f3ebacf..6b17fc646b 100644 --- a/cpp/tests/neighbors/distance_nn.cu +++ b/cpp/tests/neighbors/distance_nn.cu @@ -42,7 +42,7 @@ __global__ void fill_int8(int8_t* buff, int len, int seed_offset) template class NNTest : public ::testing::TestWithParam> { public: - using OutT = raft::KeyValuePair; + using RefOutT = raft::KeyValuePair; NNTest() : params_{::testing::TestWithParam>::GetParam()}, m{params_.m}, @@ -55,8 +55,10 @@ class NNTest : public ::testing::TestWithParam> { y{raft::make_device_matrix(handle, n, k)}, x_norm{raft::make_device_vector(handle, m)}, y_norm{raft::make_device_vector(handle, n)}, - out{raft::make_device_vector(handle, m)}, - ref_out{raft::make_device_vector(handle, m)} + out_idx{raft::make_device_vector(handle, m)}, + out_dist{raft::make_device_vector(handle, m)}, + out_kvp{raft::make_device_vector(handle, m)}, + ref_out{raft::make_device_vector(handle, m)} { } @@ -92,15 +94,11 @@ class NNTest : public ::testing::TestWithParam> { workspace_size = m * n * sizeof(AccT); } - // Reset buffer - if constexpr (std::is_same_v>) { - // OutT is a RAFT KeyValuePair - raft::matrix::fill( - handle, raft::make_device_matrix_view(out.data_handle(), m, 1), OutT{0, 0}); - } else { - // OutT is a scalar type - raft::matrix::fill(handle, raft::make_device_matrix_view(out.data_handle(), m, 1), OutT{0}); - } + raft::matrix::fill(handle, raft::make_device_matrix_view(out_idx.data_handle(), m, 1), IdxT{0}); + raft::matrix::fill( + handle, raft::make_device_matrix_view(out_dist.data_handle(), m, 1), AccT{0}); + raft::matrix::fill( + handle, raft::make_device_matrix_view(ref_out.data_handle(), m, 1), RefOutT{0, 0}); raft::resource::sync_stream(handle, stream); } @@ -109,34 +107,36 @@ class NNTest : public ::testing::TestWithParam> { raft::device_vector workspace = raft::make_device_vector(handle, workspace_size); - ref_nn( + ref_nn( ref_out.data_handle(), x.data_handle(), y.data_handle(), m, n, k, sqrt, metric, stream); if constexpr (impl == ImplType::fused) { if constexpr (std::is_same_v) { - cuvs::distance::fusedDistanceNNMinReduce(out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - m, - n, - k, - (void*)workspace.data_handle(), - sqrt, - true, - true, - metric, - 0.0, - stream); + cuvs::distance::fusedDistanceNNMinReduce(out_idx.data_handle(), + out_dist.data_handle(), + x.data_handle(), + y.data_handle(), + x_norm.data_handle(), + y_norm.data_handle(), + m, + n, + k, + (void*)workspace.data_handle(), + sqrt, + true, + true, + metric, + 0.0, + out_kvp.data_handle(), + stream); } else { static_assert(sizeof(DataT) == 0, "fusedDistanceNNMinReduce is not implemented for datatype other than float"); } } else if constexpr (impl == ImplType::unfused) { - cuvs::distance::unfusedDistanceNNMinReduce( + cuvs::distance::unfusedDistanceNNMinReduce( handle, - out.data_handle(), + out_kvp.data_handle(), x.data_handle(), y.data_handle(), x_norm.data_handle(), @@ -156,7 +156,12 @@ class NNTest : public ::testing::TestWithParam> { void compare() { - vector_compare(handle, ref_out.data_handle(), out.data_handle(), m, summary); + if constexpr (impl == ImplType::fused) { + vector_compare_soa( + handle, ref_out.data_handle(), out_idx.data_handle(), out_dist.data_handle(), m, summary); + } else { + vector_compare(handle, ref_out.data_handle(), out_kvp.data_handle(), m, summary); + } ASSERT_TRUE(summary.max_diff < params_.tol) << summary; } @@ -174,8 +179,10 @@ class NNTest : public ::testing::TestWithParam> { raft::device_matrix y; raft::device_vector x_norm; raft::device_vector y_norm; - raft::device_vector out; - raft::device_vector ref_out; + raft::device_vector out_idx; + raft::device_vector out_dist; + raft::device_vector out_kvp; + raft::device_vector ref_out; size_t workspace_size; }; @@ -187,6 +194,7 @@ const std::vector> input_fp32 = { {4096, 16384, 128, DistanceType::L2Expanded, true, uint64_t(31415926), 0.1}, {4096, 4096, 64, DistanceType::L2SqrtExpanded, false, uint64_t(31415926), 0.1}, {4096, 16384, 128, DistanceType::L2SqrtExpanded, false, uint64_t(31415926), 0.1}, + {512, 1024, 64, DistanceType::InnerProduct, false, uint64_t(31415926), 0.1}, {4096, 4096, 64, DistanceType::CosineExpanded, false, uint64_t(31415926), 0.1}, {8192, 4096, 64, DistanceType::CosineExpanded, false, uint64_t(31415926), 0.1}, // Fused implementation for cosine distance ignores the sqrt parameter, therefore diff --git a/cpp/tests/neighbors/distance_nn_helper.cuh b/cpp/tests/neighbors/distance_nn_helper.cuh index fda7b76573..dfa71f71a4 100644 --- a/cpp/tests/neighbors/distance_nn_helper.cuh +++ b/cpp/tests/neighbors/distance_nn_helper.cuh @@ -66,6 +66,16 @@ __device__ AccT cosine_distance(const DataT* v1, const DataT* v2, IdxT K) } // This is a naive implementation of 1-NN computation +template +__device__ AccT inner_product_score(const DataT* v1, const DataT* v2, IdxT K) +{ + AccT score = AccT(0.0); + for (IdxT i = 0; i < K; i++) { + score += AccT(v1[i]) * AccT(v2[i]); + } + return score; +} + template RAFT_KERNEL ref_nn_kernel( OutT* out, const DataT* A, const DataT* B, IdxT M, IdxT N, IdxT K, bool sqrt, DistanceType metric) @@ -73,22 +83,47 @@ RAFT_KERNEL ref_nn_kernel( IdxT tid = threadIdx.x + blockIdx.x * IdxT(blockDim.x); for (IdxT m = tid; m < M; m += (blockDim.x * gridDim.x)) { - IdxT min_index = N + 1; - AccT min_dist = max_val(); + IdxT best_index = N + 1; + AccT best_score = min_val(); + AccT best_dist = max_val(); for (IdxT n = 0; n < N; n++) { + if (metric == DistanceType::InnerProduct) { + AccT score = inner_product_score(&A[m * K], &B[n * K], K); + if (score > best_score) { + best_score = score; + best_index = n; + } + continue; + } + AccT dist; if (metric == DistanceType::L2SqrtExpanded || metric == DistanceType::L2Expanded) { dist = l2_distance(&A[m * K], &B[n * K], K); } else if (metric == DistanceType::CosineExpanded) { dist = cosine_distance(&A[m * K], &B[n * K], K); + } else { + continue; + } + if (dist < best_dist) { + best_dist = dist; + best_index = n; } - if (dist < min_dist) { - min_dist = dist; - min_index = n; + } + + if (metric == DistanceType::InnerProduct) { + if constexpr (std::is_fundamental::value) { + out[m] = AccT(best_score); + } else { + out[m].key = IdxT(best_index); + out[m].value = AccT(best_score); } + continue; } + IdxT min_index = best_index; + AccT min_dist = best_dist; + if constexpr (std::is_fundamental::value) { static_assert(std::is_same::value, "OutT and AccT are not same type"); out[m] = AccT(min_dist); @@ -174,6 +209,34 @@ class ComparisonSummary { } }; +template +void vector_compare_soa(raft::resources const& handle, + const raft::KeyValuePair* ref, + const IdxT* out_idx, + const AccT* out_dist, + IdxT n, + ComparisonSummary& summary) +{ + auto ref_h = raft::make_host_vector, IdxT>(n); + auto idx_h = raft::make_host_vector(n); + auto dist_h = raft::make_host_vector(n); + + raft::copy(ref_h.data_handle(), ref, n, raft::resource::get_cuda_stream(handle)); + raft::copy(idx_h.data_handle(), out_idx, n, raft::resource::get_cuda_stream(handle)); + raft::copy(dist_h.data_handle(), out_dist, n, raft::resource::get_cuda_stream(handle)); + raft::resource::sync_stream(handle, raft::resource::get_cuda_stream(handle)); + + summary.init(); + + for (IdxT i = 0; i < n; i++) { + const double a_val = double(dist_h(i)); + const double b_val = double(ref_h(i).value); + const bool missed = idx_h(i) != ref_h(i).key; + const double diff = std::abs(a_val - b_val); + summary.update(diff, i, a_val, b_val, missed); + } +} + template void vector_compare( raft::resources const& handle, const OutT* a, const OutT* b, IdxT n, ComparisonSummary& summary) diff --git a/dependencies.yaml b/dependencies.yaml index 744e4d9227..756041e60c 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -395,6 +395,7 @@ dependencies: - cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -430,12 +431,14 @@ dependencies: packages: - &ctk_cu13 cuda-toolkit[cublas,curand,cusolver,cusparse,nvrtc]==13.* - &nvjitlink_cu13 nvidia-nvjitlink>=13.0,<14 + - &cutile_cu13 cuda-tile[tileiras] # if no matching matrix selectors passed, list the CUDA 13 requirement # (just as a source of documentation, as this populates pyproject.toml in source control) - matrix: packages: - *ctk_cu13 - *nvjitlink_cu13 + - *cutile_cu13 depends_on_cudart: common: - output_types: conda diff --git a/python/libcuvs/pyproject.toml b/python/libcuvs/pyproject.toml index 5025daa66d..b4e848304f 100644 --- a/python/libcuvs/pyproject.toml +++ b/python/libcuvs/pyproject.toml @@ -19,6 +19,7 @@ authors = [ license = "Apache-2.0" requires-python = ">=3.11" dependencies = [ + "cuda-tile[tileiras]", "cuda-toolkit[cublas,curand,cusolver,cusparse,nvrtc]==13.*", "libraft==26.8.*,>=0.0.0a0", "librmm==26.8.*,>=0.0.0a0", diff --git a/run_benchmark_kmeans.sh b/run_benchmark_kmeans.sh new file mode 100755 index 0000000000..37291b5518 --- /dev/null +++ b/run_benchmark_kmeans.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Compare baseline / cuTile / flash-kmeans for one shape. +# +# Usage: +# export BENCH_CONDA=/path/to/miniforge3 +# export BENCH_ENV_BASE=... +# export BENCH_ENV_CUTILE=... +# export BENCH_ENV_FLASH=... +# export MAX_ITER=5 TOL=1e-4 SEED=42 +# export WARMUP_FIT=1 ITERS_FIT=3 WARMUP_PRED=1 ITERS_PRED=3 +# ./run_benchmark_kmeans.sh N D K +# + +set -u + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -ne 3 ]]; then + echo "usage: $0 N D K" >&2 + echo "See script header for required env vars and examples." >&2 + exit 2 +fi + +: "${BENCH_CONDA:?set BENCH_CONDA to conda/miniforge root}" +: "${BENCH_ENV_BASE:?set BENCH_ENV_BASE}" +: "${BENCH_ENV_CUTILE:?set BENCH_ENV_CUTILE}" +: "${BENCH_ENV_FLASH:?set BENCH_ENV_FLASH}" +: "${MAX_ITER:?set MAX_ITER}" +: "${SEED:?set SEED}" +: "${WARMUP_FIT:?set WARMUP_FIT}" +: "${ITERS_FIT:?set ITERS_FIT}" +: "${WARMUP_PRED:?set WARMUP_PRED}" +: "${ITERS_PRED:?set ITERS_PRED}" +: "${TOL:?set TOL}" + +N=$1 +D=$2 +K=$3 + +exec python3 "$SCRIPT_DIR/benchmark_kmeans.py" --compare \ + --n "$N" --d "$D" --k "$K" \ + --max-iter "$MAX_ITER" --tol "$TOL" --seed "$SEED" \ + --warmup-fit "$WARMUP_FIT" --iters-fit "$ITERS_FIT" \ + --warmup-pred "$WARMUP_PRED" --iters-pred "$ITERS_PRED"