From e5f4b23fc2081a1efec117e6c831dc2eac0051e5 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 22 Apr 2026 09:45:29 +0800 Subject: [PATCH] Reduce WanAnimate TorchAO test input sizes to prevent OOM Shrink dummy inputs to avoid OOM on devices without FlashAttention. Reduce hidden_states spatial size from 64x64 to 16x16 and frame count from 21 to 5, bringing self-attention sequence length from 21,504 to 320. Also shrink pose_hidden_states from 20 frames at 64x64 to 4 frames at 16x16, and face_pixel_values from 77 to 13 frames. --- .../transformers/test_models_transformer_wan_animate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py index ac0ef0698c63..569e3507825e 100644 --- a/tests/models/transformers/test_models_transformer_wan_animate.py +++ b/tests/models/transformers/test_models_transformer_wan_animate.py @@ -219,7 +219,7 @@ def get_dummy_inputs(self): """Override to provide inputs matching the tiny Wan Animate model dimensions.""" return { "hidden_states": randn_tensor( - (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype + (1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype ), "encoder_hidden_states": randn_tensor( (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype @@ -228,10 +228,10 @@ def get_dummy_inputs(self): (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype ), "pose_hidden_states": randn_tensor( - (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype + (1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype ), "face_pixel_values": randn_tensor( - (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype + (1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype ), "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype), }