Skip to content
157 changes: 155 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import os
import platform
import warnings
from typing import Any, Collection, List, Optional, Sequence, Union
from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union

import sympy
import torch
from torch.export import ExportedProgram
from torch.fx.node import Target
from torch.utils._sympy.numbers import int_oo
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile
Expand Down Expand Up @@ -874,6 +876,149 @@ def _insert_complex_io_adapters(
partitioned_module.recompile()


def _build_user_symbol_bounds(
    gm: torch.fx.GraphModule,
    sample_arg_inputs: Sequence[Input],
    sample_kwarg_inputs: dict[Any, Any],
) -> Dict[sympy.Symbol, Tuple[int, int]]:
    """Collect the user-declared ``(min, max)`` for each dynamic-shape symbol.

    Every dynamic ``Input`` is matched to a graph placeholder (positional
    inputs by placeholder order, keyword inputs by name). For each dimension
    of the placeholder's ``meta["val"]`` that is a ``torch.SymInt`` backed by
    a plain ``sympy.Symbol``, the corresponding entries of the ``Input``'s
    ``min_shape`` / ``max_shape`` are recorded. The first occurrence of a
    symbol wins; later occurrences are ignored. The ``ShapeEnv`` is only
    read, never modified.

    When the symbol's ``ShapeEnv`` range has a finite upper bound, the user
    range is checked against it at compile time:

    * identical range -> nothing to do;
    * ``user_max > exp_max`` -> ``ValueError``;
    * ``user_min < exp_min`` -> ``ValueError``, except ``user_min == 1`` with
      ``exp_min == 2``, which only logs a warning (PyTorch 0/1
      specialization artifact, not a user error);
    * otherwise the user range lies inside the exported range and an info
      message about narrowing the engine profile is logged.

    Args:
        gm: Graph module whose placeholder nodes are inspected.
        sample_arg_inputs: Positional sample inputs, in placeholder order.
        sample_kwarg_inputs: Keyword sample inputs keyed by placeholder name.

    Returns:
        Mapping ``{symbol: (user_min, user_max)}``; empty when no dynamic
        ``Input`` was supplied.

    Raises:
        ValueError: If a user range extends beyond a finite exported range
            (other than the tolerated 1 -> 2 lower-bound case).
    """

    def _is_dynamic_input(candidate: Any) -> bool:
        # Only ``Input`` objects declared with a dynamic shape carry bounds.
        return (
            isinstance(candidate, Input)
            and candidate.shape_mode == Input._ShapeMode.DYNAMIC
        )

    graph_inputs = [n for n in gm.graph.nodes if n.op == "placeholder"]

    # Placeholder name -> dynamic Input. ``zip`` stops at the shorter
    # sequence, so placeholders without a positional sample are skipped.
    dynamic_inputs: dict[str, Input] = {}
    for placeholder, positional in zip(graph_inputs, sample_arg_inputs):
        if _is_dynamic_input(positional):
            dynamic_inputs[placeholder.target] = positional
    for kwarg_name, kwarg_value in sample_kwarg_inputs.items():
        if _is_dynamic_input(kwarg_value):
            dynamic_inputs[kwarg_name] = kwarg_value

    user_symbol_bounds: Dict[sympy.Symbol, Tuple[int, int]] = {}
    if not dynamic_inputs:
        return user_symbol_bounds

    for placeholder in graph_inputs:
        declared = dynamic_inputs.get(placeholder.target)
        if declared is None:
            continue
        meta_val = placeholder.meta.get("val")
        if not isinstance(meta_val, torch.Tensor):
            continue

        declared_min = declared.shape["min_shape"]
        declared_max = declared.shape["max_shape"]

        for axis, extent in enumerate(meta_val.size()):
            if not isinstance(extent, torch.SymInt):
                continue
            if axis >= len(declared_min):
                continue
            expr = extent.node.expr
            # Composite exprs (e.g. ``2*s0``) are recomputed by
            # ``ShapeEnv.bound_sympy``; overriding them directly would lie.
            # Symbols already recorded keep their first-seen bounds.
            if not isinstance(expr, sympy.Symbol) or expr in user_symbol_bounds:
                continue

            user_min, user_max = int(declared_min[axis]), int(declared_max[axis])
            user_symbol_bounds[expr] = (user_min, user_max)
            logger.debug(
                "Recorded user-supplied bounds for %s: [%d, %d]",
                expr,
                user_min,
                user_max,
            )

            # The exported program may already bound this symbol to a finite
            # range (e.g. Dim("batch", min=10, max=20)). Per the author's
            # notes, the compiled TRT engine's optimization profile follows
            # that range and TensorRT rejects shapes outside it at runtime,
            # so the user's Input range is validated here, at compile time.
            # Cases handled below:
            #   * exported upper unbounded (Dim.DYNAMIC) -> the user's Input
            #     fills the gap, nothing to validate;
            #   * Input range == exported range -> no-op;
            #   * Input extends beyond the exported range -> raise now
            #     (except the tolerated 1 -> 2 lower bound, which warns);
            #   * Input inside the exported range -> honored; the engine
            #     profile narrows to the Input.
            env = getattr(extent.node, "shape_env", None)
            exported = None if env is None else env.var_to_range.get(expr)
            if exported is None:
                continue
            upper = exported.upper
            if upper is int_oo or upper == sympy.oo:
                # Dim.DYNAMIC: user fills the gap (intended use).
                continue
            try:
                exp_min = int(exported.lower)
                exp_max = int(upper)
            except (TypeError, ValueError):
                continue
            if (user_min, user_max) == (exp_min, exp_max):
                continue

            mismatch = (
                f"Dynamic dimension '{expr}': "
                f"Input range [{user_min}, {user_max}] vs "
                f"exported program range [{exp_min}, {exp_max}]."
            )

            if user_max > exp_max:
                raise ValueError(
                    f"{mismatch} Input.max_shape ({user_max}) exceeds the "
                    f"exported program's max ({exp_max}). The program was "
                    f"exported with this dimension bounded to "
                    f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
                    f"cannot accept shapes above {exp_max}. Either re-export "
                    f"with Dim('{expr}', max={user_max}) or set "
                    f"Input.max_shape <= {exp_max}."
                )

            if user_min >= exp_min:
                # Inside the exported range: engine profile narrows to the
                # user's bounds (applied in ``extract_var_range_info``). Not
                # a warning -- the user got exactly what they asked for.
                logger.info(
                    "%s Narrowing engine profile to user bounds [%d, %d] "
                    "(exported program range was [%d, %d]).",
                    mismatch,
                    user_min,
                    user_max,
                    exp_min,
                    exp_max,
                )
                continue

            # user_min < exp_min from here on.
            if not (user_min == 1 and exp_min == 2):
                raise ValueError(
                    f"{mismatch} Input.min_shape ({user_min}) is below the "
                    f"exported program's min ({exp_min}). The program was "
                    f"exported with this dimension bounded to "
                    f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
                    f"cannot accept shapes below {exp_min}. Either re-export "
                    f"with Dim('{expr}', min={user_min}) or set "
                    f"Input.min_shape >= {exp_min}."
                )

            # 1->2 is the 0/1 specialization artifact, not a user error.
            logger.warning(
                "%s Input.min_shape=1 but the exported program's min "
                "is 2 (PyTorch 0/1 specialization -- Dim(min=1) is "
                "recorded as min=2). The compiled engine's min will "
                "be 2.",
                mismatch,
            )

    return user_symbol_bounds


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
Expand Down Expand Up @@ -905,6 +1050,12 @@ def compile_module(
if sample_kwarg_inputs is None:
sample_kwarg_inputs = {}

# Forwarded to the partitioner to fill Dim.DYNAMIC upper bounds.
# Read-only w.r.t. ShapeEnv so range_constraints survive save/re-export.
user_symbol_bounds = _build_user_symbol_bounds(
gm, sample_arg_inputs, sample_kwarg_inputs
)

# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

Expand Down Expand Up @@ -1086,7 +1237,9 @@ def preserve_module_specs(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.construct_submodule_inputs(submodule)
submodule_inputs = partitioning.construct_submodule_inputs(
submodule, user_symbol_bounds=user_symbol_bounds
)

assert submodule_inputs is not None

Expand Down
38 changes: 32 additions & 6 deletions py/torch_tensorrt/dynamo/partitioning/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from typing import Any, Dict, Optional, Sequence, Set, Tuple

import sympy
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.utils import (
COMPLEX_TO_REAL_DTYPE,
Expand All @@ -20,11 +20,14 @@ def construct_dynamic_input(
input_dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
Constructs a torch_tensorrt.Input based on a symbolic input
Args:
input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`extract_var_range_info` to fill unbounded exporter uppers.
Returns:
A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
"""
Expand All @@ -33,7 +36,9 @@ def construct_dynamic_input(
max_shape = []
for d, dim in enumerate(input_shape):
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
min_max_opt = extract_var_range_info(
dim, user_symbol_bounds=user_symbol_bounds
)
unwrapped_min_max_opt: Dict[str, int] = {}
if "min" not in min_max_opt or min_max_opt["min"] is None:
logger.warning(
Expand Down Expand Up @@ -85,9 +90,12 @@ def get_input(
dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs.

``user_symbol_bounds`` is forwarded to :func:`construct_dynamic_input`.
"""
if dtype in COMPLEX_TO_REAL_DTYPE:
real_dtype = COMPLEX_TO_REAL_DTYPE[dtype]
Expand All @@ -106,19 +114,25 @@ def get_input(
dtype,
name=name,
is_shape_tensor=is_shape_tensor,
user_symbol_bounds=user_symbol_bounds,
)
else:
return Input(
shape=input_shape, dtype=dtype, name=name, is_shape_tensor=is_shape_tensor
)


def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
def construct_submodule_inputs(
module: torch.fx.GraphModule,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Sequence[Input]:
"""
Construct torch_tensorrt Inputs based on the module inputs.
The module inputs will have meta data which has the shape and dtype info
Args:
module: Input FX GraphModule
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`get_input` to fill unbounded exporter uppers.
Returns:
Sequence of torch_tensorrt.Input's representing inputs to given module
"""
Expand All @@ -134,7 +148,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
if isinstance(input_meta, (FakeTensor, torch.Tensor)):
input_shape = input_meta.size()
torchtrt_inputs.append(
get_input(input_shape, input_meta.dtype, name=input.name)
get_input(
input_shape,
input_meta.dtype,
name=input.name,
user_symbol_bounds=user_symbol_bounds,
)
)
elif isinstance(input_meta, torch.SymInt):
# Assuming sym_integers | shape inputs always have torch.int64 dtype
Expand All @@ -144,6 +163,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.int64,
name=input.name,
is_shape_tensor=True,
user_symbol_bounds=user_symbol_bounds,
)
)
elif isinstance(input_meta, torch.SymFloat):
Expand All @@ -153,6 +173,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.float32,
name=input.name,
is_shape_tensor=False, # Only SymInt inputs are treated as shape tensors
user_symbol_bounds=user_symbol_bounds,
)
)
else:
Expand All @@ -164,7 +185,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
input_meta = input.meta["tensor_meta"]
input_shape = input_meta.shape
torchtrt_inputs.append(
get_input(input_shape, input_meta.dtype, name=input.name)
get_input(
input_shape,
input_meta.dtype,
name=input.name,
user_symbol_bounds=user_symbol_bounds,
)
)
else:
raise AssertionError(
Expand Down
50 changes: 43 additions & 7 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,16 @@ def contains_sym_int(tensor: torch.Tensor) -> bool:
return any(isinstance(dim, torch.SymInt) for dim in tensor)


def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Optional[int]]:
"""
This function returns the min, max, opt values of a symbolic integer.
def extract_var_range_info(
symbolic_integer: torch.SymInt,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Dict[str, Optional[int]]:
"""Return ``{min, max, opt}`` for a symbolic integer.

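``user_symbol_bounds`` (read-only ``{sym: (min, max)}``) is applied whenever
the expression is a plain symbol present in the map: it supplies the upper
when the exporter's is unbounded, and otherwise the tighter of the two wins
(``min`` of the uppers), so it can narrow but never widen a finite exporter
bound. The lower is the ``max`` of the user's and the exporter's (after the
2 -> 1 adjustment), so the 0/1 specialization survives even if the user
passes ``min_shape=0``.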
"""
node = symbolic_integer.node
expr = node.expr
Expand All @@ -435,13 +442,42 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Optional
or expr.xreplace(var_to_val_map)
)
assert var_range, var_val
min_val, max_val = (
int(var_range.lower),
int(var_range.upper) if var_range.upper != int_oo else None,
)

# ``var_to_range`` returns ``int_oo`` for unbounded; ``bound_sympy`` (used
# for composite exprs like ``s0+s1``) returns ``sympy.oo`` instead. They
# are distinct objects -- check both, else ``int(sympy.oo)`` raises.
def _bound_to_int_or_none(value: Any) -> Optional[int]:
if value is int_oo or value is -int_oo:
return None
if value == sympy.oo or value == -sympy.oo:
return None
try:
return int(value)
except (TypeError, OverflowError, AttributeError):
return None

min_val_opt = _bound_to_int_or_none(var_range.lower)
max_val = _bound_to_int_or_none(var_range.upper)
# Unbounded lower shouldn't happen for tensor dims; fall back to 1.
min_val = min_val_opt if min_val_opt is not None else 1

# Torchdynamo 0/1 specialization outlier
min_val = 1 if min_val == 2 else min_val

# Apply user bounds whenever present. ``_build_user_symbol_bounds`` already
# rejects user ranges that exceed the exporter, so the only cases reaching
# here are: Dim.DYNAMIC (max_val is None), strict subset, or the 1->2
# specialization. Clamp defensively in case validation was skipped (no
# ShapeEnv access path).
if (
user_symbol_bounds
and isinstance(expr, sympy.Symbol)
and expr in user_symbol_bounds
):
user_min, user_max = user_symbol_bounds[expr]
min_val = max(min_val, int(user_min))
max_val = int(user_max) if max_val is None else min(max_val, int(user_max))

min_max_opt: Dict[str, Optional[int]] = {}
min_max_opt["min"] = min_val
min_max_opt["max"] = max_val
Expand Down
Loading
Loading