bug in ONNX export of aten_full due to incorrect dtype default argument #2925

@bas-aarts

Exporting the linked .pt2 model with the following script:

import onnx
import torch

torch.onnx.export(torch.export.load("input.pt2"), (torch.randn(1, 3, 640, 640),), optimize=False).save("output.pt2")
onnx.shape_inference.infer_shapes_path("output.pt2", strict_mode=True)

results in shape inference errors:

(op_type:Expand, node name: node_full_1): [TypeInferenceError] Inferred elem type differs from existing elem type: (1) vs (7)...
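
For anyone decoding the numbers in that message: they are ONNX TensorProto.DataType enum values, where 1 is FLOAT and 7 is INT64. A small lookup sketch follows; the table is only a subset of the enum, the helper name describe_mismatch is made up for illustration, and it assumes the message lists the inferred type first and the existing type second, as its wording suggests.

```python
# Subset of the ONNX TensorProto.DataType enum values.
# With onnx installed, onnx.TensorProto.DataType.Name(1) gives the same names.
ONNX_ELEM_TYPES = {
    1: "FLOAT",
    6: "INT32",
    7: "INT64",
    9: "BOOL",
    11: "DOUBLE",
}


def describe_mismatch(inferred: int, existing: int) -> str:
    """Turn the two elem type codes from the error into readable names.

    Hypothetical helper, not part of onnx or onnxscript.
    """
    inferred_name = ONNX_ELEM_TYPES.get(inferred, f"unknown({inferred})")
    existing_name = ONNX_ELEM_TYPES.get(existing, f"unknown({existing})")
    return f"inferred {inferred_name}, graph records {existing_name}"


print(describe_mismatch(1, 7))
```

Read this way, the Expand node's output is inferred as FLOAT while the exported graph records it as INT64.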

The issue is caused by the dtype default in aten_full. When fill_value is integral but dtype is not provided, the line
fill_value = op.Cast(fill_value, to=dtype)
casts fill_value to float, which causes the type inference errors.
Using -1 as the default dtype instead of FLOAT.dtype fixes the issue.
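
To illustrate the difference between the two defaults, here is a simplified pure-Python model of the behaviour described above. It is not the onnxscript implementation: the function names, the UNSPECIFIED sentinel name, and the infer_elem_type helper are assumptions made up for this sketch, and each function only returns the elem type the result would end up with.

```python
FLOAT = 1         # ONNX TensorProto.DataType.FLOAT
INT64 = 7         # ONNX TensorProto.DataType.INT64
UNSPECIFIED = -1  # hypothetical sentinel meaning "caller gave no dtype"


def infer_elem_type(fill_value) -> int:
    """Elem type a Python scalar would naturally have (simplified: int or float only)."""
    return INT64 if isinstance(fill_value, int) else FLOAT


def full_with_float_default(fill_value, dtype: int = FLOAT) -> int:
    """Models the reported behaviour: the value is always cast to dtype.

    With no dtype given, an integral fill_value silently becomes FLOAT.
    """
    return dtype


def full_with_sentinel_default(fill_value, dtype: int = UNSPECIFIED) -> int:
    """Models the proposed fix: cast only when a dtype was actually passed."""
    if dtype == UNSPECIFIED:
        return infer_elem_type(fill_value)
    return dtype


# Integral fill value, no dtype given:
print(full_with_float_default(0))     # elem type under the FLOAT default
print(full_with_sentinel_default(0))  # elem type under the -1 default
```

With the FLOAT default the integral fill value ends up as elem type 1, while with the sentinel default it stays at elem type 7, which matches the "(1) vs (7)" mismatch in the error.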

input model (40MB):
https://drive.google.com/file/d/1_TaBSSF_F6mW-QruIO-55qW0xY3qI62G/view?usp=drive_web

versions of relevant packages:

onnx                   1.21.0
onnx-ir                0.2.1
onnxscript             0.7.0
torch                  2.12.0
torchvision            0.27.0
