Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion agent_r1/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,25 @@ def compute_value_loss(
response_mask: torch.Tensor,
cliprange_value: float,
loss_agg_mode: str = "token-mean",
dp_size: int = 1,
batch_num_tokens: Optional[int] = None,
global_batch_size: Optional[int] = None,
loss_scale_factor: Optional[int] = None,
):
"""Local value loss that uses pad-aware `agg_loss`."""
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
vf_loss = 0.5 * agg_loss(
loss_mat=clipped_vf_losses,
loss_mask=response_mask,
loss_agg_mode=loss_agg_mode,
dp_size=dp_size,
batch_num_tokens=batch_num_tokens,
global_batch_size=global_batch_size,
loss_scale_factor=loss_scale_factor,
)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
return vf_loss, vf_clipfrac

Expand Down
27 changes: 17 additions & 10 deletions agent_r1/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from tqdm import tqdm

from agent_r1.trainer.ppo.metric_utils import compute_data_metrics
from agent_r1.trainer.ppo.trajectory_batching import prepare_trajectory_mini_batch
from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.protocol import pad_dataproto_to_divisor
Expand Down Expand Up @@ -286,17 +287,17 @@ def _update_actor(self, batch: DataProto) -> DataProto:
rollout_config = self.config.actor_rollout_ref.rollout
batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
batch.meta_info["temperature"] = rollout_config.temperature
ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
if self.use_legacy_worker_impl == "disable":
from verl.utils import tensordict_utils as tu
from verl.utils.py_functional import rename_dict
from verl.workers.utils.padding import left_right_2_no_padding

calculate_entropy = self.config.actor_rollout_ref.actor.entropy_coeff != 0.0
ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
dp_size = self._get_dp_size(self.actor_rollout_wg, "actor")
assign_global_mini_batch_ids(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size)
batch_td = batch.to_tensordict()
update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size)
batch_td = update_batch.to_tensordict()
batch_td = left_right_2_no_padding(batch_td)
ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs
seed = self.config.actor_rollout_ref.actor.data_loader_seed
Expand All @@ -305,6 +306,7 @@ def _update_actor(self, batch: DataProto) -> DataProto:
batch_td,
calculate_entropy=calculate_entropy,
mini_batch_size=ppo_mini_batch_size,
num_mini_batch=update_batch.meta_info["num_mini_batch"],
epochs=ppo_epochs,
seed=seed,
dataloader_kwargs={"shuffle": shuffle},
Expand All @@ -316,27 +318,30 @@ def _update_actor(self, batch: DataProto) -> DataProto:
actor_output["perf/mfu/actor"] = actor_output.pop("actor/mfu")
actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output})
else:
actor_output = self.actor_rollout_wg.update_actor(batch)
dp_size = self._get_worker_group_dp_size(self.actor_rollout_wg, ("actor",))
update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size)
actor_output = self.actor_rollout_wg.update_actor(update_batch)
return actor_output

def _update_critic(self, batch: DataProto) -> DataProto:
ppo_mini_batch_size = self.config.critic.ppo_mini_batch_size
ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
if self.use_legacy_worker_impl == "disable":
from verl.utils import tensordict_utils as tu
from verl.utils.py_functional import rename_dict
from verl.workers.utils.padding import left_right_2_no_padding

ppo_mini_batch_size = self.config.critic.ppo_mini_batch_size
ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
dp_size = self._get_worker_group_dp_size(self.critic_wg, ("train", "critic"))
assign_global_mini_batch_ids(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size)
batch_td = batch.to_tensordict()
update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size)
batch_td = update_batch.to_tensordict()
batch_td = left_right_2_no_padding(batch_td)
ppo_epochs = self.config.critic.ppo_epochs
seed = self.config.critic.data_loader_seed
shuffle = self.config.critic.shuffle
tu.assign_non_tensor(
batch_td,
mini_batch_size=ppo_mini_batch_size,
num_mini_batch=update_batch.meta_info["num_mini_batch"],
epochs=ppo_epochs,
seed=seed,
dataloader_kwargs={"shuffle": shuffle},
Expand All @@ -349,7 +354,9 @@ def _update_critic(self, batch: DataProto) -> DataProto:
output["perf/mfu/critic"] = output.pop("critic/mfu")
output = DataProto.from_single_dict(data={}, meta_info={"metrics": output})
else:
output = self.critic_wg.update_critic(batch)
dp_size = self._get_worker_group_dp_size(self.critic_wg, ("critic",))
update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size)
output = self.critic_wg.update_critic(update_batch)
return output

