fix: preserve MDP integrity in PPO mini-batching#98
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces trajectory-aware PPO mini-batching helpers to preserve whole trajectories across mini-batches. It updates the Ray trainer, DP actor, DP critic, and loss computation functions to support planned mini-batches and propagate global batch metadata (such as DP size, global batch size, and token counts) for proper loss aggregation. The review feedback highlights two critical improvements in trajectory_batching.py: first, handling empty valid indices defensively to prevent pipeline crashes, and second, vectorizing token sum calculations on the GPU to avoid performance bottlenecks caused by row-by-row GPU-CPU synchronizations.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| valid_indices = _valid_indices(data) | ||
| if not valid_indices: | ||
| raise ValueError("trajectory mini-batching requires at least one valid row") |
There was a problem hiding this comment.
If valid_indices is empty (e.g., when all samples in the batch are masked out or invalid), raising a ValueError will crash the entire training pipeline. It is safer to handle this case defensively by returning a dummy mini-batch of all padding, which avoids crashes while correctly marking all samples as invalid.
| valid_indices = _valid_indices(data) | |
| if not valid_indices: | |
| raise ValueError("trajectory mini-batching requires at least one valid row") | |
| valid_indices = _valid_indices(data) | |
| if not valid_indices: | |
| device = _batch_device(data.batch) | |
| prepared = data.select_idxs(list(range(len(data)))) | |
| prepared.batch["mini_batch_id"] = torch.zeros(len(data), dtype=torch.long, device=device) | |
| prepared.batch["sample_mask"] = torch.zeros(len(data), dtype=torch.bool, device=device) | |
| prepared.batch["mini_batch_global_size"] = torch.zeros(len(data), dtype=torch.long, device=device) | |
| prepared.batch["mini_batch_global_token_num"] = torch.zeros((len(data), 1), dtype=torch.long, device=device) | |
| prepared.batch["mini_batch_global_response_token_num"] = torch.zeros(len(data), dtype=torch.long, device=device) | |
| prepared.meta_info = dict(getattr(prepared, "meta_info", {})) | |
| prepared.meta_info["num_mini_batch"] = 1 | |
| return prepared |
There was a problem hiding this comment.
I don't think returning a dummy mini-batch is safe here. valid_indices is derived from sample_mask, which in this flow only marks world-size padding rows. A non-empty training batch should always contain at least one real row; if all rows are invalid, that is an upstream invariant violation and failing fast is preferable.
The proposed dummy batch is not a no-op: planned mini-batches are still executed by actor/critic/engine workers, and the loss path does not use sample_mask as the skip condition. The suggestion also does not clear response_mask/loss_mask, so invalid rows may still affect loss/metrics depending on the path. If we ever need to support an all-invalid batch, the trainer should skip actor/critic updates explicitly, not fabricate mini-batch metadata here.
| def _assign_global_mini_batch_info( | ||
| prepared: Any, | ||
| source_batch: Any, | ||
| mini_batches: list[list[list[int]]], | ||
| mini_batch_ids: torch.Tensor, | ||
| device: torch.device, | ||
| ) -> None: | ||
| row_counts = [sum(len(group) for group in mini_batch) for mini_batch in mini_batches] | ||
| max_rows = max(row_counts) | ||
| global_sizes = torch.tensor(row_counts, dtype=torch.long, device=device) | ||
| prepared.batch["mini_batch_global_size"] = global_sizes[mini_batch_ids] | ||
|
|
||
| token_num_table = torch.zeros((len(mini_batches), max_rows), dtype=torch.long, device=device) | ||
| response_token_nums = torch.zeros(len(mini_batches), dtype=torch.long, device=device) | ||
| attention_mask = source_batch.get("attention_mask", None) | ||
| response_mask = source_batch.get("response_mask", None) | ||
|
|
||
| for mini_batch_id, mini_batch in enumerate(mini_batches): | ||
| row_indices = [row_idx for group in mini_batch for row_idx in group] | ||
| if attention_mask is not None: | ||
| source_token_nums = attention_mask.new_tensor( | ||
| [_sum_row_tokens(attention_mask, row_idx) for row_idx in row_indices], dtype=torch.long | ||
| ) | ||
| token_num_table[mini_batch_id, : len(row_indices)] = source_token_nums.to(device) | ||
| if response_mask is not None: | ||
| response_token_nums[mini_batch_id] = int( | ||
| sum(_sum_row_tokens(response_mask, row_idx) for row_idx in row_indices) | ||
| ) | ||
|
|
||
| prepared.batch["mini_batch_global_token_num"] = token_num_table[mini_batch_ids] | ||
| prepared.batch["mini_batch_global_response_token_num"] = response_token_nums[mini_batch_ids] | ||
|
|
||
|
|
||
| def _sum_row_tokens(tensor: torch.Tensor, row_idx: int) -> int: | ||
| return int(tensor[row_idx].sum().detach().cpu().item()) |
There was a problem hiding this comment.
The current implementation of _assign_global_mini_batch_info uses a loop over all rows in the mini-batch and calls _sum_row_tokens, which performs .detach().cpu().item() on each row. This causes multiple GPU-CPU synchronizations per step, creating a massive performance bottleneck. We can completely avoid this by precomputing the sum of tokens for all rows in a single vectorized operation on the GPU.
def _assign_global_mini_batch_info(
prepared: Any,
source_batch: Any,
mini_batches: list[list[list[int]]],
mini_batch_ids: torch.Tensor,
device: torch.device,
) -> None:
row_counts = [sum(len(group) for group in mini_batch) for mini_batch in mini_batches]
max_rows = max(row_counts)
global_sizes = torch.tensor(row_counts, dtype=torch.long, device=device)
prepared.batch["mini_batch_global_size"] = global_sizes[mini_batch_ids]
token_num_table = torch.zeros((len(mini_batches), max_rows), dtype=torch.long, device=device)
response_token_nums = torch.zeros(len(mini_batches), dtype=torch.long, device=device)
attention_mask = source_batch.get("attention_mask", None)
response_mask = source_batch.get("response_mask", None)
attention_token_counts = attention_mask.sum(dim=-1) if attention_mask is not None else None
response_token_counts = response_mask.sum(dim=-1) if response_mask is not None else None
for mini_batch_id, mini_batch in enumerate(mini_batches):
row_indices = [row_idx for group in mini_batch for row_idx in group]
if attention_token_counts is not None:
source_token_nums = attention_token_counts[row_indices]
token_num_table[mini_batch_id, : len(row_indices)] = source_token_nums
if response_token_counts is not None:
response_token_nums[mini_batch_id] = response_token_counts[row_indices].sum()
prepared.batch["mini_batch_global_token_num"] = token_num_table[mini_batch_ids]
prepared.batch["mini_batch_global_response_token_num"] = response_token_nums[mini_batch_ids]
Summary
Preserve MDP integrity during PPO updates by ensuring all reasoning steps from the same trajectory are assigned to the same mini-batch. This prevents a single trajectory from being split across different actor or critic update batches, which would break the consistency of trajectory-level MDP optimization.
Changes