import string

# One einsum label per logical dimension: "a".."z" followed by "A".."Z".
_EINSUM_SYMBOLS = string.ascii_letters


def _get_einsum_symbol(dim: int) -> str:
    """Return the einsum label assigned to logical dimension ``dim``.

    Raises:
        ValueError: If ``dim`` is negative or there is no label left for it.
    """
    if dim < 0:
        # A negative index would silently wrap around to the end of the table.
        raise ValueError("aten::_trilinear dimension indices must be non-negative")
    if dim >= len(_EINSUM_SYMBOLS):
        raise ValueError(
            f"aten::_trilinear only supports up to {len(_EINSUM_SYMBOLS)} dimensions"
        )
    return _EINSUM_SYMBOLS[dim]


def _validate_trilinear_dims(total_dim: int, dims: Sequence[int], dims_name: str) -> None:
    """Check that ``dims`` holds unique indices inside ``[0, total_dim)``.

    Args:
        total_dim: Number of logical dimensions of the trilinear product.
        dims: Dimension indices to validate.
        dims_name: Name of the argument being validated; used in error messages.

    Raises:
        ValueError: If an index is out of range or appears more than once.
    """
    seen_dims = set()
    for dim in dims:
        if dim < 0 or dim >= total_dim:
            raise ValueError(
                f"aten::_trilinear {dims_name} values must be in [0, {total_dim})"
            )
        if dim in seen_dims:
            raise ValueError(f"aten::_trilinear {dims_name} values must be unique")
        seen_dims.add(dim)


def _build_trilinear_subscript(
    total_dim: int, expanded_dims: Sequence[int], dims_name: str
) -> str:
    """Build the einsum subscript of one operand.

    The subscript contains the label of every logical dimension that is NOT
    listed in ``expanded_dims``, in increasing dimension order.

    Raises:
        ValueError: If ``expanded_dims`` fails validation.
    """
    _validate_trilinear_dims(total_dim, expanded_dims, dims_name)
    expanded_dims_set = set(expanded_dims)
    return "".join(
        _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in expanded_dims_set
    )


def _build_trilinear_equation(
    total_dim: int,
    expand1: Sequence[int],
    expand2: Sequence[int],
    expand3: Sequence[int],
    sumdim: Sequence[int],
) -> str:
    """Build the three-operand einsum equation ``"s1,s2,s3->out"``.

    Each operand subscript omits the labels listed in its expand list; the
    output subscript omits the labels listed in ``sumdim``.

    Raises:
        ValueError: If any dimension list fails validation, or if an output
            dimension is absent from all three operands.
    """
    _validate_trilinear_dims(total_dim, sumdim, "sumdim")
    sumdim_set = set(sumdim)
    output_subscript = "".join(
        _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in sumdim_set
    )
    operand_subscripts = [
        _build_trilinear_subscript(total_dim, expanded_dims, dims_name)
        for expanded_dims, dims_name in (
            (expand1, "expand1"),
            (expand2, "expand2"),
            (expand3, "expand3"),
        )
    ]

    # An output label that no operand carries makes the equation invalid, so
    # reject it here instead of emitting a graph that fails later.
    available_symbols = set().union(*operand_subscripts)
    missing_dims = [
        dim
        for dim in range(total_dim)
        if dim not in sumdim_set and _get_einsum_symbol(dim) not in available_symbols
    ]
    if missing_dims:
        raise ValueError(
            "aten::_trilinear output dimensions must come from at least one input; "
            f"dimensions {missing_dims} are expanded in every input and not in sumdim"
        )

    return f"{','.join(operand_subscripts)}->{output_subscript}"


def _trilinear_input_rank(input_value: TensorType) -> int:
    """Return the rank of ``input_value``.

    Uses the ``rank`` attribute when it exists and is not None, otherwise the
    length of ``shape``.

    Raises:
        ValueError: If neither ``rank`` nor ``shape`` provides a rank.
    """
    input_rank = getattr(input_value, "rank", None)
    if input_rank is not None:
        return input_rank
    input_shape = input_value.shape
    if input_shape is None:
        # len(None) would raise an opaque TypeError; report the real problem.
        raise ValueError("aten::_trilinear requires inputs with a known rank")
    return len(input_shape)


def _trilinear_operand_total_dim(
    input_value: TensorType, expanded_dims: Sequence[int]
) -> int:
    """Return the operand's rank plus the number of dimensions it expands."""
    return _trilinear_input_rank(input_value) + len(expanded_dims)


def _resolve_trilinear_total_dim(
    i1: TensorType,
    i2: TensorType,
    i3: TensorType,
    expand1: Sequence[int],
    expand2: Sequence[int],
    expand3: Sequence[int],
) -> int:
    """Return the total dimension count shared by all three operands.

    The count is taken from ``i1``/``expand1``; ``i2`` and ``i3`` must resolve
    to the same value.

    Raises:
        ValueError: If ``i2`` or ``i3`` resolves to a different total.
    """
    total_dim = _trilinear_operand_total_dim(i1, expand1)
    candidate_total_dims = (
        ("i2", "expand2", _trilinear_operand_total_dim(i2, expand2)),
        ("i3", "expand3", _trilinear_operand_total_dim(i3, expand3)),
    )
    for input_name, expand_name, candidate_total_dim in candidate_total_dims:
        if candidate_total_dim != total_dim:
            raise ValueError(
                "aten::_trilinear input ranks and expand dims must resolve "
                "to the same total dimension; "
                f"i1+expand1 resolved {total_dim}, but "
                f"{input_name}+{expand_name} resolved {candidate_total_dim}"
            )
    return total_dim


@torch_op("aten::_trilinear", trace_only=True)
def aten__trilinear(
    i1: TReal,
    i2: TReal,
    i3: TReal,
    expand1: Sequence[int],
    expand2: Sequence[int],
    expand3: Sequence[int],
    sumdim: Sequence[int],
    unroll_dim: int = 1,
) -> TReal:
    """_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor"""

    # Accepted for signature compatibility; the einsum equation built below
    # does not depend on it.
    del unroll_dim

    total_dim = _resolve_trilinear_total_dim(i1, i2, i3, expand1, expand2, expand3)
    equation = _build_trilinear_equation(total_dim, expand1, expand2, expand3, sumdim)
    return op.Einsum(i1, i2, i3, equation=equation)
def sample_inputs__trilinear(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for aten._trilinear using bilinear's internal call pattern.

    Every sample passes ``input1`` of shape ``(batch, in1)`` as the input and
    ``weight`` of shape ``(out, in1, in2)`` followed by ``input2`` of shape
    ``(batch, in2)`` as the leading positional args, then the fixed
    expand/sumdim/unroll_dim arguments.
    """
    del op_info, kwargs  # Unused

    def make_arg(shape):
        return torch_testing.make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    # expand1, expand2, expand3, sumdim, unroll_dim -- identical for every sample.
    fixed_args = ((1, 3), (0,), (1, 2), (2, 3), 1)

    # Each entry is (batch_size, in1_features, in2_features, out_features).
    feature_sizes = (
        (2, 3, 4, 5),
        (1, 2, 2, 1),
        (3, 5, 2, 4),
    )

    for batch, in1, in2, out in feature_sizes:
        # Tensors are created in this order so the random stream is consumed
        # the same way for every run.
        input1 = make_arg((batch, in1))
        weight = make_arg((out, in1, in2))
        input2 = make_arg((batch, in2))
        yield opinfo_core.SampleInput(input1, args=(weight, input2, *fixed_args))
class _FakeTensor:
    """Minimal tensor stand-in that exposes only a ``shape`` attribute."""

    __slots__ = ("shape",)

    def __init__(self, shape: Sequence[int]):
        self.shape = shape


class TestTrilinearHelpers(unittest.TestCase):
    """Unit tests for the shape-resolution helper behind ``aten::_trilinear``."""

    def test_resolve_trilinear_total_dim_returns_shared_total(self):
        # Every operand resolves to rank + number of expanded dims == 4.
        total_dim = core_ops._resolve_trilinear_total_dim(
            _FakeTensor((2, 3)),
            _FakeTensor((4, 5, 6)),
            _FakeTensor((7, 8)),
            (1, 3),
            (0,),
            (1, 2),
        )
        self.assertEqual(total_dim, 4)

    def test_resolve_trilinear_total_dim_validates_operand_dims(self):
        i1 = _FakeTensor((2, 3))
        i2 = _FakeTensor((4, 5, 6))
        i3 = _FakeTensor((7,))

        # i1 resolves to 2 + 2 == 4 while i2 resolves to 3 + 0 == 3.
        with self.assertRaisesRegex(ValueError, r"i2\+expand2 resolved 3"):
            core_ops._resolve_trilinear_total_dim(
                i1,
                i2,
                i3,
                (1, 3),
                (),
                (1, 2),
            )

    def test_resolve_trilinear_total_dim_validates_third_operand(self):
        # i1 and i2 both resolve to 4; i3 resolves to 1 + 2 == 3.
        with self.assertRaisesRegex(ValueError, r"i3\+expand3 resolved 3"):
            core_ops._resolve_trilinear_total_dim(
                _FakeTensor((2, 3)),
                _FakeTensor((4, 5, 6)),
                _FakeTensor((7,)),
                (1, 3),
                (0,),
                (1, 2),
            )