[torchlib] Reimplement as_strided without an ONNX loop#2928
Conversation
|
@copilot the inputs marked as SymInt can be dynamic. Can you update the logic to account for this situation? |
Added a dynamic path in 93d3247: when any of |
|
@copilot can you add good coverage on the new logic in tests/function_libs/torch_lib/e2e_ops_tests.py? |
|
@copilot also add more detailed comment to explain what each line of the implementation is doing in relation to the torch op, as it is a more complicated op, for readability |
Added e2e coverage in |
Added detailed line-by-line comments to |
| empty_shape = op.Constant(value=ir.tensor(np.array([], dtype=np.int64))) | ||
| # Start the running index from storage_offset, cast to an INT64 scalar so all | ||
| # the arithmetic below has a consistent dtype regardless of how the SymInt | ||
| # runtime values are typed (e.g. int32 SymInts). |
There was a problem hiding this comment.
@copilot I think we can assume all symints are int64.
There was a problem hiding this comment.
Removed the INT64 Cast ops from the dynamic path and now reshape the SymInt values directly, assuming they are already INT64.
There was a problem hiding this comment.
Pull request overview
This PR refactors the torchlib lowering for aten::as_strided to avoid generating an ONNX Loop/Scan when constructing gather indices, enabling simpler downstream constant-folding. It introduces a static fast path that precomputes the index tensor with NumPy when all size/stride/storage_offset values are known at trace time, and a dynamic path that builds the same index math using ONNX ops without loops.
Changes:
- Reimplemented
aten_as_stridedinops/core.pyas aReshape([-1]) + Gatherwith (1) a NumPy-constant index fast path and (2) an ONNX-op dynamic index path (noLoop/Scan). - Removed the now-unused private
_aten_as_strided_onnxlowering and unblocked type-constraint deduction by removing it from the “skip loop/scan” list. - Added new E2E tests covering several
as_stridedscenarios (static and dynamic shapes/offsets).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
tests/function_libs/torch_lib/e2e_ops_tests.py |
Adds E2E export coverage for torch.as_strided across static and dynamic cases. |
onnxscript/function_libs/torch_lib/ops/core.py |
Replaces loop-based index construction with static NumPy-constant and dynamic ONNX-op paths. |
onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py |
Removes _aten_as_strided_onnx from the loop/scan skip list since it no longer exists. |
| # Start the running index from storage_offset as an INT64 scalar; SymInt | ||
| # runtime values are assumed to be INT64. | ||
| indices = op.Reshape(storage_offset, empty_shape) | ||
| for dim in range(rank): | ||
| # Reshape this dimension's size and stride to INT64 scalars. | ||
| dim_size = op.Reshape(size[dim], empty_shape) | ||
| dim_stride = op.Reshape(stride[dim], empty_shape) |
aten_as_stridedwas lowered to a private ONNX function (_aten_as_strided_onnx) that built gather indices via an unrolled loop ofExpand/Range/SequenceInsert/ConcatFromSequenceops. This graph is hard for downstream passes to constant-fold.Since
aten_as_stridedis alreadytrace_only, whensize,stride, andstorage_offsetare concrete at trace time the indices can be computed once with NumPy and emitted as a constant. TheSymIntinputs can also be dynamic (runtime values), so a second path builds the indices with ONNX ops while still avoiding anyLoop/Scan.Changes
ops/core.py: Replace the loop implementation with two paths sharing the same index math, where for each output position the storage index isstorage_offset + Σ_d i_d · stride[d]and the result isReshape(self, [-1])+Gather:size/stride/storage_offsetare concrete ints): fold the indices into a single constant index tensor.SymIntis a runtime value): build the indices with ONNX ops (Range/Mul/Unsqueeze/Add). The per-dimension contributions are unrolled at trace time since the rank is always static, so noLoop/Scanis emitted. RuntimeSymIntvalues are assumed to be INT64 and reshaped to scalars directly, and mixed static/dynamic dimensions are supported.storage_offset=Noneis normalized to0so the dynamic path does not emit an invalidReshapeof a missing input.ops/core.py: Remove the now-unused private_aten_as_strided_onnxfunction.deduce_type_constraints_test.py: Drop_aten_as_strided_onnxfrom_SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN, since no loop/scan remains.tests/function_libs/torch_lib/e2e_ops_tests.py: Add end-to-end coverage for both paths — static (multi-dimensional with non-zerostorage_offset, single dimension, overlapping strides, scalar/emptysize) and dynamic (sizederived from the input shape, with and withoutstorage_offset).Implementation
The empty-
sizecase naturally yields a 0-d index tensor, producing a scalar output. Both paths were checked againsttorch.as_stridedfor multi-dimensional, non-zerostorage_offset, single-dimension, scalar/empty-size, and mixed static/dynamic inputs.