Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 58 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_all_reduce_op,
tensorrt_fused_nccl_all_to_all_op,
tensorrt_fused_nccl_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
tensorrt_fused_nccl_scatter_op,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand All @@ -27,15 +30,15 @@
@dynamo_tensorrt_converter(
tensorrt_fused_nccl_all_gather_op, requires_native_multidevice=True
)
def fused_nccl_gather(
def fused_nccl_all_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
"""All-gather using native TensorRT DistCollective API"""
return impl.nccl_ops.nccl_gather_native(
return impl.nccl_ops.nccl_all_gather_native(
ctx,
target,
SourceIR.ATEN,
Expand Down Expand Up @@ -86,6 +89,57 @@ def fused_nccl_all_reduce(
reduce_op=reduce_op,
)

@dynamo_tensorrt_converter(
    tensorrt_fused_nccl_all_to_all_op, requires_native_multidevice=True
)
def fused_nccl_all_to_all(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    """Convert the fused NCCL all-to-all op via the native TensorRT DistCollective path.

    Only the first positional argument (the tensor to exchange) is forwarded;
    ``kwargs`` is not consulted. The layer is built by
    ``impl.nccl_ops.nccl_all_to_all_native`` with ``SourceIR.ATEN`` as source IR.
    """
    # The implementation expects its inputs wrapped in a sequence.
    collective_inputs = [args[0]]
    return impl.nccl_ops.nccl_all_to_all_native(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        collective_inputs,
    )

@dynamo_tensorrt_converter(
    tensorrt_fused_nccl_scatter_op, requires_native_multidevice=True
)
def fused_nccl_scatter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    """Convert the fused NCCL scatter op via the native TensorRT DistCollective path.

    ``args[0]`` is the tensor handed to the collective. ``args[1]``, when
    present, is used as the root rank; otherwise the root defaults to 0.
    ``kwargs`` is not consulted.
    """
    # Root rank is an optional second positional argument.
    if len(args) > 1:
        root = args[1]
    else:
        root = 0

    collective_inputs = [args[0]]
    return impl.nccl_ops.nccl_scatter_native(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        collective_inputs,
        root=root,
    )

@dynamo_tensorrt_converter(
    tensorrt_fused_nccl_gather_op, requires_native_multidevice=True
)
def fused_nccl_gather(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    """Convert the fused NCCL gather op via the native TensorRT DistCollective path.

    ``args[0]`` is the tensor handed to the collective. ``args[1]``, when
    present, is used as the root rank; otherwise the root defaults to 0.
    ``kwargs`` is not consulted.
    """
    # Root rank is an optional second positional argument.
    if len(args) > 1:
        root = args[1]
    else:
        root = 0

    collective_inputs = [args[0]]
    return impl.nccl_ops.nccl_gather_native(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        collective_inputs,
        root=root,
    )


# Conditionally register NCCL converters only if TensorRT-LLM plugin is available.
# We use an `if` statement instead of @needs_trtllm_for_nccl decorator because
Expand All @@ -101,14 +155,14 @@ def fused_nccl_all_reduce(
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
def fused_nccl_all_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_gather(
return impl.nccl_ops.nccl_all_gather(
ctx,
target,
SourceIR.ATEN,
Expand Down
256 changes: 254 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_distributed_rank_and_world_size() -> Tuple[int, int]:
return rank, world_size


def nccl_gather(
def nccl_all_gather(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
Expand Down Expand Up @@ -220,7 +220,7 @@ def nccl_reduce_scatter(
return layer.get_output(0)


def nccl_gather_native(
def nccl_all_gather_native(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
Expand Down Expand Up @@ -487,3 +487,255 @@ def nccl_all_reduce_native(
except Exception as e:
logger.error(f"Native ALL_REDUCE failed: {e} (type: {type(e).__name__})")
raise


def nccl_all_to_all_native(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    plug_inputs: Tuple[Argument, ...],
) -> trt.ITensor:
    """
    Implement all_to_all using native TensorRT DistCollective API.

    Each rank contributes one chunk per participating rank. Rank ``i`` receives
    the chunk at position ``i`` of every rank's input.

    Args:
        ctx: Conversion context; the collective layer is added to ``ctx.net``.
        target: Target of the node being converted (used for layer naming).
        source_ir: Source IR of the node (used for layer naming).
        name: Name of the node (used for layer naming and log messages).
        plug_inputs: Sequence whose first element is the input tensor. Any
            further elements are ignored.

    Returns:
        Output tensor after all_to_all operation. When the world size is 1 the
        input tensor is returned unchanged.

    Raises:
        RuntimeError: If building the layer raises ``AttributeError`` (reported
            as missing native-collective support), or if the layer could not
            be created.
        Exception: Any other failure is logged and re-raised unchanged.

    Example:
        Input on rank 0: [1, 2]  shape=(2,)
        Input on rank 1: [3, 4]  shape=(2,)
        Output on rank 0: [1, 3]  shape=(2,)
        Output on rank 1: [2, 4]  shape=(2,)
    """
    rank, world_size = _get_distributed_rank_and_world_size()

    # TRT add_dist_collective crashes with world_size=1; all_to_all of a single rank
    # is an identity op.
    if world_size == 1:
        return plug_inputs[0]
    logger.debug(
        f"Adding native all_to_all: name={name}, rank={rank}, world_size={world_size}"
    )

    # Get the input tensor
    input_tensor = plug_inputs[0]

    try:
        # Use native TensorRT DistCollective API for ALL_TO_ALL.
        # For ALL_TO_ALL, the reduce operation and root rank parameters are ignored.
        import numpy as np

        # Explicit array of all participating rank IDs [0, 1, 2, ..., world_size-1]
        groups = np.arange(world_size, dtype=np.int64)

        logger.debug(
            f"Creating ALL_TO_ALL layer: groups={groups.tolist()}, groupSize={world_size}"
        )
        # Log the call arguments *before* making the call so the debug log
        # reflects the real order of events.
        logger.debug(
            f"Calling add_dist_collective: input_shape={input_tensor.shape}, "
            f"groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)"
        )
        layer = ctx.net.add_dist_collective(
            input_tensor,
            trt.CollectiveOperation.ALL_TO_ALL,
            trt.ReduceOperation.NONE,  # Ignored for ALL_TO_ALL
            -1,  # Root rank - ignored for ALL_TO_ALL
            groups,  # All world_size ranks participate
        )

        # NOTE(review): TensorRT network add_* calls typically return None on
        # failure - confirm this also holds for add_dist_collective. Without
        # this check a None layer would surface below as an AttributeError and
        # be misreported as missing native-collective support.
        if layer is None:
            raise RuntimeError(
                f"add_dist_collective returned None for ALL_TO_ALL layer: {name}"
            )

        logger.debug(f"Successfully created native ALL_TO_ALL layer: {name}")

        set_layer_name(layer, target, name, source_ir)

        output = layer.get_output(0)
        layer.num_ranks = world_size

        return output

    except AttributeError as e:
        error_msg = (
            f"Native ALL_TO_ALL failed: {e}. "
            "This usually means TensorRT doesn't support native distributed collectives. "
            f"Your TensorRT version: {trt.__version__}. "
            "Native collectives require TensorRT 11 or later. "
            "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e

    except Exception as e:
        logger.error(f"Native ALL_TO_ALL failed: {e} (type: {type(e).__name__})")
        raise


def nccl_scatter_native(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    plug_inputs: Tuple[Argument, ...],
    root: int = 0,
) -> trt.ITensor:
    """
    Implement scatter using native TensorRT DistCollective API.

    The root rank splits its data into chunks and sends one chunk to each
    participating rank (including itself).

    Args:
        ctx: Conversion context; the collective layer is added to ``ctx.net``.
        target: Target of the node being converted (used for layer naming).
        source_ir: Source IR of the node (used for layer naming).
        name: Name of the node (used for layer naming and log messages).
        plug_inputs: Sequence whose first element is the input tensor. Any
            further elements are ignored.
        root: Rank whose data is scattered. Defaults to 0.

    Returns:
        Output tensor after scatter operation. When the world size is 1 the
        input tensor is returned unchanged.

    Raises:
        RuntimeError: If building the layer raises ``AttributeError`` (reported
            as missing native-collective support), or if the layer could not
            be created.
        Exception: Any other failure is logged and re-raised unchanged.

    Example (root=0):
        Input on rank 0: [1, 2]  shape=(2,)
        Input on rank 1: None    shape=(2,)
        Output on rank 0: [1]  shape=(1,)
        Output on rank 1: [2]  shape=(1,)
    """
    rank, world_size = _get_distributed_rank_and_world_size()

    # TRT add_dist_collective crashes with world_size=1; scatter of a single rank
    # is an identity op.
    if world_size == 1:
        return plug_inputs[0]
    logger.debug(
        f"Adding native scatter: name={name}, rank={rank}, world_size={world_size}"
    )

    # Get the input tensor
    input_tensor = plug_inputs[0]

    try:
        # Use native TensorRT DistCollective API for SCATTER.
        # For SCATTER, the reduce operation parameter is ignored.
        import numpy as np

        # Explicit array of all participating rank IDs [0, 1, 2, ..., world_size-1]
        groups = np.arange(world_size, dtype=np.int64)

        logger.debug(
            f"Creating scatter layer: groups={groups.tolist()}, groupSize={world_size}"
        )
        # Log the call arguments *before* making the call so the debug log
        # reflects the real order of events.
        logger.debug(
            f"Calling add_dist_collective: input_shape={input_tensor.shape}, "
            f"root={root}, groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)"
        )
        layer = ctx.net.add_dist_collective(
            input_tensor,
            trt.CollectiveOperation.SCATTER,
            trt.ReduceOperation.NONE,  # Ignored for SCATTER
            root,
            groups,  # All world_size ranks participate
        )

        # NOTE(review): TensorRT network add_* calls typically return None on
        # failure - confirm this also holds for add_dist_collective. Without
        # this check a None layer would surface below as an AttributeError and
        # be misreported as missing native-collective support.
        if layer is None:
            raise RuntimeError(
                f"add_dist_collective returned None for SCATTER layer: {name}"
            )

        logger.debug(f"Successfully created native SCATTER layer: {name}")

        set_layer_name(layer, target, name, source_ir)

        output = layer.get_output(0)
        layer.num_ranks = world_size

        return output

    except AttributeError as e:
        error_msg = (
            f"Native SCATTER failed: {e}. "
            "This usually means TensorRT doesn't support native distributed collectives. "
            f"Your TensorRT version: {trt.__version__}. "
            "Native collectives require TensorRT 11 or later. "
            "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e

    except Exception as e:
        logger.error(f"Native SCATTER failed: {e} (type: {type(e).__name__})")
        raise


def nccl_gather_native(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    plug_inputs: Tuple[Argument, ...],
    root: int = 0,
) -> trt.ITensor:
    """
    Implement gather using native TensorRT DistCollective API.

    The root rank receives one chunk of data from every participating rank
    (including itself). The output contents on non-root ranks are undefined.

    Args:
        ctx: Conversion context; the collective layer is added to ``ctx.net``.
        target: Target of the node being converted (used for layer naming).
        source_ir: Source IR of the node (used for layer naming).
        name: Name of the node (used for layer naming and log messages).
        plug_inputs: Sequence whose first element is the input tensor. Any
            further elements are ignored.
        root: Rank that receives the gathered data. Defaults to 0.

    Returns:
        Output tensor after gather operation. When the world size is 1 the
        input tensor is returned unchanged.

    Raises:
        RuntimeError: If building the layer raises ``AttributeError`` (reported
            as missing native-collective support), or if the layer could not
            be created.
        Exception: Any other failure is logged and re-raised unchanged.

    Example (root=0):
        Input on rank 0: [1]  shape=(1,)
        Input on rank 1: [2]  shape=(1,)
        Output on rank 0: [1, 2]  shape=(2,)
        Output on rank 1: [undefined, undefined]  shape=(2,)
    """
    rank, world_size = _get_distributed_rank_and_world_size()

    # TRT add_dist_collective crashes with world_size=1; gather of a single rank
    # is an identity op.
    if world_size == 1:
        return plug_inputs[0]
    logger.debug(
        f"Adding native gather: name={name}, rank={rank}, world_size={world_size}"
    )

    # Get the input tensor
    input_tensor = plug_inputs[0]

    try:
        # Use native TensorRT DistCollective API for GATHER.
        # For GATHER, the reduce operation parameter is ignored.
        import numpy as np

        # Explicit array of all participating rank IDs [0, 1, 2, ..., world_size-1]
        groups = np.arange(world_size, dtype=np.int64)

        logger.debug(
            f"Creating gather layer: groups={groups.tolist()}, groupSize={world_size}"
        )
        # Log the call arguments *before* making the call so the debug log
        # reflects the real order of events.
        logger.debug(
            f"Calling add_dist_collective: input_shape={input_tensor.shape}, "
            f"root={root}, groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)"
        )
        layer = ctx.net.add_dist_collective(
            input_tensor,
            trt.CollectiveOperation.GATHER,
            trt.ReduceOperation.NONE,  # Ignored for GATHER
            root,
            groups,  # All world_size ranks participate
        )

        # NOTE(review): TensorRT network add_* calls typically return None on
        # failure - confirm this also holds for add_dist_collective. Without
        # this check a None layer would surface below as an AttributeError and
        # be misreported as missing native-collective support.
        if layer is None:
            raise RuntimeError(
                f"add_dist_collective returned None for GATHER layer: {name}"
            )

        logger.debug(f"Successfully created native GATHER layer: {name}")

        set_layer_name(layer, target, name, source_ir)

        output = layer.get_output(0)
        layer.num_ranks = world_size

        return output

    except AttributeError as e:
        error_msg = (
            f"Native GATHER failed: {e}. "
            "This usually means TensorRT doesn't support native distributed collectives. "
            f"Your TensorRT version: {trt.__version__}. "
            "Native collectives require TensorRT 11 or later. "
            "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e

    except Exception as e:
        logger.error(f"Native GATHER failed: {e} (type: {type(e).__name__})")
        raise
Loading
Loading