def _get_worker_group_dp_size(self, worker_group, roles: Sequence[str]) -> int:
Expand Down
209 changes: 209 additions & 0 deletions agent_r1/trainer/ppo/trajectory_batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Trajectory-aware PPO mini-batch helpers."""

from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any

import torch


@dataclass(frozen=True)
class _Entry:
    """Immutable description of one row in the prepared update batch.

    Entries are produced while laying mini-batches out per data-parallel
    rank and are later turned into row indices, mini-batch ids, and a
    padding mask.
    """

    # Index of the row in the source batch that this entry selects.
    source_idx: int
    # Id of the planned mini-batch this row belongs to.
    mini_batch_id: int
    # True when the row is a filler copy added to even out rank sizes.
    is_padding: bool


def prepare_trajectory_mini_batch(data: Any, mini_batch_size: int, dp_size: int) -> Any:
"""Build an update batch whose PPO mini-batches preserve whole trajectories.

Args:
data (Any): A DataProto-like object with `batch`, `non_tensor_batch`, and `select_idxs`.
mini_batch_size (int): Target number of trajectories per PPO mini-batch.
dp_size (int): Data parallel size used by the training worker dispatch.

Returns:
Any: A DataProto-like object with mini-batch metadata and update-only padding rows.
"""
if mini_batch_size <= 0:
raise ValueError(f"mini_batch_size must be positive, got {mini_batch_size}")
if dp_size <= 0:
raise ValueError(f"dp_size must be positive, got {dp_size}")
if len(data) == 0:
return data

valid_indices = _valid_indices(data)
if not valid_indices:
raise ValueError("trajectory mini-batching requires at least one valid row")
Comment on lines +37 to +39
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If valid_indices is empty (e.g., when all samples in the batch are masked out or invalid), raising a ValueError will crash the entire training pipeline. It is safer to handle this case defensively by returning a dummy mini-batch of all padding, which avoids crashes while correctly marking all samples as invalid.

Suggested change
valid_indices = _valid_indices(data)
if not valid_indices:
raise ValueError("trajectory mini-batching requires at least one valid row")
valid_indices = _valid_indices(data)
if not valid_indices:
device = _batch_device(data.batch)
prepared = data.select_idxs(list(range(len(data))))
prepared.batch["mini_batch_id"] = torch.zeros(len(data), dtype=torch.long, device=device)
prepared.batch["sample_mask"] = torch.zeros(len(data), dtype=torch.bool, device=device)
prepared.batch["mini_batch_global_size"] = torch.zeros(len(data), dtype=torch.long, device=device)
prepared.batch["mini_batch_global_token_num"] = torch.zeros((len(data), 1), dtype=torch.long, device=device)
prepared.batch["mini_batch_global_response_token_num"] = torch.zeros(len(data), dtype=torch.long, device=device)
prepared.meta_info = dict(getattr(prepared, "meta_info", {}))
prepared.meta_info["num_mini_batch"] = 1
return prepared

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think returning a dummy mini-batch is safe here. valid_indices is derived from sample_mask, which in this flow only marks world-size padding rows. A non-empty training batch should always contain at least one real row; if all rows are invalid, that is an upstream invariant violation and failing fast is preferable.

The proposed dummy batch is not a no-op: planned mini-batches are still executed by actor/critic/engine workers, and the loss path does not use sample_mask as the skip condition. The suggestion also does not clear response_mask/loss_mask, so invalid rows may still affect loss/metrics depending on the path. If we ever need to support an all-invalid batch, the trainer should skip actor/critic updates explicitly, not fabricate mini-batch metadata here.


mini_batches = _build_trajectory_batches(data, valid_indices, mini_batch_size)
entries = _build_rank_ordered_entries(mini_batches, dp_size)
source_indices = [entry.source_idx for entry in entries]
prepared = data.select_idxs(source_indices)

device = _batch_device(prepared.batch)
mini_batch_ids = torch.tensor([entry.mini_batch_id for entry in entries], dtype=torch.long, device=device)
padding_mask = torch.tensor([entry.is_padding for entry in entries], dtype=torch.bool, device=device)

prepared.batch["mini_batch_id"] = mini_batch_ids
prepared.batch["sample_mask"] = ~padding_mask
_zero_padding_loss_masks(prepared.batch, padding_mask)
_assign_global_mini_batch_info(prepared, data.batch, mini_batches, mini_batch_ids, device)

prepared.meta_info = dict(getattr(prepared, "meta_info", {}))
prepared.meta_info["num_mini_batch"] = len(mini_batches)
return prepared


def split_data_proto_by_mini_batch_id(data: Any, *, shuffle: bool = False, seed: int = 42) -> list[Any]:
    """Split a local worker batch using precomputed `mini_batch_id` values.

    Args:
        data (Any): A DataProto-like object containing `batch["mini_batch_id"]`
            and exposing `select_idxs`. `meta_info["num_mini_batch"]` is used
            when present; otherwise the count is inferred from the largest id.
        shuffle (bool): Whether to shuffle mini-batch id order.
        seed (int): Deterministic seed for shuffling.

    Returns:
        list[Any]: DataProto-like mini-batches, each containing exactly one
        mini-batch id. Ids with no local rows are skipped, so an empty local
        batch yields an empty list.

    Raises:
        KeyError: If `mini_batch_id` is missing from `data.batch`.
    """
    if "mini_batch_id" not in data.batch:
        raise KeyError("mini_batch_id is required for trajectory-aware mini-batch splitting")

    mini_batch_ids = data.batch["mini_batch_id"].detach().cpu()

    # Tolerate a missing or None `meta_info` attribute.
    meta_info = getattr(data, "meta_info", None) or {}
    num_mini_batch = meta_info.get("num_mini_batch")
    if num_mini_batch is None:
        # Only infer the count when it was not provided: `max()` on an empty
        # tensor raises, so it must not be evaluated eagerly as a default.
        num_mini_batch = mini_batch_ids.max().item() + 1 if mini_batch_ids.numel() > 0 else 0
    num_mini_batch = int(num_mini_batch)

    if shuffle:
        # A dedicated generator keeps the order reproducible for a given seed
        # without touching the global RNG state.
        generator = torch.Generator()
        generator.manual_seed(seed)
        ordered_ids = torch.randperm(num_mini_batch, generator=generator).tolist()
    else:
        ordered_ids = list(range(num_mini_batch))

    mini_batches = []
    for mini_batch_id in ordered_ids:
        indices = torch.nonzero(mini_batch_ids == mini_batch_id, as_tuple=False).flatten()
        if indices.numel() == 0:
            # This worker holds no rows for the id; nothing to emit.
            continue
        mini_batches.append(data.select_idxs(indices))
    return mini_batches


def get_mini_batch_global_info(mini_batch: Any) -> dict[str, Any]:
    """Collect the loss-normalization metadata stored on a planned mini-batch.

    Only the first row is read from each metadata field.

    Args:
        mini_batch (Any): A DataProto-like object whose `batch` holds
            `mini_batch_global_size`, `mini_batch_global_token_num`, and
            `mini_batch_global_response_token_num`.

    Returns:
        dict[str, Any]: `global_batch_size` and `batch_num_tokens` as ints,
        plus `global_token_num` as the list of strictly positive per-row
        token counts.
    """
    fields = mini_batch.batch
    size_value = fields["mini_batch_global_size"][0].item()
    per_row_tokens = fields["mini_batch_global_token_num"][0]
    response_value = fields["mini_batch_global_response_token_num"][0].item()

    # Zero entries are dropped from the per-row token list.
    positive_tokens = per_row_tokens[per_row_tokens > 0]

    info: dict[str, Any] = {}
    info["global_batch_size"] = int(size_value)
    info["batch_num_tokens"] = int(response_value)
    info["global_token_num"] = positive_tokens.tolist()
    return info


def _valid_indices(data: Any) -> list[int]:
    """Return the positions of rows that `sample_mask` marks as valid.

    When `data.batch` has no `sample_mask`, every row position is returned.
    """
    sample_mask = data.batch.get("sample_mask", None)
    if sample_mask is not None:
        flags = sample_mask.detach().cpu().to(dtype=torch.bool).tolist()
        return [position for position, keep in enumerate(flags) if keep]
    return list(range(len(data)))


def _build_trajectory_batches(data: Any, valid_indices: list[int], mini_batch_size: int) -> list[list[list[int]]]:
    """Group valid rows by trajectory and chunk the groups into mini-batches.

    Rows sharing a value in `data.non_tensor_batch["trajectory_uids"]` form
    one group, with groups ordered by first appearance. When no trajectory
    uids are present, every valid row is its own group.

    Returns:
        list[list[list[int]]]: Mini-batches, each a list of row-index groups,
        with at most `mini_batch_size` groups per mini-batch.
    """
    trajectory_uids = data.non_tensor_batch.get("trajectory_uids")
    if trajectory_uids is not None:
        rows_by_uid: OrderedDict[Any, list[int]] = OrderedDict()
        for row in valid_indices:
            uid = trajectory_uids[row]
            if uid not in rows_by_uid:
                rows_by_uid[uid] = []
            rows_by_uid[uid].append(row)
        row_groups = list(rows_by_uid.values())
    else:
        row_groups = [[row] for row in valid_indices]
    return _chunk_groups(row_groups, mini_batch_size)


def _chunk_groups(groups: list[list[int]], chunk_size: int) -> list[list[list[int]]]:
    """Split `groups` into consecutive slices of at most `chunk_size` groups."""
    chunks = []
    for start in range(0, len(groups), chunk_size):
        chunks.append(groups[start : start + chunk_size])
    return chunks


def _build_rank_ordered_entries(mini_batches: list[list[list[int]]], dp_size: int) -> list[_Entry]:
    """Lay out mini-batch rows rank by rank, padding ranks to equal length.

    Within each mini-batch, trajectory groups are assigned to ranks
    round-robin. Every rank is then padded up to the longest rank for that
    mini-batch (at least one row), using copies of the mini-batch's first row
    flagged as padding. The result concatenates all entries of rank 0, then
    rank 1, and so on.
    """
    ordered_by_rank: list[list[_Entry]] = [[] for _ in range(dp_size)]

    for mini_batch_id, trajectory_groups in enumerate(mini_batches):
        # Padding rows duplicate the first real row of the mini-batch.
        pad_source_idx = trajectory_groups[0][0]
        local_rows: list[list[_Entry]] = [[] for _ in range(dp_size)]

        for group_idx, row_indices in enumerate(trajectory_groups):
            bucket = local_rows[group_idx % dp_size]
            for row_idx in row_indices:
                bucket.append(_Entry(source_idx=row_idx, mini_batch_id=mini_batch_id, is_padding=False))

        # Each rank gets at least one row per mini-batch, even if empty.
        target_rows = max([1, *map(len, local_rows)])
        for rank, bucket in enumerate(local_rows):
            shortfall = target_rows - len(bucket)
            bucket.extend(
                _Entry(source_idx=pad_source_idx, mini_batch_id=mini_batch_id, is_padding=True)
                for _ in range(shortfall)
            )
            ordered_by_rank[rank].extend(bucket)

    flattened: list[_Entry] = []
    for rank_entries in ordered_by_rank:
        flattened.extend(rank_entries)
    return flattened


def _batch_device(batch: Any) -> torch.device:
    """Return the device of the first tensor value in `batch`, else CPU."""
    tensor_values = (value for value in batch.values() if torch.is_tensor(value))
    first_tensor = next(tensor_values, None)
    if first_tensor is None:
        return torch.device("cpu")
    return first_tensor.device


def _zero_padding_loss_masks(batch: Any, padding_mask: torch.Tensor) -> None:
    """Zero `response_mask` and `loss_mask` rows selected by `padding_mask`.

    The masks are modified in place. Keys absent from `batch` are ignored,
    and nothing is touched when no row is flagged as padding.
    """
    if padding_mask.any():
        for mask_key in ("response_mask", "loss_mask"):
            if mask_key in batch:
                batch[mask_key][padding_mask] = 0


def _assign_global_mini_batch_info(
    prepared: Any,
    source_batch: Any,
    mini_batches: list[list[list[int]]],
    mini_batch_ids: torch.Tensor,
    device: torch.device,
) -> None:
    """Attach per-mini-batch size and token counts to every prepared row.

    Writes three fields into `prepared.batch`, each indexed by
    `mini_batch_ids` so that all rows of one mini-batch carry the same values:

    * `mini_batch_global_size`: number of source rows in the mini-batch.
    * `mini_batch_global_token_num`: per-row `attention_mask` sums of the
      mini-batch, right-padded with zeros to the largest mini-batch.
    * `mini_batch_global_response_token_num`: total `response_mask` sum over
      the mini-batch.

    Counts stay zero for a mask that is absent from `source_batch`.

    Args:
        prepared (Any): DataProto-like object whose `batch` receives the fields.
        source_batch (Any): Mapping-like batch the row indices refer to.
        mini_batches (list[list[list[int]]]): Planned mini-batches as groups
            of source row indices.
        mini_batch_ids (torch.Tensor): Mini-batch id of every prepared row.
        device (torch.device): Device on which the new tensors are created.
    """

    def _per_row_counts(mask: Any) -> Any:
        # Sum each row once with tensor ops instead of pulling one scalar per
        # row back to the host inside the loop below.
        if mask is None:
            return None
        counts = mask if mask.dim() == 1 else mask.flatten(start_dim=1).sum(dim=-1)
        return counts.to(device=device, dtype=torch.long)

    rows_per_mini_batch = [[row_idx for group in mini_batch for row_idx in group] for mini_batch in mini_batches]
    row_counts = [len(row_indices) for row_indices in rows_per_mini_batch]
    # `default=0` keeps an empty plan from raising; the tables are then empty.
    max_rows = max(row_counts, default=0)
    global_sizes = torch.tensor(row_counts, dtype=torch.long, device=device)
    prepared.batch["mini_batch_global_size"] = global_sizes[mini_batch_ids]

    token_num_table = torch.zeros((len(mini_batches), max_rows), dtype=torch.long, device=device)
    response_token_nums = torch.zeros(len(mini_batches), dtype=torch.long, device=device)
    attention_counts = _per_row_counts(source_batch.get("attention_mask", None))
    response_counts = _per_row_counts(source_batch.get("response_mask", None))

    for mini_batch_id, row_indices in enumerate(rows_per_mini_batch):
        index = torch.tensor(row_indices, dtype=torch.long, device=device)
        if attention_counts is not None:
            token_num_table[mini_batch_id, : len(row_indices)] = attention_counts[index]
        if response_counts is not None:
            response_token_nums[mini_batch_id] = response_counts[index].sum()

    prepared.batch["mini_batch_global_token_num"] = token_num_table[mini_batch_ids]
    prepared.batch["mini_batch_global_response_token_num"] = response_token_nums[mini_batch_ids]


def _sum_row_tokens(tensor: torch.Tensor, row_idx: int) -> int:
    """Return the sum of row `row_idx` of `tensor` as a Python int."""
    row_total = tensor[row_idx].sum()
    host_value = row_total.detach().cpu().item()
    return int(host_value)
Comment on lines +175 to +209
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of _assign_global_mini_batch_info loops over every row of every mini-batch and calls _sum_row_tokens, which performs .detach().cpu().item() on each row. That forces one GPU-CPU synchronization per row for each mask (attention_mask and response_mask) on every update step, creating a massive performance bottleneck. We can completely avoid this by precomputing the per-row token sums for all rows in a single vectorized operation per mask on the GPU and indexing into the result inside the loop.

def _assign_global_mini_batch_info(
    prepared: Any,
    source_batch: Any,
    mini_batches: list[list[list[int]]],
    mini_batch_ids: torch.Tensor,
    device: torch.device,
) -> None:
    """Attach per-mini-batch size and token counts to every prepared row.

    Writes `mini_batch_global_size`, `mini_batch_global_token_num`, and
    `mini_batch_global_response_token_num` into `prepared.batch`, each
    indexed by `mini_batch_ids` so all rows of a mini-batch share the same
    values. Counts stay zero for a mask that is absent from `source_batch`.

    Args:
        prepared (Any): DataProto-like object whose `batch` receives the fields.
        source_batch (Any): Mapping-like batch the row indices refer to.
        mini_batches (list[list[list[int]]]): Planned mini-batches as groups
            of source row indices.
        mini_batch_ids (torch.Tensor): Mini-batch id of every prepared row.
        device (torch.device): Device on which the new tensors are created.
    """

    def _per_row_counts(mask: Any) -> Any:
        # Reduce over every trailing axis (not just the last one) so the
        # result matches a whole-row sum, then move to the target device.
        if mask is None:
            return None
        counts = mask if mask.dim() == 1 else mask.flatten(start_dim=1).sum(dim=-1)
        return counts.to(device=device, dtype=torch.long)

    rows_per_mini_batch = [[row_idx for group in mini_batch for row_idx in group] for mini_batch in mini_batches]
    row_counts = [len(row_indices) for row_indices in rows_per_mini_batch]
    # `default=0` keeps an empty plan from raising; the tables are then empty.
    max_rows = max(row_counts, default=0)
    global_sizes = torch.tensor(row_counts, dtype=torch.long, device=device)
    prepared.batch["mini_batch_global_size"] = global_sizes[mini_batch_ids]

    token_num_table = torch.zeros((len(mini_batches), max_rows), dtype=torch.long, device=device)
    response_token_nums = torch.zeros(len(mini_batches), dtype=torch.long, device=device)
    attention_token_counts = _per_row_counts(source_batch.get("attention_mask", None))
    response_token_counts = _per_row_counts(source_batch.get("response_mask", None))

    for mini_batch_id, row_indices in enumerate(rows_per_mini_batch):
        index = torch.tensor(row_indices, dtype=torch.long, device=device)
        if attention_token_counts is not None:
            token_num_table[mini_batch_id, : len(row_indices)] = attention_token_counts[index]
        if response_token_counts is not None:
            response_token_nums[mini_batch_id] = response_token_counts[index].sum()

    prepared.batch["mini_batch_global_token_num"] = token_num_table[mini_batch_ids]
    prepared.batch["mini_batch_global_response_token_num"] = response_token_nums[mini_batch_ids]

Loading