Skip to content

fix: FP16 crash on CUDA 12.x + kernel optimizations#14

Open
zaki699-blip wants to merge 1 commit into
SeanWangJS:mainfrom
zaki699-blip:fix/fp16-cuda12-compat
Open

fix: FP16 crash on CUDA 12.x + kernel optimizations#14
zaki699-blip wants to merge 1 commit into
SeanWangJS:mainfrom
zaki699-blip:fix/fp16-cuda12-compat

Conversation

@zaki699-blip

Copy link
Copy Markdown

Summary

Fix critical FP16 (half-precision) crash (CUDA error 700: illegal memory access) when running on CUDA 12.x, and apply kernel performance optimizations.

Bug Fixes

1. Remove conflicting __half operator overloads (grid_sample_3d.cuh)

CUDA 12.x ships native __half arithmetic operators. The 12 custom operator overloads in this file caused ODR (One Definition Rule) violations at link time, leading to undefined behavior.

One overload was inherently broken:

__half operator+=(const float& a, const half& b)  // Can't modify const ref!

Fix: All custom overloads removed. A minimal to_float<scalar_t>() helper handles __halffloat conversion.

2. Fix compute_index() precision loss for __half (grid_sample_3d.cuh)

compute_index() was templated in scalar_t, so when instantiated with __half (~3.3 decimal digits of precision), unnormalization and coordinate math produced garbage indices → out-of-bounds memory access → CUDA error 700.

Fix: compute_index() now works entirely in float32 internally and returns float. reflect_coordinates() is similarly rewritten as reflect_coordinates_f() in pure float32.

3. Fix nearest kernel channel stride bug (grid_sample_3d.cu)

The nearest interpolation kernel never advanced its input/output channel pointers inside the channel loop:

// MISSING:
input_NC_offset += input_stride_C;
output_NCDHW_offset += output_stride_C;

This caused channel-0 data to be read for all channels.

Fix: Added the missing stride advancement.

Performance Optimizations

Optimization Benefit
__restrict__ on all pointer parameters Enables compiler to generate better code (no aliasing assumptions)
Hoist boundary checks outside channel loops Coordinates are constant across C; avoids redundant comparisons
Precompute spatial offsets outside channel loops Avoid recomputing x * stride_W + y * stride_H + z * stride_D per channel
__ldg() for read-only loads Uses texture cache path for global memory reads
float32 trilinear weights Avoids repeated scalar_t conversions in the hot loop
float32 accumulation in bilinear kernel Better precision and avoids __half addition overhead

Testing

Tested with CUDA 12.6 / TensorRT 10.5 / SM80+ (Ampere):

Precision Max Diff vs PyTorch Status
FP32 0.0 ✅ Exact match
FP16 9.76e-4 ✅ Expected for half

Performance (input [1,32,16,64,64] NCDHW, grid [1,16,64,64,3]):

  • FP32: 117.5 μs
  • FP16: 93.9 μs → 1.25× speedup (20% latency reduction)

Notes

The original FP16 test was commented out in main() with a TODO comment — this PR fixes the underlying issues and makes FP16 fully functional on modern CUDA toolkits.

Fix critical FP16 (half-precision) crash (CUDA error 700: illegal memory
access) when running on CUDA 12.x, and apply kernel performance optimizations.

Bug Fixes
---------

1. Remove conflicting __half operator overloads (grid_sample_3d.cuh)

   CUDA 12.x ships native __half arithmetic operators. The 12 custom
   operator overloads caused ODR violations at link time, leading to
   undefined behavior. One overload was inherently broken:
     __half operator+=(const float&, const half&)
   attempts to modify a const reference.

   All custom overloads removed. A minimal to_float<scalar_t>() helper
   handles __half to float conversion.

2. Fix compute_index() precision loss for __half (grid_sample_3d.cuh)

   compute_index() was templated in scalar_t, so when instantiated with
   __half (~3.3 decimal digits), unnormalization and coordinate math
   produced garbage indices, causing out-of-bounds memory access.

   compute_index() now works entirely in float32 and returns float.
   reflect_coordinates() rewritten as reflect_coordinates_f() in pure
   float32.

3. Fix nearest kernel channel stride bug (grid_sample_3d.cu)

   The nearest interpolation kernel never advanced input/output channel
   pointers inside the channel loop (missing += stride_C). This caused
   channel-0 data to be read for all channels.

Performance Optimizations
-------------------------

- __restrict__ on all kernel and launcher pointer parameters
- Hoist boundary checks outside channel loops (constant across C)
- Precompute spatial offsets outside channel loops
- Use __ldg() for read-only global memory loads (texture cache path)
- Compute trilinear weights in float32 (avoids repeated conversions)
- Accumulate bilinear interpolation in float32, cast on output

Tested: CUDA 12.6 / TensorRT 10.5 / SM80+ (Ampere)
  FP32: max_diff vs PyTorch = 0.0 (exact match)
  FP16: max_diff vs PyTorch = 9.76e-4 (expected for half)
  FP16 latency: 93.9us vs FP32 117.5us = 1.25x speedup
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant