diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index cdb8cb4647..7bd70705a0 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -15,7 +15,10 @@ from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_all_reduce_op, + tensorrt_fused_nccl_all_to_all_op, + tensorrt_fused_nccl_gather_op, tensorrt_fused_nccl_reduce_scatter_op, + tensorrt_fused_nccl_scatter_op, ) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -27,7 +30,7 @@ @dynamo_tensorrt_converter( tensorrt_fused_nccl_all_gather_op, requires_native_multidevice=True ) - def fused_nccl_gather( + def fused_nccl_all_gather( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -35,7 +38,7 @@ def fused_nccl_gather( name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: """All-gather using native TensorRT DistCollective API""" - return impl.nccl_ops.nccl_gather_native( + return impl.nccl_ops.nccl_all_gather_native( ctx, target, SourceIR.ATEN, @@ -86,6 +89,57 @@ def fused_nccl_all_reduce( reduce_op=reduce_op, ) + @dynamo_tensorrt_converter( + tensorrt_fused_nccl_all_to_all_op, requires_native_multidevice=True + ) + def fused_nccl_all_to_all( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """All-to-all using native TensorRT DistCollective API.""" + return impl.nccl_ops.nccl_all_to_all_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter( + tensorrt_fused_nccl_scatter_op, requires_native_multidevice=True + ) + def fused_nccl_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """Scatter using native TensorRT DistCollective API.""" + root = args[1] if len(args) > 1 else 0 + return impl.nccl_ops.nccl_scatter_native( + ctx, target, SourceIR.ATEN, name, [args[0]], root=root + ) + + @dynamo_tensorrt_converter( + tensorrt_fused_nccl_gather_op, requires_native_multidevice=True + ) + def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """Gather using native TensorRT DistCollective API.""" + root = args[1] if len(args) > 1 else 0 + return impl.nccl_ops.nccl_gather_native( + ctx, target, SourceIR.ATEN, name, [args[0]], root=root + ) + # Conditionally register NCCL converters only if TensorRT-LLM plugin is available. # We use an `if` statement instead of @needs_trtllm_for_nccl decorator because @@ -101,14 +155,14 @@ def fused_nccl_all_reduce( ) @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) - def fused_nccl_gather( + def fused_nccl_all_gather( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_gather( + return impl.nccl_ops.nccl_all_gather( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index 26e67cab67..6a4afd5c91 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -79,7 +79,7 @@ def _get_distributed_rank_and_world_size() -> Tuple[int, int]: return rank, world_size -def nccl_gather( +def nccl_all_gather( ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], @@ -220,7 +220,7 @@ def nccl_reduce_scatter( return layer.get_output(0) -def nccl_gather_native( +def nccl_all_gather_native( ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], @@ -487,3 +487,255 @@ def nccl_all_reduce_native( except Exception as e: logger.error(f"Native ALL_REDUCE failed: {e} (type: {type(e).__name__})") raise + + +def nccl_all_to_all_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + """ + Implement all_to_all using native TensorRT DistCollective API. + + This operation sends a chunk of data from each rank. The i-th rank will receive the data + at the i-th position of every other rank. + + Returns: + Output tensor after all_to_all operation + + Example: + Input on rank 0: [1, 2] shape=(2,) + Input on rank 1: [3, 4] shape=(2,) + Output on rank 0: [1, 3] shape=(2,) + Output on rank 1: [2, 4] shape=(2,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + + # TRT add_dist_collective crashes with world_size=1; all_to_all of a single rank + # is an identity op. + if world_size == 1: + return plug_inputs[0] + logger.debug( + f"Adding native all_to_all: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for ALL_TO_ALL + # For ALL_TO_ALL, the reduce operation and root rank parameters are ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating ALL_TO_ALL layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.ALL_TO_ALL, + trt.ReduceOperation.NONE, # Ignored for ALL_TO_ALL + -1, # Root rank - ignored for ALL_TO_ALL + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native ALL_TO_ALL layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_TO_ALL failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 11 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native ALL_TO_ALL failed: {e} (type: {type(e).__name__})") + raise + + +def nccl_scatter_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + root: int = 0, +) -> trt.ITensor: + """ + Implement scatter using native TensorRT DistCollective API. + + This operation has the root rank send a chunk of its data to every other rank. + + Returns: + Output tensor after scatter operation + + Example: + Input on rank 0: [1, 2] shape=(2,) + Input on rank 1: None shape=(2,) + Output on rank 0: [1] shape=(1,) + Output on rank 1: [2] shape=(1,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + + # TRT add_dist_collective crashes with world_size=1; scatter of a single rank + # is an identity op. + if world_size == 1: + return plug_inputs[0] + logger.debug( + f"Adding native scatter: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for SCATTER + # For SCATTER, the reduce operation parameter is ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating scatter layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.SCATTER, + trt.ReduceOperation.NONE, # Ignored for SCATTER + root, + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native SCATTER layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"root={root}, groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native SCATTER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 11 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native SCATTER failed: {e} (type: {type(e).__name__})") + raise + + +def nccl_gather_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + root: int = 0, +) -> trt.ITensor: + """ + Implement gather using native TensorRT DistCollective API. + + This operation has the root rank receive a chunk of data from every other rank + + Returns: + Output tensor after scatter operation + + Example: + Input on rank 0: [1] shape=(1,) + Input on rank 1: [2] shape=(1,) + Output on rank 0: [1, 2] shape=(2,) + Output on rank 1: [undefined, undefined] shape=(2,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + + # TRT add_dist_collective crashes with world_size=1; gather of a single rank + # is an identity op. + if world_size == 1: + return plug_inputs[0] + logger.debug( + f"Adding native gather: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for GATHER + # For GATHER, the reduce operation parameter is ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating gather layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.GATHER, + trt.ReduceOperation.NONE, # Ignored for GATHER + root, + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native GATHER layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"root={root}, groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native GATHER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 11 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native GATHER failed: {e} (type: {type(e).__name__})") + raise diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index 2772ba7d9f..930ad4d7b6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -68,6 +68,74 @@ def _(inp: torch.Tensor, reduce_op: str, group_name: str) -> torch.Tensor: return torch.empty_like(inp) +@torch.library.custom_op("tensorrt::fused_nccl_all_to_all", mutates_args=()) +def _fused_nccl_all_to_all_impl( + inp: torch.Tensor, + output_splits: list[int] | None, + input_splits: list[int] | None, + group_name: str, +) -> torch.Tensor: + out_shape = inp.shape + return inp.new_empty(out_shape) + + +@_fused_nccl_all_to_all_impl.register_fake +def _( + inp: torch.Tensor, + output_splits: list[int] | None, + input_splits: list[int] | None, + group_name: str, +) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.all_to_all_single.default( + inp, output_splits, input_splits, group_name + ) + ) + + +@torch.library.custom_op("tensorrt::fused_nccl_scatter", mutates_args=()) +def _fused_nccl_scatter_impl( + inp: torch.Tensor, src: int, group_name: str +) -> torch.Tensor: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + out = torch.ops._c10d_functional.broadcast.default(inp, src, group_name) + out = torch.ops._c10d_functional.wait_tensor.default(out) + + chunk = out.shape[0] // world_size + return out[rank * chunk : (rank + 1) * chunk] + + +@_fused_nccl_scatter_impl.register_fake +def _(inp: torch.Tensor, src: int, group_name: str) -> torch.Tensor: + world_size = torch.distributed.get_world_size() + out_shape = (inp.shape[0] // world_size,) + tuple(inp.shape[1:]) + return inp.new_empty(out_shape) + + +@torch.library.custom_op("tensorrt::fused_nccl_gather", mutates_args=()) +def _fused_nccl_gather_impl( + inp: torch.Tensor, src: int, group_name: str +) -> torch.Tensor: + + # Perform all_gather + world_size = torch.distributed.get_world_size() + out = _fused_nccl_all_gather_impl(inp, world_size, group_name) + + # TRT leads to undefined data after gather on non-root ranks + # so maintain that here for parity's sake + rank = torch.distributed.get_rank() + return out if rank == src else torch.empty_like(out) + + +@_fused_nccl_gather_impl.register_fake +def _(inp: torch.Tensor, src: int, group_name: str) -> torch.Tensor: + world_size = torch.distributed.get_world_size() + out_shape = (inp.shape[0] * world_size,) + tuple(inp.shape[1:]) + return inp.new_empty(out_shape) + + # Public aliases — used as FX node targets in the fuse pass, as converter keys # in custom_ops_converters.py, and in test equality checks. Each is the # torch._ops.OpOverload created by the custom_op decoration above. @@ -76,6 +144,9 @@ def _(inp: torch.Tensor, reduce_op: str, group_name: str) -> torch.Tensor: torch.ops.tensorrt.fused_nccl_reduce_scatter.default ) tensorrt_fused_nccl_all_reduce_op = torch.ops.tensorrt.fused_nccl_all_reduce.default +tensorrt_fused_nccl_all_to_all_op = torch.ops.tensorrt.fused_nccl_all_to_all.default +tensorrt_fused_nccl_scatter_op = torch.ops.tensorrt.fused_nccl_scatter.default +tensorrt_fused_nccl_gather_op = torch.ops.tensorrt.fused_nccl_gather.default def fuse_distributed_ops( @@ -89,6 +160,7 @@ def fuse_distributed_ops( torch.ops._c10d_functional.all_gather_into_tensor.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_reduce.default, + torch.ops._c10d_functional.all_to_all_single.default, ) and len(node.users) == 1 and list(node.users)[0].target @@ -111,6 +183,13 @@ def fuse_distributed_ops( target=tensorrt_fused_nccl_reduce_scatter_op, args=(node.args[0], node.args[1], node.args[2], node.args[3]), ) + elif node.target == torch.ops._c10d_functional.all_to_all_single.default: + with gm.graph.inserting_after(wait_tensor_node): + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_all_to_all_op, + args=(node.args[0], node.args[1], node.args[2], node.args[3]), + ) else: with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( diff --git a/tests/py/dynamo/distributed/test_native_nccl.py b/tests/py/dynamo/distributed/test_native_nccl.py index e4f5b18582..f0d6ed607b 100644 --- a/tests/py/dynamo/distributed/test_native_nccl.py +++ b/tests/py/dynamo/distributed/test_native_nccl.py @@ -949,6 +949,44 @@ def test_fuse_all_reduce_no_group_size_arg(self) -> None: "3-arg all_reduce node." ) + # -- all_to_all --------------------------------------------------------- + + def test_fuse_all_to_all_replaces_pair(self) -> None: + """all_to_all_single + wait_tensor → tensorrt_fused_nccl_all_to_all_op.""" + from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_to_all_op, + ) + + gm = _build_graph( + torch.ops._c10d_functional.all_to_all_single.default, + args_without_input=([], [], "test_group"), + ) + gm = self._run_pass(gm) + targets = _node_targets(gm) + self.assertNotIn(torch.ops._c10d_functional.all_to_all_single.default, targets) + self.assertNotIn(torch.ops._c10d_functional.wait_tensor.default, targets) + self.assertIn(tensorrt_fused_nccl_all_to_all_op, targets) + + def test_fuse_all_to_all_args(self) -> None: + """Fused all_to_all node carries (inp, input splits, output splits, group_name).""" + from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_to_all_op, + ) + + gm = _build_graph( + torch.ops._c10d_functional.all_to_all_single.default, + # input splits, output splits, group name + args_without_input=([1], [2], "grp"), + ) + gm = self._run_pass(gm) + fused = next( + n for n in gm.graph.nodes if n.target == tensorrt_fused_nccl_all_to_all_op + ) + # args: (inp_placeholder, [], [], "grp") + self.assertEqual(fused.args[1], [1]) + self.assertEqual(fused.args[2], [2]) + self.assertEqual(fused.args[3], "grp") + # -- no-fuse when wait_tensor has multiple users ----------------------- def test_fuse_when_wait_tensor_result_has_multiple_uses(self) -> None: @@ -1079,6 +1117,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.ops._c10d_functional.wait_tensor.default(out) +class _AllToAllModel(nn.Module): + def __init__(self, dim: int, world_size: int, group_name: str) -> None: + super().__init__() + self.fc = nn.Linear(dim, dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + split_sizes = [x.shape[0] // self.world_size] * self.world_size + out = torch.ops._c10d_functional.all_to_all_single.default( + x, split_sizes, split_sizes, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor.default(out) + + +class _ScatterModel(nn.Module): + def __init__(self, dim: int, root: int, group_name: str) -> None: + super().__init__() + self.fc = nn.Linear(dim, dim) + self.group_name = group_name + self.root = root + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + return torch.ops.tensorrt.fused_nccl_scatter(x, self.root, self.group_name) + + +class _GatherModel(nn.Module): + def __init__(self, dim: int, root: int, group_name: str) -> None: + super().__init__() + self.fc = nn.Linear(dim, dim) + self.group_name = group_name + self.root = root + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + return torch.ops.tensorrt.fused_nccl_gather(x, self.root, self.group_name) + + @unittest.skipIf( not is_nccl_available(), "Skipped: NCCL backend not available.", @@ -1104,6 +1182,7 @@ def setUpClass(cls) -> None: cls.group = dist.new_group(ranks=[0]) cls.group_name = cls.group.group_name cls.world_size = 1 + cls.root = 0 @classmethod def tearDownClass(cls) -> None: @@ -1158,6 +1237,30 @@ def test_reduce_scatter_single_rank(self) -> None: [torch.randn(1, dim)], ) + def test_all_to_all_single_rank(self) -> None: + """all_to_all compiles and produces correct output on a single rank.""" + dim = 8 + self._run( + _AllToAllModel(dim, self.world_size, self.group_name), + [torch.randn(1, dim)], + ) + + def test_scatter_single_rank(self) -> None: + """scatter compiles and produces correct output on a single rank.""" + dim = 8 + self._run( + _ScatterModel(dim, self.root, self.group_name), + [torch.randn(1, dim)], + ) + + def test_gather_single_rank(self) -> None: + """gather compiles and produces correct output on a single rank.""" + dim = 8 + self._run( + _GatherModel(dim, self.root, self.group_name), + [torch.randn(1, dim)], + ) + def test_distributed_mode_with_single_rank_subgroup(self) -> None: """distributed_context() selects the subgroup as NCCL communicator source.""" import torch_tensorrt @@ -1401,10 +1504,73 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # After gather: shape is (world_size, 4), row i == float(i) assert out.shape == torch.Size([world_size, 4]), f"Shape mismatch: {out.shape}" for r in range(world_size): - expected_row = torch.full((4,), float(r), device=device) + expected_row = torch.full((4,), float(r), device=device, dtype=out.dtype) _check_close(out[r], expected_row, f"all_gather row {r} rank={rank}") +def _multirank_scatter_correctness( + root: int, rank: int, world_size: int, device: torch.device +) -> None: + """all_gather concatenates tensors from all ranks in order.""" + group = dist.group.WORLD + group_name = group.group_name if hasattr(group, "group_name") else "" + + class Scatter(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.tensorrt.fused_nccl_scatter(x, root, group_name) + + model = Scatter().to(device).eval() + # input is of shape (world_size, 4) laid out like + # [0, 0, 0, 0] + # ... + # [world_size - 1, world_size - 1, world_size - 1, world_size - 1] + + # Supply root only with valid input + inp = ( + torch.arange(world_size, device=device).unsqueeze(1).repeat(1, 4) + if rank == root + else torch.full((world_size, 4), -1, device=device) + ) + + with torch.no_grad(): + out = model(inp) + + # After scatter: shape is (1, 4), row[0] == [rank, rank, rank, rank] + assert out.shape == torch.Size([1, 4]), f"Shape mismatch: {out.shape}" + expected_row = torch.full((1, 4), int(rank), device=device) + _check_close(out, expected_row, f"scatter row {out} rank={rank}") + + +def _multirank_gather_correctness( + root: int, rank: int, world_size: int, device: torch.device +) -> None: + """gather has the root rank receive a chunk of data from all other ranks.""" + group = dist.group.WORLD + group_name = group.group_name if hasattr(group, "group_name") else "" + + class Gather(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.tensorrt.fused_nccl_gather(x, root, group_name) + + model = Gather().to(device).eval() + # input is of shape (1, 4) where all data is rank + inp = torch.full((1, 4), int(rank), device=device) + + with torch.no_grad(): + out = model(inp) + + # After gather: shape is (world_size, 4), row[i] == [i, i, i, i], only validate data on root + assert out.shape == torch.Size([world_size, 4]), f"Shape mismatch: {out.shape}" + + if rank == root: + # output is of shape (world_size, 4) laid out like + # [0, 0, 0, 0] + # ... + # [world_size - 1, world_size - 1, world_size - 1, world_size - 1] + expected = torch.arange(world_size, device=device).unsqueeze(1).repeat(1, 4) + _check_close(out, expected, f"gather received {out} expected {expected} ") + + def _multirank_reduce_scatter_all_reduce_ops( rank: int, world_size: int, device: torch.device ) -> None: @@ -1483,6 +1649,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) +def _multirank_all_to_all_correctness( + rank: int, world_size: int, device: torch.device +) -> None: + """all_to_all sends a chunk from each rank to every other ank.""" + import torch_tensorrt + from torch_tensorrt.distributed._nccl_utils import setup_nccl_for_torch_tensorrt + + setup_nccl_for_torch_tensorrt() + group = dist.group.WORLD + group_name = group.group_name if hasattr(group, "group_name") else "" + + class AllToAll(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + split_sizes = [x.shape[0] // world_size] * world_size + out = torch.ops._c10d_functional.all_to_all_single.default( + x, split_sizes, split_sizes, group_name + ) + return torch.ops._c10d_functional.wait_tensor.default(out) + + model = AllToAll().to(device).eval() + # Input: [0...world_size) + inp = torch.arange(world_size, device=device) + + with torch.no_grad(): + out = model(inp) + + # Result for rank r: [r] * world_size + expected = torch.tensor([rank] * world_size, device=device) + _check_close(out, expected, f"all_to_all rank={rank}") + + def _multirank_distributed_mode_tp_model( rank: int, world_size: int, device: torch.device ) -> None: @@ -1864,6 +2061,29 @@ def test_reduce_scatter_all_reduce_ops(self) -> None: device = self._init_dist() _multirank_reduce_scatter_all_reduce_ops(self.rank, self.world_size, device) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_all_to_all_correctness(self) -> None: + """all_to_all sends a chunk from each rank to every other ank.""" + device = self._init_dist() + _multirank_all_to_all_correctness(self.rank, self.world_size, device) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_scatter_correctness(self) -> None: + """scatter sends a chunk of data from the root rank to every other rank.""" + device = self._init_dist() + for i in range(self.world_size): + _multirank_scatter_correctness(i, self.rank, self.world_size, device) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_gather_correctness(self) -> None: + """gather receives a chunk of data from every other rank on the root rank.""" + device = self._init_dist() + for i in range(self.world_size): + _multirank_gather_correctness(i, self.rank, self.world_size, device) + @unittest.skipIf(not has_nccl_collectives(), "No NCCL collective support available") @requires_nccl() @skip_if_lt_x_gpu(2) @@ -1919,6 +2139,9 @@ def run_multirank_tests() -> None: _multirank_all_reduce_correctness, _multirank_all_gather_correctness, _multirank_reduce_scatter_all_reduce_ops, + _multirank_all_to_all_correctness, + _multirank_scatter_correctness, + _multirank_gather_correctness, _multirank_distributed_mode_tp_model, _multirank_distributed_mode_subgroup, _multirank_cpp_runtime_bind_nccl,