Skip to content
Open
115 changes: 98 additions & 17 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,12 +1254,41 @@ def aten_binary_cross_entropy_with_logits(
raise NotImplementedError()


@torch_op("aten::bincount", trace_only=True)
def aten_bincount(
    self: IntType, weights: Optional[TensorType] = None, minlength: int = 0
) -> TensorType:
    """bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor

    Count how many times each value occurs in the integer tensor ``self``.

    ``weights`` is not supported. Negative inputs are rejected by torch and are not
    handled here (ONNX integer ops would wrap them around).

    Args:
        self: Integer tensor of bin indices. It is concatenated with a one-element
            constant along axis 0 and scattered along axis 0, so it is expected to
            be 1-D.
        weights: Must be ``None``; any other value raises.
        minlength: Lower bound for the output length; only applied when positive.

    Returns:
        An INT64 vector of length ``max(self) + 1`` (0 for an empty input), raised
        to ``minlength`` when ``minlength > 0``, holding the per-value counts.

    Raises:
        NotImplementedError: If ``weights`` is provided.
    """
    if weights is not None:
        raise NotImplementedError("aten::bincount with weights is not supported.")

    self = op.Cast(self, to=INT64.dtype)
    axis_0 = op.Constant(value_ints=[0])
    # Append a 0 so ReduceMax is defined even when ``self`` is empty. It only sizes the
    # output and never contributes to the counts (the scatter below uses ``self``).
    data_max = op.Unsqueeze(
        op.ReduceMax(op.Concat(self, op.Constant(value_ints=[0]), axis=0), keepdims=0),
        axis_0,
    )
    # An empty input yields depth 0, so the output is empty unless ``minlength`` applies.
    non_empty = op.Unsqueeze(
        op.Cast(op.Greater(op.Size(self), op.Constant(value_int=0)), to=INT64.dtype),
        axis_0,
    )
    depth = op.Mul(op.Add(data_max, op.Constant(value_ints=[1])), non_empty)
    if minlength > 0:
        depth = op.Max(depth, op.Constant(value_ints=[minlength]))

    # Scatter-add 1 for each value into a zero vector of length ``depth``. This uses
    # O(N + depth) memory instead of the dense O(N * depth) one-hot, and behaves
    # correctly for empty inputs.
    zeros = op.Expand(op.Constant(value_int=0), depth)
    ones = op.Expand(op.Constant(value_int=1), op.Shape(self))
    return op.ScatterElements(zeros, self, ones, axis=0, reduction="add")


def aten_binomial(
Expand Down Expand Up @@ -4976,8 +5005,14 @@ def is_advanced_index(index):
# will invalidate equality-based check.
first_shape = indices[advanced_indices[0]].shape

def same_shape(other_shape: Optional[ir.Shape]) -> bool:
    """Return True only when ``other_shape`` is fully known and equals ``first_shape``.

    ``first_shape`` comes from the enclosing scope (the shape of the first advanced
    index). A missing shape, or any ``None`` dimension on either side, counts as
    "not the same" because equality cannot be trusted in that case.
    """
    # Either shape being absent means nothing can be compared.
    if first_shape is None or other_shape is None:
        return False
    # An unknown dimension on either side invalidates an equality-based check.
    if any(d is None for d in first_shape) or any(d is None for d in other_shape):
        return False
    return other_shape == first_shape

all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices)
if not all_same_shape:
Expand Down Expand Up @@ -5071,24 +5106,70 @@ def same_shape(other_shape: ir.Shape) -> bool:

def _aten_index_put_bool(
    self: TReal,
    indices: Sequence[Optional[Union[INT64, BOOL]]],
    values: TReal,
    accumulate: bool = False,
) -> TReal:
    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor

    Boolean-mask variant of ``index_put``, lowered to ``NonZero`` + ``ScatterND``.

    Two forms are handled:

    * Several indices: each must be a 1-D boolean mask. The positions selected by
      each mask become one coordinate column of the ScatterND index.
    * A single index: a boolean mask whose rank may be lower than ``self``'s; it is
      expanded across the trailing dimensions of ``self`` so it selects whole slices.

    Args:
        self: Tensor to update.
        indices: Index tensors; see the two supported forms above.
        values: Updates, broadcast to the shape of the selection.
        accumulate: When true, updates are added (``reduction="add"``) instead of
            overwriting.

    Returns:
        The result of ``ScatterND`` applied to ``self``.

    Raises:
        NotImplementedError: If multiple indices contain ``None`` or anything other
            than 1-D boolean masks, or if the single index is ``None`` or not boolean.
    """
    bool_mask = indices[0]
    if len(indices) > 1:
        if any(index is None for index in indices):
            raise NotImplementedError(
                "Boolean index_put with multiple indices does not support None indices."
            )

        advanced_indices = []
        selected_positions = []
        minus_one = op.Constant(value_ints=[-1])
        for index in indices:
            if index.dtype != BOOL.dtype or len(index.shape) != 1:
                raise NotImplementedError(
                    "Boolean index_put with multiple indices supports only 1-D boolean masks."
                )
            # NonZero yields [1, K]; transpose and flatten to the K selected positions.
            positions = op.Reshape(op.Transpose(op.NonZero(index), perm=[1, 0]), minus_one)
            selected_positions.append(positions)
            advanced_indices.append(op.Unsqueeze(positions, minus_one))
        # One coordinate column per mask -> [K, len(indices)].
        # NOTE(review): the Concat presumes every mask selects the same number of
        # positions -- confirm against callers.
        onnx_index = op.Concat(*advanced_indices, axis=-1)
        # Updates have shape [K, *self.shape[len(indices):]].
        target_shape = op.Concat(
            op.Shape(selected_positions[0]),
            op.Slice(op.Shape(self), starts=[len(indices)], ends=[len(self.shape)], axes=[0]),
            axis=0,
        )
        expanded_values = op.Expand(values, target_shape)
        return op.ScatterND(
            self, onnx_index, expanded_values, reduction="add" if accumulate else None
        )

    if bool_mask is None or bool_mask.dtype != BOOL.dtype:
        raise NotImplementedError(
            "Boolean index_put expects a boolean mask as the first index."
        )

    neg_1 = op.Constant(value_ints=[-1])
    self_rank = len(self.shape)
    mask_rank = len(bool_mask.shape)

    # Expand a lower-rank mask (e.g. a row mask) across the trailing dimensions of self
    # so it selects whole slices, then collect the coordinates of every selected element.
    # NonZero returns them in row-major order.
    expanded_mask = bool_mask
    for _ in range(self_rank - mask_rank):
        expanded_mask = op.Unsqueeze(expanded_mask, neg_1)
    expanded_mask = op.Expand(expanded_mask, op.Shape(self))
    selected_indices = op.Transpose(op.NonZero(expanded_mask), perm=[1, 0])

    # Broadcast ``values`` to the selection shape ``[num_true, *self.shape[mask_rank:]]``
    # and flatten it to one update per selected element. This keeps scalar and
    # broadcastable ``values`` working, matching ``self[mask] = values`` semantics.
    num_true = op.ReduceSum(op.Cast(op.Reshape(bool_mask, neg_1), to=INT64.dtype), keepdims=1)
    trailing_shape = op.Slice(op.Shape(self), starts=[mask_rank], ends=[self_rank], axes=[0])
    selection_shape = op.Concat(num_true, trailing_shape, axis=0)
    flat_values = op.Reshape(op.Expand(values, selection_shape), neg_1)

    return op.ScatterND(
        self, selected_indices, flat_values, reduction="add" if accumulate else None
    )


def aten_index_reduce(
Expand Down
112 changes: 112 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
_testing.assert_onnx_program(onnx_program)

def test_bincount(self):
    """Export ``torch.bincount`` with ``minlength=6`` and hand the program to ``assert_onnx_program``."""

    class Model(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.bincount(x, minlength=6)

    # int64 input with a repeated value (1) and gaps (2 and 4 never occur).
    values = torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64)
    program = torch.onnx.export(Model(), (values,), dynamo=True, optimize=False)
    _testing.assert_onnx_program(program)

Comment thread
justinchuby marked this conversation as resolved.
def test_bincount_default_minlength(self):
    """Export ``torch.bincount`` without ``minlength`` and hand the program to ``assert_onnx_program``."""

    class Model(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.bincount(x)

    # Every element is the same value, so only the last bin is non-zero.
    values = torch.tensor([2, 2, 2], dtype=torch.int64)
    program = torch.onnx.export(Model(), (values,), dynamo=True, optimize=False)
    _testing.assert_onnx_program(program)

def test_bincount_empty_input(self):
    """Export ``torch.bincount`` for an empty int64 input with ``minlength=4``."""

    class Model(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.bincount(x, minlength=4)

    # Zero-element input: the output length comes from ``minlength`` alone.
    empty_values = torch.tensor([], dtype=torch.int64)
    program = torch.onnx.export(Model(), (empty_values,), dynamo=True, optimize=False)
    _testing.assert_onnx_program(program)

def test_repeat_interleave_integer_1(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -902,6 +941,79 @@ def forward(self, x, index, update):
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_bool_mask(self):
    """Export ``aten.index_put`` with a full-shape boolean mask and one update per selected element."""

    class Model(torch.nn.Module):
        def forward(self, x, mask, update):
            return torch.ops.aten.index_put(x, [mask], update)

    # The mask has the same (2, 3) shape as x and three True entries, matching
    # the three update values.
    inputs = (
        torch.zeros((2, 3), dtype=torch.float32),
        torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool),
        torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32),
    )
    program = torch.onnx.export(
        Model(),
        inputs,
        input_names=["x", "mask", "update"],
        output_names=["output"],
        opset_version=18,
        dynamo=True,
    )
    _testing.assert_onnx_program(program)

def test_index_put_bool_mask_scalar_value(self):
    """Export ``aten.index_put`` with a full-shape boolean mask and a 0-d update value."""

    class Model(torch.nn.Module):
        def forward(self, x, mask, update):
            return torch.ops.aten.index_put(x, [mask], update)

    # The update is a scalar tensor, so it must be broadcast to every selected element.
    inputs = (
        torch.arange(6, dtype=torch.float32).reshape((2, 3)),
        torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool),
        torch.tensor(5.0, dtype=torch.float32),
    )
    program = torch.onnx.export(
        Model(),
        inputs,
        input_names=["x", "mask", "update"],
        output_names=["output"],
        opset_version=18,
        dynamo=True,
    )
    _testing.assert_onnx_program(program)

def test_index_put_bool_row_mask_scalar_value(self):
    """Export ``aten.index_put`` with a 1-D row mask on a 2-D tensor and a 0-d update value."""

    class Model(torch.nn.Module):
        def forward(self, x, mask, update):
            return torch.ops.aten.index_put(x, [mask], update)

    # The mask has lower rank than x: one entry per row of the (2, 3) input.
    inputs = (
        torch.arange(6, dtype=torch.float32).reshape((2, 3)),
        torch.tensor([True, False], dtype=torch.bool),
        torch.tensor(7.0, dtype=torch.float32),
    )
    program = torch.onnx.export(
        Model(),
        inputs,
        input_names=["x", "mask", "update"],
        output_names=["output"],
        opset_version=18,
        dynamo=True,
    )
    _testing.assert_onnx_program(program)

def test_index_put_bool_multi_mask(self):
    """Export ``aten.index_put`` with two 1-D boolean masks, one per dimension of a 2-D tensor."""

    class Model(torch.nn.Module):
        def forward(self, x, mask0, mask1, update):
            return torch.ops.aten.index_put(x, [mask0, mask1], update)

    # Each mask has two True entries, matching the two update values.
    inputs = (
        torch.zeros((3, 4), dtype=torch.float32),
        torch.tensor([True, False, True], dtype=torch.bool),
        torch.tensor([True, False, True, False], dtype=torch.bool),
        torch.tensor([10.0, 20.0], dtype=torch.float32),
    )
    program = torch.onnx.export(
        Model(),
        inputs,
        input_names=["x", "mask0", "mask1", "update"],
        output_names=["output"],
        opset_version=18,
        dynamo=True,
    )
    _testing.assert_onnx_program(program)

def test_std_mean(self):
"""Test torch.std_mean which will be decomposed into prims.sum."""

Expand Down
Loading