Skip to content

Add LTX-Video 2B streaming T2V integration #348

Open
shy982 wants to merge 1 commit into NVIDIA:main from shy982:add-ltx-video-integration

shy982 commented Jun 24, 2026

Summary

Adds a first-party FlashDreams integration for LTX-Video 2B causal streaming text-to-video.

  • New plugin under integrations/ltx_video/ wrapping LTXPipeline from diffusers
  • Three registered runner slugs:
    • ltx-video-t2v-2b — streaming wrapper using native pipe() per chunk
    • ltx-video-t2v-2b-optimized — manual denoise loop with KV-cache, torch.compile, and FlashAttention
    • ltx-video-t2v-2b-taehv — optimized path plus TAEHV fast decoder
  • CPU smoke tests for runner registration and entry-point discovery
  • GPU optimization tests (ci_gpu) validating each optimization layer in isolation
  • Model gallery docs at docs/source/models/ltx_video.rst

Performance note

The optimized runner improves time-to-first-frame after warmup. Steady-state per-chunk throughput vs the streaming runner is still being tuned (KV window size, compile/graph strategy, denoise loop overhead).

Test plan

  • uv sync --project integrations/ltx_video --extra dev
  • uv run --project integrations/ltx_video pytest integrations/ltx_video/tests/test_smoke.py -v (5 passed)
  • pytest integrations/ltx_video/tests/test_optimizations.py -v -m ci_gpu (7 passed on H100, cu128 env)
  • uv run --project integrations/ltx_video flashdreams-run --help lists all three LTX slugs
  • End-to-end: uv run --project integrations/ltx_video flashdreams-run ltx-video-t2v-2b --total-blocks 7 --pixel-height 512 --pixel-width 768 (requires HF weights)

Register streaming, optimized (KV-cache + compile), and TAEHV runner
slugs with smoke tests, GPU optimization tests, and model gallery docs.

copy-pr-bot Bot commented Jun 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

greptile-apps Bot commented Jun 24, 2026

Greptile Summary

This PR introduces a first-party FlashDreams integration for LTX-Video 2B causal streaming text-to-video, adding three runner slugs under integrations/ltx_video/ with support for native diffusers pipe(), a manual denoise loop with KV-cache, torch.compile, and CUDA graphs, plus an optional TAEHV fast decoder.

  • KV-cache path (attention.py + kv_cache.py): The attention processor stores the already-concatenated (past+current) key/value tensor, and then LTXKVCache.update() concatenates it again with the old cache — past tokens are duplicated every AR step, which corrupts the KV representation and causes the cache to grow far faster than expected.
  • Thread-safety (kv_context.py): _CTX is a bare module-level global despite the "thread-local" docstring; concurrent inference calls in a server environment would corrupt each other's KV state.
  • Benchmark script (run_benchmark.py): A fresh cache is created inside the warmup loop, so the timed call receives AR index 5 against a cache that expects index 0, causing an immediate AssertionError.

Confidence Score: 3/5

The KV-cache path — the core optimization for the two featured runner slugs — produces incorrect output due to token duplication in the attention processor, and the shared mutable context makes concurrent use dangerous. These need fixes before the optimized runners can be trusted in any environment beyond single-threaded smoke tests.

Three real defects on changed paths: the KV accumulation bug silently corrupts every AR frame after the first when KV cache is enabled; the module-level context singleton would race in any server deployment; and the benchmark crash prevents performance validation. The non-KV streaming runner (ltx-video-t2v-2b) is unaffected and appears correct, but it is one of three advertised runners.

attention.py and kv_context.py need the most attention — they carry the KV-cache correctness and thread-safety bugs. tests/benchmark/run_benchmark.py needs the cache re-initialization loop fixed before any benchmark numbers can be trusted.

Important Files Changed

| Filename | Overview |
| --- | --- |
| integrations/ltx_video/ltx_video/attention.py | KV attention processor stores the full concatenated (past+current) key/value tensors before update() further concatenates them; past tokens are duplicated each AR step, corrupting the KV cache. |
| integrations/ltx_video/ltx_video/kv_context.py | Docstring says 'thread-local' but _CTX is a bare module-level singleton; concurrent requests would corrupt each other's KV state. |
| integrations/ltx_video/ltx_video/pipeline.py | Core streaming pipeline integrating manual denoise loop, compile, and CUDA graph paths; missing guard for the kv_cache+cuda_graphs combination when compile=False. |
| integrations/ltx_video/ltx_video/kv_cache.py | Rolling KV cache with window support; the update() logic is correct in isolation but is fed already-concatenated tensors from the attention processor, making the combined behavior wrong. |
| integrations/ltx_video/ltx_video/compiler.py | torch.compile, CUDA graph capture/replay, and FlashAttention toggle helpers; CUDA graph runner design is correct but a missing guard in pipeline.py makes the kv_cache+cuda_graphs combo unsafe. |
| integrations/ltx_video/ltx_video/ltx_loader.py | Model loader wrapping diffusers LTXPipeline; low_cpu_mem_usage=False doubles peak CPU RAM during weight loading. |
| integrations/ltx_video/tests/benchmark/run_benchmark.py | Benchmark script re-creates a fresh cache inside the warmup loop, causing an AssertionError when generate() receives AR index 5 against a cache that expects index 0. |
| integrations/ltx_video/tests/test_optimizations.py | GPU optimization tests; KV-accumulation test only asserts seq_after_1 > seq_after_0 and would pass even with the token-doubling bug, so the bug goes undetected here. |
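
One way the KV-accumulation test could be tightened is to check exact growth rather than mere increase. The helper below is a sketch only: `assert_linear_kv_growth`, `seq_lens`, and `chunk_len` are hypothetical names that do not appear in the PR, and it assumes no rolling-window truncation happens within the steps being checked.

```python
def assert_linear_kv_growth(seq_lens, chunk_len):
    """Check that the cached length grows by exactly one chunk per AR step.

    seq_lens[i] is the cached sequence length observed after AR step i.
    Hypothetical helper; assumes the rolling window has not yet truncated.
    """
    for step, length in enumerate(seq_lens):
        expected = (step + 1) * chunk_len
        assert length == expected, (
            f"step {step}: cached length {length}, expected {expected}"
        )
```

With a chunk of 4 tokens, lengths of 4, 8, 12 pass, while the 4, 12 produced by token doubling fails at step 1, which a plain `seq_after_1 > seq_after_0` check would accept.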

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Caller
    participant Pipeline as LTXVideoStreamingPipeline
    participant Encoder as LTXEncoder
    participant KVCtx as KVContext (global)
    participant Attn as LTXKVAttnProcessor
    participant KVCache as LTXKVCache
    participant Decoder as LTXDecoder

    Caller->>Pipeline: "initialize_cache(text=[prompt])"
    Pipeline->>Encoder: encode(prompt, negative_prompt)
    Encoder-->>Pipeline: LTXConditionings
    Pipeline-->>Caller: LTXPipelineCache

    loop AR step N
        Caller->>Pipeline: generate(N, cache, width, height)
        Pipeline->>KVCtx: "configure_kv_context(past_kv=cache.kv.get())"
        loop Denoise timestep T
            Pipeline->>Attn: __call__(hidden_states, ...)
            Attn->>KVCtx: get_kv_context()
            Attn->>Attn: prepend past_k/v to key/value
            Note over Attn,KVCtx: stores full (past+current) key/value
            Attn->>KVCtx: set_layer_kv(layer_idx, key, value)
        end
        Pipeline->>KVCtx: collected_kv() to last_present_kv
        Pipeline->>KVCtx: reset_kv_context()
        Pipeline->>Decoder: decode_from_denoised(latents)
        Decoder-->>Pipeline: frames
        Pipeline-->>Caller: frames tensor

        Caller->>Pipeline: finalize(N, cache)
        Pipeline->>KVCache: update(pending_kv)
        Note over KVCache: re-concatenates already-accumulated KV
        KVCache-->>Pipeline: updated cache
    end

Reviews (1): Last reviewed commit: "Add LTX-Video 2B streaming T2V integrati..."

Comment on lines +141 to +151
if past_layer_kv is not None:
    past_k, past_v = past_layer_kv
    past_len = past_k.shape[1]
    key = torch.cat([past_k, key], dim=1)
    value = torch.cat([past_v, value], dim=1)

if ctx.collect:
    ctx.set_layer_kv(
        self.layer_idx,
        (key.detach().clone(), value.detach().clone()),
    )

P1 KV cache stores accumulated (past+current) tokens — causes exponential growth

When collecting the present KV, key and value have already been prepended with past_k/past_v from the previous AR step. Storing the full concatenated tensor and then having LTXKVCache.update() additionally concatenate with k_old means each step double-counts the past tokens. After step 1 the cache holds [key_step0, key_step0, key_step1] instead of [key_step0, key_step1], and the duplication compounds every subsequent step: measured in chunks, the cache length goes 1, 3, 7, 15 over the first four steps, roughly doubling each time, instead of 1, 2, 3, 4.

The fix is to capture only the current chunk's keys and values before the past-prepend, and let the cache accumulate them. Concretely, save the new-chunk slices before the torch.cat lines and store those, rather than the already-extended key/value.
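
The difference between the two storage strategies can be sketched without tensors. The simulation below is illustrative only: plain Python lists stand in for one layer's key tensor, `run_steps` and `store_full` are hypothetical names, and list concatenation stands in for both the `torch.cat` prepend and `LTXKVCache.update()`.

```python
def run_steps(num_steps, store_full):
    """Simulate one layer's key cache; each AR step contributes one labelled chunk."""
    cache = []  # stands in for the keys held by the KV cache
    for step in range(num_steps):
        current = [f"k{step}"]     # keys of the new chunk only
        key = cache + current      # what attention sees: past + current
        # Buggy path stores the extended key; fixed path stores only the new chunk.
        present = key if store_full else current
        cache = cache + present    # stands in for update() concatenating onto the old cache
    return cache
```

Under these assumptions `run_steps(2, store_full=True)` reproduces the `[k0, k0, k1]` layout described above, while `run_steps(2, store_full=False)` yields `[k0, k1]`.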

Comment on lines +38 to +42
_CTX = KVContext()


def get_kv_context() -> KVContext:
    return _CTX

P1 Module-level global _CTX is not thread-safe, contrary to the docstring

The module docstring says "Thread-local KV state" but _CTX is a plain module-level singleton, not a threading.local(). Any server environment dispatching two concurrent generation requests will have both forward passes mutating the same KVContext object — interleaving configure_kv_context, begin_forward, set_layer_kv, and reset_kv_context calls across threads, corrupting collected KV and poisoning the cache for every concurrent request. The fix is to replace _CTX = KVContext() with _CTX = threading.local() and lazily initialize a KVContext per thread.
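
A minimal sketch of the per-thread variant follows. The `KVContext` class here is a stand-in with made-up fields, not the PR's class, and `_LOCAL` is a hypothetical name; only `get_kv_context()` matches an identifier from the diff.

```python
import threading
from dataclasses import dataclass, field


@dataclass
class KVContext:
    # Stand-in for the PR's KVContext; these fields are illustrative only.
    collect: bool = False
    layer_kv: dict = field(default_factory=dict)


_LOCAL = threading.local()


def get_kv_context() -> KVContext:
    """Return this thread's KVContext, creating it on first use."""
    ctx = getattr(_LOCAL, "ctx", None)
    if ctx is None:
        ctx = KVContext()
        _LOCAL.ctx = ctx
    return ctx
```

Each thread then sees its own context object. Note that thread-local storage does not isolate coroutines that share a thread, so an asyncio-based server would likely need `contextvars` instead.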

Comment on lines +84 to +92
for s in range(STEP_TO_MEASURE):
    pipe.generate(s, cache, width=768, height=512)
    pipe.finalize(s, cache)
    cache = pipe.initialize_cache(text=[PROMPT])

def step() -> None:
    pipe.generate(STEP_TO_MEASURE, cache, width=768, height=512)

return {"mode": mode, "times_ms": measure_dit_time(step)}

P1 run_flashdreams() will always raise AssertionError for non-baseline modes

Inside the warmup loop the cache is re-initialized on every iteration (cache = pipe.initialize_cache(text=[PROMPT]) at line 87), so after the loop cache is a fresh object with autoregressive_index = None. generate() then asserts that autoregressive_index == 0, but it receives STEP_TO_MEASURE = 5, immediately crashing with AR step out of order: previous=None, expected=0, got=5. The re-initialization inside the loop also defeats the purpose of the warmup: accumulated KV state is discarded before the timed call, so the measurement never reflects a warmed-up KV cache.
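
A corrected warmup might look like the sketch below, with the cache created once before the loop. `FakePipe` and `FakeCache` are stand-ins that only model AR ordering; the assumption that `finalize()` is what advances `autoregressive_index` is a guess, not something the diff confirms, and `warm_up` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import Optional

STEP_TO_MEASURE = 5


@dataclass
class FakeCache:
    autoregressive_index: Optional[int] = None  # last finalized AR step


class FakePipe:
    """Stand-in that only enforces AR ordering; not the PR's pipeline."""

    def initialize_cache(self) -> FakeCache:
        return FakeCache()

    def generate(self, step: int, cache: FakeCache) -> None:
        prev = cache.autoregressive_index
        expected = 0 if prev is None else prev + 1
        assert step == expected, (
            f"AR step out of order: previous={prev}, expected={expected}, got={step}"
        )

    def finalize(self, step: int, cache: FakeCache) -> None:
        cache.autoregressive_index = step


def warm_up(pipe: FakePipe) -> FakeCache:
    cache = pipe.initialize_cache()  # created once, outside the loop
    for s in range(STEP_TO_MEASURE):
        pipe.generate(s, cache)
        pipe.finalize(s, cache)
    return cache  # still holds the state accumulated over steps 0..4
```

After `warm_up()`, a call to `generate(STEP_TO_MEASURE, cache)` passes the ordering check, whereas the same call against a freshly initialized cache fails with the message quoted above.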

Comment on lines +29 to +33
pipe = pipeline_cls.from_pretrained(
    checkpoint,
    torch_dtype=dtype,
    low_cpu_mem_usage=False,
)

P2 low_cpu_mem_usage=False makes the loader hold a full second copy of the model weights in CPU RAM during the initial parameter transfer, effectively doubling peak host-memory usage for a 2B-parameter model. Setting it to True (the diffusers default) streams weights into the model shard-by-shard and avoids the spike.

Suggested change
- pipe = pipeline_cls.from_pretrained(
-     checkpoint,
-     torch_dtype=dtype,
-     low_cpu_mem_usage=False,
- )
+ pipe = pipeline_cls.from_pretrained(
+     checkpoint,
+     torch_dtype=dtype,
+     low_cpu_mem_usage=True,
+ )

Comment on lines +85 to +90
if self.use_cuda_graphs and self.use_compile:
    print(
        "[LTX compiler] CUDA graphs disabled with torch.compile "
        "(incompatible graph capture on compiled DiT; graphs still apply when compile=False)"
    )
    self._cuda_graphs_enabled = False

P2 CUDA graphs silently capture stale KV state when kv_cache=True, compile=False

The existing guard disables CUDA graphs when compile=True (lines 85-90), but there is no guard for the cuda_graphs=True, kv_cache=True, compile=False combination. CUDA graphs replay the recorded GPU kernels against the tensor addresses and shapes seen at capture time, without re-executing any Python. The lookup of past_k and the torch.cat([past_k, key], dim=1) call inside LTXKVAttnProcessor are therefore frozen as captured, so during replay the attention still operates on the KV tensors captured during the first AR step. AR step 2+ will silently produce frames conditioned on stale (step-0) KV context. A note in the constructor alongside the existing guard, or an explicit self._cuda_graphs_enabled = False when use_kv_cache is True, would prevent this from being triggered by derived configs.
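
The explicit guard could be factored as a small pure function, sketched below. `resolve_cuda_graphs` and its return shape are hypothetical; the three flag names mirror `use_cuda_graphs`, `use_compile`, and `use_kv_cache` from the comment above, and the message strings are illustrative.

```python
from typing import Optional, Tuple


def resolve_cuda_graphs(
    use_cuda_graphs: bool, use_compile: bool, use_kv_cache: bool
) -> Tuple[bool, Optional[str]]:
    """Return (enabled, reason_if_disabled) for the CUDA graph path."""
    if not use_cuda_graphs:
        return False, None
    if use_compile:
        return False, "CUDA graphs disabled with torch.compile"
    if use_kv_cache:
        # Replay would keep reading the KV tensors captured on the first AR step.
        return False, "CUDA graphs disabled with KV cache"
    return True, None
```

Keeping the decision in one place makes the unsafe combination hard to reach from a derived config, and the returned reason can be logged the same way the existing compile guard prints its message.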
