Fix ErnieImagePipeline pre-computed prompt_embeds + num_images_per_prompt shape mismatch #13532

Open

Ricardo-M-L wants to merge 1 commit into huggingface:main from Ricardo-M-L:fix-ernie-image-prompt-embeds-num-images

Conversation

@Ricardo-M-L
Contributor

What this PR does

When a user passes pre-computed prompt_embeds (or negative_prompt_embeds) alongside num_images_per_prompt > 1, ErnieImagePipeline.__call__ did not replicate the provided embeddings — the embeds list kept its original length (one per prompt) while the latents were allocated with total_batch_size = batch_size * num_images_per_prompt:

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py#L287-L298

# [Phase 2] Text encoding
if prompt_embeds is not None:
    text_hiddens = prompt_embeds            # length = batch_size (NOT replicated)
else:
    text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt)

if self.do_classifier_free_guidance:
    if negative_prompt_embeds is not None:
        uncond_text_hiddens = negative_prompt_embeds   # same issue
    ...

Why this is a real bug

In the denoise loop, latents are expanded to total_batch_size:

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py#L329-L343

if self.do_classifier_free_guidance:
    latent_model_input = torch.cat([latents, latents], dim=0)   # (batch*N*2, ...)

pred = self.transformer(
    hidden_states=latent_model_input,
    text_bth=text_bth,                                          # (batch*2, ...)
    ...
)

text_bth.shape[0] is derived from len(cfg_text_hiddens), which is len(prompt_embeds) * 2 under CFG — i.e. batch_size * 2, not batch_size * N * 2. This produces a shape mismatch inside the transformer's text-conditioning attention, and the call raises a RuntimeError. The standard "pre-compute embeds once, generate N variants" usage pattern is broken.
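The failing arithmetic can be illustrated without the pipeline at all. The sketch below uses plain Python stand-ins for the embeds (the variable names mirror the pipeline's, but the values are dummies):

```python
# Standalone sketch of the batch-size arithmetic behind the bug
# (no diffusers needed; embeds are stand-in strings).

batch_size = 1          # one prompt
num_images_per_prompt = 4

# Pre-computed embeds: one entry per prompt, NOT replicated.
prompt_embeds = ["embed_for_prompt_0"] * batch_size
text_hiddens = prompt_embeds                      # len == batch_size

# Under CFG the text conditioning is doubled (uncond + cond).
uncond_text_hiddens = ["uncond_embed_0"] * batch_size
cfg_text_hiddens = uncond_text_hiddens + text_hiddens
text_rows = len(cfg_text_hiddens)                 # batch_size * 2 == 2

# Latents, however, are allocated per generated image, then doubled for CFG.
total_batch_size = batch_size * num_images_per_prompt
latent_rows = total_batch_size * 2                # batch_size * N * 2 == 8

assert text_rows != latent_rows   # the mismatch the transformer trips over
print(text_rows, latent_rows)     # 2 8
```

With `num_images_per_prompt=1` the two row counts agree, which is why the bug only surfaces for N > 1.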

Minimal repro

pipe = ErnieImagePipeline.from_pretrained(...)
embeds = pipe.encode_prompt(["a cat"], pipe.device, num_images_per_prompt=1)
pipe(prompt_embeds=embeds, num_images_per_prompt=4, ...)
# RuntimeError during transformer forward

Fix

encode_prompt already performs this replication internally:

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py#L158-L160

# Repeat for num_images_per_prompt
for _ in range(num_images_per_prompt):
    text_hiddens.append(hidden)

Mirror the same replication in the pre-embed branches so both paths yield a text_hiddens list of length batch_size * num_images_per_prompt:

 if prompt_embeds is not None:
-    text_hiddens = prompt_embeds
+    text_hiddens = [h for h in prompt_embeds for _ in range(num_images_per_prompt)]
 else:
     text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt)

 if self.do_classifier_free_guidance:
     if negative_prompt_embeds is not None:
-        uncond_text_hiddens = negative_prompt_embeds
+        uncond_text_hiddens = [h for h in negative_prompt_embeds for _ in range(num_images_per_prompt)]
     else:
         uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt)

The non-embed path is unaffected since encode_prompt already replicates.
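The list comprehension preserves the same per-prompt grouping that the quoted `encode_prompt` loop produces (all copies of prompt 0, then all copies of prompt 1, and so on). A quick standalone check with dummy embeds:

```python
# Sketch of the replication applied in the fix: each per-prompt embed is
# repeated num_images_per_prompt times, with copies grouped per prompt,
# matching the append loop in encode_prompt quoted above.

num_images_per_prompt = 3
prompt_embeds = ["embed_p0", "embed_p1"]  # two pre-computed prompt embeds

text_hiddens = [h for h in prompt_embeds for _ in range(num_images_per_prompt)]

assert len(text_hiddens) == len(prompt_embeds) * num_images_per_prompt
print(text_hiddens)
# ['embed_p0', 'embed_p0', 'embed_p0', 'embed_p1', 'embed_p1', 'embed_p1']
```

If the per-prompt order ever mattered differently (interleaved rather than grouped), the uncond and cond lists would both need the same convention so they stay aligned under CFG.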

Before submitting

  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue or the forum? N/A — standard API consistency fix.
  • Did you make sure to update the documentation with your changes? N/A — no public API change.
  • Did you write any new necessary tests? N/A — restores documented prompt_embeds + num_images_per_prompt combo.

Who can review?

@yiyixuxu @sayakpaul

github-actions bot added the pipelines and size/S (PR with diff < 50 LOC) labels on Apr 21, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator) left a comment:
thanks!
would you be willing to add a pipeline test for ernie-image? a new PR is ok
