
[Pytorch][Common] Hybrid quantization #2817

Open
negvet wants to merge 12 commits into NVIDIA:main from negvet:hybrid_quantization

Conversation

@negvet
Collaborator

@negvet negvet commented Mar 31, 2026

Description

Hybrid (per-direction) quantization, functional. C++ optimizations (fusions, etc.) will come in follow-up PRs.

Ecosystem integration (functional):

  • quantized_model_init
  • FSDP2 (TODO in the next PRs: optimize communication buffers)
  • CPU offloading
  • Activation recomputation
  • TP/SP (TODO in the next PRs: enable quantized AG)
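The per-direction idea can be illustrated with a toy sketch. Everything below is a stand-in: the real HybridQuantizer in this PR dispatches to full TE quantizers (e.g. MXFP8, NVFP4) and returns quantized storages, not tagged tuples.

```python
from dataclasses import dataclass

@dataclass
class FakeQuantizer:
    """Hypothetical stand-in for a real TE quantizer (MXFP8Quantizer, etc.)."""
    fmt: str

    def quantize(self, tensor):
        # Real code produces a quantized storage; here we just tag the data.
        return (self.fmt, tensor)

@dataclass
class HybridQuantizer:
    """Applies a different quantization format rowwise vs. columnwise."""
    rowwise_quantizer: FakeQuantizer
    columnwise_quantizer: FakeQuantizer
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def quantize(self, tensor):
        # Skip directions that are not needed (see review discussion below).
        rowwise = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
        columnwise = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
        return rowwise, columnwise

q = HybridQuantizer(FakeQuantizer("mxfp8"), FakeQuantizer("nvfp4"))
row, col = q.quantize([1.0, 2.0])
print(row[0], col[0])  # mxfp8 nvfp4
```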

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented Mar 31, 2026

Greptile Summary

This PR introduces hybrid (per-direction) quantization, allowing a HybridQuantizer to apply different quantization formats rowwise vs. columnwise (e.g., MXFP8 rowwise + NVFP4 columnwise). The implementation covers HybridQuantizedTensor, FSDP2 all-gather protocol, CPU offloading, activation recomputation, TP/SP, and GroupedLinear batched split-quantize paths.

  • P1 — _hybrid_split_quantize TypeError on None quantizers: _is_hybrid_quantizer_list allows None entries in the list (returns True for all-hybrid-or-None), but _hybrid_split_quantize's type guard (all(isinstance(q, HybridQuantizer) for q in quantizers)) fails on None because isinstance(None, HybridQuantizer) is False. Any GroupedLinear quantizer list with a None alongside HybridQuantizer entries will raise TypeError at runtime.
  • P1 — Float8TensorStorage.fsdp_buffer_fields unconditionally returns ("_data",): On Hopper/L40, columnwise-only Float8 sub-storages have _data=None and store data in _transpose. FSDP2 will all-gather None instead of the actual weights for this architecture/quantizer combination.
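A minimal sketch of the direction-aware fix the second P1 calls for. Class and field names mirror the PR, but the bodies are illustrative, not the actual TE implementation, which operates on CUDA tensors.

```python
class Float8Storage:
    """Toy stand-in: _data holds rowwise bytes, _transpose columnwise bytes."""

    def __init__(self, data=None, transpose=None):
        self._data = data
        self._transpose = transpose

    def fsdp_buffer_fields(self):
        # The buggy version returned ("_data",) unconditionally. A
        # direction-aware version advertises only fields that hold data,
        # so columnwise-only storages gather _transpose instead of None.
        fields = []
        if self._data is not None:
            fields.append("_data")
        if self._transpose is not None:
            fields.append("_transpose")
        return tuple(fields)

# Columnwise-only storage, the Hopper/L40 case flagged in the review.
fields = Float8Storage(transpose=b"\x01").fsdp_buffer_fields()
print(fields)  # ('_transpose',)
```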

Confidence Score: 4/5

Two P1 defects — a TypeError in GroupedLinear's hybrid quantize path when None quantizers appear, and FSDP2 gathering None on Hopper/L40 for columnwise-only Float8 sub-storages — should be addressed before merging.

Two confirmed P1 bugs: (1) _hybrid_split_quantize raises TypeError for any quantizer list containing None entries because _is_hybrid_quantizer_list allows None but the type guard does not filter them. (2) Float8TensorStorage.fsdp_buffer_fields unconditionally returns ("_data",) even when _data is None on Hopper/L40. Prior concerns around _hybrid_split_quantize attribute errors and make_empty try/finally leaks have been addressed.

transformer_engine/pytorch/module/grouped_linear.py (_hybrid_split_quantize None guard), transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py (fsdp_buffer_fields direction awareness)

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Adds _is_hybrid_quantizer_list and _hybrid_split_quantize helpers; mismatch between _is_hybrid_quantizer_list (allows None entries) and _hybrid_split_quantize's type guard (fails on None) causes TypeError at runtime.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Adds fsdp_buffer_fields returning ("_data",) unconditionally; on Hopper/L40, columnwise-only sub-storages have _data=None, causing FSDP2 to all-gather None instead of _transpose.
transformer_engine/pytorch/tensor/hybrid_tensor.py New HybridQuantizer and HybridQuantizedTensor classes; FSDP2 protocol implemented via fsdp_pre/post_all_gather; silent sub-storage drop risk in fsdp_post_all_gather when orig sub is not a QuantizedTensor.
transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py New HybridQuantizedTensorStorage mixin composing two sub-storages; __repr__ correctly guards None sub-storages (fixed from prior review); FSDP2 buffer protocol delegated to sub-storages.
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py Adds fsdp_buffer_fields, fsdp_extract_buffers, and fsdp_assign_gathered with correct direction-aware field selection and block-scale de-/re-padding logic; handles None directions cleanly.
transformer_engine/pytorch/tensor/float8_tensor.py Guards _data references in clone and aten.split dispatch to handle columnwise-only sub-storages (_data=None) in hybrid quantization; logic is correct.
transformer_engine/pytorch/tensor/mxfp8_tensor.py Refactors aten.split dispatch to handle rowwise-only or columnwise-only sub-tensors; adds aten.clone dispatch; changes are clean and correct.
transformer_engine/pytorch/cpp_extensions/gemm.py Adds _unwrap_hybrid_A/B helpers to extract the direction-appropriate sub-storage from a HybridQuantizedTensorStorage before passing to cuBLAS, with correct layout-to-direction mapping.
transformer_engine/pytorch/quantized_tensor.py Adds FSDP2 buffer protocol (fsdp_buffer_fields, fsdp_extract_buffers, fsdp_assign_gathered) to QuantizedTensorStorage base class with sensible defaults; well-documented contract.
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Adds nullptr guards for output_c and tile_scales_inv_c to support columnwise-only quantization; validates that at least one of rowwise/columnwise output is requested; changes are correct.
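The FSDP2 buffer protocol those storage files implement can be sketched as a simple contract (toy classes; the real defaults in quantized_tensor.py operate on CUDA tensors, not Python lists, and the gathered buffers come from an actual all-gather):

```python
class QuantizedTensorStorage:
    """Sketch of the buffer protocol: a storage advertises which attributes
    are flat buffers, hands them to FSDP2 for all-gather, then reassigns
    the gathered results."""

    def fsdp_buffer_fields(self):
        raise NotImplementedError

    def fsdp_extract_buffers(self):
        return [getattr(self, f) for f in self.fsdp_buffer_fields()]

    def fsdp_assign_gathered(self, gathered):
        for field, buf in zip(self.fsdp_buffer_fields(), gathered):
            setattr(self, field, buf)

class ToyStorage(QuantizedTensorStorage):
    """Hypothetical two-direction storage, analogous to MXFP8TensorStorage."""

    def __init__(self):
        self._rowwise_data = [1, 2]
        self._columnwise_data = [3, 4]

    def fsdp_buffer_fields(self):
        return ("_rowwise_data", "_columnwise_data")

s = ToyStorage()
bufs = s.fsdp_extract_buffers()                 # local shards handed to FSDP2
s.fsdp_assign_gathered([b * 2 for b in bufs])   # pretend the all-gather doubled them
print(s._rowwise_data)  # [1, 2, 1, 2]
```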

Class Diagram

%%{init: {'theme': 'neutral'}}%%
classDiagram
    class Quantizer {
        +quantize(tensor)
        +make_empty(shape)
        +update_quantized(src, dst)
    }
    class HybridQuantizer {
        +rowwise_quantizer: Quantizer
        +columnwise_quantizer: Quantizer
        +quantize_impl(tensor)
        +make_empty(shape)
        +supports_only_rowwise_all_gather()
    }
    class QuantizedTensorStorage {
        +fsdp_buffer_fields()
        +fsdp_extract_buffers()
        +fsdp_assign_gathered()
    }
    class HybridQuantizedTensorStorage {
        +_rowwise_storage
        +_columnwise_storage
        +dequantize()
        +prepare_for_saving()
        +view()
    }
    class HybridQuantizedTensor {
        +fsdp_pre_all_gather()
        +fsdp_post_all_gather()
        +detach()
        +__torch_dispatch__()
    }
    class Float8TensorStorage {
        +_data: Tensor
        +_transpose: Tensor
        +fsdp_buffer_fields() returns _data unconditionally
    }
    class MXFP8TensorStorage {
        +_rowwise_data: Tensor
        +_columnwise_data: Tensor
        +fsdp_buffer_fields()
        +fsdp_extract_buffers()
        +fsdp_assign_gathered()
    }
    Quantizer <|-- HybridQuantizer
    QuantizedTensorStorage <|-- HybridQuantizedTensorStorage
    HybridQuantizedTensorStorage <|-- HybridQuantizedTensor
    QuantizedTensorStorage <|-- Float8TensorStorage
    QuantizedTensorStorage <|-- MXFP8TensorStorage
    HybridQuantizer --> Quantizer : rowwise_quantizer
    HybridQuantizer --> Quantizer : columnwise_quantizer
    HybridQuantizedTensorStorage --> QuantizedTensorStorage : _rowwise_storage
    HybridQuantizedTensorStorage --> QuantizedTensorStorage : _columnwise_storage


Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Collaborator

@timmoon10 timmoon10 left a comment

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
Collaborator

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None

requires_grad: bool = False,
pin_memory: bool = False,
) -> HybridQuantizedTensor:
self.rowwise_quantizer.internal = True
Collaborator

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Collaborator Author

This would not work under FSDP2.

Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Comment on lines +1339 to +1355
def factory(role):
if role == "linear_weight":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_mxfp8_quantizer(),
)
if role == "linear_input":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
if role in ("linear_grad_output", "linear_grad_input"):
return HybridQuantizer(
rowwise_quantizer=_make_mxfp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
return None
Collaborator

This is horrifying. Good test.

negvet and others added 10 commits April 6, 2026 10:26
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment on lines +90 to +97
raise ValueError(
"GroupedLinear quantizer list mixes HybridQuantizer and non-hybrid"
f" quantizers ({hybrid_count} hybrid, {len(non_none) - hybrid_count}"
" non-hybrid). This combination is not supported: neither"
" `tex.split_quantize` nor `_hybrid_split_quantize` can consume a"
" heterogeneous list. Make the CustomRecipe `qfactory` return a"
" consistent type (all HybridQuantizer or all non-hybrid) across"
" every GEMM for the same role."
Contributor

P1 _hybrid_split_quantize TypeError on None-containing quantizer lists

_is_hybrid_quantizer_list explicitly allows None entries in the quantizer list (it filters them with non_none = [q for q in quantizers if q is not None]) and returns True when all non-None entries are HybridQuantizer. However, _hybrid_split_quantize's type guard is:

if not all(isinstance(q, HybridQuantizer) for q in quantizers):
    raise TypeError(...)

isinstance(None, HybridQuantizer) is False, so this raises TypeError when any None entry exists — before even reaching [q.rowwise_quantizer for q in quantizers]. Since tex.split_quantize in the non-hybrid path accepts None quantizers, None entries are a valid input to the GroupedLinear quantizer list, making this a real runtime error whenever a list like [HybridQuantizer, None, HybridQuantizer] is passed.

Fix option A — tighten _is_hybrid_quantizer_list to treat None+Hybrid as unsupported (simplest):

if hybrid_count > 0 and len(non_none) < len(quantizers):
    raise ValueError("None quantizers are not supported alongside HybridQuantizer.")

Fix option B — filter None inside _hybrid_split_quantize and propagate None entries as pass-through:

if not all(isinstance(q, HybridQuantizer) for q in quantizers if q is not None):
    raise TypeError(...)
row_quantizers = [q.rowwise_quantizer if q is not None else None for q in quantizers]
col_quantizers = [q.columnwise_quantizer if q is not None else None for q in quantizers]
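Fix option B can be sketched end-to-end with a stand-in class (illustrative only; the real path hands the split lists to the underlying split-quantize kernels):

```python
class HybridQuantizer:
    """Minimal stand-in for the real HybridQuantizer."""

    def __init__(self, rowwise_quantizer, columnwise_quantizer):
        self.rowwise_quantizer = rowwise_quantizer
        self.columnwise_quantizer = columnwise_quantizer

def split_per_direction(quantizers):
    # Option B: ignore None entries in the type guard and propagate them
    # through, matching tex.split_quantize's acceptance of None in the
    # non-hybrid path.
    if not all(isinstance(q, HybridQuantizer) for q in quantizers if q is not None):
        raise TypeError("expected only HybridQuantizer or None entries")
    rows = [q.rowwise_quantizer if q is not None else None for q in quantizers]
    cols = [q.columnwise_quantizer if q is not None else None for q in quantizers]
    return rows, cols

# A None alongside HybridQuantizer entries no longer raises TypeError.
rows, cols = split_per_direction([HybridQuantizer("fp8", "mxfp8"), None])
print(rows)  # ['fp8', None]
```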
