diff --git a/agent_r1/trainer/ppo/core_algos.py b/agent_r1/trainer/ppo/core_algos.py index 8f74861..15a75fc 100644 --- a/agent_r1/trainer/ppo/core_algos.py +++ b/agent_r1/trainer/ppo/core_algos.py @@ -383,13 +383,25 @@ def compute_value_loss( response_mask: torch.Tensor, cliprange_value: float, loss_agg_mode: str = "token-mean", + dp_size: int = 1, + batch_num_tokens: Optional[int] = None, + global_batch_size: Optional[int] = None, + loss_scale_factor: Optional[int] = None, ): """Local value loss that uses pad-aware `agg_loss`.""" vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) vf_losses1 = (vpreds - returns) ** 2 vf_losses2 = (vpredclipped - returns) ** 2 clipped_vf_losses = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + vf_loss = 0.5 * agg_loss( + loss_mat=clipped_vf_losses, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + dp_size=dp_size, + batch_num_tokens=batch_num_tokens, + global_batch_size=global_batch_size, + loss_scale_factor=loss_scale_factor, + ) vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) return vf_loss, vf_clipfrac diff --git a/agent_r1/trainer/ppo/ray_trainer.py b/agent_r1/trainer/ppo/ray_trainer.py index f1dc480..3acd78f 100644 --- a/agent_r1/trainer/ppo/ray_trainer.py +++ b/agent_r1/trainer/ppo/ray_trainer.py @@ -35,6 +35,7 @@ from tqdm import tqdm from agent_r1.trainer.ppo.metric_utils import compute_data_metrics +from agent_r1.trainer.ppo.trajectory_batching import prepare_trajectory_mini_batch from verl import DataProto from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.protocol import pad_dataproto_to_divisor @@ -286,17 +287,17 @@ def _update_actor(self, batch: DataProto) -> DataProto: rollout_config = self.config.actor_rollout_ref.rollout batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable 
batch.meta_info["temperature"] = rollout_config.temperature + ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size + ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n if self.use_legacy_worker_impl == "disable": from verl.utils import tensordict_utils as tu from verl.utils.py_functional import rename_dict from verl.workers.utils.padding import left_right_2_no_padding calculate_entropy = self.config.actor_rollout_ref.actor.entropy_coeff != 0.0 - ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size - ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n dp_size = self._get_dp_size(self.actor_rollout_wg, "actor") - assign_global_mini_batch_ids(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size) - batch_td = batch.to_tensordict() + update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size) + batch_td = update_batch.to_tensordict() batch_td = left_right_2_no_padding(batch_td) ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs seed = self.config.actor_rollout_ref.actor.data_loader_seed @@ -305,6 +306,7 @@ def _update_actor(self, batch: DataProto) -> DataProto: batch_td, calculate_entropy=calculate_entropy, mini_batch_size=ppo_mini_batch_size, + num_mini_batch=update_batch.meta_info["num_mini_batch"], epochs=ppo_epochs, seed=seed, dataloader_kwargs={"shuffle": shuffle}, @@ -316,20 +318,22 @@ def _update_actor(self, batch: DataProto) -> DataProto: actor_output["perf/mfu/actor"] = actor_output.pop("actor/mfu") actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output}) else: - actor_output = self.actor_rollout_wg.update_actor(batch) + dp_size = self._get_worker_group_dp_size(self.actor_rollout_wg, ("actor",)) + update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size) + actor_output = 
self.actor_rollout_wg.update_actor(update_batch) return actor_output def _update_critic(self, batch: DataProto) -> DataProto: + ppo_mini_batch_size = self.config.critic.ppo_mini_batch_size + ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n if self.use_legacy_worker_impl == "disable": from verl.utils import tensordict_utils as tu from verl.utils.py_functional import rename_dict from verl.workers.utils.padding import left_right_2_no_padding - ppo_mini_batch_size = self.config.critic.ppo_mini_batch_size - ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n dp_size = self._get_worker_group_dp_size(self.critic_wg, ("train", "critic")) - assign_global_mini_batch_ids(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size) - batch_td = batch.to_tensordict() + update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size) + batch_td = update_batch.to_tensordict() batch_td = left_right_2_no_padding(batch_td) ppo_epochs = self.config.critic.ppo_epochs seed = self.config.critic.data_loader_seed @@ -337,6 +341,7 @@ def _update_critic(self, batch: DataProto) -> DataProto: tu.assign_non_tensor( batch_td, mini_batch_size=ppo_mini_batch_size, + num_mini_batch=update_batch.meta_info["num_mini_batch"], epochs=ppo_epochs, seed=seed, dataloader_kwargs={"shuffle": shuffle}, @@ -349,7 +354,9 @@ def _update_critic(self, batch: DataProto) -> DataProto: output["perf/mfu/critic"] = output.pop("critic/mfu") output = DataProto.from_single_dict(data={}, meta_info={"metrics": output}) else: - output = self.critic_wg.update_critic(batch) + dp_size = self._get_worker_group_dp_size(self.critic_wg, ("critic",)) + update_batch = prepare_trajectory_mini_batch(batch, mini_batch_size=ppo_mini_batch_size, dp_size=dp_size) + output = self.critic_wg.update_critic(update_batch) return output def _get_worker_group_dp_size(self, worker_group, roles: Sequence[str]) -> int: diff --git 
"""Trajectory-aware PPO mini-batch helpers."""

from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass(frozen=True)
class _Entry:
    """One row of the prepared update batch."""

    # Row index in the source batch that this entry copies.
    source_idx: int
    # Planned PPO mini-batch the row belongs to.
    mini_batch_id: int
    # True for filler rows that only equalize per-rank row counts.
    is_padding: bool


def prepare_trajectory_mini_batch(data: Any, mini_batch_size: int, dp_size: int) -> Any:
    """Build an update batch whose PPO mini-batches preserve whole trajectories.

    Rows excluded by an existing `sample_mask` are dropped, the remaining rows are grouped by
    `trajectory_uids` (one row per group when that field is absent), groups are chunked into
    mini-batches, and every mini-batch is padded so each data-parallel rank holds the same
    number of rows. Padding rows get `sample_mask == False` and zeroed loss masks.

    Args:
        data (Any): A DataProto-like object with `batch`, `non_tensor_batch`, and `select_idxs`.
        mini_batch_size (int): Target number of trajectories per PPO mini-batch.
        dp_size (int): Data parallel size used by the training worker dispatch.

    Returns:
        Any: A DataProto-like object with mini-batch metadata and update-only padding rows.
            An empty `data` is returned unchanged, without any mini-batch metadata.

    Raises:
        ValueError: If `mini_batch_size` or `dp_size` is not positive, or no row is valid.
    """
    if mini_batch_size <= 0:
        raise ValueError(f"mini_batch_size must be positive, got {mini_batch_size}")
    if dp_size <= 0:
        raise ValueError(f"dp_size must be positive, got {dp_size}")
    if len(data) == 0:
        return data

    valid_indices = _valid_indices(data)
    if not valid_indices:
        raise ValueError("trajectory mini-batching requires at least one valid row")

    mini_batches = _build_trajectory_batches(data, valid_indices, mini_batch_size)
    entries = _build_rank_ordered_entries(mini_batches, dp_size)
    prepared = data.select_idxs([entry.source_idx for entry in entries])

    device = _batch_device(prepared.batch)
    mini_batch_ids = torch.tensor([entry.mini_batch_id for entry in entries], dtype=torch.long, device=device)
    padding_mask = torch.tensor([entry.is_padding for entry in entries], dtype=torch.bool, device=device)

    prepared.batch["mini_batch_id"] = mini_batch_ids
    prepared.batch["sample_mask"] = ~padding_mask
    _zero_padding_loss_masks(prepared.batch, padding_mask)
    _assign_global_mini_batch_info(prepared, data.batch, mini_batches, mini_batch_ids, device)

    # Copy so the source object's meta_info is never mutated; tolerate a missing/None meta_info.
    prepared.meta_info = dict(getattr(prepared, "meta_info", None) or {})
    prepared.meta_info["num_mini_batch"] = len(mini_batches)
    return prepared


def split_data_proto_by_mini_batch_id(data: Any, *, shuffle: bool = False, seed: int = 42) -> list[Any]:
    """Split a local worker batch using precomputed `mini_batch_id` values.

    Args:
        data (Any): A DataProto-like object containing `batch["mini_batch_id"]`.
        shuffle (bool): Whether to shuffle mini-batch id order.
        seed (int): Deterministic seed for shuffling.

    Returns:
        list[Any]: DataProto-like mini-batches, each containing exactly one mini-batch id.
            Ids that have no local rows are skipped, so the list may be shorter than
            `num_mini_batch` (and is empty for an empty local batch).

    Raises:
        KeyError: If `batch["mini_batch_id"]` is missing.
    """
    # `.keys()` matches how the worker code probes the batch for this field.
    if "mini_batch_id" not in data.batch.keys():
        raise KeyError("mini_batch_id is required for trajectory-aware mini-batch splitting")

    mini_batch_ids = data.batch["mini_batch_id"].detach().cpu()
    meta_info = getattr(data, "meta_info", None) or {}
    num_mini_batch = meta_info.get("num_mini_batch")
    if num_mini_batch is None:
        # Only fall back to the observed ids when no planned count exists; `max()` of an
        # empty tensor raises, so an empty batch simply has nothing to split.
        if mini_batch_ids.numel() == 0:
            return []
        num_mini_batch = mini_batch_ids.max().item() + 1
    num_mini_batch = int(num_mini_batch)

    if shuffle:
        generator = torch.Generator()
        generator.manual_seed(seed)
        ordered_ids = torch.randperm(num_mini_batch, generator=generator).tolist()
    else:
        ordered_ids = list(range(num_mini_batch))

    mini_batches = []
    for mini_batch_id in ordered_ids:
        indices = torch.nonzero(mini_batch_ids == mini_batch_id, as_tuple=False).flatten()
        if indices.numel() == 0:
            continue
        mini_batches.append(data.select_idxs(indices))
    return mini_batches


def get_mini_batch_global_info(mini_batch: Any) -> dict[str, Any]:
    """Return loss normalization metadata for one planned mini-batch.

    The metadata is read from the first row only.

    Args:
        mini_batch (Any): A DataProto-like object containing mini-batch metadata fields.

    Returns:
        dict[str, Any]: Global mini-batch size and token counts. `global_token_num` lists
            the positive entries of the per-row token table (zero entries are dropped).
    """
    first_idx = 0
    global_size = int(mini_batch.batch["mini_batch_global_size"][first_idx].item())
    token_nums = mini_batch.batch["mini_batch_global_token_num"][first_idx]
    response_token_num = int(mini_batch.batch["mini_batch_global_response_token_num"][first_idx].item())
    return {
        "global_batch_size": global_size,
        "batch_num_tokens": response_token_num,
        "global_token_num": token_nums[token_nums > 0].tolist(),
    }


def _valid_indices(data: Any) -> list[int]:
    """Return the row indices kept by `batch["sample_mask"]`, or all rows when it is absent."""
    sample_mask = data.batch.get("sample_mask", None)
    if sample_mask is None:
        return list(range(len(data)))
    mask = sample_mask.detach().cpu().to(dtype=torch.bool).tolist()
    return [idx for idx, is_valid in enumerate(mask) if is_valid]


def _build_trajectory_batches(data: Any, valid_indices: list[int], mini_batch_size: int) -> list[list[list[int]]]:
    """Group valid rows by `trajectory_uids` (first-seen order) and chunk the groups.

    Without `trajectory_uids`, every row forms its own single-row group.
    """
    trajectory_uids = data.non_tensor_batch.get("trajectory_uids")
    if trajectory_uids is None:
        row_groups = [[idx] for idx in valid_indices]
        return _chunk_groups(row_groups, mini_batch_size)

    groups: OrderedDict[Any, list[int]] = OrderedDict()
    for idx in valid_indices:
        groups.setdefault(trajectory_uids[idx], []).append(idx)
    return _chunk_groups(list(groups.values()), mini_batch_size)


def _chunk_groups(groups: list[list[int]], chunk_size: int) -> list[list[list[int]]]:
    """Split `groups` into consecutive chunks of at most `chunk_size` groups."""
    return [groups[idx : idx + chunk_size] for idx in range(0, len(groups), chunk_size)]


def _build_rank_ordered_entries(mini_batches: list[list[list[int]]], dp_size: int) -> list[_Entry]:
    """Lay out rows rank by rank, padding each mini-batch to equal per-rank length.

    Trajectory groups of a mini-batch are assigned to ranks round-robin. Shorter ranks are
    filled with padding entries that copy the mini-batch's first row. The result is the
    concatenation of rank 0's rows, then rank 1's, and so on, so every rank holds the same
    number of rows.
    """
    per_rank_entries: list[list[_Entry]] = [[] for _ in range(dp_size)]

    for mini_batch_id, trajectory_groups in enumerate(mini_batches):
        per_rank_for_mini_batch: list[list[_Entry]] = [[] for _ in range(dp_size)]

        for group_idx, row_indices in enumerate(trajectory_groups):
            rank = group_idx % dp_size
            per_rank_for_mini_batch[rank].extend(
                _Entry(source_idx=row_idx, mini_batch_id=mini_batch_id, is_padding=False) for row_idx in row_indices
            )

        # Entries are frozen, so one padding instance can safely be shared by all filler slots.
        padding_entry = _Entry(source_idx=trajectory_groups[0][0], mini_batch_id=mini_batch_id, is_padding=True)
        max_local_rows = max(1, *(len(entries) for entries in per_rank_for_mini_batch))
        for rank, rank_entries in enumerate(per_rank_for_mini_batch):
            rank_entries.extend([padding_entry] * (max_local_rows - len(rank_entries)))
            per_rank_entries[rank].extend(rank_entries)

    return [entry for rank_entries in per_rank_entries for entry in rank_entries]


def _batch_device(batch: Any) -> torch.device:
    """Return the device of the first tensor value in `batch`, defaulting to CPU."""
    for value in batch.values():
        if torch.is_tensor(value):
            return value.device
    return torch.device("cpu")


def _zero_padding_loss_masks(batch: Any, padding_mask: torch.Tensor) -> None:
    """Zero `response_mask` / `loss_mask` in place on padding rows so they add no loss."""
    if not padding_mask.any():
        return
    for key in ("response_mask", "loss_mask"):
        # `.keys()` matches how the worker code probes the batch for optional fields.
        if key in batch.keys():
            batch[key][padding_mask] = 0


def _row_token_counts(mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return per-row sums of `mask` as a CPU long tensor, or None when `mask` is None."""
    if mask is None:
        return None
    counts = mask.detach()
    if counts.dim() > 1:
        counts = counts.flatten(start_dim=1).sum(dim=1)
    return counts.to(dtype=torch.long, device="cpu")


def _assign_global_mini_batch_info(
    prepared: Any,
    source_batch: Any,
    mini_batches: list[list[list[int]]],
    mini_batch_ids: torch.Tensor,
    device: torch.device,
) -> None:
    """Attach per-row copies of each mini-batch's global size and token counts.

    Writes three fields into `prepared.batch`, each indexed by the row's mini-batch id:
    `mini_batch_global_size` (real rows in the mini-batch), `mini_batch_global_token_num`
    (per-row `attention_mask` sums, zero-filled to the widest mini-batch) and
    `mini_batch_global_response_token_num` (total `response_mask` sum). Counts come from
    `source_batch`, so padding rows are never counted; a missing mask leaves zeros.
    """
    row_counts = [sum(len(group) for group in mini_batch) for mini_batch in mini_batches]
    max_rows = max(row_counts)
    global_sizes = torch.tensor(row_counts, dtype=torch.long, device=device)
    prepared.batch["mini_batch_global_size"] = global_sizes[mini_batch_ids]

    token_num_table = torch.zeros((len(mini_batches), max_rows), dtype=torch.long, device=device)
    response_token_nums = torch.zeros(len(mini_batches), dtype=torch.long, device=device)
    # Reduce each mask once up front instead of one device-to-host sync per row.
    attention_counts = _row_token_counts(source_batch.get("attention_mask", None))
    response_counts = _row_token_counts(source_batch.get("response_mask", None))

    for mini_batch_id, mini_batch in enumerate(mini_batches):
        row_indices = [row_idx for group in mini_batch for row_idx in group]
        if attention_counts is not None:
            token_num_table[mini_batch_id, : len(row_indices)] = attention_counts[row_indices].to(device)
        if response_counts is not None:
            response_token_nums[mini_batch_id] = int(response_counts[row_indices].sum())

    prepared.batch["mini_batch_global_token_num"] = token_num_table[mini_batch_ids]
    prepared.batch["mini_batch_global_response_token_num"] = response_token_nums[mini_batch_ids]


def _sum_row_tokens(tensor: torch.Tensor, row_idx: int) -> int:
    """Return the sum of row `row_idx` of `tensor` as a Python int."""
    return int(tensor[row_idx].sum().detach().cpu().item())
len(mini_batches) == 1 and self.config.ppo_epochs == 1 metrics = { @@ -67,13 +84,35 @@ def update_policy(self, data: DataProto): } for _ in range(self.config.ppo_epochs): for mini_batch in mini_batches: + use_global_mini_batch_info = "mini_batch_global_size" in mini_batch.batch.keys() + if use_global_mini_batch_info: + global_info = get_mini_batch_global_info(mini_batch) + if not hasattr(self.config, "global_batch_info") or self.config.global_batch_info is None: + self.config.global_batch_info = {} + dp_size = ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + if torch.distributed.is_initialized() + else 1 + ) + self.config.global_batch_info.update( + { + "dp_size": dp_size, + "batch_num_tokens": global_info["batch_num_tokens"], + "global_batch_size": global_info["global_batch_size"], + "loss_scale_factor": self.config.loss_scale_factor, + } + ) + elif hasattr(self.config, "global_batch_info"): + self.config.global_batch_info.clear() + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) else: micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - micro_batches = [mb for mb in micro_batches if bool(mb.batch["response_mask"].any().item())] + if not has_planned_mini_batches: + micro_batches = [mb for mb in micro_batches if bool(mb.batch["response_mask"].any().item())] if not micro_batches: append_to_dict(metrics, {"actor/grad_norm": 0.0}) continue @@ -94,7 +133,9 @@ def update_policy(self, data: DataProto): loss_agg_mode = self.config.loss_agg_mode calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0) - if self.config.use_dynamic_bsz: + if use_global_mini_batch_info: + loss_scale_factor = 1.0 + elif self.config.use_dynamic_bsz: loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size else: loss_scale_factor = 1 / 
self.gradient_accumulation @@ -135,7 +176,12 @@ def update_policy(self, data: DataProto): policy_loss = pg_loss if calculate_entropy and entropy is not None: - entropy_agg = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + entropy_agg = agg_loss( + loss_mat=entropy, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + **getattr(self.config, "global_batch_info", {}), + ) micro_batch_metrics["actor/entropy"] = entropy_agg.detach().item() if entropy_coeff != 0: policy_loss -= entropy_agg * entropy_coeff @@ -145,7 +191,12 @@ def update_policy(self, data: DataProto): kld = kl_penalty( logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type ) - kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + kl_loss = agg_loss( + loss_mat=kld, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + **getattr(self.config, "global_batch_info", {}), + ) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics["actor/kl_loss"] += kl_loss.detach().item() * loss_scale_factor micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef diff --git a/agent_r1/workers/critic/dp_critic.py b/agent_r1/workers/critic/dp_critic.py index e77ed78..8c4294c 100644 --- a/agent_r1/workers/critic/dp_critic.py +++ b/agent_r1/workers/critic/dp_critic.py @@ -18,7 +18,10 @@ import logging import os +import torch + from agent_r1.trainer.ppo.core_algos import compute_value_loss +from agent_r1.trainer.ppo.trajectory_batching import get_mini_batch_global_info, split_data_proto_by_mini_batch_id from verl import DataProto from verl.utils.device import get_device_id from verl.utils.profiler import GPUMemoryLogger @@ -40,21 +43,52 @@ def update_critic(self, data: DataProto): } select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids", "values", "returns"] + has_planned_mini_batches = "mini_batch_id" in data.batch.keys() + if has_planned_mini_batches: + 
select_keys.extend( + [ + "mini_batch_id", + "mini_batch_global_size", + "mini_batch_global_token_num", + "mini_batch_global_response_token_num", + "sample_mask", + ] + ) + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - mini_batches = data.split(self.config.ppo_mini_batch_size) + if has_planned_mini_batches: + mini_batches = split_data_proto_by_mini_batch_id(data) + else: + mini_batches = data.split(self.config.ppo_mini_batch_size) for _ in range(self.config.ppo_epochs): for mini_batch in mini_batches: + use_global_mini_batch_info = "mini_batch_global_size" in mini_batch.batch.keys() + global_batch_info = {} + if use_global_mini_batch_info: + global_info = get_mini_batch_global_info(mini_batch) + dp_size = ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + if torch.distributed.is_initialized() + else 1 + ) + global_batch_info = { + "dp_size": dp_size, + "batch_num_tokens": global_info["batch_num_tokens"], + "global_batch_size": global_info["global_batch_size"], + } + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) else: micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - micro_batches = [mb for mb in micro_batches if bool(mb.batch["response_mask"].any().item())] + if not has_planned_mini_batches: + micro_batches = [mb for mb in micro_batches if bool(mb.batch["response_mask"].any().item())] if not micro_batches: append_to_dict(metrics, {"critic/grad_norm": 0.0}) continue @@ -79,8 +113,11 @@ def update_critic(self, data: DataProto): response_mask=response_mask, cliprange_value=self.config.cliprange_value, loss_agg_mode=self.config.loss_agg_mode, + 
**global_batch_info, ) - if self.config.use_dynamic_bsz: + if use_global_mini_batch_info: + loss_scale_factor = 1.0 + elif self.config.use_dynamic_bsz: loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size else: loss_scale_factor = 1 / self.gradient_accumulation diff --git a/agent_r1/workers/engine_workers.py b/agent_r1/workers/engine_workers.py index 248ad62..817f3b3 100644 --- a/agent_r1/workers/engine_workers.py +++ b/agent_r1/workers/engine_workers.py @@ -120,7 +120,11 @@ def train_mini_batch(self, data: TensorDict) -> TensorDict: assert mini_batch_size is not None or num_mini_batch is not None assert dataloader_kwargs.keys() <= {"shuffle"}, f"Unsupported dataloader_kwargs: {dataloader_kwargs.keys()}" - unique_mini_batch_ids = torch.unique(mini_batch_ids, sorted=True).cpu() + if num_mini_batch is not None: + num_mini_batch = int(num_mini_batch) + unique_mini_batch_ids = torch.arange(num_mini_batch, dtype=torch.long) + else: + unique_mini_batch_ids = torch.unique(mini_batch_ids, sorted=True).cpu() total_num_iterations = len(unique_mini_batch_ids) * epochs shuffle = dataloader_kwargs.get("shuffle", False) diff --git a/agent_r1/workers/utils/losses.py b/agent_r1/workers/utils/losses.py index fa7a0aa..ab062b4 100644 --- a/agent_r1/workers/utils/losses.py +++ b/agent_r1/workers/utils/losses.py @@ -130,6 +130,10 @@ def value_loss(config: CriticConfig, model_output, data: TensorDict, dp_group=No values = data["values"] returns = data["returns"] response_mask = data["response_mask"].to(bool) + global_batch_info = {} + for key in ("dp_size", "batch_num_tokens", "global_batch_size"): + if key in data.keys(): + global_batch_info[key] = data[key] vf_loss, vf_clipfrac = compute_value_loss( vpreds=vpreds, @@ -138,6 +142,7 @@ def value_loss(config: CriticConfig, model_output, data: TensorDict, dp_group=No response_mask=response_mask, cliprange_value=config.cliprange_value, loss_agg_mode=config.loss_agg_mode, + **global_batch_info, ) metrics = {