Skip to content

Refactor(linear): split LinearBackward kernel into 3 independent kernels#142

Open
chen2021673 wants to merge 4 commits intomasterfrom
split_linear_backward
Open

Refactor(linear): split LinearBackward kernel into 3 independent kernels#142
chen2021673 wants to merge 4 commits intomasterfrom
split_linear_backward

Conversation

@chen2021673
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 commented Apr 10, 2026

Summary

Architecture refactoring of Linear/Matmul/Outer kernels.

The core idea is separation of concerns — moving the decision of whether a gradient should be computed from the kernel layer up to the autograd layer, making kernels pure compute functions. At the same time, unified GEMM/SGEMV primitives are abstracted at the bottom layer to eliminate duplicated cuBLAS boilerplate.

Changes

  • Autograd layer: LinearBackward and MatmulBackward are each decomposed into multiple independent Dispatcher calls. The needs_input_grad checks happen at the autograd layer, invoking only the kernels actually needed.
  • Kernel layer: The monolithic LinearBackward is split into LinearBackwardInput / LinearBackwardWeight / LinearBackwardBias; MatmulBackward is split into MatmulBackwardInput / MatmulBackwardOther, with naming aligned to MatmulForward(input, other).
  • File split: Matmul kernels are extracted from linear.cc / linear.cu into dedicated cpu/matmul.cc and cuda/matmul.cu, giving each file a single responsibility.
  • GEMM primitive: New gemm.cuh / gemm.cu define the GemmParams struct and GemmCuda(), providing a unified wrapper over cublasGemmEx and cublasGemmStridedBatchedEx branching logic. GetCublasHandle() / GetCudaStream() are centrally defined and shared across linear.cu, matmul.cu, and outer.cu, eliminating duplicate definitions.
  • SGEMV primitive: New SgemvParams struct and SgemvCuda() wrap the cublasSgemv call. LinearForward and LinearBackwardInput in linear.cu take the SGEMV path when bs==1 and fp32 (more efficient for matrix-vector shapes); bf16 falls back to GemmCuda since cublasSgemv does not support it. The fp32 backward path in outer.cu is migrated to SgemvCuda as well, eliminating inline cublasSgemv calls.

@chen2021673 chen2021673 force-pushed the split_linear_backward branch 3 times, most recently from 283d083 to 23d301b Compare April 15, 2026 01:58
@chen2021673 chen2021673 requested a review from kilinchange April 15, 2026 02:08
Move grad_flags logic from kernel to autograd layer. The
monolithic LinearBackward kernel is replaced by LinearBackwardInput,
LinearBackwardWeight, and LinearBackwardBias — each a pure compute
operation with no autograd-related parameters.
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel
is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or
  cublasGemmStridedBatchedEx based on batch_count), GetCublasHandle(), GetCudaStream()
  shared across all GEMM-using kernels
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated
  matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther
  for semantic clarity matching MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths);
  keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported)
@chen2021673 chen2021673 force-pushed the split_linear_backward branch from 23d301b to 97dabe4 Compare April 27, 2026 02:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant