Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import math
import string
from typing import Any, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -56,6 +57,7 @@
_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
_MATH_PI = math.pi
_EINSUM_SYMBOLS = string.ascii_letters


@torch_op("aten::_local_scalar_dense", trace_only=True)
Expand Down Expand Up @@ -9791,6 +9793,117 @@ def aten_tril_indices(row: int, col: int, offset: int = 0) -> TensorType:
raise NotImplementedError()


def _get_einsum_symbol(dim: int) -> str:
if dim >= len(_EINSUM_SYMBOLS):
raise ValueError("aten::_trilinear only supports up to 52 dimensions")
return _EINSUM_SYMBOLS[dim]


def _validate_trilinear_dims(
total_dim: int, dims: Sequence[int], dims_name: str
) -> None:
seen_dims = set()
for dim in dims:
if dim < 0 or dim >= total_dim:
raise ValueError(
f"aten::_trilinear {dims_name} values must be in [0, {total_dim})"
)
if dim in seen_dims:
raise ValueError(
f"aten::_trilinear {dims_name} values must be unique"
)
seen_dims.add(dim)


def _build_trilinear_subscript(
total_dim: int, expanded_dims: Sequence[int], dims_name: str
) -> str:
_validate_trilinear_dims(total_dim, expanded_dims, dims_name)
expanded_dims_set = set(expanded_dims)
return "".join(
_get_einsum_symbol(dim) for dim in range(total_dim) if dim not in expanded_dims_set
)


def _build_trilinear_equation(
total_dim: int,
expand1: Sequence[int],
expand2: Sequence[int],
expand3: Sequence[int],
sumdim: Sequence[int],
) -> str:
_validate_trilinear_dims(total_dim, sumdim, "sumdim")
sumdim_set = set(sumdim)
output_subscript = "".join(
_get_einsum_symbol(dim) for dim in range(total_dim) if dim not in sumdim_set
)
return (
f"{_build_trilinear_subscript(total_dim, expand1, 'expand1')},"
f"{_build_trilinear_subscript(total_dim, expand2, 'expand2')},"
f"{_build_trilinear_subscript(total_dim, expand3, 'expand3')}"
f"->{output_subscript}"
)


def _trilinear_input_rank(input_value: TensorType) -> int:
input_rank = getattr(input_value, "rank", None)
if input_rank is not None:
return input_rank
return len(input_value.shape)


def _trilinear_operand_total_dim(
input_value: TensorType, expanded_dims: Sequence[int]
) -> int:
return _trilinear_input_rank(input_value) + len(expanded_dims)


def _resolve_trilinear_total_dim(
i1: TensorType,
i2: TensorType,
i3: TensorType,
expand1: Sequence[int],
expand2: Sequence[int],
expand3: Sequence[int],
) -> int:
total_dim = _trilinear_operand_total_dim(i1, expand1)
candidate_total_dims = (
("i2", "expand2", _trilinear_operand_total_dim(i2, expand2)),
("i3", "expand3", _trilinear_operand_total_dim(i3, expand3)),
)
for input_name, expand_name, candidate_total_dim in candidate_total_dims:
if candidate_total_dim != total_dim:
raise ValueError(
"aten::_trilinear input ranks and expand dims must resolve "
"to the same total dimension; "
f"i1+expand1 resolved {total_dim}, but "
f"{input_name}+{expand_name} resolved {candidate_total_dim}"
)
return total_dim


@torch_op("aten::_trilinear", trace_only=True)
def aten__trilinear(
i1: TReal,
i2: TReal,
i3: TReal,
expand1: Sequence[int],
expand2: Sequence[int],
expand3: Sequence[int],
sumdim: Sequence[int],
unroll_dim: int = 1,
) -> TReal:
"""_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor"""

del unroll_dim

total_dim = _resolve_trilinear_total_dim(
i1, i2, i3, expand1, expand2, expand3
)
equation = _build_trilinear_equation(total_dim, expand1, expand2, expand3, sumdim)
return op.Einsum(i1, i2, i3, equation=equation)


def aten_triplet_margin_loss(
anchor: TensorType,
positive: TensorType,
Expand Down
36 changes: 36 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,35 @@ def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs):
yield opinfo_core.SampleInput(input1, args=(input2, weight, None))


def sample_inputs__trilinear(op_info, device, dtype, requires_grad, **kwargs):
"""Sample inputs for aten._trilinear using bilinear's internal call pattern."""
del op_info
del kwargs

make_arg = functools.partial(
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)

cases = [
(2, 3, 4, 5),
(1, 2, 2, 1),
(3, 5, 2, 4),
]
expand1 = (1, 3)
expand2 = (0,)
expand3 = (1, 2)
sumdim = (2, 3)
Comment thread
WineChord marked this conversation as resolved.

for batch_size, in1_features, in2_features, out_features in cases:
input1 = make_arg((batch_size, in1_features))
weight = make_arg((out_features, in1_features, in2_features))
input2 = make_arg((batch_size, in2_features))
yield opinfo_core.SampleInput(
input1,
args=(weight, input2, expand1, expand2, expand3, sumdim, 1),
)


def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
del op_info

Expand Down Expand Up @@ -2516,6 +2545,13 @@ def __init__(self):
sample_inputs_func=sample_inputs_bilinear,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten._trilinear.default",
aten_name="_trilinear.default",
dtypes=common_dtype.floating_types(),
sample_inputs_func=sample_inputs__trilinear,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.bernoulli.p",
aten_name="bernoulli.p",
Expand Down
28 changes: 28 additions & 0 deletions tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import onnxscript
from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib.ops import core as core_ops
from tests.function_libs.torch_lib import (
error_reproduction,
ops_test_common,
Expand Down Expand Up @@ -110,6 +111,33 @@ def test_script_function_passes_checker(
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]


class _FakeTensor:
__slots__ = ("shape",)

def __init__(self, shape: Sequence[int]):
self.shape = shape
Comment on lines +114 to +118
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may use ir.value() to create these objects.



class TestTrilinearHelpers(unittest.TestCase):
def test_resolve_trilinear_total_dim_validates_operand_dims(self):
i1 = _FakeTensor((2, 3))
i2 = _FakeTensor((4, 5, 6))
i3 = _FakeTensor((7,))

with self.assertRaisesRegex(
ValueError,
"i2\\+expand2 resolved 3",
):
core_ops._resolve_trilinear_total_dim(
i1,
i2,
i3,
(1, 3),
(),
(1, 2),
)


def run_test_output_match(
test_suite: unittest.TestCase,
device: str,
Expand Down
5 changes: 5 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,11 @@ def _where_input_wrangler(
dtypes=(torch.int32,),
reason="fixme: ORT does not have an implementation of Trilu for int32.",
),
TorchLibOpInfo(
"ops.aten._trilinear.default",
core_ops.aten__trilinear,
tolerance={torch.float32: (2e-5, 2e-5)},
),
TorchLibOpInfo("triu", core_ops.aten_triu).xfail(
dtypes=(torch.int32,),
reason="fixme: ORT does not have an implementation of Trilu for int32.",
Expand Down