diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py
index df67e55c9b5d..94dab90dc20a 100644
--- a/tests/models/transformers/test_models_transformer_wan_animate.py
+++ b/tests/models/transformers/test_models_transformer_wan_animate.py
@@ -224,7 +224,7 @@ def get_dummy_inputs(self):
         """Override to provide inputs matching the tiny Wan Animate model dimensions."""
         return {
             "hidden_states": randn_tensor(
-                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "encoder_hidden_states": randn_tensor(
                 (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
@@ -233,10 +233,10 @@
                 (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "pose_hidden_states": randn_tensor(
-                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "face_pixel_values": randn_tensor(
-                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
         }