From 9e82471ff15e0e02995cf10bb84a392e98ef5647 Mon Sep 17 00:00:00 2001 From: Joseph Loftin Date: Tue, 2 Jun 2026 00:26:25 +0000 Subject: [PATCH 1/5] All To All --- .../conversion/custom_ops_converters.py | 20 ++++ .../dynamo/conversion/impl/nccl_ops.py | 83 ++++++++++++++ .../lowering/passes/fuse_distributed_ops.py | 29 +++++ .../py/dynamo/distributed/test_native_nccl.py | 102 ++++++++++++++++++ 4 files changed, 234 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index cdb8cb4647..e956ae2c0f 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -16,6 +16,7 @@ tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_all_reduce_op, tensorrt_fused_nccl_reduce_scatter_op, + tensorrt_fused_nccl_all_to_all_op ) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -86,6 +87,25 @@ def fused_nccl_all_reduce( reduce_op=reduce_op, ) + @dynamo_tensorrt_converter( + tensorrt_fused_nccl_all_to_all_op, requires_native_multidevice=True + ) + def fused_nccl_all_to_all( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """All-to-all using native TensorRT DistCollective API.""" + return impl.nccl_ops.nccl_all_to_all_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + # Conditionally register NCCL converters only if TensorRT-LLM plugin is available. # We use an `if` statement instead of @needs_trtllm_for_nccl decorator because diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index 26e67cab67..68fd16845a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -487,3 +487,86 @@ def nccl_all_reduce_native( except Exception as e: logger.error(f"Native ALL_REDUCE failed: {e} (type: {type(e).__name__})") raise + +def nccl_all_to_all_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + """ + Implement all_to_all using native TensorRT DistCollective API. + + This operation sends a chunk of data from each rank. The i-th rank will receive the data + at the i-th position of every other rank. + + Returns: + Output tensor after all_to_all operation + + Example: + Input on rank 0: [1, 2] shape=(2,) + Input on rank 1: [3, 4] shape=(2,) + Output on rank 0: [1, 3] shape=(2,) + Output on rank 1: [2, 4] shape=(2,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + + # TRT add_dist_collective crashes with world_size=1; all_to_all of a single rank + # is an identity op. + if world_size == 1: + return plug_inputs[0] + logger.debug( + f"Adding native all_gather: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for ALL_TO_ALL + # For ALL_TO_ALL, the reduce operation and root rank parameters are ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating ALL_TO_ALL layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.ALL_TO_ALL, + trt.ReduceOperation.NONE, # Ignored for ALL_TO_ALL + -1, # Root rank - ignored for ALL_TO_ALL + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native ALL_TO_ALL layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_TO_ALL failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 11 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native ALL_TO_ALL failed: {e} (type: {type(e).__name__})") + raise diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index 2772ba7d9f..b02e3f154d 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -68,6 +68,23 @@ def _(inp: torch.Tensor, reduce_op: str, group_name: str) -> torch.Tensor: return torch.empty_like(inp) +@torch.library.custom_op("tensorrt::fused_nccl_all_to_all", mutates_args=()) +def _fused_nccl_all_to_all_impl( + inp: torch.Tensor, output_splits: list[int] | None, input_splits: list[int] | None, group_name: str +) -> torch.Tensor: + out_shape = inp.shape + return inp.new_empty(out_shape) + + +@_fused_nccl_all_to_all_impl.register_fake +def _( + inp: torch.Tensor, output_splits: list[int] | None, input_splits: list[int] | None, group_name: str +) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.all_to_all_single.default(inp, output_splits, input_splits, group_name) + ) + + # Public aliases — used as FX node targets in the fuse pass, as converter keys # in custom_ops_converters.py, and in test equality checks. Each is the # torch._ops.OpOverload created by the custom_op decoration above. @@ -76,6 +93,7 @@ def _(inp: torch.Tensor, reduce_op: str, group_name: str) -> torch.Tensor: torch.ops.tensorrt.fused_nccl_reduce_scatter.default ) tensorrt_fused_nccl_all_reduce_op = torch.ops.tensorrt.fused_nccl_all_reduce.default +tensorrt_fused_nccl_all_to_all_op = torch.ops.tensorrt.fused_nccl_all_to_all.default def fuse_distributed_ops( @@ -89,6 +107,7 @@ def fuse_distributed_ops( torch.ops._c10d_functional.all_gather_into_tensor.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_reduce.default, + torch.ops._c10d_functional.all_to_all_single.default, ) and len(node.users) == 1 and list(node.users)[0].target @@ -111,6 +130,16 @@ def fuse_distributed_ops( target=tensorrt_fused_nccl_reduce_scatter_op, args=(node.args[0], node.args[1], node.args[2], node.args[3]), ) + elif ( + node.target == torch.ops._c10d_functional.all_to_all_single.default + ): + with gm.graph.inserting_after(wait_tensor_node): + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_all_to_all_op, + # Drop input and output splits, since TRT doesn't use them. + args=(node.args[0], node.args[1], node.args[2], node.args[3]), + ) else: with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( diff --git a/tests/py/dynamo/distributed/test_native_nccl.py b/tests/py/dynamo/distributed/test_native_nccl.py index e4f5b18582..1ada3fbe1e 100644 --- a/tests/py/dynamo/distributed/test_native_nccl.py +++ b/tests/py/dynamo/distributed/test_native_nccl.py @@ -949,6 +949,46 @@ def test_fuse_all_reduce_no_group_size_arg(self) -> None: "3-arg all_reduce node." ) + # -- all_to_all --------------------------------------------------------- + + def test_fuse_all_to_all_replaces_pair(self) -> None: + """all_to_all_single + wait_tensor → tensorrt_fused_nccl_all_to_all_op.""" + from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_to_all_op, + ) + + gm = _build_graph( + torch.ops._c10d_functional.all_to_all_single.default, + args_without_input=([], [], "test_group"), + ) + gm = self._run_pass(gm) + targets = _node_targets(gm) + self.assertNotIn( + torch.ops._c10d_functional.all_to_all_single.default, targets + ) + self.assertNotIn(torch.ops._c10d_functional.wait_tensor.default, targets) + self.assertIn(tensorrt_fused_nccl_all_to_all_op, targets) + + def test_fuse_all_to_all_args(self) -> None: + """Fused all_to_all node carries (inp, input splits, output splits, group_name).""" + from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_to_all_op, + ) + + gm = _build_graph( + torch.ops._c10d_functional.all_to_all_single.default, + # input splits, output splits, group name + args_without_input=([1], [2], "grp"), + ) + gm = self._run_pass(gm) + fused = next( + n for n in gm.graph.nodes if n.target == tensorrt_fused_nccl_all_to_all_op + ) + # args: (inp_placeholder, [], [], "grp") + self.assertEqual(fused.args[1], [1]) + self.assertEqual(fused.args[2], [2]) + self.assertEqual(fused.args[3], "grp") + # -- no-fuse when wait_tensor has multiple users ----------------------- def test_fuse_when_wait_tensor_result_has_multiple_uses(self) -> None: @@ -1078,6 +1118,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return torch.ops._c10d_functional.wait_tensor.default(out) +class _AllToAllModel(nn.Module): + def __init__(self, dim: int, world_size: int, group_name: str) -> None: + super().__init__() + self.fc = nn.Linear(dim, dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + split_sizes = [x.shape[0] // self.world_size] * self.world_size + out = torch.ops._c10d_functional.all_to_all_single.default( + x, split_sizes, split_sizes, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor.default(out) + @unittest.skipIf( not is_nccl_available(), @@ -1158,6 +1213,14 @@ def test_reduce_scatter_single_rank(self) -> None: [torch.randn(1, dim)], ) + def test_all_to_all_single_rank(self) -> None: + """all_to_all compiles and produces correct output on a single rank.""" + dim = 8 + self._run( + _AllToAllModel(dim, self.world_size, self.group_name), + [torch.randn(1, dim)], + ) + def test_distributed_mode_with_single_rank_subgroup(self) -> None: """distributed_context() selects the subgroup as NCCL communicator source.""" import torch_tensorrt @@ -1483,6 +1546,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) +def _multirank_all_to_all_correctness( + rank: int, world_size: int, device: torch.device +) -> None: + """all_to_all sends a chunk from each rank to every other ank.""" + import torch_tensorrt + from torch_tensorrt.distributed._nccl_utils import setup_nccl_for_torch_tensorrt + + setup_nccl_for_torch_tensorrt() + group = dist.group.WORLD + group_name = group.group_name if hasattr(group, "group_name") else "" + + class AllToAll(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + split_sizes = [x.shape[0] // world_size] * world_size + out = torch.ops._c10d_functional.all_to_all_single.default( + x, split_sizes, split_sizes, group_name + ) + return torch.ops._c10d_functional.wait_tensor.default(out) + + model = AllToAll().to(device).eval() + # Input: [0...world_size) + inp = torch.arange(world_size, device=device) + + with torch.no_grad(): + out = model(inp) + + # Result for rank r: [r] * world_size + expected = torch.tensor([rank] * world_size, device=device) + _check_close(out, expected, f"all_to_all rank={rank}") + + def _multirank_distributed_mode_tp_model( rank: int, world_size: int, device: torch.device ) -> None: @@ -1864,6 +1958,13 @@ def test_reduce_scatter_all_reduce_ops(self) -> None: device = self._init_dist() _multirank_reduce_scatter_all_reduce_ops(self.rank, self.world_size, device) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_all_to_all_correctness(self) -> None: + """all_to_all sends a chunk from each rank to every other ank.""" + device = self._init_dist() + _multirank_all_to_all_correctness(self.rank, self.world_size, device) + @unittest.skipIf(not has_nccl_collectives(), "No NCCL collective support available") @requires_nccl() @skip_if_lt_x_gpu(2) @@ -1919,6 +2020,7 @@ def run_multirank_tests() -> None: _multirank_all_reduce_correctness, _multirank_all_gather_correctness, _multirank_reduce_scatter_all_reduce_ops, + _multirank_all_to_all_correctness, _multirank_distributed_mode_tp_model, _multirank_distributed_mode_subgroup, _multirank_cpp_runtime_bind_nccl, From 33fc48ab0432ec4184114a7557106470411e8e7f Mon Sep 17 00:00:00 2001 From: Joseph Loftin Date: Wed, 3 Jun 2026 22:42:48 +0000 Subject: [PATCH 2/5] Scatter --- .../conversion/custom_ops_converters.py | 24 ++++- .../dynamo/conversion/impl/nccl_ops.py | 87 ++++++++++++++++++- .../lowering/passes/fuse_distributed_ops.py | 27 +++++- .../py/dynamo/distributed/test_native_nccl.py | 82 ++++++++++++++++- 4 files changed, 215 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index e956ae2c0f..48ad5e816a 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -16,7 +16,8 @@ tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_all_reduce_op, tensorrt_fused_nccl_reduce_scatter_op, - tensorrt_fused_nccl_all_to_all_op + tensorrt_fused_nccl_all_to_all_op, + tensorrt_fused_nccl_scatter_op ) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -106,6 +107,27 @@ def fused_nccl_all_to_all( [args[0]], ) + @dynamo_tensorrt_converter( + tensorrt_fused_nccl_scatter_op, requires_native_multidevice=True + ) + def fused_nccl_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """Scatter using native TensorRT DistCollective API.""" + root = args[1] if len(args) > 1 else 0 + return impl.nccl_ops.nccl_scatter_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + root=root + ) + # Conditionally register NCCL converters only if TensorRT-LLM plugin is available. # We use an `if` statement instead of @needs_trtllm_for_nccl decorator because diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index 68fd16845a..e82a35b681 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -488,6 +488,7 @@ def nccl_all_reduce_native( logger.error(f"Native ALL_REDUCE failed: {e} (type: {type(e).__name__})") raise + def nccl_all_to_all_native( ctx: ConversionContext, target: Union[Target, str], @@ -517,7 +518,7 @@ def nccl_all_to_all_native( if world_size == 1: return plug_inputs[0] logger.debug( - f"Adding native all_gather: name={name}, rank={rank}, world_size={world_size}" + f"Adding native all_to_all: name={name}, rank={rank}, world_size={world_size}" ) # Get the input tensor @@ -570,3 +571,87 @@ def nccl_all_to_all_native( except Exception as e: logger.error(f"Native ALL_TO_ALL failed: {e} (type: {type(e).__name__})") raise + + +def nccl_scatter_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + root: int = 0 +) -> trt.ITensor: + """ + Implement scatter using native TensorRT DistCollective API. + + This operation has the root rank send a chunk of its data to every other rank. + + Returns: + Output tensor after scatter operation + + Example: + Input on rank 0: [1, 2] shape=(2,) + Input on rank 1: None shape=(2,) + Output on rank 0: [1] shape=(1,) + Output on rank 1: [2] shape=(1,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + + # TRT add_dist_collective crashes with world_size=1; scatter of a single rank + # is an identity op. + if world_size == 1: + return plug_inputs[0] + logger.debug( + f"Adding native scatter: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for SCATTER + # For SCATTER, the reduce operation parameter is ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating scatter layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.SCATTER, + trt.ReduceOperation.NONE, # Ignored for SCATTER + root, + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native SCATTER layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"root={root}, groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native SCATTER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 11 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native SCATTER failed: {e} (type: {type(e).__name__})") + raise diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index b02e3f154d..4c8511a662 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -85,6 +85,29 @@ def _( ) +@torch.library.custom_op("tensorrt::fused_nccl_scatter", mutates_args=()) +def _fused_nccl_scatter_impl( + inp: torch.Tensor, src: int, group_name: str +) -> torch.Tensor: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + out = torch.ops._c10d_functional.broadcast.default(inp, src, group_name) + out = torch.ops._c10d_functional.wait_tensor.default(out) + + chunk = out.shape[0] // world_size + return out[rank * chunk : (rank + 1) * chunk] + + +@_fused_nccl_scatter_impl.register_fake +def _( + inp: torch.Tensor, src: int, group_name: str +) -> torch.Tensor: + world_size = torch.distributed.get_world_size() + out_shape = (inp.shape[0] // world_size,) + tuple(inp.shape[1:]) + return inp.new_empty(out_shape) + + # Public aliases — used as FX node targets in the fuse pass, as converter keys # in custom_ops_converters.py, and in test equality checks. Each is the # torch._ops.OpOverload created by the custom_op decoration above. @@ -94,6 +117,7 @@ def _( ) tensorrt_fused_nccl_all_reduce_op = torch.ops.tensorrt.fused_nccl_all_reduce.default tensorrt_fused_nccl_all_to_all_op = torch.ops.tensorrt.fused_nccl_all_to_all.default +tensorrt_fused_nccl_scatter_op = torch.ops.tensorrt.fused_nccl_scatter.default def fuse_distributed_ops( @@ -107,7 +131,7 @@ def fuse_distributed_ops( torch.ops._c10d_functional.all_gather_into_tensor.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_reduce.default, - torch.ops._c10d_functional.all_to_all_single.default, + torch.ops._c10d_functional.all_to_all_single.default ) and len(node.users) == 1 and list(node.users)[0].target @@ -137,7 +161,6 @@ def fuse_distributed_ops( fused_node = gm.graph.create_node( op="call_function", target=tensorrt_fused_nccl_all_to_all_op, - # Drop input and output splits, since TRT doesn't use them. args=(node.args[0], node.args[1], node.args[2], node.args[3]), ) else: diff --git a/tests/py/dynamo/distributed/test_native_nccl.py b/tests/py/dynamo/distributed/test_native_nccl.py index 1ada3fbe1e..ec01255626 100644 --- a/tests/py/dynamo/distributed/test_native_nccl.py +++ b/tests/py/dynamo/distributed/test_native_nccl.py @@ -60,6 +60,8 @@ ) from torch.testing._internal.common_utils import run_tests +import pytest + # --------------------------------------------------------------------------- # helpers # --------------------------------------------------------------------------- @@ -785,6 +787,25 @@ def _build_graph(collective_op, args_without_input): return torch.fx.GraphModule({}, g) +def _build_graph_scatter(args_without_input): + """Build a minimal FX graph: input → scatter → wait_tensor → slice → output.""" + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + g = torch.fx.Graph() + inp = g.placeholder("inp") + coll = g.call_function(torch.ops._c10d_functional.broadcast.default, args=(inp, *args_without_input)) + wait = g.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(coll,)) + # To avoid tracing tensor sizes for test + # use dummy dimension to represent inp's + # 0th axis shape + dummy_shape_dim = 10 + chunk = dummy_shape_dim // world_size + chunk = g.call_function(torch.ops.aten.slice.Tensor, args=(wait, 0, rank * chunk, (rank+1) * chunk)) + g.output(chunk) + return torch.fx.GraphModule({}, g) + + def _node_targets(gm: torch.fx.GraphModule) -> list: return [n.target for n in gm.graph.nodes if n.op == "call_function"] @@ -797,6 +818,7 @@ class TestFuseDistributedOps(unittest.TestCase): handled correctly. """ + def _settings(self): from torch_tensorrt.dynamo._settings import CompilationSettings @@ -1133,6 +1155,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return torch.ops._c10d_functional.wait_tensor.default(out) +class _ScatterModel(nn.Module): + def __init__(self, dim: int, root : int, group_name: str) -> None: + super().__init__() + self.fc = nn.Linear(dim, dim) + self.group_name = group_name + self.root = root + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + return torch.ops.tensorrt.fused_nccl_scatter(x, self.root, self.group_name) + + @unittest.skipIf( not is_nccl_available(), @@ -1159,6 +1193,7 @@ def setUpClass(cls) -> None: cls.group = dist.new_group(ranks=[0]) cls.group_name = cls.group.group_name cls.world_size = 1 + cls.root = 0 @classmethod def tearDownClass(cls) -> None: @@ -1221,6 +1256,14 @@ def test_all_to_all_single_rank(self) -> None: [torch.randn(1, dim)], ) + def test_scatter_single_rank(self) -> None: + """scatter compiles and produces correct output on a single rank.""" + dim = 8 + self._run( + _ScatterModel(dim, self.root, self.group_name), + [torch.randn(1, dim)], + ) + def test_distributed_mode_with_single_rank_subgroup(self) -> None: """distributed_context() selects the subgroup as NCCL communicator source.""" import torch_tensorrt @@ -1464,10 +1507,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # After gather: shape is (world_size, 4), row i == float(i) assert out.shape == torch.Size([world_size, 4]), f"Shape mismatch: {out.shape}" for r in range(world_size): - expected_row = torch.full((4,), float(r), device=device) + expected_row = torch.full((4,), float(r), device=device, dtype=out.dtype) _check_close(out[r], expected_row, f"all_gather row {r} rank={rank}") +def _multirank_scatter_correctness( + root: int, rank: int, world_size: int, device: torch.device +) -> None: + """all_gather concatenates tensors from all ranks in order.""" + group = dist.group.WORLD + group_name = group.group_name if hasattr(group, "group_name") else "" + + class Scatter(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.tensorrt.fused_nccl_scatter(x, root, group_name) + + model = Scatter().to(device).eval() + # input is of shape (world_size, 4) laid out like + # [0, 0, 0, 0] + # ... + # [world_size - 1, world_size - 1, world_size - 1, world_size - 1] + + # Supply root only with valid input + inp = torch.arange(world_size, device=device).unsqueeze(1).repeat(1, 4) if rank == root else torch.full((world_size, 4), -1, device=device) + + with torch.no_grad(): + out = model(inp) + + # After scatter: shape is (1, 4), row[0] == [rank, rank, rank, rank] + assert out.shape == torch.Size([1, 4]), f"Shape mismatch: {out.shape}" + expected_row = torch.full((1, 4), int(rank), device=device) + _check_close(out, expected_row, f"scatter row {out} rank={rank}") + + def _multirank_reduce_scatter_all_reduce_ops( rank: int, world_size: int, device: torch.device ) -> None: @@ -1965,6 +2037,14 @@ def test_all_to_all_correctness(self) -> None: device = self._init_dist() _multirank_all_to_all_correctness(self.rank, self.world_size, device) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_scatter_correctness(self) -> None: + """all_to_all sends a chunk from each rank to every other ank.""" + device = self._init_dist() + for i in range(self.world_size): + _multirank_scatter_correctness(i, self.rank, self.world_size, device) + @unittest.skipIf(not has_nccl_collectives(), "No NCCL collective support available") @requires_nccl() @skip_if_lt_x_gpu(2) From ef8f488340d13e28c758223b59bcfa4c09513928 Mon Sep 17 00:00:00 2001 From: Joseph Loftin Date: Wed, 3 Jun 2026 23:59:27 +0000 Subject: [PATCH 3/5] Gather --- .../conversion/custom_ops_converters.py | 36 ++++--- .../dynamo/conversion/impl/nccl_ops.py | 92 +++++++++++++++++- .../lowering/passes/fuse_distributed_ops.py | 47 +++++++-- .../py/dynamo/distributed/test_native_nccl.py | 97 +++++++++++++------ 4 files changed, 217 insertions(+), 55 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 48ad5e816a..7bd70705a0 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -15,9 +15,10 @@ from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_all_reduce_op, - tensorrt_fused_nccl_reduce_scatter_op, tensorrt_fused_nccl_all_to_all_op, - tensorrt_fused_nccl_scatter_op + tensorrt_fused_nccl_gather_op, + tensorrt_fused_nccl_reduce_scatter_op, + tensorrt_fused_nccl_scatter_op, ) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ @dynamo_tensorrt_converter( tensorrt_fused_nccl_all_gather_op, requires_native_multidevice=True ) - def fused_nccl_gather( + def fused_nccl_all_gather( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -37,7 +38,7 @@ def fused_nccl_gather( name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: """All-gather using native TensorRT DistCollective API""" - return impl.nccl_ops.nccl_gather_native( + return impl.nccl_ops.nccl_all_gather_native( ctx, target, SourceIR.ATEN, @@ -120,12 +121,23 @@ def fused_nccl_scatter( """Scatter using native TensorRT DistCollective API.""" root = args[1] if len(args) > 1 else 0 return impl.nccl_ops.nccl_scatter_native( - ctx, - target, - SourceIR.ATEN, - name, - [args[0]], - root=root + ctx, target, SourceIR.ATEN, name, [args[0]], root=root + ) + + @dynamo_tensorrt_converter( + tensorrt_fused_nccl_gather_op, requires_native_multidevice=True + ) + def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """Gather using native TensorRT DistCollective API.""" + root = args[1] if len(args) > 1 else 0 + return impl.nccl_ops.nccl_gather_native( + ctx, target, SourceIR.ATEN, name, [args[0]], root=root ) @@ -143,14 +155,14 @@ def fused_nccl_scatter( ) @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) - def fused_nccl_gather( + def fused_nccl_all_gather( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_gather( + return impl.nccl_ops.nccl_all_gather( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index e82a35b681..f26879235f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -79,7 +79,7 @@ def _get_distributed_rank_and_world_size() -> Tuple[int, int]: return rank, world_size -def nccl_gather( +def nccl_all_gather( ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], @@ -220,7 +220,7 @@ def nccl_reduce_scatter( return layer.get_output(0) -def nccl_gather_native( +def nccl_all_gather_native( ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], @@ -579,7 +579,7 @@ def nccl_scatter_native( source_ir: Optional[SourceIR], name: str, plug_inputs: Tuple[Argument, ...], - root: int = 0 + root: int = 0, ) -> trt.ITensor: """ Implement scatter using native TensorRT DistCollective API. @@ -624,7 +624,7 @@ def nccl_scatter_native( input_tensor, trt.CollectiveOperation.SCATTER, trt.ReduceOperation.NONE, # Ignored for SCATTER - root, + root, groups, # None means all ranks participate (world_size ranks) ) @@ -655,3 +655,87 @@ def nccl_scatter_native( except Exception as e: logger.error(f"Native SCATTER failed: {e} (type: {type(e).__name__})") raise + + +def nccl_gather_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + root: int = 0, +) -> trt.ITensor: + """ + Implement gather using native TensorRT DistCollective API. + + This operation has the root rank receive a chunk of data from every other rank + + Returns: + Output tensor after scatter operation + + Example: + Input on rank 0: [1] shape=(1,) + Input on rank 1: [2] shape=(1,) + Output on rank 0: [1, 2] shape=(2,) + Output on rank 1: [undefined, undefined] shape=(2,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + + # TRT add_dist_collective crashes with world_size=1; scatter of a single rank + # is an identity op. + if world_size == 1: + return plug_inputs[0] + logger.debug( + f"Adding native scatter: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for GATHER + # For GATHER, the reduce operation parameter is ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating scatter layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.GATHER, + trt.ReduceOperation.NONE, # Ignored for GATHER + root, + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native GATHER layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"root={root}, groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native GATHER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 11 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native GATHER failed: {e} (type: {type(e).__name__})") + raise diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index 4c8511a662..930ad4d7b6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -70,7 +70,10 @@ def _(inp: torch.Tensor, reduce_op: str, group_name: str) -> torch.Tensor: @torch.library.custom_op("tensorrt::fused_nccl_all_to_all", mutates_args=()) def _fused_nccl_all_to_all_impl( - inp: torch.Tensor, output_splits: list[int] | None, input_splits: list[int] | None, group_name: str + inp: torch.Tensor, + output_splits: list[int] | None, + input_splits: list[int] | None, + group_name: str, ) -> torch.Tensor: out_shape = inp.shape return inp.new_empty(out_shape) @@ -78,10 +81,15 @@ def _fused_nccl_all_to_all_impl( @_fused_nccl_all_to_all_impl.register_fake def _( - inp: torch.Tensor, output_splits: list[int] | None, input_splits: list[int] | None, group_name: str + inp: torch.Tensor, + output_splits: list[int] | None, + input_splits: list[int] | None, + group_name: str, ) -> torch.Tensor: return torch.ops._c10d_functional.wait_tensor.default( - torch.ops._c10d_functional.all_to_all_single.default(inp, output_splits, input_splits, group_name) + torch.ops._c10d_functional.all_to_all_single.default( + inp, output_splits, input_splits, group_name + ) ) @@ -100,13 +108,33 @@ def _fused_nccl_scatter_impl( @_fused_nccl_scatter_impl.register_fake -def _( +def _(inp: torch.Tensor, src: int, group_name: str) -> torch.Tensor: + world_size = torch.distributed.get_world_size() + out_shape = (inp.shape[0] // world_size,) + tuple(inp.shape[1:]) + return inp.new_empty(out_shape) + + +@torch.library.custom_op("tensorrt::fused_nccl_gather", mutates_args=()) +def _fused_nccl_gather_impl( inp: torch.Tensor, src: int, group_name: str ) -> torch.Tensor: + + # Perform all_gather world_size = torch.distributed.get_world_size() - out_shape = (inp.shape[0] // world_size,) + tuple(inp.shape[1:]) + out = _fused_nccl_all_gather_impl(inp, world_size, group_name) + + # TRT leads to undefined data after gather on non-root ranks + # so maintain that here for parity's sake + rank = torch.distributed.get_rank() + return out if rank == src else torch.empty_like(out) + + +@_fused_nccl_gather_impl.register_fake +def _(inp: torch.Tensor, src: int, group_name: str) -> torch.Tensor: + world_size = torch.distributed.get_world_size() + out_shape = (inp.shape[0] * world_size,) + tuple(inp.shape[1:]) return inp.new_empty(out_shape) - + # Public aliases — used as FX node targets in the fuse pass, as converter keys # in custom_ops_converters.py, and in test equality checks. Each is the @@ -118,6 +146,7 @@ def _( tensorrt_fused_nccl_all_reduce_op = torch.ops.tensorrt.fused_nccl_all_reduce.default tensorrt_fused_nccl_all_to_all_op = torch.ops.tensorrt.fused_nccl_all_to_all.default tensorrt_fused_nccl_scatter_op = torch.ops.tensorrt.fused_nccl_scatter.default +tensorrt_fused_nccl_gather_op = torch.ops.tensorrt.fused_nccl_gather.default def fuse_distributed_ops( @@ -131,7 +160,7 @@ def fuse_distributed_ops( torch.ops._c10d_functional.all_gather_into_tensor.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_reduce.default, - torch.ops._c10d_functional.all_to_all_single.default + torch.ops._c10d_functional.all_to_all_single.default, ) and len(node.users) == 1 and list(node.users)[0].target @@ -154,9 +183,7 @@ def fuse_distributed_ops( target=tensorrt_fused_nccl_reduce_scatter_op, args=(node.args[0], node.args[1], node.args[2], node.args[3]), ) - elif ( - node.target == torch.ops._c10d_functional.all_to_all_single.default - ): + elif node.target == torch.ops._c10d_functional.all_to_all_single.default: with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( op="call_function", diff --git a/tests/py/dynamo/distributed/test_native_nccl.py b/tests/py/dynamo/distributed/test_native_nccl.py index ec01255626..f8fa813055 100644 --- a/tests/py/dynamo/distributed/test_native_nccl.py +++ b/tests/py/dynamo/distributed/test_native_nccl.py @@ -60,8 +60,6 @@ ) from torch.testing._internal.common_utils import run_tests -import pytest - # --------------------------------------------------------------------------- # helpers # --------------------------------------------------------------------------- @@ -787,25 +785,6 @@ def _build_graph(collective_op, args_without_input): return torch.fx.GraphModule({}, g) -def _build_graph_scatter(args_without_input): - """Build a minimal FX graph: input → scatter → wait_tensor → slice → output.""" - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - g = torch.fx.Graph() - inp = g.placeholder("inp") - coll = g.call_function(torch.ops._c10d_functional.broadcast.default, args=(inp, *args_without_input)) - wait = g.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(coll,)) - # To avoid tracing tensor sizes for test - # use dummy dimension to represent inp's - # 0th axis shape - dummy_shape_dim = 10 - chunk = dummy_shape_dim // world_size - chunk = g.call_function(torch.ops.aten.slice.Tensor, args=(wait, 0, rank * chunk, (rank+1) * chunk)) - g.output(chunk) - return torch.fx.GraphModule({}, g) - - def _node_targets(gm: torch.fx.GraphModule) -> list: return [n.target for n in gm.graph.nodes if n.op == "call_function"] @@ -818,7 +797,6 @@ class TestFuseDistributedOps(unittest.TestCase): handled correctly. """ - def _settings(self): from torch_tensorrt.dynamo._settings import CompilationSettings @@ -985,9 +963,7 @@ def test_fuse_all_to_all_replaces_pair(self) -> None: ) gm = self._run_pass(gm) targets = _node_targets(gm) - self.assertNotIn( - torch.ops._c10d_functional.all_to_all_single.default, targets - ) + self.assertNotIn(torch.ops._c10d_functional.all_to_all_single.default, targets) self.assertNotIn(torch.ops._c10d_functional.wait_tensor.default, targets) self.assertIn(tensorrt_fused_nccl_all_to_all_op, targets) @@ -1140,6 +1116,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return torch.ops._c10d_functional.wait_tensor.default(out) + class _AllToAllModel(nn.Module): def __init__(self, dim: int, world_size: int, group_name: str) -> None: super().__init__() @@ -1155,8 +1132,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return torch.ops._c10d_functional.wait_tensor.default(out) + class _ScatterModel(nn.Module): - def __init__(self, dim: int, root : int, group_name: str) -> None: + def __init__(self, dim: int, root: int, group_name: str) -> None: super().__init__() self.fc = nn.Linear(dim, dim) self.group_name = group_name @@ -1167,6 +1145,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.tensorrt.fused_nccl_scatter(x, self.root, self.group_name) +class _GatherModel(nn.Module): + def __init__(self, dim: int, root: int, group_name: str) -> None: + super().__init__() + self.fc = nn.Linear(dim, dim) + self.group_name = group_name + self.root = root + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + return torch.ops.tensorrt.fused_nccl_gather(x, self.root, self.group_name) + @unittest.skipIf( not is_nccl_available(), @@ -1264,6 +1253,14 @@ def test_scatter_single_rank(self) -> None: [torch.randn(1, dim)], ) + def test_gather_single_rank(self) -> None: + """gather compiles and produces correct output on a single rank.""" + dim = 8 + self._run( + _GatherModel(dim, self.root, self.group_name), + [torch.randn(1, dim)], + ) + def test_distributed_mode_with_single_rank_subgroup(self) -> None: """distributed_context() selects the subgroup as NCCL communicator source.""" import torch_tensorrt @@ -1523,13 +1520,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.tensorrt.fused_nccl_scatter(x, root, group_name) model = Scatter().to(device).eval() - # input is of shape (world_size, 4) laid out like + # input is of shape (world_size, 4) laid out like # [0, 0, 0, 0] # ... # [world_size - 1, world_size - 1, world_size - 1, world_size - 1] # Supply root only with valid input - inp = torch.arange(world_size, device=device).unsqueeze(1).repeat(1, 4) if rank == root else torch.full((world_size, 4), -1, device=device) + inp = ( + torch.arange(world_size, device=device).unsqueeze(1).repeat(1, 4) + if rank == root + else torch.full((world_size, 4), -1, device=device) + ) with torch.no_grad(): out = model(inp) @@ -1540,6 +1541,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: _check_close(out, expected_row, f"scatter row {out} rank={rank}") +def _multirank_gather_correctness( + root: int, rank: int, world_size: int, device: torch.device +) -> None: + """gather has the root rank receive a chunk of data from all other ranks.""" + group = dist.group.WORLD + group_name = group.group_name if hasattr(group, "group_name") else "" + + class Gather(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.tensorrt.fused_nccl_gather(x, root, group_name) + + model = Gather().to(device).eval() + # input is of shape (1, 4) where all data is rank + inp = torch.full((1, 4), int(rank), device=device) + + with torch.no_grad(): + out = model(inp) + + # After gather: shape is (world_size, 4), row[i] == [i, i, i, i], only validate data on root + assert out.shape == torch.Size([world_size, 4]), f"Shape mismatch: {out.shape}" + + if rank == root: + # output is of shape (world_size, 4) laid out like + # [0, 0, 0, 0] + # ... + # [world_size - 1, world_size - 1, world_size - 1, world_size - 1] + expected = torch.arange(world_size, device=device).unsqueeze(1).repeat(1, 4) + _check_close(out, expected, f"gather received {out} expected {expected} ") + + def _multirank_reduce_scatter_all_reduce_ops( rank: int, world_size: int, device: torch.device ) -> None: @@ -2040,11 +2071,19 @@ def test_all_to_all_correctness(self) -> None: @requires_nccl() @skip_if_lt_x_gpu(2) def test_scatter_correctness(self) -> None: - """all_to_all sends a chunk from each rank to every other ank.""" + """scatter sends a chunk of data from the root rank to every other rank.""" device = self._init_dist() for i in range(self.world_size): _multirank_scatter_correctness(i, self.rank, self.world_size, device) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_gather_correctness(self) -> None: + """gather receives a chunk of data from every other rank on the root rank.""" + device = self._init_dist() + for i in range(self.world_size): + _multirank_gather_correctness(i, self.rank, self.world_size, device) + @unittest.skipIf(not has_nccl_collectives(), "No NCCL collective support available") @requires_nccl() @skip_if_lt_x_gpu(2) From b49910fb87518d5de3888a00e077ae695c81be48 Mon Sep 17 00:00:00 2001 From: Joseph Loftin Date: Thu, 4 Jun 2026 00:22:42 +0000 Subject: [PATCH 4/5] Add Tests To Runner --- tests/py/dynamo/distributed/test_native_nccl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/py/dynamo/distributed/test_native_nccl.py b/tests/py/dynamo/distributed/test_native_nccl.py index f8fa813055..f0d6ed607b 100644 --- a/tests/py/dynamo/distributed/test_native_nccl.py +++ b/tests/py/dynamo/distributed/test_native_nccl.py @@ -2140,6 +2140,8 @@ def run_multirank_tests() -> None: _multirank_all_gather_correctness, _multirank_reduce_scatter_all_reduce_ops, _multirank_all_to_all_correctness, + _multirank_scatter_correctness, + _multirank_gather_correctness, _multirank_distributed_mode_tp_model, _multirank_distributed_mode_subgroup, _multirank_cpp_runtime_bind_nccl, From 9336ca3022b73f88dd4d3b111c91655afa19d095 Mon Sep 17 00:00:00 2001 From: Joseph Loftin Date: Thu, 4 Jun 2026 00:30:14 +0000 Subject: [PATCH 5/5] Gather Typo --- py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index f26879235f..6a4afd5c91 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -681,12 +681,12 @@ def nccl_gather_native( """ rank, world_size = _get_distributed_rank_and_world_size() - # TRT add_dist_collective crashes with world_size=1; scatter of a single rank + # TRT add_dist_collective crashes with world_size=1; gather of a single rank # is an identity op. if world_size == 1: return plug_inputs[0] logger.debug( - f"Adding native scatter: name={name}, rank={rank}, world_size={world_size}" + f"Adding native gather: name={name}, rank={rank}, world_size={world_size}" ) # Get the input tensor @@ -702,7 +702,7 @@ def nccl_gather_native( groups = np.arange(world_size, dtype=np.int64) logger.debug( - f"Creating scatter layer: groups={groups.tolist()}, groupSize={world_size}" + f"Creating gather layer: groups={groups.tolist()}, groupSize={world_size}" ) layer = ctx.net.add_dist_collective( input_tensor,