Add LTX-Video 2B streaming T2V integration #348
Register streaming, optimized (KV-cache + compile), and TAEHV runner slugs with smoke tests, GPU optimization tests, and model gallery docs.
Greptile Summary
This PR introduces a first-party FlashDreams integration for LTX-Video 2B causal streaming text-to-video, adding three runner slugs under
Confidence Score: 3/5
The KV-cache path, the core optimization for the two featured runner slugs, produces incorrect output due to token duplication in the attention processor, and the shared mutable context makes concurrent use dangerous. These need fixes before the optimized runners can be trusted in any environment beyond single-threaded smoke tests.
Three real defects on changed paths: the KV accumulation bug silently corrupts every AR frame after the first when KV cache is enabled; the module-level context singleton would race in any server deployment; and the benchmark crash prevents performance validation. The non-KV streaming runner (
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Caller
participant Pipeline as LTXVideoStreamingPipeline
participant Encoder as LTXEncoder
participant KVCtx as KVContext (global)
participant Attn as LTXKVAttnProcessor
participant KVCache as LTXKVCache
participant Decoder as LTXDecoder
Caller->>Pipeline: "initialize_cache(text=[prompt])"
Pipeline->>Encoder: encode(prompt, negative_prompt)
Encoder-->>Pipeline: LTXConditionings
Pipeline-->>Caller: LTXPipelineCache
loop AR step N
Caller->>Pipeline: generate(N, cache, width, height)
Pipeline->>KVCtx: "configure_kv_context(past_kv=cache.kv.get())"
loop Denoise timestep T
Pipeline->>Attn: __call__(hidden_states, ...)
Attn->>KVCtx: get_kv_context()
Attn->>Attn: prepend past_k/v to key/value
Note over Attn,KVCtx: stores full (past+current) key/value
Attn->>KVCtx: set_layer_kv(layer_idx, key, value)
end
Pipeline->>KVCtx: collected_kv() to last_present_kv
Pipeline->>KVCtx: reset_kv_context()
Pipeline->>Decoder: decode_from_denoised(latents)
Decoder-->>Pipeline: frames
Pipeline-->>Caller: frames tensor
Caller->>Pipeline: finalize(N, cache)
Pipeline->>KVCache: update(pending_kv)
Note over KVCache: re-concatenates already-accumulated KV
KVCache-->>Pipeline: updated cache
end
Reviews (1): Last reviewed commit: "Add LTX-Video 2B streaming T2V integrati..."
    if past_layer_kv is not None:
        past_k, past_v = past_layer_kv
        past_len = past_k.shape[1]
        key = torch.cat([past_k, key], dim=1)
        value = torch.cat([past_v, value], dim=1)

    if ctx.collect:
        ctx.set_layer_kv(
            self.layer_idx,
            (key.detach().clone(), value.detach().clone()),
        )
KV cache stores accumulated (past+current) tokens — causes exponential growth
When collecting the present KV, key and value have already been prepended with past_k/past_v from the previous AR step. Storing the full concatenated tensor and then having LTXKVCache.update() additionally concatenate with k_old means each step double-counts the past tokens. After step 1 the cache holds [key_step0, key_step0, key_step1] instead of [key_step0, key_step1], and the duplication compounds every subsequent step.
The fix is to capture only the current chunk's keys and values before the past-prepend, and let the cache accumulate them. Concretely, save the new-chunk slices before the torch.cat lines and store those, rather than the already-extended key/value.
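The duplication described above can be reproduced with a toy model. This is only a sketch: plain Python lists stand in for the key tensors, list concatenation stands in for `torch.cat(..., dim=1)`, and the `simulate_cache` helper and its `store_full` flag are hypothetical names that do not exist in the PR.

```python
def simulate_cache(num_steps: int, store_full: bool) -> list[str]:
    """Toy model of the KV cache across AR steps (lists stand in for tensors)."""
    cache: list[str] = []  # stands in for the contents of LTXKVCache
    for step in range(num_steps):
        current = [f"key_step{step}"]  # keys produced by this chunk only
        attended = cache + current     # mirrors torch.cat([past_k, key], dim=1)
        # Buggy path stores the already-extended tensor; fixed path stores
        # only the current chunk, captured before the past-prepend.
        collected = attended if store_full else current
        cache = cache + collected      # mirrors the concat in LTXKVCache.update()
    return cache


buggy = simulate_cache(3, store_full=True)
fixed = simulate_cache(3, store_full=False)
print(len(buggy), buggy)
print(len(fixed), fixed)
```

Under these assumptions the buggy path roughly doubles the cache length on every step, while the fixed path grows by one chunk per step.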
    _CTX = KVContext()


    def get_kv_context() -> KVContext:
        return _CTX
Module-level global _CTX is not thread-safe, contrary to the docstring
The module docstring says "Thread-local KV state" but _CTX is a plain module-level singleton, not a threading.local(). Any server environment dispatching two concurrent generation requests will have both forward passes mutating the same KVContext object — interleaving configure_kv_context, begin_forward, set_layer_kv, and reset_kv_context calls across threads, corrupting collected KV and poisoning the cache for every concurrent request. The fix is to replace _CTX = KVContext() with _CTX = threading.local() and lazily initialize a KVContext per thread.
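A minimal sketch of the per-thread lazy initialization suggested above. The `KVContext` dataclass here is a simplified stand-in whose fields are assumptions; only the names `KVContext` and `get_kv_context` come from the PR, and `_LOCAL` is a hypothetical replacement for `_CTX`.

```python
import threading
from dataclasses import dataclass, field


@dataclass
class KVContext:
    """Simplified stand-in; the real class in the PR likely has more state."""
    collect: bool = False
    layer_kv: dict = field(default_factory=dict)


# One attribute namespace per thread instead of a shared singleton.
_LOCAL = threading.local()


def get_kv_context() -> KVContext:
    """Return this thread's KVContext, creating it on first use."""
    ctx = getattr(_LOCAL, "ctx", None)
    if ctx is None:
        ctx = KVContext()
        _LOCAL.ctx = ctx
    return ctx
```

Note that `threading.local()` isolates threads only. If the server dispatches requests as asyncio tasks on a single thread, `contextvars.ContextVar` would be the closer fit.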
    for s in range(STEP_TO_MEASURE):
        pipe.generate(s, cache, width=768, height=512)
        pipe.finalize(s, cache)
        cache = pipe.initialize_cache(text=[PROMPT])

    def step() -> None:
        pipe.generate(STEP_TO_MEASURE, cache, width=768, height=512)

    return {"mode": mode, "times_ms": measure_dit_time(step)}
run_flashdreams() will always raise AssertionError for non-baseline modes
Inside the warmup loop the cache is re-initialized on every iteration (cache = pipe.initialize_cache(text=[PROMPT]) at line 87), so after the loop cache is a fresh object with autoregressive_index = None. generate() then asserts that the requested step matches the expected index 0, but it receives STEP_TO_MEASURE = 5 and immediately crashes with AR step out of order: previous=None, expected=0, got=5. The re-initialization inside the loop also defeats the purpose of the warmup: accumulated KV state is discarded before the timed call, so the measurement never reflects a warmed-up KV cache.
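One way to restructure the warmup is to initialize the cache once, before the loop, so the timed step sees the accumulated state. The sketch below uses a hypothetical `StubPipeline` that only mimics the ordering assertion quoted above; where the real pipeline advances `autoregressive_index` is an assumption (here, in `finalize`), and `warm_up` is an illustrative helper, not a function from the PR.

```python
class StubPipeline:
    """Hypothetical stand-in that enforces AR step ordering like the real pipeline."""

    def initialize_cache(self, text):
        return {"autoregressive_index": None}

    def generate(self, step, cache, width, height):
        prev = cache["autoregressive_index"]
        expected = 0 if prev is None else prev + 1
        assert step == expected, (
            f"AR step out of order: previous={prev}, expected={expected}, got={step}"
        )

    def finalize(self, step, cache):
        # Assumption: the step index is committed when the step is finalized.
        cache["autoregressive_index"] = step


def warm_up(pipe, prompt, steps_before_measure):
    """Initialize the cache once, then run and finalize every step before the timed one."""
    cache = pipe.initialize_cache(text=[prompt])
    for s in range(steps_before_measure):
        pipe.generate(s, cache, width=768, height=512)
        pipe.finalize(s, cache)
    return cache


pipe = StubPipeline()
cache = warm_up(pipe, "a prompt", 5)
pipe.generate(5, cache, width=768, height=512)  # the step that would be timed
```

If the timing helper calls the step function repeatedly, each call must also be valid against the same cache state, which this sketch satisfies only because the stub's `generate` does not advance the index.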
    pipe = pipeline_cls.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        low_cpu_mem_usage=False,
    )
low_cpu_mem_usage=False makes the loader hold a full second copy of the model weights in CPU RAM during the initial parameter transfer, effectively doubling peak host-memory usage for a 2B-parameter model. Setting it to True (the diffusers default) streams weights into the model shard-by-shard and avoids the spike.
    - pipe = pipeline_cls.from_pretrained(
    -     checkpoint,
    -     torch_dtype=dtype,
    -     low_cpu_mem_usage=False,
    - )
    + pipe = pipeline_cls.from_pretrained(
    +     checkpoint,
    +     torch_dtype=dtype,
    +     low_cpu_mem_usage=True,
    + )
    if self.use_cuda_graphs and self.use_compile:
        print(
            "[LTX compiler] CUDA graphs disabled with torch.compile "
            "(incompatible graph capture on compiled DiT; graphs still apply when compile=False)"
        )
        self._cuda_graphs_enabled = False
CUDA graphs silently capture stale KV state when kv_cache=True, compile=False
The existing guard disables CUDA graphs when compile=True (lines 85-90), but there is no guard for the cuda_graphs=True, kv_cache=True, compile=False combination. CUDA graphs replay the recorded GPU kernels without re-executing Python, so the torch.cat([past_k, key], dim=1) call inside LTXKVAttnProcessor is not re-run with the updated cache: during replay the attention still operates on the KV tensors captured during the first AR step. AR step 2+ will silently produce frames conditioned on stale (step-0) KV context. A note in the constructor alongside the existing guard, or an explicit self._cuda_graphs_enabled = False when use_kv_cache is True, would prevent this from being triggered by derived configs.
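The extended guard could look roughly like the sketch below. `resolve_cuda_graphs` is a hypothetical free function written for illustration; in the PR the equivalent logic would live in the constructor and set `self._cuda_graphs_enabled`, and the second log message is invented wording.

```python
def resolve_cuda_graphs(use_cuda_graphs: bool, use_compile: bool, use_kv_cache: bool) -> bool:
    """Decide whether CUDA graphs may stay enabled for a given config."""
    if not use_cuda_graphs:
        return False
    if use_compile:
        # Existing guard from the PR: graph capture is incompatible with the compiled DiT.
        print("[LTX compiler] CUDA graphs disabled with torch.compile")
        return False
    if use_kv_cache:
        # Proposed guard: replay would reuse the KV tensors captured at the first AR step.
        print("[LTX compiler] CUDA graphs disabled with KV cache (stale KV on replay)")
        return False
    return True


print(resolve_cuda_graphs(use_cuda_graphs=True, use_compile=False, use_kv_cache=True))
```

Disabling graphs is the conservative choice. Keeping them with a KV cache would need static, preallocated KV buffers that are updated in place, which is a larger change than this guard.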
Summary
Adds a first-party FlashDreams integration for LTX-Video 2B causal streaming text-to-video.
- integrations/ltx_video/ wrapping LTXPipeline from diffusers
- ltx-video-t2v-2b: streaming wrapper using native pipe() per chunk
- ltx-video-t2v-2b-optimized: manual denoise loop with KV-cache, torch.compile, and FlashAttention
- ltx-video-t2v-2b-taehv: optimized path plus TAEHV fast decoder
- GPU optimization tests (ci_gpu) validating each optimization layer in isolation
- Model gallery docs: docs/source/models/ltx_video.rst
Performance note
The optimized runner improves time-to-first-frame after warmup. Steady-state per-chunk throughput vs the streaming runner is still being tuned (KV window size, compile/graph strategy, denoise loop overhead).
Test plan
- uv sync --project integrations/ltx_video --extra dev
- uv run --project integrations/ltx_video pytest integrations/ltx_video/tests/test_smoke.py -v (5 passed)
- pytest integrations/ltx_video/tests/test_optimizations.py -v -m ci_gpu (7 passed on H100, cu128 env)
- uv run --project integrations/ltx_video flashdreams-run --help lists all three LTX slugs
- uv run --project integrations/ltx_video flashdreams-run ltx-video-t2v-2b --total-blocks 7 --pixel-height 512 --pixel-width 768 (requires HF weights)