From 12dff7fb97fd0ea521c1283a981d8c8606c98a92 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 16 Apr 2026 22:11:43 +0800 Subject: [PATCH 001/122] Align tracking env with the intended training and playback behavior Remove unused self-collision plumbing, update domain-randomization terms to match the active General-Tracking setup, and keep play-mode evaluation deterministic. The rigid-body perturbation now uses physics-consistent pseudo-inertia randomization instead of mass-only scaling. Constraint: Playback and benchmark paths use play mode and must stay deterministic Constraint: Domain randomization must stay within current mjlab APIs and repo conventions Rejected: Keep dr.body_mass randomization | leaves inertia inconsistent with mass changes Rejected: Preserve startup DR in play mode | makes playback and benchmark nondeterministic Confidence: high Scope-risk: narrow Reversibility: clean Directive: If training-only events change again, update _TRAIN_ONLY_EVENTS and the task registry/domain-randomization tests together Tested: pytest tests/test_task_registry.py tests/test_domain_randomization.py -q Tested: git diff --check Not-tested: pytest tests/ -q | collection currently fails in tests/test_math_utils.py because train_mimic.pose is missing in this tree --- .gitignore | 1 + tests/test_domain_randomization.py | 82 +++++++++++++++++++ tests/test_task_registry.py | 3 + train_mimic/tasks/tracking/config/env.py | 30 +++---- train_mimic/tasks/tracking/mdp/rewards.py | 17 ---- .../tasks/tracking/tracking_env_cfg.py | 20 +++-- 6 files changed, 112 insertions(+), 41 deletions(-) create mode 100644 tests/test_domain_randomization.py diff --git a/.gitignore b/.gitignore index da8602fb..00126fc3 100644 --- a/.gitignore +++ b/.gitignore @@ -100,3 +100,4 @@ data/modelscope_upload/ # Codex local workspace state .codex +.omx/ diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py new file mode 100644 index 00000000..96432d6e --- /dev/null +++ 
b/tests/test_domain_randomization.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from train_mimic.app import DEFAULT_TASK +from train_mimic.tasks.tracking import mdp + + +def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> None: + import mjlab.tasks # noqa: F401 + import train_mimic.tasks # noqa: F401 + from mjlab.envs.mdp import dr + from mjlab.tasks.registry import load_env_cfg + + env_cfg = load_env_cfg(DEFAULT_TASK) + events = env_cfg.events + + assert set(events) == { + "push_robot", + "base_com", + "encoder_bias", + "physics_material", + "randomize_rigid_body_mass", + } + + push_robot = events["push_robot"] + assert push_robot.func is mdp.push_by_setting_velocity + assert push_robot.mode == "interval" + assert push_robot.interval_range_s == (4.0, 6.0) + assert push_robot.params["velocity_range"] == { + "x": (-0.5, 0.5), + "y": (-0.5, 0.5), + "z": (-0.2, 0.2), + "roll": (-0.52, 0.52), + "pitch": (-0.52, 0.52), + "yaw": (-0.78, 0.78), + } + + base_com = events["base_com"] + assert base_com.func is dr.body_com_offset + assert base_com.mode == "startup" + assert base_com.params["asset_cfg"].body_names == ("torso_link",) + assert base_com.params["operation"] == "add" + assert base_com.params["ranges"] == { + 0: (-0.025, 0.025), + 1: (-0.05, 0.05), + 2: (-0.05, 0.05), + } + + encoder_bias = events["encoder_bias"] + assert encoder_bias.func is dr.encoder_bias + assert encoder_bias.mode == "startup" + assert encoder_bias.params["bias_range"] == (-0.01, 0.01) + + physics_material = events["physics_material"] + assert physics_material.func is dr.geom_friction + assert physics_material.mode == "startup" + assert physics_material.params["asset_cfg"].geom_names == r".*_collision$" + assert physics_material.params["operation"] == "abs" + assert physics_material.params["ranges"] == (0.3, 1.6) + + mass = events["randomize_rigid_body_mass"] + assert mass.func is dr.pseudo_inertia + assert mass.mode == "startup" + assert 
mass.params["asset_cfg"].body_names == r".*wrist_yaw.*|torso_link" + assert mass.params["alpha_range"] == ( + -0.11157177565710488, + 0.4581453659370775, + ) + + +def test_play_env_disables_training_only_domain_randomization() -> None: + import mjlab.tasks # noqa: F401 + import train_mimic.tasks # noqa: F401 + from mjlab.tasks.registry import load_env_cfg + + play_cfg = load_env_cfg(DEFAULT_TASK, play=True) + + assert "push_robot" not in play_cfg.events + assert "base_com" not in play_cfg.events + assert "encoder_bias" not in play_cfg.events + assert "physics_material" not in play_cfg.events + assert "randomize_rigid_body_mass" not in play_cfg.events + assert play_cfg.events == {} diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 43c1cac1..ff5b3096 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -33,6 +33,9 @@ def test_general_tracking_task_is_registered() -> None: assert "critic_history" in env_cfg.observations assert env_cfg.commands["motion"].sampling_mode == "uniform" assert env_cfg.commands["motion"].window_steps == (0,) + assert "self_collisions" not in env_cfg.rewards + assert "undesired_contacts" not in env_cfg.rewards + assert not getattr(env_cfg.scene, "sensors", ()) rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME assert rl_cfg.actor.hidden_dims == (1024, 512, 256, 256, 128) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index d8aa5268..dbd2c25e 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -9,7 +9,6 @@ from mjlab.envs.mdp.actions import JointPositionActionCfg from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg -from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.utils.noise import UniformNoiseCfg as Unoise from 
train_mimic.tasks.tracking import mdp @@ -34,6 +33,14 @@ "right_wrist_yaw_link", ) +_TRAIN_ONLY_EVENTS = ( + "push_robot", + "base_com", + "encoder_bias", + "physics_material", + "randomize_rigid_body_mass", +) + def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: motion_cmd = cfg.commands["motion"] @@ -41,7 +48,8 @@ def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: cfg.episode_length_s = int(1e9) cfg.observations["actor"].enable_corruption = False - cfg.events.pop("push_robot", None) + for event_name in _TRAIN_ONLY_EVENTS: + cfg.events.pop(event_name, None) motion_cmd.pose_range = {} motion_cmd.velocity_range = {} motion_cmd.sampling_mode = "start" @@ -109,17 +117,6 @@ def make_general_tracking_env_cfg( cfg = make_tracking_env_cfg() cfg.scene.entities = {"robot": get_g1_robot_cfg()} - cfg.scene.sensors = ( - ContactSensorCfg( - name="self_collision", - primary=ContactMatch(mode="subtree", pattern="pelvis", entity="robot"), - secondary=ContactMatch(mode="subtree", pattern="pelvis", entity="robot"), - fields=("found", "force"), - reduce="none", - num_slots=1, - history_length=4, - ), - ) joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) @@ -133,10 +130,13 @@ def make_general_tracking_env_cfg( motion_cmd.sampling_mode = "uniform" motion_cmd.window_steps = (0,) - cfg.events["foot_friction"].params[ + cfg.events["physics_material"].params[ "asset_cfg" - ].geom_names = r"^(left|right)_foot[1-7]_collision$" + ].geom_names = r".*_collision$" cfg.events["base_com"].params["asset_cfg"].body_names = ("torso_link",) + cfg.events["randomize_rigid_body_mass"].params[ + "asset_cfg" + ].body_names = r".*wrist_yaw.*|torso_link" cfg.terminations["ee_body_pos"].params["body_names"] = ( "left_ankle_roll_link", "right_ankle_roll_link", diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index 108bf46b..dee3318a 100644 --- 
a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -7,7 +7,6 @@ from mjlab.entity import Entity from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg -from mjlab.sensor import ContactSensor from mjlab.utils.lab_api.math import ( quat_error_magnitude, ) @@ -144,22 +143,6 @@ def motion_global_body_angular_velocity_error_exp( return torch.exp(-error.mean(-1) / std**2) -def self_collision_cost( - env: ManagerBasedRlEnv, - sensor_name: str, - force_threshold: float = 10.0, -) -> torch.Tensor: - """Penalize self-collisions.""" - sensor: ContactSensor = env.scene[sensor_name] - data = sensor.data - if data.force_history is not None: - force_mag = torch.norm(data.force_history, dim=-1) - hit = (force_mag > force_threshold).any(dim=1) - return hit.sum(dim=-1).float() - assert data.found is not None - return data.found.squeeze(-1) - - class joint_torque_limits: """Penalize actuator-force limit violations with a configurable soft margin.""" diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index ce2e252b..d79b0e9b 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -163,7 +163,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "push_robot": EventTermCfg( func=mdp.push_by_setting_velocity, mode="interval", - interval_range_s=(1.0, 3.0), + interval_range_s=(4.0, 6.0), params={"velocity_range": VELOCITY_RANGE}, ), "base_com": EventTermCfg( @@ -187,14 +187,21 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "bias_range": (-0.01, 0.01), }, ), - "foot_friction": EventTermCfg( + "physics_material": EventTermCfg( mode="startup", func=dr.geom_friction, params={ "asset_cfg": SceneEntityCfg("robot", geom_names=()), # Set per-robot. "operation": "abs", - "ranges": (0.3, 1.2), - "shared_random": True, # All foot geoms share the same friction. 
+ "ranges": (0.3, 1.6), + }, + ), + "randomize_rigid_body_mass": EventTermCfg( + mode="startup", + func=dr.pseudo_inertia, + params={ + "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. + "alpha_range": (-0.11157177565710488, 0.4581453659370775), }, ), } @@ -240,11 +247,6 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: weight=-10.0, params={"asset_cfg": SceneEntityCfg("robot", joint_names=(".*",))}, ), - "self_collisions": RewardTermCfg( - func=mdp.self_collision_cost, - weight=-10.0, - params={"sensor_name": "self_collision", "force_threshold": 10.0}, - ), } ## From b7a73e52427669a3ad1498a5a4f7b8bf1e41a9a7 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 17 Apr 2026 00:00:03 +0800 Subject: [PATCH 002/122] Avoid premature tracking resets on low-height motions General-Tracking-G1 now uses adaptive Z-only anchor and end-effector termination thresholds, while separating foot XYZ checks into an explicit termination term. Benchmark video mode also disables the new foot termination so long clip capture keeps the previous behavior. 
Constraint: Low-height reference poses need looser failure thresholds without weakening normal upright tracking Rejected: Raise all tracking thresholds globally | too permissive for upright motions and hides real failures Confidence: high Scope-risk: moderate Reversibility: clean Directive: Keep benchmark/video termination overrides in sync whenever new tracking failure terms are added Tested: pytest tests/test_termination_config.py tests/test_task_registry.py -q Not-tested: Full training run and benchmark rollout on real motion clips --- tests/test_termination_config.py | 108 ++++++++++++++++++ train_mimic/scripts/benchmark.py | 1 + train_mimic/tasks/tracking/config/env.py | 6 +- .../tasks/tracking/mdp/terminations.py | 34 ++++++ .../tasks/tracking/tracking_env_cfg.py | 23 +++- 5 files changed, 166 insertions(+), 6 deletions(-) create mode 100644 tests/test_termination_config.py diff --git a/tests/test_termination_config.py b/tests/test_termination_config.py new file mode 100644 index 00000000..1336070c --- /dev/null +++ b/tests/test_termination_config.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch + +from train_mimic.app import DEFAULT_TASK +from train_mimic.tasks.tracking import mdp + + +def test_general_tracking_termination_config_matches_requested_policy() -> None: + import mjlab.tasks # noqa: F401 + import train_mimic.tasks # noqa: F401 + from mjlab.tasks.registry import load_env_cfg + + env_cfg = load_env_cfg(DEFAULT_TASK) + terminations = env_cfg.terminations + + assert set(terminations) == { + "time_out", + "anchor_pos", + "anchor_ori", + "ee_body_pos", + "foot_pos_xyz", + } + + anchor_pos = terminations["anchor_pos"] + assert anchor_pos.func is mdp.bad_anchor_pos_z_only_adaptive + assert anchor_pos.params == { + "command_name": "motion", + "threshold": 0.15, + "down_threshold": 0.4, + "root_height_threshold": 0.5, + } + + anchor_ori = terminations["anchor_ori"] + assert anchor_ori.func is 
mdp.bad_anchor_ori + assert anchor_ori.params["threshold"] == 1.0 + + ee_body_pos = terminations["ee_body_pos"] + assert ee_body_pos.func is mdp.bad_motion_body_pos_z_only_adaptive + assert ee_body_pos.params == { + "command_name": "motion", + "threshold": 0.15, + "down_threshold": 0.4, + "root_height_threshold": 0.5, + "body_names": ( + "left_ankle_roll_link", + "right_ankle_roll_link", + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ), + } + + foot_pos_xyz = terminations["foot_pos_xyz"] + assert foot_pos_xyz.func is mdp.bad_motion_body_pos + assert foot_pos_xyz.params == { + "command_name": "motion", + "threshold": 0.2, + "body_names": ( + "left_ankle_roll_link", + "right_ankle_roll_link", + ), + } + + +def test_adaptive_height_termination_uses_relaxed_threshold_for_low_reference() -> None: + command = SimpleNamespace( + cfg=SimpleNamespace(body_names=("left_ankle_roll_link",)), + anchor_pos_w=torch.tensor([[0.0, 0.0, 0.3], [0.0, 0.0, 0.8]], dtype=torch.float32), + robot_anchor_pos_w=torch.tensor([[0.0, 0.0, -0.09], [0.0, 0.0, 0.55]], dtype=torch.float32), + body_pos_relative_w=torch.tensor( + [ + [[0.0, 0.0, 0.30]], + [[0.0, 0.0, 0.80]], + ], + dtype=torch.float32, + ), + robot_body_pos_w=torch.tensor( + [ + [[0.0, 0.0, -0.09]], + [[0.0, 0.0, 0.55]], + ], + dtype=torch.float32, + ), + ) + env = SimpleNamespace( + command_manager=SimpleNamespace(get_term=lambda _name: command), + ) + + anchor_done = mdp.bad_anchor_pos_z_only_adaptive( + env, + "motion", + threshold=0.15, + down_threshold=0.4, + root_height_threshold=0.5, + ) + ee_done = mdp.bad_motion_body_pos_z_only_adaptive( + env, + "motion", + threshold=0.15, + down_threshold=0.4, + root_height_threshold=0.5, + body_names=("left_ankle_roll_link",), + ) + + assert anchor_done.tolist() == [False, True] + assert ee_done.tolist() == [False, True] diff --git a/train_mimic/scripts/benchmark.py b/train_mimic/scripts/benchmark.py index f9748e84..68a940d4 100644 --- a/train_mimic/scripts/benchmark.py +++ 
b/train_mimic/scripts/benchmark.py @@ -288,6 +288,7 @@ def main() -> int: env_cfg.terminations.pop("anchor_pos", None) env_cfg.terminations.pop("anchor_ori", None) env_cfg.terminations.pop("ee_body_pos", None) + env_cfg.terminations.pop("foot_pos_xyz", None) env_cfg.terminations.pop("body_z_tracking_failure", None) env_cfg.terminations.pop("gravity_tracking_failure", None) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index dbd2c25e..0ae53539 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -143,9 +143,11 @@ def make_general_tracking_env_cfg( "left_wrist_yaw_link", "right_wrist_yaw_link", ) - cfg.terminations["anchor_pos"].params["threshold"] = 0.4 + cfg.terminations["foot_pos_xyz"].params["body_names"] = ( + "left_ankle_roll_link", + "right_ankle_roll_link", + ) cfg.terminations["anchor_ori"].params["threshold"] = 1.0 - cfg.terminations["ee_body_pos"].params["threshold"] = 0.4 cfg.viewer.body_name = "torso_link" cfg.episode_length_s = 10.0 if cfg.sim.njmax < 500: diff --git a/train_mimic/tasks/tracking/mdp/terminations.py b/train_mimic/tasks/tracking/mdp/terminations.py index f5d64e29..5c62bb0a 100644 --- a/train_mimic/tasks/tracking/mdp/terminations.py +++ b/train_mimic/tasks/tracking/mdp/terminations.py @@ -34,6 +34,20 @@ def bad_anchor_pos_z_only( ) +def bad_anchor_pos_z_only_adaptive( + env: ManagerBasedRlEnv, + command_name: str, + threshold: float, + down_threshold: float, + root_height_threshold: float, +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + height_err = torch.abs(command.anchor_pos_w[:, -1] - command.robot_anchor_pos_w[:, -1]) + threshold_tensor = torch.full_like(height_err, threshold) + threshold_tensor[command.anchor_pos_w[:, -1] < root_height_threshold] = down_threshold + return height_err > threshold_tensor + + def bad_anchor_ori( env: ManagerBasedRlEnv, asset_cfg: SceneEntityCfg, 
command_name: str, threshold: float ) -> torch.Tensor: @@ -84,3 +98,23 @@ def bad_motion_body_pos_z_only( - command.robot_body_pos_w[:, body_indexes, -1] ) return torch.any(error > threshold, dim=-1) + + +def bad_motion_body_pos_z_only_adaptive( + env: ManagerBasedRlEnv, + command_name: str, + threshold: float, + down_threshold: float, + root_height_threshold: float, + body_names: tuple[str, ...] | None = None, +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + + body_indexes = _get_body_indexes(command, body_names) + error = torch.abs( + command.body_pos_relative_w[:, body_indexes, -1] + - command.robot_body_pos_w[:, body_indexes, -1] + ) + threshold_tensor = torch.full_like(error, threshold) + threshold_tensor[command.anchor_pos_w[:, -1] < root_height_threshold] = down_threshold + return torch.any(error > threshold_tensor, dim=-1) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index d79b0e9b..685eab62 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -256,8 +256,13 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: terminations: dict[str, TerminationTermCfg] = { "time_out": TerminationTermCfg(func=mdp.time_out, time_out=True), "anchor_pos": TerminationTermCfg( - func=mdp.bad_anchor_pos_z_only, - params={"command_name": "motion", "threshold": 0.25}, + func=mdp.bad_anchor_pos_z_only_adaptive, + params={ + "command_name": "motion", + "threshold": 0.15, + "down_threshold": 0.4, + "root_height_threshold": 0.5, + }, ), "anchor_ori": TerminationTermCfg( func=mdp.bad_anchor_ori, @@ -268,10 +273,20 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: }, ), "ee_body_pos": TerminationTermCfg( - func=mdp.bad_motion_body_pos_z_only, + func=mdp.bad_motion_body_pos_z_only_adaptive, + params={ + "command_name": "motion", + "threshold": 0.15, + "down_threshold": 0.4, + "root_height_threshold": 0.5, + 
"body_names": (), # Set per-robot. + }, + ), + "foot_pos_xyz": TerminationTermCfg( + func=mdp.bad_motion_body_pos, params={ "command_name": "motion", - "threshold": 0.25, + "threshold": 0.2, "body_names": (), # Set per-robot. }, ), From efb2a0bc5a85153c4bb7005a4763aaafe649ee85 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 20 Apr 2026 22:29:27 +0800 Subject: [PATCH 003/122] Keep General-Tracking-G1 resets aligned with the active training policy The training env no longer keeps a separate foot XYZ termination term. This commit removes that term from the assembled General-Tracking-G1 config and updates the regression test to assert the reduced termination surface directly. Constraint: Training resets should match the currently intended termination policy without reintroducing premature foot-specific failures Rejected: Keep foot_pos_xyz configured but unused in tests | leaves the env behavior and regression contract out of sync Confidence: high Scope-risk: narrow Reversibility: clean Directive: If foot-specific reset logic is reintroduced, update both the env assembly and termination regression test together Tested: pytest tests/test_termination_config.py tests/test_task_registry.py -q Not-tested: Full training run or benchmark rollout with real motion clips --- tests/test_termination_config.py | 13 ------------- train_mimic/tasks/tracking/config/env.py | 5 +---- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/tests/test_termination_config.py b/tests/test_termination_config.py index 1336070c..bfda66f4 100644 --- a/tests/test_termination_config.py +++ b/tests/test_termination_config.py @@ -21,7 +21,6 @@ def test_general_tracking_termination_config_matches_requested_policy() -> None: "anchor_pos", "anchor_ori", "ee_body_pos", - "foot_pos_xyz", } anchor_pos = terminations["anchor_pos"] @@ -52,18 +51,6 @@ def test_general_tracking_termination_config_matches_requested_policy() -> None: ), } - foot_pos_xyz = terminations["foot_pos_xyz"] - assert 
foot_pos_xyz.func is mdp.bad_motion_body_pos - assert foot_pos_xyz.params == { - "command_name": "motion", - "threshold": 0.2, - "body_names": ( - "left_ankle_roll_link", - "right_ankle_roll_link", - ), - } - - def test_adaptive_height_termination_uses_relaxed_threshold_for_low_reference() -> None: command = SimpleNamespace( cfg=SimpleNamespace(body_names=("left_ankle_roll_link",)), diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 0ae53539..bf1a897f 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -143,10 +143,7 @@ def make_general_tracking_env_cfg( "left_wrist_yaw_link", "right_wrist_yaw_link", ) - cfg.terminations["foot_pos_xyz"].params["body_names"] = ( - "left_ankle_roll_link", - "right_ankle_roll_link", - ) + cfg.terminations.pop("foot_pos_xyz", None) cfg.terminations["anchor_ori"].params["threshold"] = 1.0 cfg.viewer.body_name = "torso_link" cfg.episode_length_s = 10.0 From 89066121b5fc830ac9596446d947b4fed6c82cc9 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 21 Apr 2026 17:26:42 +0800 Subject: [PATCH 004/122] Reduce dataset-config drift across training and review workflows Standardize the public dataset spec names to twist2/seed/lafan1, update default motion-file references, and add a seed_filter_preset path that records filtering diagnostics in build reports. This keeps docs, scripts, and tests aligned with the renamed configs while making SEED curation reproducible and observable. 
Constraint: Keep dataset builder behavior fail-fast and avoid adding new dependencies Rejected: Preserve versioned spec filenames as the primary defaults | keeps user-facing commands and docs inconsistent Confidence: high Scope-risk: moderate Reversibility: clean Directive: If external automation still depends on the old YAML filenames, add explicit compatibility aliases before changing defaults again Tested: pytest tests/test_dataset_v2.py tests/test_review_pipeline.py tests/test_train_script.py -q Not-tested: End-to-end dataset build against real SEED/twist2 assets --- .gitignore | 1 + AGENTS.md | 4 +- docs/docs/reference/dataset.md | 20 +- .../current/reference/dataset.md | 20 +- scripts/render/render_motion_npz.py | 2 +- scripts/review/build_dataset_from_review.py | 6 +- scripts/review/export_reviewed_manifest.py | 4 +- scripts/review/init_review_manifest.py | 6 +- scripts/review/review_dataset.py | 4 +- tests/test_dataset_v2.py | 175 +++++++++++++- tests/test_review_pipeline.py | 2 +- tests/test_train_script.py | 2 +- .../datasets/{lafan1_v1.yaml => lafan1.yaml} | 4 +- .../datasets/{seed_v1.yaml => seed.yaml} | 5 +- .../configs/datasets/seed_v1_smoke.yaml | 13 -- train_mimic/configs/datasets/seed_v2_3h.yaml | 16 -- .../{twist2_full.yaml => twist2.yaml} | 2 +- train_mimic/data/dataset_builder.py | 214 ++++++++++++++++-- train_mimic/scripts/benchmark.py | 2 +- train_mimic/scripts/data/split_shards.py | 4 +- train_mimic/scripts/play.py | 6 +- train_mimic/scripts/train.py | 8 +- .../tasks/tracking/config/constants.py | 2 +- 23 files changed, 419 insertions(+), 103 deletions(-) rename train_mimic/configs/datasets/{lafan1_v1.yaml => lafan1.yaml} (83%) rename train_mimic/configs/datasets/{seed_v1.yaml => seed.yaml} (82%) delete mode 100644 train_mimic/configs/datasets/seed_v1_smoke.yaml delete mode 100644 train_mimic/configs/datasets/seed_v2_3h.yaml rename train_mimic/configs/datasets/{twist2_full.yaml => twist2.yaml} (96%) diff --git a/.gitignore b/.gitignore index 
da8602fb..00126fc3 100644 --- a/.gitignore +++ b/.gitignore @@ -100,3 +100,4 @@ data/modelscope_upload/ # Codex local workspace state .codex +.omx/ diff --git a/AGENTS.md b/AGENTS.md index 77d2fc92..d0f173b9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -180,8 +180,8 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g Quick reference: ```bash -python train_mimic/scripts/data/build_dataset.py --spec train_mimic/configs/datasets/twist2_full.yaml -python train_mimic/scripts/train.py --motion_file data/datasets/twist2_full/train +python train_mimic/scripts/data/build_dataset.py --spec train_mimic/configs/datasets/twist2.yaml +python train_mimic/scripts/train.py --motion_file data/datasets/twist2/train python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt --output policy.onnx --history_length 10 ``` diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index 348db65a..53bd41d2 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -26,7 +26,7 @@ Data pipeline: `typed source YAML -> preprocess/filter -> shard-only training da ```bash python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml + --spec train_mimic/configs/datasets/twist2.yaml ``` ## Output Structure @@ -48,10 +48,10 @@ data/datasets// ## YAML Spec Format -Example (`train_mimic/configs/datasets/twist2_full.yaml`): +Example (`train_mimic/configs/datasets/twist2.yaml`): ```yaml -name: twist2_full +name: twist2 target_fps: 30 val_percent: 5 hash_salt: "" @@ -62,7 +62,7 @@ sources: - name: OMOMO_g1_GMR type: pkl input: data/twist2_retarget_pkl/OMOMO_g1_GMR - - name: lafan1_v1 + - name: lafan1 type: bvh input: data/lafan1_bvh bvh_format: lafan1 @@ -105,20 +105,20 @@ Each shard contains: `clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`. 
```bash # Force rebuild python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --force + --spec train_mimic/configs/datasets/twist2.yaml --force # Parallel processing python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --jobs 8 + --spec train_mimic/configs/datasets/twist2.yaml --jobs 8 # Custom output root python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml \ + --spec train_mimic/configs/datasets/twist2.yaml \ --output_root /tmp/my_datasets # Print build report python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --json + --spec train_mimic/configs/datasets/twist2.yaml --json ``` ## Batch Ingest to NPZ Clips @@ -128,8 +128,8 @@ Convert raw data to standard NPZ clips without merging: ```bash python train_mimic/scripts/data/ingest_motion.py \ --type bvh --input data/lafan1_bvh \ - --output data/datasets/lafan1_v1/clips/lafan1_v1 \ - --source lafan1_v1 --bvh_format lafan1 --jobs 8 + --output data/datasets/lafan1/clips/lafan1 \ + --source lafan1 --bvh_format lafan1 --jobs 8 ``` ## Check Clip FK Consistency diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index b7b17831..71802832 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -26,7 +26,7 @@ python train_mimic/scripts/train.py --motion_file data/datasets/seed/train ```bash python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml + --spec train_mimic/configs/datasets/twist2.yaml ``` ## 输出目录结构 @@ -48,10 +48,10 @@ data/datasets// ## YAML spec -示例(`train_mimic/configs/datasets/twist2_full.yaml`): 
+示例(`train_mimic/configs/datasets/twist2.yaml`): ```yaml -name: twist2_full +name: twist2 target_fps: 30 val_percent: 5 hash_salt: "" @@ -62,7 +62,7 @@ sources: - name: OMOMO_g1_GMR type: pkl input: data/twist2_retarget_pkl/OMOMO_g1_GMR - - name: lafan1_v1 + - name: lafan1 type: bvh input: data/lafan1_bvh bvh_format: lafan1 @@ -93,20 +93,20 @@ sources: ```bash # 强制重建 python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --force + --spec train_mimic/configs/datasets/twist2.yaml --force # 多进程并行 python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --jobs 8 + --spec train_mimic/configs/datasets/twist2.yaml --jobs 8 # 自定义输出根目录 python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml \ + --spec train_mimic/configs/datasets/twist2.yaml \ --output_root /tmp/my_datasets # 打印 build report python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --json + --spec train_mimic/configs/datasets/twist2.yaml --json ``` ## 批量转换为 NPZ clips @@ -116,8 +116,8 @@ python train_mimic/scripts/data/build_dataset.py \ ```bash python train_mimic/scripts/data/ingest_motion.py \ --type bvh --input data/lafan1_bvh \ - --output data/datasets/lafan1_v1/clips/lafan1_v1 \ - --source lafan1_v1 --bvh_format lafan1 --jobs 8 + --output data/datasets/lafan1/clips/lafan1 \ + --source lafan1 --bvh_format lafan1 --jobs 8 ``` ## FK 一致性检查 diff --git a/scripts/render/render_motion_npz.py b/scripts/render/render_motion_npz.py index 40b1c8d2..12cb306c 100644 --- a/scripts/render/render_motion_npz.py +++ b/scripts/render/render_motion_npz.py @@ -27,7 +27,7 @@ def parse_args() -> argparse.Namespace: "--robot", type=str, default="unitree_g1", - help="Robot viewer type. For lafan1_v1 clips this should usually stay as unitree_g1.", + help="Robot viewer type. 
For lafan1 clips this should usually stay as unitree_g1.", ) parser.add_argument("--start", type=int, default=0, help="Start frame index") parser.add_argument("--end", type=int, default=-1, help="Exclusive end frame index; -1 means full clip") diff --git a/scripts/review/build_dataset_from_review.py b/scripts/review/build_dataset_from_review.py index 9018f313..15f8d9a4 100644 --- a/scripts/review/build_dataset_from_review.py +++ b/scripts/review/build_dataset_from_review.py @@ -6,8 +6,8 @@ Usage: python scripts/data/build_dataset_from_review.py \ - --filtered_manifest data/datasets/review/twist2_full/filtered_manifest.csv \ - --output_dir data/datasets/builds/twist2_full_cleaned + --filtered_manifest data/datasets/review/twist2/filtered_manifest.csv \ + --output_dir data/datasets/builds/twist2_cleaned """ from __future__ import annotations @@ -41,7 +41,7 @@ def main() -> None: ) parser.add_argument( "--output_dir", type=str, required=True, - help="Output directory, e.g. data/datasets/builds/twist2_full_cleaned", + help="Output directory, e.g. 
data/datasets/builds/twist2_cleaned", ) parser.add_argument( "--target_fps", type=int, default=None, diff --git a/scripts/review/export_reviewed_manifest.py b/scripts/review/export_reviewed_manifest.py index b2a893f5..1f4f3c32 100644 --- a/scripts/review/export_reviewed_manifest.py +++ b/scripts/review/export_reviewed_manifest.py @@ -6,8 +6,8 @@ Usage: python scripts/data/export_reviewed_manifest.py \ - --review data/datasets/review/twist2_full/review_state.csv \ - --output data/datasets/review/twist2_full/filtered_manifest.csv + --review data/datasets/review/twist2/review_state.csv \ + --output data/datasets/review/twist2/filtered_manifest.csv """ from __future__ import annotations diff --git a/scripts/review/init_review_manifest.py b/scripts/review/init_review_manifest.py index afd5c01b..8c81bfe8 100644 --- a/scripts/review/init_review_manifest.py +++ b/scripts/review/init_review_manifest.py @@ -3,8 +3,8 @@ Usage: python scripts/data/init_review_manifest.py \ - --dataset twist2_full \ - --manifest data/datasets/builds/twist2_full/manifest_resolved.csv + --dataset twist2 \ + --manifest data/datasets/builds/twist2/manifest_resolved.csv """ from __future__ import annotations @@ -22,7 +22,7 @@ def main() -> None: parser = argparse.ArgumentParser(description="Initialize review state from manifest") - parser.add_argument("--dataset", type=str, required=True, help="Dataset name, e.g. twist2_full") + parser.add_argument("--dataset", type=str, required=True, help="Dataset name, e.g. 
twist2") parser.add_argument( "--manifest", type=str, required=True, help="Path to manifest_resolved.csv" ) diff --git a/scripts/review/review_dataset.py b/scripts/review/review_dataset.py index c0dfa8a2..f87568a9 100644 --- a/scripts/review/review_dataset.py +++ b/scripts/review/review_dataset.py @@ -6,8 +6,8 @@ Usage: python scripts/review/review_dataset.py \ - --dataset lafan1_v1 \ - --review data/datasets/review/lafan1_v1/review_state.csv + --dataset lafan1 \ + --review data/datasets/review/lafan1/review_state.csv """ from __future__ import annotations diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 07f5d894..2625ddfd 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,5 +1,6 @@ from __future__ import annotations +import csv import pickle from pathlib import Path @@ -146,6 +147,77 @@ def test_load_dataset_spec_parses_preprocess(tmp_path: Path) -> None: assert spec.preprocess.min_frames == 10 +def test_load_dataset_spec_parses_seed_filter_preset(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + metadata_csv.write_text("move_g1_path,is_mirror\n", encoding="utf-8") + spec_path = tmp_path / "seed.yaml" + spec_path.write_text( + f"""name: seed_demo +target_fps: 30 +val_percent: 5 +hash_salt: "" +sources: + - name: seed + type: seed_csv + input: {tmp_path / 'seed_source'} + metadata_csv: {metadata_csv} + seed_filter_preset: groot_strict + filters: + is_mirror: [false] +""", + encoding="utf-8", + ) + + spec = load_dataset_spec(spec_path) + assert spec.sources[0].seed_filter_preset == "groot_strict" + + +def test_load_dataset_spec_rejects_seed_filter_preset_on_non_seed_source(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + metadata_csv.write_text("move_g1_path,is_mirror\n", encoding="utf-8") + spec_path = tmp_path / "bad_seed_preset.yaml" + spec_path.write_text( + f"""name: demo +target_fps: 30 +val_percent: 5 +hash_salt: "" +sources: + - name: clips + type: npz + input: 
{tmp_path / 'npz_source'} + metadata_csv: {metadata_csv} + seed_filter_preset: groot_strict +""", + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="supported only for seed_csv sources"): + load_dataset_spec(spec_path) + + +def test_load_dataset_spec_rejects_unknown_seed_filter_preset(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + metadata_csv.write_text("move_g1_path,is_mirror\n", encoding="utf-8") + spec_path = tmp_path / "bad_seed_preset.yaml" + spec_path.write_text( + f"""name: demo +target_fps: 30 +val_percent: 5 +hash_salt: "" +sources: + - name: seed + type: seed_csv + input: {tmp_path / 'seed_source'} + metadata_csv: {metadata_csv} + seed_filter_preset: unknown_preset +""", + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="unknown seed_filter_preset"): + load_dataset_spec(spec_path) + + def test_load_dataset_spec_rejects_bvh_without_format(tmp_path: Path) -> None: spec_path = tmp_path / "bad.yaml" spec_path.write_text( @@ -273,6 +345,89 @@ def test_convert_source_to_npz_clips_rejects_dataset_root_for_npz_source(tmp_pat convert_source_to_npz_clips(source, tmp_path / "dataset" / "clips" / "npz_src", jobs=1) +def test_collect_source_files_with_report_applies_seed_filter_preset(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + with metadata_csv.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter( + handle, + fieldnames=[ + "move_g1_path", + "is_mirror", + "content_body_position", + "content_type_of_movement", + "content_props", + "filename", + "move_name", + ], + ) + writer.writeheader() + writer.writerow( + { + "move_g1_path": "g1/csv/240101/walk_forward.csv", + "is_mirror": "False", + "content_body_position": "standing", + "content_type_of_movement": "walking", + "content_props": "0", + "filename": "walk_forward", + "move_name": "walk_forward", + } + ) + writer.writerow( + { + "move_g1_path": "g1/csv/240101/sit_pose.csv", + "is_mirror": "False", + 
"content_body_position": "sitting", + "content_type_of_movement": "sitting", + "content_props": "chair", + "filename": "sit_pose", + "move_name": "sit_pose", + } + ) + + source = DatasetSourceSpec( + name="seed", + type="seed_csv", + input=str(tmp_path / "seed_source" / "g1" / "csv"), + metadata_csv=str(metadata_csv), + filters={"is_mirror": [False]}, + seed_filter_preset="groot_strict", + ) + + input_root = tmp_path / "seed_source" / "g1" / "csv" + input_root.mkdir(parents=True, exist_ok=True) + (input_root / "240101").mkdir(parents=True, exist_ok=True) + (input_root / "240101" / "walk_forward.csv").write_text("placeholder", encoding="utf-8") + (input_root / "240101" / "sit_pose.csv").write_text("placeholder", encoding="utf-8") + + items, _scan_root, report = dataset_builder._collect_source_files_with_report(source, quiet=True) + + assert [item.rel_no_suffix.as_posix() for item in items] == ["240101/walk_forward"] + assert report["scanned_files"] == 2 + assert report["metadata_rows_matched"] == 2 + assert report["preset_rejected_rows"] == 1 + assert report["kept_files"] == 1 + assert report["filtered_files"] == 1 + assert report["preset_reject_reasons"]["content_body_position:sitting"] == 1 + + +def test_collect_source_files_with_report_handles_single_file_source(tmp_path: Path) -> None: + npz_path = tmp_path / "clip_a.npz" + _write_npz_from_pkl(npz_path) + source = DatasetSourceSpec(name="clip", type="npz", input=str(npz_path)) + + items, scan_root, report = dataset_builder._collect_source_files_with_report( + source, quiet=True + ) + legacy_items, legacy_scan_root = dataset_builder._collect_source_files(source) + + assert scan_root == tmp_path + assert [item.rel_no_suffix.as_posix() for item in items] == ["clip_a"] + assert report["scanned_files"] == 1 + assert report["kept_files"] == 1 + assert legacy_scan_root == scan_root + assert [item.rel_no_suffix.as_posix() for item in legacy_items] == ["clip_a"] + + def 
test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> None: npz_input = tmp_path / "npz_source" _write_npz_from_pkl(npz_input / "clip_a.npz") @@ -506,12 +661,24 @@ def test_build_dataset_batch_manifest_skips_filtered_entries( ) dataset_dir = tmp_path / "datasets" / spec.name - def _collect(_source): + def _collect_with_report(_source, *, quiet=False): + _ = quiet return ([ SourceInputFile(path=keep_train, rel_no_suffix=Path("keep_train")), SourceInputFile(path=drop_train, rel_no_suffix=Path("drop_train")), SourceInputFile(path=keep_val, rel_no_suffix=Path("keep_val")), - ], source_dir) + ], source_dir, { + "source": "seed", + "type": "seed_csv", + "metadata_csv": None, + "seed_filter_preset": "groot_strict", + "scanned_files": 3, + "metadata_rows_matched": 3, + "preset_rejected_rows": 1, + "kept_files": 2, + "filtered_files": 1, + "preset_reject_reasons": {"content_body_position:sitting": 1}, + }) def _hash_split(clip_id: str, _val_percent: int, _salt: str = "") -> str: return "val" if clip_id.endswith("keep_val") else "train" @@ -582,7 +749,7 @@ def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, prepro "kept_file_paths": [str(keep_val)], }]) - monkeypatch.setattr(dataset_builder, "_collect_source_files", _collect) + monkeypatch.setattr(dataset_builder, "_collect_source_files_with_report", _collect_with_report) monkeypatch.setattr(dataset_builder, "hash_split", _hash_split) monkeypatch.setattr(dataset_builder, "_batch_convert_split", _batch_convert_split) @@ -601,3 +768,5 @@ def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, prepro assert "seed:drop_train" not in manifest assert report["clip_counts"] == {"total": 2, "train": 1, "val": 1} assert report["input_clip_counts"] == {"total": 3, "train": 2, "val": 1} + assert report["source_filters"][0]["seed_filter_preset"] == "groot_strict" + assert report["source_filters"][0]["preset_reject_reasons"] == {"content_body_position:sitting": 1} diff --git 
a/tests/test_review_pipeline.py b/tests/test_review_pipeline.py index 7c9715cc..a2e4e5b8 100644 --- a/tests/test_review_pipeline.py +++ b/tests/test_review_pipeline.py @@ -229,7 +229,7 @@ def test_build_dataset_from_review_resamples_mixed_fps_and_preserves_weights( ] ) - output_dir = tmp_path / "twist2_full_cleaned" + output_dir = tmp_path / "twist2_cleaned" monkeypatch.setattr( sys, "argv", diff --git a/tests/test_train_script.py b/tests/test_train_script.py index 1fdfe9cb..f1dd315c 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -29,7 +29,7 @@ def _args(**overrides: object) -> argparse.Namespace: "seed": 42, "wandb_project": None, "experiment_name": None, - "motion_file": "data/datasets/twist2_full/train", + "motion_file": "data/datasets/twist2/train", "resume": None, "device": None, "gpu_ids": None, diff --git a/train_mimic/configs/datasets/lafan1_v1.yaml b/train_mimic/configs/datasets/lafan1.yaml similarity index 83% rename from train_mimic/configs/datasets/lafan1_v1.yaml rename to train_mimic/configs/datasets/lafan1.yaml index 2c4ffac3..bc58527c 100644 --- a/train_mimic/configs/datasets/lafan1_v1.yaml +++ b/train_mimic/configs/datasets/lafan1.yaml @@ -1,4 +1,4 @@ -name: lafan1_v1 +name: lafan1 target_fps: 30 val_percent: 5 hash_salt: "" @@ -6,7 +6,7 @@ preprocess: normalize_root_xy: true ground_align: clip_min_foot sources: - - name: lafan1_v1 + - name: lafan1 type: bvh input: data/lafan1_bvh bvh_format: lafan1 diff --git a/train_mimic/configs/datasets/seed_v1.yaml b/train_mimic/configs/datasets/seed.yaml similarity index 82% rename from train_mimic/configs/datasets/seed_v1.yaml rename to train_mimic/configs/datasets/seed.yaml index 882a39d8..17736f0a 100644 --- a/train_mimic/configs/datasets/seed_v1.yaml +++ b/train_mimic/configs/datasets/seed.yaml @@ -1,4 +1,4 @@ -name: seed_v1 +name: seed target_fps: 30 val_percent: 5 hash_salt: "" @@ -7,10 +7,11 @@ preprocess: ground_align: clip_min_foot min_frames: 22 sources: - - name: 
seed_full + - name: seed type: seed_csv input: data/SEED/g1/csv metadata_csv: data/SEED/seed_metadata_v003.csv + seed_filter_preset: groot_strict weight: 1.0 filters: is_mirror: [false] diff --git a/train_mimic/configs/datasets/seed_v1_smoke.yaml b/train_mimic/configs/datasets/seed_v1_smoke.yaml deleted file mode 100644 index 7b9bbd39..00000000 --- a/train_mimic/configs/datasets/seed_v1_smoke.yaml +++ /dev/null @@ -1,13 +0,0 @@ -name: seed_v1_smoke -target_fps: 30 -val_percent: 5 -hash_salt: "" -preprocess: - normalize_root_xy: true - ground_align: clip_min_foot - min_frames: 22 -sources: - - name: seed_smoke - type: seed_csv - input: data/SEED/g1/csv/221118 - weight: 1.0 diff --git a/train_mimic/configs/datasets/seed_v2_3h.yaml b/train_mimic/configs/datasets/seed_v2_3h.yaml deleted file mode 100644 index 2f275413..00000000 --- a/train_mimic/configs/datasets/seed_v2_3h.yaml +++ /dev/null @@ -1,16 +0,0 @@ -name: seed_v2_3h -target_fps: 30 -val_percent: 5 -hash_salt: "" -preprocess: - normalize_root_xy: true - ground_align: clip_min_foot - min_frames: 22 -sources: - - name: seed_3h_sampled - type: seed_csv - input: data/SEED/g1/csv - metadata_csv: train_mimic/data/seed/seed_metadata_v003_3h.csv - weight: 1.0 - filters: - is_mirror: [false] diff --git a/train_mimic/configs/datasets/twist2_full.yaml b/train_mimic/configs/datasets/twist2.yaml similarity index 96% rename from train_mimic/configs/datasets/twist2_full.yaml rename to train_mimic/configs/datasets/twist2.yaml index 48c14d78..0c819271 100644 --- a/train_mimic/configs/datasets/twist2_full.yaml +++ b/train_mimic/configs/datasets/twist2.yaml @@ -1,4 +1,4 @@ -name: twist2_full +name: twist2 target_fps: 30 val_percent: 5 hash_salt: "" diff --git a/train_mimic/data/dataset_builder.py b/train_mimic/data/dataset_builder.py index 7ef9a52f..3385ccaa 100644 --- a/train_mimic/data/dataset_builder.py +++ b/train_mimic/data/dataset_builder.py @@ -5,6 +5,7 @@ import shutil import multiprocessing from concurrent.futures 
import ProcessPoolExecutor, as_completed +from collections import Counter from dataclasses import asdict, dataclass, field from pathlib import Path from tempfile import TemporaryDirectory @@ -38,7 +39,7 @@ from teleopit.retargeting.export_pkl import convert_bvh_to_retarget_pkl, mocap_xml_path PROJECT_ROOT = Path(__file__).resolve().parents[2] -DEFAULT_SPEC_PATH = PROJECT_ROOT / "train_mimic" / "configs" / "datasets" / "twist2_full.yaml" +DEFAULT_SPEC_PATH = PROJECT_ROOT / "train_mimic" / "configs" / "datasets" / "twist2.yaml" DEFAULT_DATASETS_ROOT = PROJECT_ROOT / "data" / "datasets" DEFAULT_FK_SAMPLE_CLIPS = 2 DEFAULT_FK_SAMPLE_FRAMES = 16 @@ -86,6 +87,7 @@ class DatasetSourceSpec: max_frames: int = 0 metadata_csv: str | None = None filters: dict[str, list] | None = None + seed_filter_preset: str | None = None @dataclass(frozen=True) @@ -148,6 +150,39 @@ class ConversionTask: preprocess: DatasetPreprocessSpec = field(default_factory=DatasetPreprocessSpec) +@dataclass(frozen=True) +class SeedFilterRule: + columns: tuple[str, ...] + patterns: tuple[str, ...] 
+ label: str + + +_SEED_FILTER_PRESETS: dict[str, tuple[SeedFilterRule, ...]] = { + "groot_strict": ( + SeedFilterRule( + columns=("content_body_position",), + patterns=("sitting", "on all fours", "handstand"), + label="content_body_position", + ), + SeedFilterRule( + columns=("content_type_of_movement",), + patterns=("crawling", "on hands and knees", "rolling", "flipping", "climbing"), + label="content_type_of_movement", + ), + SeedFilterRule( + columns=("content_props",), + patterns=("chair", "crutch", "crutches", "ladder", "box", "table", "bike", "scooter", "bed"), + label="content_props", + ), + SeedFilterRule( + columns=("filename", "move_name"), + patterns=("safety_roll", "cartwheel", "box_jump", "monkey_jump", "walking_on_edge"), + label="filename_or_move_name", + ), + ) +} + + def _display_path(path: Path) -> str: try: return path.relative_to(PROJECT_ROOT).as_posix() @@ -165,6 +200,34 @@ def _validate_source_type(raw_type: object, spec_path: Path, source_name: str) - return source_type +def _validate_seed_filter_preset( + source_type: str, + seed_filter_preset: str | None, + metadata_csv: str | None, + *, + spec_path: Path, + source_name: str, +) -> str | None: + if seed_filter_preset is None: + return None + if source_type != "seed_csv": + raise ValueError( + f"source {source_name!r} uses seed_filter_preset={seed_filter_preset!r}, " + "but seed_filter_preset is supported only for seed_csv sources" + ) + if metadata_csv is None: + raise ValueError( + f"source {source_name!r} uses seed_filter_preset={seed_filter_preset!r} " + f"without metadata_csv: {spec_path}" + ) + if seed_filter_preset not in _SEED_FILTER_PRESETS: + raise ValueError( + f"source {source_name!r} has unknown seed_filter_preset {seed_filter_preset!r} " + f"in {spec_path}. Expected one of {sorted(_SEED_FILTER_PRESETS)}." 
+ ) + return seed_filter_preset + + def _load_preprocess_spec(raw: object, spec_path: Path) -> DatasetPreprocessSpec: if raw is None: return DatasetPreprocessSpec() @@ -266,6 +329,16 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: filters = raw.get("filters") if filters is not None and not isinstance(filters, dict): raise ValueError(f"source {source_name!r} filters must be a mapping: {spec_path}") + seed_filter_preset = raw.get("seed_filter_preset") + if seed_filter_preset is not None: + seed_filter_preset = str(seed_filter_preset).strip() or None + seed_filter_preset = _validate_seed_filter_preset( + source_type, + seed_filter_preset, + metadata_csv, + spec_path=spec_path, + source_name=source_name, + ) sources.append( DatasetSourceSpec( @@ -278,6 +351,7 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: max_frames=max_frames, metadata_csv=metadata_csv, filters=filters, + seed_filter_preset=seed_filter_preset, ) ) @@ -346,10 +420,24 @@ def _filter_seed_csv_by_metadata( source: DatasetSourceSpec, all_files: list[SourceInputFile], input_dir: Path, -) -> list[SourceInputFile]: + *, + quiet: bool = False, +) -> tuple[list[SourceInputFile], dict[str, Any]]: """Filter seed_csv files using metadata_csv + filters from the source spec.""" - if source.metadata_csv is None or source.filters is None: - return all_files + report: dict[str, Any] = { + "source": source.name, + "type": source.type, + "metadata_csv": source.metadata_csv, + "seed_filter_preset": source.seed_filter_preset, + "scanned_files": len(all_files), + "metadata_rows_matched": len(all_files), + "preset_rejected_rows": 0, + "kept_files": len(all_files), + "filtered_files": 0, + "preset_reject_reasons": {}, + } + if source.metadata_csv is None or (source.filters is None and source.seed_filter_preset is None): + return all_files, report meta_path = Path(source.metadata_csv).expanduser() if not meta_path.is_absolute(): @@ -360,24 +448,53 @@ def _filter_seed_csv_by_metadata( with 
meta_path.open("r", encoding="utf-8") as f: reader = csv.DictReader(f) fieldnames = reader.fieldnames or [] - for col in source.filters: - if col not in fieldnames: - raise ValueError( - f"filter column {col!r} not found in metadata CSV for {source.name}. " - f"Available: {sorted(fieldnames)}" - ) + if source.filters is not None: + for col in source.filters: + if col not in fieldnames: + raise ValueError( + f"filter column {col!r} not found in metadata CSV for {source.name}. " + f"Available: {sorted(fieldnames)}" + ) if "move_g1_path" not in fieldnames: raise ValueError(f"metadata CSV missing move_g1_path column for {source.name}") + if source.seed_filter_preset is not None: + for rule in _SEED_FILTER_PRESETS[source.seed_filter_preset]: + for col in rule.columns: + if col not in fieldnames: + raise ValueError( + f"seed_filter_preset column {col!r} not found in metadata CSV for " + f"{source.name}. Available: {sorted(fieldnames)}" + ) # Normalize filter values to strings for comparison str_filters: dict[str, set[str]] = {} - for col, allowed_values in source.filters.items(): + for col, allowed_values in (source.filters or {}).items(): str_filters[col] = {str(v) for v in allowed_values} rows = [] for row in reader: if all(row.get(col, "") in vals for col, vals in str_filters.items()): rows.append(row) + report["metadata_csv"] = str(meta_path) + report["metadata_rows_matched"] = len(rows) + + if source.seed_filter_preset is not None: + reject_counts: Counter[str] = Counter() + kept_rows = [] + for row in rows: + reasons: list[str] = [] + for rule in _SEED_FILTER_PRESETS[source.seed_filter_preset]: + for pattern in rule.patterns: + if any(pattern in row.get(col, "").lower() for col in rule.columns): + reasons.append(f"{rule.label}:{pattern}") + break + if reasons: + reject_counts.update(reasons) + continue + kept_rows.append(row) + report["preset_rejected_rows"] = len(rows) - len(kept_rows) + report["preset_reject_reasons"] = dict(sorted(reject_counts.items())) + rows 
= kept_rows # Build set of allowed relative paths (without .csv suffix) allowed_rels: set[str] = set() @@ -395,14 +512,27 @@ def _filter_seed_csv_by_metadata( allowed_rels.add(Path(g1_path).stem) filtered = [f for f in all_files if f.rel_no_suffix.as_posix() in allowed_rels] - print( - f"[FILTER] source={source.name}: {len(filtered)}/{len(all_files)} files " - f"after metadata filtering" - ) - return filtered + report["kept_files"] = len(filtered) + report["filtered_files"] = len(all_files) - len(filtered) + if not quiet: + print( + f"[FILTER] source={source.name}: {len(filtered)}/{len(all_files)} files " + f"after metadata filtering" + ) + if source.seed_filter_preset is not None and report["preset_rejected_rows"] > 0: + print( + f"[FILTER] source={source.name}: preset={source.seed_filter_preset} " + f"rejected={report['preset_rejected_rows']} " + f"reasons={report['preset_reject_reasons']}" + ) + return filtered, report -def _collect_source_files(source: DatasetSourceSpec) -> tuple[list[SourceInputFile], Path]: +def _collect_source_files_with_report( + source: DatasetSourceSpec, + *, + quiet: bool = False, +) -> tuple[list[SourceInputFile], Path, dict[str, Any]]: input_path = resolve_source_input_path(source) _ensure_not_dataset_root_npz_input(source, input_path) suffix = _SOURCE_SUFFIXES[source.type] @@ -412,7 +542,20 @@ def _collect_source_files(source: DatasetSourceSpec) -> tuple[list[SourceInputFi raise ValueError( f"source {source.name} expected {suffix} input, got file {input_path.name}" ) - return [SourceInputFile(path=input_path, rel_no_suffix=Path(input_path.stem))], input_path.parent + items = [SourceInputFile(path=input_path, rel_no_suffix=Path(input_path.stem))] + report: dict[str, Any] = { + "source": source.name, + "type": source.type, + "metadata_csv": source.metadata_csv, + "seed_filter_preset": source.seed_filter_preset, + "scanned_files": len(items), + "metadata_rows_matched": len(items), + "preset_rejected_rows": 0, + "kept_files": len(items), 
+ "filtered_files": 0, + "preset_reject_reasons": {}, + } + return items, input_path.parent, report if not input_path.is_dir(): raise FileNotFoundError(f"source input is neither file nor directory: {input_path}") @@ -430,14 +573,32 @@ def _collect_source_files(source: DatasetSourceSpec) -> tuple[list[SourceInputFi for path in files ] + report: dict[str, Any] = { + "source": source.name, + "type": source.type, + "metadata_csv": source.metadata_csv, + "seed_filter_preset": source.seed_filter_preset, + "scanned_files": len(items), + "metadata_rows_matched": len(items), + "preset_rejected_rows": 0, + "kept_files": len(items), + "filtered_files": 0, + "preset_reject_reasons": {}, + } + # Apply metadata filtering for seed_csv sources if source.type == "seed_csv" and source.metadata_csv is not None: - items = _filter_seed_csv_by_metadata(source, items, input_path) + items, report = _filter_seed_csv_by_metadata(source, items, input_path, quiet=quiet) if not items: raise ValueError( f"no files remain after metadata filtering for source {source.name}: {input_path}" ) + return items, input_path, report + + +def _collect_source_files(source: DatasetSourceSpec) -> tuple[list[SourceInputFile], Path]: + items, input_path, _report = _collect_source_files_with_report(source, quiet=False) return items, input_path @@ -488,6 +649,14 @@ def _pending_tasks(tasks: list[ConversionTask]) -> list[ConversionTask]: return [task for task in tasks if not Path(task.output_path).is_file()] +def _build_source_filter_reports(spec: DatasetSpec) -> list[dict[str, Any]]: + reports: list[dict[str, Any]] = [] + for source in spec.sources: + _, _, report = _collect_source_files_with_report(source, quiet=True) + reports.append(report) + return reports + + def _get_fk_extractor() -> MotionFkExtractor: global _PROCESS_FK_EXTRACTOR if _PROCESS_FK_EXTRACTOR is None: @@ -1132,8 +1301,10 @@ def _build_dataset_batch( # 1. 
Enumerate all source files and pre-compute splits clip_entries: list[tuple[str, str, str, float, str]] = [] + source_filter_reports: list[dict[str, Any]] = [] for source in spec.sources: - items, _ = _collect_source_files(source) + items, _, filter_report = _collect_source_files_with_report(source, quiet=False) + source_filter_reports.append(filter_report) for item in items: clip_id = f"{source.name}:{item.rel_no_suffix.as_posix()}" split = hash_split(clip_id, spec.val_percent, spec.hash_salt) @@ -1249,6 +1420,7 @@ def _build_rows_for_shards( "jobs": int(jobs), "preprocess": spec.preprocess.to_dict(), "sources": [asdict(source) for source in spec.sources], + "source_filters": source_filter_reports, "splits": { "train": train_stats, "val": val_stats, @@ -1293,6 +1465,7 @@ def build_dataset_from_spec( ) # Legacy per-file mode for bvh/npz sources + source_filter_reports = _build_source_filter_reports(spec) convert_sources_to_npz(spec, paths=paths, force=force, jobs=jobs) rows = collect_clip_rows(spec, paths=paths) @@ -1385,6 +1558,7 @@ def build_dataset_from_spec( "jobs": int(jobs), "preprocess": spec.preprocess.to_dict(), "sources": [asdict(source) for source in spec.sources], + "source_filters": source_filter_reports, "splits": { "train": train_stats, "val": val_stats, diff --git a/train_mimic/scripts/benchmark.py b/train_mimic/scripts/benchmark.py index f9748e84..4dab5b72 100644 --- a/train_mimic/scripts/benchmark.py +++ b/train_mimic/scripts/benchmark.py @@ -10,7 +10,7 @@ # Benchmark only (no video) python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_tracking/.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val \ + --motion_file data/datasets/twist2/val \ --num_envs 1 # Single video (one continuous clip) diff --git a/train_mimic/scripts/data/split_shards.py b/train_mimic/scripts/data/split_shards.py index 69072263..0b8ea602 100644 --- a/train_mimic/scripts/data/split_shards.py +++ b/train_mimic/scripts/data/split_shards.py @@ -7,8 
+7,8 @@ Usage: python train_mimic/scripts/data/split_shards.py \ - --input data/datasets/seed_v1/train \ - --output data/datasets/seed_v1/train_small_shards \ + --input data/datasets/seed/train \ + --output data/datasets/seed/train_small_shards \ --max_size_gb 2 """ diff --git a/train_mimic/scripts/play.py b/train_mimic/scripts/play.py index 6d2a4b5a..33069710 100644 --- a/train_mimic/scripts/play.py +++ b/train_mimic/scripts/play.py @@ -9,18 +9,18 @@ # Native window python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val + --motion_file data/datasets/twist2/val # Browser viewer (no display required) python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val \ + --motion_file data/datasets/twist2/val \ --viewer viser # Record video instead of interactive viewer python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val \ + --motion_file data/datasets/twist2/val \ --video """ diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index f45d5afc..e8e2d30e 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -4,24 +4,24 @@ Usage: python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 18000 \ - --motion_file data/datasets/twist2_full/train + --motion_file data/datasets/twist2/train # Quick verification python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets/twist2_full/train + --motion_file data/datasets/twist2/train # With wandb logging python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 30000 \ - --motion_file data/datasets/twist2_full/train \ + --motion_file data/datasets/twist2/train \ --wandb_project teleopit # Resume for additional iterations python train_mimic/scripts/train.py 
\ --resume logs/rsl_rl/g1_general_tracking//model_12000.pt \ --max_iterations 18000 \ - --motion_file data/datasets/twist2_full/train + --motion_file data/datasets/twist2/train """ from __future__ import annotations diff --git a/train_mimic/tasks/tracking/config/constants.py b/train_mimic/tasks/tracking/config/constants.py index f534929e..a96cf8b1 100644 --- a/train_mimic/tasks/tracking/config/constants.py +++ b/train_mimic/tasks/tracking/config/constants.py @@ -1,6 +1,6 @@ """Public constants for supported tracking tasks.""" -DEFAULT_TRAIN_MOTION_FILE = "data/datasets/twist2_full/train" +DEFAULT_TRAIN_MOTION_FILE = "data/datasets/twist2/train" GENERAL_TRACKING_TASK = "General-Tracking-G1" GENERAL_TRACKING_EXPERIMENT_NAME = "g1_general_tracking" From 512f6101c38316cc9e3287d86b96c1ee3f9ce245 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 21 Apr 2026 19:36:55 +0800 Subject: [PATCH 005/122] Preserve both raw and curated SEED dataset entry points Keep the canonical `seed.yaml` spec pointed at the full SEED source while moving the strict Groot-filtered variant into a dedicated `seed_clean.yaml` spec. This makes the public config names match their actual coverage and keeps the curated path explicit for reproducible review or training flows. 
Constraint: Keep existing dataset-builder semantics and avoid hidden filtering in the default SEED spec Rejected: Leave `groot_strict` inside `seed.yaml` | makes the canonical SEED entry point silently exclude source data Confidence: high Scope-risk: narrow Reversibility: clean Directive: Use `seed_clean.yaml` for strict curated SEED runs; keep `seed.yaml` as the unfiltered baseline unless downstream defaults are updated intentionally Tested: `pytest tests/test_dataset_v2.py -q` Not-tested: End-to-end dataset build against real SEED assets --- train_mimic/configs/datasets/seed.yaml | 3 +-- train_mimic/configs/datasets/seed_clean.yaml | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 train_mimic/configs/datasets/seed_clean.yaml diff --git a/train_mimic/configs/datasets/seed.yaml b/train_mimic/configs/datasets/seed.yaml index 17736f0a..45d0c40e 100644 --- a/train_mimic/configs/datasets/seed.yaml +++ b/train_mimic/configs/datasets/seed.yaml @@ -7,11 +7,10 @@ preprocess: ground_align: clip_min_foot min_frames: 22 sources: - - name: seed + - name: seed_full type: seed_csv input: data/SEED/g1/csv metadata_csv: data/SEED/seed_metadata_v003.csv - seed_filter_preset: groot_strict weight: 1.0 filters: is_mirror: [false] diff --git a/train_mimic/configs/datasets/seed_clean.yaml b/train_mimic/configs/datasets/seed_clean.yaml new file mode 100644 index 00000000..8a05a26c --- /dev/null +++ b/train_mimic/configs/datasets/seed_clean.yaml @@ -0,0 +1,17 @@ +name: seed_clean +target_fps: 30 +val_percent: 5 +hash_salt: "" +preprocess: + normalize_root_xy: true + ground_align: clip_min_foot + min_frames: 22 +sources: + - name: seed_full + type: seed_csv + input: data/SEED/g1/csv + metadata_csv: data/SEED/seed_metadata_v003.csv + seed_filter_preset: groot_strict + weight: 1.0 + filters: + is_mirror: [false] From 61b507a748c8361ff2d16a5d941a4212329cec96 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 8 May 2026 07:13:41 +0000 Subject: 
[PATCH 006/122] Simplify tracking termination configuration --- tests/test_domain_randomization.py | 2 +- tests/test_termination_config.py | 61 ++----------------- train_mimic/scripts/benchmark.py | 1 - train_mimic/tasks/tracking/config/env.py | 3 +- .../tasks/tracking/mdp/terminations.py | 34 ----------- .../tasks/tracking/tracking_env_cfg.py | 25 ++------ 6 files changed, 13 insertions(+), 113 deletions(-) diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index 96432d6e..dad68089 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -24,7 +24,7 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non push_robot = events["push_robot"] assert push_robot.func is mdp.push_by_setting_velocity assert push_robot.mode == "interval" - assert push_robot.interval_range_s == (4.0, 6.0) + assert push_robot.interval_range_s == (1.0, 3.0) assert push_robot.params["velocity_range"] == { "x": (-0.5, 0.5), "y": (-0.5, 0.5), diff --git a/tests/test_termination_config.py b/tests/test_termination_config.py index bfda66f4..3c9303b5 100644 --- a/tests/test_termination_config.py +++ b/tests/test_termination_config.py @@ -1,14 +1,10 @@ from __future__ import annotations -from types import SimpleNamespace - -import torch - from train_mimic.app import DEFAULT_TASK from train_mimic.tasks.tracking import mdp -def test_general_tracking_termination_config_matches_requested_policy() -> None: +def test_general_tracking_termination_config_matches_baseline_policy() -> None: import mjlab.tasks # noqa: F401 import train_mimic.tasks # noqa: F401 from mjlab.tasks.registry import load_env_cfg @@ -24,12 +20,10 @@ def test_general_tracking_termination_config_matches_requested_policy() -> None: } anchor_pos = terminations["anchor_pos"] - assert anchor_pos.func is mdp.bad_anchor_pos_z_only_adaptive + assert anchor_pos.func is mdp.bad_anchor_pos_z_only assert anchor_pos.params == { "command_name": "motion", - 
"threshold": 0.15, - "down_threshold": 0.4, - "root_height_threshold": 0.5, + "threshold": 0.4, } anchor_ori = terminations["anchor_ori"] @@ -37,12 +31,10 @@ def test_general_tracking_termination_config_matches_requested_policy() -> None: assert anchor_ori.params["threshold"] == 1.0 ee_body_pos = terminations["ee_body_pos"] - assert ee_body_pos.func is mdp.bad_motion_body_pos_z_only_adaptive + assert ee_body_pos.func is mdp.bad_motion_body_pos_z_only assert ee_body_pos.params == { "command_name": "motion", - "threshold": 0.15, - "down_threshold": 0.4, - "root_height_threshold": 0.5, + "threshold": 0.4, "body_names": ( "left_ankle_roll_link", "right_ankle_roll_link", @@ -50,46 +42,3 @@ def test_general_tracking_termination_config_matches_requested_policy() -> None: "right_wrist_yaw_link", ), } - -def test_adaptive_height_termination_uses_relaxed_threshold_for_low_reference() -> None: - command = SimpleNamespace( - cfg=SimpleNamespace(body_names=("left_ankle_roll_link",)), - anchor_pos_w=torch.tensor([[0.0, 0.0, 0.3], [0.0, 0.0, 0.8]], dtype=torch.float32), - robot_anchor_pos_w=torch.tensor([[0.0, 0.0, -0.09], [0.0, 0.0, 0.55]], dtype=torch.float32), - body_pos_relative_w=torch.tensor( - [ - [[0.0, 0.0, 0.30]], - [[0.0, 0.0, 0.80]], - ], - dtype=torch.float32, - ), - robot_body_pos_w=torch.tensor( - [ - [[0.0, 0.0, -0.09]], - [[0.0, 0.0, 0.55]], - ], - dtype=torch.float32, - ), - ) - env = SimpleNamespace( - command_manager=SimpleNamespace(get_term=lambda _name: command), - ) - - anchor_done = mdp.bad_anchor_pos_z_only_adaptive( - env, - "motion", - threshold=0.15, - down_threshold=0.4, - root_height_threshold=0.5, - ) - ee_done = mdp.bad_motion_body_pos_z_only_adaptive( - env, - "motion", - threshold=0.15, - down_threshold=0.4, - root_height_threshold=0.5, - body_names=("left_ankle_roll_link",), - ) - - assert anchor_done.tolist() == [False, True] - assert ee_done.tolist() == [False, True] diff --git a/train_mimic/scripts/benchmark.py 
b/train_mimic/scripts/benchmark.py index 68a940d4..f9748e84 100644 --- a/train_mimic/scripts/benchmark.py +++ b/train_mimic/scripts/benchmark.py @@ -288,7 +288,6 @@ def main() -> int: env_cfg.terminations.pop("anchor_pos", None) env_cfg.terminations.pop("anchor_ori", None) env_cfg.terminations.pop("ee_body_pos", None) - env_cfg.terminations.pop("foot_pos_xyz", None) env_cfg.terminations.pop("body_z_tracking_failure", None) env_cfg.terminations.pop("gravity_tracking_failure", None) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index bf1a897f..dbd2c25e 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -143,8 +143,9 @@ def make_general_tracking_env_cfg( "left_wrist_yaw_link", "right_wrist_yaw_link", ) - cfg.terminations.pop("foot_pos_xyz", None) + cfg.terminations["anchor_pos"].params["threshold"] = 0.4 cfg.terminations["anchor_ori"].params["threshold"] = 1.0 + cfg.terminations["ee_body_pos"].params["threshold"] = 0.4 cfg.viewer.body_name = "torso_link" cfg.episode_length_s = 10.0 if cfg.sim.njmax < 500: diff --git a/train_mimic/tasks/tracking/mdp/terminations.py b/train_mimic/tasks/tracking/mdp/terminations.py index 5c62bb0a..f5d64e29 100644 --- a/train_mimic/tasks/tracking/mdp/terminations.py +++ b/train_mimic/tasks/tracking/mdp/terminations.py @@ -34,20 +34,6 @@ def bad_anchor_pos_z_only( ) -def bad_anchor_pos_z_only_adaptive( - env: ManagerBasedRlEnv, - command_name: str, - threshold: float, - down_threshold: float, - root_height_threshold: float, -) -> torch.Tensor: - command = cast(MotionCommand, env.command_manager.get_term(command_name)) - height_err = torch.abs(command.anchor_pos_w[:, -1] - command.robot_anchor_pos_w[:, -1]) - threshold_tensor = torch.full_like(height_err, threshold) - threshold_tensor[command.anchor_pos_w[:, -1] < root_height_threshold] = down_threshold - return height_err > threshold_tensor - - def bad_anchor_ori( env: ManagerBasedRlEnv, 
asset_cfg: SceneEntityCfg, command_name: str, threshold: float ) -> torch.Tensor: @@ -98,23 +84,3 @@ def bad_motion_body_pos_z_only( - command.robot_body_pos_w[:, body_indexes, -1] ) return torch.any(error > threshold, dim=-1) - - -def bad_motion_body_pos_z_only_adaptive( - env: ManagerBasedRlEnv, - command_name: str, - threshold: float, - down_threshold: float, - root_height_threshold: float, - body_names: tuple[str, ...] | None = None, -) -> torch.Tensor: - command = cast(MotionCommand, env.command_manager.get_term(command_name)) - - body_indexes = _get_body_indexes(command, body_names) - error = torch.abs( - command.body_pos_relative_w[:, body_indexes, -1] - - command.robot_body_pos_w[:, body_indexes, -1] - ) - threshold_tensor = torch.full_like(error, threshold) - threshold_tensor[command.anchor_pos_w[:, -1] < root_height_threshold] = down_threshold - return torch.any(error > threshold_tensor, dim=-1) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 685eab62..c09fd060 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -163,7 +163,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "push_robot": EventTermCfg( func=mdp.push_by_setting_velocity, mode="interval", - interval_range_s=(4.0, 6.0), + interval_range_s=(1.0, 3.0), params={"velocity_range": VELOCITY_RANGE}, ), "base_com": EventTermCfg( @@ -256,13 +256,8 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: terminations: dict[str, TerminationTermCfg] = { "time_out": TerminationTermCfg(func=mdp.time_out, time_out=True), "anchor_pos": TerminationTermCfg( - func=mdp.bad_anchor_pos_z_only_adaptive, - params={ - "command_name": "motion", - "threshold": 0.15, - "down_threshold": 0.4, - "root_height_threshold": 0.5, - }, + func=mdp.bad_anchor_pos_z_only, + params={"command_name": "motion", "threshold": 0.25}, ), "anchor_ori": TerminationTermCfg( func=mdp.bad_anchor_ori, @@ 
-273,20 +268,10 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: }, ), "ee_body_pos": TerminationTermCfg( - func=mdp.bad_motion_body_pos_z_only_adaptive, - params={ - "command_name": "motion", - "threshold": 0.15, - "down_threshold": 0.4, - "root_height_threshold": 0.5, - "body_names": (), # Set per-robot. - }, - ), - "foot_pos_xyz": TerminationTermCfg( - func=mdp.bad_motion_body_pos, + func=mdp.bad_motion_body_pos_z_only, params={ "command_name": "motion", - "threshold": 0.2, + "threshold": 0.25, "body_names": (), # Set per-robot. }, ), From a09aa9e7e4672c726ac6a759f35d0d03dcc306ac Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Sat, 9 May 2026 15:46:47 +0000 Subject: [PATCH 007/122] Add self-collision tracking reward --- tests/test_task_registry.py | 21 +++++++++++-- tests/test_tracking_rewards.py | 31 ++++++++++++++++++ train_mimic/tasks/tracking/config/env.py | 38 +++++++++++++++++++++++ train_mimic/tasks/tracking/mdp/rewards.py | 38 +++++++++++++++++++++++ 4 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/test_tracking_rewards.py diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index ff5b3096..1f7aaf6b 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -33,9 +33,26 @@ def test_general_tracking_task_is_registered() -> None: assert "critic_history" in env_cfg.observations assert env_cfg.commands["motion"].sampling_mode == "uniform" assert env_cfg.commands["motion"].window_steps == (0,) - assert "self_collisions" not in env_cfg.rewards + reward = env_cfg.rewards["self_collisions"] + assert reward.weight == -0.1 + assert reward.params == { + "sensor_name": "self_collision", + "force_threshold": 1.0, + } assert "undesired_contacts" not in env_cfg.rewards - assert not getattr(env_cfg.scene, "sensors", ()) + sensors = {sensor.name: sensor for sensor in env_cfg.scene.sensors} + assert set(sensors) == {"self_collision"} + assert sensors["self_collision"].primary.mode == "body" + assert 
sensors["self_collision"].primary.pattern == r".*" + assert sensors["self_collision"].primary.exclude == ( + "left_ankle_roll_link", + "right_ankle_roll_link", + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ) + assert sensors["self_collision"].secondary.mode == "subtree" + assert sensors["self_collision"].secondary.pattern == "pelvis" + assert sensors["self_collision"].reduce == "maxforce" rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME assert rl_cfg.actor.hidden_dims == (1024, 512, 256, 256, 128) diff --git a/tests/test_tracking_rewards.py b/tests/test_tracking_rewards.py new file mode 100644 index 00000000..7736cc1d --- /dev/null +++ b/tests/test_tracking_rewards.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch + +from train_mimic.tasks.tracking.mdp.rewards import self_collision_cost + + +def _env_with_force_history(force_history: torch.Tensor) -> SimpleNamespace: + sensor = SimpleNamespace( + data=SimpleNamespace(force_history=force_history, found=None) + ) + return SimpleNamespace(scene={"self_collision": sensor}) + + +def test_self_collision_cost_counts_history_frames_not_contacts() -> None: + force_history = torch.zeros((2, 3, 4, 3), dtype=torch.float32) + force_history[0, 0, 0, 0] = 2.0 + force_history[0, 1, 0, 0] = 3.0 + force_history[0, 2, 2, 2] = 2.0 + force_history[1, 0, 1, 0] = 0.5 + force_history[1, 0, 3, 0] = 1.0 + + penalty = self_collision_cost( + _env_with_force_history(force_history), + sensor_name="self_collision", + force_threshold=1.0, + ) + + torch.testing.assert_close(penalty, torch.tensor([2.0, 0.0])) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index dbd2c25e..313fc282 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -8,7 +8,9 @@ from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import 
JointPositionActionCfg from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg +from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.utils.noise import UniformNoiseCfg as Unoise from train_mimic.tasks.tracking import mdp @@ -110,6 +112,41 @@ def _add_history_obs_groups( } +def _configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: + excluded_body_names = ( + "left_ankle_roll_link", + "right_ankle_roll_link", + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ) + cfg.scene.sensors = ( + *tuple(getattr(cfg.scene, "sensors", ()) or ()), + ContactSensorCfg( + name="self_collision", + # Exclude only primary bodies: wrist/ankle vs torso is still caught by torso. + primary=ContactMatch( + mode="body", + pattern=r".*", + entity="robot", + exclude=excluded_body_names, + ), + secondary=ContactMatch(mode="subtree", pattern="pelvis", entity="robot"), + fields=("found", "force"), + reduce="maxforce", + num_slots=1, + history_length=4, + ), + ) + cfg.rewards["self_collisions"] = RewardTermCfg( + func=mdp.self_collision_cost, + weight=-0.1, + params={ + "sensor_name": "self_collision", + "force_threshold": 1.0, + }, + ) + + def make_general_tracking_env_cfg( *, play: bool = False, ) -> ManagerBasedRlEnvCfg: @@ -137,6 +174,7 @@ def make_general_tracking_env_cfg( cfg.events["randomize_rigid_body_mass"].params[ "asset_cfg" ].body_names = r".*wrist_yaw.*|torso_link" + _configure_self_collision_reward(cfg) cfg.terminations["ee_body_pos"].params["body_names"] = ( "left_ankle_roll_link", "right_ankle_roll_link", diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index dee3318a..db36632e 100644 --- a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -143,6 +143,44 @@ def motion_global_body_angular_velocity_error_exp( 
return torch.exp(-error.mean(-1) / std**2) +def self_collision_cost( + env: ManagerBasedRlEnv, + sensor_name: str | tuple[str, ...], + force_threshold: float = 10.0, +) -> torch.Tensor: + """Penalize self-collision history frames above the configured force threshold.""" + hit = _self_collision_hits(env, sensor_name, force_threshold) + return hit.sum(dim=-1).float() + + +def _self_collision_hits( + env: ManagerBasedRlEnv, + sensor_name: str | tuple[str, ...], + force_threshold: float, +) -> torch.Tensor: + sensor_names = (sensor_name,) if isinstance(sensor_name, str) else sensor_name + force_histories = [] + found_values = [] + for name in sensor_names: + data = env.scene[name].data + if data.force_history is not None: + force_histories.append(data.force_history) + else: + assert data.found is not None + found = data.found + if found.ndim == 1: + found = found.unsqueeze(-1) + found_values.append(found) + + if force_histories: + force_history = torch.cat(force_histories, dim=1) + force_mag = torch.norm(force_history, dim=-1) + return (force_mag > force_threshold).any(dim=1) + + found = torch.cat(found_values, dim=1) + return (found > 0).any(dim=1, keepdim=True) + + class joint_torque_limits: """Penalize actuator-force limit violations with a configurable soft margin.""" From 3fcdc48cc6fe2a8eef76aac612c5682d7d755361 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 11 May 2026 05:56:13 +0000 Subject: [PATCH 008/122] Remove adaptive motion sampling --- docs/docs/tutorials/training.md | 2 +- .../current/tutorials/training.md | 2 +- tests/test_adaptive_sampling.py | 284 -------------- train_mimic/scripts/benchmark.py | 3 - train_mimic/scripts/train.py | 2 +- train_mimic/tasks/tracking/mdp/commands.py | 357 +----------------- 6 files changed, 8 insertions(+), 642 deletions(-) delete mode 100644 tests/test_adaptive_sampling.py diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index ead24ecb..03c9ab70 100644 --- 
a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -113,4 +113,4 @@ Key files: - `train_mimic/app.py` - Shared entry point for train/play/benchmark - `train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 env builder - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO config -- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `adaptive` / `uniform` / `start` sampling modes +- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `uniform` / `start` sampling modes diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index e8f23f80..16e3f59d 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -113,4 +113,4 @@ train_mimic/scripts - `train_mimic/app.py` - 训练/播放/评估的统一入口 - `train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 环境构建器 - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO 配置 -- `train_mimic/tasks/tracking/mdp/commands.py` - 支持 `adaptive` / `uniform` / `start` 三种采样模式 +- `train_mimic/tasks/tracking/mdp/commands.py` - 支持 `uniform` / `start` 两种采样模式 diff --git a/tests/test_adaptive_sampling.py b/tests/test_adaptive_sampling.py deleted file mode 100644 index 49f932f5..00000000 --- a/tests/test_adaptive_sampling.py +++ /dev/null @@ -1,284 +0,0 @@ -from __future__ import annotations - -import pytest -import torch - -from train_mimic.tasks.tracking.mdp.commands import ( - _cap_failure_rates, - _compute_clip_counts, - _compute_clip_failure_rate, - _normalize_sampling_probabilities, - _validate_legacy_adaptive_config, -) - - -# --------------------------------------------------------------------------- -# Existing tests (refactored helper, same semantics) -# --------------------------------------------------------------------------- - - -def 
test_compute_clip_failure_rate_is_invariant_to_parallel_env_count() -> None: - motion_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) - episode_failed = torch.tensor([True, False, True, False, False, True]) - - small_rate = _compute_clip_failure_rate(motion_ids, episode_failed, bin_count=3) - large_rate = _compute_clip_failure_rate( - motion_ids.repeat(16), episode_failed.repeat(16), bin_count=3 - ) - - expected = torch.tensor([2.0 / 3.0, 1.0 / 3.0, 0.0], dtype=torch.float32) - assert torch.allclose(small_rate, expected) - assert torch.allclose(large_rate, expected) - - -def test_compute_clip_failure_rate_ignores_unseen_clips() -> None: - motion_ids = torch.tensor([1, 1, 3, 3], dtype=torch.long) - episode_failed = torch.tensor([True, False, False, False]) - - rate = _compute_clip_failure_rate(motion_ids, episode_failed, bin_count=5) - - expected = torch.tensor([0.0, 0.5, 0.0, 0.0, 0.0], dtype=torch.float32) - assert torch.allclose(rate, expected) - - -# --------------------------------------------------------------------------- -# New tests for _compute_clip_counts -# --------------------------------------------------------------------------- - - -def test_compute_clip_counts_returns_correct_counts() -> None: - motion_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) - episode_failed = torch.tensor([True, False, True, False, False, True]) - - exposure, failure = _compute_clip_counts(motion_ids, episode_failed, bin_count=3) - - assert torch.allclose(exposure, torch.tensor([3.0, 3.0, 0.0])) - assert torch.allclose(failure, torch.tensor([2.0, 1.0, 0.0])) - - -# --------------------------------------------------------------------------- -# EMA no-decay on empty steps -# --------------------------------------------------------------------------- - - -def test_ema_skips_update_when_no_data() -> None: - """bin_failed_rate must NOT decay when accumulators are empty.""" - bin_count = 4 - alpha = 0.01 - bin_failed_rate = torch.tensor([0.5, 0.3, 0.0, 0.2]) - 
accum_exposure = torch.zeros(bin_count) - accum_failure = torch.zeros(bin_count) - - # Simulate the guarded EMA logic from _update_command - if accum_exposure.sum() > 0: - valid = accum_exposure > 0 - global_rate = torch.zeros_like(bin_failed_rate) - global_rate[valid] = accum_failure[valid] / accum_exposure[valid] - bin_failed_rate = alpha * global_rate + (1 - alpha) * bin_failed_rate - - expected = torch.tensor([0.5, 0.3, 0.0, 0.2]) - assert torch.allclose(bin_failed_rate, expected) - - -# --------------------------------------------------------------------------- -# Fail-fast validation for invalid adaptive distributions -# --------------------------------------------------------------------------- - - -def test_normalize_sampling_probabilities_raises_on_zero_mass() -> None: - sampling_probabilities = torch.zeros(5) - - with pytest.raises(ValueError, match="invalid probability mass"): - _normalize_sampling_probabilities( - sampling_probabilities, - adaptive_uniform_ratio=0.0, - bin_count=5, - ) - - -def test_validate_legacy_adaptive_config_rejects_nondefault_kernel_size() -> None: - with pytest.raises(ValueError, match="adaptive_kernel_size"): - _validate_legacy_adaptive_config(adaptive_kernel_size=3, adaptive_lambda=0.8) - - -def test_validate_legacy_adaptive_config_rejects_nondefault_lambda() -> None: - with pytest.raises(ValueError, match="adaptive_lambda"): - _validate_legacy_adaptive_config(adaptive_kernel_size=1, adaptive_lambda=0.5) - - -# --------------------------------------------------------------------------- -# Accumulation sums across multiple resamples -# --------------------------------------------------------------------------- - - -def test_accumulation_sums_across_resamples() -> None: - """Two accumulate calls before EMA update should sum counts correctly.""" - bin_count = 3 - accum_exposure = torch.zeros(bin_count) - accum_failure = torch.zeros(bin_count) - - # First resample batch - ids1 = torch.tensor([0, 0, 1], dtype=torch.long) - failed1 
= torch.tensor([True, False, True]) - e1, f1 = _compute_clip_counts(ids1, failed1, bin_count) - accum_exposure += e1 - accum_failure += f1 - - # Second resample batch - ids2 = torch.tensor([1, 2, 2], dtype=torch.long) - failed2 = torch.tensor([False, True, False]) - e2, f2 = _compute_clip_counts(ids2, failed2, bin_count) - accum_exposure += e2 - accum_failure += f2 - - assert torch.allclose(accum_exposure, torch.tensor([2.0, 2.0, 2.0])) - assert torch.allclose(accum_failure, torch.tensor([1.0, 1.0, 1.0])) - - -# --------------------------------------------------------------------------- -# adaptive_bin: _cap_failure_rates -# --------------------------------------------------------------------------- - - -def test_cap_failure_rates_clamps_outliers() -> None: - rates = torch.tensor([0.1, 0.9, 0.2]) - # mean = 0.4, beta=2.0 -> cap = 0.8 - capped = _cap_failure_rates(rates, beta=2.0) - expected = torch.tensor([0.1, 0.8, 0.2]) - assert torch.allclose(capped, expected) - - -def test_cap_failure_rates_all_zero() -> None: - rates = torch.zeros(5) - capped = _cap_failure_rates(rates, beta=5.0) - assert torch.allclose(capped, torch.zeros(5)) - - -def test_cap_failure_rates_uniform() -> None: - """When all rates are equal, capping changes nothing.""" - rates = torch.tensor([0.3, 0.3, 0.3]) - capped = _cap_failure_rates(rates, beta=2.0) - assert torch.allclose(capped, rates) - - -# --------------------------------------------------------------------------- -# adaptive_bin: build_time_bins (via MotionLib mock) -# --------------------------------------------------------------------------- - - -def _make_mock_motion_lib( - clip_sample_start_s: list[float], - clip_sample_end_s: list[float], - clip_weights: list[float], -): - """Create a lightweight mock with just the fields build_time_bins needs.""" - import math as _math - from types import SimpleNamespace - - from train_mimic.tasks.tracking.mdp.commands import MotionLib - - num_clips = len(clip_weights) - device = "cpu" - - mock 
= SimpleNamespace() - mock.num_clips = num_clips - mock._device = device - mock.clip_weights = torch.tensor(clip_weights, dtype=torch.float32, device=device) - mock.clip_sample_start_s = torch.tensor( - clip_sample_start_s, dtype=torch.float32, device=device - ) - mock.clip_sample_end_s = torch.tensor( - clip_sample_end_s, dtype=torch.float32, device=device - ) - # Bind the real method - mock.build_time_bins = MotionLib.build_time_bins.__get__(mock, type(mock)) - return mock - - -def test_build_time_bins_single_clip() -> None: - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0], - clip_sample_end_s=[12.0], - clip_weights=[1.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - assert result["num_bins"] == 3 # [0,5), [5,10), [10,12) - assert torch.equal(result["bin_clip_id"], torch.tensor([0, 0, 0])) - assert torch.allclose(result["bin_start_s"], torch.tensor([0.0, 5.0, 10.0])) - assert torch.allclose(result["bin_end_s"], torch.tensor([5.0, 10.0, 12.0])) - assert torch.allclose(result["bin_duration"], torch.tensor([5.0, 5.0, 2.0])) - assert result["clip_bin_offset"][0].item() == 0 - - -def test_build_time_bins_multiple_clips() -> None: - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0, 0.0, 0.0], - clip_sample_end_s=[3.0, 7.0, 10.0], - clip_weights=[1.0, 1.0, 1.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - # clip 0: 3s -> 1 bin; clip 1: 7s -> 2 bins; clip 2: 10s -> 2 bins - assert result["num_bins"] == 5 - expected_clip_ids = torch.tensor([0, 1, 1, 2, 2]) - assert torch.equal(result["bin_clip_id"], expected_clip_ids) - - -def test_build_time_bins_skips_zero_weight() -> None: - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0, 0.0], - clip_sample_end_s=[10.0, 10.0], - clip_weights=[1.0, 0.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - assert result["num_bins"] == 2 # only clip 0 - assert torch.equal(result["bin_clip_id"], torch.tensor([0, 0])) - assert result["clip_bin_offset"][1].item() == -1 - - -def 
test_build_time_bins_short_clip() -> None: - """A clip shorter than bin_duration produces exactly 1 bin.""" - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0], - clip_sample_end_s=[2.0], - clip_weights=[1.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - assert result["num_bins"] == 1 - assert torch.allclose(result["bin_duration"], torch.tensor([2.0])) - - -# --------------------------------------------------------------------------- -# adaptive_bin: probability formula -# --------------------------------------------------------------------------- - - -def test_adaptive_bin_probability_formula() -> None: - """Verify the full capped + blended probability computation.""" - rates = torch.tensor([0.1, 0.9, 0.2, 0.0]) - beta = 2.0 - alpha = 0.8 - N = len(rates) - - # Step 1: cap - capped = _cap_failure_rates(rates, beta) - # mean=0.3, cap=0.6 -> [0.1, 0.6, 0.2, 0.0] - expected_capped = torch.tensor([0.1, 0.6, 0.2, 0.0]) - assert torch.allclose(capped, expected_capped) - - # Step 2: normalize to p_hat - capped_sum = capped.sum() - p_hat = capped / capped_sum # [0.1/0.9, 0.6/0.9, 0.2/0.9, 0] - - # Step 3: blend - p_final = alpha * p_hat + (1.0 - alpha) / N - p_final = p_final / p_final.sum() - - # Verify it's a valid distribution - assert torch.allclose(p_final.sum(), torch.tensor(1.0)) - assert (p_final > 0).all(), "All bins should have positive probability" - # Bin 1 (highest failure) should have highest probability - assert p_final[1] == p_final.max() diff --git a/train_mimic/scripts/benchmark.py b/train_mimic/scripts/benchmark.py index 4dab5b72..d826dd8a 100644 --- a/train_mimic/scripts/benchmark.py +++ b/train_mimic/scripts/benchmark.py @@ -505,9 +505,6 @@ def _open_clip_writer(idx: int): "error_body_ang_vel", "error_joint_pos", "error_joint_vel", - "sampling_entropy", - "sampling_top1_prob", - "sampling_top1_bin", ): if key not in metric_stats: continue diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 
e8e2d30e..7126fa35 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -74,7 +74,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ), ) parser.add_argument("--sampling_mode", type=str, default=None, - choices=["uniform", "adaptive", "adaptive_bin", "start"], + choices=["uniform", "start"], help="Motion sampling mode (default: from task config)") parser.add_argument("--device", type=str, default=None) parser.add_argument( diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index f24e964d..916c45b4 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -2,7 +2,6 @@ import copy import logging -import math from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -58,80 +57,6 @@ def _batched_quat_slerp( return result / result.norm(dim=-1, keepdim=True) -def _compute_clip_counts( - motion_ids: torch.Tensor, episode_failed: torch.Tensor, bin_count: int -) -> tuple[torch.Tensor, torch.Tensor]: - """Return (exposure_count, failure_count) per clip bin.""" - exposure = torch.bincount(motion_ids, minlength=bin_count).float() - failure = torch.bincount( - motion_ids[episode_failed], minlength=bin_count - ).float() - return exposure, failure - - -def _compute_clip_failure_rate( - motion_ids: torch.Tensor, episode_failed: torch.Tensor, bin_count: int -) -> torch.Tensor: - """Compute per-clip failure rate for the current adaptive-sampling window.""" - exposure, failure = _compute_clip_counts(motion_ids, episode_failed, bin_count) - rate = torch.zeros(bin_count, dtype=torch.float32, device=motion_ids.device) - valid = exposure > 0 - rate[valid] = failure[valid] / exposure[valid] - return rate - - -def _is_distributed() -> bool: - """Return True when running inside an initialized torch.distributed group.""" - return torch.distributed.is_available() and 
torch.distributed.is_initialized() - - -def _normalize_sampling_probabilities( - sampling_probabilities: torch.Tensor, - *, - adaptive_uniform_ratio: float, - bin_count: int, -) -> torch.Tensor: - """Normalize adaptive sampling weights and fail fast on invalid mass.""" - prob_sum = sampling_probabilities.sum() - if not torch.isfinite(prob_sum) or prob_sum <= 0: - raise ValueError( - "Adaptive sampling produced an invalid probability mass. " - f"sum={prob_sum.item() if torch.isfinite(prob_sum) else prob_sum}, " - f"adaptive_uniform_ratio={adaptive_uniform_ratio}, bin_count={bin_count}. " - "Increase adaptive_uniform_ratio or accumulate failure statistics before " - "using pure adaptive sampling." - ) - - sampling_probabilities = sampling_probabilities / prob_sum - if not torch.isfinite(sampling_probabilities).all(): - raise ValueError("Adaptive sampling probabilities contain NaN or Inf values.") - return sampling_probabilities - - -def _cap_failure_rates(rates: torch.Tensor, beta: float) -> torch.Tensor: - """Clamp per-bin failure rates to ``beta * mean(rates)``. - - When all rates are zero the cap is zero and the returned tensor is all-zero, - which is safe — callers fall back to uniform sampling via the blend term. - """ - f_mean = rates.mean() - return torch.clamp(rates, max=beta * f_mean) - - -def _validate_legacy_adaptive_config(*, adaptive_kernel_size: int, adaptive_lambda: float) -> None: - """Reject legacy adaptive knobs that are no longer implemented.""" - if adaptive_kernel_size != 1: - raise ValueError( - "adaptive_kernel_size is not implemented in the restored adaptive sampler. " - f"Expected 1, got {adaptive_kernel_size}." - ) - if adaptive_lambda != 0.8: - raise ValueError( - "adaptive_lambda is not implemented in the restored adaptive sampler. " - f"Expected 0.8, got {adaptive_lambda}." - ) - - def _load_shard_dir(shard_dir: Path) -> dict[str, Any]: """Load and merge all shard NPZ files from a directory. 
@@ -386,69 +311,6 @@ def sample_start_times(self, motion_ids: torch.Tensor) -> torch.Tensor: """Return the earliest valid center time for each motion id.""" return self.clip_sample_start_s[motion_ids] - def build_time_bins( - self, bin_duration_s: float - ) -> dict[str, torch.Tensor | int]: - """Partition clips into fixed-duration time bins. - - Each clip is independently split into bins of ``bin_duration_s`` seconds. - The last bin in a clip may be shorter than ``bin_duration_s``. - - Returns a dict with: - - ``num_bins``: total number of bins across all clips. - - ``bin_clip_id``: (num_bins,) which clip each bin belongs to. - - ``bin_start_s``: (num_bins,) start time (seconds) within the clip. - - ``bin_end_s``: (num_bins,) end time (seconds) within the clip. - - ``bin_duration``: (num_bins,) actual duration of each bin. - - ``clip_bin_offset``: (num_clips,) index of the first bin for each clip - (-1 for zero-weight clips). - """ - clip_ids: list[int] = [] - starts: list[float] = [] - ends: list[float] = [] - clip_bin_offsets = torch.full( - (self.num_clips,), -1, dtype=torch.long, device=self._device - ) - - for i in range(self.num_clips): - w = self.clip_weights[i].item() - if w <= 0: - continue - t0 = self.clip_sample_start_s[i].item() - t1 = self.clip_sample_end_s[i].item() - sampleable = t1 - t0 - if sampleable <= 0: - continue - - clip_bin_offsets[i] = len(clip_ids) - n_bins = math.ceil(sampleable / bin_duration_s) - for j in range(n_bins): - b_start = t0 + j * bin_duration_s - b_end = min(t0 + (j + 1) * bin_duration_s, t1) - clip_ids.append(i) - starts.append(b_start) - ends.append(b_end) - - if not clip_ids: - raise ValueError( - "No valid time bins could be created — all clips have zero " - "weight or zero sampleable range." 
- ) - - bin_clip_id = torch.tensor(clip_ids, dtype=torch.long, device=self._device) - bin_start_s = torch.tensor(starts, dtype=torch.float32, device=self._device) - bin_end_s = torch.tensor(ends, dtype=torch.float32, device=self._device) - bin_duration = bin_end_s - bin_start_s - - return { - "num_bins": len(clip_ids), - "bin_clip_id": bin_clip_id, - "bin_start_s": bin_start_s, - "bin_end_s": bin_end_s, - "bin_duration": bin_duration, - "clip_bin_offset": clip_bin_offsets, - } - def _compute_interpolation_state( self, motion_ids: torch.Tensor, @@ -585,10 +447,6 @@ class MotionCommand(CommandTerm): def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): super().__init__(cfg, env) - _validate_legacy_adaptive_config( - adaptive_kernel_size=cfg.adaptive_kernel_size, - adaptive_lambda=cfg.adaptive_lambda, - ) self.robot: Entity = env.scene[cfg.entity_name] self.robot_anchor_body_index = self.robot.body_names.index( @@ -632,50 +490,6 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.body_lin_vel_b = torch.zeros(self.num_envs, nb, 3, device=self.device) self.body_ang_vel_b = torch.zeros(self.num_envs, nb, 3, device=self.device) - self.bin_count = max(self.motion.num_clips, 1) - self.bin_failed_rate = torch.zeros( - self.bin_count, dtype=torch.float, device=self.device - ) - self._accum_exposure_count = torch.zeros( - self.bin_count, dtype=torch.float, device=self.device - ) - self._accum_failure_count = torch.zeros( - self.bin_count, dtype=torch.float, device=self.device - ) - self.kernel = torch.ones(1, device=self.device) - - # --- adaptive_bin mode: time-bin data structures --- - if self.cfg.sampling_mode == "adaptive_bin": - bin_data = self.motion.build_time_bins(self.cfg.adaptive_bin_duration_s) - self.num_time_bins: int = bin_data["num_bins"] - self.tb_clip_id: torch.Tensor = bin_data["bin_clip_id"] - self.tb_start_s: torch.Tensor = bin_data["bin_start_s"] - self.tb_end_s: torch.Tensor = bin_data["bin_end_s"] - self.tb_duration: 
torch.Tensor = bin_data["bin_duration"] - self.tb_clip_bin_offset: torch.Tensor = bin_data["clip_bin_offset"] - - self.tb_failed_rate = torch.zeros( - self.num_time_bins, dtype=torch.float, device=self.device - ) - self._tb_accum_exposure = torch.zeros( - self.num_time_bins, dtype=torch.float, device=self.device - ) - self._tb_accum_failure = torch.zeros( - self.num_time_bins, dtype=torch.float, device=self.device - ) - self._env_bin_ids = torch.zeros( - self.num_envs, dtype=torch.long, device=self.device - ) - # Override bin_count so metrics (entropy, top1) use time bins - self.bin_count = self.num_time_bins - _LOG.info( - "adaptive_bin: %d time bins (duration=%.1fs, beta=%.1f, alpha=%.2f)", - self.num_time_bins, - self.cfg.adaptive_bin_duration_s, - self.cfg.adaptive_bin_beta, - self.cfg.adaptive_bin_alpha, - ) - self.metrics["error_anchor_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_anchor_rot"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_anchor_lin_vel"] = torch.zeros( @@ -688,9 +502,6 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.metrics["error_body_rot"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_vel"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_entropy"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_top1_prob"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_top1_bin"] = torch.zeros(self.num_envs, device=self.device) # Feet standing state (for feet_air_time_ref rewards) if self.cfg.feet_body_names: @@ -932,110 +743,9 @@ def _update_metrics(self): # Sampling # ------------------------------------------------------------------ - def _adaptive_sampling(self, env_ids: torch.Tensor): - current_motion_ids = self.motion_ids[env_ids] - episode_failed = 
self._env.termination_manager.terminated[env_ids] - exposure, failure = _compute_clip_counts( - current_motion_ids, episode_failed, self.bin_count - ) - self._accum_exposure_count += exposure - self._accum_failure_count += failure - - sampling_probabilities = ( - self.bin_failed_rate - + self.cfg.adaptive_uniform_ratio / float(self.bin_count) - ) - sampling_probabilities = _normalize_sampling_probabilities( - sampling_probabilities, - adaptive_uniform_ratio=self.cfg.adaptive_uniform_ratio, - bin_count=self.bin_count, - ) - - sampled_clips = torch.multinomial( - sampling_probabilities, len(env_ids), replacement=True - ) - self.motion_ids[env_ids] = sampled_clips - self.motion_times[env_ids] = self.motion.sample_times(sampled_clips) - - entropy = -(sampling_probabilities * (sampling_probabilities + 1e-12).log()).sum() - entropy_norm = entropy / math.log(self.bin_count) if self.bin_count > 1 else 1.0 - top1_prob, top1_bin = sampling_probabilities.max(dim=0) - self.metrics["sampling_entropy"][:] = entropy_norm - self.metrics["sampling_top1_prob"][:] = top1_prob - self.metrics["sampling_top1_bin"][:] = top1_bin.float() / self.bin_count - def _uniform_sampling(self, env_ids: torch.Tensor): self.motion_ids[env_ids] = self.motion.sample_motion_ids(len(env_ids)) self.motion_times[env_ids] = self.motion.sample_times(self.motion_ids[env_ids]) - self.metrics["sampling_entropy"][:] = 1.0 - self.metrics["sampling_top1_prob"][:] = 1.0 / max(self.bin_count, 1) - self.metrics["sampling_top1_bin"][:] = 0.5 - - def _time_to_bin_index( - self, motion_ids: torch.Tensor, motion_times: torch.Tensor - ) -> torch.Tensor: - """Map (clip_id, time_in_clip) pairs to time-bin indices.""" - offsets = self.tb_clip_bin_offset[motion_ids] # first bin of each clip - local_time = motion_times - self.tb_start_s[offsets] - bin_within_clip = (local_time / self.cfg.adaptive_bin_duration_s).long() - # Clamp to valid range: last bin of each clip - # Total bins per clip = number of consecutive bins with 
same clip_id - # Simple approach: clamp so offset + bin_within_clip stays < num_time_bins - # and the clip_id matches - bin_within_clip = torch.clamp(bin_within_clip, min=0) - result = offsets + bin_within_clip - result = torch.clamp(result, max=self.num_time_bins - 1) - return result - - def _adaptive_bin_sampling(self, env_ids: torch.Tensor): - """SONIC-inspired bin-based adaptive sampling.""" - # 1. Accumulate failure statistics per time bin - current_bin_ids = self._env_bin_ids[env_ids] - episode_failed = self._env.termination_manager.terminated[env_ids] - exposure = torch.bincount( - current_bin_ids, minlength=self.num_time_bins - ).float() - failure = torch.bincount( - current_bin_ids[episode_failed], minlength=self.num_time_bins - ).float() - self._tb_accum_exposure += exposure - self._tb_accum_failure += failure - - # 2. Compute capped, blended probabilities - capped = _cap_failure_rates(self.tb_failed_rate, self.cfg.adaptive_bin_beta) - capped_sum = capped.sum() - if capped_sum > 0: - p_hat = capped / capped_sum - else: - p_hat = torch.ones_like(capped) / self.num_time_bins - - alpha = self.cfg.adaptive_bin_alpha - p_final = alpha * p_hat + (1.0 - alpha) / self.num_time_bins - # Normalize for numerical safety - p_final = p_final / p_final.sum() - - # 3. Sample bins, then sample frames within bins - sampled_bins = torch.multinomial(p_final, len(env_ids), replacement=True) - sampled_clip_ids = self.tb_clip_id[sampled_bins] - bin_starts = self.tb_start_s[sampled_bins] - bin_durs = self.tb_duration[sampled_bins] - sampled_times = bin_starts + torch.rand( - len(env_ids), device=self.device - ) * bin_durs - - self.motion_ids[env_ids] = sampled_clip_ids - self.motion_times[env_ids] = sampled_times - self._env_bin_ids[env_ids] = sampled_bins - - # 4. 
Update metrics - entropy = -(p_final * (p_final + 1e-12).log()).sum() - entropy_norm = ( - entropy / math.log(self.num_time_bins) if self.num_time_bins > 1 else 1.0 - ) - top1_prob, top1_bin = p_final.max(dim=0) - self.metrics["sampling_entropy"][:] = entropy_norm - self.metrics["sampling_top1_prob"][:] = top1_prob - self.metrics["sampling_top1_bin"][:] = top1_bin.float() / self.num_time_bins def _resample_command(self, env_ids: torch.Tensor): if self.cfg.sampling_mode == "start": @@ -1043,11 +753,11 @@ def _resample_command(self, env_ids: torch.Tensor): self.motion_times[env_ids] = self.motion.sample_start_times(self.motion_ids[env_ids]) elif self.cfg.sampling_mode == "uniform": self._uniform_sampling(env_ids) - elif self.cfg.sampling_mode == "adaptive_bin": - self._adaptive_bin_sampling(env_ids) else: - assert self.cfg.sampling_mode == "adaptive" - self._adaptive_sampling(env_ids) + raise ValueError( + f"Unsupported motion sampling_mode={self.cfg.sampling_mode!r}. " + "Supported modes are 'uniform' and 'start'." 
+ ) if env_ids.numel() == 0: return @@ -1174,54 +884,6 @@ def _update_command(self): self._refresh_body_local_cache() self._update_feet_standing() - if self.cfg.sampling_mode == "adaptive": - # Sync raw counts across ranks for unified statistics - if _is_distributed(): - torch.distributed.all_reduce(self._accum_exposure_count) - torch.distributed.all_reduce(self._accum_failure_count) - - # Only update EMA when new data exists (fixes decay-on-empty-step) - if self._accum_exposure_count.sum() > 0: - valid = self._accum_exposure_count > 0 - global_rate = torch.zeros_like(self.bin_failed_rate) - global_rate[valid] = ( - self._accum_failure_count[valid] - / self._accum_exposure_count[valid] - ) - self.bin_failed_rate = ( - self.cfg.adaptive_alpha * global_rate - + (1 - self.cfg.adaptive_alpha) * self.bin_failed_rate - ) - - self._accum_exposure_count.zero_() - self._accum_failure_count.zero_() - - elif self.cfg.sampling_mode == "adaptive_bin": - # Keep bin ids in sync with current motion_times so that - # failures are attributed to the bin the agent is *in*, not the - # bin it was initially sampled from. 
- self._env_bin_ids = self._time_to_bin_index( - self.motion_ids, self.motion_times - ) - - if _is_distributed(): - torch.distributed.all_reduce(self._tb_accum_exposure) - torch.distributed.all_reduce(self._tb_accum_failure) - - if self._tb_accum_exposure.sum() > 0: - valid = self._tb_accum_exposure > 0 - global_rate = torch.zeros_like(self.tb_failed_rate) - global_rate[valid] = ( - self._tb_accum_failure[valid] / self._tb_accum_exposure[valid] - ) - ema = self.cfg.adaptive_bin_alpha_ema - self.tb_failed_rate = ( - ema * global_rate + (1 - ema) * self.tb_failed_rate - ) - - self._tb_accum_exposure.zero_() - self._tb_accum_failure.zero_() - # ------------------------------------------------------------------ # Visualization # ------------------------------------------------------------------ @@ -1306,16 +968,7 @@ class MotionCommandCfg(CommandTermCfg): pose_range: dict[str, tuple[float, float]] = field(default_factory=dict) velocity_range: dict[str, tuple[float, float]] = field(default_factory=dict) joint_position_range: tuple[float, float] = (-0.52, 0.52) - adaptive_kernel_size: int = 1 - adaptive_lambda: float = 0.8 - adaptive_uniform_ratio: float = 0.1 - adaptive_alpha: float = 0.001 - sampling_mode: Literal["adaptive", "adaptive_bin", "uniform", "start"] = "uniform" - # --- adaptive_bin mode params --- - adaptive_bin_duration_s: float = 5.0 - adaptive_bin_beta: float = 5.0 - adaptive_bin_alpha: float = 0.8 - adaptive_bin_alpha_ema: float = 0.001 + sampling_mode: Literal["uniform", "start"] = "uniform" window_steps: tuple[int, ...] = (0,) feet_body_names: tuple[str, ...] 
= () feet_standing_z_threshold: float = 0.18 From d3b0e02ca93676b78a840dc9090a8b200938c196 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 13 May 2026 21:18:42 +0800 Subject: [PATCH 009/122] Add optional LinkerHand sim2real control --- .gitmodules | 3 + AGENTS.md | 7 + README.md | 4 + docs/docs/configuration/config-reference.md | 17 + docs/docs/getting-started/installation.md | 9 + docs/docs/tutorials/pico-sim2real.md | 37 ++ .../current/configuration/config-reference.md | 16 + .../current/getting-started/installation.md | 9 + .../current/tutorials/pico-sim2real.md | 36 ++ teleopit/configs/pico4_sim2real.yaml | 19 + teleopit/configs/sim2real.yaml | 19 + teleopit/inputs/pico4_provider.py | 50 +++ teleopit/sim2real/controller.py | 26 ++ teleopit/sim2real/dexterous_hand.py | 380 ++++++++++++++++++ tests/test_dexterous_hand.py | 233 +++++++++++ tests/test_mocap_mujoco.py | 28 ++ tests/test_pico4_provider.py | 15 + tests/test_sim2real_runtime.py | 42 ++ third_party/linkerhand-python-sdk | 1 + 19 files changed, 951 insertions(+) create mode 100644 teleopit/sim2real/dexterous_hand.py create mode 100644 tests/test_dexterous_hand.py create mode 160000 third_party/linkerhand-python-sdk diff --git a/.gitmodules b/.gitmodules index a5085287..6ddf149b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "third_party/unitree_sdk2_python"] path = third_party/unitree_sdk2_python url = https://github.com/unitreerobotics/unitree_sdk2_python.git +[submodule "third_party/linkerhand-python-sdk"] + path = third_party/linkerhand-python-sdk + url = https://github.com/BotRunner64/linkerhand-python-sdk.git diff --git a/AGENTS.md b/AGENTS.md index 8d62dc4f..ed756a36 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -54,6 +54,9 @@ teleopit/ # Core inference package │ └── mujoco_robot.py # MuJoCoRobot — MuJoCo sim wrapper ├── sim/ │ └── loop.py # SimulationLoop — PD control at 1000Hz, policy at 50Hz +├── sim2real/ +│ ├── controller.py # G1 state machine and hardware control loop 
+│ └── dexterous_hand.py # Optional Pico controller → LinkerHand L6 runtime └── recording/ # HDF5Recorder scripts/ ├── run_sim.py # Offline sim2sim pipeline @@ -137,6 +140,9 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `X` → back to `STANDING`, `Q` → quit - Pico4 sim2real pause/resume is handled as a mocap-session control event (`toggle_pause`), not as a mode switch to `STANDING` - Default Pico pause button is `A`; restore tracking by rebuilding the realtime buffer and yaw/XY root-offset alignment, then accept the current live mocap retarget reference directly +- Optional LinkerHand L6 control uses `third_party/linkerhand-python-sdk` and `dexterous_hand.enabled=true` +- LinkerHand control reuses `Pico4InputProvider.get_controller_snapshot()`; do not start a second `PicoBridge` for hand control +- LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, frame timeout, and shutdown must send the configured open pose ### SimulationLoop Runtime Behavior - `realtime=true` enforces wall-clock pacing even without a viewer @@ -204,6 +210,7 @@ python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_trac - Do not commit robot meshes, datasets, checkpoints, or demo media to Git; use `scripts/setup/download_assets.py` - `teleopit/retargeting/gmr/assets/` is gitignored; downloaded at runtime - `train_mimic/assets/` is no longer tracked; FK tooling reuses `teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml` +- `third_party/linkerhand-python-sdk` is a git submodule for optional LinkerHand L6 sim2real control - Run `python scripts/check_large_tracked_files.py` before pushing Assets are split across two ModelScope repos by type: diff --git a/README.md b/README.md index baba966a..03acbf5f 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,10 @@ Full docs at 
**[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ## Changelog +### Unreleased + +- Added optional Pico controller control for LinkerHand L6 in sim2real, backed by the LinkerHand SDK submodule. + ### v0.3.0 (2026-05-12) - Consolidated realtime input around pico-bridge 0.2.0 and removed the old ZMQ/onboard Pico path. diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 96b86f21..76ae83f6 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -123,6 +123,23 @@ Fields used by sim2real configs (`sim2real.yaml`, `pico4_sim2real.yaml`). Realtime Pico resume re-centers heading and ground-plane position before tracking continues. Operators should keep still and stay as close as practical to the paused pose to reduce sudden reference changes. +### Dexterous Hand (Pico sim2real) + +`dexterous_hand.enabled=true` requires `input.provider=pico4` and the optional +LinkerHand SDK submodule. Control is active only in `MOCAP`; inactive modes and +timeouts send the open pose. 
+ +| Field | Description | Default | +|-------|-------------|---------| +| `dexterous_hand.enabled` | Enable Pico controller control for LinkerHand L6 | `false` | +| `dexterous_hand.hand_type` | Controlled side: `left`, `right`, or `both` | `both` | +| `dexterous_hand.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | +| `dexterous_hand.rate` | Maximum command rate in Hz | `30.0` | +| `dexterous_hand.frame_timeout` | Missing-controller timeout before opening hands | `0.3` | +| `dexterous_hand.deadman_threshold` | Minimum grip value required to enable a side | `0.5` | +| `dexterous_hand.trigger_deadzone` | Trigger deadzone at both ends | `0.05` | +| `dexterous_hand.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | + ### Realtime Catch-up (Pico sim2real) | Field | Description | Default | diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index e16332b4..9c6cc634 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -60,6 +60,15 @@ The receiver can run on a workstation PC or the robot onboard computer. See [Pico Sim2Sim](../tutorials/pico-sim2sim) and [Pico Sim2Real](../tutorials/pico-sim2real) for the full setup guides. +Optional LinkerHand L6 control for Pico sim2real uses a submodule SDK: + +```bash +git submodule update --init third_party/linkerhand-python-sdk +pip install -e third_party/linkerhand-python-sdk +``` + +This SDK is only required when `dexterous_hand.enabled=true`. + ## Verify Installation ```bash diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 868f1aa9..fb9c0866 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -138,6 +138,39 @@ Resume while standing still and close to the paused pose. This reduces sudden reference changes when live tracking resumes. 
::: +## Optional LinkerHand L6 Control + +Pico sim2real can drive LinkerHand L6 hands from the Pico controllers. Hold the +matching side grip as a deadman switch; the matching trigger closes that hand. +Hand control is active only in `MOCAP`. It sends the open pose in `STANDING`, +`DAMPING`, paused mocap, frame timeout, and shutdown. + +Install the SDK submodule first: + +```bash +git submodule update --init third_party/linkerhand-python-sdk +pip install -e third_party/linkerhand-python-sdk +``` + +Dry-run before connecting CAN: + +```bash +python scripts/run/run_sim2real.py \ + --config-name pico4_sim2real \ + controller.policy_path=track.onnx \ + dexterous_hand.enabled=true \ + dexterous_hand.dry_run=true +``` + +For real hands, configure the CAN channels and disable dry-run: + +```bash +dexterous_hand.enabled=true +dexterous_hand.dry_run=false +dexterous_hand.left_can=can0 +dexterous_hand.right_can=can1 +``` + ## Optional RealSense Preview Stream the G1 RealSense color camera back to the Pico headset: @@ -177,6 +210,9 @@ pause_resume_warmup_steps=2 # Change Pico pause button input.pause_button=right_axis_click +# Enable LinkerHand L6 dry-run +dexterous_hand.enabled=true dexterous_hand.dry_run=true + # Enable headset video preview input.video.enabled=true ``` @@ -190,4 +226,5 @@ input.video.enabled=true | Cannot enter debug mode | Unitree mode release failed | Stop other robot modes and press `Start` again | | Robot enters `STANDING` but not `MOCAP` | Mocap validation failed | Keep tracking active and stable; check `mocap_switch.check_frames` logs | | Pico pause does not return to `STANDING` | Expected behavior | Pico pause freezes mocap; press remote `X` for `STANDING` | +| LinkerHand does not move | Not in `MOCAP`, deadman grip released, SDK not installed, or CAN channel wrong | Enter `MOCAP`, hold the matching side grip, verify `pip install -e third_party/linkerhand-python-sdk`, and check `dexterous_hand.left_can` / `right_can` | | Video preview is unavailable 
| RealSense or video source failed | Check camera permissions, `input.video.source`, and logs | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 27641896..a84a57ba 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -142,6 +142,22 @@ target = clip(action, clip_range) * action_scale + default_dof_pos 实时 Pico 恢复追踪时会先重新居中航向和地面平面位置。操作者应保持静止,并尽量贴近暂停时的姿态,以减少参考突变。 +### 灵巧手(Pico sim2real) + +`dexterous_hand.enabled=true` 要求 `input.provider=pico4`,并安装可选的 +LinkerHand SDK submodule。控制只在 `MOCAP` 中生效;非活动模式和超时会发送张开姿态。 + +| 字段 | 说明 | 默认值 | +|---|---|---| +| `dexterous_hand.enabled` | 启用 Pico 手柄控制 LinkerHand L6 | `false` | +| `dexterous_hand.hand_type` | 控制侧:`left`、`right` 或 `both` | `both` | +| `dexterous_hand.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | +| `dexterous_hand.rate` | 最大命令频率(Hz) | `30.0` | +| `dexterous_hand.frame_timeout` | 手柄超时后张开手的时间 | `0.3` | +| `dexterous_hand.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | +| `dexterous_hand.trigger_deadzone` | trigger 两端死区 | `0.05` | +| `dexterous_hand.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | + ### 实时追赶(Pico sim2real) | 字段 | 说明 | 默认值 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 8f95acb5..41f2c6fd 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -60,6 +60,15 @@ receiver 可以运行在工作站 PC,也可以运行在机器人 onboard 计 完整设置流程详见 [Pico Sim2Sim](../tutorials/pico-sim2sim) 和 [Pico Sim2Real](../tutorials/pico-sim2real)。 
+Pico sim2real 可选的 LinkerHand L6 控制使用一个 submodule SDK: + +```bash +git submodule update --init third_party/linkerhand-python-sdk +pip install -e third_party/linkerhand-python-sdk +``` + +只有在 `dexterous_hand.enabled=true` 时才需要安装该 SDK。 + ## 验证安装 ```bash diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index bb7f5052..b25daff6 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -130,6 +130,38 @@ Pico 暂停/恢复是 mocap-session control event。 恢复时请保持静止,并尽量接近暂停时的姿态。这样可以减少实时追踪恢复时的参考突变。 ::: +## 可选 LinkerHand L6 控制 + +Pico sim2real 可以用 Pico 手柄控制 LinkerHand L6。按住同侧 grip 作为 deadman, +同侧 trigger 控制对应手闭合。手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、 +mocap 暂停、帧超时和退出时都会发送张开姿态。 + +先安装 SDK submodule: + +```bash +git submodule update --init third_party/linkerhand-python-sdk +pip install -e third_party/linkerhand-python-sdk +``` + +连接 CAN 前先 dry-run: + +```bash +python scripts/run/run_sim2real.py \ + --config-name pico4_sim2real \ + controller.policy_path=track.onnx \ + dexterous_hand.enabled=true \ + dexterous_hand.dry_run=true +``` + +使用真实手时,配置 CAN 通道并关闭 dry-run: + +```bash +dexterous_hand.enabled=true +dexterous_hand.dry_run=false +dexterous_hand.left_can=can0 +dexterous_hand.right_can=can1 +``` + ## 可选 RealSense 预览 将 G1 RealSense 彩色相机推送回 Pico 头显: @@ -169,6 +201,9 @@ pause_resume_warmup_steps=2 # 更换 Pico 暂停键 input.pause_button=right_axis_click +# 开启 LinkerHand L6 dry-run +dexterous_hand.enabled=true dexterous_hand.dry_run=true + # 开启头显视频预览 input.video.enabled=true ``` @@ -182,4 +217,5 @@ input.video.enabled=true | 无法进入 debug mode | Unitree mode 释放失败 | 停止其他机器人模式后再次按 `Start` | | 机器人进入 `STANDING` 但不进入 `MOCAP` | 动捕验证失败 | 保持追踪稳定,查看 `mocap_switch.check_frames` 日志 | | Pico 暂停没有返回 `STANDING` | 这是预期行为 | Pico 暂停只冻结 mocap;按遥控器 `X` 返回 
`STANDING` | +| LinkerHand 不动 | 不在 `MOCAP`、deadman grip 未按住、SDK 未安装,或 CAN 通道错误 | 进入 `MOCAP`,按住同侧 grip,确认已执行 `pip install -e third_party/linkerhand-python-sdk`,并检查 `dexterous_hand.left_can` / `right_can` | | 视频预览不可用 | RealSense 或视频源失败 | 检查相机权限、`input.video.source` 和日志 | diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 3c55d1c4..081fc192 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -32,6 +32,25 @@ startup_ramp_duration: 2.0 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 +# Optional LinkerHand L6 control from Pico controller grip/trigger. +dexterous_hand: + enabled: false + hand_joint: L6 + hand_type: both + left_can: can0 + right_can: can1 + modbus: "None" + rate: 30.0 + frame_timeout: 0.3 + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + thumb_yaw_center: 10 + speed: [50, 50, 50, 50, 50, 50] + open_pose: [250, 10, 250, 250, 250, 250] + close_pose: [79, 10, 0, 0, 0, 0] + dry_run: false + print_input: false + # Physical robot SDK configuration real_robot: network_interface: "eth0" diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index a677d3a9..ef31bc6d 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -27,6 +27,25 @@ startup_ramp_duration: 2.0 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 +# Optional LinkerHand L6 control. Enable only with input.provider=pico4. 
+dexterous_hand: + enabled: false + hand_joint: L6 + hand_type: both + left_can: can0 + right_can: can1 + modbus: "None" + rate: 30.0 + frame_timeout: 0.3 + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + thumb_yaw_center: 10 + speed: [50, 50, 50, 50, 50, 50] + open_pose: [250, 10, 250, 250, 250, 250] + close_pose: [79, 10, 0, 0, 0, 0] + dry_run: false + print_input: false + # Physical robot SDK configuration real_robot: network_interface: "eth0" diff --git a/teleopit/inputs/pico4_provider.py b/teleopit/inputs/pico4_provider.py index b6bc021e..df6e8517 100644 --- a/teleopit/inputs/pico4_provider.py +++ b/teleopit/inputs/pico4_provider.py @@ -8,6 +8,7 @@ from __future__ import annotations from collections import deque +from dataclasses import dataclass import inspect import logging import threading @@ -52,6 +53,26 @@ dtype=np.int32, ) + +@dataclass(frozen=True) +class PicoControllerState: + """Latest per-controller input state exposed by pico_bridge.""" + + raw: bool + grip: float + trigger: float + present: bool = True + + +@dataclass(frozen=True) +class PicoControllerSnapshot: + """Immutable snapshot of Pico controller inputs for auxiliary runtimes.""" + + left: PicoControllerState + right: PicoControllerState + timestamp_s: float + seq: int + _PAUSE_BUTTON_MAP: dict[str, tuple[str, str]] = { "A": ("right", "primaryButton"), "B": ("right", "secondaryButton"), @@ -143,6 +164,7 @@ def __init__( self._last_raw_body_joints: NDArray[np.float64] | None = None self._last_frame_timestamp: float | None = None self._last_source_seq: int | None = None + self._controller_snapshot: PicoControllerSnapshot | None = None self._bridge = bridge_cls( host=bridge_host, port=int(bridge_port), @@ -223,6 +245,11 @@ def pop_control_events(self) -> tuple[ControlEvent, ...]: self._pending_control_events.clear() return control_events + def get_controller_snapshot(self) -> PicoControllerSnapshot | None: + """Return the latest Pico controller-axis snapshot, if one has arrived.""" + with 
self._lock: + return self._controller_snapshot + def push_video_frame(self, frame: NDArray[np.uint8]) -> int: """Push one RGB camera frame to pico-bridge 0.2.0 video output.""" push_video_frame = getattr(self._bridge, "push_video_frame", None) @@ -287,6 +314,7 @@ def _poll_loop(self) -> None: def _accept_pico_frame(self, frame: Any) -> bool: timestamp = float(getattr(frame, "receive_time_s", time.monotonic())) + self._accept_controller_snapshot(frame, timestamp=timestamp) self._poll_control_events(frame, timestamp=timestamp) body = getattr(frame, "body", None) @@ -327,6 +355,18 @@ def _accept_pico_frame(self, frame: Any) -> bool: self._frame_ready.set() return True + def _accept_controller_snapshot(self, frame: Any, *, timestamp: float) -> None: + seq = int(getattr(frame, "seq", self._last_source_seq or -1)) + controllers = getattr(frame, "controllers", None) + snapshot = PicoControllerSnapshot( + left=self._read_controller_state(None if controllers is None else getattr(controllers, "left", None)), + right=self._read_controller_state(None if controllers is None else getattr(controllers, "right", None)), + timestamp_s=float(timestamp), + seq=seq, + ) + with self._lock: + self._controller_snapshot = snapshot + def _poll_control_events(self, frame: Any, *, timestamp: float) -> bool: if self._pause_button_path is None: return False @@ -361,6 +401,16 @@ def _resolve_button_path(pause_button: str | None) -> tuple[str, str] | None: return None return _PAUSE_BUTTON_MAP.get(pause_button) + @staticmethod + def _read_controller_state(controller: Any) -> PicoControllerState: + axis = {} if controller is None else getattr(controller, "axis", {}) or {} + return PicoControllerState( + raw=bool(False if controller is None else getattr(controller, "raw", False)), + grip=float(axis.get("grip", 0.0)), + trigger=float(axis.get("trigger", 0.0)), + present=controller is not None, + ) + @staticmethod def _convert_body_joints_to_frame(body_joints: NDArray[np.float64]) -> HumanFrame: 
body_joints = Pico4InputProvider._normalize_pico_bridge_body_joints(body_joints) diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index 8324743b..c29ac0eb 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -45,6 +45,7 @@ obs_builder_requires_reference_window, ) from teleopit.sim.realtime_utils import RealtimeReferenceManager +from teleopit.sim2real.dexterous_hand import build_linkerhand_runtime from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor from teleopit.sim2real.remote import UnitreeRemote from teleopit.sim2real.safety import Sim2RealSafetyManager @@ -120,6 +121,7 @@ def _init_components(self, cfg: Any) -> None: config=parse_pico_video_config(cfg_get(cfg, "input", {})), mode="sim2real", ) + self._hand_runtime = build_linkerhand_runtime(cfg, self.input_provider) self._offline_reference: OfflineReferenceMotion | None = None self._offline_playback: OfflinePlaybackController | None = None if hasattr(self.input_provider, "__len__") and hasattr(self.input_provider, "get_frame_by_index"): @@ -223,6 +225,7 @@ def run(self) -> None: try: self._video_runtime.start() + self._hand_runtime.start() while True: t0 = time.monotonic() self._video_runtime.tick() @@ -236,6 +239,7 @@ def run(self) -> None: if self.mode != RobotMode.DAMPING: logger.warning("EMERGENCY STOP (L1+R1)") self._enter_damping() + self._tick_dexterous_hand() self._sleep_until(t0, dt) continue @@ -248,6 +252,8 @@ def run(self) -> None: elif self.mode == RobotMode.MOCAP: self._mocap_step() + self._tick_dexterous_hand() + # 6. 
Rate control self._sleep_until(t0, dt) @@ -563,6 +569,7 @@ def _enter_standing(self) -> None: self._mocap_reentry_armed = prev_mode == RobotMode.MOCAP self.mode = RobotMode.STANDING + self._deactivate_dexterous_hand() logger.info("Mode -> STANDING (RL policy maintaining balance at default pose)") # ------------------------------------------------------------------ @@ -666,6 +673,7 @@ def _enter_damping(self) -> None: self.robot.exit_debug_mode() self.mode = RobotMode.DAMPING + self._deactivate_dexterous_hand() self._ref_proc.last_reference_qpos = None if self._reference_timeline is not None: self._reference_timeline.clear() @@ -878,6 +886,20 @@ def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: self._ref_proc.last_reference_qpos = qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + def _tick_dexterous_hand(self) -> None: + active = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.ACTIVE + try: + self._hand_runtime.tick(active=active) + except Exception: + logger.exception("Dexterous hand runtime failed -- entering damping") + self._enter_damping() + + def _deactivate_dexterous_hand(self) -> None: + try: + self._hand_runtime.tick(active=False) + except Exception: + logger.exception("Failed to deactivate dexterous hand runtime") + @staticmethod def _sleep_until(t0: float, dt: float) -> None: """Sleep to maintain control frequency.""" @@ -907,6 +929,10 @@ def shutdown(self) -> None: self._video_runtime.stop() except Exception: pass + try: + self._hand_runtime.close() + except Exception: + pass try: self.input_provider.close() except Exception: diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py new file mode 100644 index 00000000..9899caf0 --- /dev/null +++ b/teleopit/sim2real/dexterous_hand.py @@ -0,0 +1,380 @@ +"""Optional LinkerHand L6 control for Pico sim2real.""" + +from __future__ import annotations + +from dataclasses import dataclass +import logging +import time +from 
typing import Any, Protocol, Sequence + +from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState +from teleopit.runtime.common import cfg_get + +logger = logging.getLogger(__name__) + +THUMB_YAW_DEFAULT = 10 +OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] +CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] +DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] +HAND_TYPES = ("left", "right") + + +class ControllerSnapshotProvider(Protocol): + def get_controller_snapshot(self) -> PicoControllerSnapshot | None: + ... + + +@dataclass(frozen=True) +class LinkerHandConfig: + enabled: bool = False + hand_joint: str = "L6" + hand_type: str = "both" + left_can: str = "can0" + right_can: str = "can1" + modbus: str = "None" + rate: float = 30.0 + frame_timeout: float = 0.3 + trigger_deadzone: float = 0.05 + deadman_threshold: float = 0.5 + thumb_yaw_center: int = THUMB_YAW_DEFAULT + speed: tuple[int, ...] = tuple(DEFAULT_SPEED) + open_pose: tuple[int, ...] = tuple(OPEN_POSE) + close_pose: tuple[int, ...] 
= tuple(CLOSE_POSE) + dry_run: bool = False + print_input: bool = False + + @property + def selected_hand_types(self) -> tuple[str, ...]: + if self.hand_type == "both": + return HAND_TYPES + return (self.hand_type,) + + +def clamp_unit(value: float) -> float: + return max(0.0, min(1.0, float(value))) + + +def normalize_trigger(value: float, deadzone: float) -> float: + value = clamp_unit(value) + deadzone = clamp_unit(deadzone) + if deadzone >= 0.5: + raise ValueError(f"trigger_deadzone must be < 0.5, got {deadzone}") + if value <= deadzone: + return 0.0 + upper = 1.0 - deadzone + if value >= upper: + return 1.0 + return (value - deadzone) / (upper - deadzone) + + +def trigger_to_pose( + trigger: float, + *, + open_pose: Sequence[int], + close_pose: Sequence[int], + deadzone: float, + thumb_yaw_default: int, +) -> list[int]: + if len(open_pose) != 6 or len(close_pose) != 6: + raise ValueError("LinkerHand L6 open_pose and close_pose must each contain 6 values") + alpha = normalize_trigger(trigger, deadzone) + pose = [ + int(round(float(open_value) + alpha * (float(close_value) - float(open_value)))) + for open_value, close_value in zip(open_pose, close_pose) + ] + pose[1] = int(thumb_yaw_default) + return pose + + +def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: + hand_cfg = cfg_get(cfg, "dexterous_hand", {}) or {} + thumb_yaw = _uint8(cfg_get(hand_cfg, "thumb_yaw_center", THUMB_YAW_DEFAULT), "thumb_yaw_center") + open_pose = _pose_values(cfg_get(hand_cfg, "open_pose", OPEN_POSE), "open_pose") + close_pose = _pose_values(cfg_get(hand_cfg, "close_pose", CLOSE_POSE), "close_pose") + open_pose[1] = thumb_yaw + close_pose[1] = thumb_yaw + + config = LinkerHandConfig( + enabled=bool(cfg_get(hand_cfg, "enabled", False)), + hand_joint=str(cfg_get(hand_cfg, "hand_joint", "L6")).upper(), + hand_type=str(cfg_get(hand_cfg, "hand_type", "both")).lower(), + left_can=str(cfg_get(hand_cfg, "left_can", "can0")), + right_can=str(cfg_get(hand_cfg, "right_can", "can1")), + 
modbus=str(cfg_get(hand_cfg, "modbus", "None")), + rate=_positive_float(cfg_get(hand_cfg, "rate", 30.0), "rate"), + frame_timeout=_positive_float(cfg_get(hand_cfg, "frame_timeout", 0.3), "frame_timeout"), + trigger_deadzone=_trigger_deadzone(cfg_get(hand_cfg, "trigger_deadzone", 0.05)), + deadman_threshold=_deadman_threshold(cfg_get(hand_cfg, "deadman_threshold", 0.5)), + thumb_yaw_center=thumb_yaw, + speed=tuple(_pose_values(cfg_get(hand_cfg, "speed", DEFAULT_SPEED), "speed")), + open_pose=tuple(open_pose), + close_pose=tuple(close_pose), + dry_run=bool(cfg_get(hand_cfg, "dry_run", False)), + print_input=bool(cfg_get(hand_cfg, "print_input", False)), + ) + if config.hand_joint != "L6": + raise ValueError(f"dexterous_hand.hand_joint must be 'L6', got {config.hand_joint!r}") + if config.hand_type not in ("left", "right", "both"): + raise ValueError("dexterous_hand.hand_type must be left, right, or both") + return config + + +class L6PoseSender: + """Thin adapter around LinkerHandApi with duplicate-command suppression.""" + + def __init__(self, config: LinkerHandConfig): + self._config = config + self._hand_types = config.selected_hand_types + self._can_channels = {"left": config.left_can, "right": config.right_can} + self._hands: dict[str, Any] = {} + self._last_pose: dict[str, list[int] | None] = { + hand_type: None for hand_type in self._hand_types + } + self._started = False + + @property + def started(self) -> bool: + return self._started + + def start(self) -> None: + if self._started: + return + if self._config.dry_run: + logger.info("LinkerHand L6 dry-run enabled | hands=%s", ",".join(self._hand_types)) + self._started = True + return + try: + from LinkerHand.linker_hand_api import LinkerHandApi + except ImportError as exc: + raise ImportError( + "LinkerHand SDK is required when dexterous_hand.enabled=true. 
" + "Run: pip install -e third_party/linkerhand-python-sdk" + ) from exc + + try: + for hand_type in self._hand_types: + hand = LinkerHandApi( + hand_joint="L6", + hand_type=hand_type, + modbus=self._config.modbus, + can=self._can_channels[hand_type], + ) + hand.set_speed(speed=list(self._config.speed)) + self._hands[hand_type] = hand + self._started = True + except SystemExit as exc: + self._close_hands() + self._started = False + raise RuntimeError( + "LinkerHand SDK exited during startup. Check CAN interface configuration " + f"({', '.join(self._can_channels[hand_type] for hand_type in self._hand_types)}) " + "or use dexterous_hand.dry_run=true before connecting hardware." + ) from exc + except Exception: + self._close_hands() + self._started = False + raise + logger.info("LinkerHand L6 runtime started | hands=%s", ",".join(self._hand_types)) + + def send(self, hand_type: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + if not self._started: + return + next_pose = [int(value) for value in pose] + if not force and self._last_pose.get(hand_type) == next_pose: + return + if self._config.dry_run: + suffix = f" ({reason})" if reason else "" + logger.info("dry-run LinkerHand %s pose%s: %s", hand_type, suffix, next_pose) + else: + hand = self._hands.get(hand_type) + if hand is None: + raise RuntimeError("L6PoseSender has not been started") + hand.finger_move(pose=next_pose) + self._last_pose[hand_type] = next_pose + + def send_all(self, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + for hand_type in self._hand_types: + self.send(hand_type, pose, force=force, reason=reason) + + def close(self) -> None: + if not self._started and not self._hands: + return + try: + if self._started: + self.send_all(self._config.open_pose, force=True, reason="exit") + time.sleep(0.2) + except Exception: + logger.exception("Failed to send LinkerHand open pose on exit") + self._close_hands() + self._started = False + + def 
_close_hands(self) -> None: + for hand in self._hands.values(): + close_can = getattr(hand, "close_can", None) + if callable(close_can): + close_can() + inner_hand = getattr(hand, "hand", None) + close = getattr(inner_hand, "close", None) + if callable(close): + close() + self._hands.clear() + + +class LinkerHandRuntime: + """Drive LinkerHand L6 from Pico controller grip/trigger snapshots.""" + + def __init__(self, config: LinkerHandConfig, provider: ControllerSnapshotProvider): + self.config = config + self._provider = provider + self._sender = L6PoseSender(config) + self._interval_s = 1.0 / config.rate + self._next_tick_s = 0.0 + self._active = False + self._last_status: dict[str, str] = {hand_type: "" for hand_type in config.selected_hand_types} + + @property + def enabled(self) -> bool: + return self.config.enabled + + def start(self) -> None: + if not self.enabled: + return + self._sender.start() + self._sender.send_all(self.config.open_pose, force=True, reason="startup") + + def tick(self, *, active: bool, now_s: float | None = None) -> None: + if not self.enabled: + return + now = time.monotonic() if now_s is None else float(now_s) + if not active: + self._deactivate(reason="inactive") + return + if not self._active: + self._active = True + self._next_tick_s = 0.0 + if now < self._next_tick_s: + return + self._next_tick_s = now + self._interval_s + + snapshot = self._provider.get_controller_snapshot() + if snapshot is None or now - snapshot.timestamp_s > self.config.frame_timeout: + self._open_all(reason="timeout") + return + + for hand_type in self.config.selected_hand_types: + state = getattr(snapshot, hand_type) + self._tick_hand(hand_type, state, snapshot.seq) + + def close(self) -> None: + self._deactivate(reason="shutdown") + self._sender.close() + + def _tick_hand(self, hand_type: str, state: PicoControllerState, seq: int) -> None: + if not state.present: + self._set_status(hand_type, "missing", f"{hand_type} controller missing; opening hand") + 
self._sender.send(hand_type, self.config.open_pose, reason="missing-controller") + return + + grip = clamp_unit(state.grip) + trigger = clamp_unit(state.trigger) + if grip < self.config.deadman_threshold: + self._set_status(hand_type, "deadman", f"{hand_type} deadman released; opening hand") + self._sender.send(hand_type, self.config.open_pose, reason="deadman-released") + return + + self._set_status(hand_type, "enabled", f"{hand_type} controller active") + if self.config.print_input: + logger.info( + "LinkerHand input | seq=%d hand=%s grip=%.3f trigger=%.3f", + seq, + hand_type, + grip, + trigger, + ) + pose = trigger_to_pose( + trigger, + open_pose=self.config.open_pose, + close_pose=self.config.close_pose, + deadzone=self.config.trigger_deadzone, + thumb_yaw_default=self.config.thumb_yaw_center, + ) + self._sender.send(hand_type, pose, reason="controller") + + def _deactivate(self, *, reason: str) -> None: + if self._active: + self._open_all(reason=reason, force=True) + self._active = False + + def _open_all(self, *, reason: str, force: bool = False) -> None: + self._sender.send_all(self.config.open_pose, force=force, reason=reason) + + def _set_status(self, hand_type: str, status: str, message: str) -> None: + if self._last_status.get(hand_type) == status: + return + self._last_status[hand_type] = status + logger.info("LinkerHand L6: %s", message) + + +class DisabledLinkerHandRuntime: + enabled = False + + def start(self) -> None: + pass + + def tick(self, *, active: bool, now_s: float | None = None) -> None: + del active, now_s + + def close(self) -> None: + pass + + +def build_linkerhand_runtime(cfg: Any, input_provider: Any) -> LinkerHandRuntime | DisabledLinkerHandRuntime: + config = parse_linkerhand_config(cfg) + if not config.enabled: + return DisabledLinkerHandRuntime() + + input_cfg = cfg_get(cfg, "input", {}) or {} + provider_kind = str(cfg_get(input_cfg, "provider", "")).lower() + if provider_kind != "pico4": + raise 
ValueError("dexterous_hand.enabled=true requires input.provider=pico4") + if not callable(getattr(input_provider, "get_controller_snapshot", None)): + raise ValueError("dexterous_hand.enabled=true requires a Pico input provider with controller snapshots") + return LinkerHandRuntime(config, input_provider) + + +def _positive_float(value: object, field_name: str) -> float: + parsed = float(value) + if parsed <= 0.0: + raise ValueError(f"dexterous_hand.{field_name} must be > 0, got {value!r}") + return parsed + + +def _uint8(value: object, field_name: str) -> int: + parsed = int(value) + if parsed < 0 or parsed > 255: + raise ValueError(f"dexterous_hand.{field_name} must be in range 0-255, got {value!r}") + return parsed + + +def _pose_values(value: object, field_name: str) -> list[int]: + try: + parsed = [_uint8(item, field_name) for item in value] # type: ignore[union-attr] + except TypeError as exc: + raise ValueError(f"dexterous_hand.{field_name} must be a sequence of 6 uint8 values") from exc + if len(parsed) != 6: + raise ValueError(f"dexterous_hand.{field_name} must contain 6 values, got {len(parsed)}") + return parsed + + +def _trigger_deadzone(value: object) -> float: + parsed = float(value) + if parsed < 0.0 or parsed >= 0.5: + raise ValueError(f"dexterous_hand.trigger_deadzone must be in [0, 0.5), got {value!r}") + return parsed + + +def _deadman_threshold(value: object) -> float: + parsed = float(value) + if parsed <= 0.0 or parsed >= 1.0: + raise ValueError(f"dexterous_hand.deadman_threshold must be in (0, 1), got {value!r}") + return parsed diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py new file mode 100644 index 00000000..6878fc02 --- /dev/null +++ b/tests/test_dexterous_hand.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import sys +from types import SimpleNamespace + +import pytest + +from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState +from teleopit.sim2real.dexterous_hand 
import ( + L6PoseSender, + LinkerHandRuntime, + parse_linkerhand_config, + trigger_to_pose, +) + + +class SnapshotProvider: + def __init__(self) -> None: + self.snapshot: PicoControllerSnapshot | None = None + + def get_controller_snapshot(self) -> PicoControllerSnapshot | None: + return self.snapshot + + +def _snapshot( + *, + left: PicoControllerState | None = None, + right: PicoControllerState | None = None, + timestamp_s: float = 10.0, + seq: int = 1, +) -> PicoControllerSnapshot: + missing = PicoControllerState(raw=False, grip=0.0, trigger=0.0) + return PicoControllerSnapshot( + left=left or missing, + right=right or missing, + timestamp_s=timestamp_s, + seq=seq, + ) + + +def _runtime(provider: SnapshotProvider) -> LinkerHandRuntime: + cfg = parse_linkerhand_config( + { + "dexterous_hand": { + "enabled": True, + "dry_run": True, + "hand_type": "both", + } + } + ) + runtime = LinkerHandRuntime(cfg, provider) + runtime.start() + return runtime + + +def test_trigger_to_pose_applies_deadzone_and_fixed_thumb_yaw() -> None: + pose = trigger_to_pose( + 0.5, + open_pose=[250, 10, 250, 250, 250, 250], + close_pose=[79, 10, 0, 0, 0, 0], + deadzone=0.05, + thumb_yaw_default=10, + ) + + assert pose == [164, 10, 125, 125, 125, 125] + + +def test_runtime_opens_when_deadman_released() -> None: + provider = SnapshotProvider() + runtime = _runtime(provider) + provider.snapshot = _snapshot( + left=PicoControllerState(raw=True, grip=0.1, trigger=1.0), + right=PicoControllerState(raw=True, grip=0.1, trigger=1.0), + ) + + runtime.tick(active=True, now_s=10.0) + + assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) + assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) + + +def test_runtime_maps_present_controller_even_without_raw_flag() -> None: + provider = SnapshotProvider() + runtime = _runtime(provider) + provider.snapshot = _snapshot( + left=PicoControllerState(raw=False, grip=1.0, trigger=1.0, present=True), + 
right=PicoControllerState(raw=False, grip=1.0, trigger=0.0, present=True), + ) + + runtime.tick(active=True, now_s=10.0) + + assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) + assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) + + +def test_runtime_maps_trigger_when_deadman_active() -> None: + provider = SnapshotProvider() + runtime = _runtime(provider) + provider.snapshot = _snapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0), + right=PicoControllerState(raw=True, grip=1.0, trigger=0.0), + ) + + runtime.tick(active=True, now_s=10.0) + + assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) + assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) + + +def test_runtime_opens_on_timeout_and_inactive_mode() -> None: + provider = SnapshotProvider() + runtime = _runtime(provider) + provider.snapshot = _snapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0), + right=PicoControllerState(raw=True, grip=1.0, trigger=1.0), + timestamp_s=10.0, + ) + + runtime.tick(active=True, now_s=10.0) + assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) + + provider.snapshot = SimpleNamespace(timestamp_s=9.0, seq=2, left=None, right=None) + runtime.tick(active=True, now_s=20.0) + assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) + + provider.snapshot = _snapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0), + right=PicoControllerState(raw=True, grip=1.0, trigger=1.0), + timestamp_s=20.1, + ) + runtime.tick(active=True, now_s=20.1) + assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) + + runtime.tick(active=False, now_s=20.2) + assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) + + +def test_pose_sender_cleans_up_partial_start_failure(monkeypatch) -> None: + created_hands = [] + + class FakeInnerHand: + def __init__(self) -> None: + 
self.close_calls = 0 + + def close(self) -> None: + self.close_calls += 1 + + class FakeLinkerHandApi: + def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: + del hand_joint, modbus, can + if hand_type == "right": + raise RuntimeError("right hand failed") + self.hand = FakeInnerHand() + self.close_can_calls = 0 + created_hands.append(self) + + def set_speed(self, speed: list[int]) -> None: + self.speed = speed + + def close_can(self) -> None: + self.close_can_calls += 1 + + fake_module = SimpleNamespace(LinkerHandApi=FakeLinkerHandApi) + monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) + + cfg = parse_linkerhand_config( + { + "dexterous_hand": { + "enabled": True, + "dry_run": False, + "hand_type": "both", + } + } + ) + sender = L6PoseSender(cfg) + + with pytest.raises(RuntimeError, match="right hand failed"): + sender.start() + + assert sender.started is False + assert sender._hands == {} + assert len(created_hands) == 1 + assert created_hands[0].close_can_calls == 1 + assert created_hands[0].hand.close_calls == 1 + + +def test_pose_sender_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: + created_hands = [] + + class FakeInnerHand: + def __init__(self) -> None: + self.close_calls = 0 + + def close(self) -> None: + self.close_calls += 1 + + class FakeLinkerHandApi: + def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: + del hand_joint, modbus, can + if hand_type == "right": + raise SystemExit(1) + self.hand = FakeInnerHand() + self.close_can_calls = 0 + created_hands.append(self) + + def set_speed(self, speed: list[int]) -> None: + self.speed = speed + + def close_can(self) -> None: + self.close_can_calls += 1 + + fake_module = SimpleNamespace(LinkerHandApi=FakeLinkerHandApi) + monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) + + cfg = parse_linkerhand_config( + { + "dexterous_hand": { + "enabled": True, + "dry_run": False, + 
"hand_type": "both", + } + } + ) + sender = L6PoseSender(cfg) + + with pytest.raises(RuntimeError, match="LinkerHand SDK exited during startup"): + sender.start() + + assert sender.started is False + assert sender._hands == {} + assert len(created_hands) == 1 + assert created_hands[0].close_can_calls == 1 + assert created_hands[0].hand.close_calls == 1 diff --git a/tests/test_mocap_mujoco.py b/tests/test_mocap_mujoco.py index e0a077c7..8d128641 100644 --- a/tests/test_mocap_mujoco.py +++ b/tests/test_mocap_mujoco.py @@ -1,5 +1,8 @@ from __future__ import annotations +import threading +from types import SimpleNamespace + import mujoco import numpy as np import pytest @@ -157,3 +160,28 @@ def test_pico4_provider_exposes_mocap_skeleton_metadata() -> None: assert pico4.bone_names == list(BODY_JOINT_NAMES) np.testing.assert_array_equal(pico4.bone_parents, BODY_JOINT_PARENTS) + + +def test_pico4_provider_exposes_controller_snapshot() -> None: + pico4 = object.__new__(Pico4InputProvider) + pico4._lock = threading.Lock() + pico4._controller_snapshot = None + pico4._last_source_seq = None + + frame = SimpleNamespace( + seq=42, + controllers=SimpleNamespace( + left=SimpleNamespace(raw=True, axis={"grip": 0.75, "trigger": 0.25}), + right=SimpleNamespace(raw=False, axis={}), + ), + ) + pico4._accept_controller_snapshot(frame, timestamp=12.5) + + snapshot = pico4.get_controller_snapshot() + assert snapshot is not None + assert snapshot.seq == 42 + assert snapshot.timestamp_s == pytest.approx(12.5) + assert snapshot.left.raw is True + assert snapshot.left.grip == pytest.approx(0.75) + assert snapshot.left.trigger == pytest.approx(0.25) + assert snapshot.right.raw is False diff --git a/tests/test_pico4_provider.py b/tests/test_pico4_provider.py index 4dcde8e5..7830eded 100644 --- a/tests/test_pico4_provider.py +++ b/tests/test_pico4_provider.py @@ -201,6 +201,21 @@ def test_pico4_provider_exposes_pause_control_events_once() -> None: assert packet.control_events == () +def 
test_pico4_provider_marks_controller_present_without_raw_field() -> None: + provider = _make_provider() + frame = _pico_frame(_body_poses(1.0), seq=1, timestamp=1.0) + frame.controllers.left.axis = {"grip": 0.8, "trigger": 0.4} + + assert provider._accept_pico_frame(frame) is True + + snapshot = provider.get_controller_snapshot() + assert snapshot is not None + assert snapshot.left.present is True + assert snapshot.left.raw is False + assert snapshot.left.grip == pytest.approx(0.8) + assert snapshot.left.trigger == pytest.approx(0.4) + + def test_pico4_provider_reads_pause_control_events_when_body_inactive() -> None: provider = _make_provider() diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index 80963eb1..f38cf37a 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -138,6 +138,21 @@ def build( return np.zeros(self.total_obs_size, dtype=np.float32) +class DummyHandRuntime: + def __init__(self) -> None: + self.active_flags: list[bool] = [] + self.close_calls = 0 + + def start(self) -> None: + pass + + def tick(self, *, active: bool) -> None: + self.active_flags.append(active) + + def close(self) -> None: + self.close_calls += 1 + + def _make_cfg(transition_duration: float = 1.0) -> dict[str, object]: return { "policy_hz": 50.0, @@ -270,6 +285,33 @@ def test_state_machine_allows_mocap_reentry_after_returning_to_standing(monkeypa assert ctrl.mode == RobotMode.MOCAP +def test_dexterous_hand_ticks_only_during_active_mocap(monkeypatch) -> None: + import teleopit.sim2real.controller as controller_mod + from teleopit.runtime.mocap_session import MocapSessionState + from teleopit.sim2real.controller import RobotMode, Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + hand_runtime = DummyHandRuntime() + _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) + monkeypatch.setattr(controller_mod, 
"build_linkerhand_runtime", lambda _cfg, _provider: hand_runtime) + + ctrl = Sim2RealController(_make_cfg()) + + ctrl.mode = RobotMode.STANDING + ctrl._tick_dexterous_hand() + + ctrl.mode = RobotMode.MOCAP + ctrl._mocap_session.reset() + assert ctrl._mocap_session.state == MocapSessionState.ACTIVE + ctrl._tick_dexterous_hand() + + ctrl._mocap_session.pause(np.zeros(36, dtype=np.float64)) + ctrl._tick_dexterous_hand() + + assert hand_runtime.active_flags == [False, True, False] + + def test_can_switch_to_mocap_returns_false_without_blocking_when_realtime_has_no_frame(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController diff --git a/third_party/linkerhand-python-sdk b/third_party/linkerhand-python-sdk new file mode 160000 index 00000000..d884a720 --- /dev/null +++ b/third_party/linkerhand-python-sdk @@ -0,0 +1 @@ +Subproject commit d884a72081539bb159855f54945e497eedefd31a From 5b22280042f3a37446510491c88a83cde0565afd Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 13 May 2026 21:42:35 +0800 Subject: [PATCH 010/122] Remove LinkerHand dry run mode --- docs/docs/tutorials/pico-sim2real.md | 21 ++- .../current/tutorials/pico-sim2real.md | 20 +-- scripts/dev/test_linkerhand_l6.py | 163 ++++++++++++++++++ teleopit/configs/pico4_sim2real.yaml | 1 - teleopit/configs/sim2real.yaml | 1 - teleopit/sim2real/dexterous_hand.py | 23 +-- tests/test_dexterous_hand.py | 67 ++++--- 7 files changed, 234 insertions(+), 62 deletions(-) create mode 100644 scripts/dev/test_linkerhand_l6.py diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index fb9c0866..fb7cb9a3 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -152,21 +152,20 @@ git submodule update --init third_party/linkerhand-python-sdk pip install -e third_party/linkerhand-python-sdk ``` -Dry-run before connecting CAN: +Before enabling full sim2real, verify the hand connection with a standalone +open/close test: 
```bash -python scripts/run/run_sim2real.py \ - --config-name pico4_sim2real \ - controller.policy_path=track.onnx \ - dexterous_hand.enabled=true \ - dexterous_hand.dry_run=true +python scripts/dev/test_linkerhand_l6.py \ + --hand-type both \ + --left-can can0 \ + --right-can can1 ``` -For real hands, configure the CAN channels and disable dry-run: +Then enable L6 control in Pico sim2real: ```bash dexterous_hand.enabled=true -dexterous_hand.dry_run=false dexterous_hand.left_can=can0 dexterous_hand.right_can=can1 ``` @@ -210,8 +209,8 @@ pause_resume_warmup_steps=2 # Change Pico pause button input.pause_button=right_axis_click -# Enable LinkerHand L6 dry-run -dexterous_hand.enabled=true dexterous_hand.dry_run=true +# Enable LinkerHand L6 control +dexterous_hand.enabled=true # Enable headset video preview input.video.enabled=true @@ -226,5 +225,5 @@ input.video.enabled=true | Cannot enter debug mode | Unitree mode release failed | Stop other robot modes and press `Start` again | | Robot enters `STANDING` but not `MOCAP` | Mocap validation failed | Keep tracking active and stable; check `mocap_switch.check_frames` logs | | Pico pause does not return to `STANDING` | Expected behavior | Pico pause freezes mocap; press remote `X` for `STANDING` | -| LinkerHand does not move | Not in `MOCAP`, deadman grip released, SDK not installed, or CAN channel wrong | Enter `MOCAP`, hold the matching side grip, verify `pip install -e third_party/linkerhand-python-sdk`, and check `dexterous_hand.left_can` / `right_can` | +| LinkerHand does not move | Not in `MOCAP`, deadman grip released, SDK not installed, or CAN channel wrong | Enter `MOCAP`, hold the matching side grip, run `scripts/dev/test_linkerhand_l6.py`, and check `dexterous_hand.left_can` / `right_can` | | Video preview is unavailable | RealSense or video source failed | Check camera permissions, `input.video.source`, and logs | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md 
b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index b25daff6..5f8a23cd 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -143,21 +143,19 @@ git submodule update --init third_party/linkerhand-python-sdk pip install -e third_party/linkerhand-python-sdk ``` -连接 CAN 前先 dry-run: +启用完整 sim2real 前,先用独立开合测试验证灵巧手连接: ```bash -python scripts/run/run_sim2real.py \ - --config-name pico4_sim2real \ - controller.policy_path=track.onnx \ - dexterous_hand.enabled=true \ - dexterous_hand.dry_run=true +python scripts/dev/test_linkerhand_l6.py \ + --hand-type both \ + --left-can can0 \ + --right-can can1 ``` -使用真实手时,配置 CAN 通道并关闭 dry-run: +然后在 Pico sim2real 中启用 L6 控制: ```bash dexterous_hand.enabled=true -dexterous_hand.dry_run=false dexterous_hand.left_can=can0 dexterous_hand.right_can=can1 ``` @@ -201,8 +199,8 @@ pause_resume_warmup_steps=2 # 更换 Pico 暂停键 input.pause_button=right_axis_click -# 开启 LinkerHand L6 dry-run -dexterous_hand.enabled=true dexterous_hand.dry_run=true +# 开启 LinkerHand L6 控制 +dexterous_hand.enabled=true # 开启头显视频预览 input.video.enabled=true @@ -217,5 +215,5 @@ input.video.enabled=true | 无法进入 debug mode | Unitree mode 释放失败 | 停止其他机器人模式后再次按 `Start` | | 机器人进入 `STANDING` 但不进入 `MOCAP` | 动捕验证失败 | 保持追踪稳定,查看 `mocap_switch.check_frames` 日志 | | Pico 暂停没有返回 `STANDING` | 这是预期行为 | Pico 暂停只冻结 mocap;按遥控器 `X` 返回 `STANDING` | -| LinkerHand 不动 | 不在 `MOCAP`、deadman grip 未按住、SDK 未安装,或 CAN 通道错误 | 进入 `MOCAP`,按住同侧 grip,确认已执行 `pip install -e third_party/linkerhand-python-sdk`,并检查 `dexterous_hand.left_can` / `right_can` | +| LinkerHand 不动 | 不在 `MOCAP`、deadman grip 未按住、SDK 未安装,或 CAN 通道错误 | 进入 `MOCAP`,按住同侧 grip,运行 `scripts/dev/test_linkerhand_l6.py`,并检查 `dexterous_hand.left_can` / `right_can` | | 视频预览不可用 | RealSense 或视频源失败 | 检查相机权限、`input.video.source` 和日志 | diff --git a/scripts/dev/test_linkerhand_l6.py 
b/scripts/dev/test_linkerhand_l6.py new file mode 100644 index 00000000..2f0e8a4c --- /dev/null +++ b/scripts/dev/test_linkerhand_l6.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Exercise LinkerHand L6 open/close motion to verify hardware connectivity.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +import sys +import time +from typing import Sequence + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SDK_PATH = REPO_ROOT / "third_party" / "linkerhand-python-sdk" +if SDK_PATH.exists(): + sys.path.insert(0, str(SDK_PATH)) + + +THUMB_YAW_DEFAULT = 10 +OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] +CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] +DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] + + +def uint8(value: str) -> int: + parsed = int(value) + if parsed < 0 or parsed > 255: + raise argparse.ArgumentTypeError("value must be in range 0-255") + return parsed + + +def positive_float(value: str) -> float: + parsed = float(value) + if parsed <= 0.0: + raise argparse.ArgumentTypeError("value must be greater than 0") + return parsed + + +def positive_int(value: str) -> int: + parsed = int(value) + if parsed <= 0: + raise argparse.ArgumentTypeError("value must be greater than 0") + return parsed + + +def selected_hand_types(hand_type: str) -> tuple[str, ...]: + if hand_type == "both": + return ("left", "right") + return (hand_type,) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Test LinkerHand L6 open/close motion") + parser.add_argument("--hand-type", choices=["left", "right", "both"], default="both") + parser.add_argument("--left-can", default="can0") + parser.add_argument("--right-can", default="can1") + parser.add_argument( + "--modbus", + default="None", + help='RS485 serial port such as /dev/ttyUSB0; "None" uses CAN', + ) + parser.add_argument("--cycles", type=positive_int, default=3) + parser.add_argument("--hold-s", type=positive_float, default=1.0) + 
parser.add_argument("--thumb-yaw-center", type=uint8, default=THUMB_YAW_DEFAULT) + parser.add_argument( + "--speed", + type=uint8, + nargs=6, + default=DEFAULT_SPEED, + metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), + ) + parser.add_argument( + "--open-pose", + type=uint8, + nargs=6, + default=OPEN_POSE, + metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), + ) + parser.add_argument( + "--close-pose", + type=uint8, + nargs=6, + default=CLOSE_POSE, + metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), + ) + args = parser.parse_args() + args.open_pose[1] = args.thumb_yaw_center + args.close_pose[1] = args.thumb_yaw_center + return args + + +def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> None: + print(f"{label}: {list(pose)}", flush=True) + for hand_type, hand in hands.items(): + print(f" {hand_type}", flush=True) + hand.finger_move(pose=list(pose)) + + +def close_all(hands: dict[str, object]) -> None: + for hand in hands.values(): + close_can = getattr(hand, "close_can", None) + if callable(close_can): + close_can() + inner_hand = getattr(hand, "hand", None) + close = getattr(inner_hand, "close", None) + if callable(close): + close() + + +def main() -> None: + args = parse_args() + try: + from LinkerHand.linker_hand_api import LinkerHandApi + except ImportError as exc: + raise SystemExit( + "LinkerHand SDK import failed. 
Run: " + "git submodule update --init third_party/linkerhand-python-sdk && " + "pip install -e third_party/linkerhand-python-sdk" + ) from exc + + hand_types = selected_hand_types(args.hand_type) + can_channels = {"left": args.left_can, "right": args.right_can} + hands: dict[str, object] = {} + + print( + "Testing LinkerHand L6 | " + f"hands={','.join(hand_types)} | " + f"can={','.join(f'{hand}:{can_channels[hand]}' for hand in hand_types)} | " + f"modbus={args.modbus}", + flush=True, + ) + try: + for hand_type in hand_types: + hand = LinkerHandApi( + hand_joint="L6", + hand_type=hand_type, + modbus=args.modbus, + can=can_channels[hand_type], + ) + hand.set_speed(speed=list(args.speed)) + hands[hand_type] = hand + + send_all(hands, args.open_pose, label="startup open") + time.sleep(args.hold_s) + for cycle in range(args.cycles): + print(f"cycle {cycle + 1}/{args.cycles}", flush=True) + send_all(hands, args.close_pose, label="close") + time.sleep(args.hold_s) + send_all(hands, args.open_pose, label="open") + time.sleep(args.hold_s) + except KeyboardInterrupt: + print("Interrupted; opening hands before exit", flush=True) + finally: + if hands: + try: + send_all(hands, args.open_pose, label="exit open") + time.sleep(0.2) + finally: + close_all(hands) + + +if __name__ == "__main__": + main() diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 081fc192..f85900e2 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -48,7 +48,6 @@ dexterous_hand: speed: [50, 50, 50, 50, 50, 50] open_pose: [250, 10, 250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] - dry_run: false print_input: false # Physical robot SDK configuration diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index ef31bc6d..27c99e39 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -43,7 +43,6 @@ dexterous_hand: speed: [50, 50, 50, 50, 50, 50] open_pose: [250, 10, 
250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] - dry_run: false print_input: false # Physical robot SDK configuration diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 9899caf0..658a32cc 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -40,7 +40,6 @@ class LinkerHandConfig: speed: tuple[int, ...] = tuple(DEFAULT_SPEED) open_pose: tuple[int, ...] = tuple(OPEN_POSE) close_pose: tuple[int, ...] = tuple(CLOSE_POSE) - dry_run: bool = False print_input: bool = False @property @@ -109,7 +108,6 @@ def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: speed=tuple(_pose_values(cfg_get(hand_cfg, "speed", DEFAULT_SPEED), "speed")), open_pose=tuple(open_pose), close_pose=tuple(close_pose), - dry_run=bool(cfg_get(hand_cfg, "dry_run", False)), print_input=bool(cfg_get(hand_cfg, "print_input", False)), ) if config.hand_joint != "L6": @@ -139,10 +137,6 @@ def started(self) -> bool: def start(self) -> None: if self._started: return - if self._config.dry_run: - logger.info("LinkerHand L6 dry-run enabled | hands=%s", ",".join(self._hand_types)) - self._started = True - return try: from LinkerHand.linker_hand_api import LinkerHandApi except ImportError as exc: @@ -167,8 +161,8 @@ def start(self) -> None: self._started = False raise RuntimeError( "LinkerHand SDK exited during startup. Check CAN interface configuration " - f"({', '.join(self._can_channels[hand_type] for hand_type in self._hand_types)}) " - "or use dexterous_hand.dry_run=true before connecting hardware." + f"({', '.join(self._can_channels[hand_type] for hand_type in self._hand_types)}). " + "Run scripts/dev/test_linkerhand_l6.py to verify the hand connection." 
) from exc except Exception: self._close_hands() @@ -182,14 +176,11 @@ def send(self, hand_type: str, pose: Sequence[int], *, force: bool = False, reas next_pose = [int(value) for value in pose] if not force and self._last_pose.get(hand_type) == next_pose: return - if self._config.dry_run: - suffix = f" ({reason})" if reason else "" - logger.info("dry-run LinkerHand %s pose%s: %s", hand_type, suffix, next_pose) - else: - hand = self._hands.get(hand_type) - if hand is None: - raise RuntimeError("L6PoseSender has not been started") - hand.finger_move(pose=next_pose) + del reason + hand = self._hands.get(hand_type) + if hand is None: + raise RuntimeError("L6PoseSender has not been started") + hand.finger_move(pose=next_pose) self._last_pose[hand_type] = next_pose def send_all(self, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 6878fc02..42e30ad5 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -14,6 +14,46 @@ ) +class FakeInnerHand: + def __init__(self) -> None: + self.close_calls = 0 + + def close(self) -> None: + self.close_calls += 1 + + +class FakeLinkerHandApi: + instances: list["FakeLinkerHandApi"] = [] + + def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: + self.hand_joint = hand_joint + self.hand_type = hand_type + self.modbus = modbus + self.can = can + self.hand = FakeInnerHand() + self.close_can_calls = 0 + self.speed: list[int] | None = None + self.poses: list[list[int]] = [] + FakeLinkerHandApi.instances.append(self) + + def set_speed(self, speed: list[int]) -> None: + self.speed = speed + + def finger_move(self, pose: list[int]) -> None: + self.poses.append(list(pose)) + + def close_can(self) -> None: + self.close_can_calls += 1 + + +@pytest.fixture(autouse=True) +def fake_linkerhand_sdk(monkeypatch): + FakeLinkerHandApi.instances = [] + fake_module = 
SimpleNamespace(LinkerHandApi=FakeLinkerHandApi) + monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) + yield + + class SnapshotProvider: def __init__(self) -> None: self.snapshot: PicoControllerSnapshot | None = None @@ -29,7 +69,7 @@ def _snapshot( timestamp_s: float = 10.0, seq: int = 1, ) -> PicoControllerSnapshot: - missing = PicoControllerState(raw=False, grip=0.0, trigger=0.0) + missing = PicoControllerState(raw=False, grip=0.0, trigger=0.0, present=False) return PicoControllerSnapshot( left=left or missing, right=right or missing, @@ -43,7 +83,6 @@ def _runtime(provider: SnapshotProvider) -> LinkerHandRuntime: { "dexterous_hand": { "enabled": True, - "dry_run": True, "hand_type": "both", } } @@ -138,14 +177,7 @@ def test_runtime_opens_on_timeout_and_inactive_mode() -> None: def test_pose_sender_cleans_up_partial_start_failure(monkeypatch) -> None: created_hands = [] - class FakeInnerHand: - def __init__(self) -> None: - self.close_calls = 0 - - def close(self) -> None: - self.close_calls += 1 - - class FakeLinkerHandApi: + class FailingLinkerHandApi: def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: del hand_joint, modbus, can if hand_type == "right": @@ -160,14 +192,13 @@ def set_speed(self, speed: list[int]) -> None: def close_can(self) -> None: self.close_can_calls += 1 - fake_module = SimpleNamespace(LinkerHandApi=FakeLinkerHandApi) + fake_module = SimpleNamespace(LinkerHandApi=FailingLinkerHandApi) monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) cfg = parse_linkerhand_config( { "dexterous_hand": { "enabled": True, - "dry_run": False, "hand_type": "both", } } @@ -187,14 +218,7 @@ def close_can(self) -> None: def test_pose_sender_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: created_hands = [] - class FakeInnerHand: - def __init__(self) -> None: - self.close_calls = 0 - - def close(self) -> None: - self.close_calls += 1 - - class FakeLinkerHandApi: + 
class ExitingLinkerHandApi: def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: del hand_joint, modbus, can if hand_type == "right": @@ -209,14 +233,13 @@ def set_speed(self, speed: list[int]) -> None: def close_can(self) -> None: self.close_can_calls += 1 - fake_module = SimpleNamespace(LinkerHandApi=FakeLinkerHandApi) + fake_module = SimpleNamespace(LinkerHandApi=ExitingLinkerHandApi) monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) cfg = parse_linkerhand_config( { "dexterous_hand": { "enabled": True, - "dry_run": False, "hand_type": "both", } } From 9e3507c839803d0c4a21d92e7a7f2672be6017ef Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 13 May 2026 21:55:55 +0800 Subject: [PATCH 011/122] Adjust LinkerHand L6 test cleanup --- scripts/dev/test_linkerhand_l6.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py index 2f0e8a4c..643e8bb7 100644 --- a/scripts/dev/test_linkerhand_l6.py +++ b/scripts/dev/test_linkerhand_l6.py @@ -96,17 +96,6 @@ def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> No hand.finger_move(pose=list(pose)) -def close_all(hands: dict[str, object]) -> None: - for hand in hands.values(): - close_can = getattr(hand, "close_can", None) - if callable(close_can): - close_can() - inner_hand = getattr(hand, "hand", None) - close = getattr(inner_hand, "close", None) - if callable(close): - close() - - def main() -> None: args = parse_args() try: @@ -152,11 +141,12 @@ def main() -> None: print("Interrupted; opening hands before exit", flush=True) finally: if hands: - try: - send_all(hands, args.open_pose, label="exit open") - time.sleep(0.2) - finally: - close_all(hands) + send_all(hands, args.open_pose, label="exit open") + time.sleep(0.2) + print( + "Exit cleanup intentionally leaves CAN interfaces up to avoid SDK network-down noise.", + flush=True, + ) 
if __name__ == "__main__": From 07d9d0bcf14b158a12752249dd27dd0b1cd15373 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 13 May 2026 22:27:39 +0800 Subject: [PATCH 012/122] Keep LinkerHand CAN interfaces open on close --- teleopit/sim2real/dexterous_hand.py | 3 --- tests/test_dexterous_hand.py | 22 ++++++++++++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 658a32cc..753dded2 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -201,9 +201,6 @@ def close(self) -> None: def _close_hands(self) -> None: for hand in self._hands.values(): - close_can = getattr(hand, "close_can", None) - if callable(close_can): - close_can() inner_hand = getattr(hand, "hand", None) close = getattr(inner_hand, "close", None) if callable(close): diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 42e30ad5..2d6dfb92 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -174,6 +174,24 @@ def test_runtime_opens_on_timeout_and_inactive_mode() -> None: assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) +def test_pose_sender_close_leaves_can_interfaces_up() -> None: + cfg = parse_linkerhand_config( + { + "dexterous_hand": { + "enabled": True, + "hand_type": "both", + } + } + ) + sender = L6PoseSender(cfg) + sender.start() + + sender.close() + + assert [hand.close_can_calls for hand in FakeLinkerHandApi.instances] == [0, 0] + assert [hand.hand.close_calls for hand in FakeLinkerHandApi.instances] == [1, 1] + + def test_pose_sender_cleans_up_partial_start_failure(monkeypatch) -> None: created_hands = [] @@ -211,7 +229,7 @@ def close_can(self) -> None: assert sender.started is False assert sender._hands == {} assert len(created_hands) == 1 - assert created_hands[0].close_can_calls == 1 + assert created_hands[0].close_can_calls == 0 assert created_hands[0].hand.close_calls 
== 1 @@ -252,5 +270,5 @@ def close_can(self) -> None: assert sender.started is False assert sender._hands == {} assert len(created_hands) == 1 - assert created_hands[0].close_can_calls == 1 + assert created_hands[0].close_can_calls == 0 assert created_hands[0].hand.close_calls == 1 From 7c74e84baaddb83dbce3ab4544650ff206304a98 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 13 May 2026 23:33:09 +0800 Subject: [PATCH 013/122] Tune Pico4 sim2real realtime buffering --- teleopit/configs/pico4_sim2real.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index f85900e2..ce164da3 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -7,21 +7,21 @@ defaults: policy_hz: 50.0 transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 2 # realtime mocap frames to accumulate before resuming live tracking +pause_resume_warmup_steps: 0 # realtime mocap frames to accumulate before resuming live tracking input: video: source: realsense retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 -retarget_buffer_delay_s: 0.05 +retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling realtime_buffer_low_watermark_steps: 2 realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 -realtime_catchup_enabled: true +realtime_catchup_enabled: false realtime_catchup_trigger_steps: 6 realtime_catchup_release_steps: 3 realtime_catchup_target_delay_s: 0.04 -reference_qpos_smoothing_alpha: 0.4 +reference_qpos_smoothing_alpha: 1.0 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] From bba22ddbdf02a5a5f6322daad857f91992d4b599 Mon Sep 17 00:00:00 
2001 From: Wu Bingqian Date: Thu, 14 May 2026 12:22:29 +0800 Subject: [PATCH 014/122] Decouple LinkerHand control from body loop --- teleopit/sim2real/controller.py | 3 +- teleopit/sim2real/dexterous_hand.py | 129 +++++++++++++++++++++++++++- tests/test_dexterous_hand.py | 11 +++ tests/test_sim2real_runtime.py | 26 ++++++ 4 files changed, 166 insertions(+), 3 deletions(-) diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index c29ac0eb..0301d65b 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -891,8 +891,7 @@ def _tick_dexterous_hand(self) -> None: try: self._hand_runtime.tick(active=active) except Exception: - logger.exception("Dexterous hand runtime failed -- entering damping") - self._enter_damping() + logger.exception("Dexterous hand runtime failed; body control continues") def _deactivate_dexterous_hand(self) -> None: try: diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 753dded2..98949d3b 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -4,6 +4,7 @@ from dataclasses import dataclass import logging +import threading import time from typing import Any, Protocol, Sequence @@ -208,13 +209,139 @@ def _close_hands(self) -> None: self._hands.clear() +class AsyncL6PoseSender: + """Run blocking LinkerHand SDK calls outside the robot control loop.""" + + def __init__(self, config: LinkerHandConfig): + self._config = config + self._sync_sender = L6PoseSender(config) + self._condition = threading.Condition() + self._pending: dict[str, tuple[list[int], bool, str]] = {} + self._thread: threading.Thread | None = None + self._running = False + self._stopping = False + self._busy = False + self._failed = False + + @property + def started(self) -> bool: + return self._running and not self._failed + + @property + def _last_pose(self) -> dict[str, list[int] | None]: + return self._sync_sender._last_pose + + def start(self) -> 
None: + with self._condition: + if self._running: + return + self._running = True + self._stopping = False + self._failed = False + self._busy = True + self._thread = threading.Thread( + target=self._run, + name="linkerhand-l6-sender", + daemon=True, + ) + self._thread.start() + + def send(self, hand_type: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + next_pose = [int(value) for value in pose] + if not force and self._sync_sender._last_pose.get(hand_type) == next_pose: + return + with self._condition: + if not self._running or self._failed or self._stopping: + return + self._pending[hand_type] = (next_pose, force, reason) + self._condition.notify_all() + + def send_all(self, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + for hand_type in self._config.selected_hand_types: + self.send(hand_type, pose, force=force, reason=reason) + + def close(self) -> None: + thread: threading.Thread | None + with self._condition: + if not self._running: + return + if not self._failed: + for hand_type in self._config.selected_hand_types: + self._pending[hand_type] = (list(self._config.open_pose), True, "exit") + self._stopping = True + self._condition.notify_all() + thread = self._thread + if thread is not None: + thread.join(timeout=3.0) + if thread.is_alive(): + logger.warning("LinkerHand L6 worker did not stop within timeout") + + def wait_idle(self, timeout_s: float = 1.0) -> bool: + deadline = time.monotonic() + timeout_s + with self._condition: + while self._busy or self._pending: + remaining = deadline - time.monotonic() + if remaining <= 0.0: + return False + self._condition.wait(timeout=remaining) + return True + + def _run(self) -> None: + try: + self._sync_sender.start() + self._sync_sender.send_all(self._config.open_pose, force=True, reason="startup") + while True: + commands = self._take_commands() + if not commands: + break + try: + for hand_type, pose, force, reason in commands: + 
self._sync_sender.send(hand_type, pose, force=force, reason=reason) + finally: + with self._condition: + self._busy = False + self._condition.notify_all() + except Exception: + logger.exception("LinkerHand L6 worker failed; hand control is disabled") + with self._condition: + self._failed = True + self._pending.clear() + self._busy = False + self._condition.notify_all() + finally: + try: + self._sync_sender.close() + except Exception: + logger.exception("Failed to close LinkerHand L6 worker cleanly") + with self._condition: + self._running = False + self._busy = False + self._condition.notify_all() + + def _take_commands(self) -> list[tuple[str, list[int], bool, str]]: + with self._condition: + while not self._pending and not self._stopping: + self._busy = False + self._condition.notify_all() + self._condition.wait() + if not self._pending and self._stopping: + return [] + self._busy = True + commands = [ + (hand_type, pose, force, reason) + for hand_type, (pose, force, reason) in self._pending.items() + ] + self._pending.clear() + return commands + + class LinkerHandRuntime: """Drive LinkerHand L6 from Pico controller grip/trigger snapshots.""" def __init__(self, config: LinkerHandConfig, provider: ControllerSnapshotProvider): self.config = config self._provider = provider - self._sender = L6PoseSender(config) + self._sender = AsyncL6PoseSender(config) self._interval_s = 1.0 / config.rate self._next_tick_s = 0.0 self._active = False diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 2d6dfb92..de39e01c 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -92,6 +92,10 @@ def _runtime(provider: SnapshotProvider) -> LinkerHandRuntime: return runtime +def _wait_runtime_idle(runtime: LinkerHandRuntime) -> None: + assert runtime._sender.wait_idle(timeout_s=1.0) + + def test_trigger_to_pose_applies_deadzone_and_fixed_thumb_yaw() -> None: pose = trigger_to_pose( 0.5, @@ -113,6 +117,7 @@ def 
test_runtime_opens_when_deadman_released() -> None: ) runtime.tick(active=True, now_s=10.0) + _wait_runtime_idle(runtime) assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) @@ -127,6 +132,7 @@ def test_runtime_maps_present_controller_even_without_raw_flag() -> None: ) runtime.tick(active=True, now_s=10.0) + _wait_runtime_idle(runtime) assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) @@ -141,6 +147,7 @@ def test_runtime_maps_trigger_when_deadman_active() -> None: ) runtime.tick(active=True, now_s=10.0) + _wait_runtime_idle(runtime) assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) @@ -156,10 +163,12 @@ def test_runtime_opens_on_timeout_and_inactive_mode() -> None: ) runtime.tick(active=True, now_s=10.0) + _wait_runtime_idle(runtime) assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) provider.snapshot = SimpleNamespace(timestamp_s=9.0, seq=2, left=None, right=None) runtime.tick(active=True, now_s=20.0) + _wait_runtime_idle(runtime) assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) provider.snapshot = _snapshot( @@ -168,9 +177,11 @@ def test_runtime_opens_on_timeout_and_inactive_mode() -> None: timestamp_s=20.1, ) runtime.tick(active=True, now_s=20.1) + _wait_runtime_idle(runtime) assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) runtime.tick(active=False, now_s=20.2) + _wait_runtime_idle(runtime) assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index f38cf37a..8ec74241 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -153,6 +153,12 @@ def 
close(self) -> None: self.close_calls += 1 +class FailingHandRuntime(DummyHandRuntime): + def tick(self, *, active: bool) -> None: + super().tick(active=active) + raise RuntimeError("hand send failed") + + def _make_cfg(transition_duration: float = 1.0) -> dict[str, object]: return { "policy_hz": 50.0, @@ -312,6 +318,26 @@ def test_dexterous_hand_ticks_only_during_active_mocap(monkeypatch) -> None: assert hand_runtime.active_flags == [False, True, False] +def test_dexterous_hand_failure_does_not_enter_damping(monkeypatch) -> None: + import teleopit.sim2real.controller as controller_mod + from teleopit.sim2real.controller import RobotMode, Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + hand_runtime = FailingHandRuntime() + _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) + monkeypatch.setattr(controller_mod, "build_linkerhand_runtime", lambda _cfg, _provider: hand_runtime) + + ctrl = Sim2RealController(_make_cfg()) + ctrl.mode = RobotMode.MOCAP + ctrl._mocap_session.reset() + + ctrl._tick_dexterous_hand() + + assert ctrl.mode == RobotMode.MOCAP + assert hand_runtime.active_flags == [True] + + def test_can_switch_to_mocap_returns_false_without_blocking_when_realtime_has_no_frame(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController From 11c12b91d49a1687448d7352f8164da77dea6dac Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 14 May 2026 14:56:32 +0800 Subject: [PATCH 015/122] Remove realtime resume smoothing knobs --- AGENTS.md | 10 +- docs/docs/configuration/config-reference.md | 17 -- docs/docs/tutorials/bvh-sim2real.md | 3 - docs/docs/tutorials/pico-sim2real.md | 3 - .../current/configuration/config-reference.md | 17 -- .../current/tutorials/bvh-sim2real.md | 3 - .../current/tutorials/pico-sim2real.md | 3 - teleopit/configs/online.yaml | 5 - teleopit/configs/pico4_sim.yaml | 4 - teleopit/configs/pico4_sim2real.yaml | 
9 - teleopit/configs/sim2real.yaml | 5 - teleopit/controllers/qpos_interpolator.py | 28 --- teleopit/runtime/factory.py | 11 -- teleopit/runtime/reference_config.py | 56 ------ teleopit/sim/loop.py | 23 +-- teleopit/sim/realtime_utils.py | 165 +----------------- teleopit/sim/runtime_components.py | 7 +- teleopit/sim/session.py | 25 +-- teleopit/sim2real/controller.py | 40 +---- teleopit/sim2real/reference_processor.py | 7 - tests/test_pipeline.py | 6 - tests/test_realtime_utils.py | 76 +------- tests/test_sim2real_runtime.py | 15 +- tests/test_sim_loop.py | 8 +- 24 files changed, 30 insertions(+), 516 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ed756a36..1dfcbccb 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -122,7 +122,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos ### Offline Playback - Offline sim2sim and default sim2real both read `input.bvh_file` directly; no UDP relay path remains - Offline sim2sim playback can be keyboard-controlled: `Space/P` pause/resume, `R` replay from frame 0, `Q` stop -- Pause/resume now includes a short hold window; users should stay still during resume and pause again if visible distortion appears +- Offline pause holds the commanded pose; resume resets policy/reference state and uses `transition_duration` for the playback blend - sim2sim keyboard playback is optional via `playback.keyboard.enabled=true` - sim2real reuses the Unitree remote: `Start` → `STANDING`, `Y` → playback, `X` → back to `STANDING`, `L1+R1` → `DAMPING` - `playback.pause_on_end=true` keeps the final pose and waits for manual replay @@ -139,7 +139,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Pico sim2sim supports a keyboard-driven top-level mode state machine: `STANDING → MOCAP → STANDING` - Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `X` → back to `STANDING`, `Q` → quit - Pico4 sim2real pause/resume is handled as a mocap-session control event 
(`toggle_pause`), not as a mode switch to `STANDING` -- Default Pico pause button is `A`; restore tracking by rebuilding the realtime buffer and yaw/XY root-offset alignment, then accept the current live mocap retarget reference directly +- Default Pico pause button is `A`; resume rebuilds the realtime buffer and yaw/XY root-offset alignment, then waits for the configured realtime warmup before tracking continues - Optional LinkerHand L6 control uses `third_party/linkerhand-python-sdk` and `dexterous_hand.enabled=true` - LinkerHand control reuses `Pico4InputProvider.get_controller_snapshot()`; do not start a second `PicoBridge` for hand control - LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, frame timeout, and shutdown must send the configured open pose @@ -149,10 +149,10 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - `num_steps=0` means infinite loop (`max_steps = 2**63`) - `KeyboardInterrupt` is handled for clean shutdown - BVH frame alignment is time-based: `bvh_idx = int(policy_time × input_fps)` -- Realtime reference buffering is controlled by `retarget_buffer_enabled`, `retarget_buffer_window_s`, `retarget_buffer_delay_s`, `reference_steps`, `realtime_buffer_warmup_steps`, and the low/high watermark knobs +- Realtime reference buffering is controlled by `retarget_buffer_enabled`, `retarget_buffer_window_s`, `retarget_buffer_delay_s`, `reference_steps`, and `realtime_buffer_warmup_steps` - Realtime inferred `motion_joint_vel`, anchor linear velocity, and anchor angular velocity can be EMA-smoothed via `reference_velocity_smoothing_alpha` and `reference_anchor_velocity_smoothing_alpha` -- Sim2real Pico pause/resume uses mocap-session states `ACTIVE ↔ PAUSED`; resume clears policy/reference state, warms the realtime buffer, rebuilds yaw/XY root alignment, and does not interpolate retarget qpos from the paused pose -- Realtime sim2sim with Pico control events uses the same mocap-session 
pause/resume semantics and rebuilds the realtime reference path on resume +- Sim2real Pico pause/resume uses mocap-session states `ACTIVE ↔ PAUSED`; resume clears policy/reference state, rebuilds yaw/XY root alignment, warms the realtime buffer, and does not interpolate retarget qpos from the paused pose +- Realtime sim2sim with Pico control events uses the same mocap-session pause/resume semantics and rebuilds the realtime reference path on resume, including the configured warmup - Realtime Pico sim2sim can start directly in `STANDING` with keyboard mode control enabled via top-level `keyboard.enabled` ### Inference Observation diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 76ae83f6..5c7bfc95 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -84,8 +84,6 @@ Complete reference for all configurable fields. | `retarget_buffer_delay_s` | Buffer delay | | `reference_steps` | Reference window steps | | `realtime_buffer_warmup_steps` | Warmup before playback | -| `realtime_buffer_low_watermark_steps` | Low watermark | -| `realtime_buffer_high_watermark_steps` | High watermark | | `reference_velocity_smoothing_alpha` | Velocity smoothing | | `reference_anchor_velocity_smoothing_alpha` | Anchor velocity smoothing | @@ -116,11 +114,6 @@ Fields used by sim2real configs (`sim2real.yaml`, `pico4_sim2real.yaml`). ### Pause/Resume (Pico sim2real) -| Field | Description | Default | -|-------|-------------|---------| -| `pause_resume_transition_duration` | Resume blend duration for offline playback; realtime Pico resume uses live tracking re-centering | `1.0` | -| `pause_resume_warmup_steps` | Realtime mocap frames to collect before tracking resumes | `2` | - Realtime Pico resume re-centers heading and ground-plane position before tracking continues. Operators should keep still and stay as close as practical to the paused pose to reduce sudden reference changes. 
### Dexterous Hand (Pico sim2real) @@ -140,16 +133,6 @@ timeouts send the open pose. | `dexterous_hand.trigger_deadzone` | Trigger deadzone at both ends | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | -### Realtime Catch-up (Pico sim2real) - -| Field | Description | Default | -|-------|-------------|---------| -| `realtime_catchup_enabled` | Enable catch-up when buffer grows too large | `true` | -| `realtime_catchup_trigger_steps` | Buffer depth that triggers catch-up | `6` | -| `realtime_catchup_release_steps` | Buffer depth to release catch-up | `3` | -| `realtime_catchup_target_delay_s` | Target delay for catch-up | `0.04` | -| `reference_qpos_smoothing_alpha` | Joint position smoothing (1.0 = no smoothing) | `0.4` | - ## Critical: `default_dof_pos` The RL policy outputs action **offsets** relative to the default standing pose, not absolute joint angles: diff --git a/docs/docs/tutorials/bvh-sim2real.md b/docs/docs/tutorials/bvh-sim2real.md index 18741b2a..884c4d23 100644 --- a/docs/docs/tutorials/bvh-sim2real.md +++ b/docs/docs/tutorials/bvh-sim2real.md @@ -79,9 +79,6 @@ playback.pause_on_end=true # Smooth transition from standing/current robot state into playback transition_duration=2.0 -# Resume blend for offline playback -pause_resume_transition_duration=1.0 - # Control loop rate policy_hz=50 ``` diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index fb7cb9a3..77a80a02 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -203,9 +203,6 @@ mocap_switch.check_frames=10 # Smooth transition into mocap reference transition_duration=2.0 -# Realtime frames to collect before resume -pause_resume_warmup_steps=2 - # Change Pico pause button input.pause_button=right_axis_click diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md 
b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index a84a57ba..6bb02e7a 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -103,8 +103,6 @@ target = clip(action, clip_range) * action_scale + default_dof_pos | `retarget_buffer_delay_s` | 缓冲延迟 | | `reference_steps` | 参考轨迹窗口步数 | | `realtime_buffer_warmup_steps` | 播放前预热帧数 | -| `realtime_buffer_low_watermark_steps` | 低水位线 | -| `realtime_buffer_high_watermark_steps` | 高水位线 | | `reference_velocity_smoothing_alpha` | 速度平滑系数 | | `reference_anchor_velocity_smoothing_alpha` | 锚点速度平滑系数 | @@ -135,11 +133,6 @@ target = clip(action, clip_range) * action_scale + default_dof_pos ### 暂停/恢复(Pico sim2real) -| 字段 | 说明 | 默认值 | -|---|---|---| -| `pause_resume_transition_duration` | 离线回放恢复时的混合时长;实时 Pico 恢复使用实时追踪重新居中 | `1.0` | -| `pause_resume_warmup_steps` | 恢复追踪前采集的实时动捕帧数 | `2` | - 实时 Pico 恢复追踪时会先重新居中航向和地面平面位置。操作者应保持静止,并尽量贴近暂停时的姿态,以减少参考突变。 ### 灵巧手(Pico sim2real) @@ -157,13 +150,3 @@ LinkerHand SDK submodule。控制只在 `MOCAP` 中生效;非活动模式和 | `dexterous_hand.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | | `dexterous_hand.trigger_deadzone` | trigger 两端死区 | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | - -### 实时追赶(Pico sim2real) - -| 字段 | 说明 | 默认值 | -|---|---|---| -| `realtime_catchup_enabled` | 缓冲区过大时启用追赶 | `true` | -| `realtime_catchup_trigger_steps` | 触发追赶的缓冲区深度 | `6` | -| `realtime_catchup_release_steps` | 释放追赶的缓冲区深度 | `3` | -| `realtime_catchup_target_delay_s` | 追赶目标延迟 | `0.04` | -| `reference_qpos_smoothing_alpha` | 关节位置平滑系数(1.0 = 无平滑) | `0.4` | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md index 80644d14..ab57700f 100644 --- 
a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md @@ -76,9 +76,6 @@ playback.pause_on_end=true # 从 standing/当前机器人状态平滑进入回放 transition_duration=2.0 -# 离线回放恢复混合时长 -pause_resume_transition_duration=1.0 - # 控制循环频率 policy_hz=50 ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 5f8a23cd..0a91d69c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -193,9 +193,6 @@ mocap_switch.check_frames=10 # 平滑过渡到 mocap 参考 transition_duration=2.0 -# 恢复前采集的实时帧数 -pause_resume_warmup_steps=2 - # 更换 Pico 暂停键 input.pause_button=right_axis_click diff --git a/teleopit/configs/online.yaml b/teleopit/configs/online.yaml index 7fa0a2db..f9a5533e 100644 --- a/teleopit/configs/online.yaml +++ b/teleopit/configs/online.yaml @@ -6,15 +6,10 @@ defaults: policy_hz: 50.0 pd_hz: 200.0 -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 2 retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 -reference_qpos_smoothing_alpha: 0.4 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] diff --git a/teleopit/configs/pico4_sim.yaml b/teleopit/configs/pico4_sim.yaml index a5313ae0..c36be166 100644 --- a/teleopit/configs/pico4_sim.yaml +++ b/teleopit/configs/pico4_sim.yaml @@ -6,8 +6,6 @@ defaults: policy_hz: 50.0 pd_hz: 200.0 -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap 
resumes from re-centered live tracking -pause_resume_warmup_steps: 0 # realtime mocap frames to accumulate before resuming live tracking in sim2sim keyboard: enabled: true # starts in STANDING; Y=enter mocap, A=pause/resume, X=return standing, Q=quit input: @@ -16,8 +14,6 @@ input: retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index ce164da3..8e270c4f 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -6,22 +6,13 @@ defaults: policy_hz: 50.0 transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 0 # realtime mocap frames to accumulate before resuming live tracking input: video: source: realsense retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 -realtime_catchup_enabled: false -realtime_catchup_trigger_steps: 6 -realtime_catchup_release_steps: 3 -realtime_catchup_target_delay_s: 0.04 -reference_qpos_smoothing_alpha: 1.0 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 27c99e39..14a7500b 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -6,15 +6,10 
@@ defaults: policy_hz: 50.0 transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 2 # realtime mocap frames to accumulate before resuming live tracking retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 -reference_qpos_smoothing_alpha: 1.0 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] diff --git a/teleopit/controllers/qpos_interpolator.py b/teleopit/controllers/qpos_interpolator.py index da1af66e..15c8e79d 100644 --- a/teleopit/controllers/qpos_interpolator.py +++ b/teleopit/controllers/qpos_interpolator.py @@ -109,31 +109,3 @@ def apply(self, target_qpos: NDArray) -> NDArray: # Joints: lerp result[7:] = (1.0 - alpha) * self._start_qpos[7:] + alpha * target_qpos[7:] return result - - -class QposLowPassFilter: - """Low-pass filter for retargeted qpos.""" - - def __init__(self, alpha: float) -> None: - alpha_f = float(alpha) - if not np.isfinite(alpha_f) or alpha_f <= 0.0 or alpha_f > 1.0: - raise ValueError(f"alpha must be finite and in (0, 1], got {alpha}") - self._alpha = alpha_f - self._state: NDArray | None = None - - def reset(self) -> None: - self._state = None - - def apply(self, target_qpos: NDArray) -> NDArray: - target = np.asarray(target_qpos, dtype=np.float64).reshape(-1) - if self._state is None or self._state.shape != target.shape or self._alpha >= 1.0 - 1e-6: - self._state = target.copy() - return self._state.copy() - - alpha = float(self._alpha) - filtered = np.empty_like(target) - filtered[0:3] = (1.0 - alpha) * self._state[0:3] + alpha * target[0:3] - filtered[3:7] = 
_slerp(self._state[3:7], target[3:7], alpha) - filtered[7:] = (1.0 - alpha) * self._state[7:] + alpha * target[7:] - self._state = filtered - return filtered.copy() diff --git a/teleopit/runtime/factory.py b/teleopit/runtime/factory.py index 970b8a13..3f960afd 100644 --- a/teleopit/runtime/factory.py +++ b/teleopit/runtime/factory.py @@ -38,24 +38,13 @@ def build_simulation_cfg(cfg: Any) -> dict[str, object]: "policy_hz": float(cfg_get(cfg, "policy_hz", 50.0)), "pd_hz": float(cfg_get(cfg, "pd_hz", 1000.0)), "transition_duration": float(cfg_get(cfg, "transition_duration", 0.0) or 0.0), - "pause_resume_transition_duration": float( - cfg_get(cfg, "pause_resume_transition_duration", cfg_get(cfg, "transition_duration", 0.0)) or 0.0 - ), - "pause_resume_warmup_steps": cfg_get(cfg, "pause_resume_warmup_steps", None), "retarget_buffer_enabled": bool(cfg_get(cfg, "retarget_buffer_enabled", True)), "retarget_buffer_window_s": float(cfg_get(cfg, "retarget_buffer_window_s", 0.5)), "retarget_buffer_delay_s": cfg_get(cfg, "retarget_buffer_delay_s", None), "reference_steps": cfg_get(cfg, "reference_steps", [0]), "reference_debug_log": bool(cfg_get(cfg, "reference_debug_log", False)), "realtime_input_delay_s": cfg_get(cfg, "realtime_input_delay_s", None), - "realtime_buffer_low_watermark_steps": cfg_get(cfg, "realtime_buffer_low_watermark_steps", None), - "realtime_buffer_high_watermark_steps": cfg_get(cfg, "realtime_buffer_high_watermark_steps", None), "realtime_buffer_warmup_steps": cfg_get(cfg, "realtime_buffer_warmup_steps", None), - "realtime_catchup_enabled": bool(cfg_get(cfg, "realtime_catchup_enabled", False)), - "realtime_catchup_trigger_steps": cfg_get(cfg, "realtime_catchup_trigger_steps", None), - "realtime_catchup_release_steps": cfg_get(cfg, "realtime_catchup_release_steps", None), - "realtime_catchup_target_delay_s": cfg_get(cfg, "realtime_catchup_target_delay_s", None), - "reference_qpos_smoothing_alpha": float(cfg_get(cfg, "reference_qpos_smoothing_alpha", 1.0)), 
"reference_velocity_smoothing_alpha": float(cfg_get(cfg, "reference_velocity_smoothing_alpha", 1.0)), "reference_anchor_velocity_smoothing_alpha": float( cfg_get(cfg, "reference_anchor_velocity_smoothing_alpha", 1.0) diff --git a/teleopit/runtime/reference_config.py b/teleopit/runtime/reference_config.py index 468e29c6..a84f7f23 100644 --- a/teleopit/runtime/reference_config.py +++ b/teleopit/runtime/reference_config.py @@ -13,7 +13,6 @@ cfg_get, parse_alpha, parse_nonnegative_int, - parse_optional_nonnegative_int, ) @@ -23,17 +22,9 @@ class ReferenceConfig: retarget_buffer_window_s: float reference_delay_s: float | None reference_debug_log: bool - realtime_buffer_low_watermark_steps: int - realtime_buffer_high_watermark_steps: int | None realtime_buffer_warmup_steps: int - pause_resume_warmup_steps: int - realtime_catchup_enabled: bool - realtime_catchup_trigger_steps: int | None - realtime_catchup_release_steps: int | None - realtime_catchup_target_delay_s: float | None reference_velocity_smoothing_alpha: float reference_anchor_velocity_smoothing_alpha: float - reference_qpos_smoothing_alpha: float def _resolve_delay(cfg: Any, *, provider_fps: float | None) -> float | None: @@ -48,13 +39,6 @@ def _resolve_delay(cfg: Any, *, provider_fps: float | None) -> float | None: return None -def _resolve_catchup_target_delay(cfg: Any) -> float | None: - raw = cfg_get(cfg, "realtime_catchup_target_delay_s", None) - if raw in (None, "", "null"): - return None - return float(raw) - - def parse_reference_config( cfg: Any, *, @@ -80,38 +64,11 @@ def parse_reference_config( reference_debug_log = bool(cfg_get(cfg, "reference_debug_log", False)) reference_delay_s = _resolve_delay(cfg, provider_fps=provider_fps) - low = parse_nonnegative_int( - cfg_get(cfg, "realtime_buffer_low_watermark_steps", 0), - field_name="realtime_buffer_low_watermark_steps", - default=0, - ) - high = parse_optional_nonnegative_int( - cfg_get(cfg, "realtime_buffer_high_watermark_steps", None), - 
field_name="realtime_buffer_high_watermark_steps", - ) - if high is not None and high < low: - raise ValueError("realtime_buffer_high_watermark_steps must be >= realtime_buffer_low_watermark_steps") - warmup = parse_nonnegative_int( cfg_get(cfg, "realtime_buffer_warmup_steps", 0), field_name="realtime_buffer_warmup_steps", default=0, ) - pause_resume_warmup = parse_nonnegative_int( - cfg_get(cfg, "pause_resume_warmup_steps", warmup), - field_name="pause_resume_warmup_steps", - default=warmup, - ) - catchup_enabled = bool(cfg_get(cfg, "realtime_catchup_enabled", False)) - catchup_trigger = parse_optional_nonnegative_int( - cfg_get(cfg, "realtime_catchup_trigger_steps", None), - field_name="realtime_catchup_trigger_steps", - ) - catchup_release = parse_optional_nonnegative_int( - cfg_get(cfg, "realtime_catchup_release_steps", None), - field_name="realtime_catchup_release_steps", - ) - catchup_target_delay = _resolve_catchup_target_delay(cfg) vel_alpha = parse_alpha( cfg_get(cfg, "reference_velocity_smoothing_alpha", 1.0), @@ -123,26 +80,13 @@ def parse_reference_config( field_name="reference_anchor_velocity_smoothing_alpha", default=1.0, ) - qpos_alpha = parse_alpha( - cfg_get(cfg, "reference_qpos_smoothing_alpha", 1.0), - field_name="reference_qpos_smoothing_alpha", - default=1.0, - ) return ReferenceConfig( retarget_buffer_enabled=retarget_buffer_enabled, retarget_buffer_window_s=retarget_buffer_window_s, reference_delay_s=reference_delay_s, reference_debug_log=reference_debug_log, - realtime_buffer_low_watermark_steps=low, - realtime_buffer_high_watermark_steps=high, realtime_buffer_warmup_steps=warmup, - pause_resume_warmup_steps=pause_resume_warmup, - realtime_catchup_enabled=catchup_enabled, - realtime_catchup_trigger_steps=catchup_trigger, - realtime_catchup_release_steps=catchup_release, - realtime_catchup_target_delay_s=catchup_target_delay, reference_velocity_smoothing_alpha=vel_alpha, reference_anchor_velocity_smoothing_alpha=anchor_vel_alpha, - 
reference_qpos_smoothing_alpha=qpos_alpha, ) diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index a726c281..9958b110 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -84,9 +84,6 @@ def __init__( # Motion command transition smoothing transition_dur = float(self._try_get_cfg("transition_duration") or 0.0) self._mocap_transition_duration = transition_dur - self._pause_resume_transition_duration = float( - self._try_get_cfg("pause_resume_transition_duration") or transition_dur - ) self._qpos_interpolator = QposInterpolator(transition_dur, self.policy_hz) self._init_reference_config() @@ -131,7 +128,6 @@ def _init_components(self, viewers: set[str] | None) -> None: qpos_interpolator=self._qpos_interpolator, reference_velocity_smoothing_alpha=self._ref_cfg.reference_velocity_smoothing_alpha, reference_anchor_velocity_smoothing_alpha=self._ref_cfg.reference_anchor_velocity_smoothing_alpha, - reference_qpos_smoothing_alpha=self._ref_cfg.reference_qpos_smoothing_alpha, ) self._publisher = RuntimePublisher(self.bus) self._recorder_helper = RunRecorder() @@ -306,7 +302,7 @@ def _resume_offline_playback( self._step_runner.last_retarget_qpos = resume_qpos.copy() self._step_runner.arm_motion_transition( resume_qpos, - duration_s=self._pause_resume_transition_duration, + duration_s=self._mocap_transition_duration, ) def _build_observation( @@ -444,8 +440,6 @@ def _write_debug_trace( reference_future_horizon_steps=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "future_horizon_steps"), dtype=np.int64)), reference_real_frame_count=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "real_frame_count"), dtype=np.int64)), reference_warmup_done=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "warmup_done"), dtype=np.bool_)), - reference_used_repeat_padding=(None if realtime_reference_diag is None else 
np.asarray(getattr(realtime_reference_diag, "used_repeat_padding"), dtype=np.bool_)), - reference_padding_active=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "padding_active"), dtype=np.bool_)), ) def _log_reference_window(self, reference_window: ReferenceWindow, buffer_len: int) -> None: @@ -458,18 +452,3 @@ def _log_reference_window(self, reference_window: ReferenceWindow, buffer_len: i list(reference_window.reference_steps), list(reference_window.modes()), ) - - def _log_repeat_padding( - self, - reference_window: ReferenceWindow, - diagnostics: RealtimeReferenceDiagnostics, - buffer_len: int, - ) -> None: - import logging - - logging.getLogger(__name__).warning( - "Reference timeline repeat padding | buffer_len=%d | future_horizon_steps=%d | steps=%s", - buffer_len, - diagnostics.future_horizon_steps, - list(reference_window.reference_steps), - ) diff --git a/teleopit/sim/realtime_utils.py b/teleopit/sim/realtime_utils.py index 7a51283a..589e1e2f 100644 --- a/teleopit/sim/realtime_utils.py +++ b/teleopit/sim/realtime_utils.py @@ -5,12 +5,7 @@ import numpy as np from numpy.typing import NDArray -from teleopit.sim.reference_timeline import ( - ReferenceSample, - ReferenceTimeline, - ReferenceWindow, - ReferenceWindowBuilder, -) +from teleopit.sim.reference_timeline import ReferenceTimeline, ReferenceWindow, ReferenceWindowBuilder Float32Array = NDArray[np.float32] @@ -44,13 +39,9 @@ class RealtimeReferenceDiagnostics: future_horizon_steps: int real_frame_count: int warmup_done: bool - used_repeat_padding: bool - padding_active: bool requested_base_time_s: float effective_base_time_s: float latest_timestamp_s: float | None - used_catchup: bool - catchup_active: bool class RealtimeReferenceManager: @@ -58,62 +49,13 @@ def __init__( self, *, reference_window_builder: ReferenceWindowBuilder, - low_watermark_steps: int = 0, - high_watermark_steps: int | None = None, warmup_steps: int = 0, - catchup_enabled: bool = False, - 
catchup_trigger_steps: int | None = None, - catchup_release_steps: int | None = None, - catchup_target_delay_s: float | None = None, ) -> None: self._builder = reference_window_builder - self._low_watermark_steps = max(int(low_watermark_steps), self._builder.max_future_step) - if high_watermark_steps is None: - self._high_watermark_steps = self._low_watermark_steps - else: - self._high_watermark_steps = int(high_watermark_steps) - if self._low_watermark_steps < 0: - raise ValueError("low_watermark_steps must be >= 0") - if self._high_watermark_steps < self._low_watermark_steps: - raise ValueError( - "high_watermark_steps must be >= low_watermark_steps, " - f"got {self._high_watermark_steps} < {self._low_watermark_steps}" - ) self._warmup_steps = max(int(warmup_steps), 0) - self._catchup_enabled = bool(catchup_enabled) - self._catchup_trigger_steps = ( - max(int(catchup_trigger_steps), 0) - if catchup_trigger_steps is not None - else max(self._high_watermark_steps, self._low_watermark_steps + 2) - ) - self._catchup_release_steps = ( - max(int(catchup_release_steps), 0) - if catchup_release_steps is not None - else max(self._low_watermark_steps, self._builder.max_future_step) - ) - if self._catchup_release_steps > self._catchup_trigger_steps: - raise ValueError( - "catchup_release_steps must be <= catchup_trigger_steps, " - f"got {self._catchup_release_steps} > {self._catchup_trigger_steps}" - ) - self._catchup_target_delay_s = ( - float(catchup_target_delay_s) - if catchup_target_delay_s is not None - else float(self._builder.policy_dt_s * max(self._builder.max_future_step, 1)) - ) - if ( - not np.isfinite(self._catchup_target_delay_s) - or self._catchup_target_delay_s < 0.0 - ): - raise ValueError( - "catchup_target_delay_s must be finite and >= 0, " - f"got {catchup_target_delay_s}" - ) self.reset() def reset(self) -> None: - self._padding_active = False - self._catchup_active = False self._real_frame_count = 0 def set_warmup_steps(self, warmup_steps: int) -> 
None: @@ -148,115 +90,14 @@ def sample( ) -> tuple[ReferenceWindow, RealtimeReferenceDiagnostics]: requested_base_time_s = float(base_time_s) latest_timestamp = timeline.latest_timestamp() - requested_future_horizon_steps = self.future_horizon_steps(timeline, requested_base_time_s) - effective_base_time_s, used_catchup = self._apply_catchup( - requested_base_time_s, - latest_timestamp, - requested_future_horizon_steps, - ) + effective_base_time_s = requested_base_time_s window = self._builder.sample(timeline, effective_base_time_s) future_horizon_steps = self.future_horizon_steps(timeline, effective_base_time_s) - self._update_padding_state(future_horizon_steps) - target_padding_horizon = ( - self._high_watermark_steps - if self._padding_active - else self._builder.max_future_step - ) - padded_window, used_repeat_padding = self._pad_future_window( - timeline, - window, - effective_base_time_s, - target_padding_horizon, - ) - return padded_window, RealtimeReferenceDiagnostics( + return window, RealtimeReferenceDiagnostics( future_horizon_steps=future_horizon_steps, real_frame_count=self._real_frame_count, warmup_done=self.warmup_done, - used_repeat_padding=used_repeat_padding, - padding_active=self._padding_active, requested_base_time_s=requested_base_time_s, effective_base_time_s=effective_base_time_s, latest_timestamp_s=None if latest_timestamp is None else float(latest_timestamp), - used_catchup=used_catchup, - catchup_active=self._catchup_active, - ) - - def _apply_catchup( - self, - requested_base_time_s: float, - latest_timestamp: float | None, - requested_future_horizon_steps: int, - ) -> tuple[float, bool]: - if not self._catchup_enabled or latest_timestamp is None: - self._catchup_active = False - return requested_base_time_s, False - - if self._catchup_active: - if requested_future_horizon_steps <= self._catchup_release_steps: - self._catchup_active = False - elif requested_future_horizon_steps >= self._catchup_trigger_steps: - self._catchup_active = True 
- - if not self._catchup_active: - return requested_base_time_s, False - - target_base_time_s = max( - requested_base_time_s, - float(latest_timestamp) - self._catchup_target_delay_s, ) - used_catchup = target_base_time_s > requested_base_time_s + 1e-9 - return target_base_time_s, used_catchup - - def _update_padding_state(self, future_horizon_steps: int) -> None: - if self._high_watermark_steps <= 0: - self._padding_active = False - return - if future_horizon_steps <= self._low_watermark_steps: - self._padding_active = True - elif future_horizon_steps >= self._high_watermark_steps: - self._padding_active = False - - def _pad_future_window( - self, - timeline: ReferenceTimeline, - window: ReferenceWindow, - base_time_s: float, - target_padding_horizon: int, - ) -> tuple[ReferenceWindow, bool]: - latest_timestamp = timeline.latest_timestamp() - if latest_timestamp is None or target_padding_horizon <= 0: - return window, False - - latest_sample = timeline.sample_at(latest_timestamp) - latest_ts = float(latest_timestamp) - policy_dt_s = self._builder.policy_dt_s - samples: list[ReferenceSample] = [] - used_repeat_padding = False - - for step, sample in zip(window.reference_steps, window.samples): - target_time_s = float(base_time_s) + float(step) * policy_dt_s - if step > 0 and step <= target_padding_horizon and target_time_s > latest_ts + 1e-9: - used_repeat_padding = True - samples.append( - ReferenceSample( - qpos=np.asarray(latest_sample.qpos, dtype=np.float64).copy(), - timestamp_s=target_time_s, - mode="repeat_latest", - used_fallback=False, - older_timestamp_s=latest_ts, - newer_timestamp_s=latest_ts, - alpha=None, - ) - ) - else: - samples.append(sample) - - if not used_repeat_padding: - return window, False - - return ReferenceWindow( - base_time_s=float(window.base_time_s), - policy_dt_s=float(window.policy_dt_s), - reference_steps=tuple(window.reference_steps), - samples=tuple(samples), - ), True diff --git a/teleopit/sim/runtime_components.py 
b/teleopit/sim/runtime_components.py index 91e83ff4..094ac873 100644 --- a/teleopit/sim/runtime_components.py +++ b/teleopit/sim/runtime_components.py @@ -14,7 +14,7 @@ from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM from teleopit.bus.topics import TOPIC_ACTION, TOPIC_MIMIC_OBS, TOPIC_ROBOT_STATE from teleopit.controllers.observation import VelCmdObservationBuilder -from teleopit.controllers.qpos_interpolator import QposInterpolator, QposLowPassFilter +from teleopit.controllers.qpos_interpolator import QposInterpolator from teleopit.controllers import reference_processing as ref_proc from teleopit.interfaces import MessageBus, ObservationBuilder, Recorder, Robot, RobotState from teleopit.retargeting.core import extract_mimic_obs @@ -97,7 +97,6 @@ def __init__( qpos_interpolator: QposInterpolator, reference_velocity_smoothing_alpha: float = 1.0, reference_anchor_velocity_smoothing_alpha: float = 1.0, - reference_qpos_smoothing_alpha: float = 1.0, ) -> None: self.robot = robot self.controller = controller @@ -113,7 +112,6 @@ def __init__( self._motion_joint_vel_smoother = ExponentialVecSmoother(reference_velocity_smoothing_alpha) self._motion_anchor_lin_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) self._motion_anchor_ang_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) - self._reference_qpos_smoother = QposLowPassFilter(reference_qpos_smoothing_alpha) self.last_action: Float32Array = np.zeros((self.num_actions,), dtype=np.float32) self.last_retarget_qpos: Float64Array | None = None self.last_reference_qpos: Float64Array | None = None @@ -135,7 +133,6 @@ def reset(self) -> None: self._motion_joint_vel_smoother.reset() self._motion_anchor_lin_vel_smoother.reset() self._motion_anchor_ang_vel_smoother.reset() - self._reference_qpos_smoother.reset() self.qpos_interpolator.reset() def soft_reset_reference_state(self, *, reset_alignment: bool = True) -> None: @@ -149,7 +146,6 @@ def 
soft_reset_reference_state(self, *, reset_alignment: bool = True) -> None: self._motion_joint_vel_smoother.reset() self._motion_anchor_lin_vel_smoother.reset() self._motion_anchor_ang_vel_smoother.reset() - self._reference_qpos_smoother.reset() def reset_reference_alignment(self, target_qpos: Float64Array | None = None) -> None: self._fixed_reference_yaw_quat = None @@ -191,7 +187,6 @@ def prepare_static_motion_command(self, qpos: Float64Array) -> MotionPreparation def prepare_motion_command(self, retargeted: object, state: object) -> MotionPreparation: reference_qpos = self._retarget_to_qpos(retargeted) reference_qpos = self._align_velcmd_reference_yaw(reference_qpos, state) - reference_qpos = self._reference_qpos_smoother.apply(reference_qpos) self._pending_reference_qpos = reference_qpos.copy() if self.last_retarget_qpos is None and self.qpos_interpolator.duration > 0: diff --git a/teleopit/sim/session.py b/teleopit/sim/session.py index 703430fd..f963f308 100644 --- a/teleopit/sim/session.py +++ b/teleopit/sim/session.py @@ -148,13 +148,7 @@ def __init__( if self.reference_timeline is not None: self.realtime_reference_manager = RealtimeReferenceManager( reference_window_builder=loop._reference_window_builder, - low_watermark_steps=ref_cfg.realtime_buffer_low_watermark_steps, - high_watermark_steps=ref_cfg.realtime_buffer_high_watermark_steps, warmup_steps=ref_cfg.realtime_buffer_warmup_steps, - catchup_enabled=ref_cfg.realtime_catchup_enabled, - catchup_trigger_steps=ref_cfg.realtime_catchup_trigger_steps, - catchup_release_steps=ref_cfg.realtime_catchup_release_steps, - catchup_target_delay_s=ref_cfg.realtime_catchup_target_delay_s, ) # Realtime live-frame tracking @@ -217,14 +211,12 @@ def __init__( # State-management methods (formerly closures with nonlocal) # ------------------------------------------------------------------ - def reset_runtime_tracking(self, *, warmup_steps: int | None = None) -> None: + def reset_runtime_tracking(self) -> None: ref_cfg = 
self._loop._ref_cfg if self.reference_timeline is not None: self.reference_timeline.clear() if self.realtime_reference_manager is not None: - self.realtime_reference_manager.set_warmup_steps( - ref_cfg.realtime_buffer_warmup_steps if warmup_steps is None else warmup_steps - ) + self.realtime_reference_manager.set_warmup_steps(ref_cfg.realtime_buffer_warmup_steps) self.realtime_reference_manager.reset() self.previous_live_human_frame = None self.previous_live_retargeted = None @@ -236,14 +228,14 @@ def reset_runtime_tracking(self, *, warmup_steps: int | None = None) -> None: self.cached_human_frame = None self.cached_retargeted = None - def full_policy_reset(self, *, warmup_steps: int | None = None) -> None: + def full_policy_reset(self) -> None: self._step_runner.reset() self._loop.controller.reset() self._loop.obs_builder.reset() self._retargeter.reset() self.mocap_session.reset() self.last_commanded_motion_qpos = None - self.reset_runtime_tracking(warmup_steps=warmup_steps) + self.reset_runtime_tracking() def enter_standing_mode(self) -> None: from teleopit.sim.loop import SimulationMode @@ -274,7 +266,7 @@ def toggle_realtime_mocap_pause(self) -> None: if hold_qpos is None: raise RuntimeError("Cannot resume mocap without a paused hold qpos") resume_qpos = loop._build_resume_alignment_qpos(hold_qpos, loop.robot.get_state()) - self.full_policy_reset(warmup_steps=loop._ref_cfg.pause_resume_warmup_steps) + self.full_policy_reset() self._step_runner.reset_reference_alignment(resume_qpos) self.last_commanded_motion_qpos = resume_qpos.copy() return @@ -485,11 +477,8 @@ def _fetch_realtime_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeR self.reference_timeline, target_base_time, ) - if loop._ref_cfg.reference_debug_log: - if any(reference_window.fallback_mask()): - loop._log_reference_window(reference_window, len(self.reference_timeline)) - if realtime_reference_diag.used_repeat_padding: - loop._log_repeat_padding(reference_window, realtime_reference_diag, 
len(self.reference_timeline)) + if loop._ref_cfg.reference_debug_log and any(reference_window.fallback_mask()): + loop._log_reference_window(reference_window, len(self.reference_timeline)) self.cached_retargeted = reference_window.current_sample().qpos else: if self.latest_live_retargeted is None: diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index 0301d65b..5435e914 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -82,9 +82,6 @@ def __init__(self, cfg: Any) -> None: # Motion command transition smoothing transition_dur = float(cfg_get(cfg, "transition_duration", 0.0) or 0.0) self._mocap_transition_duration = transition_dur - self._pause_resume_transition_duration = float( - cfg_get(cfg, "pause_resume_transition_duration", transition_dur) or 0.0 - ) self._qpos_interpolator = QposInterpolator(transition_dur, self.policy_hz) self._init_components(cfg) @@ -184,13 +181,7 @@ def _init_reference_config(self, cfg: Any) -> None: self._reference_manager: RealtimeReferenceManager | None = ( RealtimeReferenceManager( reference_window_builder=self._reference_window_builder, - low_watermark_steps=rc.realtime_buffer_low_watermark_steps, - high_watermark_steps=rc.realtime_buffer_high_watermark_steps, warmup_steps=rc.realtime_buffer_warmup_steps, - catchup_enabled=rc.realtime_catchup_enabled, - catchup_trigger_steps=rc.realtime_catchup_trigger_steps, - catchup_release_steps=rc.realtime_catchup_release_steps, - catchup_target_delay_s=rc.realtime_catchup_target_delay_s, ) if self._reference_timeline is not None else None @@ -203,7 +194,6 @@ def _init_reference_config(self, cfg: Any) -> None: num_actions=self.num_actions, reference_velocity_smoothing_alpha=rc.reference_velocity_smoothing_alpha, reference_anchor_velocity_smoothing_alpha=rc.reference_anchor_velocity_smoothing_alpha, - reference_qpos_smoothing_alpha=rc.reference_qpos_smoothing_alpha, max_pos_value=float(cfg_get(mocap_sw, "max_position_value", 5.0)), ) 
self._last_live_packet_seq = -1 @@ -356,21 +346,6 @@ def _mocap_step(self) -> None: list(reference_window.reference_steps), list(reference_window.modes()), ) - if self._ref_cfg.reference_debug_log and reference_diag.used_repeat_padding: - logger.warning( - "Reference timeline repeat padding | buffer_len=%d | future_horizon_steps=%d | steps=%s", - len(self._reference_timeline), - reference_diag.future_horizon_steps, - list(reference_window.reference_steps), - ) - if self._ref_cfg.reference_debug_log and reference_diag.used_catchup: - logger.warning( - "Reference timeline catch-up | requested_base=%.6f | effective_base=%.6f | latest=%.6f | future_horizon_steps=%d", - reference_diag.requested_base_time_s, - reference_diag.effective_base_time_s, - -1.0 if reference_diag.latest_timestamp_s is None else reference_diag.latest_timestamp_s, - reference_diag.future_horizon_steps, - ) reference_qpos = reference_window.current_sample().qpos else: retargeted = self.retargeter.retarget(human_frame) @@ -416,9 +391,8 @@ def _execute_mocap_pipeline( robot_state: object, reference_window: ReferenceWindow | None, ) -> None: - """Shared mocap control pipeline: align → smooth → interpolate → infer → send.""" + """Shared mocap control pipeline: align → interpolate → infer → send.""" reference_qpos = self._ref_proc.align_reference_yaw(reference_qpos, robot_state=robot_state) - reference_qpos = self._ref_proc.apply_qpos_smoothing(reference_qpos) qpos = self._qpos_interpolator.apply(reference_qpos) # Compute joint velocities via finite difference @@ -700,7 +674,7 @@ def _reset_policy_state(self) -> None: self.obs_builder.reset() self.retargeter.reset() - def _reset_mocap_reference_state(self, *, warmup_steps: int | None = None) -> None: + def _reset_mocap_reference_state(self) -> None: """Reset mocap-specific reference state without disrupting policy observation continuity. 
Unlike ``_reset_policy_state``, this preserves ``_last_action``, the @@ -710,9 +684,7 @@ def _reset_mocap_reference_state(self, *, warmup_steps: int | None = None) -> No if self._reference_timeline is not None: self._reference_timeline.clear() if self._reference_manager is not None: - self._reference_manager.set_warmup_steps( - self._ref_cfg.realtime_buffer_warmup_steps if warmup_steps is None else warmup_steps - ) + self._reference_manager.set_warmup_steps(self._ref_cfg.realtime_buffer_warmup_steps) self._reference_manager.reset() self._ref_proc.reset_smoothers() self._last_live_packet_seq = -1 @@ -827,14 +799,10 @@ def _resume_paused_mocap(self) -> None: self._last_commanded_motion_qpos = resume_qpos.copy() # Override warmup steps for the resume-specific buffer warmup. - if self._reference_manager is not None: - self._reference_manager.set_warmup_steps(self._ref_cfg.pause_resume_warmup_steps) - self._reference_manager.reset() - self._ref_proc.reset_alignment(target_qpos=resume_qpos) if self._offline_playback is not None: self._last_retarget_qpos = resume_qpos.copy() - self._arm_qpos_transition(resume_qpos, duration_s=self._pause_resume_transition_duration) + self._arm_qpos_transition(resume_qpos, duration_s=self._mocap_transition_duration) self._offline_playback.resume() logger.info("Mocap session -> ACTIVE (episode-reset + reference realignment)") diff --git a/teleopit/sim2real/reference_processor.py b/teleopit/sim2real/reference_processor.py index e0dc4841..bfbf72c8 100644 --- a/teleopit/sim2real/reference_processor.py +++ b/teleopit/sim2real/reference_processor.py @@ -15,7 +15,6 @@ from teleopit.controllers import reference_processing as ref_proc from teleopit.controllers.observation import VelCmdObservationBuilder -from teleopit.controllers.qpos_interpolator import QposLowPassFilter from teleopit.sim.realtime_utils import ExponentialVecSmoother from teleopit.sim.reference_timeline import ReferenceWindow @@ -35,7 +34,6 @@ def __init__( num_actions: int, 
reference_velocity_smoothing_alpha: float, reference_anchor_velocity_smoothing_alpha: float, - reference_qpos_smoothing_alpha: float, max_pos_value: float, ) -> None: self._obs_builder = obs_builder @@ -54,7 +52,6 @@ def __init__( self._motion_joint_vel_smoother = ExponentialVecSmoother(reference_velocity_smoothing_alpha) self._motion_anchor_lin_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) self._motion_anchor_ang_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) - self._reference_qpos_smoother = QposLowPassFilter(reference_qpos_smoothing_alpha) # Last reference qpos for velocity computation self._last_reference_qpos: Float64Array | None = None @@ -201,9 +198,6 @@ def validate_observation(self, obs: Float32Array) -> Float32Array: # Smoothing # ------------------------------------------------------------------ - def apply_qpos_smoothing(self, qpos: Float64Array) -> Float64Array: - return self._reference_qpos_smoother.apply(qpos) - def apply_joint_vel_smoothing(self, vel: Float32Array) -> Float32Array: return self._motion_joint_vel_smoother.apply(vel) @@ -219,5 +213,4 @@ def reset_smoothers(self) -> None: self._motion_joint_vel_smoother.reset() self._motion_anchor_lin_vel_smoother.reset() self._motion_anchor_ang_vel_smoother.reset() - self._reference_qpos_smoother.reset() self._last_reference_qpos = None diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 31487c52..c586484b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -82,13 +82,10 @@ def __init__(self, *args: object, **kwargs: object) -> None: "policy_hz": 50, "pd_hz": 1000, "transition_duration": 1.5, - "pause_resume_transition_duration": 0.75, - "pause_resume_warmup_steps": 3, "keyboard": {"enabled": True}, "retarget_buffer_enabled": True, "retarget_buffer_window_s": 0.75, "retarget_buffer_delay_s": 0.02, - "reference_qpos_smoothing_alpha": 0.4, "reference_steps": [0, 1, -1], "realtime": True, "viewers": 
["retarget", "sim2sim"], @@ -109,13 +106,10 @@ def __init__(self, *args: object, **kwargs: object) -> None: assert list(controller_cfg.default_dof_pos) == robot_default_angles assert list(controller_cfg.action_scale) == robot_action_scale assert loop_cfg["transition_duration"] == pytest.approx(1.5) - assert loop_cfg["pause_resume_transition_duration"] == pytest.approx(0.75) - assert loop_cfg["pause_resume_warmup_steps"] == 3 assert loop_cfg["keyboard"]["enabled"] is True assert loop_cfg["retarget_buffer_enabled"] is True assert loop_cfg["retarget_buffer_window_s"] == pytest.approx(0.75) assert loop_cfg["retarget_buffer_delay_s"] == pytest.approx(0.02) - assert loop_cfg["reference_qpos_smoothing_alpha"] == pytest.approx(0.4) assert list(loop_cfg["reference_steps"]) == [0, 1, -1] assert loop_cfg["realtime"] is True assert captured["loop_kwargs"]["viewers"] == {"retarget", "sim2sim"} diff --git a/tests/test_realtime_utils.py b/tests/test_realtime_utils.py index 99a10c72..29649436 100644 --- a/tests/test_realtime_utils.py +++ b/tests/test_realtime_utils.py @@ -30,8 +30,6 @@ def test_exponential_vec_smoother_blends_and_resets() -> None: def test_realtime_reference_manager_warmup_counts_real_frames() -> None: manager = RealtimeReferenceManager( reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0, 1, 2]), - low_watermark_steps=1, - high_watermark_steps=3, warmup_steps=2, ) @@ -42,15 +40,13 @@ def test_realtime_reference_manager_warmup_counts_real_frames() -> None: assert manager.warmup_done is True -def test_realtime_reference_manager_repeat_pads_future_window() -> None: +def test_realtime_reference_manager_samples_without_padding_or_catchup() -> None: timeline = ReferenceTimeline(window_s=1.0) timeline.append(_qpos(0.0), 0.0) timeline.append(_qpos(1.0), 0.02) manager = RealtimeReferenceManager( reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0, 1, 2]), - low_watermark_steps=1, - high_watermark_steps=2, 
warmup_steps=0, ) manager.note_realtime_frame() @@ -59,71 +55,7 @@ def test_realtime_reference_manager_repeat_pads_future_window() -> None: window, diagnostics = manager.sample(timeline, 0.02) assert diagnostics.future_horizon_steps == 0 - assert diagnostics.used_repeat_padding is True - assert diagnostics.padding_active is True np.testing.assert_allclose(window.current_sample().qpos[0], 1.0, atol=1e-6) - assert window.samples[1].mode == 'repeat_latest' - assert window.samples[2].mode == 'repeat_latest' - np.testing.assert_allclose(window.samples[1].timestamp_s, 0.04, atol=1e-6) - np.testing.assert_allclose(window.samples[2].timestamp_s, 0.06, atol=1e-6) - - -def test_realtime_reference_manager_defaults_high_watermark_to_effective_low() -> None: - manager = RealtimeReferenceManager( - reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0, 1, 2, 3, 4]), - low_watermark_steps=0, - high_watermark_steps=None, - warmup_steps=0, - ) - - assert manager._low_watermark_steps == 4 - assert manager._high_watermark_steps == 4 - - -def test_realtime_reference_manager_catchup_advances_base_time() -> None: - timeline = ReferenceTimeline(window_s=1.0) - for idx in range(11): - timeline.append(_qpos(float(idx)), idx * 0.02) - - manager = RealtimeReferenceManager( - reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0]), - low_watermark_steps=2, - high_watermark_steps=4, - warmup_steps=0, - catchup_enabled=True, - catchup_trigger_steps=6, - catchup_release_steps=3, - catchup_target_delay_s=0.04, - ) - - window, diagnostics = manager.sample(timeline, 0.00) - - assert diagnostics.used_catchup is True - assert diagnostics.catchup_active is True - np.testing.assert_allclose(diagnostics.effective_base_time_s, 0.16, atol=1e-6) - np.testing.assert_allclose(window.current_sample().qpos[0], 8.0, atol=1e-6) - - -def test_realtime_reference_manager_catchup_releases_with_hysteresis() -> None: - timeline = 
ReferenceTimeline(window_s=1.0) - for idx in range(11): - timeline.append(_qpos(float(idx)), idx * 0.02) - - manager = RealtimeReferenceManager( - reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0]), - low_watermark_steps=2, - high_watermark_steps=4, - warmup_steps=0, - catchup_enabled=True, - catchup_trigger_steps=6, - catchup_release_steps=3, - catchup_target_delay_s=0.04, - ) - - _, first_diag = manager.sample(timeline, 0.00) - _, second_diag = manager.sample(timeline, 0.18) - - assert first_diag.catchup_active is True - assert second_diag.used_catchup is False - assert second_diag.catchup_active is False - np.testing.assert_allclose(second_diag.effective_base_time_s, 0.18, atol=1e-6) + assert window.samples[1].mode == "fallback_latest" + assert window.samples[2].mode == "fallback_latest" + np.testing.assert_allclose(diagnostics.effective_base_time_s, 0.02, atol=1e-6) diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index 8ec74241..8e2b92a9 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -493,7 +493,7 @@ def test_mocap_step_waits_for_realtime_warmup_before_running_policy(monkeypatch) assert len(ctrl.robot.sent_positions) == 1 -def test_sim2real_allows_future_reference_steps_without_explicit_high_watermark(monkeypatch) -> None: +def test_sim2real_allows_future_reference_steps(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController policy = DummyPolicy() @@ -504,15 +504,11 @@ def test_sim2real_allows_future_reference_steps_without_explicit_high_watermark( cfg["reference_steps"] = [0, 1, 2, 3, 4] cfg["retarget_buffer_delay_s"] = 0.08 cfg["retarget_buffer_window_s"] = 0.5 - cfg["realtime_buffer_low_watermark_steps"] = 0 - ctrl = Sim2RealController(cfg) - - assert ctrl._ref_cfg.realtime_buffer_low_watermark_steps == 0 - assert ctrl._ref_cfg.realtime_buffer_high_watermark_steps is None + Sim2RealController(cfg) -def 
test_mocap_step_reference_qpos_smoothing_filters_motion_change(monkeypatch) -> None: +def test_mocap_step_uses_current_reference_qpos(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController policy = DummyPolicy() @@ -523,7 +519,6 @@ def test_mocap_step_reference_qpos_smoothing_filters_motion_change(monkeypatch) cfg = _make_cfg(transition_duration=0.0) cfg["retarget_buffer_enabled"] = False - cfg["reference_qpos_smoothing_alpha"] = 0.5 ctrl = Sim2RealController(cfg) monkeypatch.setattr( ctrl._ref_proc, @@ -540,7 +535,7 @@ def test_mocap_step_reference_qpos_smoothing_filters_motion_change(monkeypatch) assert len(obs_builder.build_calls) == 2 np.testing.assert_allclose(obs_builder.build_calls[0]["motion_qpos"][0], 0.0, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[1]["motion_qpos"][0], 0.5, atol=1e-6) + np.testing.assert_allclose(obs_builder.build_calls[1]["motion_qpos"][0], 1.0, atol=1e-6) def test_mocap_pause_freezes_reference_and_zeroes_velocities(monkeypatch) -> None: @@ -606,8 +601,6 @@ def test_mocap_resume_uses_episode_reset_semantics(monkeypatch) -> None: cfg = _make_cfg(transition_duration=0.0) cfg["retarget_buffer_enabled"] = False - cfg["pause_resume_transition_duration"] = 1.0 - cfg["pause_resume_warmup_steps"] = 0 ctrl = Sim2RealController(cfg) monkeypatch.setattr( ctrl._ref_proc, diff --git a/tests/test_sim_loop.py b/tests/test_sim_loop.py index e3ca6d31..56408951 100644 --- a/tests/test_sim_loop.py +++ b/tests/test_sim_loop.py @@ -326,7 +326,7 @@ def get_frame_packet(self): @requires_mujoco -def test_simulation_loop_allows_future_reference_steps_without_explicit_high_watermark() -> None: +def test_simulation_loop_allows_future_reference_steps() -> None: from teleopit.sim.loop import SimulationLoop bus = InProcessBus() @@ -344,14 +344,10 @@ def test_simulation_loop_allows_future_reference_steps_without_explicit_high_wat "reference_steps": [0, 1, 2, 3, 4], "retarget_buffer_delay_s": 0.08, 
"retarget_buffer_window_s": 0.5, - "realtime_buffer_low_watermark_steps": 0, }, viewers=set(), ) - assert loop._ref_cfg.realtime_buffer_low_watermark_steps == 0 - assert loop._ref_cfg.realtime_buffer_high_watermark_steps is None - class _RealtimeInputProvider: fps = 50 @@ -467,8 +463,6 @@ def get_realtime_input_packet(self): "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, - "pause_resume_transition_duration": 1.0, - "pause_resume_warmup_steps": 0, }, viewers=set(), ) From 7a99b6c30787f301485be3e76df04148880a7f3a Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 15 May 2026 16:13:40 +0800 Subject: [PATCH 016/122] Update tracking domain randomization --- tests/test_domain_randomization.py | 54 +++++++++++++++++-- train_mimic/tasks/tracking/config/env.py | 4 ++ .../tasks/tracking/tracking_env_cfg.py | 49 ++++++++++++++++- 3 files changed, 100 insertions(+), 7 deletions(-) diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index dad68089..16a4d080 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -17,6 +17,10 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non "push_robot", "base_com", "encoder_bias", + "add_joint_default_pos", + "motor_params_implicit_upper_body_pd", + "motor_params_implicit_lower_body_pd", + "motor_params_implicit_armature", "physics_material", "randomize_rigid_body_mass", } @@ -24,7 +28,7 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non push_robot = events["push_robot"] assert push_robot.func is mdp.push_by_setting_velocity assert push_robot.mode == "interval" - assert push_robot.interval_range_s == (1.0, 3.0) + assert push_robot.interval_range_s == (4.0, 6.0) assert push_robot.params["velocity_range"] == { "x": (-0.5, 0.5), "y": (-0.5, 0.5), @@ -50,6 +54,45 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non assert 
encoder_bias.mode == "startup" assert encoder_bias.params["bias_range"] == (-0.01, 0.01) + add_joint_default_pos = events["add_joint_default_pos"] + assert add_joint_default_pos.func is dr.joint_default_pos + assert add_joint_default_pos.mode == "startup" + assert add_joint_default_pos.params["asset_cfg"].joint_names == ".*" + assert add_joint_default_pos.params["operation"] == "add" + assert add_joint_default_pos.params["ranges"] == (-0.01, 0.01) + + upper_motor_pd = events["motor_params_implicit_upper_body_pd"] + assert upper_motor_pd.func is dr.pd_gains + assert upper_motor_pd.mode == "reset" + assert ( + upper_motor_pd.params["asset_cfg"].actuator_names + == r".*(shoulder|elbow|wrist).*" + ) + assert upper_motor_pd.params["kp_range"] == (0.9, 1.1) + assert upper_motor_pd.params["kd_range"] == (0.9, 1.1) + assert upper_motor_pd.params["distribution"] == "log_uniform" + assert upper_motor_pd.params["operation"] == "scale" + + lower_motor_pd = events["motor_params_implicit_lower_body_pd"] + assert lower_motor_pd.func is dr.pd_gains + assert lower_motor_pd.mode == "reset" + assert ( + lower_motor_pd.params["asset_cfg"].actuator_names + == r".*(waist|hip|knee|ankle).*" + ) + assert lower_motor_pd.params["kp_range"] == (0.5, 2.0) + assert lower_motor_pd.params["kd_range"] == (0.5, 2.0) + assert lower_motor_pd.params["distribution"] == "log_uniform" + assert lower_motor_pd.params["operation"] == "scale" + + motor_armature = events["motor_params_implicit_armature"] + assert motor_armature.func is dr.joint_armature + assert motor_armature.mode == "startup" + assert motor_armature.params["asset_cfg"].joint_names == ".*" + assert motor_armature.params["ranges"] == (0.75, 1.25) + assert motor_armature.params["distribution"] == "log_uniform" + assert motor_armature.params["operation"] == "scale" + physics_material = events["physics_material"] assert physics_material.func is dr.geom_friction assert physics_material.mode == "startup" @@ -61,10 +104,7 @@ def 
test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non assert mass.func is dr.pseudo_inertia assert mass.mode == "startup" assert mass.params["asset_cfg"].body_names == r".*wrist_yaw.*|torso_link" - assert mass.params["alpha_range"] == ( - -0.11157177565710488, - 0.4581453659370775, - ) + assert mass.params["alpha_range"] == (-0.1, 0.45) def test_play_env_disables_training_only_domain_randomization() -> None: @@ -77,6 +117,10 @@ def test_play_env_disables_training_only_domain_randomization() -> None: assert "push_robot" not in play_cfg.events assert "base_com" not in play_cfg.events assert "encoder_bias" not in play_cfg.events + assert "add_joint_default_pos" not in play_cfg.events + assert "motor_params_implicit_upper_body_pd" not in play_cfg.events + assert "motor_params_implicit_lower_body_pd" not in play_cfg.events + assert "motor_params_implicit_armature" not in play_cfg.events assert "physics_material" not in play_cfg.events assert "randomize_rigid_body_mass" not in play_cfg.events assert play_cfg.events == {} diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 313fc282..addc0167 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -39,6 +39,10 @@ "push_robot", "base_com", "encoder_bias", + "add_joint_default_pos", + "motor_params_implicit_upper_body_pd", + "motor_params_implicit_lower_body_pd", + "motor_params_implicit_armature", "physics_material", "randomize_rigid_body_mass", ) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index c09fd060..9a56e107 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -163,7 +163,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "push_robot": EventTermCfg( func=mdp.push_by_setting_velocity, mode="interval", - interval_range_s=(1.0, 3.0), + interval_range_s=(4.0, 6.0), 
params={"velocity_range": VELOCITY_RANGE}, ), "base_com": EventTermCfg( @@ -187,6 +187,51 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "bias_range": (-0.01, 0.01), }, ), + "add_joint_default_pos": EventTermCfg( + mode="startup", + func=dr.joint_default_pos, + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=".*"), + "operation": "add", + "ranges": (-0.01, 0.01), + }, + ), + "motor_params_implicit_upper_body_pd": EventTermCfg( + mode="reset", + func=dr.pd_gains, + params={ + "asset_cfg": SceneEntityCfg( + "robot", actuator_names=r".*(shoulder|elbow|wrist).*" + ), + "kp_range": (0.9, 1.1), + "kd_range": (0.9, 1.1), + "distribution": "log_uniform", + "operation": "scale", + }, + ), + "motor_params_implicit_lower_body_pd": EventTermCfg( + mode="reset", + func=dr.pd_gains, + params={ + "asset_cfg": SceneEntityCfg( + "robot", actuator_names=r".*(waist|hip|knee|ankle).*" + ), + "kp_range": (0.5, 2.0), + "kd_range": (0.5, 2.0), + "distribution": "log_uniform", + "operation": "scale", + }, + ), + "motor_params_implicit_armature": EventTermCfg( + mode="startup", + func=dr.joint_armature, + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=".*"), + "ranges": (0.75, 1.25), + "distribution": "log_uniform", + "operation": "scale", + }, + ), "physics_material": EventTermCfg( mode="startup", func=dr.geom_friction, @@ -201,7 +246,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: func=dr.pseudo_inertia, params={ "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. 
- "alpha_range": (-0.11157177565710488, 0.4581453659370775), + "alpha_range": (-0.1, 0.45), }, ), } From 6a33e66256d861bd08a5c66a09801a7b2d924109 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 15 May 2026 17:33:43 +0800 Subject: [PATCH 017/122] Update tracking contact rewards --- tests/test_task_registry.py | 30 ++++++++++++------- tests/test_tracking_rewards.py | 14 ++++----- train_mimic/tasks/tracking/config/env.py | 36 +++++++++++++++-------- train_mimic/tasks/tracking/mdp/rewards.py | 33 ++++++++++----------- 4 files changed, 65 insertions(+), 48 deletions(-) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 1f7aaf6b..f3aeb039 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -33,26 +33,34 @@ def test_general_tracking_task_is_registered() -> None: assert "critic_history" in env_cfg.observations assert env_cfg.commands["motion"].sampling_mode == "uniform" assert env_cfg.commands["motion"].window_steps == (0,) - reward = env_cfg.rewards["self_collisions"] + reward = env_cfg.rewards["undesired_contacts"] assert reward.weight == -0.1 assert reward.params == { - "sensor_name": "self_collision", - "force_threshold": 1.0, + "sensor_name": "undesired_contacts", + "threshold": 1.0, } - assert "undesired_contacts" not in env_cfg.rewards + assert "self_collisions" not in env_cfg.rewards sensors = {sensor.name: sensor for sensor in env_cfg.scene.sensors} - assert set(sensors) == {"self_collision"} - assert sensors["self_collision"].primary.mode == "body" - assert sensors["self_collision"].primary.pattern == r".*" - assert sensors["self_collision"].primary.exclude == ( + assert set(sensors) == {"undesired_contacts"} + assert sensors["undesired_contacts"].primary.mode == "body" + assert sensors["undesired_contacts"].primary.pattern == r".*" + assert sensors["undesired_contacts"].primary.exclude == ( "left_ankle_roll_link", "right_ankle_roll_link", "left_wrist_yaw_link", "right_wrist_yaw_link", + "left_elbow_link", 
+ "right_elbow_link", ) - assert sensors["self_collision"].secondary.mode == "subtree" - assert sensors["self_collision"].secondary.pattern == "pelvis" - assert sensors["self_collision"].reduce == "maxforce" + assert sensors["undesired_contacts"].secondary is None + assert sensors["undesired_contacts"].fields == ("force",) + assert sensors["undesired_contacts"].reduce == "netforce" + assert sensors["undesired_contacts"].history_length == 3 + feet_acc = env_cfg.rewards["feet_acc"] + assert feet_acc.weight == -2.5e-6 + assert feet_acc.params["asset_cfg"].name == "robot" + assert feet_acc.params["asset_cfg"].joint_names == r".*ankle.*" + assert "anti_shake_ang_vel" not in env_cfg.rewards rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME assert rl_cfg.actor.hidden_dims == (1024, 512, 256, 256, 128) diff --git a/tests/test_tracking_rewards.py b/tests/test_tracking_rewards.py index 7736cc1d..e6cc02fd 100644 --- a/tests/test_tracking_rewards.py +++ b/tests/test_tracking_rewards.py @@ -4,17 +4,17 @@ import torch -from train_mimic.tasks.tracking.mdp.rewards import self_collision_cost +from train_mimic.tasks.tracking.mdp.rewards import undesired_contacts def _env_with_force_history(force_history: torch.Tensor) -> SimpleNamespace: sensor = SimpleNamespace( data=SimpleNamespace(force_history=force_history, found=None) ) - return SimpleNamespace(scene={"self_collision": sensor}) + return SimpleNamespace(scene={"undesired_contacts": sensor}) -def test_self_collision_cost_counts_history_frames_not_contacts() -> None: +def test_undesired_contacts_counts_bodies_not_history_frames() -> None: force_history = torch.zeros((2, 3, 4, 3), dtype=torch.float32) force_history[0, 0, 0, 0] = 2.0 force_history[0, 1, 0, 0] = 3.0 @@ -22,10 +22,10 @@ def test_self_collision_cost_counts_history_frames_not_contacts() -> None: force_history[1, 0, 1, 0] = 0.5 force_history[1, 0, 3, 0] = 1.0 - penalty = self_collision_cost( + penalty = 
undesired_contacts( _env_with_force_history(force_history), - sensor_name="self_collision", - force_threshold=1.0, + sensor_name="undesired_contacts", + threshold=1.0, ) - torch.testing.assert_close(penalty, torch.tensor([2.0, 0.0])) + torch.testing.assert_close(penalty, torch.tensor([3.0, 0.0])) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index addc0167..a4906bce 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -116,37 +116,48 @@ def _add_history_obs_groups( } -def _configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: +def _configure_undesired_contacts_reward(cfg: ManagerBasedRlEnvCfg) -> None: excluded_body_names = ( "left_ankle_roll_link", "right_ankle_roll_link", "left_wrist_yaw_link", "right_wrist_yaw_link", + "left_elbow_link", + "right_elbow_link", ) cfg.scene.sensors = ( *tuple(getattr(cfg.scene, "sensors", ()) or ()), ContactSensorCfg( - name="self_collision", - # Exclude only primary bodies: wrist/ankle vs torso is still caught by torso. 
+ name="undesired_contacts", primary=ContactMatch( mode="body", pattern=r".*", entity="robot", exclude=excluded_body_names, ), - secondary=ContactMatch(mode="subtree", pattern="pelvis", entity="robot"), - fields=("found", "force"), - reduce="maxforce", + secondary=None, + fields=("force",), + reduce="netforce", num_slots=1, - history_length=4, + history_length=3, ), ) - cfg.rewards["self_collisions"] = RewardTermCfg( - func=mdp.self_collision_cost, + cfg.rewards["undesired_contacts"] = RewardTermCfg( + func=mdp.undesired_contacts, weight=-0.1, params={ - "sensor_name": "self_collision", - "force_threshold": 1.0, + "sensor_name": "undesired_contacts", + "threshold": 1.0, + }, + ) + + +def _configure_feet_acc_reward(cfg: ManagerBasedRlEnvCfg) -> None: + cfg.rewards["feet_acc"] = RewardTermCfg( + func=mdp.joint_acc_l2, + weight=-2.5e-6, + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=r".*ankle.*"), }, ) @@ -178,7 +189,8 @@ def make_general_tracking_env_cfg( cfg.events["randomize_rigid_body_mass"].params[ "asset_cfg" ].body_names = r".*wrist_yaw.*|torso_link" - _configure_self_collision_reward(cfg) + _configure_undesired_contacts_reward(cfg) + _configure_feet_acc_reward(cfg) cfg.terminations["ee_body_pos"].params["body_names"] = ( "left_ankle_roll_link", "right_ankle_roll_link", diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index db36632e..3af37bbb 100644 --- a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -143,42 +143,39 @@ def motion_global_body_angular_velocity_error_exp( return torch.exp(-error.mean(-1) / std**2) -def self_collision_cost( +def undesired_contacts( env: ManagerBasedRlEnv, sensor_name: str | tuple[str, ...], - force_threshold: float = 10.0, + threshold: float = 10.0, ) -> torch.Tensor: - """Penalize self-collision history frames above the configured force threshold.""" - hit = _self_collision_hits(env, sensor_name, force_threshold) - 
return hit.sum(dim=-1).float() + """Penalize bodies whose contact force exceeds the configured threshold.""" + hits = _undesired_contact_hits(env, sensor_name, threshold) + return hits.sum(dim=-1).float() -def _self_collision_hits( +def _undesired_contact_hits( env: ManagerBasedRlEnv, sensor_name: str | tuple[str, ...], - force_threshold: float, + threshold: float, ) -> torch.Tensor: sensor_names = (sensor_name,) if isinstance(sensor_name, str) else sensor_name - force_histories = [] - found_values = [] + hit_values = [] for name in sensor_names: data = env.scene[name].data if data.force_history is not None: - force_histories.append(data.force_history) + force_mag = torch.norm(data.force_history, dim=-1) + hit_values.append(force_mag.amax(dim=2) > threshold) + elif getattr(data, "force", None) is not None: + force_mag = torch.norm(data.force, dim=-1) + hit_values.append(force_mag > threshold) else: assert data.found is not None found = data.found if found.ndim == 1: found = found.unsqueeze(-1) - found_values.append(found) + hit_values.append(found > 0) - if force_histories: - force_history = torch.cat(force_histories, dim=1) - force_mag = torch.norm(force_history, dim=-1) - return (force_mag > force_threshold).any(dim=1) - - found = torch.cat(found_values, dim=1) - return (found > 0).any(dim=1, keepdim=True) + return torch.cat(hit_values, dim=1) class joint_torque_limits: From 6e398239b1777caf8fde02665d6f5035cc559102 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 15 May 2026 17:55:47 +0800 Subject: [PATCH 018/122] Restore self collision reward semantics --- tests/test_task_registry.py | 26 +++++++++--------- tests/test_tracking_rewards.py | 12 ++++----- train_mimic/tasks/tracking/config/env.py | 25 +++++++++-------- train_mimic/tasks/tracking/mdp/rewards.py | 33 ++++++++++++----------- 4 files changed, 48 insertions(+), 48 deletions(-) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index f3aeb039..973feffa 100644 --- 
a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -33,29 +33,27 @@ def test_general_tracking_task_is_registered() -> None: assert "critic_history" in env_cfg.observations assert env_cfg.commands["motion"].sampling_mode == "uniform" assert env_cfg.commands["motion"].window_steps == (0,) - reward = env_cfg.rewards["undesired_contacts"] + reward = env_cfg.rewards["self_collisions"] assert reward.weight == -0.1 assert reward.params == { - "sensor_name": "undesired_contacts", - "threshold": 1.0, + "sensor_name": "self_collision", + "force_threshold": 1.0, } - assert "self_collisions" not in env_cfg.rewards + assert "undesired_contacts" not in env_cfg.rewards sensors = {sensor.name: sensor for sensor in env_cfg.scene.sensors} - assert set(sensors) == {"undesired_contacts"} - assert sensors["undesired_contacts"].primary.mode == "body" - assert sensors["undesired_contacts"].primary.pattern == r".*" - assert sensors["undesired_contacts"].primary.exclude == ( + assert set(sensors) == {"self_collision"} + assert sensors["self_collision"].primary.mode == "body" + assert sensors["self_collision"].primary.pattern == r".*" + assert sensors["self_collision"].primary.exclude == ( "left_ankle_roll_link", "right_ankle_roll_link", "left_wrist_yaw_link", "right_wrist_yaw_link", - "left_elbow_link", - "right_elbow_link", ) - assert sensors["undesired_contacts"].secondary is None - assert sensors["undesired_contacts"].fields == ("force",) - assert sensors["undesired_contacts"].reduce == "netforce" - assert sensors["undesired_contacts"].history_length == 3 + assert sensors["self_collision"].secondary.mode == "subtree" + assert sensors["self_collision"].secondary.pattern == "pelvis" + assert sensors["self_collision"].reduce == "maxforce" + assert sensors["self_collision"].history_length == 4 feet_acc = env_cfg.rewards["feet_acc"] assert feet_acc.weight == -2.5e-6 assert feet_acc.params["asset_cfg"].name == "robot" diff --git a/tests/test_tracking_rewards.py 
b/tests/test_tracking_rewards.py index e6cc02fd..1550cf3d 100644 --- a/tests/test_tracking_rewards.py +++ b/tests/test_tracking_rewards.py @@ -4,17 +4,17 @@ import torch -from train_mimic.tasks.tracking.mdp.rewards import undesired_contacts +from train_mimic.tasks.tracking.mdp.rewards import self_collision_cost def _env_with_force_history(force_history: torch.Tensor) -> SimpleNamespace: sensor = SimpleNamespace( data=SimpleNamespace(force_history=force_history, found=None) ) - return SimpleNamespace(scene={"undesired_contacts": sensor}) + return SimpleNamespace(scene={"self_collision": sensor}) -def test_undesired_contacts_counts_bodies_not_history_frames() -> None: +def test_self_collision_cost_counts_contact_slots_not_history_frames() -> None: force_history = torch.zeros((2, 3, 4, 3), dtype=torch.float32) force_history[0, 0, 0, 0] = 2.0 force_history[0, 1, 0, 0] = 3.0 @@ -22,10 +22,10 @@ def test_undesired_contacts_counts_bodies_not_history_frames() -> None: force_history[1, 0, 1, 0] = 0.5 force_history[1, 0, 3, 0] = 1.0 - penalty = undesired_contacts( + penalty = self_collision_cost( _env_with_force_history(force_history), - sensor_name="undesired_contacts", - threshold=1.0, + sensor_name="self_collision", + force_threshold=1.0, ) torch.testing.assert_close(penalty, torch.tensor([3.0, 0.0])) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index a4906bce..b55c4d34 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -116,38 +116,37 @@ def _add_history_obs_groups( } -def _configure_undesired_contacts_reward(cfg: ManagerBasedRlEnvCfg) -> None: +def _configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: excluded_body_names = ( "left_ankle_roll_link", "right_ankle_roll_link", "left_wrist_yaw_link", "right_wrist_yaw_link", - "left_elbow_link", - "right_elbow_link", ) cfg.scene.sensors = ( *tuple(getattr(cfg.scene, "sensors", ()) or ()), ContactSensorCfg( - 
name="undesired_contacts", + name="self_collision", + # Exclude only primary bodies: wrist/ankle vs torso is still caught by torso. primary=ContactMatch( mode="body", pattern=r".*", entity="robot", exclude=excluded_body_names, ), - secondary=None, - fields=("force",), - reduce="netforce", + secondary=ContactMatch(mode="subtree", pattern="pelvis", entity="robot"), + fields=("found", "force"), + reduce="maxforce", num_slots=1, - history_length=3, + history_length=4, ), ) - cfg.rewards["undesired_contacts"] = RewardTermCfg( - func=mdp.undesired_contacts, + cfg.rewards["self_collisions"] = RewardTermCfg( + func=mdp.self_collision_cost, weight=-0.1, params={ - "sensor_name": "undesired_contacts", - "threshold": 1.0, + "sensor_name": "self_collision", + "force_threshold": 1.0, }, ) @@ -189,7 +188,7 @@ def make_general_tracking_env_cfg( cfg.events["randomize_rigid_body_mass"].params[ "asset_cfg" ].body_names = r".*wrist_yaw.*|torso_link" - _configure_undesired_contacts_reward(cfg) + _configure_self_collision_reward(cfg) _configure_feet_acc_reward(cfg) cfg.terminations["ee_body_pos"].params["body_names"] = ( "left_ankle_roll_link", diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index 3af37bbb..ec032e05 100644 --- a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -143,39 +143,42 @@ def motion_global_body_angular_velocity_error_exp( return torch.exp(-error.mean(-1) / std**2) -def undesired_contacts( +def self_collision_cost( env: ManagerBasedRlEnv, sensor_name: str | tuple[str, ...], - threshold: float = 10.0, + force_threshold: float = 10.0, ) -> torch.Tensor: - """Penalize bodies whose contact force exceeds the configured threshold.""" - hits = _undesired_contact_hits(env, sensor_name, threshold) - return hits.sum(dim=-1).float() + """Penalize self-collision slots above the configured force threshold.""" + hit = _self_collision_hits(env, sensor_name, force_threshold) + return 
hit.sum(dim=-1).float() -def _undesired_contact_hits( +def _self_collision_hits( env: ManagerBasedRlEnv, sensor_name: str | tuple[str, ...], - threshold: float, + force_threshold: float, ) -> torch.Tensor: sensor_names = (sensor_name,) if isinstance(sensor_name, str) else sensor_name - hit_values = [] + force_histories = [] + found_values = [] for name in sensor_names: data = env.scene[name].data if data.force_history is not None: - force_mag = torch.norm(data.force_history, dim=-1) - hit_values.append(force_mag.amax(dim=2) > threshold) - elif getattr(data, "force", None) is not None: - force_mag = torch.norm(data.force, dim=-1) - hit_values.append(force_mag > threshold) + force_histories.append(data.force_history) else: assert data.found is not None found = data.found if found.ndim == 1: found = found.unsqueeze(-1) - hit_values.append(found > 0) + found_values.append(found) - return torch.cat(hit_values, dim=1) + if force_histories: + force_history = torch.cat(force_histories, dim=1) + force_mag = torch.norm(force_history, dim=-1) + return (force_mag > force_threshold).any(dim=2) + + found = torch.cat(found_values, dim=1) + return (found > 0).any(dim=1, keepdim=True) class joint_torque_limits: From 7ed7303512f4c8e8d92b67356209412edc923c3c Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 15 May 2026 18:47:45 +0800 Subject: [PATCH 019/122] Fix G1 PD gain actuator group selection --- tests/test_domain_randomization.py | 12 ++++-------- tests/test_task_registry.py | 12 ++++++++++++ train_mimic/tasks/tracking/tracking_env_cfg.py | 9 +++++++-- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index 16a4d080..d06e188d 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -64,10 +64,8 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non upper_motor_pd = events["motor_params_implicit_upper_body_pd"] assert 
upper_motor_pd.func is dr.pd_gains assert upper_motor_pd.mode == "reset" - assert ( - upper_motor_pd.params["asset_cfg"].actuator_names - == r".*(shoulder|elbow|wrist).*" - ) + assert upper_motor_pd.params["asset_cfg"].actuator_names is None + assert upper_motor_pd.params["asset_cfg"].actuator_ids == [0, 3] assert upper_motor_pd.params["kp_range"] == (0.9, 1.1) assert upper_motor_pd.params["kd_range"] == (0.9, 1.1) assert upper_motor_pd.params["distribution"] == "log_uniform" @@ -76,10 +74,8 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non lower_motor_pd = events["motor_params_implicit_lower_body_pd"] assert lower_motor_pd.func is dr.pd_gains assert lower_motor_pd.mode == "reset" - assert ( - lower_motor_pd.params["asset_cfg"].actuator_names - == r".*(waist|hip|knee|ankle).*" - ) + assert lower_motor_pd.params["asset_cfg"].actuator_names is None + assert lower_motor_pd.params["asset_cfg"].actuator_ids == [1, 2, 4, 5] assert lower_motor_pd.params["kp_range"] == (0.5, 2.0) assert lower_motor_pd.params["kd_range"] == (0.5, 2.0) assert lower_motor_pd.params["distribution"] == "log_uniform" diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 973feffa..f2d54d92 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -59,6 +59,18 @@ def test_general_tracking_task_is_registered() -> None: assert feet_acc.params["asset_cfg"].name == "robot" assert feet_acc.params["asset_cfg"].joint_names == r".*ankle.*" assert "anti_shake_ang_vel" not in env_cfg.rewards + upper_pd_asset = env_cfg.events["motor_params_implicit_upper_body_pd"].params[ + "asset_cfg" + ] + lower_pd_asset = env_cfg.events["motor_params_implicit_lower_body_pd"].params[ + "asset_cfg" + ] + assert upper_pd_asset.name == "robot" + assert upper_pd_asset.actuator_names is None + assert upper_pd_asset.actuator_ids == [0, 3] + assert lower_pd_asset.name == "robot" + assert lower_pd_asset.actuator_names is None + assert 
lower_pd_asset.actuator_ids == [1, 2, 4, 5] rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME assert rl_cfg.actor.hidden_dims == (1024, 512, 256, 256, 128) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 9a56e107..be39a6d7 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -33,6 +33,11 @@ "yaw": (-0.78, 0.78), } +# G1 uses six mjlab actuator groups. dr.pd_gains indexes asset.actuators +# by group id, not the expanded XML per-joint actuator ids. +G1_UPPER_BODY_ACTUATOR_GROUP_IDS = (0, 3) +G1_LOWER_BODY_ACTUATOR_GROUP_IDS = (1, 2, 4, 5) + def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: """Create base tracking task configuration.""" @@ -201,7 +206,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: func=dr.pd_gains, params={ "asset_cfg": SceneEntityCfg( - "robot", actuator_names=r".*(shoulder|elbow|wrist).*" + "robot", actuator_ids=list(G1_UPPER_BODY_ACTUATOR_GROUP_IDS) ), "kp_range": (0.9, 1.1), "kd_range": (0.9, 1.1), @@ -214,7 +219,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: func=dr.pd_gains, params={ "asset_cfg": SceneEntityCfg( - "robot", actuator_names=r".*(waist|hip|knee|ankle).*" + "robot", actuator_ids=list(G1_LOWER_BODY_ACTUATOR_GROUP_IDS) ), "kp_range": (0.5, 2.0), "kd_range": (0.5, 2.0), From 4f6f97527edb8e17c5b9bf9f6b8e738541b360ba Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 15 May 2026 20:36:46 +0800 Subject: [PATCH 020/122] Fix Pico4 ground lift reset --- teleopit/inputs/pico4_provider.py | 48 +++++++++++++++++++++++++++++ tests/test_pico4_provider.py | 50 +++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/teleopit/inputs/pico4_provider.py b/teleopit/inputs/pico4_provider.py index df6e8517..a1020f29 100644 --- a/teleopit/inputs/pico4_provider.py +++ b/teleopit/inputs/pico4_provider.py @@ -85,6 +85,32 @@ class 
PicoControllerSnapshot: } +def _has_non_degenerate_positions(positions: NDArray[np.float64]) -> bool: + pos = np.asarray(positions, dtype=np.float64).reshape(-1, 3) + if pos.size == 0: + return False + finite_mask = np.all(np.isfinite(pos), axis=1) + valid_pos = pos[finite_mask] + if valid_pos.shape[0] < 2: + return False + nonzero_pos = valid_pos[np.linalg.norm(valid_pos, axis=1) > 1e-9] + if nonzero_pos.shape[0] < 2: + return False + extent = float(np.max(np.ptp(nonzero_pos, axis=0))) + return extent > 1e-6 + + +def _compute_ground_lift_offset(positions: NDArray[np.float64]) -> float: + pos = np.asarray(positions, dtype=np.float64).reshape(-1, 3) + if pos.size == 0: + return 0.0 + finite_mask = np.all(np.isfinite(pos), axis=1) + if not np.any(finite_mask): + return 0.0 + min_z = float(np.min(pos[finite_mask, 2])) + return max(-min_z, 0.0) + + def _bridge_accepts_video_enabled(bridge_cls: type[Any]) -> bool: try: signature = inspect.signature(bridge_cls) @@ -165,6 +191,7 @@ def __init__( self._last_frame_timestamp: float | None = None self._last_source_seq: int | None = None self._controller_snapshot: PicoControllerSnapshot | None = None + self._ground_lift_offset: float | None = None self._bridge = bridge_cls( host=bridge_host, port=int(bridge_port), @@ -340,6 +367,7 @@ def _accept_pico_frame(self, frame: Any) -> bool: and timestamp - self._last_frame_timestamp > self._timestamp_gap_reset_s ): self._frame_cache.clear() + self._ground_lift_offset = None logger.warning( "Pico4InputProvider timestamp-gap reset | gap=%.4fs", timestamp - self._last_frame_timestamp, @@ -347,6 +375,7 @@ def _accept_pico_frame(self, frame: Any) -> bool: if self._last_frame_timestamp is not None and timestamp <= self._last_frame_timestamp + 1e-9: timestamp = self._last_frame_timestamp + 1e-6 + human_frame = self._apply_ground_lift(human_frame) self._frame_cache.append(human_frame, timestamp, fps_timestamp=timestamp) self._last_raw_body_joints = body_joints.copy() 
self._last_frame_timestamp = timestamp @@ -428,6 +457,25 @@ def _convert_body_joints_to_frame(body_joints: NDArray[np.float64]) -> HumanFram result[name] = (np.asarray(pos, dtype=np.float64), np.asarray(quat, dtype=np.float64)) return result + def _apply_ground_lift(self, human_frame: HumanFrame) -> HumanFrame: + """Apply one fixed Z lift so the initial Pico skeleton sits on the floor.""" + if self._ground_lift_offset is None: + positions = np.asarray([value[0] for value in human_frame.values()], dtype=np.float64) + if _has_non_degenerate_positions(positions): + self._ground_lift_offset = _compute_ground_lift_offset(positions) + else: + return human_frame + + offset = float(self._ground_lift_offset) + if offset <= 0.0: + return human_frame + + z_offset = np.array([0.0, 0.0, offset], dtype=np.float64) + lifted: HumanFrame = {} + for name, (pos, quat) in human_frame.items(): + lifted[name] = (np.asarray(pos, dtype=np.float64) + z_offset, np.asarray(quat, dtype=np.float64)) + return lifted + @staticmethod def _normalize_pico_bridge_body_joints(body_joints: NDArray[np.float64]) -> NDArray[np.float64]: """Match Teleopit's calibrated Pico body-pose convention.""" diff --git a/tests/test_pico4_provider.py b/tests/test_pico4_provider.py index 7830eded..b64edd8f 100644 --- a/tests/test_pico4_provider.py +++ b/tests/test_pico4_provider.py @@ -55,6 +55,7 @@ def _make_provider() -> Pico4InputProvider: provider._last_raw_body_joints = None provider._last_frame_timestamp = None provider._last_source_seq = None + provider._ground_lift_offset = None provider._closed = False return provider @@ -163,6 +164,55 @@ def test_pico4_provider_normalizes_pico_bridge_body_pose_convention() -> None: np.testing.assert_allclose(frame["Pelvis"][0], [1.0, 3.0, 2.0], atol=1e-6) +def test_pico4_provider_applies_fixed_ground_lift_from_first_real_frame() -> None: + provider = _make_provider() + body_poses = np.zeros((len(BODY_JOINT_NAMES), 7), dtype=np.float64) + pelvis_idx = 
BODY_JOINT_NAMES.index("Pelvis") + left_ankle_idx = BODY_JOINT_NAMES.index("Left_Ankle") + right_ankle_idx = BODY_JOINT_NAMES.index("Right_Ankle") + body_poses[pelvis_idx, 0:3] = [0.0, 0.8, 0.0] + body_poses[left_ankle_idx, 0:3] = [0.1, -0.2, 0.0] + body_poses[right_ankle_idx, 0:3] = [-0.1, 0.1, 0.0] + body_poses[:, 6] = 1.0 + + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=1, timestamp=1.0)) is True + first_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(first_frame["Pelvis"][0][2], 0.8 + 0.2, atol=1e-6) + np.testing.assert_allclose(first_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) + assert provider._ground_lift_offset == pytest.approx(0.2) + + body_poses[:, 1] += 0.3 + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=2, timestamp=1.1)) is True + second_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(second_frame["Pelvis"][0][2], first_frame["Pelvis"][0][2] + 0.3, atol=1e-6) + np.testing.assert_allclose(second_frame["Left_Ankle"][0][2], 0.3, atol=1e-6) + + +def test_pico4_provider_recomputes_ground_lift_after_timestamp_gap_reset() -> None: + provider = _make_provider() + body_poses = np.zeros((len(BODY_JOINT_NAMES), 7), dtype=np.float64) + pelvis_idx = BODY_JOINT_NAMES.index("Pelvis") + left_ankle_idx = BODY_JOINT_NAMES.index("Left_Ankle") + right_ankle_idx = BODY_JOINT_NAMES.index("Right_Ankle") + body_poses[pelvis_idx, 0:3] = [0.0, 0.8, 0.0] + body_poses[left_ankle_idx, 0:3] = [0.1, -0.2, 0.0] + body_poses[right_ankle_idx, 0:3] = [-0.1, 0.1, 0.0] + body_poses[:, 6] = 1.0 + + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=1, timestamp=1.0)) is True + assert provider._ground_lift_offset == pytest.approx(0.2) + + body_poses[pelvis_idx, 1] = 0.7 + body_poses[left_ankle_idx, 1] = -0.5 + body_poses[right_ankle_idx, 1] = 0.2 + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=2, timestamp=1.3)) is True + latest_frame, _, _ = 
provider._frame_cache.latest_packet() + np.testing.assert_allclose(latest_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) + np.testing.assert_allclose(latest_frame["Pelvis"][0][2], 1.2, atol=1e-6) + assert provider._ground_lift_offset == pytest.approx(0.5) + assert len(provider._frame_cache) == 1 + + def test_pico4_provider_drops_duplicate_raw_body_pose() -> None: provider = _make_provider() body_poses = _body_poses(1.0) From 50b0c87ca508e49bb9e98aa74295691518912a33 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 18 May 2026 11:35:58 +0800 Subject: [PATCH 021/122] Tighten tracking position terminations --- train_mimic/tasks/tracking/config/env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index b55c4d34..87de1083 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -196,9 +196,9 @@ def make_general_tracking_env_cfg( "left_wrist_yaw_link", "right_wrist_yaw_link", ) - cfg.terminations["anchor_pos"].params["threshold"] = 0.4 + cfg.terminations["anchor_pos"].params["threshold"] = 0.25 cfg.terminations["anchor_ori"].params["threshold"] = 1.0 - cfg.terminations["ee_body_pos"].params["threshold"] = 0.4 + cfg.terminations["ee_body_pos"].params["threshold"] = 0.25 cfg.viewer.body_name = "torso_link" cfg.episode_length_s = 10.0 if cfg.sim.njmax < 500: From 8032747c582ed0f8049e51439117b2af8f6eef64 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 19 May 2026 06:40:38 +0000 Subject: [PATCH 022/122] Fix training ETA duration formatting --- tests/test_runner_iteration_numbering.py | 5 +++++ train_mimic/tasks/tracking/rl/runner.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/test_runner_iteration_numbering.py b/tests/test_runner_iteration_numbering.py index adad80e4..9b61ed5f 100644 --- a/tests/test_runner_iteration_numbering.py +++ b/tests/test_runner_iteration_numbering.py 
@@ -3,6 +3,7 @@ import pytest from train_mimic.tasks.tracking.rl.runner import ( + _format_duration, _one_based_iteration_range, _resolve_total_iterations, ) @@ -36,3 +37,7 @@ def test_resolve_total_iterations_adds_requested_iterations_on_resume() -> None: def test_resolve_total_iterations_rejects_negative_requested_iterations() -> None: with pytest.raises(ValueError, match='non-negative'): _resolve_total_iterations(10, -1) + + +def test_format_duration_keeps_hours_above_one_day() -> None: + assert _format_duration(33 * 3600 + 16 * 60 + 25) == "33:16:25" diff --git a/train_mimic/tasks/tracking/rl/runner.py b/train_mimic/tasks/tracking/rl/runner.py index 9bd143b8..058a1360 100644 --- a/train_mimic/tasks/tracking/rl/runner.py +++ b/train_mimic/tasks/tracking/rl/runner.py @@ -34,6 +34,14 @@ def _resolve_total_iterations(start_iteration: int, num_learning_iterations: int return start_iteration + num_learning_iterations +def _format_duration(seconds: float) -> str: + """Format elapsed/remaining seconds without wrapping after 24 hours.""" + total_seconds = max(0, int(seconds)) + hours, remainder = divmod(total_seconds, 3600) + minutes, secs = divmod(remainder, 60) + return f"{hours:02d}:{minutes:02d}:{secs:02d}" + + class _OnnxMotionModel(nn.Module): """ONNX-exportable model that wraps the policy and bundles motion reference data.""" @@ -264,9 +272,9 @@ def _log_one_based_iteration( """ f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s """ - f"""{'Time elapsed:':>{pad}} {time.strftime('%H:%M:%S', time.gmtime(logger.tot_time))} + f"""{'Time elapsed:':>{pad}} {_format_duration(logger.tot_time)} """ - f"""{'ETA:':>{pad}} {time.strftime('%H:%M:%S', time.gmtime(eta))} + f"""{'ETA:':>{pad}} {_format_duration(eta)} """ ) print(log_string) From 4309a79c8d85f28390ebd7a35b95d390a6f9cac6 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 19 May 2026 15:19:39 +0800 Subject: [PATCH 023/122] Add ONNX policy benchmark script --- scripts/dev/bench_policy_onnx.py | 495 
+++++++++++++++++++++++++++++++ 1 file changed, 495 insertions(+) create mode 100644 scripts/dev/bench_policy_onnx.py diff --git a/scripts/dev/bench_policy_onnx.py b/scripts/dev/bench_policy_onnx.py new file mode 100644 index 00000000..b40b7c2a --- /dev/null +++ b/scripts/dev/bench_policy_onnx.py @@ -0,0 +1,495 @@ +"""Benchmark Teleopit ONNX policy inference latency. + +This is a policy-only micro-benchmark intended for onboard model-size checks. +It does not require MuJoCo, robot hardware, GMR assets, or Pico input. + +Examples: + python scripts/dev/bench_policy_onnx.py --policy track.onnx + python scripts/dev/bench_policy_onnx.py --policy track.onnx --runs 20000 --device cpu + python scripts/dev/bench_policy_onnx.py --policy track.onnx --mode direct +""" + +from __future__ import annotations + +import argparse +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Sequence + +import numpy as np + + +@dataclass(frozen=True) +class InputSpec: + name: str + shape: tuple[Any, ...] 
+ dtype: str + + +@dataclass(frozen=True) +class PolicySignature: + obs_name: str + obs_dim: int + history_name: str | None + history_length: int + history_obs_dim: int + output_name: str + + @property + def is_dual_input(self) -> bool: + return self.history_name is not None + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark ONNX policy latency for Teleopit runtime sizing.", + ) + parser.add_argument("--policy", required=True, help="Path to exported policy.onnx") + parser.add_argument( + "--device", + choices=["cpu", "cuda", "auto"], + default="cpu", + help="ONNX Runtime execution provider preference (default: cpu)", + ) + parser.add_argument( + "--mode", + choices=["controller", "direct"], + default="controller", + help=( + "controller simulates RLPolicyController obs_history stacking; " + "direct reuses prebuilt feed tensors and measures session.run only" + ), + ) + parser.add_argument("--runs", type=int, default=5000, help="Measured iterations") + parser.add_argument("--warmup", type=int, default=200, help="Warmup iterations") + parser.add_argument("--policy-hz", type=float, default=50.0, help="Runtime policy frequency") + parser.add_argument( + "--input-mode", + choices=["random", "zeros"], + default="random", + help="Synthetic observation contents; generated outside the timed region", + ) + parser.add_argument( + "--obs-dim", + type=int, + default=None, + help="Required only when the ONNX obs feature dimension is dynamic", + ) + parser.add_argument( + "--history-length", + type=int, + default=None, + help="Required only when the ONNX obs_history length dimension is dynamic", + ) + parser.add_argument( + "--intra-op-threads", + type=int, + default=0, + help="ONNX Runtime intra-op threads; 0 keeps ORT default", + ) + parser.add_argument( + "--inter-op-threads", + type=int, + default=0, + help="ONNX Runtime inter-op threads; 0 keeps ORT default", + ) + parser.add_argument("--seed", type=int, default=0) + 
parser.add_argument( + "--max-p99-ms", + type=float, + default=None, + help="Exit non-zero if p99 latency exceeds this value", + ) + parser.add_argument( + "--json", + type=str, + default=None, + help="Optional path to write benchmark summary JSON", + ) + return parser.parse_args() + + +def _dim_to_int(dim: Any) -> int | None: + if isinstance(dim, int): + return dim + if isinstance(dim, np.integer): + return int(dim) + if isinstance(dim, float) and dim.is_integer(): + return int(dim) + if isinstance(dim, str): + try: + return int(dim) + except ValueError: + return None + return None + + +def _feature_dim(shape: Sequence[Any], fallback: int | None, label: str) -> int: + if not shape: + raise ValueError(f"{label} input has empty shape; cannot infer feature dimension") + dim = _dim_to_int(shape[-1]) + if dim is not None: + return dim + if fallback is not None: + return int(fallback) + raise ValueError( + f"{label} feature dimension is dynamic ({shape[-1]!r}). " + f"Pass --obs-dim to make the benchmark input explicit." + ) + + +def _history_len(shape: Sequence[Any], fallback: int | None) -> int: + if len(shape) < 3: + raise ValueError(f"obs_history must be rank 3 [batch, history, obs_dim], got shape={shape}") + dim = _dim_to_int(shape[1]) + if dim is not None: + return dim + if fallback is not None: + return int(fallback) + raise ValueError( + f"obs_history length dimension is dynamic ({shape[1]!r}). " + f"Pass --history-length to make the benchmark input explicit." + ) + + +def _select_providers(ort: Any, device: str) -> list[str]: + available = set(ort.get_available_providers()) + providers: list[str] = [] + if (device == "cuda" or device == "auto") and "CUDAExecutionProvider" in available: + providers.append("CUDAExecutionProvider") + if device == "cuda" and not providers: + raise RuntimeError( + "CUDAExecutionProvider was requested but is not available. 
" + f"Available providers: {sorted(available)}" + ) + providers.append("CPUExecutionProvider") + return providers + + +def _make_session(policy_path: Path, args: argparse.Namespace) -> tuple[Any, list[str], list[str]]: + try: + import onnxruntime as ort + except ModuleNotFoundError as exc: + raise RuntimeError("onnxruntime is required; install the Teleopit inference dependencies") from exc + + options = ort.SessionOptions() + options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + if args.intra_op_threads > 0: + options.intra_op_num_threads = int(args.intra_op_threads) + if args.inter_op_threads > 0: + options.inter_op_num_threads = int(args.inter_op_threads) + + providers = _select_providers(ort, str(args.device)) + session = ort.InferenceSession(str(policy_path), sess_options=options, providers=providers) + return session, providers, list(ort.get_available_providers()) + + +def _inspect_signature(session: Any, args: argparse.Namespace) -> tuple[PolicySignature, list[InputSpec]]: + inputs = session.get_inputs() + outputs = session.get_outputs() + if len(outputs) < 1: + raise ValueError("ONNX policy has no outputs") + if len(inputs) not in (1, 2): + names = [inp.name for inp in inputs] + raise ValueError( + "Unsupported ONNX policy input signature. Expected one obs input or " + f"dual inputs ('obs', 'obs_history'), got {names}." + ) + + specs = [ + InputSpec(name=inp.name, shape=tuple(inp.shape), dtype=str(getattr(inp, "type", ""))) + for inp in inputs + ] + obs_dim = _feature_dim(inputs[0].shape, args.obs_dim, inputs[0].name) + history_name: str | None = None + history_length = 0 + history_obs_dim = 0 + + if len(inputs) == 2: + if inputs[1].name != "obs_history": + names = [inp.name for inp in inputs] + raise ValueError( + "Unsupported dual-input policy. Expected second input named " + f"'obs_history', got {names}." 
+ ) + history_name = inputs[1].name + history_length = _history_len(inputs[1].shape, args.history_length) + history_obs_dim = _feature_dim(inputs[1].shape, args.obs_dim, "obs_history") + if history_obs_dim != obs_dim: + raise ValueError( + f"obs_dim mismatch: obs has {obs_dim}, obs_history has {history_obs_dim}" + ) + + signature = PolicySignature( + obs_name=inputs[0].name, + obs_dim=obs_dim, + history_name=history_name, + history_length=history_length, + history_obs_dim=history_obs_dim, + output_name=outputs[0].name, + ) + return signature, specs + + +def _make_obs(total: int, obs_dim: int, input_mode: str, seed: int) -> np.ndarray: + if input_mode == "zeros": + return np.zeros((total, obs_dim), dtype=np.float32) + rng = np.random.default_rng(seed) + return rng.standard_normal((total, obs_dim), dtype=np.float32) + + +def _stats_ms(samples_ms: np.ndarray) -> dict[str, float]: + return { + "mean": float(np.mean(samples_ms)), + "std": float(np.std(samples_ms)), + "min": float(np.min(samples_ms)), + "p50": float(np.percentile(samples_ms, 50)), + "p90": float(np.percentile(samples_ms, 90)), + "p95": float(np.percentile(samples_ms, 95)), + "p99": float(np.percentile(samples_ms, 99)), + "max": float(np.max(samples_ms)), + } + + +def _run_direct( + session: Any, + signature: PolicySignature, + obs_samples: np.ndarray, + runs: int, + warmup: int, +) -> tuple[np.ndarray, tuple[int, ...]]: + obs = obs_samples[0:1] + feed = {signature.obs_name: obs} + if signature.history_name is not None: + feed[signature.history_name] = np.repeat( + obs[:, np.newaxis, :], + signature.history_length, + axis=1, + ).astype(np.float32, copy=False) + + for _ in range(warmup): + session.run([signature.output_name], feed) + + timings = np.empty(runs, dtype=np.float64) + output_shape: tuple[int, ...] 
= () + for idx in range(runs): + t0 = time.perf_counter_ns() + output = session.run([signature.output_name], feed)[0] + t1 = time.perf_counter_ns() + timings[idx] = (t1 - t0) / 1_000_000.0 + if idx == 0: + output_shape = tuple(np.asarray(output).shape) + return timings, output_shape + + +def _run_controller_like( + session: Any, + signature: PolicySignature, + obs_samples: np.ndarray, + runs: int, + warmup: int, +) -> tuple[np.ndarray, tuple[int, ...]]: + from collections import deque + + history_buf: deque[np.ndarray] = deque(maxlen=max(signature.history_length, 1)) + total = warmup + runs + timings = np.empty(runs, dtype=np.float64) + output_shape: tuple[int, ...] = () + + for idx in range(total): + measured = idx >= warmup + t0 = time.perf_counter_ns() if measured else 0 + obs_flat = obs_samples[idx] + obs = obs_flat[np.newaxis, :] + if signature.history_name is not None: + if len(history_buf) == 0: + for _ in range(signature.history_length): + history_buf.append(obs_flat.copy()) + else: + history_buf.append(obs_flat.copy()) + obs_history = np.stack(list(history_buf), axis=0)[np.newaxis].astype(np.float32) + feed = { + signature.obs_name: obs, + signature.history_name: obs_history, + } + else: + feed = {signature.obs_name: obs} + + output = session.run([signature.output_name], feed)[0] + if measured: + t1 = time.perf_counter_ns() + out_idx = idx - warmup + timings[out_idx] = (t1 - t0) / 1_000_000.0 + if out_idx == 0: + output_shape = tuple(np.asarray(output).shape) + return timings, output_shape + + +def _print_summary( + policy_path: Path, + args: argparse.Namespace, + providers: list[str], + available_providers: list[str], + signature: PolicySignature, + input_specs: list[InputSpec], + stats: dict[str, float], + output_shape: tuple[int, ...], + over_budget: int, +) -> None: + budget_ms = 1000.0 / float(args.policy_hz) + measured_fps_mean = 1000.0 / stats["mean"] if stats["mean"] > 0 else float("inf") + p95_margin_ms = budget_ms - stats["p95"] + p99_margin_ms = 
budget_ms - stats["p99"] + + print("=" * 72) + print("ONNX Policy Benchmark") + print("=" * 72) + print(f"Policy: {policy_path}") + print(f"Mode: {args.mode}") + print(f"Runs / warmup: {args.runs} / {args.warmup}") + print(f"Policy rate budget: {args.policy_hz:.1f} Hz = {budget_ms:.2f} ms/step") + print(f"Providers selected: {providers}") + print(f"Providers available: {available_providers}") + print() + print("Input signature:") + for spec in input_specs: + print(f" {spec.name}: shape={spec.shape}, type={spec.dtype}") + print(f"Output: {signature.output_name}, measured shape={output_shape}") + print() + print("Latency (ms):") + print(f" mean: {stats['mean']:.4f} std: {stats['std']:.4f}") + print(f" min: {stats['min']:.4f} max: {stats['max']:.4f}") + print(f" p50: {stats['p50']:.4f}") + print(f" p90: {stats['p90']:.4f}") + print(f" p95: {stats['p95']:.4f} margin vs budget: {p95_margin_ms:.4f} ms") + print(f" p99: {stats['p99']:.4f} margin vs budget: {p99_margin_ms:.4f} ms") + print() + print(f"Mean throughput: {measured_fps_mean:.1f} policy calls/s") + print(f"Over budget: {over_budget} / {args.runs}") + print("=" * 72) + + if p99_margin_ms < 0: + print("WARNING: p99 latency exceeds the policy-rate budget.") + elif p99_margin_ms < 0.25 * budget_ms: + print("NOTE: p99 latency fits but leaves less than 25% budget margin.") + else: + print("OK: p99 latency leaves at least 25% policy-rate budget margin.") + + +def _write_json( + path: Path, + policy_path: Path, + args: argparse.Namespace, + providers: list[str], + available_providers: list[str], + signature: PolicySignature, + stats: dict[str, float], + output_shape: tuple[int, ...], + over_budget: int, +) -> None: + budget_ms = 1000.0 / float(args.policy_hz) + payload = { + "policy": str(policy_path), + "mode": args.mode, + "runs": int(args.runs), + "warmup": int(args.warmup), + "policy_hz": float(args.policy_hz), + "budget_ms": budget_ms, + "providers": providers, + "available_providers": available_providers, + 
"signature": { + "obs_name": signature.obs_name, + "obs_dim": signature.obs_dim, + "history_name": signature.history_name, + "history_length": signature.history_length, + "history_obs_dim": signature.history_obs_dim, + "output_name": signature.output_name, + "output_shape": list(output_shape), + }, + "latency_ms": stats, + "mean_policy_calls_per_s": 1000.0 / stats["mean"] if stats["mean"] > 0 else None, + "over_budget": int(over_budget), + "p95_margin_ms": budget_ms - stats["p95"], + "p99_margin_ms": budget_ms - stats["p99"], + } + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def main() -> int: + args = _parse_args() + if args.runs <= 0: + raise ValueError("--runs must be > 0") + if args.warmup < 0: + raise ValueError("--warmup must be >= 0") + if args.policy_hz <= 0: + raise ValueError("--policy-hz must be > 0") + + policy_path = Path(args.policy).expanduser().resolve() + if not policy_path.is_file(): + raise FileNotFoundError(f"ONNX policy not found: {policy_path}") + + session, providers, available_providers = _make_session(policy_path, args) + signature, input_specs = _inspect_signature(session, args) + total_samples = max(1, args.warmup + args.runs) + obs_samples = _make_obs(total_samples, signature.obs_dim, args.input_mode, args.seed) + + if args.mode == "direct": + timings_ms, output_shape = _run_direct( + session=session, + signature=signature, + obs_samples=obs_samples, + runs=args.runs, + warmup=args.warmup, + ) + else: + timings_ms, output_shape = _run_controller_like( + session=session, + signature=signature, + obs_samples=obs_samples, + runs=args.runs, + warmup=args.warmup, + ) + + stats = _stats_ms(timings_ms) + budget_ms = 1000.0 / float(args.policy_hz) + over_budget = int(np.count_nonzero(timings_ms > budget_ms)) + _print_summary( + policy_path=policy_path, + args=args, + providers=providers, + available_providers=available_providers, + 
signature=signature, + input_specs=input_specs, + stats=stats, + output_shape=output_shape, + over_budget=over_budget, + ) + + if args.json is not None: + json_path = Path(args.json).expanduser() + _write_json( + path=json_path, + policy_path=policy_path, + args=args, + providers=providers, + available_providers=available_providers, + signature=signature, + stats=stats, + output_shape=output_shape, + over_budget=over_budget, + ) + print(f"Wrote JSON summary: {json_path}") + + if args.max_p99_ms is not None and stats["p99"] > args.max_p99_ms: + print( + f"FAIL: p99={stats['p99']:.4f}ms exceeds --max-p99-ms={args.max_p99_ms:.4f}ms" + ) + return 2 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From def441814c403fb4d37b641bbf90ac7a0484c1ef Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 20 May 2026 10:02:19 +0800 Subject: [PATCH 024/122] Add payload randomization to tracking task --- teleopit/runtime/assets.py | 1 + tests/test_domain_randomization.py | 50 ++++++++++++++++++- tests/test_motion_sampling.py | 22 ++++++++ tests/test_task_registry.py | 9 ++++ train_mimic/tasks/tracking/config/env.py | 40 +++++++++++++-- train_mimic/tasks/tracking/mdp/commands.py | 24 ++++++++- .../tasks/tracking/tracking_env_cfg.py | 44 ++++++++++++++++ 7 files changed, 184 insertions(+), 6 deletions(-) diff --git a/teleopit/runtime/assets.py b/teleopit/runtime/assets.py index e9be0638..f86ca79a 100644 --- a/teleopit/runtime/assets.py +++ b/teleopit/runtime/assets.py @@ -6,6 +6,7 @@ PROJECT_ROOT = Path(__file__).resolve().parents[2] GMR_ASSETS_ROOT = PROJECT_ROOT / "teleopit" / "retargeting" / "gmr" / "assets" UNITREE_G1_MJLAB_XML = GMR_ASSETS_ROOT / "unitree_g1" / "g1_mjlab.xml" +UNITREE_G1_MJLAB_PAYLOAD_XML = GMR_ASSETS_ROOT / "unitree_g1" / "g1_mjlab_payload.xml" def missing_gmr_assets_message(path: str | Path, *, label: str = "Required asset") -> str: diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index 
d06e188d..d58d64e9 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -23,6 +23,10 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non "motor_params_implicit_armature", "physics_material", "randomize_rigid_body_mass", + "randomize_dexhand_payload_mass", + "randomize_gimbal_payload_mass", + "randomize_dexhand_payload_pos", + "randomize_gimbal_payload_pos", } push_robot = events["push_robot"] @@ -99,9 +103,49 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non mass = events["randomize_rigid_body_mass"] assert mass.func is dr.pseudo_inertia assert mass.mode == "startup" - assert mass.params["asset_cfg"].body_names == r".*wrist_yaw.*|torso_link" + assert mass.params["asset_cfg"].body_names == "torso_link" assert mass.params["alpha_range"] == (-0.1, 0.45) + dexhand_mass = events["randomize_dexhand_payload_mass"] + assert dexhand_mass.func is dr.pseudo_inertia + assert dexhand_mass.mode == "startup" + assert dexhand_mass.params["asset_cfg"].body_names == ( + "left_dexhand_payload", + "right_dexhand_payload", + ) + assert dexhand_mass.params["alpha_range"] == (-8.0, 0.34657359027997264) + + gimbal_mass = events["randomize_gimbal_payload_mass"] + assert gimbal_mass.func is dr.pseudo_inertia + assert gimbal_mass.mode == "startup" + assert gimbal_mass.params["asset_cfg"].body_names == ("head_gimbal_payload",) + assert gimbal_mass.params["alpha_range"] == (-8.0, 0.34657359027997264) + + dexhand_pos = events["randomize_dexhand_payload_pos"] + assert dexhand_pos.func is dr.body_pos + assert dexhand_pos.mode == "startup" + assert dexhand_pos.params["asset_cfg"].body_names == ( + "left_dexhand_payload", + "right_dexhand_payload", + ) + assert dexhand_pos.params["operation"] == "abs" + assert dexhand_pos.params["ranges"] == { + 0: (0.04, 0.12), + 1: (-0.03, 0.03), + 2: (-0.03, 0.03), + } + + gimbal_pos = events["randomize_gimbal_payload_pos"] + assert gimbal_pos.func is 
dr.body_pos + assert gimbal_pos.mode == "startup" + assert gimbal_pos.params["asset_cfg"].body_names == ("head_gimbal_payload",) + assert gimbal_pos.params["operation"] == "abs" + assert gimbal_pos.params["ranges"] == { + 0: (0.03, 0.12), + 1: (-0.03, 0.03), + 2: (0.40, 0.50), + } + def test_play_env_disables_training_only_domain_randomization() -> None: import mjlab.tasks # noqa: F401 @@ -119,4 +163,8 @@ def test_play_env_disables_training_only_domain_randomization() -> None: assert "motor_params_implicit_armature" not in play_cfg.events assert "physics_material" not in play_cfg.events assert "randomize_rigid_body_mass" not in play_cfg.events + assert "randomize_dexhand_payload_mass" not in play_cfg.events + assert "randomize_gimbal_payload_mass" not in play_cfg.events + assert "randomize_dexhand_payload_pos" not in play_cfg.events + assert "randomize_gimbal_payload_pos" not in play_cfg.events assert play_cfg.events == {} diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index d9b6b5de..ca68090b 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -18,6 +18,7 @@ def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: body_pos_w = np.zeros((num_frames, 3, 3), dtype=np.float32) body_pos_w[:, :, 0] = time[:, None] + body_pos_w[:, :, 1] = np.arange(3, dtype=np.float32)[None, :] body_quat_w = np.zeros((num_frames, 3, 4), dtype=np.float32) body_quat_w[:, :, 0] = 1.0 body_lin_vel_w = np.zeros((num_frames, 3, 3), dtype=np.float32) @@ -109,6 +110,27 @@ def test_motion_lib_get_window_frames_returns_requested_offsets(tmp_path: Path) assert torch.allclose(current["joint_pos"][0, :1], torch.tensor([2.0], dtype=torch.float32)) +def test_motion_lib_selects_bodies_by_dataset_names(tmp_path: Path) -> None: + motion_path = _write_shard_dir(tmp_path / "motion_named_bodies", [_clip_dict()]) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([99, 0], dtype=torch.long), + 
body_names=["right_ankle_roll_link", "pelvis"], + window_steps=(0,), + ) + frames = motion.get_frames( + torch.tensor([0], dtype=torch.long), + torch.tensor([2.0], dtype=torch.float32), + ) + + assert frames["body_pos_w"].shape == (1, 2, 3) + assert torch.allclose( + frames["body_pos_w"][0, :, 1], + torch.tensor([2.0, 0.0], dtype=torch.float32), + ) + + def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_path: Path) -> None: motion_path = _write_shard_dir(tmp_path / "motion_windowed", [_clip_dict()]) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index f2d54d92..7cb02df1 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -20,8 +20,17 @@ def test_general_tracking_task_is_registered() -> None: env_cfg = load_env_cfg(DEFAULT_TASK) actor_terms = env_cfg.observations["actor"].terms critic_terms = env_cfg.observations["critic"].terms + robot_model = env_cfg.scene.entities["robot"].spec_fn().compile() assert DEFAULT_TASK == GENERAL_TRACKING_TASK + for body_name in ( + "left_dexhand_payload", + "right_dexhand_payload", + "head_gimbal_payload", + ): + assert body_name in { + robot_model.body(i).name for i in range(robot_model.nbody) + } for terms in (actor_terms, critic_terms): assert "projected_gravity" in terms assert "ref_base_lin_vel_b" in terms diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 87de1083..338af2c3 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -4,6 +4,7 @@ from copy import deepcopy +import mujoco from mjlab.asset_zoo.robots import G1_ACTION_SCALE, get_g1_robot_cfg from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg @@ -13,6 +14,7 @@ from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.utils.noise import UniformNoiseCfg as Unoise +from teleopit.runtime.assets import UNITREE_G1_MJLAB_PAYLOAD_XML from 
train_mimic.tasks.tracking import mdp from train_mimic.tasks.tracking.config.constants import DEFAULT_TRAIN_MOTION_FILE from train_mimic.tasks.tracking.mdp import MotionCommandCfg @@ -45,6 +47,10 @@ "motor_params_implicit_armature", "physics_material", "randomize_rigid_body_mass", + "randomize_dexhand_payload_mass", + "randomize_gimbal_payload_mass", + "randomize_dexhand_payload_pos", + "randomize_gimbal_payload_pos", ) @@ -116,6 +122,16 @@ def _add_history_obs_groups( } +def _payload_g1_spec() -> mujoco.MjSpec: + return mujoco.MjSpec.from_file(str(UNITREE_G1_MJLAB_PAYLOAD_XML)) + + +def _payload_g1_robot_cfg(): + robot_cfg = get_g1_robot_cfg() + robot_cfg.spec_fn = _payload_g1_spec + return robot_cfg + + def _configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: excluded_body_names = ( "left_ankle_roll_link", @@ -161,13 +177,31 @@ def _configure_feet_acc_reward(cfg: ManagerBasedRlEnvCfg) -> None: ) +def _configure_payload_randomization(cfg: ManagerBasedRlEnvCfg) -> None: + cfg.events["randomize_rigid_body_mass"].params[ + "asset_cfg" + ].body_names = "torso_link" + cfg.events["randomize_dexhand_payload_mass"].params[ + "asset_cfg" + ].body_names = ("left_dexhand_payload", "right_dexhand_payload") + cfg.events["randomize_gimbal_payload_mass"].params[ + "asset_cfg" + ].body_names = ("head_gimbal_payload",) + cfg.events["randomize_dexhand_payload_pos"].params[ + "asset_cfg" + ].body_names = ("left_dexhand_payload", "right_dexhand_payload") + cfg.events["randomize_gimbal_payload_pos"].params[ + "asset_cfg" + ].body_names = ("head_gimbal_payload",) + + def make_general_tracking_env_cfg( *, play: bool = False, ) -> ManagerBasedRlEnvCfg: """Create the General-Tracking-G1 training env.""" cfg = make_tracking_env_cfg() - cfg.scene.entities = {"robot": get_g1_robot_cfg()} + cfg.scene.entities = {"robot": _payload_g1_robot_cfg()} joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) @@ -185,9 +219,7 @@ def 
make_general_tracking_env_cfg( "asset_cfg" ].geom_names = r".*_collision$" cfg.events["base_com"].params["asset_cfg"].body_names = ("torso_link",) - cfg.events["randomize_rigid_body_mass"].params[ - "asset_cfg" - ].body_names = r".*wrist_yaw.*|torso_link" + _configure_payload_randomization(cfg) _configure_self_collision_reward(cfg) _configure_feet_acc_reward(cfg) cfg.terminations["ee_body_pos"].params["body_names"] = ( diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 916c45b4..753b1163 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -170,6 +170,7 @@ def __init__( self, motion_file: str, body_indexes: torch.Tensor, + body_names: tuple[str, ...] | list[str] | None = None, device: str = "cpu", window_steps: tuple[int, ...] | list[int] | None = None, ) -> None: @@ -182,7 +183,27 @@ def __init__( f"motion_file must be a shard directory, got: {motion_file}" ) data = _load_shard_dir(motion_path) - body_idx_np = body_indexes.cpu().numpy() + if body_names is None: + body_idx_np = body_indexes.cpu().numpy() + else: + dataset_body_names = [str(name) for name in np.asarray(data["body_names"])] + dataset_body_index_by_name = { + name: index for index, name in enumerate(dataset_body_names) + } + missing_body_names = [ + name for name in body_names if name not in dataset_body_index_by_name + ] + if missing_body_names: + raise ValueError( + "Motion dataset body_names do not contain all requested tracking " + f"bodies. Missing: {missing_body_names}. " + "Rebuild the dataset with the current G1 body metadata or update " + "motion command body_names." 
+ ) + body_idx_np = np.asarray( + [dataset_body_index_by_name[name] for name in body_names], + dtype=np.int64, + ) self._joint_pos = np.asarray(data["joint_pos"], dtype=np.float32) # (T, 29) self._joint_vel = np.asarray(data["joint_vel"], dtype=np.float32) # (T, 29) @@ -462,6 +483,7 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.motion = MotionLib( self.cfg.motion_file, self.body_indexes, + body_names=self.cfg.body_names, device=self.device, window_steps=self.cfg.window_steps, ) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index be39a6d7..4ff88d25 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -254,6 +254,50 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "alpha_range": (-0.1, 0.45), }, ), + "randomize_dexhand_payload_mass": EventTermCfg( + mode="startup", + func=dr.pseudo_inertia, + params={ + "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. + # Nominal is 0.5 kg per hand. Scale covers 0-1.0 kg. + "alpha_range": (-8.0, 0.34657359027997264), + }, + ), + "randomize_gimbal_payload_mass": EventTermCfg( + mode="startup", + func=dr.pseudo_inertia, + params={ + "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. + # Nominal is 0.25 kg. Scale covers 0-0.5 kg. + "alpha_range": (-8.0, 0.34657359027997264), + }, + ), + "randomize_dexhand_payload_pos": EventTermCfg( + mode="startup", + func=dr.body_pos, + params={ + "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. + "operation": "abs", + "ranges": { + 0: (0.04, 0.12), + 1: (-0.03, 0.03), + 2: (-0.03, 0.03), + }, + }, + ), + "randomize_gimbal_payload_pos": EventTermCfg( + mode="startup", + func=dr.body_pos, + params={ + "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. 
+ "operation": "abs", + "ranges": { + 0: (0.03, 0.12), + 1: (-0.03, 0.03), + 2: (0.40, 0.50), + }, + }, + ), } ## From 5a4377d3c846f3c0456c4ee072bada6178527d80 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 20 May 2026 15:04:54 +0800 Subject: [PATCH 025/122] Add dexhand optional dependency --- docs/docs/configuration/config-reference.md | 4 ++-- docs/docs/getting-started/installation.md | 10 ++++++---- docs/docs/tutorials/pico-sim2real.md | 6 +++--- .../current/configuration/config-reference.md | 2 +- .../current/getting-started/installation.md | 9 +++++---- .../current/tutorials/pico-sim2real.md | 5 ++--- pyproject.toml | 3 +++ 7 files changed, 22 insertions(+), 17 deletions(-) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 5c7bfc95..5e2a16b9 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -119,8 +119,8 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin ### Dexterous Hand (Pico sim2real) `dexterous_hand.enabled=true` requires `input.provider=pico4` and the optional -LinkerHand SDK submodule. Control is active only in `MOCAP`; inactive modes and -timeouts send the open pose. +`dexhand` extra. Control is active only in `MOCAP`; inactive modes and timeouts +send the open pose. | Field | Description | Default | |-------|-------------|---------| diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 9c6cc634..9a9bfa6d 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -60,14 +60,16 @@ The receiver can run on a workstation PC or the robot onboard computer. See [Pico Sim2Sim](../tutorials/pico-sim2sim) and [Pico Sim2Real](../tutorials/pico-sim2real) for the full setup guides. 
-Optional LinkerHand L6 control for Pico sim2real uses a submodule SDK: +Optional LinkerHand L6 control for Pico sim2real is installed through the +`dexhand` extra. The SDK itself is provided by the repository submodule, so make +sure submodules are initialized first: ```bash -git submodule update --init third_party/linkerhand-python-sdk -pip install -e third_party/linkerhand-python-sdk +git submodule update --init --recursive +pip install -e '.[dexhand]' ``` -This SDK is only required when `dexterous_hand.enabled=true`. +This extra is only required when `dexterous_hand.enabled=true`. ## Verify Installation diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 77a80a02..a3ee541f 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -145,11 +145,11 @@ matching side grip as a deadman switch; the matching trigger closes that hand. Hand control is active only in `MOCAP`. It sends the open pose in `STANDING`, `DAMPING`, paused mocap, frame timeout, and shutdown. 
-Install the SDK submodule first: +Install the dexhand extra first if it was not installed with the main Pico +profile: ```bash -git submodule update --init third_party/linkerhand-python-sdk -pip install -e third_party/linkerhand-python-sdk +pip install -e '.[dexhand]' ``` Before enabling full sim2real, verify the hand connection with a standalone diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 6bb02e7a..e0bfa1c4 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -138,7 +138,7 @@ target = clip(action, clip_range) * action_scale + default_dof_pos ### 灵巧手(Pico sim2real) `dexterous_hand.enabled=true` 要求 `input.provider=pico4`,并安装可选的 -LinkerHand SDK submodule。控制只在 `MOCAP` 中生效;非活动模式和超时会发送张开姿态。 +`dexhand` extra。控制只在 `MOCAP` 中生效;非活动模式和超时会发送张开姿态。 | 字段 | 说明 | 默认值 | |---|---|---| diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 41f2c6fd..748b5bc2 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -60,14 +60,15 @@ receiver 可以运行在工作站 PC,也可以运行在机器人 onboard 计 完整设置流程详见 [Pico Sim2Sim](../tutorials/pico-sim2sim) 和 [Pico Sim2Real](../tutorials/pico-sim2real)。 -Pico sim2real 可选的 LinkerHand L6 控制使用一个 submodule SDK: +Pico sim2real 可选的 LinkerHand L6 控制通过 `dexhand` extra 安装。SDK +本身由仓库 submodule 提供,因此需要先初始化 submodule: ```bash -git submodule update --init third_party/linkerhand-python-sdk -pip install -e third_party/linkerhand-python-sdk +git submodule update --init --recursive +pip install -e 
'.[dexhand]' ``` -只有在 `dexterous_hand.enabled=true` 时才需要安装该 SDK。 +只有在 `dexterous_hand.enabled=true` 时才需要安装这个 extra。 ## 验证安装 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 0a91d69c..027e888e 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -136,11 +136,10 @@ Pico sim2real 可以用 Pico 手柄控制 LinkerHand L6。按住同侧 grip 作 同侧 trigger 控制对应手闭合。手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、 mocap 暂停、帧超时和退出时都会发送张开姿态。 -先安装 SDK submodule: +如果主 Pico profile 没有包含手控支持,先安装 dexhand extra: ```bash -git submodule update --init third_party/linkerhand-python-sdk -pip install -e third_party/linkerhand-python-sdk +pip install -e '.[dexhand]' ``` 启用完整 sim2real 前,先用独立开合测试验证灵巧手连接: diff --git a/pyproject.toml b/pyproject.toml index a17b4714..e7d43c97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ pico4 = [ "pico-bridge[camera] @ https://github.com/BotRunner64/pico-bridge/releases/download/v0.2.0/pico_bridge-0.2.0-py3-none-any.whl", "teleopit[sim2real]", ] +dexhand = [ + "linkerhand-python-sdk @ file:third_party/linkerhand-python-sdk", +] [tool.setuptools.packages.find] where = ["."] From 8714b58c394f196b99364d7ed1abab560b27369b Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 20 May 2026 19:31:25 +0800 Subject: [PATCH 026/122] Scale up tracking policy model --- AGENTS.md | 4 ++-- docs/docs/reference/architecture.md | 2 +- train_mimic/tasks/tracking/config/rl.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 24256b2e..7c289809 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -20,7 +20,7 @@ Module-internal isolation: all modules run in-process and communicate via `InPro - Training task: `General-Tracking-G1` - Inference observation: `velcmd_history` (166D, 
dual-input ONNX with `obs` + `obs_history`) -- TemporalCNN actor/critic with larger dims (1024,512,256,256,128) +- TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - Realtime inference uses a retargeted-reference timeline before observation build; `reference_steps=[0]` is the default production path ## Directory Structure @@ -179,7 +179,7 @@ Runtime constraints: ### Training Task The single supported training task is `General-Tracking-G1` (experiment name: `g1_general_tracking`). -- Uses TemporalCNN actor/critic with larger dims (1024,512,256,256,128) +- Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - 166D `velcmd_history` observation, dual-input ONNX export - Training env uses `sampling_mode="uniform"` - Playback/benchmark use `play=True`, which switches motion sampling to `start` diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index fb0915e0..ab1fdae9 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -57,7 +57,7 @@ train_mimic/scripts/data | Training task | `General-Tracking-G1` | | Inference observation | `velcmd_history` (166D) | | ONNX signature | Dual-input `obs` (166D) + `obs_history` | -| Actor/Critic | TemporalCNN (1024, 512, 256, 256, 128) | +| Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | | Training sampling | `uniform`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | | Data format | Shard directories only (`shard_*.npz`) | diff --git a/train_mimic/tasks/tracking/config/rl.py b/train_mimic/tasks/tracking/config/rl.py index 2c674a36..c35715cc 100644 --- a/train_mimic/tasks/tracking/config/rl.py +++ b/train_mimic/tasks/tracking/config/rl.py @@ -10,7 +10,7 @@ "train_mimic.tasks.tracking.rl.temporal_cnn_model:TemporalCNNModel" ) _CNN_CFG: dict = { - "output_channels": (128, 64, 32), + "output_channels": (256, 128, 64), "kernel_size": 3, "activation": "elu", "global_pool": "avg", @@ -24,7 +24,7 @@ def 
make_general_tracking_ppo_runner_cfg( return RslRlOnPolicyRunnerCfg( actor=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(1024, 512, 256, 256, 128), + hidden_dims=(2048, 1024, 512, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -36,7 +36,7 @@ def make_general_tracking_ppo_runner_cfg( ), critic=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(1024, 512, 256, 256, 128), + hidden_dims=(2048, 1024, 512, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -48,7 +48,7 @@ def make_general_tracking_ppo_runner_cfg( entropy_coef=0.005, num_learning_epochs=5, num_mini_batches=4, - learning_rate=1.0e-3, + learning_rate=5.0e-4, schedule="adaptive", gamma=0.99, lam=0.95, From 6a01fa9966327ca9a7a79afe94693dd38d4edcaf Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 20 May 2026 22:23:42 +0800 Subject: [PATCH 027/122] test: verify payload collision filtering --- tests/test_task_registry.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 7cb02df1..df22ce20 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -31,6 +31,14 @@ def test_general_tracking_task_is_registered() -> None: assert body_name in { robot_model.body(i).name for i in range(robot_model.nbody) } + for geom_name in ( + "left_dexhand_payload_collision", + "right_dexhand_payload_collision", + "head_gimbal_payload_collision", + ): + geom = robot_model.geom(geom_name) + assert int(geom.contype[0]) == 0 + assert int(geom.conaffinity[0]) == 0 for terms in (actor_terms, critic_terms): assert "projected_gravity" in terms assert "ref_base_lin_vel_b" in terms @@ -82,7 +90,7 @@ def test_general_tracking_task_is_registered() -> None: assert lower_pd_asset.actuator_ids == [1, 2, 4, 5] rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME - assert rl_cfg.actor.hidden_dims == 
(1024, 512, 256, 256, 128) + assert rl_cfg.actor.hidden_dims == (2048, 1024, 512, 256, 128) assert load_runner_cls(DEFAULT_TASK) is MotionTrackingOnPolicyRunner From 2dcfc0b66a352ef2ddd83f27c0abedba0e3e0a22 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 21 May 2026 19:07:44 +0800 Subject: [PATCH 028/122] Offset train worker seeds by rank --- tests/test_train_script.py | 10 ++++++++++ train_mimic/scripts/train.py | 10 +++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/test_train_script.py b/tests/test_train_script.py index f1dd315c..fba33b7b 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -104,6 +104,16 @@ def test_resolve_device_rejects_wrong_distributed_device(self, monkeypatch: pyte with pytest.raises(ValueError, match="LOCAL_RANK=1"): train._resolve_device(_args(device="cuda:0"), _TorchStub()) + def test_resolve_worker_seed_offsets_by_global_rank(self) -> None: + assert train._resolve_worker_seed(42, env={"WORLD_SIZE": "4", "RANK": "0"}) == 42 + assert train._resolve_worker_seed(42, env={"WORLD_SIZE": "4", "RANK": "3"}) == 300051 + + def test_resolve_worker_seed_defaults_to_base_seed_without_rank(self) -> None: + assert train._resolve_worker_seed(123, env={}) == 123 + + def test_resolve_worker_seed_ignores_rank_outside_distributed_mode(self) -> None: + assert train._resolve_worker_seed(42, env={"WORLD_SIZE": "1", "RANK": "3"}) == 42 + def test_main_uses_launcher_branch(self, monkeypatch: pytest.MonkeyPatch) -> None: called: dict[str, object] = {} diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 7126fa35..1e38c892 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -219,6 +219,14 @@ def _resolve_device(args: argparse.Namespace, torch_module: object) -> str: return "cuda:0" if torch_module.cuda.is_available() else "cpu" +def _resolve_worker_seed(base_seed: int, env: dict[str, str] | None = None) -> int: + runtime_env = os.environ if env is None 
else env + if not _is_distributed_env(runtime_env): + return base_seed + global_rank = int(runtime_env.get("RANK", "0")) + return base_seed + global_rank * 100003 + + def _launch_multi_gpu(args: argparse.Namespace, argv: Sequence[str]) -> None: _validate_multi_gpu_args(args) command = _build_torchrun_command(args, argv) @@ -291,7 +299,7 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: agent_cfg.logger = "tensorboard" # CLI overrides - env_cfg.seed = args.seed + env_cfg.seed = _resolve_worker_seed(args.seed) if args.num_envs is not None: env_cfg.scene.num_envs = args.num_envs if args.motion_file is not None: From ac75cb165c6f5d72b3d1f5c6724e60ce81f3fea3 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 22 May 2026 16:47:13 +0800 Subject: [PATCH 029/122] Update self-collision wrist exclusions --- tests/test_task_registry.py | 2 -- train_mimic/tasks/tracking/config/env.py | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index df22ce20..71ecd80e 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -62,8 +62,6 @@ def test_general_tracking_task_is_registered() -> None: assert sensors["self_collision"].primary.mode == "body" assert sensors["self_collision"].primary.pattern == r".*" assert sensors["self_collision"].primary.exclude == ( - "left_ankle_roll_link", - "right_ankle_roll_link", "left_wrist_yaw_link", "right_wrist_yaw_link", ) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 338af2c3..025e09f3 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -134,8 +134,6 @@ def _payload_g1_robot_cfg(): def _configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: excluded_body_names = ( - "left_ankle_roll_link", - "right_ankle_roll_link", "left_wrist_yaw_link", "right_wrist_yaw_link", ) @@ -143,7 +141,7 @@ def 
_configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: *tuple(getattr(cfg.scene, "sensors", ()) or ()), ContactSensorCfg( name="self_collision", - # Exclude only primary bodies: wrist/ankle vs torso is still caught by torso. + # Exclude only primary wrist bodies; wrist vs torso is still caught by torso. primary=ContactMatch( mode="body", pattern=r".*", From d9cdc243edc10a1357a64a3d7eb345d67f163777 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 22 May 2026 21:43:48 +0800 Subject: [PATCH 030/122] Adapt Pico input to pico-bridge 0.2.1 --- AGENTS.md | 2 +- CHANGELOG.md | 4 ++ README.md | 1 + docs/docs/configuration/config-reference.md | 2 +- docs/docs/getting-started/installation.md | 4 +- docs/docs/tutorials/pico-sim2real.md | 4 +- docs/docs/tutorials/pico-sim2sim.md | 8 ++-- .../current/configuration/config-reference.md | 2 +- .../current/getting-started/installation.md | 3 +- .../current/tutorials/pico-sim2real.md | 3 +- .../current/tutorials/pico-sim2sim.md | 7 ++- pyproject.toml | 2 +- teleopit/inputs/pico4_provider.py | 43 ++++++++++++------- teleopit/inputs/pico_video.py | 2 +- tests/test_pico4_provider.py | 11 ++--- 15 files changed, 51 insertions(+), 47 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 7c289809..c7dd15b3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -130,7 +130,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos ### Pico4 Realtime Input - `Pico4InputProvider` reads realtime body tracking from the in-process `pico_bridge.PicoBridge` - The pico-bridge receiver runs on the Teleopit host, which can be a workstation PC or robot onboard computer; do not maintain a separate onboard Pico input mode -- pico-bridge 0.2.0 is the supported runtime; camera preview uses `PicoBridge(video="frames").push_video_frame(rgb_uint8)` +- pico-bridge 0.2.1 is the supported runtime; camera preview uses `PicoBridge(video="frames").push_video_frame(rgb_uint8)` - Pico video preview is optional and disabled by default; sim2sim 
uses the MuJoCo `d435i_rgb` camera and sim2real uses RealSense when `input.video.enabled=true` - Bone naming follows `pico_bridge_to_g1.json` - The provider applies an input-space transform to match the current retarget config diff --git a/CHANGELOG.md b/CHANGELOG.md index 56573946..124d5579 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## [Unreleased] + +- 支持 pico-bridge 0.2.1,并适配其修正后的 tracking pose 语义。 + ## [0.3.0] - 2026-05-12 - 重构实时输入栈,Pico 4 统一使用 pico-bridge 0.2.0 in-process receiver,并移除旧 ZMQ/onboard Pico 路径。 diff --git a/README.md b/README.md index 03acbf5f..26f05c39 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ### Unreleased +- Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. - Added optional Pico controller control for LinkerHand L6 in sim2real, backed by the LinkerHand SDK submodule. ### v0.3.0 (2026-05-12) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 5e2a16b9..94ca097f 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -69,7 +69,7 @@ Complete reference for all configurable fields. 
| `input.bridge_advertise_ip` | Optional advertised host IP override | `null` | | `input.bridge_start_timeout` | Timeout while starting the bridge | `10.0` | | `input.bridge_history_size` | Pico frame history retained by the bridge | `120` | -| `input.video.enabled` | Stream host camera preview back to Pico through pico-bridge 0.2.0 | `false` | +| `input.video.enabled` | Stream host camera preview back to Pico through pico-bridge 0.2.1 | `false` | | `input.video.source` | Video source: `mujoco`, `realsense`, or `test-pattern` | `null` | | `input.video.width` / `height` / `fps` | Video capture/render settings | `1280` / `720` / `30` | | `input.video.device` | Optional RealSense serial | `null` | diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 072be0ee..1771c6c8 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -56,9 +56,7 @@ pip install -e '.[pico4]' ``` Teleopit uses the in-process `pico_bridge.PicoBridge` receiver for Pico tracking. -For Teleopit 0.3.0, use pico-bridge 0.2.0. Do not upgrade the receiver to -pico-bridge 0.2.1, because 0.2.1 changes interface semantics that Teleopit -0.3.0 does not target. +Teleopit targets pico-bridge 0.2.1 and its `pico_native` tracking semantics. The receiver can run on a workstation PC or the robot onboard computer. See [Pico Sim2Sim](../tutorials/pico-sim2sim) and [Pico Sim2Real](../tutorials/pico-sim2real) for the full setup guides. diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index a0c6957b..e33bff46 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -21,9 +21,7 @@ There are two deployment styles: Both styles use `Pico4InputProvider` and the in-process pico-bridge receiver. There is no separate onboard Pico input mode. -For Teleopit 0.3.0, keep the host receiver on pico-bridge 0.2.0. 
pico-bridge -0.2.1 changes interface semantics and is not the supported receiver version for -this Teleopit release. +Teleopit targets pico-bridge 0.2.1 and its `pico_native` tracking semantics. ## 1. Install Runtime Dependencies diff --git a/docs/docs/tutorials/pico-sim2sim.md b/docs/docs/tutorials/pico-sim2sim.md index 06bccf2c..b72bfc36 100644 --- a/docs/docs/tutorials/pico-sim2sim.md +++ b/docs/docs/tutorials/pico-sim2sim.md @@ -47,9 +47,7 @@ Teleopit starts `pico_bridge.PicoBridge` in-process through `Pico4InputProvider`. The same Pico input path is used later for wired and onboard sim2real deployment. -For Teleopit 0.3.0, keep the host receiver on pico-bridge 0.2.0. pico-bridge -0.2.1 changes interface semantics and is not the supported receiver version for -this Teleopit release. +Teleopit targets pico-bridge 0.2.1 and its `pico_native` tracking semantics. ## 3. Download Assets @@ -96,7 +94,7 @@ The default Pico pause button is `A`. Supported overrides include `B`, `X`, `Y`, ## Optional Headset Video Preview -pico-bridge 0.2.0 can show a host-side camera stream in the headset. In +pico-bridge 0.2.1 can show a host-side camera stream in the headset. 
In simulation, Teleopit can stream the MuJoCo `d435i_rgb` camera: ```bash @@ -141,7 +139,7 @@ input.video.enabled=true | Symptom | Likely Cause | Fix | |---------|--------------|-----| | `ImportError: pico_bridge` | Pico extra not installed | Run `pip install -e '.[pico4]'` | -| Startup says pico-bridge is too old | Installed receiver does not support video args | Reinstall the Pico extra so pico-bridge 0.2.0 is used | +| Startup says pico-bridge is too old | Installed receiver does not support the required API or tracking semantics | Reinstall the Pico extra so pico-bridge 0.2.1 is used | | `TimeoutError: No Pico4 body data` | Headset is not connected or body tracking is inactive | Check the headset app, network, and `input.pico4_timeout` | | Discovery cannot find the host | Wrong advertised IP or blocked UDP | Set `input.bridge_advertise_ip=` and confirm UDP port `63901` is reachable | | Sim robot does not follow | Loop is still in `STANDING` | Press `Y` after tracking is ready | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index e0bfa1c4..5a3a3b65 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -86,7 +86,7 @@ target = clip(action, clip_range) * action_scale + default_dof_pos | `bridge_advertise_ip` | str/null | `null` | 可选的 host 广播 IP 覆盖 | | `bridge_start_timeout` | float | `10.0` | 启动 bridge 的超时时间 | | `bridge_history_size` | int | `120` | bridge 保留的 Pico 帧历史长度 | -| `video.enabled` | bool | `false` | 通过 pico-bridge 0.2.0 将 host 相机预览发送回 Pico | +| `video.enabled` | bool | `false` | 通过 pico-bridge 0.2.1 将 host 相机预览发送回 Pico | | `video.source` | str/null | `null` | 视频源:`mujoco`、`realsense` 或 `test-pattern` | | `video.width` / `height` / `fps` | int | `1280` / `720` 
/ `30` | 视频采集/渲染设置 | | `video.device` | str/null | `null` | 可选的 RealSense 序列号 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 086a6d5d..5fc8cd81 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -56,8 +56,7 @@ pip install -e '.[pico4]' ``` Teleopit 使用进程内的 `pico_bridge.PicoBridge` receiver 接收 Pico 追踪数据。 -Teleopit 0.3.0 请使用 pico-bridge 0.2.0,不要升级到 pico-bridge 0.2.1; -0.2.1 修改了接口语义,Teleopit 0.3.0 未按该语义适配。 +Teleopit 面向 pico-bridge 0.2.1 及其 `pico_native` tracking 语义。 receiver 可以运行在工作站 PC,也可以运行在机器人 onboard 计算机。 完整设置流程详见 [Pico Sim2Sim](../tutorials/pico-sim2sim) 和 [Pico Sim2Real](../tutorials/pico-sim2real)。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 74083e19..7d17be7d 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -21,8 +21,7 @@ Pico 头显 -> Teleopit host -> retarget -> RL policy -> g1_bridge_sdk -> G1 两种方式都使用 `Pico4InputProvider` 和进程内 pico-bridge receiver。不存在单独的 onboard Pico 输入模式。 -Teleopit 0.3.0 请保持 host receiver 使用 pico-bridge 0.2.0。pico-bridge -0.2.1 修改了接口语义,不是该 Teleopit 版本支持的 receiver 版本。 +Teleopit 面向 pico-bridge 0.2.1 及其 `pico_native` tracking 语义。 ## 1. 
安装运行时依赖 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md index 867f78f7..3601a9f1 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md @@ -46,8 +46,7 @@ python -c "from pico_bridge import PicoBridge; print('OK')" Teleopit 会通过 `Pico4InputProvider` 在进程内启动 `pico_bridge.PicoBridge`。 后续 wired 和 onboard sim2real 部署也使用同一条 Pico 输入路径。 -Teleopit 0.3.0 请保持 host receiver 使用 pico-bridge 0.2.0。pico-bridge -0.2.1 修改了接口语义,不是该 Teleopit 版本支持的 receiver 版本。 +Teleopit 面向 pico-bridge 0.2.1 及其 `pico_native` tracking 语义。 ## 3. 下载资源 @@ -88,7 +87,7 @@ Pico 暂停/恢复会冻结 mocap session;它不是切回 `STANDING`。 ## 可选头显视频预览 -pico-bridge 0.2.0 可以在头显中显示 host 侧视频流。在仿真中,Teleopit 可以推送 +pico-bridge 0.2.1 可以在头显中显示 host 侧视频流。在仿真中,Teleopit 可以推送 MuJoCo `d435i_rgb` 相机: ```bash @@ -132,7 +131,7 @@ input.video.enabled=true | 现象 | 可能原因 | 解决方法 | |------|----------|----------| | `ImportError: pico_bridge` | 未安装 Pico extra | 执行 `pip install -e '.[pico4]'` | -| 启动提示 pico-bridge 太旧 | 已安装 receiver 不支持视频参数 | 重新安装 Pico extra,确保使用 pico-bridge 0.2.0 | +| 启动提示 pico-bridge 太旧 | 已安装 receiver 不支持所需 API 或 tracking 语义 | 重新安装 Pico extra,确保使用 pico-bridge 0.2.1 | | `TimeoutError: No Pico4 body data` | 头显未连接或 body tracking 未激活 | 检查头显 app、网络和 `input.pico4_timeout` | | discovery 找不到 host | 广播 IP 不对或 UDP 被阻断 | 设置 `input.bridge_advertise_ip=`,确认 UDP 端口 `63901` 可达 | | 仿真机器人不跟随 | 循环仍在 `STANDING` | 追踪准备好后按 `Y` | diff --git a/pyproject.toml b/pyproject.toml index e7d43c97..59de8612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ train = [ "tqdm>=4.65.0", ] pico4 = [ - "pico-bridge[camera] @ https://github.com/BotRunner64/pico-bridge/releases/download/v0.2.0/pico_bridge-0.2.0-py3-none-any.whl", + "pico-bridge[camera] @ 
https://github.com/BotRunner64/pico-bridge/releases/download/v0.2.1/pico_bridge-0.2.1-py3-none-any.whl", "teleopit[sim2real]", ] dexhand = [ diff --git a/teleopit/inputs/pico4_provider.py b/teleopit/inputs/pico4_provider.py index a1020f29..663d2ce5 100644 --- a/teleopit/inputs/pico4_provider.py +++ b/teleopit/inputs/pico4_provider.py @@ -1,7 +1,7 @@ """Pico4 VR full-body motion capture input provider. Uses the in-process ``pico_bridge`` receiver to collect PICO tracking frames. -The provider converts native PICO/Unity poses (meters, xyzw quaternions) into +The provider converts native PICO poses (meters, xyzw quaternions) into Teleopit's realtime ``HumanFrame`` format. """ @@ -10,6 +10,7 @@ from collections import deque from dataclasses import dataclass import inspect +from importlib.metadata import PackageNotFoundError, version import logging import threading import time @@ -32,7 +33,7 @@ logger = logging.getLogger(__name__) -# PICO/Unity -> Teleopit retarget input space. +# PICO native -> Teleopit retarget input space. _INPUT_TO_TELEOPIT_MATRIX = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float64) _INPUT_TO_TELEOPIT_QUAT = R.from_matrix(_INPUT_TO_TELEOPIT_MATRIX).as_quat(scalar_first=True) @@ -122,6 +123,20 @@ def _bridge_accepts_video_enabled(bridge_cls: type[Any]) -> bool: return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values()) +def _installed_pico_bridge_version() -> tuple[int, ...] 
| None: + try: + raw_version = version("pico-bridge") + except PackageNotFoundError: + return None + release = raw_version.split("+", 1)[0].split("-", 1)[0] + parts: list[int] = [] + for part in release.split("."): + if not part.isdigit(): + break + parts.append(int(part)) + return tuple(parts) if parts else None + + def _coordinate_transform_input(body_pose_dict: dict[str, list]) -> dict[str, list]: """Transform provider-space poses into Teleopit's expected coordinates.""" for body_name, value in body_pose_dict.items(): @@ -167,10 +182,16 @@ def __init__( "pico_bridge is required for Pico4 input. Install the receiver package, " "for example: pip install -e '.[pico4]'" ) from exc + installed_version = _installed_pico_bridge_version() + if installed_version is None or installed_version < (0, 2, 1): + raise RuntimeError( + "pico_bridge >= 0.2.1 is required for Pico4 input. Reinstall the Pico extra with " + "pip install -e '.[pico4]' so Teleopit receives pico_native tracking semantics." + ) bridge_cls = PicoBridge if not _bridge_accepts_video_enabled(bridge_cls): raise RuntimeError( - "pico_bridge >= 0.2.0 is required for Pico4 input. Reinstall the Pico extra with " + "pico_bridge >= 0.2.1 is required for Pico4 input. Reinstall the Pico extra with " "pip install -e '.[pico4]' so PicoBridge accepts video_enabled and push_video_frame()." 
) @@ -278,10 +299,10 @@ def get_controller_snapshot(self) -> PicoControllerSnapshot | None: return self._controller_snapshot def push_video_frame(self, frame: NDArray[np.uint8]) -> int: - """Push one RGB camera frame to pico-bridge 0.2.0 video output.""" + """Push one RGB camera frame to pico-bridge 0.2.1 video output.""" push_video_frame = getattr(self._bridge, "push_video_frame", None) if not callable(push_video_frame): - raise RuntimeError("Installed pico_bridge does not expose push_video_frame(); use pico-bridge 0.2.0") + raise RuntimeError("Installed pico_bridge does not expose push_video_frame(); use pico-bridge 0.2.1") return int(push_video_frame(frame)) def has_frame(self) -> bool: @@ -442,11 +463,10 @@ def _read_controller_state(controller: Any) -> PicoControllerState: @staticmethod def _convert_body_joints_to_frame(body_joints: NDArray[np.float64]) -> HumanFrame: - body_joints = Pico4InputProvider._normalize_pico_bridge_body_joints(body_joints) body_pose_dict: dict[str, list] = {} for i, joint_name in enumerate(BODY_JOINT_NAMES): pos = [body_joints[i][0], body_joints[i][1], body_joints[i][2]] - # pico_bridge returns [x, y, z, qx, qy, qz, qw]. + # pico_bridge 0.2.1 returns pico_native [x, y, z, qx, qy, qz, qw]. 
rot = [body_joints[i][6], body_joints[i][3], body_joints[i][4], body_joints[i][5]] body_pose_dict[joint_name] = [pos, rot] @@ -475,12 +495,3 @@ def _apply_ground_lift(self, human_frame: HumanFrame) -> HumanFrame: for name, (pos, quat) in human_frame.items(): lifted[name] = (np.asarray(pos, dtype=np.float64) + z_offset, np.asarray(quat, dtype=np.float64)) return lifted - - @staticmethod - def _normalize_pico_bridge_body_joints(body_joints: NDArray[np.float64]) -> NDArray[np.float64]: - """Match Teleopit's calibrated Pico body-pose convention.""" - converted = np.array(body_joints, dtype=np.float64, copy=True) - converted[:, 2] *= -1.0 - converted[:, 5] *= -1.0 - converted[:, 6] *= -1.0 - return converted diff --git a/teleopit/inputs/pico_video.py b/teleopit/inputs/pico_video.py index 500b3104..7f01f43f 100644 --- a/teleopit/inputs/pico_video.py +++ b/teleopit/inputs/pico_video.py @@ -1,4 +1,4 @@ -"""Optional camera-to-Pico video streaming for pico-bridge 0.2.0.""" +"""Optional camera-to-Pico video streaming for pico-bridge 0.2.1.""" from __future__ import annotations diff --git a/tests/test_pico4_provider.py b/tests/test_pico4_provider.py index b64edd8f..c6f6cabe 100644 --- a/tests/test_pico4_provider.py +++ b/tests/test_pico4_provider.py @@ -130,8 +130,8 @@ def test_pico4_provider_starts_pico_bridge_receiver_with_config() -> None: assert bridge.closed is True -def test_pico4_provider_requires_pico_bridge_0_2_signature() -> None: - with pytest.raises(RuntimeError, match=r"pico_bridge >= 0\.2\.0"): +def test_pico4_provider_requires_pico_bridge_0_2_1_signature() -> None: + with pytest.raises(RuntimeError, match=r"pico_bridge >= 0\.2\.1"): Pico4InputProvider(timeout=0.01, bridge_cls=_LegacyBridge) @@ -153,15 +153,12 @@ def push_video_frame(self: _FakeBridge, frame: np.ndarray) -> int: delattr(_FakeBridge, "push_video_frame") -def test_pico4_provider_normalizes_pico_bridge_body_pose_convention() -> None: +def 
test_pico4_provider_converts_pico_native_body_pose_convention() -> None: body_poses = _body_poses(0.0) body_poses[0] = [1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9] - converted = Pico4InputProvider._normalize_pico_bridge_body_joints(body_poses) - np.testing.assert_allclose(converted[0], [1.0, 2.0, -3.0, 0.1, 0.2, -0.3, -0.9]) - frame = Pico4InputProvider._convert_body_joints_to_frame(body_poses) - np.testing.assert_allclose(frame["Pelvis"][0], [1.0, 3.0, 2.0], atol=1e-6) + np.testing.assert_allclose(frame["Pelvis"][0], [1.0, -3.0, 2.0], atol=1e-6) def test_pico4_provider_applies_fixed_ground_lift_from_first_real_frame() -> None: From 47517b388d2e8b51d1f6d7746a32dae121ff603c Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 25 May 2026 15:38:33 +0800 Subject: [PATCH 031/122] Fix sim standing retarget roll --- teleopit/sim/loop.py | 4 +++- tests/test_sim_loop.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index 9958b110..b69ad238 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -8,6 +8,7 @@ from numpy.typing import NDArray from teleopit.constants import FULL_QPOS_DIM, ROOT_DIM +from teleopit.controllers.observation import align_motion_qpos_yaw from teleopit.controllers.qpos_interpolator import QposInterpolator from teleopit.runtime.reference_config import parse_reference_config from teleopit.inputs.realtime_packet import RealtimeInputPacket @@ -167,7 +168,8 @@ def _build_standing_qpos(self, state: RobotState) -> Float64Array: standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) if state.base_pos is not None: standing_qpos[0:3] = np.asarray(state.base_pos, dtype=np.float64)[:3] - standing_qpos[3:7] = np.asarray(state.quat, dtype=np.float64)[:4] + standing_qpos[3] = 1.0 + align_motion_qpos_yaw(np.asarray(state.quat, dtype=np.float32), standing_qpos) standing_qpos[7:7 + self._num_actions] = self._default_dof_pos.astype(np.float64)[: 
self._num_actions] return standing_qpos diff --git a/tests/test_sim_loop.py b/tests/test_sim_loop.py index 56408951..e79e0bba 100644 --- a/tests/test_sim_loop.py +++ b/tests/test_sim_loop.py @@ -130,6 +130,56 @@ def add_frame(self, data: dict[str, object]) -> None: self.frames.append(data) +def _quat_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: + aw, ax, ay, az = a + bw, bx, by, bz = b + return np.array( + [ + aw * bw - ax * bx - ay * by - az * bz, + aw * bx + ax * bw + ay * bz - az * by, + aw * by - ax * bz + ay * bw + az * bx, + aw * bz + ax * by - ay * bx + az * bw, + ], + dtype=np.float32, + ) + + +def test_standing_qpos_keeps_yaw_but_drops_roll() -> None: + from teleopit.sim.loop import SimulationLoop + + bus = InProcessBus() + robot = _DummyRobot() + loop = SimulationLoop( + robot=robot, + controller=_DummyController(), + obs_builder=_DummyObsBuilder(), + bus=bus, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, "transition_duration": 0.0}, + viewers=set(), + ) + + yaw = np.float32(np.pi / 2.0) + roll = np.float32(np.pi / 6.0) + yaw_quat = np.array([np.cos(yaw / 2.0), 0.0, 0.0, np.sin(yaw / 2.0)], dtype=np.float32) + roll_quat = np.array([np.cos(roll / 2.0), np.sin(roll / 2.0), 0.0, 0.0], dtype=np.float32) + tilted_quat = _quat_mul(yaw_quat, roll_quat) + state = RobotState( + qpos=np.array([0.2, -0.1], dtype=np.float32), + qvel=np.zeros(2, dtype=np.float32), + quat=tilted_quat, + ang_vel=np.zeros(3, dtype=np.float32), + timestamp=0.0, + base_pos=np.array([1.0, 2.0, 0.9], dtype=np.float32), + base_lin_vel=np.zeros(3, dtype=np.float32), + ) + + standing_qpos = loop._build_standing_qpos(state) + + np.testing.assert_allclose(standing_qpos[0:3], state.base_pos, atol=1e-6) + np.testing.assert_allclose(standing_qpos[3:7], yaw_quat, atol=1e-6) + np.testing.assert_allclose(standing_qpos[7:9], robot.default_dof_pos, atol=1e-6) + + @requires_mujoco def test_simulation_loop_runs_and_records_without_viewers() -> None: from teleopit.sim.loop import 
SimulationLoop From 58a3e58ee3ae88b5354d501c99707a912858b27e Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 25 May 2026 19:42:20 +0800 Subject: [PATCH 032/122] Sync TemporalCNN config and docs --- AGENTS.md | 4 ++-- docs/docs/reference/architecture.md | 2 +- tests/test_task_registry.py | 2 +- train_mimic/tasks/tracking/config/rl.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index c7dd15b3..e84903af 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -20,7 +20,7 @@ Module-internal isolation: all modules run in-process and communicate via `InPro - Training task: `General-Tracking-G1` - Inference observation: `velcmd_history` (166D, dual-input ONNX with `obs` + `obs_history`) -- TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) +- TemporalCNN actor/critic with scaled dims (1024,512,256,256,128) - Realtime inference uses a retargeted-reference timeline before observation build; `reference_steps=[0]` is the default production path ## Directory Structure @@ -179,7 +179,7 @@ Runtime constraints: ### Training Task The single supported training task is `General-Tracking-G1` (experiment name: `g1_general_tracking`). 
-- Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) +- Uses TemporalCNN actor/critic with scaled dims (1024,512,256,256,128) - 166D `velcmd_history` observation, dual-input ONNX export - Training env uses `sampling_mode="uniform"` - Playback/benchmark use `play=True`, which switches motion sampling to `start` diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index ab1fdae9..fb0915e0 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -57,7 +57,7 @@ train_mimic/scripts/data | Training task | `General-Tracking-G1` | | Inference observation | `velcmd_history` (166D) | | ONNX signature | Dual-input `obs` (166D) + `obs_history` | -| Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | +| Actor/Critic | TemporalCNN (1024, 512, 256, 256, 128) | | Training sampling | `uniform`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | | Data format | Shard directories only (`shard_*.npz`) | diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 71ecd80e..f3974a3b 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -88,7 +88,7 @@ def test_general_tracking_task_is_registered() -> None: assert lower_pd_asset.actuator_ids == [1, 2, 4, 5] rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME - assert rl_cfg.actor.hidden_dims == (2048, 1024, 512, 256, 128) + assert rl_cfg.actor.hidden_dims == (1024, 512, 256, 256, 128) assert load_runner_cls(DEFAULT_TASK) is MotionTrackingOnPolicyRunner diff --git a/train_mimic/tasks/tracking/config/rl.py b/train_mimic/tasks/tracking/config/rl.py index c35715cc..2c674a36 100644 --- a/train_mimic/tasks/tracking/config/rl.py +++ b/train_mimic/tasks/tracking/config/rl.py @@ -10,7 +10,7 @@ "train_mimic.tasks.tracking.rl.temporal_cnn_model:TemporalCNNModel" ) _CNN_CFG: dict = { - "output_channels": (256, 128, 64), + "output_channels": (128, 
64, 32), "kernel_size": 3, "activation": "elu", "global_pool": "avg", @@ -24,7 +24,7 @@ def make_general_tracking_ppo_runner_cfg( return RslRlOnPolicyRunnerCfg( actor=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(2048, 1024, 512, 256, 128), + hidden_dims=(1024, 512, 256, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -36,7 +36,7 @@ def make_general_tracking_ppo_runner_cfg( ), critic=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(2048, 1024, 512, 256, 128), + hidden_dims=(1024, 512, 256, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -48,7 +48,7 @@ def make_general_tracking_ppo_runner_cfg( entropy_coef=0.005, num_learning_epochs=5, num_mini_batches=4, - learning_rate=5.0e-4, + learning_rate=1.0e-3, schedule="adaptive", gamma=0.99, lam=0.95, From e9cce31f13ea52a2d431b14d2974e160f92d7e79 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 25 May 2026 21:21:48 +0800 Subject: [PATCH 033/122] Revert "Sync TemporalCNN config and docs" This reverts commit 58a3e58ee3ae88b5354d501c99707a912858b27e. 
--- AGENTS.md | 4 ++-- docs/docs/reference/architecture.md | 2 +- tests/test_task_registry.py | 2 +- train_mimic/tasks/tracking/config/rl.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index e84903af..c7dd15b3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -20,7 +20,7 @@ Module-internal isolation: all modules run in-process and communicate via `InPro - Training task: `General-Tracking-G1` - Inference observation: `velcmd_history` (166D, dual-input ONNX with `obs` + `obs_history`) -- TemporalCNN actor/critic with scaled dims (1024,512,256,256,128) +- TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - Realtime inference uses a retargeted-reference timeline before observation build; `reference_steps=[0]` is the default production path ## Directory Structure @@ -179,7 +179,7 @@ Runtime constraints: ### Training Task The single supported training task is `General-Tracking-G1` (experiment name: `g1_general_tracking`). -- Uses TemporalCNN actor/critic with scaled dims (1024,512,256,256,128) +- Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - 166D `velcmd_history` observation, dual-input ONNX export - Training env uses `sampling_mode="uniform"` - Playback/benchmark use `play=True`, which switches motion sampling to `start` diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index fb0915e0..ab1fdae9 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -57,7 +57,7 @@ train_mimic/scripts/data | Training task | `General-Tracking-G1` | | Inference observation | `velcmd_history` (166D) | | ONNX signature | Dual-input `obs` (166D) + `obs_history` | -| Actor/Critic | TemporalCNN (1024, 512, 256, 256, 128) | +| Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | | Training sampling | `uniform`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | | Data format | Shard directories only 
(`shard_*.npz`) | diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index f3974a3b..71ecd80e 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -88,7 +88,7 @@ def test_general_tracking_task_is_registered() -> None: assert lower_pd_asset.actuator_ids == [1, 2, 4, 5] rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME - assert rl_cfg.actor.hidden_dims == (1024, 512, 256, 256, 128) + assert rl_cfg.actor.hidden_dims == (2048, 1024, 512, 256, 128) assert load_runner_cls(DEFAULT_TASK) is MotionTrackingOnPolicyRunner diff --git a/train_mimic/tasks/tracking/config/rl.py b/train_mimic/tasks/tracking/config/rl.py index 2c674a36..c35715cc 100644 --- a/train_mimic/tasks/tracking/config/rl.py +++ b/train_mimic/tasks/tracking/config/rl.py @@ -10,7 +10,7 @@ "train_mimic.tasks.tracking.rl.temporal_cnn_model:TemporalCNNModel" ) _CNN_CFG: dict = { - "output_channels": (128, 64, 32), + "output_channels": (256, 128, 64), "kernel_size": 3, "activation": "elu", "global_pool": "avg", @@ -24,7 +24,7 @@ def make_general_tracking_ppo_runner_cfg( return RslRlOnPolicyRunnerCfg( actor=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(1024, 512, 256, 256, 128), + hidden_dims=(2048, 1024, 512, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -36,7 +36,7 @@ def make_general_tracking_ppo_runner_cfg( ), critic=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(1024, 512, 256, 256, 128), + hidden_dims=(2048, 1024, 512, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -48,7 +48,7 @@ def make_general_tracking_ppo_runner_cfg( entropy_coef=0.005, num_learning_epochs=5, num_mini_batches=4, - learning_rate=1.0e-3, + learning_rate=5.0e-4, schedule="adaptive", gamma=0.99, lam=0.95, From 1e71c58ee74a7049dc511baa0a4489549fc797fe Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 25 May 2026 22:59:17 +0800 Subject: [PATCH 
034/122] Remove mocap qpos interpolation --- AGENTS.md | 2 +- docs/docs/configuration/config-reference.md | 1 - docs/docs/tutorials/bvh-sim2real.md | 3 - docs/docs/tutorials/pico-sim2real.md | 3 - .../current/configuration/config-reference.md | 1 - .../current/tutorials/bvh-sim2real.md | 3 - .../current/tutorials/pico-sim2real.md | 3 - teleopit/configs/default.yaml | 1 - teleopit/configs/pico4_sim2real.yaml | 1 - teleopit/configs/sim2real.yaml | 1 - teleopit/controllers/qpos_interpolator.py | 111 ------------------ teleopit/runtime/factory.py | 1 - teleopit/sim/loop.py | 24 ++-- teleopit/sim/runtime_components.py | 27 +---- teleopit/sim/session.py | 29 +++-- teleopit/sim2real/controller.py | 76 ++++++------ tests/test_pipeline.py | 2 - tests/test_runtime_components.py | 2 - tests/test_sim2real_runtime.py | 21 ++-- tests/test_sim_loop.py | 60 ++++++++-- 20 files changed, 120 insertions(+), 252 deletions(-) delete mode 100644 teleopit/controllers/qpos_interpolator.py diff --git a/AGENTS.md b/AGENTS.md index c7dd15b3..178b7f1c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -122,7 +122,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos ### Offline Playback - Offline sim2sim and default sim2real both read `input.bvh_file` directly; no UDP relay path remains - Offline sim2sim playback can be keyboard-controlled: `Space/P` pause/resume, `R` replay from frame 0, `Q` stop -- Offline pause holds the commanded pose; resume resets policy/reference state and uses `transition_duration` for the playback blend +- Offline pause holds the commanded pose; resume resets policy/reference state and reanchors yaw/XY without qpos interpolation - sim2sim keyboard playback is optional via `playback.keyboard.enabled=true` - sim2real reuses the Unitree remote: `Start` → `STANDING`, `Y` → playback, `X` → back to `STANDING`, `L1+R1` → `DAMPING` - `playback.pause_on_end=true` keeps the final pose and waits for manual replay diff --git 
a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 94ca097f..04e8c4b5 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -15,7 +15,6 @@ Complete reference for all configurable fields. | `viewers` | Viewer set: `mocap`, `retarget`, `sim2sim`, `camera`, `all`, `none`. `all` opens `mocap`, `retarget`, and `sim2sim`; add `camera` explicitly. | `sim2sim` | | `realtime` | Rate-limit to wall clock | `false` | | `num_steps` | Number of steps; `0` = infinite | `0` | -| `transition_duration` | Smooth transition time (seconds) from current pose to retarget command | - | | `keyboard.enabled` | Enable realtime keyboard mode control for sim2sim | `false` | | `playback.pause_on_end` | Pause at last frame when offline motion ends | `false` | | `playback.keyboard.enabled` | Enable keyboard control for offline playback | `false` | diff --git a/docs/docs/tutorials/bvh-sim2real.md b/docs/docs/tutorials/bvh-sim2real.md index 884c4d23..75c9a83b 100644 --- a/docs/docs/tutorials/bvh-sim2real.md +++ b/docs/docs/tutorials/bvh-sim2real.md @@ -76,9 +76,6 @@ real_robot.network_interface=enp130s0 # Pause at the final BVH frame playback.pause_on_end=true -# Smooth transition from standing/current robot state into playback -transition_duration=2.0 - # Control loop rate policy_hz=50 ``` diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index e33bff46..5c46f67b 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -202,9 +202,6 @@ input.bridge_advertise_ip=192.168.1.20 # Consecutive valid mocap frames required before MOCAP mocap_switch.check_frames=10 -# Smooth transition into mocap reference -transition_duration=2.0 - # Change Pico pause button input.pause_button=right_axis_click diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md 
b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 5a3a3b65..fab9683c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -15,7 +15,6 @@ sidebar_position: 2 | `viewers` | str/list | `sim2sim` | 可视化窗口集合:`mocap`、`retarget`、`sim2sim`、`camera`、`all`、`none`。`all` 打开 `mocap`、`retarget` 和 `sim2sim`;如需相机画面需显式加入 `camera` | | `realtime` | bool | `false` | 是否启用实时模式(实机部署时需开启) | | `num_steps` | int | — | 仿真总步数;设为 `-1` 表示无限运行 | -| `transition_duration` | float | — | 从静止姿态过渡到策略控制的时长(秒) | | `keyboard.enabled` | bool | `false` | 是否启用 sim2sim 实时键盘模式控制 | | `playback.pause_on_end` | bool | `false` | 回放结束后是否暂停(而非退出) | | `playback.keyboard.enabled` | bool | `false` | 是否启用键盘控制回放进度 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md index ab57700f..59acadaa 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md @@ -73,9 +73,6 @@ real_robot.network_interface=enp130s0 # 在 BVH 最后一帧暂停 playback.pause_on_end=true -# 从 standing/当前机器人状态平滑进入回放 -transition_duration=2.0 - # 控制循环频率 policy_hz=50 ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 7d17be7d..51c0917f 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -191,9 +191,6 @@ input.bridge_advertise_ip=192.168.1.20 # 进入 MOCAP 前要求的连续有效动捕帧数 mocap_switch.check_frames=10 -# 平滑过渡到 mocap 参考 -transition_duration=2.0 - # 更换 
Pico 暂停键 input.pause_button=right_axis_click diff --git a/teleopit/configs/default.yaml b/teleopit/configs/default.yaml index 9e45aefb..17036bb1 100644 --- a/teleopit/configs/default.yaml +++ b/teleopit/configs/default.yaml @@ -6,7 +6,6 @@ defaults: policy_hz: 50.0 pd_hz: 200.0 -transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) viewers: "sim2sim" realtime: false debug_trace_path: null diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 8e270c4f..2f976b62 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -5,7 +5,6 @@ defaults: - _self_ policy_hz: 50.0 -transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) input: video: source: realsense diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 14a7500b..17befeda 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -5,7 +5,6 @@ defaults: - _self_ policy_hz: 50.0 -transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling diff --git a/teleopit/controllers/qpos_interpolator.py b/teleopit/controllers/qpos_interpolator.py deleted file mode 100644 index 15c8e79d..00000000 --- a/teleopit/controllers/qpos_interpolator.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Smooth transition interpolator for retargeted 36D qpos. - -Prevents violent robot motion when switching from default/idle pose to a new -motion command by gradually blending between the start pose and the live -retargeted target over a configurable duration. 
-""" - -from __future__ import annotations - -import numpy as np -from numpy.typing import NDArray - - -def _slerp(q0: NDArray, q1: NDArray, t: float) -> NDArray: - """Spherical linear interpolation between two wxyz quaternions.""" - q0 = q0 / max(np.linalg.norm(q0), 1e-8) - q1 = q1 / max(np.linalg.norm(q1), 1e-8) - dot = float(np.dot(q0, q1)) - # Ensure shortest path - if dot < 0.0: - q1 = -q1 - dot = -dot - # Fall back to lerp for nearly identical quaternions - if dot > 0.9995: - result = q0 + t * (q1 - q0) - return result / max(np.linalg.norm(result), 1e-8) - theta = np.arccos(np.clip(dot, -1.0, 1.0)) - sin_theta = np.sin(theta) - a = np.sin((1.0 - t) * theta) / sin_theta - b = np.sin(t * theta) / sin_theta - return a * q0 + b * q1 - - -class QposInterpolator: - """Smoothly interpolates retargeted qpos from a start pose to the live target. - - Operates on 36D qpos: pos(3) + quat_wxyz(4) + joints(29). - Position and joints use linear interpolation; quaternion uses SLERP. - - Parameters - ---------- - duration : float - Transition duration in seconds. 0.0 disables interpolation. - policy_hz : float - Policy frequency (steps per second) for step-based progress. 
- """ - - def __init__(self, duration: float, policy_hz: float) -> None: - self._policy_hz = policy_hz - self._duration = 0.0 - self._total_steps = 0 - self._step = 0 - self._start_qpos: NDArray | None = None - self._active = False - self._last_alpha = np.float64(1.0) - self.configure(duration) - - @property - def duration(self) -> float: - return self._duration - - @property - def is_active(self) -> bool: - return self._active - - @property - def last_alpha(self) -> float: - return float(self._last_alpha) - - def reset(self) -> None: - self._step = 0 - self._start_qpos = None - self._active = False - self._last_alpha = np.float64(1.0) - - def configure(self, duration: float) -> None: - self._duration = max(float(duration), 0.0) - self._total_steps = int(self._duration * self._policy_hz) - - def start(self, start_qpos: NDArray) -> None: - """Begin interpolation from *start_qpos* toward future targets.""" - if self._total_steps <= 0: - return - self._start_qpos = np.array(start_qpos, dtype=np.float64).ravel() - self._step = 0 - self._active = True - self._last_alpha = np.float64(0.0) - - def apply(self, target_qpos: NDArray) -> NDArray: - """Return interpolated qpos. 
Passthrough when inactive or finished.""" - if not self._active or self._start_qpos is None: - self._last_alpha = np.float64(1.0) - return target_qpos - - if self._step >= self._total_steps: - self._active = False - self._last_alpha = np.float64(1.0) - return target_qpos - - alpha = self._step / self._total_steps - self._step += 1 - self._last_alpha = np.float64(alpha) - - result = np.empty_like(target_qpos) - # Position: lerp - result[0:3] = (1.0 - alpha) * self._start_qpos[0:3] + alpha * target_qpos[0:3] - # Quaternion: SLERP - result[3:7] = _slerp(self._start_qpos[3:7], target_qpos[3:7], alpha) - # Joints: lerp - result[7:] = (1.0 - alpha) * self._start_qpos[7:] + alpha * target_qpos[7:] - return result diff --git a/teleopit/runtime/factory.py b/teleopit/runtime/factory.py index 3f960afd..3466b613 100644 --- a/teleopit/runtime/factory.py +++ b/teleopit/runtime/factory.py @@ -37,7 +37,6 @@ def build_simulation_cfg(cfg: Any) -> dict[str, object]: return { "policy_hz": float(cfg_get(cfg, "policy_hz", 50.0)), "pd_hz": float(cfg_get(cfg, "pd_hz", 1000.0)), - "transition_duration": float(cfg_get(cfg, "transition_duration", 0.0) or 0.0), "retarget_buffer_enabled": bool(cfg_get(cfg, "retarget_buffer_enabled", True)), "retarget_buffer_window_s": float(cfg_get(cfg, "retarget_buffer_window_s", 0.5)), "retarget_buffer_delay_s": cfg_get(cfg, "retarget_buffer_delay_s", None), diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index b69ad238..21779743 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -9,7 +9,6 @@ from teleopit.constants import FULL_QPOS_DIM, ROOT_DIM from teleopit.controllers.observation import align_motion_qpos_yaw -from teleopit.controllers.qpos_interpolator import QposInterpolator from teleopit.runtime.reference_config import parse_reference_config from teleopit.inputs.realtime_packet import RealtimeInputPacket from teleopit.interfaces import Controller, InputProvider, MessageBus, ObservationBuilder, Recorder, Retargeter, Robot, 
RobotState @@ -76,17 +75,13 @@ def __init__( self._last_action: Float32Array = np.zeros((self._num_actions,), dtype=np.float32) self._last_retarget_qpos: Float64Array | None = None + self._standing_qpos: Float64Array | None = None self._realtime: bool = bool(self._try_get_cfg("realtime") or False) raw_debug_trace_path = self._try_get_cfg("debug_trace_path") self._debug_trace_path: str | None = None if raw_debug_trace_path not in (None, "", "null"): self._debug_trace_path = str(raw_debug_trace_path) - # Motion command transition smoothing - transition_dur = float(self._try_get_cfg("transition_duration") or 0.0) - self._mocap_transition_duration = transition_dur - self._qpos_interpolator = QposInterpolator(transition_dur, self.policy_hz) - self._init_reference_config() self._init_components(viewers) @@ -126,7 +121,6 @@ def _init_components(self, viewers: set[str] | None) -> None: kds=self._kds, torque_limits=self._torque_limits, default_dof_pos=self._default_dof_pos, - qpos_interpolator=self._qpos_interpolator, reference_velocity_smoothing_alpha=self._ref_cfg.reference_velocity_smoothing_alpha, reference_anchor_velocity_smoothing_alpha=self._ref_cfg.reference_anchor_velocity_smoothing_alpha, ) @@ -173,6 +167,11 @@ def _build_standing_qpos(self, state: RobotState) -> Float64Array: standing_qpos[7:7 + self._num_actions] = self._default_dof_pos.astype(np.float64)[: self._num_actions] return standing_qpos + def _set_standing_reference(self, state: RobotState) -> Float64Array: + standing_qpos = self._build_standing_qpos(state) + self._standing_qpos = standing_qpos.copy() + return standing_qpos + @staticmethod def _drain_realtime_control_events(input_provider: InputProvider) -> tuple[object, ...]: pop_control_events = getattr(input_provider, "pop_control_events", None) @@ -273,14 +272,12 @@ def _pause_offline_playback( offline_playback: OfflinePlaybackController, mocap_session: MocapSessionManager, hold_qpos: Float64Array, - retargeter: Retargeter, ) -> None: 
offline_playback.pause() self._step_runner.reset() self._last_action = np.zeros((self._num_actions,), dtype=np.float32) self.controller.reset() self.obs_builder.reset() - retargeter.reset() mocap_session.pause(hold_qpos) def _resume_offline_playback( @@ -288,24 +285,19 @@ def _resume_offline_playback( *, offline_playback: OfflinePlaybackController, mocap_session: MocapSessionManager, - retargeter: Retargeter, state: RobotState, ) -> None: if mocap_session.hold_qpos is None: raise RuntimeError("Cannot resume offline playback without a paused hold qpos") - resume_qpos = self._build_robot_state_qpos(state) + resume_qpos = self._build_resume_alignment_qpos(mocap_session.hold_qpos, state) offline_playback.resume() mocap_session.reset() self._step_runner.reset() self._last_action = np.zeros((self._num_actions,), dtype=np.float32) self.controller.reset() self.obs_builder.reset() - retargeter.reset() + self._step_runner.reset_reference_alignment(resume_qpos) self._step_runner.last_retarget_qpos = resume_qpos.copy() - self._step_runner.arm_motion_transition( - resume_qpos, - duration_s=self._mocap_transition_duration, - ) def _build_observation( self, diff --git a/teleopit/sim/runtime_components.py b/teleopit/sim/runtime_components.py index 094ac873..3c44db23 100644 --- a/teleopit/sim/runtime_components.py +++ b/teleopit/sim/runtime_components.py @@ -11,10 +11,9 @@ _logger = logging.getLogger(__name__) -from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.constants import FULL_QPOS_DIM from teleopit.bus.topics import TOPIC_ACTION, TOPIC_MIMIC_OBS, TOPIC_ROBOT_STATE from teleopit.controllers.observation import VelCmdObservationBuilder -from teleopit.controllers.qpos_interpolator import QposInterpolator from teleopit.controllers import reference_processing as ref_proc from teleopit.interfaces import MessageBus, ObservationBuilder, Recorder, Robot, RobotState from teleopit.retargeting.core import extract_mimic_obs @@ -94,7 +93,6 @@ def __init__( 
kds: Float32Array, torque_limits: Float32Array, default_dof_pos: Float32Array, - qpos_interpolator: QposInterpolator, reference_velocity_smoothing_alpha: float = 1.0, reference_anchor_velocity_smoothing_alpha: float = 1.0, ) -> None: @@ -108,7 +106,6 @@ def __init__( self.kds = kds self.torque_limits = torque_limits self.default_dof_pos = default_dof_pos - self.qpos_interpolator = qpos_interpolator self._motion_joint_vel_smoother = ExponentialVecSmoother(reference_velocity_smoothing_alpha) self._motion_anchor_lin_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) self._motion_anchor_ang_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) @@ -133,7 +130,6 @@ def reset(self) -> None: self._motion_joint_vel_smoother.reset() self._motion_anchor_lin_vel_smoother.reset() self._motion_anchor_ang_vel_smoother.reset() - self.qpos_interpolator.reset() def soft_reset_reference_state(self, *, reset_alignment: bool = True) -> None: self.last_reference_qpos = None @@ -157,11 +153,6 @@ def reset_reference_alignment(self, target_qpos: Float64Array | None = None) -> else np.asarray(target_qpos[0:2], dtype=np.float32).reshape(2).copy() ) - def arm_motion_transition(self, start_qpos: Float64Array, *, duration_s: float) -> None: - self.qpos_interpolator.reset() - self.qpos_interpolator.configure(duration_s) - self.qpos_interpolator.start(start_qpos) - def prepare_static_motion_command(self, qpos: Float64Array) -> MotionPreparation: hold_qpos = self._retarget_to_qpos(qpos) mimic_obs = extract_mimic_obs( @@ -188,14 +179,7 @@ def prepare_motion_command(self, retargeted: object, state: object) -> MotionPre reference_qpos = self._retarget_to_qpos(retargeted) reference_qpos = self._align_velcmd_reference_yaw(reference_qpos, state) self._pending_reference_qpos = reference_qpos.copy() - - if self.last_retarget_qpos is None and self.qpos_interpolator.duration > 0: - start_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - 
start_qpos[0:3] = np.asarray(state.base_pos[:3], dtype=np.float64) - start_qpos[3:7] = np.asarray(state.quat[:4], dtype=np.float64) - start_qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(state.qpos[:NUM_JOINTS], dtype=np.float64) - self.qpos_interpolator.start(start_qpos) - qpos = self.qpos_interpolator.apply(reference_qpos) + qpos = reference_qpos.copy() mimic_obs = extract_mimic_obs(qpos=qpos, last_qpos=self.last_retarget_qpos, dt=1.0 / self.policy_hz) retarget_viewer_qpos = qpos.copy() @@ -209,13 +193,6 @@ def prepare_motion_command(self, retargeted: object, state: object) -> MotionPre if obs_builder_requires_reference_window(self.obs_builder): motion_anchor_lin_vel_w = None motion_anchor_ang_vel_w = None - elif self.qpos_interpolator.is_active: - true_lin_vel_w, true_ang_vel_w = self._compute_anchor_velocities(reference_qpos) - blend = np.float32(self.qpos_interpolator.last_alpha) - raw_motion_anchor_lin_vel_w = np.asarray(true_lin_vel_w * blend, dtype=np.float32) - raw_motion_anchor_ang_vel_w = np.asarray(true_ang_vel_w * blend, dtype=np.float32) - motion_anchor_lin_vel_w = self._motion_anchor_lin_vel_smoother.apply(raw_motion_anchor_lin_vel_w) - motion_anchor_ang_vel_w = self._motion_anchor_ang_vel_smoother.apply(raw_motion_anchor_ang_vel_w) else: raw_motion_anchor_lin_vel_w, raw_motion_anchor_ang_vel_w = self._compute_anchor_velocities(reference_qpos) motion_anchor_lin_vel_w = self._motion_anchor_lin_vel_smoother.apply(raw_motion_anchor_lin_vel_w) diff --git a/teleopit/sim/session.py b/teleopit/sim/session.py index f963f308..18287477 100644 --- a/teleopit/sim/session.py +++ b/teleopit/sim/session.py @@ -192,6 +192,8 @@ def __init__( self.simulation_mode: SimulationMode = ( SimulationMode.STANDING if self.realtime_keyboard_mode_enabled else SimulationMode.MOCAP ) + if self.simulation_mode == SimulationMode.STANDING: + loop._set_standing_reference(loop.robot.get_state()) # Debug writer self.debug_writer: RolloutTraceWriter | None = None @@ -228,18 +230,23 @@ def 
reset_runtime_tracking(self) -> None: self.cached_human_frame = None self.cached_retargeted = None - def full_policy_reset(self) -> None: + def reset_policy_reference_state(self, *, reset_mocap_session: bool = True) -> None: self._step_runner.reset() self._loop.controller.reset() self._loop.obs_builder.reset() - self._retargeter.reset() - self.mocap_session.reset() + if reset_mocap_session: + self.mocap_session.reset() self.last_commanded_motion_qpos = None self.reset_runtime_tracking() + def full_policy_reset(self) -> None: + self.reset_policy_reference_state() + self._retargeter.reset() + def enter_standing_mode(self) -> None: from teleopit.sim.loop import SimulationMode self.full_policy_reset() + self._loop._set_standing_reference(self._loop.robot.get_state()) self.simulation_mode = SimulationMode.STANDING def enter_mocap_mode(self) -> None: @@ -252,10 +259,6 @@ def enter_mocap_mode(self) -> None: start_qpos = loop._resolve_hold_qpos(None, None, None, state) self.full_policy_reset() self._step_runner.last_retarget_qpos = start_qpos.copy() - self._step_runner.arm_motion_transition( - start_qpos, - duration_s=loop._mocap_transition_duration, - ) self.last_commanded_motion_qpos = start_qpos.copy() self.simulation_mode = SimulationMode.MOCAP @@ -266,7 +269,7 @@ def toggle_realtime_mocap_pause(self) -> None: if hold_qpos is None: raise RuntimeError("Cannot resume mocap without a paused hold qpos") resume_qpos = loop._build_resume_alignment_qpos(hold_qpos, loop.robot.get_state()) - self.full_policy_reset() + self.reset_policy_reference_state() self._step_runner.reset_reference_alignment(resume_qpos) self.last_commanded_motion_qpos = resume_qpos.copy() return @@ -276,7 +279,7 @@ def toggle_realtime_mocap_pause(self) -> None: self.latest_live_retargeted, loop.robot.get_state(), ) - self.full_policy_reset() + self.reset_policy_reference_state() self.mocap_session.pause(hold_qpos) self.last_commanded_motion_qpos = hold_qpos.copy() @@ -335,7 +338,6 @@ def 
_handle_offline_keyboard(self) -> bool: loop._resume_offline_playback( offline_playback=self.offline_playback, mocap_session=self.mocap_session, - retargeter=self._retargeter, state=loop.robot.get_state(), ) self.last_commanded_motion_qpos = None @@ -350,7 +352,6 @@ def _handle_offline_keyboard(self) -> bool: offline_playback=self.offline_playback, mocap_session=self.mocap_session, hold_qpos=hold_qpos, - retargeter=self._retargeter, ) return False @@ -360,9 +361,11 @@ def _handle_offline_keyboard(self) -> bool: def _fetch_standing_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeReferenceDiagnostics | None]: """Fetch input when in STANDING mode (keyboard). Returns (new_bvh_frame, ref_window, diag).""" - state = self._loop.robot.get_state() self.cached_human_frame = None - self.cached_retargeted = self._loop._build_standing_qpos(state) + if self._loop._standing_qpos is None: + self.cached_retargeted = self._loop._set_standing_reference(self._loop.robot.get_state()) + else: + self.cached_retargeted = self._loop._standing_qpos.copy() return False, None, None def _fetch_offline_reference_input( diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index 5435e914..fcaa41f1 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -25,7 +25,6 @@ VelCmdObservationBuilder, align_motion_qpos_yaw, ) -from teleopit.controllers.qpos_interpolator import QposInterpolator from teleopit.controllers.rl_policy import RLPolicyController from teleopit.inputs.bvh_provider import BVHInputProvider from teleopit.inputs.pico4_provider import Pico4InputProvider @@ -79,11 +78,6 @@ def __init__(self, cfg: Any) -> None: self.policy_hz: float = float(cfg_get(cfg, "policy_hz", 50.0)) self._project_root = Path(__file__).resolve().parent.parent.parent - # Motion command transition smoothing - transition_dur = float(cfg_get(cfg, "transition_duration", 0.0) or 0.0) - self._mocap_transition_duration = transition_dur - 
self._qpos_interpolator = QposInterpolator(transition_dur, self.policy_hz) - self._init_components(cfg) self._init_reference_config(cfg) self._safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) @@ -258,11 +252,7 @@ def _standing_step(self) -> None: """Standing mode: feed fixed default-pose reference to RL policy.""" robot_state = self.robot.get_state() - # Build standing reference qpos aligned to robot's current yaw qpos = self._standing_qpos.copy() - align_motion_qpos_yaw( - np.asarray(robot_state.quat, dtype=np.float32), qpos - ) # Standing → zero joint velocity reference motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) @@ -391,9 +381,9 @@ def _execute_mocap_pipeline( robot_state: object, reference_window: ReferenceWindow | None, ) -> None: - """Shared mocap control pipeline: align → interpolate → infer → send.""" + """Shared mocap control pipeline: align → infer → send.""" reference_qpos = self._ref_proc.align_reference_yaw(reference_qpos, robot_state=robot_state) - qpos = self._qpos_interpolator.apply(reference_qpos) + qpos = reference_qpos.copy() # Compute joint velocities via finite difference if qpos.shape[0] < 7 + self.num_actions: @@ -408,17 +398,11 @@ def _execute_mocap_pipeline( raw_motion_joint_vel = (motion_joint_pos - prev_joint_pos) * np.float32(self.policy_hz) motion_joint_vel = self._ref_proc.apply_joint_vel_smoothing(raw_motion_joint_vel) - # Compute anchor velocities (with interpolator blending if active) + # Compute anchor velocities. 
anchor_lin_vel_w = np.zeros(3, dtype=np.float32) anchor_ang_vel_w = np.zeros(3, dtype=np.float32) if not obs_builder_requires_reference_window(self.obs_builder): - if self._qpos_interpolator.is_active: - true_lin_vel_w, true_ang_vel_w = self._ref_proc.compute_anchor_velocities(reference_qpos) - blend = np.float32(self._qpos_interpolator.last_alpha) - raw_anchor_lin_vel_w = np.asarray(true_lin_vel_w * blend, dtype=np.float32) - raw_anchor_ang_vel_w = np.asarray(true_ang_vel_w * blend, dtype=np.float32) - else: - raw_anchor_lin_vel_w, raw_anchor_ang_vel_w = self._ref_proc.compute_anchor_velocities(reference_qpos) + raw_anchor_lin_vel_w, raw_anchor_ang_vel_w = self._ref_proc.compute_anchor_velocities(reference_qpos) anchor_lin_vel_w, anchor_ang_vel_w = self._ref_proc.apply_anchor_vel_smoothing( raw_anchor_lin_vel_w, raw_anchor_ang_vel_w, ) @@ -531,6 +515,12 @@ def _enter_standing(self) -> None: self._ref_proc.last_reference_qpos = None self._mocap_session.reset() self._last_commanded_motion_qpos = None + self._standing_qpos[0:3] = 0.0 + if getattr(state, "base_pos", None) is not None: + self._standing_qpos[0:3] = np.asarray(state.base_pos, dtype=np.float64)[:3] + self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) + align_motion_qpos_yaw(np.asarray(state.quat, dtype=np.float32), self._standing_qpos) + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) # Always do a full policy reset (episode-reset semantics) to ensure # the TemporalCNN history is clean and action-state causality holds. @@ -606,11 +596,10 @@ def _can_switch_to_mocap(self) -> bool: def _transition_to_mocap(self) -> None: """Switch from STANDING -> MOCAP. - Episode-reset + reference-side interpolation. The policy state is - fully reset (clean history, zero last_action) so the TemporalCNN - starts fresh. 
A QposInterpolator smoothly blends the *reference* - from the current robot state toward incoming live mocap so the policy - never sees a large instantaneous tracking error. + Episode-reset + reference realignment. The policy state is fully + reset (clean history, zero last_action) so the TemporalCNN starts + fresh. Incoming mocap is aligned by fixed yaw/XY offsets and then + consumed directly; switching does not interpolate reference qpos. """ state = self.robot.get_state() init_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) @@ -622,11 +611,6 @@ def _transition_to_mocap(self) -> None: # Full episode reset: clean policy state, alignment, timeline. self._reset_policy_state() - - # Reference-side interpolation: smoothly blend reference from current - # robot state toward incoming live mocap. This is done AFTER the - # episode reset so the interpolator starts with a clean slate. - self._arm_qpos_transition(init_qpos, duration_s=self._mocap_transition_duration) if self._offline_playback is not None: self._offline_playback.replay() @@ -665,7 +649,6 @@ def _reset_policy_state(self) -> None: """Full episode-reset: clear all policy state so the TemporalCNN sees a clean start identical to training episode reset.""" self._last_action = np.zeros(self.num_actions, dtype=np.float32) - self._qpos_interpolator.reset() self._reset_mocap_reference_state() self._ref_proc.reset_alignment() self._mocap_session.reset() @@ -674,6 +657,16 @@ def _reset_policy_state(self) -> None: self.obs_builder.reset() self.retargeter.reset() + def _reset_policy_reference_state(self) -> None: + """Reset policy/reference state without resetting the retargeter.""" + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._reset_mocap_reference_state() + self._ref_proc.reset_alignment() + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self.policy.reset() + self.obs_builder.reset() + def _reset_mocap_reference_state(self) -> None: """Reset mocap-specific reference 
state without disrupting policy observation continuity. @@ -713,7 +706,6 @@ def _restart_offline_playback(self) -> None: self._last_commanded_motion_qpos = restart_qpos.copy() self._offline_playback.replay() self._reset_policy_state() - self._arm_qpos_transition(restart_qpos, duration_s=self._mocap_transition_duration) logger.info("Offline playback restarted from frame 0") def _hold_completed_offline_playback(self, hold_qpos: Float64Array) -> None: @@ -723,11 +715,6 @@ def _hold_completed_offline_playback(self, hold_qpos: Float64Array) -> None: self._mocap_session.pause(hold_qpos) logger.info("Offline playback reached the end; press B to replay") - def _arm_qpos_transition(self, start_qpos: Float64Array, *, duration_s: float) -> None: - self._qpos_interpolator.reset() - self._qpos_interpolator.configure(duration_s) - self._qpos_interpolator.start(start_qpos) - def _fetch_realtime_input_packet(self) -> RealtimeInputPacket[object]: get_realtime_input_packet = getattr(self.input_provider, "get_realtime_input_packet", None) if callable(get_realtime_input_packet): @@ -770,10 +757,12 @@ def _pause_active_mocap(self) -> None: self._ref_proc.last_reference_qpos = hold_qpos.copy() self._last_commanded_motion_qpos = hold_qpos.copy() - # Reset policy state (clears last_action, history, smoothers, etc.) + # Reset policy/reference state (clears last_action, history, smoothers, etc.) + # without resetting the retargeter IK warm-start. Pause is a mocap-session + # control event, not a new retargeting source. # Note: _reset_policy_state resets _mocap_session to ACTIVE, so we # must call pause() *after* it to set the correct PAUSED state. 
- self._reset_policy_state() + self._reset_policy_reference_state() self._mocap_session.pause(hold_qpos) if self._offline_playback is not None: self._offline_playback.pause() @@ -792,9 +781,11 @@ def _resume_paused_mocap(self) -> None: self._last_commanded_motion_qpos = resume_qpos.copy() - # Full policy reset -- clean history, zero last_action, smoothers, - # timeline, alignment. Also resets _mocap_session to ACTIVE. - self._reset_policy_state() + # Policy/reference reset -- clean history, zero last_action, smoothers, + # timeline, alignment. Keep the retargeter IK warm-start so the first + # resumed frame is solved from the current retarget state rather than + # from the model default qpos. Also resets _mocap_session to ACTIVE. + self._reset_policy_reference_state() self._last_retarget_qpos = None self._last_commanded_motion_qpos = resume_qpos.copy() @@ -802,7 +793,6 @@ def _resume_paused_mocap(self) -> None: self._ref_proc.reset_alignment(target_qpos=resume_qpos) if self._offline_playback is not None: self._last_retarget_qpos = resume_qpos.copy() - self._arm_qpos_transition(resume_qpos, duration_s=self._mocap_transition_duration) self._offline_playback.resume() logger.info("Mocap session -> ACTIVE (episode-reset + reference realignment)") diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c586484b..81419cc4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -81,7 +81,6 @@ def __init__(self, *args: object, **kwargs: object) -> None: }, "policy_hz": 50, "pd_hz": 1000, - "transition_duration": 1.5, "keyboard": {"enabled": True}, "retarget_buffer_enabled": True, "retarget_buffer_window_s": 0.75, @@ -105,7 +104,6 @@ def __init__(self, *args: object, **kwargs: object) -> None: loop_cfg = captured["loop_args"][4] assert list(controller_cfg.default_dof_pos) == robot_default_angles assert list(controller_cfg.action_scale) == robot_action_scale - assert loop_cfg["transition_duration"] == pytest.approx(1.5) assert 
loop_cfg["keyboard"]["enabled"] is True assert loop_cfg["retarget_buffer_enabled"] is True assert loop_cfg["retarget_buffer_window_s"] == pytest.approx(0.75) diff --git a/tests/test_runtime_components.py b/tests/test_runtime_components.py index 7e6e776f..8e10c3ef 100644 --- a/tests/test_runtime_components.py +++ b/tests/test_runtime_components.py @@ -6,7 +6,6 @@ import numpy as np from teleopit.controllers.observation import VelCmdObservationBuilder -from teleopit.controllers.qpos_interpolator import QposInterpolator from teleopit.sim.runtime_components import PolicyStepRunner @@ -39,7 +38,6 @@ def _make_runner( kds=np.ones(29, dtype=np.float32), torque_limits=np.ones(29, dtype=np.float32), default_dof_pos=np.zeros(29, dtype=np.float32), - qpos_interpolator=QposInterpolator(0.0, 50.0), reference_velocity_smoothing_alpha=reference_velocity_smoothing_alpha, ) diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index 8e2b92a9..79f54212 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -85,12 +85,13 @@ def get_realtime_input_packet(self) -> RealtimeInputPacket[dict[str, tuple[np.nd class DummyRetargeter: def __init__(self, qpos: np.ndarray) -> None: self._qpos = np.asarray(qpos, dtype=np.float64) + self.reset_calls = 0 def retarget(self, _frame: object) -> np.ndarray: return self._qpos.copy() def reset(self) -> None: - pass + self.reset_calls += 1 class DummyPolicy: @@ -159,10 +160,9 @@ def tick(self, *, active: bool) -> None: raise RuntimeError("hand send failed") -def _make_cfg(transition_duration: float = 1.0) -> dict[str, object]: +def _make_cfg() -> dict[str, object]: return { "policy_hz": 50.0, - "transition_duration": transition_duration, "real_robot": {}, "mocap_switch": {"check_frames": 1}, "robot": { @@ -376,7 +376,7 @@ def test_mocap_step_episode_reset_on_transition(monkeypatch) -> None: target_qpos[0] = 0.3 _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, 
qpos=target_qpos) - ctrl = Sim2RealController(_make_cfg(transition_duration=2.0)) + ctrl = Sim2RealController(_make_cfg()) ctrl._transition_to_mocap() monkeypatch.setattr( ctrl._ref_proc, @@ -407,7 +407,7 @@ def test_mocap_step_velcmd_applies_fixed_initial_yaw_alignment(monkeypatch) -> N target_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - ctrl = Sim2RealController(_make_cfg(transition_duration=0.0)) + ctrl = Sim2RealController(_make_cfg()) ctrl.robot._state.quat = np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32) monkeypatch.setattr( ctrl._ref_proc, @@ -436,7 +436,7 @@ def test_mocap_step_velcmd_keeps_fixed_yaw_after_start(monkeypatch) -> None: target_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - ctrl = Sim2RealController(_make_cfg(transition_duration=0.0)) + ctrl = Sim2RealController(_make_cfg()) monkeypatch.setattr( ctrl._ref_proc, "compute_anchor_velocities", @@ -467,7 +467,7 @@ def test_mocap_step_waits_for_realtime_warmup_before_running_policy(monkeypatch) target_qpos[3] = 1.0 _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - cfg = _make_cfg(transition_duration=0.0) + cfg = _make_cfg() cfg['realtime_buffer_warmup_steps'] = 2 ctrl = Sim2RealController(cfg) monkeypatch.setattr( @@ -517,7 +517,7 @@ def test_mocap_step_uses_current_reference_qpos(monkeypatch) -> None: target_qpos[3] = 1.0 _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - cfg = _make_cfg(transition_duration=0.0) + cfg = _make_cfg() cfg["retarget_buffer_enabled"] = False ctrl = Sim2RealController(cfg) monkeypatch.setattr( @@ -549,7 +549,7 @@ def test_mocap_pause_freezes_reference_and_zeroes_velocities(monkeypatch) -> Non target_qpos[3] = 1.0 
_install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - cfg = _make_cfg(transition_duration=0.0) + cfg = _make_cfg() cfg["retarget_buffer_enabled"] = False ctrl = Sim2RealController(cfg) monkeypatch.setattr( @@ -599,7 +599,7 @@ def test_mocap_resume_uses_episode_reset_semantics(monkeypatch) -> None: target_qpos[3] = 1.0 _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - cfg = _make_cfg(transition_duration=0.0) + cfg = _make_cfg() cfg["retarget_buffer_enabled"] = False ctrl = Sim2RealController(cfg) monkeypatch.setattr( @@ -643,6 +643,7 @@ def test_mocap_resume_uses_episode_reset_semantics(monkeypatch) -> None: assert ctrl._mocap_session.state == MocapSessionState.ACTIVE # Policy was reset (last_action zeroed, history cleared) assert np.allclose(ctrl._last_action, 0.0) + assert ctrl.retargeter.reset_calls == 0 # Retarget reference jumps to the live mocap pose (joint 0), while root XY # is reanchored to the paused reference because real-robot XY is unobserved. 
np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"][0], 0.2, atol=1e-6) diff --git a/tests/test_sim_loop.py b/tests/test_sim_loop.py index e79e0bba..f54e59f5 100644 --- a/tests/test_sim_loop.py +++ b/tests/test_sim_loop.py @@ -110,6 +110,9 @@ def is_available(self) -> bool: class _DummyRetargeter: + def __init__(self) -> None: + self.reset_calls = 0 + def retarget(self, human_data: dict[str, tuple[np.ndarray, np.ndarray]]) -> tuple[np.ndarray, np.ndarray, np.ndarray]: pelvis_x = float(human_data["Pelvis"][0][0]) return ( @@ -119,7 +122,7 @@ def retarget(self, human_data: dict[str, tuple[np.ndarray, np.ndarray]]) -> tupl ) def reset(self) -> None: - pass + self.reset_calls += 1 class _DummyRecorder: @@ -154,7 +157,7 @@ def test_standing_qpos_keeps_yaw_but_drops_roll() -> None: controller=_DummyController(), obs_builder=_DummyObsBuilder(), bus=bus, - cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, "transition_duration": 0.0}, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, viewers=set(), ) @@ -180,6 +183,46 @@ def test_standing_qpos_keeps_yaw_but_drops_roll() -> None: np.testing.assert_allclose(standing_qpos[7:9], robot.default_dof_pos, atol=1e-6) +def test_standing_reference_is_fixed_after_initialization() -> None: + from teleopit.sim.loop import SimulationLoop + + bus = InProcessBus() + robot = _DummyRobot() + loop = SimulationLoop( + robot=robot, + controller=_DummyController(), + obs_builder=_DummyObsBuilder(), + bus=bus, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, + viewers=set(), + ) + + first_state = RobotState( + qpos=np.zeros(2, dtype=np.float32), + qvel=np.zeros(2, dtype=np.float32), + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.zeros(3, dtype=np.float32), + timestamp=0.0, + base_pos=np.array([1.0, 2.0, 0.9], dtype=np.float32), + base_lin_vel=np.zeros(3, dtype=np.float32), + ) + first = loop._set_standing_reference(first_state) + + drifted_state = RobotState( + 
qpos=np.zeros(2, dtype=np.float32), + qvel=np.zeros(2, dtype=np.float32), + quat=np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32), + ang_vel=np.zeros(3, dtype=np.float32), + timestamp=1.0, + base_pos=np.array([5.0, 6.0, 0.9], dtype=np.float32), + base_lin_vel=np.zeros(3, dtype=np.float32), + ) + live = loop._build_standing_qpos(drifted_state) + + np.testing.assert_allclose(loop._standing_qpos, first, atol=1e-6) + assert not np.allclose(live[0:7], first[0:7]) + + @requires_mujoco def test_simulation_loop_runs_and_records_without_viewers() -> None: from teleopit.sim.loop import SimulationLoop @@ -191,7 +234,7 @@ def test_simulation_loop_runs_and_records_without_viewers() -> None: controller=_DummyController(), obs_builder=_DummyObsBuilder(), bus=bus, - cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, "transition_duration": 0.0}, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, viewers=set(), ) @@ -263,7 +306,7 @@ def get_frame_packet(self): controller=_DummyController(), obs_builder=obs_builder, bus=bus, - cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, "transition_duration": 0.0}, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, viewers=set(), ) @@ -293,7 +336,6 @@ def test_simulation_loop_rejects_nonzero_reference_steps_without_realtime_timeli "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": True, "reference_steps": [0, 1], }, @@ -357,7 +399,6 @@ def get_frame_packet(self): 'policy_hz': 50.0, 'pd_hz': 50.0, 'realtime': False, - 'transition_duration': 0.0, 'realtime_buffer_warmup_steps': 2, }, viewers=set(), @@ -390,7 +431,6 @@ def test_simulation_loop_allows_future_reference_steps() -> None: "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "reference_steps": [0, 1, 2, 3, 4], "retarget_buffer_delay_s": 0.08, "retarget_buffer_window_s": 0.5, @@ -510,7 +550,6 @@ def get_realtime_input_packet(self): "policy_hz": 50.0, 
"pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, }, @@ -528,8 +567,10 @@ def get_realtime_input_packet(self): np.testing.assert_allclose(obs_builder.mimic_obs_calls[1], np.array([0.2], dtype=np.float32), atol=1e-6) np.testing.assert_allclose(obs_builder.mimic_obs_calls[2], np.array([0.0], dtype=np.float32), atol=1e-6) np.testing.assert_allclose(obs_builder.mimic_obs_calls[3], np.array([0.0], dtype=np.float32), atol=1e-6) + assert loop._step_runner.last_retarget_qpos is not None +@requires_mujoco @requires_mujoco def test_simulation_loop_realtime_keyboard_mode_transitions(monkeypatch) -> None: from teleopit.sim.loop import SimulationLoop @@ -606,7 +647,6 @@ def close(self) -> None: "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, "keyboard": {"enabled": True}, @@ -707,7 +747,6 @@ def close(self) -> None: "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, "keyboard": {"enabled": True}, @@ -775,7 +814,6 @@ def close(self) -> None: "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, "keyboard": {"enabled": True}, From dec76fb0d62f8e517c6bf972c3de088be5fa16f8 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 26 May 2026 12:25:17 +0800 Subject: [PATCH 035/122] Add optional sim2real retarget viewer --- AGENTS.md | 1 + README.md | 3 + docs/docs/configuration/config-reference.md | 6 +- docs/docs/tutorials/pico-sim2real.md | 6 +- .../current/configuration/config-reference.md | 6 +- .../current/tutorials/pico-sim2real.md | 4 +- scripts/run/standalone_standing.py | 50 +----------- teleopit/configs/pico4_sim2real.yaml | 3 +- teleopit/configs/sim2real.yaml | 3 +- teleopit/sim2real/controller.py | 73 
+++++++++++++++++- teleopit/sim2real/safety.py | 1 - tests/test_sim2real_runtime.py | 77 +++++++++++++++++++ 12 files changed, 175 insertions(+), 58 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 178b7f1c..be65dca3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -108,6 +108,7 @@ python scripts/run/run_sim.py controller.policy_path=policy.onnx viewers=none - `viewers=all` opens `mocap`, `retarget`, and `sim2sim`; add `camera` explicitly when needed - All viewers run in separate subprocesses because GLFW/GLX only supports one window per process - Simulation exits when all active viewer windows are closed +- sim2real defaults to `viewers=none`; it supports only optional `viewers=retarget` - `viewers` is the only supported viewer key; legacy `viewer` alias is removed ### default_dof_pos Propagation diff --git a/README.md b/README.md index 26f05c39..00f80b9c 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,9 @@ python scripts/run/run_sim.py \ 'viewers=[sim2sim,camera]' ``` +For sim2real, viewers are disabled by default. Add `viewers=retarget` to show +the retargeted reference in an optional MuJoCo window. + ## Documentation Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Teleopit/)**, covering installation profiles, all tutorials, configuration reference, and architecture. diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 04e8c4b5..85da3790 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -90,11 +90,15 @@ Complete reference for all configurable fields. Fields used by sim2real configs (`sim2real.yaml`, `pico4_sim2real.yaml`). +Sim2real defaults to `viewers=none`. Set `viewers=retarget` to open an optional +MuJoCo window showing the retargeted reference; `sim2sim`, `mocap`, `camera`, +and `all` are simulation-only viewer modes. 
+ ### Safety | Field | Description | Default | |-------|-------------|---------| -| `startup_ramp_duration` | Seconds to smoothly blend from locked to policy positions | `2.0` | +| `startup_ramp_duration` | Kp ramp duration after entering `STANDING`; gradually increases PD gains without changing policy targets | `2.0` | | `joint_vel_limit` | Joint velocity limit (rad/s); triggers emergency damping if exceeded | `10.0` | | `mocap_switch.check_frames` | Consecutive valid frames required before switching to MOCAP | `10` | | `mocap_switch.max_position_value` | Position sanity threshold in meters | `5.0` | diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 5c46f67b..e847a488 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -122,10 +122,10 @@ Pico body frames -> retarget -> reference buffer -> observation -> policy -> G1 When entering `STANDING`, Teleopit releases active Unitree modes, enters debug/low-level control, locks the current joints briefly, resets policy state, -and ramps Kp to reduce startup spikes. +and ramps Kp without changing policy targets. -When entering `MOCAP`, Teleopit resets policy/reference state and blends the -reference from the current robot state into the live mocap command. +When entering `MOCAP`, Teleopit resets policy/reference state and starts tracking +the live mocap command through the realtime reference timeline. 
## Pause / Resume diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index fab9683c..60b8e116 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -109,11 +109,15 @@ target = clip(action, clip_range) * action_scale + default_dof_pos 以下字段用于 sim2real 配置(`sim2real.yaml`、`pico4_sim2real.yaml`)。 +sim2real 默认使用 `viewers=none`。设置 `viewers=retarget` 可打开一个可选的 +MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` +仅用于仿真 viewer。 + ### 安全相关 | 字段 | 说明 | 默认值 | |---|---|---| -| `startup_ramp_duration` | 从锁定位置平滑过渡到策略控制的时长(秒) | `2.0` | +| `startup_ramp_duration` | 进入 `STANDING` 后的 Kp ramp 时长;逐步提高 PD 增益,不改变 policy target | `2.0` | | `joint_vel_limit` | 关节速度限制(rad/s),超过时触发急停 | `10.0` | | `mocap_switch.check_frames` | 切换到 MOCAP 前所需的连续有效帧数 | `10` | | `mocap_switch.max_position_value` | 位置合理性阈值(米) | `5.0` | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 51c0917f..771bf234 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -115,9 +115,9 @@ Pico body frames -> retarget -> reference buffer -> observation -> policy -> G1 ``` 进入 `STANDING` 时,Teleopit 会释放当前 Unitree 模式,进入 debug/low-level 控制, -短暂锁住当前关节,重置 policy 状态,并通过 Kp ramp 减少启动冲击。 +短暂锁住当前关节,重置 policy 状态,并在不改变 policy target 的情况下执行 Kp ramp。 -进入 `MOCAP` 时,Teleopit 会重置 policy/reference 状态,并将参考从当前机器人状态平滑过渡到 +进入 `MOCAP` 时,Teleopit 会重置 policy/reference 状态,并通过实时参考时间线开始跟踪 实时 mocap 命令。 ## 暂停 / 恢复 diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py 
index 81bb4b60..6338bbbb 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -17,7 +17,7 @@ 1. Init DDS, subscribe to rt/lowstate 2. Load ONNX policy + MuJoCo model for observation building 3. Enter debug mode (release MotionSwitcher modes) - 4. Lock joints, then run RL policy standing loop with startup ramp + 4. Lock joints, then run RL policy standing loop 5. Hold standing until Ctrl-C or L1+R1 6. On exit: set damping, restore ai mode """ @@ -54,7 +54,6 @@ POS_STOP_F = 2146000000.0 VEL_STOP_F = 16000.0 KD_DAMPING = 8.0 -RAMP_DURATION = 2.0 JOINT_VEL_LIMIT = 10.0 # Default standing pose (from g1_constants.py HOME_KEYFRAME) @@ -357,11 +356,9 @@ class StandingController: """RL-policy-based standing controller matching Sim2RealController.STANDING.""" def __init__(self, network_interface: str, policy_path: str, - ramp_duration: float = RAMP_DURATION, no_policy: bool = False, publish_hz: int = 250) -> None: self._network_interface = network_interface - self._ramp_duration = ramp_duration self._shutdown = False # ---- Load policy and observation builder ---- @@ -391,12 +388,6 @@ def __init__(self, network_interface: str, policy_path: str, self._standing_qpos[3] = 1.0 # identity quaternion w=1 self._standing_qpos[7:36] = DEFAULT_ANGLES.astype(np.float64) - # Startup ramp state - self._ramp_duration_steps = max(1, int(ramp_duration * POLICY_HZ)) - self._ramp_step = 0 - self._ramp_start_positions: np.ndarray | None = None - self._ramp_active = False - # ---- Pipeline state ---- self._inference_thread: threading.Thread | None = None self._inference_running = False @@ -507,22 +498,6 @@ def _check_joint_vel_safety(self, qvel: np.ndarray) -> bool: return True return False - # ---- Startup ramp ---- - - def _apply_startup_ramp(self, target_dof_pos: np.ndarray) -> np.ndarray: - if not self._ramp_active or self._ramp_start_positions is None: - return target_dof_pos - - ramp_factor = min(1.0, self._ramp_step / self._ramp_duration_steps) - 
ramped = self._ramp_start_positions + ramp_factor * (target_dof_pos - self._ramp_start_positions) - - self._ramp_step += 1 - if self._ramp_step >= self._ramp_duration_steps: - self._ramp_active = False - logger.info("Startup ramp complete (%d steps)", self._ramp_duration_steps) - - return np.asarray(ramped, dtype=np.float32) - # ---- Standing step (matches Sim2RealController._standing_step) ---- def _standing_step(self) -> np.ndarray: @@ -575,9 +550,6 @@ def _standing_step(self) -> np.ndarray: np.array2string(qpos[:6], precision=4, separator=','), ) - # Startup ramp - target_dof_pos = self._apply_startup_ramp(target_dof_pos) - # Joint limits target_dof_pos = np.clip(target_dof_pos, JOINT_POS_LOWER, JOINT_POS_UPPER) @@ -633,8 +605,7 @@ def _inference_loop(self) -> None: # Policy step if self._no_policy: - target = self._apply_startup_ramp(DEFAULT_ANGLES.copy()) - target = np.clip(target, JOINT_POS_LOWER, JOINT_POS_UPPER) + target = np.clip(DEFAULT_ANGLES.copy(), JOINT_POS_LOWER, JOINT_POS_UPPER) else: target = self._standing_step() @@ -667,7 +638,6 @@ def _run_dry(self) -> None: logger.info("=== DRY-RUN MODE: no motor commands will be sent ===") self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) self._policy.reset() - self._ramp_active = False dt = 1.0 / POLICY_HZ loop_count = 0 @@ -774,17 +744,10 @@ def run(self) -> None: self._policy.reset() logger.info("ONNX warmup complete") - # 5. Initialize policy state and startup ramp - qpos, _, quat, _ = self._get_robot_state() + # 5. Initialize policy state self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) - self._ramp_start_positions = qpos.copy() - self._ramp_step = 0 - self._ramp_active = True - logger.info( - "Starting RL policy standing (pipelined) | ramp=%d steps (%.1fs)", - self._ramp_duration_steps, self._ramp_duration_steps / POLICY_HZ, - ) + logger.info("Starting RL policy standing (pipelined)") # 6. 
Start inference thread (~50Hz, soft deadline) self._start_inference() @@ -825,10 +788,6 @@ def main(): "--network-interface", type=str, default="eth0", help="Network interface for DDS (e.g. eth0, enp130s0)", ) - parser.add_argument( - "--ramp-duration", type=float, default=RAMP_DURATION, - help="Seconds for startup ramp (default: 2.0)", - ) parser.add_argument( "--no-policy", action="store_true", help="Skip RL policy, just send fixed DEFAULT_ANGLES (diagnostic mode)", @@ -850,7 +809,6 @@ def main(): controller = StandingController( network_interface=args.network_interface, policy_path=args.policy, - ramp_duration=args.ramp_duration, no_policy=args.no_policy, publish_hz=args.publish_hz, ) diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 2f976b62..a0597668 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -5,6 +5,7 @@ defaults: - _self_ policy_hz: 50.0 +viewers: "none" # Optional: set viewers=retarget to show the retargeted reference input: video: source: realsense @@ -17,7 +18,7 @@ reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] reference_debug_log: false -# Startup ramp duration (seconds) -- smoothly blend from locked to policy positions +# Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 17befeda..92f72d7c 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -5,6 +5,7 @@ defaults: - _self_ policy_hz: 50.0 +viewers: "none" # Optional: set viewers=retarget to show the retargeted reference retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling @@ -16,7 +17,7 @@ reference_debug_log: false 
playback: pause_on_end: true -# Startup ramp duration (seconds) -- smoothly blend from locked to policy positions +# Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index fcaa41f1..3798526a 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -31,7 +31,7 @@ from teleopit.inputs.pico_video import PicoVideoRuntime, parse_pico_video_config from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType, RealtimeInputPacket from teleopit.retargeting.core import RetargetingModule -from teleopit.runtime.common import cfg_get +from teleopit.runtime.common import cfg_get, parse_viewers from teleopit.runtime.reference_config import parse_reference_config from teleopit.runtime.factory import build_sim2real_mocap_components from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState @@ -44,6 +44,7 @@ obs_builder_requires_reference_window, ) from teleopit.sim.realtime_utils import RealtimeReferenceManager +from teleopit.sim.viewer_subprocess import start_robot_viewer from teleopit.sim2real.dexterous_hand import build_linkerhand_runtime from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor from teleopit.sim2real.remote import UnitreeRemote @@ -63,6 +64,56 @@ class RobotMode(Enum): DAMPING = "damping" # Emergency stop / recovery +def _parse_sim2real_viewers(cfg: Any) -> set[str]: + viewers = parse_viewers(cfg) + unsupported = viewers.difference({"retarget"}) + if unsupported: + raise ValueError( + f"Sim2real supports only the optional 'retarget' viewer; got unsupported viewers {sorted(unsupported)}. " + "Use viewers=retarget or viewers=none." 
+ ) + return viewers + + +class _Sim2RealRetargetViewer: + def __init__(self, *, xml_path: str | None, enabled: bool) -> None: + self._entry: tuple[Any, Any, Any, Any] | None = None + if not enabled: + return + if not xml_path: + raise ValueError("Sim2real retarget viewer requires robot.xml_path to be set.") + self._entry = start_robot_viewer( + xml_path, + FULL_QPOS_DIM, + True, + "Retarget", + 900, + 50, + ) + + def write(self, qpos: Float64Array) -> None: + if self._entry is None: + return + _, arr, alive, _ = self._entry + if not alive.value: + return + qpos = np.asarray(qpos, dtype=np.float64).reshape(-1) + if qpos.shape[0] < FULL_QPOS_DIM: + return + with arr.get_lock(): + arr[:FULL_QPOS_DIM] = qpos[:FULL_QPOS_DIM].tolist() + + def shutdown(self) -> None: + if self._entry is None: + return + proc, _, _, shutdown = self._entry + shutdown.set() + proc.join(timeout=3) + if proc.is_alive(): + proc.terminate() + self._entry = None + + class Sim2RealController: """G1 real-robot controller -- standing/mocap dual mode with state machine. 
@@ -145,6 +196,11 @@ def _init_components(self, cfg: Any) -> None: self._mocap_reentry_armed: bool = False self._mocap_session = MocapSessionManager() self._last_commanded_motion_qpos: Float64Array | None = None + self._viewers = _parse_sim2real_viewers(cfg) + self._retarget_viewer = _Sim2RealRetargetViewer( + xml_path=str(cfg_get(robot_cfg, "xml_path", "")) if "retarget" in self._viewers else None, + enabled="retarget" in self._viewers, + ) def _init_reference_config(self, cfg: Any) -> None: """Parse reference-window / realtime-buffer configuration.""" @@ -282,6 +338,7 @@ def _standing_step(self) -> None: self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) self._last_retarget_qpos = qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + self._write_retarget_viewer(qpos) def _mocap_step(self) -> None: """Mocap mode: input provider -> retarget -> policy -> update LowCmd targets.""" @@ -430,6 +487,7 @@ def _execute_mocap_pipeline( self._last_retarget_qpos = qpos.copy() self._ref_proc.last_reference_qpos = reference_qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + self._write_retarget_viewer(qpos) # ------------------------------------------------------------------ # State machine transitions @@ -527,7 +585,7 @@ def _enter_standing(self) -> None: self._reset_policy_state() # Kp ramp: gradually increase PD gains to avoid torque spike. - # Unlike the old position ramp, this does NOT break action-state causality. + # Unlike position ramping, this does not alter policy targets. 
self._safety.start_kp_ramp() self._mocap_reentry_armed = prev_mode == RobotMode.MOCAP @@ -843,6 +901,7 @@ def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: self._last_retarget_qpos = qpos.copy() self._ref_proc.last_reference_qpos = qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + self._write_retarget_viewer(qpos) def _tick_dexterous_hand(self) -> None: active = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.ACTIVE @@ -857,6 +916,12 @@ def _deactivate_dexterous_hand(self) -> None: except Exception: logger.exception("Failed to deactivate dexterous hand runtime") + def _write_retarget_viewer(self, qpos: Float64Array) -> None: + try: + self._retarget_viewer.write(qpos) + except Exception: + logger.exception("Sim2real retarget viewer update failed; control continues") + @staticmethod def _sleep_until(t0: float, dt: float) -> None: """Sleep to maintain control frequency.""" @@ -890,6 +955,10 @@ def shutdown(self) -> None: self._hand_runtime.close() except Exception: pass + try: + self._retarget_viewer.shutdown() + except Exception: + pass try: self.input_provider.close() except Exception: diff --git a/teleopit/sim2real/safety.py b/teleopit/sim2real/safety.py index 7cd05423..2b31ee23 100644 --- a/teleopit/sim2real/safety.py +++ b/teleopit/sim2real/safety.py @@ -30,7 +30,6 @@ def __init__( ) -> None: self._robot = robot self._policy_hz = policy_hz - real_cfg = cfg_get(cfg, "real_robot") # KP ramp (gradually increase PD gains after episode-reset) diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index 79f54212..e5687c4a 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -168,6 +168,7 @@ def _make_cfg() -> dict[str, object]: "robot": { "default_angles": [0.0] * 29, "num_actions": 29, + "xml_path": "robot.xml", }, "controller": {}, "input": {"provider": "pico4"}, @@ -227,6 +228,82 @@ def test_reset_policy_state_clears_reference_timeline(monkeypatch) -> 
None: assert ctrl._last_live_packet_seq == -1 +def test_sim2real_retarget_viewer_defaults_off(monkeypatch) -> None: + import teleopit.sim2real.controller as controller_mod + from teleopit.sim2real.controller import Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) + + starts: list[tuple[object, ...]] = [] + monkeypatch.setattr(controller_mod, "start_robot_viewer", lambda *args, **kwargs: starts.append(args)) + + Sim2RealController(_make_cfg()) + + assert starts == [] + + +def test_sim2real_retarget_viewer_writes_reference_qpos(monkeypatch) -> None: + import multiprocessing as mp + + import teleopit.sim2real.controller as controller_mod + from teleopit.sim2real.controller import Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + target_qpos = np.zeros(36, dtype=np.float64) + target_qpos[0] = 0.25 + target_qpos[3] = 1.0 + target_qpos[7] = 0.5 + _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) + + arr = mp.Array("d", 36) + alive = mp.Value("i", 1) + shutdown = mp.Event() + proc = SimpleNamespace(join=lambda timeout=None: None, is_alive=lambda: False, terminate=lambda: None) + starts: list[tuple[object, ...]] = [] + + def fake_start_robot_viewer(*args: object, **_kwargs: object) -> tuple[object, object, object, object]: + starts.append(args) + return proc, arr, alive, shutdown + + monkeypatch.setattr(controller_mod, "start_robot_viewer", fake_start_robot_viewer) + cfg = _make_cfg() + cfg["retarget_buffer_enabled"] = False + cfg["viewers"] = "retarget" + ctrl = Sim2RealController(cfg) + monkeypatch.setattr( + ctrl._ref_proc, + "compute_anchor_velocities", + lambda _qpos: ( + np.zeros(3, dtype=np.float32), + np.zeros(3, dtype=np.float32), + ), + ) + + ctrl._mocap_step() + + assert starts + with arr.get_lock(): + written 
= np.asarray(arr[:], dtype=np.float64) + np.testing.assert_allclose(written[[0, 3, 7]], target_qpos[[0, 3, 7]], atol=1e-6) + + +def test_sim2real_retarget_viewer_rejects_sim_viewers(monkeypatch) -> None: + from teleopit.sim2real.controller import Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) + + cfg = _make_cfg() + cfg["viewers"] = ["retarget", "sim2sim"] + + with pytest.raises(ValueError, match="supports only the optional 'retarget' viewer"): + Sim2RealController(cfg) + + def test_sim2real_rejects_nonzero_reference_steps_without_buffer(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController From 63e0946f40f52aa6292b296ffb748f82c9bd41ed Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 26 May 2026 14:19:26 +0800 Subject: [PATCH 036/122] Add Kp ramp to standalone standing --- docs/docs/tutorials/standalone-standing.md | 12 ++++ .../current/tutorials/standalone-standing.md | 11 ++++ scripts/run/standalone_standing.py | 62 ++++++++++++++++++- 3 files changed, 82 insertions(+), 3 deletions(-) diff --git a/docs/docs/tutorials/standalone-standing.md b/docs/docs/tutorials/standalone-standing.md index a40c63bf..214ee6e2 100644 --- a/docs/docs/tutorials/standalone-standing.md +++ b/docs/docs/tutorials/standalone-standing.md @@ -58,6 +58,18 @@ python scripts/run/standalone_standing.py \ --network-interface eth0 ``` +Standalone standing uses the same Kp ramp semantics as sim2real: after locking +the current joints, policy targets are sent immediately while Kp ramps from 10% +to the configured gains over 2 seconds. To tune this startup behavior: + +```bash +python scripts/run/standalone_standing.py \ + --policy track.onnx \ + --network-interface eth0 \ + --kp-ramp-duration 2.0 \ + --kp-ramp-floor-ratio 0.1 +``` + ## What It Checks - `g1_bridge_sdk` imports correctly. 
diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md index 14802910..5201df02 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md @@ -57,6 +57,17 @@ python scripts/run/standalone_standing.py \ --network-interface eth0 ``` +standalone standing 使用与 sim2real 相同的 Kp ramp 语义:锁住当前关节后立即发送 +policy target,同时在 2 秒内把 Kp 从 10% 逐步升到配置的增益。可以这样调整启动行为: + +```bash +python scripts/run/standalone_standing.py \ + --policy track.onnx \ + --network-interface eth0 \ + --kp-ramp-duration 2.0 \ + --kp-ramp-floor-ratio 0.1 +``` + ## 它会检查什么 - `g1_bridge_sdk` 能正确导入。 diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 6338bbbb..11cd3577 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -55,6 +55,8 @@ VEL_STOP_F = 16000.0 KD_DAMPING = 8.0 JOINT_VEL_LIMIT = 10.0 +DEFAULT_KP_RAMP_DURATION = 2.0 +DEFAULT_KP_RAMP_FLOOR_RATIO = 0.1 # Default standing pose (from g1_constants.py HOME_KEYFRAME) DEFAULT_ANGLES = np.array([ @@ -357,9 +359,15 @@ class StandingController: def __init__(self, network_interface: str, policy_path: str, no_policy: bool = False, - publish_hz: int = 250) -> None: + publish_hz: int = 250, + kp_ramp_duration: float = DEFAULT_KP_RAMP_DURATION, + kp_ramp_floor_ratio: float = DEFAULT_KP_RAMP_FLOOR_RATIO) -> None: self._network_interface = network_interface self._shutdown = False + if kp_ramp_duration < 0.0: + raise ValueError("kp_ramp_duration must be >= 0") + if not 0.0 <= kp_ramp_floor_ratio <= 1.0: + raise ValueError("kp_ramp_floor_ratio must be in [0, 1]") # ---- Load policy and observation builder ---- self._policy = PolicyInference(policy_path) @@ -393,6 +401,10 @@ def __init__(self, network_interface: str, 
policy_path: str, self._inference_running = False self._publish_hz = publish_hz + self._kp_ramp_duration_steps = max(1, int(kp_ramp_duration * POLICY_HZ)) + self._kp_ramp_floor_ratio = float(kp_ramp_floor_ratio) + self._kp_ramp_step = 0 + self._kp_ramp_active = False self._init_cpp_backend() @@ -477,6 +489,38 @@ def _lock_joints(self) -> None: self._bridge.set_target(qpos, KP, KD) self._bridge.lock_joints() + def _start_kp_ramp(self) -> None: + self._kp_ramp_step = 0 + self._kp_ramp_active = True + logger.info( + "Kp ramp armed: %d steps (%.1fs), floor_ratio=%.2f", + self._kp_ramp_duration_steps, + self._kp_ramp_duration_steps / POLICY_HZ, + self._kp_ramp_floor_ratio, + ) + + def _compute_kp_ramp_gains(self) -> tuple[np.ndarray, np.ndarray] | None: + if not self._kp_ramp_active: + return None + + factor = min(1.0, self._kp_ramp_step / self._kp_ramp_duration_steps) + kp = KP * (self._kp_ramp_floor_ratio + (1.0 - self._kp_ramp_floor_ratio) * factor) + + self._kp_ramp_step += 1 + if self._kp_ramp_step >= self._kp_ramp_duration_steps: + self._kp_ramp_active = False + logger.info("Kp ramp complete (%d steps)", self._kp_ramp_duration_steps) + + return np.asarray(kp, dtype=np.float32), KD.copy() + + def _send_target(self, target: np.ndarray) -> None: + gains = self._compute_kp_ramp_gains() + if gains is None: + self._bridge.set_target(target, KP, KD) + return + kp, kd = gains + self._bridge.set_target(target, kp, kd) + # ================================================================== # Safety checks # ================================================================== @@ -609,8 +653,9 @@ def _inference_loop(self) -> None: else: target = self._standing_step() - # Write target to publish thread - self._bridge.set_target(target, KP, KD) + # Write target to publish thread. Kp ramps after standing entry; + # policy targets stay unchanged, matching sim2real STANDING. 
+ self._send_target(target) # Timing diagnostics (informational only — not a control failure) elapsed = time.monotonic() - t0 @@ -746,6 +791,7 @@ def run(self) -> None: # 5. Initialize policy state self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) + self._start_kp_ramp() logger.info("Starting RL policy standing (pipelined)") @@ -804,6 +850,14 @@ def main(): "--publish-hz", type=int, default=200, help="C++ publish frequency in Hz (default: 200, matching training pd_hz)", ) + parser.add_argument( + "--kp-ramp-duration", type=float, default=DEFAULT_KP_RAMP_DURATION, + help="Seconds to ramp Kp after entering standing (default: 2.0, matches sim2real)", + ) + parser.add_argument( + "--kp-ramp-floor-ratio", type=float, default=DEFAULT_KP_RAMP_FLOOR_RATIO, + help="Initial Kp ratio for the standing ramp (default: 0.1, matches sim2real)", + ) args = parser.parse_args() controller = StandingController( @@ -811,6 +865,8 @@ def main(): policy_path=args.policy, no_policy=args.no_policy, publish_hz=args.publish_hz, + kp_ramp_duration=args.kp_ramp_duration, + kp_ramp_floor_ratio=args.kp_ramp_floor_ratio, ) controller._state_delay = args.state_delay controller._dry_run = args.dry_run From f3b856844f4ef1a8e10d12d2814e5404ac8f055d Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 26 May 2026 15:53:33 +0800 Subject: [PATCH 037/122] Add latency simulation to standalone standing --- scripts/run/standalone_standing.py | 174 +++++++++++++++++++++++++++-- 1 file changed, 167 insertions(+), 7 deletions(-) diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 11cd3577..8cdaef3e 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -360,10 +360,16 @@ class StandingController: def __init__(self, network_interface: str, policy_path: str, no_policy: bool = False, publish_hz: int = 250, + obs_delay: float = 0.0, + command_delay: float = 0.0, kp_ramp_duration: float = DEFAULT_KP_RAMP_DURATION, 
kp_ramp_floor_ratio: float = DEFAULT_KP_RAMP_FLOOR_RATIO) -> None: self._network_interface = network_interface self._shutdown = False + if obs_delay < 0.0: + raise ValueError("obs_delay must be >= 0") + if command_delay < 0.0: + raise ValueError("command_delay must be >= 0") if kp_ramp_duration < 0.0: raise ValueError("kp_ramp_duration must be >= 0") if not 0.0 <= kp_ramp_floor_ratio <= 1.0: @@ -387,6 +393,18 @@ def __init__(self, network_interface: str, policy_path: str, self._no_policy = no_policy self._dry_run = False self._state_delay = 0.0 + self._obs_delay = float(obs_delay) + self._command_delay = float(command_delay) + self._state_history: deque[tuple[float, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]] = deque(maxlen=512) + self._state_history_lock = threading.Lock() + self._state_sampler_thread: threading.Thread | None = None + self._state_sampler_running = False + self._pending_targets: deque[tuple[float, np.ndarray]] = deque() + self._pending_targets_cv = threading.Condition() + self._command_sender_thread: threading.Thread | None = None + self._command_sender_running = False + self._last_obs_age_s = 0.0 + self._last_command_queue_len = 0 # ---- Policy state ---- self._step_count = 0 @@ -427,7 +445,59 @@ def _init_cpp_backend(self) -> None: # ================================================================== def _get_robot_state(self): - return self._bridge.get_state() + return self._read_robot_state() + + def _read_robot_state(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + qpos, qvel, quat, ang_vel = self._bridge.get_state() + return ( + np.asarray(qpos, dtype=np.float32).copy(), + np.asarray(qvel, dtype=np.float32).copy(), + np.asarray(quat, dtype=np.float32).copy(), + np.asarray(ang_vel, dtype=np.float32).copy(), + ) + + def _record_robot_state(self) -> tuple[float, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: + now = time.monotonic() + state = self._read_robot_state() + with self._state_history_lock: + 
self._state_history.append((now, state)) + return now, state + + def _get_observation_state(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + now, current = self._record_robot_state() + if self._obs_delay <= 0.0: + self._last_obs_age_s = 0.0 + return current + + target_time = now - self._obs_delay + with self._state_history_lock: + history = list(self._state_history) + selected_time, selected_state = history[0] + for sample_time, sample_state in reversed(history): + if sample_time <= target_time: + selected_time, selected_state = sample_time, sample_state + break + self._last_obs_age_s = max(0.0, now - selected_time) + return selected_state + + def _start_state_sampler(self) -> None: + if self._obs_delay <= 0.0 or self._state_sampler_thread is not None: + return + self._state_sampler_running = True + self._state_sampler_thread = threading.Thread(target=self._state_sampler_loop, daemon=True) + self._state_sampler_thread.start() + + def _stop_state_sampler(self) -> None: + self._state_sampler_running = False + if self._state_sampler_thread is not None: + self._state_sampler_thread.join(timeout=1.0) + self._state_sampler_thread = None + + def _state_sampler_loop(self) -> None: + sample_dt = 1.0 / max(float(self._publish_hz), POLICY_HZ) + while self._state_sampler_running and not self._shutdown: + self._record_robot_state() + time.sleep(sample_dt) # ================================================================== # Publish thread @@ -513,7 +583,7 @@ def _compute_kp_ramp_gains(self) -> tuple[np.ndarray, np.ndarray] | None: return np.asarray(kp, dtype=np.float32), KD.copy() - def _send_target(self, target: np.ndarray) -> None: + def _write_target_now(self, target: np.ndarray) -> None: gains = self._compute_kp_ramp_gains() if gains is None: self._bridge.set_target(target, KP, KD) @@ -521,6 +591,69 @@ def _send_target(self, target: np.ndarray) -> None: kp, kd = gains self._bridge.set_target(target, kp, kd) + def _pop_due_targets(self, now: float) -> 
list[np.ndarray]: + due: list[np.ndarray] = [] + with self._pending_targets_cv: + while self._pending_targets and self._pending_targets[0][0] <= now: + _, target = self._pending_targets.popleft() + due.append(target) + self._last_command_queue_len = len(self._pending_targets) + return due + + def _flush_pending_targets(self, now: float | None = None) -> None: + if now is None: + now = time.monotonic() + due = self._pop_due_targets(now) + for target in due: + self._write_target_now(target) + if len(due) > 1: + logger.warning("Flushed %d delayed targets in one control tick", len(due)) + + def _send_target(self, target: np.ndarray) -> None: + if self._command_delay <= 0.0: + self._write_target_now(target) + self._last_command_queue_len = len(self._pending_targets) + return + with self._pending_targets_cv: + self._pending_targets.append((time.monotonic() + self._command_delay, np.asarray(target, dtype=np.float32).copy())) + self._last_command_queue_len = len(self._pending_targets) + self._pending_targets_cv.notify() + + def _start_command_sender(self) -> None: + if self._command_delay <= 0.0 or self._command_sender_thread is not None: + return + self._command_sender_running = True + self._command_sender_thread = threading.Thread(target=self._command_sender_loop, daemon=True) + self._command_sender_thread.start() + + def _stop_command_sender(self) -> None: + self._command_sender_running = False + with self._pending_targets_cv: + self._pending_targets_cv.notify_all() + if self._command_sender_thread is not None: + self._command_sender_thread.join(timeout=1.0) + self._command_sender_thread = None + with self._pending_targets_cv: + self._pending_targets.clear() + self._last_command_queue_len = 0 + + def _command_sender_loop(self) -> None: + while self._command_sender_running and not self._shutdown: + now = time.monotonic() + due = self._pop_due_targets(now) + if due: + self._write_target_now(due[-1]) + if len(due) > 1: + logger.warning("Dropped %d stale delayed targets", 
len(due) - 1) + continue + + with self._pending_targets_cv: + if not self._pending_targets: + self._pending_targets_cv.wait(timeout=0.02) + continue + wait_s = max(0.0, self._pending_targets[0][0] - time.monotonic()) + self._pending_targets_cv.wait(timeout=min(wait_s, 0.02)) + # ================================================================== # Safety checks # ================================================================== @@ -547,7 +680,7 @@ def _check_joint_vel_safety(self, qvel: np.ndarray) -> bool: def _standing_step(self) -> np.ndarray: """One step of RL policy standing inference. Returns target joint positions.""" _t0 = time.monotonic() - qpos, qvel, quat, ang_vel = self._get_robot_state() + qpos, qvel, quat, ang_vel = self._get_observation_state() # Build standing reference aligned to robot's current yaw ref_qpos = self._standing_qpos.copy() @@ -583,11 +716,13 @@ def _standing_step(self) -> np.ndarray: tag = "OVERRUN" if step_ms > (1000.0 / POLICY_HZ) else "DIAG" logger.info( "%s step=%d | state=%.2fms obs=%.2fms infer=%.2fms total=%.1fms | " - "qvel_norm=%.4f | action_norm=%.4f | " + "obs_age=%.1fms cmd_q=%d | qvel_norm=%.4f | action_norm=%.4f | " "target[:6]=%s | qpos[:6]=%s", tag, self._step_count, (_t1 - _t0) * 1000, (_t2 - _t1) * 1000, (_t3 - _t2) * 1000, step_ms, + self._last_obs_age_s * 1000, + self._last_command_queue_len, float(np.linalg.norm(qvel)), float(np.linalg.norm(action)), np.array2string(target_dof_pos[:6], precision=4, separator=','), @@ -631,6 +766,8 @@ def _inference_loop(self) -> None: while self._inference_running and not self._shutdown: t0 = time.monotonic() + if self._command_delay <= 0.0: + self._flush_pending_targets(t0) # Emergency stop check if self._check_emergency_stop(): @@ -676,6 +813,9 @@ def _inference_loop(self) -> None: if remain > 0: time.sleep(remain) + if self._command_delay <= 0.0: + self._flush_pending_targets() + # ---- Main loop ---- def _run_dry(self) -> None: @@ -693,11 +833,12 @@ def _run_dry(self) -> None: 
obs_sum = 0.0 infer_sum = 0.0 + self._start_state_sampler() while not self._shutdown: t0 = time.monotonic() # 1. Read state - qpos, qvel, quat, ang_vel = self._get_robot_state() + qpos, qvel, quat, ang_vel = self._get_observation_state() t1 = time.monotonic() # 2. Build reference @@ -734,11 +875,12 @@ def _run_dry(self) -> None: n = loop_count logger.info( "DRY step=%d | state=%.2fms obs=%.2fms infer=%.2fms total=%.2fms | " - "max=%.2fms overruns=%d/%d | target[:6]=%s", + "max=%.2fms overruns=%d/%d | obs_age=%.1fms | target[:6]=%s", n, (state_sum / n) * 1000, (obs_sum / n) * 1000, (infer_sum / n) * 1000, (elapsed_sum / n) * 1000, max_elapsed * 1000, overrun_count, n, + self._last_obs_age_s * 1000, np.array2string(target[:6], precision=4, separator=','), ) @@ -746,6 +888,7 @@ def _run_dry(self) -> None: if remain > 0: time.sleep(remain) + self._stop_state_sampler() logger.info( "DRY-RUN finished: %d steps, avg total=%.2fms " "(state=%.2f obs=%.2f infer=%.2f) max=%.2fms overruns=%d", @@ -792,6 +935,8 @@ def run(self) -> None: # 5. Initialize policy state self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) self._start_kp_ramp() + self._start_state_sampler() + self._start_command_sender() logger.info("Starting RL policy standing (pipelined)") @@ -816,6 +961,8 @@ def _signal_handler(self, signum, frame) -> None: def _cleanup(self) -> None: self._inference_running = False + self._stop_state_sampler() + self._stop_command_sender() logger.info("Shutting down: setting damping ...") self._set_damping() time.sleep(0.5) @@ -840,7 +987,18 @@ def main(): ) parser.add_argument( "--state-delay", type=float, default=0.0, - help="Artificial delay (seconds) before reading state, simulates network latency (e.g. 0.005)", + help=( + "Legacy loop delay before the policy step. This consumes timing budget but does not make " + "the observation stale; prefer --obs-delay or --command-delay for latency tests." 
+ ), + ) + parser.add_argument( + "--obs-delay", type=float, default=0.0, + help="Use LowState sampled this many seconds in the past when building the policy observation.", + ) + parser.add_argument( + "--command-delay", type=float, default=0.0, + help="Delay writing each computed target to the C++ publish thread by this many seconds.", ) parser.add_argument( "--dry-run", action="store_true", @@ -865,6 +1023,8 @@ def main(): policy_path=args.policy, no_policy=args.no_policy, publish_hz=args.publish_hz, + obs_delay=args.obs_delay, + command_delay=args.command_delay, kp_ramp_duration=args.kp_ramp_duration, kp_ramp_floor_ratio=args.kp_ramp_floor_ratio, ) From 5da1bdfe4665e0af033537b952e5ffb6b6281c2d Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 26 May 2026 16:35:35 +0800 Subject: [PATCH 038/122] Adjust tracking randomization tests and config --- tests/test_domain_randomization.py | 57 ++----------- tests/test_task_registry.py | 12 --- train_mimic/tasks/tracking/config/env.py | 4 - .../tasks/tracking/tracking_env_cfg.py | 83 +++++-------------- 4 files changed, 29 insertions(+), 127 deletions(-) diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index d58d64e9..0c0941b9 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -16,11 +16,7 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non assert set(events) == { "push_robot", "base_com", - "encoder_bias", "add_joint_default_pos", - "motor_params_implicit_upper_body_pd", - "motor_params_implicit_lower_body_pd", - "motor_params_implicit_armature", "physics_material", "randomize_rigid_body_mass", "randomize_dexhand_payload_mass", @@ -53,11 +49,6 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non 2: (-0.05, 0.05), } - encoder_bias = events["encoder_bias"] - assert encoder_bias.func is dr.encoder_bias - assert encoder_bias.mode == "startup" - assert encoder_bias.params["bias_range"] 
== (-0.01, 0.01) - add_joint_default_pos = events["add_joint_default_pos"] assert add_joint_default_pos.func is dr.joint_default_pos assert add_joint_default_pos.mode == "startup" @@ -65,34 +56,6 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non assert add_joint_default_pos.params["operation"] == "add" assert add_joint_default_pos.params["ranges"] == (-0.01, 0.01) - upper_motor_pd = events["motor_params_implicit_upper_body_pd"] - assert upper_motor_pd.func is dr.pd_gains - assert upper_motor_pd.mode == "reset" - assert upper_motor_pd.params["asset_cfg"].actuator_names is None - assert upper_motor_pd.params["asset_cfg"].actuator_ids == [0, 3] - assert upper_motor_pd.params["kp_range"] == (0.9, 1.1) - assert upper_motor_pd.params["kd_range"] == (0.9, 1.1) - assert upper_motor_pd.params["distribution"] == "log_uniform" - assert upper_motor_pd.params["operation"] == "scale" - - lower_motor_pd = events["motor_params_implicit_lower_body_pd"] - assert lower_motor_pd.func is dr.pd_gains - assert lower_motor_pd.mode == "reset" - assert lower_motor_pd.params["asset_cfg"].actuator_names is None - assert lower_motor_pd.params["asset_cfg"].actuator_ids == [1, 2, 4, 5] - assert lower_motor_pd.params["kp_range"] == (0.5, 2.0) - assert lower_motor_pd.params["kd_range"] == (0.5, 2.0) - assert lower_motor_pd.params["distribution"] == "log_uniform" - assert lower_motor_pd.params["operation"] == "scale" - - motor_armature = events["motor_params_implicit_armature"] - assert motor_armature.func is dr.joint_armature - assert motor_armature.mode == "startup" - assert motor_armature.params["asset_cfg"].joint_names == ".*" - assert motor_armature.params["ranges"] == (0.75, 1.25) - assert motor_armature.params["distribution"] == "log_uniform" - assert motor_armature.params["operation"] == "scale" - physics_material = events["physics_material"] assert physics_material.func is dr.geom_friction assert physics_material.mode == "startup" @@ -113,13 +76,13 @@ def 
test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non "left_dexhand_payload", "right_dexhand_payload", ) - assert dexhand_mass.params["alpha_range"] == (-8.0, 0.34657359027997264) + assert dexhand_mass.params["alpha_range"] == (-1, 0) gimbal_mass = events["randomize_gimbal_payload_mass"] assert gimbal_mass.func is dr.pseudo_inertia assert gimbal_mass.mode == "startup" assert gimbal_mass.params["asset_cfg"].body_names == ("head_gimbal_payload",) - assert gimbal_mass.params["alpha_range"] == (-8.0, 0.34657359027997264) + assert gimbal_mass.params["alpha_range"] == (-1, 0) dexhand_pos = events["randomize_dexhand_payload_pos"] assert dexhand_pos.func is dr.body_pos @@ -130,9 +93,9 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non ) assert dexhand_pos.params["operation"] == "abs" assert dexhand_pos.params["ranges"] == { - 0: (0.04, 0.12), - 1: (-0.03, 0.03), - 2: (-0.03, 0.03), + 0: (0.055, 0.095), + 1: (-0.02, 0.02), + 2: (-0.02, 0.02), } gimbal_pos = events["randomize_gimbal_payload_pos"] @@ -141,9 +104,9 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non assert gimbal_pos.params["asset_cfg"].body_names == ("head_gimbal_payload",) assert gimbal_pos.params["operation"] == "abs" assert gimbal_pos.params["ranges"] == { - 0: (0.03, 0.12), - 1: (-0.03, 0.03), - 2: (0.40, 0.50), + 0: (0.05, 0.09), + 1: (-0.02, 0.02), + 2: (0.43, 0.47), } @@ -156,11 +119,7 @@ def test_play_env_disables_training_only_domain_randomization() -> None: assert "push_robot" not in play_cfg.events assert "base_com" not in play_cfg.events - assert "encoder_bias" not in play_cfg.events assert "add_joint_default_pos" not in play_cfg.events - assert "motor_params_implicit_upper_body_pd" not in play_cfg.events - assert "motor_params_implicit_lower_body_pd" not in play_cfg.events - assert "motor_params_implicit_armature" not in play_cfg.events assert "physics_material" not in play_cfg.events assert 
"randomize_rigid_body_mass" not in play_cfg.events assert "randomize_dexhand_payload_mass" not in play_cfg.events diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 71ecd80e..7a437b97 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -74,18 +74,6 @@ def test_general_tracking_task_is_registered() -> None: assert feet_acc.params["asset_cfg"].name == "robot" assert feet_acc.params["asset_cfg"].joint_names == r".*ankle.*" assert "anti_shake_ang_vel" not in env_cfg.rewards - upper_pd_asset = env_cfg.events["motor_params_implicit_upper_body_pd"].params[ - "asset_cfg" - ] - lower_pd_asset = env_cfg.events["motor_params_implicit_lower_body_pd"].params[ - "asset_cfg" - ] - assert upper_pd_asset.name == "robot" - assert upper_pd_asset.actuator_names is None - assert upper_pd_asset.actuator_ids == [0, 3] - assert lower_pd_asset.name == "robot" - assert lower_pd_asset.actuator_names is None - assert lower_pd_asset.actuator_ids == [1, 2, 4, 5] rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME assert rl_cfg.actor.hidden_dims == (2048, 1024, 512, 256, 128) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 025e09f3..78a89e43 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -40,11 +40,7 @@ _TRAIN_ONLY_EVENTS = ( "push_robot", "base_com", - "encoder_bias", "add_joint_default_pos", - "motor_params_implicit_upper_body_pd", - "motor_params_implicit_lower_body_pd", - "motor_params_implicit_armature", "physics_material", "randomize_rigid_body_mass", "randomize_dexhand_payload_mass", diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 4ff88d25..29b21dff 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -33,11 +33,22 @@ "yaw": (-0.78, 0.78), } -# G1 uses 
six mjlab actuator groups. dr.pd_gains indexes asset.actuators -# by group id, not the expanded XML per-joint actuator ids. -G1_UPPER_BODY_ACTUATOR_GROUP_IDS = (0, 3) -G1_LOWER_BODY_ACTUATOR_GROUP_IDS = (1, 2, 4, 5) +_DEXHAND_PAYLOAD_MASS_ALPHA_RANGE = (-1, 0) +_GIMBAL_PAYLOAD_MASS_ALPHA_RANGE = (-1, 0) +_DEXHAND_PAYLOAD_POS_RANGES_MM = { + 0: (55, 95), + 1: (-20, 20), + 2: (-20, 20), +} +_GIMBAL_PAYLOAD_POS_RANGES_MM = { + 0: (50, 90), + 1: (-20, 20), + 2: (430, 470), +} + +def _mm_ranges_to_m(ranges_mm: dict[int, tuple[int, int]]) -> dict[int, tuple[float, float]]: + return {axis: (lower / 1000.0, upper / 1000.0) for axis, (lower, upper) in ranges_mm.items()} def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: """Create base tracking task configuration.""" @@ -184,14 +195,6 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: }, }, ), - "encoder_bias": EventTermCfg( - mode="startup", - func=dr.encoder_bias, - params={ - "asset_cfg": SceneEntityCfg("robot"), - "bias_range": (-0.01, 0.01), - }, - ), "add_joint_default_pos": EventTermCfg( mode="startup", func=dr.joint_default_pos, @@ -201,42 +204,6 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "ranges": (-0.01, 0.01), }, ), - "motor_params_implicit_upper_body_pd": EventTermCfg( - mode="reset", - func=dr.pd_gains, - params={ - "asset_cfg": SceneEntityCfg( - "robot", actuator_ids=list(G1_UPPER_BODY_ACTUATOR_GROUP_IDS) - ), - "kp_range": (0.9, 1.1), - "kd_range": (0.9, 1.1), - "distribution": "log_uniform", - "operation": "scale", - }, - ), - "motor_params_implicit_lower_body_pd": EventTermCfg( - mode="reset", - func=dr.pd_gains, - params={ - "asset_cfg": SceneEntityCfg( - "robot", actuator_ids=list(G1_LOWER_BODY_ACTUATOR_GROUP_IDS) - ), - "kp_range": (0.5, 2.0), - "kd_range": (0.5, 2.0), - "distribution": "log_uniform", - "operation": "scale", - }, - ), - "motor_params_implicit_armature": EventTermCfg( - mode="startup", - func=dr.joint_armature, - params={ - "asset_cfg": SceneEntityCfg("robot", 
joint_names=".*"), - "ranges": (0.75, 1.25), - "distribution": "log_uniform", - "operation": "scale", - }, - ), "physics_material": EventTermCfg( mode="startup", func=dr.geom_friction, @@ -259,8 +226,8 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: func=dr.pseudo_inertia, params={ "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. - # Nominal is 0.5 kg per hand. Scale covers 0-1.0 kg. - "alpha_range": (-8.0, 0.34657359027997264), + # Nominal is 0.5 kg per hand. Keep a tighter ~0.37-1.0x band. + "alpha_range": _DEXHAND_PAYLOAD_MASS_ALPHA_RANGE, }, ), "randomize_gimbal_payload_mass": EventTermCfg( @@ -268,8 +235,8 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: func=dr.pseudo_inertia, params={ "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. - # Nominal is 0.25 kg. Scale covers 0-0.5 kg. - "alpha_range": (-8.0, 0.34657359027997264), + # Nominal is 0.25 kg. Keep a tighter ~0.37-1.0x band. + "alpha_range": _GIMBAL_PAYLOAD_MASS_ALPHA_RANGE, }, ), "randomize_dexhand_payload_pos": EventTermCfg( @@ -278,11 +245,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: params={ "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. "operation": "abs", - "ranges": { - 0: (0.04, 0.12), - 1: (-0.03, 0.03), - 2: (-0.03, 0.03), - }, + "ranges": _mm_ranges_to_m(_DEXHAND_PAYLOAD_POS_RANGES_MM), }, ), "randomize_gimbal_payload_pos": EventTermCfg( @@ -291,11 +254,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: params={ "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. 
"operation": "abs", - "ranges": { - 0: (0.03, 0.12), - 1: (-0.03, 0.03), - 2: (0.40, 0.50), - }, + "ranges": _mm_ranges_to_m(_GIMBAL_PAYLOAD_POS_RANGES_MM), }, ), } From 7c390cf49268eb75171bbdc84203b39725d79a3d Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 26 May 2026 21:17:26 +0800 Subject: [PATCH 039/122] docs: add multinode training example --- docs/docs/tutorials/training.md | 18 ++++++++++++++++++ .../current/tutorials/training.md | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index 03c9ab70..c49d52f6 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -53,8 +53,26 @@ python train_mimic/scripts/train.py \ --motion_file data/datasets/seed/train ``` +### Multi-Node Multi-GPU + +Use `torchrun` directly when training across multiple machines: + +```bash +torchrun \ + --nnodes=$PET_NNODES \ + --nproc_per_node=$PET_NPROC_PER_NODE \ + --node_rank=$PET_NODE_RANK \ + --master_addr=$PET_MASTER_ADDR \ + --master_port=$PET_MASTER_PORT \ + train_mimic/scripts/train.py \ + --num_envs 1024 \ + --max_iterations 1000 \ + --motion_file data/datasets/seed/train +``` + **Notes:** - `--num_envs` is per-GPU in multi-GPU mode +- `--num_envs` is also per-process in multi-node mode, so total environments scale with `world_size` - Default logger is TensorBoard; pass `--wandb_project ` to enable W&B - `--motion_file` accepts only shard directories (containing `shard_*.npz` files) - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index 16e3f59d..dc9387ba 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ 
b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -53,8 +53,26 @@ python train_mimic/scripts/train.py \ --motion_file data/datasets/seed/train ``` +### 多机多卡训练 + +跨多台机器训练时,直接使用 `torchrun`: + +```bash +torchrun \ + --nnodes=$PET_NNODES \ + --nproc_per_node=$PET_NPROC_PER_NODE \ + --node_rank=$PET_NODE_RANK \ + --master_addr=$PET_MASTER_ADDR \ + --master_port=$PET_MASTER_PORT \ + train_mimic/scripts/train.py \ + --num_envs 1024 \ + --max_iterations 1000 \ + --motion_file data/datasets/seed/train +``` + **注意事项:** - 多卡模式下 `--num_envs` 为每张 GPU 的环境数量 +- 多机模式下 `--num_envs` 也按每个进程计算,因此总环境数会随 `world_size` 线性增长 - 默认日志工具为 TensorBoard;传入 `--wandb_project ` 可启用 W&B - `--motion_file` 仅接受分片目录(包含 `shard_*.npz` 文件的目录) - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` From b8a2897d54555245805b46c0f88794af70828331 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 19:41:09 +0800 Subject: [PATCH 040/122] Preserve retargeter warm-start on runtime resets --- AGENTS.md | 4 +++- README.md | 1 + teleopit/sim/loop.py | 2 -- teleopit/sim/session.py | 9 ++------- teleopit/sim2real/controller.py | 1 - teleopit/sim2real/safety.py | 7 +++---- tests/test_sim2real_runtime.py | 3 ++- 7 files changed, 11 insertions(+), 16 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index be65dca3..2bb25b66 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -123,7 +123,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos ### Offline Playback - Offline sim2sim and default sim2real both read `input.bvh_file` directly; no UDP relay path remains - Offline sim2sim playback can be keyboard-controlled: `Space/P` pause/resume, `R` replay from frame 0, `Q` stop -- Offline pause holds the commanded pose; resume resets policy/reference state and reanchors yaw/XY without qpos interpolation +- Offline pause holds the commanded pose; resume resets policy/reference state and reanchors yaw/XY without qpos 
interpolation or retargeter IK reset - sim2sim keyboard playback is optional via `playback.keyboard.enabled=true` - sim2real reuses the Unitree remote: `Start` → `STANDING`, `Y` → playback, `X` → back to `STANDING`, `L1+R1` → `DAMPING` - `playback.pause_on_end=true` keeps the final pose and waits for manual replay @@ -141,6 +141,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `X` → back to `STANDING`, `Q` → quit - Pico4 sim2real pause/resume is handled as a mocap-session control event (`toggle_pause`), not as a mode switch to `STANDING` - Default Pico pause button is `A`; resume rebuilds the realtime buffer and yaw/XY root-offset alignment, then waits for the configured realtime warmup before tracking continues +- Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference history, smoothers, realtime buffers, and reference alignment are reset, while the GMR IK warm-start is retained - Optional LinkerHand L6 control uses `third_party/linkerhand-python-sdk` and `dexterous_hand.enabled=true` - LinkerHand control reuses `Pico4InputProvider.get_controller_snapshot()`; do not start a second `PicoBridge` for hand control - LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, frame timeout, and shutdown must send the configured open pose @@ -154,6 +155,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Realtime inferred `motion_joint_vel`, anchor linear velocity, and anchor angular velocity can be EMA-smoothed via `reference_velocity_smoothing_alpha` and `reference_anchor_velocity_smoothing_alpha` - Sim2real Pico pause/resume uses mocap-session states `ACTIVE ↔ PAUSED`; resume clears policy/reference state, rebuilds yaw/XY root alignment, warms the realtime buffer, and does not interpolate retarget qpos from the paused pose - Realtime sim2sim with Pico 
control events uses the same mocap-session pause/resume semantics and rebuilds the realtime reference path on resume, including the configured warmup +- Realtime sim2sim/sim2real `STANDING ↔ MOCAP` transitions use the same retargeter-preserving soft reset, rather than cold-starting the retargeter from its default qpos - Realtime Pico sim2sim can start directly in `STANDING` with keyboard mode control enabled via top-level `keyboard.enabled` ### Inference Observation diff --git a/README.md b/README.md index 00f80b9c..85f62270 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. - Added optional Pico controller control for LinkerHand L6 in sim2real, backed by the LinkerHand SDK submodule. +- Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. 
### v0.3.0 (2026-05-12) diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index 21779743..6de109b1 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -255,7 +255,6 @@ def _restart_offline_playback( *, offline_playback: OfflinePlaybackController, mocap_session: MocapSessionManager, - retargeter: Retargeter, ) -> None: offline_playback.replay() mocap_session.reset() @@ -263,7 +262,6 @@ def _restart_offline_playback( self._last_action = np.zeros((self._num_actions,), dtype=np.float32) self.controller.reset() self.obs_builder.reset() - retargeter.reset() self.robot.reset() def _pause_offline_playback( diff --git a/teleopit/sim/session.py b/teleopit/sim/session.py index 18287477..df228c6a 100644 --- a/teleopit/sim/session.py +++ b/teleopit/sim/session.py @@ -239,13 +239,9 @@ def reset_policy_reference_state(self, *, reset_mocap_session: bool = True) -> N self.last_commanded_motion_qpos = None self.reset_runtime_tracking() - def full_policy_reset(self) -> None: - self.reset_policy_reference_state() - self._retargeter.reset() - def enter_standing_mode(self) -> None: from teleopit.sim.loop import SimulationMode - self.full_policy_reset() + self.reset_policy_reference_state() self._loop._set_standing_reference(self._loop.robot.get_state()) self.simulation_mode = SimulationMode.STANDING @@ -257,7 +253,7 @@ def enter_mocap_mode(self) -> None: return state = loop.robot.get_state() start_qpos = loop._resolve_hold_qpos(None, None, None, state) - self.full_policy_reset() + self.reset_policy_reference_state() self._step_runner.last_retarget_qpos = start_qpos.copy() self.last_commanded_motion_qpos = start_qpos.copy() self.simulation_mode = SimulationMode.MOCAP @@ -321,7 +317,6 @@ def _handle_offline_keyboard(self) -> bool: loop._restart_offline_playback( offline_playback=self.offline_playback, mocap_session=self.mocap_session, - retargeter=self._retargeter, ) self.cached_human_frame = None self.cached_retargeted = None diff --git 
a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index 3798526a..23e50d38 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -713,7 +713,6 @@ def _reset_policy_state(self) -> None: self._last_commanded_motion_qpos = None self.policy.reset() self.obs_builder.reset() - self.retargeter.reset() def _reset_policy_reference_state(self) -> None: """Reset policy/reference state without resetting the retargeter.""" diff --git a/teleopit/sim2real/safety.py b/teleopit/sim2real/safety.py index 2b31ee23..d5b7dfed 100644 --- a/teleopit/sim2real/safety.py +++ b/teleopit/sim2real/safety.py @@ -32,10 +32,9 @@ def __init__( self._policy_hz = policy_hz real_cfg = cfg_get(cfg, "real_robot") - # KP ramp (gradually increase PD gains after episode-reset) - _legacy_ramp_dur = cfg_get(cfg, "startup_ramp_duration", cfg_get(real_cfg, "startup_ramp_duration", 2.0)) - kp_ramp_dur = float(cfg_get(cfg, "kp_ramp_duration", _legacy_ramp_dur)) - self._kp_ramp_duration_steps: int = max(1, int(kp_ramp_dur * policy_hz)) + # Kp ramp gradually increases PD gains after episode-reset. + startup_ramp_duration = float(cfg_get(cfg, "startup_ramp_duration", 2.0)) + self._kp_ramp_duration_steps: int = max(1, int(startup_ramp_duration * policy_hz)) self._kp_ramp_step: int = 0 self._kp_ramp_active: bool = False self._kp_nominal = np.asarray(cfg_get(real_cfg, "kp_real", [100] * num_actions), dtype=np.float32) diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index e5687c4a..1281eb85 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -205,9 +205,10 @@ def test_mode_transitions_reset_stateful_policy(monkeypatch) -> None: ctrl._enter_standing() ctrl._transition_to_mocap() - # Both _enter_standing and _transition_to_mocap now do full episode-reset + # Both transitions now do soft episode-reset and preserve retargeter warm-start. 
assert policy.reset_calls == 2 assert obs_builder.reset_calls == 2 + assert ctrl.retargeter.reset_calls == 0 def test_reset_policy_state_clears_reference_timeline(monkeypatch) -> None: From 925510b61c51541048cb080afee76e553465efff Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 20:52:30 +0800 Subject: [PATCH 041/122] Refine sim2real state transitions --- docs/docs/tutorials/pico-sim2real.md | 7 ++ .../current/tutorials/pico-sim2real.md | 7 ++ teleopit/configs/pico4_sim2real.yaml | 4 + teleopit/configs/sim2real.yaml | 4 + teleopit/sim2real/controller.py | 76 +++++++++------ teleopit/sim2real/safety.py | 21 +++- tests/test_sim2real_runtime.py | 95 ++++++++++++++++--- 7 files changed, 173 insertions(+), 41 deletions(-) diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index e847a488..45559c9b 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -154,6 +154,13 @@ profile: pip install -e '.[dexhand]' ``` +Bring up the CAN interfaces before testing or running hand control: + +```bash +sudo /usr/sbin/ip link set can0 up type can bitrate 1000000 +sudo /usr/sbin/ip link set can1 up type can bitrate 1000000 +``` + Before enabling full sim2real, verify the hand connection with a standalone open/close test: diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 771bf234..e659b71a 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -144,6 +144,13 @@ mocap 暂停、帧超时和退出时都会发送张开姿态。 pip install -e '.[dexhand]' ``` +测试或运行手控前,先开启 CAN 接口: + +```bash +sudo /usr/sbin/ip link set can0 up type can bitrate 1000000 +sudo /usr/sbin/ip link set can1 up type can bitrate 1000000 +``` + 启用完整 sim2real 前,先用独立开合测试验证灵巧手连接: ```bash diff 
--git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index a0597668..f3fc588f 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -20,6 +20,10 @@ reference_debug_log: false # Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 +kp_ramp_floor_ratio: 0.1 +# Faster Kp ramp used when returning from MOCAP to default STANDING with X +standing_return_ramp_duration: 0.5 +standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 92f72d7c..cb6796a9 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -19,6 +19,10 @@ playback: # Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 +kp_ramp_floor_ratio: 0.1 +# Faster Kp ramp used when returning from MOCAP to default STANDING with X +standing_return_ramp_duration: 0.5 +standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index 23e50d38..23d78779 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -132,6 +132,10 @@ def __init__(self, cfg: Any) -> None: self._init_components(cfg) self._init_reference_config(cfg) self._safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) + self._standing_return_ramp_duration = float(cfg_get(cfg, "standing_return_ramp_duration", 0.5)) + self._standing_return_kp_ramp_floor_ratio = float( + cfg_get(cfg, "standing_return_kp_ramp_floor_ratio", 0.5) + ) logger.info( "Sim2RealController ready | mode=IDLE | policy_hz=%.0f", @@ -557,28 +561,22 @@ def _enter_standing(self) -> None: return 
time.sleep(0.5) - # Lock joints to current position (prevent collapse during init) - logger.info("Locking joints to current position...") - self.robot.lock_all_joints() - time.sleep(0.3) + state = self.robot.get_state() + if prev_mode != RobotMode.MOCAP: + # Lock joints to current position during initial low-level takeover. + logger.info("Locking joints to current position...") + self.robot.lock_all_joints() + time.sleep(0.3) # Episode-reset semantics: reference = current robot state, full policy reset. # This matches training where robot is teleported to reference position at # episode start, so policy sees reference ≈ robot state with clean history. - state = self.robot.get_state() - init_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - init_qpos[3:7] = state.quat.astype(np.float64) - init_qpos[ROOT_DIM:FULL_QPOS_DIM] = state.qpos.astype(np.float64) + init_qpos = self._build_robot_state_qpos(state) self._last_retarget_qpos = init_qpos self._ref_proc.last_reference_qpos = None self._mocap_session.reset() self._last_commanded_motion_qpos = None - self._standing_qpos[0:3] = 0.0 - if getattr(state, "base_pos", None) is not None: - self._standing_qpos[0:3] = np.asarray(state.base_pos, dtype=np.float64)[:3] - self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) - align_motion_qpos_yaw(np.asarray(state.quat, dtype=np.float32), self._standing_qpos) - self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + self._set_default_standing_reference(state) # Always do a full policy reset (episode-reset semantics) to ensure # the TemporalCNN history is clean and action-state causality holds. @@ -586,7 +584,13 @@ def _enter_standing(self) -> None: # Kp ramp: gradually increase PD gains to avoid torque spike. # Unlike position ramping, this does not alter policy targets. 
- self._safety.start_kp_ramp() + if prev_mode == RobotMode.MOCAP: + self._safety.start_kp_ramp( + duration_s=self._standing_return_ramp_duration, + floor_ratio=self._standing_return_kp_ramp_floor_ratio, + ) + else: + self._safety.start_kp_ramp() self._mocap_reentry_armed = prev_mode == RobotMode.MOCAP @@ -654,22 +658,22 @@ def _can_switch_to_mocap(self) -> bool: def _transition_to_mocap(self) -> None: """Switch from STANDING -> MOCAP. - Episode-reset + reference realignment. The policy state is fully - reset (clean history, zero last_action) so the TemporalCNN starts - fresh. Incoming mocap is aligned by fixed yaw/XY offsets and then - consumed directly; switching does not interpolate reference qpos. + Episode-reset + reference realignment. The policy/reference state is + reset like pause/resume so the first mocap frame is anchored to the + current robot pose and starts with zero inferred reference velocity. """ state = self.robot.get_state() - init_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - init_qpos[3:7] = state.quat.astype(np.float64) - init_qpos[ROOT_DIM:FULL_QPOS_DIM] = state.qpos.astype(np.float64) - self._last_retarget_qpos = init_qpos - self._last_commanded_motion_qpos = init_qpos.copy() + resume_qpos = self._build_resume_alignment_qpos(self._standing_qpos, state) + self._last_commanded_motion_qpos = resume_qpos.copy() self._mocap_reentry_armed = False # Full episode reset: clean policy state, alignment, timeline. 
self._reset_policy_state() + self._last_retarget_qpos = None + self._last_commanded_motion_qpos = resume_qpos.copy() + self._ref_proc.reset_alignment(target_qpos=resume_qpos) if self._offline_playback is not None: + self._last_retarget_qpos = resume_qpos.copy() self._offline_playback.replay() self.mode = RobotMode.MOCAP @@ -739,15 +743,33 @@ def _reset_mocap_reference_state(self) -> None: self._ref_proc.reset_smoothers() self._last_live_packet_seq = -1 - def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: object) -> Float64Array: + def _build_robot_state_qpos(self, state: object) -> Float64Array: qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + base_pos = getattr(state, "base_pos", None) + if base_pos is not None: + qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] + qpos[3:7] = np.asarray(getattr(state, "quat"), dtype=np.float64).reshape(-1)[:4] + qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(getattr(state, "qpos"), dtype=np.float64).reshape(-1)[ + : self.num_actions + ] + return qpos + + def _set_default_standing_reference(self, state: object) -> None: + self._standing_qpos[:] = 0.0 + base_pos = getattr(state, "base_pos", None) + if base_pos is not None: + self._standing_qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] + self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) + align_motion_qpos_yaw(np.asarray(getattr(state, "quat"), dtype=np.float32), self._standing_qpos) + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + + def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: object) -> Float64Array: + qpos = self._build_robot_state_qpos(state) if hold_qpos is not None: qpos[0:2] = np.asarray(hold_qpos, dtype=np.float64).reshape(-1)[0:2] base_pos = getattr(state, "base_pos", None) if base_pos is not None: qpos[0:2] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[0:2] - qpos[3:7] = np.asarray(getattr(state, "quat"), 
dtype=np.float64) - qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(getattr(state, "qpos"), dtype=np.float64) return qpos def _restart_offline_playback(self) -> None: diff --git a/teleopit/sim2real/safety.py b/teleopit/sim2real/safety.py index d5b7dfed..6ce08ebe 100644 --- a/teleopit/sim2real/safety.py +++ b/teleopit/sim2real/safety.py @@ -34,12 +34,14 @@ def __init__( # Kp ramp gradually increases PD gains after episode-reset. startup_ramp_duration = float(cfg_get(cfg, "startup_ramp_duration", 2.0)) - self._kp_ramp_duration_steps: int = max(1, int(startup_ramp_duration * policy_hz)) + self._default_kp_ramp_duration_steps: int = max(1, int(startup_ramp_duration * policy_hz)) + self._kp_ramp_duration_steps: int = self._default_kp_ramp_duration_steps self._kp_ramp_step: int = 0 self._kp_ramp_active: bool = False self._kp_nominal = np.asarray(cfg_get(real_cfg, "kp_real", [100] * num_actions), dtype=np.float32) self._kd_nominal = np.asarray(cfg_get(real_cfg, "kd_real", [2] * num_actions), dtype=np.float32) - self._kp_ramp_floor_ratio: float = float(cfg_get(cfg, "kp_ramp_floor_ratio", 0.1)) + self._default_kp_ramp_floor_ratio: float = float(cfg_get(cfg, "kp_ramp_floor_ratio", 0.1)) + self._kp_ramp_floor_ratio: float = self._default_kp_ramp_floor_ratio # Joint safety limits self._joint_vel_limit: float = float( @@ -69,8 +71,21 @@ def compute_kp_ramp_gains(self) -> tuple[Float32Array, Float32Array] | None: return np.asarray(kp, dtype=np.float32), self._kd_nominal.copy() - def start_kp_ramp(self) -> None: + def start_kp_ramp( + self, + *, + duration_s: float | None = None, + floor_ratio: float | None = None, + ) -> None: """Arm the Kp ramp for gradual PD gain increase.""" + if duration_s is None: + self._kp_ramp_duration_steps = self._default_kp_ramp_duration_steps + else: + self._kp_ramp_duration_steps = max(1, int(float(duration_s) * self._policy_hz)) + if floor_ratio is None: + self._kp_ramp_floor_ratio = self._default_kp_ramp_floor_ratio + else: + self._kp_ramp_floor_ratio = 
float(floor_ratio) self._kp_ramp_step = 0 self._kp_ramp_active = True logger.info( diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index 1281eb85..df0469b3 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -17,18 +17,24 @@ def __init__(self, _cfg: object) -> None: ang_vel=np.zeros(3, dtype=np.float32), ) self.sent_positions: list[np.ndarray] = [] + self.sent_gains: list[tuple[np.ndarray | None, np.ndarray | None]] = [] + self.lock_calls = 0 def enter_debug_mode(self) -> bool: return True def lock_all_joints(self) -> None: - pass + self.lock_calls += 1 def get_state(self) -> SimpleNamespace: return self._state def send_positions(self, target_dof_pos: np.ndarray, kp: np.ndarray | None = None, kd: np.ndarray | None = None) -> None: self.sent_positions.append(np.asarray(target_dof_pos, dtype=np.float32)) + self.sent_gains.append(( + None if kp is None else np.asarray(kp, dtype=np.float32), + None if kd is None else np.asarray(kd, dtype=np.float32), + )) def set_damping(self) -> None: pass @@ -163,7 +169,12 @@ def tick(self, *, active: bool) -> None: def _make_cfg() -> dict[str, object]: return { "policy_hz": 50.0, - "real_robot": {}, + "real_robot": { + "kp_real": [100.0] * 29, + "kd_real": [2.0] * 29, + }, + "standing_return_ramp_duration": 0.5, + "standing_return_kp_ramp_floor_ratio": 0.5, "mocap_switch": {"check_frames": 1}, "robot": { "default_angles": [0.0] * 29, @@ -369,6 +380,37 @@ def test_state_machine_allows_mocap_reentry_after_returning_to_standing(monkeypa assert ctrl.mode == RobotMode.MOCAP +def test_return_to_standing_uses_default_pose_and_stronger_ramp_without_relock(monkeypatch) -> None: + from teleopit.sim2real.controller import RobotMode, Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + _install_controller_mocks( + monkeypatch, + policy=policy, + obs_builder=obs_builder, + qpos=np.zeros(36, dtype=np.float64), + ) + + cfg = _make_cfg() + 
cfg["robot"]["default_angles"] = [0.2] * 29 + ctrl = Sim2RealController(cfg) + ctrl.mode = RobotMode.MOCAP + ctrl.robot._state.qpos = np.ones(29, dtype=np.float32) + ctrl._last_commanded_motion_qpos = np.ones(36, dtype=np.float64) + + ctrl._enter_standing() + ctrl._standing_step() + + assert ctrl.mode == RobotMode.STANDING + assert ctrl.robot.lock_calls == 0 + np.testing.assert_allclose(ctrl._standing_qpos[7:36], np.full(29, 0.2, dtype=np.float64)) + kp, kd = ctrl.robot.sent_gains[-1] + assert kp is not None + assert kd is not None + np.testing.assert_allclose(kp, np.full(29, 50.0, dtype=np.float32)) + + def test_dexterous_hand_ticks_only_during_active_mocap(monkeypatch) -> None: import teleopit.sim2real.controller as controller_mod from teleopit.runtime.mocap_session import MocapSessionState @@ -444,17 +486,19 @@ def blocking_get_frame() -> dict[str, tuple[np.ndarray, np.ndarray]]: def test_mocap_step_episode_reset_on_transition(monkeypatch) -> None: - """After _transition_to_mocap (episode-reset), the first mocap step should - produce zero anchor velocities because _last_reference_qpos is None.""" + """After _transition_to_mocap, the first mocap step starts with zero joint velocity.""" from teleopit.sim2real.controller import Sim2RealController policy = DummyPolicy() obs_builder = DummyVelCmdObservationBuilder() target_qpos = np.zeros(36, dtype=np.float64) target_qpos[0] = 0.3 + target_qpos[7] = 1.0 _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - ctrl = Sim2RealController(_make_cfg()) + cfg = _make_cfg() + cfg["retarget_buffer_enabled"] = False + ctrl = Sim2RealController(cfg) ctrl._transition_to_mocap() monkeypatch.setattr( ctrl._ref_proc, @@ -468,12 +512,8 @@ def test_mocap_step_episode_reset_on_transition(monkeypatch) -> None: ctrl._mocap_step() assert len(obs_builder.build_calls) == 1 - # First step after episode-reset: _last_reference_qpos was None on entry, - # so _compute_anchor_velocities returns zeros 
for the initial call. - # The mock overrides this, but the first call computes joint vel from - # finite diff with _last_retarget_qpos which IS set, so vel ≠ 0. - # Just verify the observation was built. - assert obs_builder.build_calls[0]["motion_qpos"] is not None + np.testing.assert_allclose(obs_builder.build_calls[0]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) + np.testing.assert_allclose(obs_builder.build_calls[0]["motion_qpos"][7], 1.0, atol=1e-6) def test_mocap_step_velcmd_applies_fixed_initial_yaw_alignment(monkeypatch) -> None: @@ -536,6 +576,39 @@ def test_mocap_step_velcmd_keeps_fixed_yaw_after_start(monkeypatch) -> None: ) +def test_transition_to_mocap_uses_resume_style_alignment_and_zero_velocity(monkeypatch) -> None: + from teleopit.sim2real.controller import Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + target_qpos = np.zeros(36, dtype=np.float64) + target_qpos[0] = 0.25 + target_qpos[7] = 0.75 + _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) + + cfg = _make_cfg() + cfg["retarget_buffer_enabled"] = False + ctrl = Sim2RealController(cfg) + ctrl.robot._state.base_pos = np.array([1.0, 2.0, 0.0], dtype=np.float32) + ctrl.robot._state.quat = np.array([0.9238795, 0.0, 0.0, 0.38268343], dtype=np.float32) + ctrl._transition_to_mocap() + + assert ctrl._last_retarget_qpos is None + + monkeypatch.setattr( + ctrl._ref_proc, + "compute_anchor_velocities", + lambda _qpos: ( + np.zeros(3, dtype=np.float32), + np.zeros(3, dtype=np.float32), + ), + ) + ctrl._mocap_step() + + np.testing.assert_allclose(obs_builder.build_calls[0]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) + np.testing.assert_allclose(obs_builder.build_calls[0]["motion_qpos"][0:2], np.array([1.0, 2.0], dtype=np.float32), atol=1e-6) + + def test_mocap_step_waits_for_realtime_warmup_before_running_policy(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController 
From 1b77f04ab9aaecd0b31521427bf9fb4b1d6cfbaf Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 21:51:41 +0800 Subject: [PATCH 042/122] Add somehand L6 VR hand pose support --- .gitmodules | 3 + AGENTS.md | 12 +- README.md | 2 +- docs/docs/configuration/config-reference.md | 14 +- docs/docs/getting-started/installation.md | 8 +- docs/docs/tutorials/pico-sim2real.md | 28 +- .../current/configuration/config-reference.md | 13 +- .../current/getting-started/installation.md | 8 +- .../current/tutorials/pico-sim2real.md | 27 +- pyproject.toml | 1 + scripts/dev/test_linkerhand_l6.py | 288 +++++++++++++++++- scripts/run/run_sim2real.py | 1 + scripts/setup/download_somehand_l6_assets.sh | 189 ++++++++++++ teleopit/configs/pico4_sim2real.yaml | 7 +- teleopit/configs/sim2real.yaml | 7 +- teleopit/inputs/pico4_provider.py | 57 ++++ teleopit/sim2real/dexterous_hand.py | 191 +++++++++++- tests/test_dexterous_hand.py | 143 ++++++++- tests/test_pico4_provider.py | 27 ++ third_party/somehand | 1 + 20 files changed, 975 insertions(+), 52 deletions(-) create mode 100755 scripts/setup/download_somehand_l6_assets.sh create mode 160000 third_party/somehand diff --git a/.gitmodules b/.gitmodules index 6ddf149b..1b97b5ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "third_party/linkerhand-python-sdk"] path = third_party/linkerhand-python-sdk url = https://github.com/BotRunner64/linkerhand-python-sdk.git +[submodule "third_party/somehand"] + path = third_party/somehand + url = https://github.com/BotRunner64/somehand.git diff --git a/AGENTS.md b/AGENTS.md index 2bb25b66..eec9346d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -56,7 +56,7 @@ teleopit/ # Core inference package │ └── loop.py # SimulationLoop — PD control at 1000Hz, policy at 50Hz ├── sim2real/ │ ├── controller.py # G1 state machine and hardware control loop -│ └── dexterous_hand.py # Optional Pico controller → LinkerHand L6 runtime +│ └── dexterous_hand.py # Optional Pico gripper / VR hand 
pose → LinkerHand L6 runtime └── recording/ # HDF5Recorder scripts/ ├── run_sim.py # Offline sim2sim pipeline @@ -142,9 +142,11 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Pico4 sim2real pause/resume is handled as a mocap-session control event (`toggle_pause`), not as a mode switch to `STANDING` - Default Pico pause button is `A`; resume rebuilds the realtime buffer and yaw/XY root-offset alignment, then waits for the configured realtime warmup before tracking continues - Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference history, smoothers, realtime buffers, and reference alignment are reset, while the GMR IK warm-start is retained -- Optional LinkerHand L6 control uses `third_party/linkerhand-python-sdk` and `dexterous_hand.enabled=true` -- LinkerHand control reuses `Pico4InputProvider.get_controller_snapshot()`; do not start a second `PicoBridge` for hand control -- LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, frame timeout, and shutdown must send the configured open pose +- Optional LinkerHand L6 control uses `dexterous_hand.mode=off|gripper|vr_hand_pose`; default is `off` +- `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control +- `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and `somehand` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control +- LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose +- In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand ### SimulationLoop Runtime Behavior - `realtime=true` enforces wall-clock pacing even without a viewer @@ -213,7 +215,7 @@ python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_trac - Do not 
commit robot meshes, datasets, checkpoints, or demo media to Git; use `scripts/setup/download_assets.py` - `teleopit/retargeting/gmr/assets/` is gitignored; downloaded at runtime - `train_mimic/assets/` is no longer tracked; FK tooling reuses `teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml` -- `third_party/linkerhand-python-sdk` is a git submodule for optional LinkerHand L6 sim2real control +- `third_party/linkerhand-python-sdk` and `third_party/somehand` support optional LinkerHand L6 sim2real control - Run `python scripts/check_large_tracked_files.py` before pushing Assets are split across two ModelScope repos by type: diff --git a/README.md b/README.md index 85f62270..cc3ffc81 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ### Unreleased - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. -- Added optional Pico controller control for LinkerHand L6 in sim2real, backed by the LinkerHand SDK submodule. +- Added optional LinkerHand L6 sim2real modes: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. ### v0.3.0 (2026-05-12) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 85da3790..07368c08 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -121,20 +121,22 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin ### Dexterous Hand (Pico sim2real) -`dexterous_hand.enabled=true` requires `input.provider=pico4` and the optional -`dexhand` extra. Control is active only in `MOCAP`; inactive modes and timeouts -send the open pose. 
+`dexterous_hand.mode=gripper` or `dexterous_hand.mode=vr_hand_pose` requires +`input.provider=pico4` and the optional `dexhand` extra. Control is active only +in `MOCAP`; inactive modes send the open pose. In `vr_hand_pose`, missing hand +pose holds the last command for that side. | Field | Description | Default | |-------|-------------|---------| -| `dexterous_hand.enabled` | Enable Pico controller control for LinkerHand L6 | `false` | -| `dexterous_hand.hand_type` | Controlled side: `left`, `right`, or `both` | `both` | +| `dexterous_hand.mode` | `off`, `gripper`, or `vr_hand_pose` | `off` | +| `dexterous_hand.hand_type` | Controlled side: `left`, `right`, or `both`; `vr_hand_pose` requires `both` | `both` | | `dexterous_hand.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | | `dexterous_hand.rate` | Maximum command rate in Hz | `30.0` | -| `dexterous_hand.frame_timeout` | Missing-controller timeout before opening hands | `0.3` | +| `dexterous_hand.frame_timeout` | Gripper controller timeout, or VR hand-pose staleness threshold | `0.3` | | `dexterous_hand.deadman_threshold` | Minimum grip value required to enable a side | `0.5` | | `dexterous_hand.trigger_deadzone` | Trigger deadzone at both ends | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | +| `dexterous_hand.somehand.config_path` | somehand bi-hand L6 config used by `vr_hand_pose` | see config | ## Critical: `default_dof_pos` diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 1771c6c8..6129c681 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -62,15 +62,17 @@ See [Pico Sim2Sim](../tutorials/pico-sim2sim) and [Pico Sim2Real](../tutorials/pico-sim2real) for the full setup guides. Optional LinkerHand L6 control for Pico sim2real is installed through the -`dexhand` extra. 
The SDK itself is provided by the repository submodule, so make -sure submodules are initialized first: +`dexhand` extra. It includes the LinkerHand SDK submodule and the remote +somehand package used by VR hand-pose mode: ```bash git submodule update --init --recursive pip install -e '.[dexhand]' +scripts/setup/download_somehand_l6_assets.sh ``` -This extra is only required when `dexterous_hand.enabled=true`. +This extra is only required when `dexterous_hand.mode=gripper` or +`dexterous_hand.mode=vr_hand_pose`. ## Verify Installation diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 45559c9b..49c687d1 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -142,16 +142,23 @@ reference changes when live tracking resumes. ## Optional LinkerHand L6 Control -Pico sim2real can drive LinkerHand L6 hands from the Pico controllers. Hold the -matching side grip as a deadman switch; the matching trigger closes that hand. +Pico sim2real can drive LinkerHand L6 hands in two modes: + +- `gripper`: hold the matching side grip as a deadman switch; the matching + trigger closes that hand. +- `vr_hand_pose`: retarget Pico hand pose through somehand and command the + continuous L6 hand target. If a hand pose disappears, that side keeps its last + commanded pose. This mode currently uses `hand_type=both`. + Hand control is active only in `MOCAP`. It sends the open pose in `STANDING`, -`DAMPING`, paused mocap, frame timeout, and shutdown. +`DAMPING`, paused mocap, and shutdown. 
Install the dexhand extra first if it was not installed with the main Pico profile: ```bash pip install -e '.[dexhand]' +scripts/setup/download_somehand_l6_assets.sh ``` Bring up the CAN interfaces before testing or running hand control: @@ -174,7 +181,16 @@ python scripts/dev/test_linkerhand_l6.py \ Then enable L6 control in Pico sim2real: ```bash -dexterous_hand.enabled=true +dexterous_hand.mode=gripper +dexterous_hand.left_can=can0 +dexterous_hand.right_can=can1 +``` + +For continuous VR hand-pose control, use: + +```bash +dexterous_hand.mode=vr_hand_pose +dexterous_hand.hand_type=both dexterous_hand.left_can=can0 dexterous_hand.right_can=can1 ``` @@ -213,7 +229,7 @@ mocap_switch.check_frames=10 input.pause_button=right_axis_click # Enable LinkerHand L6 control -dexterous_hand.enabled=true +dexterous_hand.mode=gripper # Enable headset video preview input.video.enabled=true @@ -228,5 +244,5 @@ input.video.enabled=true | Cannot enter debug mode | Unitree mode release failed | Stop other robot modes and press `Start` again | | Robot enters `STANDING` but not `MOCAP` | Mocap validation failed | Keep tracking active and stable; check `mocap_switch.check_frames` logs | | Pico pause does not return to `STANDING` | Expected behavior | Pico pause freezes mocap; press remote `X` for `STANDING` | -| LinkerHand does not move | Not in `MOCAP`, deadman grip released, SDK not installed, or CAN channel wrong | Enter `MOCAP`, hold the matching side grip, run `scripts/dev/test_linkerhand_l6.py`, and check `dexterous_hand.left_can` / `right_can` | +| LinkerHand does not move | Mode is `off`, not in `MOCAP`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Set `dexterous_hand.mode`, enter `MOCAP`, run `scripts/dev/test_linkerhand_l6.py`, and check `dexterous_hand.left_can` / `right_can` | | Video preview is unavailable | RealSense or video source failed | Check camera permissions, `input.video.source`, and logs | diff --git 
a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 60b8e116..478d35eb 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -140,16 +140,19 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` ### 灵巧手(Pico sim2real) -`dexterous_hand.enabled=true` 要求 `input.provider=pico4`,并安装可选的 -`dexhand` extra。控制只在 `MOCAP` 中生效;非活动模式和超时会发送张开姿态。 +`dexterous_hand.mode=gripper` 或 `dexterous_hand.mode=vr_hand_pose` 要求 +`input.provider=pico4`,并安装可选的 `dexhand` extra。控制只在 `MOCAP` +中生效;非活动模式会发送张开姿态。在 `vr_hand_pose` 中,手部 pose 消失时, +对应侧会保持上一条命令。 | 字段 | 说明 | 默认值 | |---|---|---| -| `dexterous_hand.enabled` | 启用 Pico 手柄控制 LinkerHand L6 | `false` | -| `dexterous_hand.hand_type` | 控制侧:`left`、`right` 或 `both` | `both` | +| `dexterous_hand.mode` | `off`、`gripper` 或 `vr_hand_pose` | `off` | +| `dexterous_hand.hand_type` | 控制侧:`left`、`right` 或 `both`;`vr_hand_pose` 要求 `both` | `both` | | `dexterous_hand.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | | `dexterous_hand.rate` | 最大命令频率(Hz) | `30.0` | -| `dexterous_hand.frame_timeout` | 手柄超时后张开手的时间 | `0.3` | +| `dexterous_hand.frame_timeout` | gripper 手柄超时或 VR 手部 pose 过期阈值 | `0.3` | | `dexterous_hand.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | | `dexterous_hand.trigger_deadzone` | trigger 两端死区 | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | +| `dexterous_hand.somehand.config_path` | `vr_hand_pose` 使用的 somehand 双手 L6 配置 | 见配置 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 5fc8cd81..8a038992 100644 --- 
a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -61,15 +61,17 @@ receiver 可以运行在工作站 PC,也可以运行在机器人 onboard 计 完整设置流程详见 [Pico Sim2Sim](../tutorials/pico-sim2sim) 和 [Pico Sim2Real](../tutorials/pico-sim2real)。 -Pico sim2real 可选的 LinkerHand L6 控制通过 `dexhand` extra 安装。SDK -本身由仓库 submodule 提供,因此需要先初始化 submodule: +Pico sim2real 可选的 LinkerHand L6 控制通过 `dexhand` extra 安装。它包含 +LinkerHand SDK submodule,以及 `vr_hand_pose` 模式使用的远程 somehand 包: ```bash git submodule update --init --recursive pip install -e '.[dexhand]' +scripts/setup/download_somehand_l6_assets.sh ``` -只有在 `dexterous_hand.enabled=true` 时才需要安装这个 extra。 +只有在 `dexterous_hand.mode=gripper` 或 +`dexterous_hand.mode=vr_hand_pose` 时才需要安装这个 extra。 ## 验证安装 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index e659b71a..b4eb1f1a 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -134,14 +134,20 @@ Pico 暂停/恢复是 mocap-session control event。 ## 可选 LinkerHand L6 控制 -Pico sim2real 可以用 Pico 手柄控制 LinkerHand L6。按住同侧 grip 作为 deadman, -同侧 trigger 控制对应手闭合。手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、 -mocap 暂停、帧超时和退出时都会发送张开姿态。 +Pico sim2real 可以用两种模式控制 LinkerHand L6: + +- `gripper`:按住同侧 grip 作为 deadman,同侧 trigger 控制对应手闭合。 +- `vr_hand_pose`:通过 somehand 重定向 Pico 手部 pose,并下发连续 L6 手部目标。 + 如果某侧手部 pose 消失,该侧会保持上一条手势命令。这个模式当前使用 + `hand_type=both`。 + +手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 如果主 Pico profile 没有包含手控支持,先安装 dexhand extra: ```bash pip install -e '.[dexhand]' +scripts/setup/download_somehand_l6_assets.sh ``` 测试或运行手控前,先开启 CAN 接口: @@ -163,7 +169,16 @@ python scripts/dev/test_linkerhand_l6.py \ 然后在 Pico sim2real 中启用 
L6 控制: ```bash -dexterous_hand.enabled=true +dexterous_hand.mode=gripper +dexterous_hand.left_can=can0 +dexterous_hand.right_can=can1 +``` + +连续 VR 手部 pose 控制使用: + +```bash +dexterous_hand.mode=vr_hand_pose +dexterous_hand.hand_type=both dexterous_hand.left_can=can0 dexterous_hand.right_can=can1 ``` @@ -202,7 +217,7 @@ mocap_switch.check_frames=10 input.pause_button=right_axis_click # 开启 LinkerHand L6 控制 -dexterous_hand.enabled=true +dexterous_hand.mode=gripper # 开启头显视频预览 input.video.enabled=true @@ -217,5 +232,5 @@ input.video.enabled=true | 无法进入 debug mode | Unitree mode 释放失败 | 停止其他机器人模式后再次按 `Start` | | 机器人进入 `STANDING` 但不进入 `MOCAP` | 动捕验证失败 | 保持追踪稳定,查看 `mocap_switch.check_frames` 日志 | | Pico 暂停没有返回 `STANDING` | 这是预期行为 | Pico 暂停只冻结 mocap;按遥控器 `X` 返回 `STANDING` | -| LinkerHand 不动 | 不在 `MOCAP`、deadman grip 未按住、SDK 未安装,或 CAN 通道错误 | 进入 `MOCAP`,按住同侧 grip,运行 `scripts/dev/test_linkerhand_l6.py`,并检查 `dexterous_hand.left_can` / `right_can` | +| LinkerHand 不动 | 模式为 `off`、不在 `MOCAP`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `dexterous_hand.mode`,进入 `MOCAP`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查 `dexterous_hand.left_can` / `right_can` | | 视频预览不可用 | RealSense 或视频源失败 | 检查相机权限、`input.video.source` 和日志 | diff --git a/pyproject.toml b/pyproject.toml index 59de8612..7a858fe7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ pico4 = [ ] dexhand = [ "linkerhand-python-sdk @ file:third_party/linkerhand-python-sdk", + "somehand @ file:third_party/somehand", ] [tool.setuptools.packages.find] diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py index 643e8bb7..91c9a1ae 100644 --- a/scripts/dev/test_linkerhand_l6.py +++ b/scripts/dev/test_linkerhand_l6.py @@ -1,25 +1,71 @@ #!/usr/bin/env python3 -"""Exercise LinkerHand L6 open/close motion to verify hardware connectivity.""" +"""Exercise LinkerHand L6 dexterous-hand control modes.""" from __future__ import annotations import argparse +import logging from pathlib import Path import 
sys import time from typing import Sequence +import numpy as np + REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) SDK_PATH = REPO_ROOT / "third_party" / "linkerhand-python-sdk" if SDK_PATH.exists(): sys.path.insert(0, str(SDK_PATH)) +SOMEHAND_SRC_PATH = REPO_ROOT / "third_party" / "somehand" / "src" +if SOMEHAND_SRC_PATH.exists(): + sys.path.insert(0, str(SOMEHAND_SRC_PATH)) + +from teleopit.inputs.pico4_provider import ( # noqa: E402 + PicoControllerSnapshot, + PicoControllerState, + PicoHandSnapshot, + PicoHandState, +) +from teleopit.sim2real.dexterous_hand import ( # noqa: E402 + LinkerHandConfig, + LinkerHandRuntime, + SomeHandPoseRuntime, + trigger_to_pose, +) THUMB_YAW_DEFAULT = 10 OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] +DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" +DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" + +PICO_BRIDGE_TO_MEDIAPIPE = [ + 1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25 +] +PICO_NATIVE_TO_RH = np.array( + [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]], + dtype=np.float64, +) + + +class ScriptControllerProvider: + def __init__(self) -> None: + self.snapshot: PicoControllerSnapshot | None = None + + def get_controller_snapshot(self) -> PicoControllerSnapshot | None: + return self.snapshot + + +class ScriptHandProvider: + def __init__(self) -> None: + self.snapshot: PicoHandSnapshot | None = None + + def get_hand_snapshot(self) -> PicoHandSnapshot | None: + return self.snapshot def uint8(value: str) -> int: @@ -50,7 +96,17 @@ def selected_hand_types(hand_type: str) -> tuple[str, ...]: def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Test LinkerHand L6 open/close motion") + parser = argparse.ArgumentParser(description="Test LinkerHand L6 
dexterous-hand control modes") + parser.add_argument( + "--mode", + choices=["open_close", "gripper", "vr_hand_pose"], + default="open_close", + help=( + "open_close sends fixed poses directly; gripper exercises the sim2real Pico " + "grip/trigger mapping; vr_hand_pose exercises the somehand Pico hand-pose path " + "with synthetic hand landmarks." + ), + ) parser.add_argument("--hand-type", choices=["left", "right", "both"], default="both") parser.add_argument("--left-can", default="can0") parser.add_argument("--right-can", default="can1") @@ -61,7 +117,12 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--cycles", type=positive_int, default=3) parser.add_argument("--hold-s", type=positive_float, default=1.0) + parser.add_argument("--rate", type=positive_float, default=30.0) + parser.add_argument("--frame-timeout", type=positive_float, default=0.3) + parser.add_argument("--trigger-deadzone", type=float, default=0.05) + parser.add_argument("--deadman-threshold", type=float, default=0.5) parser.add_argument("--thumb-yaw-center", type=uint8, default=THUMB_YAW_DEFAULT) + parser.add_argument("--print-input", action="store_true") parser.add_argument( "--speed", type=uint8, @@ -83,12 +144,41 @@ def parse_args() -> argparse.Namespace: default=CLOSE_POSE, metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), ) + parser.add_argument("--somehand-config-path", default=DEFAULT_SOMEHAND_CONFIG_PATH) + parser.add_argument("--somehand-sdk-root", default=DEFAULT_LINKERHAND_SDK_ROOT) args = parser.parse_args() args.open_pose[1] = args.thumb_yaw_center args.close_pose[1] = args.thumb_yaw_center + if args.trigger_deadzone < 0.0 or args.trigger_deadzone >= 0.5: + raise SystemExit("--trigger-deadzone must be in [0, 0.5)") + if args.deadman_threshold <= 0.0 or args.deadman_threshold >= 1.0: + raise SystemExit("--deadman-threshold must be in (0, 1)") return args +def make_config(args: argparse.Namespace, *, mode: str) -> LinkerHandConfig: + return 
LinkerHandConfig( + mode=mode, + enabled=True, + hand_joint="L6", + hand_type=args.hand_type, + left_can=args.left_can, + right_can=args.right_can, + modbus=args.modbus, + rate=args.rate, + frame_timeout=args.frame_timeout, + trigger_deadzone=args.trigger_deadzone, + deadman_threshold=args.deadman_threshold, + thumb_yaw_center=args.thumb_yaw_center, + speed=tuple(args.speed), + open_pose=tuple(args.open_pose), + close_pose=tuple(args.close_pose), + print_input=args.print_input, + somehand_config_path=args.somehand_config_path, + somehand_sdk_root=args.somehand_sdk_root, + ) + + def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> None: print(f"{label}: {list(pose)}", flush=True) for hand_type, hand in hands.items(): @@ -96,8 +186,20 @@ def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> No hand.finger_move(pose=list(pose)) -def main() -> None: - args = parse_args() +def wait_runtime_idle(runtime: object, *, timeout_s: float = 2.0) -> None: + sender = getattr(runtime, "_sender", None) + wait_idle = getattr(sender, "wait_idle", None) + if callable(wait_idle) and not wait_idle(timeout_s=timeout_s): + raise RuntimeError("Timed out waiting for LinkerHand sender to become idle") + + +def assert_runtime_started(runtime: object) -> None: + sender = getattr(runtime, "_sender", None) + if not bool(getattr(sender, "started", False)): + raise RuntimeError("LinkerHand sender failed to start; check the log above for SDK/CAN errors") + + +def run_open_close(args: argparse.Namespace) -> None: try: from LinkerHand.linker_hand_api import LinkerHandApi except ImportError as exc: @@ -149,5 +251,183 @@ def main() -> None: ) +def controller_snapshot( + *, + timestamp_s: float, + seq: int, + trigger: float, + grip: float, + config: LinkerHandConfig, +) -> PicoControllerSnapshot: + missing = PicoControllerState(raw=False, grip=0.0, trigger=0.0, present=False) + active = PicoControllerState(raw=True, grip=grip, trigger=trigger, 
present=True) + left = active if "left" in config.selected_hand_types else missing + right = active if "right" in config.selected_hand_types else missing + return PicoControllerSnapshot(left=left, right=right, timestamp_s=timestamp_s, seq=seq) + + +def run_gripper(args: argparse.Namespace) -> None: + config = make_config(args, mode="gripper") + provider = ScriptControllerProvider() + runtime = LinkerHandRuntime(config, provider) + + print("Testing dexterous_hand.mode=gripper with synthetic Pico grip/trigger snapshots", flush=True) + try: + runtime.start() + wait_runtime_idle(runtime) + assert_runtime_started(runtime) + + now_s = time.monotonic() + print("inactive safety open", flush=True) + runtime.tick(active=False, now_s=now_s) + wait_runtime_idle(runtime) + time.sleep(args.hold_s) + + seq = 1 + print("deadman released -> open", flush=True) + now_s = time.monotonic() + provider.snapshot = controller_snapshot( + timestamp_s=now_s, + seq=seq, + trigger=1.0, + grip=0.0, + config=config, + ) + runtime.tick(active=True, now_s=now_s) + wait_runtime_idle(runtime) + time.sleep(args.hold_s) + + sweep = [0.0, 0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25, 0.0] + for cycle in range(args.cycles): + print(f"gripper cycle {cycle + 1}/{args.cycles}", flush=True) + for trigger in sweep: + seq += 1 + now_s = time.monotonic() + pose = trigger_to_pose( + trigger, + open_pose=config.open_pose, + close_pose=config.close_pose, + deadzone=config.trigger_deadzone, + thumb_yaw_default=config.thumb_yaw_center, + ) + print(f" grip=1.00 trigger={trigger:.2f} -> {pose}", flush=True) + provider.snapshot = controller_snapshot( + timestamp_s=now_s, + seq=seq, + trigger=trigger, + grip=1.0, + config=config, + ) + runtime.tick(active=True, now_s=now_s) + wait_runtime_idle(runtime) + time.sleep(args.hold_s) + except KeyboardInterrupt: + print("Interrupted; opening hands before exit", flush=True) + finally: + runtime.tick(active=False) + wait_runtime_idle(runtime) + runtime.close() + + +def 
rh_to_pico_native(position: Sequence[float]) -> np.ndarray: + return np.asarray(position, dtype=np.float64) @ PICO_NATIVE_TO_RH + + +def synthetic_pico_hand_joints(hand_type: str, *, curl: float) -> np.ndarray: + curl = max(0.0, min(1.0, float(curl))) + side_sign = -1.0 if hand_type == "left" else 1.0 + joints = np.zeros((26, 7), dtype=np.float64) + + mp_landmarks = np.zeros((21, 3), dtype=np.float64) + mp_landmarks[0] = [0.0, 0.0, 0.0] + finger_bases = [ + (1, side_sign * 0.035, 0.035, [0.018, 0.033, 0.046, 0.058]), + (5, side_sign * 0.020, 0.060, [0.040, 0.070, 0.095, 0.120]), + (9, 0.0, 0.065, [0.045, 0.080, 0.110, 0.140]), + (13, -side_sign * 0.020, 0.060, [0.040, 0.070, 0.095, 0.120]), + (17, -side_sign * 0.040, 0.052, [0.035, 0.060, 0.082, 0.102]), + ] + for base_idx, x, base_y, lengths in finger_bases: + for offset, length in enumerate(lengths): + bend = curl * (offset + 1) / len(lengths) + y = base_y + length * (1.0 - 0.65 * bend) + z = -0.055 * bend + if base_idx == 1: + x_pos = x + side_sign * length * 0.65 + y = 0.015 + length * (1.0 - 0.35 * bend) + else: + x_pos = x + mp_landmarks[base_idx + offset] = [x_pos, y, z] + + for mp_idx, pico_idx in enumerate(PICO_BRIDGE_TO_MEDIAPIPE): + joints[pico_idx, :3] = rh_to_pico_native(mp_landmarks[mp_idx]) + joints[0, :3] = rh_to_pico_native([0.0, 0.025, 0.0]) + return joints + + +def hand_snapshot(*, timestamp_s: float, seq: int, curl: float) -> PicoHandSnapshot: + return PicoHandSnapshot( + left=PicoHandState(active=True, joints=synthetic_pico_hand_joints("left", curl=curl), present=True), + right=PicoHandState(active=True, joints=synthetic_pico_hand_joints("right", curl=curl), present=True), + timestamp_s=timestamp_s, + seq=seq, + ) + + +def run_vr_hand_pose(args: argparse.Namespace) -> None: + if args.hand_type != "both": + raise SystemExit("dexterous_hand.mode=vr_hand_pose currently requires --hand-type both") + + config = make_config(args, mode="vr_hand_pose") + provider = ScriptHandProvider() + runtime = 
SomeHandPoseRuntime(config, provider) + + print( + "Testing dexterous_hand.mode=vr_hand_pose with synthetic Pico hand-pose snapshots. " + "This drives poses produced by somehand; start with the robot clear of contacts.", + flush=True, + ) + try: + runtime.start() + wait_runtime_idle(runtime) + assert_runtime_started(runtime) + + seq = 0 + curl_sweep = [0.0, 0.35, 0.7, 1.0, 0.7, 0.35, 0.0] + for cycle in range(args.cycles): + print(f"vr_hand_pose cycle {cycle + 1}/{args.cycles}", flush=True) + for curl in curl_sweep: + seq += 1 + now_s = time.monotonic() + print(f" synthetic curl={curl:.2f}", flush=True) + provider.snapshot = hand_snapshot(timestamp_s=now_s, seq=seq, curl=curl) + runtime.tick(active=True, now_s=now_s) + wait_runtime_idle(runtime) + time.sleep(args.hold_s) + + print("inactive mode -> configured open pose", flush=True) + runtime.tick(active=False) + wait_runtime_idle(runtime) + except KeyboardInterrupt: + print("Interrupted; opening hands before exit", flush=True) + finally: + runtime.tick(active=False) + wait_runtime_idle(runtime) + runtime.close() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") + args = parse_args() + if args.mode == "open_close": + run_open_close(args) + elif args.mode == "gripper": + run_gripper(args) + elif args.mode == "vr_hand_pose": + run_vr_hand_pose(args) + else: + raise AssertionError(f"Unhandled mode: {args.mode}") + + if __name__ == "__main__": main() diff --git a/scripts/run/run_sim2real.py b/scripts/run/run_sim2real.py index d546d8a3..84619387 100644 --- a/scripts/run/run_sim2real.py +++ b/scripts/run/run_sim2real.py @@ -18,6 +18,7 @@ def _print_sim2real_controls(cfg: DictConfig) -> None: print(" Remote L1+R1: DAMPING / estop.") if provider == "pico4": print(" Mocap pause/resume: Pico/controller A.") + print(" Dexterous hand: dexterous_hand.mode=off|gripper|vr_hand_pose (default off).") else: print(" Offline playback: A pause/resume, B replay from start.") 
print(" State flow: IDLE -> STANDING -> MOCAP -> STANDING, Any -> DAMPING.") diff --git a/scripts/setup/download_somehand_l6_assets.sh b/scripts/setup/download_somehand_l6_assets.sh new file mode 100755 index 00000000..447a50b0 --- /dev/null +++ b/scripts/setup/download_somehand_l6_assets.sh @@ -0,0 +1,189 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +SOURCE="modelscope" +REPO_ID="" +DEST="${PROJECT_ROOT}/third_party/somehand/assets/mjcf" +CACHE_DIR="${PROJECT_ROOT}/data/somehand_assets_cache" + +usage() { + cat <<'EOF' +Download somehand LinkerHand L6 bi-hand MJCF assets. + +Usage: + scripts/setup/download_somehand_l6_assets.sh + scripts/setup/download_somehand_l6_assets.sh --source huggingface + scripts/setup/download_somehand_l6_assets.sh --dest third_party/somehand/assets/mjcf + +Options: + --source modelscope|huggingface Download backend (default: modelscope) + --repo-id REPO Override asset repo id + --dest PATH Destination mjcf directory + --cache-dir PATH Download cache directory + -h, --help Show this help +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --source) + SOURCE="$2" + shift 2 + ;; + --repo-id) + REPO_ID="$2" + shift 2 + ;; + --dest) + DEST="$2" + shift 2 + ;; + --cache-dir) + CACHE_DIR="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage >&2 + exit 2 + ;; + esac +done + +if [[ "${SOURCE}" != "modelscope" && "${SOURCE}" != "huggingface" ]]; then + echo "--source must be modelscope or huggingface" >&2 + exit 2 +fi + +python - "$PROJECT_ROOT" "$SOURCE" "$REPO_ID" "$DEST" "$CACHE_DIR" <<'PY' +from __future__ import annotations + +import shutil +import subprocess +import sys +import tarfile +from pathlib import Path + +project_root = Path(sys.argv[1]).resolve() +source = sys.argv[2] +repo_id = sys.argv[3] or ("BingqianWu/somehand-assets" if source == "modelscope" else 
"12e21/somehand-assets") +dest = Path(sys.argv[4]).expanduser() +if not dest.is_absolute(): + dest = (project_root / dest).resolve() +cache_dir = Path(sys.argv[5]).expanduser() +if not cache_dir.is_absolute(): + cache_dir = (project_root / cache_dir).resolve() + +hands = ("linkerhand_l6_left", "linkerhand_l6_right") +direct_patterns = [f"assets/mjcf/{hand}/**" for hand in hands] +archive_pattern = "archives/mjcf_assets.tar.gz" +repo_cache = cache_dir / source / repo_id.split("/")[-1] + + +def ensure_package(name: str): + try: + return __import__(name) + except ImportError: + subprocess.check_call([sys.executable, "-m", "pip", "install", name]) + return __import__(name) + + +def remove_path(path: Path) -> None: + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path) + elif path.exists() or path.is_symlink(): + path.unlink() + + +def copy_dir(src: Path, dst: Path) -> None: + remove_path(dst) + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(src, dst) + + +def download(patterns: list[str]) -> None: + repo_cache.mkdir(parents=True, exist_ok=True) + if source == "huggingface": + hub = ensure_package("huggingface_hub") + hub.snapshot_download( + repo_id=repo_id, + repo_type="model", + local_dir=str(repo_cache), + allow_patterns=patterns, + ) + return + modelscope = ensure_package("modelscope") + modelscope.snapshot_download( + repo_id, + repo_type="model", + local_dir=str(repo_cache), + allow_patterns=patterns, + allow_file_pattern=patterns, + ) + + +def place_direct_assets() -> bool: + placed = False + for hand in hands: + src = repo_cache / "assets" / "mjcf" / hand + if not src.exists(): + return False + copy_dir(src, dest / hand) + print(f" {src.relative_to(repo_cache)} -> {dest / hand}") + placed = True + return placed + + +def safe_extract_l6_from_archive(archive_path: Path) -> None: + dest.mkdir(parents=True, exist_ok=True) + wanted_prefixes = {f"assets/mjcf/{hand}/" for hand in hands} + with tarfile.open(archive_path, "r:*") as tar: + 
members = [] + for member in tar.getmembers(): + path = Path(member.name) + if path.is_absolute() or ".." in path.parts: + raise ValueError(f"Unsafe archive member path: {member.name}") + normalized = member.name.lstrip("./") + if any(normalized.startswith(prefix) for prefix in wanted_prefixes): + members.append(member) + if not members: + raise FileNotFoundError(f"No LinkerHand L6 assets found in {archive_path}") + tmp = dest.parent / ".somehand_l6_extracting" + remove_path(tmp) + tmp.mkdir(parents=True, exist_ok=True) + tar.extractall(tmp, members=members) + for hand in hands: + src = tmp / "assets" / "mjcf" / hand + if not src.exists(): + raise FileNotFoundError(f"Archive missing assets/mjcf/{hand}") + copy_dir(src, dest / hand) + print(f" archive:{hand} -> {dest / hand}") + remove_path(tmp) + + +print(f"Downloading somehand L6 assets from {source}:{repo_id}") +print(f"Destination: {dest}") + +download(direct_patterns) +if not place_direct_assets(): + print("Direct L6 asset paths not found; downloading mjcf archive fallback.") + download([archive_pattern]) + archive = repo_cache / archive_pattern + if not archive.exists(): + raise FileNotFoundError(f"Downloaded repo is missing {archive_pattern}") + safe_extract_l6_from_archive(archive) + +for hand in hands: + model = dest / hand / "model.xml" + if not model.exists(): + raise FileNotFoundError(f"Expected model file not found: {model}") + +print("Done.") +PY diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index f3fc588f..3d639e94 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -27,9 +27,9 @@ standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 -# Optional LinkerHand L6 control from Pico controller grip/trigger. +# Optional LinkerHand L6 control from Pico controller grip/trigger or VR hand pose. 
dexterous_hand: - enabled: false + mode: off # off | gripper | vr_hand_pose hand_joint: L6 hand_type: both left_can: can0 @@ -44,6 +44,9 @@ dexterous_hand: open_pose: [250, 10, 250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] print_input: false + somehand: + config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml + sdk_root: third_party/linkerhand-python-sdk # Physical robot SDK configuration real_robot: diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index cb6796a9..b483a881 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -26,9 +26,9 @@ standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 -# Optional LinkerHand L6 control. Enable only with input.provider=pico4. +# Optional LinkerHand L6 control. Use only with input.provider=pico4. dexterous_hand: - enabled: false + mode: off # off | gripper | vr_hand_pose hand_joint: L6 hand_type: both left_can: can0 @@ -43,6 +43,9 @@ dexterous_hand: open_pose: [250, 10, 250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] print_input: false + somehand: + config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml + sdk_root: third_party/linkerhand-python-sdk # Physical robot SDK configuration real_robot: diff --git a/teleopit/inputs/pico4_provider.py b/teleopit/inputs/pico4_provider.py index 663d2ce5..1df18f3e 100644 --- a/teleopit/inputs/pico4_provider.py +++ b/teleopit/inputs/pico4_provider.py @@ -74,6 +74,26 @@ class PicoControllerSnapshot: timestamp_s: float seq: int + +@dataclass(frozen=True) +class PicoHandState: + """Latest per-hand pose state exposed by pico_bridge.""" + + active: bool + joints: NDArray[np.float64] + present: bool = True + + +@dataclass(frozen=True) +class PicoHandSnapshot: + """Immutable snapshot of Pico hand poses for auxiliary runtimes.""" + + left: PicoHandState + right: PicoHandState + 
timestamp_s: float + seq: int + + _PAUSE_BUTTON_MAP: dict[str, tuple[str, str]] = { "A": ("right", "primaryButton"), "B": ("right", "secondaryButton"), @@ -212,6 +232,7 @@ def __init__( self._last_frame_timestamp: float | None = None self._last_source_seq: int | None = None self._controller_snapshot: PicoControllerSnapshot | None = None + self._hand_snapshot: PicoHandSnapshot | None = None self._ground_lift_offset: float | None = None self._bridge = bridge_cls( host=bridge_host, @@ -298,6 +319,11 @@ def get_controller_snapshot(self) -> PicoControllerSnapshot | None: with self._lock: return self._controller_snapshot + def get_hand_snapshot(self) -> PicoHandSnapshot | None: + """Return the latest Pico hand-pose snapshot, if one has arrived.""" + with self._lock: + return self._hand_snapshot + def push_video_frame(self, frame: NDArray[np.uint8]) -> int: """Push one RGB camera frame to pico-bridge 0.2.1 video output.""" push_video_frame = getattr(self._bridge, "push_video_frame", None) @@ -363,6 +389,7 @@ def _poll_loop(self) -> None: def _accept_pico_frame(self, frame: Any) -> bool: timestamp = float(getattr(frame, "receive_time_s", time.monotonic())) self._accept_controller_snapshot(frame, timestamp=timestamp) + self._accept_hand_snapshot(frame, timestamp=timestamp) self._poll_control_events(frame, timestamp=timestamp) body = getattr(frame, "body", None) @@ -417,6 +444,17 @@ def _accept_controller_snapshot(self, frame: Any, *, timestamp: float) -> None: with self._lock: self._controller_snapshot = snapshot + def _accept_hand_snapshot(self, frame: Any, *, timestamp: float) -> None: + seq = int(getattr(frame, "seq", self._last_source_seq or -1)) + snapshot = PicoHandSnapshot( + left=self._read_hand_state(getattr(frame, "left_hand", None)), + right=self._read_hand_state(getattr(frame, "right_hand", None)), + timestamp_s=float(timestamp), + seq=seq, + ) + with self._lock: + self._hand_snapshot = snapshot + def _poll_control_events(self, frame: Any, *, timestamp: float) 
-> bool: if self._pause_button_path is None: return False @@ -461,6 +499,25 @@ def _read_controller_state(controller: Any) -> PicoControllerState: present=controller is not None, ) + @staticmethod + def _read_hand_state(hand: Any) -> PicoHandState: + joints = np.zeros((26, 7), dtype=np.float64) + valid_shape = False + if hand is None: + return PicoHandState(active=False, joints=joints, present=False) + try: + raw_joints = np.asarray(getattr(hand, "joints"), dtype=np.float64) + if raw_joints.shape == (26, 7): + joints = raw_joints.copy() + valid_shape = True + except (TypeError, ValueError): + pass + return PicoHandState( + active=bool(getattr(hand, "active", False)) and valid_shape, + joints=joints, + present=True, + ) + @staticmethod def _convert_body_joints_to_frame(body_joints: NDArray[np.float64]) -> HumanFrame: body_pose_dict: dict[str, list] = {} diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 98949d3b..98fbebc0 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -4,20 +4,30 @@ from dataclasses import dataclass import logging +from pathlib import Path import threading import time from typing import Any, Protocol, Sequence -from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState +from teleopit.inputs.pico4_provider import ( + PicoControllerSnapshot, + PicoControllerState, + PicoHandSnapshot, + PicoHandState, +) from teleopit.runtime.common import cfg_get logger = logging.getLogger(__name__) +PROJECT_ROOT = Path(__file__).resolve().parents[2] THUMB_YAW_DEFAULT = 10 OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] HAND_TYPES = ("left", "right") +HAND_MODES = ("off", "gripper", "vr_hand_pose") +DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" +DEFAULT_LINKERHAND_SDK_ROOT = 
"third_party/linkerhand-python-sdk" class ControllerSnapshotProvider(Protocol): @@ -25,8 +35,14 @@ def get_controller_snapshot(self) -> PicoControllerSnapshot | None: ... +class HandSnapshotProvider(Protocol): + def get_hand_snapshot(self) -> PicoHandSnapshot | None: + ... + + @dataclass(frozen=True) class LinkerHandConfig: + mode: str = "off" enabled: bool = False hand_joint: str = "L6" hand_type: str = "both" @@ -42,6 +58,8 @@ class LinkerHandConfig: open_pose: tuple[int, ...] = tuple(OPEN_POSE) close_pose: tuple[int, ...] = tuple(CLOSE_POSE) print_input: bool = False + somehand_config_path: str = DEFAULT_SOMEHAND_CONFIG_PATH + somehand_sdk_root: str = DEFAULT_LINKERHAND_SDK_ROOT @property def selected_hand_types(self) -> tuple[str, ...]: @@ -88,6 +106,10 @@ def trigger_to_pose( def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: hand_cfg = cfg_get(cfg, "dexterous_hand", {}) or {} + raw_mode = cfg_get(hand_cfg, "mode", None) + legacy_enabled = bool(cfg_get(hand_cfg, "enabled", False)) + mode = str(raw_mode if raw_mode is not None else ("gripper" if legacy_enabled else "off")).lower() + somehand_cfg = cfg_get(hand_cfg, "somehand", {}) or {} thumb_yaw = _uint8(cfg_get(hand_cfg, "thumb_yaw_center", THUMB_YAW_DEFAULT), "thumb_yaw_center") open_pose = _pose_values(cfg_get(hand_cfg, "open_pose", OPEN_POSE), "open_pose") close_pose = _pose_values(cfg_get(hand_cfg, "close_pose", CLOSE_POSE), "close_pose") @@ -95,7 +117,8 @@ def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: close_pose[1] = thumb_yaw config = LinkerHandConfig( - enabled=bool(cfg_get(hand_cfg, "enabled", False)), + mode=mode, + enabled=mode != "off", hand_joint=str(cfg_get(hand_cfg, "hand_joint", "L6")).upper(), hand_type=str(cfg_get(hand_cfg, "hand_type", "both")).lower(), left_can=str(cfg_get(hand_cfg, "left_can", "can0")), @@ -110,7 +133,11 @@ def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: open_pose=tuple(open_pose), close_pose=tuple(close_pose), 
print_input=bool(cfg_get(hand_cfg, "print_input", False)), + somehand_config_path=str(cfg_get(somehand_cfg, "config_path", DEFAULT_SOMEHAND_CONFIG_PATH)), + somehand_sdk_root=str(cfg_get(somehand_cfg, "sdk_root", DEFAULT_LINKERHAND_SDK_ROOT)), ) + if config.mode not in HAND_MODES: + raise ValueError(f"dexterous_hand.mode must be one of {', '.join(HAND_MODES)}, got {config.mode!r}") if config.hand_joint != "L6": raise ValueError(f"dexterous_hand.hand_joint must be 'L6', got {config.hand_joint!r}") if config.hand_type not in ("left", "right", "both"): @@ -142,7 +169,7 @@ def start(self) -> None: from LinkerHand.linker_hand_api import LinkerHandApi except ImportError as exc: raise ImportError( - "LinkerHand SDK is required when dexterous_hand.enabled=true. " + "LinkerHand SDK is required when dexterous_hand.mode is gripper or vr_hand_pose. " "Run: pip install -e third_party/linkerhand-python-sdk" ) from exc @@ -430,6 +457,139 @@ def _set_status(self, hand_type: str, status: str, message: str) -> None: logger.info("LinkerHand L6: %s", message) +class SomeHandPoseRuntime: + """Drive LinkerHand L6 from Pico hand-pose snapshots through somehand.""" + + def __init__(self, config: LinkerHandConfig, provider: HandSnapshotProvider): + self.config = config + self._provider = provider + self._sender = AsyncL6PoseSender(config) + self._interval_s = 1.0 / config.rate + self._next_tick_s = 0.0 + self._active = False + self._last_status: dict[str, str] = {hand_type: "" for hand_type in config.selected_hand_types} + self._engine: Any | None = None + self._hand_frame_cls: Any | None = None + self._bihand_frame_cls: Any | None = None + self._pico_hand_to_landmarks: Any | None = None + self._adapters: dict[str, Any] = {} + + @property + def enabled(self) -> bool: + return self.config.enabled + + def start(self) -> None: + if not self.enabled: + return + self._load_somehand() + self._sender.start() + self._sender.send_all(self.config.open_pose, force=True, reason="startup") + + def 
tick(self, *, active: bool, now_s: float | None = None) -> None: + if not self.enabled: + return + now = time.monotonic() if now_s is None else float(now_s) + if not active: + self._deactivate(reason="inactive") + return + if not self._active: + self._active = True + self._next_tick_s = 0.0 + if now < self._next_tick_s: + return + self._next_tick_s = now + self._interval_s + + snapshot = self._provider.get_hand_snapshot() + if snapshot is None: + self._set_status("both", "missing", "Pico hand pose missing; holding last hand command") + return + if now - snapshot.timestamp_s > self.config.frame_timeout: + self._set_status("both", "timeout", "Pico hand pose timed out; holding last hand command") + return + + self._tick_snapshot(snapshot) + + def close(self) -> None: + self._deactivate(reason="shutdown") + self._sender.close() + + def _tick_snapshot(self, snapshot: PicoHandSnapshot) -> None: + left_frame = self._make_hand_frame("left", snapshot.left) if "left" in self.config.selected_hand_types else None + right_frame = self._make_hand_frame("right", snapshot.right) if "right" in self.config.selected_hand_types else None + if left_frame is None and right_frame is None: + return + + result = self._engine.process(self._bihand_frame_cls(left=left_frame, right=right_frame)) + for hand_type, detected, step in ( + ("left", result.left_detected, result.left), + ("right", result.right_detected, result.right), + ): + if hand_type not in self.config.selected_hand_types or not detected: + continue + pose = self._adapters[hand_type].qpos_to_sdk_range(step.qpos) + self._sender.send(hand_type, pose, reason="vr-hand-pose") + + def _make_hand_frame(self, hand_type: str, state: PicoHandState) -> Any | None: + if not state.present: + self._set_status(hand_type, "missing", f"{hand_type} hand pose missing; holding last hand command") + return None + if not state.active: + self._set_status(hand_type, "inactive", f"{hand_type} hand pose inactive; holding last hand command") + return None + 
self._set_status(hand_type, "enabled", f"{hand_type} hand pose active") + landmarks = self._pico_hand_to_landmarks(state.joints) + return self._hand_frame_cls(landmarks_3d=landmarks, landmarks_2d=None, hand_side=hand_type) + + def _deactivate(self, *, reason: str) -> None: + if self._active: + self._sender.send_all(self.config.open_pose, force=True, reason=reason) + self._active = False + + def _load_somehand(self) -> None: + try: + from somehand.api import BiHandFrame, BiHandRetargetingEngine, HandFrame + from somehand.infrastructure import HandModel, LinkerHandModelAdapter, infer_linkerhand_model_family + from somehand.pico_input import pico_hand_to_landmarks + except ImportError as exc: + raise ImportError( + "somehand is required when dexterous_hand.mode=vr_hand_pose. " + "Install it with: pip install -e '.[dexhand]'" + ) from exc + + config_path = _resolve_project_path(self.config.somehand_config_path) + if not config_path.exists(): + raise FileNotFoundError( + "somehand bi-hand config not found: " + f"{config_path}. 
Initialize the submodule and download assets with " + "scripts/setup/download_somehand_l6_assets.sh" + ) + sdk_root = _resolve_project_path(self.config.somehand_sdk_root) + self._engine = BiHandRetargetingEngine.from_config_path(str(config_path)) + self._hand_frame_cls = HandFrame + self._bihand_frame_cls = BiHandFrame + self._pico_hand_to_landmarks = pico_hand_to_landmarks + + self._adapters = {} + for hand_type, engine in (("left", self._engine.left_engine), ("right", self._engine.right_engine)): + if hand_type not in self.config.selected_hand_types: + continue + family = infer_linkerhand_model_family(engine.config.hand.name) + self._adapters[hand_type] = LinkerHandModelAdapter( + HandModel(engine.config.hand.mjcf_path), + family=family, + hand_side=hand_type, + sdk_root=str(sdk_root), + ) + logger.info("somehand LinkerHand L6 runtime started | hands=%s", ",".join(self.config.selected_hand_types)) + + def _set_status(self, hand_type: str, status: str, message: str) -> None: + key = hand_type + if self._last_status.get(key) == status: + return + self._last_status[key] = status + logger.info("somehand LinkerHand L6: %s", message) + + class DisabledLinkerHandRuntime: enabled = False @@ -443,7 +603,7 @@ def close(self) -> None: pass -def build_linkerhand_runtime(cfg: Any, input_provider: Any) -> LinkerHandRuntime | DisabledLinkerHandRuntime: +def build_linkerhand_runtime(cfg: Any, input_provider: Any) -> LinkerHandRuntime | SomeHandPoseRuntime | DisabledLinkerHandRuntime: config = parse_linkerhand_config(cfg) if not config.enabled: return DisabledLinkerHandRuntime() @@ -451,10 +611,25 @@ def build_linkerhand_runtime(cfg: Any, input_provider: Any) -> LinkerHandRuntime input_cfg = cfg_get(cfg, "input", {}) or {} provider_kind = str(cfg_get(input_cfg, "provider", "")).lower() if provider_kind != "pico4": - raise ValueError("dexterous_hand.enabled=true requires input.provider=pico4") - if not callable(getattr(input_provider, "get_controller_snapshot", None)): - raise 
ValueError("dexterous_hand.enabled=true requires a Pico input provider with controller snapshots") - return LinkerHandRuntime(config, input_provider) + raise ValueError("dexterous_hand.mode requires input.provider=pico4") + if config.mode == "gripper": + if not callable(getattr(input_provider, "get_controller_snapshot", None)): + raise ValueError("dexterous_hand.mode=gripper requires a Pico input provider with controller snapshots") + return LinkerHandRuntime(config, input_provider) + if config.mode == "vr_hand_pose": + if config.hand_type != "both": + raise ValueError("dexterous_hand.mode=vr_hand_pose currently requires dexterous_hand.hand_type=both") + if not callable(getattr(input_provider, "get_hand_snapshot", None)): + raise ValueError("dexterous_hand.mode=vr_hand_pose requires a Pico input provider with hand snapshots") + return SomeHandPoseRuntime(config, input_provider) + raise ValueError(f"Unsupported dexterous_hand.mode={config.mode!r}") + + +def _resolve_project_path(path_value: str) -> Path: + path = Path(path_value).expanduser() + if path.is_absolute(): + return path + return (PROJECT_ROOT / path).resolve() def _positive_float(value: object, field_name: str) -> float: diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index de39e01c..95171d28 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -3,12 +3,14 @@ import sys from types import SimpleNamespace +import numpy as np import pytest -from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState +from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState, PicoHandSnapshot, PicoHandState from teleopit.sim2real.dexterous_hand import ( L6PoseSender, LinkerHandRuntime, + SomeHandPoseRuntime, parse_linkerhand_config, trigger_to_pose, ) @@ -62,6 +64,14 @@ def get_controller_snapshot(self) -> PicoControllerSnapshot | None: return self.snapshot +class HandSnapshotProvider: + def __init__(self) -> None: + 
self.snapshot: PicoHandSnapshot | None = None + + def get_hand_snapshot(self) -> PicoHandSnapshot | None: + return self.snapshot + + def _snapshot( *, left: PicoControllerState | None = None, @@ -78,6 +88,28 @@ def _snapshot( ) +def _hand_snapshot( + *, + left: PicoHandState | None = None, + right: PicoHandState | None = None, + timestamp_s: float = 10.0, + seq: int = 1, +) -> PicoHandSnapshot: + missing = PicoHandState(active=False, joints=np.zeros((26, 7), dtype=np.float64), present=False) + return PicoHandSnapshot( + left=left or missing, + right=right or missing, + timestamp_s=timestamp_s, + seq=seq, + ) + + +def _hand_state(*, active: bool = True, value: float = 1.0, present: bool = True) -> PicoHandState: + joints = np.zeros((26, 7), dtype=np.float64) + joints[:, 0] = value + return PicoHandState(active=active, joints=joints, present=present) + + def _runtime(provider: SnapshotProvider) -> LinkerHandRuntime: cfg = parse_linkerhand_config( { @@ -244,6 +276,115 @@ def close_can(self) -> None: assert created_hands[0].hand.close_calls == 1 +def _install_fake_somehand(monkeypatch, *, left_pose: list[int], right_pose: list[int]) -> None: + class FakeHandFrame: + def __init__(self, *, landmarks_3d, landmarks_2d, hand_side): + self.landmarks_3d = landmarks_3d + self.landmarks_2d = landmarks_2d + self.hand_side = hand_side + + class FakeBiHandFrame: + def __init__(self, *, left=None, right=None): + self.left = left + self.right = right + + class FakeEngine: + def __init__(self): + self.left_engine = SimpleNamespace( + config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_left", mjcf_path="left.xml")) + ) + self.right_engine = SimpleNamespace( + config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_right", mjcf_path="right.xml")) + ) + + @classmethod + def from_config_path(cls, _path: str): + return cls() + + def process(self, frame): + return SimpleNamespace( + left_detected=frame.left is not None, + right_detected=frame.right is not None, + 
left=SimpleNamespace(qpos=np.array([1.0], dtype=np.float64)), + right=SimpleNamespace(qpos=np.array([2.0], dtype=np.float64)), + ) + + class FakeHandModel: + def __init__(self, mjcf_path: str): + self.mjcf_path = mjcf_path + + class FakeAdapter: + def __init__(self, _hand_model, *, family: str, hand_side: str, sdk_root: str): + del family, sdk_root + self.hand_side = hand_side + + def qpos_to_sdk_range(self, _qpos): + return left_pose if self.hand_side == "left" else right_pose + + fake_api = SimpleNamespace( + BiHandFrame=FakeBiHandFrame, + BiHandRetargetingEngine=FakeEngine, + HandFrame=FakeHandFrame, + ) + fake_infra = SimpleNamespace( + HandModel=FakeHandModel, + LinkerHandModelAdapter=FakeAdapter, + infer_linkerhand_model_family=lambda _name: "L6", + ) + fake_pico = SimpleNamespace(pico_hand_to_landmarks=lambda joints: np.asarray(joints, dtype=np.float64)[:21, :3]) + monkeypatch.setitem(sys.modules, "somehand.api", fake_api) + monkeypatch.setitem(sys.modules, "somehand.infrastructure", fake_infra) + monkeypatch.setitem(sys.modules, "somehand.pico_input", fake_pico) + + +def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypatch, tmp_path) -> None: + _install_fake_somehand(monkeypatch, left_pose=[1, 2, 3, 4, 5, 6], right_pose=[6, 5, 4, 3, 2, 1]) + config_path = tmp_path / "linkerhand_l6_bihand.yaml" + config_path.write_text("left: {}\nright: {}\n", encoding="utf-8") + provider = HandSnapshotProvider() + cfg = parse_linkerhand_config( + { + "input": {"provider": "pico4"}, + "dexterous_hand": { + "mode": "vr_hand_pose", + "hand_type": "both", + "somehand": {"config_path": str(config_path), "sdk_root": "third_party/linkerhand-python-sdk"}, + }, + } + ) + runtime = SomeHandPoseRuntime(cfg, provider) + runtime.start() + + provider.snapshot = _hand_snapshot( + left=_hand_state(active=True, value=1.0), + right=_hand_state(active=True, value=2.0), + timestamp_s=10.0, + ) + runtime.tick(active=True, now_s=10.0) + assert 
runtime._sender.wait_idle(timeout_s=1.0) + + assert runtime._sender._last_pose["left"] == [1, 2, 3, 4, 5, 6] + assert runtime._sender._last_pose["right"] == [6, 5, 4, 3, 2, 1] + + provider.snapshot = _hand_snapshot( + left=_hand_state(active=False, value=9.0), + right=_hand_state(active=False, value=9.0, present=False), + timestamp_s=10.1, + seq=2, + ) + runtime.tick(active=True, now_s=10.1) + assert runtime._sender.wait_idle(timeout_s=1.0) + + assert runtime._sender._last_pose["left"] == [1, 2, 3, 4, 5, 6] + assert runtime._sender._last_pose["right"] == [6, 5, 4, 3, 2, 1] + + runtime.tick(active=False, now_s=10.2) + assert runtime._sender.wait_idle(timeout_s=1.0) + assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) + assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) + runtime.close() + + def test_pose_sender_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: created_hands = [] diff --git a/tests/test_pico4_provider.py b/tests/test_pico4_provider.py index c6f6cabe..da82c507 100644 --- a/tests/test_pico4_provider.py +++ b/tests/test_pico4_provider.py @@ -39,6 +39,12 @@ def _pico_frame( ) +def _hand_state(*, active: bool, value: float) -> SimpleNamespace: + joints = np.zeros((26, 7), dtype=np.float64) + joints[:, 0:3] = value + return SimpleNamespace(active=active, joints=joints) + + def _make_provider() -> Pico4InputProvider: provider = object.__new__(Pico4InputProvider) provider._lock = threading.Lock() @@ -56,6 +62,8 @@ def _make_provider() -> Pico4InputProvider: provider._last_frame_timestamp = None provider._last_source_seq = None provider._ground_lift_offset = None + provider._controller_snapshot = None + provider._hand_snapshot = None provider._closed = False return provider @@ -273,3 +281,22 @@ def test_pico4_provider_reads_pause_control_events_when_body_inactive() -> None: events = provider.pop_control_events() assert [event.event_type for event in events] == [ControlEventType.TOGGLE_PAUSE] assert 
provider._last_source_seq == 1 + + +def test_pico4_provider_exposes_hand_snapshot_when_body_inactive() -> None: + provider = _make_provider() + frame = _pico_frame(_body_poses(1.0), seq=4, timestamp=2.0, body_active=False) + frame.left_hand = _hand_state(active=True, value=1.5) + frame.right_hand = _hand_state(active=False, value=2.5) + + assert provider._accept_pico_frame(frame) is False + + snapshot = provider.get_hand_snapshot() + assert snapshot is not None + assert snapshot.seq == 4 + assert snapshot.timestamp_s == pytest.approx(2.0) + assert snapshot.left.present is True + assert snapshot.left.active is True + assert snapshot.right.present is True + assert snapshot.right.active is False + np.testing.assert_allclose(snapshot.left.joints[:, 0:3], 1.5) diff --git a/third_party/somehand b/third_party/somehand new file mode 160000 index 00000000..3e5cc805 --- /dev/null +++ b/third_party/somehand @@ -0,0 +1 @@ +Subproject commit 3e5cc8052c75f14e6ee6ce400e6762a20bdd42fa From 90a8a21b349668d6acdd08effa240145519b3628 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 22:22:52 +0800 Subject: [PATCH 043/122] Fix somehand L6 archive extraction paths --- scripts/setup/download_somehand_l6_assets.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/setup/download_somehand_l6_assets.sh b/scripts/setup/download_somehand_l6_assets.sh index 447a50b0..b60db5e7 100755 --- a/scripts/setup/download_somehand_l6_assets.sh +++ b/scripts/setup/download_somehand_l6_assets.sh @@ -143,7 +143,7 @@ def place_direct_assets() -> bool: def safe_extract_l6_from_archive(archive_path: Path) -> None: dest.mkdir(parents=True, exist_ok=True) - wanted_prefixes = {f"assets/mjcf/{hand}/" for hand in hands} + wanted_prefixes = {f"mjcf/{hand}/" for hand in hands} with tarfile.open(archive_path, "r:*") as tar: members = [] for member in tar.getmembers(): @@ -160,9 +160,9 @@ def safe_extract_l6_from_archive(archive_path: Path) -> None: tmp.mkdir(parents=True, 
exist_ok=True) tar.extractall(tmp, members=members) for hand in hands: - src = tmp / "assets" / "mjcf" / hand + src = tmp / "mjcf" / hand if not src.exists(): - raise FileNotFoundError(f"Archive missing assets/mjcf/{hand}") + raise FileNotFoundError(f"Archive missing mjcf/{hand}") copy_dir(src, dest / hand) print(f" archive:{hand} -> {dest / hand}") remove_path(tmp) From 509e80f7444e7865ffe86063f752a2bd0183e50a Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 22:27:38 +0800 Subject: [PATCH 044/122] Fix somehand L6 asset download --- scripts/setup/download_somehand_l6_assets.sh | 26 ++++---------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/scripts/setup/download_somehand_l6_assets.sh b/scripts/setup/download_somehand_l6_assets.sh index b60db5e7..61835102 100755 --- a/scripts/setup/download_somehand_l6_assets.sh +++ b/scripts/setup/download_somehand_l6_assets.sh @@ -82,7 +82,6 @@ if not cache_dir.is_absolute(): cache_dir = (project_root / cache_dir).resolve() hands = ("linkerhand_l6_left", "linkerhand_l6_right") -direct_patterns = [f"assets/mjcf/{hand}/**" for hand in hands] archive_pattern = "archives/mjcf_assets.tar.gz" repo_cache = cache_dir / source / repo_id.split("/")[-1] @@ -129,18 +128,6 @@ def download(patterns: list[str]) -> None: ) -def place_direct_assets() -> bool: - placed = False - for hand in hands: - src = repo_cache / "assets" / "mjcf" / hand - if not src.exists(): - return False - copy_dir(src, dest / hand) - print(f" {src.relative_to(repo_cache)} -> {dest / hand}") - placed = True - return placed - - def safe_extract_l6_from_archive(archive_path: Path) -> None: dest.mkdir(parents=True, exist_ok=True) wanted_prefixes = {f"mjcf/{hand}/" for hand in hands} @@ -171,14 +158,11 @@ def safe_extract_l6_from_archive(archive_path: Path) -> None: print(f"Downloading somehand L6 assets from {source}:{repo_id}") print(f"Destination: {dest}") -download(direct_patterns) -if not place_direct_assets(): - print("Direct L6 
asset paths not found; downloading mjcf archive fallback.") - download([archive_pattern]) - archive = repo_cache / archive_pattern - if not archive.exists(): - raise FileNotFoundError(f"Downloaded repo is missing {archive_pattern}") - safe_extract_l6_from_archive(archive) +download([archive_pattern]) +archive = repo_cache / archive_pattern +if not archive.exists(): + raise FileNotFoundError(f"Downloaded repo is missing {archive_pattern}") +safe_extract_l6_from_archive(archive) for hand in hands: model = dest / hand / "model.xml" From 4485c1a61460551e8093ccfdac3380ec49e02cbd Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 22:41:42 +0800 Subject: [PATCH 045/122] Add live dexterous hand test modes --- scripts/dev/test_linkerhand_l6.py | 228 +++++++++--------------------- 1 file changed, 67 insertions(+), 161 deletions(-) diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py index 91c9a1ae..e7ae2993 100644 --- a/scripts/dev/test_linkerhand_l6.py +++ b/scripts/dev/test_linkerhand_l6.py @@ -10,8 +10,6 @@ import time from typing import Sequence -import numpy as np - REPO_ROOT = Path(__file__).resolve().parents[2] sys.path.insert(0, str(REPO_ROOT)) @@ -23,16 +21,12 @@ sys.path.insert(0, str(SOMEHAND_SRC_PATH)) from teleopit.inputs.pico4_provider import ( # noqa: E402 - PicoControllerSnapshot, - PicoControllerState, - PicoHandSnapshot, - PicoHandState, + Pico4InputProvider, ) from teleopit.sim2real.dexterous_hand import ( # noqa: E402 LinkerHandConfig, LinkerHandRuntime, SomeHandPoseRuntime, - trigger_to_pose, ) @@ -43,30 +37,6 @@ DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" -PICO_BRIDGE_TO_MEDIAPIPE = [ - 1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25 -] -PICO_NATIVE_TO_RH = np.array( - [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]], - dtype=np.float64, -) - - -class 
ScriptControllerProvider: - def __init__(self) -> None: - self.snapshot: PicoControllerSnapshot | None = None - - def get_controller_snapshot(self) -> PicoControllerSnapshot | None: - return self.snapshot - - -class ScriptHandProvider: - def __init__(self) -> None: - self.snapshot: PicoHandSnapshot | None = None - - def get_hand_snapshot(self) -> PicoHandSnapshot | None: - return self.snapshot - def uint8(value: str) -> int: parsed = int(value) @@ -102,9 +72,8 @@ def parse_args() -> argparse.Namespace: choices=["open_close", "gripper", "vr_hand_pose"], default="open_close", help=( - "open_close sends fixed poses directly; gripper exercises the sim2real Pico " - "grip/trigger mapping; vr_hand_pose exercises the somehand Pico hand-pose path " - "with synthetic hand landmarks." + "open_close sends fixed poses directly; gripper reads real Pico controller " + "grip/trigger input; vr_hand_pose reads real Pico hand pose input through somehand." ), ) parser.add_argument("--hand-type", choices=["left", "right", "both"], default="both") @@ -117,6 +86,12 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--cycles", type=positive_int, default=3) parser.add_argument("--hold-s", type=positive_float, default=1.0) + parser.add_argument( + "--duration-s", + type=positive_float, + default=30.0, + help="Live Pico test duration for gripper/vr_hand_pose modes.", + ) parser.add_argument("--rate", type=positive_float, default=30.0) parser.add_argument("--frame-timeout", type=positive_float, default=0.3) parser.add_argument("--trigger-deadzone", type=float, default=0.05) @@ -146,6 +121,11 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--somehand-config-path", default=DEFAULT_SOMEHAND_CONFIG_PATH) parser.add_argument("--somehand-sdk-root", default=DEFAULT_LINKERHAND_SDK_ROOT) + parser.add_argument("--bridge-host", default="0.0.0.0") + parser.add_argument("--bridge-port", type=positive_int, default=63901) + parser.add_argument("--bridge-advertise-ip", 
default=None) + parser.add_argument("--bridge-start-timeout", type=positive_float, default=10.0) + parser.add_argument("--no-bridge-discovery", action="store_true") args = parser.parse_args() args.open_pose[1] = args.thumb_yaw_center args.close_pose[1] = args.thumb_yaw_center @@ -199,6 +179,46 @@ def assert_runtime_started(runtime: object) -> None: raise RuntimeError("LinkerHand sender failed to start; check the log above for SDK/CAN errors") +def make_pico_provider(args: argparse.Namespace) -> Pico4InputProvider: + return Pico4InputProvider( + timeout=args.duration_s, + pause_button=None, + bridge_host=args.bridge_host, + bridge_port=args.bridge_port, + bridge_discovery=not args.no_bridge_discovery, + bridge_advertise_ip=args.bridge_advertise_ip, + bridge_start_timeout=args.bridge_start_timeout, + bridge_video=None, + bridge_video_enabled=False, + ) + + +def run_live_until_done( + runtime: LinkerHandRuntime | SomeHandPoseRuntime, + *, + provider: Pico4InputProvider, + duration_s: float, + mode_label: str, +) -> None: + deadline = time.monotonic() + duration_s + last_seq: int | None = None + print(f"Running {mode_label} for {duration_s:.1f}s; press Ctrl-C to stop early.", flush=True) + while time.monotonic() < deadline: + now_s = time.monotonic() + runtime.tick(active=True, now_s=now_s) + snapshot = ( + provider.get_controller_snapshot() + if isinstance(runtime, LinkerHandRuntime) + else provider.get_hand_snapshot() + ) + if snapshot is not None and snapshot.seq != last_seq: + last_seq = snapshot.seq + age_ms = max((now_s - snapshot.timestamp_s) * 1000.0, 0.0) + print(f" pico seq={snapshot.seq} age={age_ms:.1f}ms", flush=True) + wait_runtime_idle(runtime) + time.sleep(max(1.0 / runtime.config.rate, 0.001)) + + def run_open_close(args: argparse.Namespace) -> None: try: from LinkerHand.linker_hand_api import LinkerHandApi @@ -251,127 +271,28 @@ def run_open_close(args: argparse.Namespace) -> None: ) -def controller_snapshot( - *, - timestamp_s: float, - seq: int, - 
trigger: float, - grip: float, - config: LinkerHandConfig, -) -> PicoControllerSnapshot: - missing = PicoControllerState(raw=False, grip=0.0, trigger=0.0, present=False) - active = PicoControllerState(raw=True, grip=grip, trigger=trigger, present=True) - left = active if "left" in config.selected_hand_types else missing - right = active if "right" in config.selected_hand_types else missing - return PicoControllerSnapshot(left=left, right=right, timestamp_s=timestamp_s, seq=seq) - - def run_gripper(args: argparse.Namespace) -> None: config = make_config(args, mode="gripper") - provider = ScriptControllerProvider() + provider = make_pico_provider(args) runtime = LinkerHandRuntime(config, provider) - print("Testing dexterous_hand.mode=gripper with synthetic Pico grip/trigger snapshots", flush=True) + print( + "Testing dexterous_hand.mode=gripper with real Pico controller input. " + "Hold grip above the deadman threshold, then use trigger to close/open.", + flush=True, + ) try: runtime.start() wait_runtime_idle(runtime) assert_runtime_started(runtime) - - now_s = time.monotonic() - print("inactive safety open", flush=True) - runtime.tick(active=False, now_s=now_s) - wait_runtime_idle(runtime) - time.sleep(args.hold_s) - - seq = 1 - print("deadman released -> open", flush=True) - now_s = time.monotonic() - provider.snapshot = controller_snapshot( - timestamp_s=now_s, - seq=seq, - trigger=1.0, - grip=0.0, - config=config, - ) - runtime.tick(active=True, now_s=now_s) - wait_runtime_idle(runtime) - time.sleep(args.hold_s) - - sweep = [0.0, 0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25, 0.0] - for cycle in range(args.cycles): - print(f"gripper cycle {cycle + 1}/{args.cycles}", flush=True) - for trigger in sweep: - seq += 1 - now_s = time.monotonic() - pose = trigger_to_pose( - trigger, - open_pose=config.open_pose, - close_pose=config.close_pose, - deadzone=config.trigger_deadzone, - thumb_yaw_default=config.thumb_yaw_center, - ) - print(f" grip=1.00 trigger={trigger:.2f} -> 
{pose}", flush=True) - provider.snapshot = controller_snapshot( - timestamp_s=now_s, - seq=seq, - trigger=trigger, - grip=1.0, - config=config, - ) - runtime.tick(active=True, now_s=now_s) - wait_runtime_idle(runtime) - time.sleep(args.hold_s) + run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="gripper") except KeyboardInterrupt: print("Interrupted; opening hands before exit", flush=True) finally: runtime.tick(active=False) wait_runtime_idle(runtime) runtime.close() - - -def rh_to_pico_native(position: Sequence[float]) -> np.ndarray: - return np.asarray(position, dtype=np.float64) @ PICO_NATIVE_TO_RH - - -def synthetic_pico_hand_joints(hand_type: str, *, curl: float) -> np.ndarray: - curl = max(0.0, min(1.0, float(curl))) - side_sign = -1.0 if hand_type == "left" else 1.0 - joints = np.zeros((26, 7), dtype=np.float64) - - mp_landmarks = np.zeros((21, 3), dtype=np.float64) - mp_landmarks[0] = [0.0, 0.0, 0.0] - finger_bases = [ - (1, side_sign * 0.035, 0.035, [0.018, 0.033, 0.046, 0.058]), - (5, side_sign * 0.020, 0.060, [0.040, 0.070, 0.095, 0.120]), - (9, 0.0, 0.065, [0.045, 0.080, 0.110, 0.140]), - (13, -side_sign * 0.020, 0.060, [0.040, 0.070, 0.095, 0.120]), - (17, -side_sign * 0.040, 0.052, [0.035, 0.060, 0.082, 0.102]), - ] - for base_idx, x, base_y, lengths in finger_bases: - for offset, length in enumerate(lengths): - bend = curl * (offset + 1) / len(lengths) - y = base_y + length * (1.0 - 0.65 * bend) - z = -0.055 * bend - if base_idx == 1: - x_pos = x + side_sign * length * 0.65 - y = 0.015 + length * (1.0 - 0.35 * bend) - else: - x_pos = x - mp_landmarks[base_idx + offset] = [x_pos, y, z] - - for mp_idx, pico_idx in enumerate(PICO_BRIDGE_TO_MEDIAPIPE): - joints[pico_idx, :3] = rh_to_pico_native(mp_landmarks[mp_idx]) - joints[0, :3] = rh_to_pico_native([0.0, 0.025, 0.0]) - return joints - - -def hand_snapshot(*, timestamp_s: float, seq: int, curl: float) -> PicoHandSnapshot: - return PicoHandSnapshot( - 
left=PicoHandState(active=True, joints=synthetic_pico_hand_joints("left", curl=curl), present=True), - right=PicoHandState(active=True, joints=synthetic_pico_hand_joints("right", curl=curl), present=True), - timestamp_s=timestamp_s, - seq=seq, - ) + provider.close() def run_vr_hand_pose(args: argparse.Namespace) -> None: @@ -379,41 +300,26 @@ def run_vr_hand_pose(args: argparse.Namespace) -> None: raise SystemExit("dexterous_hand.mode=vr_hand_pose currently requires --hand-type both") config = make_config(args, mode="vr_hand_pose") - provider = ScriptHandProvider() + provider = make_pico_provider(args) runtime = SomeHandPoseRuntime(config, provider) print( - "Testing dexterous_hand.mode=vr_hand_pose with synthetic Pico hand-pose snapshots. " - "This drives poses produced by somehand; start with the robot clear of contacts.", + "Testing dexterous_hand.mode=vr_hand_pose with real Pico hand-pose input. " + "Enable Pico hand tracking and move both hands; start with the robot clear of contacts.", flush=True, ) try: runtime.start() wait_runtime_idle(runtime) assert_runtime_started(runtime) - - seq = 0 - curl_sweep = [0.0, 0.35, 0.7, 1.0, 0.7, 0.35, 0.0] - for cycle in range(args.cycles): - print(f"vr_hand_pose cycle {cycle + 1}/{args.cycles}", flush=True) - for curl in curl_sweep: - seq += 1 - now_s = time.monotonic() - print(f" synthetic curl={curl:.2f}", flush=True) - provider.snapshot = hand_snapshot(timestamp_s=now_s, seq=seq, curl=curl) - runtime.tick(active=True, now_s=now_s) - wait_runtime_idle(runtime) - time.sleep(args.hold_s) - - print("inactive mode -> configured open pose", flush=True) - runtime.tick(active=False) - wait_runtime_idle(runtime) + run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="vr_hand_pose") except KeyboardInterrupt: print("Interrupted; opening hands before exit", flush=True) finally: runtime.tick(active=False) wait_runtime_idle(runtime) runtime.close() + provider.close() def main() -> None: From 
0be0fa01e59fcb0517b2c23295dd9c1b609b261b Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 23:21:53 +0800 Subject: [PATCH 046/122] Update somehand submodule for Pico import fix --- third_party/somehand | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/somehand b/third_party/somehand index 3e5cc805..a88bfe2f 160000 --- a/third_party/somehand +++ b/third_party/somehand @@ -1 +1 @@ -Subproject commit 3e5cc8052c75f14e6ee6ce400e6762a20bdd42fa +Subproject commit a88bfe2f09eb3821b310774aa166f8506645a35c From 4592efd93a6b97d693356d7eebbe024ad0b061e7 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 23:48:24 +0800 Subject: [PATCH 047/122] Fix L6 hand pose mapping --- teleopit/sim2real/dexterous_hand.py | 97 +++++++++++++++++++++++++---- tests/test_dexterous_hand.py | 55 ++++++++-------- 2 files changed, 113 insertions(+), 39 deletions(-) diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 98fbebc0..3a2890f4 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -9,6 +9,8 @@ import time from typing import Any, Protocol, Sequence +import numpy as np + from teleopit.inputs.pico4_provider import ( PicoControllerSnapshot, PicoControllerState, @@ -28,6 +30,16 @@ HAND_MODES = ("off", "gripper", "vr_hand_pose") DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" +L6_QPOS_CHANNELS = ( + "thumb_cmc_pitch", + "thumb_cmc_yaw", + "index_mcp_pitch", + "middle_mcp_pitch", + "ring_mcp_pitch", + "pinky_mcp_pitch", +) +L6_QPOS_MIN = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64) +L6_QPOS_MAX = np.array([0.99, 1.39, 1.26, 1.26, 1.26, 1.26], dtype=np.float64) class ControllerSnapshotProvider(Protocol): @@ -104,6 +116,55 @@ def trigger_to_pose( return pose +class L6RetargetPoseMapper: + """Map somehand-retargeted L6 qpos into 
Teleopit's six-channel L6 SDK pose.""" + + def __init__(self, hand_model: Any | None, config: LinkerHandConfig, *, hand_type: str): + self._config = config + self._hand_type = hand_type + self._indices = self._resolve_indices(hand_model, hand_type=hand_type) + + def qpos_to_pose(self, qpos: Any) -> list[int]: + values = np.asarray(qpos, dtype=np.float64).reshape(-1) + if self._indices is None: + if values.shape[0] < len(L6_QPOS_CHANNELS): + raise ValueError( + "somehand L6 retarget qpos is too short: " + f"got {values.shape[0]}, need at least {len(L6_QPOS_CHANNELS)}" + ) + channel_values = values[:len(L6_QPOS_CHANNELS)] + else: + channel_values = values[self._indices] + + normalized = np.clip((channel_values - L6_QPOS_MIN) / (L6_QPOS_MAX - L6_QPOS_MIN), 0.0, 1.0) + pose = [ + int(round(float(open_value) + float(alpha) * (float(close_value) - float(open_value)))) + for open_value, close_value, alpha in zip( + self._config.open_pose, + self._config.close_pose, + normalized, + ) + ] + pose[1] = int(self._config.thumb_yaw_center) + return [_uint8(value, f"somehand.{self._hand_type}.pose") for value in pose] + + @staticmethod + def _resolve_indices(hand_model: Any | None, *, hand_type: str) -> np.ndarray | None: + if hand_model is None: + return None + get_index = getattr(hand_model, "get_joint_name_to_qpos_index", None) + if not callable(get_index): + return None + joint_index = get_index() + indices: list[int] = [] + for channel in L6_QPOS_CHANNELS: + resolved = _resolve_l6_joint_name(joint_index, channel, hand_type=hand_type) + if resolved is None: + return None + indices.append(int(joint_index[resolved])) + return np.asarray(indices, dtype=np.int64) + + def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: hand_cfg = cfg_get(cfg, "dexterous_hand", {}) or {} raw_mode = cfg_get(hand_cfg, "mode", None) @@ -472,7 +533,7 @@ def __init__(self, config: LinkerHandConfig, provider: HandSnapshotProvider): self._hand_frame_cls: Any | None = None self._bihand_frame_cls: 
Any | None = None self._pico_hand_to_landmarks: Any | None = None - self._adapters: dict[str, Any] = {} + self._pose_mappers: dict[str, L6RetargetPoseMapper] = {} @property def enabled(self) -> bool: @@ -526,7 +587,7 @@ def _tick_snapshot(self, snapshot: PicoHandSnapshot) -> None: ): if hand_type not in self.config.selected_hand_types or not detected: continue - pose = self._adapters[hand_type].qpos_to_sdk_range(step.qpos) + pose = self._pose_mappers[hand_type].qpos_to_pose(step.qpos) self._sender.send(hand_type, pose, reason="vr-hand-pose") def _make_hand_frame(self, hand_type: str, state: PicoHandState) -> Any | None: @@ -548,7 +609,6 @@ def _deactivate(self, *, reason: str) -> None: def _load_somehand(self) -> None: try: from somehand.api import BiHandFrame, BiHandRetargetingEngine, HandFrame - from somehand.infrastructure import HandModel, LinkerHandModelAdapter, infer_linkerhand_model_family from somehand.pico_input import pico_hand_to_landmarks except ImportError as exc: raise ImportError( @@ -563,22 +623,20 @@ def _load_somehand(self) -> None: f"{config_path}. Initialize the submodule and download assets with " "scripts/setup/download_somehand_l6_assets.sh" ) - sdk_root = _resolve_project_path(self.config.somehand_sdk_root) self._engine = BiHandRetargetingEngine.from_config_path(str(config_path)) self._hand_frame_cls = HandFrame self._bihand_frame_cls = BiHandFrame self._pico_hand_to_landmarks = pico_hand_to_landmarks - self._adapters = {} + # somehand owns hand-pose retargeting; Teleopit owns the LinkerHand L6 command mapping. 
+ self._pose_mappers = {} for hand_type, engine in (("left", self._engine.left_engine), ("right", self._engine.right_engine)): if hand_type not in self.config.selected_hand_types: continue - family = infer_linkerhand_model_family(engine.config.hand.name) - self._adapters[hand_type] = LinkerHandModelAdapter( - HandModel(engine.config.hand.mjcf_path), - family=family, - hand_side=hand_type, - sdk_root=str(sdk_root), + self._pose_mappers[hand_type] = L6RetargetPoseMapper( + getattr(engine, "hand_model", None), + self.config, + hand_type=hand_type, ) logger.info("somehand LinkerHand L6 runtime started | hands=%s", ",".join(self.config.selected_hand_types)) @@ -632,6 +690,23 @@ def _resolve_project_path(path_value: str) -> Path: return (PROJECT_ROOT / path).resolve() +def _resolve_l6_joint_name(joint_index: dict[str, int], semantic_name: str, *, hand_type: str) -> str | None: + candidates = ( + semantic_name, + f"{hand_type}_{semantic_name}", + f"{hand_type[0]}_{semantic_name}", + f"{'lh' if hand_type == 'left' else 'rh'}_{semantic_name}", + ) + for candidate in candidates: + if candidate in joint_index: + return candidate + suffix = f"_{semantic_name}" + for name in joint_index: + if name.endswith(suffix): + return name + return None + + def _positive_float(value: object, field_name: str) -> float: parsed = float(value) if parsed <= 0.0: diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 95171d28..71537c41 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -276,7 +276,7 @@ def close_can(self) -> None: assert created_hands[0].hand.close_calls == 1 -def _install_fake_somehand(monkeypatch, *, left_pose: list[int], right_pose: list[int]) -> None: +def _install_fake_somehand(monkeypatch, *, left_qpos: list[float], right_qpos: list[float]) -> None: class FakeHandFrame: def __init__(self, *, landmarks_3d, landmarks_2d, hand_side): self.landmarks_3d = landmarks_3d @@ -288,13 +288,26 @@ def __init__(self, *, left=None, 
right=None): self.left = left self.right = right + class FakeHandModel: + def get_joint_name_to_qpos_index(self): + return { + "thumb_cmc_pitch": 0, + "thumb_cmc_yaw": 1, + "index_mcp_pitch": 2, + "middle_mcp_pitch": 3, + "ring_mcp_pitch": 4, + "pinky_mcp_pitch": 5, + } + class FakeEngine: def __init__(self): self.left_engine = SimpleNamespace( - config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_left", mjcf_path="left.xml")) + config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_left", mjcf_path="left.xml")), + hand_model=FakeHandModel(), ) self.right_engine = SimpleNamespace( - config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_right", mjcf_path="right.xml")) + config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_right", mjcf_path="right.xml")), + hand_model=FakeHandModel(), ) @classmethod @@ -305,40 +318,26 @@ def process(self, frame): return SimpleNamespace( left_detected=frame.left is not None, right_detected=frame.right is not None, - left=SimpleNamespace(qpos=np.array([1.0], dtype=np.float64)), - right=SimpleNamespace(qpos=np.array([2.0], dtype=np.float64)), + left=SimpleNamespace(qpos=np.asarray(left_qpos, dtype=np.float64)), + right=SimpleNamespace(qpos=np.asarray(right_qpos, dtype=np.float64)), ) - class FakeHandModel: - def __init__(self, mjcf_path: str): - self.mjcf_path = mjcf_path - - class FakeAdapter: - def __init__(self, _hand_model, *, family: str, hand_side: str, sdk_root: str): - del family, sdk_root - self.hand_side = hand_side - - def qpos_to_sdk_range(self, _qpos): - return left_pose if self.hand_side == "left" else right_pose - fake_api = SimpleNamespace( BiHandFrame=FakeBiHandFrame, BiHandRetargetingEngine=FakeEngine, HandFrame=FakeHandFrame, ) - fake_infra = SimpleNamespace( - HandModel=FakeHandModel, - LinkerHandModelAdapter=FakeAdapter, - infer_linkerhand_model_family=lambda _name: "L6", - ) fake_pico = SimpleNamespace(pico_hand_to_landmarks=lambda joints: np.asarray(joints, 
dtype=np.float64)[:21, :3]) monkeypatch.setitem(sys.modules, "somehand.api", fake_api) - monkeypatch.setitem(sys.modules, "somehand.infrastructure", fake_infra) monkeypatch.setitem(sys.modules, "somehand.pico_input", fake_pico) def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypatch, tmp_path) -> None: - _install_fake_somehand(monkeypatch, left_pose=[1, 2, 3, 4, 5, 6], right_pose=[6, 5, 4, 3, 2, 1]) + _install_fake_somehand( + monkeypatch, + left_qpos=[0.99, 0.0, 1.26, 1.26, 1.26, 1.26], + right_qpos=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ) config_path = tmp_path / "linkerhand_l6_bihand.yaml" config_path.write_text("left: {}\nright: {}\n", encoding="utf-8") provider = HandSnapshotProvider() @@ -363,8 +362,8 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.tick(active=True, now_s=10.0) assert runtime._sender.wait_idle(timeout_s=1.0) - assert runtime._sender._last_pose["left"] == [1, 2, 3, 4, 5, 6] - assert runtime._sender._last_pose["right"] == [6, 5, 4, 3, 2, 1] + assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) + assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) provider.snapshot = _hand_snapshot( left=_hand_state(active=False, value=9.0), @@ -375,8 +374,8 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.tick(active=True, now_s=10.1) assert runtime._sender.wait_idle(timeout_s=1.0) - assert runtime._sender._last_pose["left"] == [1, 2, 3, 4, 5, 6] - assert runtime._sender._last_pose["right"] == [6, 5, 4, 3, 2, 1] + assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) + assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) runtime.tick(active=False, now_s=10.2) assert runtime._sender.wait_idle(timeout_s=1.0) From a1f25b9381f689da36a7252c5484b674d5e430d3 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 27 May 2026 23:56:07 +0800 Subject: [PATCH 048/122] 
Correct L6 VR hand pose arc mapping --- teleopit/sim2real/dexterous_hand.py | 26 +++++++++----------------- tests/test_dexterous_hand.py | 8 ++++---- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 3a2890f4..29c7e50f 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -38,8 +38,9 @@ "ring_mcp_pitch", "pinky_mcp_pitch", ) -L6_QPOS_MIN = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64) -L6_QPOS_MAX = np.array([0.99, 1.39, 1.26, 1.26, 1.26, 1.26], dtype=np.float64) +L6_ARC_MIN = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64) +L6_ARC_MAX = np.array([0.99, 1.39, 1.26, 1.26, 1.26, 1.26], dtype=np.float64) +L6_ARC_DIRECTION = np.array([-1, -1, -1, -1, -1, -1], dtype=np.int8) class ControllerSnapshotProvider(Protocol): @@ -117,10 +118,9 @@ def trigger_to_pose( class L6RetargetPoseMapper: - """Map somehand-retargeted L6 qpos into Teleopit's six-channel L6 SDK pose.""" + """Map somehand-retargeted L6 qpos into six-channel LinkerHand L6 SDK range.""" - def __init__(self, hand_model: Any | None, config: LinkerHandConfig, *, hand_type: str): - self._config = config + def __init__(self, hand_model: Any | None, *, hand_type: str): self._hand_type = hand_type self._indices = self._resolve_indices(hand_model, hand_type=hand_type) @@ -136,17 +136,10 @@ def qpos_to_pose(self, qpos: Any) -> list[int]: else: channel_values = values[self._indices] - normalized = np.clip((channel_values - L6_QPOS_MIN) / (L6_QPOS_MAX - L6_QPOS_MIN), 0.0, 1.0) - pose = [ - int(round(float(open_value) + float(alpha) * (float(close_value) - float(open_value)))) - for open_value, close_value, alpha in zip( - self._config.open_pose, - self._config.close_pose, - normalized, - ) - ] - pose[1] = int(self._config.thumb_yaw_center) - return [_uint8(value, f"somehand.{self._hand_type}.pose") for value in pose] + arc = np.clip(channel_values, L6_ARC_MIN, 
L6_ARC_MAX) + normalized = (arc - L6_ARC_MIN) / (L6_ARC_MAX - L6_ARC_MIN) + sdk_range = np.where(L6_ARC_DIRECTION < 0, 255.0 - normalized * 255.0, normalized * 255.0) + return [_uint8(round(float(value)), f"somehand.{self._hand_type}.pose") for value in sdk_range] @staticmethod def _resolve_indices(hand_model: Any | None, *, hand_type: str) -> np.ndarray | None: @@ -635,7 +628,6 @@ def _load_somehand(self) -> None: continue self._pose_mappers[hand_type] = L6RetargetPoseMapper( getattr(engine, "hand_model", None), - self.config, hand_type=hand_type, ) logger.info("somehand LinkerHand L6 runtime started | hands=%s", ",".join(self.config.selected_hand_types)) diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 71537c41..4a96fd28 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -362,8 +362,8 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.tick(active=True, now_s=10.0) assert runtime._sender.wait_idle(timeout_s=1.0) - assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) - assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) + assert runtime._sender._last_pose["left"] == [0, 255, 0, 0, 0, 0] + assert runtime._sender._last_pose["right"] == [255, 255, 255, 255, 255, 255] provider.snapshot = _hand_snapshot( left=_hand_state(active=False, value=9.0), @@ -374,8 +374,8 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.tick(active=True, now_s=10.1) assert runtime._sender.wait_idle(timeout_s=1.0) - assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) - assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) + assert runtime._sender._last_pose["left"] == [0, 255, 0, 0, 0, 0] + assert runtime._sender._last_pose["right"] == [255, 255, 255, 255, 255, 255] runtime.tick(active=False, now_s=10.2) assert runtime._sender.wait_idle(timeout_s=1.0) From 
d4908b532c310707bf7055a6eb57714de51250af Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 14:29:22 +0800 Subject: [PATCH 049/122] Fix LinkerHand L6 pose mapping --- teleopit/sim2real/dexterous_hand.py | 139 +++++++++++++++++++++------- tests/test_dexterous_hand.py | 49 ++++++++++ 2 files changed, 157 insertions(+), 31 deletions(-) diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 29c7e50f..b184012a 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -3,6 +3,8 @@ from __future__ import annotations from dataclasses import dataclass +from functools import lru_cache +import importlib.util import logging from pathlib import Path import threading @@ -30,7 +32,7 @@ HAND_MODES = ("off", "gripper", "vr_hand_pose") DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" -L6_QPOS_CHANNELS = ( +L6_SDK_JOINT_ORDER = ( "thumb_cmc_pitch", "thumb_cmc_yaw", "index_mcp_pitch", @@ -38,9 +40,6 @@ "ring_mcp_pitch", "pinky_mcp_pitch", ) -L6_ARC_MIN = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64) -L6_ARC_MAX = np.array([0.99, 1.39, 1.26, 1.26, 1.26, 1.26], dtype=np.float64) -L6_ARC_DIRECTION = np.array([-1, -1, -1, -1, -1, -1], dtype=np.int8) class ControllerSnapshotProvider(Protocol): @@ -120,40 +119,70 @@ def trigger_to_pose( class L6RetargetPoseMapper: """Map somehand-retargeted L6 qpos into six-channel LinkerHand L6 SDK range.""" - def __init__(self, hand_model: Any | None, *, hand_type: str): + def __init__(self, hand_model: Any | None, *, hand_type: str, sdk_root: str): self._hand_type = hand_type self._indices = self._resolve_indices(hand_model, hand_type=hand_type) + self._mapping = _load_linkerhand_mapping_module(sdk_root) + self._arc_min, self._arc_max, self._arc_direction = _sdk_l6_range_params( + self._mapping, + hand_type=hand_type, + ) def 
qpos_to_pose(self, qpos: Any) -> list[int]: values = np.asarray(qpos, dtype=np.float64).reshape(-1) - if self._indices is None: - if values.shape[0] < len(L6_QPOS_CHANNELS): - raise ValueError( - "somehand L6 retarget qpos is too short: " - f"got {values.shape[0]}, need at least {len(L6_QPOS_CHANNELS)}" + max_index = int(np.max(self._indices)) + if values.shape[0] <= max_index: + raise ValueError( + "somehand L6 retarget qpos is too short for resolved SDK joint mapping: " + f"got {values.shape[0]}, need index {max_index}" + ) + channel_values = values[self._indices] + sdk_range = [] + for index, value in enumerate(channel_values): + arc = self._mapping.is_within_range( + float(value), + float(self._arc_min[index]), + float(self._arc_max[index]), + ) + if int(self._arc_direction[index]) == -1: + sdk_range.append( + self._mapping.scale_value( + arc, + float(self._arc_min[index]), + float(self._arc_max[index]), + 255.0, + 0.0, + ) + ) + else: + sdk_range.append( + self._mapping.scale_value( + arc, + float(self._arc_min[index]), + float(self._arc_max[index]), + 0.0, + 255.0, + ) ) - channel_values = values[:len(L6_QPOS_CHANNELS)] - else: - channel_values = values[self._indices] - - arc = np.clip(channel_values, L6_ARC_MIN, L6_ARC_MAX) - normalized = (arc - L6_ARC_MIN) / (L6_ARC_MAX - L6_ARC_MIN) - sdk_range = np.where(L6_ARC_DIRECTION < 0, 255.0 - normalized * 255.0, normalized * 255.0) return [_uint8(round(float(value)), f"somehand.{self._hand_type}.pose") for value in sdk_range] @staticmethod - def _resolve_indices(hand_model: Any | None, *, hand_type: str) -> np.ndarray | None: + def _resolve_indices(hand_model: Any | None, *, hand_type: str) -> np.ndarray: if hand_model is None: - return None + raise ValueError("somehand L6 hand model is missing; cannot map retarget qpos to SDK joints") get_index = getattr(hand_model, "get_joint_name_to_qpos_index", None) if not callable(get_index): - return None + raise ValueError("somehand L6 hand model does not expose 
get_joint_name_to_qpos_index()") joint_index = get_index() indices: list[int] = [] - for channel in L6_QPOS_CHANNELS: + for channel in L6_SDK_JOINT_ORDER: resolved = _resolve_l6_joint_name(joint_index, channel, hand_type=hand_type) if resolved is None: - return None + available = ", ".join(sorted(str(name) for name in joint_index)) + raise ValueError( + f"Cannot resolve LinkerHand L6 SDK joint {channel!r} in somehand {hand_type} hand model. " + f"Available joints: {available}" + ) indices.append(int(joint_index[resolved])) return np.asarray(indices, dtype=np.int64) @@ -629,6 +658,7 @@ def _load_somehand(self) -> None: self._pose_mappers[hand_type] = L6RetargetPoseMapper( getattr(engine, "hand_model", None), hand_type=hand_type, + sdk_root=self.config.somehand_sdk_root, ) logger.info("somehand LinkerHand L6 runtime started | hands=%s", ",".join(self.config.selected_hand_types)) @@ -682,23 +712,70 @@ def _resolve_project_path(path_value: str) -> Path: return (PROJECT_ROOT / path).resolve() +@lru_cache(maxsize=4) +def _load_linkerhand_mapping_module(sdk_root: str) -> Any: + mapping_path = _resolve_project_path(sdk_root) / "LinkerHand" / "utils" / "mapping.py" + if not mapping_path.exists(): + raise FileNotFoundError(f"LinkerHand SDK mapping module not found: {mapping_path}") + spec = importlib.util.spec_from_file_location("teleopit_linkerhand_mapping", mapping_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load LinkerHand SDK mapping module from: {mapping_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _sdk_l6_range_params(mapping: Any, *, hand_type: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + side = "l" if hand_type == "left" else "r" + arc_min = np.asarray(getattr(mapping, f"l6_{side}_min"), dtype=np.float64) + arc_max = np.asarray(getattr(mapping, f"l6_{side}_max"), dtype=np.float64) + direction = np.asarray(getattr(mapping, f"l6_{side}_derict"), 
dtype=np.int8) + expected_shape = (len(L6_SDK_JOINT_ORDER),) + if arc_min.shape != expected_shape or arc_max.shape != expected_shape or direction.shape != expected_shape: + raise ValueError( + "LinkerHand SDK L6 mapping has unexpected shape: " + f"min={arc_min.shape}, max={arc_max.shape}, direction={direction.shape}" + ) + return arc_min, arc_max, direction + + def _resolve_l6_joint_name(joint_index: dict[str, int], semantic_name: str, *, hand_type: str) -> str | None: - candidates = ( - semantic_name, - f"{hand_type}_{semantic_name}", - f"{hand_type[0]}_{semantic_name}", - f"{'lh' if hand_type == 'left' else 'rh'}_{semantic_name}", + aliases = _l6_joint_aliases(semantic_name) + side_prefixes = ( + "", + f"{hand_type}_", + f"{hand_type[0]}_", + f"{hand_type[0].upper()}_", + f"{'lh' if hand_type == 'left' else 'rh'}_", ) + candidates = tuple(f"{prefix}{alias}" for alias in aliases for prefix in side_prefixes) for candidate in candidates: if candidate in joint_index: return candidate - suffix = f"_{semantic_name}" - for name in joint_index: - if name.endswith(suffix): - return name + for alias in aliases: + suffix = f"_{alias}" + for name in joint_index: + if name == alias or name.endswith(suffix): + return name return None +def _l6_joint_aliases(semantic_name: str) -> tuple[str, ...]: + if semantic_name == "thumb_cmc_pitch": + return ("thumb_cmc_pitch", "thumb_pitch") + if semantic_name == "thumb_cmc_yaw": + return ("thumb_cmc_yaw", "thumb_yaw") + aliases = [semantic_name] + if semantic_name.endswith("_mcp_pitch"): + finger = semantic_name[: -len("_mcp_pitch")] + aliases.append(f"{finger}_pitch") + if finger == "pinky": + aliases.extend(("little_mcp_pitch", "little_pitch")) + elif finger == "little": + aliases.extend(("pinky_mcp_pitch", "pinky_pitch")) + return tuple(dict.fromkeys(aliases)) + + def _positive_float(value: object, field_name: str) -> float: parsed = float(value) if parsed <= 0.0: diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py 
index 4a96fd28..3be96721 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -9,6 +9,7 @@ from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState, PicoHandSnapshot, PicoHandState from teleopit.sim2real.dexterous_hand import ( L6PoseSender, + L6RetargetPoseMapper, LinkerHandRuntime, SomeHandPoseRuntime, parse_linkerhand_config, @@ -384,6 +385,54 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.close() +def test_l6_retarget_pose_mapper_uses_sdk_order_and_model_joint_names() -> None: + class FakeHandModel: + def get_joint_name_to_qpos_index(self): + return { + "thumb_pitch": 2, + "thumb_yaw": 0, + "index_pitch": 5, + "middle_pitch": 1, + "ring_pitch": 4, + "little_pitch": 3, + } + + qpos = np.zeros(6, dtype=np.float64) + qpos[2] = 0.99 + qpos[0] = 0.0 + qpos[5] = 1.26 + qpos[1] = 0.0 + qpos[4] = 1.26 + qpos[3] = 0.0 + + mapper = L6RetargetPoseMapper( + FakeHandModel(), + hand_type="right", + sdk_root="third_party/linkerhand-python-sdk", + ) + + assert mapper.qpos_to_pose(qpos) == [0, 255, 0, 255, 0, 255] + + +def test_l6_retarget_pose_mapper_fails_when_model_joint_mapping_is_unknown() -> None: + class FakeHandModel: + def get_joint_name_to_qpos_index(self): + return { + "thumb_pitch": 0, + "thumb_yaw": 1, + "index_pitch": 2, + "middle_pitch": 3, + "ring_pitch": 4, + } + + with pytest.raises(ValueError, match="pinky_mcp_pitch"): + L6RetargetPoseMapper( + FakeHandModel(), + hand_type="right", + sdk_root="third_party/linkerhand-python-sdk", + ) + + def test_pose_sender_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: created_hands = [] From 8f1a36e1a92834d495123c303e1acdbe09ef24d8 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 16:14:28 +0800 Subject: [PATCH 050/122] Fix LinkerHand L6 somehand joint mapping --- teleopit/sim2real/dexterous_hand.py | 2 +- tests/test_dexterous_hand.py | 34 +++++++++++++++++++++++++++++ 2 files changed, 35 
insertions(+), 1 deletion(-) diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index b184012a..6b65f0dd 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -764,7 +764,7 @@ def _l6_joint_aliases(semantic_name: str) -> tuple[str, ...]: if semantic_name == "thumb_cmc_pitch": return ("thumb_cmc_pitch", "thumb_pitch") if semantic_name == "thumb_cmc_yaw": - return ("thumb_cmc_yaw", "thumb_yaw") + return ("thumb_cmc_yaw", "thumb_yaw", "thumb_cmc_roll", "thumb_roll") aliases = [semantic_name] if semantic_name.endswith("_mcp_pitch"): finger = semantic_name[: -len("_mcp_pitch")] diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 3be96721..48b99697 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -414,6 +414,40 @@ def get_joint_name_to_qpos_index(self): assert mapper.qpos_to_pose(qpos) == [0, 255, 0, 255, 0, 255] +def test_l6_retarget_pose_mapper_supports_somehand_l6_prefixed_roll_joint_names() -> None: + class FakeHandModel: + def get_joint_name_to_qpos_index(self): + return { + "lh_thumb_cmc_pitch": 8, + "lh_thumb_cmc_roll": 9, + "lh_thumb_dip": 10, + "lh_index_mcp_pitch": 1, + "lh_index_dip": 0, + "lh_middle_mcp_pitch": 3, + "lh_middle_dip": 2, + "lh_ring_mcp_pitch": 5, + "lh_ring_dip": 4, + "lh_pinky_mcp_pitch": 7, + "lh_pinky_dip": 6, + } + + qpos = np.zeros(11, dtype=np.float64) + qpos[8] = 0.99 + qpos[9] = 0.0 + qpos[1] = 1.26 + qpos[3] = 0.0 + qpos[5] = 1.26 + qpos[7] = 0.0 + + mapper = L6RetargetPoseMapper( + FakeHandModel(), + hand_type="left", + sdk_root="third_party/linkerhand-python-sdk", + ) + + assert mapper.qpos_to_pose(qpos) == [0, 255, 0, 255, 0, 255] + + def test_l6_retarget_pose_mapper_fails_when_model_joint_mapping_is_unknown() -> None: class FakeHandModel: def get_joint_name_to_qpos_index(self): From 457d384783a28c981f25265eb91bd35247a0b2ca Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 16:37:16 
+0800 Subject: [PATCH 051/122] Adjust LinkerHand VR hand pose speed --- AGENTS.md | 1 + README.md | 1 + docs/docs/configuration/config-reference.md | 5 +++- docs/docs/tutorials/pico-sim2real.md | 6 +++-- .../current/configuration/config-reference.md | 4 ++- .../current/tutorials/pico-sim2real.md | 3 ++- teleopit/configs/pico4_sim2real.yaml | 2 +- teleopit/configs/sim2real.yaml | 2 +- teleopit/sim2real/dexterous_hand.py | 5 +++- tests/test_dexterous_hand.py | 27 +++++++++++++++++++ 10 files changed, 48 insertions(+), 8 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index eec9346d..ab98ea23 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -145,6 +145,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Optional LinkerHand L6 control uses `dexterous_hand.mode=off|gripper|vr_hand_pose`; default is `off` - `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control - `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and `somehand` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control +- `gripper` mode uses the configured `dexterous_hand.speed` (default `[50]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` - LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose - In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand diff --git a/README.md b/README.md index cc3ffc81..d982cf31 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. - Added optional LinkerHand L6 sim2real modes: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand. 
+- Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. ### v0.3.0 (2026-05-12) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 07368c08..15f75a24 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -124,7 +124,9 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin `dexterous_hand.mode=gripper` or `dexterous_hand.mode=vr_hand_pose` requires `input.provider=pico4` and the optional `dexhand` extra. Control is active only in `MOCAP`; inactive modes send the open pose. In `vr_hand_pose`, missing hand -pose holds the last command for that side. +pose holds the last command for that side. `gripper` uses the configured +`dexterous_hand.speed`; `vr_hand_pose` always sets LinkerHand L6 speed to the +maximum. | Field | Description | Default | |-------|-------------|---------| @@ -133,6 +135,7 @@ pose holds the last command for that side. 
| `dexterous_hand.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | | `dexterous_hand.rate` | Maximum command rate in Hz | `30.0` | | `dexterous_hand.frame_timeout` | Gripper controller timeout, or VR hand-pose staleness threshold | `0.3` | +| `dexterous_hand.speed` | L6 speed used by `gripper`; `vr_hand_pose` overrides this to maximum speed | see config | | `dexterous_hand.deadman_threshold` | Minimum grip value required to enable a side | `0.5` | | `dexterous_hand.trigger_deadzone` | Trigger deadzone at both ends | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 49c687d1..f87e12c7 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -145,10 +145,12 @@ reference changes when live tracking resumes. Pico sim2real can drive LinkerHand L6 hands in two modes: - `gripper`: hold the matching side grip as a deadman switch; the matching - trigger closes that hand. + trigger closes that hand. This mode uses the configured + `dexterous_hand.speed`, which defaults to 50. - `vr_hand_pose`: retarget Pico hand pose through somehand and command the continuous L6 hand target. If a hand pose disappears, that side keeps its last - commanded pose. This mode currently uses `hand_type=both`. + commanded pose. This mode currently uses `hand_type=both` and always sets L6 + speed to the maximum. Hand control is active only in `MOCAP`. It sends the open pose in `STANDING`, `DAMPING`, paused mocap, and shutdown. 
diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 478d35eb..605e6eb2 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -143,7 +143,8 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` `dexterous_hand.mode=gripper` 或 `dexterous_hand.mode=vr_hand_pose` 要求 `input.provider=pico4`,并安装可选的 `dexhand` extra。控制只在 `MOCAP` 中生效;非活动模式会发送张开姿态。在 `vr_hand_pose` 中,手部 pose 消失时, -对应侧会保持上一条命令。 +对应侧会保持上一条命令。`gripper` 使用配置的 `dexterous_hand.speed`; +`vr_hand_pose` 始终将 LinkerHand L6 速度设为最大值。 | 字段 | 说明 | 默认值 | |---|---|---| @@ -152,6 +153,7 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` | `dexterous_hand.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | | `dexterous_hand.rate` | 最大命令频率(Hz) | `30.0` | | `dexterous_hand.frame_timeout` | gripper 手柄超时或 VR 手部 pose 过期阈值 | `0.3` | +| `dexterous_hand.speed` | `gripper` 使用的 L6 速度;`vr_hand_pose` 会覆盖为最大速度 | 见配置 | | `dexterous_hand.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | | `dexterous_hand.trigger_deadzone` | trigger 两端死区 | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index b4eb1f1a..42c5713d 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -137,9 +137,10 @@ Pico 暂停/恢复是 mocap-session control event。 Pico sim2real 可以用两种模式控制 LinkerHand L6: - `gripper`:按住同侧 grip 作为 deadman,同侧 trigger 控制对应手闭合。 + 该模式使用配置的 `dexterous_hand.speed`,默认值为 50。 - `vr_hand_pose`:通过 somehand 重定向 Pico 手部 
pose,并下发连续 L6 手部目标。 如果某侧手部 pose 消失,该侧会保持上一条手势命令。这个模式当前使用 - `hand_type=both`。 + `hand_type=both`,并始终将 L6 速度设为最大值。 手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 3d639e94..88a67a2a 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -40,7 +40,7 @@ dexterous_hand: trigger_deadzone: 0.05 deadman_threshold: 0.5 thumb_yaw_center: 10 - speed: [50, 50, 50, 50, 50, 50] + speed: [50, 50, 50, 50, 50, 50] # gripper mode; vr_hand_pose always uses max speed open_pose: [250, 10, 250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] print_input: false diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index b483a881..09ee6456 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -39,7 +39,7 @@ dexterous_hand: trigger_deadzone: 0.05 deadman_threshold: 0.5 thumb_yaw_center: 10 - speed: [50, 50, 50, 50, 50, 50] + speed: [50, 50, 50, 50, 50, 50] # gripper mode; vr_hand_pose always uses max speed open_pose: [250, 10, 250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] print_input: false diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 6b65f0dd..5bb87820 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -28,6 +28,7 @@ OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] +VR_HAND_POSE_SPEED = [255, 255, 255, 255, 255, 255] HAND_TYPES = ("left", "right") HAND_MODES = ("off", "gripper", "vr_hand_pose") DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" @@ -199,6 +200,8 @@ def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: open_pose[1] = thumb_yaw close_pose[1] = thumb_yaw + speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else 
_pose_values(cfg_get(hand_cfg, "speed", DEFAULT_SPEED), "speed") + config = LinkerHandConfig( mode=mode, enabled=mode != "off", @@ -212,7 +215,7 @@ def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: trigger_deadzone=_trigger_deadzone(cfg_get(hand_cfg, "trigger_deadzone", 0.05)), deadman_threshold=_deadman_threshold(cfg_get(hand_cfg, "deadman_threshold", 0.5)), thumb_yaw_center=thumb_yaw, - speed=tuple(_pose_values(cfg_get(hand_cfg, "speed", DEFAULT_SPEED), "speed")), + speed=tuple(speed), open_pose=tuple(open_pose), close_pose=tuple(close_pose), print_input=bool(cfg_get(hand_cfg, "print_input", False)), diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 48b99697..3384e8e1 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -141,6 +141,33 @@ def test_trigger_to_pose_applies_deadzone_and_fixed_thumb_yaw() -> None: assert pose == [164, 10, 125, 125, 125, 125] +def test_parse_config_keeps_gripper_default_speed() -> None: + cfg = parse_linkerhand_config( + { + "dexterous_hand": { + "mode": "gripper", + "hand_type": "both", + } + } + ) + + assert cfg.speed == (50, 50, 50, 50, 50, 50) + + +def test_parse_config_sets_vr_hand_pose_speed_to_max() -> None: + cfg = parse_linkerhand_config( + { + "dexterous_hand": { + "mode": "vr_hand_pose", + "hand_type": "both", + "speed": [50, 50, 50, 50, 50, 50], + } + } + ) + + assert cfg.speed == (255, 255, 255, 255, 255, 255) + + def test_runtime_opens_when_deadman_released() -> None: provider = SnapshotProvider() runtime = _runtime(provider) From bd5af86e482c7291b4ba442ff697df35369ec5d9 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 16:42:58 +0800 Subject: [PATCH 052/122] Adjust LinkerHand vr hand pose speed --- scripts/dev/test_linkerhand_l6.py | 5 ++++- tests/test_dexterous_hand.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py index e7ae2993..1f427424 100644 
--- a/scripts/dev/test_linkerhand_l6.py +++ b/scripts/dev/test_linkerhand_l6.py @@ -27,6 +27,7 @@ LinkerHandConfig, LinkerHandRuntime, SomeHandPoseRuntime, + VR_HAND_POSE_SPEED, ) @@ -103,6 +104,7 @@ def parse_args() -> argparse.Namespace: type=uint8, nargs=6, default=DEFAULT_SPEED, + help="L6 speed for open_close and gripper modes. vr_hand_pose always uses max speed.", metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), ) parser.add_argument( @@ -137,6 +139,7 @@ def parse_args() -> argparse.Namespace: def make_config(args: argparse.Namespace, *, mode: str) -> LinkerHandConfig: + speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else args.speed return LinkerHandConfig( mode=mode, enabled=True, @@ -150,7 +153,7 @@ def make_config(args: argparse.Namespace, *, mode: str) -> LinkerHandConfig: trigger_deadzone=args.trigger_deadzone, deadman_threshold=args.deadman_threshold, thumb_yaw_center=args.thumb_yaw_center, - speed=tuple(args.speed), + speed=tuple(speed), open_pose=tuple(args.open_pose), close_pose=tuple(args.close_pose), print_input=args.print_input, diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 3384e8e1..1597a3f1 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -12,6 +12,7 @@ L6RetargetPoseMapper, LinkerHandRuntime, SomeHandPoseRuntime, + VR_HAND_POSE_SPEED, parse_linkerhand_config, trigger_to_pose, ) @@ -168,6 +169,10 @@ def test_parse_config_sets_vr_hand_pose_speed_to_max() -> None: assert cfg.speed == (255, 255, 255, 255, 255, 255) +def test_vr_hand_pose_speed_constant_is_max() -> None: + assert tuple(VR_HAND_POSE_SPEED) == (255, 255, 255, 255, 255, 255) + + def test_runtime_opens_when_deadman_released() -> None: provider = SnapshotProvider() runtime = _runtime(provider) From 45a3ba50f41286fe7398be1054722724b0cecfb0 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 17:08:34 +0800 Subject: [PATCH 053/122] Reduce vr hand pose latency --- AGENTS.md | 1 + 
README.md | 1 + docs/docs/configuration/config-reference.md | 11 +- docs/docs/tutorials/pico-sim2real.md | 4 +- .../current/configuration/config-reference.md | 11 +- .../current/tutorials/pico-sim2real.md | 3 +- scripts/dev/test_linkerhand_l6.py | 6 + teleopit/configs/pico4_sim2real.yaml | 6 + teleopit/configs/sim2real.yaml | 6 + teleopit/sim2real/dexterous_hand.py | 165 +++++++++++++++++- tests/test_dexterous_hand.py | 92 ++++++++++ 11 files changed, 297 insertions(+), 9 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ab98ea23..b9864522 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -146,6 +146,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control - `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and `somehand` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control - `gripper` mode uses the configured `dexterous_hand.speed` (default `[50]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` +- `vr_hand_pose` defaults to a low-latency somehand path: `dexterous_hand.somehand.rate=60`, `threaded=true`, `max_iterations=12`, `temporal_filter_alpha=1.0`, and `output_alpha=1.0`; this prioritizes response speed over smoothing - LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose - In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand diff --git a/README.md b/README.md index d982cf31..099a83d1 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. 
- Added optional LinkerHand L6 sim2real modes: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. +- Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz threaded hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. ### v0.3.0 (2026-05-12) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 15f75a24..fd1aad80 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -126,20 +126,27 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin in `MOCAP`; inactive modes send the open pose. In `vr_hand_pose`, missing hand pose holds the last command for that side. `gripper` uses the configured `dexterous_hand.speed`; `vr_hand_pose` always sets LinkerHand L6 speed to the -maximum. +maximum. The default `vr_hand_pose` path favors low latency: it runs in a +background thread at `dexterous_hand.somehand.rate` and disables most somehand +input/output smoothing, which can make finger motion noisier. 
| Field | Description | Default | |-------|-------------|---------| | `dexterous_hand.mode` | `off`, `gripper`, or `vr_hand_pose` | `off` | | `dexterous_hand.hand_type` | Controlled side: `left`, `right`, or `both`; `vr_hand_pose` requires `both` | `both` | | `dexterous_hand.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | -| `dexterous_hand.rate` | Maximum command rate in Hz | `30.0` | +| `dexterous_hand.rate` | Maximum gripper command rate in Hz | `30.0` | | `dexterous_hand.frame_timeout` | Gripper controller timeout, or VR hand-pose staleness threshold | `0.3` | | `dexterous_hand.speed` | L6 speed used by `gripper`; `vr_hand_pose` overrides this to maximum speed | see config | | `dexterous_hand.deadman_threshold` | Minimum grip value required to enable a side | `0.5` | | `dexterous_hand.trigger_deadzone` | Trigger deadzone at both ends | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | | `dexterous_hand.somehand.config_path` | somehand bi-hand L6 config used by `vr_hand_pose` | see config | +| `dexterous_hand.somehand.rate` | Low-latency `vr_hand_pose` command rate in Hz | `60.0` | +| `dexterous_hand.somehand.threaded` | Run `vr_hand_pose` hand retargeting outside the robot control loop | `true` | +| `dexterous_hand.somehand.max_iterations` | somehand solver iteration cap for `vr_hand_pose` | `12` | +| `dexterous_hand.somehand.temporal_filter_alpha` | somehand input landmark smoothing alpha; `1.0` disables smoothing delay | `1.0` | +| `dexterous_hand.somehand.output_alpha` | somehand qpos output smoothing alpha; `1.0` disables smoothing delay | `1.0` | ## Critical: `default_dof_pos` diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index f87e12c7..1eb8bb10 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -150,7 +150,9 @@ Pico sim2real can drive LinkerHand L6 hands in two modes: - `vr_hand_pose`: 
retarget Pico hand pose through somehand and command the continuous L6 hand target. If a hand pose disappears, that side keeps its last commanded pose. This mode currently uses `hand_type=both` and always sets L6 - speed to the maximum. + speed to the maximum. The default configuration uses a low-latency somehand + path at 60 Hz with reduced smoothing, so it should feel more responsive but + can be noisier than the standard somehand settings. Hand control is active only in `MOCAP`. It sends the open pose in `STANDING`, `DAMPING`, paused mocap, and shutdown. diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 605e6eb2..5aabbddd 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -144,17 +144,24 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` `input.provider=pico4`,并安装可选的 `dexhand` extra。控制只在 `MOCAP` 中生效;非活动模式会发送张开姿态。在 `vr_hand_pose` 中,手部 pose 消失时, 对应侧会保持上一条命令。`gripper` 使用配置的 `dexterous_hand.speed`; -`vr_hand_pose` 始终将 LinkerHand L6 速度设为最大值。 +`vr_hand_pose` 始终将 LinkerHand L6 速度设为最大值。默认的 `vr_hand_pose` +路径优先降低延时:它会按 `dexterous_hand.somehand.rate` 在后台线程运行,并关闭 +大部分 somehand 输入/输出平滑,因此手指运动可能更抖。 | 字段 | 说明 | 默认值 | |---|---|---| | `dexterous_hand.mode` | `off`、`gripper` 或 `vr_hand_pose` | `off` | | `dexterous_hand.hand_type` | 控制侧:`left`、`right` 或 `both`;`vr_hand_pose` 要求 `both` | `both` | | `dexterous_hand.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | -| `dexterous_hand.rate` | 最大命令频率(Hz) | `30.0` | +| `dexterous_hand.rate` | gripper 最大命令频率(Hz) | `30.0` | | `dexterous_hand.frame_timeout` | gripper 手柄超时或 VR 手部 pose 过期阈值 | `0.3` | | `dexterous_hand.speed` | `gripper` 使用的 L6 速度;`vr_hand_pose` 会覆盖为最大速度 | 见配置 | | `dexterous_hand.deadman_threshold` | 启用单侧控制所需的最小 
grip 值 | `0.5` | | `dexterous_hand.trigger_deadzone` | trigger 两端死区 | `0.05` | | `dexterous_hand.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | | `dexterous_hand.somehand.config_path` | `vr_hand_pose` 使用的 somehand 双手 L6 配置 | 见配置 | +| `dexterous_hand.somehand.rate` | 低延时 `vr_hand_pose` 命令频率(Hz) | `60.0` | +| `dexterous_hand.somehand.threaded` | 在机器人控制循环外运行 `vr_hand_pose` 手部重定向 | `true` | +| `dexterous_hand.somehand.max_iterations` | `vr_hand_pose` 的 somehand solver 迭代上限 | `12` | +| `dexterous_hand.somehand.temporal_filter_alpha` | somehand 输入 landmarks 平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | +| `dexterous_hand.somehand.output_alpha` | somehand qpos 输出平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 42c5713d..5169aeaa 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -140,7 +140,8 @@ Pico sim2real 可以用两种模式控制 LinkerHand L6: 该模式使用配置的 `dexterous_hand.speed`,默认值为 50。 - `vr_hand_pose`:通过 somehand 重定向 Pico 手部 pose,并下发连续 L6 手部目标。 如果某侧手部 pose 消失,该侧会保持上一条手势命令。这个模式当前使用 - `hand_type=both`,并始终将 L6 速度设为最大值。 + `hand_type=both`,并始终将 L6 速度设为最大值。默认配置使用 60 Hz 的低延时 + somehand 路径并减少平滑,所以响应会更快,但可能比标准 somehand 设置更抖。 手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py index 1f427424..95f92deb 100644 --- a/scripts/dev/test_linkerhand_l6.py +++ b/scripts/dev/test_linkerhand_l6.py @@ -140,6 +140,7 @@ def parse_args() -> argparse.Namespace: def make_config(args: argparse.Namespace, *, mode: str) -> LinkerHandConfig: speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else args.speed + vr_hand_pose = mode == "vr_hand_pose" return LinkerHandConfig( mode=mode, enabled=True, @@ -159,6 +160,11 @@ def 
make_config(args: argparse.Namespace, *, mode: str) -> LinkerHandConfig: print_input=args.print_input, somehand_config_path=args.somehand_config_path, somehand_sdk_root=args.somehand_sdk_root, + somehand_rate=args.rate if vr_hand_pose else None, + somehand_threaded=False, + somehand_max_iterations=12 if vr_hand_pose else None, + somehand_temporal_filter_alpha=1.0 if vr_hand_pose else None, + somehand_output_alpha=1.0 if vr_hand_pose else None, ) diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 88a67a2a..562f9a94 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -47,6 +47,12 @@ dexterous_hand: somehand: config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml sdk_root: third_party/linkerhand-python-sdk + # Low-latency vr_hand_pose path. This favors response speed over smoothing. + rate: 60.0 + threaded: true + max_iterations: 12 + temporal_filter_alpha: 1.0 + output_alpha: 1.0 # Physical robot SDK configuration real_robot: diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 09ee6456..0690b3ce 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -46,6 +46,12 @@ dexterous_hand: somehand: config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml sdk_root: third_party/linkerhand-python-sdk + # Low-latency vr_hand_pose path. This favors response speed over smoothing. 
+ rate: 60.0 + threaded: true + max_iterations: 12 + temporal_filter_alpha: 1.0 + output_alpha: 1.0 # Physical robot SDK configuration real_robot: diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 5bb87820..84776792 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -73,6 +73,11 @@ class LinkerHandConfig: print_input: bool = False somehand_config_path: str = DEFAULT_SOMEHAND_CONFIG_PATH somehand_sdk_root: str = DEFAULT_LINKERHAND_SDK_ROOT + somehand_rate: float | None = None + somehand_threaded: bool = False + somehand_max_iterations: int | None = None + somehand_temporal_filter_alpha: float | None = None + somehand_output_alpha: float | None = None @property def selected_hand_types(self) -> tuple[str, ...]: @@ -80,6 +85,10 @@ def selected_hand_types(self) -> tuple[str, ...]: return HAND_TYPES return (self.hand_type,) + @property + def vr_hand_pose_rate(self) -> float: + return self.somehand_rate if self.somehand_rate is not None else self.rate + def clamp_unit(value: float) -> float: return max(0.0, min(1.0, float(value))) @@ -221,6 +230,20 @@ def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: print_input=bool(cfg_get(hand_cfg, "print_input", False)), somehand_config_path=str(cfg_get(somehand_cfg, "config_path", DEFAULT_SOMEHAND_CONFIG_PATH)), somehand_sdk_root=str(cfg_get(somehand_cfg, "sdk_root", DEFAULT_LINKERHAND_SDK_ROOT)), + somehand_rate=_optional_positive_float(cfg_get(somehand_cfg, "rate", None), "somehand.rate"), + somehand_threaded=bool(cfg_get(somehand_cfg, "threaded", False)), + somehand_max_iterations=_optional_positive_int( + cfg_get(somehand_cfg, "max_iterations", None), + "somehand.max_iterations", + ), + somehand_temporal_filter_alpha=_optional_unit_interval( + cfg_get(somehand_cfg, "temporal_filter_alpha", None), + "somehand.temporal_filter_alpha", + ), + somehand_output_alpha=_optional_unit_interval( + cfg_get(somehand_cfg, "output_alpha", None), + 
"somehand.output_alpha", + ), ) if config.mode not in HAND_MODES: raise ValueError(f"dexterous_hand.mode must be one of {', '.join(HAND_MODES)}, got {config.mode!r}") @@ -550,7 +573,7 @@ def __init__(self, config: LinkerHandConfig, provider: HandSnapshotProvider): self.config = config self._provider = provider self._sender = AsyncL6PoseSender(config) - self._interval_s = 1.0 / config.rate + self._interval_s = 1.0 / config.vr_hand_pose_rate self._next_tick_s = 0.0 self._active = False self._last_status: dict[str, str] = {hand_type: "" for hand_type in config.selected_hand_types} @@ -649,6 +672,7 @@ def _load_somehand(self) -> None: "scripts/setup/download_somehand_l6_assets.sh" ) self._engine = BiHandRetargetingEngine.from_config_path(str(config_path)) + self._apply_low_latency_overrides() self._hand_frame_cls = HandFrame self._bihand_frame_cls = BiHandFrame self._pico_hand_to_landmarks = pico_hand_to_landmarks @@ -665,6 +689,28 @@ def _load_somehand(self) -> None: ) logger.info("somehand LinkerHand L6 runtime started | hands=%s", ",".join(self.config.selected_hand_types)) + def _apply_low_latency_overrides(self) -> None: + for hand_type, engine in (("left", self._engine.left_engine), ("right", self._engine.right_engine)): + retargeter = getattr(engine, "retargeter", None) + if retargeter is None: + continue + if self.config.somehand_max_iterations is not None: + setattr(retargeter, "_max_iterations", int(self.config.somehand_max_iterations)) + if self.config.somehand_output_alpha is not None: + setattr(retargeter, "_output_alpha", float(self.config.somehand_output_alpha)) + if self.config.somehand_temporal_filter_alpha is not None: + landmark_filter = getattr(retargeter, "landmark_filter", None) + if landmark_filter is not None: + setattr(landmark_filter, "alpha", float(self.config.somehand_temporal_filter_alpha)) + logger.info( + "somehand low-latency overrides | hand=%s rate=%.1fHz max_iter=%s temporal_alpha=%s output_alpha=%s", + hand_type, + 
self.config.vr_hand_pose_rate, + self.config.somehand_max_iterations, + self.config.somehand_temporal_filter_alpha, + self.config.somehand_output_alpha, + ) + def _set_status(self, hand_type: str, status: str, message: str) -> None: key = hand_type if self._last_status.get(key) == status: @@ -673,6 +719,89 @@ def _set_status(self, hand_type: str, status: str, message: str) -> None: logger.info("somehand LinkerHand L6: %s", message) +class ThreadedSomeHandPoseRuntime: + """Tick the somehand path independently from the robot control loop.""" + + def __init__(self, runtime: SomeHandPoseRuntime): + self._runtime = runtime + self._condition = threading.Condition() + self._runtime_lock = threading.Lock() + self._thread: threading.Thread | None = None + self._running = False + self._active = False + self._interval_s = 1.0 / runtime.config.vr_hand_pose_rate + + @property + def config(self) -> LinkerHandConfig: + return self._runtime.config + + @property + def enabled(self) -> bool: + return self._runtime.enabled + + def start(self) -> None: + if not self.enabled: + return + self._runtime.start() + with self._condition: + if self._running: + return + self._running = True + self._thread = threading.Thread( + target=self._run, + name="somehand-pose-runtime", + daemon=True, + ) + self._thread.start() + + def tick(self, *, active: bool, now_s: float | None = None) -> None: + del now_s + if not self.enabled: + return + should_deactivate = False + with self._condition: + if self._active != bool(active): + should_deactivate = self._active and not bool(active) + self._active = bool(active) + self._condition.notify_all() + if should_deactivate: + with self._runtime_lock: + self._runtime.tick(active=False) + + def close(self) -> None: + thread: threading.Thread | None + with self._condition: + self._running = False + self._active = False + self._condition.notify_all() + thread = self._thread + if thread is not None: + thread.join(timeout=2.0) + if thread.is_alive(): + 
logger.warning("somehand pose runtime worker did not stop within timeout") + with self._runtime_lock: + self._runtime.close() + + def _run(self) -> None: + next_tick_s = 0.0 + while True: + with self._condition: + while self._running and not self._active: + self._condition.wait() + next_tick_s = 0.0 + if not self._running: + return + active = self._active + now = time.monotonic() + if now < next_tick_s: + with self._condition: + self._condition.wait(timeout=next_tick_s - now) + continue + with self._runtime_lock: + self._runtime.tick(active=active, now_s=now) + next_tick_s = now + self._interval_s + + class DisabledLinkerHandRuntime: enabled = False @@ -686,7 +815,10 @@ def close(self) -> None: pass -def build_linkerhand_runtime(cfg: Any, input_provider: Any) -> LinkerHandRuntime | SomeHandPoseRuntime | DisabledLinkerHandRuntime: +def build_linkerhand_runtime( + cfg: Any, + input_provider: Any, +) -> LinkerHandRuntime | SomeHandPoseRuntime | ThreadedSomeHandPoseRuntime | DisabledLinkerHandRuntime: config = parse_linkerhand_config(cfg) if not config.enabled: return DisabledLinkerHandRuntime() @@ -704,7 +836,10 @@ def build_linkerhand_runtime(cfg: Any, input_provider: Any) -> LinkerHandRuntime raise ValueError("dexterous_hand.mode=vr_hand_pose currently requires dexterous_hand.hand_type=both") if not callable(getattr(input_provider, "get_hand_snapshot", None)): raise ValueError("dexterous_hand.mode=vr_hand_pose requires a Pico input provider with hand snapshots") - return SomeHandPoseRuntime(config, input_provider) + runtime = SomeHandPoseRuntime(config, input_provider) + if config.somehand_threaded: + return ThreadedSomeHandPoseRuntime(runtime) + return runtime raise ValueError(f"Unsupported dexterous_hand.mode={config.mode!r}") @@ -786,6 +921,30 @@ def _positive_float(value: object, field_name: str) -> float: return parsed +def _optional_positive_float(value: object, field_name: str) -> float | None: + if value is None: + return None + return _positive_float(value, 
field_name) + + +def _optional_positive_int(value: object, field_name: str) -> int | None: + if value is None: + return None + parsed = int(value) + if parsed <= 0: + raise ValueError(f"dexterous_hand.{field_name} must be > 0, got {value!r}") + return parsed + + +def _optional_unit_interval(value: object, field_name: str) -> float | None: + if value is None: + return None + parsed = float(value) + if parsed <= 0.0 or parsed > 1.0: + raise ValueError(f"dexterous_hand.{field_name} must be in (0, 1], got {value!r}") + return parsed + + def _uint8(value: object, field_name: str) -> int: parsed = int(value) if parsed < 0 or parsed > 255: diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 1597a3f1..c5b54780 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -12,7 +12,9 @@ L6RetargetPoseMapper, LinkerHandRuntime, SomeHandPoseRuntime, + ThreadedSomeHandPoseRuntime, VR_HAND_POSE_SPEED, + build_linkerhand_runtime, parse_linkerhand_config, trigger_to_pose, ) @@ -169,6 +171,30 @@ def test_parse_config_sets_vr_hand_pose_speed_to_max() -> None: assert cfg.speed == (255, 255, 255, 255, 255, 255) +def test_parse_config_accepts_somehand_low_latency_overrides() -> None: + cfg = parse_linkerhand_config( + { + "dexterous_hand": { + "mode": "vr_hand_pose", + "hand_type": "both", + "somehand": { + "rate": 60.0, + "threaded": True, + "max_iterations": 12, + "temporal_filter_alpha": 1.0, + "output_alpha": 1.0, + }, + } + } + ) + + assert cfg.vr_hand_pose_rate == 60.0 + assert cfg.somehand_threaded is True + assert cfg.somehand_max_iterations == 12 + assert cfg.somehand_temporal_filter_alpha == 1.0 + assert cfg.somehand_output_alpha == 1.0 + + def test_vr_hand_pose_speed_constant_is_max() -> None: assert tuple(VR_HAND_POSE_SPEED) == (255, 255, 255, 255, 255, 255) @@ -332,15 +358,27 @@ def get_joint_name_to_qpos_index(self): "pinky_mcp_pitch": 5, } + class FakeLandmarkFilter: + def __init__(self): + self.alpha = 0.65 + + class 
FakeRetargeter: + def __init__(self): + self._max_iterations = 60 + self._output_alpha = 0.92 + self.landmark_filter = FakeLandmarkFilter() + class FakeEngine: def __init__(self): self.left_engine = SimpleNamespace( config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_left", mjcf_path="left.xml")), hand_model=FakeHandModel(), + retargeter=FakeRetargeter(), ) self.right_engine = SimpleNamespace( config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_right", mjcf_path="right.xml")), hand_model=FakeHandModel(), + retargeter=FakeRetargeter(), ) @classmethod @@ -417,6 +455,60 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.close() +def test_vr_hand_pose_runtime_applies_low_latency_overrides(monkeypatch, tmp_path) -> None: + _install_fake_somehand( + monkeypatch, + left_qpos=[0.99, 0.0, 1.26, 1.26, 1.26, 1.26], + right_qpos=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ) + config_path = tmp_path / "linkerhand_l6_bihand.yaml" + config_path.write_text("left: {}\nright: {}\n", encoding="utf-8") + provider = HandSnapshotProvider() + cfg = parse_linkerhand_config( + { + "input": {"provider": "pico4"}, + "dexterous_hand": { + "mode": "vr_hand_pose", + "hand_type": "both", + "somehand": { + "config_path": str(config_path), + "sdk_root": "third_party/linkerhand-python-sdk", + "rate": 60.0, + "max_iterations": 12, + "temporal_filter_alpha": 1.0, + "output_alpha": 1.0, + }, + }, + } + ) + runtime = SomeHandPoseRuntime(cfg, provider) + runtime.start() + + assert runtime._interval_s == pytest.approx(1.0 / 60.0) + assert runtime._engine.left_engine.retargeter._max_iterations == 12 + assert runtime._engine.left_engine.retargeter._output_alpha == 1.0 + assert runtime._engine.left_engine.retargeter.landmark_filter.alpha == 1.0 + assert runtime._engine.right_engine.retargeter._max_iterations == 12 + runtime.close() + + +def test_build_linkerhand_runtime_returns_threaded_vr_hand_pose_runtime() -> None: + provider = HandSnapshotProvider() 
+ runtime = build_linkerhand_runtime( + { + "input": {"provider": "pico4"}, + "dexterous_hand": { + "mode": "vr_hand_pose", + "hand_type": "both", + "somehand": {"threaded": True}, + }, + }, + provider, + ) + + assert isinstance(runtime, ThreadedSomeHandPoseRuntime) + + def test_l6_retarget_pose_mapper_uses_sdk_order_and_model_joint_names() -> None: class FakeHandModel: def get_joint_name_to_qpos_index(self): From 77eb3c361c85ad98a63fa293385e79f41ed3a4e2 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 20:19:59 +0800 Subject: [PATCH 054/122] Fix L6 thumb mapping and range --- teleopit/sim2real/dexterous_hand.py | 6 ++--- tests/test_dexterous_hand.py | 34 ++++++++++++++--------------- third_party/linkerhand-python-sdk | 2 +- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 84776792..6b313671 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -35,7 +35,7 @@ DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" L6_SDK_JOINT_ORDER = ( "thumb_cmc_pitch", - "thumb_cmc_yaw", + "thumb_cmc_roll", "index_mcp_pitch", "middle_mcp_pitch", "ring_mcp_pitch", @@ -901,8 +901,8 @@ def _resolve_l6_joint_name(joint_index: dict[str, int], semantic_name: str, *, h def _l6_joint_aliases(semantic_name: str) -> tuple[str, ...]: if semantic_name == "thumb_cmc_pitch": return ("thumb_cmc_pitch", "thumb_pitch") - if semantic_name == "thumb_cmc_yaw": - return ("thumb_cmc_yaw", "thumb_yaw", "thumb_cmc_roll", "thumb_roll") + if semantic_name == "thumb_cmc_roll": + return ("thumb_cmc_roll", "thumb_roll") aliases = [semantic_name] if semantic_name.endswith("_mcp_pitch"): finger = semantic_name[: -len("_mcp_pitch")] diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index c5b54780..728b08f7 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -351,7 +351,7 @@ class FakeHandModel: def 
get_joint_name_to_qpos_index(self): return { "thumb_cmc_pitch": 0, - "thumb_cmc_yaw": 1, + "thumb_cmc_roll": 1, "index_mcp_pitch": 2, "middle_mcp_pitch": 3, "ring_mcp_pitch": 4, @@ -406,7 +406,7 @@ def process(self, frame): def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypatch, tmp_path) -> None: _install_fake_somehand( monkeypatch, - left_qpos=[0.99, 0.0, 1.26, 1.26, 1.26, 1.26], + left_qpos=[0.837758, 0.0, 1.134464, 1.134464, 1.134464, 1.134464], right_qpos=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ) config_path = tmp_path / "linkerhand_l6_bihand.yaml" @@ -433,8 +433,8 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.tick(active=True, now_s=10.0) assert runtime._sender.wait_idle(timeout_s=1.0) - assert runtime._sender._last_pose["left"] == [0, 255, 0, 0, 0, 0] - assert runtime._sender._last_pose["right"] == [255, 255, 255, 255, 255, 255] + assert runtime._sender._last_pose["left"] == [0, 238, 0, 0, 0, 0] + assert runtime._sender._last_pose["right"] == [255, 238, 255, 255, 255, 255] provider.snapshot = _hand_snapshot( left=_hand_state(active=False, value=9.0), @@ -445,8 +445,8 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa runtime.tick(active=True, now_s=10.1) assert runtime._sender.wait_idle(timeout_s=1.0) - assert runtime._sender._last_pose["left"] == [0, 255, 0, 0, 0, 0] - assert runtime._sender._last_pose["right"] == [255, 255, 255, 255, 255, 255] + assert runtime._sender._last_pose["left"] == [0, 238, 0, 0, 0, 0] + assert runtime._sender._last_pose["right"] == [255, 238, 255, 255, 255, 255] runtime.tick(active=False, now_s=10.2) assert runtime._sender.wait_idle(timeout_s=1.0) @@ -458,7 +458,7 @@ def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypa def test_vr_hand_pose_runtime_applies_low_latency_overrides(monkeypatch, tmp_path) -> None: _install_fake_somehand( monkeypatch, - left_qpos=[0.99, 0.0, 1.26, 1.26, 1.26, 1.26], + 
left_qpos=[0.837758, -0.087266, 1.134464, 1.134464, 1.134464, 1.134464], right_qpos=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ) config_path = tmp_path / "linkerhand_l6_bihand.yaml" @@ -514,7 +514,7 @@ class FakeHandModel: def get_joint_name_to_qpos_index(self): return { "thumb_pitch": 2, - "thumb_yaw": 0, + "thumb_roll": 0, "index_pitch": 5, "middle_pitch": 1, "ring_pitch": 4, @@ -522,11 +522,11 @@ def get_joint_name_to_qpos_index(self): } qpos = np.zeros(6, dtype=np.float64) - qpos[2] = 0.99 - qpos[0] = 0.0 - qpos[5] = 1.26 + qpos[2] = 0.837758 + qpos[0] = -0.087266 + qpos[5] = 1.134464 qpos[1] = 0.0 - qpos[4] = 1.26 + qpos[4] = 1.134464 qpos[3] = 0.0 mapper = L6RetargetPoseMapper( @@ -556,11 +556,11 @@ def get_joint_name_to_qpos_index(self): } qpos = np.zeros(11, dtype=np.float64) - qpos[8] = 0.99 - qpos[9] = 0.0 - qpos[1] = 1.26 + qpos[8] = 0.837758 + qpos[9] = -0.087266 + qpos[1] = 1.134464 qpos[3] = 0.0 - qpos[5] = 1.26 + qpos[5] = 1.134464 qpos[7] = 0.0 mapper = L6RetargetPoseMapper( @@ -577,7 +577,7 @@ class FakeHandModel: def get_joint_name_to_qpos_index(self): return { "thumb_pitch": 0, - "thumb_yaw": 1, + "thumb_roll": 1, "index_pitch": 2, "middle_pitch": 3, "ring_pitch": 4, diff --git a/third_party/linkerhand-python-sdk b/third_party/linkerhand-python-sdk index d884a720..40dbb8f8 160000 --- a/third_party/linkerhand-python-sdk +++ b/third_party/linkerhand-python-sdk @@ -1 +1 @@ -Subproject commit d884a72081539bb159855f54945e497eedefd31a +Subproject commit 40dbb8f85a98d636285fc23f391fa083d0a30724 From 674ce1816ea8348bc34b193fd259eaf399e3853e Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 20:54:49 +0800 Subject: [PATCH 055/122] Fix dexterous hand mode parsing --- teleopit/configs/pico4_sim2real.yaml | 2 +- teleopit/configs/sim2real.yaml | 2 +- teleopit/sim2real/dexterous_hand.py | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 562f9a94..aaa21fc8 
100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -29,7 +29,7 @@ joint_vel_limit: 10.0 # Optional LinkerHand L6 control from Pico controller grip/trigger or VR hand pose. dexterous_hand: - mode: off # off | gripper | vr_hand_pose + mode: "off" # off | gripper | vr_hand_pose hand_joint: L6 hand_type: both left_can: can0 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 0690b3ce..6ef63c33 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -28,7 +28,7 @@ joint_vel_limit: 10.0 # Optional LinkerHand L6 control. Use only with input.provider=pico4. dexterous_hand: - mode: off # off | gripper | vr_hand_pose + mode: "off" # off | gripper | vr_hand_pose hand_joint: L6 hand_type: both left_can: can0 diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py index 6b313671..04ae9812 100644 --- a/teleopit/sim2real/dexterous_hand.py +++ b/teleopit/sim2real/dexterous_hand.py @@ -201,7 +201,10 @@ def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: hand_cfg = cfg_get(cfg, "dexterous_hand", {}) or {} raw_mode = cfg_get(hand_cfg, "mode", None) legacy_enabled = bool(cfg_get(hand_cfg, "enabled", False)) - mode = str(raw_mode if raw_mode is not None else ("gripper" if legacy_enabled else "off")).lower() + if isinstance(raw_mode, bool): + mode = "gripper" if raw_mode else "off" + else: + mode = str(raw_mode if raw_mode is not None else ("gripper" if legacy_enabled else "off")).lower() somehand_cfg = cfg_get(hand_cfg, "somehand", {}) or {} thumb_yaw = _uint8(cfg_get(hand_cfg, "thumb_yaw_center", THUMB_YAW_DEFAULT), "thumb_yaw_center") open_pose = _pose_values(cfg_get(hand_cfg, "open_pose", OPEN_POSE), "open_pose") From ad81cfb6098abb987fbddadd04cc86e05814af1c Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 28 May 2026 22:02:14 +0800 Subject: [PATCH 056/122] Add sim2real timing diagnostics --- teleopit/sim2real/controller.py | 130 
++++++++++++++++++++++++++++---- tests/test_sim2real_runtime.py | 22 ++++++ 2 files changed, 138 insertions(+), 14 deletions(-) diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index 23d78779..dfdc5eaa 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -64,6 +64,89 @@ class RobotMode(Enum): DAMPING = "damping" # Emergency stop / recovery +class _LoopTimingReporter: + """Aggregate best-effort control-loop timing stats and emit periodic logs.""" + + def __init__(self, *, target_period_s: float, log_interval_s: float = 1.0) -> None: + self._target_period_s = float(target_period_s) + self._log_interval_s = float(log_interval_s) + self._window_start_s: float | None = None + self._loop_ms: list[float] = [] + self._work_ms: list[float] = [] + self._pico_age_ms: list[float] = [] + self._overrun_count = 0 + + def record( + self, + *, + loop_start_s: float, + work_elapsed_s: float, + cycle_elapsed_s: float, + pico_age_s: float | None, + ) -> None: + if self._window_start_s is None: + self._window_start_s = float(loop_start_s) + + self._loop_ms.append(float(cycle_elapsed_s) * 1000.0) + self._work_ms.append(float(work_elapsed_s) * 1000.0) + if pico_age_s is not None: + self._pico_age_ms.append(float(pico_age_s) * 1000.0) + if cycle_elapsed_s > self._target_period_s + 1e-9: + self._overrun_count += 1 + + if loop_start_s - self._window_start_s >= self._log_interval_s: + self._emit(loop_start_s) + + def _emit(self, end_s: float) -> None: + sample_count = len(self._loop_ms) + if sample_count <= 0: + self._reset(end_s) + return + + loop_summary = self._summarize(self._loop_ms) + work_summary = self._summarize(self._work_ms) + message = ( + "Timing stats | samples=%d window=%.1fs | " + "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f overrun=%d/%d | " + "work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" + ) + args: list[object] = [ + sample_count, + end_s - float(self._window_start_s), + loop_summary[0], + loop_summary[1], + 
loop_summary[2], + loop_summary[3], + self._overrun_count, + sample_count, + work_summary[0], + work_summary[1], + work_summary[2], + work_summary[3], + ] + if self._pico_age_ms: + pico_summary = self._summarize(self._pico_age_ms) + message += " | pico_age_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" + args.extend([pico_summary[0], pico_summary[1], pico_summary[2], pico_summary[3]]) + logger.info(message, *args) + self._reset(end_s) + + def _reset(self, window_start_s: float) -> None: + self._window_start_s = float(window_start_s) + self._loop_ms.clear() + self._work_ms.clear() + self._pico_age_ms.clear() + self._overrun_count = 0 + + @staticmethod + def _summarize(samples: list[float]) -> tuple[float, float, float, float]: + values = np.asarray(samples, dtype=np.float64) + if values.size <= 0: + return 0.0, 0.0, 0.0, 0.0 + p50, p95, p99 = np.percentile(values, [50.0, 95.0, 99.0]) + return float(p50), float(p95), float(p99), float(np.max(values)) + + def _parse_sim2real_viewers(cfg: Any) -> set[str]: viewers = parse_viewers(cfg) unsupported = viewers.difference({"retarget"}) @@ -266,6 +349,7 @@ def run(self) -> None: "Control loop started | mode=IDLE | press Start to enter STANDING" ) dt = 1.0 / self.policy_hz + timing = _LoopTimingReporter(target_period_s=dt) try: self._video_runtime.start() @@ -284,22 +368,26 @@ def run(self) -> None: logger.warning("EMERGENCY STOP (L1+R1)") self._enter_damping() self._tick_dexterous_hand() - self._sleep_until(t0, dt) - continue - - # 3. Mode transitions - self._handle_transitions() + else: + # 3. Mode transitions + self._handle_transitions() - # 5. Execute current mode - if self.mode == RobotMode.STANDING: - self._standing_step() - elif self.mode == RobotMode.MOCAP: - self._mocap_step() + # 5. Execute current mode + if self.mode == RobotMode.STANDING: + self._standing_step() + elif self.mode == RobotMode.MOCAP: + self._mocap_step() - self._tick_dexterous_hand() + self._tick_dexterous_hand() - # 6. 
Rate control - self._sleep_until(t0, dt) + work_elapsed_s = time.monotonic() - t0 + cycle_elapsed_s = self._sleep_until(t0, dt) + timing.record( + loop_start_s=t0, + work_elapsed_s=work_elapsed_s, + cycle_elapsed_s=cycle_elapsed_s, + pico_age_s=self._sample_pico_frame_age_s(), + ) except KeyboardInterrupt: logger.info("KeyboardInterrupt -- shutting down") @@ -944,12 +1032,26 @@ def _write_retarget_viewer(self, qpos: Float64Array) -> None: logger.exception("Sim2real retarget viewer update failed; control continues") @staticmethod - def _sleep_until(t0: float, dt: float) -> None: + def _sleep_until(t0: float, dt: float) -> float: """Sleep to maintain control frequency.""" elapsed = time.monotonic() - t0 remaining = dt - elapsed if remaining > 0: time.sleep(remaining) + return time.monotonic() - t0 + + def _sample_pico_frame_age_s(self) -> float | None: + has_frame = getattr(self.input_provider, "has_frame", None) + get_frame_packet = getattr(self.input_provider, "get_frame_packet", None) + if not callable(has_frame) or not callable(get_frame_packet): + return None + try: + if not has_frame(): + return None + _, frame_timestamp_s, _ = get_frame_packet() + except Exception: + return None + return max(0.0, time.monotonic() - float(frame_timestamp_s)) # ------------------------------------------------------------------ # Lifecycle diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index df0469b3..39c38008 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -316,6 +316,28 @@ def test_sim2real_retarget_viewer_rejects_sim_viewers(monkeypatch) -> None: Sim2RealController(cfg) +def test_loop_timing_reporter_logs_percentiles_and_overruns(caplog) -> None: + import logging + + from teleopit.sim2real.controller import _LoopTimingReporter + + reporter = _LoopTimingReporter(target_period_s=0.02, log_interval_s=1.0) + with caplog.at_level(logging.INFO, logger="teleopit.sim2real.controller"): + 
reporter.record(loop_start_s=0.0, work_elapsed_s=0.005, cycle_elapsed_s=0.020, pico_age_s=0.010) + reporter.record(loop_start_s=0.5, work_elapsed_s=0.006, cycle_elapsed_s=0.021, pico_age_s=0.012) + reporter.record(loop_start_s=1.0, work_elapsed_s=0.007, cycle_elapsed_s=0.050, pico_age_s=0.030) + + text = caplog.text + assert "Timing stats" in text + assert "loop_ms p50=" in text + assert "p95=" in text + assert "p99=" in text + assert "max=" in text + assert "overrun=2/3" in text + assert "work_ms p50=" in text + assert "pico_age_ms p50=" in text + + def test_sim2real_rejects_nonzero_reference_steps_without_buffer(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController From 1ce7f3fdafe16b10b380bed2ae33327d9bbf401f Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 29 May 2026 16:18:15 +0800 Subject: [PATCH 057/122] Fix multiprocess sim2real reset handling --- pyproject.toml | 1 + scripts/run/run_sim2real.py | 13 +- teleopit/configs/pico4_sim2real.yaml | 13 + teleopit/configs/sim2real.yaml | 13 + teleopit/sim2real/mp/__init__.py | 11 + teleopit/sim2real/mp/ipc.py | 122 +++ teleopit/sim2real/mp/messages.py | 91 ++ teleopit/sim2real/mp/runtime.py | 1277 ++++++++++++++++++++++++++ teleopit/sim2real/mp/shm.py | 106 +++ tests/test_sim2real_multiprocess.py | 257 ++++++ 10 files changed, 1903 insertions(+), 1 deletion(-) create mode 100644 teleopit/sim2real/mp/__init__.py create mode 100644 teleopit/sim2real/mp/ipc.py create mode 100644 teleopit/sim2real/mp/messages.py create mode 100644 teleopit/sim2real/mp/runtime.py create mode 100644 teleopit/sim2real/mp/shm.py create mode 100644 tests/test_sim2real_multiprocess.py diff --git a/pyproject.toml b/pyproject.toml index 7a858fe7..924f10e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "omegaconf", "h5py", "onnxruntime", + "pyzmq", "rich", "loop-rate-limiters", "imageio", diff --git a/scripts/run/run_sim2real.py b/scripts/run/run_sim2real.py index 
84619387..5aa312ed 100644 --- a/scripts/run/run_sim2real.py +++ b/scripts/run/run_sim2real.py @@ -6,6 +6,7 @@ from omegaconf import DictConfig from teleopit.runtime.cli import validate_policy_path +from teleopit.sim2real.mp import MultiprocessSim2RealController, resolve_sim2real_runtime_mode from teleopit.sim2real.controller import Sim2RealController @@ -26,10 +27,20 @@ def _print_sim2real_controls(cfg: DictConfig) -> None: @hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="sim2real") def main(cfg: DictConfig) -> None: + _run_sim2real(cfg) + + +def _run_sim2real(cfg: DictConfig) -> None: validate_policy_path(cfg, "run_sim2real.py") - controller = Sim2RealController(cfg) + runtime_mode = resolve_sim2real_runtime_mode(cfg) + controller = ( + MultiprocessSim2RealController(cfg) + if runtime_mode == "multiprocess" + else Sim2RealController(cfg) + ) if cfg.input.get("provider") == "pico4": print("Waiting for Pico4 body tracking data...") + print(f"Sim2real runtime: {runtime_mode}") _print_sim2real_controls(cfg) try: controller.run() diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index aaa21fc8..a038de2e 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -5,6 +5,7 @@ defaults: - _self_ policy_hz: 50.0 +sim2real_runtime: multiprocess viewers: "none" # Optional: set viewers=retarget to show the retargeted reference input: video: @@ -18,6 +19,18 @@ reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] reference_debug_log: false +multiprocess: + host: 127.0.0.1 + base_port: 39700 + start_method: spawn + shutdown_timeout_s: 3.0 + pico_io_hz: 120.0 + hand_worker_hz: 120.0 + retarget_idle_sleep_s: 0.001 + video_slots: 3 + stale_reference_hold_s: 0.08 + max_reference_age_s: 0.25 + # Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 kp_ramp_floor_ratio: 0.1 diff --git 
a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 6ef63c33..c2ebe287 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -5,6 +5,7 @@ defaults: - _self_ policy_hz: 50.0 +sim2real_runtime: auto # auto | single_process | multiprocess; auto uses multiprocess for Pico4 viewers: "none" # Optional: set viewers=retarget to show the retargeted reference retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 @@ -17,6 +18,18 @@ reference_debug_log: false playback: pause_on_end: true +multiprocess: + host: 127.0.0.1 + base_port: 39700 + start_method: spawn + shutdown_timeout_s: 3.0 + pico_io_hz: 120.0 + hand_worker_hz: 120.0 + retarget_idle_sleep_s: 0.001 + video_slots: 3 + stale_reference_hold_s: 0.08 + max_reference_age_s: 0.25 + # Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 kp_ramp_floor_ratio: 0.1 diff --git a/teleopit/sim2real/mp/__init__.py b/teleopit/sim2real/mp/__init__.py new file mode 100644 index 00000000..59adde7c --- /dev/null +++ b/teleopit/sim2real/mp/__init__.py @@ -0,0 +1,11 @@ +"""Multiprocess sim2real runtime.""" + +from teleopit.sim2real.mp.runtime import ( + MultiprocessSim2RealController, + resolve_sim2real_runtime_mode, +) + +__all__ = [ + "MultiprocessSim2RealController", + "resolve_sim2real_runtime_mode", +] diff --git a/teleopit/sim2real/mp/ipc.py b/teleopit/sim2real/mp/ipc.py new file mode 100644 index 00000000..02864c73 --- /dev/null +++ b/teleopit/sim2real/mp/ipc.py @@ -0,0 +1,122 @@ +"""Small-message IPC helpers for multiprocess sim2real.""" + +from __future__ import annotations + +from dataclasses import dataclass +import pickle +import time +from typing import Any + +import zmq + + +BODY_TOPIC = "body" +HAND_TOPIC = "hand" +CONTROLLER_TOPIC = "controller" +CONTROL_EVENTS_TOPIC = "control_events" +REFERENCE_TOPIC = "reference" +MODE_TOPIC = "mode" +REFERENCE_RESET_TOPIC = "reference_reset" +VIDEO_TOPIC = "video" 
+HEALTH_TOPIC = "health" +COMMAND_TOPIC = "command" + + +@dataclass(frozen=True) +class Sim2RealIpcEndpoints: + body_pub: str + hand_pub: str + controller_pub: str + control_events_pub: str + reference_pub: str + mode_pub: str + video_pub: str + health_pub: str + command_pub: str + + +def default_endpoints(*, host: str = "127.0.0.1", base_port: int = 39700) -> Sim2RealIpcEndpoints: + """Return deterministic localhost TCP endpoints for one sim2real runtime.""" + prefix = f"tcp://{host}:" + return Sim2RealIpcEndpoints( + body_pub=f"{prefix}{base_port}", + hand_pub=f"{prefix}{base_port + 1}", + controller_pub=f"{prefix}{base_port + 2}", + control_events_pub=f"{prefix}{base_port + 3}", + reference_pub=f"{prefix}{base_port + 4}", + mode_pub=f"{prefix}{base_port + 5}", + video_pub=f"{prefix}{base_port + 6}", + health_pub=f"{prefix}{base_port + 7}", + command_pub=f"{prefix}{base_port + 8}", + ) + + +class ZmqPublisher: + """Topic publisher with low watermarks for realtime latest-only streams.""" + + def __init__(self, endpoint: str, *, context: zmq.Context[Any] | None = None) -> None: + self._own_context = context is None + self._context = zmq.Context() if context is None else context + self._socket = self._context.socket(zmq.PUB) + self._socket.setsockopt(zmq.SNDHWM, 1) + self._socket.setsockopt(zmq.LINGER, 0) + self._socket.bind(endpoint) + self._endpoint = endpoint + # Give subscribers a short chance to connect during process startup. + time.sleep(0.05) + + @property + def endpoint(self) -> str: + return self._endpoint + + def publish(self, topic: str, payload: object) -> None: + try: + self._socket.send_multipart( + [topic.encode("utf-8"), pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)], + flags=zmq.NOBLOCK, + ) + except zmq.Again: + # Realtime streams are latest-only; dropping beats backpressure. 
+ return + + def close(self) -> None: + self._socket.close(linger=0) + if self._own_context: + self._context.term() + + +class LatestSubscriber: + """Subscriber that drains all pending messages and returns only the latest.""" + + def __init__( + self, + endpoint: str, + topic: str, + *, + context: zmq.Context[Any] | None = None, + ) -> None: + self._own_context = context is None + self._context = zmq.Context() if context is None else context + self._socket = self._context.socket(zmq.SUB) + self._socket.setsockopt(zmq.RCVHWM, 1) + self._socket.setsockopt(zmq.LINGER, 0) + self._socket.setsockopt_string(zmq.SUBSCRIBE, topic) + self._socket.connect(endpoint) + self._topic = topic + + def recv_latest(self) -> object | None: + latest: object | None = None + while True: + try: + topic_raw, payload_raw = self._socket.recv_multipart(flags=zmq.NOBLOCK) + except zmq.Again: + return latest + topic = topic_raw.decode("utf-8") + if topic != self._topic: + continue + latest = pickle.loads(payload_raw) + + def close(self) -> None: + self._socket.close(linger=0) + if self._own_context: + self._context.term() diff --git a/teleopit/sim2real/mp/messages.py b/teleopit/sim2real/mp/messages.py new file mode 100644 index 00000000..aabe9a88 --- /dev/null +++ b/teleopit/sim2real/mp/messages.py @@ -0,0 +1,91 @@ +"""Pickle-serializable IPC message contracts for multiprocess sim2real.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from teleopit.inputs.realtime_packet import ControlEvent, HumanFrame +from teleopit.sim.reference_timeline import ReferenceWindow + + +Float64Array = NDArray[np.float64] + + +@dataclass(frozen=True) +class BodyFramePacket: + frame: HumanFrame + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class ReferencePacket: + qpos: Float64Array + timestamp_s: float + seq: int + source_timestamp_s: float + source_seq: int + frame_valid: bool = 
True + reference_reset_seq: int = 0 + reference_window: ReferenceWindow | None = None + retarget_elapsed_s: float = 0.0 + + +@dataclass(frozen=True) +class ControlEventsPacket: + events: tuple[ControlEvent, ...] + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class SnapshotPacket: + snapshot: Any + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class ModeStatePacket: + mode: str + mocap_active: bool + mocap_paused: bool + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class ReferenceResetPacket: + reason: str + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class HealthPacket: + worker: str + timestamp_s: float + status: str = "ok" + metrics: dict[str, float | int | str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class CommandPacket: + command: str + timestamp_s: float + payload: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class SharedFrameDescriptor: + shm_name: str + slot: int + seq: int + timestamp_s: float + shape: tuple[int, ...] 
+ dtype: str + slots: int diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py new file mode 100644 index 00000000..c492d307 --- /dev/null +++ b/teleopit/sim2real/mp/runtime.py @@ -0,0 +1,1277 @@ +"""Multiprocess sim2real runtime using ZMQ and shared memory.""" + +from __future__ import annotations + +import logging +import multiprocessing as mp +from multiprocessing.synchronize import Event as MpEvent +from pathlib import Path +import time +from typing import Any, Callable + +import numpy as np +from numpy.typing import NDArray + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.controllers.observation import VelCmdObservationBuilder, align_motion_qpos_yaw +from teleopit.controllers.rl_policy import RLPolicyController +from teleopit.inputs.pico4_provider import Pico4InputProvider +from teleopit.inputs.pico_video import bridge_video_source, parse_pico_video_config +from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType +from teleopit.retargeting.core import RetargetingModule +from teleopit.runtime.common import cfg_get, require_section +from teleopit.runtime.factory import _build_policy_components, build_simulation_cfg +from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState +from teleopit.runtime.reference_config import parse_reference_config +from teleopit.sim.reference_timeline import ReferenceTimeline, ReferenceWindow, ReferenceWindowBuilder +from teleopit.sim.reference_utils import build_static_reference_window, obs_builder_requires_reference_window +from teleopit.sim.realtime_utils import RealtimeReferenceManager +from teleopit.sim2real.controller import ( + RobotMode, + _LoopTimingReporter, + _parse_sim2real_viewers, + _Sim2RealRetargetViewer, +) +from teleopit.sim2real.dexterous_hand import build_linkerhand_runtime +from teleopit.sim2real.mp.ipc import ( + BODY_TOPIC, + COMMAND_TOPIC, + CONTROL_EVENTS_TOPIC, + CONTROLLER_TOPIC, + HAND_TOPIC, + 
HEALTH_TOPIC, + MODE_TOPIC, + REFERENCE_RESET_TOPIC, + REFERENCE_TOPIC, + VIDEO_TOPIC, + LatestSubscriber, + Sim2RealIpcEndpoints, + ZmqPublisher, + default_endpoints, +) +from teleopit.sim2real.mp.messages import ( + BodyFramePacket, + CommandPacket, + ControlEventsPacket, + HealthPacket, + ModeStatePacket, + ReferencePacket, + ReferenceResetPacket, + SharedFrameDescriptor, + SnapshotPacket, +) +from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter +from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor +from teleopit.sim2real.remote import UnitreeRemote +from teleopit.sim2real.safety import Sim2RealSafetyManager +from teleopit.sim2real.unitree_g1 import UnitreeG1Robot + +try: + from omegaconf import OmegaConf +except ImportError: # pragma: no cover - OmegaConf is a project dependency. + OmegaConf = None # type: ignore[assignment] + + +logger = logging.getLogger(__name__) + +Float32Array = NDArray[np.float32] +Float64Array = NDArray[np.float64] +PROJECT_ROOT = Path(__file__).resolve().parents[3] + + +def resolve_sim2real_runtime_mode(cfg: Any) -> str: + """Resolve ``auto|single_process|multiprocess`` into a concrete runtime.""" + raw = str(cfg_get(cfg, "sim2real_runtime", "auto")).strip().lower() + if raw in ("single", "single_process", "legacy"): + return "single_process" + if raw in ("mp", "multi", "multiprocess"): + provider = str(cfg_get(cfg_get(cfg, "input", {}), "provider", "")).lower() + if provider != "pico4": + raise ValueError("sim2real_runtime=multiprocess currently requires input.provider=pico4") + return "multiprocess" + if raw != "auto": + raise ValueError("sim2real_runtime must be auto, single_process, or multiprocess") + provider = str(cfg_get(cfg_get(cfg, "input", {}), "provider", "")).lower() + return "multiprocess" if provider == "pico4" else "single_process" + + +def _plain_cfg(cfg: Any) -> dict[str, Any]: + if isinstance(cfg, dict): + return dict(cfg) + if OmegaConf is not None and 
OmegaConf.is_config(cfg): + return OmegaConf.to_container(cfg, resolve=True) # type: ignore[return-value] + raise TypeError(f"Unsupported sim2real cfg type for multiprocessing: {type(cfg)!r}") + + +def _mp_cfg(cfg: Any) -> Any: + return cfg_get(cfg, "multiprocess", {}) or {} + + +def _worker_loop(name: str, fn: Callable[[], None]) -> None: + logging.basicConfig(level=logging.INFO) + try: + fn() + except KeyboardInterrupt: + pass + except BaseException: + logger.exception("%s worker crashed", name) + raise + + +def _human_frame_is_valid(frame: object, *, max_pos_value: float) -> bool: + if not isinstance(frame, dict): + return False + max_pos = float(max_pos_value) + if not np.isfinite(max_pos) or max_pos <= 0.0: + return False + for value in frame.values(): + try: + pos, quat = value + except Exception: + return False + pos_arr = np.asarray(pos, dtype=np.float64).reshape(-1) + quat_arr = np.asarray(quat, dtype=np.float64).reshape(-1) + if np.any(np.isnan(pos_arr)) or np.any(np.isinf(pos_arr)): + return False + if np.any(np.abs(pos_arr) > max_pos): + return False + if np.any(np.isnan(quat_arr)) or np.any(np.isinf(quat_arr)): + return False + return True + + +class MultiprocessSim2RealController: + """Supervisor facade for the multiprocess Pico sim2real runtime.""" + + def __init__(self, cfg: Any) -> None: + self.cfg = _plain_cfg(cfg) + if resolve_sim2real_runtime_mode(self.cfg) != "multiprocess": + raise ValueError("MultiprocessSim2RealController requires sim2real_runtime=multiprocess or auto+pico4") + + mp_cfg = _mp_cfg(self.cfg) + video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) + if video_cfg.enabled and video_cfg.source not in ("realsense", "test-pattern"): + raise ValueError( + "Multiprocess sim2real only supports input.video.source=realsense or test-pattern" + ) + self._ctx = mp.get_context(str(cfg_get(mp_cfg, "start_method", "spawn"))) + self._stop_event = self._ctx.Event() + self._processes: list[mp.Process] = [] + self._shutdown_timeout_s 
= float(cfg_get(mp_cfg, "shutdown_timeout_s", 3.0)) + self._endpoints = default_endpoints( + host=str(cfg_get(mp_cfg, "host", "127.0.0.1")), + base_port=int(cfg_get(mp_cfg, "base_port", 39700)), + ) + + def run(self) -> None: + logger.info("Starting multiprocess sim2real runtime") + try: + self._start_processes() + while not self._stop_event.is_set(): + time.sleep(0.2) + critical_dead = [ + process.name + for process in self._processes + if not process.is_alive() + and process.exitcode not in (None, 0) + and process.name in {"robot_control", "pico_io", "retarget_worker"} + ] + if critical_dead: + logger.error("Critical sim2real worker exited: %s", ", ".join(critical_dead)) + self._stop_event.set() + break + except KeyboardInterrupt: + logger.info("KeyboardInterrupt -- shutting down multiprocess sim2real") + self._stop_event.set() + finally: + self.shutdown() + + def shutdown(self) -> None: + self._stop_event.set() + for process in self._processes: + process.join(timeout=self._shutdown_timeout_s) + for process in self._processes: + if process.is_alive(): + logger.warning("Terminating sim2real worker %s", process.name) + process.terminate() + process.join(timeout=1.0) + self._processes.clear() + + def _start_processes(self) -> None: + if self._processes: + return + + specs: list[tuple[str, Callable[..., None]]] = [ + ("pico_io", _run_pico_io_worker), + ("retarget_worker", _run_retarget_worker), + ("robot_control", _run_robot_control_worker), + ] + hand_mode = str(cfg_get(cfg_get(self.cfg, "dexterous_hand", {}) or {}, "mode", "off")).lower() + if hand_mode != "off": + specs.append(("hand_worker", _run_hand_worker)) + video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) + if video_cfg.enabled and video_cfg.source not in (None, "test-pattern"): + specs.append(("video_worker", _run_video_worker)) + elif video_cfg.enabled and video_cfg.source == "test-pattern": + # pico-bridge can generate test-pattern internally without a camera worker. 
+ logger.info("Pico video test-pattern uses pico_bridge internal source") + + for name, target in specs: + process = self._ctx.Process( + name=name, + target=target, + args=(self.cfg, self._endpoints, self._stop_event), + ) + process.start() + self._processes.append(process) + + +def _run_pico_io_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + input_cfg = cfg_get(cfg, "input", {}) or {} + video_cfg = parse_pico_video_config(input_cfg) + provider = Pico4InputProvider( + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + timeout=float(cfg_get(input_cfg, "pico4_timeout", 60.0)), + buffer_size=int(cfg_get(input_cfg, "pico4_buffer_size", 60)), + timestamp_gap_reset_s=float(cfg_get(input_cfg, "pico4_timestamp_gap_reset_s", 0.15)), + pause_button=cfg_get(input_cfg, "pause_button", "A"), + pause_debounce_s=float(cfg_get(input_cfg, "pause_debounce_s", 0.25)), + bridge_host=str(cfg_get(input_cfg, "bridge_host", "0.0.0.0")), + bridge_port=int(cfg_get(input_cfg, "bridge_port", 63901)), + bridge_discovery=bool(cfg_get(input_cfg, "bridge_discovery", True)), + bridge_advertise_ip=cfg_get(input_cfg, "bridge_advertise_ip", None), + bridge_video=bridge_video_source(video_cfg), + bridge_video_enabled=video_cfg.enabled, + bridge_start_timeout=float(cfg_get(input_cfg, "bridge_start_timeout", 10.0)), + bridge_history_size=int(cfg_get(input_cfg, "bridge_history_size", 120)), + ) + + body_pub = ZmqPublisher(endpoints.body_pub) + hand_pub = ZmqPublisher(endpoints.hand_pub) + controller_pub = ZmqPublisher(endpoints.controller_pub) + events_pub = ZmqPublisher(endpoints.control_events_pub) + health_pub = ZmqPublisher(endpoints.health_pub) + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + video_sub = ( + LatestSubscriber(endpoints.video_pub, VIDEO_TOPIC) + if video_cfg.enabled and video_cfg.source not in (None, "test-pattern") + else None + ) + frame_reader = 
SharedFrameRingReader() + + hz = float(cfg_get(_mp_cfg(cfg), "pico_io_hz", 120.0)) + sleep_s = 1.0 / max(hz, 1.0) + last_body_seq = -1 + last_hand_seq = -1 + last_controller_seq = -1 + last_video_seq = -1 + last_health_s = 0.0 + try: + while not stop_event.is_set(): + command = command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + stop_event.set() + break + + now = time.monotonic() + if callable(getattr(provider, "has_frame", None)) and provider.has_frame(): + try: + frame, timestamp_s, seq = provider.get_frame_packet() + except Exception: + logger.exception("pico_io failed to read body frame") + else: + if int(seq) != last_body_seq: + body_pub.publish( + BODY_TOPIC, + BodyFramePacket(frame=frame, timestamp_s=float(timestamp_s), seq=int(seq)), + ) + last_body_seq = int(seq) + + events = provider.pop_control_events() + if events: + events_pub.publish( + CONTROL_EVENTS_TOPIC, + ControlEventsPacket(events=tuple(events), timestamp_s=now, seq=last_body_seq), + ) + + controller_snapshot = provider.get_controller_snapshot() + if controller_snapshot is not None and int(controller_snapshot.seq) != last_controller_seq: + controller_pub.publish( + CONTROLLER_TOPIC, + SnapshotPacket( + snapshot=controller_snapshot, + timestamp_s=float(controller_snapshot.timestamp_s), + seq=int(controller_snapshot.seq), + ), + ) + last_controller_seq = int(controller_snapshot.seq) + + hand_snapshot = provider.get_hand_snapshot() + if hand_snapshot is not None and int(hand_snapshot.seq) != last_hand_seq: + hand_pub.publish( + HAND_TOPIC, + SnapshotPacket( + snapshot=hand_snapshot, + timestamp_s=float(hand_snapshot.timestamp_s), + seq=int(hand_snapshot.seq), + ), + ) + last_hand_seq = int(hand_snapshot.seq) + + if video_sub is not None: + descriptor = video_sub.recv_latest() + if isinstance(descriptor, SharedFrameDescriptor): + try: + frame = frame_reader.read(descriptor, copy=False) + provider.push_video_frame(np.asarray(frame, dtype=np.uint8)) + 
last_video_seq = int(descriptor.seq) + except Exception as exc: + logger.warning("Pico video frame dropped: %s", exc) + + if now - last_health_s >= 1.0: + health_pub.publish( + HEALTH_TOPIC, + HealthPacket( + worker="pico_io", + timestamp_s=now, + metrics={ + "body_seq": last_body_seq, + "body_fps": float(provider.fps), + "hand_seq": last_hand_seq, + "controller_seq": last_controller_seq, + "video_seq": last_video_seq, + }, + ), + ) + last_health_s = now + time.sleep(sleep_s) + finally: + frame_reader.close() + if video_sub is not None: + video_sub.close() + command_sub.close() + for publisher in (body_pub, hand_pub, controller_pub, events_pub, health_pub): + publisher.close() + provider.close() + + _worker_loop("pico_io", _main) + + +def _run_retarget_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + input_cfg = cfg_get(cfg, "input", {}) or {} + policy_hz = float(cfg_get(cfg, "policy_hz", 50.0)) + ref_cfg = parse_reference_config(cfg, provider_fps=None) + reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=1.0 / policy_hz, + reference_steps=cfg_get(cfg, "reference_steps", [0]), + ) + if ref_cfg.retarget_buffer_enabled and ref_cfg.reference_delay_s is not None: + reference_window_builder.validate_runtime_support( + delay_s=float(ref_cfg.reference_delay_s or 0.0), + window_s=ref_cfg.retarget_buffer_window_s, + config_label="Multiprocess sim2real reference timeline", + ) + timeline = ReferenceTimeline(window_s=ref_cfg.retarget_buffer_window_s) if ref_cfg.retarget_buffer_enabled else None + reference_manager = ( + RealtimeReferenceManager( + reference_window_builder=reference_window_builder, + warmup_steps=ref_cfg.realtime_buffer_warmup_steps, + ) + if timeline is not None + else None + ) + + retargeter = RetargetingModule( + robot_name=str(cfg_get(input_cfg, "robot_name", "unitree_g1")), + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + 
actual_human_height=float(cfg_get(input_cfg, "human_height", 1.75)), + ) + mocap_sw = cfg_get(cfg, "mocap_switch", {}) or {} + max_position_value = float(cfg_get(mocap_sw, "max_position_value", 5.0)) + body_sub = LatestSubscriber(endpoints.body_pub, BODY_TOPIC) + health_sub = LatestSubscriber(endpoints.health_pub, HEALTH_TOPIC) + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + reference_reset_sub = LatestSubscriber(endpoints.mode_pub, REFERENCE_RESET_TOPIC) + ref_pub = ZmqPublisher(endpoints.reference_pub) + idle_sleep_s = float(cfg_get(_mp_cfg(cfg), "retarget_idle_sleep_s", 0.001)) + last_body_seq = -1 + last_body_timestamp_s: float | None = None + body_dt_s_ema: float | None = None + latest_body_fps: float | None = None + resolved_reference_delay_s = ( + float(ref_cfg.reference_delay_s) if ref_cfg.reference_delay_s is not None else None + ) + runtime_support_validated = ref_cfg.reference_delay_s is not None or not reference_window_builder.requires_timeline + last_reference_reset_seq = 0 + last_reference_reset_timestamp_s = 0.0 + last_valid_qpos: Float64Array | None = None + + def _reset_reference_state(packet: ReferenceResetPacket) -> None: + nonlocal last_body_seq + nonlocal last_body_timestamp_s + nonlocal body_dt_s_ema + nonlocal resolved_reference_delay_s + nonlocal runtime_support_validated + nonlocal last_reference_reset_seq + nonlocal last_reference_reset_timestamp_s + nonlocal last_valid_qpos + + last_reference_reset_seq = int(packet.seq) + last_reference_reset_timestamp_s = float(packet.timestamp_s) + last_body_seq = -1 + last_body_timestamp_s = None + body_dt_s_ema = None + last_valid_qpos = None + resolved_reference_delay_s = ( + float(ref_cfg.reference_delay_s) if ref_cfg.reference_delay_s is not None else None + ) + runtime_support_validated = ( + ref_cfg.reference_delay_s is not None or not reference_window_builder.requires_timeline + ) + if timeline is not None: + timeline.clear() + if reference_manager is not None: + 
reference_manager.set_warmup_steps(ref_cfg.realtime_buffer_warmup_steps) + reference_manager.reset() + logger.info( + "Retarget reference state reset | reason=%s seq=%d", + packet.reason, + last_reference_reset_seq, + ) + + def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> None: + qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + qpos[3] = 1.0 + if last_valid_qpos is not None: + qpos = np.asarray(last_valid_qpos, dtype=np.float64).copy() + ref_pub.publish( + REFERENCE_TOPIC, + ReferencePacket( + qpos=qpos, + timestamp_s=time.monotonic(), + seq=int(packet.seq), + source_timestamp_s=float(packet.timestamp_s), + source_seq=int(packet.seq), + frame_valid=False, + reference_reset_seq=last_reference_reset_seq, + retarget_elapsed_s=elapsed_s, + ), + ) + + try: + while not stop_event.is_set(): + health_packet = health_sub.recv_latest() + if isinstance(health_packet, HealthPacket) and health_packet.worker == "pico_io": + metric_fps = health_packet.metrics.get("body_fps") + if isinstance(metric_fps, (int, float)) and float(metric_fps) > 0.0: + latest_body_fps = float(metric_fps) + + command = command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + stop_event.set() + break + + reset_packet = reference_reset_sub.recv_latest() + if ( + isinstance(reset_packet, ReferenceResetPacket) + and int(reset_packet.seq) > last_reference_reset_seq + ): + _reset_reference_state(reset_packet) + + packet = body_sub.recv_latest() + if packet is None: + time.sleep(idle_sleep_s) + continue + if not isinstance(packet, BodyFramePacket) or int(packet.seq) == last_body_seq: + continue + if float(packet.timestamp_s) < last_reference_reset_timestamp_s: + continue + + start_s = time.monotonic() + frame_valid = _human_frame_is_valid(packet.frame, max_pos_value=max_position_value) + if not frame_valid: + last_body_seq = int(packet.seq) + last_body_timestamp_s = None + body_dt_s_ema = None + _publish_invalid_reference(packet, 
elapsed_s=time.monotonic() - start_s) + logger.warning("retarget_worker dropped invalid body frame seq=%s", packet.seq) + continue + + try: + retargeted = retargeter.retarget(packet.frame) + qpos = np.asarray(retargeted, dtype=np.float64).reshape(-1) + reference_window: ReferenceWindow | None = None + if timeline is not None: + timeline.append(qpos, float(packet.timestamp_s)) + if reference_manager is not None: + reference_manager.note_realtime_frame() + if reference_manager is None or not reference_manager.warmup_done: + last_body_timestamp_s = float(packet.timestamp_s) + last_body_seq = int(packet.seq) + continue + if last_body_timestamp_s is not None: + dt_s = float(packet.timestamp_s) - float(last_body_timestamp_s) + if dt_s > 1e-6: + body_dt_s_ema = dt_s if body_dt_s_ema is None else 0.9 * body_dt_s_ema + 0.1 * dt_s + last_body_timestamp_s = float(packet.timestamp_s) + if resolved_reference_delay_s is None: + if latest_body_fps is not None and latest_body_fps > 1e-6: + resolved_reference_delay_s = 1.0 / latest_body_fps + elif body_dt_s_ema is not None and body_dt_s_ema > 1e-6: + resolved_reference_delay_s = float(body_dt_s_ema) + elif reference_window_builder.requires_timeline: + last_body_seq = int(packet.seq) + continue + else: + resolved_reference_delay_s = 0.0 + if not runtime_support_validated: + reference_window_builder.validate_runtime_support( + delay_s=float(resolved_reference_delay_s), + window_s=ref_cfg.retarget_buffer_window_s, + config_label="Multiprocess sim2real reference timeline", + ) + runtime_support_validated = True + reference_window, _diag = reference_manager.sample( + timeline, + time.monotonic() - float(resolved_reference_delay_s), + ) + qpos = reference_window.current_sample().qpos + last_valid_qpos = np.asarray(qpos, dtype=np.float64).copy() + ref_pub.publish( + REFERENCE_TOPIC, + ReferencePacket( + qpos=np.asarray(qpos, dtype=np.float64).copy(), + timestamp_s=time.monotonic(), + seq=int(packet.seq), + 
source_timestamp_s=float(packet.timestamp_s), + source_seq=int(packet.seq), + frame_valid=True, + reference_reset_seq=last_reference_reset_seq, + reference_window=reference_window, + retarget_elapsed_s=time.monotonic() - start_s, + ), + ) + last_body_seq = int(packet.seq) + except Exception: + logger.exception("retarget_worker failed to retarget body seq=%s", getattr(packet, "seq", None)) + finally: + body_sub.close() + health_sub.close() + command_sub.close() + reference_reset_sub.close() + ref_pub.close() + + _worker_loop("retarget_worker", _main) + + +class _RobotControlWorker: + def __init__( + self, + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, + ) -> None: + self.cfg = cfg + self.endpoints = endpoints + self.stop_event = stop_event + self.mode = RobotMode.IDLE + self.policy_hz = float(cfg_get(cfg, "policy_hz", 50.0)) + self.dt = 1.0 / self.policy_hz + + self.robot = UnitreeG1Robot(cfg_get(cfg, "real_robot")) + self.remote = UnitreeRemote() + self.policy, self.obs_builder = self._build_policy_and_obs() + + robot_cfg = cfg_get(cfg, "robot") + self.default_angles = np.asarray(cfg_get(robot_cfg, "default_angles"), dtype=np.float32) + self.num_actions = int(cfg_get(robot_cfg, "num_actions", NUM_JOINTS)) + self._safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) + self._standing_return_ramp_duration = float(cfg_get(cfg, "standing_return_ramp_duration", 0.5)) + self._standing_return_kp_ramp_floor_ratio = float( + cfg_get(cfg, "standing_return_kp_ramp_floor_ratio", 0.5) + ) + + self._ref_cfg = parse_reference_config(cfg, provider_fps=None) + self._reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=self.dt, + reference_steps=cfg_get(cfg, "reference_steps", [0]), + ) + mocap_sw = cfg_get(cfg, "mocap_switch", {}) or {} + self._ref_proc = Sim2RealReferenceProcessor( + obs_builder=self.obs_builder, + policy=self.policy, + policy_hz=self.policy_hz, + num_actions=self.num_actions, + 
reference_velocity_smoothing_alpha=self._ref_cfg.reference_velocity_smoothing_alpha, + reference_anchor_velocity_smoothing_alpha=self._ref_cfg.reference_anchor_velocity_smoothing_alpha, + max_pos_value=float(cfg_get(mocap_sw, "max_position_value", 5.0)), + ) + + self._standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + self._standing_qpos[3] = 1.0 + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._last_retarget_qpos: Float64Array | None = None + self._last_commanded_motion_qpos: Float64Array | None = None + self._mocap_reentry_armed = False + self._mocap_session = MocapSessionManager() + + self._latest_reference: ReferencePacket | None = None + mp_cfg = _mp_cfg(cfg) + self._max_reference_age_s = float(cfg_get(mp_cfg, "max_reference_age_s", 0.25)) + self._stale_reference_hold_s = float(cfg_get(mp_cfg, "stale_reference_hold_s", 0.08)) + mocap_sw = cfg_get(cfg, "mocap_switch", {}) or {} + self._check_frames = int(cfg_get(mocap_sw, "check_frames", 10)) + self._reference_reset_seq = 0 + self._last_reference_seq = -1 + self._consecutive_valid_references = 0 + + self._reference_sub = LatestSubscriber(endpoints.reference_pub, REFERENCE_TOPIC) + self._events_sub = LatestSubscriber(endpoints.control_events_pub, CONTROL_EVENTS_TOPIC) + self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + self._reference_reset_sub = LatestSubscriber(endpoints.mode_pub, REFERENCE_RESET_TOPIC) + self._mode_pub = ZmqPublisher(endpoints.mode_pub) + + viewers = _parse_sim2real_viewers(cfg) + self._retarget_viewer = _Sim2RealRetargetViewer( + xml_path=str(cfg_get(robot_cfg, "xml_path", "")) if "retarget" in viewers else None, + enabled="retarget" in viewers, + ) + self._mode_seq = 0 + + def run(self) -> None: + logger.info("Robot control worker started | mode=IDLE | policy_hz=%.0f", self.policy_hz) + timing = _LoopTimingReporter(target_period_s=self.dt) + 
try: + while not self.stop_event.is_set(): + t0 = time.monotonic() + self._drain_ipc() + + remote_bytes = self.robot.get_wireless_remote() + self.remote.update(remote_bytes) + if self.remote.LB.pressed and self.remote.RB.pressed: + if self.mode != RobotMode.DAMPING: + logger.warning("EMERGENCY STOP (L1+R1)") + self._enter_damping() + else: + self._handle_transitions() + if self.mode == RobotMode.STANDING: + self._standing_step() + elif self.mode == RobotMode.MOCAP: + self._mocap_step() + + self._publish_mode_state() + work_elapsed_s = time.monotonic() - t0 + cycle_elapsed_s = self._sleep_until(t0, self.dt) + timing.record( + loop_start_s=t0, + work_elapsed_s=work_elapsed_s, + cycle_elapsed_s=cycle_elapsed_s, + pico_age_s=self._reference_age_s(), + ) + finally: + self.shutdown() + + def shutdown(self) -> None: + if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): + try: + self.robot.set_damping() + time.sleep(0.5) + except Exception: + logger.exception("Failed to send damping during robot_control shutdown") + try: + self.robot.exit_debug_mode() + except Exception: + logger.exception("Failed to exit debug mode during robot_control shutdown") + self._retarget_viewer.shutdown() + self._reference_sub.close() + self._events_sub.close() + self._command_sub.close() + self._reference_reset_sub.close() + self._mode_pub.close() + self.robot.close() + + def _build_policy_and_obs(self) -> tuple[Any, Any]: + robot_cfg = require_section(self.cfg, "robot") + controller_cfg = require_section(self.cfg, "controller") + sim_cfg = build_simulation_cfg(self.cfg) + policy, obs_builder = _build_policy_components( + robot_cfg=robot_cfg, + controller_cfg=controller_cfg, + sim_cfg=sim_cfg, + project_root=PROJECT_ROOT, + controller_cls=RLPolicyController, + ) + if not bool(getattr(policy, "_multi_input", False)): + raise ValueError("Sim2real requires an ONNX policy with dual inputs ('obs' and 'obs_history').") + return policy, obs_builder + + def _drain_ipc(self) -> None: + command = 
self._command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + self.stop_event.set() + return + reference_reset = self._reference_reset_sub.recv_latest() + if isinstance(reference_reset, ReferenceResetPacket): + self._apply_reference_reset(reference_reset.seq) + reference = self._reference_sub.recv_latest() + if isinstance(reference, ReferencePacket): + self._note_reference_packet(reference) + events = self._events_sub.recv_latest() + if isinstance(events, ControlEventsPacket): + self._handle_mocap_control_events(events.events) + + def _handle_transitions(self) -> None: + if self.mode == RobotMode.IDLE: + if self.remote.start.on_pressed: + logger.info("Start pressed (from IDLE)") + self._enter_standing() + elif self.mode == RobotMode.STANDING: + reentry_request = self._mocap_reentry_armed and self.remote.Y.pressed + if self.remote.Y.on_pressed or reentry_request: + if self._can_switch_to_mocap(): + logger.info("Y pressed -> entering MOCAP") + self._transition_to_mocap() + else: + logger.warning("Cannot switch to MOCAP -- no fresh retarget reference") + elif self.mode == RobotMode.MOCAP: + if self.remote.A.on_pressed: + if self._mocap_session.state == MocapSessionState.PAUSED: + logger.info("A pressed -> resuming playback") + self._resume_paused_mocap() + else: + logger.info("A pressed -> pausing playback") + self._pause_active_mocap() + return + if self.remote.X.on_pressed: + logger.info("X pressed -> returning to STANDING") + self._enter_standing() + elif self.mode == RobotMode.DAMPING: + if self.remote.start.on_pressed: + logger.info("Start pressed (from DAMPING)") + self._enter_standing() + + def _standing_step(self) -> None: + robot_state = self.robot.get_state() + qpos = self._standing_qpos.copy() + motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) + motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) + reference_window = None + if 
obs_builder_requires_reference_window(self.obs_builder): + reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) + obs = self._ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, + motion_joint_vel=motion_joint_vel, + last_action=self._last_action, + anchor_lin_vel_w=np.zeros(3, dtype=np.float32), + anchor_ang_vel_w=np.zeros(3, dtype=np.float32), + reference_window=reference_window, + ) + obs = self._ref_proc.validate_observation(obs) + action = self.policy.compute_action(obs) + target_dof_pos = self._safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + self._safety.send_positions(target_dof_pos) + self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) + self._last_retarget_qpos = qpos.copy() + self._last_commanded_motion_qpos = qpos.copy() + self._write_retarget_viewer(qpos) + + def _mocap_step(self) -> None: + if self._mocap_session.state == MocapSessionState.PAUSED: + self._paused_mocap_step() + return + + reference = self._latest_reference + age_s = self._reference_age_s() + if reference is None or age_s is None: + self._hold_or_damp_stale_reference("no retarget reference") + return + if int(reference.reference_reset_seq) != self._reference_reset_seq: + self._hold_or_damp_stale_reference("stale reset-generation retarget reference") + return + if not reference.frame_valid: + logger.warning("Retarget reference invalid -- holding last command") + self._hold_or_damp_stale_reference("invalid retarget reference") + return + if age_s > self._max_reference_age_s: + logger.warning("Retarget reference stale %.3fs -- entering damping", age_s) + self._enter_damping() + return + if age_s > self._stale_reference_hold_s and self._last_commanded_motion_qpos is not None: + self._run_static_mocap_step(self._last_commanded_motion_qpos) + return + + robot_state = self.robot.get_state() + self._execute_mocap_pipeline(reference.qpos, robot_state, reference.reference_window) 
+ + def _execute_mocap_pipeline( + self, + reference_qpos: Float64Array, + robot_state: object, + reference_window: ReferenceWindow | None, + ) -> None: + reference_qpos = self._ref_proc.align_reference_yaw(reference_qpos, robot_state=robot_state) + qpos = reference_qpos.copy() + if qpos.shape[0] < 7 + self.num_actions: + raise ValueError(f"Retargeted qpos too short: {qpos.shape[0]} (need >= {7 + self.num_actions})") + motion_joint_pos = np.asarray(qpos[7:7 + self.num_actions], dtype=np.float32) + if self._last_retarget_qpos is None: + raw_motion_joint_vel = np.zeros((self.num_actions,), dtype=np.float32) + else: + prev_joint_pos = np.asarray(self._last_retarget_qpos[7:7 + self.num_actions], dtype=np.float32) + raw_motion_joint_vel = (motion_joint_pos - prev_joint_pos) * np.float32(self.policy_hz) + motion_joint_vel = self._ref_proc.apply_joint_vel_smoothing(raw_motion_joint_vel) + + anchor_lin_vel_w = np.zeros(3, dtype=np.float32) + anchor_ang_vel_w = np.zeros(3, dtype=np.float32) + if not obs_builder_requires_reference_window(self.obs_builder): + raw_lin, raw_ang = self._ref_proc.compute_anchor_velocities(reference_qpos) + anchor_lin_vel_w, anchor_ang_vel_w = self._ref_proc.apply_anchor_vel_smoothing(raw_lin, raw_ang) + + motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) + obs = self._ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, + motion_joint_vel=motion_joint_vel, + last_action=self._last_action, + anchor_lin_vel_w=anchor_lin_vel_w, + anchor_ang_vel_w=anchor_ang_vel_w, + reference_window=reference_window, + ) + obs = self._ref_proc.validate_observation(obs) + action = self.policy.compute_action(obs) + target_dof_pos = self._safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + self._safety.send_positions(target_dof_pos) + self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) + self._last_retarget_qpos = qpos.copy() + self._ref_proc.last_reference_qpos = 
reference_qpos.copy() + self._last_commanded_motion_qpos = qpos.copy() + self._write_retarget_viewer(qpos) + + def _enter_standing(self) -> None: + prev_mode = self.mode + already_in_debug = self.mode in (RobotMode.STANDING, RobotMode.MOCAP) + if not already_in_debug: + logger.info("Entering debug mode...") + ok = self.robot.enter_debug_mode() + if not ok: + logger.error("Failed to enter debug mode -- staying in %s", self.mode.value) + return + time.sleep(0.5) + + state = self.robot.get_state() + if prev_mode != RobotMode.MOCAP: + logger.info("Locking joints to current position...") + self.robot.lock_all_joints() + time.sleep(0.3) + + self._publish_reference_reset("enter_standing") + init_qpos = self._build_robot_state_qpos(state) + self._last_retarget_qpos = init_qpos + self._ref_proc.last_reference_qpos = None + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self._set_default_standing_reference(state) + self._reset_policy_state() + if prev_mode == RobotMode.MOCAP: + self._safety.start_kp_ramp( + duration_s=self._standing_return_ramp_duration, + floor_ratio=self._standing_return_kp_ramp_floor_ratio, + ) + else: + self._safety.start_kp_ramp() + self._mocap_reentry_armed = prev_mode == RobotMode.MOCAP + self.mode = RobotMode.STANDING + logger.info("Mode -> STANDING (multiprocess robot control)") + + def _can_switch_to_mocap(self) -> bool: + age_s = self._reference_age_s() + if self._latest_reference is None or age_s is None: + return False + if not self._latest_reference.frame_valid: + return False + if self._latest_reference.reference_reset_seq != self._reference_reset_seq: + return False + if age_s > self._max_reference_age_s: + return False + if self._consecutive_valid_references < self._check_frames: + logger.warning( + "Mocap check: only %d/%d valid references", + self._consecutive_valid_references, + self._check_frames, + ) + return False + return True + + def _transition_to_mocap(self) -> None: + state = self.robot.get_state() + 
resume_qpos = self._build_resume_alignment_qpos(self._standing_qpos, state) + self._mocap_reentry_armed = False + self._publish_reference_reset("standing_to_mocap") + self._reset_policy_state() + self._last_retarget_qpos = None + self._last_commanded_motion_qpos = resume_qpos.copy() + self._ref_proc.reset_alignment(target_qpos=resume_qpos) + self.mode = RobotMode.MOCAP + logger.info("Mode -> MOCAP (tracking multiprocess retarget reference)") + + def _enter_damping(self) -> None: + if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): + logger.info("DAMPING: sending LowCmd damping...") + self.robot.set_damping() + time.sleep(0.5) + logger.info("DAMPING: exiting debug mode...") + self.robot.exit_debug_mode() + self._publish_reference_reset("enter_damping") + self.mode = RobotMode.DAMPING + self._ref_proc.last_reference_qpos = None + self._mocap_reentry_armed = False + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + logger.info("Mode -> DAMPING (press Start to re-enter STANDING)") + + def _reset_policy_state(self) -> None: + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._ref_proc.reset_smoothers() + self._ref_proc.reset_alignment() + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self.policy.reset() + self.obs_builder.reset() + + def _reset_policy_reference_state(self) -> None: + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._ref_proc.reset_smoothers() + self._ref_proc.reset_alignment() + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self.policy.reset() + self.obs_builder.reset() + + def _build_robot_state_qpos(self, state: object) -> Float64Array: + qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + base_pos = getattr(state, "base_pos", None) + if base_pos is not None: + qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] + qpos[3:7] = np.asarray(getattr(state, "quat"), dtype=np.float64).reshape(-1)[:4] + 
qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(getattr(state, "qpos"), dtype=np.float64).reshape(-1)[ + : self.num_actions + ] + return qpos + + def _set_default_standing_reference(self, state: object) -> None: + self._standing_qpos[:] = 0.0 + base_pos = getattr(state, "base_pos", None) + if base_pos is not None: + self._standing_qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] + self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) + align_motion_qpos_yaw(np.asarray(getattr(state, "quat"), dtype=np.float32), self._standing_qpos) + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + + def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: object) -> Float64Array: + qpos = self._build_robot_state_qpos(state) + if hold_qpos is not None: + qpos[0:2] = np.asarray(hold_qpos, dtype=np.float64).reshape(-1)[0:2] + base_pos = getattr(state, "base_pos", None) + if base_pos is not None: + qpos[0:2] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[0:2] + return qpos + + def _handle_mocap_control_events(self, control_events: tuple[ControlEvent, ...]) -> None: + for event in control_events: + if event.event_type != ControlEventType.TOGGLE_PAUSE: + continue + if self.mode != RobotMode.MOCAP: + continue + if self._mocap_session.state == MocapSessionState.PAUSED: + self._resume_paused_mocap() + else: + self._pause_active_mocap() + + def _pause_active_mocap(self) -> None: + hold_qpos = self._resolve_mocap_hold_qpos() + self._last_retarget_qpos = hold_qpos.copy() + self._ref_proc.last_reference_qpos = hold_qpos.copy() + self._last_commanded_motion_qpos = hold_qpos.copy() + self._publish_reference_reset("pause_active_mocap") + self._reset_policy_reference_state() + self._mocap_session.pause(hold_qpos) + logger.info("Mocap session -> PAUSED (multiprocess episode-reset)") + + def _resume_paused_mocap(self) -> None: + hold_qpos = self._mocap_session.hold_qpos + if hold_qpos is None: + raise 
RuntimeError("Cannot resume mocap without a paused hold qpos") + state = self.robot.get_state() + resume_qpos = self._build_resume_alignment_qpos(hold_qpos, state) + self._last_commanded_motion_qpos = resume_qpos.copy() + self._publish_reference_reset("resume_paused_mocap") + self._reset_policy_reference_state() + self._last_retarget_qpos = None + self._last_commanded_motion_qpos = resume_qpos.copy() + self._ref_proc.reset_alignment(target_qpos=resume_qpos) + logger.info("Mocap session -> ACTIVE (multiprocess episode-reset + reference realignment)") + + def _resolve_mocap_hold_qpos(self) -> Float64Array: + if self._last_commanded_motion_qpos is not None: + return self._last_commanded_motion_qpos.copy() + if self._last_retarget_qpos is not None: + return np.asarray(self._last_retarget_qpos, dtype=np.float64).copy() + state = self.robot.get_state() + hold_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + hold_qpos[3:7] = np.asarray(state.quat, dtype=np.float64) + hold_qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(state.qpos, dtype=np.float64) + return hold_qpos + + def _paused_mocap_step(self) -> None: + hold_qpos = self._mocap_session.hold_qpos + if hold_qpos is None: + raise RuntimeError("Paused mocap session is missing a hold_qpos") + self._run_static_mocap_step(hold_qpos) + + def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: + robot_state = self.robot.get_state() + qpos = np.asarray(hold_qpos, dtype=np.float64).copy() + motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) + motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) + reference_window = None + if obs_builder_requires_reference_window(self.obs_builder): + reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) + obs = self._ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, + motion_joint_vel=motion_joint_vel, + last_action=self._last_action, + anchor_lin_vel_w=np.zeros(3, 
dtype=np.float32), + anchor_ang_vel_w=np.zeros(3, dtype=np.float32), + reference_window=reference_window, + ) + obs = self._ref_proc.validate_observation(obs) + action = self.policy.compute_action(obs) + target_dof_pos = self._safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + self._safety.send_positions(target_dof_pos) + self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) + self._last_retarget_qpos = qpos.copy() + self._ref_proc.last_reference_qpos = qpos.copy() + self._last_commanded_motion_qpos = qpos.copy() + self._write_retarget_viewer(qpos) + + def _hold_or_damp_stale_reference(self, reason: str) -> None: + if self._last_commanded_motion_qpos is not None: + self._run_static_mocap_step(self._last_commanded_motion_qpos) + return + logger.warning("No mocap hold pose available after %s -- entering damping", reason) + self._enter_damping() + + def _publish_mode_state(self) -> None: + self._mode_seq += 1 + active = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.ACTIVE + paused = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.PAUSED + self._mode_pub.publish( + MODE_TOPIC, + ModeStatePacket( + mode=self.mode.value, + mocap_active=active, + mocap_paused=paused, + timestamp_s=time.monotonic(), + seq=self._mode_seq, + ), + ) + + def _write_retarget_viewer(self, qpos: Float64Array) -> None: + try: + self._retarget_viewer.write(qpos) + except Exception: + logger.exception("Sim2real retarget viewer update failed; control continues") + + def _reference_age_s(self) -> float | None: + if self._latest_reference is None: + return None + return max(0.0, time.monotonic() - float(self._latest_reference.timestamp_s)) + + def _apply_reference_reset(self, reference_reset_seq: int) -> None: + reset_seq = int(reference_reset_seq) + if reset_seq <= self._reference_reset_seq: + return + self._reference_reset_seq = reset_seq + self._latest_reference = None + self._last_reference_seq = -1 
+ self._consecutive_valid_references = 0 + + def _note_reference_packet(self, reference: ReferencePacket) -> None: + if int(reference.reference_reset_seq) < self._reference_reset_seq: + return + if int(reference.reference_reset_seq) > self._reference_reset_seq: + self._apply_reference_reset(int(reference.reference_reset_seq)) + if int(reference.seq) <= self._last_reference_seq: + return + self._last_reference_seq = int(reference.seq) + self._latest_reference = reference + if not reference.frame_valid: + self._consecutive_valid_references = 0 + return + self._consecutive_valid_references += 1 + + def _publish_reference_reset(self, reason: str) -> None: + self._reference_reset_seq += 1 + self._latest_reference = None + self._last_reference_seq = -1 + self._consecutive_valid_references = 0 + self._mode_pub.publish( + REFERENCE_RESET_TOPIC, + ReferenceResetPacket( + reason=reason, + timestamp_s=time.monotonic(), + seq=self._reference_reset_seq, + ), + ) + + @staticmethod + def _sleep_until(t0: float, dt: float) -> float: + elapsed = time.monotonic() - t0 + remaining = dt - elapsed + if remaining > 0: + time.sleep(remaining) + return time.monotonic() - t0 + + +def _run_robot_control_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + worker = _RobotControlWorker(cfg, endpoints, stop_event) + worker.run() + + _worker_loop("robot_control", _main) + + +class _HandSnapshotProxy: + def __init__(self) -> None: + self.hand_snapshot: Any | None = None + self.controller_snapshot: Any | None = None + + def get_hand_snapshot(self) -> Any | None: + return self.hand_snapshot + + def get_controller_snapshot(self) -> Any | None: + return self.controller_snapshot + + +def _run_hand_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + proxy = _HandSnapshotProxy() + runtime = build_linkerhand_runtime(cfg, proxy) + hand_sub = 
LatestSubscriber(endpoints.hand_pub, HAND_TOPIC) + controller_sub = LatestSubscriber(endpoints.controller_pub, CONTROLLER_TOPIC) + mode_sub = LatestSubscriber(endpoints.mode_pub, MODE_TOPIC) + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + active = False + hz = float(cfg_get(_mp_cfg(cfg), "hand_worker_hz", 120.0)) + sleep_s = 1.0 / max(hz, 1.0) + runtime.start() + try: + while not stop_event.is_set(): + command = command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + stop_event.set() + break + hand_packet = hand_sub.recv_latest() + if isinstance(hand_packet, SnapshotPacket): + proxy.hand_snapshot = hand_packet.snapshot + controller_packet = controller_sub.recv_latest() + if isinstance(controller_packet, SnapshotPacket): + proxy.controller_snapshot = controller_packet.snapshot + mode_packet = mode_sub.recv_latest() + if isinstance(mode_packet, ModeStatePacket): + active = bool(mode_packet.mocap_active) + try: + runtime.tick(active=active) + except Exception: + logger.exception("Dexterous hand worker tick failed; hand control continues") + time.sleep(sleep_s) + finally: + try: + runtime.close() + finally: + hand_sub.close() + controller_sub.close() + mode_sub.close() + command_sub.close() + + _worker_loop("hand_worker", _main) + + +def _run_video_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + input_cfg = cfg_get(cfg, "input", {}) or {} + video_cfg = parse_pico_video_config(input_cfg) + if not video_cfg.enabled: + return + if video_cfg.source not in ("realsense",): + logger.warning("Multiprocess video worker supports source=realsense; got %s", video_cfg.source) + return + + writer = SharedFrameRingWriter( + shape=(video_cfg.height, video_cfg.width, 3), + dtype=np.uint8, + slots=int(cfg_get(_mp_cfg(cfg), "video_slots", 3)), + ) + video_pub = ZmqPublisher(endpoints.video_pub) + command_sub = 
LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + try: + import pyrealsense2 as rs + + pipeline = rs.pipeline() + rs_config = rs.config() + if video_cfg.device is not None: + rs_config.enable_device(video_cfg.device) + rs_config.enable_stream( + rs.stream.color, + video_cfg.width, + video_cfg.height, + rs.format.rgb8, + video_cfg.fps, + ) + pipeline.start(rs_config) + try: + while not stop_event.is_set(): + command = command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + stop_event.set() + break + frames = pipeline.wait_for_frames() + color_frame = frames.get_color_frame() + if not color_frame: + continue + rgb = np.ascontiguousarray(np.asanyarray(color_frame.get_data()), dtype=np.uint8) + descriptor = writer.write(rgb, timestamp_s=time.monotonic()) + video_pub.publish(VIDEO_TOPIC, descriptor) + finally: + pipeline.stop() + finally: + command_sub.close() + video_pub.close() + writer.close(unlink=True) + + _worker_loop("video_worker", _main) diff --git a/teleopit/sim2real/mp/shm.py b/teleopit/sim2real/mp/shm.py new file mode 100644 index 00000000..36feaf9a --- /dev/null +++ b/teleopit/sim2real/mp/shm.py @@ -0,0 +1,106 @@ +"""Shared-memory ring buffer for large sim2real video frames.""" + +from __future__ import annotations + +from multiprocessing import shared_memory +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from teleopit.sim2real.mp.messages import SharedFrameDescriptor + + +class SharedFrameRingWriter: + """Owns a fixed-size ring of frame slots in shared memory.""" + + def __init__( + self, + *, + shape: tuple[int, ...], + dtype: str | np.dtype[Any] = np.uint8, + slots: int = 3, + name: str | None = None, + ) -> None: + if slots <= 0: + raise ValueError(f"slots must be positive, got {slots}") + dtype_np = np.dtype(dtype) + shape_tuple = tuple(int(dim) for dim in shape) + if not shape_tuple or any(dim <= 0 for dim in shape_tuple): + raise ValueError(f"shape must contain 
positive dimensions, got {shape}") + + self.shape = shape_tuple + self.dtype = dtype_np + self.slots = int(slots) + self._slot_size = int(np.prod(shape_tuple)) * dtype_np.itemsize + self._shm = shared_memory.SharedMemory( + name=name, + create=True, + size=self._slot_size * self.slots, + ) + self._seq = 0 + + @property + def name(self) -> str: + return self._shm.name + + def write(self, frame: NDArray[np.generic], *, timestamp_s: float) -> SharedFrameDescriptor: + frame_arr = np.asarray(frame, dtype=self.dtype) + if tuple(frame_arr.shape) != self.shape: + raise ValueError(f"frame shape {frame_arr.shape} does not match shared ring shape {self.shape}") + slot = self._seq % self.slots + view = self._slot_view(slot) + np.copyto(view, np.ascontiguousarray(frame_arr)) + descriptor = SharedFrameDescriptor( + shm_name=self._shm.name, + slot=slot, + seq=self._seq, + timestamp_s=float(timestamp_s), + shape=self.shape, + dtype=str(self.dtype), + slots=self.slots, + ) + self._seq += 1 + return descriptor + + def close(self, *, unlink: bool = True) -> None: + self._shm.close() + if unlink: + try: + self._shm.unlink() + except FileNotFoundError: + pass + + def _slot_view(self, slot: int) -> NDArray[np.generic]: + if slot < 0 or slot >= self.slots: + raise ValueError(f"slot must be in [0, {self.slots}), got {slot}") + start = slot * self._slot_size + return np.ndarray(self.shape, dtype=self.dtype, buffer=self._shm.buf, offset=start) + + +class SharedFrameRingReader: + """Attaches to shared frame rings lazily and reads descriptor-selected slots.""" + + def __init__(self) -> None: + self._rings: dict[str, shared_memory.SharedMemory] = {} + + def read(self, descriptor: SharedFrameDescriptor, *, copy: bool = False) -> NDArray[np.generic]: + shm = self._rings.get(descriptor.shm_name) + if shm is None: + shm = shared_memory.SharedMemory(name=descriptor.shm_name) + self._rings[descriptor.shm_name] = shm + + dtype = np.dtype(descriptor.dtype) + shape = tuple(int(dim) for dim in 
descriptor.shape) + slot_size = int(np.prod(shape)) * dtype.itemsize + if descriptor.slot < 0 or descriptor.slot >= descriptor.slots: + raise ValueError(f"descriptor slot {descriptor.slot} out of range for {descriptor.slots} slots") + view = np.ndarray(shape, dtype=dtype, buffer=shm.buf, offset=descriptor.slot * slot_size) + if copy: + return view.copy() + return view + + def close(self) -> None: + for shm in self._rings.values(): + shm.close() + self._rings.clear() diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py new file mode 100644 index 00000000..d7e6b0f1 --- /dev/null +++ b/tests/test_sim2real_multiprocess.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import importlib.util +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest + +from teleopit.runtime.mocap_session import MocapSessionState +from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, REFERENCE_RESET_TOPIC, LatestSubscriber, ZmqPublisher +from teleopit.sim2real.mp import resolve_sim2real_runtime_mode +from teleopit.sim2real.mp.messages import ReferencePacket, ReferenceResetPacket, SharedFrameDescriptor +from teleopit.sim2real.mp.runtime import MultiprocessSim2RealController, _RobotControlWorker, _human_frame_is_valid +from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter + + +def test_resolve_runtime_auto_uses_multiprocess_for_pico4() -> None: + cfg = {"sim2real_runtime": "auto", "input": {"provider": "pico4"}} + assert resolve_sim2real_runtime_mode(cfg) == "multiprocess" + + +def test_resolve_runtime_auto_uses_single_process_for_bvh() -> None: + cfg = {"sim2real_runtime": "auto", "input": {"provider": "bvh"}} + assert resolve_sim2real_runtime_mode(cfg) == "single_process" + + +def test_multiprocess_requires_pico4_provider() -> None: + cfg = {"sim2real_runtime": "multiprocess", "input": {"provider": "bvh"}} + with pytest.raises(ValueError, match="requires input.provider=pico4"): + 
resolve_sim2real_runtime_mode(cfg) + + +def test_shared_frame_ring_roundtrip() -> None: + writer = SharedFrameRingWriter(shape=(2, 3, 1), dtype=np.uint8, slots=2) + reader = SharedFrameRingReader() + try: + frame0 = np.arange(6, dtype=np.uint8).reshape(2, 3, 1) + desc0 = writer.write(frame0, timestamp_s=1.0) + assert isinstance(desc0, SharedFrameDescriptor) + np.testing.assert_array_equal(reader.read(desc0, copy=True), frame0) + + frame1 = np.full((2, 3, 1), 9, dtype=np.uint8) + desc1 = writer.write(frame1, timestamp_s=2.0) + np.testing.assert_array_equal(reader.read(desc1, copy=True), frame1) + assert desc1.slot != desc0.slot + finally: + reader.close() + writer.close(unlink=True) + + +def test_multiprocess_rejects_unsupported_video_source() -> None: + cfg = { + "sim2real_runtime": "multiprocess", + "input": {"provider": "pico4", "video": {"enabled": True, "source": "mujoco"}}, + } + with pytest.raises(ValueError, match="only supports input.video.source=realsense or test-pattern"): + MultiprocessSim2RealController(cfg) + + +def test_zmq_endpoint_allows_one_publisher_and_subscribers() -> None: + endpoint = "inproc://sim2real-health-test" + import zmq + + context = zmq.Context() + publisher = ZmqPublisher(endpoint, context=context) + subscriber = LatestSubscriber(endpoint, HEALTH_TOPIC, context=context) + try: + with pytest.raises(zmq.ZMQError): + ZmqPublisher(endpoint, context=context) + finally: + subscriber.close() + publisher.close() + context.term() + + +def test_run_sim2real_single_process_shutdowns_on_exception(monkeypatch) -> None: + script_path = Path.cwd() / "scripts" / "run" / "run_sim2real.py" + spec = importlib.util.spec_from_file_location("test_run_sim2real", script_path) + assert spec is not None and spec.loader is not None + run_sim2real = importlib.util.module_from_spec(spec) + spec.loader.exec_module(run_sim2real) + + calls: list[str] = [] + + class FailingController: + def __init__(self, _cfg: object) -> None: + calls.append("init") + + def 
run(self) -> None: + calls.append("run") + raise RuntimeError("boom") + + def shutdown(self) -> None: + calls.append("shutdown") + + cfg = SimpleNamespace( + input={"provider": "bvh"}, + controller=SimpleNamespace(policy_path="policy.onnx"), + ) + monkeypatch.setattr(run_sim2real, "validate_policy_path", lambda *_args, **_kwargs: None) + monkeypatch.setattr(run_sim2real, "resolve_sim2real_runtime_mode", lambda _cfg: "single_process") + monkeypatch.setattr(run_sim2real, "Sim2RealController", FailingController) + + with pytest.raises(RuntimeError, match="boom"): + run_sim2real._run_sim2real(cfg) + + assert calls == ["init", "run", "shutdown"] + + +def test_multiprocess_run_cleans_up_after_start_failure(monkeypatch) -> None: + class FakeProcess: + def __init__(self, *, name: str, fail_start: bool = False) -> None: + self.name = name + self.exitcode = None + self.fail_start = fail_start + self.started = False + self.join_calls: list[float | None] = [] + self.terminated = False + + def start(self) -> None: + if self.fail_start: + raise RuntimeError("start failed") + self.started = True + + def is_alive(self) -> bool: + return self.started and not self.terminated + + def join(self, timeout: float | None = None) -> None: + self.join_calls.append(timeout) + + def terminate(self) -> None: + self.terminated = True + + started_process = FakeProcess(name="pico_io") + + def fake_start_processes(self: MultiprocessSim2RealController) -> None: + started_process.started = True + self._processes.append(started_process) + raise RuntimeError("start failed") + + cfg = { + "sim2real_runtime": "multiprocess", + "input": {"provider": "pico4"}, + "multiprocess": {"shutdown_timeout_s": 0.01}, + } + controller = MultiprocessSim2RealController(cfg) + monkeypatch.setattr(controller, "_start_processes", fake_start_processes.__get__(controller)) + + with pytest.raises(RuntimeError, match="start failed"): + controller.run() + + assert started_process.join_calls == [0.01, 1.0] + assert 
started_process.terminated is True + assert controller._processes == [] + + +def test_human_frame_validation_rejects_bad_inputs() -> None: + valid_frame = { + "Pelvis": (np.zeros(3, dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), + } + assert _human_frame_is_valid(valid_frame, max_pos_value=5.0) + + bad_frame = { + "Pelvis": (np.array([6.0, 0.0, 0.0], dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), + } + assert not _human_frame_is_valid(bad_frame, max_pos_value=5.0) + + +def test_robot_worker_requires_consecutive_valid_references_and_reset_generation(monkeypatch) -> None: + worker = object.__new__(_RobotControlWorker) + worker._reference_reset_seq = 0 + worker._latest_reference = None + worker._last_reference_seq = -1 + worker._consecutive_valid_references = 0 + worker._check_frames = 2 + worker._max_reference_age_s = 0.25 + worker._reference_age_s = lambda: 0.0 + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._last_commanded_motion_qpos = np.zeros(36, dtype=np.float64) + worker._mode_pub_events: list[tuple[str, object]] = [] + worker._mode_pub = SimpleNamespace( + publish=lambda topic, payload: worker._mode_pub_events.append((topic, payload)) + ) + + valid0 = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.0, + seq=1, + source_timestamp_s=1.0, + source_seq=1, + frame_valid=True, + reference_reset_seq=0, + ) + worker._note_reference_packet(valid0) + assert worker._can_switch_to_mocap() is False + + valid1 = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.1, + seq=2, + source_timestamp_s=1.1, + source_seq=2, + frame_valid=True, + reference_reset_seq=0, + ) + worker._note_reference_packet(valid1) + assert worker._can_switch_to_mocap() is True + + invalid = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.2, + seq=3, + source_timestamp_s=1.2, + source_seq=3, + frame_valid=False, + reference_reset_seq=0, + ) + 
worker._note_reference_packet(invalid) + assert worker._can_switch_to_mocap() is False + + worker._publish_reference_reset("enter_standing") + assert worker._reference_reset_seq == 1 + assert worker._latest_reference is None + assert worker._consecutive_valid_references == 0 + assert worker._mode_pub_events + topic, payload = worker._mode_pub_events[-1] + assert topic == REFERENCE_RESET_TOPIC + assert isinstance(payload, ReferenceResetPacket) + assert payload.seq == 1 + assert payload.reason == "enter_standing" + + old_packet = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.3, + seq=4, + source_timestamp_s=1.3, + source_seq=4, + frame_valid=True, + reference_reset_seq=0, + ) + worker._note_reference_packet(old_packet) + assert worker._latest_reference is None + assert worker._consecutive_valid_references == 0 + + fresh_packet = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.4, + seq=5, + source_timestamp_s=1.4, + source_seq=5, + frame_valid=True, + reference_reset_seq=1, + ) + worker._note_reference_packet(fresh_packet) + assert worker._latest_reference == fresh_packet + assert worker._consecutive_valid_references == 1 From 461aacce5279cf3131e6a9511bf0f22c511052e0 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 29 May 2026 17:09:09 +0800 Subject: [PATCH 058/122] Remove multiprocess reference reset path --- teleopit/sim2real/mp/ipc.py | 1 - teleopit/sim2real/mp/messages.py | 8 --- teleopit/sim2real/mp/runtime.py | 94 ----------------------------- tests/test_sim2real_multiprocess.py | 39 ++---------- 4 files changed, 5 insertions(+), 137 deletions(-) diff --git a/teleopit/sim2real/mp/ipc.py b/teleopit/sim2real/mp/ipc.py index 02864c73..d31fc328 100644 --- a/teleopit/sim2real/mp/ipc.py +++ b/teleopit/sim2real/mp/ipc.py @@ -16,7 +16,6 @@ CONTROL_EVENTS_TOPIC = "control_events" REFERENCE_TOPIC = "reference" MODE_TOPIC = "mode" -REFERENCE_RESET_TOPIC = "reference_reset" VIDEO_TOPIC = "video" HEALTH_TOPIC = "health" 
COMMAND_TOPIC = "command" diff --git a/teleopit/sim2real/mp/messages.py b/teleopit/sim2real/mp/messages.py index aabe9a88..137666e4 100644 --- a/teleopit/sim2real/mp/messages.py +++ b/teleopit/sim2real/mp/messages.py @@ -30,7 +30,6 @@ class ReferencePacket: source_timestamp_s: float source_seq: int frame_valid: bool = True - reference_reset_seq: int = 0 reference_window: ReferenceWindow | None = None retarget_elapsed_s: float = 0.0 @@ -58,13 +57,6 @@ class ModeStatePacket: seq: int -@dataclass(frozen=True) -class ReferenceResetPacket: - reason: str - timestamp_s: float - seq: int - - @dataclass(frozen=True) class HealthPacket: worker: str diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index c492d307..fe955351 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -41,7 +41,6 @@ HAND_TOPIC, HEALTH_TOPIC, MODE_TOPIC, - REFERENCE_RESET_TOPIC, REFERENCE_TOPIC, VIDEO_TOPIC, LatestSubscriber, @@ -56,7 +55,6 @@ HealthPacket, ModeStatePacket, ReferencePacket, - ReferenceResetPacket, SharedFrameDescriptor, SnapshotPacket, ) @@ -401,7 +399,6 @@ def _main() -> None: body_sub = LatestSubscriber(endpoints.body_pub, BODY_TOPIC) health_sub = LatestSubscriber(endpoints.health_pub, HEALTH_TOPIC) command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) - reference_reset_sub = LatestSubscriber(endpoints.mode_pub, REFERENCE_RESET_TOPIC) ref_pub = ZmqPublisher(endpoints.reference_pub) idle_sleep_s = float(cfg_get(_mp_cfg(cfg), "retarget_idle_sleep_s", 0.001)) last_body_seq = -1 @@ -412,43 +409,8 @@ def _main() -> None: float(ref_cfg.reference_delay_s) if ref_cfg.reference_delay_s is not None else None ) runtime_support_validated = ref_cfg.reference_delay_s is not None or not reference_window_builder.requires_timeline - last_reference_reset_seq = 0 - last_reference_reset_timestamp_s = 0.0 last_valid_qpos: Float64Array | None = None - def _reset_reference_state(packet: ReferenceResetPacket) -> None: - nonlocal 
last_body_seq - nonlocal last_body_timestamp_s - nonlocal body_dt_s_ema - nonlocal resolved_reference_delay_s - nonlocal runtime_support_validated - nonlocal last_reference_reset_seq - nonlocal last_reference_reset_timestamp_s - nonlocal last_valid_qpos - - last_reference_reset_seq = int(packet.seq) - last_reference_reset_timestamp_s = float(packet.timestamp_s) - last_body_seq = -1 - last_body_timestamp_s = None - body_dt_s_ema = None - last_valid_qpos = None - resolved_reference_delay_s = ( - float(ref_cfg.reference_delay_s) if ref_cfg.reference_delay_s is not None else None - ) - runtime_support_validated = ( - ref_cfg.reference_delay_s is not None or not reference_window_builder.requires_timeline - ) - if timeline is not None: - timeline.clear() - if reference_manager is not None: - reference_manager.set_warmup_steps(ref_cfg.realtime_buffer_warmup_steps) - reference_manager.reset() - logger.info( - "Retarget reference state reset | reason=%s seq=%d", - packet.reason, - last_reference_reset_seq, - ) - def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> None: qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) qpos[3] = 1.0 @@ -463,7 +425,6 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> source_timestamp_s=float(packet.timestamp_s), source_seq=int(packet.seq), frame_valid=False, - reference_reset_seq=last_reference_reset_seq, retarget_elapsed_s=elapsed_s, ), ) @@ -481,22 +442,12 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> stop_event.set() break - reset_packet = reference_reset_sub.recv_latest() - if ( - isinstance(reset_packet, ReferenceResetPacket) - and int(reset_packet.seq) > last_reference_reset_seq - ): - _reset_reference_state(reset_packet) - packet = body_sub.recv_latest() if packet is None: time.sleep(idle_sleep_s) continue if not isinstance(packet, BodyFramePacket) or int(packet.seq) == last_body_seq: continue - if float(packet.timestamp_s) < 
last_reference_reset_timestamp_s: - continue - start_s = time.monotonic() frame_valid = _human_frame_is_valid(packet.frame, max_pos_value=max_position_value) if not frame_valid: @@ -556,7 +507,6 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> source_timestamp_s=float(packet.timestamp_s), source_seq=int(packet.seq), frame_valid=True, - reference_reset_seq=last_reference_reset_seq, reference_window=reference_window, retarget_elapsed_s=time.monotonic() - start_s, ), @@ -568,7 +518,6 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> body_sub.close() health_sub.close() command_sub.close() - reference_reset_sub.close() ref_pub.close() _worker_loop("retarget_worker", _main) @@ -632,14 +581,12 @@ def __init__( self._stale_reference_hold_s = float(cfg_get(mp_cfg, "stale_reference_hold_s", 0.08)) mocap_sw = cfg_get(cfg, "mocap_switch", {}) or {} self._check_frames = int(cfg_get(mocap_sw, "check_frames", 10)) - self._reference_reset_seq = 0 self._last_reference_seq = -1 self._consecutive_valid_references = 0 self._reference_sub = LatestSubscriber(endpoints.reference_pub, REFERENCE_TOPIC) self._events_sub = LatestSubscriber(endpoints.control_events_pub, CONTROL_EVENTS_TOPIC) self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) - self._reference_reset_sub = LatestSubscriber(endpoints.mode_pub, REFERENCE_RESET_TOPIC) self._mode_pub = ZmqPublisher(endpoints.mode_pub) viewers = _parse_sim2real_viewers(cfg) @@ -697,7 +644,6 @@ def shutdown(self) -> None: self._reference_sub.close() self._events_sub.close() self._command_sub.close() - self._reference_reset_sub.close() self._mode_pub.close() self.robot.close() @@ -721,9 +667,6 @@ def _drain_ipc(self) -> None: if isinstance(command, CommandPacket) and command.command == "shutdown": self.stop_event.set() return - reference_reset = self._reference_reset_sub.recv_latest() - if isinstance(reference_reset, ReferenceResetPacket): - 
self._apply_reference_reset(reference_reset.seq) reference = self._reference_sub.recv_latest() if isinstance(reference, ReferencePacket): self._note_reference_packet(reference) @@ -797,9 +740,6 @@ def _mocap_step(self) -> None: if reference is None or age_s is None: self._hold_or_damp_stale_reference("no retarget reference") return - if int(reference.reference_reset_seq) != self._reference_reset_seq: - self._hold_or_damp_stale_reference("stale reset-generation retarget reference") - return if not reference.frame_valid: logger.warning("Retarget reference invalid -- holding last command") self._hold_or_damp_stale_reference("invalid retarget reference") @@ -876,7 +816,6 @@ def _enter_standing(self) -> None: self.robot.lock_all_joints() time.sleep(0.3) - self._publish_reference_reset("enter_standing") init_qpos = self._build_robot_state_qpos(state) self._last_retarget_qpos = init_qpos self._ref_proc.last_reference_qpos = None @@ -901,8 +840,6 @@ def _can_switch_to_mocap(self) -> bool: return False if not self._latest_reference.frame_valid: return False - if self._latest_reference.reference_reset_seq != self._reference_reset_seq: - return False if age_s > self._max_reference_age_s: return False if self._consecutive_valid_references < self._check_frames: @@ -918,7 +855,6 @@ def _transition_to_mocap(self) -> None: state = self.robot.get_state() resume_qpos = self._build_resume_alignment_qpos(self._standing_qpos, state) self._mocap_reentry_armed = False - self._publish_reference_reset("standing_to_mocap") self._reset_policy_state() self._last_retarget_qpos = None self._last_commanded_motion_qpos = resume_qpos.copy() @@ -933,7 +869,6 @@ def _enter_damping(self) -> None: time.sleep(0.5) logger.info("DAMPING: exiting debug mode...") self.robot.exit_debug_mode() - self._publish_reference_reset("enter_damping") self.mode = RobotMode.DAMPING self._ref_proc.last_reference_qpos = None self._mocap_reentry_armed = False @@ -1004,7 +939,6 @@ def _pause_active_mocap(self) -> None: 
self._last_retarget_qpos = hold_qpos.copy() self._ref_proc.last_reference_qpos = hold_qpos.copy() self._last_commanded_motion_qpos = hold_qpos.copy() - self._publish_reference_reset("pause_active_mocap") self._reset_policy_reference_state() self._mocap_session.pause(hold_qpos) logger.info("Mocap session -> PAUSED (multiprocess episode-reset)") @@ -1016,7 +950,6 @@ def _resume_paused_mocap(self) -> None: state = self.robot.get_state() resume_qpos = self._build_resume_alignment_qpos(hold_qpos, state) self._last_commanded_motion_qpos = resume_qpos.copy() - self._publish_reference_reset("resume_paused_mocap") self._reset_policy_reference_state() self._last_retarget_qpos = None self._last_commanded_motion_qpos = resume_qpos.copy() @@ -1100,20 +1033,7 @@ def _reference_age_s(self) -> float | None: return None return max(0.0, time.monotonic() - float(self._latest_reference.timestamp_s)) - def _apply_reference_reset(self, reference_reset_seq: int) -> None: - reset_seq = int(reference_reset_seq) - if reset_seq <= self._reference_reset_seq: - return - self._reference_reset_seq = reset_seq - self._latest_reference = None - self._last_reference_seq = -1 - self._consecutive_valid_references = 0 - def _note_reference_packet(self, reference: ReferencePacket) -> None: - if int(reference.reference_reset_seq) < self._reference_reset_seq: - return - if int(reference.reference_reset_seq) > self._reference_reset_seq: - self._apply_reference_reset(int(reference.reference_reset_seq)) if int(reference.seq) <= self._last_reference_seq: return self._last_reference_seq = int(reference.seq) @@ -1123,20 +1043,6 @@ def _note_reference_packet(self, reference: ReferencePacket) -> None: return self._consecutive_valid_references += 1 - def _publish_reference_reset(self, reason: str) -> None: - self._reference_reset_seq += 1 - self._latest_reference = None - self._last_reference_seq = -1 - self._consecutive_valid_references = 0 - self._mode_pub.publish( - REFERENCE_RESET_TOPIC, - 
ReferenceResetPacket( - reason=reason, - timestamp_s=time.monotonic(), - seq=self._reference_reset_seq, - ), - ) - @staticmethod def _sleep_until(t0: float, dt: float) -> float: elapsed = time.monotonic() - t0 diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index d7e6b0f1..87117c5f 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -8,9 +8,9 @@ import pytest from teleopit.runtime.mocap_session import MocapSessionState -from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, REFERENCE_RESET_TOPIC, LatestSubscriber, ZmqPublisher +from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher from teleopit.sim2real.mp import resolve_sim2real_runtime_mode -from teleopit.sim2real.mp.messages import ReferencePacket, ReferenceResetPacket, SharedFrameDescriptor +from teleopit.sim2real.mp.messages import ReferencePacket, SharedFrameDescriptor from teleopit.sim2real.mp.runtime import MultiprocessSim2RealController, _RobotControlWorker, _human_frame_is_valid from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter @@ -167,9 +167,8 @@ def test_human_frame_validation_rejects_bad_inputs() -> None: assert not _human_frame_is_valid(bad_frame, max_pos_value=5.0) -def test_robot_worker_requires_consecutive_valid_references_and_reset_generation(monkeypatch) -> None: +def test_robot_worker_requires_consecutive_valid_references(monkeypatch) -> None: worker = object.__new__(_RobotControlWorker) - worker._reference_reset_seq = 0 worker._latest_reference = None worker._last_reference_seq = -1 worker._consecutive_valid_references = 0 @@ -190,7 +189,6 @@ def test_robot_worker_requires_consecutive_valid_references_and_reset_generation source_timestamp_s=1.0, source_seq=1, frame_valid=True, - reference_reset_seq=0, ) worker._note_reference_packet(valid0) assert worker._can_switch_to_mocap() is False @@ -202,7 +200,6 @@ def 
test_robot_worker_requires_consecutive_valid_references_and_reset_generation source_timestamp_s=1.1, source_seq=2, frame_valid=True, - reference_reset_seq=0, ) worker._note_reference_packet(valid1) assert worker._can_switch_to_mocap() is True @@ -214,43 +211,17 @@ def test_robot_worker_requires_consecutive_valid_references_and_reset_generation source_timestamp_s=1.2, source_seq=3, frame_valid=False, - reference_reset_seq=0, ) worker._note_reference_packet(invalid) assert worker._can_switch_to_mocap() is False - worker._publish_reference_reset("enter_standing") - assert worker._reference_reset_seq == 1 - assert worker._latest_reference is None - assert worker._consecutive_valid_references == 0 - assert worker._mode_pub_events - topic, payload = worker._mode_pub_events[-1] - assert topic == REFERENCE_RESET_TOPIC - assert isinstance(payload, ReferenceResetPacket) - assert payload.seq == 1 - assert payload.reason == "enter_standing" - - old_packet = ReferencePacket( - qpos=np.zeros(36, dtype=np.float64), - timestamp_s=1.3, - seq=4, - source_timestamp_s=1.3, - source_seq=4, - frame_valid=True, - reference_reset_seq=0, - ) - worker._note_reference_packet(old_packet) - assert worker._latest_reference is None - assert worker._consecutive_valid_references == 0 - fresh_packet = ReferencePacket( qpos=np.zeros(36, dtype=np.float64), timestamp_s=1.4, - seq=5, + seq=4, source_timestamp_s=1.4, - source_seq=5, + source_seq=4, frame_valid=True, - reference_reset_seq=1, ) worker._note_reference_packet(fresh_packet) assert worker._latest_reference == fresh_packet From 9aa533a8858c311a66bca76e4961614a5c5098ff Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 29 May 2026 17:32:53 +0800 Subject: [PATCH 059/122] Fix multiprocess Pico video frame streaming --- teleopit/inputs/pico_video.py | 21 +++++++++++++++++-- teleopit/sim2real/mp/runtime.py | 37 ++++++++++++--------------------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/teleopit/inputs/pico_video.py 
b/teleopit/inputs/pico_video.py index 7f01f43f..a1c3191b 100644 --- a/teleopit/inputs/pico_video.py +++ b/teleopit/inputs/pico_video.py @@ -82,6 +82,13 @@ def __init__( def enabled(self) -> bool: return self._config.enabled + @property + def pushed_frames(self) -> int: + producer = self._producer + if producer is None: + return 0 + return int(getattr(producer, "pushed_frames", 0)) + def start(self) -> None: if not self._config.enabled: return @@ -152,6 +159,11 @@ def __init__(self, provider: Any, config: PicoVideoConfig) -> None: self._ready_event = threading.Event() self._thread = threading.Thread(target=self._run, name="pico_realsense_video", daemon=True) self._error: BaseException | None = None + self._pushed_frames = 0 + + @property + def pushed_frames(self) -> int: + return int(self._pushed_frames) def start(self) -> None: self._thread.start() @@ -194,7 +206,7 @@ def _run(self) -> None: if not color_frame: continue rgb = np.ascontiguousarray(np.asanyarray(color_frame.get_data()), dtype=np.uint8) - self._provider.push_video_frame(rgb) + self._pushed_frames = int(self._provider.push_video_frame(rgb)) finally: pipeline.stop() except BaseException as exc: @@ -211,6 +223,11 @@ def __init__(self, provider: Any, config: PicoVideoConfig, robot: Any | None) -> self._renderer: Any | None = None self._next_frame_time = 0.0 self._camera_name = "d435i_rgb" + self._pushed_frames = 0 + + @property + def pushed_frames(self) -> int: + return int(self._pushed_frames) def start(self) -> None: if self._robot is None: @@ -236,7 +253,7 @@ def tick(self) -> None: raise RuntimeError("MuJoCo Pico video requires robot.data") self._renderer.update_scene(data, camera=self._camera_name) frame = np.ascontiguousarray(self._renderer.render(), dtype=np.uint8) - self._provider.push_video_frame(frame) + self._pushed_frames = int(self._provider.push_video_frame(frame)) self._next_frame_time = now + 1.0 / float(self._config.fps) def stop(self) -> None: diff --git a/teleopit/sim2real/mp/runtime.py 
b/teleopit/sim2real/mp/runtime.py index fe955351..19f6f72e 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -16,7 +16,7 @@ from teleopit.controllers.observation import VelCmdObservationBuilder, align_motion_qpos_yaw from teleopit.controllers.rl_policy import RLPolicyController from teleopit.inputs.pico4_provider import Pico4InputProvider -from teleopit.inputs.pico_video import bridge_video_source, parse_pico_video_config +from teleopit.inputs.pico_video import PicoVideoRuntime, bridge_video_source, parse_pico_video_config from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType from teleopit.retargeting.core import RetargetingModule from teleopit.runtime.common import cfg_get, require_section @@ -58,7 +58,7 @@ SharedFrameDescriptor, SnapshotPacket, ) -from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter +from teleopit.sim2real.mp.shm import SharedFrameRingWriter from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor from teleopit.sim2real.remote import UnitreeRemote from teleopit.sim2real.safety import Sim2RealSafetyManager @@ -208,11 +208,8 @@ def _start_processes(self) -> None: if hand_mode != "off": specs.append(("hand_worker", _run_hand_worker)) video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) - if video_cfg.enabled and video_cfg.source not in (None, "test-pattern"): - specs.append(("video_worker", _run_video_worker)) - elif video_cfg.enabled and video_cfg.source == "test-pattern": - # pico-bridge can generate test-pattern internally without a camera worker. 
- logger.info("Pico video test-pattern uses pico_bridge internal source") + if video_cfg.enabled: + logger.info("Pico video runs inside pico_io so frames are pushed directly to PicoBridge") for name, target in specs: process = self._ctx.Process( @@ -255,12 +252,11 @@ def _main() -> None: events_pub = ZmqPublisher(endpoints.control_events_pub) health_pub = ZmqPublisher(endpoints.health_pub) command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) - video_sub = ( - LatestSubscriber(endpoints.video_pub, VIDEO_TOPIC) - if video_cfg.enabled and video_cfg.source not in (None, "test-pattern") - else None + video_runtime = PicoVideoRuntime( + provider=provider, + config=video_cfg, + mode="sim2real", ) - frame_reader = SharedFrameRingReader() hz = float(cfg_get(_mp_cfg(cfg), "pico_io_hz", 120.0)) sleep_s = 1.0 / max(hz, 1.0) @@ -270,7 +266,9 @@ def _main() -> None: last_video_seq = -1 last_health_s = 0.0 try: + video_runtime.start() while not stop_event.is_set(): + video_runtime.tick() command = command_sub.recv_latest() if isinstance(command, CommandPacket) and command.command == "shutdown": stop_event.set() @@ -321,15 +319,8 @@ def _main() -> None: ) last_hand_seq = int(hand_snapshot.seq) - if video_sub is not None: - descriptor = video_sub.recv_latest() - if isinstance(descriptor, SharedFrameDescriptor): - try: - frame = frame_reader.read(descriptor, copy=False) - provider.push_video_frame(np.asarray(frame, dtype=np.uint8)) - last_video_seq = int(descriptor.seq) - except Exception as exc: - logger.warning("Pico video frame dropped: %s", exc) + if video_cfg.enabled: + last_video_seq = int(video_runtime.pushed_frames) if now - last_health_s >= 1.0: health_pub.publish( @@ -349,9 +340,7 @@ def _main() -> None: last_health_s = now time.sleep(sleep_s) finally: - frame_reader.close() - if video_sub is not None: - video_sub.close() + video_runtime.stop() command_sub.close() for publisher in (body_pub, hand_pub, controller_pub, events_pub, health_pub): 
publisher.close() From 991c49aeb68f14e363ebcd1f439df368ead9d248 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 29 May 2026 19:21:14 +0800 Subject: [PATCH 060/122] Add Pico signal diagnostics --- scripts/run/check_pico_signal.py | 264 ++++++++++++++++++++++ teleopit/inputs/human_frame_validation.py | 129 +++++++++++ teleopit/sim2real/mp/runtime.py | 21 +- tests/test_human_frame_validation.py | 44 ++++ 4 files changed, 439 insertions(+), 19 deletions(-) create mode 100644 scripts/run/check_pico_signal.py create mode 100644 teleopit/inputs/human_frame_validation.py create mode 100644 tests/test_human_frame_validation.py diff --git a/scripts/run/check_pico_signal.py b/scripts/run/check_pico_signal.py new file mode 100644 index 00000000..2de7cd56 --- /dev/null +++ b/scripts/run/check_pico_signal.py @@ -0,0 +1,264 @@ +"""Pico mocap/video signal diagnostic entry point.""" + +from __future__ import annotations + +from collections import Counter +import logging +import time +from typing import Any + +import hydra +import numpy as np +from omegaconf import DictConfig + +from teleopit.inputs.human_frame_validation import HumanFrameValidationResult, validate_human_frame +from teleopit.inputs.pico4_provider import Pico4InputProvider +from teleopit.inputs.pico_video import PicoVideoRuntime, bridge_video_source, parse_pico_video_config +from teleopit.runtime.common import cfg_get + + +logger = logging.getLogger("teleopit.tools.check_pico_signal") + + +def _fmt_vec(values: tuple[float, ...] 
| None) -> str: + if values is None: + return "None" + return "[" + ", ".join(f"{value:.4f}" for value in values) + "]" + + +def _frame_stats(frame: dict[str, Any]) -> dict[str, Any]: + positions = [] + quat_norms = [] + pelvis_pos = None + for name, value in frame.items(): + try: + pos, quat = value + except Exception: + continue + try: + pos_arr = np.asarray(pos, dtype=np.float64).reshape(-1) + quat_arr = np.asarray(quat, dtype=np.float64).reshape(-1) + except Exception: + continue + if pos_arr.shape[0] >= 3 and np.all(np.isfinite(pos_arr[:3])): + positions.append(pos_arr[:3]) + if str(name) == "Pelvis": + pelvis_pos = pos_arr[:3].copy() + if quat_arr.size > 0 and np.all(np.isfinite(quat_arr)): + quat_norms.append(float(np.linalg.norm(quat_arr))) + + if not positions: + return {} + + pos = np.asarray(positions, dtype=np.float64) + return { + "pelvis_pos": pelvis_pos, + "min_pos": np.min(pos, axis=0), + "max_pos": np.max(pos, axis=0), + "extent": np.ptp(pos, axis=0), + "max_abs_pos": float(np.max(np.abs(pos))), + "quat_norm_min": min(quat_norms) if quat_norms else None, + "quat_norm_max": max(quat_norms) if quat_norms else None, + } + + +def _log_invalid(seq: int, age_ms: float, result: HumanFrameValidationResult) -> None: + logger.warning( + "Invalid Pico body frame | seq=%s age_ms=%.1f reason=%s joint=%s " + "max_abs_pos=%s threshold=%s pos=%s quat=%s detail=%s", + seq, + age_ms, + result.reason, + result.joint_name, + f"{result.max_abs_pos:.4f}" if result.max_abs_pos is not None else "None", + f"{result.max_pos_value:.4f}" if result.max_pos_value is not None else "None", + _fmt_vec(result.pos), + _fmt_vec(result.quat), + result.detail, + ) + + +def _log_summary( + *, + window_s: float, + total: int, + valid: int, + invalid_reasons: Counter[str], + provider_fps: float, + last_seq: int | None, + last_age_ms: float | None, + last_stats: dict[str, Any], + pushed_video_frames: int, +) -> None: + if total <= 0: + logger.info( + "Pico signal summary | window=%.1fs 
samples=0 provider_fps=%.1f " + "last_seq=%s video_frames=%d", + window_s, + provider_fps, + last_seq, + pushed_video_frames, + ) + return + + invalid = total - valid + reason_text = ",".join(f"{reason}:{count}" for reason, count in invalid_reasons.most_common()) or "none" + pelvis = last_stats.get("pelvis_pos") + extent = last_stats.get("extent") + min_pos = last_stats.get("min_pos") + max_pos = last_stats.get("max_pos") + logger.info( + "Pico signal summary | window=%.1fs samples=%d valid=%d invalid=%d reasons=%s " + "provider_fps=%.1f last_seq=%s last_age_ms=%s video_frames=%d " + "max_abs_pos=%s pelvis=%s extent=%s min=%s max=%s quat_norm=[%s,%s]", + window_s, + total, + valid, + invalid, + reason_text, + provider_fps, + last_seq, + f"{last_age_ms:.1f}" if last_age_ms is not None else "None", + pushed_video_frames, + f"{last_stats.get('max_abs_pos'):.4f}" if "max_abs_pos" in last_stats else "None", + _fmt_np_vec(pelvis), + _fmt_np_vec(extent), + _fmt_np_vec(min_pos), + _fmt_np_vec(max_pos), + _fmt_float(last_stats.get("quat_norm_min")), + _fmt_float(last_stats.get("quat_norm_max")), + ) + + +def _fmt_np_vec(values: Any) -> str: + if values is None: + return "None" + arr = np.asarray(values, dtype=np.float64).reshape(-1) + return "[" + ", ".join(f"{float(value):.4f}" for value in arr) + "]" + + +def _fmt_float(value: Any) -> str: + if value is None: + return "None" + return f"{float(value):.4f}" + + +def _build_provider(cfg: DictConfig, video_enabled: bool) -> Pico4InputProvider: + input_cfg = cfg_get(cfg, "input", {}) or {} + video_cfg = parse_pico_video_config(input_cfg) + return Pico4InputProvider( + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + timeout=float(cfg_get(input_cfg, "pico4_timeout", 60.0)), + buffer_size=int(cfg_get(input_cfg, "pico4_buffer_size", 60)), + timestamp_gap_reset_s=float(cfg_get(input_cfg, "pico4_timestamp_gap_reset_s", 0.15)), + pause_button=cfg_get(input_cfg, "pause_button", "A"), + 
pause_debounce_s=float(cfg_get(input_cfg, "pause_debounce_s", 0.25)), + bridge_host=str(cfg_get(input_cfg, "bridge_host", "0.0.0.0")), + bridge_port=int(cfg_get(input_cfg, "bridge_port", 63901)), + bridge_discovery=bool(cfg_get(input_cfg, "bridge_discovery", True)), + bridge_advertise_ip=cfg_get(input_cfg, "bridge_advertise_ip", None), + bridge_video=bridge_video_source(video_cfg), + bridge_video_enabled=video_enabled, + bridge_start_timeout=float(cfg_get(input_cfg, "bridge_start_timeout", 10.0)), + bridge_history_size=int(cfg_get(input_cfg, "bridge_history_size", 120)), + ) + + +@hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="pico4_sim2real") +def main(cfg: DictConfig) -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") + input_cfg = cfg_get(cfg, "input", {}) or {} + video_cfg = parse_pico_video_config(input_cfg) + mocap_switch = cfg_get(cfg, "mocap_switch", {}) or {} + max_pos_value = float(cfg_get(mocap_switch, "max_position_value", 5.0)) + diag_cfg = cfg_get(cfg, "diagnostic", {}) or {} + poll_hz = float(cfg_get(diag_cfg, "poll_hz", cfg_get(cfg_get(cfg, "multiprocess", {}) or {}, "pico_io_hz", 120.0))) + summary_interval_s = float(cfg_get(diag_cfg, "summary_interval_s", 1.0)) + duration_s = float(cfg_get(diag_cfg, "duration_s", 0.0)) + + logger.info("Starting Pico signal diagnostic") + logger.info( + "Pico bridge | host=%s port=%s discovery=%s advertise_ip=%s", + cfg_get(input_cfg, "bridge_host", "0.0.0.0"), + cfg_get(input_cfg, "bridge_port", 63901), + cfg_get(input_cfg, "bridge_discovery", True), + cfg_get(input_cfg, "bridge_advertise_ip", None), + ) + logger.info( + "Signal check | max_position_value=%.3fm poll_hz=%.1f summary_interval_s=%.1f " + "duration_s=%s video_enabled=%s video_source=%s", + max_pos_value, + poll_hz, + summary_interval_s, + f"{duration_s:.1f}" if duration_s > 0.0 else "until Ctrl-C", + video_cfg.enabled, + video_cfg.source, + ) + + provider = 
_build_provider(cfg, video_cfg.enabled) + video_runtime = PicoVideoRuntime(provider=provider, config=video_cfg, mode="sim2real") + total = 0 + valid = 0 + invalid_reasons: Counter[str] = Counter() + last_seq: int | None = None + last_age_ms: float | None = None + last_stats: dict[str, Any] = {} + window_start_s = time.monotonic() + start_s = window_start_s + sleep_s = 1.0 / max(poll_hz, 1.0) + + try: + video_runtime.start() + while True: + now = time.monotonic() + if duration_s > 0.0 and now - start_s >= duration_s: + break + + video_runtime.tick() + if provider.has_frame(): + try: + frame, timestamp_s, seq = provider.get_frame_packet() + except Exception: + logger.exception("Failed to read Pico body frame packet") + else: + seq = int(seq) + if seq != last_seq: + last_seq = seq + last_age_ms = max((time.monotonic() - float(timestamp_s)) * 1000.0, 0.0) + last_stats = _frame_stats(frame) + result = validate_human_frame(frame, max_pos_value=max_pos_value) + total += 1 + if result.valid: + valid += 1 + else: + invalid_reasons[result.reason] += 1 + _log_invalid(seq, last_age_ms, result) + + now = time.monotonic() + if now - window_start_s >= summary_interval_s: + _log_summary( + window_s=now - window_start_s, + total=total, + valid=valid, + invalid_reasons=invalid_reasons, + provider_fps=float(provider.fps), + last_seq=last_seq, + last_age_ms=last_age_ms, + last_stats=last_stats, + pushed_video_frames=video_runtime.pushed_frames, + ) + total = 0 + valid = 0 + invalid_reasons.clear() + window_start_s = now + + time.sleep(sleep_s) + except KeyboardInterrupt: + logger.info("KeyboardInterrupt -- stopping Pico signal diagnostic") + finally: + video_runtime.stop() + provider.close() + + +if __name__ == "__main__": + main() diff --git a/teleopit/inputs/human_frame_validation.py b/teleopit/inputs/human_frame_validation.py new file mode 100644 index 00000000..4ce86052 --- /dev/null +++ b/teleopit/inputs/human_frame_validation.py @@ -0,0 +1,129 @@ +"""HumanFrame sanity validation 
shared by realtime diagnostics and runtimes.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass(frozen=True) +class HumanFrameValidationResult: + valid: bool + reason: str = "ok" + joint_name: str | None = None + pos: tuple[float, ...] | None = None + quat: tuple[float, ...] | None = None + max_abs_pos: float | None = None + max_pos_value: float | None = None + detail: str = "" + + +def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameValidationResult: + """Validate the same HumanFrame conditions used before realtime retargeting.""" + if not isinstance(frame, dict): + return HumanFrameValidationResult(False, reason="frame_not_dict", detail=f"type={type(frame)!r}") + + max_pos = float(max_pos_value) + if not np.isfinite(max_pos) or max_pos <= 0.0: + return HumanFrameValidationResult( + False, + reason="invalid_max_position_value", + max_pos_value=max_pos, + ) + + for name, value in frame.items(): + joint_name = str(name) + try: + pos, quat = value + except Exception as exc: + return HumanFrameValidationResult( + False, + reason="joint_unpack_failed", + joint_name=joint_name, + max_pos_value=max_pos, + detail=str(exc), + ) + + try: + pos_arr = np.asarray(pos, dtype=np.float64).reshape(-1) + except Exception as exc: + return HumanFrameValidationResult( + False, + reason="position_cast_failed", + joint_name=joint_name, + max_pos_value=max_pos, + detail=str(exc), + ) + try: + quat_arr = np.asarray(quat, dtype=np.float64).reshape(-1) + except Exception as exc: + return HumanFrameValidationResult( + False, + reason="quaternion_cast_failed", + joint_name=joint_name, + pos=_to_tuple(pos_arr), + max_pos_value=max_pos, + detail=str(exc), + ) + + pos_tuple = _to_tuple(pos_arr) + quat_tuple = _to_tuple(quat_arr) + max_abs_pos = float(np.max(np.abs(pos_arr))) if pos_arr.size > 0 else 0.0 + + if np.any(np.isnan(pos_arr)): + return HumanFrameValidationResult( + False, + reason="position_nan", 
+ joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + max_pos_value=max_pos, + ) + if np.any(np.isinf(pos_arr)): + return HumanFrameValidationResult( + False, + reason="position_inf", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + max_pos_value=max_pos, + ) + if np.any(np.abs(pos_arr) > max_pos): + return HumanFrameValidationResult( + False, + reason="position_out_of_range", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + max_pos_value=max_pos, + ) + if np.any(np.isnan(quat_arr)): + return HumanFrameValidationResult( + False, + reason="quaternion_nan", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + max_pos_value=max_pos, + ) + if np.any(np.isinf(quat_arr)): + return HumanFrameValidationResult( + False, + reason="quaternion_inf", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + max_pos_value=max_pos, + ) + + return HumanFrameValidationResult(True, max_pos_value=max_pos) + + +def _to_tuple(values: np.ndarray) -> tuple[float, ...]: + return tuple(float(value) for value in np.asarray(values, dtype=np.float64).reshape(-1)) diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 19f6f72e..fbcb7c71 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -15,6 +15,7 @@ from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM from teleopit.controllers.observation import VelCmdObservationBuilder, align_motion_qpos_yaw from teleopit.controllers.rl_policy import RLPolicyController +from teleopit.inputs.human_frame_validation import validate_human_frame from teleopit.inputs.pico4_provider import Pico4InputProvider from teleopit.inputs.pico_video import PicoVideoRuntime, bridge_video_source, parse_pico_video_config from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType @@ -117,25 +118,7 @@ 
def _worker_loop(name: str, fn: Callable[[], None]) -> None: def _human_frame_is_valid(frame: object, *, max_pos_value: float) -> bool: - if not isinstance(frame, dict): - return False - max_pos = float(max_pos_value) - if not np.isfinite(max_pos) or max_pos <= 0.0: - return False - for value in frame.values(): - try: - pos, quat = value - except Exception: - return False - pos_arr = np.asarray(pos, dtype=np.float64).reshape(-1) - quat_arr = np.asarray(quat, dtype=np.float64).reshape(-1) - if np.any(np.isnan(pos_arr)) or np.any(np.isinf(pos_arr)): - return False - if np.any(np.abs(pos_arr) > max_pos): - return False - if np.any(np.isnan(quat_arr)) or np.any(np.isinf(quat_arr)): - return False - return True + return validate_human_frame(frame, max_pos_value=max_pos_value).valid class MultiprocessSim2RealController: diff --git a/tests/test_human_frame_validation.py b/tests/test_human_frame_validation.py new file mode 100644 index 00000000..25400102 --- /dev/null +++ b/tests/test_human_frame_validation.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import numpy as np + +from teleopit.inputs.human_frame_validation import validate_human_frame + + +def test_validate_human_frame_reports_out_of_range_joint() -> None: + frame = { + "Pelvis": (np.array([6.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame, max_pos_value=5.0) + + assert not result.valid + assert result.reason == "position_out_of_range" + assert result.joint_name == "Pelvis" + assert result.pos == (6.0, 0.0, 0.0) + assert result.max_abs_pos == 6.0 + assert result.max_pos_value == 5.0 + + +def test_validate_human_frame_reports_quaternion_nan() -> None: + frame = { + "Head": (np.array([0.0, 0.0, 1.5]), np.array([1.0, np.nan, 0.0, 0.0])), + } + + result = validate_human_frame(frame, max_pos_value=5.0) + + assert not result.valid + assert result.reason == "quaternion_nan" + assert result.joint_name == "Head" + + +def 
test_validate_human_frame_accepts_finite_frame() -> None: + frame = { + "Pelvis": (np.array([0.0, 0.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + "Head": (np.array([0.0, 0.0, 1.7]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame, max_pos_value=5.0) + + assert result.valid + assert result.reason == "ok" From 041121489a635d378bbe7b28b93c6fa768439c19 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 29 May 2026 19:32:34 +0800 Subject: [PATCH 061/122] Make Pico signal diagnostics interruptible --- scripts/run/check_pico_signal.py | 42 +++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/scripts/run/check_pico_signal.py b/scripts/run/check_pico_signal.py index 2de7cd56..6d1e19ba 100644 --- a/scripts/run/check_pico_signal.py +++ b/scripts/run/check_pico_signal.py @@ -4,7 +4,10 @@ from collections import Counter import logging +import os +import signal import time +import threading from typing import Any import hydra @@ -164,6 +167,31 @@ def _build_provider(cfg: DictConfig, video_enabled: bool) -> Pico4InputProvider: ) +def _install_signal_handlers(stop_event: threading.Event) -> None: + def _handle_signal(signum: int, _frame: Any) -> None: + if stop_event.is_set(): + os._exit(130) + logger.info("Received signal %s -- shutting down", signum) + stop_event.set() + + signal.signal(signal.SIGINT, _handle_signal) + signal.signal(signal.SIGTERM, _handle_signal) + + +def _start_video_runtime_async(video_runtime: PicoVideoRuntime) -> threading.Event: + done = threading.Event() + + def _run() -> None: + try: + video_runtime.start() + finally: + done.set() + + thread = threading.Thread(target=_run, name="pico_video_start", daemon=True) + thread.start() + return done + + @hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="pico4_sim2real") def main(cfg: DictConfig) -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") @@ -195,6 +223,8 @@ def 
main(cfg: DictConfig) -> None: video_cfg.source, ) + stop_event = threading.Event() + _install_signal_handlers(stop_event) provider = _build_provider(cfg, video_cfg.enabled) video_runtime = PicoVideoRuntime(provider=provider, config=video_cfg, mode="sim2real") total = 0 @@ -206,10 +236,13 @@ def main(cfg: DictConfig) -> None: window_start_s = time.monotonic() start_s = window_start_s sleep_s = 1.0 / max(poll_hz, 1.0) + video_start_done: threading.Event | None = None try: - video_runtime.start() - while True: + if video_cfg.enabled: + logger.info("Starting Pico video backend asynchronously") + video_start_done = _start_video_runtime_async(video_runtime) + while not stop_event.is_set(): now = time.monotonic() if duration_s > 0.0 and now - start_s >= duration_s: break @@ -252,7 +285,10 @@ def main(cfg: DictConfig) -> None: invalid_reasons.clear() window_start_s = now - time.sleep(sleep_s) + if video_start_done is not None and not video_start_done.is_set() and now - start_s >= 5.0: + logger.info("Waiting for Pico video backend to become ready in the background") + video_start_done = None + stop_event.wait(timeout=sleep_s) except KeyboardInterrupt: logger.info("KeyboardInterrupt -- stopping Pico signal diagnostic") finally: From 7f213013d7ae160311c944f602a73f62080d8ac2 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 29 May 2026 19:53:59 +0800 Subject: [PATCH 062/122] Remove Pico position threshold validation --- docs/docs/configuration/config-reference.md | 1 - .../current/configuration/config-reference.md | 1 - scripts/run/check_pico_signal.py | 10 ++--- teleopit/configs/pico4_sim2real.yaml | 1 - teleopit/configs/sim2real.yaml | 1 - teleopit/inputs/human_frame_validation.py | 38 ++++-------------- teleopit/sim2real/controller.py | 2 - teleopit/sim2real/mp/runtime.py | 10 ++--- teleopit/sim2real/reference_processor.py | 12 +----- tests/test_human_frame_validation.py | 39 ++++++++++++++----- tests/test_sim2real_multiprocess.py | 6 +-- 11 files changed, 49 
insertions(+), 72 deletions(-) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index fd1aad80..213ce1b7 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -101,7 +101,6 @@ and `all` are simulation-only viewer modes. | `startup_ramp_duration` | Kp ramp duration after entering `STANDING`; gradually increases PD gains without changing policy targets | `2.0` | | `joint_vel_limit` | Joint velocity limit (rad/s); triggers emergency damping if exceeded | `10.0` | | `mocap_switch.check_frames` | Consecutive valid frames required before switching to MOCAP | `10` | -| `mocap_switch.max_position_value` | Position sanity threshold in meters | `5.0` | ### Real Robot diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 5aabbddd..942bbdcb 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -120,7 +120,6 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` | `startup_ramp_duration` | 进入 `STANDING` 后的 Kp ramp 时长;逐步提高 PD 增益,不改变 policy target | `2.0` | | `joint_vel_limit` | 关节速度限制(rad/s),超过时触发急停 | `10.0` | | `mocap_switch.check_frames` | 切换到 MOCAP 前所需的连续有效帧数 | `10` | -| `mocap_switch.max_position_value` | 位置合理性阈值(米) | `5.0` | ### 真机 SDK diff --git a/scripts/run/check_pico_signal.py b/scripts/run/check_pico_signal.py index 6d1e19ba..b30cb671 100644 --- a/scripts/run/check_pico_signal.py +++ b/scripts/run/check_pico_signal.py @@ -68,13 +68,12 @@ def _frame_stats(frame: dict[str, Any]) -> dict[str, Any]: def _log_invalid(seq: int, age_ms: float, result: HumanFrameValidationResult) -> None: logger.warning( "Invalid Pico body frame | seq=%s age_ms=%.1f reason=%s joint=%s " - 
"max_abs_pos=%s threshold=%s pos=%s quat=%s detail=%s", + "max_abs_pos=%s pos=%s quat=%s detail=%s", seq, age_ms, result.reason, result.joint_name, f"{result.max_abs_pos:.4f}" if result.max_abs_pos is not None else "None", - f"{result.max_pos_value:.4f}" if result.max_pos_value is not None else "None", _fmt_vec(result.pos), _fmt_vec(result.quat), result.detail, @@ -197,8 +196,6 @@ def main(cfg: DictConfig) -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") input_cfg = cfg_get(cfg, "input", {}) or {} video_cfg = parse_pico_video_config(input_cfg) - mocap_switch = cfg_get(cfg, "mocap_switch", {}) or {} - max_pos_value = float(cfg_get(mocap_switch, "max_position_value", 5.0)) diag_cfg = cfg_get(cfg, "diagnostic", {}) or {} poll_hz = float(cfg_get(diag_cfg, "poll_hz", cfg_get(cfg_get(cfg, "multiprocess", {}) or {}, "pico_io_hz", 120.0))) summary_interval_s = float(cfg_get(diag_cfg, "summary_interval_s", 1.0)) @@ -213,9 +210,8 @@ def main(cfg: DictConfig) -> None: cfg_get(input_cfg, "bridge_advertise_ip", None), ) logger.info( - "Signal check | max_position_value=%.3fm poll_hz=%.1f summary_interval_s=%.1f " + "Signal check | validation=finite_values poll_hz=%.1f summary_interval_s=%.1f " "duration_s=%s video_enabled=%s video_source=%s", - max_pos_value, poll_hz, summary_interval_s, f"{duration_s:.1f}" if duration_s > 0.0 else "until Ctrl-C", @@ -259,7 +255,7 @@ def main(cfg: DictConfig) -> None: last_seq = seq last_age_ms = max((time.monotonic() - float(timestamp_s)) * 1000.0, 0.0) last_stats = _frame_stats(frame) - result = validate_human_frame(frame, max_pos_value=max_pos_value) + result = validate_human_frame(frame) total += 1 if result.valid: valid += 1 diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index a038de2e..8629a415 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -103,7 +103,6 @@ real_robot: # Mocap mode switching safety checks 
mocap_switch: check_frames: 10 # Consecutive valid frames required before switching - max_position_value: 5.0 # Position sanity threshold (meters) hydra: run: diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index c2ebe287..13b267c5 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -102,7 +102,6 @@ real_robot: # Mocap mode switching safety checks mocap_switch: check_frames: 10 # Consecutive valid frames required before switching - max_position_value: 5.0 # Position sanity threshold (meters) hydra: run: diff --git a/teleopit/inputs/human_frame_validation.py b/teleopit/inputs/human_frame_validation.py index 4ce86052..ab854505 100644 --- a/teleopit/inputs/human_frame_validation.py +++ b/teleopit/inputs/human_frame_validation.py @@ -1,4 +1,4 @@ -"""HumanFrame sanity validation shared by realtime diagnostics and runtimes.""" +"""HumanFrame finite-value validation shared by realtime diagnostics and runtimes.""" from __future__ import annotations @@ -15,23 +15,15 @@ class HumanFrameValidationResult: pos: tuple[float, ...] | None = None quat: tuple[float, ...] 
| None = None max_abs_pos: float | None = None - max_pos_value: float | None = None detail: str = "" -def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameValidationResult: - """Validate the same HumanFrame conditions used before realtime retargeting.""" +def validate_human_frame(frame: object) -> HumanFrameValidationResult: + """Validate HumanFrame numeric values before realtime retargeting.""" if not isinstance(frame, dict): return HumanFrameValidationResult(False, reason="frame_not_dict", detail=f"type={type(frame)!r}") - max_pos = float(max_pos_value) - if not np.isfinite(max_pos) or max_pos <= 0.0: - return HumanFrameValidationResult( - False, - reason="invalid_max_position_value", - max_pos_value=max_pos, - ) - + max_frame_abs_pos: float | None = None for name, value in frame.items(): joint_name = str(name) try: @@ -41,7 +33,6 @@ def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameVa False, reason="joint_unpack_failed", joint_name=joint_name, - max_pos_value=max_pos, detail=str(exc), ) @@ -52,7 +43,6 @@ def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameVa False, reason="position_cast_failed", joint_name=joint_name, - max_pos_value=max_pos, detail=str(exc), ) try: @@ -63,13 +53,15 @@ def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameVa reason="quaternion_cast_failed", joint_name=joint_name, pos=_to_tuple(pos_arr), - max_pos_value=max_pos, detail=str(exc), ) pos_tuple = _to_tuple(pos_arr) quat_tuple = _to_tuple(quat_arr) max_abs_pos = float(np.max(np.abs(pos_arr))) if pos_arr.size > 0 else 0.0 + max_frame_abs_pos = ( + max_abs_pos if max_frame_abs_pos is None else max(max_frame_abs_pos, max_abs_pos) + ) if np.any(np.isnan(pos_arr)): return HumanFrameValidationResult( @@ -79,7 +71,6 @@ def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameVa pos=pos_tuple, quat=quat_tuple, max_abs_pos=max_abs_pos, - max_pos_value=max_pos, ) if 
np.any(np.isinf(pos_arr)): return HumanFrameValidationResult( @@ -89,17 +80,6 @@ def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameVa pos=pos_tuple, quat=quat_tuple, max_abs_pos=max_abs_pos, - max_pos_value=max_pos, - ) - if np.any(np.abs(pos_arr) > max_pos): - return HumanFrameValidationResult( - False, - reason="position_out_of_range", - joint_name=joint_name, - pos=pos_tuple, - quat=quat_tuple, - max_abs_pos=max_abs_pos, - max_pos_value=max_pos, ) if np.any(np.isnan(quat_arr)): return HumanFrameValidationResult( @@ -109,7 +89,6 @@ def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameVa pos=pos_tuple, quat=quat_tuple, max_abs_pos=max_abs_pos, - max_pos_value=max_pos, ) if np.any(np.isinf(quat_arr)): return HumanFrameValidationResult( @@ -119,10 +98,9 @@ def validate_human_frame(frame: object, *, max_pos_value: float) -> HumanFrameVa pos=pos_tuple, quat=quat_tuple, max_abs_pos=max_abs_pos, - max_pos_value=max_pos, ) - return HumanFrameValidationResult(True, max_pos_value=max_pos) + return HumanFrameValidationResult(True, max_abs_pos=max_frame_abs_pos) def _to_tuple(values: np.ndarray) -> tuple[float, ...]: diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index dfdc5eaa..2592f6d9 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -323,7 +323,6 @@ def _init_reference_config(self, cfg: Any) -> None: if self._reference_timeline is not None else None ) - mocap_sw = cfg_get(cfg, "mocap_switch", {}) self._ref_proc = Sim2RealReferenceProcessor( obs_builder=self.obs_builder, policy=self.policy, @@ -331,7 +330,6 @@ def _init_reference_config(self, cfg: Any) -> None: num_actions=self.num_actions, reference_velocity_smoothing_alpha=rc.reference_velocity_smoothing_alpha, reference_anchor_velocity_smoothing_alpha=rc.reference_anchor_velocity_smoothing_alpha, - max_pos_value=float(cfg_get(mocap_sw, "max_position_value", 5.0)), ) self._last_live_packet_seq = -1 
diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index fbcb7c71..01ab17b2 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -117,8 +117,8 @@ def _worker_loop(name: str, fn: Callable[[], None]) -> None: raise -def _human_frame_is_valid(frame: object, *, max_pos_value: float) -> bool: - return validate_human_frame(frame, max_pos_value=max_pos_value).valid +def _human_frame_is_valid(frame: object) -> bool: + return validate_human_frame(frame).valid class MultiprocessSim2RealController: @@ -366,8 +366,6 @@ def _main() -> None: human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), actual_human_height=float(cfg_get(input_cfg, "human_height", 1.75)), ) - mocap_sw = cfg_get(cfg, "mocap_switch", {}) or {} - max_position_value = float(cfg_get(mocap_sw, "max_position_value", 5.0)) body_sub = LatestSubscriber(endpoints.body_pub, BODY_TOPIC) health_sub = LatestSubscriber(endpoints.health_pub, HEALTH_TOPIC) command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) @@ -421,7 +419,7 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> if not isinstance(packet, BodyFramePacket) or int(packet.seq) == last_body_seq: continue start_s = time.monotonic() - frame_valid = _human_frame_is_valid(packet.frame, max_pos_value=max_position_value) + frame_valid = _human_frame_is_valid(packet.frame) if not frame_valid: last_body_seq = int(packet.seq) last_body_timestamp_s = None @@ -527,7 +525,6 @@ def __init__( policy_dt_s=self.dt, reference_steps=cfg_get(cfg, "reference_steps", [0]), ) - mocap_sw = cfg_get(cfg, "mocap_switch", {}) or {} self._ref_proc = Sim2RealReferenceProcessor( obs_builder=self.obs_builder, policy=self.policy, @@ -535,7 +532,6 @@ def __init__( num_actions=self.num_actions, reference_velocity_smoothing_alpha=self._ref_cfg.reference_velocity_smoothing_alpha, reference_anchor_velocity_smoothing_alpha=self._ref_cfg.reference_anchor_velocity_smoothing_alpha, - 
max_pos_value=float(cfg_get(mocap_sw, "max_position_value", 5.0)), ) self._standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) diff --git a/teleopit/sim2real/reference_processor.py b/teleopit/sim2real/reference_processor.py index bfbf72c8..32567182 100644 --- a/teleopit/sim2real/reference_processor.py +++ b/teleopit/sim2real/reference_processor.py @@ -15,6 +15,7 @@ from teleopit.controllers import reference_processing as ref_proc from teleopit.controllers.observation import VelCmdObservationBuilder +from teleopit.inputs.human_frame_validation import validate_human_frame from teleopit.sim.realtime_utils import ExponentialVecSmoother from teleopit.sim.reference_timeline import ReferenceWindow @@ -34,13 +35,11 @@ def __init__( num_actions: int, reference_velocity_smoothing_alpha: float, reference_anchor_velocity_smoothing_alpha: float, - max_pos_value: float, ) -> None: self._obs_builder = obs_builder self._policy = policy self._policy_hz = policy_hz self._num_actions = num_actions - self._max_pos_value = max_pos_value # Yaw alignment state (lazy-init) self._fixed_reference_yaw_quat: Float32Array | None = None @@ -77,14 +76,7 @@ def retarget_to_qpos(self, retargeted: object) -> Float64Array: return ref_proc.retarget_to_qpos(retargeted) def frame_is_valid(self, frame: dict[str, tuple[np.ndarray, np.ndarray]]) -> bool: - for pos, quat in frame.values(): - if np.any(np.isnan(pos)) or np.any(np.isinf(pos)): - return False - if np.any(np.abs(pos) > self._max_pos_value): - return False - if np.any(np.isnan(quat)) or np.any(np.isinf(quat)): - return False - return True + return validate_human_frame(frame).valid # ------------------------------------------------------------------ # Yaw alignment diff --git a/tests/test_human_frame_validation.py b/tests/test_human_frame_validation.py index 25400102..d1c7d358 100644 --- a/tests/test_human_frame_validation.py +++ b/tests/test_human_frame_validation.py @@ -5,19 +5,40 @@ from teleopit.inputs.human_frame_validation import 
validate_human_frame -def test_validate_human_frame_reports_out_of_range_joint() -> None: +def test_validate_human_frame_accepts_large_finite_positions() -> None: frame = { - "Pelvis": (np.array([6.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0, 0.0])), + "Pelvis": (np.array([100.0, -100.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + "Head": (np.array([100.1, -100.1, 1.7]), np.array([1.0, 0.0, 0.0, 0.0])), } - result = validate_human_frame(frame, max_pos_value=5.0) + result = validate_human_frame(frame) + + assert result.valid + assert result.reason == "ok" + + +def test_validate_human_frame_reports_position_inf() -> None: + frame = { + "Pelvis": (np.array([np.inf, 0.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame) + + assert not result.valid + assert result.reason == "position_inf" + assert result.joint_name == "Pelvis" + + +def test_validate_human_frame_reports_position_nan() -> None: + frame = { + "Pelvis": (np.array([np.nan, 0.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame) assert not result.valid - assert result.reason == "position_out_of_range" + assert result.reason == "position_nan" assert result.joint_name == "Pelvis" - assert result.pos == (6.0, 0.0, 0.0) - assert result.max_abs_pos == 6.0 - assert result.max_pos_value == 5.0 def test_validate_human_frame_reports_quaternion_nan() -> None: @@ -25,7 +46,7 @@ def test_validate_human_frame_reports_quaternion_nan() -> None: "Head": (np.array([0.0, 0.0, 1.5]), np.array([1.0, np.nan, 0.0, 0.0])), } - result = validate_human_frame(frame, max_pos_value=5.0) + result = validate_human_frame(frame) assert not result.valid assert result.reason == "quaternion_nan" @@ -38,7 +59,7 @@ def test_validate_human_frame_accepts_finite_frame() -> None: "Head": (np.array([0.0, 0.0, 1.7]), np.array([1.0, 0.0, 0.0, 0.0])), } - result = validate_human_frame(frame, max_pos_value=5.0) + result = validate_human_frame(frame) assert result.valid assert result.reason == 
"ok" diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 87117c5f..52c737ce 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -159,12 +159,12 @@ def test_human_frame_validation_rejects_bad_inputs() -> None: valid_frame = { "Pelvis": (np.zeros(3, dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), } - assert _human_frame_is_valid(valid_frame, max_pos_value=5.0) + assert _human_frame_is_valid(valid_frame) bad_frame = { - "Pelvis": (np.array([6.0, 0.0, 0.0], dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), + "Pelvis": (np.array([np.nan, 0.0, 0.0], dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), } - assert not _human_frame_is_valid(bad_frame, max_pos_value=5.0) + assert not _human_frame_is_valid(bad_frame) def test_robot_worker_requires_consecutive_valid_references(monkeypatch) -> None: From 5d95c887fc0dd91a3c564aea1bb54f38c0fb1f3f Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 29 May 2026 20:17:37 +0800 Subject: [PATCH 063/122] Hold mocap command on delayed references --- teleopit/configs/pico4_sim2real.yaml | 2 ++ teleopit/configs/sim2real.yaml | 2 ++ teleopit/sim2real/controller.py | 25 ++++++++++++++++---- teleopit/sim2real/mp/runtime.py | 34 ++++++++++++++++------------ tests/test_sim2real_multiprocess.py | 29 ++++++++++++++++++++++++ tests/test_sim2real_runtime.py | 24 ++++++++++++++++++++ 6 files changed, 96 insertions(+), 20 deletions(-) diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 8629a415..e2c49810 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -28,6 +28,8 @@ multiprocess: hand_worker_hz: 120.0 retarget_idle_sleep_s: 0.001 video_slots: 3 + # In active MOCAP, stale/invalid realtime references hold the last command instead of damping. 
+ # max_reference_age_s only rejects entering MOCAP from STANDING on an old reference. stale_reference_hold_s: 0.08 max_reference_age_s: 0.25 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 13b267c5..2b4a3277 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -27,6 +27,8 @@ multiprocess: hand_worker_hz: 120.0 retarget_idle_sleep_s: 0.001 video_slots: 3 + # In active MOCAP, delayed/invalid realtime references hold the last command instead of damping. + # max_reference_age_s only rejects entering MOCAP from STANDING on an old reference. stale_reference_hold_s: 0.08 max_reference_age_s: 0.25 diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py index 2592f6d9..5c853794 100644 --- a/teleopit/sim2real/controller.py +++ b/teleopit/sim2real/controller.py @@ -283,6 +283,7 @@ def _init_components(self, cfg: Any) -> None: self._mocap_reentry_armed: bool = False self._mocap_session = MocapSessionManager() self._last_commanded_motion_qpos: Float64Array | None = None + self._last_mocap_hold_reason: str | None = None self._viewers = _parse_sim2real_viewers(cfg) self._retarget_viewer = _Sim2RealRetargetViewer( xml_path=str(cfg_get(robot_cfg, "xml_path", "")) if "retarget" in self._viewers else None, @@ -437,15 +438,16 @@ def _mocap_step(self) -> None: return if not self.input_provider.is_available(): - logger.warning("Input provider unavailable -- entering damping") - self._enter_damping() + self._hold_mocap_reference("input provider unavailable") return try: packet = self._fetch_realtime_input_packet() - except (TimeoutError, RuntimeError): - logger.warning("Input provider error -- entering damping") - self._enter_damping() + except (TimeoutError, RuntimeError) as exc: + self._hold_mocap_reference( + "input provider error", + detail=f"{type(exc).__name__}: {exc}", + ) return self._handle_mocap_control_events(packet.control_events) @@ -577,6 +579,7 @@ def _execute_mocap_pipeline( 
self._last_retarget_qpos = qpos.copy() self._ref_proc.last_reference_qpos = reference_qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + self._last_mocap_hold_reason = None self._write_retarget_viewer(qpos) # ------------------------------------------------------------------ @@ -787,6 +790,7 @@ def _enter_damping(self) -> None: self._mocap_reentry_armed = False self._mocap_session.reset() self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None logger.info("Mode -> DAMPING (press Start to re-enter STANDING)") # ------------------------------------------------------------------ @@ -801,6 +805,7 @@ def _reset_policy_state(self) -> None: self._ref_proc.reset_alignment() self._mocap_session.reset() self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None self.policy.reset() self.obs_builder.reset() @@ -811,6 +816,7 @@ def _reset_policy_reference_state(self) -> None: self._ref_proc.reset_alignment() self._mocap_session.reset() self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None self.policy.reset() self.obs_builder.reset() @@ -828,6 +834,7 @@ def _reset_mocap_reference_state(self) -> None: self._reference_manager.reset() self._ref_proc.reset_smoothers() self._last_live_packet_seq = -1 + self._last_mocap_hold_reason = None def _build_robot_state_qpos(self, state: object) -> Float64Array: qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) @@ -1010,6 +1017,14 @@ def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: self._last_commanded_motion_qpos = qpos.copy() self._write_retarget_viewer(qpos) + def _hold_mocap_reference(self, reason: str, *, detail: str | None = None) -> None: + if self._last_mocap_hold_reason != reason: + suffix = f" ({detail})" if detail else "" + logger.warning("Mocap reference not fresh: %s%s -- holding command", reason, suffix) + self._last_mocap_hold_reason = reason + hold_qpos = self._resolve_mocap_hold_qpos() + self._run_static_mocap_step(hold_qpos) + def 
_tick_dexterous_hand(self) -> None: active = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.ACTIVE try: diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 01ab17b2..9052c1bf 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -540,6 +540,7 @@ def __init__( self._last_action = np.zeros(self.num_actions, dtype=np.float32) self._last_retarget_qpos: Float64Array | None = None self._last_commanded_motion_qpos: Float64Array | None = None + self._last_mocap_hold_reason: str | None = None self._mocap_reentry_armed = False self._mocap_session = MocapSessionManager() @@ -706,18 +707,16 @@ def _mocap_step(self) -> None: reference = self._latest_reference age_s = self._reference_age_s() if reference is None or age_s is None: - self._hold_or_damp_stale_reference("no retarget reference") + self._hold_mocap_reference("no retarget reference") return if not reference.frame_valid: - logger.warning("Retarget reference invalid -- holding last command") - self._hold_or_damp_stale_reference("invalid retarget reference") + self._hold_mocap_reference("invalid retarget reference") return - if age_s > self._max_reference_age_s: - logger.warning("Retarget reference stale %.3fs -- entering damping", age_s) - self._enter_damping() - return - if age_s > self._stale_reference_hold_s and self._last_commanded_motion_qpos is not None: - self._run_static_mocap_step(self._last_commanded_motion_qpos) + if age_s > self._stale_reference_hold_s: + self._hold_mocap_reference( + "delayed retarget reference", + detail=f"age={age_s:.3f}s", + ) return robot_state = self.robot.get_state() @@ -765,6 +764,7 @@ def _execute_mocap_pipeline( self._last_retarget_qpos = qpos.copy() self._ref_proc.last_reference_qpos = reference_qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + self._last_mocap_hold_reason = None self._write_retarget_viewer(qpos) def _enter_standing(self) -> None: @@ -842,6 +842,7 @@ def 
_enter_damping(self) -> None: self._mocap_reentry_armed = False self._mocap_session.reset() self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None logger.info("Mode -> DAMPING (press Start to re-enter STANDING)") def _reset_policy_state(self) -> None: @@ -850,6 +851,7 @@ def _reset_policy_state(self) -> None: self._ref_proc.reset_alignment() self._mocap_session.reset() self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None self.policy.reset() self.obs_builder.reset() @@ -859,6 +861,7 @@ def _reset_policy_reference_state(self) -> None: self._ref_proc.reset_alignment() self._mocap_session.reset() self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None self.policy.reset() self.obs_builder.reset() @@ -968,12 +971,13 @@ def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: self._last_commanded_motion_qpos = qpos.copy() self._write_retarget_viewer(qpos) - def _hold_or_damp_stale_reference(self, reason: str) -> None: - if self._last_commanded_motion_qpos is not None: - self._run_static_mocap_step(self._last_commanded_motion_qpos) - return - logger.warning("No mocap hold pose available after %s -- entering damping", reason) - self._enter_damping() + def _hold_mocap_reference(self, reason: str, *, detail: str | None = None) -> None: + if self._last_mocap_hold_reason != reason: + suffix = f" ({detail})" if detail else "" + logger.warning("Mocap reference not fresh: %s%s -- holding command", reason, suffix) + self._last_mocap_hold_reason = reason + hold_qpos = self._resolve_mocap_hold_qpos() + self._run_static_mocap_step(hold_qpos) def _publish_mode_state(self) -> None: self._mode_seq += 1 diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 52c737ce..428dc886 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -167,6 +167,35 @@ def test_human_frame_validation_rejects_bad_inputs() -> None: assert not 
_human_frame_is_valid(bad_frame) +def test_robot_worker_holds_stale_reference_instead_of_damping() -> None: + worker = object.__new__(_RobotControlWorker) + hold_qpos = np.zeros(36, dtype=np.float64) + hold_qpos[3] = 1.0 + hold_qpos[7] = 0.25 + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._latest_reference = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.0, + seq=1, + source_timestamp_s=1.0, + source_seq=1, + frame_valid=True, + ) + worker._reference_age_s = lambda: 0.30 + worker._stale_reference_hold_s = 0.08 + worker._last_mocap_hold_reason = None + worker._last_commanded_motion_qpos = hold_qpos.copy() + worker._resolve_mocap_hold_qpos = lambda: hold_qpos.copy() + held: list[np.ndarray] = [] + worker._run_static_mocap_step = lambda qpos: held.append(np.asarray(qpos, dtype=np.float64).copy()) + worker._enter_damping = lambda: pytest.fail("stale references must not enter damping") + + worker._mocap_step() + + assert len(held) == 1 + np.testing.assert_allclose(held[0], hold_qpos) + + def test_robot_worker_requires_consecutive_valid_references(monkeypatch) -> None: worker = object.__new__(_RobotControlWorker) worker._latest_reference = None diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py index 39c38008..1320dc33 100644 --- a/tests/test_sim2real_runtime.py +++ b/tests/test_sim2real_runtime.py @@ -480,6 +480,30 @@ def test_dexterous_hand_failure_does_not_enter_damping(monkeypatch) -> None: assert hand_runtime.active_flags == [True] +def test_realtime_input_timeout_holds_mocap_instead_of_damping(monkeypatch) -> None: + from teleopit.sim2real.controller import RobotMode, Sim2RealController + + policy = DummyPolicy() + obs_builder = DummyVelCmdObservationBuilder() + _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) + + ctrl = Sim2RealController(_make_cfg()) + ctrl.mode = RobotMode.MOCAP + ctrl._mocap_session.reset() + 
hold_qpos = np.zeros(36, dtype=np.float64) + hold_qpos[3] = 1.0 + hold_qpos[7] = 0.25 + ctrl._last_commanded_motion_qpos = hold_qpos.copy() + ctrl._fetch_realtime_input_packet = lambda: (_ for _ in ()).throw(TimeoutError("stalled")) + ctrl._enter_damping = lambda: pytest.fail("input timeouts must not enter damping") + + ctrl._mocap_step() + + assert ctrl.mode == RobotMode.MOCAP + np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"], hold_qpos.astype(np.float32)) + assert len(ctrl.robot.sent_positions) == 1 + + def test_can_switch_to_mocap_returns_false_without_blocking_when_realtime_has_no_frame(monkeypatch) -> None: from teleopit.sim2real.controller import Sim2RealController From 95dce87ee662ed28d8b6e1d0301ae26492b511ed Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 9 Jun 2026 16:20:56 +0800 Subject: [PATCH 064/122] Refactor sim2real hand runtime --- AGENTS.md | 19 +- README.md | 4 +- docs/docs/configuration/config-reference.md | 43 +- docs/docs/getting-started/installation.md | 3 +- docs/docs/reference/architecture.md | 4 +- docs/docs/tutorials/pico-sim2real.md | 29 +- .../current/configuration/config-reference.md | 46 +- .../current/getting-started/installation.md | 3 +- .../current/reference/architecture.md | 2 +- .../current/tutorials/pico-sim2real.md | 29 +- scripts/dev/test_linkerhand_l6.py | 128 +- scripts/run/check_pico_signal.py | 2 +- scripts/run/run_sim2real.py | 14 +- scripts/run/standalone_standing.py | 6 +- scripts/setup/download_somehand_l6_assets.sh | 171 +-- teleopit/configs/pico4_sim2real.yaml | 43 +- teleopit/configs/sim2real.yaml | 43 +- teleopit/runtime/reference_config.py | 2 +- teleopit/sim/reference_utils.py | 2 +- teleopit/sim2real/__init__.py | 4 +- teleopit/sim2real/controller.py | 1105 ----------------- teleopit/sim2real/dexterous_hand.py | 979 --------------- teleopit/sim2real/hands/__init__.py | 3 + teleopit/sim2real/hands/base.py | 33 + teleopit/sim2real/hands/linkerhand_l6.py | 439 +++++++ 
teleopit/sim2real/hands/pico_landmarks.py | 43 + teleopit/sim2real/hands/worker.py | 79 ++ teleopit/sim2real/mp/__init__.py | 8 +- teleopit/sim2real/mp/ipc.py | 2 + teleopit/sim2real/mp/messages.py | 2 + teleopit/sim2real/mp/runtime.py | 427 ++++++- tests/test_dexterous_hand.py | 659 ++-------- tests/test_sim2real_dim.py | 140 +-- tests/test_sim2real_multiprocess.py | 96 +- tests/test_sim2real_runtime.py | 848 ------------- tests/test_termination_config.py | 4 +- third_party/somehand | 2 +- 37 files changed, 1417 insertions(+), 4049 deletions(-) delete mode 100644 teleopit/sim2real/controller.py delete mode 100644 teleopit/sim2real/dexterous_hand.py create mode 100644 teleopit/sim2real/hands/__init__.py create mode 100644 teleopit/sim2real/hands/base.py create mode 100644 teleopit/sim2real/hands/linkerhand_l6.py create mode 100644 teleopit/sim2real/hands/pico_landmarks.py create mode 100644 teleopit/sim2real/hands/worker.py delete mode 100644 tests/test_sim2real_runtime.py diff --git a/AGENTS.md b/AGENTS.md index b9864522..30e48481 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -55,8 +55,8 @@ teleopit/ # Core inference package ├── sim/ │ └── loop.py # SimulationLoop — PD control at 1000Hz, policy at 50Hz ├── sim2real/ -│ ├── controller.py # G1 state machine and hardware control loop -│ └── dexterous_hand.py # Optional Pico gripper / VR hand pose → LinkerHand L6 runtime +│ ├── mp/ # Process-isolated sim2real runtime and IPC +│ └── hands/ # Optional LinkerHand L6 driver/mapper plugins └── recording/ # HDF5Recorder scripts/ ├── run_sim.py # Offline sim2sim pipeline @@ -140,13 +140,14 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Pico sim2sim supports a keyboard-driven top-level mode state machine: `STANDING → MOCAP → STANDING` - Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `X` → back to `STANDING`, `Q` → quit - Pico4 sim2real pause/resume is handled as a mocap-session control event (`toggle_pause`), not 
as a mode switch to `STANDING` -- Default Pico pause button is `A`; resume rebuilds the realtime buffer and yaw/XY root-offset alignment, then waits for the configured realtime warmup before tracking continues -- Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference history, smoothers, realtime buffers, and reference alignment are reset, while the GMR IK warm-start is retained -- Optional LinkerHand L6 control uses `dexterous_hand.mode=off|gripper|vr_hand_pose`; default is `off` +- Default Pico pause button is `A`; resume resets policy/reference state and yaw/XY root-offset alignment while the process-isolated realtime reference worker continues its live input timeline +- Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference state, smoothers, and reference alignment are reset, while the GMR IK warm-start is retained +- Optional LinkerHand L6 control uses `hands.enabled=true` and `hands.mode=gripper|vr_hand_pose`; default is disabled - `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control -- `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and `somehand` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control -- `gripper` mode uses the configured `dexterous_hand.speed` (default `[50]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` -- `vr_hand_pose` defaults to a low-latency somehand path: `dexterous_hand.somehand.rate=60`, `threaded=true`, `max_iterations=12`, `temporal_filter_alpha=1.0`, and `output_alpha=1.0`; this prioritizes response speed over smoothing +- `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and somehand 0.2.0 public `somehand.api` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control +- Teleopit owns Pico 26-joint hand-state to 21-landmark conversion; do not import 
`somehand.pico_input` +- `gripper` mode uses the configured `hands.linkerhand_l6.speed` (default `[50]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` +- `vr_hand_pose` defaults to a low-latency somehand path: `hands.somehand.rate_hz=60`, `max_iterations=12`, `temporal_filter_alpha=1.0`, and `output_alpha=1.0`; this prioritizes response speed over smoothing - LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose - In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand @@ -157,7 +158,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - BVH frame alignment is time-based: `bvh_idx = int(policy_time × input_fps)` - Realtime reference buffering is controlled by `retarget_buffer_enabled`, `retarget_buffer_window_s`, `retarget_buffer_delay_s`, `reference_steps`, and `realtime_buffer_warmup_steps` - Realtime inferred `motion_joint_vel`, anchor linear velocity, and anchor angular velocity can be EMA-smoothed via `reference_velocity_smoothing_alpha` and `reference_anchor_velocity_smoothing_alpha` -- Sim2real Pico pause/resume uses mocap-session states `ACTIVE ↔ PAUSED`; resume clears policy/reference state, rebuilds yaw/XY root alignment, warms the realtime buffer, and does not interpolate retarget qpos from the paused pose +- Sim2real Pico pause/resume uses mocap-session states `ACTIVE ↔ PAUSED`; resume clears policy/reference state, rebuilds yaw/XY root alignment, and does not interpolate retarget qpos from the paused pose - Realtime sim2sim with Pico control events uses the same mocap-session pause/resume semantics and rebuilds the realtime reference path on resume, including the configured warmup - Realtime sim2sim/sim2real `STANDING ↔ MOCAP` transitions use the same retargeter-preserving soft reset, rather than cold-starting the retargeter from its default qpos - Realtime 
Pico sim2sim can start directly in `STANDING` with keyboard mode control enabled via top-level `keyboard.enabled` diff --git a/README.md b/README.md index 099a83d1..8faf086f 100644 --- a/README.md +++ b/README.md @@ -66,9 +66,9 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ### Unreleased - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. -- Added optional LinkerHand L6 sim2real modes: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand. +- Added optional LinkerHand L6 sim2real modes under `hands.*`: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand 0.2.0 public API. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. -- Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz threaded hand retargeting and reduced smoothing. +- Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. ### v0.3.0 (2026-05-12) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 213ce1b7..2b922496 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -120,32 +120,29 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin ### Dexterous Hand (Pico sim2real) -`dexterous_hand.mode=gripper` or `dexterous_hand.mode=vr_hand_pose` requires -`input.provider=pico4` and the optional `dexhand` extra. Control is active only -in `MOCAP`; inactive modes send the open pose. In `vr_hand_pose`, missing hand -pose holds the last command for that side. 
`gripper` uses the configured -`dexterous_hand.speed`; `vr_hand_pose` always sets LinkerHand L6 speed to the -maximum. The default `vr_hand_pose` path favors low latency: it runs in a -background thread at `dexterous_hand.somehand.rate` and disables most somehand -input/output smoothing, which can make finger motion noisier. +`hands.enabled=true` requires `input.provider=pico4` and the optional `dexhand` +extra. Control is active only in `MOCAP`; inactive modes send the open pose. In +`vr_hand_pose`, missing hand pose holds the last command for that side. +`gripper` uses the configured `hands.linkerhand_l6.speed`; `vr_hand_pose` +always sets LinkerHand L6 speed to the maximum. Teleopit converts Pico hand +state to 21 landmarks and embeds somehand 0.2.0 through `somehand.api` only. | Field | Description | Default | |-------|-------------|---------| -| `dexterous_hand.mode` | `off`, `gripper`, or `vr_hand_pose` | `off` | -| `dexterous_hand.hand_type` | Controlled side: `left`, `right`, or `both`; `vr_hand_pose` requires `both` | `both` | -| `dexterous_hand.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | -| `dexterous_hand.rate` | Maximum gripper command rate in Hz | `30.0` | -| `dexterous_hand.frame_timeout` | Gripper controller timeout, or VR hand-pose staleness threshold | `0.3` | -| `dexterous_hand.speed` | L6 speed used by `gripper`; `vr_hand_pose` overrides this to maximum speed | see config | -| `dexterous_hand.deadman_threshold` | Minimum grip value required to enable a side | `0.5` | -| `dexterous_hand.trigger_deadzone` | Trigger deadzone at both ends | `0.05` | -| `dexterous_hand.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | -| `dexterous_hand.somehand.config_path` | somehand bi-hand L6 config used by `vr_hand_pose` | see config | -| `dexterous_hand.somehand.rate` | Low-latency `vr_hand_pose` command rate in Hz | `60.0` | -| `dexterous_hand.somehand.threaded` | Run `vr_hand_pose` hand retargeting outside 
the robot control loop | `true` | -| `dexterous_hand.somehand.max_iterations` | somehand solver iteration cap for `vr_hand_pose` | `12` | -| `dexterous_hand.somehand.temporal_filter_alpha` | somehand input landmark smoothing alpha; `1.0` disables smoothing delay | `1.0` | -| `dexterous_hand.somehand.output_alpha` | somehand qpos output smoothing alpha; `1.0` disables smoothing delay | `1.0` | +| `hands.enabled` | Enable optional hand worker | `false` | +| `hands.driver` | Hand driver plugin | `linkerhand_l6` | +| `hands.mode` | `gripper` or `vr_hand_pose` | `gripper` | +| `hands.sides` | Controlled sides | `[left, right]` | +| `hands.rate_hz` | Maximum gripper command rate in Hz | `30.0` | +| `hands.frame_timeout_s` | Controller or hand-pose staleness threshold | `0.3` | +| `hands.linkerhand_l6.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | +| `hands.linkerhand_l6.speed` | L6 speed used by `gripper`; `vr_hand_pose` overrides this to maximum speed | see config | +| `hands.linkerhand_l6.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | +| `hands.somehand.config_path` | Official somehand 0.2.0 bi-hand L6 config used by `vr_hand_pose` | see config | +| `hands.somehand.rate_hz` | Low-latency `vr_hand_pose` command rate in Hz | `60.0` | +| `hands.somehand.max_iterations` | somehand solver iteration cap for `vr_hand_pose` | `12` | +| `hands.somehand.temporal_filter_alpha` | somehand input landmark smoothing alpha; `1.0` disables smoothing delay | `1.0` | +| `hands.somehand.output_alpha` | somehand qpos output smoothing alpha; `1.0` disables smoothing delay | `1.0` | ## Critical: `default_dof_pos` diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 6129c681..7af73776 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -71,8 +71,7 @@ pip install -e '.[dexhand]' scripts/setup/download_somehand_l6_assets.sh ``` -This 
extra is only required when `dexterous_hand.mode=gripper` or -`dexterous_hand.mode=vr_hand_pose`. +This extra is only required when `hands.enabled=true`. ## Verify Installation diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index ab1fdae9..c3aa900e 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -16,7 +16,7 @@ InputProvider (BVH file / Pico4) -> Robot (MuJoCo sim or Unitree G1) ``` -Offline/online inference is assembled by `teleopit/runtime/` and `teleopit/pipeline.py`. The hardware state machine lives in `teleopit/sim2real/controller.py`. Training is provided by `train_mimic/`. +Offline/online inference is assembled by `teleopit/runtime/` and `teleopit/pipeline.py`. The hardware state machine runs through the process-isolated runtime in `teleopit/sim2real/mp/`. Training is provided by `train_mimic/`. ## Code Structure @@ -43,7 +43,7 @@ train_mimic/scripts/data | `teleopit/interfaces.py` | Stable protocols: InputProvider, Retargeter, Controller, Robot, ObservationBuilder, Recorder | | `teleopit/runtime/` | Config parsing, path normalization, component assembly, CLI validation | | `teleopit/pipeline.py` | Lightweight facade for offline sim | -| `teleopit/sim2real/controller.py` | Hardware state machine and control logic | +| `teleopit/sim2real/mp/` | Process-isolated sim2real state machine, IPC, and robot-control loop | | `teleopit/controllers/observation.py` | ObservationBuilder | | `teleopit/controllers/rl_policy.py` | Only accepts 166D dual-input ONNX | | `train_mimic/app.py` | Shared train/play/benchmark assembly | diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 1eb8bb10..f9a1e776 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -146,13 +146,12 @@ Pico sim2real can drive LinkerHand L6 hands in two modes: - `gripper`: hold the matching side grip as a deadman switch; the matching trigger 
closes that hand. This mode uses the configured - `dexterous_hand.speed`, which defaults to 50. + `hands.linkerhand_l6.speed`, which defaults to 50. - `vr_hand_pose`: retarget Pico hand pose through somehand and command the continuous L6 hand target. If a hand pose disappears, that side keeps its last - commanded pose. This mode currently uses `hand_type=both` and always sets L6 - speed to the maximum. The default configuration uses a low-latency somehand - path at 60 Hz with reduced smoothing, so it should feel more responsive but - can be noisier than the standard somehand settings. + commanded pose. This mode uses Teleopit's Pico landmark adapter and the + public `somehand.api` from somehand 0.2.0. It always sets L6 speed to the + maximum. Hand control is active only in `MOCAP`. It sends the open pose in `STANDING`, `DAMPING`, paused mocap, and shutdown. @@ -185,18 +184,19 @@ python scripts/dev/test_linkerhand_l6.py \ Then enable L6 control in Pico sim2real: ```bash -dexterous_hand.mode=gripper -dexterous_hand.left_can=can0 -dexterous_hand.right_can=can1 +hands.enabled=true +hands.mode=gripper +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 ``` For continuous VR hand-pose control, use: ```bash -dexterous_hand.mode=vr_hand_pose -dexterous_hand.hand_type=both -dexterous_hand.left_can=can0 -dexterous_hand.right_can=can1 +hands.enabled=true +hands.mode=vr_hand_pose +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 ``` ## Optional RealSense Preview @@ -233,7 +233,8 @@ mocap_switch.check_frames=10 input.pause_button=right_axis_click # Enable LinkerHand L6 control -dexterous_hand.mode=gripper +hands.enabled=true +hands.mode=gripper # Enable headset video preview input.video.enabled=true @@ -248,5 +249,5 @@ input.video.enabled=true | Cannot enter debug mode | Unitree mode release failed | Stop other robot modes and press `Start` again | | Robot enters `STANDING` but not `MOCAP` | Mocap validation failed | Keep tracking active 
and stable; check `mocap_switch.check_frames` logs | | Pico pause does not return to `STANDING` | Expected behavior | Pico pause freezes mocap; press remote `X` for `STANDING` | -| LinkerHand does not move | Mode is `off`, not in `MOCAP`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Set `dexterous_hand.mode`, enter `MOCAP`, run `scripts/dev/test_linkerhand_l6.py`, and check `dexterous_hand.left_can` / `right_can` | +| LinkerHand does not move | `hands.enabled=false`, not in `MOCAP`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Enable `hands.enabled`, enter `MOCAP`, run `scripts/dev/test_linkerhand_l6.py`, and check `hands.linkerhand_l6.left_can` / `right_can` | | Video preview is unavailable | RealSense or video source failed | Check camera permissions, `input.video.source`, and logs | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 942bbdcb..c1bf8350 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -139,28 +139,30 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` ### 灵巧手(Pico sim2real) -`dexterous_hand.mode=gripper` 或 `dexterous_hand.mode=vr_hand_pose` 要求 -`input.provider=pico4`,并安装可选的 `dexhand` extra。控制只在 `MOCAP` -中生效;非活动模式会发送张开姿态。在 `vr_hand_pose` 中,手部 pose 消失时, -对应侧会保持上一条命令。`gripper` 使用配置的 `dexterous_hand.speed`; -`vr_hand_pose` 始终将 LinkerHand L6 速度设为最大值。默认的 `vr_hand_pose` -路径优先降低延时:它会按 `dexterous_hand.somehand.rate` 在后台线程运行,并关闭 -大部分 somehand 输入/输出平滑,因此手指运动可能更抖。 +`hands.mode=gripper` 或 `hands.mode=vr_hand_pose` 要求 `input.provider=pico4`, +并安装可选的 `dexhand` extra。控制只在 `MOCAP` 中生效;非活动模式会发送张开姿态。 +在 `vr_hand_pose` 中,Teleopit 将 Pico 手部 pose 适配成 somehand 0.2.0 的 +landmark 输入,只调用公开的 `somehand.api`;手部 
pose 消失时,对应侧会保持上一条命令。 +`gripper` 使用配置的 `hands.linkerhand_l6.speed`;`vr_hand_pose` 始终将 +LinkerHand L6 速度设为最大值。默认的 `vr_hand_pose` 路径优先降低延时:它会按 +`hands.somehand.rate_hz` 在后台线程运行,并关闭大部分 somehand 输入/输出平滑,因此手指运动可能更抖。 | 字段 | 说明 | 默认值 | |---|---|---| -| `dexterous_hand.mode` | `off`、`gripper` 或 `vr_hand_pose` | `off` | -| `dexterous_hand.hand_type` | 控制侧:`left`、`right` 或 `both`;`vr_hand_pose` 要求 `both` | `both` | -| `dexterous_hand.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | -| `dexterous_hand.rate` | gripper 最大命令频率(Hz) | `30.0` | -| `dexterous_hand.frame_timeout` | gripper 手柄超时或 VR 手部 pose 过期阈值 | `0.3` | -| `dexterous_hand.speed` | `gripper` 使用的 L6 速度;`vr_hand_pose` 会覆盖为最大速度 | 见配置 | -| `dexterous_hand.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | -| `dexterous_hand.trigger_deadzone` | trigger 两端死区 | `0.05` | -| `dexterous_hand.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | -| `dexterous_hand.somehand.config_path` | `vr_hand_pose` 使用的 somehand 双手 L6 配置 | 见配置 | -| `dexterous_hand.somehand.rate` | 低延时 `vr_hand_pose` 命令频率(Hz) | `60.0` | -| `dexterous_hand.somehand.threaded` | 在机器人控制循环外运行 `vr_hand_pose` 手部重定向 | `true` | -| `dexterous_hand.somehand.max_iterations` | `vr_hand_pose` 的 somehand solver 迭代上限 | `12` | -| `dexterous_hand.somehand.temporal_filter_alpha` | somehand 输入 landmarks 平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | -| `dexterous_hand.somehand.output_alpha` | somehand qpos 输出平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | +| `hands.enabled` | 启用可选手部运行时 | `false` | +| `hands.mode` | `off`、`gripper` 或 `vr_hand_pose` | `off` | +| `hands.driver` | 手部设备驱动;当前支持 `linkerhand_l6` | `linkerhand_l6` | +| `hands.sides` | 控制侧列表 | `[left, right]` | +| `hands.linkerhand_l6.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | +| `hands.rate_hz` | gripper 最大命令频率(Hz) | `30.0` | +| `hands.frame_timeout_s` | gripper 手柄超时或 VR 手部 pose 过期阈值 | `0.3` | +| `hands.linkerhand_l6.speed` | `gripper` 使用的 L6
速度;`vr_hand_pose` 会覆盖为最大速度 | 见配置 | +| `hands.linkerhand_l6.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | +| `hands.linkerhand_l6.trigger_deadzone` | trigger 两端死区 | `0.05` | +| `hands.linkerhand_l6.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | +| `hands.somehand.config_path` | `vr_hand_pose` 使用的 somehand 双手 L6 配置 | 见配置 | +| `hands.somehand.rate_hz` | 低延时 `vr_hand_pose` 命令频率(Hz) | `60.0` | +| `hands.somehand.threaded` | 在机器人控制循环外运行 `vr_hand_pose` 手部重定向 | `true` | +| `hands.somehand.max_iterations` | `vr_hand_pose` 的 somehand solver 迭代上限 | `12` | +| `hands.somehand.temporal_filter_alpha` | somehand 输入 landmarks 平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | +| `hands.somehand.output_alpha` | somehand qpos 输出平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 8a038992..5317b6d0 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -70,8 +70,7 @@ pip install -e '.[dexhand]' scripts/setup/download_somehand_l6_assets.sh ``` -只有在 `dexterous_hand.mode=gripper` 或 -`dexterous_hand.mode=vr_hand_pose` 时才需要安装这个 extra。 +只有在 `hands.mode=gripper` 或 `hands.mode=vr_hand_pose` 时才需要安装这个 extra。 ## 验证安装 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index cbeaa013..a17dccfd 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -43,7 +43,7 @@ teleopit/ | 接口层 | `interfaces.py` | 定义所有核心抽象接口,模块间仅通过接口通信 | | 运行时 | `runtime/` | Hydra 配置加载、对象组装、依赖注入 | | Pipeline | `pipeline/` | 数据流编排,驱动每一帧的采样-推理-执行循环 | -| Sim2Real |
`sim2real/` | 实机通信适配(DDS 桥接、状态同步) | +| Sim2Real | `sim2real/mp/` | 进程隔离的实机状态机、IPC 与机器人控制循环 | | 观测 | `observation/` | 从仿真/实机状态构建策略所需的观测向量 | | 策略 | `rl_policy/` | ONNX 模型加载与推理,action 后处理 | | 入口 | `app.py` | 命令行入口,调用 runtime 装配并启动 pipeline | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 5169aeaa..ef6318ff 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -137,11 +137,12 @@ Pico 暂停/恢复是 mocap-session control event。 Pico sim2real 可以用两种模式控制 LinkerHand L6: - `gripper`:按住同侧 grip 作为 deadman,同侧 trigger 控制对应手闭合。 - 该模式使用配置的 `dexterous_hand.speed`,默认值为 50。 + 该模式使用配置的 `hands.linkerhand_l6.speed`,默认值为 50。 - `vr_hand_pose`:通过 somehand 重定向 Pico 手部 pose,并下发连续 L6 手部目标。 - 如果某侧手部 pose 消失,该侧会保持上一条手势命令。这个模式当前使用 - `hand_type=both`,并始终将 L6 速度设为最大值。默认配置使用 60 Hz 的低延时 - somehand 路径并减少平滑,所以响应会更快,但可能比标准 somehand 设置更抖。 + 如果某侧手部 pose 消失,该侧会保持上一条手势命令。这个模式使用 Teleopit 的 + Pico landmark 适配器和 somehand 0.2.0 公开的 `somehand.api`,并始终将 L6 + 速度设为最大值。默认配置使用 60 Hz 的低延时 somehand 路径并减少平滑,所以响应会更快, + 但可能比标准 somehand 设置更抖。 手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 @@ -171,18 +172,19 @@ python scripts/dev/test_linkerhand_l6.py \ 然后在 Pico sim2real 中启用 L6 控制: ```bash -dexterous_hand.mode=gripper -dexterous_hand.left_can=can0 -dexterous_hand.right_can=can1 +hands.enabled=true +hands.mode=gripper +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 ``` 连续 VR 手部 pose 控制使用: ```bash -dexterous_hand.mode=vr_hand_pose -dexterous_hand.hand_type=both -dexterous_hand.left_can=can0 -dexterous_hand.right_can=can1 +hands.enabled=true +hands.mode=vr_hand_pose +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 ``` ## 可选 RealSense 预览 @@ -219,7 +221,8 @@ mocap_switch.check_frames=10 
input.pause_button=right_axis_click # 开启 LinkerHand L6 控制 -dexterous_hand.mode=gripper +hands.enabled=true +hands.mode=gripper # 开启头显视频预览 input.video.enabled=true @@ -234,5 +237,5 @@ input.video.enabled=true | 无法进入 debug mode | Unitree mode 释放失败 | 停止其他机器人模式后再次按 `Start` | | 机器人进入 `STANDING` 但不进入 `MOCAP` | 动捕验证失败 | 保持追踪稳定,查看 `mocap_switch.check_frames` 日志 | | Pico 暂停没有返回 `STANDING` | 这是预期行为 | Pico 暂停只冻结 mocap;按遥控器 `X` 返回 `STANDING` | -| LinkerHand 不动 | 模式为 `off`、不在 `MOCAP`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `dexterous_hand.mode`,进入 `MOCAP`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查 `dexterous_hand.left_can` / `right_can` | +| LinkerHand 不动 | `hands.enabled=false`、模式为 `off`、不在 `MOCAP`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `hands.enabled=true` 和 `hands.mode`,进入 `MOCAP`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查 `hands.linkerhand_l6.left_can` / `right_can` | | 视频预览不可用 | RealSense 或视频源失败 | 检查相机权限、`input.video.source` 和日志 | diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py index 95f92deb..e3f73826 100644 --- a/scripts/dev/test_linkerhand_l6.py +++ b/scripts/dev/test_linkerhand_l6.py @@ -23,12 +23,7 @@ from teleopit.inputs.pico4_provider import ( # noqa: E402 Pico4InputProvider, ) -from teleopit.sim2real.dexterous_hand import ( # noqa: E402 - LinkerHandConfig, - LinkerHandRuntime, - SomeHandPoseRuntime, - VR_HAND_POSE_SPEED, -) +from teleopit.sim2real.hands.linkerhand_l6 import VR_HAND_POSE_SPEED, build_linkerhand_l6 # noqa: E402 THUMB_YAW_DEFAULT = 10 @@ -74,7 +69,8 @@ def parse_args() -> argparse.Namespace: default="open_close", help=( "open_close sends fixed poses directly; gripper reads real Pico controller " - "grip/trigger input; vr_hand_pose reads real Pico hand pose input through somehand." + "grip/trigger input; vr_hand_pose reads real Pico hand pose input through Teleopit " + "and uses somehand only for hand retargeting." 
), ) parser.add_argument("--hand-type", choices=["left", "right", "both"], default="both") @@ -138,34 +134,38 @@ def parse_args() -> argparse.Namespace: return args -def make_config(args: argparse.Namespace, *, mode: str) -> LinkerHandConfig: +def make_config(args: argparse.Namespace, *, mode: str) -> dict[str, object]: speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else args.speed - vr_hand_pose = mode == "vr_hand_pose" - return LinkerHandConfig( - mode=mode, - enabled=True, - hand_joint="L6", - hand_type=args.hand_type, - left_can=args.left_can, - right_can=args.right_can, - modbus=args.modbus, - rate=args.rate, - frame_timeout=args.frame_timeout, - trigger_deadzone=args.trigger_deadzone, - deadman_threshold=args.deadman_threshold, - thumb_yaw_center=args.thumb_yaw_center, - speed=tuple(speed), - open_pose=tuple(args.open_pose), - close_pose=tuple(args.close_pose), - print_input=args.print_input, - somehand_config_path=args.somehand_config_path, - somehand_sdk_root=args.somehand_sdk_root, - somehand_rate=args.rate if vr_hand_pose else None, - somehand_threaded=False, - somehand_max_iterations=12 if vr_hand_pose else None, - somehand_temporal_filter_alpha=1.0 if vr_hand_pose else None, - somehand_output_alpha=1.0 if vr_hand_pose else None, - ) + return { + "input": {"provider": "pico4"}, + "hands": { + "enabled": True, + "driver": "linkerhand_l6", + "mode": mode, + "sides": list(selected_hand_types(args.hand_type)), + "rate_hz": args.rate, + "frame_timeout_s": args.frame_timeout, + "linkerhand_l6": { + "left_can": args.left_can, + "right_can": args.right_can, + "modbus": args.modbus, + "trigger_deadzone": args.trigger_deadzone, + "deadman_threshold": args.deadman_threshold, + "thumb_yaw_center": args.thumb_yaw_center, + "speed": list(speed), + "open_pose": list(args.open_pose), + "close_pose": list(args.close_pose), + "print_input": args.print_input, + }, + "somehand": { + "config_path": args.somehand_config_path, + "rate_hz": args.rate, + "max_iterations": 
12, + "temporal_filter_alpha": 1.0, + "output_alpha": 1.0, + }, + }, + } def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> None: @@ -175,19 +175,6 @@ def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> No hand.finger_move(pose=list(pose)) -def wait_runtime_idle(runtime: object, *, timeout_s: float = 2.0) -> None: - sender = getattr(runtime, "_sender", None) - wait_idle = getattr(sender, "wait_idle", None) - if callable(wait_idle) and not wait_idle(timeout_s=timeout_s): - raise RuntimeError("Timed out waiting for LinkerHand sender to become idle") - - -def assert_runtime_started(runtime: object) -> None: - sender = getattr(runtime, "_sender", None) - if not bool(getattr(sender, "started", False)): - raise RuntimeError("LinkerHand sender failed to start; check the log above for SDK/CAN errors") - - def make_pico_provider(args: argparse.Namespace) -> Pico4InputProvider: return Pico4InputProvider( timeout=args.duration_s, @@ -203,29 +190,32 @@ def make_pico_provider(args: argparse.Namespace) -> Pico4InputProvider: def run_live_until_done( - runtime: LinkerHandRuntime | SomeHandPoseRuntime, + runtime: object, *, provider: Pico4InputProvider, duration_s: float, mode_label: str, + rate_hz: float, ) -> None: deadline = time.monotonic() + duration_s last_seq: int | None = None print(f"Running {mode_label} for {duration_s:.1f}s; press Ctrl-C to stop early.", flush=True) while time.monotonic() < deadline: now_s = time.monotonic() - runtime.tick(active=True, now_s=now_s) - snapshot = ( - provider.get_controller_snapshot() - if isinstance(runtime, LinkerHandRuntime) - else provider.get_hand_snapshot() + controller_snapshot = provider.get_controller_snapshot() + hand_snapshot = provider.get_hand_snapshot() + runtime.tick( + controller_snapshot=controller_snapshot, + hand_snapshot=hand_snapshot, + active=True, + now_s=now_s, ) + snapshot = controller_snapshot if mode_label == "gripper" else hand_snapshot if snapshot is not 
None and snapshot.seq != last_seq: last_seq = snapshot.seq age_ms = max((now_s - snapshot.timestamp_s) * 1000.0, 0.0) print(f" pico seq={snapshot.seq} age={age_ms:.1f}ms", flush=True) - wait_runtime_idle(runtime) - time.sleep(max(1.0 / runtime.config.rate, 0.001)) + time.sleep(max(1.0 / rate_hz, 0.001)) def run_open_close(args: argparse.Namespace) -> None: @@ -283,50 +273,48 @@ def run_open_close(args: argparse.Namespace) -> None: def run_gripper(args: argparse.Namespace) -> None: config = make_config(args, mode="gripper") provider = make_pico_provider(args) - runtime = LinkerHandRuntime(config, provider) + device, mapper = build_linkerhand_l6(config) + from teleopit.sim2real.hands.worker import HandRuntime + runtime = HandRuntime(device, mapper) print( - "Testing dexterous_hand.mode=gripper with real Pico controller input. " + "Testing hands.mode=gripper with real Pico controller input. " "Hold grip above the deadman threshold, then use trigger to close/open.", flush=True, ) try: runtime.start() - wait_runtime_idle(runtime) - assert_runtime_started(runtime) - run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="gripper") + run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="gripper", rate_hz=args.rate) except KeyboardInterrupt: print("Interrupted; opening hands before exit", flush=True) finally: - runtime.tick(active=False) - wait_runtime_idle(runtime) + runtime.tick(controller_snapshot=None, hand_snapshot=None, active=False) runtime.close() provider.close() def run_vr_hand_pose(args: argparse.Namespace) -> None: if args.hand_type != "both": - raise SystemExit("dexterous_hand.mode=vr_hand_pose currently requires --hand-type both") + raise SystemExit("hands.mode=vr_hand_pose currently requires --hand-type both") config = make_config(args, mode="vr_hand_pose") provider = make_pico_provider(args) - runtime = SomeHandPoseRuntime(config, provider) + device, mapper = build_linkerhand_l6(config) + 
from teleopit.sim2real.hands.worker import HandRuntime + runtime = HandRuntime(device, mapper) print( - "Testing dexterous_hand.mode=vr_hand_pose with real Pico hand-pose input. " + "Testing hands.mode=vr_hand_pose with real Pico hand-pose input. " "Enable Pico hand tracking and move both hands; start with the robot clear of contacts.", flush=True, ) try: runtime.start() - wait_runtime_idle(runtime) - assert_runtime_started(runtime) - run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="vr_hand_pose") + run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="vr_hand_pose", rate_hz=args.rate) except KeyboardInterrupt: print("Interrupted; opening hands before exit", flush=True) finally: - runtime.tick(active=False) - wait_runtime_idle(runtime) + runtime.tick(controller_snapshot=None, hand_snapshot=None, active=False) runtime.close() provider.close() diff --git a/scripts/run/check_pico_signal.py b/scripts/run/check_pico_signal.py index b30cb671..44eac250 100644 --- a/scripts/run/check_pico_signal.py +++ b/scripts/run/check_pico_signal.py @@ -197,7 +197,7 @@ def main(cfg: DictConfig) -> None: input_cfg = cfg_get(cfg, "input", {}) or {} video_cfg = parse_pico_video_config(input_cfg) diag_cfg = cfg_get(cfg, "diagnostic", {}) or {} - poll_hz = float(cfg_get(diag_cfg, "poll_hz", cfg_get(cfg_get(cfg, "multiprocess", {}) or {}, "pico_io_hz", 120.0))) + poll_hz = float(cfg_get(diag_cfg, "poll_hz", cfg_get(cfg_get(cfg, "runtime", {}) or {}, "pico_input_hz", 120.0))) summary_interval_s = float(cfg_get(diag_cfg, "summary_interval_s", 1.0)) duration_s = float(cfg_get(diag_cfg, "duration_s", 0.0)) diff --git a/scripts/run/run_sim2real.py b/scripts/run/run_sim2real.py index 5aa312ed..a3553ddc 100644 --- a/scripts/run/run_sim2real.py +++ b/scripts/run/run_sim2real.py @@ -6,8 +6,7 @@ from omegaconf import DictConfig from teleopit.runtime.cli import validate_policy_path -from teleopit.sim2real.mp import 
MultiprocessSim2RealController, resolve_sim2real_runtime_mode -from teleopit.sim2real.controller import Sim2RealController +from teleopit.sim2real.mp import Sim2RealRuntime def _print_sim2real_controls(cfg: DictConfig) -> None: @@ -19,7 +18,7 @@ def _print_sim2real_controls(cfg: DictConfig) -> None: print(" Remote L1+R1: DAMPING / estop.") if provider == "pico4": print(" Mocap pause/resume: Pico/controller A.") - print(" Dexterous hand: dexterous_hand.mode=off|gripper|vr_hand_pose (default off).") + print(" Dexterous hand: hands.enabled=true hands.mode=gripper|vr_hand_pose.") else: print(" Offline playback: A pause/resume, B replay from start.") print(" State flow: IDLE -> STANDING -> MOCAP -> STANDING, Any -> DAMPING.") @@ -32,15 +31,10 @@ def main(cfg: DictConfig) -> None: def _run_sim2real(cfg: DictConfig) -> None: validate_policy_path(cfg, "run_sim2real.py") - runtime_mode = resolve_sim2real_runtime_mode(cfg) - controller = ( - MultiprocessSim2RealController(cfg) - if runtime_mode == "multiprocess" - else Sim2RealController(cfg) - ) + controller = Sim2RealRuntime(cfg) if cfg.input.get("provider") == "pico4": print("Waiting for Pico4 body tracking data...") - print(f"Sim2real runtime: {runtime_mode}") + print("Sim2real runtime: multiprocess") _print_sim2real_controls(cfg) try: controller.run() diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 8cdaef3e..007c2085 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -2,7 +2,7 @@ """Standalone G1 standing script with RL policy -- no Teleopit/Pico dependency. Uses ONNX RL policy inference to maintain balanced standing, matching the -STANDING mode in Sim2RealController. Only depends on: +STANDING mode used by the sim2real robot-control runtime. 
Only depends on: - g1_bridge_sdk (C++ DDS bridge) - onnxruntime - mujoco @@ -355,7 +355,7 @@ def reset(self): # ===================================================================== class StandingController: - """RL-policy-based standing controller matching Sim2RealController.STANDING.""" + """RL-policy-based standing controller matching sim2real STANDING behavior.""" def __init__(self, network_interface: str, policy_path: str, no_policy: bool = False, @@ -675,7 +675,7 @@ def _check_joint_vel_safety(self, qvel: np.ndarray) -> bool: return True return False - # ---- Standing step (matches Sim2RealController._standing_step) ---- + # ---- Standing step (matches sim2real robot-control standing step) ---- def _standing_step(self) -> np.ndarray: """One step of RL policy standing inference. Returns target joint positions.""" diff --git a/scripts/setup/download_somehand_l6_assets.sh b/scripts/setup/download_somehand_l6_assets.sh index 61835102..5729e731 100755 --- a/scripts/setup/download_somehand_l6_assets.sh +++ b/scripts/setup/download_somehand_l6_assets.sh @@ -3,171 +3,12 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +SOMEHAND_DIR="${PROJECT_ROOT}/third_party/somehand" -SOURCE="modelscope" -REPO_ID="" -DEST="${PROJECT_ROOT}/third_party/somehand/assets/mjcf" -CACHE_DIR="${PROJECT_ROOT}/data/somehand_assets_cache" - -usage() { - cat <<'EOF' -Download somehand LinkerHand L6 bi-hand MJCF assets. 
- -Usage: - scripts/setup/download_somehand_l6_assets.sh - scripts/setup/download_somehand_l6_assets.sh --source huggingface - scripts/setup/download_somehand_l6_assets.sh --dest third_party/somehand/assets/mjcf - -Options: - --source modelscope|huggingface Download backend (default: modelscope) - --repo-id REPO Override asset repo id - --dest PATH Destination mjcf directory - --cache-dir PATH Download cache directory - -h, --help Show this help -EOF -} - -while [[ $# -gt 0 ]]; do - case "$1" in - --source) - SOURCE="$2" - shift 2 - ;; - --repo-id) - REPO_ID="$2" - shift 2 - ;; - --dest) - DEST="$2" - shift 2 - ;; - --cache-dir) - CACHE_DIR="$2" - shift 2 - ;; - -h|--help) - usage - exit 0 - ;; - *) - echo "Unknown argument: $1" >&2 - usage >&2 - exit 2 - ;; - esac -done - -if [[ "${SOURCE}" != "modelscope" && "${SOURCE}" != "huggingface" ]]; then - echo "--source must be modelscope or huggingface" >&2 - exit 2 +if [[ ! -f "${SOMEHAND_DIR}/scripts/setup/download_assets.py" ]]; then + echo "somehand submodule is not initialized. 
Run: git submodule update --init third_party/somehand" >&2 + exit 1 fi -python - "$PROJECT_ROOT" "$SOURCE" "$REPO_ID" "$DEST" "$CACHE_DIR" <<'PY' -from __future__ import annotations - -import shutil -import subprocess -import sys -import tarfile -from pathlib import Path - -project_root = Path(sys.argv[1]).resolve() -source = sys.argv[2] -repo_id = sys.argv[3] or ("BingqianWu/somehand-assets" if source == "modelscope" else "12e21/somehand-assets") -dest = Path(sys.argv[4]).expanduser() -if not dest.is_absolute(): - dest = (project_root / dest).resolve() -cache_dir = Path(sys.argv[5]).expanduser() -if not cache_dir.is_absolute(): - cache_dir = (project_root / cache_dir).resolve() - -hands = ("linkerhand_l6_left", "linkerhand_l6_right") -archive_pattern = "archives/mjcf_assets.tar.gz" -repo_cache = cache_dir / source / repo_id.split("/")[-1] - - -def ensure_package(name: str): - try: - return __import__(name) - except ImportError: - subprocess.check_call([sys.executable, "-m", "pip", "install", name]) - return __import__(name) - - -def remove_path(path: Path) -> None: - if path.is_dir() and not path.is_symlink(): - shutil.rmtree(path) - elif path.exists() or path.is_symlink(): - path.unlink() - - -def copy_dir(src: Path, dst: Path) -> None: - remove_path(dst) - dst.parent.mkdir(parents=True, exist_ok=True) - shutil.copytree(src, dst) - - -def download(patterns: list[str]) -> None: - repo_cache.mkdir(parents=True, exist_ok=True) - if source == "huggingface": - hub = ensure_package("huggingface_hub") - hub.snapshot_download( - repo_id=repo_id, - repo_type="model", - local_dir=str(repo_cache), - allow_patterns=patterns, - ) - return - modelscope = ensure_package("modelscope") - modelscope.snapshot_download( - repo_id, - repo_type="model", - local_dir=str(repo_cache), - allow_patterns=patterns, - allow_file_pattern=patterns, - ) - - -def safe_extract_l6_from_archive(archive_path: Path) -> None: - dest.mkdir(parents=True, exist_ok=True) - wanted_prefixes = 
{f"mjcf/{hand}/" for hand in hands} - with tarfile.open(archive_path, "r:*") as tar: - members = [] - for member in tar.getmembers(): - path = Path(member.name) - if path.is_absolute() or ".." in path.parts: - raise ValueError(f"Unsafe archive member path: {member.name}") - normalized = member.name.lstrip("./") - if any(normalized.startswith(prefix) for prefix in wanted_prefixes): - members.append(member) - if not members: - raise FileNotFoundError(f"No LinkerHand L6 assets found in {archive_path}") - tmp = dest.parent / ".somehand_l6_extracting" - remove_path(tmp) - tmp.mkdir(parents=True, exist_ok=True) - tar.extractall(tmp, members=members) - for hand in hands: - src = tmp / "mjcf" / hand - if not src.exists(): - raise FileNotFoundError(f"Archive missing mjcf/{hand}") - copy_dir(src, dest / hand) - print(f" archive:{hand} -> {dest / hand}") - remove_path(tmp) - - -print(f"Downloading somehand L6 assets from {source}:{repo_id}") -print(f"Destination: {dest}") - -download([archive_pattern]) -archive = repo_cache / archive_pattern -if not archive.exists(): - raise FileNotFoundError(f"Downloaded repo is missing {archive_pattern}") -safe_extract_l6_from_archive(archive) - -for hand in hands: - model = dest / hand / "model.xml" - if not model.exists(): - raise FileNotFoundError(f"Expected model file not found: {model}") - -print("Done.") -PY +cd "${SOMEHAND_DIR}" +python scripts/setup/download_assets.py --only mjcf "$@" diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index e2c49810..d1e2816a 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -5,7 +5,6 @@ defaults: - _self_ policy_hz: 50.0 -sim2real_runtime: multiprocess viewers: "none" # Optional: set viewers=retarget to show the retargeted reference input: video: @@ -19,12 +18,12 @@ reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] reference_debug_log: false -multiprocess: +runtime: host: 127.0.0.1 base_port: 
39700 start_method: spawn shutdown_timeout_s: 3.0 - pico_io_hz: 120.0 + pico_input_hz: 120.0 hand_worker_hz: 120.0 retarget_idle_sleep_s: 0.001 video_slots: 3 @@ -43,28 +42,28 @@ standing_return_kp_ramp_floor_ratio: 0.5 joint_vel_limit: 10.0 # Optional LinkerHand L6 control from Pico controller grip/trigger or VR hand pose. -dexterous_hand: - mode: "off" # off | gripper | vr_hand_pose - hand_joint: L6 - hand_type: both - left_can: can0 - right_can: can1 - modbus: "None" - rate: 30.0 - frame_timeout: 0.3 - trigger_deadzone: 0.05 - deadman_threshold: 0.5 - thumb_yaw_center: 10 - speed: [50, 50, 50, 50, 50, 50] # gripper mode; vr_hand_pose always uses max speed - open_pose: [250, 10, 250, 250, 250, 250] - close_pose: [79, 10, 0, 0, 0, 0] - print_input: false +hands: + enabled: false + driver: linkerhand_l6 + mode: gripper # gripper | vr_hand_pose + sides: [left, right] + rate_hz: 30.0 + frame_timeout_s: 0.3 + linkerhand_l6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + thumb_yaw_center: 10 + speed: [50, 50, 50, 50, 50, 50] + open_pose: [250, 10, 250, 250, 250, 250] + close_pose: [79, 10, 0, 0, 0, 0] + print_input: false somehand: config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml - sdk_root: third_party/linkerhand-python-sdk # Low-latency vr_hand_pose path. This favors response speed over smoothing. 
- rate: 60.0 - threaded: true + rate_hz: 60.0 max_iterations: 12 temporal_filter_alpha: 1.0 output_alpha: 1.0 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 2b4a3277..3a84fbd9 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -5,7 +5,6 @@ defaults: - _self_ policy_hz: 50.0 -sim2real_runtime: auto # auto | single_process | multiprocess; auto uses multiprocess for Pico4 viewers: "none" # Optional: set viewers=retarget to show the retargeted reference retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 @@ -18,12 +17,12 @@ reference_debug_log: false playback: pause_on_end: true -multiprocess: +runtime: host: 127.0.0.1 base_port: 39700 start_method: spawn shutdown_timeout_s: 3.0 - pico_io_hz: 120.0 + pico_input_hz: 120.0 hand_worker_hz: 120.0 retarget_idle_sleep_s: 0.001 video_slots: 3 @@ -42,28 +41,28 @@ standing_return_kp_ramp_floor_ratio: 0.5 joint_vel_limit: 10.0 # Optional LinkerHand L6 control. Use only with input.provider=pico4. 
-dexterous_hand: - mode: "off" # off | gripper | vr_hand_pose - hand_joint: L6 - hand_type: both - left_can: can0 - right_can: can1 - modbus: "None" - rate: 30.0 - frame_timeout: 0.3 - trigger_deadzone: 0.05 - deadman_threshold: 0.5 - thumb_yaw_center: 10 - speed: [50, 50, 50, 50, 50, 50] # gripper mode; vr_hand_pose always uses max speed - open_pose: [250, 10, 250, 250, 250, 250] - close_pose: [79, 10, 0, 0, 0, 0] - print_input: false +hands: + enabled: false + driver: linkerhand_l6 + mode: gripper # gripper | vr_hand_pose + sides: [left, right] + rate_hz: 30.0 + frame_timeout_s: 0.3 + linkerhand_l6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + thumb_yaw_center: 10 + speed: [50, 50, 50, 50, 50, 50] # gripper mode; vr_hand_pose always uses max speed + open_pose: [250, 10, 250, 250, 250, 250] + close_pose: [79, 10, 0, 0, 0, 0] + print_input: false somehand: config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml - sdk_root: third_party/linkerhand-python-sdk # Low-latency vr_hand_pose path. This favors response speed over smoothing. - rate: 60.0 - threaded: true + rate_hz: 60.0 max_iterations: 12 temporal_filter_alpha: 1.0 output_alpha: 1.0 diff --git a/teleopit/runtime/reference_config.py b/teleopit/runtime/reference_config.py index a84f7f23..cf0e7ab5 100644 --- a/teleopit/runtime/reference_config.py +++ b/teleopit/runtime/reference_config.py @@ -1,7 +1,7 @@ """Shared reference-window / realtime-buffer configuration. Parsed once from the top-level config and consumed by both -``SimulationLoop`` and ``Sim2RealController``. +``SimulationLoop`` and the process-isolated sim2real runtime. 
""" from __future__ import annotations diff --git a/teleopit/sim/reference_utils.py b/teleopit/sim/reference_utils.py index 8f862eae..9b5807ad 100644 --- a/teleopit/sim/reference_utils.py +++ b/teleopit/sim/reference_utils.py @@ -1,6 +1,6 @@ """Shared reference-window utilities used by both offline simulation and sim2real. -These are pure functions extracted from ``SimLoop`` and ``Sim2RealController`` to avoid +These are pure functions extracted from ``SimLoop`` and sim2real runtime code to avoid code duplication. Each function takes explicit arguments instead of ``self``. """ diff --git a/teleopit/sim2real/__init__.py b/teleopit/sim2real/__init__.py index 7fe810de..9f8f6921 100644 --- a/teleopit/sim2real/__init__.py +++ b/teleopit/sim2real/__init__.py @@ -1,9 +1,9 @@ -from teleopit.sim2real.controller import Sim2RealController +from teleopit.sim2real.mp import Sim2RealRuntime from teleopit.sim2real.unitree_g1 import UnitreeG1Robot from teleopit.sim2real.remote import UnitreeRemote, Button __all__ = [ - "Sim2RealController", + "Sim2RealRuntime", "UnitreeG1Robot", "UnitreeRemote", "Button", diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py deleted file mode 100644 index 5c853794..00000000 --- a/teleopit/sim2real/controller.py +++ /dev/null @@ -1,1105 +0,0 @@ -"""Sim2Real controller -- state machine + dual-mode control loop. 
- -Supports two operating modes for a physical Unitree G1: -- **Standing**: RL policy maintains balance with fixed default-pose reference -- **Mocap**: RL policy tracks retargeted motion commands - -State machine: - IDLE ──Start──▶ STANDING ──Y──▶ MOCAP ──X──▶ STANDING - Any ──L1+R1──▶ DAMPING ──Start──▶ STANDING -""" - -from __future__ import annotations - -import logging -import time -from enum import Enum -from pathlib import Path -from typing import Any - -import numpy as np -from numpy.typing import NDArray - -from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM -from teleopit.controllers.observation import ( - VelCmdObservationBuilder, - align_motion_qpos_yaw, -) -from teleopit.controllers.rl_policy import RLPolicyController -from teleopit.inputs.bvh_provider import BVHInputProvider -from teleopit.inputs.pico4_provider import Pico4InputProvider -from teleopit.inputs.pico_video import PicoVideoRuntime, parse_pico_video_config -from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType, RealtimeInputPacket -from teleopit.retargeting.core import RetargetingModule -from teleopit.runtime.common import cfg_get, parse_viewers -from teleopit.runtime.reference_config import parse_reference_config -from teleopit.runtime.factory import build_sim2real_mocap_components -from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState -from teleopit.runtime.offline_playback import OfflinePlaybackController -from teleopit.sim.reference_motion import OfflineReferenceMotion -from teleopit.sim.reference_timeline import ReferenceTimeline, ReferenceWindow, ReferenceWindowBuilder -from teleopit.sim.reference_utils import ( - build_offline_reference_window, - build_static_reference_window, - obs_builder_requires_reference_window, -) -from teleopit.sim.realtime_utils import RealtimeReferenceManager -from teleopit.sim.viewer_subprocess import start_robot_viewer -from teleopit.sim2real.dexterous_hand import build_linkerhand_runtime -from 
teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor -from teleopit.sim2real.remote import UnitreeRemote -from teleopit.sim2real.safety import Sim2RealSafetyManager -from teleopit.sim2real.unitree_g1 import UnitreeG1Robot - -logger = logging.getLogger(__name__) - -Float32Array = NDArray[np.float32] -Float64Array = NDArray[np.float64] - - -class RobotMode(Enum): - IDLE = "idle" # Script waiting, robot controlled by remote - STANDING = "standing" # Debug mode, RL policy holds default pose - MOCAP = "mocap" # Debug mode, RL policy tracks motion commands - DAMPING = "damping" # Emergency stop / recovery - - -class _LoopTimingReporter: - """Aggregate best-effort control-loop timing stats and emit periodic logs.""" - - def __init__(self, *, target_period_s: float, log_interval_s: float = 1.0) -> None: - self._target_period_s = float(target_period_s) - self._log_interval_s = float(log_interval_s) - self._window_start_s: float | None = None - self._loop_ms: list[float] = [] - self._work_ms: list[float] = [] - self._pico_age_ms: list[float] = [] - self._overrun_count = 0 - - def record( - self, - *, - loop_start_s: float, - work_elapsed_s: float, - cycle_elapsed_s: float, - pico_age_s: float | None, - ) -> None: - if self._window_start_s is None: - self._window_start_s = float(loop_start_s) - - self._loop_ms.append(float(cycle_elapsed_s) * 1000.0) - self._work_ms.append(float(work_elapsed_s) * 1000.0) - if pico_age_s is not None: - self._pico_age_ms.append(float(pico_age_s) * 1000.0) - if cycle_elapsed_s > self._target_period_s + 1e-9: - self._overrun_count += 1 - - if loop_start_s - self._window_start_s >= self._log_interval_s: - self._emit(loop_start_s) - - def _emit(self, end_s: float) -> None: - sample_count = len(self._loop_ms) - if sample_count <= 0: - self._reset(end_s) - return - - loop_summary = self._summarize(self._loop_ms) - work_summary = self._summarize(self._work_ms) - message = ( - "Timing stats | samples=%d window=%.1fs | " - "loop_ms 
p50=%.2f p95=%.2f p99=%.2f max=%.2f overrun=%d/%d | " - "work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" - ) - args: list[object] = [ - sample_count, - end_s - float(self._window_start_s), - loop_summary[0], - loop_summary[1], - loop_summary[2], - loop_summary[3], - self._overrun_count, - sample_count, - work_summary[0], - work_summary[1], - work_summary[2], - work_summary[3], - ] - if self._pico_age_ms: - pico_summary = self._summarize(self._pico_age_ms) - message += " | pico_age_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" - args.extend([pico_summary[0], pico_summary[1], pico_summary[2], pico_summary[3]]) - logger.info(message, *args) - self._reset(end_s) - - def _reset(self, window_start_s: float) -> None: - self._window_start_s = float(window_start_s) - self._loop_ms.clear() - self._work_ms.clear() - self._pico_age_ms.clear() - self._overrun_count = 0 - - @staticmethod - def _summarize(samples: list[float]) -> tuple[float, float, float, float]: - values = np.asarray(samples, dtype=np.float64) - if values.size <= 0: - return 0.0, 0.0, 0.0, 0.0 - p50, p95, p99 = np.percentile(values, [50.0, 95.0, 99.0]) - return float(p50), float(p95), float(p99), float(np.max(values)) - - -def _parse_sim2real_viewers(cfg: Any) -> set[str]: - viewers = parse_viewers(cfg) - unsupported = viewers.difference({"retarget"}) - if unsupported: - raise ValueError( - f"Sim2real supports only the optional 'retarget' viewer; got unsupported viewers {sorted(unsupported)}. " - "Use viewers=retarget or viewers=none." 
- ) - return viewers - - -class _Sim2RealRetargetViewer: - def __init__(self, *, xml_path: str | None, enabled: bool) -> None: - self._entry: tuple[Any, Any, Any, Any] | None = None - if not enabled: - return - if not xml_path: - raise ValueError("Sim2real retarget viewer requires robot.xml_path to be set.") - self._entry = start_robot_viewer( - xml_path, - FULL_QPOS_DIM, - True, - "Retarget", - 900, - 50, - ) - - def write(self, qpos: Float64Array) -> None: - if self._entry is None: - return - _, arr, alive, _ = self._entry - if not alive.value: - return - qpos = np.asarray(qpos, dtype=np.float64).reshape(-1) - if qpos.shape[0] < FULL_QPOS_DIM: - return - with arr.get_lock(): - arr[:FULL_QPOS_DIM] = qpos[:FULL_QPOS_DIM].tolist() - - def shutdown(self) -> None: - if self._entry is None: - return - proc, _, _, shutdown = self._entry - shutdown.set() - proc.join(timeout=3) - if proc.is_alive(): - proc.terminate() - self._entry = None - - -class Sim2RealController: - """G1 real-robot controller -- standing/mocap dual mode with state machine. - - Standing mode: enter debug mode, RL policy maintains balance at default pose. - Mocap mode: RL policy tracks retargeted motion commands. - Both modes share the same RL policy inference pipeline. 
- """ - - def __init__(self, cfg: Any) -> None: - self.cfg = cfg - self.mode = RobotMode.IDLE - - self.policy_hz: float = float(cfg_get(cfg, "policy_hz", 50.0)) - self._project_root = Path(__file__).resolve().parent.parent.parent - - self._init_components(cfg) - self._init_reference_config(cfg) - self._safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) - self._standing_return_ramp_duration = float(cfg_get(cfg, "standing_return_ramp_duration", 0.5)) - self._standing_return_kp_ramp_floor_ratio = float( - cfg_get(cfg, "standing_return_kp_ramp_floor_ratio", 0.5) - ) - - logger.info( - "Sim2RealController ready | mode=IDLE | policy_hz=%.0f", - self.policy_hz, - ) - - def _init_components(self, cfg: Any) -> None: - """Build robot hardware and mocap pipeline components.""" - real_cfg = cfg_get(cfg, "real_robot") - self.robot = UnitreeG1Robot(real_cfg) - self.remote = UnitreeRemote() - - robot_cfg = cfg_get(cfg, "robot") - mocap_components = build_sim2real_mocap_components( - cfg, - self._project_root, - controller_cls=RLPolicyController, - obs_builder_cls=VelCmdObservationBuilder, - bvh_input_cls=BVHInputProvider, - pico4_input_cls=Pico4InputProvider, - retargeter_cls=RetargetingModule, - ) - self.input_provider = mocap_components.input_provider - self.retargeter = mocap_components.retargeter - self.policy = mocap_components.controller - self.obs_builder = mocap_components.obs_builder - self._video_runtime = PicoVideoRuntime( - provider=self.input_provider, - config=parse_pico_video_config(cfg_get(cfg, "input", {})), - mode="sim2real", - ) - self._hand_runtime = build_linkerhand_runtime(cfg, self.input_provider) - self._offline_reference: OfflineReferenceMotion | None = None - self._offline_playback: OfflinePlaybackController | None = None - if hasattr(self.input_provider, "__len__") and hasattr(self.input_provider, "get_frame_by_index"): - playback_cfg = cfg_get(cfg, "playback", {}) - self._offline_reference = 
OfflineReferenceMotion(self.input_provider, self.retargeter) - self._offline_playback = OfflinePlaybackController( - duration_s=self._offline_reference.duration_s, - step_dt_s=1.0 / self.policy_hz, - pause_on_end=bool(cfg_get(playback_cfg, "pause_on_end", True)), - ) - if not bool(getattr(self.policy, "_multi_input", False)): - raise ValueError( - "Sim2real requires an ONNX policy with dual inputs ('obs' and 'obs_history')." - ) - - # Default standing pose (29-DOF) - self.default_angles = np.asarray( - cfg_get(robot_cfg, "default_angles"), dtype=np.float32 - ) - self.num_actions: int = int(cfg_get(robot_cfg, "num_actions", NUM_JOINTS)) - - # Standing mode reference qpos - self._standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - self._standing_qpos[3] = 1.0 # identity quaternion w=1 - self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) - - # Policy state (shared by STANDING and MOCAP) - self._last_action: Float32Array = np.zeros(self.num_actions, dtype=np.float32) - self._last_retarget_qpos: Float64Array | None = None - self._mocap_reentry_armed: bool = False - self._mocap_session = MocapSessionManager() - self._last_commanded_motion_qpos: Float64Array | None = None - self._last_mocap_hold_reason: str | None = None - self._viewers = _parse_sim2real_viewers(cfg) - self._retarget_viewer = _Sim2RealRetargetViewer( - xml_path=str(cfg_get(robot_cfg, "xml_path", "")) if "retarget" in self._viewers else None, - enabled="retarget" in self._viewers, - ) - - def _init_reference_config(self, cfg: Any) -> None: - """Parse reference-window / realtime-buffer configuration.""" - provider_fps = float(getattr(self.input_provider, "fps", 30.0)) - self._ref_cfg = parse_reference_config(cfg, provider_fps=provider_fps) - rc = self._ref_cfg - - self._reference_window_builder = ReferenceWindowBuilder( - policy_dt_s=1.0 / self.policy_hz, - reference_steps=cfg_get(cfg, "reference_steps", [0]), - ) - if not rc.retarget_buffer_enabled and 
self._reference_window_builder.requires_timeline: - raise ValueError( - "Non-zero reference_steps require retarget_buffer_enabled=true in sim2real so " - "the realtime reference timeline can sample future/history horizons." - ) - if rc.retarget_buffer_enabled: - self._reference_window_builder.validate_runtime_support( - delay_s=rc.reference_delay_s, - window_s=rc.retarget_buffer_window_s, - config_label="Sim2Real reference timeline", - ) - self._reference_timeline: ReferenceTimeline | None = ( - ReferenceTimeline(window_s=rc.retarget_buffer_window_s) - if rc.retarget_buffer_enabled - else None - ) - self._reference_manager: RealtimeReferenceManager | None = ( - RealtimeReferenceManager( - reference_window_builder=self._reference_window_builder, - warmup_steps=rc.realtime_buffer_warmup_steps, - ) - if self._reference_timeline is not None - else None - ) - self._ref_proc = Sim2RealReferenceProcessor( - obs_builder=self.obs_builder, - policy=self.policy, - policy_hz=self.policy_hz, - num_actions=self.num_actions, - reference_velocity_smoothing_alpha=rc.reference_velocity_smoothing_alpha, - reference_anchor_velocity_smoothing_alpha=rc.reference_anchor_velocity_smoothing_alpha, - ) - self._last_live_packet_seq = -1 - - # Mocap switch safety - mocap_sw = cfg_get(cfg, "mocap_switch", {}) - self._check_frames: int = int(cfg_get(mocap_sw, "check_frames", 10)) - - # ------------------------------------------------------------------ - # Main control loop - # ------------------------------------------------------------------ - - def run(self) -> None: - """Main control loop at policy_hz.""" - logger.info( - "Control loop started | mode=IDLE | press Start to enter STANDING" - ) - dt = 1.0 / self.policy_hz - timing = _LoopTimingReporter(target_period_s=dt) - - try: - self._video_runtime.start() - self._hand_runtime.start() - while True: - t0 = time.monotonic() - self._video_runtime.tick() - - # 1. 
Read remote state - remote_bytes = self.robot.get_wireless_remote() - self.remote.update(remote_bytes) - - # 2. Emergency stop (highest priority) - if self.remote.LB.pressed and self.remote.RB.pressed: - if self.mode != RobotMode.DAMPING: - logger.warning("EMERGENCY STOP (L1+R1)") - self._enter_damping() - self._tick_dexterous_hand() - else: - # 3. Mode transitions - self._handle_transitions() - - # 5. Execute current mode - if self.mode == RobotMode.STANDING: - self._standing_step() - elif self.mode == RobotMode.MOCAP: - self._mocap_step() - - self._tick_dexterous_hand() - - work_elapsed_s = time.monotonic() - t0 - cycle_elapsed_s = self._sleep_until(t0, dt) - timing.record( - loop_start_s=t0, - work_elapsed_s=work_elapsed_s, - cycle_elapsed_s=cycle_elapsed_s, - pico_age_s=self._sample_pico_frame_age_s(), - ) - - except KeyboardInterrupt: - logger.info("KeyboardInterrupt -- shutting down") - - # ------------------------------------------------------------------ - # Mode execution - # ------------------------------------------------------------------ - - def _standing_step(self) -> None: - """Standing mode: feed fixed default-pose reference to RL policy.""" - robot_state = self.robot.get_state() - - qpos = self._standing_qpos.copy() - - # Standing → zero joint velocity reference - motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) - motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) - - reference_window = None - if obs_builder_requires_reference_window(self.obs_builder): - reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) - - obs = self._ref_proc.build_observation( - robot_state=robot_state, - motion_qpos=motion_qpos, - motion_joint_vel=motion_joint_vel, - last_action=self._last_action, - anchor_lin_vel_w=np.zeros(3, dtype=np.float32), - anchor_ang_vel_w=np.zeros(3, dtype=np.float32), - reference_window=reference_window, - ) - obs = self._ref_proc.validate_observation(obs) - - 
action = self.policy.compute_action(obs) - target_dof_pos = self.policy.get_target_dof_pos(action) - target_dof_pos = self._safety.clip_to_joint_limits(target_dof_pos) - - self._safety.send_positions(target_dof_pos) - - self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - self._last_retarget_qpos = qpos.copy() - self._last_commanded_motion_qpos = qpos.copy() - self._write_retarget_viewer(qpos) - - def _mocap_step(self) -> None: - """Mocap mode: input provider -> retarget -> policy -> update LowCmd targets.""" - if self._offline_reference is not None: - self._offline_mocap_step() - return - - if not self.input_provider.is_available(): - self._hold_mocap_reference("input provider unavailable") - return - - try: - packet = self._fetch_realtime_input_packet() - except (TimeoutError, RuntimeError) as exc: - self._hold_mocap_reference( - "input provider error", - detail=f"{type(exc).__name__}: {exc}", - ) - return - - self._handle_mocap_control_events(packet.control_events) - if self._mocap_session.state == MocapSessionState.PAUSED: - self._paused_mocap_step() - return - - human_frame = packet.frame - frame_timestamp = float(packet.timestamp_s) - frame_seq = int(packet.seq) - - reference_window: ReferenceWindow | None = None - if self._reference_timeline is not None: - if int(frame_seq) != self._last_live_packet_seq: - retargeted = self.retargeter.retarget(human_frame) - self._reference_timeline.append( - self._ref_proc.retarget_to_qpos(retargeted), - float(frame_timestamp), - ) - if self._reference_manager is not None: - self._reference_manager.note_realtime_frame() - self._last_live_packet_seq = int(frame_seq) - if self._reference_manager is None: - raise RuntimeError("Realtime reference manager must be initialized when using reference_timeline") - if not self._reference_manager.warmup_done: - return - reference_window, reference_diag = self._reference_manager.sample( - self._reference_timeline, - time.monotonic() - self._ref_cfg.reference_delay_s, - 
) - if self._ref_cfg.reference_debug_log and any(reference_window.fallback_mask()): - logger.warning( - "Reference timeline fallback | buffer_len=%d | steps=%s | modes=%s", - len(self._reference_timeline), - list(reference_window.reference_steps), - list(reference_window.modes()), - ) - reference_qpos = reference_window.current_sample().qpos - else: - retargeted = self.retargeter.retarget(human_frame) - reference_qpos = self._ref_proc.retarget_to_qpos(retargeted) - - robot_state = self.robot.get_state() - self._execute_mocap_pipeline(reference_qpos, robot_state, reference_window) - - def _offline_mocap_step(self) -> None: - if self._offline_reference is None or self._offline_playback is None: - raise RuntimeError("Offline playback step requires an offline reference motion") - - if self._mocap_session.state == MocapSessionState.PAUSED: - self._paused_mocap_step() - return - - sample_time_s = self._offline_playback.current_time_s - sampled = self._offline_reference.sample(sample_time_s) - if sampled is None: - self._hold_completed_offline_playback(self._resolve_mocap_hold_qpos()) - self._paused_mocap_step() - return - - reference_window: ReferenceWindow | None = None - if obs_builder_requires_reference_window(self.obs_builder): - reference_window = build_offline_reference_window( - self._offline_reference, - sample_time_s, - self._reference_window_builder, - self.policy_hz, - ) - - reference_qpos = np.asarray(sampled.qpos, dtype=np.float64).copy() - robot_state = self.robot.get_state() - self._execute_mocap_pipeline(reference_qpos, robot_state, reference_window) - - if self._offline_playback.advance(): - self._hold_completed_offline_playback(self._last_commanded_motion_qpos) - - def _execute_mocap_pipeline( - self, - reference_qpos: Float64Array, - robot_state: object, - reference_window: ReferenceWindow | None, - ) -> None: - """Shared mocap control pipeline: align → infer → send.""" - reference_qpos = self._ref_proc.align_reference_yaw(reference_qpos, 
robot_state=robot_state) - qpos = reference_qpos.copy() - - # Compute joint velocities via finite difference - if qpos.shape[0] < 7 + self.num_actions: - raise ValueError( - f"Retargeted qpos too short: {qpos.shape[0]} (need >= {7 + self.num_actions})" - ) - motion_joint_pos = np.asarray(qpos[7:7 + self.num_actions], dtype=np.float32) - if self._last_retarget_qpos is None: - raw_motion_joint_vel = np.zeros((self.num_actions,), dtype=np.float32) - else: - prev_joint_pos = np.asarray(self._last_retarget_qpos[7:7 + self.num_actions], dtype=np.float32) - raw_motion_joint_vel = (motion_joint_pos - prev_joint_pos) * np.float32(self.policy_hz) - motion_joint_vel = self._ref_proc.apply_joint_vel_smoothing(raw_motion_joint_vel) - - # Compute anchor velocities. - anchor_lin_vel_w = np.zeros(3, dtype=np.float32) - anchor_ang_vel_w = np.zeros(3, dtype=np.float32) - if not obs_builder_requires_reference_window(self.obs_builder): - raw_anchor_lin_vel_w, raw_anchor_ang_vel_w = self._ref_proc.compute_anchor_velocities(reference_qpos) - anchor_lin_vel_w, anchor_ang_vel_w = self._ref_proc.apply_anchor_vel_smoothing( - raw_anchor_lin_vel_w, raw_anchor_ang_vel_w, - ) - - # Build observation and run policy - motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) - obs = self._ref_proc.build_observation( - robot_state=robot_state, - motion_qpos=motion_qpos, - motion_joint_vel=motion_joint_vel, - last_action=self._last_action, - anchor_lin_vel_w=anchor_lin_vel_w, - anchor_ang_vel_w=anchor_ang_vel_w, - reference_window=reference_window, - ) - obs = self._ref_proc.validate_observation(obs) - - action = self.policy.compute_action(obs) - target_dof_pos = self.policy.get_target_dof_pos(action) - target_dof_pos = self._safety.clip_to_joint_limits(target_dof_pos) - self._safety.send_positions(target_dof_pos) - - # Update state - self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - self._last_retarget_qpos = qpos.copy() - self._ref_proc.last_reference_qpos = 
reference_qpos.copy() - self._last_commanded_motion_qpos = qpos.copy() - self._last_mocap_hold_reason = None - self._write_retarget_viewer(qpos) - - # ------------------------------------------------------------------ - # State machine transitions - # ------------------------------------------------------------------ - - def _handle_transitions(self) -> None: - """Handle remote-triggered mode transitions.""" - if self.mode == RobotMode.IDLE: - if self.remote.start.on_pressed: - logger.info("Start pressed (from IDLE)") - self._enter_standing() - - elif self.mode == RobotMode.STANDING: - reentry_request = self._mocap_reentry_armed and self.remote.Y.pressed - if self.remote.Y.on_pressed or reentry_request: - if self._can_switch_to_mocap(): - if reentry_request and not self.remote.Y.on_pressed: - logger.info("Y held after STANDING return -> re-entering MOCAP") - else: - logger.info("Y pressed -> entering MOCAP") - self._transition_to_mocap() - else: - logger.warning("Cannot switch to MOCAP -- input check failed") - - elif self.mode == RobotMode.MOCAP: - if self.remote.B.on_pressed and self._offline_playback is not None: - logger.info("B pressed -> replaying offline motion from start") - self._restart_offline_playback() - return - if self.remote.A.on_pressed: - if self._mocap_session.state == MocapSessionState.PAUSED: - if self._offline_playback is not None and self._offline_playback.finished: - logger.info("Playback already ended; press B to replay from the start") - else: - logger.info("A pressed -> resuming playback") - self._resume_paused_mocap() - else: - logger.info("A pressed -> pausing playback") - self._pause_active_mocap() - return - if self.remote.X.on_pressed: - logger.info("X pressed -> returning to STANDING") - self._enter_standing() - - elif self.mode == RobotMode.DAMPING: - if self.remote.start.on_pressed: - logger.info("Start pressed (from DAMPING)") - self._enter_standing() - - # ------------------------------------------------------------------ - # 
Enter STANDING (from IDLE, MOCAP, or DAMPING) - # ------------------------------------------------------------------ - - def _enter_standing(self) -> None: - """Enter standing mode: debug mode + RL policy holds default pose. - - Works from IDLE, MOCAP (already in debug mode), or DAMPING. - """ - prev_mode = self.mode - already_in_debug = self.mode in (RobotMode.STANDING, RobotMode.MOCAP) - - if not already_in_debug: - logger.info("Entering debug mode...") - ok = self.robot.enter_debug_mode() - if not ok: - logger.error("Failed to enter debug mode -- staying in %s", self.mode.value) - return - time.sleep(0.5) - - state = self.robot.get_state() - if prev_mode != RobotMode.MOCAP: - # Lock joints to current position during initial low-level takeover. - logger.info("Locking joints to current position...") - self.robot.lock_all_joints() - time.sleep(0.3) - - # Episode-reset semantics: reference = current robot state, full policy reset. - # This matches training where robot is teleported to reference position at - # episode start, so policy sees reference ≈ robot state with clean history. - init_qpos = self._build_robot_state_qpos(state) - self._last_retarget_qpos = init_qpos - self._ref_proc.last_reference_qpos = None - self._mocap_session.reset() - self._last_commanded_motion_qpos = None - self._set_default_standing_reference(state) - - # Always do a full policy reset (episode-reset semantics) to ensure - # the TemporalCNN history is clean and action-state causality holds. - self._reset_policy_state() - - # Kp ramp: gradually increase PD gains to avoid torque spike. - # Unlike position ramping, this does not alter policy targets. 
- if prev_mode == RobotMode.MOCAP: - self._safety.start_kp_ramp( - duration_s=self._standing_return_ramp_duration, - floor_ratio=self._standing_return_kp_ramp_floor_ratio, - ) - else: - self._safety.start_kp_ramp() - - self._mocap_reentry_armed = prev_mode == RobotMode.MOCAP - - self.mode = RobotMode.STANDING - self._deactivate_dexterous_hand() - logger.info("Mode -> STANDING (RL policy maintaining balance at default pose)") - - # ------------------------------------------------------------------ - # STANDING -> MOCAP - # ------------------------------------------------------------------ - - def _can_switch_to_mocap(self) -> bool: - """Verify input signal is stable and values are reasonable.""" - if not self.input_provider.is_available(): - logger.warning("Mocap check: input provider not available") - return False - - if self._offline_reference is not None: - frame_count = min(self._check_frames, self._offline_reference.num_frames) - valid_count = 0 - for frame_index in range(frame_count): - try: - frame = self.input_provider.get_frame_by_index(frame_index) - except (IndexError, RuntimeError, ValueError): - return False - if self._ref_proc.frame_is_valid(frame): - valid_count += 1 - else: - break - if valid_count >= frame_count: - return True - logger.warning("Mocap check: only %d/%d valid offline frames", valid_count, frame_count) - return False - - has_frame = getattr(self.input_provider, "has_frame", None) - if callable(has_frame): - try: - if not bool(has_frame()): - logger.warning("Mocap check: realtime input has no frame available yet") - return False - except Exception: - logger.warning("Mocap check: failed to query realtime input availability") - return False - - valid_count = 0 - for _ in range(self._check_frames + 5): - try: - frame = self.input_provider.get_frame() - except (TimeoutError, RuntimeError): - return False - - if self._ref_proc.frame_is_valid(frame): - valid_count += 1 - else: - valid_count = 0 - - if valid_count >= self._check_frames: - 
return True - - time.sleep(0.02) - - logger.warning("Mocap check: only %d/%d valid frames", valid_count, self._check_frames) - return False - - def _transition_to_mocap(self) -> None: - """Switch from STANDING -> MOCAP. - - Episode-reset + reference realignment. The policy/reference state is - reset like pause/resume so the first mocap frame is anchored to the - current robot pose and starts with zero inferred reference velocity. - """ - state = self.robot.get_state() - resume_qpos = self._build_resume_alignment_qpos(self._standing_qpos, state) - self._last_commanded_motion_qpos = resume_qpos.copy() - self._mocap_reentry_armed = False - - # Full episode reset: clean policy state, alignment, timeline. - self._reset_policy_state() - self._last_retarget_qpos = None - self._last_commanded_motion_qpos = resume_qpos.copy() - self._ref_proc.reset_alignment(target_qpos=resume_qpos) - if self._offline_playback is not None: - self._last_retarget_qpos = resume_qpos.copy() - self._offline_playback.replay() - - self.mode = RobotMode.MOCAP - logger.info("Mode -> MOCAP (tracking motion commands)") - - # ------------------------------------------------------------------ - # Emergency stop / damping - # ------------------------------------------------------------------ - - def _enter_damping(self) -> None: - """Enter damping mode from any state.""" - if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): - logger.info("DAMPING: sending LowCmd damping...") - self.robot.set_damping() - time.sleep(0.5) - logger.info("DAMPING: exiting debug mode...") - self.robot.exit_debug_mode() - - self.mode = RobotMode.DAMPING - self._deactivate_dexterous_hand() - self._ref_proc.last_reference_qpos = None - if self._reference_timeline is not None: - self._reference_timeline.clear() - self._last_live_packet_seq = -1 - self._mocap_reentry_armed = False - self._mocap_session.reset() - self._last_commanded_motion_qpos = None - self._last_mocap_hold_reason = None - logger.info("Mode -> DAMPING (press 
Start to re-enter STANDING)") - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _reset_policy_state(self) -> None: - """Full episode-reset: clear all policy state so the TemporalCNN sees - a clean start identical to training episode reset.""" - self._last_action = np.zeros(self.num_actions, dtype=np.float32) - self._reset_mocap_reference_state() - self._ref_proc.reset_alignment() - self._mocap_session.reset() - self._last_commanded_motion_qpos = None - self._last_mocap_hold_reason = None - self.policy.reset() - self.obs_builder.reset() - - def _reset_policy_reference_state(self) -> None: - """Reset policy/reference state without resetting the retargeter.""" - self._last_action = np.zeros(self.num_actions, dtype=np.float32) - self._reset_mocap_reference_state() - self._ref_proc.reset_alignment() - self._mocap_session.reset() - self._last_commanded_motion_qpos = None - self._last_mocap_hold_reason = None - self.policy.reset() - self.obs_builder.reset() - - def _reset_mocap_reference_state(self) -> None: - """Reset mocap-specific reference state without disrupting policy observation continuity. - - Unlike ``_reset_policy_state``, this preserves ``_last_action``, the - policy history buffer, and the observation builder state so that the - TemporalCNN sees a continuous observation stream across mode switches. 
- """ - if self._reference_timeline is not None: - self._reference_timeline.clear() - if self._reference_manager is not None: - self._reference_manager.set_warmup_steps(self._ref_cfg.realtime_buffer_warmup_steps) - self._reference_manager.reset() - self._ref_proc.reset_smoothers() - self._last_live_packet_seq = -1 - self._last_mocap_hold_reason = None - - def _build_robot_state_qpos(self, state: object) -> Float64Array: - qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - base_pos = getattr(state, "base_pos", None) - if base_pos is not None: - qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] - qpos[3:7] = np.asarray(getattr(state, "quat"), dtype=np.float64).reshape(-1)[:4] - qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(getattr(state, "qpos"), dtype=np.float64).reshape(-1)[ - : self.num_actions - ] - return qpos - - def _set_default_standing_reference(self, state: object) -> None: - self._standing_qpos[:] = 0.0 - base_pos = getattr(state, "base_pos", None) - if base_pos is not None: - self._standing_qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] - self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) - align_motion_qpos_yaw(np.asarray(getattr(state, "quat"), dtype=np.float32), self._standing_qpos) - self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) - - def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: object) -> Float64Array: - qpos = self._build_robot_state_qpos(state) - if hold_qpos is not None: - qpos[0:2] = np.asarray(hold_qpos, dtype=np.float64).reshape(-1)[0:2] - base_pos = getattr(state, "base_pos", None) - if base_pos is not None: - qpos[0:2] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[0:2] - return qpos - - def _restart_offline_playback(self) -> None: - if self._offline_playback is None: - raise RuntimeError("Offline playback replay is only available for indexed BVH input") - - state = self.robot.get_state() - restart_qpos = 
np.zeros(FULL_QPOS_DIM, dtype=np.float64) - restart_qpos[3:7] = state.quat.astype(np.float64) - restart_qpos[ROOT_DIM:FULL_QPOS_DIM] = state.qpos.astype(np.float64) - - self._last_retarget_qpos = restart_qpos.copy() - self._last_commanded_motion_qpos = restart_qpos.copy() - self._offline_playback.replay() - self._reset_policy_state() - logger.info("Offline playback restarted from frame 0") - - def _hold_completed_offline_playback(self, hold_qpos: Float64Array) -> None: - if self._offline_playback is None or self._mocap_session.state == MocapSessionState.PAUSED: - return - self._offline_playback.finish() - self._mocap_session.pause(hold_qpos) - logger.info("Offline playback reached the end; press B to replay") - - def _fetch_realtime_input_packet(self) -> RealtimeInputPacket[object]: - get_realtime_input_packet = getattr(self.input_provider, "get_realtime_input_packet", None) - if callable(get_realtime_input_packet): - return get_realtime_input_packet() - - get_packet = getattr(self.input_provider, "get_frame_packet", None) - if callable(get_packet): - frame, frame_timestamp, frame_seq = get_packet() - return RealtimeInputPacket( - frame=frame, - timestamp_s=float(frame_timestamp), - seq=int(frame_seq), - control_events=(), - ) - - frame = self.input_provider.get_frame() - return RealtimeInputPacket( - frame=frame, - timestamp_s=time.monotonic(), - seq=self._last_live_packet_seq + 1, - control_events=(), - ) - - def _handle_mocap_control_events(self, control_events: tuple[ControlEvent, ...]) -> None: - for event in control_events: - if event.event_type != ControlEventType.TOGGLE_PAUSE: - continue - if self._mocap_session.state == MocapSessionState.PAUSED: - self._resume_paused_mocap() - else: - self._pause_active_mocap() - - def _pause_active_mocap(self) -> None: - # Episode-reset semantics: treat pause as a new episode starting at - # the hold pose. 
Full policy reset ensures TemporalCNN history is - # clean -- no stale frames from the previous motion that would create - # an OOD discontinuity (reference jumps from motion to static). - hold_qpos = self._resolve_mocap_hold_qpos() - self._last_retarget_qpos = hold_qpos.copy() - self._ref_proc.last_reference_qpos = hold_qpos.copy() - self._last_commanded_motion_qpos = hold_qpos.copy() - - # Reset policy/reference state (clears last_action, history, smoothers, etc.) - # without resetting the retargeter IK warm-start. Pause is a mocap-session - # control event, not a new retargeting source. - # Note: _reset_policy_state resets _mocap_session to ACTIVE, so we - # must call pause() *after* it to set the correct PAUSED state. - self._reset_policy_reference_state() - self._mocap_session.pause(hold_qpos) - if self._offline_playback is not None: - self._offline_playback.pause() - logger.info("Mocap session -> PAUSED (episode-reset)") - - def _resume_paused_mocap(self) -> None: - if self._offline_playback is not None and self._offline_playback.finished: - logger.info("Offline playback already ended; press B to replay from the start") - return - - hold_qpos = self._mocap_session.hold_qpos - if hold_qpos is None: - raise RuntimeError("Cannot resume mocap without a paused hold qpos") - state = self.robot.get_state() - resume_qpos = self._build_resume_alignment_qpos(hold_qpos, state) - - self._last_commanded_motion_qpos = resume_qpos.copy() - - # Policy/reference reset -- clean history, zero last_action, smoothers, - # timeline, alignment. Keep the retargeter IK warm-start so the first - # resumed frame is solved from the current retarget state rather than - # from the model default qpos. Also resets _mocap_session to ACTIVE. - self._reset_policy_reference_state() - self._last_retarget_qpos = None - self._last_commanded_motion_qpos = resume_qpos.copy() - - # Override warmup steps for the resume-specific buffer warmup. 
- self._ref_proc.reset_alignment(target_qpos=resume_qpos) - if self._offline_playback is not None: - self._last_retarget_qpos = resume_qpos.copy() - self._offline_playback.resume() - - logger.info("Mocap session -> ACTIVE (episode-reset + reference realignment)") - - def _resolve_mocap_hold_qpos(self) -> Float64Array: - if self._last_commanded_motion_qpos is not None: - return self._last_commanded_motion_qpos.copy() - if self._last_retarget_qpos is not None: - return np.asarray(self._last_retarget_qpos, dtype=np.float64).copy() - state = self.robot.get_state() - hold_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - hold_qpos[3:7] = np.asarray(state.quat, dtype=np.float64) - hold_qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(state.qpos, dtype=np.float64) - return hold_qpos - - def _paused_mocap_step(self) -> None: - hold_qpos = self._mocap_session.hold_qpos - if hold_qpos is None: - raise RuntimeError("Paused mocap session is missing a hold_qpos") - self._run_static_mocap_step(hold_qpos) - - def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: - robot_state = self.robot.get_state() - qpos = np.asarray(hold_qpos, dtype=np.float64).copy() - motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) - motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) - reference_window = None - if obs_builder_requires_reference_window(self.obs_builder): - reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) - - obs = self._ref_proc.build_observation( - robot_state=robot_state, - motion_qpos=motion_qpos, - motion_joint_vel=motion_joint_vel, - last_action=self._last_action, - anchor_lin_vel_w=np.zeros(3, dtype=np.float32), - anchor_ang_vel_w=np.zeros(3, dtype=np.float32), - reference_window=reference_window, - ) - obs = self._ref_proc.validate_observation(obs) - - action = self.policy.compute_action(obs) - target_dof_pos = self.policy.get_target_dof_pos(action) - target_dof_pos = 
self._safety.clip_to_joint_limits(target_dof_pos) - - self._safety.send_positions(target_dof_pos) - self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - self._last_retarget_qpos = qpos.copy() - self._ref_proc.last_reference_qpos = qpos.copy() - self._last_commanded_motion_qpos = qpos.copy() - self._write_retarget_viewer(qpos) - - def _hold_mocap_reference(self, reason: str, *, detail: str | None = None) -> None: - if self._last_mocap_hold_reason != reason: - suffix = f" ({detail})" if detail else "" - logger.warning("Mocap reference not fresh: %s%s -- holding command", reason, suffix) - self._last_mocap_hold_reason = reason - hold_qpos = self._resolve_mocap_hold_qpos() - self._run_static_mocap_step(hold_qpos) - - def _tick_dexterous_hand(self) -> None: - active = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.ACTIVE - try: - self._hand_runtime.tick(active=active) - except Exception: - logger.exception("Dexterous hand runtime failed; body control continues") - - def _deactivate_dexterous_hand(self) -> None: - try: - self._hand_runtime.tick(active=False) - except Exception: - logger.exception("Failed to deactivate dexterous hand runtime") - - def _write_retarget_viewer(self, qpos: Float64Array) -> None: - try: - self._retarget_viewer.write(qpos) - except Exception: - logger.exception("Sim2real retarget viewer update failed; control continues") - - @staticmethod - def _sleep_until(t0: float, dt: float) -> float: - """Sleep to maintain control frequency.""" - elapsed = time.monotonic() - t0 - remaining = dt - elapsed - if remaining > 0: - time.sleep(remaining) - return time.monotonic() - t0 - - def _sample_pico_frame_age_s(self) -> float | None: - has_frame = getattr(self.input_provider, "has_frame", None) - get_frame_packet = getattr(self.input_provider, "get_frame_packet", None) - if not callable(has_frame) or not callable(get_frame_packet): - return None - try: - if not has_frame(): - return None - _, 
frame_timestamp_s, _ = get_frame_packet() - except Exception: - return None - return max(0.0, time.monotonic() - float(frame_timestamp_s)) - - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - - def shutdown(self) -> None: - """Clean shutdown.""" - logger.info("Shutting down Sim2RealController") - if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): - try: - self.robot.set_damping() - time.sleep(0.5) - except Exception: - pass - try: - self.robot.exit_debug_mode() - except Exception: - pass - try: - self._video_runtime.stop() - except Exception: - pass - try: - self._hand_runtime.close() - except Exception: - pass - try: - self._retarget_viewer.shutdown() - except Exception: - pass - try: - self.input_provider.close() - except Exception: - pass - try: - self.robot.close() - except Exception: - pass diff --git a/teleopit/sim2real/dexterous_hand.py b/teleopit/sim2real/dexterous_hand.py deleted file mode 100644 index 04ae9812..00000000 --- a/teleopit/sim2real/dexterous_hand.py +++ /dev/null @@ -1,979 +0,0 @@ -"""Optional LinkerHand L6 control for Pico sim2real.""" - -from __future__ import annotations - -from dataclasses import dataclass -from functools import lru_cache -import importlib.util -import logging -from pathlib import Path -import threading -import time -from typing import Any, Protocol, Sequence - -import numpy as np - -from teleopit.inputs.pico4_provider import ( - PicoControllerSnapshot, - PicoControllerState, - PicoHandSnapshot, - PicoHandState, -) -from teleopit.runtime.common import cfg_get - -logger = logging.getLogger(__name__) - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -THUMB_YAW_DEFAULT = 10 -OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] -CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] -DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] -VR_HAND_POSE_SPEED = [255, 255, 255, 255, 255, 255] -HAND_TYPES = ("left", "right") 
-HAND_MODES = ("off", "gripper", "vr_hand_pose") -DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" -DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" -L6_SDK_JOINT_ORDER = ( - "thumb_cmc_pitch", - "thumb_cmc_roll", - "index_mcp_pitch", - "middle_mcp_pitch", - "ring_mcp_pitch", - "pinky_mcp_pitch", -) - - -class ControllerSnapshotProvider(Protocol): - def get_controller_snapshot(self) -> PicoControllerSnapshot | None: - ... - - -class HandSnapshotProvider(Protocol): - def get_hand_snapshot(self) -> PicoHandSnapshot | None: - ... - - -@dataclass(frozen=True) -class LinkerHandConfig: - mode: str = "off" - enabled: bool = False - hand_joint: str = "L6" - hand_type: str = "both" - left_can: str = "can0" - right_can: str = "can1" - modbus: str = "None" - rate: float = 30.0 - frame_timeout: float = 0.3 - trigger_deadzone: float = 0.05 - deadman_threshold: float = 0.5 - thumb_yaw_center: int = THUMB_YAW_DEFAULT - speed: tuple[int, ...] = tuple(DEFAULT_SPEED) - open_pose: tuple[int, ...] = tuple(OPEN_POSE) - close_pose: tuple[int, ...] 
= tuple(CLOSE_POSE) - print_input: bool = False - somehand_config_path: str = DEFAULT_SOMEHAND_CONFIG_PATH - somehand_sdk_root: str = DEFAULT_LINKERHAND_SDK_ROOT - somehand_rate: float | None = None - somehand_threaded: bool = False - somehand_max_iterations: int | None = None - somehand_temporal_filter_alpha: float | None = None - somehand_output_alpha: float | None = None - - @property - def selected_hand_types(self) -> tuple[str, ...]: - if self.hand_type == "both": - return HAND_TYPES - return (self.hand_type,) - - @property - def vr_hand_pose_rate(self) -> float: - return self.somehand_rate if self.somehand_rate is not None else self.rate - - -def clamp_unit(value: float) -> float: - return max(0.0, min(1.0, float(value))) - - -def normalize_trigger(value: float, deadzone: float) -> float: - value = clamp_unit(value) - deadzone = clamp_unit(deadzone) - if deadzone >= 0.5: - raise ValueError(f"trigger_deadzone must be < 0.5, got {deadzone}") - if value <= deadzone: - return 0.0 - upper = 1.0 - deadzone - if value >= upper: - return 1.0 - return (value - deadzone) / (upper - deadzone) - - -def trigger_to_pose( - trigger: float, - *, - open_pose: Sequence[int], - close_pose: Sequence[int], - deadzone: float, - thumb_yaw_default: int, -) -> list[int]: - if len(open_pose) != 6 or len(close_pose) != 6: - raise ValueError("LinkerHand L6 open_pose and close_pose must each contain 6 values") - alpha = normalize_trigger(trigger, deadzone) - pose = [ - int(round(float(open_value) + alpha * (float(close_value) - float(open_value)))) - for open_value, close_value in zip(open_pose, close_pose) - ] - pose[1] = int(thumb_yaw_default) - return pose - - -class L6RetargetPoseMapper: - """Map somehand-retargeted L6 qpos into six-channel LinkerHand L6 SDK range.""" - - def __init__(self, hand_model: Any | None, *, hand_type: str, sdk_root: str): - self._hand_type = hand_type - self._indices = self._resolve_indices(hand_model, hand_type=hand_type) - self._mapping = 
_load_linkerhand_mapping_module(sdk_root) - self._arc_min, self._arc_max, self._arc_direction = _sdk_l6_range_params( - self._mapping, - hand_type=hand_type, - ) - - def qpos_to_pose(self, qpos: Any) -> list[int]: - values = np.asarray(qpos, dtype=np.float64).reshape(-1) - max_index = int(np.max(self._indices)) - if values.shape[0] <= max_index: - raise ValueError( - "somehand L6 retarget qpos is too short for resolved SDK joint mapping: " - f"got {values.shape[0]}, need index {max_index}" - ) - channel_values = values[self._indices] - sdk_range = [] - for index, value in enumerate(channel_values): - arc = self._mapping.is_within_range( - float(value), - float(self._arc_min[index]), - float(self._arc_max[index]), - ) - if int(self._arc_direction[index]) == -1: - sdk_range.append( - self._mapping.scale_value( - arc, - float(self._arc_min[index]), - float(self._arc_max[index]), - 255.0, - 0.0, - ) - ) - else: - sdk_range.append( - self._mapping.scale_value( - arc, - float(self._arc_min[index]), - float(self._arc_max[index]), - 0.0, - 255.0, - ) - ) - return [_uint8(round(float(value)), f"somehand.{self._hand_type}.pose") for value in sdk_range] - - @staticmethod - def _resolve_indices(hand_model: Any | None, *, hand_type: str) -> np.ndarray: - if hand_model is None: - raise ValueError("somehand L6 hand model is missing; cannot map retarget qpos to SDK joints") - get_index = getattr(hand_model, "get_joint_name_to_qpos_index", None) - if not callable(get_index): - raise ValueError("somehand L6 hand model does not expose get_joint_name_to_qpos_index()") - joint_index = get_index() - indices: list[int] = [] - for channel in L6_SDK_JOINT_ORDER: - resolved = _resolve_l6_joint_name(joint_index, channel, hand_type=hand_type) - if resolved is None: - available = ", ".join(sorted(str(name) for name in joint_index)) - raise ValueError( - f"Cannot resolve LinkerHand L6 SDK joint {channel!r} in somehand {hand_type} hand model. 
" - f"Available joints: {available}" - ) - indices.append(int(joint_index[resolved])) - return np.asarray(indices, dtype=np.int64) - - -def parse_linkerhand_config(cfg: Any) -> LinkerHandConfig: - hand_cfg = cfg_get(cfg, "dexterous_hand", {}) or {} - raw_mode = cfg_get(hand_cfg, "mode", None) - legacy_enabled = bool(cfg_get(hand_cfg, "enabled", False)) - if isinstance(raw_mode, bool): - mode = "gripper" if raw_mode else "off" - else: - mode = str(raw_mode if raw_mode is not None else ("gripper" if legacy_enabled else "off")).lower() - somehand_cfg = cfg_get(hand_cfg, "somehand", {}) or {} - thumb_yaw = _uint8(cfg_get(hand_cfg, "thumb_yaw_center", THUMB_YAW_DEFAULT), "thumb_yaw_center") - open_pose = _pose_values(cfg_get(hand_cfg, "open_pose", OPEN_POSE), "open_pose") - close_pose = _pose_values(cfg_get(hand_cfg, "close_pose", CLOSE_POSE), "close_pose") - open_pose[1] = thumb_yaw - close_pose[1] = thumb_yaw - - speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else _pose_values(cfg_get(hand_cfg, "speed", DEFAULT_SPEED), "speed") - - config = LinkerHandConfig( - mode=mode, - enabled=mode != "off", - hand_joint=str(cfg_get(hand_cfg, "hand_joint", "L6")).upper(), - hand_type=str(cfg_get(hand_cfg, "hand_type", "both")).lower(), - left_can=str(cfg_get(hand_cfg, "left_can", "can0")), - right_can=str(cfg_get(hand_cfg, "right_can", "can1")), - modbus=str(cfg_get(hand_cfg, "modbus", "None")), - rate=_positive_float(cfg_get(hand_cfg, "rate", 30.0), "rate"), - frame_timeout=_positive_float(cfg_get(hand_cfg, "frame_timeout", 0.3), "frame_timeout"), - trigger_deadzone=_trigger_deadzone(cfg_get(hand_cfg, "trigger_deadzone", 0.05)), - deadman_threshold=_deadman_threshold(cfg_get(hand_cfg, "deadman_threshold", 0.5)), - thumb_yaw_center=thumb_yaw, - speed=tuple(speed), - open_pose=tuple(open_pose), - close_pose=tuple(close_pose), - print_input=bool(cfg_get(hand_cfg, "print_input", False)), - somehand_config_path=str(cfg_get(somehand_cfg, "config_path", 
DEFAULT_SOMEHAND_CONFIG_PATH)), - somehand_sdk_root=str(cfg_get(somehand_cfg, "sdk_root", DEFAULT_LINKERHAND_SDK_ROOT)), - somehand_rate=_optional_positive_float(cfg_get(somehand_cfg, "rate", None), "somehand.rate"), - somehand_threaded=bool(cfg_get(somehand_cfg, "threaded", False)), - somehand_max_iterations=_optional_positive_int( - cfg_get(somehand_cfg, "max_iterations", None), - "somehand.max_iterations", - ), - somehand_temporal_filter_alpha=_optional_unit_interval( - cfg_get(somehand_cfg, "temporal_filter_alpha", None), - "somehand.temporal_filter_alpha", - ), - somehand_output_alpha=_optional_unit_interval( - cfg_get(somehand_cfg, "output_alpha", None), - "somehand.output_alpha", - ), - ) - if config.mode not in HAND_MODES: - raise ValueError(f"dexterous_hand.mode must be one of {', '.join(HAND_MODES)}, got {config.mode!r}") - if config.hand_joint != "L6": - raise ValueError(f"dexterous_hand.hand_joint must be 'L6', got {config.hand_joint!r}") - if config.hand_type not in ("left", "right", "both"): - raise ValueError("dexterous_hand.hand_type must be left, right, or both") - return config - - -class L6PoseSender: - """Thin adapter around LinkerHandApi with duplicate-command suppression.""" - - def __init__(self, config: LinkerHandConfig): - self._config = config - self._hand_types = config.selected_hand_types - self._can_channels = {"left": config.left_can, "right": config.right_can} - self._hands: dict[str, Any] = {} - self._last_pose: dict[str, list[int] | None] = { - hand_type: None for hand_type in self._hand_types - } - self._started = False - - @property - def started(self) -> bool: - return self._started - - def start(self) -> None: - if self._started: - return - try: - from LinkerHand.linker_hand_api import LinkerHandApi - except ImportError as exc: - raise ImportError( - "LinkerHand SDK is required when dexterous_hand.mode is gripper or vr_hand_pose. 
" - "Run: pip install -e third_party/linkerhand-python-sdk" - ) from exc - - try: - for hand_type in self._hand_types: - hand = LinkerHandApi( - hand_joint="L6", - hand_type=hand_type, - modbus=self._config.modbus, - can=self._can_channels[hand_type], - ) - hand.set_speed(speed=list(self._config.speed)) - self._hands[hand_type] = hand - self._started = True - except SystemExit as exc: - self._close_hands() - self._started = False - raise RuntimeError( - "LinkerHand SDK exited during startup. Check CAN interface configuration " - f"({', '.join(self._can_channels[hand_type] for hand_type in self._hand_types)}). " - "Run scripts/dev/test_linkerhand_l6.py to verify the hand connection." - ) from exc - except Exception: - self._close_hands() - self._started = False - raise - logger.info("LinkerHand L6 runtime started | hands=%s", ",".join(self._hand_types)) - - def send(self, hand_type: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: - if not self._started: - return - next_pose = [int(value) for value in pose] - if not force and self._last_pose.get(hand_type) == next_pose: - return - del reason - hand = self._hands.get(hand_type) - if hand is None: - raise RuntimeError("L6PoseSender has not been started") - hand.finger_move(pose=next_pose) - self._last_pose[hand_type] = next_pose - - def send_all(self, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: - for hand_type in self._hand_types: - self.send(hand_type, pose, force=force, reason=reason) - - def close(self) -> None: - if not self._started and not self._hands: - return - try: - if self._started: - self.send_all(self._config.open_pose, force=True, reason="exit") - time.sleep(0.2) - except Exception: - logger.exception("Failed to send LinkerHand open pose on exit") - self._close_hands() - self._started = False - - def _close_hands(self) -> None: - for hand in self._hands.values(): - inner_hand = getattr(hand, "hand", None) - close = getattr(inner_hand, "close", None) 
- if callable(close): - close() - self._hands.clear() - - -class AsyncL6PoseSender: - """Run blocking LinkerHand SDK calls outside the robot control loop.""" - - def __init__(self, config: LinkerHandConfig): - self._config = config - self._sync_sender = L6PoseSender(config) - self._condition = threading.Condition() - self._pending: dict[str, tuple[list[int], bool, str]] = {} - self._thread: threading.Thread | None = None - self._running = False - self._stopping = False - self._busy = False - self._failed = False - - @property - def started(self) -> bool: - return self._running and not self._failed - - @property - def _last_pose(self) -> dict[str, list[int] | None]: - return self._sync_sender._last_pose - - def start(self) -> None: - with self._condition: - if self._running: - return - self._running = True - self._stopping = False - self._failed = False - self._busy = True - self._thread = threading.Thread( - target=self._run, - name="linkerhand-l6-sender", - daemon=True, - ) - self._thread.start() - - def send(self, hand_type: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: - next_pose = [int(value) for value in pose] - if not force and self._sync_sender._last_pose.get(hand_type) == next_pose: - return - with self._condition: - if not self._running or self._failed or self._stopping: - return - self._pending[hand_type] = (next_pose, force, reason) - self._condition.notify_all() - - def send_all(self, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: - for hand_type in self._config.selected_hand_types: - self.send(hand_type, pose, force=force, reason=reason) - - def close(self) -> None: - thread: threading.Thread | None - with self._condition: - if not self._running: - return - if not self._failed: - for hand_type in self._config.selected_hand_types: - self._pending[hand_type] = (list(self._config.open_pose), True, "exit") - self._stopping = True - self._condition.notify_all() - thread = self._thread - if thread is 
not None: - thread.join(timeout=3.0) - if thread.is_alive(): - logger.warning("LinkerHand L6 worker did not stop within timeout") - - def wait_idle(self, timeout_s: float = 1.0) -> bool: - deadline = time.monotonic() + timeout_s - with self._condition: - while self._busy or self._pending: - remaining = deadline - time.monotonic() - if remaining <= 0.0: - return False - self._condition.wait(timeout=remaining) - return True - - def _run(self) -> None: - try: - self._sync_sender.start() - self._sync_sender.send_all(self._config.open_pose, force=True, reason="startup") - while True: - commands = self._take_commands() - if not commands: - break - try: - for hand_type, pose, force, reason in commands: - self._sync_sender.send(hand_type, pose, force=force, reason=reason) - finally: - with self._condition: - self._busy = False - self._condition.notify_all() - except Exception: - logger.exception("LinkerHand L6 worker failed; hand control is disabled") - with self._condition: - self._failed = True - self._pending.clear() - self._busy = False - self._condition.notify_all() - finally: - try: - self._sync_sender.close() - except Exception: - logger.exception("Failed to close LinkerHand L6 worker cleanly") - with self._condition: - self._running = False - self._busy = False - self._condition.notify_all() - - def _take_commands(self) -> list[tuple[str, list[int], bool, str]]: - with self._condition: - while not self._pending and not self._stopping: - self._busy = False - self._condition.notify_all() - self._condition.wait() - if not self._pending and self._stopping: - return [] - self._busy = True - commands = [ - (hand_type, pose, force, reason) - for hand_type, (pose, force, reason) in self._pending.items() - ] - self._pending.clear() - return commands - - -class LinkerHandRuntime: - """Drive LinkerHand L6 from Pico controller grip/trigger snapshots.""" - - def __init__(self, config: LinkerHandConfig, provider: ControllerSnapshotProvider): - self.config = config - 
self._provider = provider - self._sender = AsyncL6PoseSender(config) - self._interval_s = 1.0 / config.rate - self._next_tick_s = 0.0 - self._active = False - self._last_status: dict[str, str] = {hand_type: "" for hand_type in config.selected_hand_types} - - @property - def enabled(self) -> bool: - return self.config.enabled - - def start(self) -> None: - if not self.enabled: - return - self._sender.start() - self._sender.send_all(self.config.open_pose, force=True, reason="startup") - - def tick(self, *, active: bool, now_s: float | None = None) -> None: - if not self.enabled: - return - now = time.monotonic() if now_s is None else float(now_s) - if not active: - self._deactivate(reason="inactive") - return - if not self._active: - self._active = True - self._next_tick_s = 0.0 - if now < self._next_tick_s: - return - self._next_tick_s = now + self._interval_s - - snapshot = self._provider.get_controller_snapshot() - if snapshot is None or now - snapshot.timestamp_s > self.config.frame_timeout: - self._open_all(reason="timeout") - return - - for hand_type in self.config.selected_hand_types: - state = getattr(snapshot, hand_type) - self._tick_hand(hand_type, state, snapshot.seq) - - def close(self) -> None: - self._deactivate(reason="shutdown") - self._sender.close() - - def _tick_hand(self, hand_type: str, state: PicoControllerState, seq: int) -> None: - if not state.present: - self._set_status(hand_type, "missing", f"{hand_type} controller missing; opening hand") - self._sender.send(hand_type, self.config.open_pose, reason="missing-controller") - return - - grip = clamp_unit(state.grip) - trigger = clamp_unit(state.trigger) - if grip < self.config.deadman_threshold: - self._set_status(hand_type, "deadman", f"{hand_type} deadman released; opening hand") - self._sender.send(hand_type, self.config.open_pose, reason="deadman-released") - return - - self._set_status(hand_type, "enabled", f"{hand_type} controller active") - if self.config.print_input: - logger.info( - 
"LinkerHand input | seq=%d hand=%s grip=%.3f trigger=%.3f", - seq, - hand_type, - grip, - trigger, - ) - pose = trigger_to_pose( - trigger, - open_pose=self.config.open_pose, - close_pose=self.config.close_pose, - deadzone=self.config.trigger_deadzone, - thumb_yaw_default=self.config.thumb_yaw_center, - ) - self._sender.send(hand_type, pose, reason="controller") - - def _deactivate(self, *, reason: str) -> None: - if self._active: - self._open_all(reason=reason, force=True) - self._active = False - - def _open_all(self, *, reason: str, force: bool = False) -> None: - self._sender.send_all(self.config.open_pose, force=force, reason=reason) - - def _set_status(self, hand_type: str, status: str, message: str) -> None: - if self._last_status.get(hand_type) == status: - return - self._last_status[hand_type] = status - logger.info("LinkerHand L6: %s", message) - - -class SomeHandPoseRuntime: - """Drive LinkerHand L6 from Pico hand-pose snapshots through somehand.""" - - def __init__(self, config: LinkerHandConfig, provider: HandSnapshotProvider): - self.config = config - self._provider = provider - self._sender = AsyncL6PoseSender(config) - self._interval_s = 1.0 / config.vr_hand_pose_rate - self._next_tick_s = 0.0 - self._active = False - self._last_status: dict[str, str] = {hand_type: "" for hand_type in config.selected_hand_types} - self._engine: Any | None = None - self._hand_frame_cls: Any | None = None - self._bihand_frame_cls: Any | None = None - self._pico_hand_to_landmarks: Any | None = None - self._pose_mappers: dict[str, L6RetargetPoseMapper] = {} - - @property - def enabled(self) -> bool: - return self.config.enabled - - def start(self) -> None: - if not self.enabled: - return - self._load_somehand() - self._sender.start() - self._sender.send_all(self.config.open_pose, force=True, reason="startup") - - def tick(self, *, active: bool, now_s: float | None = None) -> None: - if not self.enabled: - return - now = time.monotonic() if now_s is None else 
float(now_s) - if not active: - self._deactivate(reason="inactive") - return - if not self._active: - self._active = True - self._next_tick_s = 0.0 - if now < self._next_tick_s: - return - self._next_tick_s = now + self._interval_s - - snapshot = self._provider.get_hand_snapshot() - if snapshot is None: - self._set_status("both", "missing", "Pico hand pose missing; holding last hand command") - return - if now - snapshot.timestamp_s > self.config.frame_timeout: - self._set_status("both", "timeout", "Pico hand pose timed out; holding last hand command") - return - - self._tick_snapshot(snapshot) - - def close(self) -> None: - self._deactivate(reason="shutdown") - self._sender.close() - - def _tick_snapshot(self, snapshot: PicoHandSnapshot) -> None: - left_frame = self._make_hand_frame("left", snapshot.left) if "left" in self.config.selected_hand_types else None - right_frame = self._make_hand_frame("right", snapshot.right) if "right" in self.config.selected_hand_types else None - if left_frame is None and right_frame is None: - return - - result = self._engine.process(self._bihand_frame_cls(left=left_frame, right=right_frame)) - for hand_type, detected, step in ( - ("left", result.left_detected, result.left), - ("right", result.right_detected, result.right), - ): - if hand_type not in self.config.selected_hand_types or not detected: - continue - pose = self._pose_mappers[hand_type].qpos_to_pose(step.qpos) - self._sender.send(hand_type, pose, reason="vr-hand-pose") - - def _make_hand_frame(self, hand_type: str, state: PicoHandState) -> Any | None: - if not state.present: - self._set_status(hand_type, "missing", f"{hand_type} hand pose missing; holding last hand command") - return None - if not state.active: - self._set_status(hand_type, "inactive", f"{hand_type} hand pose inactive; holding last hand command") - return None - self._set_status(hand_type, "enabled", f"{hand_type} hand pose active") - landmarks = self._pico_hand_to_landmarks(state.joints) - return 
self._hand_frame_cls(landmarks_3d=landmarks, landmarks_2d=None, hand_side=hand_type) - - def _deactivate(self, *, reason: str) -> None: - if self._active: - self._sender.send_all(self.config.open_pose, force=True, reason=reason) - self._active = False - - def _load_somehand(self) -> None: - try: - from somehand.api import BiHandFrame, BiHandRetargetingEngine, HandFrame - from somehand.pico_input import pico_hand_to_landmarks - except ImportError as exc: - raise ImportError( - "somehand is required when dexterous_hand.mode=vr_hand_pose. " - "Install it with: pip install -e '.[dexhand]'" - ) from exc - - config_path = _resolve_project_path(self.config.somehand_config_path) - if not config_path.exists(): - raise FileNotFoundError( - "somehand bi-hand config not found: " - f"{config_path}. Initialize the submodule and download assets with " - "scripts/setup/download_somehand_l6_assets.sh" - ) - self._engine = BiHandRetargetingEngine.from_config_path(str(config_path)) - self._apply_low_latency_overrides() - self._hand_frame_cls = HandFrame - self._bihand_frame_cls = BiHandFrame - self._pico_hand_to_landmarks = pico_hand_to_landmarks - - # somehand owns hand-pose retargeting; Teleopit owns the LinkerHand L6 command mapping. 
- self._pose_mappers = {} - for hand_type, engine in (("left", self._engine.left_engine), ("right", self._engine.right_engine)): - if hand_type not in self.config.selected_hand_types: - continue - self._pose_mappers[hand_type] = L6RetargetPoseMapper( - getattr(engine, "hand_model", None), - hand_type=hand_type, - sdk_root=self.config.somehand_sdk_root, - ) - logger.info("somehand LinkerHand L6 runtime started | hands=%s", ",".join(self.config.selected_hand_types)) - - def _apply_low_latency_overrides(self) -> None: - for hand_type, engine in (("left", self._engine.left_engine), ("right", self._engine.right_engine)): - retargeter = getattr(engine, "retargeter", None) - if retargeter is None: - continue - if self.config.somehand_max_iterations is not None: - setattr(retargeter, "_max_iterations", int(self.config.somehand_max_iterations)) - if self.config.somehand_output_alpha is not None: - setattr(retargeter, "_output_alpha", float(self.config.somehand_output_alpha)) - if self.config.somehand_temporal_filter_alpha is not None: - landmark_filter = getattr(retargeter, "landmark_filter", None) - if landmark_filter is not None: - setattr(landmark_filter, "alpha", float(self.config.somehand_temporal_filter_alpha)) - logger.info( - "somehand low-latency overrides | hand=%s rate=%.1fHz max_iter=%s temporal_alpha=%s output_alpha=%s", - hand_type, - self.config.vr_hand_pose_rate, - self.config.somehand_max_iterations, - self.config.somehand_temporal_filter_alpha, - self.config.somehand_output_alpha, - ) - - def _set_status(self, hand_type: str, status: str, message: str) -> None: - key = hand_type - if self._last_status.get(key) == status: - return - self._last_status[key] = status - logger.info("somehand LinkerHand L6: %s", message) - - -class ThreadedSomeHandPoseRuntime: - """Tick the somehand path independently from the robot control loop.""" - - def __init__(self, runtime: SomeHandPoseRuntime): - self._runtime = runtime - self._condition = threading.Condition() - 
self._runtime_lock = threading.Lock() - self._thread: threading.Thread | None = None - self._running = False - self._active = False - self._interval_s = 1.0 / runtime.config.vr_hand_pose_rate - - @property - def config(self) -> LinkerHandConfig: - return self._runtime.config - - @property - def enabled(self) -> bool: - return self._runtime.enabled - - def start(self) -> None: - if not self.enabled: - return - self._runtime.start() - with self._condition: - if self._running: - return - self._running = True - self._thread = threading.Thread( - target=self._run, - name="somehand-pose-runtime", - daemon=True, - ) - self._thread.start() - - def tick(self, *, active: bool, now_s: float | None = None) -> None: - del now_s - if not self.enabled: - return - should_deactivate = False - with self._condition: - if self._active != bool(active): - should_deactivate = self._active and not bool(active) - self._active = bool(active) - self._condition.notify_all() - if should_deactivate: - with self._runtime_lock: - self._runtime.tick(active=False) - - def close(self) -> None: - thread: threading.Thread | None - with self._condition: - self._running = False - self._active = False - self._condition.notify_all() - thread = self._thread - if thread is not None: - thread.join(timeout=2.0) - if thread.is_alive(): - logger.warning("somehand pose runtime worker did not stop within timeout") - with self._runtime_lock: - self._runtime.close() - - def _run(self) -> None: - next_tick_s = 0.0 - while True: - with self._condition: - while self._running and not self._active: - self._condition.wait() - next_tick_s = 0.0 - if not self._running: - return - active = self._active - now = time.monotonic() - if now < next_tick_s: - with self._condition: - self._condition.wait(timeout=next_tick_s - now) - continue - with self._runtime_lock: - self._runtime.tick(active=active, now_s=now) - next_tick_s = now + self._interval_s - - -class DisabledLinkerHandRuntime: - enabled = False - - def start(self) -> 
None: - pass - - def tick(self, *, active: bool, now_s: float | None = None) -> None: - del active, now_s - - def close(self) -> None: - pass - - -def build_linkerhand_runtime( - cfg: Any, - input_provider: Any, -) -> LinkerHandRuntime | SomeHandPoseRuntime | ThreadedSomeHandPoseRuntime | DisabledLinkerHandRuntime: - config = parse_linkerhand_config(cfg) - if not config.enabled: - return DisabledLinkerHandRuntime() - - input_cfg = cfg_get(cfg, "input", {}) or {} - provider_kind = str(cfg_get(input_cfg, "provider", "")).lower() - if provider_kind != "pico4": - raise ValueError("dexterous_hand.mode requires input.provider=pico4") - if config.mode == "gripper": - if not callable(getattr(input_provider, "get_controller_snapshot", None)): - raise ValueError("dexterous_hand.mode=gripper requires a Pico input provider with controller snapshots") - return LinkerHandRuntime(config, input_provider) - if config.mode == "vr_hand_pose": - if config.hand_type != "both": - raise ValueError("dexterous_hand.mode=vr_hand_pose currently requires dexterous_hand.hand_type=both") - if not callable(getattr(input_provider, "get_hand_snapshot", None)): - raise ValueError("dexterous_hand.mode=vr_hand_pose requires a Pico input provider with hand snapshots") - runtime = SomeHandPoseRuntime(config, input_provider) - if config.somehand_threaded: - return ThreadedSomeHandPoseRuntime(runtime) - return runtime - raise ValueError(f"Unsupported dexterous_hand.mode={config.mode!r}") - - -def _resolve_project_path(path_value: str) -> Path: - path = Path(path_value).expanduser() - if path.is_absolute(): - return path - return (PROJECT_ROOT / path).resolve() - - -@lru_cache(maxsize=4) -def _load_linkerhand_mapping_module(sdk_root: str) -> Any: - mapping_path = _resolve_project_path(sdk_root) / "LinkerHand" / "utils" / "mapping.py" - if not mapping_path.exists(): - raise FileNotFoundError(f"LinkerHand SDK mapping module not found: {mapping_path}") - spec = 
importlib.util.spec_from_file_location("teleopit_linkerhand_mapping", mapping_path) - if spec is None or spec.loader is None: - raise RuntimeError(f"Cannot load LinkerHand SDK mapping module from: {mapping_path}") - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - -def _sdk_l6_range_params(mapping: Any, *, hand_type: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - side = "l" if hand_type == "left" else "r" - arc_min = np.asarray(getattr(mapping, f"l6_{side}_min"), dtype=np.float64) - arc_max = np.asarray(getattr(mapping, f"l6_{side}_max"), dtype=np.float64) - direction = np.asarray(getattr(mapping, f"l6_{side}_derict"), dtype=np.int8) - expected_shape = (len(L6_SDK_JOINT_ORDER),) - if arc_min.shape != expected_shape or arc_max.shape != expected_shape or direction.shape != expected_shape: - raise ValueError( - "LinkerHand SDK L6 mapping has unexpected shape: " - f"min={arc_min.shape}, max={arc_max.shape}, direction={direction.shape}" - ) - return arc_min, arc_max, direction - - -def _resolve_l6_joint_name(joint_index: dict[str, int], semantic_name: str, *, hand_type: str) -> str | None: - aliases = _l6_joint_aliases(semantic_name) - side_prefixes = ( - "", - f"{hand_type}_", - f"{hand_type[0]}_", - f"{hand_type[0].upper()}_", - f"{'lh' if hand_type == 'left' else 'rh'}_", - ) - candidates = tuple(f"{prefix}{alias}" for alias in aliases for prefix in side_prefixes) - for candidate in candidates: - if candidate in joint_index: - return candidate - for alias in aliases: - suffix = f"_{alias}" - for name in joint_index: - if name == alias or name.endswith(suffix): - return name - return None - - -def _l6_joint_aliases(semantic_name: str) -> tuple[str, ...]: - if semantic_name == "thumb_cmc_pitch": - return ("thumb_cmc_pitch", "thumb_pitch") - if semantic_name == "thumb_cmc_roll": - return ("thumb_cmc_roll", "thumb_roll") - aliases = [semantic_name] - if semantic_name.endswith("_mcp_pitch"): - finger = 
semantic_name[: -len("_mcp_pitch")] - aliases.append(f"{finger}_pitch") - if finger == "pinky": - aliases.extend(("little_mcp_pitch", "little_pitch")) - elif finger == "little": - aliases.extend(("pinky_mcp_pitch", "pinky_pitch")) - return tuple(dict.fromkeys(aliases)) - - -def _positive_float(value: object, field_name: str) -> float: - parsed = float(value) - if parsed <= 0.0: - raise ValueError(f"dexterous_hand.{field_name} must be > 0, got {value!r}") - return parsed - - -def _optional_positive_float(value: object, field_name: str) -> float | None: - if value is None: - return None - return _positive_float(value, field_name) - - -def _optional_positive_int(value: object, field_name: str) -> int | None: - if value is None: - return None - parsed = int(value) - if parsed <= 0: - raise ValueError(f"dexterous_hand.{field_name} must be > 0, got {value!r}") - return parsed - - -def _optional_unit_interval(value: object, field_name: str) -> float | None: - if value is None: - return None - parsed = float(value) - if parsed <= 0.0 or parsed > 1.0: - raise ValueError(f"dexterous_hand.{field_name} must be in (0, 1], got {value!r}") - return parsed - - -def _uint8(value: object, field_name: str) -> int: - parsed = int(value) - if parsed < 0 or parsed > 255: - raise ValueError(f"dexterous_hand.{field_name} must be in range 0-255, got {value!r}") - return parsed - - -def _pose_values(value: object, field_name: str) -> list[int]: - try: - parsed = [_uint8(item, field_name) for item in value] # type: ignore[union-attr] - except TypeError as exc: - raise ValueError(f"dexterous_hand.{field_name} must be a sequence of 6 uint8 values") from exc - if len(parsed) != 6: - raise ValueError(f"dexterous_hand.{field_name} must contain 6 values, got {len(parsed)}") - return parsed - - -def _trigger_deadzone(value: object) -> float: - parsed = float(value) - if parsed < 0.0 or parsed >= 0.5: - raise ValueError(f"dexterous_hand.trigger_deadzone must be in [0, 0.5), got {value!r}") - return 
parsed - - -def _deadman_threshold(value: object) -> float: - parsed = float(value) - if parsed <= 0.0 or parsed >= 1.0: - raise ValueError(f"dexterous_hand.deadman_threshold must be in (0, 1), got {value!r}") - return parsed diff --git a/teleopit/sim2real/hands/__init__.py b/teleopit/sim2real/hands/__init__.py new file mode 100644 index 00000000..de2de569 --- /dev/null +++ b/teleopit/sim2real/hands/__init__.py @@ -0,0 +1,3 @@ +from teleopit.sim2real.hands.worker import build_hand_runtime + +__all__ = ["build_hand_runtime"] diff --git a/teleopit/sim2real/hands/base.py b/teleopit/sim2real/hands/base.py new file mode 100644 index 00000000..76257789 --- /dev/null +++ b/teleopit/sim2real/hands/base.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol, Sequence + + +HAND_SIDES = ("left", "right") + + +@dataclass(frozen=True) +class HandPoseCommand: + side: str + pose: tuple[int, ...] + force: bool = False + reason: str = "" + + +class HandDevice(Protocol): + def connect(self) -> None: ... + + def send_pose(self, side: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: ... + + def open_all(self, *, force: bool = False, reason: str = "") -> None: ... + + def close(self) -> None: ... + + +class HandInputMapper(Protocol): + def start(self) -> None: ... + + def map(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float) -> tuple[HandPoseCommand, ...]: ... + + def close(self) -> None: ... 
diff --git a/teleopit/sim2real/hands/linkerhand_l6.py b/teleopit/sim2real/hands/linkerhand_l6.py new file mode 100644 index 00000000..52c11f86 --- /dev/null +++ b/teleopit/sim2real/hands/linkerhand_l6.py @@ -0,0 +1,439 @@ +from __future__ import annotations + +from dataclasses import dataclass +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +import importlib.util +import logging +from typing import Any, Sequence + +import numpy as np + +from teleopit.runtime.common import cfg_get +from teleopit.sim2real.hands.base import HAND_SIDES, HandDevice, HandInputMapper, HandPoseCommand +from teleopit.sim2real.hands.pico_landmarks import pico_hand_to_landmarks + +logger = logging.getLogger(__name__) + +PROJECT_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_SOMEHAND_CONFIG = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" +DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" +THUMB_YAW_DEFAULT = 10 +OPEN_POSE = (250, THUMB_YAW_DEFAULT, 250, 250, 250, 250) +CLOSE_POSE = (79, THUMB_YAW_DEFAULT, 0, 0, 0, 0) +DEFAULT_SPEED = (50, 50, 50, 50, 50, 50) +VR_HAND_POSE_SPEED = (255, 255, 255, 255, 255, 255) +L6_SDK_JOINT_ORDER = ( + "thumb_cmc_pitch", + "thumb_cmc_roll", + "index_mcp_pitch", + "middle_mcp_pitch", + "ring_mcp_pitch", + "pinky_mcp_pitch", +) + + +@dataclass(frozen=True) +class LinkerHandL6Config: + mode: str + sides: tuple[str, ...] + left_can: str + right_can: str + modbus: str + rate_hz: float + frame_timeout_s: float + trigger_deadzone: float + deadman_threshold: float + thumb_yaw_center: int + speed: tuple[int, ...] + open_pose: tuple[int, ...] + close_pose: tuple[int, ...] 
+ print_input: bool + somehand_config_path: str + somehand_rate_hz: float + somehand_max_iterations: int | None + somehand_temporal_filter_alpha: float | None + somehand_output_alpha: float | None + + +def parse_linkerhand_l6_config(cfg: Any) -> LinkerHandL6Config: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + l6_cfg = cfg_get(hands_cfg, "linkerhand_l6", {}) or {} + somehand_cfg = cfg_get(hands_cfg, "somehand", {}) or {} + mode = str(cfg_get(hands_cfg, "mode", "gripper")).strip().lower() + if mode not in ("gripper", "vr_hand_pose"): + raise ValueError(f"hands.mode must be gripper or vr_hand_pose, got {mode!r}") + sides = tuple(str(side).strip().lower() for side in cfg_get(hands_cfg, "sides", HAND_SIDES)) + if not sides or any(side not in HAND_SIDES for side in sides): + raise ValueError("hands.sides must contain left, right, or both sides") + thumb_yaw = _uint8(cfg_get(l6_cfg, "thumb_yaw_center", THUMB_YAW_DEFAULT), "thumb_yaw_center") + open_pose = _pose_values(cfg_get(l6_cfg, "open_pose", OPEN_POSE), "open_pose") + close_pose = _pose_values(cfg_get(l6_cfg, "close_pose", CLOSE_POSE), "close_pose") + open_pose[1] = thumb_yaw + close_pose[1] = thumb_yaw + speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else tuple(_pose_values(cfg_get(l6_cfg, "speed", DEFAULT_SPEED), "speed")) + return LinkerHandL6Config( + mode=mode, + sides=sides, + left_can=str(cfg_get(l6_cfg, "left_can", "can0")), + right_can=str(cfg_get(l6_cfg, "right_can", "can1")), + modbus=str(cfg_get(l6_cfg, "modbus", "None")), + rate_hz=_positive_float(cfg_get(hands_cfg, "rate_hz", cfg_get(l6_cfg, "rate_hz", 30.0)), "rate_hz"), + frame_timeout_s=_positive_float(cfg_get(hands_cfg, "frame_timeout_s", 0.3), "frame_timeout_s"), + trigger_deadzone=_deadzone(cfg_get(l6_cfg, "trigger_deadzone", 0.05)), + deadman_threshold=_threshold(cfg_get(l6_cfg, "deadman_threshold", 0.5)), + thumb_yaw_center=thumb_yaw, + speed=tuple(speed), + open_pose=tuple(open_pose), + close_pose=tuple(close_pose), + 
print_input=bool(cfg_get(l6_cfg, "print_input", False)), + somehand_config_path=str(cfg_get(somehand_cfg, "config_path", DEFAULT_SOMEHAND_CONFIG)), + somehand_rate_hz=_positive_float(cfg_get(somehand_cfg, "rate_hz", cfg_get(somehand_cfg, "rate", 60.0)), "somehand.rate_hz"), + somehand_max_iterations=_optional_positive_int(cfg_get(somehand_cfg, "max_iterations", None), "somehand.max_iterations"), + somehand_temporal_filter_alpha=_optional_alpha(cfg_get(somehand_cfg, "temporal_filter_alpha", None), "somehand.temporal_filter_alpha"), + somehand_output_alpha=_optional_alpha(cfg_get(somehand_cfg, "output_alpha", None), "somehand.output_alpha"), + ) + + +class LinkerHandL6Device(HandDevice): + def __init__(self, config: LinkerHandL6Config): + self.config = config + self._hands: dict[str, Any] = {} + self._last_pose: dict[str, tuple[int, ...] | None] = {side: None for side in config.sides} + + def connect(self) -> None: + try: + from LinkerHand.linker_hand_api import LinkerHandApi + except ImportError as exc: + raise ImportError( + "LinkerHand SDK is required for hands.driver=linkerhand_l6. 
" + "Install it with: pip install -e third_party/linkerhand-python-sdk" + ) from exc + try: + for side in self.config.sides: + hand = LinkerHandApi( + hand_joint="L6", + hand_type=side, + modbus=self.config.modbus, + can=self.config.left_can if side == "left" else self.config.right_can, + ) + hand.set_speed(speed=list(self.config.speed)) + self._hands[side] = hand + except (Exception, SystemExit) as exc: + self.close() + if isinstance(exc, SystemExit): + raise RuntimeError("LinkerHand SDK exited during startup") from exc + raise + self.open_all(force=True, reason="startup") + + def send_pose(self, side: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + del reason + next_pose = tuple(_uint8(value, f"{side}.pose") for value in pose) + if not force and self._last_pose.get(side) == next_pose: + return + hand = self._hands.get(side) + if hand is None: + return + hand.finger_move(pose=list(next_pose)) + self._last_pose[side] = next_pose + + def open_all(self, *, force: bool = False, reason: str = "") -> None: + for side in self.config.sides: + self.send_pose(side, self.config.open_pose, force=force, reason=reason) + + def close(self) -> None: + try: + self.open_all(force=True, reason="shutdown") + except Exception: + logger.exception("Failed to open LinkerHand L6 on shutdown") + for hand in self._hands.values(): + inner = getattr(hand, "hand", None) + close = getattr(inner, "close", None) + if callable(close): + close() + self._hands.clear() + + +class GripperMapper(HandInputMapper): + def __init__(self, config: LinkerHandL6Config): + self.config = config + self._active = False + self._next_tick_s = 0.0 + + def start(self) -> None: + pass + + def map(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float) -> tuple[HandPoseCommand, ...]: + del hand_snapshot + if not active: + if not self._active: + return () + self._active = False + return tuple(HandPoseCommand(side, self.config.open_pose, True, 
"inactive") for side in self.config.sides) + if now_s < self._next_tick_s: + return () + self._active = True + self._next_tick_s = now_s + 1.0 / self.config.rate_hz + if controller_snapshot is None or now_s - float(getattr(controller_snapshot, "timestamp_s", 0.0)) > self.config.frame_timeout_s: + return tuple(HandPoseCommand(side, self.config.open_pose, False, "timeout") for side in self.config.sides) + commands: list[HandPoseCommand] = [] + for side in self.config.sides: + state = getattr(controller_snapshot, side) + if not bool(getattr(state, "present", True)): + commands.append(HandPoseCommand(side, self.config.open_pose, False, "missing-controller")) + continue + grip = _clamp01(getattr(state, "grip", 0.0)) + trigger = _clamp01(getattr(state, "trigger", 0.0)) + if grip < self.config.deadman_threshold: + commands.append(HandPoseCommand(side, self.config.open_pose, False, "deadman")) + continue + pose = trigger_to_pose( + trigger, + open_pose=self.config.open_pose, + close_pose=self.config.close_pose, + deadzone=self.config.trigger_deadzone, + thumb_yaw_default=self.config.thumb_yaw_center, + ) + commands.append(HandPoseCommand(side, tuple(pose), False, "controller")) + return tuple(commands) + + def close(self) -> None: + pass + + +class SomehandL6Mapper(HandInputMapper): + def __init__(self, config: LinkerHandL6Config): + self.config = config + self._engine: Any | None = None + self._hand_frame_cls: Any | None = None + self._bihand_frame_cls: Any | None = None + self._mappers: dict[str, L6RetargetPoseMapper] = {} + self._next_tick_s = 0.0 + self._active = False + + def start(self) -> None: + _require_somehand_020() + from somehand.api import HandFrame, RetargetingEngine, load_bihand_config, load_retargeting_config + + config_path = _resolve_project_path(self.config.somehand_config_path) + if not config_path.exists(): + raise FileNotFoundError(f"somehand L6 config not found: {config_path}") + bihand_config = load_bihand_config(str(config_path)) + self._engine = 
{} + for side, path in (("left", bihand_config.left_config_path), ("right", bihand_config.right_config_path)): + retarget_cfg = load_retargeting_config(path) + self._apply_low_latency_overrides(retarget_cfg) + self._engine[side] = RetargetingEngine(retarget_cfg) + self._hand_frame_cls = HandFrame + for side, engine in self._engine.items(): + if side in self.config.sides: + self._mappers[side] = L6RetargetPoseMapper(getattr(engine, "hand_model", None), side=side) + + def map(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float) -> tuple[HandPoseCommand, ...]: + del controller_snapshot + if not active: + if not self._active: + return () + self._active = False + return tuple(HandPoseCommand(side, self.config.open_pose, True, "inactive") for side in self.config.sides) + if now_s < self._next_tick_s: + return () + self._active = True + self._next_tick_s = now_s + 1.0 / self.config.somehand_rate_hz + if hand_snapshot is None or now_s - float(getattr(hand_snapshot, "timestamp_s", 0.0)) > self.config.frame_timeout_s: + return () + left_frame = self._make_frame("left", getattr(hand_snapshot, "left", None)) + right_frame = self._make_frame("right", getattr(hand_snapshot, "right", None)) + if left_frame is None and right_frame is None: + return () + commands: list[HandPoseCommand] = [] + for side, frame in (("left", left_frame), ("right", right_frame)): + if side not in self.config.sides or frame is None: + continue + step = self._engine[side].process(frame) + commands.append(HandPoseCommand(side, tuple(self._mappers[side].qpos_to_pose(step.qpos)), False, "vr-hand-pose")) + return tuple(commands) + + def close(self) -> None: + pass + + def _make_frame(self, side: str, state: object | None) -> object | None: + if side not in self.config.sides or state is None: + return None + if not bool(getattr(state, "present", False)) or not bool(getattr(state, "active", False)): + return None + landmarks = 
pico_hand_to_landmarks(getattr(state, "joints")) + return self._hand_frame_cls(landmarks_3d=landmarks, landmarks_2d=None, hand_side=side) + + def _apply_low_latency_overrides(self, cfg: object) -> None: + if self.config.somehand_max_iterations is not None: + cfg.solver.max_iterations = int(self.config.somehand_max_iterations) + if self.config.somehand_output_alpha is not None: + cfg.solver.output_alpha = float(self.config.somehand_output_alpha) + if self.config.somehand_temporal_filter_alpha is not None: + cfg.preprocess.temporal_filter_alpha = float(self.config.somehand_temporal_filter_alpha) + + +class L6RetargetPoseMapper: + def __init__(self, hand_model: Any | None, *, side: str): + if hand_model is None: + raise ValueError("somehand L6 hand model is missing") + get_index = getattr(hand_model, "get_joint_name_to_qpos_index", None) + if not callable(get_index): + raise ValueError("somehand L6 hand model does not expose get_joint_name_to_qpos_index()") + joint_index = get_index() + self._indices = np.asarray([_resolve_l6_joint_index(joint_index, name, side=side) for name in L6_SDK_JOINT_ORDER], dtype=np.int64) + mapping = _load_linkerhand_mapping_module() + side_key = "l" if side == "left" else "r" + self._mapping = mapping + self._arc_min = np.asarray(getattr(mapping, f"l6_{side_key}_min"), dtype=np.float64) + self._arc_max = np.asarray(getattr(mapping, f"l6_{side_key}_max"), dtype=np.float64) + self._direction = np.asarray(getattr(mapping, f"l6_{side_key}_derict"), dtype=np.int8) + + def qpos_to_pose(self, qpos: object) -> list[int]: + values = np.asarray(qpos, dtype=np.float64).reshape(-1) + selected = values[self._indices] + pose = [] + for index, value in enumerate(selected): + arc = self._mapping.is_within_range(float(value), float(self._arc_min[index]), float(self._arc_max[index])) + if int(self._direction[index]) == -1: + scaled = self._mapping.scale_value(arc, float(self._arc_min[index]), float(self._arc_max[index]), 255.0, 0.0) + else: + scaled = 
self._mapping.scale_value(arc, float(self._arc_min[index]), float(self._arc_max[index]), 0.0, 255.0) + pose.append(_uint8(round(float(scaled)), "somehand.pose")) + return pose + + +def build_linkerhand_l6(cfg: Any) -> tuple[HandDevice, HandInputMapper]: + config = parse_linkerhand_l6_config(cfg) + device = LinkerHandL6Device(config) + mapper: HandInputMapper = SomehandL6Mapper(config) if config.mode == "vr_hand_pose" else GripperMapper(config) + return device, mapper + + +def trigger_to_pose(trigger: float, *, open_pose: Sequence[int], close_pose: Sequence[int], deadzone: float, thumb_yaw_default: int) -> list[int]: + alpha = _normalize_trigger(trigger, deadzone) + pose = [int(round(float(a) + alpha * (float(b) - float(a)))) for a, b in zip(open_pose, close_pose)] + pose[1] = int(thumb_yaw_default) + return pose + + +def _require_somehand_020() -> None: + try: + installed = version("somehand") + except PackageNotFoundError as exc: + raise ImportError("somehand==0.2.0 is required for hands.mode=vr_hand_pose") from exc + if installed != "0.2.0": + raise ImportError(f"somehand==0.2.0 is required for hands.mode=vr_hand_pose, found {installed}") + + +def _resolve_l6_joint_index(joint_index: dict[str, int], semantic_name: str, *, side: str) -> int: + for candidate in _l6_joint_candidates(semantic_name, side=side): + if candidate in joint_index: + return int(joint_index[candidate]) + suffixes = tuple(f"_{alias}" for alias in _l6_aliases(semantic_name)) + for name, index in joint_index.items(): + if name in _l6_aliases(semantic_name) or any(name.endswith(suffix) for suffix in suffixes): + return int(index) + raise ValueError(f"Cannot resolve LinkerHand L6 SDK joint {semantic_name!r} in somehand hand model") + + +def _l6_joint_candidates(semantic_name: str, *, side: str) -> tuple[str, ...]: + prefixes = ("", f"{side}_", f"{side[0]}_", f"{side[0].upper()}_", f"{'lh' if side == 'left' else 'rh'}_") + return tuple(f"{prefix}{alias}" for alias in _l6_aliases(semantic_name) for 
prefix in prefixes) + + +def _l6_aliases(semantic_name: str) -> tuple[str, ...]: + if semantic_name == "thumb_cmc_pitch": + return ("thumb_cmc_pitch", "thumb_pitch") + if semantic_name == "thumb_cmc_roll": + return ("thumb_cmc_roll", "thumb_roll") + aliases = [semantic_name] + if semantic_name.endswith("_mcp_pitch"): + finger = semantic_name[: -len("_mcp_pitch")] + aliases.append(f"{finger}_pitch") + if finger == "pinky": + aliases.extend(("little_mcp_pitch", "little_pitch")) + return tuple(dict.fromkeys(aliases)) + + +def _load_linkerhand_mapping_module() -> Any: + mapping_path = _resolve_project_path(DEFAULT_LINKERHAND_SDK_ROOT) / "LinkerHand" / "utils" / "mapping.py" + spec = importlib.util.spec_from_file_location("teleopit_linkerhand_mapping", mapping_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load LinkerHand mapping module from {mapping_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _resolve_project_path(path_value: str) -> Path: + path = Path(path_value).expanduser() + return path if path.is_absolute() else (PROJECT_ROOT / path).resolve() + + +def _clamp01(value: object) -> float: + return max(0.0, min(1.0, float(value))) + + +def _normalize_trigger(value: float, deadzone: float) -> float: + value = _clamp01(value) + if value <= deadzone: + return 0.0 + upper = 1.0 - deadzone + if value >= upper: + return 1.0 + return (value - deadzone) / (upper - deadzone) + + +def _uint8(value: object, field_name: str) -> int: + parsed = int(value) + if parsed < 0 or parsed > 255: + raise ValueError(f"hands.linkerhand_l6.{field_name} must be in 0-255, got {value!r}") + return parsed + + +def _pose_values(value: object, field_name: str) -> list[int]: + parsed = [_uint8(item, field_name) for item in value] # type: ignore[union-attr] + if len(parsed) != 6: + raise ValueError(f"hands.linkerhand_l6.{field_name} must contain 6 values") + return parsed + + +def 
_positive_float(value: object, field_name: str) -> float: + parsed = float(value) + if parsed <= 0: + raise ValueError(f"hands.{field_name} must be > 0") + return parsed + + +def _optional_positive_int(value: object, field_name: str) -> int | None: + if value is None: + return None + parsed = int(value) + if parsed <= 0: + raise ValueError(f"hands.{field_name} must be > 0") + return parsed + + +def _optional_alpha(value: object, field_name: str) -> float | None: + if value is None: + return None + parsed = float(value) + if parsed <= 0.0 or parsed > 1.0: + raise ValueError(f"hands.{field_name} must be in (0, 1]") + return parsed + + +def _deadzone(value: object) -> float: + parsed = float(value) + if parsed < 0.0 or parsed >= 0.5: + raise ValueError("hands.linkerhand_l6.trigger_deadzone must be in [0, 0.5)") + return parsed + + +def _threshold(value: object) -> float: + parsed = float(value) + if parsed <= 0.0 or parsed >= 1.0: + raise ValueError("hands.linkerhand_l6.deadman_threshold must be in (0, 1)") + return parsed diff --git a/teleopit/sim2real/hands/pico_landmarks.py b/teleopit/sim2real/hands/pico_landmarks.py new file mode 100644 index 00000000..c747046c --- /dev/null +++ b/teleopit/sim2real/hands/pico_landmarks.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import numpy as np + +PICO_BRIDGE_TO_MEDIAPIPE = ( + 1, + 2, + 3, + 4, + 5, + 7, + 8, + 9, + 10, + 12, + 13, + 14, + 15, + 17, + 18, + 19, + 20, + 22, + 23, + 24, + 25, +) + +PICO_NATIVE_TO_RH = np.array( + [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]], + dtype=np.float64, +) + + +def pico_hand_to_landmarks(hand_state: object) -> np.ndarray: + state = np.asarray(hand_state, dtype=np.float64) + if state.shape != (26, 7): + state = state.reshape(26, 7) + positions = state[:, :3] @ PICO_NATIVE_TO_RH.T + landmarks = np.empty((21, 3), dtype=np.float64) + for mp_index, pico_index in enumerate(PICO_BRIDGE_TO_MEDIAPIPE): + landmarks[mp_index] = positions[pico_index] + return landmarks diff 
--git a/teleopit/sim2real/hands/worker.py b/teleopit/sim2real/hands/worker.py new file mode 100644 index 00000000..c1a9cdb5 --- /dev/null +++ b/teleopit/sim2real/hands/worker.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import logging +import time +from typing import Any + +from teleopit.runtime.common import cfg_get +from teleopit.sim2real.hands.base import HandDevice, HandInputMapper +from teleopit.sim2real.hands.linkerhand_l6 import build_linkerhand_l6 + +logger = logging.getLogger(__name__) + + +class HandRuntime: + def __init__(self, device: HandDevice, mapper: HandInputMapper): + self._device = device + self._mapper = mapper + self.enabled = True + self._failed = False + + def start(self) -> None: + try: + self._device.connect() + self._mapper.start() + except Exception: + try: + self._device.close() + finally: + raise + + def tick(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float | None = None) -> None: + if self._failed: + return + now = time.monotonic() if now_s is None else float(now_s) + try: + for command in self._mapper.map( + controller_snapshot=controller_snapshot, + hand_snapshot=hand_snapshot, + active=active, + now_s=now, + ): + self._device.send_pose(command.side, command.pose, force=command.force, reason=command.reason) + except Exception: + self._failed = True + logger.exception("Hand runtime failed; disabling hand control") + try: + self._device.open_all(force=True, reason="failure") + except Exception: + logger.exception("Failed to open hand after hand runtime failure") + + def close(self) -> None: + try: + self._mapper.close() + finally: + self._device.close() + + +class DisabledHandRuntime: + enabled = False + + def start(self) -> None: + pass + + def tick(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float | None = None) -> None: + del controller_snapshot, hand_snapshot, active, now_s + + def close(self) -> None: + pass + + 
+def build_hand_runtime(cfg: Any) -> HandRuntime | DisabledHandRuntime: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + if not bool(cfg_get(hands_cfg, "enabled", False)): + return DisabledHandRuntime() + driver = str(cfg_get(hands_cfg, "driver", "linkerhand_l6")).strip().lower() + if driver != "linkerhand_l6": + raise ValueError(f"Unsupported hands.driver={driver!r}; only linkerhand_l6 is implemented") + device, mapper = build_linkerhand_l6(cfg) + return HandRuntime(device, mapper) diff --git a/teleopit/sim2real/mp/__init__.py b/teleopit/sim2real/mp/__init__.py index 59adde7c..f32f55cd 100644 --- a/teleopit/sim2real/mp/__init__.py +++ b/teleopit/sim2real/mp/__init__.py @@ -1,11 +1,9 @@ -"""Multiprocess sim2real runtime.""" +"""Process-isolated sim2real runtime.""" from teleopit.sim2real.mp.runtime import ( - MultiprocessSim2RealController, - resolve_sim2real_runtime_mode, + Sim2RealRuntime, ) __all__ = [ - "MultiprocessSim2RealController", - "resolve_sim2real_runtime_mode", + "Sim2RealRuntime", ] diff --git a/teleopit/sim2real/mp/ipc.py b/teleopit/sim2real/mp/ipc.py index d31fc328..109d41d8 100644 --- a/teleopit/sim2real/mp/ipc.py +++ b/teleopit/sim2real/mp/ipc.py @@ -32,6 +32,7 @@ class Sim2RealIpcEndpoints: video_pub: str health_pub: str command_pub: str + reference_command_pub: str def default_endpoints(*, host: str = "127.0.0.1", base_port: int = 39700) -> Sim2RealIpcEndpoints: @@ -47,6 +48,7 @@ def default_endpoints(*, host: str = "127.0.0.1", base_port: int = 39700) -> Sim video_pub=f"{prefix}{base_port + 6}", health_pub=f"{prefix}{base_port + 7}", command_pub=f"{prefix}{base_port + 8}", + reference_command_pub=f"{prefix}{base_port + 9}", ) diff --git a/teleopit/sim2real/mp/messages.py b/teleopit/sim2real/mp/messages.py index 137666e4..d16a1496 100644 --- a/teleopit/sim2real/mp/messages.py +++ b/teleopit/sim2real/mp/messages.py @@ -32,6 +32,8 @@ class ReferencePacket: frame_valid: bool = True reference_window: ReferenceWindow | None = None retarget_elapsed_s: 
float = 0.0 + playback_paused: bool = False + playback_finished: bool = False @dataclass(frozen=True) diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 9052c1bf..b1279e98 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -5,6 +5,7 @@ import logging import multiprocessing as mp from multiprocessing.synchronize import Event as MpEvent +from enum import Enum from pathlib import Path import time from typing import Any, Callable @@ -15,25 +16,27 @@ from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM from teleopit.controllers.observation import VelCmdObservationBuilder, align_motion_qpos_yaw from teleopit.controllers.rl_policy import RLPolicyController +from teleopit.inputs.bvh_provider import BVHInputProvider from teleopit.inputs.human_frame_validation import validate_human_frame from teleopit.inputs.pico4_provider import Pico4InputProvider from teleopit.inputs.pico_video import PicoVideoRuntime, bridge_video_source, parse_pico_video_config from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType from teleopit.retargeting.core import RetargetingModule -from teleopit.runtime.common import cfg_get, require_section +from teleopit.runtime.offline_playback import OfflinePlaybackController +from teleopit.runtime.common import cfg_get, parse_viewers, require_section from teleopit.runtime.factory import _build_policy_components, build_simulation_cfg from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState from teleopit.runtime.reference_config import parse_reference_config +from teleopit.sim.reference_motion import OfflineReferenceMotion from teleopit.sim.reference_timeline import ReferenceTimeline, ReferenceWindow, ReferenceWindowBuilder -from teleopit.sim.reference_utils import build_static_reference_window, obs_builder_requires_reference_window -from teleopit.sim.realtime_utils import RealtimeReferenceManager -from teleopit.sim2real.controller 
import ( - RobotMode, - _LoopTimingReporter, - _parse_sim2real_viewers, - _Sim2RealRetargetViewer, +from teleopit.sim.reference_utils import ( + build_offline_reference_window, + build_static_reference_window, + obs_builder_requires_reference_window, ) -from teleopit.sim2real.dexterous_hand import build_linkerhand_runtime +from teleopit.sim.realtime_utils import RealtimeReferenceManager +from teleopit.sim.viewer_subprocess import start_robot_viewer +from teleopit.sim2real.hands.worker import build_hand_runtime from teleopit.sim2real.mp.ipc import ( BODY_TOPIC, COMMAND_TOPIC, @@ -78,20 +81,118 @@ PROJECT_ROOT = Path(__file__).resolve().parents[3] -def resolve_sim2real_runtime_mode(cfg: Any) -> str: - """Resolve ``auto|single_process|multiprocess`` into a concrete runtime.""" - raw = str(cfg_get(cfg, "sim2real_runtime", "auto")).strip().lower() - if raw in ("single", "single_process", "legacy"): - return "single_process" - if raw in ("mp", "multi", "multiprocess"): - provider = str(cfg_get(cfg_get(cfg, "input", {}), "provider", "")).lower() - if provider != "pico4": - raise ValueError("sim2real_runtime=multiprocess currently requires input.provider=pico4") - return "multiprocess" - if raw != "auto": - raise ValueError("sim2real_runtime must be auto, single_process, or multiprocess") - provider = str(cfg_get(cfg_get(cfg, "input", {}), "provider", "")).lower() - return "multiprocess" if provider == "pico4" else "single_process" +class RobotMode(Enum): + IDLE = "idle" + STANDING = "standing" + MOCAP = "mocap" + DAMPING = "damping" + + +class _LoopTimingReporter: + def __init__(self, *, target_period_s: float, log_interval_s: float = 1.0) -> None: + self._target_period_s = float(target_period_s) + self._log_interval_s = float(log_interval_s) + self._window_start_s: float | None = None + self._loop_ms: list[float] = [] + self._work_ms: list[float] = [] + self._pico_age_ms: list[float] = [] + self._overrun_count = 0 + + def record(self, *, loop_start_s: float, 
work_elapsed_s: float, cycle_elapsed_s: float, pico_age_s: float | None) -> None: + if self._window_start_s is None: + self._window_start_s = float(loop_start_s) + self._loop_ms.append(float(cycle_elapsed_s) * 1000.0) + self._work_ms.append(float(work_elapsed_s) * 1000.0) + if pico_age_s is not None: + self._pico_age_ms.append(float(pico_age_s) * 1000.0) + if cycle_elapsed_s > self._target_period_s + 1e-9: + self._overrun_count += 1 + if loop_start_s - self._window_start_s >= self._log_interval_s: + self._emit(loop_start_s) + + def _emit(self, end_s: float) -> None: + sample_count = len(self._loop_ms) + if sample_count <= 0: + self._reset(end_s) + return + loop_summary = self._summarize(self._loop_ms) + work_summary = self._summarize(self._work_ms) + message = ( + "Timing stats | samples=%d window=%.1fs | " + "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f overrun=%d/%d | " + "work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" + ) + args: list[object] = [ + sample_count, + end_s - float(self._window_start_s), + *loop_summary, + self._overrun_count, + sample_count, + *work_summary, + ] + if self._pico_age_ms: + message += " | reference_age_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" + args.extend(self._summarize(self._pico_age_ms)) + logger.info(message, *args) + self._reset(end_s) + + def _reset(self, window_start_s: float) -> None: + self._window_start_s = float(window_start_s) + self._loop_ms.clear() + self._work_ms.clear() + self._pico_age_ms.clear() + self._overrun_count = 0 + + @staticmethod + def _summarize(samples: list[float]) -> tuple[float, float, float, float]: + values = np.asarray(samples, dtype=np.float64) + if values.size <= 0: + return 0.0, 0.0, 0.0, 0.0 + p50, p95, p99 = np.percentile(values, [50.0, 95.0, 99.0]) + return float(p50), float(p95), float(p99), float(np.max(values)) + + +def _parse_sim2real_viewers(cfg: Any) -> set[str]: + viewers = parse_viewers(cfg) + unsupported = viewers.difference({"retarget"}) + if unsupported: + raise ValueError( + f"Sim2real 
supports only the optional 'retarget' viewer; got unsupported viewers {sorted(unsupported)}. " + "Use viewers=retarget or viewers=none." + ) + return viewers + + +class _Sim2RealRetargetViewer: + def __init__(self, *, xml_path: str | None, enabled: bool) -> None: + self._entry: tuple[Any, Any, Any, Any] | None = None + if not enabled: + return + if not xml_path: + raise ValueError("Sim2real retarget viewer requires robot.xml_path to be set.") + self._entry = start_robot_viewer(xml_path, FULL_QPOS_DIM, True, "Retarget", 900, 50) + + def write(self, qpos: Float64Array) -> None: + if self._entry is None: + return + _, arr, alive, _ = self._entry + if not alive.value: + return + qpos = np.asarray(qpos, dtype=np.float64).reshape(-1) + if qpos.shape[0] < FULL_QPOS_DIM: + return + with arr.get_lock(): + arr[:FULL_QPOS_DIM] = qpos[:FULL_QPOS_DIM].tolist() + + def shutdown(self) -> None: + if self._entry is None: + return + proc, _, _, shutdown = self._entry + shutdown.set() + proc.join(timeout=3) + if proc.is_alive(): + proc.terminate() + self._entry = None def _plain_cfg(cfg: Any) -> dict[str, Any]: @@ -103,7 +204,26 @@ def _plain_cfg(cfg: Any) -> dict[str, Any]: def _mp_cfg(cfg: Any) -> Any: - return cfg_get(cfg, "multiprocess", {}) or {} + return cfg_get(cfg, "runtime", {}) or {} + + +def _input_provider_kind(cfg: Any) -> str: + return str(cfg_get(cfg_get(cfg, "input", {}) or {}, "provider", "bvh")).strip().lower() + + +def _validate_new_runtime_config(cfg: Any) -> None: + legacy_keys = [key for key in ("sim2real_runtime", "multiprocess", "dexterous_hand") if cfg_get(cfg, key, None) is not None] + if legacy_keys: + raise ValueError( + "Legacy sim2real config keys are no longer supported: " + f"{', '.join(legacy_keys)}. Use input.provider, runtime, and hands instead." 
+ ) + provider = _input_provider_kind(cfg) + if provider not in ("pico4", "bvh"): + raise ValueError(f"sim2real input.provider must be pico4 or bvh, got {provider!r}") + hands_cfg = cfg_get(cfg, "hands", {}) or {} + if bool(cfg_get(hands_cfg, "enabled", False)) and provider != "pico4": + raise ValueError("hands.enabled=true requires input.provider=pico4") def _worker_loop(name: str, fn: Callable[[], None]) -> None: @@ -121,19 +241,18 @@ def _human_frame_is_valid(frame: object) -> bool: return validate_human_frame(frame).valid -class MultiprocessSim2RealController: - """Supervisor facade for the multiprocess Pico sim2real runtime.""" +class Sim2RealRuntime: + """Supervisor facade for the process-isolated sim2real runtime.""" def __init__(self, cfg: Any) -> None: self.cfg = _plain_cfg(cfg) - if resolve_sim2real_runtime_mode(self.cfg) != "multiprocess": - raise ValueError("MultiprocessSim2RealController requires sim2real_runtime=multiprocess or auto+pico4") + _validate_new_runtime_config(self.cfg) mp_cfg = _mp_cfg(self.cfg) video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) if video_cfg.enabled and video_cfg.source not in ("realsense", "test-pattern"): raise ValueError( - "Multiprocess sim2real only supports input.video.source=realsense or test-pattern" + "Sim2RealRuntime only supports input.video.source=realsense or test-pattern" ) self._ctx = mp.get_context(str(cfg_get(mp_cfg, "start_method", "spawn"))) self._stop_event = self._ctx.Event() @@ -145,24 +264,36 @@ def __init__(self, cfg: Any) -> None: ) def run(self) -> None: - logger.info("Starting multiprocess sim2real runtime") + logger.info("Starting sim2real runtime") try: self._start_processes() while not self._stop_event.is_set(): time.sleep(0.2) + critical_names = {"robot_control", "reference"} + if _input_provider_kind(self.cfg) == "pico4": + critical_names.add("pico_input") critical_dead = [ process.name for process in self._processes if not process.is_alive() and process.exitcode not in 
(None, 0) - and process.name in {"robot_control", "pico_io", "retarget_worker"} + and process.name in critical_names ] if critical_dead: logger.error("Critical sim2real worker exited: %s", ", ".join(critical_dead)) self._stop_event.set() break + noncritical_dead = [ + process.name + for process in self._processes + if not process.is_alive() + and process.exitcode not in (None, 0) + and process.name not in critical_names + ] + if noncritical_dead: + logger.warning("Non-critical sim2real worker exited: %s", ", ".join(noncritical_dead)) except KeyboardInterrupt: - logger.info("KeyboardInterrupt -- shutting down multiprocess sim2real") + logger.info("KeyboardInterrupt -- shutting down sim2real") self._stop_event.set() finally: self.shutdown() @@ -182,17 +313,21 @@ def _start_processes(self) -> None: if self._processes: return - specs: list[tuple[str, Callable[..., None]]] = [ - ("pico_io", _run_pico_io_worker), - ("retarget_worker", _run_retarget_worker), - ("robot_control", _run_robot_control_worker), - ] - hand_mode = str(cfg_get(cfg_get(self.cfg, "dexterous_hand", {}) or {}, "mode", "off")).lower() - if hand_mode != "off": + specs: list[tuple[str, Callable[..., None]]] = [] + if _input_provider_kind(self.cfg) == "pico4": + specs.append(("pico_input", _run_pico_io_worker)) + specs.extend( + [ + ("reference", _run_reference_worker), + ("robot_control", _run_robot_control_worker), + ] + ) + hands_cfg = cfg_get(self.cfg, "hands", {}) or {} + if bool(cfg_get(hands_cfg, "enabled", False)): specs.append(("hand_worker", _run_hand_worker)) video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) if video_cfg.enabled: - logger.info("Pico video runs inside pico_io so frames are pushed directly to PicoBridge") + logger.info("Pico video runs inside pico_input so frames are pushed directly to PicoBridge") for name, target in specs: process = self._ctx.Process( @@ -241,7 +376,7 @@ def _main() -> None: mode="sim2real", ) - hz = float(cfg_get(_mp_cfg(cfg), "pico_io_hz", 
120.0)) + hz = float(cfg_get(_mp_cfg(cfg), "pico_input_hz", 120.0)) sleep_s = 1.0 / max(hz, 1.0) last_body_seq = -1 last_hand_seq = -1 @@ -262,7 +397,7 @@ def _main() -> None: try: frame, timestamp_s, seq = provider.get_frame_packet() except Exception: - logger.exception("pico_io failed to read body frame") + logger.exception("pico_input failed to read body frame") else: if int(seq) != last_body_seq: body_pub.publish( @@ -309,7 +444,7 @@ def _main() -> None: health_pub.publish( HEALTH_TOPIC, HealthPacket( - worker="pico_io", + worker="pico_input", timestamp_s=now, metrics={ "body_seq": last_body_seq, @@ -329,10 +464,21 @@ def _main() -> None: publisher.close() provider.close() - _worker_loop("pico_io", _main) + _worker_loop("pico_input", _main) + + +def _run_reference_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + if _input_provider_kind(cfg) == "bvh": + _run_bvh_reference_worker(cfg, endpoints, stop_event) + return + _run_pico_reference_worker(cfg, endpoints, stop_event) -def _run_retarget_worker( +def _run_pico_reference_worker( cfg: dict[str, Any], endpoints: Sim2RealIpcEndpoints, stop_event: MpEvent, @@ -401,17 +547,17 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> try: while not stop_event.is_set(): - health_packet = health_sub.recv_latest() - if isinstance(health_packet, HealthPacket) and health_packet.worker == "pico_io": - metric_fps = health_packet.metrics.get("body_fps") - if isinstance(metric_fps, (int, float)) and float(metric_fps) > 0.0: - latest_body_fps = float(metric_fps) - command = command_sub.recv_latest() if isinstance(command, CommandPacket) and command.command == "shutdown": stop_event.set() break + health_packet = health_sub.recv_latest() + if isinstance(health_packet, HealthPacket) and health_packet.worker == "pico_input": + metric_fps = health_packet.metrics.get("body_fps") + if isinstance(metric_fps, (int, float)) and float(metric_fps) > 0.0: + 
latest_body_fps = float(metric_fps) + packet = body_sub.recv_latest() if packet is None: time.sleep(idle_sleep_s) @@ -425,7 +571,7 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> last_body_timestamp_s = None body_dt_s_ema = None _publish_invalid_reference(packet, elapsed_s=time.monotonic() - start_s) - logger.warning("retarget_worker dropped invalid body frame seq=%s", packet.seq) + logger.warning("reference worker dropped invalid body frame seq=%s", packet.seq) continue try: @@ -483,14 +629,144 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> ) last_body_seq = int(packet.seq) except Exception: - logger.exception("retarget_worker failed to retarget body seq=%s", getattr(packet, "seq", None)) + logger.exception("reference worker failed to retarget body seq=%s", getattr(packet, "seq", None)) finally: body_sub.close() health_sub.close() command_sub.close() ref_pub.close() - _worker_loop("retarget_worker", _main) + _worker_loop("reference", _main) + + +def _run_bvh_reference_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + input_cfg = cfg_get(cfg, "input", {}) or {} + policy_hz = float(cfg_get(cfg, "policy_hz", 50.0)) + provider = BVHInputProvider( + str(cfg_get(input_cfg, "bvh_file", "")), + human_format=str(cfg_get(input_cfg, "bvh_format", cfg_get(input_cfg, "human_format", "lafan1"))), + ) + retargeter = RetargetingModule( + robot_name=str(cfg_get(input_cfg, "robot_name", "unitree_g1")), + human_format=str(cfg_get(input_cfg, "human_format", cfg_get(input_cfg, "bvh_format", "lafan1"))), + actual_human_height=float(cfg_get(input_cfg, "human_height", provider.human_height)), + ) + offline_reference = OfflineReferenceMotion(provider, retargeter) + playback_cfg = cfg_get(cfg, "playback", {}) or {} + playback = OfflinePlaybackController( + duration_s=offline_reference.duration_s, + step_dt_s=1.0 / policy_hz, + 
pause_on_end=bool(cfg_get(playback_cfg, "pause_on_end", True)), + ) + reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=1.0 / policy_hz, + reference_steps=cfg_get(cfg, "reference_steps", [0]), + ) + ref_pub = ZmqPublisher(endpoints.reference_pub) + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + reference_command_sub = LatestSubscriber(endpoints.reference_command_pub, COMMAND_TOPIC) + mode_sub = LatestSubscriber(endpoints.mode_pub, MODE_TOPIC) + health_pub = ZmqPublisher(endpoints.health_pub) + tick_s = 1.0 / policy_hz + seq = 0 + last_health_s = 0.0 + mocap_active = False + + def _publish(sample_time_s: float, *, frame_valid: bool = True) -> Float64Array | None: + nonlocal seq + start_s = time.monotonic() + sampled = offline_reference.sample(sample_time_s) + if sampled is None: + return None + reference_window = None + if reference_window_builder.requires_timeline: + reference_window = build_offline_reference_window( + offline_reference, + sample_time_s, + reference_window_builder, + policy_hz, + ) + qpos = np.asarray(sampled.qpos, dtype=np.float64).copy() + ref_pub.publish( + REFERENCE_TOPIC, + ReferencePacket( + qpos=qpos, + timestamp_s=time.monotonic(), + seq=seq, + source_timestamp_s=float(sample_time_s), + source_seq=int(sampled.frame_idx0), + frame_valid=frame_valid, + reference_window=reference_window, + retarget_elapsed_s=time.monotonic() - start_s, + playback_paused=playback.paused, + playback_finished=playback.finished, + ), + ) + seq += 1 + return qpos + + try: + while not stop_event.is_set(): + t0 = time.monotonic() + command = command_sub.recv_latest() + if isinstance(command, CommandPacket): + if command.command == "shutdown": + stop_event.set() + break + reference_command = reference_command_sub.recv_latest() + if isinstance(reference_command, CommandPacket): + command = reference_command + if command.command == "pause_mocap": + playback.pause() + elif command.command == "resume_mocap": + if not 
playback.finished: + playback.resume() + elif command.command == "replay_mocap": + playback.replay() + mode_packet = mode_sub.recv_latest() + if isinstance(mode_packet, ModeStatePacket): + mocap_active = bool(mode_packet.mocap_active) + + qpos = _publish(playback.current_time_s) + if qpos is None: + playback.finish() + _publish(playback.current_time_s) + elif mocap_active: + playback.advance() + + now = time.monotonic() + if now - last_health_s >= 1.0: + health_pub.publish( + HEALTH_TOPIC, + HealthPacket( + worker="reference", + timestamp_s=now, + metrics={ + "source": "bvh", + "seq": seq, + "playback_time_s": float(playback.current_time_s), + "paused": int(playback.paused), + "finished": int(playback.finished), + }, + ), + ) + last_health_s = now + elapsed = time.monotonic() - t0 + if elapsed < tick_s: + time.sleep(tick_s - elapsed) + finally: + command_sub.close() + reference_command_sub.close() + mode_sub.close() + ref_pub.close() + health_pub.close() + + _worker_loop("reference", _main) class _RobotControlWorker: @@ -503,6 +779,7 @@ def __init__( self.cfg = cfg self.endpoints = endpoints self.stop_event = stop_event + self.provider_kind = _input_provider_kind(cfg) self.mode = RobotMode.IDLE self.policy_hz = float(cfg_get(cfg, "policy_hz", 50.0)) self.dt = 1.0 / self.policy_hz @@ -556,6 +833,7 @@ def __init__( self._reference_sub = LatestSubscriber(endpoints.reference_pub, REFERENCE_TOPIC) self._events_sub = LatestSubscriber(endpoints.control_events_pub, CONTROL_EVENTS_TOPIC) self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + self._reference_command_pub = ZmqPublisher(endpoints.reference_command_pub) self._mode_pub = ZmqPublisher(endpoints.mode_pub) viewers = _parse_sim2real_viewers(cfg) @@ -613,6 +891,7 @@ def shutdown(self) -> None: self._reference_sub.close() self._events_sub.close() self._command_sub.close() + self._reference_command_pub.close() self._mode_pub.close() self.robot.close() @@ -657,12 +936,19 @@ def 
_handle_transitions(self) -> None: else: logger.warning("Cannot switch to MOCAP -- no fresh retarget reference") elif self.mode == RobotMode.MOCAP: + if self.provider_kind == "bvh" and self.remote.B.on_pressed: + logger.info("B pressed -> replaying BVH motion from start") + self._send_reference_command("replay_mocap") + self._resume_paused_mocap_if_needed() + return if self.remote.A.on_pressed: if self._mocap_session.state == MocapSessionState.PAUSED: logger.info("A pressed -> resuming playback") + self._send_reference_command("resume_mocap") self._resume_paused_mocap() else: logger.info("A pressed -> pausing playback") + self._send_reference_command("pause_mocap") self._pause_active_mocap() return if self.remote.X.on_pressed: @@ -808,6 +1094,8 @@ def _can_switch_to_mocap(self) -> bool: return False if not self._latest_reference.frame_valid: return False + if self.provider_kind == "bvh": + return True if age_s > self._max_reference_age_s: return False if self._consecutive_valid_references < self._check_frames: @@ -827,16 +1115,22 @@ def _transition_to_mocap(self) -> None: self._last_retarget_qpos = None self._last_commanded_motion_qpos = resume_qpos.copy() self._ref_proc.reset_alignment(target_qpos=resume_qpos) + if self.provider_kind == "bvh": + self._send_reference_command("replay_mocap") self.mode = RobotMode.MOCAP logger.info("Mode -> MOCAP (tracking multiprocess retarget reference)") + def _resume_paused_mocap_if_needed(self) -> None: + if self._mocap_session.state == MocapSessionState.PAUSED: + self._resume_paused_mocap() + def _enter_damping(self) -> None: if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): logger.info("DAMPING: sending LowCmd damping...") self.robot.set_damping() time.sleep(0.5) logger.info("DAMPING: exiting debug mode...") - self.robot.exit_debug_mode() + self.robot.exit_debug_mode() self.mode = RobotMode.DAMPING self._ref_proc.last_reference_qpos = None self._mocap_reentry_armed = False @@ -927,6 +1221,12 @@ def 
_resume_paused_mocap(self) -> None: self._ref_proc.reset_alignment(target_qpos=resume_qpos) logger.info("Mocap session -> ACTIVE (multiprocess episode-reset + reference realignment)") + def _send_reference_command(self, command: str) -> None: + self._reference_command_pub.publish( + COMMAND_TOPIC, + CommandPacket(command=command, timestamp_s=time.monotonic()), + ) + def _resolve_mocap_hold_qpos(self) -> Float64Array: if self._last_commanded_motion_qpos is not None: return self._last_commanded_motion_qpos.copy() @@ -1010,6 +1310,13 @@ def _note_reference_packet(self, reference: ReferencePacket) -> None: return self._last_reference_seq = int(reference.seq) self._latest_reference = reference + if ( + self.provider_kind == "bvh" + and bool(getattr(reference, "playback_paused", False)) + and self.mode == RobotMode.MOCAP + and self._mocap_session.state == MocapSessionState.ACTIVE + ): + self._pause_active_mocap() if not reference.frame_valid: self._consecutive_valid_references = 0 return @@ -1055,7 +1362,7 @@ def _run_hand_worker( ) -> None: def _main() -> None: proxy = _HandSnapshotProxy() - runtime = build_linkerhand_runtime(cfg, proxy) + runtime = build_hand_runtime(cfg) hand_sub = LatestSubscriber(endpoints.hand_pub, HAND_TOPIC) controller_sub = LatestSubscriber(endpoints.controller_pub, CONTROLLER_TOPIC) mode_sub = LatestSubscriber(endpoints.mode_pub, MODE_TOPIC) @@ -1063,8 +1370,8 @@ def _main() -> None: active = False hz = float(cfg_get(_mp_cfg(cfg), "hand_worker_hz", 120.0)) sleep_s = 1.0 / max(hz, 1.0) - runtime.start() try: + runtime.start() while not stop_event.is_set(): command = command_sub.recv_latest() if isinstance(command, CommandPacket) and command.command == "shutdown": @@ -1080,7 +1387,11 @@ def _main() -> None: if isinstance(mode_packet, ModeStatePacket): active = bool(mode_packet.mocap_active) try: - runtime.tick(active=active) + runtime.tick( + controller_snapshot=proxy.controller_snapshot, + hand_snapshot=proxy.hand_snapshot, + active=active, + ) 
except Exception: logger.exception("Dexterous hand worker tick failed; hand control continues") time.sleep(sleep_s) diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 728b08f7..0e65b644 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -6,18 +6,16 @@ import numpy as np import pytest -from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState, PicoHandSnapshot, PicoHandState -from teleopit.sim2real.dexterous_hand import ( - L6PoseSender, - L6RetargetPoseMapper, - LinkerHandRuntime, - SomeHandPoseRuntime, - ThreadedSomeHandPoseRuntime, - VR_HAND_POSE_SPEED, - build_linkerhand_runtime, - parse_linkerhand_config, +from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState +from teleopit.sim2real.hands.linkerhand_l6 import ( + GripperMapper, + LinkerHandL6Device, + SomehandL6Mapper, + parse_linkerhand_l6_config, trigger_to_pose, ) +from teleopit.sim2real.hands.pico_landmarks import pico_hand_to_landmarks +from teleopit.sim2real.hands.worker import HandRuntime class FakeInnerHand: @@ -37,561 +35,167 @@ def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> self.modbus = modbus self.can = can self.hand = FakeInnerHand() - self.close_can_calls = 0 self.speed: list[int] | None = None self.poses: list[list[int]] = [] FakeLinkerHandApi.instances.append(self) def set_speed(self, speed: list[int]) -> None: - self.speed = speed + self.speed = list(speed) def finger_move(self, pose: list[int]) -> None: self.poses.append(list(pose)) - def close_can(self) -> None: - self.close_can_calls += 1 +def _cfg(mode: str = "gripper") -> dict[str, object]: + return { + "input": {"provider": "pico4"}, + "hands": { + "enabled": True, + "driver": "linkerhand_l6", + "mode": mode, + "sides": ["left", "right"], + "rate_hz": 30.0, + "frame_timeout_s": 0.3, + "linkerhand_l6": { + "left_can": "can0", + "right_can": "can1", + "modbus": "None", + "trigger_deadzone": 
0.05, + "deadman_threshold": 0.5, + "open_pose": [250, 10, 250, 250, 250, 250], + "close_pose": [79, 10, 0, 0, 0, 0], + }, + "somehand": { + "rate_hz": 60.0, + "max_iterations": 12, + "temporal_filter_alpha": 1.0, + "output_alpha": 1.0, + }, + }, + } -@pytest.fixture(autouse=True) -def fake_linkerhand_sdk(monkeypatch): - FakeLinkerHandApi.instances = [] - fake_module = SimpleNamespace(LinkerHandApi=FakeLinkerHandApi) - monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) - yield +def test_pico_hand_to_landmarks_uses_teleopit_adapter() -> None: + joints = np.zeros((26, 7), dtype=np.float64) + joints[:, 0] = np.arange(26) + joints[:, 1] = np.arange(26) + 100 + joints[:, 2] = np.arange(26) + 200 -class SnapshotProvider: - def __init__(self) -> None: - self.snapshot: PicoControllerSnapshot | None = None + landmarks = pico_hand_to_landmarks(joints) - def get_controller_snapshot(self) -> PicoControllerSnapshot | None: - return self.snapshot + assert landmarks.shape == (21, 3) + np.testing.assert_allclose(landmarks[0], [1.0, -201.0, 101.0]) + np.testing.assert_allclose(landmarks[-1], [25.0, -225.0, 125.0]) -class HandSnapshotProvider: - def __init__(self) -> None: - self.snapshot: PicoHandSnapshot | None = None - - def get_hand_snapshot(self) -> PicoHandSnapshot | None: - return self.snapshot - - -def _snapshot( - *, - left: PicoControllerState | None = None, - right: PicoControllerState | None = None, - timestamp_s: float = 10.0, - seq: int = 1, -) -> PicoControllerSnapshot: - missing = PicoControllerState(raw=False, grip=0.0, trigger=0.0, present=False) - return PicoControllerSnapshot( - left=left or missing, - right=right or missing, - timestamp_s=timestamp_s, - seq=seq, +def test_gripper_mapper_maps_trigger_and_deadman() -> None: + cfg = parse_linkerhand_l6_config(_cfg()) + mapper = GripperMapper(cfg) + snapshot = PicoControllerSnapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + 
right=PicoControllerState(raw=True, grip=0.1, trigger=1.0, present=True), + timestamp_s=10.0, + seq=1, ) + commands = mapper.map(controller_snapshot=snapshot, hand_snapshot=None, active=True, now_s=10.0) -def _hand_snapshot( - *, - left: PicoHandState | None = None, - right: PicoHandState | None = None, - timestamp_s: float = 10.0, - seq: int = 1, -) -> PicoHandSnapshot: - missing = PicoHandState(active=False, joints=np.zeros((26, 7), dtype=np.float64), present=False) - return PicoHandSnapshot( - left=left or missing, - right=right or missing, - timestamp_s=timestamp_s, - seq=seq, - ) + assert commands[0].side == "left" + assert commands[0].pose == cfg.close_pose + assert commands[1].side == "right" + assert commands[1].pose == cfg.open_pose -def _hand_state(*, active: bool = True, value: float = 1.0, present: bool = True) -> PicoHandState: - joints = np.zeros((26, 7), dtype=np.float64) - joints[:, 0] = value - return PicoHandState(active=active, joints=joints, present=present) - - -def _runtime(provider: SnapshotProvider) -> LinkerHandRuntime: - cfg = parse_linkerhand_config( - { - "dexterous_hand": { - "enabled": True, - "hand_type": "both", - } - } +def test_hand_mappers_force_open_once_when_inactive() -> None: + cfg = parse_linkerhand_l6_config(_cfg()) + snapshot = PicoControllerSnapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + right=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + timestamp_s=10.0, + seq=1, ) - runtime = LinkerHandRuntime(cfg, provider) - runtime.start() - return runtime + gripper = GripperMapper(cfg) + assert gripper.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=9.0) == () + assert gripper.map(controller_snapshot=snapshot, hand_snapshot=None, active=True, now_s=10.0) + first_inactive = gripper.map(controller_snapshot=snapshot, hand_snapshot=None, active=False, now_s=10.1) + assert [command.force for command in first_inactive] == [True, True] + assert 
gripper.map(controller_snapshot=snapshot, hand_snapshot=None, active=False, now_s=10.2) == () -def _wait_runtime_idle(runtime: LinkerHandRuntime) -> None: - assert runtime._sender.wait_idle(timeout_s=1.0) + somehand = SomehandL6Mapper(cfg) + assert somehand.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=9.0) == () + somehand._active = True + first_inactive = somehand.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=10.0) + assert [command.force for command in first_inactive] == [True, True] + assert somehand.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=10.1) == () def test_trigger_to_pose_applies_deadzone_and_fixed_thumb_yaw() -> None: - pose = trigger_to_pose( + assert trigger_to_pose( 0.5, open_pose=[250, 10, 250, 250, 250, 250], close_pose=[79, 10, 0, 0, 0, 0], deadzone=0.05, thumb_yaw_default=10, - ) - - assert pose == [164, 10, 125, 125, 125, 125] - - -def test_parse_config_keeps_gripper_default_speed() -> None: - cfg = parse_linkerhand_config( - { - "dexterous_hand": { - "mode": "gripper", - "hand_type": "both", - } - } - ) - - assert cfg.speed == (50, 50, 50, 50, 50, 50) - - -def test_parse_config_sets_vr_hand_pose_speed_to_max() -> None: - cfg = parse_linkerhand_config( - { - "dexterous_hand": { - "mode": "vr_hand_pose", - "hand_type": "both", - "speed": [50, 50, 50, 50, 50, 50], - } - } - ) - - assert cfg.speed == (255, 255, 255, 255, 255, 255) - - -def test_parse_config_accepts_somehand_low_latency_overrides() -> None: - cfg = parse_linkerhand_config( - { - "dexterous_hand": { - "mode": "vr_hand_pose", - "hand_type": "both", - "somehand": { - "rate": 60.0, - "threaded": True, - "max_iterations": 12, - "temporal_filter_alpha": 1.0, - "output_alpha": 1.0, - }, - } - } - ) - - assert cfg.vr_hand_pose_rate == 60.0 - assert cfg.somehand_threaded is True - assert cfg.somehand_max_iterations == 12 - assert cfg.somehand_temporal_filter_alpha == 1.0 - assert cfg.somehand_output_alpha == 1.0 - - 
-def test_vr_hand_pose_speed_constant_is_max() -> None: - assert tuple(VR_HAND_POSE_SPEED) == (255, 255, 255, 255, 255, 255) - - -def test_runtime_opens_when_deadman_released() -> None: - provider = SnapshotProvider() - runtime = _runtime(provider) - provider.snapshot = _snapshot( - left=PicoControllerState(raw=True, grip=0.1, trigger=1.0), - right=PicoControllerState(raw=True, grip=0.1, trigger=1.0), - ) - - runtime.tick(active=True, now_s=10.0) - _wait_runtime_idle(runtime) - - assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) - assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) - - -def test_runtime_maps_present_controller_even_without_raw_flag() -> None: - provider = SnapshotProvider() - runtime = _runtime(provider) - provider.snapshot = _snapshot( - left=PicoControllerState(raw=False, grip=1.0, trigger=1.0, present=True), - right=PicoControllerState(raw=False, grip=1.0, trigger=0.0, present=True), - ) - - runtime.tick(active=True, now_s=10.0) - _wait_runtime_idle(runtime) - - assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) - assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) - - -def test_runtime_maps_trigger_when_deadman_active() -> None: - provider = SnapshotProvider() - runtime = _runtime(provider) - provider.snapshot = _snapshot( - left=PicoControllerState(raw=True, grip=1.0, trigger=1.0), - right=PicoControllerState(raw=True, grip=1.0, trigger=0.0), - ) - - runtime.tick(active=True, now_s=10.0) - _wait_runtime_idle(runtime) - - assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) - assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) - - -def test_runtime_opens_on_timeout_and_inactive_mode() -> None: - provider = SnapshotProvider() - runtime = _runtime(provider) - provider.snapshot = _snapshot( - left=PicoControllerState(raw=True, grip=1.0, trigger=1.0), - right=PicoControllerState(raw=True, grip=1.0, 
trigger=1.0), - timestamp_s=10.0, - ) + ) == [164, 10, 125, 125, 125, 125] - runtime.tick(active=True, now_s=10.0) - _wait_runtime_idle(runtime) - assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) - provider.snapshot = SimpleNamespace(timestamp_s=9.0, seq=2, left=None, right=None) - runtime.tick(active=True, now_s=20.0) - _wait_runtime_idle(runtime) - assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) - - provider.snapshot = _snapshot( - left=PicoControllerState(raw=True, grip=1.0, trigger=1.0), - right=PicoControllerState(raw=True, grip=1.0, trigger=1.0), - timestamp_s=20.1, - ) - runtime.tick(active=True, now_s=20.1) - _wait_runtime_idle(runtime) - assert runtime._sender._last_pose["left"] == list(runtime.config.close_pose) - - runtime.tick(active=False, now_s=20.2) - _wait_runtime_idle(runtime) - assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) - - -def test_pose_sender_close_leaves_can_interfaces_up() -> None: - cfg = parse_linkerhand_config( - { - "dexterous_hand": { - "enabled": True, - "hand_type": "both", - } - } +def test_linkerhand_l6_device_starts_sdk(monkeypatch) -> None: + FakeLinkerHandApi.instances = [] + monkeypatch.setitem( + sys.modules, + "LinkerHand.linker_hand_api", + SimpleNamespace(LinkerHandApi=FakeLinkerHandApi), ) - sender = L6PoseSender(cfg) - sender.start() + cfg = parse_linkerhand_l6_config(_cfg()) + device = LinkerHandL6Device(cfg) - sender.close() + device.connect() + device.send_pose("left", cfg.close_pose) + device.close() - assert [hand.close_can_calls for hand in FakeLinkerHandApi.instances] == [0, 0] + assert [hand.can for hand in FakeLinkerHandApi.instances] == ["can0", "can1"] + assert FakeLinkerHandApi.instances[0].speed == [50, 50, 50, 50, 50, 50] + assert FakeLinkerHandApi.instances[0].poses[-2] == list(cfg.close_pose) assert [hand.hand.close_calls for hand in FakeLinkerHandApi.instances] == [1, 1] -def 
test_pose_sender_cleans_up_partial_start_failure(monkeypatch) -> None: - created_hands = [] - - class FailingLinkerHandApi: - def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: - del hand_joint, modbus, can - if hand_type == "right": - raise RuntimeError("right hand failed") - self.hand = FakeInnerHand() - self.close_can_calls = 0 - created_hands.append(self) - - def set_speed(self, speed: list[int]) -> None: - self.speed = speed +def test_hand_runtime_closes_device_when_mapper_start_fails() -> None: + calls: list[str] = [] - def close_can(self) -> None: - self.close_can_calls += 1 + class FakeDevice: + def connect(self) -> None: + calls.append("connect") - fake_module = SimpleNamespace(LinkerHandApi=FailingLinkerHandApi) - monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) + def send_pose(self, *args, **kwargs) -> None: + raise AssertionError("send_pose should not be called") - cfg = parse_linkerhand_config( - { - "dexterous_hand": { - "enabled": True, - "hand_type": "both", - } - } - ) - sender = L6PoseSender(cfg) + def open_all(self, *args, **kwargs) -> None: + calls.append("open_all") - with pytest.raises(RuntimeError, match="right hand failed"): - sender.start() + def close(self) -> None: + calls.append("close") - assert sender.started is False - assert sender._hands == {} - assert len(created_hands) == 1 - assert created_hands[0].close_can_calls == 0 - assert created_hands[0].hand.close_calls == 1 + class FailingMapper: + def start(self) -> None: + calls.append("mapper_start") + raise RuntimeError("mapper failed") + def map(self, *args, **kwargs): + return () -def _install_fake_somehand(monkeypatch, *, left_qpos: list[float], right_qpos: list[float]) -> None: - class FakeHandFrame: - def __init__(self, *, landmarks_3d, landmarks_2d, hand_side): - self.landmarks_3d = landmarks_3d - self.landmarks_2d = landmarks_2d - self.hand_side = hand_side - - class FakeBiHandFrame: - def __init__(self, *, 
left=None, right=None): - self.left = left - self.right = right - - class FakeHandModel: - def get_joint_name_to_qpos_index(self): - return { - "thumb_cmc_pitch": 0, - "thumb_cmc_roll": 1, - "index_mcp_pitch": 2, - "middle_mcp_pitch": 3, - "ring_mcp_pitch": 4, - "pinky_mcp_pitch": 5, - } - - class FakeLandmarkFilter: - def __init__(self): - self.alpha = 0.65 - - class FakeRetargeter: - def __init__(self): - self._max_iterations = 60 - self._output_alpha = 0.92 - self.landmark_filter = FakeLandmarkFilter() - - class FakeEngine: - def __init__(self): - self.left_engine = SimpleNamespace( - config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_left", mjcf_path="left.xml")), - hand_model=FakeHandModel(), - retargeter=FakeRetargeter(), - ) - self.right_engine = SimpleNamespace( - config=SimpleNamespace(hand=SimpleNamespace(name="linkerhand_l6_right", mjcf_path="right.xml")), - hand_model=FakeHandModel(), - retargeter=FakeRetargeter(), - ) - - @classmethod - def from_config_path(cls, _path: str): - return cls() - - def process(self, frame): - return SimpleNamespace( - left_detected=frame.left is not None, - right_detected=frame.right is not None, - left=SimpleNamespace(qpos=np.asarray(left_qpos, dtype=np.float64)), - right=SimpleNamespace(qpos=np.asarray(right_qpos, dtype=np.float64)), - ) - - fake_api = SimpleNamespace( - BiHandFrame=FakeBiHandFrame, - BiHandRetargetingEngine=FakeEngine, - HandFrame=FakeHandFrame, - ) - fake_pico = SimpleNamespace(pico_hand_to_landmarks=lambda joints: np.asarray(joints, dtype=np.float64)[:21, :3]) - monkeypatch.setitem(sys.modules, "somehand.api", fake_api) - monkeypatch.setitem(sys.modules, "somehand.pico_input", fake_pico) + def close(self) -> None: + calls.append("mapper_close") + runtime = HandRuntime(FakeDevice(), FailingMapper()) -def test_vr_hand_pose_runtime_holds_last_pose_when_hand_pose_disappears(monkeypatch, tmp_path) -> None: - _install_fake_somehand( - monkeypatch, - left_qpos=[0.837758, 0.0, 1.134464, 1.134464, 
1.134464, 1.134464], - right_qpos=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ) - config_path = tmp_path / "linkerhand_l6_bihand.yaml" - config_path.write_text("left: {}\nright: {}\n", encoding="utf-8") - provider = HandSnapshotProvider() - cfg = parse_linkerhand_config( - { - "input": {"provider": "pico4"}, - "dexterous_hand": { - "mode": "vr_hand_pose", - "hand_type": "both", - "somehand": {"config_path": str(config_path), "sdk_root": "third_party/linkerhand-python-sdk"}, - }, - } - ) - runtime = SomeHandPoseRuntime(cfg, provider) - runtime.start() + with pytest.raises(RuntimeError, match="mapper failed"): + runtime.start() - provider.snapshot = _hand_snapshot( - left=_hand_state(active=True, value=1.0), - right=_hand_state(active=True, value=2.0), - timestamp_s=10.0, - ) - runtime.tick(active=True, now_s=10.0) - assert runtime._sender.wait_idle(timeout_s=1.0) + assert calls == ["connect", "mapper_start", "close"] - assert runtime._sender._last_pose["left"] == [0, 238, 0, 0, 0, 0] - assert runtime._sender._last_pose["right"] == [255, 238, 255, 255, 255, 255] - provider.snapshot = _hand_snapshot( - left=_hand_state(active=False, value=9.0), - right=_hand_state(active=False, value=9.0, present=False), - timestamp_s=10.1, - seq=2, - ) - runtime.tick(active=True, now_s=10.1) - assert runtime._sender.wait_idle(timeout_s=1.0) - - assert runtime._sender._last_pose["left"] == [0, 238, 0, 0, 0, 0] - assert runtime._sender._last_pose["right"] == [255, 238, 255, 255, 255, 255] - - runtime.tick(active=False, now_s=10.2) - assert runtime._sender.wait_idle(timeout_s=1.0) - assert runtime._sender._last_pose["left"] == list(runtime.config.open_pose) - assert runtime._sender._last_pose["right"] == list(runtime.config.open_pose) - runtime.close() - - -def test_vr_hand_pose_runtime_applies_low_latency_overrides(monkeypatch, tmp_path) -> None: - _install_fake_somehand( - monkeypatch, - left_qpos=[0.837758, -0.087266, 1.134464, 1.134464, 1.134464, 1.134464], - right_qpos=[0.0, 0.0, 0.0, 0.0, 
0.0, 0.0], - ) - config_path = tmp_path / "linkerhand_l6_bihand.yaml" - config_path.write_text("left: {}\nright: {}\n", encoding="utf-8") - provider = HandSnapshotProvider() - cfg = parse_linkerhand_config( - { - "input": {"provider": "pico4"}, - "dexterous_hand": { - "mode": "vr_hand_pose", - "hand_type": "both", - "somehand": { - "config_path": str(config_path), - "sdk_root": "third_party/linkerhand-python-sdk", - "rate": 60.0, - "max_iterations": 12, - "temporal_filter_alpha": 1.0, - "output_alpha": 1.0, - }, - }, - } - ) - runtime = SomeHandPoseRuntime(cfg, provider) - runtime.start() - - assert runtime._interval_s == pytest.approx(1.0 / 60.0) - assert runtime._engine.left_engine.retargeter._max_iterations == 12 - assert runtime._engine.left_engine.retargeter._output_alpha == 1.0 - assert runtime._engine.left_engine.retargeter.landmark_filter.alpha == 1.0 - assert runtime._engine.right_engine.retargeter._max_iterations == 12 - runtime.close() - - -def test_build_linkerhand_runtime_returns_threaded_vr_hand_pose_runtime() -> None: - provider = HandSnapshotProvider() - runtime = build_linkerhand_runtime( - { - "input": {"provider": "pico4"}, - "dexterous_hand": { - "mode": "vr_hand_pose", - "hand_type": "both", - "somehand": {"threaded": True}, - }, - }, - provider, - ) - - assert isinstance(runtime, ThreadedSomeHandPoseRuntime) - - -def test_l6_retarget_pose_mapper_uses_sdk_order_and_model_joint_names() -> None: - class FakeHandModel: - def get_joint_name_to_qpos_index(self): - return { - "thumb_pitch": 2, - "thumb_roll": 0, - "index_pitch": 5, - "middle_pitch": 1, - "ring_pitch": 4, - "little_pitch": 3, - } - - qpos = np.zeros(6, dtype=np.float64) - qpos[2] = 0.837758 - qpos[0] = -0.087266 - qpos[5] = 1.134464 - qpos[1] = 0.0 - qpos[4] = 1.134464 - qpos[3] = 0.0 - - mapper = L6RetargetPoseMapper( - FakeHandModel(), - hand_type="right", - sdk_root="third_party/linkerhand-python-sdk", - ) - - assert mapper.qpos_to_pose(qpos) == [0, 255, 0, 255, 0, 255] - - -def 
test_l6_retarget_pose_mapper_supports_somehand_l6_prefixed_roll_joint_names() -> None: - class FakeHandModel: - def get_joint_name_to_qpos_index(self): - return { - "lh_thumb_cmc_pitch": 8, - "lh_thumb_cmc_roll": 9, - "lh_thumb_dip": 10, - "lh_index_mcp_pitch": 1, - "lh_index_dip": 0, - "lh_middle_mcp_pitch": 3, - "lh_middle_dip": 2, - "lh_ring_mcp_pitch": 5, - "lh_ring_dip": 4, - "lh_pinky_mcp_pitch": 7, - "lh_pinky_dip": 6, - } - - qpos = np.zeros(11, dtype=np.float64) - qpos[8] = 0.837758 - qpos[9] = -0.087266 - qpos[1] = 1.134464 - qpos[3] = 0.0 - qpos[5] = 1.134464 - qpos[7] = 0.0 - - mapper = L6RetargetPoseMapper( - FakeHandModel(), - hand_type="left", - sdk_root="third_party/linkerhand-python-sdk", - ) - - assert mapper.qpos_to_pose(qpos) == [0, 255, 0, 255, 0, 255] - - -def test_l6_retarget_pose_mapper_fails_when_model_joint_mapping_is_unknown() -> None: - class FakeHandModel: - def get_joint_name_to_qpos_index(self): - return { - "thumb_pitch": 0, - "thumb_roll": 1, - "index_pitch": 2, - "middle_pitch": 3, - "ring_pitch": 4, - } - - with pytest.raises(ValueError, match="pinky_mcp_pitch"): - L6RetargetPoseMapper( - FakeHandModel(), - hand_type="right", - sdk_root="third_party/linkerhand-python-sdk", - ) - - -def test_pose_sender_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: +def test_linkerhand_l6_device_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: created_hands = [] class ExitingLinkerHandApi: @@ -600,33 +204,24 @@ def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> if hand_type == "right": raise SystemExit(1) self.hand = FakeInnerHand() - self.close_can_calls = 0 created_hands.append(self) def set_speed(self, speed: list[int]) -> None: - self.speed = speed - - def close_can(self) -> None: - self.close_can_calls += 1 + self.speed = list(speed) - fake_module = SimpleNamespace(LinkerHandApi=ExitingLinkerHandApi) - monkeypatch.setitem(sys.modules, "LinkerHand.linker_hand_api", fake_module) + def 
finger_move(self, pose: list[int]) -> None: + self.pose = list(pose) - cfg = parse_linkerhand_config( - { - "dexterous_hand": { - "enabled": True, - "hand_type": "both", - } - } + monkeypatch.setitem( + sys.modules, + "LinkerHand.linker_hand_api", + SimpleNamespace(LinkerHandApi=ExitingLinkerHandApi), ) - sender = L6PoseSender(cfg) + cfg = parse_linkerhand_l6_config(_cfg()) + device = LinkerHandL6Device(cfg) with pytest.raises(RuntimeError, match="LinkerHand SDK exited during startup"): - sender.start() + device.connect() - assert sender.started is False - assert sender._hands == {} assert len(created_hands) == 1 - assert created_hands[0].close_can_calls == 0 assert created_hands[0].hand.close_calls == 1 diff --git a/tests/test_sim2real_dim.py b/tests/test_sim2real_dim.py index 51b8fa15..dea8a8dd 100644 --- a/tests/test_sim2real_dim.py +++ b/tests/test_sim2real_dim.py @@ -1,134 +1,54 @@ -"""Integration tests for Sim2RealController startup validation.""" from __future__ import annotations -from pathlib import Path -from unittest.mock import MagicMock +from types import SimpleNamespace import pytest -from conftest import find_g1_xml_path, requires_mujoco +from teleopit.sim2real.mp.runtime import _RobotControlWorker -_XML_PATH = find_g1_xml_path() -_skip_no_xml = pytest.mark.skipif(_XML_PATH is None, reason="Robot XML not found") - -_DEFAULT_ANGLES_29 = [ - -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, - -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, - 0.0, 0.0, 0.0, - 0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, - 0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, -] - - -def _make_dummy_policy(expected_obs_dim: int, *, multi_input: bool = True) -> MagicMock: - policy = MagicMock() - policy._expected_obs_dim = expected_obs_dim - policy._multi_input = multi_input - return policy - - -def _make_sim2real_cfg(tmp_path: Path) -> dict: - policy_path = tmp_path / "policy.onnx" - policy_path.write_bytes(b"dummy") - bvh_path = tmp_path / "clip.bvh" - bvh_path.write_text("HIERARCHY\n", encoding="utf-8") - +def _cfg() 
-> dict[str, object]: return { "policy_hz": 50.0, - "real_robot": {"network_interface": "lo"}, - "gamepad": {}, - "mocap_switch": {}, - "input": { - "provider": "bvh", - "bvh_file": str(bvh_path), - "bvh_format": "hc_mocap", - "robot_name": "unitree_g1", - }, - "controller": { - "policy_path": str(policy_path), - "action_scale": [0.5] * 29, - "default_dof_pos": list(_DEFAULT_ANGLES_29), - }, + "input": {"provider": "bvh"}, + "runtime": {}, + "real_robot": {}, + "controller": {}, "robot": { "num_actions": 29, - "xml_path": _XML_PATH or "", - "default_angles": _DEFAULT_ANGLES_29, - "action_scale": [0.5] * 29, - "anchor_body_name": "torso_link", + "default_angles": [0.0] * 29, + "xml_path": "robot.xml", }, } -def _apply_sim2real_mocks(monkeypatch, policy_mock: MagicMock) -> None: - dummy_provider = MagicMock() - dummy_provider.human_format = "hc_mocap" - dummy_provider.fps = 30 - dummy_provider.__len__.return_value = 1 - dummy_provider.get_frame_by_index.return_value = {} - obs_builder = MagicMock(total_obs_size=166) - - def _build_components(*args, **kwargs): - if policy_mock._expected_obs_dim != 166: - raise ValueError( - f"Only 166D velcmd_history ONNX policies are supported here; " - f"obs_builder produces 166D but policy expects {policy_mock._expected_obs_dim}D." - ) - if not policy_mock._multi_input: - raise ValueError( - "Sim2real requires an ONNX policy with dual inputs ('obs' and 'obs_history')." 
- ) - return MagicMock( - input_provider=dummy_provider, - retargeter=MagicMock(), - controller=policy_mock, - obs_builder=obs_builder, - ) - +def test_robot_worker_requires_dual_input_policy(monkeypatch) -> None: + policy = SimpleNamespace(_multi_input=False) + obs_builder = SimpleNamespace(total_obs_size=166) monkeypatch.setattr( - "teleopit.sim2real.controller.UnitreeG1Robot", - lambda cfg: MagicMock(check_mode=MagicMock(return_value={"name": "mock"})), - ) - monkeypatch.setattr( - "teleopit.sim2real.controller.UnitreeRemote", - MagicMock, - ) - - monkeypatch.setattr( - "teleopit.sim2real.controller.build_sim2real_mocap_components", - _build_components, + "teleopit.sim2real.mp.runtime._build_policy_components", + lambda **_kwargs: (policy, obs_builder), ) + worker = object.__new__(_RobotControlWorker) + worker.cfg = _cfg() -@requires_mujoco -@_skip_no_xml -class TestSim2RealStartupDim: - def test_velcmd_history_startup_builds_166d_obs(self, monkeypatch, tmp_path: Path) -> None: - from teleopit.sim2real.controller import Sim2RealController + with pytest.raises(ValueError, match="dual inputs"): + worker._build_policy_and_obs() - policy_mock = _make_dummy_policy(166, multi_input=True) - _apply_sim2real_mocks(monkeypatch, policy_mock) - cfg = _make_sim2real_cfg(tmp_path) - ctrl = Sim2RealController(cfg) - assert ctrl.obs_builder.total_obs_size == 166 - - def test_non_166_policy_is_rejected(self, monkeypatch, tmp_path: Path) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy_mock = _make_dummy_policy(160, multi_input=True) - _apply_sim2real_mocks(monkeypatch, policy_mock) - cfg = _make_sim2real_cfg(tmp_path) - - with pytest.raises(ValueError, match="Only 166D"): - Sim2RealController(cfg) +def test_robot_worker_accepts_166d_dual_input_policy(monkeypatch) -> None: + policy = SimpleNamespace(_multi_input=True) + obs_builder = SimpleNamespace(total_obs_size=166) + monkeypatch.setattr( + 
"teleopit.sim2real.mp.runtime._build_policy_components", + lambda **_kwargs: (policy, obs_builder), + ) - def test_single_input_policy_is_rejected(self, monkeypatch, tmp_path: Path) -> None: - from teleopit.sim2real.controller import Sim2RealController + worker = object.__new__(_RobotControlWorker) + worker.cfg = _cfg() - policy_mock = _make_dummy_policy(166, multi_input=False) - _apply_sim2real_mocks(monkeypatch, policy_mock) - cfg = _make_sim2real_cfg(tmp_path) + built_policy, built_obs_builder = worker._build_policy_and_obs() - with pytest.raises(ValueError, match="dual inputs"): - Sim2RealController(cfg) + assert built_policy is policy + assert built_obs_builder is obs_builder diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 428dc886..db653532 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -9,26 +9,31 @@ from teleopit.runtime.mocap_session import MocapSessionState from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher -from teleopit.sim2real.mp import resolve_sim2real_runtime_mode from teleopit.sim2real.mp.messages import ReferencePacket, SharedFrameDescriptor -from teleopit.sim2real.mp.runtime import MultiprocessSim2RealController, _RobotControlWorker, _human_frame_is_valid +from teleopit.sim2real.mp.runtime import RobotMode, Sim2RealRuntime, _RobotControlWorker, _human_frame_is_valid from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter -def test_resolve_runtime_auto_uses_multiprocess_for_pico4() -> None: +def test_sim2real_runtime_rejects_legacy_runtime_keys() -> None: cfg = {"sim2real_runtime": "auto", "input": {"provider": "pico4"}} - assert resolve_sim2real_runtime_mode(cfg) == "multiprocess" + with pytest.raises(ValueError, match="Legacy sim2real config keys"): + Sim2RealRuntime(cfg) -def test_resolve_runtime_auto_uses_single_process_for_bvh() -> None: - cfg = {"sim2real_runtime": "auto", "input": {"provider": 
"bvh"}} - assert resolve_sim2real_runtime_mode(cfg) == "single_process" +def test_sim2real_runtime_accepts_bvh_provider() -> None: + cfg = {"input": {"provider": "bvh"}, "runtime": {"shutdown_timeout_s": 0.01}} + runtime = Sim2RealRuntime(cfg) + runtime.shutdown() -def test_multiprocess_requires_pico4_provider() -> None: - cfg = {"sim2real_runtime": "multiprocess", "input": {"provider": "bvh"}} - with pytest.raises(ValueError, match="requires input.provider=pico4"): - resolve_sim2real_runtime_mode(cfg) +def test_sim2real_runtime_rejects_hands_without_pico_provider() -> None: + cfg = { + "input": {"provider": "bvh"}, + "runtime": {"shutdown_timeout_s": 0.01}, + "hands": {"enabled": True, "driver": "linkerhand_l6", "mode": "gripper"}, + } + with pytest.raises(ValueError, match="hands.enabled=true requires input.provider=pico4"): + Sim2RealRuntime(cfg) def test_shared_frame_ring_roundtrip() -> None: @@ -51,11 +56,11 @@ def test_shared_frame_ring_roundtrip() -> None: def test_multiprocess_rejects_unsupported_video_source() -> None: cfg = { - "sim2real_runtime": "multiprocess", "input": {"provider": "pico4", "video": {"enabled": True, "source": "mujoco"}}, + "runtime": {}, } with pytest.raises(ValueError, match="only supports input.video.source=realsense or test-pattern"): - MultiprocessSim2RealController(cfg) + Sim2RealRuntime(cfg) def test_zmq_endpoint_allows_one_publisher_and_subscribers() -> None: @@ -74,7 +79,7 @@ def test_zmq_endpoint_allows_one_publisher_and_subscribers() -> None: context.term() -def test_run_sim2real_single_process_shutdowns_on_exception(monkeypatch) -> None: +def test_run_sim2real_shutdowns_on_exception(monkeypatch) -> None: script_path = Path.cwd() / "scripts" / "run" / "run_sim2real.py" spec = importlib.util.spec_from_file_location("test_run_sim2real", script_path) assert spec is not None and spec.loader is not None @@ -83,7 +88,7 @@ def test_run_sim2real_single_process_shutdowns_on_exception(monkeypatch) -> None calls: list[str] = [] - class 
FailingController: + class FailingRuntime: def __init__(self, _cfg: object) -> None: calls.append("init") @@ -99,8 +104,7 @@ def shutdown(self) -> None: controller=SimpleNamespace(policy_path="policy.onnx"), ) monkeypatch.setattr(run_sim2real, "validate_policy_path", lambda *_args, **_kwargs: None) - monkeypatch.setattr(run_sim2real, "resolve_sim2real_runtime_mode", lambda _cfg: "single_process") - monkeypatch.setattr(run_sim2real, "Sim2RealController", FailingController) + monkeypatch.setattr(run_sim2real, "Sim2RealRuntime", FailingRuntime) with pytest.raises(RuntimeError, match="boom"): run_sim2real._run_sim2real(cfg) @@ -132,19 +136,18 @@ def join(self, timeout: float | None = None) -> None: def terminate(self) -> None: self.terminated = True - started_process = FakeProcess(name="pico_io") + started_process = FakeProcess(name="pico_input") - def fake_start_processes(self: MultiprocessSim2RealController) -> None: + def fake_start_processes(self: Sim2RealRuntime) -> None: started_process.started = True self._processes.append(started_process) raise RuntimeError("start failed") cfg = { - "sim2real_runtime": "multiprocess", "input": {"provider": "pico4"}, - "multiprocess": {"shutdown_timeout_s": 0.01}, + "runtime": {"shutdown_timeout_s": 0.01}, } - controller = MultiprocessSim2RealController(cfg) + controller = Sim2RealRuntime(cfg) monkeypatch.setattr(controller, "_start_processes", fake_start_processes.__get__(controller)) with pytest.raises(RuntimeError, match="start failed"): @@ -203,6 +206,7 @@ def test_robot_worker_requires_consecutive_valid_references(monkeypatch) -> None worker._consecutive_valid_references = 0 worker._check_frames = 2 worker._max_reference_age_s = 0.25 + worker.provider_kind = "pico4" worker._reference_age_s = lambda: 0.0 worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) worker._last_commanded_motion_qpos = np.zeros(36, dtype=np.float64) @@ -255,3 +259,51 @@ def 
test_robot_worker_requires_consecutive_valid_references(monkeypatch) -> None worker._note_reference_packet(fresh_packet) assert worker._latest_reference == fresh_packet assert worker._consecutive_valid_references == 1 + + +def test_robot_worker_replays_bvh_on_mocap_entry() -> None: + worker = object.__new__(_RobotControlWorker) + commands: list[str] = [] + worker.provider_kind = "bvh" + worker.robot = SimpleNamespace(get_state=lambda: SimpleNamespace()) + worker._standing_qpos = np.zeros(36, dtype=np.float64) + worker._mocap_reentry_armed = True + worker._reset_policy_state = lambda: None + worker._build_resume_alignment_qpos = lambda _hold, _state: np.ones(36, dtype=np.float64) + worker._ref_proc = SimpleNamespace(reset_alignment=lambda **_kwargs: None) + worker._send_reference_command = commands.append + + worker._transition_to_mocap() + + assert worker.mode == RobotMode.MOCAP + assert commands == ["replay_mocap"] + + +def test_robot_worker_pauses_when_bvh_reference_reports_paused() -> None: + worker = object.__new__(_RobotControlWorker) + worker.provider_kind = "bvh" + worker.mode = RobotMode.MOCAP + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._last_reference_seq = -1 + worker._consecutive_valid_references = 0 + paused: list[str] = [] + + def pause_active_mocap() -> None: + paused.append("pause") + worker._mocap_session.state = MocapSessionState.PAUSED + + worker._pause_active_mocap = pause_active_mocap + packet = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.0, + seq=1, + source_timestamp_s=0.0, + source_seq=0, + playback_paused=True, + playback_finished=True, + ) + + worker._note_reference_packet(packet) + + assert paused == ["pause"] + assert worker._latest_reference is packet diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py deleted file mode 100644 index 1320dc33..00000000 --- a/tests/test_sim2real_runtime.py +++ /dev/null @@ -1,848 +0,0 @@ -from __future__ import 
annotations - -from types import SimpleNamespace - -import numpy as np -import pytest - -from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType, RealtimeInputPacket - - -class DummyRobot: - def __init__(self, _cfg: object) -> None: - self._state = SimpleNamespace( - quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - qpos=np.zeros(29, dtype=np.float32), - qvel=np.zeros(29, dtype=np.float32), - ang_vel=np.zeros(3, dtype=np.float32), - ) - self.sent_positions: list[np.ndarray] = [] - self.sent_gains: list[tuple[np.ndarray | None, np.ndarray | None]] = [] - self.lock_calls = 0 - - def enter_debug_mode(self) -> bool: - return True - - def lock_all_joints(self) -> None: - self.lock_calls += 1 - - def get_state(self) -> SimpleNamespace: - return self._state - - def send_positions(self, target_dof_pos: np.ndarray, kp: np.ndarray | None = None, kd: np.ndarray | None = None) -> None: - self.sent_positions.append(np.asarray(target_dof_pos, dtype=np.float32)) - self.sent_gains.append(( - None if kp is None else np.asarray(kp, dtype=np.float32), - None if kd is None else np.asarray(kd, dtype=np.float32), - )) - - def set_damping(self) -> None: - pass - - def exit_debug_mode(self) -> None: - pass - - -class DummyRemote: - def __init__(self) -> None: - self.LB = SimpleNamespace(pressed=False, on_pressed=False) - self.RB = SimpleNamespace(pressed=False, on_pressed=False) - self.start = SimpleNamespace(pressed=False, on_pressed=False) - self.A = SimpleNamespace(pressed=False, on_pressed=False) - self.B = SimpleNamespace(pressed=False, on_pressed=False) - self.Y = SimpleNamespace(pressed=False, on_pressed=False) - self.X = SimpleNamespace(pressed=False, on_pressed=False) - - -def _set_button(button: SimpleNamespace, *, pressed: bool, on_pressed: bool) -> None: - button.pressed = pressed - button.on_pressed = on_pressed - - -class DummyProvider: - def __init__(self) -> None: - self._xrt = SimpleNamespace(is_body_data_available=lambda: True) - self._frame = 
{"Pelvis": (np.zeros(3, dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64))} - self.fps = 30 - self._frame_seq = 0 - self._frame_timestamp = 1.0 - self._control_events: tuple[ControlEvent, ...] = () - - def is_available(self) -> bool: - return True - - def get_frame(self) -> dict[str, tuple[np.ndarray, np.ndarray]]: - return self._frame - - def get_frame_packet(self) -> tuple[dict[str, tuple[np.ndarray, np.ndarray]], float, int]: - return self._frame, self._frame_timestamp, self._frame_seq - - def get_realtime_input_packet(self) -> RealtimeInputPacket[dict[str, tuple[np.ndarray, np.ndarray]]]: - control_events = tuple(self._control_events) - self._control_events = () - return RealtimeInputPacket( - frame=self._frame, - timestamp_s=self._frame_timestamp, - seq=self._frame_seq, - control_events=control_events, - ) - - -class DummyRetargeter: - def __init__(self, qpos: np.ndarray) -> None: - self._qpos = np.asarray(qpos, dtype=np.float64) - self.reset_calls = 0 - - def retarget(self, _frame: object) -> np.ndarray: - return self._qpos.copy() - - def reset(self) -> None: - self.reset_calls += 1 - - -class DummyPolicy: - def __init__(self, expected_obs_dim: int = 166) -> None: - self._expected_obs_dim = expected_obs_dim - self._multi_input = True - self.reset_calls = 0 - - def reset(self) -> None: - self.reset_calls += 1 - - def compute_action(self, _obs: np.ndarray) -> np.ndarray: - return np.zeros(29, dtype=np.float32) - - def get_target_dof_pos(self, _action: np.ndarray) -> np.ndarray: - return np.zeros(29, dtype=np.float32) - - -class DummyVelCmdObservationBuilder: - def __init__(self) -> None: - self.total_obs_size = 166 - self.reset_calls = 0 - self.build_calls: list[dict[str, np.ndarray]] = [] - - def reset(self) -> None: - self.reset_calls += 1 - - def build( - self, - _robot_state: object, - motion_qpos: np.ndarray, - motion_joint_vel: np.ndarray, - _last_action: np.ndarray, - motion_anchor_lin_vel_w: np.ndarray, - motion_anchor_ang_vel_w: 
np.ndarray, - ) -> np.ndarray: - self.build_calls.append( - { - "motion_qpos": np.asarray(motion_qpos, dtype=np.float32).copy(), - "motion_joint_vel": np.asarray(motion_joint_vel, dtype=np.float32).copy(), - "motion_anchor_lin_vel_w": np.asarray(motion_anchor_lin_vel_w, dtype=np.float32).copy(), - "motion_anchor_ang_vel_w": np.asarray(motion_anchor_ang_vel_w, dtype=np.float32).copy(), - } - ) - return np.zeros(self.total_obs_size, dtype=np.float32) - - -class DummyHandRuntime: - def __init__(self) -> None: - self.active_flags: list[bool] = [] - self.close_calls = 0 - - def start(self) -> None: - pass - - def tick(self, *, active: bool) -> None: - self.active_flags.append(active) - - def close(self) -> None: - self.close_calls += 1 - - -class FailingHandRuntime(DummyHandRuntime): - def tick(self, *, active: bool) -> None: - super().tick(active=active) - raise RuntimeError("hand send failed") - - -def _make_cfg() -> dict[str, object]: - return { - "policy_hz": 50.0, - "real_robot": { - "kp_real": [100.0] * 29, - "kd_real": [2.0] * 29, - }, - "standing_return_ramp_duration": 0.5, - "standing_return_kp_ramp_floor_ratio": 0.5, - "mocap_switch": {"check_frames": 1}, - "robot": { - "default_angles": [0.0] * 29, - "num_actions": 29, - "xml_path": "robot.xml", - }, - "controller": {}, - "input": {"provider": "pico4"}, - } - - -def _install_controller_mocks(monkeypatch, *, policy: DummyPolicy, obs_builder: DummyVelCmdObservationBuilder, qpos: np.ndarray) -> None: - import teleopit.sim2real.controller as controller_mod - - monkeypatch.setattr(controller_mod, "UnitreeG1Robot", DummyRobot) - monkeypatch.setattr(controller_mod, "UnitreeRemote", DummyRemote) - monkeypatch.setattr(controller_mod, "VelCmdObservationBuilder", DummyVelCmdObservationBuilder) - monkeypatch.setattr(controller_mod.time, "sleep", lambda _seconds: None) - monkeypatch.setattr( - controller_mod, - "build_sim2real_mocap_components", - lambda *args, **kwargs: SimpleNamespace( - input_provider=DummyProvider(), 
- retargeter=DummyRetargeter(qpos), - controller=policy, - obs_builder=obs_builder, - ), - ) - - -def test_mode_transitions_reset_stateful_policy(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - ctrl = Sim2RealController(_make_cfg()) - ctrl._enter_standing() - ctrl._transition_to_mocap() - - # Both transitions now do soft episode-reset and preserve retargeter warm-start. - assert policy.reset_calls == 2 - assert obs_builder.reset_calls == 2 - assert ctrl.retargeter.reset_calls == 0 - - -def test_reset_policy_state_clears_reference_timeline(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - ctrl = Sim2RealController(_make_cfg()) - assert ctrl._reference_timeline is not None - ctrl._reference_timeline.append(np.zeros(36, dtype=np.float64), 1.0) - ctrl._last_live_packet_seq = 7 - - ctrl._reset_policy_state() - - assert len(ctrl._reference_timeline) == 0 - assert ctrl._last_live_packet_seq == -1 - - -def test_sim2real_retarget_viewer_defaults_off(monkeypatch) -> None: - import teleopit.sim2real.controller as controller_mod - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - starts: list[tuple[object, ...]] = [] - monkeypatch.setattr(controller_mod, "start_robot_viewer", lambda *args, **kwargs: starts.append(args)) - - Sim2RealController(_make_cfg()) - - assert starts == [] - - -def 
test_sim2real_retarget_viewer_writes_reference_qpos(monkeypatch) -> None: - import multiprocessing as mp - - import teleopit.sim2real.controller as controller_mod - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.25 - target_qpos[3] = 1.0 - target_qpos[7] = 0.5 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - arr = mp.Array("d", 36) - alive = mp.Value("i", 1) - shutdown = mp.Event() - proc = SimpleNamespace(join=lambda timeout=None: None, is_alive=lambda: False, terminate=lambda: None) - starts: list[tuple[object, ...]] = [] - - def fake_start_robot_viewer(*args: object, **_kwargs: object) -> tuple[object, object, object, object]: - starts.append(args) - return proc, arr, alive, shutdown - - monkeypatch.setattr(controller_mod, "start_robot_viewer", fake_start_robot_viewer) - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - cfg["viewers"] = "retarget" - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl._mocap_step() - - assert starts - with arr.get_lock(): - written = np.asarray(arr[:], dtype=np.float64) - np.testing.assert_allclose(written[[0, 3, 7]], target_qpos[[0, 3, 7]], atol=1e-6) - - -def test_sim2real_retarget_viewer_rejects_sim_viewers(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - cfg = _make_cfg() - cfg["viewers"] = ["retarget", "sim2sim"] - - with pytest.raises(ValueError, match="supports only the optional 'retarget' viewer"): - Sim2RealController(cfg) - - 
-def test_loop_timing_reporter_logs_percentiles_and_overruns(caplog) -> None: - import logging - - from teleopit.sim2real.controller import _LoopTimingReporter - - reporter = _LoopTimingReporter(target_period_s=0.02, log_interval_s=1.0) - with caplog.at_level(logging.INFO, logger="teleopit.sim2real.controller"): - reporter.record(loop_start_s=0.0, work_elapsed_s=0.005, cycle_elapsed_s=0.020, pico_age_s=0.010) - reporter.record(loop_start_s=0.5, work_elapsed_s=0.006, cycle_elapsed_s=0.021, pico_age_s=0.012) - reporter.record(loop_start_s=1.0, work_elapsed_s=0.007, cycle_elapsed_s=0.050, pico_age_s=0.030) - - text = caplog.text - assert "Timing stats" in text - assert "loop_ms p50=" in text - assert "p95=" in text - assert "p99=" in text - assert "max=" in text - assert "overrun=2/3" in text - assert "work_ms p50=" in text - assert "pico_age_ms p50=" in text - - -def test_sim2real_rejects_nonzero_reference_steps_without_buffer(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - cfg["reference_steps"] = [0, 1] - - with pytest.raises(ValueError, match="retarget_buffer_enabled=true"): - Sim2RealController(cfg) - - -def test_sim2real_rejects_reference_horizon_with_insufficient_delay(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - cfg = _make_cfg() - cfg["reference_steps"] = [0, 2] - cfg["retarget_buffer_delay_s"] = 0.01 - - with pytest.raises(ValueError, match="retarget_buffer_delay_s"): - Sim2RealController(cfg) - - -def 
test_state_machine_allows_mocap_reentry_after_returning_to_standing(monkeypatch) -> None: - from teleopit.sim2real.controller import RobotMode, Sim2RealController - - policy = DummyPolicy(expected_obs_dim=154) - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks( - monkeypatch, - policy=policy, - obs_builder=obs_builder, - qpos=np.zeros(36, dtype=np.float64), - ) - - ctrl = Sim2RealController(_make_cfg()) - - _set_button(ctrl.remote.start, pressed=True, on_pressed=True) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.STANDING - - _set_button(ctrl.remote.start, pressed=False, on_pressed=False) - _set_button(ctrl.remote.Y, pressed=True, on_pressed=True) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.MOCAP - - _set_button(ctrl.remote.Y, pressed=False, on_pressed=False) - _set_button(ctrl.remote.X, pressed=True, on_pressed=True) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.STANDING - - _set_button(ctrl.remote.X, pressed=False, on_pressed=False) - _set_button(ctrl.remote.Y, pressed=True, on_pressed=False) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.MOCAP - - -def test_return_to_standing_uses_default_pose_and_stronger_ramp_without_relock(monkeypatch) -> None: - from teleopit.sim2real.controller import RobotMode, Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks( - monkeypatch, - policy=policy, - obs_builder=obs_builder, - qpos=np.zeros(36, dtype=np.float64), - ) - - cfg = _make_cfg() - cfg["robot"]["default_angles"] = [0.2] * 29 - ctrl = Sim2RealController(cfg) - ctrl.mode = RobotMode.MOCAP - ctrl.robot._state.qpos = np.ones(29, dtype=np.float32) - ctrl._last_commanded_motion_qpos = np.ones(36, dtype=np.float64) - - ctrl._enter_standing() - ctrl._standing_step() - - assert ctrl.mode == RobotMode.STANDING - assert ctrl.robot.lock_calls == 0 - np.testing.assert_allclose(ctrl._standing_qpos[7:36], np.full(29, 0.2, 
dtype=np.float64)) - kp, kd = ctrl.robot.sent_gains[-1] - assert kp is not None - assert kd is not None - np.testing.assert_allclose(kp, np.full(29, 50.0, dtype=np.float32)) - - -def test_dexterous_hand_ticks_only_during_active_mocap(monkeypatch) -> None: - import teleopit.sim2real.controller as controller_mod - from teleopit.runtime.mocap_session import MocapSessionState - from teleopit.sim2real.controller import RobotMode, Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - hand_runtime = DummyHandRuntime() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - monkeypatch.setattr(controller_mod, "build_linkerhand_runtime", lambda _cfg, _provider: hand_runtime) - - ctrl = Sim2RealController(_make_cfg()) - - ctrl.mode = RobotMode.STANDING - ctrl._tick_dexterous_hand() - - ctrl.mode = RobotMode.MOCAP - ctrl._mocap_session.reset() - assert ctrl._mocap_session.state == MocapSessionState.ACTIVE - ctrl._tick_dexterous_hand() - - ctrl._mocap_session.pause(np.zeros(36, dtype=np.float64)) - ctrl._tick_dexterous_hand() - - assert hand_runtime.active_flags == [False, True, False] - - -def test_dexterous_hand_failure_does_not_enter_damping(monkeypatch) -> None: - import teleopit.sim2real.controller as controller_mod - from teleopit.sim2real.controller import RobotMode, Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - hand_runtime = FailingHandRuntime() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - monkeypatch.setattr(controller_mod, "build_linkerhand_runtime", lambda _cfg, _provider: hand_runtime) - - ctrl = Sim2RealController(_make_cfg()) - ctrl.mode = RobotMode.MOCAP - ctrl._mocap_session.reset() - - ctrl._tick_dexterous_hand() - - assert ctrl.mode == RobotMode.MOCAP - assert hand_runtime.active_flags == [True] - - -def 
test_realtime_input_timeout_holds_mocap_instead_of_damping(monkeypatch) -> None: - from teleopit.sim2real.controller import RobotMode, Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - ctrl = Sim2RealController(_make_cfg()) - ctrl.mode = RobotMode.MOCAP - ctrl._mocap_session.reset() - hold_qpos = np.zeros(36, dtype=np.float64) - hold_qpos[3] = 1.0 - hold_qpos[7] = 0.25 - ctrl._last_commanded_motion_qpos = hold_qpos.copy() - ctrl._fetch_realtime_input_packet = lambda: (_ for _ in ()).throw(TimeoutError("stalled")) - ctrl._enter_damping = lambda: pytest.fail("input timeouts must not enter damping") - - ctrl._mocap_step() - - assert ctrl.mode == RobotMode.MOCAP - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"], hold_qpos.astype(np.float32)) - assert len(ctrl.robot.sent_positions) == 1 - - -def test_can_switch_to_mocap_returns_false_without_blocking_when_realtime_has_no_frame(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks( - monkeypatch, - policy=policy, - obs_builder=obs_builder, - qpos=np.zeros(36, dtype=np.float64), - ) - - ctrl = Sim2RealController(_make_cfg()) - get_frame_calls = 0 - - def blocking_get_frame() -> dict[str, tuple[np.ndarray, np.ndarray]]: - nonlocal get_frame_calls - get_frame_calls += 1 - raise AssertionError("get_frame should not be called before a realtime frame is available") - - ctrl.input_provider.has_frame = lambda: False - ctrl.input_provider.get_frame = blocking_get_frame - - assert ctrl._can_switch_to_mocap() is False - assert get_frame_calls == 0 - - -def test_mocap_step_episode_reset_on_transition(monkeypatch) -> None: - """After _transition_to_mocap, the first mocap step starts with zero joint velocity.""" - from 
teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.3 - target_qpos[7] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - ctrl = Sim2RealController(cfg) - ctrl._transition_to_mocap() - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.array([1.0, 2.0, 3.0], dtype=np.float32), - np.array([4.0, 5.0, 6.0], dtype=np.float32), - ), - ) - - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 1 - np.testing.assert_allclose(obs_builder.build_calls[0]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) - np.testing.assert_allclose(obs_builder.build_calls[0]["motion_qpos"][7], 1.0, atol=1e-6) - - -def test_mocap_step_velcmd_applies_fixed_initial_yaw_alignment(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - ctrl = Sim2RealController(_make_cfg()) - ctrl.robot._state.quat = np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl._mocap_step() - - np.testing.assert_allclose( - obs_builder.build_calls[0]["motion_qpos"][3:7], - np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32), - atol=1e-6, - ) - - -def test_mocap_step_velcmd_keeps_fixed_yaw_after_start(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = 
DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - ctrl = Sim2RealController(_make_cfg()) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl.robot._state.quat = np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32) - ctrl._mocap_step() - ctrl.robot._state.quat = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32) - ctrl._mocap_step() - - np.testing.assert_allclose( - obs_builder.build_calls[1]["motion_qpos"][3:7], - np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32), - atol=1e-6, - ) - - -def test_transition_to_mocap_uses_resume_style_alignment_and_zero_velocity(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.25 - target_qpos[7] = 0.75 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - ctrl = Sim2RealController(cfg) - ctrl.robot._state.base_pos = np.array([1.0, 2.0, 0.0], dtype=np.float32) - ctrl.robot._state.quat = np.array([0.9238795, 0.0, 0.0, 0.38268343], dtype=np.float32) - ctrl._transition_to_mocap() - - assert ctrl._last_retarget_qpos is None - - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - ctrl._mocap_step() - - np.testing.assert_allclose(obs_builder.build_calls[0]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) - 
np.testing.assert_allclose(obs_builder.build_calls[0]["motion_qpos"][0:2], np.array([1.0, 2.0], dtype=np.float32), atol=1e-6) - - -def test_mocap_step_waits_for_realtime_warmup_before_running_policy(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg() - cfg['realtime_buffer_warmup_steps'] = 2 - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - 'compute_anchor_velocities', - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - monkeypatch.setattr('teleopit.sim2real.controller.time.monotonic', lambda: 1.1) - - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 0 - assert ctrl.robot.sent_positions == [] - - ctrl.input_provider._frame_seq = 1 - ctrl.input_provider._frame_timestamp = 1.03 - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 1 - assert len(ctrl.robot.sent_positions) == 1 - - -def test_sim2real_allows_future_reference_steps(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - cfg = _make_cfg() - cfg["reference_steps"] = [0, 1, 2, 3, 4] - cfg["retarget_buffer_delay_s"] = 0.08 - cfg["retarget_buffer_window_s"] = 0.5 - - Sim2RealController(cfg) - - -def test_mocap_step_uses_current_reference_qpos(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3] = 1.0 - 
_install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl._mocap_step() - ctrl.retargeter._qpos[0] = 1.0 - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 2 - np.testing.assert_allclose(obs_builder.build_calls[0]["motion_qpos"][0], 0.0, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[1]["motion_qpos"][0], 1.0, atol=1e-6) - - -def test_mocap_pause_freezes_reference_and_zeroes_velocities(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - from teleopit.runtime.mocap_session import MocapSessionState - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.2 - target_qpos[3] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl._mocap_step() - ctrl.input_provider._control_events = ( - ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source="pico4:test", - timestamp_s=1.1, - ), - ) - ctrl.retargeter._qpos[0] = 1.0 - ctrl.input_provider._frame_seq = 1 - ctrl.input_provider._frame_timestamp = 1.1 - ctrl._mocap_step() - - assert ctrl._mocap_session.state == MocapSessionState.PAUSED - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"][0], 0.2, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) - 
np.testing.assert_allclose( - obs_builder.build_calls[-1]["motion_anchor_lin_vel_w"], - np.zeros(3, dtype=np.float32), - ) - np.testing.assert_allclose( - obs_builder.build_calls[-1]["motion_anchor_ang_vel_w"], - np.zeros(3, dtype=np.float32), - ) - - -def test_mocap_resume_uses_episode_reset_semantics(monkeypatch) -> None: - """Resume does an episode reset and reanchors live mocap root XY.""" - from teleopit.sim2real.controller import Sim2RealController - from teleopit.runtime.mocap_session import MocapSessionState - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.2 - target_qpos[3] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - # Run one step, then pause - ctrl._mocap_step() - ctrl.input_provider._control_events = ( - ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source="pico4:test", - timestamp_s=1.1, - ), - ) - ctrl.input_provider._frame_seq = 1 - ctrl.input_provider._frame_timestamp = 1.1 - ctrl._mocap_step() - assert ctrl._mocap_session.state == MocapSessionState.PAUSED - - # Resume: should stay on the ACTIVE/PAUSED state model. 
- ctrl.retargeter._qpos[0] = 1.0 - ctrl.retargeter._qpos[7] = 1.0 - ctrl.input_provider._control_events = ( - ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source="pico4:test", - timestamp_s=1.2, - ), - ) - ctrl.input_provider._frame_seq = 2 - ctrl.input_provider._frame_timestamp = 1.2 - ctrl._mocap_step() - - # Episode-reset resume goes straight to ACTIVE - assert ctrl._mocap_session.state == MocapSessionState.ACTIVE - # Policy was reset (last_action zeroed, history cleared) - assert np.allclose(ctrl._last_action, 0.0) - assert ctrl.retargeter.reset_calls == 0 - # Retarget reference jumps to the live mocap pose (joint 0), while root XY - # is reanchored to the paused reference because real-robot XY is unobserved. - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"][0], 0.2, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"][7], 1.0, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) diff --git a/tests/test_termination_config.py b/tests/test_termination_config.py index 3c9303b5..bf51159d 100644 --- a/tests/test_termination_config.py +++ b/tests/test_termination_config.py @@ -23,7 +23,7 @@ def test_general_tracking_termination_config_matches_baseline_policy() -> None: assert anchor_pos.func is mdp.bad_anchor_pos_z_only assert anchor_pos.params == { "command_name": "motion", - "threshold": 0.4, + "threshold": 0.25, } anchor_ori = terminations["anchor_ori"] @@ -34,7 +34,7 @@ def test_general_tracking_termination_config_matches_baseline_policy() -> None: assert ee_body_pos.func is mdp.bad_motion_body_pos_z_only assert ee_body_pos.params == { "command_name": "motion", - "threshold": 0.4, + "threshold": 0.25, "body_names": ( "left_ankle_roll_link", "right_ankle_roll_link", diff --git a/third_party/somehand b/third_party/somehand index a88bfe2f..0e9adba4 160000 --- a/third_party/somehand +++ b/third_party/somehand @@ -1 +1 @@ -Subproject 
commit a88bfe2f09eb3821b310774aa166f8506645a35c +Subproject commit 0e9adba4e193540279f8e5803a9339a49666499a From 642c99f36a3c29487f38ac4f946fe2a5e2762199 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 9 Jun 2026 16:54:59 +0800 Subject: [PATCH 065/122] Clarify sim2real loop timing stats --- teleopit/sim2real/mp/runtime.py | 36 ++++++++++++++++++++++------- tests/test_sim2real_multiprocess.py | 23 +++++++++++++++++- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index b1279e98..b4924a90 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -89,24 +89,36 @@ class RobotMode(Enum): class _LoopTimingReporter: - def __init__(self, *, target_period_s: float, log_interval_s: float = 1.0) -> None: + def __init__( + self, + *, + target_period_s: float, + log_interval_s: float = 1.0, + deadline_miss_tolerance_s: float = 0.001, + ) -> None: self._target_period_s = float(target_period_s) self._log_interval_s = float(log_interval_s) + self._deadline_miss_tolerance_s = float(deadline_miss_tolerance_s) self._window_start_s: float | None = None self._loop_ms: list[float] = [] + self._late_ms: list[float] = [] self._work_ms: list[float] = [] self._pico_age_ms: list[float] = [] - self._overrun_count = 0 + self._deadline_miss_count = 0 + self._work_overrun_count = 0 def record(self, *, loop_start_s: float, work_elapsed_s: float, cycle_elapsed_s: float, pico_age_s: float | None) -> None: if self._window_start_s is None: self._window_start_s = float(loop_start_s) self._loop_ms.append(float(cycle_elapsed_s) * 1000.0) + self._late_ms.append(max(0.0, float(cycle_elapsed_s) - self._target_period_s) * 1000.0) self._work_ms.append(float(work_elapsed_s) * 1000.0) if pico_age_s is not None: self._pico_age_ms.append(float(pico_age_s) * 1000.0) - if cycle_elapsed_s > self._target_period_s + 1e-9: - self._overrun_count += 1 + if cycle_elapsed_s > self._target_period_s + 
self._deadline_miss_tolerance_s: + self._deadline_miss_count += 1 + if work_elapsed_s > self._target_period_s + 1e-9: + self._work_overrun_count += 1 if loop_start_s - self._window_start_s >= self._log_interval_s: self._emit(loop_start_s) @@ -116,19 +128,25 @@ def _emit(self, end_s: float) -> None: self._reset(end_s) return loop_summary = self._summarize(self._loop_ms) + late_summary = self._summarize(self._late_ms) work_summary = self._summarize(self._work_ms) message = ( "Timing stats | samples=%d window=%.1fs | " - "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f overrun=%d/%d | " - "work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" + "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "late_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f deadline_miss(>%.2fms)=%d/%d | " + "work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f work_overrun=%d/%d" ) args: list[object] = [ sample_count, end_s - float(self._window_start_s), *loop_summary, - self._overrun_count, + *late_summary, + self._deadline_miss_tolerance_s * 1000.0, + self._deadline_miss_count, sample_count, *work_summary, + self._work_overrun_count, + sample_count, ] if self._pico_age_ms: message += " | reference_age_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" @@ -139,9 +157,11 @@ def _emit(self, end_s: float) -> None: def _reset(self, window_start_s: float) -> None: self._window_start_s = float(window_start_s) self._loop_ms.clear() + self._late_ms.clear() self._work_ms.clear() self._pico_age_ms.clear() - self._overrun_count = 0 + self._deadline_miss_count = 0 + self._work_overrun_count = 0 @staticmethod def _summarize(samples: list[float]) -> tuple[float, float, float, float]: diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index db653532..e8023d35 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -1,6 +1,7 @@ from __future__ import annotations import importlib.util +import logging from pathlib import Path from types import SimpleNamespace @@ -10,10 +11,30 @@ 
from teleopit.runtime.mocap_session import MocapSessionState from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher from teleopit.sim2real.mp.messages import ReferencePacket, SharedFrameDescriptor -from teleopit.sim2real.mp.runtime import RobotMode, Sim2RealRuntime, _RobotControlWorker, _human_frame_is_valid +from teleopit.sim2real.mp.runtime import ( + RobotMode, + Sim2RealRuntime, + _LoopTimingReporter, + _RobotControlWorker, + _human_frame_is_valid, +) from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter +def test_loop_timing_reporter_separates_late_sleep_from_work_overrun(caplog) -> None: + reporter = _LoopTimingReporter(target_period_s=0.02, log_interval_s=1.0, deadline_miss_tolerance_s=0.001) + + with caplog.at_level(logging.INFO, logger="teleopit.sim2real.mp.runtime"): + reporter.record(loop_start_s=0.0, work_elapsed_s=0.0004, cycle_elapsed_s=0.02006, pico_age_s=None) + reporter.record(loop_start_s=1.0, work_elapsed_s=0.021, cycle_elapsed_s=0.0212, pico_age_s=None) + + message = caplog.messages[-1] + assert "late_ms" in message + assert "deadline_miss(>1.00ms)=1/2" in message + assert "work_overrun=1/2" in message + assert " overrun=" not in message + + def test_sim2real_runtime_rejects_legacy_runtime_keys() -> None: cfg = {"sim2real_runtime": "auto", "input": {"provider": "pico4"}} with pytest.raises(ValueError, match="Legacy sim2real config keys"): From 93353ef86d8eb3a955c5f0e5cee6f138a62d838a Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 9 Jun 2026 17:22:39 +0800 Subject: [PATCH 066/122] Avoid eager sim2real runtime imports --- teleopit/sim2real/__init__.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/teleopit/sim2real/__init__.py b/teleopit/sim2real/__init__.py index 9f8f6921..2c22fb9c 100644 --- a/teleopit/sim2real/__init__.py +++ b/teleopit/sim2real/__init__.py @@ -1,10 +1,22 @@ -from teleopit.sim2real.mp import Sim2RealRuntime -from 
teleopit.sim2real.unitree_g1 import UnitreeG1Robot -from teleopit.sim2real.remote import UnitreeRemote, Button - __all__ = [ "Sim2RealRuntime", "UnitreeG1Robot", "UnitreeRemote", "Button", ] + + +def __getattr__(name: str): + if name == "Sim2RealRuntime": + from teleopit.sim2real.mp import Sim2RealRuntime + + return Sim2RealRuntime + if name == "UnitreeG1Robot": + from teleopit.sim2real.unitree_g1 import UnitreeG1Robot + + return UnitreeG1Robot + if name in ("UnitreeRemote", "Button"): + from teleopit.sim2real.remote import Button, UnitreeRemote + + return {"UnitreeRemote": UnitreeRemote, "Button": Button}[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From d252cacfb852c44806fca58a3d71396ec15577f7 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 9 Jun 2026 18:01:26 +0800 Subject: [PATCH 067/122] Avoid blocking RealSense Pico video updates --- teleopit/inputs/pico_video.py | 34 +++++++++++++++++++++- tests/test_pico_video.py | 55 +++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/teleopit/inputs/pico_video.py b/teleopit/inputs/pico_video.py index a1c3191b..5732113e 100644 --- a/teleopit/inputs/pico_video.py +++ b/teleopit/inputs/pico_video.py @@ -152,6 +152,9 @@ def stop(self) -> None: ... 
class _RealSenseVideoProducer(_VideoProducer): + _FRAME_WAIT_TIMEOUT_MS = 100 + _FRAME_TIMEOUT_LOG_INTERVAL_S = 2.0 + def __init__(self, provider: Any, config: PicoVideoConfig) -> None: self._provider = provider self._config = config @@ -199,9 +202,19 @@ def _run(self) -> None: ) pipeline.start(config) self._ready_event.set() + last_timeout_log_s = 0.0 try: while not self._stop_event.is_set(): - frames = pipeline.wait_for_frames() + frames = self._wait_for_frames(pipeline) + if frames is None: + now = time.monotonic() + if now - last_timeout_log_s >= self._FRAME_TIMEOUT_LOG_INTERVAL_S: + last_timeout_log_s = now + logger.warning( + "RealSense Pico video has no new color frame yet | pushed_frames=%d", + self._pushed_frames, + ) + continue color_frame = frames.get_color_frame() if not color_frame: continue @@ -214,6 +227,25 @@ def _run(self) -> None: self._ready_event.set() logger.exception("RealSense Pico video producer failed") + def _wait_for_frames(self, pipeline: Any) -> Any | None: + try_wait_for_frames = getattr(pipeline, "try_wait_for_frames", None) + if callable(try_wait_for_frames): + result = try_wait_for_frames(timeout_ms=self._FRAME_WAIT_TIMEOUT_MS) + if isinstance(result, tuple): + ok = bool(result[0]) if result else False + return result[1] if ok and len(result) > 1 else None + return result if result else None + + try: + return pipeline.wait_for_frames(timeout_ms=self._FRAME_WAIT_TIMEOUT_MS) + except TypeError: + return pipeline.wait_for_frames() + except RuntimeError as exc: + message = str(exc).lower() + if "timeout" in message or "timed out" in message or "frame didn't arrive" in message: + return None + raise + class _MujocoCameraVideoProducer(_VideoProducer): def __init__(self, provider: Any, config: PicoVideoConfig, robot: Any | None) -> None: diff --git a/tests/test_pico_video.py b/tests/test_pico_video.py index 075ec194..c1799833 100644 --- a/tests/test_pico_video.py +++ b/tests/test_pico_video.py @@ -87,6 +87,61 @@ def stop(self) -> None: 
assert sink.frames[-1].dtype == np.uint8 +def test_realsense_video_runtime_keeps_running_when_frames_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + fake_rs = ModuleType("pyrealsense2") + fake_rs.stream = SimpleNamespace(color="color") + fake_rs.format = SimpleNamespace(rgb8="rgb8") + + class FakeConfig: + def enable_stream(self, *_args: object) -> None: + pass + + class FakeColorFrame: + def get_data(self) -> np.ndarray: + return np.full((2, 3, 3), 7, dtype=np.uint8) + + class FakeFrames: + def get_color_frame(self) -> FakeColorFrame: + return FakeColorFrame() + + class FakePipeline: + def __init__(self) -> None: + self.calls = 0 + + def start(self, _config: object) -> None: + pass + + def try_wait_for_frames(self, *, timeout_ms: int) -> tuple[bool, FakeFrames | None]: + assert timeout_ms > 0 + self.calls += 1 + time.sleep(0.002) + if self.calls == 1: + return True, FakeFrames() + return False, None + + def stop(self) -> None: + pass + + fake_rs.config = FakeConfig + fake_rs.pipeline = FakePipeline + monkeypatch.setitem(sys.modules, "pyrealsense2", fake_rs) + + sink = _FrameSink() + config = parse_pico_video_config( + {"video": {"enabled": True, "source": "realsense", "fail_on_error": True}} + ) + runtime = PicoVideoRuntime(provider=sink, config=config, mode="sim2real") + + runtime.start() + time.sleep(0.03) + runtime.tick() + pushed_frames = runtime.pushed_frames + runtime.stop() + + assert len(sink.frames) == 1 + assert pushed_frames == 1 + + def test_video_runtime_stops_producer_after_startup_error(monkeypatch: pytest.MonkeyPatch) -> None: stopped = False From 4a84253eff5d34b1569ae037b0719328b76186e4 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 9 Jun 2026 18:04:24 +0800 Subject: [PATCH 068/122] Revert "Avoid blocking RealSense Pico video updates" This reverts commit d252cacfb852c44806fca58a3d71396ec15577f7. 
--- teleopit/inputs/pico_video.py | 34 +--------------------- tests/test_pico_video.py | 55 ----------------------------------- 2 files changed, 1 insertion(+), 88 deletions(-) diff --git a/teleopit/inputs/pico_video.py b/teleopit/inputs/pico_video.py index 5732113e..a1c3191b 100644 --- a/teleopit/inputs/pico_video.py +++ b/teleopit/inputs/pico_video.py @@ -152,9 +152,6 @@ def stop(self) -> None: ... class _RealSenseVideoProducer(_VideoProducer): - _FRAME_WAIT_TIMEOUT_MS = 100 - _FRAME_TIMEOUT_LOG_INTERVAL_S = 2.0 - def __init__(self, provider: Any, config: PicoVideoConfig) -> None: self._provider = provider self._config = config @@ -202,19 +199,9 @@ def _run(self) -> None: ) pipeline.start(config) self._ready_event.set() - last_timeout_log_s = 0.0 try: while not self._stop_event.is_set(): - frames = self._wait_for_frames(pipeline) - if frames is None: - now = time.monotonic() - if now - last_timeout_log_s >= self._FRAME_TIMEOUT_LOG_INTERVAL_S: - last_timeout_log_s = now - logger.warning( - "RealSense Pico video has no new color frame yet | pushed_frames=%d", - self._pushed_frames, - ) - continue + frames = pipeline.wait_for_frames() color_frame = frames.get_color_frame() if not color_frame: continue @@ -227,25 +214,6 @@ def _run(self) -> None: self._ready_event.set() logger.exception("RealSense Pico video producer failed") - def _wait_for_frames(self, pipeline: Any) -> Any | None: - try_wait_for_frames = getattr(pipeline, "try_wait_for_frames", None) - if callable(try_wait_for_frames): - result = try_wait_for_frames(timeout_ms=self._FRAME_WAIT_TIMEOUT_MS) - if isinstance(result, tuple): - ok = bool(result[0]) if result else False - return result[1] if ok and len(result) > 1 else None - return result if result else None - - try: - return pipeline.wait_for_frames(timeout_ms=self._FRAME_WAIT_TIMEOUT_MS) - except TypeError: - return pipeline.wait_for_frames() - except RuntimeError as exc: - message = str(exc).lower() - if "timeout" in message or "timed out" in message 
or "frame didn't arrive" in message: - return None - raise - class _MujocoCameraVideoProducer(_VideoProducer): def __init__(self, provider: Any, config: PicoVideoConfig, robot: Any | None) -> None: diff --git a/tests/test_pico_video.py b/tests/test_pico_video.py index c1799833..075ec194 100644 --- a/tests/test_pico_video.py +++ b/tests/test_pico_video.py @@ -87,61 +87,6 @@ def stop(self) -> None: assert sink.frames[-1].dtype == np.uint8 -def test_realsense_video_runtime_keeps_running_when_frames_timeout(monkeypatch: pytest.MonkeyPatch) -> None: - fake_rs = ModuleType("pyrealsense2") - fake_rs.stream = SimpleNamespace(color="color") - fake_rs.format = SimpleNamespace(rgb8="rgb8") - - class FakeConfig: - def enable_stream(self, *_args: object) -> None: - pass - - class FakeColorFrame: - def get_data(self) -> np.ndarray: - return np.full((2, 3, 3), 7, dtype=np.uint8) - - class FakeFrames: - def get_color_frame(self) -> FakeColorFrame: - return FakeColorFrame() - - class FakePipeline: - def __init__(self) -> None: - self.calls = 0 - - def start(self, _config: object) -> None: - pass - - def try_wait_for_frames(self, *, timeout_ms: int) -> tuple[bool, FakeFrames | None]: - assert timeout_ms > 0 - self.calls += 1 - time.sleep(0.002) - if self.calls == 1: - return True, FakeFrames() - return False, None - - def stop(self) -> None: - pass - - fake_rs.config = FakeConfig - fake_rs.pipeline = FakePipeline - monkeypatch.setitem(sys.modules, "pyrealsense2", fake_rs) - - sink = _FrameSink() - config = parse_pico_video_config( - {"video": {"enabled": True, "source": "realsense", "fail_on_error": True}} - ) - runtime = PicoVideoRuntime(provider=sink, config=config, mode="sim2real") - - runtime.start() - time.sleep(0.03) - runtime.tick() - pushed_frames = runtime.pushed_frames - runtime.stop() - - assert len(sink.frames) == 1 - assert pushed_frames == 1 - - def test_video_runtime_stops_producer_after_startup_error(monkeypatch: pytest.MonkeyPatch) -> None: stopped = False From 
5423d32d7c4831891250e0bec56c8c78bf0d51db Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 9 Jun 2026 18:12:33 +0800 Subject: [PATCH 069/122] Remove stale multiprocess Pico video worker --- teleopit/sim2real/mp/runtime.py | 62 ----------------------------- tests/test_sim2real_multiprocess.py | 31 +++++++++++++++ 2 files changed, 31 insertions(+), 62 deletions(-) diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index b4924a90..4482221f 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -46,7 +46,6 @@ HEALTH_TOPIC, MODE_TOPIC, REFERENCE_TOPIC, - VIDEO_TOPIC, LatestSubscriber, Sim2RealIpcEndpoints, ZmqPublisher, @@ -59,10 +58,8 @@ HealthPacket, ModeStatePacket, ReferencePacket, - SharedFrameDescriptor, SnapshotPacket, ) -from teleopit.sim2real.mp.shm import SharedFrameRingWriter from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor from teleopit.sim2real.remote import UnitreeRemote from teleopit.sim2real.safety import Sim2RealSafetyManager @@ -1425,62 +1422,3 @@ def _main() -> None: command_sub.close() _worker_loop("hand_worker", _main) - - -def _run_video_worker( - cfg: dict[str, Any], - endpoints: Sim2RealIpcEndpoints, - stop_event: MpEvent, -) -> None: - def _main() -> None: - input_cfg = cfg_get(cfg, "input", {}) or {} - video_cfg = parse_pico_video_config(input_cfg) - if not video_cfg.enabled: - return - if video_cfg.source not in ("realsense",): - logger.warning("Multiprocess video worker supports source=realsense; got %s", video_cfg.source) - return - - writer = SharedFrameRingWriter( - shape=(video_cfg.height, video_cfg.width, 3), - dtype=np.uint8, - slots=int(cfg_get(_mp_cfg(cfg), "video_slots", 3)), - ) - video_pub = ZmqPublisher(endpoints.video_pub) - command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) - try: - import pyrealsense2 as rs - - pipeline = rs.pipeline() - rs_config = rs.config() - if video_cfg.device is not None: - 
rs_config.enable_device(video_cfg.device) - rs_config.enable_stream( - rs.stream.color, - video_cfg.width, - video_cfg.height, - rs.format.rgb8, - video_cfg.fps, - ) - pipeline.start(rs_config) - try: - while not stop_event.is_set(): - command = command_sub.recv_latest() - if isinstance(command, CommandPacket) and command.command == "shutdown": - stop_event.set() - break - frames = pipeline.wait_for_frames() - color_frame = frames.get_color_frame() - if not color_frame: - continue - rgb = np.ascontiguousarray(np.asanyarray(color_frame.get_data()), dtype=np.uint8) - descriptor = writer.write(rgb, timestamp_s=time.monotonic()) - video_pub.publish(VIDEO_TOPIC, descriptor) - finally: - pipeline.stop() - finally: - command_sub.close() - video_pub.close() - writer.close(unlink=True) - - _worker_loop("video_worker", _main) diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index e8023d35..ae765e1c 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -179,6 +179,37 @@ def fake_start_processes(self: Sim2RealRuntime) -> None: assert controller._processes == [] +def test_pico_video_does_not_spawn_separate_video_worker() -> None: + started_names: list[str] = [] + + class FakeProcess: + def __init__(self, *, name: str, target: object, args: tuple[object, ...]) -> None: + del target, args + self.name = name + self.exitcode = 0 + + def start(self) -> None: + started_names.append(self.name) + + class FakeContext: + def Event(self) -> object: + return SimpleNamespace(set=lambda: None, is_set=lambda: False) + + def Process(self, *, name: str, target: object, args: tuple[object, ...]) -> FakeProcess: + return FakeProcess(name=name, target=target, args=args) + + cfg = { + "input": {"provider": "pico4", "video": {"enabled": True, "source": "realsense"}}, + "runtime": {"shutdown_timeout_s": 0.01}, + } + runtime = Sim2RealRuntime(cfg) + runtime._ctx = FakeContext() # type: ignore[assignment] + + 
runtime._start_processes() + + assert started_names == ["pico_input", "reference", "robot_control"] + + def test_human_frame_validation_rejects_bad_inputs() -> None: valid_frame = { "Pelvis": (np.zeros(3, dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), From c9c68fa3468edc129141ee7b6116a6647cc4bde1 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 10 Jun 2026 10:32:32 +0800 Subject: [PATCH 070/122] Refine dataset filtering and ground alignment --- teleopit/inputs/pico4_provider.py | 22 +- teleopit/runtime/assets.py | 1 - tests/test_dataset_v2.py | 164 ++++++++- tests/test_domain_randomization.py | 48 --- tests/test_pico4_provider.py | 37 ++- tests/test_task_registry.py | 17 - train_mimic/configs/datasets/lafan1.yaml | 6 +- train_mimic/data/dataset_builder.py | 310 ++++++++++++++---- train_mimic/data/preprocess.py | 45 ++- train_mimic/tasks/tracking/config/env.py | 40 +-- .../tasks/tracking/tracking_env_cfg.py | 52 --- 11 files changed, 502 insertions(+), 240 deletions(-) diff --git a/teleopit/inputs/pico4_provider.py b/teleopit/inputs/pico4_provider.py index 1df18f3e..37b952c2 100644 --- a/teleopit/inputs/pico4_provider.py +++ b/teleopit/inputs/pico4_provider.py @@ -121,7 +121,7 @@ def _has_non_degenerate_positions(positions: NDArray[np.float64]) -> bool: return extent > 1e-6 -def _compute_ground_lift_offset(positions: NDArray[np.float64]) -> float: +def _compute_ground_alignment_offset(positions: NDArray[np.float64]) -> float: pos = np.asarray(positions, dtype=np.float64).reshape(-1, 3) if pos.size == 0: return 0.0 @@ -129,7 +129,7 @@ def _compute_ground_lift_offset(positions: NDArray[np.float64]) -> float: if not np.any(finite_mask): return 0.0 min_z = float(np.min(pos[finite_mask, 2])) - return max(-min_z, 0.0) + return -min_z def _bridge_accepts_video_enabled(bridge_cls: type[Any]) -> bool: @@ -233,7 +233,7 @@ def __init__( self._last_source_seq: int | None = None self._controller_snapshot: PicoControllerSnapshot | None = None 
self._hand_snapshot: PicoHandSnapshot | None = None - self._ground_lift_offset: float | None = None + self._ground_alignment_offset: float | None = None self._bridge = bridge_cls( host=bridge_host, port=int(bridge_port), @@ -415,7 +415,7 @@ def _accept_pico_frame(self, frame: Any) -> bool: and timestamp - self._last_frame_timestamp > self._timestamp_gap_reset_s ): self._frame_cache.clear() - self._ground_lift_offset = None + self._ground_alignment_offset = None logger.warning( "Pico4InputProvider timestamp-gap reset | gap=%.4fs", timestamp - self._last_frame_timestamp, @@ -423,7 +423,7 @@ def _accept_pico_frame(self, frame: Any) -> bool: if self._last_frame_timestamp is not None and timestamp <= self._last_frame_timestamp + 1e-9: timestamp = self._last_frame_timestamp + 1e-6 - human_frame = self._apply_ground_lift(human_frame) + human_frame = self._apply_ground_alignment(human_frame) self._frame_cache.append(human_frame, timestamp, fps_timestamp=timestamp) self._last_raw_body_joints = body_joints.copy() self._last_frame_timestamp = timestamp @@ -534,17 +534,17 @@ def _convert_body_joints_to_frame(body_joints: NDArray[np.float64]) -> HumanFram result[name] = (np.asarray(pos, dtype=np.float64), np.asarray(quat, dtype=np.float64)) return result - def _apply_ground_lift(self, human_frame: HumanFrame) -> HumanFrame: - """Apply one fixed Z lift so the initial Pico skeleton sits on the floor.""" - if self._ground_lift_offset is None: + def _apply_ground_alignment(self, human_frame: HumanFrame) -> HumanFrame: + """Apply one fixed Z offset so the initial Pico skeleton sits on the floor.""" + if self._ground_alignment_offset is None: positions = np.asarray([value[0] for value in human_frame.values()], dtype=np.float64) if _has_non_degenerate_positions(positions): - self._ground_lift_offset = _compute_ground_lift_offset(positions) + self._ground_alignment_offset = _compute_ground_alignment_offset(positions) else: return human_frame - offset = float(self._ground_lift_offset) - 
if offset <= 0.0: + offset = float(self._ground_alignment_offset) + if abs(offset) <= 1e-12: return human_frame z_offset = np.array([0.0, 0.0, offset], dtype=np.float64) diff --git a/teleopit/runtime/assets.py b/teleopit/runtime/assets.py index f86ca79a..e9be0638 100644 --- a/teleopit/runtime/assets.py +++ b/teleopit/runtime/assets.py @@ -6,7 +6,6 @@ PROJECT_ROOT = Path(__file__).resolve().parents[2] GMR_ASSETS_ROOT = PROJECT_ROOT / "teleopit" / "retargeting" / "gmr" / "assets" UNITREE_G1_MJLAB_XML = GMR_ASSETS_ROOT / "unitree_g1" / "g1_mjlab.xml" -UNITREE_G1_MJLAB_PAYLOAD_XML = GMR_ASSETS_ROOT / "unitree_g1" / "g1_mjlab_payload.xml" def missing_gmr_assets_message(path: str | Path, *, label: str = "Required asset") -> str: diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 2625ddfd..0aebf006 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -131,20 +131,26 @@ def test_load_dataset_spec_parses_preprocess(tmp_path: Path) -> None: hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: none min_frames: 10 + max_all_off_ground_s: 0.5 + off_ground_height: 0.08 sources: - name: clips type: npz input: {tmp_path / 'npz_source'} + exclude_patterns: ["*obstacle*"] """, encoding="utf-8", ) spec = load_dataset_spec(spec_path) assert spec.preprocess.normalize_root_xy is True - assert spec.preprocess.ground_align == "clip_min_foot" + assert spec.preprocess.ground_align == "none" assert spec.preprocess.min_frames == 10 + assert spec.preprocess.max_all_off_ground_s == 0.5 + assert spec.preprocess.off_ground_height == 0.08 + assert spec.sources[0].exclude_patterns == ("*obstacle*",) def test_load_dataset_spec_parses_seed_filter_preset(tmp_path: Path) -> None: @@ -428,6 +434,71 @@ def test_collect_source_files_with_report_handles_single_file_source(tmp_path: P assert [item.rel_no_suffix.as_posix() for item in legacy_items] == ["clip_a"] +def 
test_collect_source_files_with_report_applies_exclude_patterns(tmp_path: Path) -> None: + input_root = tmp_path / "lafan1" + input_root.mkdir(parents=True, exist_ok=True) + (input_root / "walk1.bvh").write_text("placeholder", encoding="utf-8") + (input_root / "obstacle_run.bvh").write_text("placeholder", encoding="utf-8") + nested = input_root / "subject1" + nested.mkdir() + (nested / "obstacle_jump.bvh").write_text("placeholder", encoding="utf-8") + + source = DatasetSourceSpec( + name="lafan1", + type="bvh", + input=str(input_root), + bvh_format="lafan1", + exclude_patterns=("*obstacle*",), + ) + + items, _scan_root, report = dataset_builder._collect_source_files_with_report( + source, + quiet=True, + ) + + assert [item.rel_no_suffix.as_posix() for item in items] == ["walk1"] + assert report["scanned_files"] == 3 + assert report["path_rejected_files"] == 2 + assert report["kept_files"] == 1 + assert report["filtered_files"] == 2 + assert report["path_reject_reasons"] == {"*obstacle*": 2} + + +def test_collect_source_files_with_report_preserves_path_excludes_with_metadata(tmp_path: Path) -> None: + input_root = tmp_path / "seed_source" / "g1" / "csv" + input_root.mkdir(parents=True, exist_ok=True) + for name in ("walk_a.csv", "walk_b.csv", "obstacle_walk.csv"): + (input_root / name).write_text("placeholder", encoding="utf-8") + + metadata_csv = tmp_path / "metadata.csv" + with metadata_csv.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=["move_g1_path", "is_mirror"]) + writer.writeheader() + for name in ("walk_a.csv", "walk_b.csv", "obstacle_walk.csv"): + writer.writerow({"move_g1_path": f"g1/csv/{name}", "is_mirror": "False"}) + + source = DatasetSourceSpec( + name="seed", + type="seed_csv", + input=str(input_root), + metadata_csv=str(metadata_csv), + filters={"is_mirror": [False]}, + exclude_patterns=("*obstacle*",), + ) + + items, _scan_root, report = dataset_builder._collect_source_files_with_report( + source, + 
quiet=True, + ) + + assert [item.rel_no_suffix.as_posix() for item in items] == ["walk_a", "walk_b"] + assert report["scanned_files"] == 3 + assert report["path_rejected_files"] == 1 + assert report["kept_files"] == 2 + assert report["filtered_files"] == 1 + assert report["path_reject_reasons"] == {"*obstacle*": 1} + + def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> None: npz_input = tmp_path / "npz_source" _write_npz_from_pkl(npz_input / "clip_a.npz") @@ -460,6 +531,35 @@ def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> Non assert "clip_lengths" in train_data.files +def test_collect_clip_rows_ignores_stale_excluded_cached_npz(tmp_path: Path) -> None: + npz_input = tmp_path / "npz_source" + for name in ("keep_a.npz", "keep_b.npz", "obstacle_old.npz"): + _write_npz_from_pkl(npz_input / name) + + spec = DatasetSpec( + name="demo_dataset", + target_fps=30, + val_percent=5, + hash_salt="", + sources=[ + DatasetSourceSpec( + name="npz_src", + type="npz", + input=str(npz_input), + exclude_patterns=("*obstacle*",), + ) + ], + ) + paths = dataset_builder.resolve_dataset_paths(spec, output_root=tmp_path / "datasets") + source_dir = paths.clips_root / "npz_src" + for name in ("keep_a.npz", "keep_b.npz", "obstacle_old.npz"): + _write_npz_from_pkl(source_dir / name) + + rows = dataset_builder.collect_clip_rows(spec, paths=paths) + + assert sorted(row.clip_id for row in rows) == ["npz_src:keep_a", "npz_src:keep_b"] + + def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: npz_input = tmp_path / "npz_source" _write_npz_from_pkl(npz_input / "clip_a.npz") @@ -503,6 +603,66 @@ def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: assert np.isclose(float(np.min(foot_z)), 0.0) +def test_convert_source_to_npz_clips_skips_all_off_ground_clips_before_ground_align(tmp_path: Path) -> None: + npz_input = tmp_path / "npz_source" + _write_npz_from_pkl(npz_input / "keep.npz") + 
_write_npz_from_pkl(npz_input / "float.npz") + + for name, floating in (("keep.npz", False), ("float.npz", True)): + path = npz_input / name + clip = dict(np.load(path, allow_pickle=True)) + body_names = [str(body_name) for body_name in clip["body_names"].tolist()] + left_idx = body_names.index("left_ankle_roll_link") + right_idx = body_names.index("right_ankle_roll_link") + body_pos_w = np.asarray(clip["body_pos_w"]).copy() + if floating: + body_pos_w[..., 2] = np.maximum(body_pos_w[..., 2], 0.3) + else: + body_pos_w[:, left_idx, 2] = 0.0 + body_pos_w[:, right_idx, 2] = 0.3 + clip["body_pos_w"] = body_pos_w + np.savez(path, **clip) + + source = DatasetSourceSpec(name="npz_src", type="npz", input=str(npz_input)) + output_dir = tmp_path / "dataset" / "clips" / "npz_src" + report = convert_source_to_npz_clips( + source, + output_dir, + jobs=1, + preprocess=dataset_builder.DatasetPreprocessSpec( + ground_align="clip_min_foot", + max_all_off_ground_s=0.05, + off_ground_height=0.08, + ), + ) + + assert report["clips"] == 1 + assert (output_dir / "keep.npz").is_file() + assert not (output_dir / "float.npz").exists() + assert (output_dir / "float.npz.filtered.json").is_file() + + def _unexpected_run_conversion_tasks(*_args, **_kwargs): + raise AssertionError("filtered clip should be skipped by marker on incremental rebuild") + + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(dataset_builder, "run_conversion_tasks", _unexpected_run_conversion_tasks) + try: + second_report = convert_source_to_npz_clips( + source, + output_dir, + jobs=1, + preprocess=dataset_builder.DatasetPreprocessSpec( + ground_align="clip_min_foot", + max_all_off_ground_s=0.05, + off_ground_height=0.08, + ), + ) + finally: + monkeypatch.undo() + + assert second_report["clips"] == 1 + + def test_merge_clip_dicts_rejects_inconsistent_body_names(tmp_path: Path) -> None: clip_a = tmp_path / "clip_a.npz" clip_b = tmp_path / "clip_b.npz" diff --git a/tests/test_domain_randomization.py 
b/tests/test_domain_randomization.py index 0c0941b9..da23645c 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -19,10 +19,6 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non "add_joint_default_pos", "physics_material", "randomize_rigid_body_mass", - "randomize_dexhand_payload_mass", - "randomize_gimbal_payload_mass", - "randomize_dexhand_payload_pos", - "randomize_gimbal_payload_pos", } push_robot = events["push_robot"] @@ -69,46 +65,6 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non assert mass.params["asset_cfg"].body_names == "torso_link" assert mass.params["alpha_range"] == (-0.1, 0.45) - dexhand_mass = events["randomize_dexhand_payload_mass"] - assert dexhand_mass.func is dr.pseudo_inertia - assert dexhand_mass.mode == "startup" - assert dexhand_mass.params["asset_cfg"].body_names == ( - "left_dexhand_payload", - "right_dexhand_payload", - ) - assert dexhand_mass.params["alpha_range"] == (-1, 0) - - gimbal_mass = events["randomize_gimbal_payload_mass"] - assert gimbal_mass.func is dr.pseudo_inertia - assert gimbal_mass.mode == "startup" - assert gimbal_mass.params["asset_cfg"].body_names == ("head_gimbal_payload",) - assert gimbal_mass.params["alpha_range"] == (-1, 0) - - dexhand_pos = events["randomize_dexhand_payload_pos"] - assert dexhand_pos.func is dr.body_pos - assert dexhand_pos.mode == "startup" - assert dexhand_pos.params["asset_cfg"].body_names == ( - "left_dexhand_payload", - "right_dexhand_payload", - ) - assert dexhand_pos.params["operation"] == "abs" - assert dexhand_pos.params["ranges"] == { - 0: (0.055, 0.095), - 1: (-0.02, 0.02), - 2: (-0.02, 0.02), - } - - gimbal_pos = events["randomize_gimbal_payload_pos"] - assert gimbal_pos.func is dr.body_pos - assert gimbal_pos.mode == "startup" - assert gimbal_pos.params["asset_cfg"].body_names == ("head_gimbal_payload",) - assert gimbal_pos.params["operation"] == "abs" - assert 
gimbal_pos.params["ranges"] == { - 0: (0.05, 0.09), - 1: (-0.02, 0.02), - 2: (0.43, 0.47), - } - def test_play_env_disables_training_only_domain_randomization() -> None: import mjlab.tasks # noqa: F401 @@ -122,8 +78,4 @@ def test_play_env_disables_training_only_domain_randomization() -> None: assert "add_joint_default_pos" not in play_cfg.events assert "physics_material" not in play_cfg.events assert "randomize_rigid_body_mass" not in play_cfg.events - assert "randomize_dexhand_payload_mass" not in play_cfg.events - assert "randomize_gimbal_payload_mass" not in play_cfg.events - assert "randomize_dexhand_payload_pos" not in play_cfg.events - assert "randomize_gimbal_payload_pos" not in play_cfg.events assert play_cfg.events == {} diff --git a/tests/test_pico4_provider.py b/tests/test_pico4_provider.py index da82c507..97a4ffde 100644 --- a/tests/test_pico4_provider.py +++ b/tests/test_pico4_provider.py @@ -61,7 +61,7 @@ def _make_provider() -> Pico4InputProvider: provider._last_raw_body_joints = None provider._last_frame_timestamp = None provider._last_source_seq = None - provider._ground_lift_offset = None + provider._ground_alignment_offset = None provider._controller_snapshot = None provider._hand_snapshot = None provider._closed = False @@ -169,7 +169,7 @@ def test_pico4_provider_converts_pico_native_body_pose_convention() -> None: np.testing.assert_allclose(frame["Pelvis"][0], [1.0, -3.0, 2.0], atol=1e-6) -def test_pico4_provider_applies_fixed_ground_lift_from_first_real_frame() -> None: +def test_pico4_provider_applies_fixed_ground_alignment_from_first_real_frame() -> None: provider = _make_provider() body_poses = np.zeros((len(BODY_JOINT_NAMES), 7), dtype=np.float64) pelvis_idx = BODY_JOINT_NAMES.index("Pelvis") @@ -184,7 +184,7 @@ def test_pico4_provider_applies_fixed_ground_lift_from_first_real_frame() -> Non first_frame, _, _ = provider._frame_cache.latest_packet() np.testing.assert_allclose(first_frame["Pelvis"][0][2], 0.8 + 0.2, atol=1e-6) 
np.testing.assert_allclose(first_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) - assert provider._ground_lift_offset == pytest.approx(0.2) + assert provider._ground_alignment_offset == pytest.approx(0.2) body_poses[:, 1] += 0.3 assert provider._accept_pico_frame(_pico_frame(body_poses, seq=2, timestamp=1.1)) is True @@ -193,7 +193,32 @@ def test_pico4_provider_applies_fixed_ground_lift_from_first_real_frame() -> Non np.testing.assert_allclose(second_frame["Left_Ankle"][0][2], 0.3, atol=1e-6) -def test_pico4_provider_recomputes_ground_lift_after_timestamp_gap_reset() -> None: +def test_pico4_provider_aligns_floating_first_frame_down_to_ground() -> None: + provider = _make_provider() + body_poses = np.zeros((len(BODY_JOINT_NAMES), 7), dtype=np.float64) + pelvis_idx = BODY_JOINT_NAMES.index("Pelvis") + left_ankle_idx = BODY_JOINT_NAMES.index("Left_Ankle") + right_ankle_idx = BODY_JOINT_NAMES.index("Right_Ankle") + body_poses[:, 1] = 0.2 + body_poses[pelvis_idx, 0:3] = [0.0, 0.9, 0.0] + body_poses[left_ankle_idx, 0:3] = [0.1, 0.2, 0.0] + body_poses[right_ankle_idx, 0:3] = [-0.1, 0.4, 0.0] + body_poses[:, 6] = 1.0 + + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=1, timestamp=1.0)) is True + first_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(first_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) + np.testing.assert_allclose(first_frame["Pelvis"][0][2], 0.7, atol=1e-6) + assert provider._ground_alignment_offset == pytest.approx(-0.2) + + body_poses[:, 1] += 0.3 + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=2, timestamp=1.1)) is True + second_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(second_frame["Left_Ankle"][0][2], 0.3, atol=1e-6) + np.testing.assert_allclose(second_frame["Pelvis"][0][2], 1.0, atol=1e-6) + + +def test_pico4_provider_recomputes_ground_alignment_after_timestamp_gap_reset() -> None: provider = _make_provider() body_poses = np.zeros((len(BODY_JOINT_NAMES), 
7), dtype=np.float64) pelvis_idx = BODY_JOINT_NAMES.index("Pelvis") @@ -205,7 +230,7 @@ def test_pico4_provider_recomputes_ground_lift_after_timestamp_gap_reset() -> No body_poses[:, 6] = 1.0 assert provider._accept_pico_frame(_pico_frame(body_poses, seq=1, timestamp=1.0)) is True - assert provider._ground_lift_offset == pytest.approx(0.2) + assert provider._ground_alignment_offset == pytest.approx(0.2) body_poses[pelvis_idx, 1] = 0.7 body_poses[left_ankle_idx, 1] = -0.5 @@ -214,7 +239,7 @@ def test_pico4_provider_recomputes_ground_lift_after_timestamp_gap_reset() -> No latest_frame, _, _ = provider._frame_cache.latest_packet() np.testing.assert_allclose(latest_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) np.testing.assert_allclose(latest_frame["Pelvis"][0][2], 1.2, atol=1e-6) - assert provider._ground_lift_offset == pytest.approx(0.5) + assert provider._ground_alignment_offset == pytest.approx(0.5) assert len(provider._frame_cache) == 1 diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 7a437b97..0f6797bf 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -20,25 +20,8 @@ def test_general_tracking_task_is_registered() -> None: env_cfg = load_env_cfg(DEFAULT_TASK) actor_terms = env_cfg.observations["actor"].terms critic_terms = env_cfg.observations["critic"].terms - robot_model = env_cfg.scene.entities["robot"].spec_fn().compile() assert DEFAULT_TASK == GENERAL_TRACKING_TASK - for body_name in ( - "left_dexhand_payload", - "right_dexhand_payload", - "head_gimbal_payload", - ): - assert body_name in { - robot_model.body(i).name for i in range(robot_model.nbody) - } - for geom_name in ( - "left_dexhand_payload_collision", - "right_dexhand_payload_collision", - "head_gimbal_payload_collision", - ): - geom = robot_model.geom(geom_name) - assert int(geom.contype[0]) == 0 - assert int(geom.conaffinity[0]) == 0 for terms in (actor_terms, critic_terms): assert "projected_gravity" in terms assert "ref_base_lin_vel_b" in 
terms diff --git a/train_mimic/configs/datasets/lafan1.yaml b/train_mimic/configs/datasets/lafan1.yaml index bc58527c..2e34512a 100644 --- a/train_mimic/configs/datasets/lafan1.yaml +++ b/train_mimic/configs/datasets/lafan1.yaml @@ -4,9 +4,13 @@ val_percent: 5 hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: none + max_all_off_ground_s: 2.0 + off_ground_height: 0.2 sources: - name: lafan1 type: bvh input: data/lafan1_bvh bvh_format: lafan1 + exclude_patterns: + - "*obstacle*" diff --git a/train_mimic/data/dataset_builder.py b/train_mimic/data/dataset_builder.py index 3385ccaa..eb7a504c 100644 --- a/train_mimic/data/dataset_builder.py +++ b/train_mimic/data/dataset_builder.py @@ -1,6 +1,8 @@ from __future__ import annotations import csv +import fnmatch +import json import os import shutil import multiprocessing @@ -88,6 +90,7 @@ class DatasetSourceSpec: metadata_csv: str | None = None filters: dict[str, list] | None = None seed_filter_preset: str | None = None + exclude_patterns: tuple[str, ...] = () @dataclass(frozen=True) @@ -150,6 +153,15 @@ class ConversionTask: preprocess: DatasetPreprocessSpec = field(default_factory=DatasetPreprocessSpec) +@dataclass(frozen=True) +class FilteredClipResult: + input_path: str + reason: str + + +_FILTERED_MARKER_SUFFIX = ".filtered.json" + + @dataclass(frozen=True) class SeedFilterRule: columns: tuple[str, ...] 
@@ -256,10 +268,31 @@ def _load_preprocess_spec(raw: object, spec_path: Path) -> DatasetPreprocessSpec else float(raw["max_all_off_ground_s"]) ), off_ground_height=float(raw.get("off_ground_height", 0.2)), + max_feet_off_ground_s=( + None + if raw.get("max_feet_off_ground_s") in (None, "", "null") + else float(raw["max_feet_off_ground_s"]) + ), + foot_off_ground_height=float(raw.get("foot_off_ground_height", 0.08)), ) return validate_preprocess_spec(spec) +def _load_exclude_patterns(raw: object, spec_path: Path, source_name: str) -> tuple[str, ...]: + if raw is None: + return () + if not isinstance(raw, list): + raise ValueError( + f"source {source_name!r} exclude_patterns must be a list in {spec_path}" + ) + patterns = tuple(str(item).strip() for item in raw if str(item).strip()) + if not patterns: + raise ValueError( + f"source {source_name!r} exclude_patterns must contain at least one non-empty pattern" + ) + return patterns + + def load_dataset_spec(path: str | Path) -> DatasetSpec: spec_path = Path(path).expanduser().resolve() if not spec_path.is_file(): @@ -339,6 +372,11 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: spec_path=spec_path, source_name=source_name, ) + exclude_patterns = _load_exclude_patterns( + raw.get("exclude_patterns"), + spec_path, + source_name, + ) sources.append( DatasetSourceSpec( @@ -352,6 +390,7 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: metadata_csv=metadata_csv, filters=filters, seed_filter_preset=seed_filter_preset, + exclude_patterns=exclude_patterns, ) ) @@ -422,20 +461,25 @@ def _filter_seed_csv_by_metadata( input_dir: Path, *, quiet: bool = False, + report: dict[str, Any] | None = None, ) -> tuple[list[SourceInputFile], dict[str, Any]]: """Filter seed_csv files using metadata_csv + filters from the source spec.""" - report: dict[str, Any] = { - "source": source.name, - "type": source.type, - "metadata_csv": source.metadata_csv, - "seed_filter_preset": source.seed_filter_preset, - "scanned_files": 
len(all_files), - "metadata_rows_matched": len(all_files), - "preset_rejected_rows": 0, - "kept_files": len(all_files), - "filtered_files": 0, - "preset_reject_reasons": {}, - } + if report is None: + report = { + "source": source.name, + "type": source.type, + "metadata_csv": source.metadata_csv, + "seed_filter_preset": source.seed_filter_preset, + "exclude_patterns": list(source.exclude_patterns), + "scanned_files": len(all_files), + "metadata_rows_matched": len(all_files), + "preset_rejected_rows": 0, + "path_rejected_files": 0, + "kept_files": len(all_files), + "filtered_files": 0, + "preset_reject_reasons": {}, + "path_reject_reasons": {}, + } if source.metadata_csv is None or (source.filters is None and source.seed_filter_preset is None): return all_files, report @@ -513,7 +557,7 @@ def _filter_seed_csv_by_metadata( filtered = [f for f in all_files if f.rel_no_suffix.as_posix() in allowed_rels] report["kept_files"] = len(filtered) - report["filtered_files"] = len(all_files) - len(filtered) + report["filtered_files"] = int(report.get("scanned_files", len(all_files))) - len(filtered) if not quiet: print( f"[FILTER] source={source.name}: {len(filtered)}/{len(all_files)} files " @@ -537,24 +581,75 @@ def _collect_source_files_with_report( _ensure_not_dataset_root_npz_input(source, input_path) suffix = _SOURCE_SUFFIXES[source.type] - if input_path.is_file(): - if input_path.suffix.lower() != suffix: - raise ValueError( - f"source {source.name} expected {suffix} input, got file {input_path.name}" - ) - items = [SourceInputFile(path=input_path, rel_no_suffix=Path(input_path.stem))] - report: dict[str, Any] = { + def _base_report(items_count: int) -> dict[str, Any]: + return { "source": source.name, "type": source.type, "metadata_csv": source.metadata_csv, "seed_filter_preset": source.seed_filter_preset, - "scanned_files": len(items), - "metadata_rows_matched": len(items), + "exclude_patterns": list(source.exclude_patterns), + "scanned_files": items_count, + 
"metadata_rows_matched": items_count, "preset_rejected_rows": 0, - "kept_files": len(items), + "path_rejected_files": 0, + "kept_files": items_count, "filtered_files": 0, "preset_reject_reasons": {}, + "path_reject_reasons": {}, } + + def _matches_exclude(item: SourceInputFile) -> str | None: + rel_no_suffix = item.rel_no_suffix.as_posix() + candidates = ( + rel_no_suffix, + f"{rel_no_suffix}{suffix}", + item.path.name, + item.path.stem, + ) + for pattern in source.exclude_patterns: + pat = pattern.lower() + for candidate in candidates: + if fnmatch.fnmatchcase(candidate.lower(), pat): + return pattern + return None + + def _apply_path_excludes( + items: list[SourceInputFile], + report: dict[str, Any], + ) -> list[SourceInputFile]: + if not source.exclude_patterns: + return items + reject_counts: Counter[str] = Counter() + kept: list[SourceInputFile] = [] + for item in items: + reason = _matches_exclude(item) + if reason is None: + kept.append(item) + else: + reject_counts[reason] += 1 + report["path_rejected_files"] = len(items) - len(kept) + report["path_reject_reasons"] = dict(sorted(reject_counts.items())) + report["kept_files"] = len(kept) + report["filtered_files"] = len(items) - len(kept) + if not quiet and report["path_rejected_files"] > 0: + print( + f"[FILTER] source={source.name}: path_excludes rejected=" + f"{report['path_rejected_files']} reasons={report['path_reject_reasons']}" + ) + if not kept: + raise ValueError( + f"no files remain after path exclude filtering for source {source.name}: {input_path}" + ) + return kept + + if input_path.is_file(): + if input_path.suffix.lower() != suffix: + raise ValueError( + f"source {source.name} expected {suffix} input, got file {input_path.name}" + ) + items = [SourceInputFile(path=input_path, rel_no_suffix=Path(input_path.stem))] + report = _base_report(len(items)) + items = _apply_path_excludes(items, report) return items, input_path.parent, report if not input_path.is_dir(): @@ -573,22 +668,18 @@ def 
_collect_source_files_with_report( for path in files ] - report: dict[str, Any] = { - "source": source.name, - "type": source.type, - "metadata_csv": source.metadata_csv, - "seed_filter_preset": source.seed_filter_preset, - "scanned_files": len(items), - "metadata_rows_matched": len(items), - "preset_rejected_rows": 0, - "kept_files": len(items), - "filtered_files": 0, - "preset_reject_reasons": {}, - } + report = _base_report(len(items)) + items = _apply_path_excludes(items, report) # Apply metadata filtering for seed_csv sources if source.type == "seed_csv" and source.metadata_csv is not None: - items, report = _filter_seed_csv_by_metadata(source, items, input_path, quiet=quiet) + items, report = _filter_seed_csv_by_metadata( + source, + items, + input_path, + quiet=quiet, + report=report, + ) if not items: raise ValueError( f"no files remain after metadata filtering for source {source.name}: {input_path}" @@ -641,12 +732,94 @@ def build_source_conversion_tasks( return tasks -def _source_has_cached_npz(out_dir: Path) -> bool: - return out_dir.is_dir() and any(out_dir.rglob("*.npz")) +def _filtered_marker_path(output_path: Path) -> Path: + return output_path.with_name(f"{output_path.name}{_FILTERED_MARKER_SUFFIX}") + + +def _conversion_task_signature(task: ConversionTask) -> dict[str, Any]: + input_path = Path(task.input_path) + stat = input_path.stat() + signature = { + "source_name": task.source_name, + "source_type": task.source_type, + "input_path": str(input_path), + "input_size": int(stat.st_size), + "input_mtime_ns": int(stat.st_mtime_ns), + "bvh_format": task.bvh_format, + "robot_name": task.robot_name, + "max_frames": int(task.max_frames), + "mocap_xml": task.mocap_xml, + "preprocess": task.preprocess.to_dict(), + } + return json.loads(json.dumps(signature, sort_keys=True)) + + +def _filtered_marker_matches(task: ConversionTask) -> bool: + marker_path = _filtered_marker_path(Path(task.output_path)) + if not marker_path.is_file(): + return False + try: + 
payload = json.loads(marker_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + marker_path.unlink(missing_ok=True) + return False + if payload.get("signature") == _conversion_task_signature(task): + return True + marker_path.unlink(missing_ok=True) + return False + + +def _write_filtered_marker(task: ConversionTask, reason: str) -> None: + marker_path = _filtered_marker_path(Path(task.output_path)) + write_json( + marker_path, + { + "filtered": True, + "reason": reason, + "signature": _conversion_task_signature(task), + }, + ) + + +def _clear_filtered_marker(output_path: Path) -> None: + _filtered_marker_path(output_path).unlink(missing_ok=True) + + +def _prune_unexpected_source_outputs(out_dir: Path, tasks: list[ConversionTask]) -> None: + if not out_dir.is_dir(): + return + expected = {Path(task.output_path).resolve() for task in tasks} + for npz_path in sorted(out_dir.rglob("*.npz")): + if npz_path.resolve() not in expected: + npz_path.unlink() + _clear_filtered_marker(npz_path) + for marker_path in sorted(out_dir.rglob(f"*.npz{_FILTERED_MARKER_SUFFIX}")): + output_path = Path(str(marker_path)[: -len(_FILTERED_MARKER_SUFFIX)]) + if output_path.resolve() not in expected: + marker_path.unlink(missing_ok=True) + + +def _current_source_npz_files(source: DatasetSourceSpec, source_dir: Path) -> list[Path]: + items, _scan_root, _report = _collect_source_files_with_report(source, quiet=True) + allowed_rels = {item.rel_no_suffix.as_posix() for item in items} + return [ + npz_path + for npz_path in sorted(source_dir.rglob("*.npz")) + if npz_path.relative_to(source_dir).with_suffix("").as_posix() in allowed_rels + ] def _pending_tasks(tasks: list[ConversionTask]) -> list[ConversionTask]: - return [task for task in tasks if not Path(task.output_path).is_file()] + pending: list[ConversionTask] = [] + for task in tasks: + output_path = Path(task.output_path) + if output_path.is_file(): + _clear_filtered_marker(output_path) + continue + if 
_filtered_marker_matches(task): + continue + pending.append(task) + return pending def _build_source_filter_reports(spec: DatasetSpec) -> list[dict[str, Any]]: @@ -699,7 +872,7 @@ def _maybe_preprocess_npz_file( np.savez(npz_path, **processed) -def _convert_task(task: ConversionTask) -> str: +def _convert_task(task: ConversionTask) -> str | FilteredClipResult: input_path = Path(task.input_path) output_path = Path(task.output_path) output_path.parent.mkdir(parents=True, exist_ok=True) @@ -708,36 +881,42 @@ def _convert_task(task: ConversionTask) -> str: if task.source_type == "npz": payload = np.load(input_path, allow_pickle=True) clip_dict = {key: payload[key] for key in payload.files} + clip_label = f"{task.source_name}:{input_path.name}" processed = _maybe_preprocess_clip_dict( clip_dict, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) np.savez(output_path, **processed) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) extractor = _get_fk_extractor() if task.source_type == "pkl": convert_pkl_to_npz(str(input_path), str(output_path), extractor=extractor) + clip_label = f"{task.source_name}:{input_path.name}" _maybe_preprocess_npz_file( output_path, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) if task.source_type == "seed_csv": arrays = convert_seed_csv_to_arrays(str(input_path), extractor=extractor) + clip_label = f"{task.source_name}:{input_path.name}" arrays = _maybe_preprocess_clip_dict( arrays, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) output_path.parent.mkdir(parents=True, exist_ok=True) np.savez(str(output_path), **arrays) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) if task.source_type == "bvh": @@ -755,15 
+934,26 @@ def _convert_task(task: ConversionTask) -> str: model, ) convert_pkl_to_npz(str(tmp_pkl), str(output_path), extractor=extractor) + clip_label = f"{task.source_name}:{input_path.name}" _maybe_preprocess_npz_file( output_path, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) raise ValueError(f"unsupported source type: {task.source_type}") + except ValueError as exc: + clip_label = f"{task.source_name}:{input_path.name}" + if str(exc).startswith(f"{clip_label}:"): + output_path.unlink(missing_ok=True) + _write_filtered_marker(task, str(exc)) + return FilteredClipResult(input_path=str(input_path), reason=str(exc)) + raise RuntimeError( + f"failed converting source={task.source_name} input={input_path}: {exc}" + ) from exc except Exception as exc: raise RuntimeError( f"failed converting source={task.source_name} input={input_path}: {exc}" @@ -773,7 +963,10 @@ def _convert_task(task: ConversionTask) -> str: def _run_conversion_tasks_serial(tasks: list[ConversionTask]) -> None: total = len(tasks) for idx, task in enumerate(tasks, start=1): - _convert_task(task) + result = _convert_task(task) + if isinstance(result, FilteredClipResult): + print(f"[FILTER] {result.reason}") + continue print(f"[CONVERT] {idx}/{total} source={task.source_name} -> {_display_path(Path(task.output_path))}") @@ -797,12 +990,15 @@ def run_conversion_tasks(tasks: list[ConversionTask], *, jobs: int = DEFAULT_JOB try: for future in as_completed(future_map): task = future_map[future] - future.result() + result = future.result() completed += 1 - print( - f"[CONVERT] {completed}/{total} " - f"source={task.source_name} -> {_display_path(Path(task.output_path))}" - ) + if isinstance(result, FilteredClipResult): + print(f"[FILTER] {result.reason}") + else: + print( + f"[CONVERT] {completed}/{total} " + f"source={task.source_name} -> 
{_display_path(Path(task.output_path))}" + ) except Exception: for future in future_map: future.cancel() @@ -823,14 +1019,15 @@ def convert_source_to_npz_clips( if force and output_dir.exists(): shutil.rmtree(output_dir) tasks = build_source_conversion_tasks(source, output_dir, preprocess=preprocess) + _prune_unexpected_source_outputs(output_dir, tasks) pending = _pending_tasks(tasks) - if tasks and not pending and _source_has_cached_npz(output_dir): + if tasks and not pending: print(f"[CACHE] reusing source={source.name} clips: {_display_path(output_dir)}") else: output_dir.mkdir(parents=True, exist_ok=True) - run_conversion_tasks(pending or tasks, jobs=jobs) + run_conversion_tasks(pending, jobs=jobs) - npz_files = sorted(output_dir.rglob("*.npz")) + npz_files = _current_source_npz_files(source, output_dir) if not npz_files: raise ValueError(f"no converted npz clips found for source {source.name}: {output_dir}") return { @@ -859,12 +1056,13 @@ def convert_sources_to_npz( for source in spec.sources: out_dir = source_out_dirs[source.name] tasks = build_source_conversion_tasks(source, out_dir, preprocess=spec.preprocess) + _prune_unexpected_source_outputs(out_dir, tasks) pending = _pending_tasks(tasks) - if tasks and not pending and _source_has_cached_npz(out_dir): + if tasks and not pending: print(f"[CACHE] reusing source={source.name} clips: {_display_path(out_dir)}") continue out_dir.mkdir(parents=True, exist_ok=True) - pending_tasks.extend(pending or tasks) + pending_tasks.extend(pending) run_conversion_tasks(pending_tasks, jobs=jobs) return source_out_dirs @@ -876,7 +1074,7 @@ def collect_clip_rows(spec: DatasetSpec, *, paths: DatasetPaths) -> list[Dataset source_dir = paths.clips_root / source.name if not source_dir.is_dir(): raise FileNotFoundError(f"expected converted npz dir for {source.name}: {source_dir}") - npz_files = sorted(source_dir.rglob("*.npz")) + npz_files = _current_source_npz_files(source, source_dir) if not npz_files: raise ValueError(f"no 
converted npz clips found for source {source.name}: {source_dir}") for npz_path in npz_files: diff --git a/train_mimic/data/preprocess.py b/train_mimic/data/preprocess.py index 4e565024..ed6e6b35 100644 --- a/train_mimic/data/preprocess.py +++ b/train_mimic/data/preprocess.py @@ -23,6 +23,8 @@ class DatasetPreprocessSpec: min_peak_body_height: float | None = None max_all_off_ground_s: float | None = None off_ground_height: float = 0.2 + max_feet_off_ground_s: float | None = None + foot_off_ground_height: float = 0.08 def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -55,6 +57,16 @@ def validate_preprocess_spec(spec: DatasetPreprocessSpec) -> DatasetPreprocessSp raise ValueError( f"preprocess.off_ground_height must be >= 0, got {spec.off_ground_height}" ) + if spec.max_feet_off_ground_s is not None and spec.max_feet_off_ground_s <= 0.0: + raise ValueError( + "preprocess.max_feet_off_ground_s must be > 0, " + f"got {spec.max_feet_off_ground_s}" + ) + if spec.foot_off_ground_height < 0.0: + raise ValueError( + "preprocess.foot_off_ground_height must be >= 0, " + f"got {spec.foot_off_ground_height}" + ) return spec @@ -120,16 +132,6 @@ def preprocess_clip_dict( body_pos_w[..., 0] -= offset_xy[0] body_pos_w[..., 1] -= offset_xy[1] - foot_indices: tuple[int, int] | None = None - if spec.ground_align != "none": - foot_indices = ( - _body_index(body_names, spec.foot_body_names[0], label="foot"), - _body_index(body_names, spec.foot_body_names[1], label="foot"), - ) - foot_z = body_pos_w[:, foot_indices, 2] - if spec.ground_align == "clip_min_foot": - body_pos_w[..., 2] -= float(np.min(foot_z)) - if spec.max_root_lin_vel is not None: assert root_index is not None peak_root_lin_vel = float(np.max(np.abs(body_lin_vel_w[:, root_index, :]))) @@ -149,6 +151,29 @@ def preprocess_clip_dict( f"{spec.max_all_off_ground_s:.3f}s" ) + foot_indices: tuple[int, int] | None = None + if spec.ground_align != "none" or spec.max_feet_off_ground_s is not None: + foot_indices = ( + 
_body_index(body_names, spec.foot_body_names[0], label="foot"), + _body_index(body_names, spec.foot_body_names[1], label="foot"), + ) + + if spec.max_feet_off_ground_s is not None: + assert foot_indices is not None + foot_z = body_pos_w[:, foot_indices, 2] + both_feet_off = np.all(foot_z > spec.foot_off_ground_height, axis=1) + longest_run = _longest_true_run(both_feet_off) + if longest_run > int(round(spec.max_feet_off_ground_s * fps)): + raise ValueError( + f"{clip_label}: both feet off ground for {longest_run / fps:.3f}s, exceeds " + f"{spec.max_feet_off_ground_s:.3f}s" + ) + + if spec.ground_align == "clip_min_foot": + assert foot_indices is not None + foot_z = body_pos_w[:, foot_indices, 2] + body_pos_w[..., 2] -= float(np.min(foot_z)) + if spec.min_peak_body_height is not None: peak_height = float(np.max(body_pos_w[:, :, 2])) if peak_height < spec.min_peak_body_height: diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 78a89e43..95932374 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -4,7 +4,6 @@ from copy import deepcopy -import mujoco from mjlab.asset_zoo.robots import G1_ACTION_SCALE, get_g1_robot_cfg from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg @@ -14,7 +13,6 @@ from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.utils.noise import UniformNoiseCfg as Unoise -from teleopit.runtime.assets import UNITREE_G1_MJLAB_PAYLOAD_XML from train_mimic.tasks.tracking import mdp from train_mimic.tasks.tracking.config.constants import DEFAULT_TRAIN_MOTION_FILE from train_mimic.tasks.tracking.mdp import MotionCommandCfg @@ -43,10 +41,6 @@ "add_joint_default_pos", "physics_material", "randomize_rigid_body_mass", - "randomize_dexhand_payload_mass", - "randomize_gimbal_payload_mass", - "randomize_dexhand_payload_pos", - "randomize_gimbal_payload_pos", ) @@ -118,16 +112,6 @@ def 
_add_history_obs_groups( } -def _payload_g1_spec() -> mujoco.MjSpec: - return mujoco.MjSpec.from_file(str(UNITREE_G1_MJLAB_PAYLOAD_XML)) - - -def _payload_g1_robot_cfg(): - robot_cfg = get_g1_robot_cfg() - robot_cfg.spec_fn = _payload_g1_spec - return robot_cfg - - def _configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: excluded_body_names = ( "left_wrist_yaw_link", @@ -171,31 +155,13 @@ def _configure_feet_acc_reward(cfg: ManagerBasedRlEnvCfg) -> None: ) -def _configure_payload_randomization(cfg: ManagerBasedRlEnvCfg) -> None: - cfg.events["randomize_rigid_body_mass"].params[ - "asset_cfg" - ].body_names = "torso_link" - cfg.events["randomize_dexhand_payload_mass"].params[ - "asset_cfg" - ].body_names = ("left_dexhand_payload", "right_dexhand_payload") - cfg.events["randomize_gimbal_payload_mass"].params[ - "asset_cfg" - ].body_names = ("head_gimbal_payload",) - cfg.events["randomize_dexhand_payload_pos"].params[ - "asset_cfg" - ].body_names = ("left_dexhand_payload", "right_dexhand_payload") - cfg.events["randomize_gimbal_payload_pos"].params[ - "asset_cfg" - ].body_names = ("head_gimbal_payload",) - - def make_general_tracking_env_cfg( *, play: bool = False, ) -> ManagerBasedRlEnvCfg: """Create the General-Tracking-G1 training env.""" cfg = make_tracking_env_cfg() - cfg.scene.entities = {"robot": _payload_g1_robot_cfg()} + cfg.scene.entities = {"robot": get_g1_robot_cfg()} joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) @@ -213,7 +179,9 @@ def make_general_tracking_env_cfg( "asset_cfg" ].geom_names = r".*_collision$" cfg.events["base_com"].params["asset_cfg"].body_names = ("torso_link",) - _configure_payload_randomization(cfg) + cfg.events["randomize_rigid_body_mass"].params[ + "asset_cfg" + ].body_names = "torso_link" _configure_self_collision_reward(cfg) _configure_feet_acc_reward(cfg) cfg.terminations["ee_body_pos"].params["body_names"] = ( diff --git 
a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 29b21dff..b6ed43c7 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -33,22 +33,6 @@ "yaw": (-0.78, 0.78), } -_DEXHAND_PAYLOAD_MASS_ALPHA_RANGE = (-1, 0) -_GIMBAL_PAYLOAD_MASS_ALPHA_RANGE = (-1, 0) -_DEXHAND_PAYLOAD_POS_RANGES_MM = { - 0: (55, 95), - 1: (-20, 20), - 2: (-20, 20), -} -_GIMBAL_PAYLOAD_POS_RANGES_MM = { - 0: (50, 90), - 1: (-20, 20), - 2: (430, 470), -} - - -def _mm_ranges_to_m(ranges_mm: dict[int, tuple[int, int]]) -> dict[int, tuple[float, float]]: - return {axis: (lower / 1000.0, upper / 1000.0) for axis, (lower, upper) in ranges_mm.items()} def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: """Create base tracking task configuration.""" @@ -221,42 +205,6 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "alpha_range": (-0.1, 0.45), }, ), - "randomize_dexhand_payload_mass": EventTermCfg( - mode="startup", - func=dr.pseudo_inertia, - params={ - "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. - # Nominal is 0.5 kg per hand. Keep a tighter ~0.37-1.0x band. - "alpha_range": _DEXHAND_PAYLOAD_MASS_ALPHA_RANGE, - }, - ), - "randomize_gimbal_payload_mass": EventTermCfg( - mode="startup", - func=dr.pseudo_inertia, - params={ - "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. - # Nominal is 0.25 kg. Keep a tighter ~0.37-1.0x band. - "alpha_range": _GIMBAL_PAYLOAD_MASS_ALPHA_RANGE, - }, - ), - "randomize_dexhand_payload_pos": EventTermCfg( - mode="startup", - func=dr.body_pos, - params={ - "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. - "operation": "abs", - "ranges": _mm_ranges_to_m(_DEXHAND_PAYLOAD_POS_RANGES_MM), - }, - ), - "randomize_gimbal_payload_pos": EventTermCfg( - mode="startup", - func=dr.body_pos, - params={ - "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. 
- "operation": "abs", - "ranges": _mm_ranges_to_m(_GIMBAL_PAYLOAD_POS_RANGES_MM), - }, - ), } ## From f96878e09f21892097e2b9b33340d4b529d07229 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 10 Jun 2026 16:50:31 +0800 Subject: [PATCH 071/122] Add Pico motion recorder --- AGENTS.md | 11 +- README.md | 24 ++ docs/docs/getting-started/installation.md | 2 +- docs/docs/reference/dataset.md | 31 +- docs/docs/tutorials/training.md | 2 +- .../current/getting-started/installation.md | 2 +- .../current/reference/dataset.md | 28 +- .../current/tutorials/training.md | 2 +- pyproject.toml | 1 + scripts/run/record_pico_motion.py | 373 ++++++++++++++++++ teleopit/configs/pico4_record.yaml | 24 ++ teleopit/recording/__init__.py | 20 +- teleopit/recording/pico_motion.py | 233 +++++++++++ tests/test_dataset_v2.py | 10 +- tests/test_pico_motion_recording.py | 146 +++++++ tests/test_train_script.py | 87 +++- train_mimic/configs/datasets/seed.yaml | 2 +- train_mimic/configs/datasets/seed_clean.yaml | 2 +- train_mimic/configs/datasets/twist2.yaml | 2 +- train_mimic/data/preprocess.py | 6 +- train_mimic/scripts/train.py | 89 ++++- 21 files changed, 1062 insertions(+), 35 deletions(-) create mode 100644 scripts/run/record_pico_motion.py create mode 100644 teleopit/configs/pico4_record.yaml create mode 100644 teleopit/recording/pico_motion.py create mode 100644 tests/test_pico_motion_recording.py diff --git a/AGENTS.md b/AGENTS.md index 30e48481..c81a10fa 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -57,10 +57,11 @@ teleopit/ # Core inference package ├── sim2real/ │ ├── mp/ # Process-isolated sim2real runtime and IPC │ └── hands/ # Optional LinkerHand L6 driver/mapper plugins -└── recording/ # HDF5Recorder +└── recording/ # HDF5Recorder and Pico motion NPZ recording helpers scripts/ -├── run_sim.py # Offline sim2sim pipeline -├── run_sim2real.py # G1 sim2real control; supports offline BVH playback and Pico4 +├── run/run_sim.py # Offline sim2sim pipeline +├── run/run_sim2real.py # G1 
sim2real control; supports offline BVH playback and Pico4 +├── run/record_pico_motion.py # Interactive Pico recording → G1 motion NPZ clips ├── render_sim.py # Render single BVH → 3 MuJoCo videos (mocap input, retarget, sim2sim) └── compute_ik_offsets.py # Compute IK quaternion offsets for new BVH formats train_mimic/ # Training package @@ -199,11 +200,15 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g - Final dataset outputs are shard-only: `data/datasets//train/shard_*.npz` and `data/datasets//val/shard_*.npz` - Each shard stores clip-aware metadata (`clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`); `MotionLib` loads only shard directories - `MotionLib` samples only valid center frames for the configured `window_steps`; default is `window_steps=[0]` +- `scripts/run/record_pico_motion.py` records Pico live body tracking as retargeted G1 motion NPZ clips in `data/pico_motion/clips/`; it opens a live `Retarget` viewer, uses terminal keys `R/S/D/N/Q`, stores semantic labels in filenames, and intentionally does not write per-clip JSON +- Build Pico-recorded clips into shards with `python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force`; at least two clips are required for non-empty train/val splits Quick reference: ```bash python train_mimic/scripts/data/build_dataset.py --spec train_mimic/configs/datasets/twist2.yaml +python scripts/run/record_pico_motion.py +python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force python train_mimic/scripts/train.py --motion_file data/datasets/twist2/train python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt --output policy.onnx --history_length 10 ``` diff --git a/README.md b/README.md index 8faf086f..0a1b2cf0 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,29 @@ python scripts/run/run_sim.py \ For sim2real, viewers are disabled by default. 
Add `viewers=retarget` to show the retargeted reference in an optional MuJoCo window. +## Pico Motion Recording + +Record many Pico clips as training-ready G1 motion NPZ files: + +```bash +pip install -e '.[pico4]' +python scripts/run/record_pico_motion.py +``` + +The recorder starts the Pico receiver and live Retarget viewer before waiting +for clip names, so preview keeps running while the terminal is idle. Enter a +semantic clip name, then use `R` to start, `S` to save, `D` to discard, `N` for +a new name, and `Q` to quit. Saved clips are written to +`data/pico_motion/clips/` using the semantic label in the filename, with no +sidecar JSON. + +Merge recorded clips into the standard shard dataset: + +```bash +python train_mimic/scripts/data/build_dataset.py \ + --spec data/pico_motion/pico_recorded.yaml --force +``` + ## Documentation Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Teleopit/)**, covering installation profiles, all tutorials, configuration reference, and architecture. @@ -70,6 +93,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. +- Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. ### v0.3.0 (2026-05-12) diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 7af73776..807d4159 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -32,7 +32,7 @@ This is sufficient for offline BVH playback and MuJoCo simulation. 
pip install -e '.[train]' ``` -Adds `rsl-rl-lib`, `mjlab`, `wandb`, and training dependencies. +Adds `rsl-rl-lib`, `mjlab`, `wandb`, `swanlab`, and training dependencies. ### Sim2Real (Hardware Deployment) diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index 53bd41d2..f2257af5 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -20,6 +20,33 @@ For custom dataset construction, read on. --- +## Record Pico Clips + +Use the interactive Pico recorder to create training-ready NPZ clips from live +body tracking: + +```bash +pip install -e '.[pico4]' +python scripts/run/record_pico_motion.py +``` + +The recorder starts the Pico receiver and live `Retarget` viewer before waiting +for clip names, so preview keeps running while the terminal is idle. Enter a +semantic clip name, then use `R` to start, `S` to save, `D` to discard, `N` to +enter a new name, and `Q` to quit. Saved clips go to +`data/pico_motion/clips/` as `_.npz`; no per-clip +JSON is written, so clips can be renamed or deleted manually. + +Build all recorded clips into the standard shard dataset: + +```bash +python train_mimic/scripts/data/build_dataset.py \ + --spec data/pico_motion/pico_recorded.yaml --force +``` + +Record at least two clips before building so both train and validation splits +can be populated. 
+ ## Custom Dataset Construction Data pipeline: `typed source YAML -> preprocess/filter -> shard-only training data` @@ -57,7 +84,7 @@ val_percent: 5 hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: first_frame_foot sources: - name: OMOMO_g1_GMR type: pkl @@ -77,7 +104,7 @@ sources: | `val_percent` | Validation split percentage (hash-based on clip_id) | | `hash_salt` | Optional split salt | | `preprocess.normalize_root_xy` | Normalize root body first-frame xy to origin | -| `preprocess.ground_align` | `none` / `clip_min_foot` | +| `preprocess.ground_align` | `none` / `first_frame_foot` | | `preprocess.min_frames` | Minimum clip length | | `preprocess.max_root_lin_vel` | Root linear velocity filter threshold | | `preprocess.min_peak_body_height` | Minimum peak body height | diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index c49d52f6..5fe8b972 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -73,7 +73,7 @@ torchrun \ **Notes:** - `--num_envs` is per-GPU in multi-GPU mode - `--num_envs` is also per-process in multi-node mode, so total environments scale with `world_size` -- Default logger is TensorBoard; pass `--wandb_project ` to enable W&B +- Default logger is TensorBoard. 
Use `--logger wandb` or `--logger swanlab` to select W&B or SwanLab; the project name defaults to `experiment_name` - `--motion_file` accepts only shard directories (containing `shard_*.npz` files) - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 5317b6d0..3bc1ac17 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -32,7 +32,7 @@ pip install -e . pip install -e '.[train]' ``` -额外安装 `rsl-rl-lib`、`mjlab`、`wandb` 等训练相关依赖。 +额外安装 `rsl-rl-lib`、`mjlab`、`wandb`、`swanlab` 等训练相关依赖。 ### Sim2Real(硬件部署) diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index 71802832..0ab0e366 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -20,6 +20,30 @@ python train_mimic/scripts/train.py --motion_file data/datasets/seed/train --- +## 录制 Pico clips + +使用交互式 Pico 录制脚本,从实时 body tracking 生成训练可用的 NPZ clips: + +```bash +pip install -e '.[pico4]' +python scripts/run/record_pico_motion.py +``` + +录制器会先启动 Pico receiver 和实时 `Retarget` viewer,再等待输入 clip 名; +因此终端空闲时预览仍会持续运行。输入动作语义名后,用 `R` 开始录制、`S` +保存、`D` 丢弃、`N` 输入新名字、`Q` 退出。保存的 clip 会写入 +`data/pico_motion/clips/`,文件名格式为 `_.npz`;不会写 +每段 clip 的 JSON,因此可以手动改名或删除。 + +将所有已录制 clips 构建为标准 shard 数据集: + +```bash +python train_mimic/scripts/data/build_dataset.py \ + --spec data/pico_motion/pico_recorded.yaml --force +``` + +构建前至少录制两段 clip,确保 train 和 validation split 都能生成。 + ## 自定义构建 
数据主线:`typed source YAML -> preprocess/filter -> shard-only 训练数据` @@ -57,7 +81,7 @@ val_percent: 5 hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: first_frame_foot sources: - name: OMOMO_g1_GMR type: pkl @@ -77,7 +101,7 @@ sources: | `val_percent` | 基于 `clip_id` hash 的验证集比例 | | `hash_salt` | 可选 split salt | | `preprocess.normalize_root_xy` | 是否把根 body 首帧 xy 平移到原点 | -| `preprocess.ground_align` | `none` / `clip_min_foot` | +| `preprocess.ground_align` | `none` / `first_frame_foot` | | `preprocess.min_frames` | clip 最短长度约束 | | `preprocess.max_root_lin_vel` / `min_peak_body_height` / `max_all_off_ground_s` | 基础过滤阈值 | | `sources[].name` | source 名称;生成 clip 中间产物时也作为 `clips//` 子目录名 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index dc9387ba..4ebc2521 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -73,7 +73,7 @@ torchrun \ **注意事项:** - 多卡模式下 `--num_envs` 为每张 GPU 的环境数量 - 多机模式下 `--num_envs` 也按每个进程计算,因此总环境数会随 `world_size` 线性增长 -- 默认日志工具为 TensorBoard;传入 `--wandb_project ` 可启用 W&B +- 默认日志工具为 TensorBoard。使用 `--logger wandb` 或 `--logger swanlab` 可选择 W&B 或 SwanLab;项目名默认使用 `experiment_name` - `--motion_file` 仅接受分片目录(包含 `shard_*.npz` 文件的目录) - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` diff --git a/pyproject.toml b/pyproject.toml index 924f10e7..ca01605b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ train = [ "rsl-rl-lib", "mjlab>=1.2.0", "wandb>=0.15.0", + "swanlab", "tqdm>=4.65.0", ] pico4 = [ diff --git a/scripts/run/record_pico_motion.py b/scripts/run/record_pico_motion.py new file mode 100644 index 00000000..380afbc4 --- /dev/null +++ b/scripts/run/record_pico_motion.py @@ -0,0 +1,373 @@ 
+"""Interactively record Pico-retargeted G1 motion clips as training NPZ files.""" + +from __future__ import annotations + +import logging +import threading +import time +from pathlib import Path + +import hydra +import numpy as np +from omegaconf import DictConfig + +from teleopit.constants import FULL_QPOS_DIM +from teleopit.inputs.pico4_provider import Pico4InputProvider +from teleopit.recording.pico_motion import ( + PicoDatasetSpec, + RecordingState, + ensure_pico_dataset_spec, + sanitize_clip_name, + unique_clip_path, + write_motion_clip_npz, +) +from teleopit.retargeting.core import RetargetingModule +from teleopit.runtime.assets import PROJECT_ROOT, UNITREE_G1_MJLAB_XML, missing_gmr_assets_message +from teleopit.runtime.common import cfg_get +from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader +from teleopit.sim.viewer_subprocess import start_robot_viewer + + +class RetargetPreview: + """Small wrapper around the existing MuJoCo retarget viewer subprocess.""" + + def __init__(self, xml_path: str | Path = UNITREE_G1_MJLAB_XML, *, enabled: bool = True) -> None: + self.enabled = bool(enabled) + self._proc = None + self._arr = None + self._alive = None + self._shutdown = None + self._qpos_len = FULL_QPOS_DIM + if not self.enabled: + return + + path = Path(xml_path).expanduser().resolve() + if not path.is_file(): + raise FileNotFoundError(missing_gmr_assets_message(path, label="G1 MuJoCo XML for retarget viewer")) + + import mujoco + + model = mujoco.MjModel.from_xml_path(str(path)) + self._qpos_len = int(model.nq) + self._proc, self._arr, self._alive, self._shutdown = start_robot_viewer( + str(path), + self._qpos_len, + True, + "Retarget", + 900, + 50, + ) + initial = np.zeros((self._qpos_len,), dtype=np.float64) + if self._qpos_len > 3: + initial[3] = 1.0 + self.update(initial) + + def update(self, qpos: np.ndarray) -> None: + if not self.enabled or self._arr is None: + return + qpos_arr = np.asarray(qpos, dtype=np.float64).reshape(-1) + if 
qpos_arr.shape[0] > self._qpos_len: + raise ValueError(f"retarget qpos has {qpos_arr.shape[0]} values, viewer accepts {self._qpos_len}") + out = np.zeros((self._qpos_len,), dtype=np.float64) + out[: qpos_arr.shape[0]] = qpos_arr + with self._arr.get_lock(): + self._arr[: self._qpos_len] = out.tolist() + + def close(self) -> None: + if self._shutdown is not None: + self._shutdown.set() + if self._proc is not None: + self._proc.join(timeout=3.0) + if self._proc.is_alive(): + self._proc.terminate() + + +class RetargetRecordingWorker: + """Continuously retarget Pico frames without blocking terminal input.""" + + def __init__( + self, + *, + provider: Pico4InputProvider, + retargeter: RetargetingModule, + preview: RetargetPreview, + state: RecordingState, + target_fps: int, + ) -> None: + if target_fps <= 0: + raise ValueError(f"target_fps must be > 0, got {target_fps}") + self._provider = provider + self._retargeter = retargeter + self._preview = preview + self._state = state + self._target_fps = int(target_fps) + self._stop_event = threading.Event() + self._thread = threading.Thread(target=self._run, name="pico_retarget_recorder", daemon=True) + self._last_error: BaseException | None = None + + def start(self) -> None: + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + self._thread.join(timeout=3.0) + + def raise_if_failed(self) -> None: + if self._last_error is not None: + raise RuntimeError(f"retarget worker failed: {self._last_error}") from self._last_error + + def _run(self) -> None: + last_seq: int | None = None + dt = 1.0 / float(self._target_fps) + next_tick = time.monotonic() + try: + while not self._stop_event.is_set(): + now = time.monotonic() + if now < next_tick: + self._stop_event.wait(timeout=min(next_tick - now, 0.01)) + continue + next_tick = now + dt + qpos, last_seq = _retarget_latest_frame(self._provider, self._retargeter, last_seq=last_seq) + if qpos is None: + continue + self._preview.update(qpos) + 
self._state.append(qpos) + except BaseException as exc: + self._last_error = exc + + +def _project_path(raw: str | Path) -> Path: + path = Path(raw).expanduser() + return path if path.is_absolute() else PROJECT_ROOT / path + + +def _build_provider(cfg: DictConfig) -> Pico4InputProvider: + input_cfg = cfg_get(cfg, "input", {}) or {} + return Pico4InputProvider( + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + timeout=float(cfg_get(input_cfg, "pico4_timeout", 60.0)), + buffer_size=int(cfg_get(input_cfg, "pico4_buffer_size", 60)), + timestamp_gap_reset_s=float(cfg_get(input_cfg, "pico4_timestamp_gap_reset_s", 0.15)), + pause_button=cfg_get(input_cfg, "pause_button", "A"), + pause_debounce_s=float(cfg_get(input_cfg, "pause_debounce_s", 0.25)), + bridge_host=str(cfg_get(input_cfg, "bridge_host", "0.0.0.0")), + bridge_port=int(cfg_get(input_cfg, "bridge_port", 63901)), + bridge_discovery=bool(cfg_get(input_cfg, "bridge_discovery", True)), + bridge_advertise_ip=cfg_get(input_cfg, "bridge_advertise_ip", None), + bridge_video=None, + bridge_video_enabled=False, + bridge_start_timeout=float(cfg_get(input_cfg, "bridge_start_timeout", 10.0)), + bridge_history_size=int(cfg_get(input_cfg, "bridge_history_size", 120)), + ) + + +def _build_retargeter(cfg: DictConfig, provider: Pico4InputProvider) -> RetargetingModule: + input_cfg = cfg_get(cfg, "input", {}) or {} + human_height = cfg_get(input_cfg, "human_height", 1.75) + return RetargetingModule( + robot_name=str(cfg_get(input_cfg, "robot_name", "unitree_g1")), + human_format=str(provider.human_format), + actual_human_height=float(human_height), + ) + + +def _prompt_clip_name() -> str | None: + while True: + raw = input("\nClip name (semantic label, or q to quit): ").strip() + if raw.lower() in {"q", "quit", "exit"}: + return None + try: + return sanitize_clip_name(raw) + except ValueError as exc: + print(f"Invalid clip name: {exc}") + + +def _retarget_latest_frame( + provider: Pico4InputProvider, + 
retargeter: RetargetingModule, + *, + last_seq: int | None, +) -> tuple[np.ndarray | None, int | None]: + if not provider.has_frame(): + return None, last_seq + frame, _timestamp_s, seq = provider.get_frame_packet() + seq = int(seq) + if last_seq is not None and seq == last_seq: + return None, last_seq + qpos = np.asarray(retargeter.retarget(frame), dtype=np.float64).reshape(-1) + if qpos.shape[0] != FULL_QPOS_DIM: + raise ValueError(f"retarget qpos must be {FULL_QPOS_DIM}D, got {qpos.shape[0]}") + if not np.isfinite(qpos).all(): + raise ValueError("retarget qpos contains NaN/Inf") + return qpos, seq + + +def _record_one_clip( + *, + cfg: DictConfig, + state: RecordingState, +) -> str: + record_cfg = cfg_get(cfg, "record", {}) or {} + target_fps = int(cfg_get(record_cfg, "target_fps", 30)) + min_frames = int(cfg_get(record_cfg, "min_frames", 30)) + max_duration_s = float(cfg_get(record_cfg, "max_duration_s", 0.0)) + output_dir = _project_path(str(cfg_get(record_cfg, "output_dir", "data/pico_motion/clips"))) + if target_fps <= 0: + raise ValueError(f"record.target_fps must be > 0, got {target_fps}") + if min_frames < 2: + raise ValueError(f"record.min_frames must be >= 2, got {min_frames}") + + keyboard = TerminalKeyboardReader() + if not keyboard.active: + keyboard.close() + raise RuntimeError("record_pico_motion.py requires an interactive TTY for keyboard controls") + + print("Controls: R=start S=save D=discard N=new name Q=quit") + clip_name, recording, frame_count, _elapsed = state.status() + if clip_name is None: + print("Preview is running. 
Press N to enter a clip name before recording.") + else: + print(f"Ready: {clip_name}") + + try: + while True: + for event in keyboard.poll(): + key = event.key.lower() + if key == "q": + _clip_name, was_recording, _frame_count, _elapsed = state.status() + if was_recording: + discarded_name, _ = state.discard() + print(f"\nDiscarded unsaved clip: {discarded_name}") + return "quit" + if key == "n": + _clip_name, was_recording, _frame_count, _elapsed = state.status() + if was_recording: + print("\nPress S to save or D to discard before changing clip name.") + continue + return "next" + if key == "r": + current_name, _recording, _frame_count, _elapsed = state.status() + if current_name is None: + print("\nPress N and enter a clip name before recording.") + continue + clip_name = state.start() + print(f"\nRecording: {clip_name}") + elif key == "d": + clip_name, frame_count = state.discard() + label = "" if clip_name is None else clip_name + print(f"\nDiscarded: {label} ({frame_count} frames)") + return "next" + elif key == "s": + clip_name, _was_recording, frames = state.snapshot() + if clip_name is None: + print("\nPress N and enter a clip name before saving.") + continue + if len(frames) < min_frames: + print(f"\nNeed at least {min_frames} frames before saving; current={len(frames)}") + continue + output_path = unique_clip_path(output_dir, clip_name) + write_motion_clip_npz(output_path, frames, fps=target_fps) + state.mark_saved() + duration_s = len(frames) / float(target_fps) + print(f"\nSaved: {output_path} ({len(frames)} frames, {duration_s:.2f}s)") + return "next" + + clip_name, recording, frame_count, elapsed = state.status() + if recording and clip_name is not None and max_duration_s > 0.0 and elapsed is not None: + if elapsed >= max_duration_s: + clip_name, _was_recording, frames = state.snapshot() + if clip_name is None: + raise RuntimeError("recording reached max_duration_s without a clip name") + if len(frames) < min_frames: + raise ValueError( + 
f"max_duration_s reached but only recorded {len(frames)} frames; " + f"min_frames={min_frames}" + ) + output_path = unique_clip_path(output_dir, clip_name) + write_motion_clip_npz(output_path, frames, fps=target_fps) + state.mark_saved() + print(f"\nSaved by max_duration_s: {output_path}") + return "next" + + time.sleep(0.02) + finally: + keyboard.close() + + +def _maybe_write_dataset_spec(cfg: DictConfig) -> None: + record_cfg = cfg_get(cfg, "record", {}) or {} + if not bool(cfg_get(record_cfg, "write_dataset_spec", True)): + return + output_dir = _project_path(str(cfg_get(record_cfg, "output_dir", "data/pico_motion/clips"))) + spec_path = _project_path(str(cfg_get(record_cfg, "dataset_spec_path", "data/pico_motion/pico_recorded.yaml"))) + dataset_name = str(cfg_get(record_cfg, "dataset_name", "pico_recorded")) + target_fps = int(cfg_get(record_cfg, "target_fps", 30)) + overwrite = bool(cfg_get(record_cfg, "overwrite_dataset_spec", False)) + path = ensure_pico_dataset_spec( + spec_path, + output_dir, + spec=PicoDatasetSpec(dataset_name=dataset_name, target_fps=target_fps), + overwrite=overwrite, + ) + print(f"Dataset spec: {path}") + + +def _configure_logging(cfg: DictConfig) -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") + record_cfg = cfg_get(cfg, "record", {}) or {} + if bool(cfg_get(record_cfg, "quiet_logs", True)): + logging.getLogger("pico_bridge").setLevel(logging.WARNING) + logging.getLogger("teleopit.inputs.pico4_provider").setLevel(logging.ERROR) + + +@hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="pico4_record") +def main(cfg: DictConfig) -> None: + _configure_logging(cfg) + _maybe_write_dataset_spec(cfg) + + print("Starting Pico receiver; waiting for body tracking...") + provider = _build_provider(cfg) + preview = RetargetPreview(enabled=False) + try: + preview.close() + preview = RetargetPreview(enabled=bool(cfg_get(cfg_get(cfg, "record", {}) or {}, "viewer_enabled", 
True))) + retargeter = _build_retargeter(cfg, provider) + print("Pico recorder is ready. Retarget viewer will update after the first frame.") + record_cfg = cfg_get(cfg, "record", {}) or {} + target_fps = int(cfg_get(record_cfg, "target_fps", 30)) + state = RecordingState() + worker = RetargetRecordingWorker( + provider=provider, + retargeter=retargeter, + preview=preview, + state=state, + target_fps=target_fps, + ) + worker.start() + try: + clip_name = _prompt_clip_name() + if clip_name is not None: + state.set_clip_name(clip_name) + while True: + worker.raise_if_failed() + if clip_name is None: + break + command = _record_one_clip(cfg=cfg, state=state) + worker.raise_if_failed() + if command == "quit": + break + clip_name = _prompt_clip_name() + if clip_name is not None: + state.set_clip_name(clip_name) + finally: + worker.stop() + worker.raise_if_failed() + except KeyboardInterrupt: + print("\nInterrupted; unsaved in-progress clip was discarded.") + finally: + preview.close() + provider.close() + + +if __name__ == "__main__": + main() diff --git a/teleopit/configs/pico4_record.yaml b/teleopit/configs/pico4_record.yaml new file mode 100644 index 00000000..684a380e --- /dev/null +++ b/teleopit/configs/pico4_record.yaml @@ -0,0 +1,24 @@ +defaults: + - input: pico4 + - _self_ + +input: + # Recording does not need aggressive realtime timeline reset warnings. + pico4_timestamp_gap_reset_s: 0.5 + +record: + output_dir: data/pico_motion/clips + target_fps: 30 + min_frames: 30 + dataset_name: pico_recorded + dataset_spec_path: data/pico_motion/pico_recorded.yaml + write_dataset_spec: true + overwrite_dataset_spec: false + viewer_enabled: true + quiet_logs: true + # Optional safety cap in seconds. 0 means record until S/D/Q. + max_duration_s: 0.0 + +hydra: + run: + dir: . 
diff --git a/teleopit/recording/__init__.py b/teleopit/recording/__init__.py index 82c5b0c5..3b6edc87 100644 --- a/teleopit/recording/__init__.py +++ b/teleopit/recording/__init__.py @@ -1,3 +1,21 @@ from teleopit.recording.hdf5_recorder import HDF5Recorder +from teleopit.recording.pico_motion import ( + PicoDatasetSpec, + RecordingState, + ensure_pico_dataset_spec, + qpos_sequence_to_motion_clip, + sanitize_clip_name, + unique_clip_path, + write_motion_clip_npz, +) -__all__ = ["HDF5Recorder"] +__all__ = [ + "HDF5Recorder", + "PicoDatasetSpec", + "RecordingState", + "ensure_pico_dataset_spec", + "qpos_sequence_to_motion_clip", + "sanitize_clip_name", + "unique_clip_path", + "write_motion_clip_npz", +] diff --git a/teleopit/recording/pico_motion.py b/teleopit/recording/pico_motion.py new file mode 100644 index 00000000..90968a71 --- /dev/null +++ b/teleopit/recording/pico_motion.py @@ -0,0 +1,233 @@ +"""Helpers for recording Pico-retargeted G1 motion clips.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +import re +import threading +import time +from typing import Any, Iterable + +import numpy as np + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.runtime.assets import PROJECT_ROOT +from train_mimic.data.dataset_lib import inspect_clip_dict +from train_mimic.data.motion_fk import MotionFkExtractor, compute_body_velocities +from train_mimic.scripts.convert_pkl_to_npz import _MJLAB_G1_BODY_NAMES + + +@dataclass(frozen=True) +class PicoDatasetSpec: + """Dataset spec defaults for Pico-recorded NPZ clips.""" + + dataset_name: str = "pico_recorded" + target_fps: int = 30 + val_percent: int = 5 + source_name: str = "pico_clips" + + +class RecordingState: + """Thread-safe state shared by terminal UI and retarget worker.""" + + def __init__(self, clip_name: str | None = None) -> None: + self._lock = threading.Lock() + self.clip_name = clip_name + 
self.recording = False + self.qpos_buffer: list[np.ndarray] = [] + self.record_start_s: float | None = None + + def set_clip_name(self, clip_name: str) -> None: + with self._lock: + if self.recording: + raise RuntimeError("cannot change clip name while recording") + self.clip_name = clip_name + self.qpos_buffer.clear() + self.record_start_s = None + + def start(self) -> str: + with self._lock: + if self.clip_name is None: + raise RuntimeError("clip name must be set before recording") + if self.recording: + return self.clip_name + self.qpos_buffer.clear() + self.recording = True + self.record_start_s = time.monotonic() + return self.clip_name + + def discard(self) -> tuple[str | None, int]: + with self._lock: + clip_name = self.clip_name + frame_count = len(self.qpos_buffer) + self.recording = False + self.qpos_buffer.clear() + self.record_start_s = None + return clip_name, frame_count + + def snapshot(self) -> tuple[str | None, bool, list[np.ndarray]]: + with self._lock: + clip_name = self.clip_name + recording = self.recording + frames = [frame.copy() for frame in self.qpos_buffer] + return clip_name, recording, frames + + def mark_saved(self) -> None: + with self._lock: + self.recording = False + self.qpos_buffer.clear() + self.record_start_s = None + + def append(self, qpos: np.ndarray) -> None: + with self._lock: + if self.recording: + self.qpos_buffer.append(qpos.copy()) + + def status(self) -> tuple[str | None, bool, int, float | None]: + with self._lock: + elapsed = None if self.record_start_s is None else time.monotonic() - self.record_start_s + return self.clip_name, self.recording, len(self.qpos_buffer), elapsed + + +def sanitize_clip_name(raw_name: str) -> str: + """Return a filesystem-friendly semantic clip name.""" + name = raw_name.strip().lower() + name = re.sub(r"\s+", "_", name) + name = re.sub(r"[^a-z0-9_.-]+", "_", name) + name = re.sub(r"_+", "_", name).strip("._-") + if not name: + raise ValueError("clip name must contain at least one letter or 
digit") + return name + + +def timestamp_suffix(now: datetime | None = None) -> str: + current = now or datetime.now() + return current.strftime("%Y%m%d_%H%M%S") + + +def unique_clip_path(output_dir: str | Path, clip_name: str, *, now: datetime | None = None) -> Path: + """Build a non-overwriting NPZ path for one semantic clip.""" + out_dir = Path(output_dir).expanduser() + safe_name = sanitize_clip_name(clip_name) + stem = f"{safe_name}_{timestamp_suffix(now)}" + candidate = out_dir / f"{stem}.npz" + if not candidate.exists(): + return candidate + for index in range(1, 1000): + candidate = out_dir / f"{stem}_{index:03d}.npz" + if not candidate.exists(): + return candidate + raise RuntimeError(f"could not allocate a unique clip path in {out_dir}") + + +def _display_path(path: Path, *, project_root: Path = PROJECT_ROOT) -> str: + try: + return path.resolve().relative_to(project_root.resolve()).as_posix() + except ValueError: + return str(path) + + +def ensure_pico_dataset_spec( + spec_path: str | Path, + clips_dir: str | Path, + *, + spec: PicoDatasetSpec = PicoDatasetSpec(), + overwrite: bool = False, +) -> Path: + """Create the dataset YAML spec used to merge recorded Pico clips. + + Existing specs are preserved by default so hand-edited settings are not lost. 
+ """ + path = Path(spec_path).expanduser() + if path.exists() and not overwrite: + return path + + clips_path = Path(clips_dir).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + clips_display = _display_path(clips_path) + content = ( + f"name: {spec.dataset_name}\n" + f"target_fps: {int(spec.target_fps)}\n" + f"val_percent: {int(spec.val_percent)}\n" + 'hash_salt: ""\n' + "preprocess:\n" + " normalize_root_xy: true\n" + " ground_align: first_frame_foot\n" + "sources:\n" + f" - name: {spec.source_name}\n" + " type: npz\n" + f" input: {clips_display}\n" + ) + path.write_text(content, encoding="utf-8") + return path + + +def qpos_sequence_to_motion_clip( + qpos_sequence: Iterable[np.ndarray] | np.ndarray, + *, + fps: int, + extractor: Any | None = None, + body_names: list[str] | None = None, +) -> dict[str, Any]: + """Convert retargeted G1 qpos frames into a standard training motion clip.""" + qpos = np.asarray(list(qpos_sequence) if not isinstance(qpos_sequence, np.ndarray) else qpos_sequence, dtype=np.float32) + if qpos.ndim != 2 or qpos.shape[1] != FULL_QPOS_DIM: + raise ValueError(f"qpos_sequence must have shape (T,{FULL_QPOS_DIM}), got {qpos.shape}") + if qpos.shape[0] < 2: + raise ValueError("qpos_sequence must contain at least 2 frames") + if int(fps) <= 0: + raise ValueError(f"fps must be > 0, got {fps}") + if not np.isfinite(qpos).all(): + raise ValueError("qpos_sequence contains NaN/Inf") + + names = list(_MJLAB_G1_BODY_NAMES if body_names is None else body_names) + root_pos = qpos[:, 0:3].astype(np.float32, copy=False) + root_quat_wxyz = qpos[:, 3:7].astype(np.float32, copy=False) + joint_pos = qpos[:, ROOT_DIM:ROOT_DIM + NUM_JOINTS].astype(np.float32, copy=False) + if joint_pos.shape[1] != NUM_JOINTS: + raise ValueError(f"joint_pos must have {NUM_JOINTS} columns, got {joint_pos.shape}") + + dt = 1.0 / float(fps) + joint_vel = np.gradient(joint_pos, dt, axis=0).astype(np.float32) + + fk_extractor = extractor or MotionFkExtractor() + 
body_pos_w, body_quat_w = fk_extractor.extract(root_pos, root_quat_wxyz, joint_pos, names) + body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) + + clip = { + "fps": int(fps), + "joint_pos": joint_pos.astype(np.float32, copy=False), + "joint_vel": joint_vel.astype(np.float32, copy=False), + "body_pos_w": np.asarray(body_pos_w, dtype=np.float32), + "body_quat_w": np.asarray(body_quat_w, dtype=np.float32), + "body_lin_vel_w": np.asarray(body_lin_vel_w, dtype=np.float32), + "body_ang_vel_w": np.asarray(body_ang_vel_w, dtype=np.float32), + "body_names": np.asarray(names, dtype=str), + } + inspect_clip_dict(clip) + return clip + + +def write_motion_clip_npz( + output_path: str | Path, + qpos_sequence: Iterable[np.ndarray] | np.ndarray, + *, + fps: int, + extractor: Any | None = None, +) -> Path: + """Write a retargeted qpos sequence as an atomically replaced motion NPZ.""" + path = Path(output_path).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + clip = qpos_sequence_to_motion_clip(qpos_sequence, fps=fps, extractor=extractor) + tmp_path = path.with_name(f"{path.name}.tmp") + try: + with tmp_path.open("wb") as handle: + np.savez(handle, **clip) + path_tmp = tmp_path + path_tmp.replace(path) + finally: + if tmp_path.exists(): + tmp_path.unlink() + return path diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 0aebf006..f5eebe82 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -590,7 +590,7 @@ def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: jobs=1, preprocess=DatasetPreprocessSpec( normalize_root_xy=True, - ground_align="clip_min_foot", + ground_align="first_frame_foot", ), ) clip = np.load(output_dir_2 / "clip_a.npz", allow_pickle=True) @@ -600,7 +600,7 @@ def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: right_idx = body_names.index("right_ankle_roll_link") assert np.allclose(clip["body_pos_w"][0, pelvis_idx, :2], 
0.0) foot_z = clip["body_pos_w"][:, [left_idx, right_idx], 2] - assert np.isclose(float(np.min(foot_z)), 0.0) + assert np.isclose(float(np.min(foot_z[0])), 0.0) def test_convert_source_to_npz_clips_skips_all_off_ground_clips_before_ground_align(tmp_path: Path) -> None: @@ -630,7 +630,7 @@ def test_convert_source_to_npz_clips_skips_all_off_ground_clips_before_ground_al output_dir, jobs=1, preprocess=dataset_builder.DatasetPreprocessSpec( - ground_align="clip_min_foot", + ground_align="first_frame_foot", max_all_off_ground_s=0.05, off_ground_height=0.08, ), @@ -652,7 +652,7 @@ def _unexpected_run_conversion_tasks(*_args, **_kwargs): output_dir, jobs=1, preprocess=dataset_builder.DatasetPreprocessSpec( - ground_align="clip_min_foot", + ground_align="first_frame_foot", max_all_off_ground_s=0.05, off_ground_height=0.08, ), @@ -788,7 +788,7 @@ def _convert(path: str, **_kwargs): "train", preprocess=dataset_builder.DatasetPreprocessSpec( normalize_root_xy=True, - ground_align="clip_min_foot", + ground_align="first_frame_foot", min_frames=22, ), ) diff --git a/tests/test_pico_motion_recording.py b/tests/test_pico_motion_recording.py new file mode 100644 index 00000000..cfa540fd --- /dev/null +++ b/tests/test_pico_motion_recording.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from datetime import datetime +from pathlib import Path + +import numpy as np +import pytest + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS +from teleopit.recording.pico_motion import ( + PicoDatasetSpec, + RecordingState, + ensure_pico_dataset_spec, + qpos_sequence_to_motion_clip, + sanitize_clip_name, + unique_clip_path, + write_motion_clip_npz, +) +from train_mimic.data.dataset_builder import load_dataset_spec +from train_mimic.data.dataset_lib import inspect_npz + + +class _FakeFkExtractor: + num_actions = NUM_JOINTS + + def extract( + self, + root_pos: np.ndarray, + root_quat_wxyz: np.ndarray, + joint_pos: np.ndarray, + body_names: list[str], + ) -> tuple[np.ndarray, 
np.ndarray]: + del root_quat_wxyz, joint_pos + offsets = np.zeros((len(body_names), 3), dtype=np.float32) + offsets[:, 2] = np.linspace(0.0, 0.2, len(body_names), dtype=np.float32) + body_pos_w = root_pos[:, None, :] + offsets[None, :, :] + body_quat_w = np.zeros((root_pos.shape[0], len(body_names), 4), dtype=np.float32) + body_quat_w[..., 0] = 1.0 + return body_pos_w.astype(np.float32), body_quat_w + + +def _qpos_sequence(num_frames: int = 4) -> np.ndarray: + qpos = np.zeros((num_frames, FULL_QPOS_DIM), dtype=np.float32) + qpos[:, 0] = np.linspace(0.0, 0.3, num_frames, dtype=np.float32) + qpos[:, 2] = 0.76 + qpos[:, 3] = 1.0 + qpos[:, 7] = np.linspace(0.0, 0.2, num_frames, dtype=np.float32) + return qpos + + +def test_sanitize_clip_name_keeps_semantic_label_filesystem_safe() -> None: + assert sanitize_clip_name(" Walk Forward Slow ") == "walk_forward_slow" + assert sanitize_clip_name("turn-left/fast") == "turn-left_fast" + with pytest.raises(ValueError, match="clip name"): + sanitize_clip_name("...") + + +def test_unique_clip_path_adds_timestamp_and_avoids_overwrite(tmp_path: Path) -> None: + now = datetime(2026, 6, 10, 14, 22, 33) + first = unique_clip_path(tmp_path, "walk forward", now=now) + assert first.name == "walk_forward_20260610_142233.npz" + first.write_bytes(b"placeholder") + + second = unique_clip_path(tmp_path, "walk forward", now=now) + assert second.name == "walk_forward_20260610_142233_001.npz" + + +def test_qpos_sequence_to_motion_clip_writes_standard_npz_fields() -> None: + clip = qpos_sequence_to_motion_clip(_qpos_sequence(), fps=30, extractor=_FakeFkExtractor()) + assert int(clip["fps"]) == 30 + assert clip["joint_pos"].shape == (4, NUM_JOINTS) + assert clip["joint_vel"].shape == (4, NUM_JOINTS) + assert clip["body_pos_w"].shape[0] == 4 + assert clip["body_quat_w"].shape[-1] == 4 + + +def test_qpos_sequence_to_motion_clip_rejects_invalid_input() -> None: + with pytest.raises(ValueError, match=r"shape"): + 
qpos_sequence_to_motion_clip(np.zeros((3, FULL_QPOS_DIM - 1)), fps=30, extractor=_FakeFkExtractor()) + + bad = _qpos_sequence() + bad[0, 0] = np.nan + with pytest.raises(ValueError, match="NaN/Inf"): + qpos_sequence_to_motion_clip(bad, fps=30, extractor=_FakeFkExtractor()) + + with pytest.raises(ValueError, match="at least 2"): + qpos_sequence_to_motion_clip(_qpos_sequence(1), fps=30, extractor=_FakeFkExtractor()) + + +def test_write_motion_clip_npz_and_inspect(tmp_path: Path) -> None: + out = tmp_path / "clip.npz" + write_motion_clip_npz(out, _qpos_sequence(), fps=30, extractor=_FakeFkExtractor()) + assert out.exists() + meta = inspect_npz(out) + assert meta.fps == 30 + assert meta.num_frames == 4 + + +def test_ensure_pico_dataset_spec_preserves_existing_file(tmp_path: Path) -> None: + clips_dir = tmp_path / "clips" + clips_dir.mkdir() + spec_path = tmp_path / "pico_recorded.yaml" + + ensure_pico_dataset_spec( + spec_path, + clips_dir, + spec=PicoDatasetSpec(dataset_name="pico_recorded", target_fps=30), + ) + spec = load_dataset_spec(spec_path) + assert spec.name == "pico_recorded" + assert spec.target_fps == 30 + assert spec.sources[0].type == "npz" + assert spec.sources[0].input == str(clips_dir) + + spec_path.write_text("name: hand_edited\n", encoding="utf-8") + ensure_pico_dataset_spec(spec_path, clips_dir) + assert spec_path.read_text(encoding="utf-8") == "name: hand_edited\n" + + +def test_recording_state_snapshot_does_not_clear_buffer() -> None: + state = RecordingState("walk") + state.start() + state.append(_qpos_sequence(2)[0]) + state.append(_qpos_sequence(2)[1]) + + clip_name, recording, frames = state.snapshot() + assert clip_name == "walk" + assert recording is True + assert len(frames) == 2 + assert state.status()[2] == 2 + + state.mark_saved() + assert state.status()[1] is False + assert state.status()[2] == 0 + + +def test_recording_state_discard_clears_buffer() -> None: + state = RecordingState("turn") + state.start() + 
state.append(_qpos_sequence(2)[0]) + + clip_name, frame_count = state.discard() + assert clip_name == "turn" + assert frame_count == 1 + assert state.status()[1] is False + assert state.status()[2] == 0 diff --git a/tests/test_train_script.py b/tests/test_train_script.py index fba33b7b..0dbce3bf 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -3,6 +3,8 @@ from __future__ import annotations import argparse +import sys +import types from pathlib import Path import pytest @@ -27,7 +29,7 @@ def _args(**overrides: object) -> argparse.Namespace: "num_envs": 1024, "max_iterations": 10, "seed": 42, - "wandb_project": None, + "logger": "tensorboard", "experiment_name": None, "motion_file": "data/datasets/twist2/train", "resume": None, @@ -43,6 +45,18 @@ def _args(**overrides: object) -> argparse.Namespace: class TestTrainLauncherHelpers: + def test_parse_args_defaults_to_tensorboard_logger(self) -> None: + args = train.parse_args([]) + assert args.logger == "tensorboard" + + def test_parse_args_accepts_logger_choice(self) -> None: + args = train.parse_args(["--logger", "swanlab"]) + assert args.logger == "swanlab" + + def test_parse_args_rejects_removed_wandb_project(self) -> None: + with pytest.raises(SystemExit): + train.parse_args(["--wandb_project", "teleopit"]) + def test_parse_args_with_gpu_ids(self) -> None: args = train.parse_args(["--gpu_ids", "0", "2", "3", "--master_port", "29600"]) assert args.gpu_ids == [0, 2, 3] @@ -114,6 +128,77 @@ def test_resolve_worker_seed_defaults_to_base_seed_without_rank(self) -> None: def test_resolve_worker_seed_ignores_rank_outside_distributed_mode(self) -> None: assert train._resolve_worker_seed(42, env={"WORLD_SIZE": "1", "RANK": "3"}) == 42 + def test_configure_tensorboard_logger(self) -> None: + agent_cfg = types.SimpleNamespace(logger="wandb", experiment_name="exp") + env_cfg = types.SimpleNamespace() + + active = train._configure_experiment_logger( + logger_name="tensorboard", + agent_cfg=agent_cfg, 
+ env_cfg=env_cfg, + log_dir="/tmp/run", + ) + + assert active is False + assert agent_cfg.logger == "tensorboard" + + def test_configure_wandb_logger_uses_experiment_name_as_project(self) -> None: + agent_cfg = types.SimpleNamespace(logger="tensorboard", experiment_name="exp") + env_cfg = types.SimpleNamespace() + + active = train._configure_experiment_logger( + logger_name="wandb", + agent_cfg=agent_cfg, + env_cfg=env_cfg, + log_dir="/tmp/run", + ) + + assert active is False + assert agent_cfg.logger == "wandb" + assert agent_cfg.wandb_project == "exp" + + def test_configure_swanlab_logger_syncs_tensorboard(self, monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[str, object]] = [] + fake_swanlab = types.SimpleNamespace( + init=lambda **kwargs: calls.append(("init", kwargs)), + sync_tensorboard_torch=lambda **kwargs: calls.append(("sync", kwargs)), + ) + monkeypatch.setitem(sys.modules, "swanlab", fake_swanlab) + monkeypatch.setenv("RANK", "0") + agent_cfg = types.SimpleNamespace(logger="wandb", experiment_name="exp", max_iterations=10) + env_cfg = types.SimpleNamespace( + commands={"motion": types.SimpleNamespace(motion_file="data/train", sampling_mode="uniform")}, + scene=types.SimpleNamespace(num_envs=64), + ) + + active = train._configure_experiment_logger( + logger_name="swanlab", + agent_cfg=agent_cfg, + env_cfg=env_cfg, + log_dir="/tmp/2026-01-01_00-00-00", + ) + + assert active is True + assert agent_cfg.logger == "tensorboard" + assert calls == [ + ( + "init", + { + "project": "exp", + "name": "2026-01-01_00-00-00", + "log_dir": "/tmp/2026-01-01_00-00-00", + "config": { + "experiment_name": "exp", + "motion_file": "data/train", + "num_envs": 64, + "max_iterations": 10, + "sampling_mode": "uniform", + }, + }, + ), + ("sync", {"types": ["scalar", "scalars", "image", "text"]}), + ] + def test_main_uses_launcher_branch(self, monkeypatch: pytest.MonkeyPatch) -> None: called: dict[str, object] = {} diff --git a/train_mimic/configs/datasets/seed.yaml 
b/train_mimic/configs/datasets/seed.yaml index 45d0c40e..c18c5f9a 100644 --- a/train_mimic/configs/datasets/seed.yaml +++ b/train_mimic/configs/datasets/seed.yaml @@ -4,7 +4,7 @@ val_percent: 5 hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: first_frame_foot min_frames: 22 sources: - name: seed_full diff --git a/train_mimic/configs/datasets/seed_clean.yaml b/train_mimic/configs/datasets/seed_clean.yaml index 8a05a26c..80d3021a 100644 --- a/train_mimic/configs/datasets/seed_clean.yaml +++ b/train_mimic/configs/datasets/seed_clean.yaml @@ -4,7 +4,7 @@ val_percent: 5 hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: first_frame_foot min_frames: 22 sources: - name: seed_full diff --git a/train_mimic/configs/datasets/twist2.yaml b/train_mimic/configs/datasets/twist2.yaml index 0c819271..e178b6f5 100644 --- a/train_mimic/configs/datasets/twist2.yaml +++ b/train_mimic/configs/datasets/twist2.yaml @@ -4,7 +4,7 @@ val_percent: 5 hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: first_frame_foot sources: - name: OMOMO_g1_GMR type: pkl diff --git a/train_mimic/data/preprocess.py b/train_mimic/data/preprocess.py index ed6e6b35..f6558805 100644 --- a/train_mimic/data/preprocess.py +++ b/train_mimic/data/preprocess.py @@ -9,7 +9,7 @@ from train_mimic.data.dataset_lib import inspect_clip_dict -GROUND_ALIGN_MODES = {"none", "clip_min_foot"} +GROUND_ALIGN_MODES = {"none", "first_frame_foot"} @dataclass(frozen=True) @@ -169,10 +169,10 @@ def preprocess_clip_dict( f"{spec.max_feet_off_ground_s:.3f}s" ) - if spec.ground_align == "clip_min_foot": + if spec.ground_align == "first_frame_foot": assert foot_indices is not None foot_z = body_pos_w[:, foot_indices, 2] - body_pos_w[..., 2] -= float(np.min(foot_z)) + body_pos_w[..., 2] -= float(np.min(foot_z[0])) if spec.min_peak_body_height is not None: peak_height = float(np.max(body_pos_w[:, :, 2])) 
diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 1e38c892..98e3353d 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -11,11 +11,17 @@ --num_envs 64 --max_iterations 100 \ --motion_file data/datasets/twist2/train - # With wandb logging + # With W&B logging python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 30000 \ --motion_file data/datasets/twist2/train \ - --wandb_project teleopit + --logger wandb + + # With SwanLab logging + python train_mimic/scripts/train.py \ + --num_envs 4096 --max_iterations 30000 \ + --motion_file data/datasets/twist2/train \ + --logger swanlab # Resume for additional iterations python train_mimic/scripts/train.py \ @@ -59,8 +65,13 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ), ) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--wandb_project", type=str, default=None, - help="Enable wandb and set project name (default: tensorboard)") + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb", "swanlab"], + help="Experiment logger backend (default: tensorboard)", + ) parser.add_argument("--experiment_name", type=str, default=None) parser.add_argument("--motion_file", type=str, default=None, help="Shard directory path containing shard_*.npz files") @@ -227,6 +238,58 @@ def _resolve_worker_seed(base_seed: int, env: dict[str, str] | None = None) -> i return base_seed + global_rank * 100003 +def _is_main_process(env: dict[str, str] | None = None) -> bool: + runtime_env = os.environ if env is None else env + return int(runtime_env.get("RANK", "0")) == 0 + + +def _configure_experiment_logger( + *, + logger_name: str, + agent_cfg: Any, + env_cfg: Any, + log_dir: str, +) -> bool: + """Configure the training logger and return whether SwanLab was started.""" + if logger_name == "tensorboard": + agent_cfg.logger = "tensorboard" + return False + + if logger_name == 
"wandb": + agent_cfg.logger = "wandb" + agent_cfg.wandb_project = agent_cfg.experiment_name + return False + + if logger_name != "swanlab": + raise ValueError(f"Unsupported logger '{logger_name}'") + + agent_cfg.logger = "tensorboard" + if not _is_main_process(): + return False + + try: + import swanlab + except ModuleNotFoundError: + raise ModuleNotFoundError( + "swanlab package is required for --logger swanlab. Install it with `pip install swanlab`." + ) from None + + swanlab.init( + project=agent_cfg.experiment_name, + name=os.path.basename(log_dir), + log_dir=log_dir, + config={ + "experiment_name": agent_cfg.experiment_name, + "motion_file": env_cfg.commands["motion"].motion_file, + "num_envs": env_cfg.scene.num_envs, + "max_iterations": agent_cfg.max_iterations, + "sampling_mode": env_cfg.commands["motion"].sampling_mode, + }, + ) + swanlab.sync_tensorboard_torch(types=["scalar", "scalars", "image", "text"]) + return True + + def _launch_multi_gpu(args: argparse.Namespace, argv: Sequence[str]) -> None: _validate_multi_gpu_args(args) command = _build_torchrun_command(args, argv) @@ -294,10 +357,6 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: load_runner_cls=load_runner_cls, ) - # Default to tensorboard (mjlab defaults to wandb) - if args.wandb_project is None: - agent_cfg.logger = "tensorboard" - # CLI overrides env_cfg.seed = _resolve_worker_seed(args.seed) if args.num_envs is not None: @@ -311,9 +370,6 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: agent_cfg.max_iterations = args.max_iterations if args.experiment_name is not None: agent_cfg.experiment_name = args.experiment_name - if args.wandb_project is not None: - agent_cfg.logger = "wandb" - agent_cfg.wandb_project = args.wandb_project device = _resolve_device(args, torch) @@ -322,6 +378,12 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: os.makedirs(log_root, exist_ok=True) log_dir = os.path.join(log_root, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 
os.makedirs(log_dir, exist_ok=True) + swanlab_active = _configure_experiment_logger( + logger_name=args.logger, + agent_cfg=agent_cfg, + env_cfg=env_cfg, + log_dir=log_dir, + ) # render_mode only needed for video recording render_mode = "rgb_array" if args.video else None @@ -363,6 +425,11 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: if env is not None: with contextlib.suppress(Exception): env.close() + if swanlab_active: + with contextlib.suppress(Exception): + import swanlab + + swanlab.finish() _destroy_process_group(torch) signal.signal(signal.SIGINT, old_sigint) signal.signal(signal.SIGTERM, old_sigterm) From ffbc04e7b0cbc99ddd997d6c871fbfc3def696ec Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 11 Jun 2026 12:01:54 +0800 Subject: [PATCH 072/122] Add adaptive motion sampling --- AGENTS.md | 2 +- README.md | 1 + tests/test_motion_sampling.py | 84 ++++++- tests/test_task_registry.py | 2 +- train_mimic/tasks/tracking/config/env.py | 2 +- train_mimic/tasks/tracking/mdp/commands.py | 245 ++++++++++++++++++++- train_mimic/tasks/tracking/rl/runner.py | 23 ++ 7 files changed, 351 insertions(+), 8 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index c81a10fa..16e7c6a7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -190,7 +190,7 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g - Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - 166D `velcmd_history` observation, dual-input ONNX export -- Training env uses `sampling_mode="uniform"` +- Training env uses `sampling_mode="adaptive"` - Playback/benchmark use `play=True`, which switches motion sampling to `start` - `window_steps=[0]` - `save_onnx.py` exports dual-input TemporalCNN ONNX diff --git a/README.md b/README.md index 0a1b2cf0..c6a97f1b 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Switched default `vr_hand_pose` to a low-latency somehand path with 60 
Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. +- General-Tracking-G1 training now defaults to clip-local adaptive motion sampling with checkpointed sampler state. ### v0.3.0 (2026-05-12) diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index ca68090b..ecaeb51d 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -1,12 +1,13 @@ from __future__ import annotations from pathlib import Path +from types import SimpleNamespace import numpy as np import torch from train_mimic.data.dataset_lib import merge_clip_dicts -from train_mimic.tasks.tracking.mdp.commands import MotionLib +from train_mimic.tasks.tracking.mdp.commands import MotionCommand, MotionLib def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: @@ -40,9 +41,14 @@ def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: } -def _write_shard_dir(path: Path, clip_dicts: list[dict[str, object]]) -> Path: +def _write_shard_dir( + path: Path, + clip_dicts: list[dict[str, object]], + *, + weights: list[float] | None = None, +) -> Path: path.mkdir(parents=True, exist_ok=True) - merge_clip_dicts(clip_dicts, path / "shard_000.npz") + merge_clip_dicts(clip_dicts, path / "shard_000.npz", weights=weights) return path @@ -143,3 +149,75 @@ def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_pat assert torch.allclose(motion.sample_start_times(motion_ids), torch.tensor([1.0])) assert torch.allclose(motion.clip_sample_end_s[motion_ids], torch.tensor([3.0])) + + +def test_motion_lib_adaptive_bins_are_clip_local(tmp_path: Path) -> None: + motion_path = _write_shard_dir( + tmp_path / "motion_adaptive_bins", + [_clip_dict(num_frames=6), _clip_dict(num_frames=8)], + weights=[1.0, 3.0], 
+ ) + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + num_bins = motion.prepare_adaptive_sampling(bin_size_frames=2) + + assert num_bins == 7 + assert motion.adaptive_bin_clip_ids.tolist() == [0, 0, 0, 1, 1, 1, 1] + assert motion.adaptive_bin_start_frames.tolist() == [0, 2, 4, 0, 2, 4, 6] + assert motion.adaptive_bin_end_frames.tolist() == [2, 4, 5, 2, 4, 6, 7] + + clip0_mass = motion.adaptive_bin_base_probs[:3].sum() + clip1_mass = motion.adaptive_bin_base_probs[3:].sum() + assert torch.allclose(clip0_mass, torch.tensor(0.25), atol=1e-6) + assert torch.allclose(clip1_mass, torch.tensor(0.75), atol=1e-6) + + +def test_motion_lib_adaptive_sampling_never_crosses_clip_boundaries(tmp_path: Path) -> None: + motion_path = _write_shard_dir( + tmp_path / "motion_adaptive_sample", + [_clip_dict(num_frames=6), _clip_dict(num_frames=8)], + ) + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + motion.prepare_adaptive_sampling(bin_size_frames=2) + + motion_ids, motion_times, bins = motion.sample_adaptive_times( + motion.adaptive_bin_base_probs, + 512, + ) + sampled_frames = motion_times * motion.clip_fps[motion_ids] + + assert torch.all(sampled_frames >= motion.clip_sample_starts[motion_ids]) + assert torch.all(sampled_frames < motion.clip_sample_ends[motion_ids]) + assert torch.equal(motion.adaptive_bin_clip_ids[bins], motion_ids) + assert torch.equal(motion.adaptive_bins_for(motion_ids, motion_times), bins) + + +def test_motion_command_adaptive_sampling_state_round_trips() -> None: + source = MotionCommand.__new__(MotionCommand) + source.cfg = SimpleNamespace(sampling_mode="adaptive", adaptive_bin_size_frames=2) + source.adaptive_bin_failed_count = torch.tensor([0.0, 2.0, 4.0]) + source._current_adaptive_bin_failed = torch.tensor([1.0, 0.0, 3.0]) + + target = MotionCommand.__new__(MotionCommand) + target.cfg = 
SimpleNamespace(sampling_mode="adaptive", adaptive_bin_size_frames=2) + target._env = SimpleNamespace(device="cpu") + target.adaptive_bin_failed_count = torch.zeros(3) + target._current_adaptive_bin_failed = torch.zeros(3) + + state = source.get_adaptive_sampling_state() + assert state is not None + target.load_adaptive_sampling_state(state) + + assert torch.equal(target.adaptive_bin_failed_count, source.adaptive_bin_failed_count) + assert torch.equal( + target._current_adaptive_bin_failed, + source._current_adaptive_bin_failed, + ) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 0f6797bf..3400709d 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -31,7 +31,7 @@ def test_general_tracking_task_is_registered() -> None: assert "base_lin_vel" not in actor_terms assert "actor_history" in env_cfg.observations assert "critic_history" in env_cfg.observations - assert env_cfg.commands["motion"].sampling_mode == "uniform" + assert env_cfg.commands["motion"].sampling_mode == "adaptive" assert env_cfg.commands["motion"].window_steps == (0,) reward = env_cfg.rewards["self_collisions"] assert reward.weight == -0.1 diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 95932374..00dd1e6b 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -172,7 +172,7 @@ def make_general_tracking_env_cfg( motion_cmd.anchor_body_name = "torso_link" motion_cmd.body_names = _TRACKING_BODY_NAMES motion_cmd.motion_file = DEFAULT_TRAIN_MOTION_FILE - motion_cmd.sampling_mode = "uniform" + motion_cmd.sampling_mode = "adaptive" motion_cmd.window_steps = (0,) cfg.events["physics_material"].params[ diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 753b1163..004c7636 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -299,6 +299,7 @@ def 
__init__( ) self.clip_sample_start_s = self.clip_sample_starts.float() * self.clip_dt self.clip_sample_end_s = self.clip_sample_ends.float() * self.clip_dt + self._adaptive_bin_size_frames: int | None = None # ------------------------------------------------------------------ # Sampling helpers @@ -332,6 +333,132 @@ def sample_start_times(self, motion_ids: torch.Tensor) -> torch.Tensor: """Return the earliest valid center time for each motion id.""" return self.clip_sample_start_s[motion_ids] + def prepare_adaptive_sampling(self, bin_size_frames: int) -> int: + """Build clip-local adaptive sampling bins. + + Bins are cut only from each clip's valid center-frame range, so sampled + times never cross into adjacent clips in the flat motion arrays. + """ + if bin_size_frames <= 0: + raise ValueError( + f"adaptive_bin_size_frames must be positive, got {bin_size_frames}" + ) + if self._adaptive_bin_size_frames == bin_size_frames: + return int(self.adaptive_bin_clip_ids.numel()) + + clip_sample_starts = self.clip_sample_starts.cpu().numpy() + clip_sample_ends = self.clip_sample_ends.cpu().numpy() + clip_weights = self.clip_weights.cpu().numpy() + + bin_clip_ids: list[int] = [] + bin_start_frames: list[int] = [] + bin_end_frames: list[int] = [] + bin_base_weights: list[float] = [] + clip_bin_offsets = np.full(self.num_clips, -1, dtype=np.int64) + clip_bin_counts = np.zeros(self.num_clips, dtype=np.int64) + + for clip_id in range(self.num_clips): + clip_weight = float(clip_weights[clip_id]) + sample_start = int(clip_sample_starts[clip_id]) + sample_end = int(clip_sample_ends[clip_id]) + valid_length = sample_end - sample_start + if clip_weight <= 0.0 or valid_length <= 0: + continue + + clip_bin_offsets[clip_id] = len(bin_clip_ids) + for start in range(sample_start, sample_end, bin_size_frames): + end = min(start + bin_size_frames, sample_end) + width = end - start + bin_clip_ids.append(clip_id) + bin_start_frames.append(start) + bin_end_frames.append(end) + 
bin_base_weights.append(clip_weight * float(width) / float(valid_length)) + clip_bin_counts[clip_id] = len(bin_clip_ids) - clip_bin_offsets[clip_id] + + if not bin_clip_ids: + raise ValueError( + "Adaptive sampling has no valid bins. Check clip_weights and " + f"window_steps={list(self.window_steps)}." + ) + + device = self._device + self.adaptive_bin_clip_ids = torch.tensor( + bin_clip_ids, dtype=torch.long, device=device + ) + self.adaptive_bin_start_frames = torch.tensor( + bin_start_frames, dtype=torch.float32, device=device + ) + self.adaptive_bin_end_frames = torch.tensor( + bin_end_frames, dtype=torch.float32, device=device + ) + self.adaptive_bin_base_probs = torch.tensor( + bin_base_weights, dtype=torch.float32, device=device + ) + total = self.adaptive_bin_base_probs.sum() + if total <= 0: + raise ValueError("Adaptive sampling base probabilities sum to zero.") + self.adaptive_bin_base_probs = self.adaptive_bin_base_probs / total + self.adaptive_clip_bin_offsets = torch.tensor( + clip_bin_offsets, dtype=torch.long, device=device + ) + self.adaptive_clip_bin_counts = torch.tensor( + clip_bin_counts, dtype=torch.long, device=device + ) + self._adaptive_bin_size_frames = bin_size_frames + return int(self.adaptive_bin_clip_ids.numel()) + + def sample_adaptive_times( + self, + bin_probabilities: torch.Tensor, + n: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sample clip ids and clip-local times from adaptive bin probabilities.""" + if self._adaptive_bin_size_frames is None: + raise RuntimeError("prepare_adaptive_sampling() must be called first.") + if bin_probabilities.shape != self.adaptive_bin_base_probs.shape: + raise ValueError( + "adaptive bin probability shape mismatch: " + f"{tuple(bin_probabilities.shape)} vs " + f"{tuple(self.adaptive_bin_base_probs.shape)}" + ) + sampled_bins = torch.multinomial(bin_probabilities, n, replacement=True) + motion_ids = self.adaptive_bin_clip_ids[sampled_bins] + starts = 
self.adaptive_bin_start_frames[sampled_bins] + ends = self.adaptive_bin_end_frames[sampled_bins] + frame_f = starts + torch.rand_like(starts) * (ends - starts) + motion_times = frame_f / self.clip_fps[motion_ids] + return motion_ids, motion_times, sampled_bins + + def adaptive_bins_for( + self, motion_ids: torch.Tensor, motion_times: torch.Tensor + ) -> torch.Tensor: + """Return adaptive bin ids for clip-local motion states.""" + if self._adaptive_bin_size_frames is None: + raise RuntimeError("prepare_adaptive_sampling() must be called first.") + counts = self.adaptive_clip_bin_counts[motion_ids] + offsets = self.adaptive_clip_bin_offsets[motion_ids] + bins = torch.full_like(motion_ids, -1) + valid = (counts > 0) & (offsets >= 0) + if not torch.any(valid): + return bins + + valid_ids = motion_ids[valid] + local_frames = torch.floor( + motion_times[valid] * self.clip_fps[valid_ids] + ).long() + rel = local_frames - self.clip_sample_starts[valid_ids] + local_bins = torch.div( + torch.clamp(rel, min=0), + self._adaptive_bin_size_frames, + rounding_mode="floor", + ) + local_bins = torch.minimum( + torch.clamp(local_bins, min=0), + counts[valid] - 1, + ) + bins[valid] = offsets[valid] + local_bins + return bins + def _compute_interpolation_state( self, motion_ids: torch.Tensor, @@ -487,11 +614,23 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): device=self.device, window_steps=self.cfg.window_steps, ) + if self.cfg.sampling_mode == "adaptive": + adaptive_bin_count = self.motion.prepare_adaptive_sampling( + self.cfg.adaptive_bin_size_frames + ) + else: + adaptive_bin_count = 0 # Per-env motion state: clip id + elapsed time (seconds) self.motion_ids = torch.zeros(self.num_envs, dtype=torch.long, device=self.device) self.motion_times = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) self._step_dt = env.step_dt + self.adaptive_bin_failed_count = torch.zeros( + adaptive_bin_count, dtype=torch.float32, device=self.device + ) + 
self._current_adaptive_bin_failed = torch.zeros( + adaptive_bin_count, dtype=torch.float32, device=self.device + ) # Cached interpolated frames — refreshed every step self._cached_frames: dict[str, torch.Tensor] = {} @@ -524,6 +663,9 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.metrics["error_body_rot"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_vel"] = torch.zeros(self.num_envs, device=self.device) + self.metrics["sampling_entropy"] = torch.zeros(self.num_envs, device=self.device) + self.metrics["sampling_top1_prob"] = torch.zeros(self.num_envs, device=self.device) + self.metrics["sampling_top1_bin"] = torch.zeros(self.num_envs, device=self.device) # Feet standing state (for feet_air_time_ref rewards) if self.cfg.feet_body_names: @@ -769,16 +911,68 @@ def _uniform_sampling(self, env_ids: torch.Tensor): self.motion_ids[env_ids] = self.motion.sample_motion_ids(len(env_ids)) self.motion_times[env_ids] = self.motion.sample_times(self.motion_ids[env_ids]) + def _adaptive_sampling(self, env_ids: torch.Tensor): + episode_failed = self._env.termination_manager.terminated[env_ids] + if torch.any(episode_failed): + failed_env_ids = env_ids[episode_failed] + failed_bins = self.motion.adaptive_bins_for( + self.motion_ids[failed_env_ids], + self.motion_times[failed_env_ids], + ) + failed_bins = failed_bins[failed_bins >= 0] + if failed_bins.numel() > 0: + self._current_adaptive_bin_failed += torch.bincount( + failed_bins, + minlength=self.adaptive_bin_failed_count.numel(), + ).to(dtype=torch.float32, device=self.device) + + sampling_probabilities = ( + self.adaptive_bin_failed_count + + self.cfg.adaptive_uniform_ratio * self.motion.adaptive_bin_base_probs + ) + probability_sum = sampling_probabilities.sum() + if probability_sum <= 0: + sampling_probabilities = self.motion.adaptive_bin_base_probs + else: + sampling_probabilities = 
sampling_probabilities / probability_sum + + motion_ids, motion_times, _sampled_bins = self.motion.sample_adaptive_times( + sampling_probabilities, + len(env_ids), + ) + self.motion_ids[env_ids] = motion_ids + self.motion_times[env_ids] = motion_times + + entropy = -( + sampling_probabilities * (sampling_probabilities + 1e-12).log() + ).sum() + if sampling_probabilities.numel() > 1: + entropy = entropy / torch.log( + torch.tensor( + float(sampling_probabilities.numel()), + dtype=torch.float32, + device=self.device, + ) + ) + pmax, imax = sampling_probabilities.max(dim=0) + self.metrics["sampling_entropy"][:] = entropy + self.metrics["sampling_top1_prob"][:] = pmax + self.metrics["sampling_top1_bin"][:] = ( + imax.float() / max(sampling_probabilities.numel(), 1) + ) + def _resample_command(self, env_ids: torch.Tensor): if self.cfg.sampling_mode == "start": self.motion_ids[env_ids] = self.motion.sample_motion_ids(len(env_ids)) self.motion_times[env_ids] = self.motion.sample_start_times(self.motion_ids[env_ids]) elif self.cfg.sampling_mode == "uniform": self._uniform_sampling(env_ids) + elif self.cfg.sampling_mode == "adaptive": + self._adaptive_sampling(env_ids) else: raise ValueError( f"Unsupported motion sampling_mode={self.cfg.sampling_mode!r}. " - "Supported modes are 'uniform' and 'start'." + "Supported modes are 'uniform', 'start', and 'adaptive'." 
) if env_ids.numel() == 0: @@ -903,9 +1097,53 @@ def _update_command(self): delta_ori_w, self.body_pos_w - anchor_pos_w_repeat ) + if self.cfg.sampling_mode == "adaptive": + self.adaptive_bin_failed_count = ( + self.cfg.adaptive_alpha * self._current_adaptive_bin_failed + + (1.0 - self.cfg.adaptive_alpha) * self.adaptive_bin_failed_count + ) + self._current_adaptive_bin_failed.zero_() + self._refresh_body_local_cache() self._update_feet_standing() + def get_adaptive_sampling_state(self) -> dict[str, torch.Tensor | int] | None: + if self.cfg.sampling_mode != "adaptive": + return None + return { + "adaptive_bin_size_frames": int(self.cfg.adaptive_bin_size_frames), + "adaptive_bin_failed_count": self.adaptive_bin_failed_count.detach().cpu(), + "current_adaptive_bin_failed": self._current_adaptive_bin_failed.detach().cpu(), + } + + def load_adaptive_sampling_state(self, state: dict[str, torch.Tensor | int]) -> None: + if self.cfg.sampling_mode != "adaptive": + return + bin_size = int(state.get("adaptive_bin_size_frames", -1)) + if bin_size != self.cfg.adaptive_bin_size_frames: + raise ValueError( + "adaptive sampling checkpoint bin size mismatch: " + f"checkpoint={bin_size}, current={self.cfg.adaptive_bin_size_frames}" + ) + failed_count = state.get("adaptive_bin_failed_count") + current_failed = state.get("current_adaptive_bin_failed") + if not isinstance(failed_count, torch.Tensor) or not isinstance(current_failed, torch.Tensor): + raise ValueError("adaptive sampling checkpoint state is missing tensors") + if failed_count.shape != self.adaptive_bin_failed_count.shape: + raise ValueError( + "adaptive sampling checkpoint bin count mismatch: " + f"checkpoint={tuple(failed_count.shape)}, " + f"current={tuple(self.adaptive_bin_failed_count.shape)}" + ) + if current_failed.shape != self._current_adaptive_bin_failed.shape: + raise ValueError( + "adaptive sampling checkpoint current-bin count mismatch: " + f"checkpoint={tuple(current_failed.shape)}, " + 
f"current={tuple(self._current_adaptive_bin_failed.shape)}" + ) + self.adaptive_bin_failed_count.copy_(failed_count.to(self.device)) + self._current_adaptive_bin_failed.copy_(current_failed.to(self.device)) + # ------------------------------------------------------------------ # Visualization # ------------------------------------------------------------------ @@ -990,8 +1228,11 @@ class MotionCommandCfg(CommandTermCfg): pose_range: dict[str, tuple[float, float]] = field(default_factory=dict) velocity_range: dict[str, tuple[float, float]] = field(default_factory=dict) joint_position_range: tuple[float, float] = (-0.52, 0.52) - sampling_mode: Literal["uniform", "start"] = "uniform" + sampling_mode: Literal["uniform", "start", "adaptive"] = "uniform" window_steps: tuple[int, ...] = (0,) + adaptive_bin_size_frames: int = 10 + adaptive_uniform_ratio: float = 0.1 + adaptive_alpha: float = 0.001 feet_body_names: tuple[str, ...] = () feet_standing_z_threshold: float = 0.18 feet_standing_vxy_threshold: float = 0.2 diff --git a/train_mimic/tasks/tracking/rl/runner.py b/train_mimic/tasks/tracking/rl/runner.py index 058a1360..05ef981a 100644 --- a/train_mimic/tasks/tracking/rl/runner.py +++ b/train_mimic/tasks/tracking/rl/runner.py @@ -93,6 +93,29 @@ def __init__( super().__init__(env, train_cfg, log_dir, device) self.registry_name = registry_name + def _motion_command(self) -> MotionCommand: + return cast(MotionCommand, self.env.unwrapped.command_manager.get_term("motion")) + + def save(self, path: str, infos=None) -> None: + motion_state = self._motion_command().get_adaptive_sampling_state() + if motion_state is not None: + infos = {**(infos or {}), "motion_adaptive_sampling_state": motion_state} + super().save(path, infos) + + def load( + self, + path: str, + load_cfg: dict | None = None, + strict: bool = True, + map_location: str | None = None, + ) -> dict: + infos = super().load(path, load_cfg=load_cfg, strict=strict, map_location=map_location) + if infos and 
"motion_adaptive_sampling_state" in infos: + self._motion_command().load_adaptive_sampling_state( + infos["motion_adaptive_sampling_state"] + ) + return infos + def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None: """Run the learning loop using 1-based iteration numbering.""" if init_at_random_ep_len: From d6cabd871e892c41699621ad33a80f125235b730 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 11 Jun 2026 12:12:49 +0800 Subject: [PATCH 073/122] Refine adaptive sampling metrics --- train_mimic/tasks/tracking/mdp/commands.py | 65 ++++++++++++---------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 004c7636..9edbfeef 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -665,7 +665,7 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.metrics["error_joint_vel"] = torch.zeros(self.num_envs, device=self.device) self.metrics["sampling_entropy"] = torch.zeros(self.num_envs, device=self.device) self.metrics["sampling_top1_prob"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_top1_bin"] = torch.zeros(self.num_envs, device=self.device) + self.metrics["sampling_failed_bin_mean"] = torch.zeros(self.num_envs, device=self.device) # Feet standing state (for feet_air_time_ref rewards) if self.cfg.feet_body_names: @@ -911,6 +911,37 @@ def _uniform_sampling(self, env_ids: torch.Tensor): self.motion_ids[env_ids] = self.motion.sample_motion_ids(len(env_ids)) self.motion_times[env_ids] = self.motion.sample_times(self.motion_ids[env_ids]) + def _update_adaptive_sampling_metrics( + self, + sampling_probabilities: torch.Tensor, + ) -> None: + entropy = -( + sampling_probabilities * (sampling_probabilities + 1e-12).log() + ).sum() + if sampling_probabilities.numel() > 1: + entropy = entropy / torch.log( + torch.tensor( + 
float(sampling_probabilities.numel()), + dtype=torch.float32, + device=self.device, + ) + ) + pmax = sampling_probabilities.max() + failed = self.adaptive_bin_failed_count + self.metrics["sampling_entropy"][:] = entropy + self.metrics["sampling_top1_prob"][:] = pmax + self.metrics["sampling_failed_bin_mean"][:] = failed.mean() + + def _adaptive_sampling_probabilities(self) -> torch.Tensor: + sampling_probabilities = ( + self.adaptive_bin_failed_count + + self.cfg.adaptive_uniform_ratio * self.motion.adaptive_bin_base_probs + ) + probability_sum = sampling_probabilities.sum() + if probability_sum <= 0: + return self.motion.adaptive_bin_base_probs + return sampling_probabilities / probability_sum + def _adaptive_sampling(self, env_ids: torch.Tensor): episode_failed = self._env.termination_manager.terminated[env_ids] if torch.any(episode_failed): @@ -926,16 +957,7 @@ def _adaptive_sampling(self, env_ids: torch.Tensor): minlength=self.adaptive_bin_failed_count.numel(), ).to(dtype=torch.float32, device=self.device) - sampling_probabilities = ( - self.adaptive_bin_failed_count - + self.cfg.adaptive_uniform_ratio * self.motion.adaptive_bin_base_probs - ) - probability_sum = sampling_probabilities.sum() - if probability_sum <= 0: - sampling_probabilities = self.motion.adaptive_bin_base_probs - else: - sampling_probabilities = sampling_probabilities / probability_sum - + sampling_probabilities = self._adaptive_sampling_probabilities() motion_ids, motion_times, _sampled_bins = self.motion.sample_adaptive_times( sampling_probabilities, len(env_ids), @@ -943,23 +965,7 @@ def _adaptive_sampling(self, env_ids: torch.Tensor): self.motion_ids[env_ids] = motion_ids self.motion_times[env_ids] = motion_times - entropy = -( - sampling_probabilities * (sampling_probabilities + 1e-12).log() - ).sum() - if sampling_probabilities.numel() > 1: - entropy = entropy / torch.log( - torch.tensor( - float(sampling_probabilities.numel()), - dtype=torch.float32, - device=self.device, - ) - ) - 
pmax, imax = sampling_probabilities.max(dim=0) - self.metrics["sampling_entropy"][:] = entropy - self.metrics["sampling_top1_prob"][:] = pmax - self.metrics["sampling_top1_bin"][:] = ( - imax.float() / max(sampling_probabilities.numel(), 1) - ) + self._update_adaptive_sampling_metrics(sampling_probabilities) def _resample_command(self, env_ids: torch.Tensor): if self.cfg.sampling_mode == "start": @@ -1103,6 +1109,9 @@ def _update_command(self): + (1.0 - self.cfg.adaptive_alpha) * self.adaptive_bin_failed_count ) self._current_adaptive_bin_failed.zero_() + self._update_adaptive_sampling_metrics( + self._adaptive_sampling_probabilities() + ) self._refresh_body_local_cache() self._update_feet_standing() From 4b78920d80d6c72d9f6ad7b2bb2f20f97a9233eb Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 11 Jun 2026 12:22:38 +0800 Subject: [PATCH 074/122] Allow adaptive sampling CLI override --- train_mimic/scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 98e3353d..4b3fc611 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -85,7 +85,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ), ) parser.add_argument("--sampling_mode", type=str, default=None, - choices=["uniform", "start"], + choices=["uniform", "start", "adaptive"], help="Motion sampling mode (default: from task config)") parser.add_argument("--device", type=str, default=None) parser.add_argument( From f6b58553f16b4fd0f3c08907b0e377ee6675b3a6 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 11 Jun 2026 15:37:22 +0800 Subject: [PATCH 075/122] Remove unused tracking rewards --- train_mimic/tasks/tracking/mdp/commands.py | 1 - train_mimic/tasks/tracking/mdp/rewards.py | 376 --------------------- 2 files changed, 377 deletions(-) diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 9edbfeef..c29a69fb 
100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -667,7 +667,6 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.metrics["sampling_top1_prob"] = torch.zeros(self.num_envs, device=self.device) self.metrics["sampling_failed_bin_mean"] = torch.zeros(self.num_envs, device=self.device) - # Feet standing state (for feet_air_time_ref rewards) if self.cfg.feet_body_names: self._feet_body_indexes = [ self.cfg.body_names.index(n) for n in self.cfg.feet_body_names diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index ec032e05..9644a810 100644 --- a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -4,9 +4,6 @@ import torch -from mjlab.entity import Entity -from mjlab.managers.reward_manager import RewardTermCfg -from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.utils.lab_api.math import ( quat_error_magnitude, ) @@ -17,9 +14,6 @@ from mjlab.envs import ManagerBasedRlEnv -_DEFAULT_ASSET_CFG = SceneEntityCfg("robot") - - def _get_body_indexes( command: MotionCommand, body_names: tuple[str, ...] 
| None ) -> list[int]: @@ -30,29 +24,6 @@ def _get_body_indexes( ] -def survival(env: ManagerBasedRlEnv) -> torch.Tensor: - """Match the sibling motion_tracking task's constant alive reward.""" - return torch.ones(env.num_envs, dtype=torch.float32, device=env.device) - - -def joint_pos_tracking_exp( - env: ManagerBasedRlEnv, command_name: str, std: float -) -> torch.Tensor: - """Joint position tracking reward using MotionCommand interface.""" - command = cast(MotionCommand, env.command_manager.get_term(command_name)) - error = torch.mean(torch.abs(command.joint_pos - command.robot_joint_pos), dim=-1) - return torch.exp(-error / std**2) - - -def joint_vel_tracking_exp( - env: ManagerBasedRlEnv, command_name: str, std: float -) -> torch.Tensor: - """Joint velocity tracking reward using MotionCommand interface.""" - command = cast(MotionCommand, env.command_manager.get_term(command_name)) - error = torch.mean(torch.abs(command.joint_vel - command.robot_joint_vel), dim=-1) - return torch.exp(-error / std**2) - - def motion_global_anchor_position_error_exp( env: ManagerBasedRlEnv, command_name: str, std: float ) -> torch.Tensor: @@ -179,350 +150,3 @@ def _self_collision_hits( found = torch.cat(found_values, dim=1) return (found > 0).any(dim=1, keepdim=True) - - -class joint_torque_limits: - """Penalize actuator-force limit violations with a configurable soft margin.""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - asset_cfg = cast(SceneEntityCfg, cfg.params.get("asset_cfg", _DEFAULT_ASSET_CFG)) - self.asset: Entity = env.scene[asset_cfg.name] - self.soft_factor = float(cfg.params.get("soft_factor", 0.9)) - if not (0.0 < self.soft_factor <= 1.0): - raise ValueError(f"soft_factor must be in (0, 1], got {self.soft_factor}") - - joint_ids = asset_cfg.joint_ids - if isinstance(joint_ids, slice): - joint_names = list(self.asset.joint_names) - else: - joint_names = [self.asset.joint_names[idx] for idx in joint_ids] - actuator_names = 
list(self.asset.actuator_names) - name_to_actuator = {name: idx for idx, name in enumerate(actuator_names)} - actuator_ids: list[int] = [] - for joint_name in joint_names: - if joint_name not in name_to_actuator: - raise RuntimeError(f"Actuator for joint '{joint_name}' not found.") - actuator_ids.append(name_to_actuator[joint_name]) - self._actuator_ids = torch.tensor(actuator_ids, device=env.device, dtype=torch.long) - - force_range = torch.as_tensor( - env.sim.model.actuator_forcerange, device=env.device, dtype=torch.float32 - ) - if force_range.ndim == 2: - force_range = force_range.unsqueeze(0).expand(env.num_envs, -1, -1) - elif force_range.ndim != 3: - raise RuntimeError( - f"Unexpected actuator_forcerange shape: {tuple(force_range.shape)}" - ) - ctrl_ids = torch.as_tensor( - self.asset.indexing.ctrl_ids, device=env.device, dtype=torch.long - ) - force_range = force_range.index_select(1, ctrl_ids) - torque_limit = torch.maximum(force_range[..., 0].abs(), force_range[..., 1].abs()) - self._soft_limits = torque_limit.index_select(1, self._actuator_ids) * self.soft_factor - - def __call__( - self, - env: ManagerBasedRlEnv, - asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, - soft_factor: float = 0.9, - ) -> torch.Tensor: - del env, asset_cfg, soft_factor - applied_torque = self.asset.data.actuator_force.index_select(1, self._actuator_ids) - violation = (applied_torque.abs() / self._soft_limits - 1.0).clamp_min(0.0) - return -violation.sum(dim=1) - - -class feet_air_time_ref: - """Reward landing timing that matches the reference contact schedule.""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - self.sensor_name = str(cfg.params["sensor_name"]) - self.command_name = str(cfg.params["command_name"]) - self.threshold = float(cfg.params.get("thres", 0.5)) - self._reward_time = torch.zeros((env.num_envs, 0), dtype=torch.float32, device=env.device) - - def reset(self, env_ids: torch.Tensor | slice | None) -> None: - if env_ids is None: - env_ids = 
slice(None) - self._reward_time[env_ids] = 0.0 - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - command_name: str, - thres: float = 0.5, - ) -> torch.Tensor: - del sensor_name, command_name, thres - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.found is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'found'") - current_contact = data.found > 0 - first_contact = sensor.compute_first_contact(dt=env.step_dt) - command = cast(object, env.command_manager.get_term(self.command_name)) - target_contact = torch.as_tensor( - cast(object, command).feet_standing, device=env.device, dtype=torch.bool - ) - if self._reward_time.shape != current_contact.shape: - self._reward_time = torch.zeros_like(current_contact, dtype=torch.float32) - - contact_match = current_contact == target_contact - self._reward_time = self._reward_time + torch.where( - contact_match, - torch.full_like(self._reward_time, env.step_dt), - torch.full_like(self._reward_time, -env.step_dt), - ) - reward = torch.sum( - (self._reward_time - self.threshold).clamp_max(0.0) * first_contact.float(), - dim=1, - ) - self._reward_time = self._reward_time * (~current_contact).float() - return reward - - -class feet_air_time_ref_dense: - """Dense contact/height shaping against the reference foot contact state.""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - asset_cfg = cast(SceneEntityCfg, cfg.params.get("asset_cfg", _DEFAULT_ASSET_CFG)) - self.asset: Entity = env.scene[asset_cfg.name] - body_names = cast(list[str], cfg.params["body_names"]) - body_ids, _ = self.asset.find_bodies(body_names, preserve_order=True) - self._body_ids = torch.tensor(body_ids, device=env.device, dtype=torch.long) - - site_names = cast(list[str] | None, cfg.params.get("site_names")) - self._site_ids: torch.Tensor | None = None - if site_names is not None: - site_ids, _ = self.asset.find_sites(site_names, preserve_order=True) - if 
len(site_ids) != len(body_ids): - raise ValueError( - "site_names must match body_names length for feet_air_time_ref_dense" - ) - self._site_ids = torch.tensor(site_ids, device=env.device, dtype=torch.long) - else: - body2_names = cast(list[str] | None, cfg.params.get("body2_names")) - if body2_names is None: - self._body2_ids = self._body_ids - else: - body2_ids, _ = self.asset.find_bodies(body2_names, preserve_order=True) - if len(body2_ids) != len(body_ids): - raise ValueError( - "body2_names must match body_names length for feet_air_time_ref_dense" - ) - self._body2_ids = torch.tensor(body2_ids, device=env.device, dtype=torch.long) - - self.sensor_name = str(cfg.params["sensor_name"]) - self.command_name = str(cfg.params["command_name"]) - self.air_h_low = float(cfg.params.get("air_h_low", 0.035)) - self.air_h_high = float(cfg.params.get("air_h_high", 0.155)) - self.contact_h_low = float(cfg.params.get("contact_h_low", 0.035)) - self.contact_h_high = float(cfg.params.get("contact_h_high", 0.125)) - self.air_h_span = max(self.air_h_high - self.air_h_low, 1.0e-6) - self.contact_h_span = max(self.contact_h_high - self.contact_h_low, 1.0e-6) - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - command_name: str, - body_names: tuple[str, ...], - body2_names: tuple[str, ...] | None = None, - site_names: tuple[str, ...] 
| None = None, - air_h_low: float = 0.035, - air_h_high: float = 0.155, - contact_h_low: float = 0.035, - contact_h_high: float = 0.125, - asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, - ) -> torch.Tensor: - del ( - sensor_name, - command_name, - body_names, - body2_names, - site_names, - air_h_low, - air_h_high, - contact_h_low, - contact_h_high, - asset_cfg, - ) - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.found is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'found'") - current_contact = data.found > 0 - command = cast(object, env.command_manager.get_term(self.command_name)) - target_contact = torch.as_tensor( - cast(object, command).feet_standing, device=env.device, dtype=torch.bool - ) - - mismatch = current_contact ^ target_contact - both_air = (~current_contact) & (~target_contact) - both_contact = current_contact & target_contact - - penalty = torch.zeros_like(target_contact, dtype=torch.float32) - penalty[mismatch] = -1.0 - - if self._site_ids is not None: - foot_probe_height = self.asset.data.site_pos_w.index_select(1, self._site_ids)[..., 2] - feet_height_air = foot_probe_height - feet_height_contact = foot_probe_height - else: - feet_height_air = torch.minimum( - self.asset.data.body_link_pos_w.index_select(1, self._body_ids)[..., 2], - self.asset.data.body_link_pos_w.index_select(1, self._body2_ids)[..., 2], - ) - feet_height_contact = torch.maximum( - self.asset.data.body_link_pos_w.index_select(1, self._body_ids)[..., 2], - self.asset.data.body_link_pos_w.index_select(1, self._body2_ids)[..., 2], - ) - air_ratio = ((feet_height_air - self.air_h_low) / self.air_h_span).clamp(0.0, 1.0) - penalty = torch.where(both_air, -(1.0 - air_ratio), penalty) - - contact_ratio = ( - (feet_height_contact - self.contact_h_low) / self.contact_h_span - ).clamp(0.0, 1.0) - penalty = torch.where(both_contact, -contact_ratio, penalty) - return penalty.mean(dim=1) - - -# 
--------------------------------------------------------------------------- -# TWIST2-style feet rewards (adapted to mjlab API) -# --------------------------------------------------------------------------- - - -class feet_air_time: - """Reward feet air time matching a target duration (TWIST2-style). - - Accumulates air time per foot; on first contact, rewards - ``(air_time - target).clamp(max=0)`` so that too-short air phases are - penalised. Only active when the reference root XY speed > 0.05 m/s. - """ - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - self.sensor_name = str(cfg.params["sensor_name"]) - self.command_name = str(cfg.params["command_name"]) - self.air_time_target = float(cfg.params.get("air_time_target", 0.5)) - self.feet_air_time = torch.zeros( - (env.num_envs, 0), dtype=torch.float32, device=env.device - ) - self._last_contacts = torch.zeros( - (env.num_envs, 0), dtype=torch.bool, device=env.device - ) - - def reset(self, env_ids: torch.Tensor | slice | None) -> None: - if env_ids is None: - env_ids = slice(None) - self.feet_air_time[env_ids] = 0.0 - self._last_contacts[env_ids] = False - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - command_name: str, - air_time_target: float = 0.5, - ) -> torch.Tensor: - del sensor_name, command_name, air_time_target - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'force'") - contact = data.force[..., 2].abs() > 5.0 # [B, N] — TWIST2: fz > 5N - - # Lazy init on first call (shape unknown at __init__) - if self.feet_air_time.shape != contact.shape: - self.feet_air_time = torch.zeros_like(contact, dtype=torch.float32) - self._last_contacts = torch.zeros_like(contact, dtype=torch.bool) - - # contact_filt = contact OR last_contacts (same as TWIST2) - contact_filt = contact | self._last_contacts - self._last_contacts = contact - - first_contact = 
(self.feet_air_time > 0.0) & contact_filt - - self.feet_air_time += env.step_dt - air_time = (self.feet_air_time - self.air_time_target) * first_contact.float() - air_time = air_time.clamp(max=0.0) - self.feet_air_time *= ~contact_filt - - reward = air_time.sum(dim=1) - - # Gate by reference root XY speed > 0.05 - command = cast(MotionCommand, env.command_manager.get_term(self.command_name)) - ref_root_vxy = torch.norm(command.body_lin_vel_w[:, 0, :2], dim=-1) - reward *= (ref_root_vxy > 0.05).float() - - return reward - - -def feet_stumble( - env: ManagerBasedRlEnv, - sensor_name: str, -) -> torch.Tensor: - """Penalise stumbling: lateral contact force > 4x vertical (TWIST2-style).""" - sensor: ContactSensor = env.scene[sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{sensor_name}' must expose 'force'") - force = data.force # [B, N, 3] - lateral = torch.norm(force[..., :2], dim=-1) # [B, N] - vertical = torch.abs(force[..., 2]) # [B, N] - stumble = torch.any(lateral > 4.0 * vertical, dim=1) - return stumble.float() - - -def feet_contact_forces( - env: ManagerBasedRlEnv, - sensor_name: str, - max_contact_force: float = 350.0, -) -> torch.Tensor: - """Penalise excessive vertical contact forces (TWIST2-style). - - Computes L2 norm of vertical forces across all feet, then penalises the - excess above *max_contact_force*. This matches TWIST2's implementation - where two feet at 300 N each (norm ≈ 424 N) would trigger a penalty. 
- """ - sensor: ContactSensor = env.scene[sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{sensor_name}' must expose 'force'") - fz = data.force[..., 2] # [B, N] - fz_norm = torch.norm(fz, dim=-1) # [B] - excess = (fz_norm - max_contact_force).clamp(min=0.0) - return excess - - -class feet_slip: - """Penalise horizontal foot velocity while in contact (TWIST2-style).""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - asset_cfg = cast(SceneEntityCfg, cfg.params.get("asset_cfg", _DEFAULT_ASSET_CFG)) - self.asset: Entity = env.scene[asset_cfg.name] - body_names = cast(list[str], cfg.params["body_names"]) - body_ids, _ = self.asset.find_bodies(body_names, preserve_order=True) - self._body_ids = torch.tensor(body_ids, device=env.device, dtype=torch.long) - self.sensor_name = str(cfg.params["sensor_name"]) - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - body_names: tuple[str, ...], - asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, - ) -> torch.Tensor: - del sensor_name, body_names, asset_cfg - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'force'") - contact = (data.force[..., 2].abs() > 5.0).float() # [B, N] — TWIST2: fz > 5N - - feet_vel_xy = self.asset.data.body_link_lin_vel_w.index_select( - 1, self._body_ids - )[..., :2] # [B, N, 2] - speed_xy = torch.norm(feet_vel_xy, dim=-1) # [B, N] - slip = torch.sqrt(speed_xy) * contact - return slip.sum(dim=1) From 0a6b4d71721eb389537b0e520e297ce52e05aab5 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 11 Jun 2026 19:04:10 +0800 Subject: [PATCH 076/122] Add rewind motion sampling --- AGENTS.md | 3 +- README.md | 3 +- docs/docs/reference/architecture.md | 2 +- docs/docs/tutorials/training.md | 2 +- .../current/reference/architecture.md | 4 + .../current/tutorials/training.md | 2 +- 
tests/test_motion_sampling.py | 76 +++++++++++++++++++ tests/test_task_registry.py | 2 +- tests/test_train_script.py | 13 +++- train_mimic/scripts/train.py | 17 ++++- train_mimic/tasks/tracking/config/env.py | 2 +- train_mimic/tasks/tracking/mdp/commands.py | 60 ++++++++++++++- 12 files changed, 175 insertions(+), 11 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 16e7c6a7..81dc954b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -190,7 +190,8 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g - Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - 166D `velcmd_history` observation, dual-input ONNX export -- Training env uses `sampling_mode="adaptive"` +- Training env uses `sampling_mode="uniform"` +- Supported motion sampling modes are `adaptive`, `uniform`, `start`, and `rewind`; `rewind` restarts failed environments from the same clip after stepping back `rewind_min_steps..rewind_max_steps` with probability `rewind_prob`, otherwise it falls back to uniform sampling - Playback/benchmark use `play=True`, which switches motion sampling to `start` - `window_steps=[0]` - `save_onnx.py` exports dual-input TemporalCNN ONNX diff --git a/README.md b/README.md index c6a97f1b..7621ecbd 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,8 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. -- General-Tracking-G1 training now defaults to clip-local adaptive motion sampling with checkpointed sampler state. 
+- General-Tracking-G1 training now defaults to uniform motion sampling; clip-local adaptive sampling remains available through `sampling_mode=adaptive`. +- Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. ### v0.3.0 (2026-05-12) diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index c3aa900e..b3edd11b 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -58,7 +58,7 @@ train_mimic/scripts/data | Inference observation | `velcmd_history` (166D) | | ONNX signature | Dual-input `obs` (166D) + `obs_history` | | Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | -| Training sampling | `uniform`; playback/benchmark use `start` | +| Training sampling | Default `uniform`; also supports `adaptive` and `rewind`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | | Data format | Shard directories only (`shard_*.npz`) | diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index 5fe8b972..1e57142e 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -131,4 +131,4 @@ Key files: - `train_mimic/app.py` - Shared entry point for train/play/benchmark - `train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 env builder - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO config -- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `uniform` / `start` sampling modes +- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `uniform`, `start`, `adaptive`, and `rewind` sampling modes. Training defaults to `uniform`; playback/benchmark use `start`. 
diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index a17dccfd..b73e931e 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -60,6 +60,10 @@ teleopit/ | 实机通信 | DDS(Cyclone DDS) | | 支持的机器人 | 宇树 G1 | | 输入源 | BVH 文件、Pico 4 VR 头显 | +| 训练任务 | `General-Tracking-G1` | +| 推理观测 | `velcmd_history`(166D) | +| 训练采样 | 默认 `uniform`;也支持 `adaptive` 和 `rewind`;播放/评估使用 `start` | +| 训练 `window_steps` | `[0]` | ## 约束 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index 4ebc2521..4e385e4b 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -131,4 +131,4 @@ train_mimic/scripts - `train_mimic/app.py` - 训练/播放/评估的统一入口 - `train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 环境构建器 - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO 配置 -- `train_mimic/tasks/tracking/mdp/commands.py` - 支持 `uniform` / `start` 两种采样模式 +- `train_mimic/tasks/tracking/mdp/commands.py` - 支持 `uniform`、`start`、`adaptive` 和 `rewind` 采样模式。训练默认使用 `uniform`;播放/评估使用 `start`。 diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index ecaeb51d..cf1b5132 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import numpy as np +import pytest import torch from train_mimic.data.dataset_lib import merge_clip_dicts @@ -221,3 +222,78 @@ def test_motion_command_adaptive_sampling_state_round_trips() -> None: target._current_adaptive_bin_failed, source._current_adaptive_bin_failed, ) + + +class _FakeMotion: 
+ def __init__(self) -> None: + self.clip_sample_start_s = torch.tensor([0.0, 1.0, 2.0]) + + def sample_motion_ids(self, n: int) -> torch.Tensor: + return torch.full((n,), 2, dtype=torch.long) + + def sample_times(self, motion_ids: torch.Tensor) -> torch.Tensor: + return torch.full_like(motion_ids, 9.0, dtype=torch.float32) + + +def test_motion_command_rewind_sampling_uses_failed_env_previous_time() -> None: + command = MotionCommand.__new__(MotionCommand) + command.cfg = SimpleNamespace( + sampling_mode="rewind", + rewind_prob=1.0, + rewind_min_steps=2, + rewind_max_steps=2, + ) + command._env = SimpleNamespace( + device="cpu", + termination_manager=SimpleNamespace( + terminated=torch.tensor([True, False, True]) + ), + ) + command.motion = _FakeMotion() + command.motion_ids = torch.tensor([0, 1, 1], dtype=torch.long) + command.motion_times = torch.tensor([5.0, 6.0, 1.25], dtype=torch.float32) + command._step_dt = 0.5 + + command._rewind_sampling(torch.tensor([0, 1, 2], dtype=torch.long)) + + assert torch.equal(command.motion_ids, torch.tensor([0, 2, 1])) + assert torch.allclose(command.motion_times, torch.tensor([4.0, 9.0, 1.0])) + + +def test_motion_command_rewind_sampling_falls_back_to_uniform_when_disabled() -> None: + command = MotionCommand.__new__(MotionCommand) + command.cfg = SimpleNamespace( + sampling_mode="rewind", + rewind_prob=0.0, + rewind_min_steps=2, + rewind_max_steps=2, + ) + command._env = SimpleNamespace( + device="cpu", + termination_manager=SimpleNamespace(terminated=torch.tensor([True])), + ) + command.motion = _FakeMotion() + command.motion_ids = torch.tensor([0], dtype=torch.long) + command.motion_times = torch.tensor([5.0], dtype=torch.float32) + command._step_dt = 0.5 + + command._rewind_sampling(torch.tensor([0], dtype=torch.long)) + + assert torch.equal(command.motion_ids, torch.tensor([2])) + assert torch.allclose(command.motion_times, torch.tensor([9.0])) + + +def test_motion_command_rewind_sampling_rejects_invalid_step_range() -> 
None: + command = MotionCommand.__new__(MotionCommand) + command.cfg = SimpleNamespace( + sampling_mode="rewind", + rewind_prob=1.0, + rewind_min_steps=3, + rewind_max_steps=2, + ) + command.motion_ids = torch.tensor([0], dtype=torch.long) + command.motion_times = torch.tensor([5.0], dtype=torch.float32) + command.motion = _FakeMotion() + + with pytest.raises(ValueError, match="rewind_max_steps"): + command._rewind_sampling(torch.tensor([0], dtype=torch.long)) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 3400709d..0f6797bf 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -31,7 +31,7 @@ def test_general_tracking_task_is_registered() -> None: assert "base_lin_vel" not in actor_terms assert "actor_history" in env_cfg.observations assert "critic_history" in env_cfg.observations - assert env_cfg.commands["motion"].sampling_mode == "adaptive" + assert env_cfg.commands["motion"].sampling_mode == "uniform" assert env_cfg.commands["motion"].window_steps == (0,) reward = env_cfg.rewards["self_collisions"] assert reward.weight == -0.1 diff --git a/tests/test_train_script.py b/tests/test_train_script.py index 0dbce3bf..f92aa81c 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -167,7 +167,15 @@ def test_configure_swanlab_logger_syncs_tensorboard(self, monkeypatch: pytest.Mo monkeypatch.setenv("RANK", "0") agent_cfg = types.SimpleNamespace(logger="wandb", experiment_name="exp", max_iterations=10) env_cfg = types.SimpleNamespace( - commands={"motion": types.SimpleNamespace(motion_file="data/train", sampling_mode="uniform")}, + commands={ + "motion": types.SimpleNamespace( + motion_file="data/train", + sampling_mode="uniform", + rewind_prob=0.8, + rewind_min_steps=25, + rewind_max_steps=75, + ) + }, scene=types.SimpleNamespace(num_envs=64), ) @@ -193,6 +201,9 @@ def test_configure_swanlab_logger_syncs_tensorboard(self, monkeypatch: pytest.Mo "num_envs": 64, "max_iterations": 10, 
"sampling_mode": "uniform", + "rewind_prob": 0.8, + "rewind_min_steps": 25, + "rewind_max_steps": 75, }, }, ), diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 4b3fc611..a576b9de 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -85,8 +85,14 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ), ) parser.add_argument("--sampling_mode", type=str, default=None, - choices=["uniform", "start", "adaptive"], + choices=["uniform", "start", "adaptive", "rewind"], help="Motion sampling mode (default: from task config)") + parser.add_argument("--rewind_prob", type=float, default=None, + help="Rewind sampling probability for failed episodes") + parser.add_argument("--rewind_min_steps", type=int, default=None, + help="Minimum policy steps to rewind for rewind sampling") + parser.add_argument("--rewind_max_steps", type=int, default=None, + help="Maximum policy steps to rewind for rewind sampling") parser.add_argument("--device", type=str, default=None) parser.add_argument( "--gpu_ids", @@ -284,6 +290,9 @@ def _configure_experiment_logger( "num_envs": env_cfg.scene.num_envs, "max_iterations": agent_cfg.max_iterations, "sampling_mode": env_cfg.commands["motion"].sampling_mode, + "rewind_prob": env_cfg.commands["motion"].rewind_prob, + "rewind_min_steps": env_cfg.commands["motion"].rewind_min_steps, + "rewind_max_steps": env_cfg.commands["motion"].rewind_max_steps, }, ) swanlab.sync_tensorboard_torch(types=["scalar", "scalars", "image", "text"]) @@ -366,6 +375,12 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: validate_motion_file(env_cfg.commands["motion"].motion_file) if args.sampling_mode is not None: env_cfg.commands["motion"].sampling_mode = args.sampling_mode + if args.rewind_prob is not None: + env_cfg.commands["motion"].rewind_prob = args.rewind_prob + if args.rewind_min_steps is not None: + env_cfg.commands["motion"].rewind_min_steps = args.rewind_min_steps + if 
args.rewind_max_steps is not None: + env_cfg.commands["motion"].rewind_max_steps = args.rewind_max_steps if args.max_iterations is not None: agent_cfg.max_iterations = args.max_iterations if args.experiment_name is not None: diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 00dd1e6b..95932374 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -172,7 +172,7 @@ def make_general_tracking_env_cfg( motion_cmd.anchor_body_name = "torso_link" motion_cmd.body_names = _TRACKING_BODY_NAMES motion_cmd.motion_file = DEFAULT_TRAIN_MOTION_FILE - motion_cmd.sampling_mode = "adaptive" + motion_cmd.sampling_mode = "uniform" motion_cmd.window_steps = (0,) cfg.events["physics_material"].params[ diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index c29a69fb..ce16af86 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -589,12 +589,28 @@ def get_frames( MotionLoader = MotionLib +def _validate_rewind_sampling_cfg(cfg: Any) -> None: + if cfg.rewind_min_steps < 0: + raise ValueError( + f"rewind_min_steps must be non-negative, got {cfg.rewind_min_steps}" + ) + if cfg.rewind_max_steps < cfg.rewind_min_steps: + raise ValueError( + "rewind_max_steps must be >= rewind_min_steps, got " + f"{cfg.rewind_max_steps} < {cfg.rewind_min_steps}" + ) + if not 0.0 <= cfg.rewind_prob <= 1.0: + raise ValueError(f"rewind_prob must be in [0, 1], got {cfg.rewind_prob}") + + class MotionCommand(CommandTerm): cfg: MotionCommandCfg _env: ManagerBasedRlEnv def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): super().__init__(cfg, env) + if self.cfg.sampling_mode == "rewind": + _validate_rewind_sampling_cfg(self.cfg) self.robot: Entity = env.scene[cfg.entity_name] self.robot_anchor_body_index = self.robot.body_names.index( @@ -966,6 +982,41 @@ def _adaptive_sampling(self, env_ids: 
torch.Tensor): self._update_adaptive_sampling_metrics(sampling_probabilities) + def _rewind_sampling(self, env_ids: torch.Tensor) -> None: + _validate_rewind_sampling_cfg(self.cfg) + + previous_motion_ids = self.motion_ids[env_ids].clone() + previous_motion_times = self.motion_times[env_ids].clone() + self._uniform_sampling(env_ids) + if env_ids.numel() == 0 or self.cfg.rewind_prob <= 0.0: + return + + episode_failed = self._env.termination_manager.terminated[env_ids] + use_rewind = episode_failed & ( + torch.rand(env_ids.numel(), device=self.device) < self.cfg.rewind_prob + ) + if not torch.any(use_rewind): + return + + rewind_env_ids = env_ids[use_rewind] + rewind_steps = torch.randint( + self.cfg.rewind_min_steps, + self.cfg.rewind_max_steps + 1, + (rewind_env_ids.numel(),), + device=self.device, + dtype=torch.long, + ) + rewind_s = rewind_steps.to(dtype=self.motion_times.dtype) * float(self._step_dt) + motion_ids = previous_motion_ids[use_rewind] + rewind_times = previous_motion_times[use_rewind] - rewind_s + rewind_times = torch.maximum( + rewind_times, + self.motion.clip_sample_start_s[motion_ids], + ) + + self.motion_ids[rewind_env_ids] = motion_ids + self.motion_times[rewind_env_ids] = rewind_times + def _resample_command(self, env_ids: torch.Tensor): if self.cfg.sampling_mode == "start": self.motion_ids[env_ids] = self.motion.sample_motion_ids(len(env_ids)) @@ -974,10 +1025,12 @@ def _resample_command(self, env_ids: torch.Tensor): self._uniform_sampling(env_ids) elif self.cfg.sampling_mode == "adaptive": self._adaptive_sampling(env_ids) + elif self.cfg.sampling_mode == "rewind": + self._rewind_sampling(env_ids) else: raise ValueError( f"Unsupported motion sampling_mode={self.cfg.sampling_mode!r}. " - "Supported modes are 'uniform', 'start', and 'adaptive'." + "Supported modes are 'uniform', 'start', 'adaptive', and 'rewind'." 
) if env_ids.numel() == 0: @@ -1236,11 +1289,14 @@ class MotionCommandCfg(CommandTermCfg): pose_range: dict[str, tuple[float, float]] = field(default_factory=dict) velocity_range: dict[str, tuple[float, float]] = field(default_factory=dict) joint_position_range: tuple[float, float] = (-0.52, 0.52) - sampling_mode: Literal["uniform", "start", "adaptive"] = "uniform" + sampling_mode: Literal["uniform", "start", "adaptive", "rewind"] = "uniform" window_steps: tuple[int, ...] = (0,) adaptive_bin_size_frames: int = 10 adaptive_uniform_ratio: float = 0.1 adaptive_alpha: float = 0.001 + rewind_prob: float = 0.8 + rewind_min_steps: int = 25 + rewind_max_steps: int = 75 feet_body_names: tuple[str, ...] = () feet_standing_z_threshold: float = 0.18 feet_standing_vxy_threshold: float = 0.2 From 77bc91d4f9636fa3008c40ece4f07e2050cae190 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 11 Jun 2026 21:42:15 +0800 Subject: [PATCH 077/122] Add reference base height to velcmd observation --- AGENTS.md | 9 +- docs/docs/configuration/faq.md | 3 +- docs/docs/intro.md | 4 +- docs/docs/reference/architecture.md | 10 +- docs/docs/tutorials/training.md | 2 +- .../current/configuration/faq.md | 3 +- .../current/intro.md | 4 +- .../current/reference/architecture.md | 96 +++++++++---------- .../current/tutorials/training.md | 2 +- scripts/run/standalone_standing.py | 21 ++-- teleopit/configs/robot/g1.yaml | 2 +- teleopit/controllers/observation.py | 10 +- teleopit/runtime/factory.py | 5 - tests/test_controller.py | 20 ++-- tests/test_e2e.py | 2 +- tests/test_observation.py | 8 +- tests/test_pipeline.py | 30 +++--- tests/test_save_onnx.py | 6 +- tests/test_sim2real_dim.py | 6 +- tests/test_task_registry.py | 1 + train_mimic/tasks/tracking/config/env.py | 8 ++ .../tasks/tracking/mdp/observations.py | 4 +- 22 files changed, 133 insertions(+), 123 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 81dc954b..6feae533 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,7 +11,7 @@ Config: 
Hydra/OmegaConf YAML files in `teleopit/configs/` ## Architecture ``` -InputProvider (BVH file / Pico4 VR) → Retargeter (GMR) → ObservationBuilder (166D) → Controller (dual-input TemporalCNN ONNX) → Robot (MuJoCo + PD / Unitree SDK) +InputProvider (BVH file / Pico4 VR) → Retargeter (GMR) → ObservationBuilder (167D) → Controller (dual-input TemporalCNN ONNX) → Robot (MuJoCo + PD / Unitree SDK) ``` Module-internal isolation: all modules run in-process and communicate via `InProcessBus` (zero-copy). Core interfaces are defined as `typing.Protocol` in `teleopit/interfaces.py`. @@ -19,7 +19,7 @@ Module-internal isolation: all modules run in-process and communicate via `InPro ## Supported Surface - Training task: `General-Tracking-G1` -- Inference observation: `velcmd_history` (166D, dual-input ONNX with `obs` + `obs_history`) +- Inference observation: `velcmd_history` (167D, dual-input ONNX with `obs` + `obs_history`) - TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - Realtime inference uses a retargeted-reference timeline before observation build; `reference_steps=[0]` is the default production path @@ -165,7 +165,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Realtime Pico sim2sim can start directly in `STANDING` with keyboard mode control enabled via top-level `keyboard.enabled` ### Inference Observation -Observation format: `velcmd_history` (166D, dual-input ONNX) +Observation format: `velcmd_history` (167D, dual-input ONNX) ``` command(58) @@ -178,6 +178,7 @@ command(58) + ref_base_lin_vel_b(3) + ref_base_ang_vel_b(3) + ref_projected_gravity_b(3) ++ ref_base_height(1) ``` Runtime constraints: @@ -189,7 +190,7 @@ Runtime constraints: The single supported training task is `General-Tracking-G1` (experiment name: `g1_general_tracking`). 
- Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) -- 166D `velcmd_history` observation, dual-input ONNX export +- 167D `velcmd_history` observation, dual-input ONNX export - Training env uses `sampling_mode="uniform"` - Supported motion sampling modes are `adaptive`, `uniform`, `start`, and `rewind`; `rewind` restarts failed environments from the same clip after stepping back `rewind_min_steps..rewind_max_steps` with probability `rewind_prob`, otherwise it falls back to uniform sampling - Playback/benchmark use `play=True`, which switches motion sampling to `start` diff --git a/docs/docs/configuration/faq.md b/docs/docs/configuration/faq.md index 2509abcb..ccc28537 100644 --- a/docs/docs/configuration/faq.md +++ b/docs/docs/configuration/faq.md @@ -7,8 +7,7 @@ sidebar_position: 3 ## Why does it fail even though I set `policy_path`? 1. Verify the file exists -2. Confirm it's not an old 1402D / TWIST2 ONNX model -3. Confirm the input dimension is `166` with dual inputs (`obs` + `obs_history`) +2. Confirm the input dimension is `167` with dual inputs (`obs` + `obs_history`) ## Why must I specify `input.bvh_file` explicitly? 
diff --git a/docs/docs/intro.md b/docs/docs/intro.md index b26720b9..9e8cb8be 100644 --- a/docs/docs/intro.md +++ b/docs/docs/intro.md @@ -22,7 +22,7 @@ slug: / ```text InputProvider (BVH / Pico4 VR) -> Retargeter (GMR) - -> ObservationBuilder (166D) + -> ObservationBuilder (167D) -> Controller (dual-input TemporalCNN ONNX) -> Robot (MuJoCo sim or Unitree G1) ``` @@ -33,7 +33,7 @@ InputProvider (BVH / Pico4 VR) |------|-------| | Policy frequency | 50 Hz | | PD control frequency | 1000 Hz | -| Observation dimension | 166D | +| Observation dimension | 167D | | Action dimension | 29D (G1 joints) | | ONNX model | Dual-input TemporalCNN | | Retargeting | GMR (General Motion Retargeting) | diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index b3edd11b..1606ea6e 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -11,7 +11,7 @@ System internals and technical constraints for developers. ```text InputProvider (BVH file / Pico4) -> Retargeter (GMR) - -> ObservationBuilder (166D) + -> ObservationBuilder (167D) -> Controller (dual-input TemporalCNN ONNX) -> Robot (MuJoCo sim or Unitree G1) ``` @@ -45,7 +45,7 @@ train_mimic/scripts/data | `teleopit/pipeline.py` | Lightweight facade for offline sim | | `teleopit/sim2real/mp/` | Process-isolated sim2real state machine, IPC, and robot-control loop | | `teleopit/controllers/observation.py` | ObservationBuilder | -| `teleopit/controllers/rl_policy.py` | Only accepts 166D dual-input ONNX | +| `teleopit/controllers/rl_policy.py` | Accepts dual-input ONNX whose observation dimension matches the runtime builder | | `train_mimic/app.py` | Shared train/play/benchmark assembly | | `train_mimic/tasks/tracking/config/` | Single task registration (`General-Tracking-G1`) | | `train_mimic/data/dataset_builder.py` | Sole official dataset construction entry | @@ -55,8 +55,8 @@ train_mimic/scripts/data | Spec | Value | |------|-------| | Training task | 
`General-Tracking-G1` | -| Inference observation | `velcmd_history` (166D) | -| ONNX signature | Dual-input `obs` (166D) + `obs_history` | +| Inference observation | `velcmd_history` (167D) | +| ONNX signature | Dual-input `obs` (167D) + `obs_history` | | Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | | Training sampling | Default `uniform`; also supports `adaptive` and `rewind`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | @@ -68,7 +68,7 @@ train_mimic/scripts/data - Offline BVH runs require explicit `input.bvh_file` - `viewers` is the sole viewer configuration entry - Observation/ONNX dimension mismatch causes immediate startup error -- sim2real also only supports 166D dual-input ONNX +- sim2real also requires a dual-input ONNX whose observation dimension matches the runtime builder ## Public Surface diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index 1e57142e..e3356ae2 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -86,7 +86,7 @@ python train_mimic/scripts/save_onnx.py \ --history_length 10 ``` -The exported model is a dual-input ONNX (`obs` + `obs_history`). The inference side only supports 166D dual-input ONNX. +The exported model is a dual-input ONNX (`obs` + `obs_history`). The inference side expects a 167D dual-input ONNX policy matching the current `velcmd_history` observation. ## Evaluation diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md index a3a75e25..65856f4e 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md @@ -7,8 +7,7 @@ sidebar_position: 3 ## 为什么设置了 `policy_path` 还是启动不了? 1. 确认文件存在 -2. 确认不是旧的 1402D / TWIST2 ONNX 模型 -3. 确认输入维度是 `166`,且为双输入 ONNX(`obs` + `obs_history`) +2. 
确认输入维度是 `167`,且为双输入 ONNX(`obs` + `obs_history`) ## 为什么必须显式指定 `input.bvh_file`? diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md index 8856a069..39969812 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md @@ -20,7 +20,7 @@ slug: / ```text InputProvider (BVH / Pico4 VR) -> Retargeter (GMR) - -> ObservationBuilder (166D) + -> ObservationBuilder (167D) -> Controller (双输入 TemporalCNN ONNX) -> Robot (MuJoCo 仿真 或 Unitree G1) ``` @@ -31,7 +31,7 @@ InputProvider (BVH / Pico4 VR) |------|------| | 策略频率 | 50 Hz | | PD 控制频率 | 1000 Hz | -| 观测维度 | 166D | +| 观测维度 | 167D | | 动作维度 | 29D(G1 关节) | | ONNX 模型 | 双输入 TemporalCNN | | 运动重定向 | GMR(General Motion Retargeting) | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index b73e931e..d82232c2 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -4,80 +4,76 @@ sidebar_position: 1 # 架构 -本页介绍 Teleopit 的整体架构设计、核心模块边界及技术规格。 +面向开发者的系统内部结构和技术约束。 -## 运行主线 Pipeline +## Pipeline -Teleopit 的运行主线是一条线性数据流管线: - -``` -输入源 (Input) - → 重定向 (Retarget) - → 观测构建 (Observation) - → 策略推理 (Policy) - → PD 控制 (Controller) - → 仿真/实机执行 (Sim / Real) +```text +InputProvider(BVH 文件 / Pico4) + -> Retargeter(GMR) + -> ObservationBuilder(167D) + -> Controller(双输入 TemporalCNN ONNX) + -> Robot(MuJoCo 仿真或 Unitree G1) ``` -每个环节职责单一,通过明确定义的接口相互连接。 +离线/在线推理由 `teleopit/runtime/` 和 `teleopit/pipeline.py` 装配。硬件状态机通过 `teleopit/sim2real/mp/` 中的进程隔离运行时执行。训练由 `train_mimic/` 提供。 ## 代码结构 -``` -teleopit/ -├── app.py # 应用入口 -├── interfaces.py # 核心接口定义 -├── runtime/ # 运行时配置装配与启动逻辑 -├── pipeline/ # 数据流管线 -├── sim2real/ # 实机部署适配层 
-├── observation/ # 观测构建 -├── rl_policy/ # 强化学习策略推理 -├── task/ # 任务配置 -└── dataset_builder/ # 数据集构建工具 +```text +configs / scripts + -> runtime + -> interfaces + pipeline state machines + -> adapters(inputs / retargeting / controller / robot / recording) + +train_mimic/scripts + -> train_mimic/app.py + -> single task registry / env builder / runner cfg + -> mjlab / rsl_rl + +train_mimic/scripts/data + -> train_mimic/data/dataset_builder.py + -> dataset_lib / motion_fk / convert_pkl_to_npz ``` ## 核心模块边界 -| 模块 | 文件/目录 | 职责 | -|---|---|---| -| 接口层 | `interfaces.py` | 定义所有核心抽象接口,模块间仅通过接口通信 | -| 运行时 | `runtime/` | Hydra 配置加载、对象组装、依赖注入 | -| Pipeline | `pipeline/` | 数据流编排,驱动每一帧的采样-推理-执行循环 | -| Sim2Real | `sim2real/mp/` | 进程隔离的实机状态机、IPC 与机器人控制循环 | -| 观测 | `observation/` | 从仿真/实机状态构建策略所需的观测向量 | -| 策略 | `rl_policy/` | ONNX 模型加载与推理,action 后处理 | -| 入口 | `app.py` | 命令行入口,调用 runtime 装配并启动 pipeline | -| 任务配置 | `task/` | Hydra 配置文件(YAML) | -| 数据集 | `dataset_builder/` | 动捕数据转换、NPZ 打包、数据集分片 | +| 模块 | 职责 | +|------|------| +| `teleopit/interfaces.py` | 稳定协议:InputProvider、Retargeter、Controller、Robot、ObservationBuilder、Recorder | +| `teleopit/runtime/` | 配置解析、路径规范化、组件装配、CLI 校验 | +| `teleopit/pipeline.py` | 离线仿真的轻量 facade | +| `teleopit/sim2real/mp/` | 进程隔离的 sim2real 状态机、IPC 和机器人控制循环 | +| `teleopit/controllers/observation.py` | ObservationBuilder | +| `teleopit/controllers/rl_policy.py` | 接受观测维度与运行时 builder 匹配的双输入 ONNX | +| `train_mimic/app.py` | 共享的训练/播放/benchmark 装配 | +| `train_mimic/tasks/tracking/config/` | 单一任务注册(`General-Tracking-G1`) | +| `train_mimic/data/dataset_builder.py` | 唯一官方数据集构建入口 | ## 技术规格 | 项目 | 规格 | |---|---| -| 仿真引擎 | MuJoCo | -| 策略推理 | ONNX Runtime | -| 配置系统 | Hydra | -| 实机通信 | DDS(Cyclone DDS) | -| 支持的机器人 | 宇树 G1 | -| 输入源 | BVH 文件、Pico 4 VR 头显 | | 训练任务 | `General-Tracking-G1` | -| 推理观测 | `velcmd_history`(166D) | +| 推理观测 | `velcmd_history`(167D) | +| ONNX 签名 | 双输入 `obs`(167D)+ `obs_history` | +| Actor/Critic | TemporalCNN(2048、1024、512、256、128) | | 训练采样 | 默认 
`uniform`;也支持 `adaptive` 和 `rewind`;播放/评估使用 `start` | | 训练 `window_steps` | `[0]` | +| 数据格式 | 仅 shard 目录(`shard_*.npz`) | ## 约束 -- **单机器人**:当前架构假设同一时刻只控制一台机器人。 -- **固定观测格式**:观测构建器的输出维度在初始化时确定,运行时不可变。 -- **同步 pipeline**:pipeline 各阶段串行执行,策略推理频率即为 pipeline 的帧率。 -- **ONNX 模型**:策略必须导出为 ONNX 格式,不直接支持 PyTorch checkpoint。 +- 必须显式提供 `controller.policy_path`,且文件必须存在 +- 离线 BVH 运行必须显式提供 `input.bvh_file` +- `viewers` 是唯一的 viewer 配置入口 +- 观测/ONNX 维度不匹配会在启动时立即报错 +- sim2real 也要求双输入 ONNX,且观测维度必须与运行时 builder 匹配 ## 公共接口 -以下接口构成 Teleopit 的公共 API 面,外部扩展应仅依赖这些接口: +**稳定运行模式:** 离线 sim2sim、离线 sim2real playback、Pico4 sim2sim、G1 sim2real -- `interfaces.py` 中定义的抽象基类(InputProvider、ObsBuilder、Controller 等) -- `runtime/` 提供的工厂注册机制 -- Hydra 配置 schema +**稳定训练入口:** `train.py`、`play.py`、`benchmark.py`、`save_onnx.py` -内部实现细节(如 pipeline 的具体调度逻辑)不属于公共 API,可能在版本迭代中变更。 +**稳定数据入口:** `build_dataset.py`、`split_shards.py` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index 4e385e4b..7ea3c7e9 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -86,7 +86,7 @@ python train_mimic/scripts/save_onnx.py \ --history_length 10 ``` -导出的模型为双输入 ONNX(`obs` + `obs_history`)。推理端仅支持 166D 双输入 ONNX 格式。 +导出的模型为双输入 ONNX(`obs` + `obs_history`)。推理端需要与当前 `velcmd_history` 观测匹配的 167D 双输入 ONNX 策略。 ## 评估 diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 007c2085..cbab7720 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -184,7 +184,7 @@ def quat_to_rot6d(q): # ===================================================================== class ObservationBuilder: - """166D VelCmd observation builder using MuJoCo FK.""" + """167D VelCmd observation builder using MuJoCo FK.""" def __init__(self, xml_path: str): 
self._mj_model = mujoco.MjModel.from_xml_path(xml_path) @@ -196,10 +196,10 @@ def __init__(self, xml_path: str): # Base obs: command(29+29) + anchor_ori_b(6) + ang_vel(3) + joint_pos_rel(29) + qvel(29) + last_act(29) # = 154 - # VelCmd extra: projected_gravity(3) + ref_lin_vel_b(3) + ref_ang_vel_b(3) + ref_proj_gravity(3) - # = 12 - # Total = 166 - self.total_obs_size = NUM_JOINTS * 2 + 6 + 3 + NUM_JOINTS * 3 + 12 + # VelCmd extra: projected_gravity(3) + ref_lin_vel_b(3) + ref_ang_vel_b(3) + # + ref_proj_gravity(3) + ref_base_height(1) = 13 + # Total = 167 + self.total_obs_size = NUM_JOINTS * 2 + 6 + 3 + NUM_JOINTS * 3 + 13 # Precompute motion torso offset for standing (DEFAULT_ANGLES with identity base) # torso_quat_world = quat_mul(base_quat, torso_offset) for constant joint angles @@ -219,9 +219,12 @@ def _run_fk(self, base_pos, base_quat, joint_pos): def _get_body_quat(self, body_id): return np.asarray(self._mj_data.xquat[body_id], dtype=np.float32).copy() + def _get_body_pos(self, body_id): + return np.asarray(self._mj_data.xpos[body_id], dtype=np.float32).copy() + def build(self, robot_qpos, robot_qvel, robot_quat, robot_ang_vel, motion_qpos, motion_joint_vel, last_action): - """Build 166D observation for VelCmd policy. + """Build 167D observation for VelCmd policy. 
Args: robot_qpos: (29,) current joint positions @@ -246,6 +249,8 @@ def build(self, robot_qpos, robot_qvel, robot_quat, robot_ang_vel, motion_base_quat = motion[3:7] motion_joint_pos = motion[7:7 + NUM_JOINTS] + self._run_fk(motion[0:3], motion_base_quat, motion_joint_pos) + motion_anchor_pos = self._get_body_pos(self._anchor_body_id) motion_anchor_quat = quat_mul(motion_base_quat, self._standing_torso_offset) # Base observation (154D) @@ -263,19 +268,21 @@ def build(self, robot_qpos, robot_qvel, robot_quat, robot_ang_vel, last_act, # 29 ], dtype=np.float32) - # VelCmd extra (12D) -- standing has zero reference velocities + # VelCmd extra (13D) -- standing has zero reference velocities projected_gravity = quat_rotate(quat_inv(robot_q), GRAVITY_UNIT_W) robot_inv = quat_inv(robot_anchor_quat) # Zero reference velocities for standing ref_lin_vel_b = np.zeros(3, dtype=np.float32) ref_ang_vel_b = np.zeros(3, dtype=np.float32) ref_proj_gravity = quat_rotate(quat_inv(motion_anchor_quat), GRAVITY_UNIT_W) + ref_base_height = motion_anchor_pos[2:3] velcmd_obs = np.concatenate([ projected_gravity, # 3 ref_lin_vel_b, # 3 ref_ang_vel_b, # 3 ref_proj_gravity, # 3 + ref_base_height, # 1 ], dtype=np.float32) obs = np.concatenate([base_obs, velcmd_obs], dtype=np.float32) diff --git a/teleopit/configs/robot/g1.yaml b/teleopit/configs/robot/g1.yaml index f1cfa837..f46d4487 100644 --- a/teleopit/configs/robot/g1.yaml +++ b/teleopit/configs/robot/g1.yaml @@ -3,7 +3,7 @@ # Do NOT copy from deploy.yaml (contains known waist_yaw scale error). num_actions: 29 -# Inference uses the 166D velcmd_history observation path (General-Tracking-G1). +# Inference uses the 167D velcmd_history observation path (General-Tracking-G1). 
kps: [40.2, 99.1, 40.2, 99.1, 28.5, 28.5, 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, diff --git a/teleopit/controllers/observation.py b/teleopit/controllers/observation.py index b637b5b8..fd1fa987 100644 --- a/teleopit/controllers/observation.py +++ b/teleopit/controllers/observation.py @@ -107,7 +107,7 @@ def _quat_to_rot6d_np(q: FloatVec) -> FloatVec: @final class _VelCmdBaseObservationBuilder: - """Internal base block used by the public 166D VelCmd builder.""" + """Internal base block used by the public 167D VelCmd builder.""" def __init__(self, cfg: ConfigType) -> None: self.num_actions: int = _as_int_scalar(cfg_get(cfg, "num_actions"), "num_actions") @@ -215,12 +215,12 @@ def build( @final class VelCmdObservationBuilder: - """166D observation builder for the only supported VelCmdHistory policy.""" + """167D observation builder for the only supported VelCmdHistory policy.""" def __init__(self, cfg: ConfigType) -> None: self._base = _VelCmdBaseObservationBuilder(cfg) self.num_actions = self._base.num_actions - self.total_obs_size = self._base.total_obs_size + 12 + self.total_obs_size = self._base.total_obs_size + 13 def reset(self) -> None: self._base.reset() @@ -239,6 +239,7 @@ def build( robot_quat = np.asarray(robot_state.quat, dtype=np.float32).reshape(-1) motion = np.asarray(motion_qpos, dtype=np.float32).reshape(-1) self._base._run_fk(motion[0:3], motion[3:7], motion[7:7 + self.num_actions]) + motion_anchor_pos = self._base._get_body_pos(self._base._anchor_body_id) motion_anchor_quat = self._base._get_body_quat(self._base._anchor_body_id) qpos = np.asarray(robot_state.qpos, dtype=np.float32).reshape(-1)[: self.num_actions] @@ -255,12 +256,14 @@ def build( ref_base_lin_vel_b = _quat_rotate_np(robot_inv, ref_lin_vel_w) ref_base_ang_vel_b = _quat_rotate_np(robot_inv, ref_ang_vel_w) ref_projected_gravity_b = _quat_rotate_np(_quat_inv_np(motion_anchor_quat), _GRAVITY_UNIT_W) + ref_base_height = motion_anchor_pos[2:3] velcmd_obs = np.concatenate([ projected_gravity, 
ref_base_lin_vel_b, ref_base_ang_vel_b, ref_projected_gravity_b, + ref_base_height, ], dtype=np.float32) obs = np.concatenate([base_obs, velcmd_obs], dtype=np.float32) if obs.shape[0] != self.total_obs_size: @@ -282,4 +285,3 @@ def build_observation( "Use build(robot_state, motion_qpos, motion_joint_vel, last_action, " "motion_anchor_lin_vel_w, motion_anchor_ang_vel_w)." ) - diff --git a/teleopit/runtime/factory.py b/teleopit/runtime/factory.py index 3466b613..ba581c8f 100644 --- a/teleopit/runtime/factory.py +++ b/teleopit/runtime/factory.py @@ -152,11 +152,6 @@ def _build_policy_components( policy_dim = getattr(controller, "_expected_obs_dim", None) builder_dim = getattr(obs_builder, "total_obs_size", None) if policy_dim is not None and builder_dim is not None and policy_dim != builder_dim: - if builder_dim == 166: - raise ValueError( - f"Only 166D velcmd_history ONNX policies are supported here; " - f"obs_builder produces 166D but policy expects {policy_dim}D." - ) raise ValueError( f"Observation dimension mismatch at startup: obs_builder produces {builder_dim}D " f"but policy expects {policy_dim}D. Use a matching ONNX model." 
diff --git a/tests/test_controller.py b/tests/test_controller.py index 3a07d103..e86e41ca 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -41,11 +41,11 @@ def test_normalize_clip_range_inverted_raises(self): def test_extract_feature_dim_int(self): from teleopit.controllers.rl_policy import RLPolicyController - assert RLPolicyController._extract_feature_dim([1, 166]) == 166 + assert RLPolicyController._extract_feature_dim([1, 167]) == 167 def test_extract_feature_dim_string(self): from teleopit.controllers.rl_policy import RLPolicyController - assert RLPolicyController._extract_feature_dim(["batch", "166"]) == 166 + assert RLPolicyController._extract_feature_dim(["batch", "167"]) == 167 def test_extract_feature_dim_dynamic(self): from teleopit.controllers.rl_policy import RLPolicyController @@ -105,7 +105,7 @@ class TestRLPolicyControllerInference: def test_compute_action_shape(self): from teleopit.controllers.rl_policy import RLPolicyController - obs_dim = 166 + obs_dim = 167 action_dim = 29 # Build a controller with mocked internals @@ -138,7 +138,7 @@ def test_compute_action_wrong_dim_raises(self): from teleopit.controllers.rl_policy import RLPolicyController ctrl = RLPolicyController.__new__(RLPolicyController) - ctrl._expected_obs_dim = 166 + ctrl._expected_obs_dim = 167 ctrl.clip_range = (-10.0, 10.0) ctrl.action_scale = np.ones(29, dtype=np.float32) ctrl._session = MagicMock() @@ -154,8 +154,8 @@ def test_reset_is_noop(self): ctrl = RLPolicyController.__new__(RLPolicyController) from collections import deque ctrl._history_buf = deque(maxlen=3) - ctrl._last_obs_input = np.zeros(166, dtype=np.float32) - ctrl._last_obs_history_input = np.zeros((3, 166), dtype=np.float32) + ctrl._last_obs_input = np.zeros(167, dtype=np.float32) + ctrl._last_obs_history_input = np.zeros((3, 167), dtype=np.float32) assert ctrl.reset() is None assert len(ctrl._history_buf) == 0 assert ctrl._last_obs_input is None @@ -165,7 +165,7 @@ def 
test_multi_input_debug_inputs_capture_history(self): from teleopit.controllers.rl_policy import RLPolicyController ctrl = RLPolicyController.__new__(RLPolicyController) - ctrl._expected_obs_dim = 166 + ctrl._expected_obs_dim = 167 ctrl.clip_range = (-10.0, 10.0) ctrl.action_scale = np.ones(2, dtype=np.float32) ctrl.default_dof_pos = np.zeros(2, dtype=np.float32) @@ -173,7 +173,7 @@ def test_multi_input_debug_inputs_capture_history(self): ctrl._output_name = "action" ctrl._multi_input = True ctrl._history_length = 3 - ctrl._history_obs_dim = 166 + ctrl._history_obs_dim = 167 from collections import deque ctrl._history_buf = deque(maxlen=3) ctrl._last_obs_input = None @@ -182,9 +182,9 @@ def test_multi_input_debug_inputs_capture_history(self): mock_session.run.return_value = [np.zeros((1, 2), dtype=np.float32)] ctrl._session = mock_session - obs = np.zeros(166, dtype=np.float32) + obs = np.zeros(167, dtype=np.float32) ctrl.compute_action(obs) debug = ctrl.get_debug_inputs() assert debug["obs"] is not None assert debug["obs_history"] is not None - assert debug["obs_history"].shape == (3, 166) + assert debug["obs_history"].shape == (3, 167) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 879210c2..904515cd 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -46,7 +46,7 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: policy_path, bvh_path, xml_path = _asset_paths(project_root) if not policy_path.exists() or not bvh_path.exists() or not xml_path.exists(): - pytest.skip("set TELEOPIT_TEST_POLICY_ONNX to a compatible 166D ONNX policy to run this e2e test") + pytest.skip("set TELEOPIT_TEST_POLICY_ONNX to a compatible 167D ONNX policy to run this e2e test") robot_cfg = OmegaConf.load(project_root / "teleopit" / "configs" / "robot" / "g1.yaml") controller_cfg = OmegaConf.load(project_root / "teleopit" / "configs" / "controller" / "rl_policy.yaml") diff --git a/tests/test_observation.py b/tests/test_observation.py index 
0086125b..64dcbddb 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -70,7 +70,7 @@ def test_rotate_motion_qpos_by_yaw_keeps_first_frame_position_with_pivot() -> No @requires_mujoco @_skip_no_xml class TestVelCmdObservationBuilder: - def test_output_dimension_is_166(self) -> None: + def test_output_dimension_is_167(self) -> None: builder = VelCmdObservationBuilder(_velcmd_cfg()) obs = builder.build( _make_state(), @@ -80,15 +80,17 @@ def test_output_dimension_is_166(self) -> None: np.zeros(3, dtype=np.float32), np.zeros(3, dtype=np.float32), ) - assert builder.total_obs_size == 166 - assert obs.shape == (166,) + assert builder.total_obs_size == 167 + assert obs.shape == (167,) assert obs.dtype == np.float32 + assert obs[-1] > 0.0 def test_projected_gravity_uses_root_quat_not_anchor_quat(self) -> None: builder = VelCmdObservationBuilder(_velcmd_cfg()) builder._base.build = lambda *args, **kwargs: np.zeros(154, dtype=np.float32) # type: ignore[method-assign] builder._base._run_fk = lambda *args, **kwargs: None # type: ignore[method-assign] builder._base._anchor_body_id = 0 + builder._base._get_body_pos = lambda _body_id: np.array([0.0, 0.0, 1.0], dtype=np.float32) # type: ignore[method-assign] builder._base._get_body_quat = lambda _body_id: np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) # type: ignore[method-assign] root_quat = np.array( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 81419cc4..17333348 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -18,7 +18,7 @@ def __init__(self, cfg: object) -> None: captured["robot_cfg"] = cfg class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -28,7 +28,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: captured["obs_cfg"] = cfg @@ -119,7 +119,7 @@ def __init__(self, cfg: object) -> 
None: pass class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -129,7 +129,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: pass @@ -188,7 +188,7 @@ def __init__(self, cfg: object) -> None: pass class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -198,7 +198,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: pass @@ -268,7 +268,7 @@ def __init__(self, cfg: object) -> None: pass class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -278,7 +278,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: pass @@ -384,7 +384,7 @@ def _pipeline_cfg(tmp_path: Path) -> dict: @requires_mujoco @_skip_no_xml -def test_pipeline_166d_policy_required(monkeypatch, tmp_path: Path) -> None: +def test_pipeline_policy_dim_mismatch_required(monkeypatch, tmp_path: Path) -> None: class DummyController160: _expected_obs_dim = 160 _multi_input = True @@ -420,7 +420,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: monkeypatch.setattr("teleopit.pipeline.SimulationLoop", DummyLoop) cfg = OmegaConf.create(_pipeline_cfg(tmp_path)) - with pytest.raises(ValueError, match="Only 166D"): + with pytest.raises(ValueError, match="Observation dimension mismatch"): TeleopPipeline(cfg) @@ -428,7 +428,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: @_skip_no_xml def test_pipeline_requires_dual_input_policy(monkeypatch, tmp_path: Path) -> None: class DummyControllerSingle: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = False 
def __init__(self, cfg: object) -> None: @@ -469,8 +469,8 @@ def __init__(self, *args: object, **kwargs: object) -> None: @requires_mujoco @_skip_no_xml def test_pipeline_dim_match_passes(monkeypatch, tmp_path: Path) -> None: - class DummyController166: - _expected_obs_dim = 166 + class DummyController167: + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -497,7 +497,7 @@ class DummyLoop: def __init__(self, *args: object, **kwargs: object) -> None: pass - monkeypatch.setattr("teleopit.pipeline.RLPolicyController", DummyController166) + monkeypatch.setattr("teleopit.pipeline.RLPolicyController", DummyController167) monkeypatch.setattr("teleopit.pipeline.MuJoCoRobot", DummyRobot) monkeypatch.setattr("teleopit.pipeline.BVHInputProvider", DummyInputProvider) monkeypatch.setattr("teleopit.pipeline.RetargetingModule", DummyRetargeter) @@ -505,4 +505,4 @@ def __init__(self, *args: object, **kwargs: object) -> None: cfg = OmegaConf.create(_pipeline_cfg(tmp_path)) pipeline = TeleopPipeline(cfg) - assert pipeline.obs_builder.total_obs_size == 166 + assert pipeline.obs_builder.total_obs_size == 167 diff --git a/tests/test_save_onnx.py b/tests/test_save_onnx.py index c351b2ac..d2ff0ad6 100644 --- a/tests/test_save_onnx.py +++ b/tests/test_save_onnx.py @@ -11,7 +11,7 @@ def _build_temporal_actor_model( *, - obs_dim: int = 166, + obs_dim: int = 167, history_length: int = 10, ref_window_dim: int | None = None, ref_window_length: int = 20, @@ -51,7 +51,7 @@ def _build_temporal_actor_model( ) -def _build_temporal_actor_state_dict(obs_dim: int = 166, history_length: int = 10) -> dict[str, torch.Tensor]: +def _build_temporal_actor_state_dict(obs_dim: int = 167, history_length: int = 10) -> dict[str, torch.Tensor]: return _build_temporal_actor_model(obs_dim=obs_dim, history_length=history_length).state_dict() @@ -113,7 +113,7 @@ def fake_export(model, args, output_path, **kwargs): # type: ignore[no-untyped- assert captured["output_path"] == 
str(tmp_path / "policy.onnx") assert captured["input_names"] == ["obs", "obs_history"] - assert captured["arg_shapes"] == [(1, 166), (1, 10, 166)] + assert captured["arg_shapes"] == [(1, 167), (1, 10, 167)] def test_export_temporal_cnn_multi_group_checkpoint(monkeypatch, tmp_path: Path) -> None: diff --git a/tests/test_sim2real_dim.py b/tests/test_sim2real_dim.py index dea8a8dd..31dc247a 100644 --- a/tests/test_sim2real_dim.py +++ b/tests/test_sim2real_dim.py @@ -24,7 +24,7 @@ def _cfg() -> dict[str, object]: def test_robot_worker_requires_dual_input_policy(monkeypatch) -> None: policy = SimpleNamespace(_multi_input=False) - obs_builder = SimpleNamespace(total_obs_size=166) + obs_builder = SimpleNamespace(total_obs_size=167) monkeypatch.setattr( "teleopit.sim2real.mp.runtime._build_policy_components", lambda **_kwargs: (policy, obs_builder), @@ -37,9 +37,9 @@ def test_robot_worker_requires_dual_input_policy(monkeypatch) -> None: worker._build_policy_and_obs() -def test_robot_worker_accepts_166d_dual_input_policy(monkeypatch) -> None: +def test_robot_worker_accepts_167d_dual_input_policy(monkeypatch) -> None: policy = SimpleNamespace(_multi_input=True) - obs_builder = SimpleNamespace(total_obs_size=166) + obs_builder = SimpleNamespace(total_obs_size=167) monkeypatch.setattr( "teleopit.sim2real.mp.runtime._build_policy_components", lambda **_kwargs: (policy, obs_builder), diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 0f6797bf..ea0944c9 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -27,6 +27,7 @@ def test_general_tracking_task_is_registered() -> None: assert "ref_base_lin_vel_b" in terms assert "ref_base_ang_vel_b" in terms assert "ref_projected_gravity_b" in terms + assert "ref_base_height" in terms assert "motion_anchor_pos_b" not in actor_terms assert "base_lin_vel" not in actor_terms assert "actor_history" in env_cfg.observations diff --git a/train_mimic/tasks/tracking/config/env.py 
b/train_mimic/tasks/tracking/config/env.py index 95932374..880662c0 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -93,6 +93,10 @@ def _add_history_obs_groups( func=mdp.ref_projected_gravity_b, params={"command_name": "motion"}, ), + "ref_base_height": ObservationTermCfg( + func=mdp.ref_base_height, + params={"command_name": "motion"}, + ), } _VELCMD_CRITIC_TERMS: dict[str, ObservationTermCfg] = { @@ -109,6 +113,10 @@ def _add_history_obs_groups( func=mdp.ref_projected_gravity_b, params={"command_name": "motion"}, ), + "ref_base_height": ObservationTermCfg( + func=mdp.ref_base_height, + params={"command_name": "motion"}, + ), } diff --git a/train_mimic/tasks/tracking/mdp/observations.py b/train_mimic/tasks/tracking/mdp/observations.py index b3180c7f..7b9b66fe 100644 --- a/train_mimic/tasks/tracking/mdp/observations.py +++ b/train_mimic/tasks/tracking/mdp/observations.py @@ -102,14 +102,14 @@ def ref_projected_gravity_b(env: ManagerBasedRlEnv, command_name: str) -> torch. def ref_base_height(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: - """Reference anchor height (z-coordinate). (N, 1) — critic privileged.""" + """Reference anchor height (z-coordinate). (N, 1)""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) return command.anchor_pos_w[:, 2:3] # --------------------------------------------------------------------------- # Yaw-only variants: use yaw_quat(robot_anchor_quat_w) to decouple -# roll/pitch from the coordinate transform, matching the TWIST2 approach. +# roll/pitch from the coordinate transform. 
# --------------------------------------------------------------------------- From 9d664a795e30663bb5b457acef02bb187e1bb099 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 12 Jun 2026 14:24:39 +0800 Subject: [PATCH 078/122] Add tracking velocity and survival rewards --- AGENTS.md | 1 + README.md | 1 + tests/test_task_registry.py | 22 ++++++++++ tests/test_tracking_rewards.py | 27 ++++++++++++- train_mimic/tasks/tracking/mdp/rewards.py | 40 +++++++++++++++++++ .../tasks/tracking/tracking_env_cfg.py | 21 ++++++++++ 6 files changed, 111 insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index 6feae533..9c4af62d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -192,6 +192,7 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g - Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - 167D `velcmd_history` observation, dual-input ONNX export - Training env uses `sampling_mode="uniform"` +- Tracking rewards include root position/orientation/linear velocity/angular velocity, body pose/velocity, joint position/velocity, survival, action-rate, joint-limit, self-collision, and ankle acceleration terms - Supported motion sampling modes are `adaptive`, `uniform`, `start`, and `rewind`; `rewind` restarts failed environments from the same clip after stepping back `rewind_min_steps..rewind_max_steps` with probability `rewind_prob`, otherwise it falls back to uniform sampling - Playback/benchmark use `play=True`, which switches motion sampling to `start` - `window_steps=[0]` diff --git a/README.md b/README.md index 7621ecbd..9b10be17 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. 
- General-Tracking-G1 training now defaults to uniform motion sampling; clip-local adaptive sampling remains available through `sampling_mode=adaptive`. - Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. +- Added root velocity, joint tracking, and survival rewards to the General-Tracking-G1 training objective. ### v0.3.0 (2026-05-12) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index ea0944c9..1419ab47 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -34,6 +34,28 @@ def test_general_tracking_task_is_registered() -> None: assert "critic_history" in env_cfg.observations assert env_cfg.commands["motion"].sampling_mode == "uniform" assert env_cfg.commands["motion"].window_steps == (0,) + assert env_cfg.rewards["motion_global_root_lin_vel"].weight == 1.0 + assert env_cfg.rewards["motion_global_root_lin_vel"].params == { + "command_name": "motion", + "std": 1.0, + } + assert env_cfg.rewards["motion_global_root_ang_vel"].weight == 1.0 + assert env_cfg.rewards["motion_global_root_ang_vel"].params == { + "command_name": "motion", + "std": 3.0, + } + assert env_cfg.rewards["motion_joint_pos"].weight == 1.0 + assert env_cfg.rewards["motion_joint_pos"].params == { + "command_name": "motion", + "std": 0.5, + } + assert env_cfg.rewards["motion_joint_vel"].weight == 0.5 + assert env_cfg.rewards["motion_joint_vel"].params == { + "command_name": "motion", + "std": 3.0, + } + assert env_cfg.rewards["survival"].weight == 3.0 + assert env_cfg.rewards["survival"].params == {} reward = env_cfg.rewards["self_collisions"] assert reward.weight == -0.1 assert reward.params == { diff --git a/tests/test_tracking_rewards.py b/tests/test_tracking_rewards.py index 1550cf3d..377551ad 100644 --- a/tests/test_tracking_rewards.py +++ b/tests/test_tracking_rewards.py @@ -4,7 +4,11 @@ import torch -from train_mimic.tasks.tracking.mdp.rewards 
import self_collision_cost +from train_mimic.tasks.tracking.mdp.rewards import ( + motion_joint_position_error_exp, + self_collision_cost, + survival, +) def _env_with_force_history(force_history: torch.Tensor) -> SimpleNamespace: @@ -29,3 +33,24 @@ def test_self_collision_cost_counts_contact_slots_not_history_frames() -> None: ) torch.testing.assert_close(penalty, torch.tensor([3.0, 0.0])) + + +def test_motion_joint_position_error_exp_uses_mean_squared_error() -> None: + command = SimpleNamespace( + joint_pos=torch.tensor([[0.0, 1.0], [0.5, -0.5]], dtype=torch.float32), + robot_joint_pos=torch.tensor([[0.0, 3.0], [0.5, -0.5]], dtype=torch.float32), + ) + env = SimpleNamespace( + command_manager=SimpleNamespace(get_term=lambda _name: command), + ) + + reward = motion_joint_position_error_exp(env, command_name="motion", std=2.0) + + expected = torch.exp(torch.tensor([-0.5, 0.0])) + torch.testing.assert_close(reward, expected) + + +def test_survival_reward_returns_one_per_env() -> None: + reward = survival(SimpleNamespace(num_envs=3, device=torch.device("cpu"))) + + torch.testing.assert_close(reward, torch.ones(3)) diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index 9644a810..6a08dc44 100644 --- a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -42,6 +42,26 @@ def motion_global_anchor_orientation_error_exp( return torch.exp(-error / std**2) +def motion_global_anchor_linear_velocity_error_exp( + env: ManagerBasedRlEnv, command_name: str, std: float +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + error = torch.sum( + torch.square(command.anchor_lin_vel_w - command.robot_anchor_lin_vel_w), dim=-1 + ) + return torch.exp(-error / std**2) + + +def motion_global_anchor_angular_velocity_error_exp( + env: ManagerBasedRlEnv, command_name: str, std: float +) -> torch.Tensor: + command = cast(MotionCommand, 
env.command_manager.get_term(command_name)) + error = torch.sum( + torch.square(command.anchor_ang_vel_w - command.robot_anchor_ang_vel_w), dim=-1 + ) + return torch.exp(-error / std**2) + + def motion_relative_body_position_error_exp( env: ManagerBasedRlEnv, command_name: str, @@ -114,6 +134,26 @@ def motion_global_body_angular_velocity_error_exp( return torch.exp(-error.mean(-1) / std**2) +def motion_joint_position_error_exp( + env: ManagerBasedRlEnv, command_name: str, std: float +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + error = torch.square(command.joint_pos - command.robot_joint_pos) + return torch.exp(-error.mean(-1) / std**2) + + +def motion_joint_velocity_error_exp( + env: ManagerBasedRlEnv, command_name: str, std: float +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + error = torch.square(command.joint_vel - command.robot_joint_vel) + return torch.exp(-error.mean(-1) / std**2) + + +def survival(env: ManagerBasedRlEnv) -> torch.Tensor: + return torch.ones(env.num_envs, device=env.device) + + def self_collision_cost( env: ManagerBasedRlEnv, sensor_name: str | tuple[str, ...], diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index b6ed43c7..1b5be1f6 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -222,6 +222,16 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: weight=0.5, params={"command_name": "motion", "std": 0.4}, ), + "motion_global_root_lin_vel": RewardTermCfg( + func=mdp.motion_global_anchor_linear_velocity_error_exp, + weight=1.0, + params={"command_name": "motion", "std": 1.0}, + ), + "motion_global_root_ang_vel": RewardTermCfg( + func=mdp.motion_global_anchor_angular_velocity_error_exp, + weight=1.0, + params={"command_name": "motion", "std": 3.0}, + ), "motion_body_pos": RewardTermCfg( 
func=mdp.motion_relative_body_position_error_exp, weight=1.0, @@ -242,6 +252,17 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: weight=1.0, params={"command_name": "motion", "std": 3.14}, ), + "motion_joint_pos": RewardTermCfg( + func=mdp.motion_joint_position_error_exp, + weight=1.0, + params={"command_name": "motion", "std": 0.5}, + ), + "motion_joint_vel": RewardTermCfg( + func=mdp.motion_joint_velocity_error_exp, + weight=0.5, + params={"command_name": "motion", "std": 3.0}, + ), + "survival": RewardTermCfg(func=mdp.survival, weight=3.0), "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-1e-1), "joint_limit": RewardTermCfg( func=mdp.joint_pos_limits, From a244722ebebeaaaf3530e854b794b47a67a7c5ba Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 12 Jun 2026 15:59:20 +0800 Subject: [PATCH 079/122] Rename tracking observation terms --- AGENTS.md | 21 +++---- README.md | 1 + teleopit/controllers/observation.py | 60 ++++++++++--------- tests/test_task_registry.py | 40 ++++++++++--- train_mimic/tasks/tracking/config/env.py | 30 +++++----- .../tasks/tracking/mdp/observations.py | 35 +++++++---- .../tasks/tracking/tracking_env_cfg.py | 58 ++++++++++-------- 7 files changed, 144 insertions(+), 101 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 9c4af62d..f7828ebd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -168,17 +168,18 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos Observation format: `velcmd_history` (167D, dual-input ONNX) ``` -command(58) -+ motion_anchor_ori_b(6) -+ base_ang_vel(3) -+ joint_pos_rel(29) -+ joint_vel(29) -+ last_action(29) -+ projected_gravity(3) -+ ref_base_lin_vel_b(3) -+ ref_base_ang_vel_b(3) +ref_joint_pos(29) ++ ref_joint_vel(29) ++ ref_anchor_ori_b(6) ++ robot_base_ang_vel_b(3) ++ robot_joint_pos_rel(29) ++ robot_joint_vel(29) ++ prev_action(29) ++ robot_projected_gravity_b(3) ++ ref_anchor_lin_vel_b(3) ++ ref_anchor_ang_vel_b(3) + ref_projected_gravity_b(3) -+ ref_base_height(1) ++ 
ref_anchor_height(1) ``` Runtime constraints: diff --git a/README.md b/README.md index 9b10be17..b217d1b5 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - General-Tracking-G1 training now defaults to uniform motion sampling; clip-local adaptive sampling remains available through `sampling_mode=adaptive`. - Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. - Added root velocity, joint tracking, and survival rewards to the General-Tracking-G1 training objective. +- Renamed General-Tracking-G1 observation terms to explicit `ref_*`, `robot_*`, and `prev_action` keys. ### v0.3.0 (2026-05-12) diff --git a/teleopit/controllers/observation.py b/teleopit/controllers/observation.py index fd1fa987..606ceaea 100644 --- a/teleopit/controllers/observation.py +++ b/teleopit/controllers/observation.py @@ -164,13 +164,13 @@ def build( last_action: FloatVec, ) -> FloatVec: qpos = np.asarray(robot_state.qpos, dtype=np.float32).reshape(-1)[: self.num_actions] - qvel = np.asarray(robot_state.qvel, dtype=np.float32).reshape(-1)[: self.num_actions] + robot_joint_vel = np.asarray(robot_state.qvel, dtype=np.float32).reshape(-1)[: self.num_actions] robot_quat = np.asarray(robot_state.quat, dtype=np.float32).reshape(-1) - base_ang_vel = np.asarray(robot_state.ang_vel, dtype=np.float32).reshape(-1) + robot_base_ang_vel_b = np.asarray(robot_state.ang_vel, dtype=np.float32).reshape(-1) if robot_quat.shape[0] != 4: raise ValueError(f"robot_state.quat must be 4D (wxyz), got {robot_quat.shape[0]}") - if base_ang_vel.shape[0] != 3: - raise ValueError(f"robot_state.ang_vel must be 3D, got {base_ang_vel.shape[0]}") + if robot_base_ang_vel_b.shape[0] != 3: + raise ValueError(f"robot_state.ang_vel must be 3D, got {robot_base_ang_vel_b.shape[0]}") motion = np.asarray(motion_qpos, dtype=np.float32).reshape(-1) 
if motion.shape[0] < 7 + self.num_actions: @@ -182,9 +182,9 @@ def build( raise ValueError( f"motion_joint_vel must be {self.num_actions}D, got {motion_joint_vel_vec.shape[0]}" ) - last_act = np.asarray(last_action, dtype=np.float32).reshape(-1) - if last_act.shape[0] != self.num_actions: - raise ValueError(f"last_action length must be {self.num_actions}, got {last_act.shape[0]}") + prev_action = np.asarray(last_action, dtype=np.float32).reshape(-1) + if prev_action.shape[0] != self.num_actions: + raise ValueError(f"last_action length must be {self.num_actions}, got {prev_action.shape[0]}") self._run_fk(np.zeros(3, dtype=np.float32), robot_quat, qpos) robot_anchor_quat = self._get_body_quat(self._anchor_body_id) @@ -193,20 +193,22 @@ def build( motion_base_quat = motion[3:7] motion_joint_pos = motion[7:7 + self.num_actions] self._run_fk(motion_base_pos, motion_base_quat, motion_joint_pos) - motion_anchor_quat = self._get_body_quat(self._anchor_body_id) + ref_anchor_quat = self._get_body_quat(self._anchor_body_id) - command = np.concatenate((motion_joint_pos, motion_joint_vel_vec), dtype=np.float32) - rel_quat = _quat_mul_np(_quat_inv_np(robot_anchor_quat), motion_anchor_quat) - motion_anchor_ori_b = _quat_to_rot6d_np(rel_quat) - joint_pos_rel = qpos - self.default_dof_pos + ref_joint_pos = motion_joint_pos + ref_joint_vel = motion_joint_vel_vec + rel_quat = _quat_mul_np(_quat_inv_np(robot_anchor_quat), ref_anchor_quat) + ref_anchor_ori_b = _quat_to_rot6d_np(rel_quat) + robot_joint_pos_rel = qpos - self.default_dof_pos obs = np.concatenate([ - command, - motion_anchor_ori_b, - base_ang_vel, - joint_pos_rel, - qvel, - last_act, + ref_joint_pos, + ref_joint_vel, + ref_anchor_ori_b, + robot_base_ang_vel_b, + robot_joint_pos_rel, + robot_joint_vel, + prev_action, ], dtype=np.float32) if obs.shape[0] != self.total_obs_size: raise ValueError(f"Expected {self.total_obs_size}D base observation, got {obs.shape[0]}") @@ -239,8 +241,8 @@ def build( robot_quat = 
np.asarray(robot_state.quat, dtype=np.float32).reshape(-1) motion = np.asarray(motion_qpos, dtype=np.float32).reshape(-1) self._base._run_fk(motion[0:3], motion[3:7], motion[7:7 + self.num_actions]) - motion_anchor_pos = self._base._get_body_pos(self._base._anchor_body_id) - motion_anchor_quat = self._base._get_body_quat(self._base._anchor_body_id) + ref_anchor_pos = self._base._get_body_pos(self._base._anchor_body_id) + ref_anchor_quat = self._base._get_body_quat(self._base._anchor_body_id) qpos = np.asarray(robot_state.qpos, dtype=np.float32).reshape(-1)[: self.num_actions] robot_base_pos = np.zeros(3, dtype=np.float32) @@ -251,19 +253,19 @@ def build( ref_lin_vel_w = np.asarray(motion_anchor_lin_vel_w, dtype=np.float32).reshape(3) ref_ang_vel_w = np.asarray(motion_anchor_ang_vel_w, dtype=np.float32).reshape(3) - projected_gravity = _quat_rotate_np(_quat_inv_np(robot_quat), _GRAVITY_UNIT_W) + robot_projected_gravity_b = _quat_rotate_np(_quat_inv_np(robot_quat), _GRAVITY_UNIT_W) robot_inv = _quat_inv_np(robot_anchor_quat) - ref_base_lin_vel_b = _quat_rotate_np(robot_inv, ref_lin_vel_w) - ref_base_ang_vel_b = _quat_rotate_np(robot_inv, ref_ang_vel_w) - ref_projected_gravity_b = _quat_rotate_np(_quat_inv_np(motion_anchor_quat), _GRAVITY_UNIT_W) - ref_base_height = motion_anchor_pos[2:3] + ref_anchor_lin_vel_b = _quat_rotate_np(robot_inv, ref_lin_vel_w) + ref_anchor_ang_vel_b = _quat_rotate_np(robot_inv, ref_ang_vel_w) + ref_projected_gravity_b = _quat_rotate_np(_quat_inv_np(ref_anchor_quat), _GRAVITY_UNIT_W) + ref_anchor_height = ref_anchor_pos[2:3] velcmd_obs = np.concatenate([ - projected_gravity, - ref_base_lin_vel_b, - ref_base_ang_vel_b, + robot_projected_gravity_b, + ref_anchor_lin_vel_b, + ref_anchor_ang_vel_b, ref_projected_gravity_b, - ref_base_height, + ref_anchor_height, ], dtype=np.float32) obs = np.concatenate([base_obs, velcmd_obs], dtype=np.float32) if obs.shape[0] != self.total_obs_size: diff --git a/tests/test_task_registry.py 
b/tests/test_task_registry.py index 1419ab47..9d889926 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -22,14 +22,38 @@ def test_general_tracking_task_is_registered() -> None: critic_terms = env_cfg.observations["critic"].terms assert DEFAULT_TASK == GENERAL_TRACKING_TASK - for terms in (actor_terms, critic_terms): - assert "projected_gravity" in terms - assert "ref_base_lin_vel_b" in terms - assert "ref_base_ang_vel_b" in terms - assert "ref_projected_gravity_b" in terms - assert "ref_base_height" in terms - assert "motion_anchor_pos_b" not in actor_terms - assert "base_lin_vel" not in actor_terms + assert list(actor_terms) == [ + "ref_joint_pos", + "ref_joint_vel", + "ref_anchor_ori_b", + "robot_base_ang_vel_b", + "robot_joint_pos_rel", + "robot_joint_vel", + "prev_action", + "robot_projected_gravity_b", + "ref_anchor_lin_vel_b", + "ref_anchor_ang_vel_b", + "ref_projected_gravity_b", + "ref_anchor_height", + ] + assert list(critic_terms) == [ + "ref_joint_pos", + "ref_joint_vel", + "ref_anchor_pos_b", + "ref_anchor_ori_b", + "robot_tracking_body_pos_b", + "robot_tracking_body_ori_b", + "robot_base_lin_vel_b", + "robot_base_ang_vel_b", + "robot_joint_pos_rel", + "robot_joint_vel", + "prev_action", + "robot_projected_gravity_b", + "ref_anchor_lin_vel_b", + "ref_anchor_ang_vel_b", + "ref_projected_gravity_b", + "ref_anchor_height", + ] assert "actor_history" in env_cfg.observations assert "critic_history" in env_cfg.observations assert env_cfg.commands["motion"].sampling_mode == "uniform" diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 880662c0..9f45a2ee 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -77,44 +77,44 @@ def _add_history_obs_groups( _VELCMD_ACTOR_TERMS: dict[str, ObservationTermCfg] = { - "projected_gravity": ObservationTermCfg( + "robot_projected_gravity_b": ObservationTermCfg( func=mdp.projected_gravity, 
noise=Unoise(n_min=-0.05, n_max=0.05), ), - "ref_base_lin_vel_b": ObservationTermCfg( - func=mdp.ref_base_lin_vel_b, + "ref_anchor_lin_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_lin_vel_b, params={"command_name": "motion"}, ), - "ref_base_ang_vel_b": ObservationTermCfg( - func=mdp.ref_base_ang_vel_b, + "ref_anchor_ang_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_ang_vel_b, params={"command_name": "motion"}, ), "ref_projected_gravity_b": ObservationTermCfg( func=mdp.ref_projected_gravity_b, params={"command_name": "motion"}, ), - "ref_base_height": ObservationTermCfg( - func=mdp.ref_base_height, + "ref_anchor_height": ObservationTermCfg( + func=mdp.ref_anchor_height, params={"command_name": "motion"}, ), } _VELCMD_CRITIC_TERMS: dict[str, ObservationTermCfg] = { - "projected_gravity": ObservationTermCfg(func=mdp.projected_gravity), - "ref_base_lin_vel_b": ObservationTermCfg( - func=mdp.ref_base_lin_vel_b, + "robot_projected_gravity_b": ObservationTermCfg(func=mdp.projected_gravity), + "ref_anchor_lin_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_lin_vel_b, params={"command_name": "motion"}, ), - "ref_base_ang_vel_b": ObservationTermCfg( - func=mdp.ref_base_ang_vel_b, + "ref_anchor_ang_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_ang_vel_b, params={"command_name": "motion"}, ), "ref_projected_gravity_b": ObservationTermCfg( func=mdp.ref_projected_gravity_b, params={"command_name": "motion"}, ), - "ref_base_height": ObservationTermCfg( - func=mdp.ref_base_height, + "ref_anchor_height": ObservationTermCfg( + func=mdp.ref_anchor_height, params={"command_name": "motion"}, ), } @@ -209,7 +209,7 @@ def make_general_tracking_env_cfg( actor_terms = { key: value for key, value in cfg.observations["actor"].terms.items() - if key not in {"motion_anchor_pos_b", "base_lin_vel"} + if key not in {"ref_anchor_pos_b", "robot_base_lin_vel_b"} } cfg.observations["actor"] = ObservationGroupCfg( terms=actor_terms, diff --git 
a/train_mimic/tasks/tracking/mdp/observations.py b/train_mimic/tasks/tracking/mdp/observations.py index 7b9b66fe..a7484efa 100644 --- a/train_mimic/tasks/tracking/mdp/observations.py +++ b/train_mimic/tasks/tracking/mdp/observations.py @@ -18,7 +18,17 @@ from mjlab.envs import ManagerBasedRlEnv -def motion_anchor_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_joint_pos(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + return command.joint_pos + + +def ref_joint_vel(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + return command.joint_vel + + +def ref_anchor_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) pos, _ = subtract_frame_transforms( @@ -31,7 +41,7 @@ def motion_anchor_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tens return pos.view(env.num_envs, -1) -def motion_anchor_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) _, ori = subtract_frame_transforms( @@ -44,7 +54,7 @@ def motion_anchor_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tens return mat[..., :2].reshape(mat.shape[0], -1) -def robot_body_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def robot_tracking_body_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) @@ -58,7 +68,7 @@ def robot_body_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: return pos_b.view(env.num_envs, -1) -def robot_body_ori_b(env: ManagerBasedRlEnv, command_name: str) -> 
torch.Tensor: +def robot_tracking_body_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) @@ -73,18 +83,17 @@ def robot_body_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: # --------------------------------------------------------------------------- -# Velocity-command observation terms: reference velocities and projected -# gravity for the VelCmd task variant. +# Velocity-command observation terms: reference velocities and projected gravity. # --------------------------------------------------------------------------- -def ref_base_lin_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_lin_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: """Reference anchor linear velocity in the robot's body frame. (N, 3)""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) return quat_apply(quat_inv(command.robot_anchor_quat_w), command.anchor_lin_vel_w) -def ref_base_ang_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_ang_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: """Reference anchor angular velocity in the robot's body frame. (N, 3)""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) return quat_apply(quat_inv(command.robot_anchor_quat_w), command.anchor_ang_vel_w) @@ -101,7 +110,7 @@ def ref_projected_gravity_b(env: ManagerBasedRlEnv, command_name: str) -> torch. return quat_apply(quat_inv(command.anchor_quat_w), asset.data.gravity_vec_w) -def ref_base_height(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_height(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: """Reference anchor height (z-coordinate). 
(N, 1)""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) return command.anchor_pos_w[:, 2:3] @@ -113,7 +122,7 @@ def ref_base_height(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: # --------------------------------------------------------------------------- -def motion_anchor_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) pos, _ = subtract_frame_transforms( @@ -126,7 +135,7 @@ def motion_anchor_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch. return pos.view(env.num_envs, -1) -def motion_anchor_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) _, ori = subtract_frame_transforms( @@ -139,7 +148,7 @@ def motion_anchor_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch. 
return mat[..., :2].reshape(mat.shape[0], -1) -def robot_body_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def robot_tracking_body_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) @@ -154,7 +163,7 @@ def robot_body_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Ten return pos_b.view(env.num_envs, -1) -def robot_body_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def robot_tracking_body_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 1b5be1f6..c0301f56 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -42,65 +42,71 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: ## actor_terms = { - "command": ObservationTermCfg( - func=mdp.generated_commands, params={"command_name": "motion"} + "ref_joint_pos": ObservationTermCfg( + func=mdp.ref_joint_pos, params={"command_name": "motion"} ), - "motion_anchor_pos_b": ObservationTermCfg( - func=mdp.motion_anchor_pos_b, + "ref_joint_vel": ObservationTermCfg( + func=mdp.ref_joint_vel, params={"command_name": "motion"} + ), + "ref_anchor_pos_b": ObservationTermCfg( + func=mdp.ref_anchor_pos_b, params={"command_name": "motion"}, noise=Unoise(n_min=-0.25, n_max=0.25), ), - "motion_anchor_ori_b": ObservationTermCfg( - func=mdp.motion_anchor_ori_b, + "ref_anchor_ori_b": ObservationTermCfg( + func=mdp.ref_anchor_ori_b, params={"command_name": "motion"}, noise=Unoise(n_min=-0.05, n_max=0.05), ), - "base_lin_vel": ObservationTermCfg( + "robot_base_lin_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, params={"sensor_name": 
"robot/imu_lin_vel"}, noise=Unoise(n_min=-0.5, n_max=0.5), ), - "base_ang_vel": ObservationTermCfg( + "robot_base_ang_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, params={"sensor_name": "robot/imu_ang_vel"}, noise=Unoise(n_min=-0.2, n_max=0.2), ), - "joint_pos": ObservationTermCfg( + "robot_joint_pos_rel": ObservationTermCfg( func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01), params={"biased": True}, ), - "joint_vel": ObservationTermCfg( + "robot_joint_vel": ObservationTermCfg( func=mdp.joint_vel_rel, noise=Unoise(n_min=-0.5, n_max=0.5) ), - "actions": ObservationTermCfg(func=mdp.last_action), + "prev_action": ObservationTermCfg(func=mdp.last_action), } critic_terms = { - "command": ObservationTermCfg( - func=mdp.generated_commands, params={"command_name": "motion"} + "ref_joint_pos": ObservationTermCfg( + func=mdp.ref_joint_pos, params={"command_name": "motion"} + ), + "ref_joint_vel": ObservationTermCfg( + func=mdp.ref_joint_vel, params={"command_name": "motion"} ), - "motion_anchor_pos_b": ObservationTermCfg( - func=mdp.motion_anchor_pos_b, params={"command_name": "motion"} + "ref_anchor_pos_b": ObservationTermCfg( + func=mdp.ref_anchor_pos_b, params={"command_name": "motion"} ), - "motion_anchor_ori_b": ObservationTermCfg( - func=mdp.motion_anchor_ori_b, params={"command_name": "motion"} + "ref_anchor_ori_b": ObservationTermCfg( + func=mdp.ref_anchor_ori_b, params={"command_name": "motion"} ), - "body_pos": ObservationTermCfg( - func=mdp.robot_body_pos_b, params={"command_name": "motion"} + "robot_tracking_body_pos_b": ObservationTermCfg( + func=mdp.robot_tracking_body_pos_b, params={"command_name": "motion"} ), - "body_ori": ObservationTermCfg( - func=mdp.robot_body_ori_b, params={"command_name": "motion"} + "robot_tracking_body_ori_b": ObservationTermCfg( + func=mdp.robot_tracking_body_ori_b, params={"command_name": "motion"} ), - "base_lin_vel": ObservationTermCfg( + "robot_base_lin_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, 
params={"sensor_name": "robot/imu_lin_vel"} ), - "base_ang_vel": ObservationTermCfg( + "robot_base_ang_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, params={"sensor_name": "robot/imu_ang_vel"} ), - "joint_pos": ObservationTermCfg(func=mdp.joint_pos_rel), - "joint_vel": ObservationTermCfg(func=mdp.joint_vel_rel), - "actions": ObservationTermCfg(func=mdp.last_action), + "robot_joint_pos_rel": ObservationTermCfg(func=mdp.joint_pos_rel), + "robot_joint_vel": ObservationTermCfg(func=mdp.joint_vel_rel), + "prev_action": ObservationTermCfg(func=mdp.last_action), } observations = { From e4d37251af5b51192c67e4daa3ccd090d318eeca Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 12 Jun 2026 17:31:05 +0800 Subject: [PATCH 080/122] Add Pico ARMS mode --- AGENTS.md | 8 +- README.md | 1 + docs/docs/configuration/config-reference.md | 5 +- docs/docs/tutorials/pico-sim2real.md | 10 +- docs/docs/tutorials/pico-sim2sim.md | 4 + .../current/configuration/config-reference.md | 5 +- .../current/tutorials/pico-sim2real.md | 7 +- .../current/tutorials/pico-sim2sim.md | 4 + scripts/run/run_sim.py | 6 +- scripts/run/run_sim2real.py | 4 +- teleopit/configs/input/pico4.yaml | 2 + teleopit/configs/pico4_sim.yaml | 2 + teleopit/configs/pico4_sim2real.yaml | 3 + teleopit/inputs/pico4_provider.py | 75 ++++++++-- teleopit/inputs/realtime_packet.py | 1 + teleopit/runtime/arm_mocap.py | 78 ++++++++++ teleopit/sim/loop.py | 14 ++ teleopit/sim/session.py | 48 ++++++- teleopit/sim2real/mp/runtime.py | 104 +++++++++++--- teleopit/sim2real/reference_processor.py | 5 +- tests/test_pico4_provider.py | 22 ++- tests/test_sim2real_multiprocess.py | 103 ++++++++++++++ tests/test_sim_loop.py | 134 ++++++++++++++++++ 23 files changed, 595 insertions(+), 50 deletions(-) create mode 100644 teleopit/runtime/arm_mocap.py diff --git a/AGENTS.md b/AGENTS.md index f7828ebd..2b4d48b5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -138,10 +138,12 @@ target_dof_pos = clip(action, -10, 10) × action_scale + 
default_dof_pos - The provider applies an input-space transform to match the current retarget config - Do not hardcode that transform as a public coordinate-system contract; validate against actual retarget/sim2sim behavior when SDK or firmware changes - Pico4 realtime control uses the same retargeted-reference timeline path as the shared realtime input stack -- Pico sim2sim supports a keyboard-driven top-level mode state machine: `STANDING → MOCAP → STANDING` -- Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `X` → back to `STANDING`, `Q` → quit +- Pico sim2sim supports a keyboard-driven top-level mode state machine: `STANDING → MOCAP ↔ ARMS`, `X` returns to `STANDING` +- Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `B` → toggle `MOCAP`/`ARMS`, `X` → back to `STANDING`, `Q` → quit - Pico4 sim2real pause/resume is handled as a mocap-session control event (`toggle_pause`), not as a mode switch to `STANDING` - Default Pico pause button is `A`; resume resets policy/reference state and yaw/XY root-offset alignment while the process-isolated realtime reference worker continues its live input timeline +- Pico4 sim2sim/sim2real support `ARMS` mode toggled from `MOCAP` with Pico/controller `B`; retargeting continues, while the control loop sends the motion tracker a composed reference with stand-pose body/legs/waist and live retargeted arms +- `ARMS` entering/exiting/resume resets policy/reference alignment and uses Kp ramp; offline BVH sim2real does not use `ARMS`, and Unitree remote `B` remains BVH replay - Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference state, smoothers, and reference alignment are reset, while the GMR IK warm-start is retained - Optional LinkerHand L6 control uses `hands.enabled=true` and `hands.mode=gripper|vr_hand_pose`; default is disabled - `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico 
grip/trigger open-close control @@ -149,7 +151,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Teleopit owns Pico 26-joint hand-state to 21-landmark conversion; do not import `somehand.pico_input` - `gripper` mode uses the configured `hands.linkerhand_l6.speed` (default `[50]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` - `vr_hand_pose` defaults to a low-latency somehand path: `hands.somehand.rate_hz=60`, `max_iterations=12`, `temporal_filter_alpha=1.0`, and `output_alpha=1.0`; this prioritizes response speed over smoothing -- LinkerHand L6 control is active only in sim2real `MOCAP`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose +- LinkerHand L6 control is active in sim2real `MOCAP` and `ARMS`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose - In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand ### SimulationLoop Runtime Behavior diff --git a/README.md b/README.md index b217d1b5..8b88adb7 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ### Unreleased +- Added Pico sim2real `ARMS` mode: Pico/controller `B` toggles between whole-body `MOCAP` and stand-pose body/legs with live retargeted arms. - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. - Added optional LinkerHand L6 sim2real modes under `hands.*`: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand 0.2.0 public API. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. 
diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 2b922496..3f49c70c 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -62,6 +62,8 @@ Complete reference for all configurable fields. | `input.pico4_buffer_size` | Frame buffer size | `60` | | `input.pause_button` | Button for pause/resume | `A` | | `input.pause_debounce_s` | Debounce time for pause button | `0.25` | +| `input.arms_button` | Button for Pico `MOCAP` / `ARMS` toggle | `B` | +| `input.arms_debounce_s` | Debounce time for arms-mode button | `0.25` | | `input.bridge_host` | Teleopit host receiver bind host | `0.0.0.0` | | `input.bridge_port` | Teleopit host receiver TCP/UDP port | `63901` | | `input.bridge_discovery` | Enable pico-bridge discovery advertising | `true` | @@ -101,6 +103,7 @@ and `all` are simulation-only viewer modes. | `startup_ramp_duration` | Kp ramp duration after entering `STANDING`; gradually increases PD gains without changing policy targets | `2.0` | | `joint_vel_limit` | Joint velocity limit (rad/s); triggers emergency damping if exceeded | `10.0` | | `mocap_switch.check_frames` | Consecutive valid frames required before switching to MOCAP | `10` | +| `arm_mocap.controlled_joint_indices` | G1 joints driven by live retargeting in Pico `ARMS` mode | `[15..28]` | ### Real Robot @@ -121,7 +124,7 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin ### Dexterous Hand (Pico sim2real) `hands.enabled=true` requires `input.provider=pico4` and the optional `dexhand` -extra. Control is active only in `MOCAP`; inactive modes send the open pose. In +extra. Control is active in `MOCAP` and `ARMS`; inactive modes send the open pose. In `vr_hand_pose`, missing hand pose holds the last command for that side. `gripper` uses the configured `hands.linkerhand_l6.speed`; `vr_hand_pose` always sets LinkerHand L6 speed to the maximum. 
Teleopit converts Pico hand diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index f9a1e776..c7f7cfd1 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -105,6 +105,7 @@ Keep the Unitree remote in hand. `L1+R1` is the emergency stop path into | Unitree remote `Start` | Enter `STANDING` | | Unitree remote `Y` | Enter `MOCAP` | | Pico/controller `A` | Pause / resume live mocap | +| Pico/controller `B` | Toggle `MOCAP` / `ARMS` | | Unitree remote `X` | Return to `STANDING` | | Unitree remote `L1+R1` | Emergency stop (`DAMPING`) | @@ -127,6 +128,11 @@ and ramps Kp without changing policy targets. When entering `MOCAP`, Teleopit resets policy/reference state and starts tracking the live mocap command through the realtime reference timeline. +`ARMS` keeps the same live retargeting timeline running, but sends the motion +tracker a composed reference: body, waist, and legs stay at the standing pose +while both arms follow the live retargeted result. Entering or leaving `ARMS` +resets policy/reference alignment and uses the same Kp ramp safety path. + ## Pause / Resume Pico pause/resume is a mocap-session control event. @@ -153,8 +159,8 @@ Pico sim2real can drive LinkerHand L6 hands in two modes: public `somehand.api` from somehand 0.2.0. It always sets L6 speed to the maximum. -Hand control is active only in `MOCAP`. It sends the open pose in `STANDING`, -`DAMPING`, paused mocap, and shutdown. +Hand control is active in `MOCAP` and `ARMS`. It sends the open pose in +`STANDING`, `DAMPING`, paused mocap, and shutdown. Install the dexhand extra first if it was not installed with the main Pico profile: diff --git a/docs/docs/tutorials/pico-sim2sim.md b/docs/docs/tutorials/pico-sim2sim.md index b72bfc36..b72dbcd1 100644 --- a/docs/docs/tutorials/pico-sim2sim.md +++ b/docs/docs/tutorials/pico-sim2sim.md @@ -71,6 +71,7 @@ enter `MOCAP`. 
|----------|--------| | `Y` | Enter `MOCAP` | | `A` | Pause / resume live mocap | +| `B` | Toggle `MOCAP` / `ARMS` | | `X` | Return to `STANDING` | | `Q` | Quit | @@ -92,6 +93,9 @@ The default Pico pause button is `A`. Supported overrides include `B`, `X`, `Y`, `left_axis_click`, `right_axis_click`, `left_menu_button`, and `right_menu_button`. +The default Pico arms-mode button is `B`. `ARMS` keeps body, waist, and legs at +the standing pose while both arms follow the live retargeted result. + ## Optional Headset Video Preview pico-bridge 0.2.1 can show a host-side camera stream in the headset. In diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index c1bf8350..e42000af 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -79,6 +79,8 @@ target = clip(action, clip_range) * action_scale + default_dof_pos | `pico4_buffer_size` | int | `60` | 帧缓冲区大小 | | `pause_button` | str | `A` | 用于暂停/恢复的手柄按钮名称 | | `pause_debounce_s` | float | `0.25` | 暂停按钮防抖时间 | +| `arms_button` | str | `B` | Pico 中用于切换 `MOCAP` / `ARMS` 的按钮 | +| `arms_debounce_s` | float | `0.25` | 双臂模式按钮防抖时间 | | `bridge_host` | str | `0.0.0.0` | Teleopit host receiver 绑定地址 | | `bridge_port` | int | `63901` | Teleopit host receiver TCP/UDP 端口 | | `bridge_discovery` | bool | `true` | 是否启用 pico-bridge 发现广播 | @@ -120,6 +122,7 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` | `startup_ramp_duration` | 进入 `STANDING` 后的 Kp ramp 时长;逐步提高 PD 增益,不改变 policy target | `2.0` | | `joint_vel_limit` | 关节速度限制(rad/s),超过时触发急停 | `10.0` | | `mocap_switch.check_frames` | 切换到 MOCAP 前所需的连续有效帧数 | `10` | +| `arm_mocap.controlled_joint_indices` | Pico `ARMS` 模式下由实时 retargeting 驱动的 G1 关节 | `[15..28]` | ### 真机 SDK @@ -140,7 +143,7 @@ MuJoCo 
窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` ### 灵巧手(Pico sim2real) `hands.mode=gripper` 或 `hands.mode=vr_hand_pose` 要求 `input.provider=pico4`, -并安装可选的 `dexhand` extra。控制只在 `MOCAP` 中生效;非活动模式会发送张开姿态。 +并安装可选的 `dexhand` extra。控制在 `MOCAP` 和 `ARMS` 中生效;非活动模式会发送张开姿态。 在 `vr_hand_pose` 中,Teleopit 将 Pico 手部 pose 适配成 somehand 0.2.0 的 landmark 输入,只调用公开的 `somehand.api`;手部 pose 消失时,对应侧会保持上一条命令。 `gripper` 使用配置的 `hands.linkerhand_l6.speed`;`vr_hand_pose` 始终将 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index ef6318ff..37573f0c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -100,6 +100,7 @@ python scripts/run/run_sim2real.py \ | Unitree remote `Start` | 进入 `STANDING` | | Unitree remote `Y` | 进入 `MOCAP` | | Pico/controller `A` | 暂停 / 恢复实时动捕 | +| Pico/controller `B` | 在 `MOCAP` / `ARMS` 之间切换 | | Unitree remote `X` | 返回 `STANDING` | | Unitree remote `L1+R1` | 急停(`DAMPING`) | @@ -120,6 +121,10 @@ Pico body frames -> retarget -> reference buffer -> observation -> policy -> G1 进入 `MOCAP` 时,Teleopit 会重置 policy/reference 状态,并通过实时参考时间线开始跟踪 实时 mocap 命令。 +`ARMS` 会保持同一条实时 retargeting 时间线继续运行,但发送给 motion tracker 的参考会被组合: +身体、腰部和腿部保持站立姿态,双臂跟随实时 retarget 结果。进入或离开 `ARMS` 时会重置 +policy/reference 对齐,并使用同一套 Kp ramp 安全路径。 + ## 暂停 / 恢复 Pico 暂停/恢复是 mocap-session control event。 @@ -144,7 +149,7 @@ Pico sim2real 可以用两种模式控制 LinkerHand L6: 速度设为最大值。默认配置使用 60 Hz 的低延时 somehand 路径并减少平滑,所以响应会更快, 但可能比标准 somehand 设置更抖。 -手控只在 `MOCAP` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 +手控在 `MOCAP` 和 `ARMS` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 如果主 Pico profile 没有包含手控支持,先安装 dexhand extra: diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md 
b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md index 3601a9f1..1ea75b66 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md @@ -69,6 +69,7 @@ python scripts/run/run_sim.py \ |------|------| | `Y` | 进入 `MOCAP` | | `A` | 暂停 / 恢复实时动捕 | +| `B` | 在 `MOCAP` / `ARMS` 之间切换 | | `X` | 返回 `STANDING` | | `Q` | 退出 | @@ -85,6 +86,9 @@ Pico 暂停/恢复会冻结 mocap session;它不是切回 `STANDING`。 默认 Pico 暂停键是 `A`。支持的覆盖值包括 `B`、`X`、`Y`、`left_axis_click`、 `right_axis_click`、`left_menu_button` 和 `right_menu_button`。 +默认 Pico 双臂模式按钮是 `B`。`ARMS` 会让身体、腰部和腿部保持站立姿态,同时双臂跟随 +实时 retarget 结果。 + ## 可选头显视频预览 pico-bridge 0.2.1 可以在头显中显示 host 侧视频流。在仿真中,Teleopit 可以推送 diff --git a/scripts/run/run_sim.py b/scripts/run/run_sim.py index fb729844..a695945a 100644 --- a/scripts/run/run_sim.py +++ b/scripts/run/run_sim.py @@ -12,10 +12,10 @@ def _print_sim_controls(cfg: DictConfig) -> None: if provider == "pico4": print("Pico sim2sim controls:") if bool(cfg.get("keyboard", {}).get("enabled", False)): - print(" Keyboard: starts in STANDING; Y mocap, A pause/resume, X standing, Q quit.") + print(" Keyboard: starts in STANDING; Y mocap, A pause/resume, B arms, X standing, Q quit.") else: - print(" Pico controller: A pause/resume.") - print(" State flow: STANDING -> MOCAP -> STANDING.") + print(" Pico controller: A pause/resume, B arms.") + print(" State flow: STANDING -> MOCAP <-> ARMS, X -> STANDING.") return if bool(cfg.get("playback", {}).get("keyboard", {}).get("enabled", False)): print("Offline sim2sim controls:") diff --git a/scripts/run/run_sim2real.py b/scripts/run/run_sim2real.py index a3553ddc..080a7682 100644 --- a/scripts/run/run_sim2real.py +++ b/scripts/run/run_sim2real.py @@ -18,10 +18,12 @@ def _print_sim2real_controls(cfg: DictConfig) -> None: print(" Remote L1+R1: DAMPING / estop.") if provider == "pico4": print(" Mocap pause/resume: 
Pico/controller A.") + print(" Arm-only mode: Pico/controller B toggles MOCAP <-> ARMS.") print(" Dexterous hand: hands.enabled=true hands.mode=gripper|vr_hand_pose.") + print(" State flow: IDLE -> STANDING -> MOCAP <-> ARMS, X -> STANDING, Any -> DAMPING.") else: print(" Offline playback: A pause/resume, B replay from start.") - print(" State flow: IDLE -> STANDING -> MOCAP -> STANDING, Any -> DAMPING.") + print(" State flow: IDLE -> STANDING -> MOCAP -> STANDING, Any -> DAMPING.") @hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="sim2real") diff --git a/teleopit/configs/input/pico4.yaml b/teleopit/configs/input/pico4.yaml index be0ef2cd..8aa4d7dc 100644 --- a/teleopit/configs/input/pico4.yaml +++ b/teleopit/configs/input/pico4.yaml @@ -7,6 +7,8 @@ pico4_buffer_size: 60 pico4_timestamp_gap_reset_s: 0.15 pause_button: A pause_debounce_s: 0.25 +arms_button: B +arms_debounce_s: 0.25 bridge_host: "0.0.0.0" bridge_port: 63901 bridge_discovery: true diff --git a/teleopit/configs/pico4_sim.yaml b/teleopit/configs/pico4_sim.yaml index c36be166..9abf8243 100644 --- a/teleopit/configs/pico4_sim.yaml +++ b/teleopit/configs/pico4_sim.yaml @@ -19,6 +19,8 @@ reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] reference_debug_log: false +arm_mocap: + controlled_joint_indices: [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] viewers: "all" realtime: true num_steps: 0 # 0 = infinite loop until Ctrl+C or device disconnect diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index d1e2816a..4ff0f072 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -41,6 +41,9 @@ standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 +arm_mocap: + controlled_joint_indices: [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] + # 
Optional LinkerHand L6 control from Pico controller grip/trigger or VR hand pose. hands: enabled: false diff --git a/teleopit/inputs/pico4_provider.py b/teleopit/inputs/pico4_provider.py index 37b952c2..1780b03f 100644 --- a/teleopit/inputs/pico4_provider.py +++ b/teleopit/inputs/pico4_provider.py @@ -184,6 +184,8 @@ def __init__( timestamp_gap_reset_s: float = 0.15, pause_button: str | None = "A", pause_debounce_s: float = 0.25, + arms_button: str | None = "B", + arms_debounce_s: float | None = None, bridge_host: str = "0.0.0.0", bridge_port: int = 63901, bridge_discovery: bool = True, @@ -224,10 +226,15 @@ def __init__( self._timestamp_gap_reset_s = float(timestamp_gap_reset_s) self._pending_control_events: deque[ControlEvent] = deque() self._pause_button = None if pause_button in (None, "", "null") else str(pause_button) + self._arms_button = None if arms_button in (None, "", "null") else str(arms_button) self._pause_debounce_s = max(float(pause_debounce_s), 0.0) + self._arms_debounce_s = self._pause_debounce_s if arms_debounce_s is None else max(float(arms_debounce_s), 0.0) self._pause_button_path = self._resolve_button_path(self._pause_button) + self._arms_button_path = self._resolve_button_path(self._arms_button) self._last_pause_button_pressed = False + self._last_arms_button_pressed = False self._last_pause_toggle_timestamp: float | None = None + self._last_arms_toggle_timestamp: float | None = None self._last_raw_body_joints: NDArray[np.float64] | None = None self._last_frame_timestamp: float | None = None self._last_source_seq: int | None = None @@ -252,6 +259,11 @@ def __init__( "Pico4InputProvider pause button '%s' is unsupported by pico_bridge; pause events disabled", self._pause_button, ) + if self._arms_button is not None and self._arms_button_path is None: + logger.warning( + "Pico4InputProvider arms button '%s' is unsupported by pico_bridge; arms events disabled", + self._arms_button, + ) logger.info("Pico4InputProvider initialized (pico_bridge)") 
@property @@ -456,38 +468,73 @@ def _accept_hand_snapshot(self, frame: Any, *, timestamp: float) -> None: self._hand_snapshot = snapshot def _poll_control_events(self, frame: Any, *, timestamp: float) -> bool: - if self._pause_button_path is None: + emitted = False + emitted = self._poll_button_control_event( + frame, + timestamp=timestamp, + button_path=self._pause_button_path, + button_label=self._pause_button, + event_type=ControlEventType.TOGGLE_PAUSE, + last_pressed_attr="_last_pause_button_pressed", + last_toggle_attr="_last_pause_toggle_timestamp", + debounce_s=self._pause_debounce_s, + ) or emitted + emitted = self._poll_button_control_event( + frame, + timestamp=timestamp, + button_path=self._arms_button_path, + button_label=self._arms_button, + event_type=ControlEventType.TOGGLE_ARMS, + last_pressed_attr="_last_arms_button_pressed", + last_toggle_attr="_last_arms_toggle_timestamp", + debounce_s=self._arms_debounce_s, + ) or emitted + return emitted + + def _poll_button_control_event( + self, + frame: Any, + *, + timestamp: float, + button_path: tuple[str, str] | None, + button_label: str | None, + event_type: ControlEventType, + last_pressed_attr: str, + last_toggle_attr: str, + debounce_s: float, + ) -> bool: + if button_path is None: return False - side, button_name = self._pause_button_path + side, button_name = button_path controllers = getattr(frame, "controllers", None) controller = None if controllers is None else getattr(controllers, side, None) buttons = {} if controller is None else getattr(controller, "buttons", {}) or {} pressed = bool(buttons.get(button_name, False)) + last_pressed = bool(getattr(self, last_pressed_attr)) emitted = False - if pressed and not self._last_pause_button_pressed: - if ( - self._last_pause_toggle_timestamp is None - or timestamp - self._last_pause_toggle_timestamp >= self._pause_debounce_s - 1e-9 - ): + if pressed and not last_pressed: + last_toggle = getattr(self, last_toggle_attr) + if last_toggle is None or 
timestamp - float(last_toggle) >= debounce_s - 1e-9: with self._lock: self._pending_control_events.append( ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source=f"pico4:{self._pause_button}", + event_type=event_type, + source=f"pico4:{button_label}", timestamp_s=float(timestamp), ) ) - self._last_pause_toggle_timestamp = float(timestamp) + logger.info("Pico control event: %s from %s", event_type.value, button_label) + setattr(self, last_toggle_attr, float(timestamp)) emitted = True - self._last_pause_button_pressed = pressed + setattr(self, last_pressed_attr, pressed) return emitted @staticmethod - def _resolve_button_path(pause_button: str | None) -> tuple[str, str] | None: - if pause_button is None: + def _resolve_button_path(button: str | None) -> tuple[str, str] | None: + if button is None: return None - return _PAUSE_BUTTON_MAP.get(pause_button) + return _PAUSE_BUTTON_MAP.get(button) @staticmethod def _read_controller_state(controller: Any) -> PicoControllerState: diff --git a/teleopit/inputs/realtime_packet.py b/teleopit/inputs/realtime_packet.py index e10cbd4b..5299fe03 100644 --- a/teleopit/inputs/realtime_packet.py +++ b/teleopit/inputs/realtime_packet.py @@ -16,6 +16,7 @@ class ControlEventType(str, Enum): TOGGLE_PAUSE = "toggle_pause" + TOGGLE_ARMS = "toggle_arms" @dataclass(frozen=True) diff --git a/teleopit/runtime/arm_mocap.py b/teleopit/runtime/arm_mocap.py new file mode 100644 index 00000000..3144749d --- /dev/null +++ b/teleopit/runtime/arm_mocap.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from teleopit.constants import ROOT_DIM +from teleopit.runtime.common import cfg_get +from teleopit.sim.reference_timeline import ReferenceSample, ReferenceWindow + + +Float64Array = NDArray[np.float64] + + +def parse_arm_joint_indices(cfg: Any, *, num_actions: int) -> NDArray[np.int64]: + arm_mocap_cfg = cfg_get(cfg, "arm_mocap", {}) or {} + raw = 
cfg_get(arm_mocap_cfg, "controlled_joint_indices", None) + default = list(range(15, num_actions)) if num_actions > 15 else [num_actions - 1] + indices = np.asarray(default if raw is None else raw, dtype=np.int64).reshape(-1) + if indices.size == 0 or np.any(indices < 0) or np.any(indices >= num_actions): + raise ValueError( + f"arm_mocap.controlled_joint_indices must contain valid joint indices in [0, {num_actions}), " + f"got {indices.tolist()}" + ) + if np.unique(indices).shape[0] != indices.shape[0]: + raise ValueError("arm_mocap.controlled_joint_indices must not contain duplicates") + return indices + + +def compose_arm_reference( + *, + standing_qpos: Float64Array, + retarget_qpos: Float64Array, + arm_joint_indices: NDArray[np.int64], + num_actions: int, +) -> Float64Array: + retarget = np.asarray(retarget_qpos, dtype=np.float64).reshape(-1) + if retarget.shape[0] < ROOT_DIM + num_actions: + raise ValueError(f"Retargeted qpos too short: {retarget.shape[0]} (need >= {ROOT_DIM + num_actions})") + composed = np.asarray(standing_qpos, dtype=np.float64).reshape(-1).copy() + joint_indices = ROOT_DIM + np.asarray(arm_joint_indices, dtype=np.int64).reshape(-1) + composed[joint_indices] = retarget[joint_indices] + return composed + + +def compose_arm_reference_window( + reference_window: ReferenceWindow | None, + *, + standing_qpos: Float64Array, + arm_joint_indices: NDArray[np.int64], + num_actions: int, +) -> ReferenceWindow | None: + if reference_window is None: + return None + samples = tuple( + ReferenceSample( + qpos=compose_arm_reference( + standing_qpos=standing_qpos, + retarget_qpos=sample.qpos, + arm_joint_indices=arm_joint_indices, + num_actions=num_actions, + ), + timestamp_s=float(sample.timestamp_s), + mode=str(sample.mode), + used_fallback=bool(sample.used_fallback), + older_timestamp_s=sample.older_timestamp_s, + newer_timestamp_s=sample.newer_timestamp_s, + alpha=sample.alpha, + ) + for sample in reference_window.samples + ) + return ReferenceWindow( + 
base_time_s=float(reference_window.base_time_s), + policy_dt_s=float(reference_window.policy_dt_s), + reference_steps=tuple(reference_window.reference_steps), + samples=samples, + ) diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index 6de109b1..36918b7b 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -10,6 +10,7 @@ from teleopit.constants import FULL_QPOS_DIM, ROOT_DIM from teleopit.controllers.observation import align_motion_qpos_yaw from teleopit.runtime.reference_config import parse_reference_config +from teleopit.runtime.arm_mocap import compose_arm_reference, parse_arm_joint_indices from teleopit.inputs.realtime_packet import RealtimeInputPacket from teleopit.interfaces import Controller, InputProvider, MessageBus, ObservationBuilder, Recorder, Retargeter, Robot, RobotState from teleopit.sim.reference_timeline import ( @@ -34,6 +35,7 @@ class SimulationMode(Enum): IDLE = "idle" STANDING = "standing" MOCAP = "mocap" + ARMS = "arms" @final @@ -76,6 +78,7 @@ def __init__( self._last_action: Float32Array = np.zeros((self._num_actions,), dtype=np.float32) self._last_retarget_qpos: Float64Array | None = None self._standing_qpos: Float64Array | None = None + self._arm_joint_indices = parse_arm_joint_indices(cfg, num_actions=self._num_actions) self._realtime: bool = bool(self._try_get_cfg("realtime") or False) raw_debug_trace_path = self._try_get_cfg("debug_trace_path") self._debug_trace_path: str | None = None @@ -172,6 +175,17 @@ def _set_standing_reference(self, state: RobotState) -> Float64Array: self._standing_qpos = standing_qpos.copy() return standing_qpos + def _compose_arm_reference(self, retarget_qpos: Float64Array) -> Float64Array: + if self._standing_qpos is None: + self._set_standing_reference(self.robot.get_state()) + assert self._standing_qpos is not None + return compose_arm_reference( + standing_qpos=self._standing_qpos, + retarget_qpos=retarget_qpos, + arm_joint_indices=self._arm_joint_indices, + 
num_actions=self._num_actions, + ) + @staticmethod def _drain_realtime_control_events(input_provider: InputProvider) -> tuple[object, ...]: pop_control_events = getattr(input_provider, "pop_control_events", None) diff --git a/teleopit/sim/session.py b/teleopit/sim/session.py index df228c6a..53f7e3ab 100644 --- a/teleopit/sim/session.py +++ b/teleopit/sim/session.py @@ -33,6 +33,7 @@ ) from teleopit.sim.realtime_utils import RealtimeReferenceDiagnostics, RealtimeReferenceManager from teleopit.sim.runtime_components import MotionPreparation +from teleopit.runtime.arm_mocap import compose_arm_reference_window from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState from teleopit.runtime.offline_playback import OfflinePlaybackController from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader @@ -258,6 +259,30 @@ def enter_mocap_mode(self) -> None: self.last_commanded_motion_qpos = start_qpos.copy() self.simulation_mode = SimulationMode.MOCAP + def toggle_arms_mode(self) -> None: + from teleopit.sim.loop import SimulationMode + if not self.realtime_interpolated_input or self.simulation_mode not in (SimulationMode.MOCAP, SimulationMode.ARMS): + return + if self.mocap_session.state == MocapSessionState.PAUSED: + _logger.info("Ignoring arm-only mode toggle while mocap session is paused") + return + loop = self._loop + state = loop.robot.get_state() + resume_qpos = loop._build_resume_alignment_qpos(self.last_commanded_motion_qpos, state) + if self.simulation_mode == SimulationMode.MOCAP: + loop._set_standing_reference(state) + self.simulation_mode = SimulationMode.ARMS + else: + self.simulation_mode = SimulationMode.MOCAP + self._step_runner.reset() + loop.controller.reset() + loop.obs_builder.reset() + self.mocap_session.reset() + self.last_commanded_motion_qpos = None + self._step_runner.reset_reference_alignment(resume_qpos) + self.last_commanded_motion_qpos = resume_qpos.copy() + _logger.info("Simulation mode -> %s", 
self.simulation_mode.value.upper()) + def toggle_realtime_mocap_pause(self) -> None: loop = self._loop if self.mocap_session.state == MocapSessionState.PAUSED: @@ -299,6 +324,9 @@ def _handle_realtime_keyboard(self) -> bool: if key == "x": self.enter_standing_mode() continue + if key == "b": + self.toggle_arms_mode() + continue if key == "a": self.toggle_realtime_mocap_pause() return False @@ -413,9 +441,11 @@ def _fetch_realtime_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeR frame_timestamp = float(packet.timestamp_s) frame_seq = int(packet.seq) for control_event in packet.control_events: - if control_event.event_type != ControlEventType.TOGGLE_PAUSE: + if control_event.event_type == ControlEventType.TOGGLE_ARMS: + self.toggle_arms_mode() continue - self.toggle_realtime_mocap_pause() + if control_event.event_type == ControlEventType.TOGGLE_PAUSE: + self.toggle_realtime_mocap_pause() new_bvh_frame = frame_seq != self.last_live_packet_seq if self.mocap_session.state == MocapSessionState.PAUSED: @@ -499,6 +529,18 @@ def _fetch_realtime_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeR else: self.cached_retargeted = self.latest_live_retargeted + from teleopit.sim.loop import SimulationMode + if self.simulation_mode == SimulationMode.ARMS: + self.cached_retargeted = loop._compose_arm_reference(cast(Float64Array, self.cached_retargeted)) + if reference_window is not None: + assert loop._standing_qpos is not None + reference_window = compose_arm_reference_window( + reference_window, + standing_qpos=loop._standing_qpos, + arm_joint_indices=loop._arm_joint_indices, + num_actions=loop._num_actions, + ) + return new_bvh_frame, reference_window, realtime_reference_diag, False def _fetch_simple_bvh_input(self, frame_f: float) -> tuple[bool, bool]: @@ -547,7 +589,7 @@ def run(self) -> dict[str, float | int]: if self.playback_stop_requested: break - if self.realtime_keyboard_mode_enabled and self.simulation_mode != SimulationMode.MOCAP: + if 
self.realtime_keyboard_mode_enabled and self.simulation_mode == SimulationMode.STANDING: loop._drain_realtime_control_events(self._input_provider) # --- Compute time/frame --- diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 4482221f..ca39fdf8 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -25,6 +25,11 @@ from teleopit.runtime.offline_playback import OfflinePlaybackController from teleopit.runtime.common import cfg_get, parse_viewers, require_section from teleopit.runtime.factory import _build_policy_components, build_simulation_cfg +from teleopit.runtime.arm_mocap import ( + compose_arm_reference, + compose_arm_reference_window, + parse_arm_joint_indices, +) from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState from teleopit.runtime.reference_config import parse_reference_config from teleopit.sim.reference_motion import OfflineReferenceMotion @@ -82,6 +87,7 @@ class RobotMode(Enum): IDLE = "idle" STANDING = "standing" MOCAP = "mocap" + ARMS = "arms" DAMPING = "damping" @@ -371,6 +377,8 @@ def _main() -> None: timestamp_gap_reset_s=float(cfg_get(input_cfg, "pico4_timestamp_gap_reset_s", 0.15)), pause_button=cfg_get(input_cfg, "pause_button", "A"), pause_debounce_s=float(cfg_get(input_cfg, "pause_debounce_s", 0.25)), + arms_button=cfg_get(input_cfg, "arms_button", "B"), + arms_debounce_s=float(cfg_get(input_cfg, "arms_debounce_s", cfg_get(input_cfg, "pause_debounce_s", 0.25))), bridge_host=str(cfg_get(input_cfg, "bridge_host", "0.0.0.0")), bridge_port=int(cfg_get(input_cfg, "bridge_port", 63901)), bridge_discovery=bool(cfg_get(input_cfg, "bridge_discovery", True)), @@ -813,6 +821,7 @@ def __init__( self._standing_return_kp_ramp_floor_ratio = float( cfg_get(cfg, "standing_return_kp_ramp_floor_ratio", 0.5) ) + self._arm_joint_indices = parse_arm_joint_indices(cfg, num_actions=self.num_actions) self._ref_cfg = parse_reference_config(cfg, provider_fps=None) 
self._reference_window_builder = ReferenceWindowBuilder( @@ -878,7 +887,7 @@ def run(self) -> None: self._handle_transitions() if self.mode == RobotMode.STANDING: self._standing_step() - elif self.mode == RobotMode.MOCAP: + elif self.mode in (RobotMode.MOCAP, RobotMode.ARMS): self._mocap_step() self._publish_mode_state() @@ -894,7 +903,7 @@ def run(self) -> None: self.shutdown() def shutdown(self) -> None: - if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): + if self.mode in (RobotMode.STANDING, RobotMode.MOCAP, RobotMode.ARMS): try: self.robot.set_damping() time.sleep(0.5) @@ -952,7 +961,7 @@ def _handle_transitions(self) -> None: self._transition_to_mocap() else: logger.warning("Cannot switch to MOCAP -- no fresh retarget reference") - elif self.mode == RobotMode.MOCAP: + elif self.mode in (RobotMode.MOCAP, RobotMode.ARMS): if self.provider_kind == "bvh" and self.remote.B.on_pressed: logger.info("B pressed -> replaying BVH motion from start") self._send_reference_command("replay_mocap") @@ -1031,7 +1040,13 @@ def _execute_mocap_pipeline( robot_state: object, reference_window: ReferenceWindow | None, ) -> None: + reference_window_aligned = False reference_qpos = self._ref_proc.align_reference_yaw(reference_qpos, robot_state=robot_state) + if self.mode == RobotMode.ARMS: + reference_qpos = self._compose_arm_reference(reference_qpos) + aligned_window = self._ref_proc.align_reference_window(reference_window, robot_state) + reference_window = self._compose_arm_reference_window(aligned_window) + reference_window_aligned = True qpos = reference_qpos.copy() if qpos.shape[0] < 7 + self.num_actions: raise ValueError(f"Retargeted qpos too short: {qpos.shape[0]} (need >= {7 + self.num_actions})") @@ -1058,6 +1073,7 @@ def _execute_mocap_pipeline( anchor_lin_vel_w=anchor_lin_vel_w, anchor_ang_vel_w=anchor_ang_vel_w, reference_window=reference_window, + reference_window_aligned=reference_window_aligned, ) obs = self._ref_proc.validate_observation(obs) action = 
self.policy.compute_action(obs) @@ -1070,9 +1086,25 @@ def _execute_mocap_pipeline( self._last_mocap_hold_reason = None self._write_retarget_viewer(qpos) + def _compose_arm_reference(self, retarget_qpos: Float64Array) -> Float64Array: + return compose_arm_reference( + standing_qpos=self._standing_qpos, + retarget_qpos=retarget_qpos, + arm_joint_indices=self._arm_joint_indices, + num_actions=self.num_actions, + ) + + def _compose_arm_reference_window(self, reference_window: ReferenceWindow | None) -> ReferenceWindow | None: + return compose_arm_reference_window( + reference_window, + standing_qpos=self._standing_qpos, + arm_joint_indices=self._arm_joint_indices, + num_actions=self.num_actions, + ) + def _enter_standing(self) -> None: prev_mode = self.mode - already_in_debug = self.mode in (RobotMode.STANDING, RobotMode.MOCAP) + already_in_debug = self.mode in (RobotMode.STANDING, RobotMode.MOCAP, RobotMode.ARMS) if not already_in_debug: logger.info("Entering debug mode...") ok = self.robot.enter_debug_mode() @@ -1082,7 +1114,7 @@ def _enter_standing(self) -> None: time.sleep(0.5) state = self.robot.get_state() - if prev_mode != RobotMode.MOCAP: + if prev_mode not in (RobotMode.MOCAP, RobotMode.ARMS): logger.info("Locking joints to current position...") self.robot.lock_all_joints() time.sleep(0.3) @@ -1094,14 +1126,14 @@ def _enter_standing(self) -> None: self._last_commanded_motion_qpos = None self._set_default_standing_reference(state) self._reset_policy_state() - if prev_mode == RobotMode.MOCAP: + if prev_mode in (RobotMode.MOCAP, RobotMode.ARMS): self._safety.start_kp_ramp( duration_s=self._standing_return_ramp_duration, floor_ratio=self._standing_return_kp_ramp_floor_ratio, ) else: self._safety.start_kp_ramp() - self._mocap_reentry_armed = prev_mode == RobotMode.MOCAP + self._mocap_reentry_armed = prev_mode in (RobotMode.MOCAP, RobotMode.ARMS) self.mode = RobotMode.STANDING logger.info("Mode -> STANDING (multiprocess robot control)") @@ -1126,7 +1158,9 @@ def 
_can_switch_to_mocap(self) -> bool: def _transition_to_mocap(self) -> None: state = self.robot.get_state() - resume_qpos = self._build_resume_alignment_qpos(self._standing_qpos, state) + last_commanded = getattr(self, "_last_commanded_motion_qpos", None) + hold_qpos = last_commanded if last_commanded is not None else self._standing_qpos + resume_qpos = self._build_resume_alignment_qpos(hold_qpos, state) self._mocap_reentry_armed = False self._reset_policy_state() self._last_retarget_qpos = None @@ -1137,12 +1171,35 @@ def _transition_to_mocap(self) -> None: self.mode = RobotMode.MOCAP logger.info("Mode -> MOCAP (tracking multiprocess retarget reference)") + def _toggle_arms_mode(self) -> None: + if self.provider_kind != "pico4" or self.mode not in (RobotMode.MOCAP, RobotMode.ARMS): + return + if self._mocap_session.state == MocapSessionState.PAUSED: + logger.info("Ignoring Pico B mode toggle while mocap session is paused") + return + + state = self.robot.get_state() + resume_qpos = self._build_resume_alignment_qpos(self._last_commanded_motion_qpos, state) + next_mode = RobotMode.ARMS if self.mode == RobotMode.MOCAP else RobotMode.MOCAP + if next_mode == RobotMode.ARMS: + self._set_default_standing_reference(state) + self._reset_policy_state() + self._last_retarget_qpos = None + self._last_commanded_motion_qpos = resume_qpos.copy() + self._ref_proc.reset_alignment(target_qpos=resume_qpos) + self._safety.start_kp_ramp( + duration_s=self._standing_return_ramp_duration, + floor_ratio=self._standing_return_kp_ramp_floor_ratio, + ) + self.mode = next_mode + logger.info("Mode -> %s (Pico B toggle)", next_mode.value.upper()) + def _resume_paused_mocap_if_needed(self) -> None: if self._mocap_session.state == MocapSessionState.PAUSED: self._resume_paused_mocap() def _enter_damping(self) -> None: - if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): + if self.mode in (RobotMode.STANDING, RobotMode.MOCAP, RobotMode.ARMS): logger.info("DAMPING: sending LowCmd damping...") 
self.robot.set_damping() time.sleep(0.5) @@ -1207,14 +1264,16 @@ def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: ob def _handle_mocap_control_events(self, control_events: tuple[ControlEvent, ...]) -> None: for event in control_events: - if event.event_type != ControlEventType.TOGGLE_PAUSE: + if event.event_type == ControlEventType.TOGGLE_ARMS: + self._toggle_arms_mode() continue - if self.mode != RobotMode.MOCAP: - continue - if self._mocap_session.state == MocapSessionState.PAUSED: - self._resume_paused_mocap() - else: - self._pause_active_mocap() + if event.event_type == ControlEventType.TOGGLE_PAUSE: + if self.mode not in (RobotMode.MOCAP, RobotMode.ARMS): + continue + if self._mocap_session.state == MocapSessionState.PAUSED: + self._resume_paused_mocap() + else: + self._pause_active_mocap() def _pause_active_mocap(self) -> None: hold_qpos = self._resolve_mocap_hold_qpos() @@ -1236,6 +1295,12 @@ def _resume_paused_mocap(self) -> None: self._last_retarget_qpos = None self._last_commanded_motion_qpos = resume_qpos.copy() self._ref_proc.reset_alignment(target_qpos=resume_qpos) + if self.mode == RobotMode.ARMS: + self._set_default_standing_reference(state) + self._safety.start_kp_ramp( + duration_s=self._standing_return_ramp_duration, + floor_ratio=self._standing_return_kp_ramp_floor_ratio, + ) logger.info("Mocap session -> ACTIVE (multiprocess episode-reset + reference realignment)") def _send_reference_command(self, command: str) -> None: @@ -1298,8 +1363,9 @@ def _hold_mocap_reference(self, reason: str, *, detail: str | None = None) -> No def _publish_mode_state(self) -> None: self._mode_seq += 1 - active = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.ACTIVE - paused = self.mode == RobotMode.MOCAP and self._mocap_session.state == MocapSessionState.PAUSED + mocap_like = self.mode in (RobotMode.MOCAP, RobotMode.ARMS) + active = mocap_like and self._mocap_session.state == MocapSessionState.ACTIVE + 
paused = mocap_like and self._mocap_session.state == MocapSessionState.PAUSED self._mode_pub.publish( MODE_TOPIC, ModeStatePacket( @@ -1330,7 +1396,7 @@ def _note_reference_packet(self, reference: ReferencePacket) -> None: if ( self.provider_kind == "bvh" and bool(getattr(reference, "playback_paused", False)) - and self.mode == RobotMode.MOCAP + and self.mode in (RobotMode.MOCAP, RobotMode.ARMS) and self._mocap_session.state == MocapSessionState.ACTIVE ): self._pause_active_mocap() diff --git a/teleopit/sim2real/reference_processor.py b/teleopit/sim2real/reference_processor.py index 32567182..315ba795 100644 --- a/teleopit/sim2real/reference_processor.py +++ b/teleopit/sim2real/reference_processor.py @@ -165,8 +165,11 @@ def build_observation( anchor_lin_vel_w: Float32Array, anchor_ang_vel_w: Float32Array, reference_window: ReferenceWindow | None, + reference_window_aligned: bool = False, ) -> Float32Array: - aligned_reference_window = self.align_reference_window(reference_window, robot_state) + aligned_reference_window = ( + reference_window if reference_window_aligned else self.align_reference_window(reference_window, robot_state) + ) return ref_proc.dispatch_build_observation( self._obs_builder, robot_state, reference_window, aligned_reference_window, motion_qpos, motion_joint_vel, last_action, diff --git a/tests/test_pico4_provider.py b/tests/test_pico4_provider.py index 97a4ffde..e785aaae 100644 --- a/tests/test_pico4_provider.py +++ b/tests/test_pico4_provider.py @@ -27,6 +27,7 @@ def _pico_frame( timestamp: float, body_active: bool = True, right_primary: bool = False, + right_secondary: bool = False, ) -> SimpleNamespace: return SimpleNamespace( seq=seq, @@ -34,7 +35,7 @@ def _pico_frame( body=SimpleNamespace(active=body_active, joints=body_poses), controllers=SimpleNamespace( left=SimpleNamespace(buttons={}), - right=SimpleNamespace(buttons={"primaryButton": right_primary}), + right=SimpleNamespace(buttons={"primaryButton": right_primary, "secondaryButton": 
right_secondary}), ), ) @@ -54,10 +55,15 @@ def _make_provider() -> Pico4InputProvider: provider._timestamp_gap_reset_s = 0.15 provider._pending_control_events = deque() provider._pause_button = "A" + provider._arms_button = "B" provider._pause_debounce_s = 0.0 + provider._arms_debounce_s = 0.0 provider._pause_button_path = provider._resolve_button_path(provider._pause_button) + provider._arms_button_path = provider._resolve_button_path(provider._arms_button) provider._last_pause_button_pressed = False + provider._last_arms_button_pressed = False provider._last_pause_toggle_timestamp = None + provider._last_arms_toggle_timestamp = None provider._last_raw_body_joints = None provider._last_frame_timestamp = None provider._last_source_seq = None @@ -281,6 +287,20 @@ def test_pico4_provider_exposes_pause_control_events_once() -> None: assert packet.control_events == () +def test_pico4_provider_exposes_arms_control_events_once() -> None: + provider = _make_provider() + + assert provider._accept_pico_frame( + _pico_frame(_body_poses(1.0), seq=1, timestamp=1.0, right_secondary=True) + ) is True + + packet = provider.get_realtime_input_packet() + assert [event.event_type for event in packet.control_events] == [ControlEventType.TOGGLE_ARMS] + + packet = provider.get_realtime_input_packet() + assert packet.control_events == () + + def test_pico4_provider_marks_controller_present_without_raw_field() -> None: provider = _make_provider() frame = _pico_frame(_body_poses(1.0), seq=1, timestamp=1.0) diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index ae765e1c..f26ca68c 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -9,8 +9,11 @@ import pytest from teleopit.runtime.mocap_session import MocapSessionState +from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType +from teleopit.runtime.arm_mocap import compose_arm_reference, compose_arm_reference_window from 
teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher from teleopit.sim2real.mp.messages import ReferencePacket, SharedFrameDescriptor +from teleopit.sim.reference_timeline import ReferenceSample, ReferenceWindow from teleopit.sim2real.mp.runtime import ( RobotMode, Sim2RealRuntime, @@ -359,3 +362,103 @@ def pause_active_mocap() -> None: assert paused == ["pause"] assert worker._latest_reference is packet + + +def test_robot_worker_composes_arm_reference_from_standing_pose() -> None: + standing_qpos = np.arange(36, dtype=np.float64) + retarget = np.full(36, 100.0, dtype=np.float64) + retarget[7 + 15:7 + 29] = np.arange(14, dtype=np.float64) + 200.0 + + composed = compose_arm_reference( + standing_qpos=standing_qpos, + retarget_qpos=retarget, + arm_joint_indices=np.arange(15, 29, dtype=np.int64), + num_actions=29, + ) + + np.testing.assert_allclose(composed[: 7 + 15], standing_qpos[: 7 + 15]) + np.testing.assert_allclose(composed[7 + 15:7 + 29], retarget[7 + 15:7 + 29]) + + +def test_robot_worker_composes_arm_reference_window_samples() -> None: + standing_qpos = np.zeros(36, dtype=np.float64) + qpos0 = np.ones(36, dtype=np.float64) + qpos1 = np.full(36, 2.0, dtype=np.float64) + window = ReferenceWindow( + base_time_s=1.0, + policy_dt_s=0.02, + reference_steps=(0, 1), + samples=( + ReferenceSample(qpos=qpos0, timestamp_s=1.0, mode="a", used_fallback=False, older_timestamp_s=None, newer_timestamp_s=None, alpha=None), + ReferenceSample(qpos=qpos1, timestamp_s=1.02, mode="b", used_fallback=False, older_timestamp_s=None, newer_timestamp_s=None, alpha=None), + ), + ) + + composed = compose_arm_reference_window( + window, + standing_qpos=standing_qpos, + arm_joint_indices=np.arange(15, 29, dtype=np.int64), + num_actions=29, + ) + + assert composed is not None + np.testing.assert_allclose(composed.samples[0].qpos[7 + 15:7 + 29], 1.0) + np.testing.assert_allclose(composed.samples[1].qpos[7 + 15:7 + 29], 2.0) + 
np.testing.assert_allclose(composed.samples[0].qpos[:7 + 15], 0.0) + np.testing.assert_allclose(composed.samples[1].qpos[:7 + 15], 0.0) + + +def test_robot_worker_pico_arms_event_toggles_mocap_and_arms() -> None: + worker = object.__new__(_RobotControlWorker) + worker.provider_kind = "pico4" + worker.mode = RobotMode.MOCAP + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker.robot = SimpleNamespace(get_state=lambda: SimpleNamespace(base_pos=np.zeros(3), quat=np.array([1.0, 0.0, 0.0, 0.0]), qpos=np.zeros(29))) + worker._last_commanded_motion_qpos = np.zeros(36, dtype=np.float64) + worker._build_resume_alignment_qpos = lambda _hold, _state: np.ones(36, dtype=np.float64) + worker._set_default_standing_reference = lambda _state: None + worker._reset_policy_state = lambda: None + worker._last_retarget_qpos = np.zeros(36, dtype=np.float64) + resets: list[np.ndarray] = [] + ramps: list[str] = [] + worker._ref_proc = SimpleNamespace(reset_alignment=lambda *, target_qpos=None: resets.append(np.asarray(target_qpos).copy())) + worker._standing_return_ramp_duration = 0.5 + worker._standing_return_kp_ramp_floor_ratio = 0.5 + worker._safety = SimpleNamespace(start_kp_ramp=lambda **_kwargs: ramps.append("ramp")) + + event = ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="test") + worker._handle_mocap_control_events((event,)) + assert worker.mode == RobotMode.ARMS + + worker._handle_mocap_control_events((event,)) + assert worker.mode == RobotMode.MOCAP + assert len(resets) == 2 + assert ramps == ["ramp", "ramp"] + + +def test_robot_worker_bvh_ignores_pico_arms_event() -> None: + worker = object.__new__(_RobotControlWorker) + worker.provider_kind = "bvh" + worker.mode = RobotMode.MOCAP + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._handle_mocap_control_events(( + ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="test"), + )) + + assert worker.mode == RobotMode.MOCAP + + +def 
test_robot_worker_mode_state_marks_arms_as_mocap_active() -> None: + worker = object.__new__(_RobotControlWorker) + worker.mode = RobotMode.ARMS + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._mode_seq = 0 + published: list[object] = [] + worker._mode_pub = SimpleNamespace(publish=lambda _topic, packet: published.append(packet)) + + worker._publish_mode_state() + + packet = published[-1] + assert packet.mode == "arms" + assert packet.mocap_active is True + assert packet.mocap_paused is False diff --git a/tests/test_sim_loop.py b/tests/test_sim_loop.py index f54e59f5..590ad9d8 100644 --- a/tests/test_sim_loop.py +++ b/tests/test_sim_loop.py @@ -56,6 +56,7 @@ def reset(self) -> None: class _DummyObsBuilder: def __init__(self) -> None: self.mimic_obs_calls: list[np.ndarray] = [] + self.motion_qpos_calls: list[np.ndarray] = [] self._base = _DummyObsBuilderBase() def reset(self) -> None: @@ -72,6 +73,7 @@ def build( ) -> np.ndarray: del state, motion_joint_vel, last_action, motion_anchor_lin_vel_w, motion_anchor_ang_vel_w self.mimic_obs_calls.append(np.asarray(motion_qpos[:1], dtype=np.float32).copy()) + self.motion_qpos_calls.append(np.asarray(motion_qpos, dtype=np.float32).copy()) return np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) @@ -666,6 +668,138 @@ def close(self) -> None: np.testing.assert_allclose(obs_builder.mimic_obs_calls[2], np.array([0.0], dtype=np.float32), atol=1e-6) +@requires_mujoco +def test_simulation_loop_pico_arms_mode_composes_standing_body_with_live_arm(monkeypatch) -> None: + from teleopit.sim.loop import SimulationLoop + + class _RealtimeInputProvider: + fps = 1 + + def __init__(self) -> None: + self._packets = [ + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([0.3, 0.0, 0.0], dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=0.0, + seq=0, + control_events=(), + ), + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([0.9, 0.0, 0.0], 
dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=1.0, + seq=1, + control_events=(), + ), + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([0.9, 0.0, 0.0], dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=2.0, + seq=2, + control_events=(ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="pico4:test"),), + ), + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([1.2, 0.0, 0.0], dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=3.0, + seq=3, + control_events=(ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="pico4:test"),), + ), + ] + self._idx = 0 + + def get_realtime_input_packet(self): + packet = self._packets[min(self._idx, len(self._packets) - 1)] + self._idx += 1 + return packet + + class _Retargeter: + def retarget(self, frame: object) -> np.ndarray: + pelvis = np.asarray(frame["Pelvis"][0], dtype=np.float64) + qpos = np.zeros(36, dtype=np.float64) + qpos[0] = pelvis[0] + qpos[3] = 1.0 + qpos[7] = pelvis[0] + qpos[8] = pelvis[0] + 10.0 + return qpos + + class _KeyboardReader: + def __init__(self) -> None: + self._polls = [ + (TerminalKeyEvent("y"),), + (), + (), + (), + ] + self._idx = 0 + + @property + def active(self) -> bool: + return True + + def poll(self) -> tuple[TerminalKeyEvent, ...]: + if self._idx >= len(self._polls): + return () + events = self._polls[self._idx] + self._idx += 1 + return events + + def close(self) -> None: + pass + + monkeypatch.setattr("teleopit.sim.session.TerminalKeyboardReader", _KeyboardReader) + + robot = _DummyRobot() + obs_builder = _DummyObsBuilder() + loop = SimulationLoop( + robot=robot, + controller=_DummyController(), + obs_builder=obs_builder, + bus=InProcessBus(), + cfg={ + "policy_hz": 50.0, + "pd_hz": 50.0, + "realtime": False, + "retarget_buffer_enabled": False, + "realtime_input_delay_s": 0.0, + "keyboard": {"enabled": True}, + "arm_mocap": 
{"controlled_joint_indices": [1]}, + }, + viewers=set(), + ) + + result = loop.run( + input_provider=_RealtimeInputProvider(), + retargeter=_Retargeter(), + num_steps=4, + ) + + assert result["steps"] == 4 + # Step 2 is ARMS: root/non-arm stays at standing while arm index 1 follows retarget. + np.testing.assert_allclose(obs_builder.motion_qpos_calls[2][0], 0.0, atol=1e-6) + np.testing.assert_allclose(obs_builder.motion_qpos_calls[2][7], 0.5, atol=1e-6) + np.testing.assert_allclose(obs_builder.motion_qpos_calls[2][8], 10.9, atol=1e-6) + # Step 3 toggles back to full-body MOCAP; root XY is reanchored, while non-arm joints follow retarget again. + np.testing.assert_allclose(obs_builder.motion_qpos_calls[3][0], 0.0, atol=1e-6) + np.testing.assert_allclose(obs_builder.motion_qpos_calls[3][7], 1.2, atol=1e-6) + + @requires_mujoco def test_simulation_loop_realtime_keyboard_mode_drains_stale_pause_events(monkeypatch) -> None: from teleopit.sim.loop import SimulationLoop From 0bb0df77b26462f489e036e2def78df47e25f213 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 12 Jun 2026 20:50:09 +0800 Subject: [PATCH 081/122] Fix sim2real standing reference root height --- teleopit/sim2real/mp/runtime.py | 23 +++++++++++++++++------ tests/test_sim2real_multiprocess.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index ca39fdf8..668e467d 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -815,6 +815,12 @@ def __init__( robot_cfg = cfg_get(cfg, "robot") self.default_angles = np.asarray(cfg_get(robot_cfg, "default_angles"), dtype=np.float32) + default_root_qpos = np.asarray( + cfg_get(robot_cfg, "mujoco_default_qpos", [0.0, 0.0, 0.0]), dtype=np.float64 + ).reshape(-1) + self._default_root_pos = np.zeros(3, dtype=np.float64) + if default_root_qpos.shape[0] >= 3: + self._default_root_pos[:] = default_root_qpos[:3] self.num_actions = 
int(cfg_get(robot_cfg, "num_actions", NUM_JOINTS)) self._safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) self._standing_return_ramp_duration = float(cfg_get(cfg, "standing_return_ramp_duration", 0.5)) @@ -1235,9 +1241,7 @@ def _reset_policy_reference_state(self) -> None: def _build_robot_state_qpos(self, state: object) -> Float64Array: qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - base_pos = getattr(state, "base_pos", None) - if base_pos is not None: - qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] + qpos[0:3] = self._resolve_base_pos(state) qpos[3:7] = np.asarray(getattr(state, "quat"), dtype=np.float64).reshape(-1)[:4] qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(getattr(state, "qpos"), dtype=np.float64).reshape(-1)[ : self.num_actions @@ -1246,13 +1250,20 @@ def _build_robot_state_qpos(self, state: object) -> Float64Array: def _set_default_standing_reference(self, state: object) -> None: self._standing_qpos[:] = 0.0 - base_pos = getattr(state, "base_pos", None) - if base_pos is not None: - self._standing_qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[:3] + self._standing_qpos[0:3] = self._resolve_base_pos(state) self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) align_motion_qpos_yaw(np.asarray(getattr(state, "quat"), dtype=np.float32), self._standing_qpos) self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + def _resolve_base_pos(self, state: object) -> Float64Array: + base_pos = getattr(state, "base_pos", None) + if base_pos is None: + return self._default_root_pos.copy() + resolved = self._default_root_pos.copy() + live = np.asarray(base_pos, dtype=np.float64).reshape(-1) + resolved[: min(3, live.shape[0])] = live[:3] + return resolved + def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: object) -> Float64Array: qpos = self._build_robot_state_qpos(state) if hold_qpos is not None: diff --git 
a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index f26ca68c..2a4dd7c4 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -334,6 +334,22 @@ def test_robot_worker_replays_bvh_on_mocap_entry() -> None: assert commands == ["replay_mocap"] +def test_robot_worker_standing_reference_uses_default_root_height_without_base_pos() -> None: + worker = object.__new__(_RobotControlWorker) + worker.default_angles = np.zeros(29, dtype=np.float32) + worker.num_actions = 29 + worker._default_root_pos = np.array([0.0, 0.0, 0.76], dtype=np.float64) + worker._standing_qpos = np.zeros(36, dtype=np.float64) + state = SimpleNamespace( + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + qpos=np.zeros(29, dtype=np.float32), + ) + + worker._set_default_standing_reference(state) + + np.testing.assert_allclose(worker._standing_qpos[0:3], np.array([0.0, 0.0, 0.76])) + + def test_robot_worker_pauses_when_bvh_reference_reports_paused() -> None: worker = object.__new__(_RobotControlWorker) worker.provider_kind = "bvh" From 4a7d46774425c862af2ad1d09e972dbf0e18ba1b Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Sun, 14 Jun 2026 17:40:25 +0800 Subject: [PATCH 082/122] Remove adaptive motion sampling --- AGENTS.md | 4 +- CHANGELOG.md | 1 - README.md | 4 +- docs/docs/reference/architecture.md | 2 +- docs/docs/tutorials/training.md | 2 +- .../current/reference/architecture.md | 2 +- .../current/tutorials/training.md | 2 +- tests/test_motion_sampling.py | 72 ----- tests/test_task_registry.py | 2 +- train_mimic/scripts/train.py | 2 +- train_mimic/tasks/tracking/config/env.py | 2 +- train_mimic/tasks/tracking/mdp/commands.py | 254 +----------------- train_mimic/tasks/tracking/rl/runner.py | 20 -- 13 files changed, 13 insertions(+), 356 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 2b4d48b5..595e1638 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -194,9 +194,9 @@ The single supported training task is `General-Tracking-G1` 
(experiment name: `g - Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - 167D `velcmd_history` observation, dual-input ONNX export -- Training env uses `sampling_mode="uniform"` +- Training env uses `sampling_mode="rewind"` - Tracking rewards include root position/orientation/linear velocity/angular velocity, body pose/velocity, joint position/velocity, survival, action-rate, joint-limit, self-collision, and ankle acceleration terms -- Supported motion sampling modes are `adaptive`, `uniform`, `start`, and `rewind`; `rewind` restarts failed environments from the same clip after stepping back `rewind_min_steps..rewind_max_steps` with probability `rewind_prob`, otherwise it falls back to uniform sampling +- Supported motion sampling modes are `uniform`, `start`, and `rewind`; `rewind` restarts failed environments from the same clip after stepping back `rewind_min_steps..rewind_max_steps` with probability `rewind_prob`, otherwise it falls back to uniform sampling - Playback/benchmark use `play=True`, which switches motion sampling to `start` - `window_steps=[0]` - `save_onnx.py` exports dual-input TemporalCNN ONNX diff --git a/CHANGELOG.md b/CHANGELOG.md index 124d5579..bbf3975d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,6 @@ ## [0.1.1] - 2025-03-28 - 数据集改为 shard-only 输出。 -- 新增 adaptive_bin 采样。 - 引入外部资源管理并瘦身仓库。 ## [0.1.0] - 2025-03-25 diff --git a/README.md b/README.md index 8b88adb7..8745fd7e 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. 
-- General-Tracking-G1 training now defaults to uniform motion sampling; clip-local adaptive sampling remains available through `sampling_mode=adaptive`. +- General-Tracking-G1 training defaults to `rewind` motion sampling and also supports `uniform`; playback/benchmark use `start`. - Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. - Added root velocity, joint tracking, and survival rewards to the General-Tracking-G1 training objective. - Renamed General-Tracking-G1 observation terms to explicit `ref_*`, `robot_*`, and `prev_action` keys. @@ -115,7 +115,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ### v0.1.1 (2025-03-28) -- Dataset shard-only refactor and `adaptive_bin` sampling +- Dataset shard-only refactor - External asset management (ModelScope), repository slimming ### v0.1.0 (2025-03-25) diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index 1606ea6e..ce51f25a 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -58,7 +58,7 @@ train_mimic/scripts/data | Inference observation | `velcmd_history` (167D) | | ONNX signature | Dual-input `obs` (167D) + `obs_history` | | Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | -| Training sampling | Default `uniform`; also supports `adaptive` and `rewind`; playback/benchmark use `start` | +| Training sampling | Default `rewind`; also supports `uniform`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | | Data format | Shard directories only (`shard_*.npz`) | diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index e3356ae2..50aacb7e 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -131,4 +131,4 @@ Key files: - `train_mimic/app.py` - Shared entry point for train/play/benchmark - 
`train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 env builder - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO config -- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `uniform`, `start`, `adaptive`, and `rewind` sampling modes. Training defaults to `uniform`; playback/benchmark use `start`. +- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `uniform`, `start`, and `rewind` sampling modes. Training defaults to `rewind`; playback/benchmark use `start`. diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index d82232c2..e43857c3 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -58,7 +58,7 @@ train_mimic/scripts/data | 推理观测 | `velcmd_history`(167D) | | ONNX 签名 | 双输入 `obs`(167D)+ `obs_history` | | Actor/Critic | TemporalCNN(2048、1024、512、256、128) | -| 训练采样 | 默认 `uniform`;也支持 `adaptive` 和 `rewind`;播放/评估使用 `start` | +| 训练采样 | 默认 `rewind`;也支持 `uniform`;播放/评估使用 `start` | | 训练 `window_steps` | `[0]` | | 数据格式 | 仅 shard 目录(`shard_*.npz`) | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index 7ea3c7e9..5e45cc1b 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -131,4 +131,4 @@ train_mimic/scripts - `train_mimic/app.py` - 训练/播放/评估的统一入口 - `train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 环境构建器 - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO 配置 -- `train_mimic/tasks/tracking/mdp/commands.py` - 支持 `uniform`、`start`、`adaptive` 和 `rewind` 采样模式。训练默认使用 `uniform`;播放/评估使用 `start`。 +- 
`train_mimic/tasks/tracking/mdp/commands.py` - 支持 `uniform`、`start` 和 `rewind` 采样模式。训练默认使用 `rewind`;播放/评估使用 `start`。 diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index cf1b5132..be39fd42 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -152,78 +152,6 @@ def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_pat assert torch.allclose(motion.clip_sample_end_s[motion_ids], torch.tensor([3.0])) -def test_motion_lib_adaptive_bins_are_clip_local(tmp_path: Path) -> None: - motion_path = _write_shard_dir( - tmp_path / "motion_adaptive_bins", - [_clip_dict(num_frames=6), _clip_dict(num_frames=8)], - weights=[1.0, 3.0], - ) - motion = MotionLib( - str(motion_path), - body_indexes=torch.tensor([0, 1], dtype=torch.long), - window_steps=(0,), - ) - - num_bins = motion.prepare_adaptive_sampling(bin_size_frames=2) - - assert num_bins == 7 - assert motion.adaptive_bin_clip_ids.tolist() == [0, 0, 0, 1, 1, 1, 1] - assert motion.adaptive_bin_start_frames.tolist() == [0, 2, 4, 0, 2, 4, 6] - assert motion.adaptive_bin_end_frames.tolist() == [2, 4, 5, 2, 4, 6, 7] - - clip0_mass = motion.adaptive_bin_base_probs[:3].sum() - clip1_mass = motion.adaptive_bin_base_probs[3:].sum() - assert torch.allclose(clip0_mass, torch.tensor(0.25), atol=1e-6) - assert torch.allclose(clip1_mass, torch.tensor(0.75), atol=1e-6) - - -def test_motion_lib_adaptive_sampling_never_crosses_clip_boundaries(tmp_path: Path) -> None: - motion_path = _write_shard_dir( - tmp_path / "motion_adaptive_sample", - [_clip_dict(num_frames=6), _clip_dict(num_frames=8)], - ) - motion = MotionLib( - str(motion_path), - body_indexes=torch.tensor([0, 1], dtype=torch.long), - window_steps=(0,), - ) - motion.prepare_adaptive_sampling(bin_size_frames=2) - - motion_ids, motion_times, bins = motion.sample_adaptive_times( - motion.adaptive_bin_base_probs, - 512, - ) - sampled_frames = motion_times * motion.clip_fps[motion_ids] - - assert 
torch.all(sampled_frames >= motion.clip_sample_starts[motion_ids]) - assert torch.all(sampled_frames < motion.clip_sample_ends[motion_ids]) - assert torch.equal(motion.adaptive_bin_clip_ids[bins], motion_ids) - assert torch.equal(motion.adaptive_bins_for(motion_ids, motion_times), bins) - - -def test_motion_command_adaptive_sampling_state_round_trips() -> None: - source = MotionCommand.__new__(MotionCommand) - source.cfg = SimpleNamespace(sampling_mode="adaptive", adaptive_bin_size_frames=2) - source.adaptive_bin_failed_count = torch.tensor([0.0, 2.0, 4.0]) - source._current_adaptive_bin_failed = torch.tensor([1.0, 0.0, 3.0]) - - target = MotionCommand.__new__(MotionCommand) - target.cfg = SimpleNamespace(sampling_mode="adaptive", adaptive_bin_size_frames=2) - target._env = SimpleNamespace(device="cpu") - target.adaptive_bin_failed_count = torch.zeros(3) - target._current_adaptive_bin_failed = torch.zeros(3) - - state = source.get_adaptive_sampling_state() - assert state is not None - target.load_adaptive_sampling_state(state) - - assert torch.equal(target.adaptive_bin_failed_count, source.adaptive_bin_failed_count) - assert torch.equal( - target._current_adaptive_bin_failed, - source._current_adaptive_bin_failed, - ) - - class _FakeMotion: def __init__(self) -> None: self.clip_sample_start_s = torch.tensor([0.0, 1.0, 2.0]) diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 9d889926..fdae6cc6 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -56,7 +56,7 @@ def test_general_tracking_task_is_registered() -> None: ] assert "actor_history" in env_cfg.observations assert "critic_history" in env_cfg.observations - assert env_cfg.commands["motion"].sampling_mode == "uniform" + assert env_cfg.commands["motion"].sampling_mode == "rewind" assert env_cfg.commands["motion"].window_steps == (0,) assert env_cfg.rewards["motion_global_root_lin_vel"].weight == 1.0 assert env_cfg.rewards["motion_global_root_lin_vel"].params == { 
diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index a576b9de..bd40db01 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -85,7 +85,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ), ) parser.add_argument("--sampling_mode", type=str, default=None, - choices=["uniform", "start", "adaptive", "rewind"], + choices=["uniform", "start", "rewind"], help="Motion sampling mode (default: from task config)") parser.add_argument("--rewind_prob", type=float, default=None, help="Rewind sampling probability for failed episodes") diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 9f45a2ee..b4453816 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -180,7 +180,7 @@ def make_general_tracking_env_cfg( motion_cmd.anchor_body_name = "torso_link" motion_cmd.body_names = _TRACKING_BODY_NAMES motion_cmd.motion_file = DEFAULT_TRAIN_MOTION_FILE - motion_cmd.sampling_mode = "uniform" + motion_cmd.sampling_mode = "rewind" motion_cmd.window_steps = (0,) cfg.events["physics_material"].params[ diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index ce16af86..cc9cfbda 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -299,7 +299,6 @@ def __init__( ) self.clip_sample_start_s = self.clip_sample_starts.float() * self.clip_dt self.clip_sample_end_s = self.clip_sample_ends.float() * self.clip_dt - self._adaptive_bin_size_frames: int | None = None # ------------------------------------------------------------------ # Sampling helpers @@ -333,132 +332,6 @@ def sample_start_times(self, motion_ids: torch.Tensor) -> torch.Tensor: """Return the earliest valid center time for each motion id.""" return self.clip_sample_start_s[motion_ids] - def prepare_adaptive_sampling(self, bin_size_frames: int) -> int: - 
"""Build clip-local adaptive sampling bins. - - Bins are cut only from each clip's valid center-frame range, so sampled - times never cross into adjacent clips in the flat motion arrays. - """ - if bin_size_frames <= 0: - raise ValueError( - f"adaptive_bin_size_frames must be positive, got {bin_size_frames}" - ) - if self._adaptive_bin_size_frames == bin_size_frames: - return int(self.adaptive_bin_clip_ids.numel()) - - clip_sample_starts = self.clip_sample_starts.cpu().numpy() - clip_sample_ends = self.clip_sample_ends.cpu().numpy() - clip_weights = self.clip_weights.cpu().numpy() - - bin_clip_ids: list[int] = [] - bin_start_frames: list[int] = [] - bin_end_frames: list[int] = [] - bin_base_weights: list[float] = [] - clip_bin_offsets = np.full(self.num_clips, -1, dtype=np.int64) - clip_bin_counts = np.zeros(self.num_clips, dtype=np.int64) - - for clip_id in range(self.num_clips): - clip_weight = float(clip_weights[clip_id]) - sample_start = int(clip_sample_starts[clip_id]) - sample_end = int(clip_sample_ends[clip_id]) - valid_length = sample_end - sample_start - if clip_weight <= 0.0 or valid_length <= 0: - continue - - clip_bin_offsets[clip_id] = len(bin_clip_ids) - for start in range(sample_start, sample_end, bin_size_frames): - end = min(start + bin_size_frames, sample_end) - width = end - start - bin_clip_ids.append(clip_id) - bin_start_frames.append(start) - bin_end_frames.append(end) - bin_base_weights.append(clip_weight * float(width) / float(valid_length)) - clip_bin_counts[clip_id] = len(bin_clip_ids) - clip_bin_offsets[clip_id] - - if not bin_clip_ids: - raise ValueError( - "Adaptive sampling has no valid bins. Check clip_weights and " - f"window_steps={list(self.window_steps)}." 
- ) - - device = self._device - self.adaptive_bin_clip_ids = torch.tensor( - bin_clip_ids, dtype=torch.long, device=device - ) - self.adaptive_bin_start_frames = torch.tensor( - bin_start_frames, dtype=torch.float32, device=device - ) - self.adaptive_bin_end_frames = torch.tensor( - bin_end_frames, dtype=torch.float32, device=device - ) - self.adaptive_bin_base_probs = torch.tensor( - bin_base_weights, dtype=torch.float32, device=device - ) - total = self.adaptive_bin_base_probs.sum() - if total <= 0: - raise ValueError("Adaptive sampling base probabilities sum to zero.") - self.adaptive_bin_base_probs = self.adaptive_bin_base_probs / total - self.adaptive_clip_bin_offsets = torch.tensor( - clip_bin_offsets, dtype=torch.long, device=device - ) - self.adaptive_clip_bin_counts = torch.tensor( - clip_bin_counts, dtype=torch.long, device=device - ) - self._adaptive_bin_size_frames = bin_size_frames - return int(self.adaptive_bin_clip_ids.numel()) - - def sample_adaptive_times( - self, - bin_probabilities: torch.Tensor, - n: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sample clip ids and clip-local times from adaptive bin probabilities.""" - if self._adaptive_bin_size_frames is None: - raise RuntimeError("prepare_adaptive_sampling() must be called first.") - if bin_probabilities.shape != self.adaptive_bin_base_probs.shape: - raise ValueError( - "adaptive bin probability shape mismatch: " - f"{tuple(bin_probabilities.shape)} vs " - f"{tuple(self.adaptive_bin_base_probs.shape)}" - ) - sampled_bins = torch.multinomial(bin_probabilities, n, replacement=True) - motion_ids = self.adaptive_bin_clip_ids[sampled_bins] - starts = self.adaptive_bin_start_frames[sampled_bins] - ends = self.adaptive_bin_end_frames[sampled_bins] - frame_f = starts + torch.rand_like(starts) * (ends - starts) - motion_times = frame_f / self.clip_fps[motion_ids] - return motion_ids, motion_times, sampled_bins - - def adaptive_bins_for( - self, motion_ids: torch.Tensor, 
motion_times: torch.Tensor - ) -> torch.Tensor: - """Return adaptive bin ids for clip-local motion states.""" - if self._adaptive_bin_size_frames is None: - raise RuntimeError("prepare_adaptive_sampling() must be called first.") - counts = self.adaptive_clip_bin_counts[motion_ids] - offsets = self.adaptive_clip_bin_offsets[motion_ids] - bins = torch.full_like(motion_ids, -1) - valid = (counts > 0) & (offsets >= 0) - if not torch.any(valid): - return bins - - valid_ids = motion_ids[valid] - local_frames = torch.floor( - motion_times[valid] * self.clip_fps[valid_ids] - ).long() - rel = local_frames - self.clip_sample_starts[valid_ids] - local_bins = torch.div( - torch.clamp(rel, min=0), - self._adaptive_bin_size_frames, - rounding_mode="floor", - ) - local_bins = torch.minimum( - torch.clamp(local_bins, min=0), - counts[valid] - 1, - ) - bins[valid] = offsets[valid] + local_bins - return bins - def _compute_interpolation_state( self, motion_ids: torch.Tensor, @@ -630,23 +503,11 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): device=self.device, window_steps=self.cfg.window_steps, ) - if self.cfg.sampling_mode == "adaptive": - adaptive_bin_count = self.motion.prepare_adaptive_sampling( - self.cfg.adaptive_bin_size_frames - ) - else: - adaptive_bin_count = 0 # Per-env motion state: clip id + elapsed time (seconds) self.motion_ids = torch.zeros(self.num_envs, dtype=torch.long, device=self.device) self.motion_times = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) self._step_dt = env.step_dt - self.adaptive_bin_failed_count = torch.zeros( - adaptive_bin_count, dtype=torch.float32, device=self.device - ) - self._current_adaptive_bin_failed = torch.zeros( - adaptive_bin_count, dtype=torch.float32, device=self.device - ) # Cached interpolated frames — refreshed every step self._cached_frames: dict[str, torch.Tensor] = {} @@ -679,9 +540,6 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): 
self.metrics["error_body_rot"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_vel"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_entropy"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_top1_prob"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_failed_bin_mean"] = torch.zeros(self.num_envs, device=self.device) if self.cfg.feet_body_names: self._feet_body_indexes = [ @@ -926,62 +784,6 @@ def _uniform_sampling(self, env_ids: torch.Tensor): self.motion_ids[env_ids] = self.motion.sample_motion_ids(len(env_ids)) self.motion_times[env_ids] = self.motion.sample_times(self.motion_ids[env_ids]) - def _update_adaptive_sampling_metrics( - self, - sampling_probabilities: torch.Tensor, - ) -> None: - entropy = -( - sampling_probabilities * (sampling_probabilities + 1e-12).log() - ).sum() - if sampling_probabilities.numel() > 1: - entropy = entropy / torch.log( - torch.tensor( - float(sampling_probabilities.numel()), - dtype=torch.float32, - device=self.device, - ) - ) - pmax = sampling_probabilities.max() - failed = self.adaptive_bin_failed_count - self.metrics["sampling_entropy"][:] = entropy - self.metrics["sampling_top1_prob"][:] = pmax - self.metrics["sampling_failed_bin_mean"][:] = failed.mean() - - def _adaptive_sampling_probabilities(self) -> torch.Tensor: - sampling_probabilities = ( - self.adaptive_bin_failed_count - + self.cfg.adaptive_uniform_ratio * self.motion.adaptive_bin_base_probs - ) - probability_sum = sampling_probabilities.sum() - if probability_sum <= 0: - return self.motion.adaptive_bin_base_probs - return sampling_probabilities / probability_sum - - def _adaptive_sampling(self, env_ids: torch.Tensor): - episode_failed = self._env.termination_manager.terminated[env_ids] - if torch.any(episode_failed): - failed_env_ids = env_ids[episode_failed] - failed_bins = 
self.motion.adaptive_bins_for( - self.motion_ids[failed_env_ids], - self.motion_times[failed_env_ids], - ) - failed_bins = failed_bins[failed_bins >= 0] - if failed_bins.numel() > 0: - self._current_adaptive_bin_failed += torch.bincount( - failed_bins, - minlength=self.adaptive_bin_failed_count.numel(), - ).to(dtype=torch.float32, device=self.device) - - sampling_probabilities = self._adaptive_sampling_probabilities() - motion_ids, motion_times, _sampled_bins = self.motion.sample_adaptive_times( - sampling_probabilities, - len(env_ids), - ) - self.motion_ids[env_ids] = motion_ids - self.motion_times[env_ids] = motion_times - - self._update_adaptive_sampling_metrics(sampling_probabilities) - def _rewind_sampling(self, env_ids: torch.Tensor) -> None: _validate_rewind_sampling_cfg(self.cfg) @@ -1023,14 +825,12 @@ def _resample_command(self, env_ids: torch.Tensor): self.motion_times[env_ids] = self.motion.sample_start_times(self.motion_ids[env_ids]) elif self.cfg.sampling_mode == "uniform": self._uniform_sampling(env_ids) - elif self.cfg.sampling_mode == "adaptive": - self._adaptive_sampling(env_ids) elif self.cfg.sampling_mode == "rewind": self._rewind_sampling(env_ids) else: raise ValueError( f"Unsupported motion sampling_mode={self.cfg.sampling_mode!r}. " - "Supported modes are 'uniform', 'start', 'adaptive', and 'rewind'." + "Supported modes are 'uniform', 'start', and 'rewind'." 
) if env_ids.numel() == 0: @@ -1155,56 +955,9 @@ def _update_command(self): delta_ori_w, self.body_pos_w - anchor_pos_w_repeat ) - if self.cfg.sampling_mode == "adaptive": - self.adaptive_bin_failed_count = ( - self.cfg.adaptive_alpha * self._current_adaptive_bin_failed - + (1.0 - self.cfg.adaptive_alpha) * self.adaptive_bin_failed_count - ) - self._current_adaptive_bin_failed.zero_() - self._update_adaptive_sampling_metrics( - self._adaptive_sampling_probabilities() - ) - self._refresh_body_local_cache() self._update_feet_standing() - def get_adaptive_sampling_state(self) -> dict[str, torch.Tensor | int] | None: - if self.cfg.sampling_mode != "adaptive": - return None - return { - "adaptive_bin_size_frames": int(self.cfg.adaptive_bin_size_frames), - "adaptive_bin_failed_count": self.adaptive_bin_failed_count.detach().cpu(), - "current_adaptive_bin_failed": self._current_adaptive_bin_failed.detach().cpu(), - } - - def load_adaptive_sampling_state(self, state: dict[str, torch.Tensor | int]) -> None: - if self.cfg.sampling_mode != "adaptive": - return - bin_size = int(state.get("adaptive_bin_size_frames", -1)) - if bin_size != self.cfg.adaptive_bin_size_frames: - raise ValueError( - "adaptive sampling checkpoint bin size mismatch: " - f"checkpoint={bin_size}, current={self.cfg.adaptive_bin_size_frames}" - ) - failed_count = state.get("adaptive_bin_failed_count") - current_failed = state.get("current_adaptive_bin_failed") - if not isinstance(failed_count, torch.Tensor) or not isinstance(current_failed, torch.Tensor): - raise ValueError("adaptive sampling checkpoint state is missing tensors") - if failed_count.shape != self.adaptive_bin_failed_count.shape: - raise ValueError( - "adaptive sampling checkpoint bin count mismatch: " - f"checkpoint={tuple(failed_count.shape)}, " - f"current={tuple(self.adaptive_bin_failed_count.shape)}" - ) - if current_failed.shape != self._current_adaptive_bin_failed.shape: - raise ValueError( - "adaptive sampling checkpoint current-bin 
count mismatch: " - f"checkpoint={tuple(current_failed.shape)}, " - f"current={tuple(self._current_adaptive_bin_failed.shape)}" - ) - self.adaptive_bin_failed_count.copy_(failed_count.to(self.device)) - self._current_adaptive_bin_failed.copy_(current_failed.to(self.device)) - # ------------------------------------------------------------------ # Visualization # ------------------------------------------------------------------ @@ -1289,11 +1042,8 @@ class MotionCommandCfg(CommandTermCfg): pose_range: dict[str, tuple[float, float]] = field(default_factory=dict) velocity_range: dict[str, tuple[float, float]] = field(default_factory=dict) joint_position_range: tuple[float, float] = (-0.52, 0.52) - sampling_mode: Literal["uniform", "start", "adaptive", "rewind"] = "uniform" + sampling_mode: Literal["uniform", "start", "rewind"] = "rewind" window_steps: tuple[int, ...] = (0,) - adaptive_bin_size_frames: int = 10 - adaptive_uniform_ratio: float = 0.1 - adaptive_alpha: float = 0.001 rewind_prob: float = 0.8 rewind_min_steps: int = 25 rewind_max_steps: int = 75 diff --git a/train_mimic/tasks/tracking/rl/runner.py b/train_mimic/tasks/tracking/rl/runner.py index 05ef981a..9e22c6b2 100644 --- a/train_mimic/tasks/tracking/rl/runner.py +++ b/train_mimic/tasks/tracking/rl/runner.py @@ -96,26 +96,6 @@ def __init__( def _motion_command(self) -> MotionCommand: return cast(MotionCommand, self.env.unwrapped.command_manager.get_term("motion")) - def save(self, path: str, infos=None) -> None: - motion_state = self._motion_command().get_adaptive_sampling_state() - if motion_state is not None: - infos = {**(infos or {}), "motion_adaptive_sampling_state": motion_state} - super().save(path, infos) - - def load( - self, - path: str, - load_cfg: dict | None = None, - strict: bool = True, - map_location: str | None = None, - ) -> dict: - infos = super().load(path, load_cfg=load_cfg, strict=strict, map_location=map_location) - if infos and "motion_adaptive_sampling_state" in infos: - 
self._motion_command().load_adaptive_sampling_state( - infos["motion_adaptive_sampling_state"] - ) - return infos - def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None: """Run the learning loop using 1-based iteration numbering.""" if init_at_random_ep_len: From 95ec57eca490f207f16c3194ccfc563e3911d42c Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Sun, 14 Jun 2026 20:07:54 +0800 Subject: [PATCH 083/122] Switch training datasets to HDF5 shards --- AGENTS.md | 8 +- README.md | 3 +- docs/docs/getting-started/download-assets.md | 4 +- docs/docs/reference/architecture.md | 4 +- docs/docs/reference/dataset.md | 26 +- docs/docs/tutorials/training.md | 3 +- .../getting-started/download-assets.md | 4 +- .../current/reference/architecture.md | 4 +- .../current/reference/dataset.md | 24 +- .../current/tutorials/training.md | 3 +- scripts/review/build_dataset_from_review.py | 110 +++-- scripts/review/review_dataset.py | 60 +-- tests/test_dataset_v2.py | 82 ++-- tests/test_motion_sampling.py | 86 +++- tests/test_review_pipeline.py | 82 +++- tests/test_train_script.py | 12 +- train_mimic/app.py | 5 +- train_mimic/data/dataset_builder.py | 91 +++- train_mimic/data/dataset_lib.py | 285 +++++++++-- train_mimic/data/review_lib.py | 2 +- train_mimic/scripts/benchmark.py | 41 +- train_mimic/scripts/convert_pkl_to_npz.py | 10 +- train_mimic/scripts/data/split_shards.py | 217 --------- train_mimic/scripts/play.py | 2 +- train_mimic/scripts/train.py | 11 +- train_mimic/tasks/tracking/mdp/commands.py | 458 ++++++++++-------- train_mimic/tasks/tracking/rl/runner.py | 73 +-- 27 files changed, 946 insertions(+), 764 deletions(-) delete mode 100644 train_mimic/scripts/data/split_shards.py diff --git a/AGENTS.md b/AGENTS.md index 595e1638..986ebec5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -72,7 +72,7 @@ train_mimic/ # Training package │ ├── env.py # General-Tracking-G1 env builder │ └── rl.py # TemporalCNN PPO cfg ├── tasks/tracking/rl/ -│ ├── 
runner.py # ONNX export wrapper for policy + motion labels +│ ├── runner.py # Training runner and policy ONNX export wrapper │ ├── conv1d_encoder.py # 1-D CNN encoder for temporal history groups │ └── temporal_cnn_model.py # TemporalCNN actor/critic model └── scripts/ @@ -203,9 +203,11 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g ### Dataset Pipeline - Dataset build spec supports a `preprocess` section for root-xy normalization, ground alignment, and basic clip filtering -- Final dataset outputs are shard-only: `data/datasets//train/shard_*.npz` and `data/datasets//val/shard_*.npz` -- Each shard stores clip-aware metadata (`clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`); `MotionLib` loads only shard directories +- Final training dataset outputs are HDF5 split directories: `data/datasets//train/manifest.json` + `shard_*.h5` and the same under `val/` +- Each shard stores clip-aware window metadata (`clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`); long clips are split into overlapping bounded windows +- `MotionLib` loads only a configurable HDF5 subset cache into CPU/GPU memory, stages the next cache, and swaps at the PPO rollout barrier - `MotionLib` samples only valid center frames for the configured `window_steps`; default is `window_steps=[0]` +- Training supports `uniform` and `rewind` sampling on the active cache; in distributed training each rank sets a rank-offset `cache_seed` - `scripts/run/record_pico_motion.py` records Pico live body tracking as retargeted G1 motion NPZ clips in `data/pico_motion/clips/`; it opens a live `Retarget` viewer, uses terminal keys `R/S/D/N/Q`, stores semantic labels in filenames, and intentionally does not write per-clip JSON - Build Pico-recorded clips into shards with `python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force`; at least two clips are required for non-empty train/val splits diff --git a/README.md b/README.md 
index 8745fd7e..d3960059 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ a new name, and `Q` to quit. Saved clips are written to `data/pico_motion/clips/` using the semantic label in the filename, with no sidecar JSON. -Merge recorded clips into the standard shard dataset: +Merge recorded clips into the standard HDF5 shard dataset: ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -95,6 +95,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. +- Switched training datasets to HDF5 split directories with subset caching and rollout-barrier cache swaps to reduce CPU/GPU memory use. - General-Tracking-G1 training defaults to `rewind` motion sampling and also supports `uniform`; playback/benchmark use `start`. - Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. - Added root velocity, joint tracking, and survival rewards to the General-Tracking-G1 training objective. 
diff --git a/docs/docs/getting-started/download-assets.md b/docs/docs/getting-started/download-assets.md index fc8e6182..2f561efd 100644 --- a/docs/docs/getting-started/download-assets.md +++ b/docs/docs/getting-started/download-assets.md @@ -29,8 +29,8 @@ python scripts/setup/download_assets.py --only gmr ckpt bvh |-------|------|---------| | `track.onnx` | 4 MB | ONNX inference model | | `track.pt` | 27 MB | PyTorch checkpoint (for resume training) | -| `data/datasets/seed/train/shard_*.npz` | ~25 GB | Training dataset | -| `data/datasets/seed/val/shard_*.npz` | ~1.4 GB | Validation dataset | +| `data/datasets/seed/train/manifest.json` + `shard_*.h5` | ~25 GB | Training dataset | +| `data/datasets/seed/val/manifest.json` + `shard_*.h5` | ~1.4 GB | Validation dataset | | `data/sample_bvh/*.bvh` | 5 MB | Sample motion files | | `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR retargeting robot models | diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index ce51f25a..b6a2358f 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -60,7 +60,7 @@ train_mimic/scripts/data | Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | | Training sampling | Default `rewind`; also supports `uniform`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | -| Data format | Shard directories only (`shard_*.npz`) | +| Data format | HDF5 shard directories (`manifest.json` + `shard_*.h5`) | ## Constraints @@ -76,4 +76,4 @@ train_mimic/scripts/data **Stable training entry points:** `train.py`, `play.py`, `benchmark.py`, `save_onnx.py` -**Stable data entry points:** `build_dataset.py`, `split_shards.py` +**Stable data entry points:** `build_dataset.py` diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index f2257af5..f7a058d8 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -10,7 +10,7 @@ sidebar_position: 3 python 
scripts/setup/download_assets.py --only data ``` -Then train directly with the shard directory: +Then train directly with the HDF5 shard directory: ```bash python train_mimic/scripts/train.py --motion_file data/datasets/seed/train @@ -37,7 +37,7 @@ enter a new name, and `Q` to quit. Saved clips go to `data/pico_motion/clips/` as `_.npz`; no per-clip JSON is written, so clips can be renamed or deleted manually. -Build all recorded clips into the standard shard dataset: +Build all recorded clips into the standard HDF5 shard dataset: ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -49,7 +49,7 @@ can be populated. ## Custom Dataset Construction -Data pipeline: `typed source YAML -> preprocess/filter -> shard-only training data` +Data pipeline: `typed source YAML -> preprocess/filter -> HDF5 shard-only training data` ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -63,15 +63,18 @@ data/datasets// ├── clips/ # Optional; only for per-clip intermediates │ └── /... ├── train/ -│ └── shard_*.npz +│ ├── manifest.json +│ └── shard_*.h5 ├── val/ -│ └── shard_*.npz +│ ├── manifest.json +│ └── shard_*.h5 ├── manifest_resolved.csv └── build_info.json ``` - If the spec contains `bvh` or `npz` sources, the builder retains/generates `clips/` - If the spec is all `pkl` or `seed_csv` sources, the builder takes a batch path producing split-level shards directly +- Training loads only a subset cache from the HDF5 split, stages the next cache, and swaps caches at the PPO rollout barrier. ## YAML Spec Format @@ -167,16 +170,3 @@ python train_mimic/scripts/data/check_motion_npz_fk.py \ ``` Recommended thresholds: `pos_max < 1e-3 m`, `quat_mean < 0.05 rad`, `quat_p95 < 0.10 rad`. - -## Re-shard - -Split large shards for distribution: - -```bash -python train_mimic/scripts/data/split_shards.py \ - --input data/datasets/seed/train \ - --output data/datasets/seed/train_small_shards \ - --max_size_gb 2 -``` - -Each shard is self-contained with full clip metadata. 
diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index 50aacb7e..d8e8706e 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -74,7 +74,8 @@ torchrun \ - `--num_envs` is per-GPU in multi-GPU mode - `--num_envs` is also per-process in multi-node mode, so total environments scale with `world_size` - Default logger is TensorBoard. Use `--logger wandb` or `--logger swanlab` to select W&B or SwanLab; the project name defaults to `experiment_name` -- `--motion_file` accepts only shard directories (containing `shard_*.npz` files) +- `--motion_file` accepts only HDF5 shard directories containing `manifest.json` and `shard_*.h5` files +- `--cache_num_clips` controls the active HDF5 subset size; `--cache_swap_interval_steps` controls how often the next subset is swapped in at a rollout barrier - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` ## Export ONNX diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md index 419a2222..8f278db7 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md @@ -29,8 +29,8 @@ python scripts/setup/download_assets.py --only gmr ckpt bvh |------|------|------| | `track.onnx` | 4 MB | ONNX 推理模型 | | `track.pt` | 27 MB | PyTorch 检查点(用于恢复训练) | -| `data/datasets/seed/train/shard_*.npz` | ~25 GB | 训练数据集 | -| `data/datasets/seed/val/shard_*.npz` | ~1.4 GB | 验证数据集 | +| `data/datasets/seed/train/manifest.json` + `shard_*.h5` | ~25 GB | 训练数据集 | +| `data/datasets/seed/val/manifest.json` + `shard_*.h5` | ~1.4 GB | 验证数据集 | | `data/sample_bvh/*.bvh` | 5 MB | 示例动捕文件 | | `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR 
重定向机器人模型 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index e43857c3..5e03c8f5 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -60,7 +60,7 @@ train_mimic/scripts/data | Actor/Critic | TemporalCNN(2048、1024、512、256、128) | | 训练采样 | 默认 `rewind`;也支持 `uniform`;播放/评估使用 `start` | | 训练 `window_steps` | `[0]` | -| 数据格式 | 仅 shard 目录(`shard_*.npz`) | +| 数据格式 | HDF5 shard 目录(`manifest.json` + `shard_*.h5`) | ## 约束 @@ -76,4 +76,4 @@ train_mimic/scripts/data **稳定训练入口:** `train.py`、`play.py`、`benchmark.py`、`save_onnx.py` -**稳定数据入口:** `build_dataset.py`、`split_shards.py` +**稳定数据入口:** `build_dataset.py` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index 0ab0e366..439a9e41 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -10,7 +10,7 @@ sidebar_position: 3 python scripts/setup/download_assets.py --only data ``` -下载后直接传 shard 目录用于训练: +下载后直接传 HDF5 shard 目录用于训练: ```bash python train_mimic/scripts/train.py --motion_file data/datasets/seed/train @@ -35,7 +35,7 @@ python scripts/run/record_pico_motion.py `data/pico_motion/clips/`,文件名格式为 `_.npz`;不会写 每段 clip 的 JSON,因此可以手动改名或删除。 -将所有已录制 clips 构建为标准 shard 数据集: +将所有已录制 clips 构建为标准 HDF5 shard 数据集: ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -46,7 +46,7 @@ python train_mimic/scripts/data/build_dataset.py \ ## 自定义构建 -数据主线:`typed source YAML -> preprocess/filter -> shard-only 训练数据` +数据主线:`typed source YAML -> preprocess/filter -> HDF5 shard-only 训练数据` ```bash python train_mimic/scripts/data/build_dataset.py 
\ @@ -60,15 +60,18 @@ data/datasets// ├── clips/ # 可选;仅在需要逐 clip 中间产物时存在 │ └── /... ├── train/ -│ └── shard_*.npz +│ ├── manifest.json +│ └── shard_*.h5 ├── val/ -│ └── shard_*.npz +│ ├── manifest.json +│ └── shard_*.h5 ├── manifest_resolved.csv └── build_info.json ``` - 若 spec 包含 `bvh` 或 `npz` source,builder 会保留/生成 `clips/` - 若 spec 全部是 `pkl` 或 `seed_csv` source,直接并行产出 split 级别的 shard,默认不写中间 clip 文件 +- 训练时只从 HDF5 split 加载一个 subset cache,同时预加载下一个 cache,并在 PPO rollout barrier 处切换。 ## YAML spec @@ -152,14 +155,3 @@ python train_mimic/scripts/data/check_motion_npz_fk.py \ ``` 推荐判据:`pos_max < 1e-3 m`、`quat_mean < 0.05 rad`、`quat_p95 < 0.10 rad`。 - -## 重新切分 shard - -```bash -python train_mimic/scripts/data/split_shards.py \ - --input data/datasets/seed/train \ - --output data/datasets/seed/train_small_shards \ - --max_size_gb 2 -``` - -每个 shard 是自包含的 merged NPZ(含完整 clip metadata),训练时直接传目录。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index 5e45cc1b..0e241b7c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -74,7 +74,8 @@ torchrun \ - 多卡模式下 `--num_envs` 为每张 GPU 的环境数量 - 多机模式下 `--num_envs` 也按每个进程计算,因此总环境数会随 `world_size` 线性增长 - 默认日志工具为 TensorBoard。使用 `--logger wandb` 或 `--logger swanlab` 可选择 W&B 或 SwanLab;项目名默认使用 `experiment_name` -- `--motion_file` 仅接受分片目录(包含 `shard_*.npz` 文件的目录) +- `--motion_file` 仅接受包含 `manifest.json` 和 `shard_*.h5` 文件的 HDF5 分片目录 +- `--cache_num_clips` 控制当前 HDF5 subset cache 大小;`--cache_swap_interval_steps` 控制在 rollout barrier 切换下一个 subset 的频率 - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` ## 导出 ONNX diff --git a/scripts/review/build_dataset_from_review.py b/scripts/review/build_dataset_from_review.py index 15f8d9a4..e5cc32c4 100644 --- 
a/scripts/review/build_dataset_from_review.py +++ b/scripts/review/build_dataset_from_review.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 -"""Rebuild train/val shard directories from a filtered manifest (review results). +"""Rebuild train/val HDF5 shard directories from a filtered manifest. Reads filtered_manifest.csv (output of export_reviewed_manifest.py), -verifies all NPZ files exist, and rebuilds cleaned train/val shard splits. +verifies all referenced motion files exist, and rebuilds cleaned train/val HDF5 splits. Usage: python scripts/data/build_dataset_from_review.py \ @@ -15,6 +15,7 @@ import argparse import csv import sys +from dataclasses import asdict from pathlib import Path import numpy as np @@ -23,12 +24,14 @@ sys.path.insert(0, str(PROJECT_ROOT)) from train_mimic.data.dataset_lib import ( - extract_clip_arrays, - merge_clip_dicts, - merge_npz_files, + merge_clip_dicts_payload, + read_motion_clip, utc_now_iso, + write_hdf5_manifest, + write_hdf5_motion_shard, write_json, ) +from train_mimic.data.dataset_builder import DatasetClipRow, write_manifest_resolved def main() -> None: @@ -87,7 +90,7 @@ def main() -> None: print("ERROR: filtered manifest has no data rows", file=sys.stderr) sys.exit(1) - # Verify all NPZ files exist + # Verify all referenced motion files exist missing = [] for row in rows: p = Path(row["resolved_npz_path"]) @@ -102,7 +105,7 @@ def main() -> None: missing.append(f" line {row['line_no']}: {row['clip_id']} -> {row['resolved_npz_path']}") if missing: - print(f"ERROR: {len(missing)} NPZ files not found:", file=sys.stderr) + print(f"ERROR: {len(missing)} motion files not found:", file=sys.stderr) for m in missing[:20]: print(m, file=sys.stderr) if len(missing) > 20: @@ -133,8 +136,7 @@ def main() -> None: output_dir.mkdir(parents=True, exist_ok=True) - # Check if any rows use clip_index (batch-built dataset) - has_indexed_clips = any(r["clip_index"] >= 0 for r in rows) + all_output_rows: list[DatasetClipRow] = [] def 
_merge_split(split_rows: list[dict], split_name: str) -> dict | None: if not split_rows: @@ -142,61 +144,67 @@ def _merge_split(split_rows: list[dict], split_name: str) -> dict | None: print(f"Merging {len(split_rows)} {split_name} clips...") split_dir = output_dir / split_name split_dir.mkdir(parents=True, exist_ok=True) - out = split_dir / "shard_000.npz" - - if has_indexed_clips: - # Extract individual clip slices from shard NPZ files - clip_dicts = [] - for r in split_rows: - npz_path = Path(r["resolved_npz_path"]) - ci = r["clip_index"] - if ci >= 0: - clip_dicts.append(extract_clip_arrays(npz_path, ci)) - else: - # Standalone clip — load entire file as one clip dict - d = np.load(npz_path, allow_pickle=True) - clip_dicts.append({ - "fps": int(d["fps"]), - "joint_pos": np.asarray(d["joint_pos"]), - "joint_vel": np.asarray(d["joint_vel"]), - "body_pos_w": np.asarray(d["body_pos_w"]), - "body_quat_w": np.asarray(d["body_quat_w"]), - "body_lin_vel_w": np.asarray(d["body_lin_vel_w"]), - "body_ang_vel_w": np.asarray(d["body_ang_vel_w"]), - "body_names": np.asarray(d["body_names"]), - }) - weights_list = [r["weight"] for r in split_rows] - stats = merge_clip_dicts( - clip_dicts, out, - target_fps=args.target_fps, weights=weights_list, - ) - else: - # Legacy per-clip NPZ files — use file-based merge - files = [Path(r["resolved_npz_path"]) for r in split_rows] - weights_list = [r["weight"] for r in split_rows] - stats = merge_npz_files( - files, out, - target_fps=args.target_fps, weights=weights_list, - ) - - stats["output"] = str(split_dir) - stats["shards"] = 1 + out = split_dir / "shard_000.h5" + + clip_dicts = [ + read_motion_clip(Path(r["resolved_npz_path"]), int(r["clip_index"])) + for r in split_rows + ] + weights_list = [r["weight"] for r in split_rows] + payload = merge_clip_dicts_payload( + clip_dicts, + target_fps=args.target_fps, + weights=weights_list, + ) + h5_info = write_hdf5_motion_shard(payload, out) + write_hdf5_manifest( + split_dir, + 
shard_infos=[h5_info], + fps=int(payload["fps"]), + body_names=np.asarray(payload["body_names"]), + ) + + source_lengths = np.asarray(payload["clip_lengths"], dtype=np.int64) + for clip_index, (r, num_frames) in enumerate(zip(split_rows, source_lengths)): + all_output_rows.append(DatasetClipRow( + clip_id=r["clip_id"], + source=r["source"], + file_rel=r["file_rel"], + num_frames=int(num_frames), + fps=int(payload["fps"]), + resolved_split=split_name, + resolved_npz_path=str(out), + weight=float(r["weight"]), + clip_index=clip_index, + )) + + total_frames = int(np.asarray(payload["joint_pos"]).shape[0]) + stats = { + "output": str(split_dir), + "shards": 1, + "clips": int(h5_info["clips"]), + "num_clips": int(h5_info["clips"]), + "source_clips": int(h5_info["source_clips"]), + "frames": total_frames, + "fps": int(payload["fps"]), + "duration_s": float(total_frames / max(int(payload["fps"]), 1)), + } print(f" {split_name}/: {stats['frames']} frames, {stats['duration_s'] / 60:.1f} min") return stats train_stats = _merge_split(train_rows, "train") val_stats = _merge_split(val_rows, "val") - # Copy manifest into output dir - import shutil - shutil.copy2(manifest_path, output_dir / "manifest_resolved.csv") + resolved_manifest = write_manifest_resolved(all_output_rows, output_dir) # Write build info report = { "built_at_utc": utc_now_iso(), "source_manifest": str(manifest_path), + "manifest_resolved": str(resolved_manifest), "output_dir": str(output_dir), "target_fps": args.target_fps, + "source_rows": [asdict(row) for row in all_output_rows], "clip_counts": { "total": len(rows), "train": len(train_rows), diff --git a/scripts/review/review_dataset.py b/scripts/review/review_dataset.py index f87568a9..4966db94 100644 --- a/scripts/review/review_dataset.py +++ b/scripts/review/review_dataset.py @@ -29,6 +29,7 @@ from mjlab.viewer.viser import ViserMujocoScene from teleopit.runtime.assets import UNITREE_G1_MJLAB_XML, missing_gmr_assets_message +from 
train_mimic.data.dataset_lib import read_motion_clip from train_mimic.data.review_lib import ( ReviewRow, ReviewStats, @@ -42,11 +43,11 @@ # --------------------------------------------------------------------------- -# ClipPlayer: loads NPZ clip and drives MuJoCo qpos per frame +# ClipPlayer: loads motion clips and drives MuJoCo qpos per frame # --------------------------------------------------------------------------- class ClipPlayer: - """Loads a single clip NPZ and sets MuJoCo qpos frame-by-frame.""" + """Loads a single motion clip and sets MuJoCo qpos frame-by-frame.""" def __init__(self, mj_model: mujoco.MjModel) -> None: self.model = mj_model @@ -56,47 +57,18 @@ def __init__(self, mj_model: mujoco.MjModel) -> None: self._pelvis_quat: np.ndarray | None = None # (T, 4) wxyz self._fps: int = 30 self._num_frames: int = 0 - # Cache for shard NPZ: avoid re-reading large shard files on every clip switch - self._cached_npz_path: str | None = None - self._cached_npz_data: dict[str, np.ndarray] | None = None - - def _get_npz_data(self, npz_path: Path) -> dict[str, np.ndarray]: - """Return NPZ data, using cache for shard files.""" - path_str = str(npz_path) - if self._cached_npz_path == path_str and self._cached_npz_data is not None: - return self._cached_npz_data - d = dict(np.load(path_str, allow_pickle=True)) - # Only cache shard NPZ files (those with clip_starts) - if "clip_starts" in d: - self._cached_npz_path = path_str - self._cached_npz_data = d - else: - self._cached_npz_path = None - self._cached_npz_data = None - return d - def load_clip(self, npz_path: Path, clip_index: int = -1) -> None: - """Load NPZ clip data. + def load_clip(self, motion_path: Path, clip_index: int = -1) -> None: + """Load one source clip from an HDF5 shard. Args: - npz_path: Path to NPZ file (standalone clip or shard file). - clip_index: If >= 0, extract this clip from a shard NPZ using - clip_starts/clip_lengths. If -1, load the entire file - as a single clip. 
+ motion_path: Path to an HDF5 shard. + clip_index: Source-clip index for HDF5 rows. """ - d = self._get_npz_data(npz_path) - - if clip_index >= 0 and "clip_starts" in d and "clip_lengths" in d: - start = int(d["clip_starts"][clip_index]) - length = int(d["clip_lengths"][clip_index]) - s = slice(start, start + length) - self._joint_pos = np.asarray(d["joint_pos"][s]) - body_pos_w = np.asarray(d["body_pos_w"][s]) - body_quat_w = np.asarray(d["body_quat_w"][s]) - else: - self._joint_pos = np.asarray(d["joint_pos"]) - body_pos_w = np.asarray(d["body_pos_w"]) - body_quat_w = np.asarray(d["body_quat_w"]) + d = read_motion_clip(motion_path, clip_index) + self._joint_pos = np.asarray(d["joint_pos"]) + body_pos_w = np.asarray(d["body_pos_w"]) + body_quat_w = np.asarray(d["body_quat_w"]) self._pelvis_pos = body_pos_w[:, 0, :] # pelvis = body 0 self._pelvis_quat = body_quat_w[:, 0, :] @@ -451,13 +423,13 @@ def _load_current_clip(self) -> None: self._info_html.content = "No clips to review" return - # Resolve NPZ path (use resolved_npz_path which always points to .npz) - npz_path = Path(row.resolved_npz_path) - if not npz_path.is_absolute(): - npz_path = self._project_root / npz_path + # CSV keeps the historical column name, but current rows point at HDF5 shards. + motion_path = Path(row.resolved_npz_path) + if not motion_path.is_absolute(): + motion_path = self._project_root / motion_path try: - self._player.load_clip(npz_path, clip_index=row.clip_index) + self._player.load_clip(motion_path, clip_index=row.clip_index) except Exception as exc: self._info_html.content = f"Error loading clip:
{exc}" return diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index f5eebe82..207d2fe8 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -6,8 +6,10 @@ import numpy as np import pytest +import h5py from train_mimic.data import dataset_builder +from train_mimic.data.dataset_lib import write_hdf5_motion_shard from train_mimic.data.dataset_builder import ( DatasetClipRow, SourceInputFile, @@ -520,15 +522,17 @@ def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> Non assert report["build_dir"] == str(dataset_dir) assert (dataset_dir / "clips" / "npz_src" / "clip_a.npz").is_file() assert (dataset_dir / "clips" / "npz_src" / "clip_b.npz").is_file() - assert (dataset_dir / "train" / "shard_000.npz").is_file() - assert (dataset_dir / "val" / "shard_000.npz").is_file() + assert (dataset_dir / "train" / "manifest.json").is_file() + assert (dataset_dir / "train" / "shard_000.h5").is_file() + assert (dataset_dir / "val" / "manifest.json").is_file() + assert (dataset_dir / "val" / "shard_000.h5").is_file() assert (dataset_dir / "manifest_resolved.csv").is_file() assert (dataset_dir / "build_info.json").is_file() assert report["clip_counts"]["total"] == 2 - train_data = np.load(dataset_dir / "train" / "shard_000.npz", allow_pickle=True) - assert "clip_starts" in train_data.files - assert "clip_lengths" in train_data.files + with h5py.File(dataset_dir / "train" / "shard_000.h5", "r") as train_data: + assert "clip_starts" in train_data + assert "clip_lengths" in train_data def test_collect_clip_rows_ignores_stale_excluded_cached_npz(tmp_path: Path) -> None: @@ -784,7 +788,7 @@ def _convert(path: str, **_kwargs): [str(short_path), str(valid_path)], [1.0, 1.0], 30, - str(tmp_path / "merged.npz"), + str(tmp_path / "merged.h5"), "train", preprocess=dataset_builder.DatasetPreprocessSpec( normalize_root_xy=True, @@ -793,10 +797,27 @@ def _convert(path: str, **_kwargs): ), ) - merged = np.load(tmp_path / "merged.npz", 
allow_pickle=True) assert stats["clips"] == 1 assert stats["kept_file_paths"] == [str(valid_path)] - assert merged["clip_lengths"].tolist() == [22] + with h5py.File(tmp_path / "merged.h5", "r") as merged: + assert merged["clip_lengths"][()].tolist() == [22] + + +def test_shard_stats_counts_real_frames_not_overlapped_windows(tmp_path: Path) -> None: + stats = dataset_builder._shard_stats( + output_dir=tmp_path, + shard_infos=[{ + "path": tmp_path / "shard_000.h5", + "clips": 3, + "frames": 1000, + "clip_lengths": [512, 512, 512], + "source_clip_lengths": [1000], + }], + fps=30, + ) + + assert stats["frames"] == 1000 + assert stats["duration_s"] == 1000 / 30 @@ -845,7 +866,7 @@ def _hash_split(clip_id: str, _val_percent: int, _salt: str = "") -> str: num_bodies = len(_MJLAB_G1_BODY_NAMES) - def _write_merged(path: Path, lengths: list[int]) -> None: + def _write_merged(path: Path, lengths: list[int]) -> dict: total = sum(lengths) joint_pos = np.zeros((total, 29), dtype=np.float32) joint_vel = np.zeros_like(joint_pos) @@ -858,29 +879,28 @@ def _write_merged(path: Path, lengths: list[int]) -> None: clip_starts = np.zeros(len(lengths), dtype=np.int64) if len(lengths) > 1: clip_starts[1:] = np.cumsum(clip_lengths[:-1]) - np.savez( - path, - fps=30, - joint_pos=joint_pos, - joint_vel=joint_vel, - body_pos_w=body_pos_w, - body_quat_w=body_quat_w, - body_lin_vel_w=body_lin_vel_w, - body_ang_vel_w=body_ang_vel_w, - body_names=np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), - clip_starts=clip_starts, - clip_lengths=clip_lengths, - clip_fps=np.full(len(lengths), 30, dtype=np.int64), - clip_weights=np.ones(len(lengths), dtype=np.float64), - ) + return write_hdf5_motion_shard({ + "fps": 30, + "joint_pos": joint_pos, + "joint_vel": joint_vel, + "body_pos_w": body_pos_w, + "body_quat_w": body_quat_w, + "body_lin_vel_w": body_lin_vel_w, + "body_ang_vel_w": body_ang_vel_w, + "body_names": np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), + "clip_starts": clip_starts, + "clip_lengths": 
clip_lengths, + "clip_fps": np.full(len(lengths), 30, dtype=np.int64), + "clip_weights": np.ones(len(lengths), dtype=np.float64), + }, path) def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, preprocess): _ = clips, target_fps, jobs, preprocess output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - shard_path = output_dir / "shard_000.npz" + shard_path = output_dir / "shard_000.h5" if split_name == "train": - _write_merged(shard_path, [22]) + h5_info = _write_merged(shard_path, [22]) return ({ "output": str(output_dir), "shards": 1, @@ -891,10 +911,12 @@ def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, prepro "duration_s": 22.0 / 30.0, }, [{ "path": shard_path, - "clip_lengths": [22], + "clip_lengths": h5_info["clip_lengths"], + "source_clip_lengths": h5_info["source_clip_lengths"], + "frames": h5_info["frames"], "kept_file_paths": [str(keep_train)], }]) - _write_merged(shard_path, [24]) + h5_info = _write_merged(shard_path, [24]) return ({ "output": str(output_dir), "shards": 1, @@ -905,7 +927,9 @@ def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, prepro "duration_s": 24.0 / 30.0, }, [{ "path": shard_path, - "clip_lengths": [24], + "clip_lengths": h5_info["clip_lengths"], + "source_clip_lengths": h5_info["source_clip_lengths"], + "frames": h5_info["frames"], "kept_file_paths": [str(keep_val)], }]) diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index be39fd42..42811bed 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -2,12 +2,13 @@ from pathlib import Path from types import SimpleNamespace +import json import numpy as np import pytest import torch -from train_mimic.data.dataset_lib import merge_clip_dicts +from train_mimic.data.dataset_lib import write_hdf5_manifest, write_hdf5_motion_shard from train_mimic.tasks.tracking.mdp.commands import MotionCommand, MotionLib @@ -49,7 +50,28 @@ def _write_shard_dir( 
weights: list[float] | None = None, ) -> Path: path.mkdir(parents=True, exist_ok=True) - merge_clip_dicts(clip_dicts, path / "shard_000.npz", weights=weights) + array_keys = [ + "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", + "body_lin_vel_w", "body_ang_vel_w", + ] + clip_lengths = np.asarray([np.asarray(cd["joint_pos"]).shape[0] for cd in clip_dicts], dtype=np.int64) + clip_starts = np.zeros(len(clip_lengths), dtype=np.int64) + if len(clip_lengths) > 1: + clip_starts[1:] = np.cumsum(clip_lengths[:-1]) + merged = {key: np.concatenate([np.asarray(cd[key]) for cd in clip_dicts], axis=0) for key in array_keys} + merged["fps"] = int(clip_dicts[0]["fps"]) + merged["body_names"] = np.asarray(clip_dicts[0]["body_names"]) + merged["clip_starts"] = clip_starts + merged["clip_lengths"] = clip_lengths + merged["clip_fps"] = np.full(len(clip_dicts), int(clip_dicts[0]["fps"]), dtype=np.int64) + merged["clip_weights"] = np.asarray(weights if weights is not None else [1.0] * len(clip_dicts), dtype=np.float64) + shard_info = write_hdf5_motion_shard(merged, path / "shard_000.h5") + write_hdf5_manifest( + path, + shard_infos=[shard_info], + fps=int(clip_dicts[0]["fps"]), + body_names=np.asarray(clip_dicts[0]["body_names"]), + ) return path @@ -152,6 +174,66 @@ def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_pat assert torch.allclose(motion.clip_sample_end_s[motion_ids], torch.tensor([3.0])) +def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: + motion_path = tmp_path / "motion_mismatch" + clip = _clip_dict() + shard0 = _write_shard_dir(motion_path, [clip]) + + clip_bad = _clip_dict() + clip_bad["body_names"] = np.asarray( + ["pelvis", "right_ankle_roll_link", "left_ankle_roll_link"], + dtype=str, + ) + array_keys = [ + "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", + "body_lin_vel_w", "body_ang_vel_w", + ] + merged = {key: np.asarray(clip_bad[key]) for key in array_keys} + merged["fps"] = int(clip_bad["fps"]) + 
merged["body_names"] = np.asarray(clip_bad["body_names"]) + merged["clip_starts"] = np.asarray([0], dtype=np.int64) + merged["clip_lengths"] = np.asarray([np.asarray(clip_bad["joint_pos"]).shape[0]], dtype=np.int64) + merged["clip_fps"] = np.asarray([int(clip_bad["fps"])], dtype=np.int64) + merged["clip_weights"] = np.asarray([1.0], dtype=np.float64) + bad_info = write_hdf5_motion_shard(merged, motion_path / "shard_001.h5") + + (motion_path / "manifest.json").write_text( + json.dumps({ + "format": "teleopit_motion_hdf5", + "version": 1, + "fps": 1, + "body_names": np.asarray(clip["body_names"]).tolist(), + "shards": [ + {"path": "shard_000.h5", "clips": 1, "frames": 6}, + {"path": "shard_001.h5", "clips": 1, "frames": int(bad_info["frames"])}, + ], + "clips": 2, + "frames": 12, + }), + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="body_names mismatch"): + MotionLib( + str(shard0), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + +def test_write_hdf5_manifest_accepts_relative_shard_paths(tmp_path: Path) -> None: + motion_path = _write_shard_dir(tmp_path / "motion_relative_manifest", [_clip_dict()]) + manifest_path = write_hdf5_manifest( + motion_path, + shard_infos=[{"path": "shard_000.h5", "clips": 1, "frames": 6}], + fps=1, + body_names=np.asarray(["pelvis", "left_ankle_roll_link", "right_ankle_roll_link"]), + ) + + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + assert payload["shards"][0]["path"] == "shard_000.h5" + + class _FakeMotion: def __init__(self) -> None: self.clip_sample_start_s = torch.tensor([0.0, 1.0, 2.0]) diff --git a/tests/test_review_pipeline.py b/tests/test_review_pipeline.py index a2e4e5b8..a452970e 100644 --- a/tests/test_review_pipeline.py +++ b/tests/test_review_pipeline.py @@ -11,10 +11,16 @@ sys.path.insert(0, _PROJECT_ROOT) import numpy as np +import h5py from scripts.review import build_dataset_from_review from scripts.review import export_reviewed_manifest from 
scripts.review import init_review_manifest +from train_mimic.data.dataset_lib import ( + merge_clip_dicts_payload, + write_hdf5_manifest, + write_hdf5_motion_shard, +) from train_mimic.data.review_lib import ReviewRow, load_review_state, save_review_state @@ -88,6 +94,42 @@ def _write_npz(path: Path, *, num_frames: int, fps: int) -> None: ) +def _clip_dict(*, num_frames: int, fps: int) -> dict[str, object]: + joint_pos = np.linspace(0.0, 0.2, num_frames * 29, dtype=np.float32).reshape(num_frames, 29) + joint_vel = np.gradient(joint_pos, axis=0).astype(np.float32) + body_pos_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) + body_pos_w[:, 0, 2] = np.linspace(0.75, 0.8, num_frames, dtype=np.float32) + body_pos_w[:, 1, 2] = body_pos_w[:, 0, 2] + 0.3 + body_quat_w = np.zeros((num_frames, len(BODY_NAMES), 4), dtype=np.float32) + body_quat_w[..., 0] = 1.0 + body_lin_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) + body_ang_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) + return { + "fps": fps, + "joint_pos": joint_pos, + "joint_vel": joint_vel, + "body_pos_w": body_pos_w, + "body_quat_w": body_quat_w, + "body_lin_vel_w": body_lin_vel_w, + "body_ang_vel_w": body_ang_vel_w, + "body_names": BODY_NAMES, + } + + +def _write_h5_split(path: Path, clip: dict[str, object]) -> Path: + path.mkdir(parents=True, exist_ok=True) + payload = merge_clip_dicts_payload([clip]) + shard_path = path / "shard_000.h5" + h5_info = write_hdf5_motion_shard(payload, shard_path) + write_hdf5_manifest( + path, + shard_infos=[h5_info], + fps=int(payload["fps"]), + body_names=np.asarray(payload["body_names"]), + ) + return shard_path + + def test_init_review_manifest_preserves_weight(tmp_path: Path, monkeypatch) -> None: manifest_path = tmp_path / "manifest_resolved.csv" review_path = tmp_path / "review_state.csv" @@ -179,12 +221,8 @@ def test_build_dataset_from_review_resamples_mixed_fps_and_preserves_weights( tmp_path: Path, monkeypatch, ) 
-> None: - cache_dir = tmp_path / "cache" - cache_dir.mkdir() - train_npz = cache_dir / "clip_train.npz" - val_npz = cache_dir / "clip_val.npz" - _write_npz(train_npz, num_frames=4, fps=24) - _write_npz(val_npz, num_frames=5, fps=30) + train_h5 = _write_h5_split(tmp_path / "source_train", _clip_dict(num_frames=4, fps=24)) + val_h5 = _write_h5_split(tmp_path / "source_val", _clip_dict(num_frames=5, fps=30)) filtered_manifest = tmp_path / "filtered_manifest.csv" with filtered_manifest.open("w", encoding="utf-8", newline="") as f: @@ -206,26 +244,26 @@ def test_build_dataset_from_review_resamples_mixed_fps_and_preserves_weights( [ "src:clip_train", "src", - str(train_npz), + str(train_h5), 4, 24, "train", - str(train_npz), + str(train_h5), 2.5, - -1, + 0, ] ) writer.writerow( [ "src:clip_val", "src", - str(val_npz), + str(val_h5), 5, 30, "val", - str(val_npz), + str(val_h5), 0.75, - -1, + 0, ] ) @@ -246,12 +284,20 @@ def test_build_dataset_from_review_resamples_mixed_fps_and_preserves_weights( build_dataset_from_review.main() - train = np.load(output_dir / "train" / "shard_000.npz", allow_pickle=True) - val = np.load(output_dir / "val" / "shard_000.npz", allow_pickle=True) - assert int(train["fps"]) == 30 - assert int(val["fps"]) == 30 - assert train["clip_weights"].tolist() == [2.5] - assert val["clip_weights"].tolist() == [0.75] + assert (output_dir / "train" / "manifest.json").is_file() + assert (output_dir / "val" / "manifest.json").is_file() + with h5py.File(output_dir / "train" / "shard_000.h5", "r") as train: + assert int(train.attrs["fps"]) == 30 + assert train["clip_weights"][()].tolist() == [2.5] + assert train["source_clip_lengths"][()].tolist() == [5] + with h5py.File(output_dir / "val" / "shard_000.h5", "r") as val: + assert int(val.attrs["fps"]) == 30 + assert val["clip_weights"][()].tolist() == [0.75] + + with (output_dir / "manifest_resolved.csv").open("r", encoding="utf-8", newline="") as f: + rows = list(csv.DictReader(f)) + assert 
rows[0]["resolved_npz_path"].endswith("train/shard_000.h5") + assert rows[0]["clip_index"] == "0" def test_init_review_manifest_preserves_weight_metadata(tmp_path: Path, monkeypatch) -> None: diff --git a/tests/test_train_script.py b/tests/test_train_script.py index f92aa81c..8d1f8e96 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -33,6 +33,12 @@ def _args(**overrides: object) -> argparse.Namespace: "experiment_name": None, "motion_file": "data/datasets/twist2/train", "resume": None, + "sampling_mode": None, + "rewind_prob": None, + "rewind_min_steps": None, + "rewind_max_steps": None, + "cache_num_clips": None, + "cache_swap_interval_steps": None, "device": None, "gpu_ids": None, "master_port": 29500, @@ -236,7 +242,11 @@ def test_tracking_runner_configs_disable_model_upload() -> None: def test_validate_motion_file_accepts_shard_directories(tmp_path: Path) -> None: - (tmp_path / "shard_000.npz").write_bytes(b"placeholder") + (tmp_path / "manifest.json").write_text( + '{"format":"teleopit_motion_hdf5","version":1,"shards":[{"path":"shard_000.h5"}]}', + encoding="utf-8", + ) + (tmp_path / "shard_000.h5").write_bytes(b"placeholder") validate_motion_file(str(tmp_path)) diff --git a/train_mimic/app.py b/train_mimic/app.py index e261086e..5d619f53 100644 --- a/train_mimic/app.py +++ b/train_mimic/app.py @@ -17,11 +17,12 @@ def validate_motion_file(motion_file: str) -> None: p = Path(motion_file) - if p.is_dir() and any(p.glob("*.npz")): + manifest = p / "manifest.json" + if p.is_dir() and manifest.is_file() and any(p.glob("*.h5")): return raise FileNotFoundError( f"Motion shard directory not found: {motion_file}. Provide --motion_file " - f"pointing to a directory of shard NPZ files. " + f"pointing to an HDF5 split directory with manifest.json and shard_*.h5 files. 
" f"Example: {DEFAULT_TRAIN_MOTION_FILE}" ) diff --git a/train_mimic/data/dataset_builder.py b/train_mimic/data/dataset_builder.py index eb7a504c..2e20cdbb 100644 --- a/train_mimic/data/dataset_builder.py +++ b/train_mimic/data/dataset_builder.py @@ -19,12 +19,17 @@ import numpy as np from train_mimic.data.dataset_lib import ( + DEFAULT_HDF5_MAX_WINDOW_FRAMES, + DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, hash_split, inspect_clip_dict, inspect_npz, merge_npz_files, resample_along_time, + read_hdf5_body_names, utc_now_iso, + write_hdf5_manifest, + write_hdf5_motion_shard, write_json, ) from train_mimic.data.preprocess import ( @@ -113,7 +118,7 @@ class DatasetClipRow: resolved_split: str resolved_npz_path: str weight: float = 1.0 - clip_index: int = -1 # index into shard NPZ clip_starts/clip_lengths; -1 = standalone clip + clip_index: int = -1 # index into source clip metadata; -1 = standalone clip @dataclass(frozen=True) @@ -418,7 +423,7 @@ def split_output_dir(dataset_dir: Path, split: str) -> Path: def shard_output_path(split_dir: Path, shard_index: int) -> Path: - return split_dir / f"shard_{shard_index:03d}.npz" + return split_dir / f"shard_{shard_index:03d}.h5" def resolve_source_input_path(source: DatasetSourceSpec) -> Path: @@ -1231,7 +1236,7 @@ def _batch_convert_chunk( label: str, preprocess: DatasetPreprocessSpec, ) -> dict[str, Any]: - """Worker: convert a batch of PKL/seed_csv files and write one merged chunk NPZ. + """Worker: convert a batch of PKL/seed_csv files and write one HDF5 shard. Designed to run in a spawned subprocess via ProcessPoolExecutor. 
""" @@ -1309,6 +1314,7 @@ def _batch_convert_chunk( "fps": target_fps, "duration_s": 0.0, "clip_lengths": [], + "source_clip_lengths": [], "kept_file_paths": [], } @@ -1326,8 +1332,12 @@ def _batch_convert_chunk( merged["clip_fps"] = np.full(kept, target_fps, dtype=np.int64) merged["clip_weights"] = np.array(clip_weights, dtype=np.float64) - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - np.savez(output_path, **merged) + shard_info = write_hdf5_motion_shard( + merged, + Path(output_path), + max_window_frames=DEFAULT_HDF5_MAX_WINDOW_FRAMES, + overlap_frames=DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, + ) total_frames = int(merged["joint_pos"].shape[0]) print( @@ -1335,14 +1345,17 @@ def _batch_convert_chunk( f"{filtered} filtered, {total_frames} frames -> {Path(output_path).name}", flush=True, ) + hdf5_windows = int(shard_info["clips"]) return { "output": output_path, - "clips": kept, - "num_clips": kept, + "clips": hdf5_windows, + "num_clips": hdf5_windows, + "source_clips": kept, "frames": total_frames, "fps": target_fps, "duration_s": total_frames / max(target_fps, 1), - "clip_lengths": clip_lengths, + "clip_lengths": list(shard_info["clip_lengths"]), + "source_clip_lengths": clip_lengths, "kept_file_paths": kept_file_paths, } @@ -1354,7 +1367,12 @@ def _shard_stats( fps: int, ) -> dict[str, Any]: total_clips = sum(len(info["clip_lengths"]) for info in shard_infos) - total_frames = sum(sum(int(length) for length in info["clip_lengths"]) for info in shard_infos) + total_frames = sum(int(info.get("frames", 0)) for info in shard_infos) + if total_frames <= 0: + total_frames = sum( + sum(int(length) for length in info.get("source_clip_lengths", info["clip_lengths"])) + for info in shard_infos + ) return { "output": str(output_dir), "shards": len(shard_infos), @@ -1392,9 +1410,14 @@ def _batch_convert_split( raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") shard_infos = [{ "path": shard_path, + "clips": 
int(stats.get("clips", 0)), + "frames": int(stats.get("frames", 0)), "clip_lengths": list(stats.pop("clip_lengths", [])), + "source_clip_lengths": list(stats.pop("source_clip_lengths", [])), "kept_file_paths": list(stats.pop("kept_file_paths", [])), }] + if body_names := read_hdf5_body_names(shard_path): + write_hdf5_manifest(output_dir, shard_infos=shard_infos, fps=target_fps, body_names=body_names) return _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps), shard_infos # Split into chunks, one per worker @@ -1405,7 +1428,7 @@ def _batch_convert_split( end = min(start + chunk_size, len(clips)) if start >= len(clips): break - chunk_out = str(output_dir / f".{split_name}_chunk_{i}.npz") + chunk_out = str(output_dir / f".{split_name}_chunk_{i}.h5") chunk_args.append(( file_paths[start:end], weights[start:end], @@ -1446,9 +1469,14 @@ def _batch_convert_split( raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") shard_infos = [{ "path": shard_path, + "clips": int(stats.get("clips", 0)), + "frames": int(stats.get("frames", 0)), "clip_lengths": list(stats.pop("clip_lengths", [])), + "source_clip_lengths": list(stats.pop("source_clip_lengths", [])), "kept_file_paths": list(stats.pop("kept_file_paths", [])), }] + if body_names := read_hdf5_body_names(shard_path): + write_hdf5_manifest(output_dir, shard_infos=shard_infos, fps=target_fps, body_names=body_names) return _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps), shard_infos shard_infos: list[dict[str, Any]] = [] @@ -1464,7 +1492,10 @@ def _batch_convert_split( tmp_path.replace(final_path) shard_infos.append({ "path": final_path, + "clips": int(chunk_stat.get("clips", 0)), + "frames": int(chunk_stat.get("frames", 0)), "clip_lengths": list(chunk_stat.get("clip_lengths", [])), + "source_clip_lengths": list(chunk_stat.get("source_clip_lengths", [])), "kept_file_paths": list(chunk_stat.get("kept_file_paths", [])), }) @@ -1472,6 +1503,8 @@ 
def _batch_convert_split( raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") stats = _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps) + first_body_names = read_hdf5_body_names(Path(shard_infos[0]["path"])) + write_hdf5_manifest(output_dir, shard_infos=shard_infos, fps=target_fps, body_names=first_body_names) print( f"[SHARDS] {split_name}: {stats['shards']} shards, " f"{stats['clips']} clips, {stats['frames']} frames ({stats['duration_s']:.1f}s)", @@ -1571,7 +1604,7 @@ def _build_rows_for_shards( for shard in shard_infos: shard_path = Path(shard["path"]) kept_paths = list(shard["kept_file_paths"]) - clip_lengths = list(shard["clip_lengths"]) + clip_lengths = list(shard.get("source_clip_lengths", shard["clip_lengths"])) if len(kept_paths) != len(clip_lengths): raise ValueError( f"kept path count mismatch for {shard_path}: {len(kept_paths)} vs {len(clip_lengths)}" @@ -1680,27 +1713,42 @@ def build_dataset_from_spec( val_dir = split_output_dir(paths.dataset_dir, "val") train_out = shard_output_path(train_dir, 0) val_out = shard_output_path(val_dir, 0) + train_tmp_npz = train_dir / ".merged_train_tmp.npz" + val_tmp_npz = val_dir / ".merged_val_tmp.npz" train_stats = merge_npz_files( train_files, - train_out, + train_tmp_npz, target_fps=spec.target_fps, weights=train_weights, ) val_stats = merge_npz_files( val_files, - val_out, + val_tmp_npz, target_fps=spec.target_fps, weights=val_weights, ) + train_npz = np.load(train_tmp_npz, allow_pickle=True) + train_payload = {key: train_npz[key] for key in train_npz.files} + train_h5_info = write_hdf5_motion_shard(train_payload, train_out) + val_npz = np.load(val_tmp_npz, allow_pickle=True) + val_payload = {key: val_npz[key] for key in val_npz.files} + val_h5_info = write_hdf5_motion_shard(val_payload, val_out) + train_tmp_npz.unlink(missing_ok=True) + val_tmp_npz.unlink(missing_ok=True) + train_stats["output"] = str(train_dir) train_stats["shards"] = 1 + 
train_stats["clips"] = int(train_h5_info["clips"]) + train_stats["num_clips"] = int(train_h5_info["clips"]) val_stats["output"] = str(val_dir) val_stats["shards"] = 1 + val_stats["clips"] = int(val_h5_info["clips"]) + val_stats["num_clips"] = int(val_h5_info["clips"]) - train_clip_lengths = np.load(train_out, allow_pickle=True)["clip_lengths"] - val_clip_lengths = np.load(val_out, allow_pickle=True)["clip_lengths"] + train_clip_lengths = np.asarray(train_payload["clip_lengths"]) + val_clip_lengths = np.asarray(val_payload["clip_lengths"]) if len(train_rows) != len(train_clip_lengths): raise ValueError( f"train row count mismatch: {len(train_rows)} vs {len(train_clip_lengths)}" @@ -1725,6 +1773,19 @@ def build_dataset_from_spec( clip_index=clip_index, ) ) + + write_hdf5_manifest( + train_dir, + shard_infos=[train_h5_info], + fps=spec.target_fps, + body_names=np.asarray(train_payload["body_names"]), + ) + write_hdf5_manifest( + val_dir, + shard_infos=[val_h5_info], + fps=spec.target_fps, + body_names=np.asarray(val_payload["body_names"]), + ) for clip_index, (row, num_frames) in enumerate(zip(val_rows, val_clip_lengths)): updated_rows.append( DatasetClipRow( diff --git a/train_mimic/data/dataset_lib.py b/train_mimic/data/dataset_lib.py index 71bfaa12..8c11aff8 100644 --- a/train_mimic/data/dataset_lib.py +++ b/train_mimic/data/dataset_lib.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Shared utilities for the active NPZ-based dataset pipeline.""" +"""Shared utilities for motion dataset build and runtime loading.""" from __future__ import annotations @@ -10,6 +10,7 @@ from pathlib import Path from typing import Any, Mapping, Sequence +import h5py import numpy as np REQUIRED_NPZ_KEYS = [ @@ -24,6 +25,13 @@ ] NUM_ACTIONS = 29 +MOTION_ARRAY_KEYS = [ + "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", + "body_lin_vel_w", "body_ang_vel_w", +] +HDF5_DATASET_VERSION = 1 +DEFAULT_HDF5_MAX_WINDOW_FRAMES = 512 +DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES = 64 @dataclass(frozen=True) 
@@ -315,43 +323,40 @@ def merge_npz_files( } -def extract_clip_arrays(npz_path: Path, clip_index: int) -> dict[str, Any]: - """Extract a single clip's arrays from a merged NPZ by clip index. +def merge_clip_dicts( + clip_dicts: list[dict[str, Any]], + output_path: Path, + *, + target_fps: int | None = None, + weights: list[float] | None = None, +) -> dict[str, Any]: + """Merge a list of in-memory clip array dicts into a single NPZ.""" + merged = merge_clip_dicts_payload( + clip_dicts, + target_fps=target_fps, + weights=weights, + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + np.savez(output_path, **merged) - Returns a dict with the same keys as a standalone clip NPZ: - fps, joint_pos, joint_vel, body_pos_w, body_quat_w, - body_lin_vel_w, body_ang_vel_w, body_names. - """ - d = np.load(npz_path, allow_pickle=True) - clip_starts = d["clip_starts"] - clip_lengths = d["clip_lengths"] - if clip_index < 0 or clip_index >= len(clip_starts): - raise IndexError( - f"clip_index {clip_index} out of range [0, {len(clip_starts)}) in {npz_path}" - ) - start = int(clip_starts[clip_index]) - length = int(clip_lengths[clip_index]) - s = slice(start, start + length) + total_frames = int(merged["joint_pos"].shape[0]) return { - "fps": int(d["fps"]), - "joint_pos": np.asarray(d["joint_pos"][s]), - "joint_vel": np.asarray(d["joint_vel"][s]), - "body_pos_w": np.asarray(d["body_pos_w"][s]), - "body_quat_w": np.asarray(d["body_quat_w"][s]), - "body_lin_vel_w": np.asarray(d["body_lin_vel_w"][s]), - "body_ang_vel_w": np.asarray(d["body_ang_vel_w"][s]), - "body_names": np.asarray(d["body_names"]), + "output": str(output_path), + "clips": len(clip_dicts), + "num_clips": len(clip_dicts), + "frames": total_frames, + "fps": int(merged["fps"]), + "duration_s": float(total_frames / max(int(merged["fps"]), 1)), } -def merge_clip_dicts( +def merge_clip_dicts_payload( clip_dicts: list[dict[str, Any]], - output_path: Path, *, target_fps: int | None = None, weights: list[float] | None = 
None, ) -> dict[str, Any]: - """Merge a list of in-memory clip array dicts into a single NPZ. + """Merge in-memory clip array dicts and return a flat motion payload. Each dict must have keys: fps, joint_pos, joint_vel, body_pos_w, body_quat_w, body_lin_vel_w, body_ang_vel_w, body_names. @@ -415,19 +420,227 @@ def merge_clip_dicts( merged["clip_lengths"] = clip_lengths merged["clip_fps"] = np.array(per_clip_fps, dtype=np.int64) merged["clip_weights"] = np.array(per_clip_weights, dtype=np.float64) + return merged - output_path.parent.mkdir(parents=True, exist_ok=True) - np.savez(output_path, **merged) - total_frames = int(merged["joint_pos"].shape[0]) +def _window_clip_ranges( + *, + clip_start: int, + clip_length: int, + max_window_frames: int, + overlap_frames: int, +) -> list[tuple[int, int]]: + if clip_length <= 0: + raise ValueError(f"clip_length must be > 0, got {clip_length}") + if max_window_frames <= 1: + raise ValueError(f"max_window_frames must be > 1, got {max_window_frames}") + if overlap_frames < 0 or overlap_frames >= max_window_frames: + raise ValueError( + "overlap_frames must be in [0, max_window_frames), got " + f"{overlap_frames} for max_window_frames={max_window_frames}" + ) + if clip_length <= max_window_frames: + return [(clip_start, clip_length)] + + stride = max_window_frames - overlap_frames + starts = list(range(0, max(clip_length - max_window_frames + 1, 1), stride)) + tail_start = clip_length - max_window_frames + if starts[-1] != tail_start: + starts.append(tail_start) + return [(clip_start + int(start), max_window_frames) for start in starts] + + +def write_hdf5_motion_shard( + merged: Mapping[str, Any], + output_path: Path, + *, + max_window_frames: int = DEFAULT_HDF5_MAX_WINDOW_FRAMES, + overlap_frames: int = DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, +) -> dict[str, Any]: + """Write a merged motion payload as one HDF5 shard with bounded windows. + + The frame arrays remain flat in the HDF5 file. 
``clip_starts`` and + ``clip_lengths`` describe training windows, not necessarily original clips. + Long clips are split into overlapping windows to bound runtime cache size. + """ + missing = [key for key in [*MOTION_ARRAY_KEYS, "fps", "body_names", "clip_starts", "clip_lengths", "clip_fps", "clip_weights"] if key not in merged] + if missing: + raise ValueError(f"merged payload missing required keys: {missing}") + + fps = int(merged["fps"]) + if fps <= 0: + raise ValueError(f"fps must be > 0, got {fps}") + body_names = np.asarray(merged["body_names"]).astype(str) + original_starts = np.asarray(merged["clip_starts"], dtype=np.int64) + original_lengths = np.asarray(merged["clip_lengths"], dtype=np.int64) + original_fps = np.asarray(merged["clip_fps"], dtype=np.int64) + original_weights = np.asarray(merged["clip_weights"], dtype=np.float64) + + window_starts: list[int] = [] + window_lengths: list[int] = [] + window_fps: list[int] = [] + window_weights: list[float] = [] + source_clip_ids: list[int] = [] + source_start_frames: list[int] = [] + for source_idx, (clip_start, clip_length) in enumerate(zip(original_starts, original_lengths)): + ranges = _window_clip_ranges( + clip_start=int(clip_start), + clip_length=int(clip_length), + max_window_frames=max_window_frames, + overlap_frames=overlap_frames, + ) + per_window_weight = float(original_weights[source_idx]) / float(len(ranges)) + for start, length in ranges: + window_starts.append(int(start)) + window_lengths.append(int(length)) + window_fps.append(int(original_fps[source_idx])) + window_weights.append(per_window_weight) + source_clip_ids.append(int(source_idx)) + source_start_frames.append(int(start - int(clip_start))) + + output_path.parent.mkdir(parents=True, exist_ok=True) + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(output_path, "w") as h5: + h5.attrs["format"] = "teleopit_motion_hdf5" + h5.attrs["version"] = HDF5_DATASET_VERSION + h5.attrs["fps"] = fps + h5.attrs["max_window_frames"] = 
int(max_window_frames) + h5.attrs["overlap_frames"] = int(overlap_frames) + h5.create_dataset("body_names", data=body_names.astype(object), dtype=str_dt) + for key in MOTION_ARRAY_KEYS: + arr = np.asarray(merged[key], dtype=np.float32) + h5.create_dataset(key, data=arr, chunks=True) + h5.create_dataset("clip_starts", data=np.asarray(window_starts, dtype=np.int64)) + h5.create_dataset("clip_lengths", data=np.asarray(window_lengths, dtype=np.int64)) + h5.create_dataset("clip_fps", data=np.asarray(window_fps, dtype=np.int64)) + h5.create_dataset("clip_weights", data=np.asarray(window_weights, dtype=np.float64)) + h5.create_dataset("source_clip_ids", data=np.asarray(source_clip_ids, dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.asarray(source_start_frames, dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=original_starts.astype(np.int64)) + h5.create_dataset("source_clip_lengths", data=original_lengths.astype(np.int64)) + h5.create_dataset("source_clip_fps", data=original_fps.astype(np.int64)) + h5.create_dataset("source_clip_weights", data=original_weights.astype(np.float64)) + + total_frames = int(np.asarray(merged["joint_pos"]).shape[0]) return { - "output": str(output_path), - "clips": len(clip_dicts), - "num_clips": len(clip_dicts), + "path": str(output_path), + "clips": len(window_lengths), + "num_clips": len(window_lengths), + "source_clips": len(original_lengths), + "frames": total_frames, + "fps": fps, + "duration_s": float(total_frames / max(fps, 1)), + "clip_lengths": [int(v) for v in window_lengths], + "source_clip_lengths": [int(v) for v in original_lengths], + } + + +def read_hdf5_body_names(path: Path) -> list[str]: + with h5py.File(path, "r") as h5: + return [ + str(name.decode("utf-8") if isinstance(name, bytes) else name) + for name in h5["body_names"][()] + ] + + +def read_motion_clip(path: Path, clip_index: int) -> dict[str, Any]: + """Read one source clip from a current HDF5 motion shard path. 
+ + HDF5 shards use source-clip metadata, so ``clip_index`` indexes original + clips, not bounded training windows. + """ + if path.suffix == ".h5": + return read_hdf5_source_clip(path, clip_index) + raise ValueError( + f"review/rebuild input must be a current HDF5 shard (.h5), got: {path}" + ) + + +def read_hdf5_source_clip(path: Path, clip_index: int) -> dict[str, Any]: + if clip_index < 0: + raise ValueError(f"HDF5 shard rows require clip_index >= 0: {path}") + with h5py.File(path, "r") as h5: + required = [ + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", + "body_names", + *MOTION_ARRAY_KEYS, + ] + missing = [key for key in required if key not in h5] + if missing: + raise ValueError( + f"HDF5 shard {path} is missing source clip metadata {missing}. " + "Rebuild the dataset with the current HDF5 writer." + ) + starts = np.asarray(h5["source_clip_starts"], dtype=np.int64) + lengths = np.asarray(h5["source_clip_lengths"], dtype=np.int64) + fps = np.asarray(h5["source_clip_fps"], dtype=np.int64) + if clip_index >= len(starts): + raise IndexError( + f"clip_index {clip_index} out of range [0, {len(starts)}) in {path}" + ) + start = int(starts[clip_index]) + length = int(lengths[clip_index]) + sl = slice(start, start + length) + return { + "fps": int(fps[clip_index]), + "joint_pos": np.asarray(h5["joint_pos"][sl], dtype=np.float32), + "joint_vel": np.asarray(h5["joint_vel"][sl], dtype=np.float32), + "body_pos_w": np.asarray(h5["body_pos_w"][sl], dtype=np.float32), + "body_quat_w": np.asarray(h5["body_quat_w"][sl], dtype=np.float32), + "body_lin_vel_w": np.asarray(h5["body_lin_vel_w"][sl], dtype=np.float32), + "body_ang_vel_w": np.asarray(h5["body_ang_vel_w"][sl], dtype=np.float32), + "body_names": np.asarray(read_hdf5_body_names(path), dtype=str), + } + + +def write_hdf5_manifest( + split_dir: Path, + *, + shard_infos: Sequence[Mapping[str, Any]], + fps: int, + body_names: Sequence[str] | np.ndarray, +) -> Path: + shards = [] + total_windows = 0 + 
total_frames = 0 + expected_body_names = [str(name) for name in np.asarray(body_names).tolist()] + for info in shard_infos: + path = Path(str(info["path"])) + shard_path = path if path.is_absolute() else split_dir / path + if shard_path.is_file(): + actual_body_names = read_hdf5_body_names(shard_path) + if actual_body_names != expected_body_names: + raise ValueError( + f"HDF5 shard body_names mismatch for {shard_path}: " + "all shards in a split must use the same body order" + ) + if path.is_absolute(): + rel_path = path.name if path.parent == split_dir else str(path.relative_to(split_dir)) + else: + rel_path = str(path) + clips = int(info.get("clips", info.get("num_clips", 0))) + frames = int(info.get("frames", 0)) + total_windows += clips + total_frames += frames + shards.append({ + "path": rel_path, + "clips": clips, + "frames": frames, + }) + manifest = { + "format": "teleopit_motion_hdf5", + "version": HDF5_DATASET_VERSION, + "fps": int(fps), + "body_names": expected_body_names, + "shards": shards, + "clips": total_windows, "frames": total_frames, - "fps": int(merged["fps"]), - "duration_s": float(total_frames / max(int(merged["fps"]), 1)), } + path = split_dir / "manifest.json" + write_json(path, manifest) + return path def write_json(path: Path, payload: dict[str, Any]) -> None: diff --git a/train_mimic/data/review_lib.py b/train_mimic/data/review_lib.py index 72880737..e77a3700 100644 --- a/train_mimic/data/review_lib.py +++ b/train_mimic/data/review_lib.py @@ -43,7 +43,7 @@ class ReviewRow: fps: int duration_s: float weight: float = 1.0 - clip_index: int = -1 # index into merged NPZ clip_starts/clip_lengths; -1 = standalone clip + clip_index: int = -1 # source-clip index inside HDF5 shard; -1 = standalone source clip decision: str = "" difficulty: str = "" issue_tags: str = "" diff --git a/train_mimic/scripts/benchmark.py b/train_mimic/scripts/benchmark.py index d826dd8a..557df54b 100644 --- a/train_mimic/scripts/benchmark.py +++ 
b/train_mimic/scripts/benchmark.py @@ -27,6 +27,7 @@ import os from pathlib import Path +import h5py import numpy as np from tensordict import TensorDictBase @@ -142,7 +143,7 @@ def _stats(values: list[float]) -> dict[str, float]: def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Benchmark G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True) - parser.add_argument("--motion_file", type=str, required=True, help="Path to motion shard directory") + parser.add_argument("--motion_file", type=str, required=True, help="Path to HDF5 motion shard directory") parser.add_argument("--num_envs", type=int, default=1) parser.add_argument("--num_eval_steps", type=int, default=2000, help="Number of rollout steps for evaluation (default: 2000)") @@ -173,24 +174,34 @@ def parse_args() -> argparse.Namespace: def _load_motion_dir_video_metadata(motion_dir: str) -> tuple[float, int]: shard_dir = Path(motion_dir) - shard_files = sorted(shard_dir.glob("*.npz")) - if not shard_files: - raise FileNotFoundError(f"no shard NPZ files found in {motion_dir}") + manifest_path = shard_dir / "manifest.json" + if not manifest_path.is_file(): + raise FileNotFoundError(f"no HDF5 manifest.json found in {motion_dir}") + with manifest_path.open("r", encoding="utf-8") as f: + manifest = json.load(f) + if manifest.get("format") != "teleopit_motion_hdf5": + raise ValueError(f"unsupported motion manifest format in {manifest_path}") clip_fps: float | None = None max_clip_frames = 0 - for shard_path in shard_files: - motion_data = np.load(shard_path, allow_pickle=True) - cur_fps = float(motion_data["fps"]) - if clip_fps is None: - clip_fps = cur_fps - elif clip_fps != cur_fps: - raise ValueError( - f"inconsistent fps across shards: {shard_path} has {cur_fps}, expected {clip_fps}" - ) - max_clip_frames = max(max_clip_frames, int(np.asarray(motion_data["clip_lengths"]).max())) + for shard in manifest.get("shards", []): + shard_path = shard_dir / 
str(shard["path"]) + with h5py.File(shard_path, "r") as h5: + fps_arr = np.asarray(h5["clip_fps"], dtype=np.float32) + if fps_arr.size == 0: + continue + cur_fps = float(fps_arr[0]) + if np.any(fps_arr != cur_fps): + raise ValueError(f"inconsistent fps within HDF5 shard: {shard_path}") + if clip_fps is None: + clip_fps = cur_fps + elif clip_fps != cur_fps: + raise ValueError( + f"inconsistent fps across shards: {shard_path} has {cur_fps}, expected {clip_fps}" + ) + max_clip_frames = max(max_clip_frames, int(np.asarray(h5["clip_lengths"]).max())) if clip_fps is None: - raise ValueError(f"failed reading shard metadata from {motion_dir}") + raise ValueError(f"failed reading HDF5 shard metadata from {motion_dir}") return clip_fps, max_clip_frames diff --git a/train_mimic/scripts/convert_pkl_to_npz.py b/train_mimic/scripts/convert_pkl_to_npz.py index a0afb1d7..8edd0aa2 100644 --- a/train_mimic/scripts/convert_pkl_to_npz.py +++ b/train_mimic/scripts/convert_pkl_to_npz.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 -"""Convert PKL motion files to NPZ format for mjlab MotionCommand. +"""Convert PKL motion files to per-clip NPZ format for dataset building. Reads retargeted PKL files (IsaacGym convention) and converts them to the NPZ -format expected by mjlab's MotionLoader. +clip format consumed by the HDF5 dataset builder. PKL fields: fps : int scalar @@ -23,7 +23,7 @@ body_names : list[str] 30 body names in mjlab G1 robot order IMPORTANT: NPZ body ordering must match mjlab G1 robot body ordering because -mjlab's MotionLoader uses robot body indices to index into body_pos_w. +the dataset builder preserves this order for HDF5 training shards. Usage: # Convert a single file @@ -60,8 +60,8 @@ # mjlab G1 robot body ordering (matches robot.body_names from G1_ROBOT_CFG). -# mjlab's MotionLoader uses robot body indices to index into NPZ body_pos_w, -# so NPZ body ordering MUST match this list exactly. 
+# The dataset builder preserves this order in HDF5 shards, so per-clip NPZ +# body ordering MUST match this list exactly. _MJLAB_G1_BODY_NAMES = [ "pelvis", "left_hip_pitch_link", "left_hip_roll_link", "left_hip_yaw_link", diff --git a/train_mimic/scripts/data/split_shards.py b/train_mimic/scripts/data/split_shards.py deleted file mode 100644 index 0b8ea602..00000000 --- a/train_mimic/scripts/data/split_shards.py +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python3 -"""Repartition a shard directory into smaller shard NPZ files. - -Each shard is a self-contained merged NPZ (with clip_starts, clip_lengths, -clip_fps, clip_weights, etc.) that can be independently loaded. Splits are -made at clip boundaries — no clip is ever truncated across shards. - -Usage: - python train_mimic/scripts/data/split_shards.py \ - --input data/datasets/seed/train \ - --output data/datasets/seed/train_small_shards \ - --max_size_gb 2 -""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path - -import numpy as np - -MOTION_ARRAY_KEYS = [ - "joint_pos", - "joint_vel", - "body_pos_w", - "body_quat_w", - "body_lin_vel_w", - "body_ang_vel_w", -] - - -def _estimate_frames_per_gb(data: dict) -> float: - """Estimate how many frames fit in 1 GB based on array dtypes/shapes.""" - one_frame_bytes = 0 - for k in MOTION_ARRAY_KEYS: - arr = data[k] - one_frame_bytes += int(np.prod(arr.shape[1:])) * arr.dtype.itemsize - return 1e9 / max(one_frame_bytes, 1) - - -def _load_shard_dir(input_dir: Path) -> dict: - shard_files = sorted(input_dir.glob("*.npz")) - if not shard_files: - raise FileNotFoundError(f"no shard NPZ files found in {input_dir}") - - arrays: dict[str, list[np.ndarray]] = {k: [] for k in MOTION_ARRAY_KEYS} - clip_lengths: list[np.ndarray] = [] - clip_fps: list[np.ndarray] = [] - clip_weights: list[np.ndarray] = [] - clip_sample_starts: list[np.ndarray] = [] - clip_sample_ends: list[np.ndarray] = [] - window_steps: np.ndarray | None = None - fps: 
int | None = None - body_names: np.ndarray | None = None - - for shard_path in shard_files: - data = np.load(shard_path, allow_pickle=True) - for key in MOTION_ARRAY_KEYS: - arrays[key].append(np.asarray(data[key])) - if fps is None: - fps = int(data["fps"]) - body_names = np.asarray(data["body_names"]) - else: - if int(data["fps"]) != fps: - raise ValueError( - f"inconsistent fps across shards: {shard_path} has {int(data['fps'])}, expected {fps}" - ) - if not np.array_equal(np.asarray(data["body_names"]), body_names): - raise ValueError(f"inconsistent body_names across shards: {shard_path}") - clip_lengths.append(np.asarray(data["clip_lengths"])) - clip_fps.append(np.asarray(data["clip_fps"])) - clip_weights.append(np.asarray(data["clip_weights"])) - if "clip_sample_starts" in data: - clip_sample_starts.append(np.asarray(data["clip_sample_starts"])) - if "clip_sample_ends" in data: - clip_sample_ends.append(np.asarray(data["clip_sample_ends"])) - if "window_steps" in data and window_steps is None: - window_steps = np.asarray(data["window_steps"]) - - merged = {key: np.concatenate(values, axis=0) for key, values in arrays.items()} - merged["fps"] = fps - merged["body_names"] = body_names - merged["clip_lengths"] = np.concatenate(clip_lengths) - merged["clip_fps"] = np.concatenate(clip_fps) - merged["clip_weights"] = np.concatenate(clip_weights) - merged["clip_starts"] = np.zeros(len(merged["clip_lengths"]), dtype=np.int64) - if len(merged["clip_lengths"]) > 1: - merged["clip_starts"][1:] = np.cumsum(merged["clip_lengths"][:-1]) - if clip_sample_starts and sum(len(v) for v in clip_sample_starts) == len(merged["clip_lengths"]): - merged["clip_sample_starts"] = np.concatenate(clip_sample_starts) - if clip_sample_ends and sum(len(v) for v in clip_sample_ends) == len(merged["clip_lengths"]): - merged["clip_sample_ends"] = np.concatenate(clip_sample_ends) - if window_steps is not None: - merged["window_steps"] = window_steps - return merged - - -def split_shards( - 
input_path: Path, - output_dir: Path, - max_size_gb: float = 2.0, -) -> list[Path]: - """Split a shard directory into smaller shards of approximately *max_size_gb* each.""" - print(f"Loading {input_path} ...") - data = _load_shard_dir(input_path) - - clip_starts = np.asarray(data["clip_starts"]) - clip_lengths = np.asarray(data["clip_lengths"]) - num_clips = len(clip_starts) - total_frames = int(data["joint_pos"].shape[0]) - - has_fps_array = "clip_fps" in data - has_weights = "clip_weights" in data - has_window_steps = "window_steps" in data - has_sample_starts = "clip_sample_starts" in data - has_sample_ends = "clip_sample_ends" in data - fps = data["fps"] - body_names = data["body_names"] - - frames_per_gb = _estimate_frames_per_gb(data) - max_frames_per_shard = int(frames_per_gb * max_size_gb) - - # --- plan shard boundaries (by clip index) --- - shard_ranges: list[tuple[int, int]] = [] # (clip_start_idx, clip_end_idx) - cur_start = 0 - cur_frames = 0 - for i in range(num_clips): - cl = int(clip_lengths[i]) - if cur_frames + cl > max_frames_per_shard and cur_frames > 0: - shard_ranges.append((cur_start, i)) - cur_start = i - cur_frames = 0 - cur_frames += cl - if cur_start < num_clips: - shard_ranges.append((cur_start, num_clips)) - - print( - f" {num_clips} clips, {total_frames} frames -> " - f"{len(shard_ranges)} shards (target ~{max_size_gb} GB each)" - ) - - # --- write shards --- - output_dir.mkdir(parents=True, exist_ok=True) - shard_paths: list[Path] = [] - n_digits = max(3, len(str(len(shard_ranges) - 1))) - - for shard_idx, (c_start, c_end) in enumerate(shard_ranges): - frame_start = int(clip_starts[c_start]) - if c_end < num_clips: - frame_end = int(clip_starts[c_end]) - else: - frame_end = total_frames - - s = slice(frame_start, frame_end) - shard: dict = {} - for k in MOTION_ARRAY_KEYS: - shard[k] = np.asarray(data[k][s]) - - # Rebuild clip metadata relative to this shard - shard_clip_lengths = clip_lengths[c_start:c_end].copy() - shard_clip_starts 
= np.zeros(c_end - c_start, dtype=np.int64) - if len(shard_clip_lengths) > 1: - shard_clip_starts[1:] = np.cumsum(shard_clip_lengths[:-1]) - - shard["fps"] = fps - shard["body_names"] = body_names - shard["clip_starts"] = shard_clip_starts - shard["clip_lengths"] = shard_clip_lengths - if has_fps_array: - shard["clip_fps"] = np.asarray(data["clip_fps"][c_start:c_end]) - if has_weights: - shard["clip_weights"] = np.asarray(data["clip_weights"][c_start:c_end]) - if has_window_steps: - shard["window_steps"] = np.asarray(data["window_steps"]) - if has_sample_starts: - shard["clip_sample_starts"] = np.asarray(data["clip_sample_starts"][c_start:c_end]) - if has_sample_ends: - shard["clip_sample_ends"] = np.asarray(data["clip_sample_ends"][c_start:c_end]) - - shard_name = f"shard_{shard_idx:0{n_digits}d}.npz" - shard_path = output_dir / shard_name - np.savez(shard_path, **shard) - - shard_frames = frame_end - frame_start - shard_size_mb = shard_path.stat().st_size / 1e6 - print( - f" {shard_name}: {c_end - c_start} clips, " - f"{shard_frames} frames, {shard_size_mb:.0f} MB" - ) - shard_paths.append(shard_path) - - print(f"Done. 
{len(shard_paths)} shards written to {output_dir}") - return shard_paths - - -def main() -> None: - parser = argparse.ArgumentParser(description="Repartition a shard directory into smaller shards") - parser.add_argument("--input", required=True, type=Path, help="Input shard directory") - parser.add_argument("--output", required=True, type=Path, help="Output shard directory") - parser.add_argument( - "--max_size_gb", - type=float, - default=2.0, - help="Target max size per shard in GB (default: 2.0)", - ) - args = parser.parse_args() - - if not args.input.is_dir(): - print(f"Error: shard directory not found: {args.input}", file=sys.stderr) - sys.exit(1) - - split_shards(args.input, args.output, args.max_size_gb) - - -if __name__ == "__main__": - main() diff --git a/train_mimic/scripts/play.py b/train_mimic/scripts/play.py index 33069710..f64e2359 100644 --- a/train_mimic/scripts/play.py +++ b/train_mimic/scripts/play.py @@ -44,7 +44,7 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Play trained G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") - parser.add_argument("--motion_file", type=str, required=True, help="Path to motion shard directory") + parser.add_argument("--motion_file", type=str, required=True, help="Path to HDF5 motion shard directory") parser.add_argument("--num_envs", type=int, default=1) parser.add_argument( "--viewer", type=str, default="native", choices=["native", "viser"], diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index bd40db01..194f4759 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -74,7 +74,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ) parser.add_argument("--experiment_name", type=str, default=None) parser.add_argument("--motion_file", type=str, default=None, - help="Shard directory path containing shard_*.npz files") + help="HDF5 shard 
directory path containing manifest.json and shard_*.h5 files") parser.add_argument( "--resume", type=str, @@ -93,6 +93,10 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: help="Minimum policy steps to rewind for rewind sampling") parser.add_argument("--rewind_max_steps", type=int, default=None, help="Maximum policy steps to rewind for rewind sampling") + parser.add_argument("--cache_num_clips", type=int, default=None, + help="Number of HDF5 motion windows to keep in the active subset cache") + parser.add_argument("--cache_swap_interval_steps", type=int, default=None, + help="Policy steps between HDF5 motion cache swaps; swaps occur at rollout barriers") parser.add_argument("--device", type=str, default=None) parser.add_argument( "--gpu_ids", @@ -368,6 +372,7 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: # CLI overrides env_cfg.seed = _resolve_worker_seed(args.seed) + env_cfg.commands["motion"].cache_seed = env_cfg.seed if args.num_envs is not None: env_cfg.scene.num_envs = args.num_envs if args.motion_file is not None: @@ -381,6 +386,10 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: env_cfg.commands["motion"].rewind_min_steps = args.rewind_min_steps if args.rewind_max_steps is not None: env_cfg.commands["motion"].rewind_max_steps = args.rewind_max_steps + if args.cache_num_clips is not None: + env_cfg.commands["motion"].cache_num_clips = args.cache_num_clips + if args.cache_swap_interval_steps is not None: + env_cfg.commands["motion"].cache_swap_interval_steps = args.cache_swap_interval_steps if args.max_iterations is not None: agent_cfg.max_iterations = args.max_iterations if args.experiment_name is not None: diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index cc9cfbda..454c7e4c 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -1,16 +1,23 @@ from __future__ import annotations import copy +import json 
import logging from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, Literal +import h5py import mujoco import numpy as np import torch -from train_mimic.data.dataset_lib import compute_clip_sample_ranges, parse_window_steps +from train_mimic.data.dataset_lib import ( + MOTION_ARRAY_KEYS, + compute_clip_sample_ranges, + parse_window_steps, + read_hdf5_body_names, +) from mjlab.managers import CommandTerm, CommandTermCfg from mjlab.utils.lab_api.math import ( @@ -57,113 +64,168 @@ def _batched_quat_slerp( return result / result.norm(dim=-1, keepdim=True) -def _load_shard_dir(shard_dir: Path) -> dict[str, Any]: - """Load and merge all shard NPZ files from a directory. +@dataclass(frozen=True) +class _Hdf5ClipRef: + shard_index: int + start: int + length: int + fps: int + weight: float - Each shard must be a self-contained merged NPZ with clip metadata. - This function concatenates motion arrays across shards and rebuilds the - clip-level metadata (``clip_starts`` offsets are adjusted). 
- """ - shard_files = sorted(shard_dir.glob("*.npz")) - if not shard_files: - raise FileNotFoundError(f"No .npz files found in {shard_dir}") - - _LOG.info("Loading %d shards from %s ...", len(shard_files), shard_dir) - - motion_keys = [ - "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", - "body_lin_vel_w", "body_ang_vel_w", - ] - arrays: dict[str, list[np.ndarray]] = {k: [] for k in motion_keys} - all_clip_lengths: list[np.ndarray] = [] - all_clip_fps: list[np.ndarray] = [] - all_clip_weights: list[np.ndarray] = [] - all_clip_sample_starts: list[np.ndarray] = [] - all_clip_sample_ends: list[np.ndarray] = [] - window_steps: np.ndarray | None = None - fps = None - body_names: np.ndarray | None = None - - for sf in shard_files: - d = np.load(sf, allow_pickle=True) - for k in motion_keys: - arrays[k].append(np.asarray(d[k])) - - # --- validate metadata consistency across shards --- - cur_fps = d["fps"] - cur_body_names = np.asarray(d["body_names"]) - if fps is None: - fps = cur_fps - body_names = cur_body_names - else: - if int(cur_fps) != int(fps): - raise ValueError( - f"Inconsistent fps across shards: {sf} has {int(cur_fps)}, " - f"expected {int(fps)}" - ) - if not np.array_equal(cur_body_names, body_names): - raise ValueError( - f"Inconsistent body_names across shards: {sf} differs from first shard" - ) - if "clip_lengths" not in d or "clip_fps" not in d or "clip_weights" not in d: - raise ValueError( - f"Shard {sf} is missing required clip metadata. " - "Expected clip_lengths, clip_fps, and clip_weights." 
+@dataclass +class _MotionBatch: + tensors: dict[str, torch.Tensor] + lengths: torch.Tensor + fps: torch.Tensor + weights: torch.Tensor + sample_starts: torch.Tensor + sample_ends: torch.Tensor + global_ids: torch.Tensor + + +class _Hdf5MotionCache: + def __init__( + self, + motion_dir: Path, + *, + body_idx_np: np.ndarray, + device: str, + window_steps: tuple[int, ...], + cache_num_clips: int, + seed: int, + ) -> None: + if cache_num_clips <= 0: + raise ValueError(f"cache_num_clips must be positive, got {cache_num_clips}") + manifest_path = motion_dir / "manifest.json" + if not manifest_path.is_file(): + raise FileNotFoundError( + f"motion_file must be an HDF5 split directory containing manifest.json, got: {motion_dir}" + ) + with manifest_path.open("r", encoding="utf-8") as f: + manifest = json.load(f) + if manifest.get("format") != "teleopit_motion_hdf5": + raise ValueError(f"Unsupported motion manifest format in {manifest_path}") + + self.motion_dir = motion_dir + self.body_idx_np = body_idx_np + self.device = device + self.window_steps = window_steps + self.cache_num_clips = int(cache_num_clips) + self._rng = torch.Generator(device="cpu") + self._rng.manual_seed(int(seed)) + self._shard_paths = [motion_dir / shard["path"] for shard in manifest["shards"]] + self.body_names = np.asarray(manifest["body_names"], dtype=str) + + max_future = max((step for step in self.window_steps if step > 0), default=0) + max_history = -min((step for step in self.window_steps if step < 0), default=0) + min_clip_length = max_history + 1 + max_future + 1 # +1 for interpolation + + refs: list[_Hdf5ClipRef] = [] + skipped_short = 0 + for shard_index, shard_path in enumerate(self._shard_paths): + with h5py.File(shard_path, "r") as h5: + shard_body_names = read_hdf5_body_names(shard_path) + if shard_body_names != self.body_names.tolist(): + raise ValueError( + f"HDF5 shard body_names mismatch for {shard_path}: " + "all shards must match manifest body_names order" + ) + starts = 
np.asarray(h5["clip_starts"], dtype=np.int64) + lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) + fps = np.asarray(h5["clip_fps"], dtype=np.int64) + weights = np.asarray(h5["clip_weights"], dtype=np.float64) + for start, length, cur_fps, weight in zip(starts, lengths, fps, weights): + if int(length) < min_clip_length: + skipped_short += 1 + continue + refs.append(_Hdf5ClipRef( + shard_index=shard_index, + start=int(start), + length=int(length), + fps=int(cur_fps), + weight=float(weight), + )) + if not refs: + raise ValueError(f"HDF5 motion dataset is empty: {motion_dir}") + if skipped_short > 0: + _LOG.warning( + "Ignoring %d HDF5 motion windows shorter than %d frames (window_steps=%s)", + skipped_short, + min_clip_length, + list(self.window_steps), ) - shard_clip_lengths = np.asarray(d["clip_lengths"]) - n_clips = len(shard_clip_lengths) - all_clip_lengths.append(shard_clip_lengths) - all_clip_fps.append(np.asarray(d["clip_fps"])) - all_clip_weights.append(np.asarray(d["clip_weights"])) - - if "clip_sample_starts" in d: - all_clip_sample_starts.append(np.asarray(d["clip_sample_starts"])) - if "clip_sample_ends" in d: - all_clip_sample_ends.append(np.asarray(d["clip_sample_ends"])) - if "window_steps" in d and window_steps is None: - window_steps = np.asarray(d["window_steps"]) - - merged: dict[str, Any] = {k: np.concatenate(v, axis=0) for k, v in arrays.items()} - merged["fps"] = fps - merged["body_names"] = body_names - - clip_lengths = np.concatenate(all_clip_lengths) - clip_starts = np.zeros(len(clip_lengths), dtype=np.int64) - if len(clip_lengths) > 1: - clip_starts[1:] = np.cumsum(clip_lengths[:-1]) - merged["clip_starts"] = clip_starts - merged["clip_lengths"] = clip_lengths - merged["clip_fps"] = np.concatenate(all_clip_fps) - merged["clip_weights"] = np.concatenate(all_clip_weights) - - # Propagate precomputed sample ranges if all shards had them - n_total_clips = len(clip_lengths) - if all_clip_sample_starts and sum(len(a) for a in 
all_clip_sample_starts) == n_total_clips: - merged["clip_sample_starts"] = np.concatenate(all_clip_sample_starts) - if all_clip_sample_ends and sum(len(a) for a in all_clip_sample_ends) == n_total_clips: - merged["clip_sample_ends"] = np.concatenate(all_clip_sample_ends) - if window_steps is not None: - merged["window_steps"] = window_steps - - _LOG.info( - "Loaded %d shards: %d clips, %d total frames", - len(shard_files), len(clip_lengths), merged["joint_pos"].shape[0], - ) - return merged + self.refs = refs + self.global_weights = torch.tensor( + [max(ref.weight, 0.0) for ref in refs], dtype=torch.float32 + ) + if float(self.global_weights.sum()) <= 0.0: + raise ValueError(f"All HDF5 motion weights are zero in {motion_dir}") + self.generation = 0 + self.current = self._load_random_batch() + self.next = self._load_random_batch() + + def _sample_global_ids(self) -> torch.Tensor: + probs = self.global_weights / self.global_weights.sum() + return torch.multinomial( + probs, + self.cache_num_clips, + replacement=True, + generator=self._rng, + ) + + def _load_random_batch(self) -> _MotionBatch: + return self._load_batch(self._sample_global_ids()) + + def _load_batch(self, global_ids: torch.Tensor) -> _MotionBatch: + ids_np = global_ids.cpu().numpy().astype(np.int64) + selected = [self.refs[int(idx)] for idx in ids_np] + max_len = max(ref.length for ref in selected) + arrays: dict[str, np.ndarray] = {} + for key in MOTION_ARRAY_KEYS: + sample_shape: tuple[int, ...] 
+ if key in ("joint_pos", "joint_vel"): + sample_shape = (29,) + elif key == "body_quat_w": + sample_shape = (len(self.body_idx_np), 4) + else: + sample_shape = (len(self.body_idx_np), 3) + arrays[key] = np.zeros((len(selected), max_len, *sample_shape), dtype=np.float32) + + for out_i, ref in enumerate(selected): + shard_path = self._shard_paths[ref.shard_index] + sl = slice(ref.start, ref.start + ref.length) + with h5py.File(shard_path, "r") as h5: + arrays["joint_pos"][out_i, :ref.length] = np.asarray(h5["joint_pos"][sl], dtype=np.float32) + arrays["joint_vel"][out_i, :ref.length] = np.asarray(h5["joint_vel"][sl], dtype=np.float32) + for key in ("body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w"): + arrays[key][out_i, :ref.length] = np.asarray(h5[key][sl], dtype=np.float32)[:, self.body_idx_np] + + lengths_np = np.asarray([ref.length for ref in selected], dtype=np.int64) + starts_np, ends_np = compute_clip_sample_ranges(lengths_np, window_steps=self.window_steps) + tensors = {key: torch.from_numpy(value).to(self.device) for key, value in arrays.items()} + return _MotionBatch( + tensors=tensors, + lengths=torch.tensor(lengths_np, dtype=torch.long, device=self.device), + fps=torch.tensor([ref.fps for ref in selected], dtype=torch.float32, device=self.device), + weights=torch.ones(len(selected), dtype=torch.float32, device=self.device), + sample_starts=torch.tensor(starts_np, dtype=torch.long, device=self.device), + sample_ends=torch.tensor(ends_np, dtype=torch.long, device=self.device), + global_ids=global_ids.to(self.device), + ) + + def advance(self) -> None: + self.current = self.next + self.next = self._load_random_batch() + self.generation += 1 class MotionLib: """Clip-aware motion library. - Loads a directory of shard NPZ files. Each shard contains flat motion arrays - plus per-clip metadata (``clip_starts``, ``clip_lengths``, ``clip_fps``, - ``clip_weights``). - - Motion data is stored as GPU tensors for fast gather+lerp interpolation. 
- All indexing, lerp, and slerp run entirely on device with zero CPU - round-trips. Numpy arrays are kept alongside for external consumers - (e.g. runner checkpoint buffers). + Loads a bounded subset of HDF5 motion windows into a GPU-resident cache. + Sampling and interpolation operate on cache-local clip ids; the next cache + is staged in memory and swapped at a rollout barrier by ``MotionCommand``. """ def __init__( @@ -173,6 +235,8 @@ def __init__( body_names: tuple[str, ...] | list[str] | None = None, device: str = "cpu", window_steps: tuple[int, ...] | list[int] | None = None, + cache_num_clips: int = 1024, + cache_seed: int = 0, ) -> None: self._device = device self.window_steps = parse_window_steps(window_steps) @@ -180,13 +244,22 @@ def __init__( motion_path = Path(motion_file) if not motion_path.is_dir(): raise FileNotFoundError( - f"motion_file must be a shard directory, got: {motion_file}" + f"motion_file must be an HDF5 shard directory, got: {motion_file}" + ) + manifest_path = motion_path / "manifest.json" + if not manifest_path.is_file(): + raise FileNotFoundError( + f"motion_file must contain manifest.json for HDF5 loading, got: {motion_file}" ) - data = _load_shard_dir(motion_path) + with manifest_path.open("r", encoding="utf-8") as f: + manifest = json.load(f) + if manifest.get("format") != "teleopit_motion_hdf5": + raise ValueError(f"Unsupported motion manifest format in {manifest_path}") + if body_names is None: body_idx_np = body_indexes.cpu().numpy() else: - dataset_body_names = [str(name) for name in np.asarray(data["body_names"])] + dataset_body_names = [str(name) for name in manifest["body_names"]] dataset_body_index_by_name = { name: index for index, name in enumerate(dataset_body_names) } @@ -205,100 +278,43 @@ def __init__( dtype=np.int64, ) - self._joint_pos = np.asarray(data["joint_pos"], dtype=np.float32) # (T, 29) - self._joint_vel = np.asarray(data["joint_vel"], dtype=np.float32) # (T, 29) - - # Body arrays: index by selected bodies 
immediately. Accessing an - # NpzFile key inflates that array from the zip; the intermediate full - # array is released once we slice and discard the reference. - self._body_pos_w = np.asarray( - data["body_pos_w"], dtype=np.float32 - )[:, body_idx_np] - self._body_quat_w = np.asarray( - data["body_quat_w"], dtype=np.float32 - )[:, body_idx_np] - self._body_lin_vel_w = np.asarray( - data["body_lin_vel_w"], dtype=np.float32 - )[:, body_idx_np] - self._body_ang_vel_w = np.asarray( - data["body_ang_vel_w"], dtype=np.float32 - )[:, body_idx_np] - - self.time_step_total = self._joint_pos.shape[0] - - # GPU tensors for fast gather+lerp interpolation (zero CPU round-trips). - self._joint_pos_t = torch.from_numpy(self._joint_pos).to(device) - self._joint_vel_t = torch.from_numpy(self._joint_vel).to(device) - self._body_pos_w_t = torch.from_numpy(self._body_pos_w).to(device) - self._body_quat_w_t = torch.from_numpy(self._body_quat_w).to(device) - self._body_lin_vel_w_t = torch.from_numpy(self._body_lin_vel_w).to(device) - self._body_ang_vel_w_t = torch.from_numpy(self._body_ang_vel_w).to(device) - - # --- clip-aware metadata (small — lives on GPU for sampling) --- - self.clip_starts = torch.tensor(data["clip_starts"], dtype=torch.long, device=device) - self.clip_lengths = torch.tensor(data["clip_lengths"], dtype=torch.long, device=device) - self.clip_weights = torch.tensor( - data["clip_weights"], dtype=torch.float32, device=device + self._cache = _Hdf5MotionCache( + motion_path, + body_idx_np=body_idx_np, + device=device, + window_steps=self.window_steps, + cache_num_clips=cache_num_clips, + seed=cache_seed, ) - fps_arr = np.asarray(data["clip_fps"]) - if fps_arr.ndim == 0: - self.clip_fps = torch.full( - (len(self.clip_starts),), float(fps_arr), dtype=torch.float32, device=device - ) - else: - self.clip_fps = torch.tensor(fps_arr, dtype=torch.float32, device=device) - - self.num_clips = len(self.clip_starts) + self._set_batch(self._cache.current) + + def _set_batch(self, 
batch: _MotionBatch) -> None: + self._batch = batch + self._joint_pos_t = batch.tensors["joint_pos"] + self._joint_vel_t = batch.tensors["joint_vel"] + self._body_pos_w_t = batch.tensors["body_pos_w"] + self._body_quat_w_t = batch.tensors["body_quat_w"] + self._body_lin_vel_w_t = batch.tensors["body_lin_vel_w"] + self._body_ang_vel_w_t = batch.tensors["body_ang_vel_w"] + + self.clip_lengths = batch.lengths + self.clip_weights = batch.weights + self.clip_fps = batch.fps + self.num_clips = int(batch.lengths.shape[0]) + self.time_step_total = int(batch.lengths.max().item()) self.clip_dt = 1.0 / self.clip_fps self.clip_duration_s = (self.clip_lengths.float() - 1.0) * self.clip_dt - - # Compute minimum clip length required by the window - max_future = max((s for s in self.window_steps if s > 0), default=0) - max_history = -min((s for s in self.window_steps if s < 0), default=0) - min_clip_length = max_history + 1 + max_future + 1 # +1 for interpolation - short_mask = self.clip_lengths < min_clip_length - n_short = int(short_mask.sum().item()) - if n_short > 0: - _LOG.warning( - "Disabling %d/%d clips shorter than %d frames (window_steps=%s)", - n_short, self.num_clips, min_clip_length, list(self.window_steps), - ) - self.clip_weights = self.clip_weights.clone() - self.clip_weights[short_mask] = 0.0 - - file_window_steps = parse_window_steps(data["window_steps"]) if "window_steps" in data else (0,) - if ( - "clip_sample_starts" in data - and "clip_sample_ends" in data - and file_window_steps == self.window_steps - ): - self.clip_sample_starts = torch.tensor( - data["clip_sample_starts"], dtype=torch.long, device=device - ) - self.clip_sample_ends = torch.tensor( - data["clip_sample_ends"], dtype=torch.long, device=device - ) - else: - # Only compute ranges for clips long enough; short ones get dummy [0,1) - lengths_np = self.clip_lengths.cpu().numpy() - sample_starts = np.zeros(self.num_clips, dtype=np.int64) - sample_ends = np.ones(self.num_clips, dtype=np.int64) - 
long_mask = ~short_mask.cpu().numpy() - if long_mask.any(): - s, e = compute_clip_sample_ranges( - lengths_np[long_mask], - window_steps=self.window_steps, - ) - sample_starts[long_mask] = s - sample_ends[long_mask] = e - self.clip_sample_starts = torch.tensor( - sample_starts, dtype=torch.long, device=device - ) - self.clip_sample_ends = torch.tensor( - sample_ends, dtype=torch.long, device=device - ) + self.clip_sample_starts = batch.sample_starts + self.clip_sample_ends = batch.sample_ends self.clip_sample_start_s = self.clip_sample_starts.float() * self.clip_dt self.clip_sample_end_s = self.clip_sample_ends.float() * self.clip_dt + # Kept for introspection/logging; frame interpolation is cache-local. + self.clip_starts = torch.zeros(self.num_clips, dtype=torch.long, device=self._device) + self.generation = self._cache.generation + + def advance_cache(self) -> None: + self._cache.advance() + self._set_batch(self._cache.current) # ------------------------------------------------------------------ # Sampling helpers @@ -339,7 +355,6 @@ def _compute_interpolation_state( steps: tuple[int, ...], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: fps = self.clip_fps[motion_ids] - starts = self.clip_starts[motion_ids] lengths = self.clip_lengths[motion_ids] durations = self.clip_duration_s[motion_ids] @@ -358,10 +373,8 @@ def _compute_interpolation_state( frame_i0 = torch.clamp(frame_i0, min=zero, max=max_frame) frame_i1 = torch.clamp(frame_i1, min=zero, max=max_frame) - idx0 = (starts[:, None] + frame_i0).reshape(-1) - idx1 = (starts[:, None] + frame_i1).reshape(-1) window = len(steps) - return idx0, idx1, alpha.reshape(-1), window + return frame_i0.reshape(-1), frame_i1.reshape(-1), alpha.reshape(-1), window _ALL_KEYS = frozenset(( "joint_pos", "joint_vel", @@ -394,6 +407,7 @@ def get_window_frames( steps, ) batch = motion_ids.shape[0] + batch_idx = motion_ids[:, None].expand(batch, window).reshape(-1) want = self._ALL_KEYS if keys is None else keys result: 
dict[str, torch.Tensor] = {} @@ -403,7 +417,7 @@ def get_window_frames( for key, arr_t in (("joint_pos", self._joint_pos_t), ("joint_vel", self._joint_vel_t)): if key not in want: continue - v0, v1 = arr_t[idx0], arr_t[idx1] + v0, v1 = arr_t[batch_idx, idx0], arr_t[batch_idx, idx1] result[key] = (v0 + a1 * (v1 - v0)).reshape(batch, window, -1) # body arrays: (T, B, D) — GPU gather + lerp, optionally pre-slice bodies @@ -416,20 +430,21 @@ def get_window_frames( if key not in want: continue if body_indices is not None: - v0, v1 = arr_t[idx0][:, body_indices], arr_t[idx1][:, body_indices] + v0 = arr_t[batch_idx, idx0][:, body_indices] + v1 = arr_t[batch_idx, idx1][:, body_indices] else: - v0, v1 = arr_t[idx0], arr_t[idx1] + v0, v1 = arr_t[batch_idx, idx0], arr_t[batch_idx, idx1] interp = v0 + a2 * (v1 - v0) result[key] = interp.reshape(batch, window, *interp.shape[1:]) # body_quat_w: GPU slerp, optionally pre-slice bodies if "body_quat_w" in want: if body_indices is not None: - q0 = self._body_quat_w_t[idx0][:, body_indices] - q1 = self._body_quat_w_t[idx1][:, body_indices] + q0 = self._body_quat_w_t[batch_idx, idx0][:, body_indices] + q1 = self._body_quat_w_t[batch_idx, idx1][:, body_indices] else: - q0 = self._body_quat_w_t[idx0] - q1 = self._body_quat_w_t[idx1] + q0 = self._body_quat_w_t[batch_idx, idx0] + q1 = self._body_quat_w_t[batch_idx, idx1] nb = q0.shape[1] q0_flat = q0.reshape(-1, 4) q1_flat = q1.reshape(-1, 4) @@ -457,11 +472,6 @@ def get_frames( key: value[:, 0] for key, value in windowed.items() } - -# Backward compatibility alias -MotionLoader = MotionLib - - def _validate_rewind_sampling_cfg(cfg: Any) -> None: if cfg.rewind_min_steps < 0: raise ValueError( @@ -502,7 +512,11 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): body_names=self.cfg.body_names, device=self.device, window_steps=self.cfg.window_steps, + cache_num_clips=self.cfg.cache_num_clips, + cache_seed=self.cfg.cache_seed, ) + self._motion_cache_step_counter = 0 + 
self._motion_cache_swap_pending = False # Per-env motion state: clip id + elapsed time (seconds) self.motion_ids = torch.zeros(self.num_envs, dtype=torch.long, device=self.device) @@ -833,6 +847,9 @@ def _resample_command(self, env_ids: torch.Tensor): "Supported modes are 'uniform', 'start', and 'rewind'." ) + self._reset_envs_to_current_reference(env_ids) + + def _reset_envs_to_current_reference(self, env_ids: torch.Tensor) -> None: if env_ids.numel() == 0: return @@ -920,6 +937,10 @@ def _refresh_body_local_cache(self) -> None: def _update_command(self): # Advance motion time by real elapsed time self.motion_times += self._step_dt + if self.cfg.cache_swap_interval_steps > 0: + self._motion_cache_step_counter += 1 + if self._motion_cache_step_counter >= self.cfg.cache_swap_interval_steps: + self._motion_cache_swap_pending = True # Handle clips that exceeded their duration end_times = self.motion.clip_sample_end_s[self.motion_ids] @@ -958,6 +979,24 @@ def _update_command(self): self._refresh_body_local_cache() self._update_feet_standing() + def apply_cache_swap_if_pending_barrier(self) -> bool: + """Swap the staged motion cache at a rollout barrier, then resample all envs.""" + if not self._motion_cache_swap_pending: + return False + self.motion.advance_cache() + self._motion_cache_step_counter = 0 + self._motion_cache_swap_pending = False + all_env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) + if self.cfg.sampling_mode == "start": + self.motion_ids[all_env_ids] = self.motion.sample_motion_ids(self.num_envs) + self.motion_times[all_env_ids] = self.motion.sample_start_times(self.motion_ids[all_env_ids]) + else: + # Rewind only makes sense inside one cache generation. After a cache + # swap, local ids refer to different clips, so fall back to uniform. 
+ self._uniform_sampling(all_env_ids) + self._reset_envs_to_current_reference(all_env_ids) + return True + # ------------------------------------------------------------------ # Visualization # ------------------------------------------------------------------ @@ -1044,6 +1083,9 @@ class MotionCommandCfg(CommandTermCfg): joint_position_range: tuple[float, float] = (-0.52, 0.52) sampling_mode: Literal["uniform", "start", "rewind"] = "rewind" window_steps: tuple[int, ...] = (0,) + cache_num_clips: int = 1024 + cache_swap_interval_steps: int = 500 + cache_seed: int = 0 rewind_prob: float = 0.8 rewind_min_steps: int = 25 rewind_max_steps: int = 75 diff --git a/train_mimic/tasks/tracking/rl/runner.py b/train_mimic/tasks/tracking/rl/runner.py index 9e22c6b2..64acc939 100644 --- a/train_mimic/tasks/tracking/rl/runner.py +++ b/train_mimic/tasks/tracking/rl/runner.py @@ -6,7 +6,6 @@ import torch from rsl_rl.env.vec_env import VecEnv -from torch import nn from mjlab.rl import RslRlVecEnvWrapper from mjlab.rl.runner import MjlabOnPolicyRunner @@ -42,43 +41,6 @@ def _format_duration(seconds: float) -> str: return f"{hours:02d}:{minutes:02d}:{secs:02d}" -class _OnnxMotionModel(nn.Module): - """ONNX-exportable model that wraps the policy and bundles motion reference data.""" - - def __init__(self, actor, motion): - super().__init__() - self.policy = actor.as_onnx(verbose=False) - # torch.from_numpy shares memory with the underlying numpy array - # (zero-copy). The arrays are already writable (loaded from .npz). 
- self.register_buffer("joint_pos", torch.from_numpy(motion._joint_pos)) - self.register_buffer("joint_vel", torch.from_numpy(motion._joint_vel)) - self.register_buffer("body_pos_w", torch.from_numpy(motion._body_pos_w)) - self.register_buffer("body_quat_w", torch.from_numpy(motion._body_quat_w)) - self.register_buffer("body_lin_vel_w", torch.from_numpy(motion._body_lin_vel_w)) - self.register_buffer("body_ang_vel_w", torch.from_numpy(motion._body_ang_vel_w)) - self.time_step_total: int = self.joint_pos.shape[0] # type: ignore[index] - - def forward(self, *args): - # Last arg is always time_step; preceding args are policy inputs. - *policy_args, time_step = args - time_step_clamped = torch.clamp( - time_step.long().squeeze(-1), max=self.time_step_total - 1 - ) - if len(policy_args) == 1: - policy_out = self.policy(policy_args[0]) - else: - policy_out = self.policy(*policy_args) - return ( - policy_out, - self.joint_pos[time_step_clamped], # type: ignore[index] - self.joint_vel[time_step_clamped], # type: ignore[index] - self.body_pos_w[time_step_clamped], # type: ignore[index] - self.body_quat_w[time_step_clamped], # type: ignore[index] - self.body_lin_vel_w[time_step_clamped], # type: ignore[index] - self.body_ang_vel_w[time_step_clamped], # type: ignore[index] - ) - - class MotionTrackingOnPolicyRunner(MjlabOnPolicyRunner): env: RslRlVecEnvWrapper @@ -131,6 +93,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals collect_time = stop - start start = stop self.alg.compute_returns(obs) + cmd = self._motion_command() + if cmd.apply_cache_swap_if_pending_barrier(): + obs = self.env.get_observations().to(self.device) loss_dict = self.alg.update() @@ -293,41 +258,9 @@ def export_policy_to_onnx( path: str, filename: str = "policy.onnx", verbose: bool = False, - *, - include_motion_labels: bool = False, ) -> None: os.makedirs(path, exist_ok=True) output_path = os.path.join(path, filename) - if include_motion_labels: - cmd = 
cast(MotionCommand, self.env.unwrapped.command_manager.get_term("motion")) - model = _OnnxMotionModel(self.alg.get_policy(), cmd.motion) - model.to("cpu") - model.eval() - dummy_inputs = model.policy.get_dummy_inputs() - time_step = torch.zeros(1, 1) - input_names = model.policy.input_names + ["time_step"] - torch.onnx.export( - model, - (*dummy_inputs, time_step), - output_path, - export_params=True, - opset_version=18, - verbose=verbose, - input_names=input_names, - output_names=[ - "actions", - "joint_pos", - "joint_vel", - "body_pos_w", - "body_quat_w", - "body_lin_vel_w", - "body_ang_vel_w", - ], - dynamic_axes={}, - dynamo=False, - ) - return - model = self.alg.get_policy().as_onnx(verbose=False) model.to("cpu") model.eval() From c40e039b35d6fc81ca7fd03920ba374423b2f115 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 09:51:01 +0800 Subject: [PATCH 084/122] Simplify dataset pipeline --- AGENTS.md | 9 +- README.md | 2 +- docs/docs/getting-started/download-assets.md | 3 +- docs/docs/reference/architecture.md | 2 +- docs/docs/reference/assets.md | 5 +- docs/docs/reference/dataset.md | 35 +- .../reference/training-troubleshooting.md | 4 +- docs/docs/tutorials/training.md | 16 +- .../getting-started/download-assets.md | 3 +- .../current/reference/architecture.md | 2 +- .../current/reference/assets.md | 5 +- .../current/reference/dataset.md | 42 +- .../reference/training-troubleshooting.md | 4 +- .../current/tutorials/training.md | 16 +- scripts/review/__init__.py | 0 scripts/review/build_dataset_from_review.py | 225 ------ scripts/review/export_reviewed_manifest.py | 130 ---- scripts/review/init_review_manifest.py | 125 --- scripts/review/review_dataset.py | 720 ------------------ scripts/view/view_dataset.py | 451 +++++++++++ teleopit/recording/pico_motion.py | 9 +- teleopit/runtime/external_assets.py | 3 +- tests/test_dataset_v2.py | 204 ++--- tests/test_dataset_viewer.py | 59 ++ tests/test_motion_sampling.py | 82 +- 
tests/test_pico_motion_recording.py | 2 + tests/test_review_pipeline.py | 326 -------- tests/test_train_script.py | 23 +- train_mimic/app.py | 18 +- train_mimic/configs/datasets/lafan1.yaml | 2 - train_mimic/configs/datasets/seed.yaml | 3 - train_mimic/configs/datasets/seed_clean.yaml | 3 - train_mimic/configs/datasets/twist2.yaml | 2 - train_mimic/data/dataset_builder.py | 476 ++---------- train_mimic/data/dataset_lib.py | 252 +++--- train_mimic/data/motion_fk.py | 35 +- train_mimic/data/preprocess.py | 9 +- train_mimic/data/review_lib.py | 185 ----- train_mimic/scripts/benchmark.py | 17 +- train_mimic/scripts/convert_pkl_to_npz.py | 9 +- train_mimic/scripts/data/build_dataset.py | 5 +- train_mimic/scripts/data/ingest_motion.py | 2 - train_mimic/scripts/data/inspect_dataset.py | 35 + train_mimic/scripts/play.py | 8 +- train_mimic/scripts/train.py | 12 +- .../tasks/tracking/config/constants.py | 2 +- train_mimic/tasks/tracking/mdp/commands.py | 116 +-- 47 files changed, 1145 insertions(+), 2553 deletions(-) delete mode 100644 scripts/review/__init__.py delete mode 100644 scripts/review/build_dataset_from_review.py delete mode 100644 scripts/review/export_reviewed_manifest.py delete mode 100644 scripts/review/init_review_manifest.py delete mode 100644 scripts/review/review_dataset.py create mode 100644 scripts/view/view_dataset.py create mode 100644 tests/test_dataset_viewer.py delete mode 100644 tests/test_review_pipeline.py delete mode 100644 train_mimic/data/review_lib.py create mode 100644 train_mimic/scripts/data/inspect_dataset.py diff --git a/AGENTS.md b/AGENTS.md index 986ebec5..1400e12c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -203,13 +203,14 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g ### Dataset Pipeline - Dataset build spec supports a `preprocess` section for root-xy normalization, ground alignment, and basic clip filtering -- Final training dataset outputs are HDF5 split directories: 
`data/datasets//train/manifest.json` + `shard_*.h5` and the same under `val/` -- Each shard stores clip-aware window metadata (`clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`); long clips are split into overlapping bounded windows +- Final training dataset outputs are minimal HDF5 shards directly under `data/datasets//` (recursive shard discovery is supported; no train/val split and no manifest file) +- Each shard stores only `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, and clip-aware window metadata (`clip_starts`, `clip_lengths`, `clip_fps`); long clips are split into overlapping bounded windows +- Training computes joint velocities and body FK/velocities online when loading the motion cache - `MotionLib` loads only a configurable HDF5 subset cache into CPU/GPU memory, stages the next cache, and swaps at the PPO rollout barrier - `MotionLib` samples only valid center frames for the configured `window_steps`; default is `window_steps=[0]` - Training supports `uniform` and `rewind` sampling on the active cache; in distributed training each rank sets a rank-offset `cache_seed` - `scripts/run/record_pico_motion.py` records Pico live body tracking as retargeted G1 motion NPZ clips in `data/pico_motion/clips/`; it opens a live `Retarget` viewer, uses terminal keys `R/S/D/N/Q`, stores semantic labels in filenames, and intentionally does not write per-clip JSON -- Build Pico-recorded clips into shards with `python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force`; at least two clips are required for non-empty train/val splits +- Build Pico-recorded clips into shards with `python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force` Quick reference: @@ -217,7 +218,7 @@ Quick reference: python train_mimic/scripts/data/build_dataset.py --spec train_mimic/configs/datasets/twist2.yaml python scripts/run/record_pico_motion.py python train_mimic/scripts/data/build_dataset.py 
--spec data/pico_motion/pico_recorded.yaml --force -python train_mimic/scripts/train.py --motion_file data/datasets/twist2/train +python train_mimic/scripts/train.py --motion_file data/datasets/twist2 python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt --output policy.onnx --history_length 10 ``` diff --git a/README.md b/README.md index d3960059..fc121b3b 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. -- Switched training datasets to HDF5 split directories with subset caching and rollout-barrier cache swaps to reduce CPU/GPU memory use. +- Switched training datasets to recursive minimal HDF5 shards with no train/val split or manifest; training derives FK/velocities while loading the motion cache. - General-Tracking-G1 training defaults to `rewind` motion sampling and also supports `uniform`; playback/benchmark use `start`. - Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. - Added root velocity, joint tracking, and survival rewards to the General-Tracking-G1 training objective. 
diff --git a/docs/docs/getting-started/download-assets.md b/docs/docs/getting-started/download-assets.md index 2f561efd..6db8fde2 100644 --- a/docs/docs/getting-started/download-assets.md +++ b/docs/docs/getting-started/download-assets.md @@ -29,8 +29,7 @@ python scripts/setup/download_assets.py --only gmr ckpt bvh |-------|------|---------| | `track.onnx` | 4 MB | ONNX inference model | | `track.pt` | 27 MB | PyTorch checkpoint (for resume training) | -| `data/datasets/seed/train/manifest.json` + `shard_*.h5` | ~25 GB | Training dataset | -| `data/datasets/seed/val/manifest.json` + `shard_*.h5` | ~1.4 GB | Validation dataset | +| `data/datasets/seed/shard_*.h5` | ~26 GB | Training dataset | | `data/sample_bvh/*.bvh` | 5 MB | Sample motion files | | `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR retargeting robot models | diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index b6a2358f..b4d149ab 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -60,7 +60,7 @@ train_mimic/scripts/data | Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | | Training sampling | Default `rewind`; also supports `uniform`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | -| Data format | HDF5 shard directories (`manifest.json` + `shard_*.h5`) | +| Data format | Minimal recursive HDF5 shards (`shard_*.h5`) | ## Constraints diff --git a/docs/docs/reference/assets.md b/docs/docs/reference/assets.md index c41f8bb1..5408881e 100644 --- a/docs/docs/reference/assets.md +++ b/docs/docs/reference/assets.md @@ -35,7 +35,7 @@ Robot models, datasets, checkpoints, and demo media are not tracked in Git. 
They | `ckpt` | Teleopit-models | `checkpoints/track.onnx`, `checkpoints/track.pt` | | `gmr` | Teleopit-models | `archives/gmr_assets.tar.gz` | | `bvh` | Teleopit-models | `archives/sample_bvh.tar.gz` | -| `data` | Teleopit-datasets | `data/train/`, `data/val/` | +| `data` | Teleopit-datasets | `data/` | ## Download @@ -63,8 +63,7 @@ Local paths after download: | `checkpoints/track.pt` | `track.pt` | | `archives/gmr_assets.tar.gz` | `teleopit/retargeting/gmr/assets/` (extracted) | | `archives/sample_bvh.tar.gz` | `data/sample_bvh/` (extracted) | -| `data/train/` | `data/datasets/seed/train/` | -| `data/val/` | `data/datasets/seed/val/` | +| `data/` | `data/datasets/seed/` | ## Upload to ModelScope diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index f7a058d8..389ba56c 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -10,10 +10,10 @@ sidebar_position: 3 python scripts/setup/download_assets.py --only data ``` -Then train directly with the HDF5 shard directory: +Then train directly with the dataset root: ```bash -python train_mimic/scripts/train.py --motion_file data/datasets/seed/train +python train_mimic/scripts/train.py --motion_file data/datasets/seed ``` For custom dataset construction, read on. @@ -44,12 +44,11 @@ python train_mimic/scripts/data/build_dataset.py \ --spec data/pico_motion/pico_recorded.yaml --force ``` -Record at least two clips before building so both train and validation splits -can be populated. +At least one valid clip is required after preprocessing. ## Custom Dataset Construction -Data pipeline: `typed source YAML -> preprocess/filter -> HDF5 shard-only training data` +Data pipeline: `typed source YAML -> preprocess/filter -> minimal HDF5 shards` ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -62,19 +61,13 @@ python train_mimic/scripts/data/build_dataset.py \ data/datasets// ├── clips/ # Optional; only for per-clip intermediates │ └── /... 
-├── train/ -│ ├── manifest.json -│ └── shard_*.h5 -├── val/ -│ ├── manifest.json -│ └── shard_*.h5 -├── manifest_resolved.csv -└── build_info.json +└── shard_*.h5 ``` - If the spec contains `bvh` or `npz` sources, the builder retains/generates `clips/` -- If the spec is all `pkl` or `seed_csv` sources, the builder takes a batch path producing split-level shards directly -- Training loads only a subset cache from the HDF5 split, stages the next cache, and swaps caches at the PPO rollout barrier. +- If the spec is all `pkl` or `seed_csv` sources, the builder takes a batch path producing shards directly +- Training recursively discovers `*.h5` shards below the specified root, so datasets can be merged by placing multiple shard directories under one parent +- Training loads only a subset cache from the discovered shards, derives FK/velocities online, stages the next cache, and swaps caches at the PPO rollout barrier. ## YAML Spec Format @@ -83,8 +76,6 @@ Example (`train_mimic/configs/datasets/twist2.yaml`): ```yaml name: twist2 target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true ground_align: first_frame_foot @@ -104,8 +95,6 @@ sources: |-------|-------------| | `name` | Dataset name, maps to output directory | | `target_fps` | Target frame rate for resampling | -| `val_percent` | Validation split percentage (hash-based on clip_id) | -| `hash_salt` | Optional split salt | | `preprocess.normalize_root_xy` | Normalize root body first-frame xy to origin | | `preprocess.ground_align` | `none` / `first_frame_foot` | | `preprocess.min_frames` | Minimum clip length | @@ -115,20 +104,19 @@ sources: | `sources[].name` | Source name (used for clips subdirectory) | | `sources[].type` | `bvh` / `pkl` / `npz` / `seed_csv` | | `sources[].input` | Input file or directory | -| `sources[].weight` | Optional sampling weight (default `1.0`) | | `sources[].bvh_format` | Required for BVH: `lafan1` / `hc_mocap` / `nokov` | | `sources[].robot_name` | BVH only, 
default `unitree_g1` | | `sources[].max_frames` | BVH only, `0` = full length | ## Conversion Rules -All sources are converted to standard training shards. Each clip goes through preprocess/filter before writing to shards: +All sources are converted to standard training shards. Each clip goes through preprocessing/filtering before writing to shards: - `bvh -> retarget pkl -> npz clip` - `pkl -> npz clip` (or direct batch shard for pkl-only datasets) - `npz -> validate + copy/reuse` -Each shard contains: `clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`. +Each shard stores minimal motion data: `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, `clip_starts`, `clip_lengths`, `clip_fps`. Joint velocities and body FK/velocities are computed when training loads a cache. ## Common Commands @@ -149,6 +137,9 @@ python train_mimic/scripts/data/build_dataset.py \ # Print build report python train_mimic/scripts/data/build_dataset.py \ --spec train_mimic/configs/datasets/twist2.yaml --json + +# Inspect a dataset root +python train_mimic/scripts/data/inspect_dataset.py data/datasets/twist2 ``` ## Batch Ingest to NPZ Clips diff --git a/docs/docs/reference/training-troubleshooting.md b/docs/docs/reference/training-troubleshooting.md index f2e56a77..e1aff875 100644 --- a/docs/docs/reference/training-troubleshooting.md +++ b/docs/docs/reference/training-troubleshooting.md @@ -45,7 +45,7 @@ If check fails, regenerate data and run a smoke test: ```bash python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets//train + --motion_file data/datasets/ ``` Expected: `Mean episode length` significantly > 1, `error_anchor_pos` starts decreasing. 
@@ -126,7 +126,7 @@ Ensure `num_eval_steps >= video_length`: ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets//val \ + --motion_file data/datasets/ \ --num_envs 1 --num_eval_steps 2000 \ --video --video_length 600 ``` diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index d8e8706e..a210ff4f 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -31,7 +31,7 @@ python -c "import train_mimic.tasks; print('training OK')" python train_mimic/scripts/train.py \ --num_envs 64 \ --max_iterations 100 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed ``` ### Full Training @@ -40,7 +40,7 @@ python train_mimic/scripts/train.py \ python train_mimic/scripts/train.py \ --num_envs 4096 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed ``` ### Multi-GPU @@ -50,7 +50,7 @@ python train_mimic/scripts/train.py \ --gpu_ids 0 1 2 3 \ --num_envs 1024 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed ``` ### Multi-Node Multi-GPU @@ -67,14 +67,14 @@ torchrun \ train_mimic/scripts/train.py \ --num_envs 1024 \ --max_iterations 1000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed ``` **Notes:** - `--num_envs` is per-GPU in multi-GPU mode - `--num_envs` is also per-process in multi-node mode, so total environments scale with `world_size` - Default logger is TensorBoard. 
Use `--logger wandb` or `--logger swanlab` to select W&B or SwanLab; the project name defaults to `experiment_name` -- `--motion_file` accepts only HDF5 shard directories containing `manifest.json` and `shard_*.h5` files +- `--motion_file` accepts a dataset root directory or single `.h5` shard; shard discovery is recursive - `--cache_num_clips` controls the active HDF5 subset size; `--cache_swap_interval_steps` controls how often the next subset is swapped in at a rollout barrier - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` @@ -96,7 +96,7 @@ The exported model is a dual-input ONNX (`obs` + `obs_history`). The inference s ```bash python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val + --motion_file data/datasets/seed ``` ### Benchmark @@ -104,7 +104,7 @@ python train_mimic/scripts/play.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + --motion_file data/datasets/seed \ --num_envs 1 ``` @@ -113,7 +113,7 @@ python train_mimic/scripts/benchmark.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + --motion_file data/datasets/seed \ --num_envs 1 \ --video \ --video_length 600 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md index 8f278db7..61ce714e 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md @@ -29,8 +29,7 @@ python scripts/setup/download_assets.py --only gmr ckpt bvh 
|------|------|------| | `track.onnx` | 4 MB | ONNX 推理模型 | | `track.pt` | 27 MB | PyTorch 检查点(用于恢复训练) | -| `data/datasets/seed/train/manifest.json` + `shard_*.h5` | ~25 GB | 训练数据集 | -| `data/datasets/seed/val/manifest.json` + `shard_*.h5` | ~1.4 GB | 验证数据集 | +| `data/datasets/seed/shard_*.h5` | ~26 GB | 训练数据集 | | `data/sample_bvh/*.bvh` | 5 MB | 示例动捕文件 | | `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR 重定向机器人模型 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index 5e03c8f5..4f48fdbe 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -60,7 +60,7 @@ train_mimic/scripts/data | Actor/Critic | TemporalCNN(2048、1024、512、256、128) | | 训练采样 | 默认 `rewind`;也支持 `uniform`;播放/评估使用 `start` | | 训练 `window_steps` | `[0]` | -| 数据格式 | HDF5 shard 目录(`manifest.json` + `shard_*.h5`) | +| 数据格式 | 可递归发现的最小 HDF5 shard(`shard_*.h5`) | ## 约束 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md index 733afcf6..a0f7643c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md @@ -35,7 +35,7 @@ sidebar_position: 2 | `ckpt` | Teleopit-models | `checkpoints/track.onnx`、`checkpoints/track.pt` | | `gmr` | Teleopit-models | `archives/gmr_assets.tar.gz` | | `bvh` | Teleopit-models | `archives/sample_bvh.tar.gz` | -| `data` | Teleopit-datasets | `data/train/`、`data/val/` | +| `data` | Teleopit-datasets | `data/` | ## 下载 @@ -63,8 +63,7 @@ python scripts/setup/download_assets.py --source huggingface | `checkpoints/track.pt` | `track.pt` | | `archives/gmr_assets.tar.gz` | 
`teleopit/retargeting/gmr/assets/`(自动解压) | | `archives/sample_bvh.tar.gz` | `data/sample_bvh/`(自动解压) | -| `data/train/` | `data/datasets/seed/train/` | -| `data/val/` | `data/datasets/seed/val/` | +| `data/` | `data/datasets/seed/` | ## 上传到 ModelScope diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index 439a9e41..fa92e529 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -10,10 +10,10 @@ sidebar_position: 3 python scripts/setup/download_assets.py --only data ``` -下载后直接传 HDF5 shard 目录用于训练: +下载后直接传数据集根目录用于训练: ```bash -python train_mimic/scripts/train.py --motion_file data/datasets/seed/train +python train_mimic/scripts/train.py --motion_file data/datasets/seed ``` 如需自定义构建,继续阅读下文。 @@ -42,11 +42,11 @@ python train_mimic/scripts/data/build_dataset.py \ --spec data/pico_motion/pico_recorded.yaml --force ``` -构建前至少录制两段 clip,确保 train 和 validation split 都能生成。 +预处理后至少需要保留一段有效 clip。 ## 自定义构建 -数据主线:`typed source YAML -> preprocess/filter -> HDF5 shard-only 训练数据` +数据主线:`typed source YAML -> preprocess/filter -> minimal HDF5 shards` ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -59,19 +59,13 @@ python train_mimic/scripts/data/build_dataset.py \ data/datasets// ├── clips/ # 可选;仅在需要逐 clip 中间产物时存在 │ └── /... 
-├── train/ -│ ├── manifest.json -│ └── shard_*.h5 -├── val/ -│ ├── manifest.json -│ └── shard_*.h5 -├── manifest_resolved.csv -└── build_info.json +└── shard_*.h5 ``` - 若 spec 包含 `bvh` 或 `npz` source,builder 会保留/生成 `clips/` -- 若 spec 全部是 `pkl` 或 `seed_csv` source,直接并行产出 split 级别的 shard,默认不写中间 clip 文件 -- 训练时只从 HDF5 split 加载一个 subset cache,同时预加载下一个 cache,并在 PPO rollout barrier 处切换。 +- 若 spec 全部是 `pkl` 或 `seed_csv` source,builder 会直接并行产出 shard,默认不写中间 clip 文件 +- 训练会递归发现指定根目录下的 `*.h5` shard,因此可以把多个数据集目录放到同一个父目录下完成合并 +- 训练时只从发现的 shard 加载一个 subset cache,在线派生 FK/速度,同时预加载下一个 cache,并在 PPO rollout barrier 处切换。 ## YAML spec @@ -80,8 +74,6 @@ data/datasets// ```yaml name: twist2 target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true ground_align: first_frame_foot @@ -101,8 +93,6 @@ sources: |------|------| | `name` | 数据集名称,对应输出目录 `data/datasets//` | | `target_fps` | 写入 shard 前统一重采样到的目标帧率 | -| `val_percent` | 基于 `clip_id` hash 的验证集比例 | -| `hash_salt` | 可选 split salt | | `preprocess.normalize_root_xy` | 是否把根 body 首帧 xy 平移到原点 | | `preprocess.ground_align` | `none` / `first_frame_foot` | | `preprocess.min_frames` | clip 最短长度约束 | @@ -110,11 +100,20 @@ sources: | `sources[].name` | source 名称;生成 clip 中间产物时也作为 `clips//` 子目录名 | | `sources[].type` | `bvh` / `pkl` / `npz` / `seed_csv` | | `sources[].input` | 原始输入文件或目录 | -| `sources[].weight` | 可选源级别采样权重,默认 `1.0` | | `sources[].bvh_format` | 仅 `bvh` source 必填:`lafan1` / `hc_mocap` / `nokov` | | `sources[].robot_name` | 仅 `bvh` source,默认 `unitree_g1` | | `sources[].max_frames` | 仅 `bvh` source,`0` 表示全长 | +## 转换规则 + +所有 source 都会转换成标准训练 shard。每段 clip 会先经过预处理/过滤,再写入 shard: + +- `bvh -> retarget pkl -> npz clip` +- `pkl -> npz clip`(或在 pkl-only 数据集中直接 batch 写 shard) +- `npz -> validate + copy/reuse` + +每个 shard 只保存最小运动数据:`root_pos`、`root_quat_w`、`joint_pos`、`body_names`、`clip_starts`、`clip_lengths` 和 `clip_fps`。Joint velocity 和 body FK/velocity 会在训练加载 cache 时计算。 + ## 常用命令 ```bash @@ -134,11 +133,14 @@ python 
train_mimic/scripts/data/build_dataset.py \ # 打印 build report python train_mimic/scripts/data/build_dataset.py \ --spec train_mimic/configs/datasets/twist2.yaml --json + +# 查看数据集统计 +python train_mimic/scripts/data/inspect_dataset.py data/datasets/twist2 ``` ## 批量转换为 NPZ clips -只把某批原始数据转成标准 NPZ clip,不做 train/val merge: +只把某批原始数据转成标准 NPZ clip,不合并为 shard: ```bash python train_mimic/scripts/data/ingest_motion.py \ diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md index 84d119c3..935affae 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md @@ -45,7 +45,7 @@ python train_mimic/scripts/data/check_motion_npz_fk.py \ ```bash python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets//train + --motion_file data/datasets/ ``` 预期:`Mean episode length` 明显大于 1,`error_anchor_pos` 开始下降。 @@ -126,7 +126,7 @@ self.sim.nconmax = 150_000 ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets//val \ + --motion_file data/datasets/ \ --num_envs 1 --num_eval_steps 2000 \ --video --video_length 600 ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index 0e241b7c..3c43a5e9 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -31,7 +31,7 @@ python -c "import train_mimic.tasks; print('training OK')" python train_mimic/scripts/train.py \ --num_envs 64 \ --max_iterations 100 \ - --motion_file 
data/datasets/seed/train + --motion_file data/datasets/seed ``` ### 完整训练 @@ -40,7 +40,7 @@ python train_mimic/scripts/train.py \ python train_mimic/scripts/train.py \ --num_envs 4096 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed ``` ### 多卡训练 @@ -50,7 +50,7 @@ python train_mimic/scripts/train.py \ --gpu_ids 0 1 2 3 \ --num_envs 1024 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed ``` ### 多机多卡训练 @@ -67,14 +67,14 @@ torchrun \ train_mimic/scripts/train.py \ --num_envs 1024 \ --max_iterations 1000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed ``` **注意事项:** - 多卡模式下 `--num_envs` 为每张 GPU 的环境数量 - 多机模式下 `--num_envs` 也按每个进程计算,因此总环境数会随 `world_size` 线性增长 - 默认日志工具为 TensorBoard。使用 `--logger wandb` 或 `--logger swanlab` 可选择 W&B 或 SwanLab;项目名默认使用 `experiment_name` -- `--motion_file` 仅接受包含 `manifest.json` 和 `shard_*.h5` 文件的 HDF5 分片目录 +- `--motion_file` 接受数据集根目录或单个 `.h5` shard;shard 会递归发现 - `--cache_num_clips` 控制当前 HDF5 subset cache 大小;`--cache_swap_interval_steps` 控制在 rollout barrier 切换下一个 subset 的频率 - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` @@ -96,7 +96,7 @@ python train_mimic/scripts/save_onnx.py \ ```bash python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val + --motion_file data/datasets/seed ``` ### 定量评估 @@ -104,7 +104,7 @@ python train_mimic/scripts/play.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + --motion_file data/datasets/seed \ --num_envs 1 ``` @@ -113,7 +113,7 @@ python train_mimic/scripts/benchmark.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + 
--motion_file data/datasets/seed \ --num_envs 1 \ --video \ --video_length 600 diff --git a/scripts/review/__init__.py b/scripts/review/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/scripts/review/build_dataset_from_review.py b/scripts/review/build_dataset_from_review.py deleted file mode 100644 index e5cc32c4..00000000 --- a/scripts/review/build_dataset_from_review.py +++ /dev/null @@ -1,225 +0,0 @@ -#!/usr/bin/env python3 -"""Rebuild train/val HDF5 shard directories from a filtered manifest. - -Reads filtered_manifest.csv (output of export_reviewed_manifest.py), -verifies all referenced motion files exist, and rebuilds cleaned train/val HDF5 splits. - -Usage: - python scripts/data/build_dataset_from_review.py \ - --filtered_manifest data/datasets/review/twist2/filtered_manifest.csv \ - --output_dir data/datasets/builds/twist2_cleaned -""" - -from __future__ import annotations - -import argparse -import csv -import sys -from dataclasses import asdict -from pathlib import Path - -import numpy as np - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from train_mimic.data.dataset_lib import ( - merge_clip_dicts_payload, - read_motion_clip, - utc_now_iso, - write_hdf5_manifest, - write_hdf5_motion_shard, - write_json, -) -from train_mimic.data.dataset_builder import DatasetClipRow, write_manifest_resolved - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Build cleaned dataset from filtered manifest" - ) - parser.add_argument( - "--filtered_manifest", type=str, required=True, - help="Path to filtered_manifest.csv", - ) - parser.add_argument( - "--output_dir", type=str, required=True, - help="Output directory, e.g. data/datasets/builds/twist2_cleaned", - ) - parser.add_argument( - "--target_fps", type=int, default=None, - help="Resample all clips to this FPS. 
" - "Required when source clips have mixed FPS values.", - ) - args = parser.parse_args() - - manifest_path = Path(args.filtered_manifest) - if not manifest_path.is_absolute(): - manifest_path = (PROJECT_ROOT / manifest_path).resolve() - - output_dir = Path(args.output_dir) - if not output_dir.is_absolute(): - output_dir = (PROJECT_ROOT / output_dir).resolve() - - if not manifest_path.is_file(): - print(f"ERROR: filtered manifest not found: {manifest_path}", file=sys.stderr) - sys.exit(1) - - # Read filtered manifest - # Columns: clip_id, source, file_rel, num_frames, fps, resolved_split, resolved_npz_path, - # weight, clip_index - rows = [] - with manifest_path.open("r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - has_clip_index = "clip_index" in (reader.fieldnames or []) - for idx, raw in enumerate(reader, start=2): - rows.append({ - "clip_id": raw["clip_id"].strip(), - "source": raw["source"].strip(), - "file_rel": raw["file_rel"].strip(), - "num_frames": int(raw["num_frames"]), - "fps": int(raw["fps"]), - "resolved_split": raw["resolved_split"].strip(), - "resolved_npz_path": raw["resolved_npz_path"].strip(), - "weight": float(raw["weight"]), - "clip_index": int(raw["clip_index"]) if has_clip_index else -1, - "line_no": idx, - }) - - if not rows: - print("ERROR: filtered manifest has no data rows", file=sys.stderr) - sys.exit(1) - - # Verify all referenced motion files exist - missing = [] - for row in rows: - p = Path(row["resolved_npz_path"]) - if not p.is_file(): - # Try resolving from file_rel - alt = Path(row["file_rel"]) - if not alt.is_absolute(): - alt = PROJECT_ROOT / alt - if alt.is_file(): - row["resolved_npz_path"] = str(alt.resolve()) - else: - missing.append(f" line {row['line_no']}: {row['clip_id']} -> {row['resolved_npz_path']}") - - if missing: - print(f"ERROR: {len(missing)} motion files not found:", file=sys.stderr) - for m in missing[:20]: - print(m, file=sys.stderr) - if len(missing) > 20: - print(f" ... 
and {len(missing) - 20} more", file=sys.stderr) - sys.exit(1) - - # Check for mixed FPS - fps_values = sorted(set(r["fps"] for r in rows)) - if len(fps_values) > 1: - print(f"[INFO] Mixed FPS detected: {fps_values}. Resampling all clips to {args.target_fps} FPS.") - if args.target_fps is None: - print( - "ERROR: clips have mixed FPS values but --target_fps is not set.\n" - "Please specify --target_fps (e.g. --target_fps 30).", - file=sys.stderr, - ) - sys.exit(1) - - # Split into train / val - train_rows = [r for r in rows if r["resolved_split"] == "train"] - val_rows = [r for r in rows if r["resolved_split"] == "val"] - - if not train_rows: - print("ERROR: no train clips in filtered manifest", file=sys.stderr) - sys.exit(1) - if not val_rows: - print("WARNING: no val clips in filtered manifest", file=sys.stderr) - - output_dir.mkdir(parents=True, exist_ok=True) - - all_output_rows: list[DatasetClipRow] = [] - - def _merge_split(split_rows: list[dict], split_name: str) -> dict | None: - if not split_rows: - return None - print(f"Merging {len(split_rows)} {split_name} clips...") - split_dir = output_dir / split_name - split_dir.mkdir(parents=True, exist_ok=True) - out = split_dir / "shard_000.h5" - - clip_dicts = [ - read_motion_clip(Path(r["resolved_npz_path"]), int(r["clip_index"])) - for r in split_rows - ] - weights_list = [r["weight"] for r in split_rows] - payload = merge_clip_dicts_payload( - clip_dicts, - target_fps=args.target_fps, - weights=weights_list, - ) - h5_info = write_hdf5_motion_shard(payload, out) - write_hdf5_manifest( - split_dir, - shard_infos=[h5_info], - fps=int(payload["fps"]), - body_names=np.asarray(payload["body_names"]), - ) - - source_lengths = np.asarray(payload["clip_lengths"], dtype=np.int64) - for clip_index, (r, num_frames) in enumerate(zip(split_rows, source_lengths)): - all_output_rows.append(DatasetClipRow( - clip_id=r["clip_id"], - source=r["source"], - file_rel=r["file_rel"], - num_frames=int(num_frames), - 
fps=int(payload["fps"]), - resolved_split=split_name, - resolved_npz_path=str(out), - weight=float(r["weight"]), - clip_index=clip_index, - )) - - total_frames = int(np.asarray(payload["joint_pos"]).shape[0]) - stats = { - "output": str(split_dir), - "shards": 1, - "clips": int(h5_info["clips"]), - "num_clips": int(h5_info["clips"]), - "source_clips": int(h5_info["source_clips"]), - "frames": total_frames, - "fps": int(payload["fps"]), - "duration_s": float(total_frames / max(int(payload["fps"]), 1)), - } - print(f" {split_name}/: {stats['frames']} frames, {stats['duration_s'] / 60:.1f} min") - return stats - - train_stats = _merge_split(train_rows, "train") - val_stats = _merge_split(val_rows, "val") - - resolved_manifest = write_manifest_resolved(all_output_rows, output_dir) - - # Write build info - report = { - "built_at_utc": utc_now_iso(), - "source_manifest": str(manifest_path), - "manifest_resolved": str(resolved_manifest), - "output_dir": str(output_dir), - "target_fps": args.target_fps, - "source_rows": [asdict(row) for row in all_output_rows], - "clip_counts": { - "total": len(rows), - "train": len(train_rows), - "val": len(val_rows), - }, - "splits": { - "train": train_stats, - "val": val_stats, - }, - } - write_json(output_dir / "build_info.json", report) - - print(f"\nCleaned dataset built at: {output_dir}") - print(f" Total clips: {len(rows)} (train={len(train_rows)}, val={len(val_rows)})") - - -if __name__ == "__main__": - main() diff --git a/scripts/review/export_reviewed_manifest.py b/scripts/review/export_reviewed_manifest.py deleted file mode 100644 index 1f4f3c32..00000000 --- a/scripts/review/export_reviewed_manifest.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 -"""Export a filtered manifest from review results. - -Reads review_state.csv, keeps only clips with decision == 'keep', -and outputs filtered_manifest.csv + review_summary.json. 
- -Usage: - python scripts/data/export_reviewed_manifest.py \ - --review data/datasets/review/twist2/review_state.csv \ - --output data/datasets/review/twist2/filtered_manifest.csv -""" - -from __future__ import annotations - -import argparse -import csv -import sys -from pathlib import Path - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from train_mimic.data.dataset_lib import write_json -from train_mimic.data.review_lib import ( - compute_review_stats, - load_review_state, - utc_now_iso, -) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Export filtered manifest from review") - parser.add_argument("--review", type=str, required=True, help="Path to review_state.csv") - parser.add_argument( - "--output", type=str, default=None, - help="Path for filtered_manifest.csv (default: same dir as review)", - ) - parser.add_argument( - "--summary", type=str, default=None, - help="Path for review_summary.json (default: same dir as review)", - ) - parser.add_argument( - "--require_complete", action="store_true", - help="Fail if any clips are unreviewed", - ) - args = parser.parse_args() - - review_path = Path(args.review) - if not review_path.is_absolute(): - review_path = (PROJECT_ROOT / review_path).resolve() - - rows = load_review_state(review_path) - stats = compute_review_stats(rows) - - # Check completeness - unreviewed = stats.total - stats.reviewed - if args.require_complete and unreviewed > 0: - print( - f"ERROR: {unreviewed} clips are unreviewed. 
" - "Complete the review or remove --require_complete.", - file=sys.stderr, - ) - sys.exit(1) - - # Filter to keep only - kept_rows = [r for r in rows if r.decision == "keep"] - - # Output paths - review_dir = review_path.parent - if args.output: - output_path = Path(args.output) - if not output_path.is_absolute(): - output_path = (PROJECT_ROOT / output_path).resolve() - else: - output_path = review_dir / "filtered_manifest.csv" - - if args.summary: - summary_path = Path(args.summary) - if not summary_path.is_absolute(): - summary_path = (PROJECT_ROOT / summary_path).resolve() - else: - summary_path = review_dir / "review_summary.json" - - # Write filtered_manifest.csv in manifest_resolved.csv format - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow([ - "clip_id", "source", "file_rel", "num_frames", "fps", - "resolved_split", "resolved_npz_path", "weight", "clip_index", - ]) - for r in sorted(kept_rows, key=lambda x: x.clip_id): - writer.writerow([ - r.clip_id, r.source, r.file_rel, r.num_frames, r.fps, - r.resolved_split, r.resolved_npz_path, r.weight, r.clip_index, - ]) - - # Write summary JSON - summary = { - "exported_at_utc": utc_now_iso(), - "review_file": str(review_path), - "output_file": str(output_path), - "total_clips": stats.total, - "reviewed_clips": stats.reviewed, - "unreviewed_clips": unreviewed, - "keep_count": stats.keep_count, - "drop_count": stats.drop_count, - "skip_count": stats.skip_count, - "progress_pct": stats.progress_pct, - "kept_duration_s": stats.kept_duration_s, - "kept_duration_min": stats.kept_duration_s / 60.0, - "kept_train_duration_s": stats.kept_train_duration_s, - "kept_val_duration_s": stats.kept_val_duration_s, - "kept_duration_by_source": { - src: {"duration_s": dur, "duration_min": dur / 60.0} - for src, dur in sorted(stats.kept_duration_by_source.items()) - }, - } - write_json(summary_path, summary) - - 
print(f"Exported filtered manifest: {output_path}") - print(f" Kept clips: {len(kept_rows)}") - print(f" Kept duration: {stats.kept_duration_s / 60:.1f} min") - print(f" Train: {stats.kept_train_duration_s / 60:.1f} min") - print(f" Val: {stats.kept_val_duration_s / 60:.1f} min") - print(f"Summary: {summary_path}") - - -if __name__ == "__main__": - main() diff --git a/scripts/review/init_review_manifest.py b/scripts/review/init_review_manifest.py deleted file mode 100644 index 8c81bfe8..00000000 --- a/scripts/review/init_review_manifest.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python3 -"""Initialize a review_state.csv from an existing manifest_resolved.csv. - -Usage: - python scripts/data/init_review_manifest.py \ - --dataset twist2 \ - --manifest data/datasets/builds/twist2/manifest_resolved.csv -""" - -from __future__ import annotations - -import argparse -import csv -import sys -from pathlib import Path - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from train_mimic.data.review_lib import ReviewRow, save_review_state - - -def main() -> None: - parser = argparse.ArgumentParser(description="Initialize review state from manifest") - parser.add_argument("--dataset", type=str, required=True, help="Dataset name, e.g. 
twist2") - parser.add_argument( - "--manifest", type=str, required=True, help="Path to manifest_resolved.csv" - ) - parser.add_argument( - "--output", - type=str, - default=None, - help="Output path (default: data/datasets/review/{dataset}/review_state.csv)", - ) - parser.add_argument("--force", action="store_true", help="Overwrite existing review file") - args = parser.parse_args() - - manifest_path = Path(args.manifest) - if not manifest_path.is_absolute(): - manifest_path = (PROJECT_ROOT / manifest_path).resolve() - if not manifest_path.is_file(): - print(f"ERROR: manifest not found: {manifest_path}", file=sys.stderr) - sys.exit(1) - - if args.output: - output_path = Path(args.output) - if not output_path.is_absolute(): - output_path = (PROJECT_ROOT / output_path).resolve() - else: - output_path = PROJECT_ROOT / "data" / "datasets" / "review" / args.dataset / "review_state.csv" - - if output_path.is_file() and not args.force: - print( - f"ERROR: review file already exists: {output_path}\n" - "Use --force to overwrite.", - file=sys.stderr, - ) - sys.exit(1) - - # Read manifest_resolved.csv - # Columns: clip_id, source, file_rel, num_frames, fps, resolved_split, resolved_npz_path, - # weight, clip_index - rows: list[ReviewRow] = [] - with manifest_path.open("r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - if reader.fieldnames is None: - print(f"ERROR: manifest is empty: {manifest_path}", file=sys.stderr) - sys.exit(1) - - required = ["clip_id", "source", "file_rel", "num_frames", "fps", "resolved_split", "resolved_npz_path", "weight"] - missing = [c for c in required if c not in reader.fieldnames] - if missing: - print(f"ERROR: manifest missing columns: {missing}", file=sys.stderr) - sys.exit(1) - has_clip_index = "clip_index" in reader.fieldnames - - for idx, raw in enumerate(reader, start=2): - clip_id = raw["clip_id"].strip() - source = raw["source"].strip() - file_rel = raw["file_rel"].strip() - resolved_npz_path = 
raw["resolved_npz_path"].strip() - num_frames = int(raw["num_frames"]) - fps = int(raw["fps"]) - resolved_split = raw["resolved_split"].strip() - weight = float(raw["weight"]) - clip_index = int(raw["clip_index"]) if has_clip_index else -1 - duration_s = num_frames / fps if fps > 0 else 0.0 - - rows.append( - ReviewRow( - clip_id=clip_id, - source=source, - file_rel=file_rel, - resolved_npz_path=resolved_npz_path, - resolved_split=resolved_split, - num_frames=num_frames, - fps=fps, - duration_s=duration_s, - weight=weight, - clip_index=clip_index, - ) - ) - - if not rows: - print("ERROR: manifest has no data rows", file=sys.stderr) - sys.exit(1) - - save_review_state(rows, output_path) - - # Print summary - total_duration_s = sum(r.duration_s for r in rows) - sources = {} - for r in rows: - sources[r.source] = sources.get(r.source, 0) + 1 - - print(f"Initialized review state: {output_path}") - print(f" Total clips: {len(rows)}") - print(f" Total duration: {total_duration_s / 60:.1f} min ({total_duration_s / 3600:.1f} h)") - print(f" Sources:") - for src, count in sorted(sources.items()): - print(f" {src}: {count} clips") - - -if __name__ == "__main__": - main() diff --git a/scripts/review/review_dataset.py b/scripts/review/review_dataset.py deleted file mode 100644 index 4966db94..00000000 --- a/scripts/review/review_dataset.py +++ /dev/null @@ -1,720 +0,0 @@ -#!/usr/bin/env python3 -"""Web-based dataset clip review tool using viser. - -Plays reference motion clips one-by-one in a browser and provides -GUI controls for annotation (keep/drop/skip, difficulty, notes). 
- -Usage: - python scripts/review/review_dataset.py \ - --dataset lafan1 \ - --review data/datasets/review/lafan1/review_state.csv -""" - -from __future__ import annotations - -import argparse -import sys -import time -from pathlib import Path -from threading import Lock - - -import mujoco -import numpy as np -import viser - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from mjlab.viewer.viser import ViserMujocoScene - -from teleopit.runtime.assets import UNITREE_G1_MJLAB_XML, missing_gmr_assets_message -from train_mimic.data.dataset_lib import read_motion_clip -from train_mimic.data.review_lib import ( - ReviewRow, - ReviewStats, - compute_review_stats, - load_review_state, - save_review_state, - utc_now_iso, -) - -DEFAULT_XML = UNITREE_G1_MJLAB_XML - - -# --------------------------------------------------------------------------- -# ClipPlayer: loads motion clips and drives MuJoCo qpos per frame -# --------------------------------------------------------------------------- - -class ClipPlayer: - """Loads a single motion clip and sets MuJoCo qpos frame-by-frame.""" - - def __init__(self, mj_model: mujoco.MjModel) -> None: - self.model = mj_model - self.data = mujoco.MjData(mj_model) - self._joint_pos: np.ndarray | None = None # (T, 29) - self._pelvis_pos: np.ndarray | None = None # (T, 3) - self._pelvis_quat: np.ndarray | None = None # (T, 4) wxyz - self._fps: int = 30 - self._num_frames: int = 0 - - def load_clip(self, motion_path: Path, clip_index: int = -1) -> None: - """Load one source clip from an HDF5 shard. - - Args: - motion_path: Path to an HDF5 shard. - clip_index: Source-clip index for HDF5 rows. 
- """ - d = read_motion_clip(motion_path, clip_index) - self._joint_pos = np.asarray(d["joint_pos"]) - body_pos_w = np.asarray(d["body_pos_w"]) - body_quat_w = np.asarray(d["body_quat_w"]) - - self._pelvis_pos = body_pos_w[:, 0, :] # pelvis = body 0 - self._pelvis_quat = body_quat_w[:, 0, :] - self._fps = int(d["fps"]) - self._num_frames = self._joint_pos.shape[0] - - def set_frame(self, frame_idx: int) -> None: - """Set qpos from frame data and run mj_forward.""" - if self._joint_pos is None: - return - idx = max(0, min(frame_idx, self._num_frames - 1)) - self.data.qpos[:3] = self._pelvis_pos[idx] - self.data.qpos[3:7] = self._pelvis_quat[idx] - self.data.qpos[7:] = self._joint_pos[idx] - mujoco.mj_forward(self.model, self.data) - - @property - def fps(self) -> int: - return self._fps - - @property - def num_frames(self) -> int: - return self._num_frames - - @property - def duration_s(self) -> float: - return self._num_frames / self._fps if self._fps > 0 else 0.0 - - -# --------------------------------------------------------------------------- -# ReviewSession: manages review state, navigation, persistence -# --------------------------------------------------------------------------- - -class ReviewSession: - """Manages review state, navigation order, and persistence.""" - - def __init__( - self, - review_path: Path, - sort_mode: str = "unreviewed_first", - ) -> None: - self._review_path = review_path - self._rows: list[ReviewRow] = load_review_state(review_path) - self._order: list[int] = [] # indices into _rows - self._cursor: int = 0 - self._lock = Lock() - self._reorder(sort_mode) - - @property - def total(self) -> int: - return len(self._rows) - - @property - def cursor_display(self) -> int: - """1-based position in the current ordering.""" - return self._cursor + 1 - - def current_row(self) -> ReviewRow | None: - if not self._order: - return None - return self._rows[self._order[self._cursor]] - - def go_next(self) -> ReviewRow | None: - if self._cursor < 
len(self._order) - 1: - self._cursor += 1 - return self.current_row() - - def go_prev(self) -> ReviewRow | None: - if self._cursor > 0: - self._cursor -= 1 - return self.current_row() - - def go_next_unreviewed(self) -> ReviewRow | None: - """Jump to the next unreviewed clip after the current cursor.""" - for i in range(self._cursor + 1, len(self._order)): - if self._rows[self._order[i]].decision == "": - self._cursor = i - return self.current_row() - # Wrap around from beginning - for i in range(0, self._cursor): - if self._rows[self._order[i]].decision == "": - self._cursor = i - return self.current_row() - return self.current_row() - - def jump_to(self, position: int) -> ReviewRow | None: - """Jump to a 1-based position in the ordering.""" - idx = max(0, min(position - 1, len(self._order) - 1)) - self._cursor = idx - return self.current_row() - - def annotate( - self, - decision: str, - difficulty: str = "", - issue_tags: str = "", - note: str = "", - ) -> None: - """Set annotation on current row and save to disk.""" - with self._lock: - row = self.current_row() - if row is None: - return - row.decision = decision - row.difficulty = difficulty - row.issue_tags = issue_tags - row.note = note - row.reviewed_at = utc_now_iso() - self.save() - - def save(self) -> None: - save_review_state(self._rows, self._review_path) - - def stats(self) -> ReviewStats: - return compute_review_stats(self._rows) - - def _reorder(self, sort_mode: str) -> None: - indices = list(range(len(self._rows))) - - if sort_mode == "unreviewed_first": - indices.sort(key=lambda i: ( - 0 if self._rows[i].decision == "" else 1, - self._rows[i].source, - self._rows[i].clip_id, - )) - elif sort_mode == "source": - indices.sort(key=lambda i: (self._rows[i].source, self._rows[i].clip_id)) - elif sort_mode == "duration_desc": - indices.sort(key=lambda i: -self._rows[i].duration_s) - else: - indices.sort(key=lambda i: self._rows[i].clip_id) - - self._order = indices - self._cursor = 0 - - -# 
--------------------------------------------------------------------------- -# ReviewViewerApp: viser server + GUI + main loop -# --------------------------------------------------------------------------- - -class ReviewViewerApp: - """Main application: ties viser, ClipPlayer, and ReviewSession together.""" - - def __init__( - self, - review_path: Path, - xml_path: Path, - project_root: Path, - *, - port: int = 8012, - sort_mode: str = "unreviewed_first", - ) -> None: - self._project_root = project_root - self._model = mujoco.MjModel.from_xml_path(str(xml_path)) - self._player = ClipPlayer(self._model) - self._session = ReviewSession(review_path, sort_mode) - - self._server = viser.ViserServer(port=port, label="Clip Review") - self._scene = ViserMujocoScene.create( - server=self._server, mj_model=self._model, num_envs=1, - ) - - # Pre-warm viser's cached_property type hint resolution at a shallow - # call stack. Without this, the first update_from_mjdata triggered - # from a deep callback chain causes RecursionError in Python 3.10's - # get_type_hints / _eval_type. - mujoco.mj_forward(self._model, self._player.data) - self._scene.update_from_mjdata(self._player.data) - - # Playback state - self._playing: bool = False - self._speed: float = 1.0 - self._current_frame: int = 0 - - # Pending action queue: callbacks only append here, main loop processes. - # This avoids deep recursion inside viser's callback / message queue. - self._pending_actions: list[str] = [] - self._pending_frame_scrub: int | None = None # from slider drag - self._pending_jump: int | None = None # from jump input - self._hotkey_clearing: bool = False # guard against recursive on_update - - def setup_gui(self) -> None: - """Build all viser GUI elements. 
Callbacks only set flags.""" - gui = self._server.gui - - # --- Clip Info --- - with gui.add_folder("Clip Info", order=0): - self._info_html = gui.add_html("Loading...") - - # --- Playback --- - with gui.add_folder("Playback", order=1): - self._play_btn = gui.add_button("Play", color="green") - self._frame_slider = gui.add_slider( - "Frame", min=0, max=1, step=1, initial_value=0, - ) - self._speed_group = gui.add_button_group( - "Speed", options=["0.25x", "0.5x", "1x", "2x"], - ) - self._restart_btn = gui.add_button("Restart") - - @self._play_btn.on_click - def _(_) -> None: - self._pending_actions.append("toggle_play") - - @self._frame_slider.on_update - def _(_) -> None: - self._pending_frame_scrub = int(self._frame_slider.value) - - @self._speed_group.on_click - def _(event) -> None: - speed_map = {"0.25x": 0.25, "0.5x": 0.5, "1x": 1.0, "2x": 2.0} - self._speed = speed_map.get(event.target.value, 1.0) - - @self._restart_btn.on_click - def _(_) -> None: - self._pending_actions.append("restart") - - # --- Annotation --- - with gui.add_folder("Annotation", order=2): - self._decision_dropdown = gui.add_dropdown( - "Decision", - options=["", "Keep", "Drop", "Skip"], - initial_value="", - hint="Must select Keep/Drop/Skip before saving", - ) - self._difficulty_dropdown = gui.add_dropdown( - "Difficulty", - options=["", "easy", "medium", "hard", "bad_data"], - initial_value="", - ) - self._tags_input = gui.add_text("Issue Tags", initial_value="") - self._note_input = gui.add_text("Note", initial_value="") - self._save_next_btn = gui.add_button("Save & Next", color="blue") - self._save_btn = gui.add_button("Save") - - @self._save_next_btn.on_click - def _(_) -> None: - self._pending_actions.append("save_next") - - @self._save_btn.on_click - def _(_) -> None: - self._pending_actions.append("save") - - # --- Navigation --- - with gui.add_folder("Navigation", order=3): - self._prev_btn = gui.add_button("Prev") - self._next_btn = gui.add_button("Next") - 
self._next_unreviewed_btn = gui.add_button("Next Unreviewed", color="orange") - self._jump_input = gui.add_number( - "Jump to #", initial_value=1, min=1, max=self._session.total, step=1, - ) - - @self._prev_btn.on_click - def _(_) -> None: - self._pending_actions.append("prev") - - @self._next_btn.on_click - def _(_) -> None: - self._pending_actions.append("next") - - @self._next_unreviewed_btn.on_click - def _(_) -> None: - self._pending_actions.append("next_unreviewed") - - @self._jump_input.on_update - def _(_) -> None: - self._pending_jump = int(self._jump_input.value) - - # --- Stats --- - with gui.add_folder("Stats", order=4, expand_by_default=False): - self._stats_html = gui.add_html("Loading...") - - # --- Keyboard Shortcuts --- - with gui.add_folder("Shortcuts", order=5): - self._hotkey_input = gui.add_text( - "Hotkey (click here, type key)", - initial_value="", - ) - gui.add_markdown( - "**K**=Keep+Next **D**=Drop+Next **S**=Skip+Next\n\n" - "**N**=Next **P**=Prev **U**=Next Unreviewed\n\n" - "**Space**=Play/Pause **R**=Restart **F**=Speed Up\n\n" - "**1**=Easy **2**=Medium **3**=Hard" - ) - - @self._hotkey_input.on_update - def _(_) -> None: - if self._hotkey_clearing: - return - raw = self._hotkey_input.value - self._hotkey_clearing = True - self._hotkey_input.value = "" - self._hotkey_clearing = False - if not raw: - return - ch = raw[-1].lower() - key_map = { - "k": "hotkey_keep", - "d": "hotkey_drop", - "s": "hotkey_skip", - "n": "next", - "p": "prev", - "u": "next_unreviewed", - " ": "toggle_play", - "r": "restart", - "f": "speed_up", - "1": "set_easy", - "2": "set_medium", - "3": "set_hard", - } - action = key_map.get(ch) - if action: - self._pending_actions.append(action) - - # Visualization options - self._scene.create_visualization_gui(show_debug_viz_control=False) - - # ------------------------------------------------------------------ - # Actions executed from the main loop (shallow call stack) - # 
------------------------------------------------------------------ - - def _do_save(self) -> bool: - """Save the current annotation. Returns False if no decision selected.""" - decision = self._decision_dropdown.value.lower() if self._decision_dropdown.value else "" - if decision not in ("keep", "drop", "skip"): - print("[REVIEW] WARNING: no decision selected, save skipped") - return False - self._session.annotate( - decision=decision, - difficulty=self._difficulty_dropdown.value, - issue_tags=self._tags_input.value, - note=self._note_input.value, - ) - self._update_stats_display() - - # Terminal summary - s = self._session.stats() - row = self._session.current_row() - clip_id = row.clip_id if row else "?" - print( - f"[REVIEW] {clip_id} -> {decision} | " - f"{s.reviewed}/{s.total} ({s.progress_pct:.1f}%) | " - f"keep={s.keep_count} drop={s.drop_count} skip={s.skip_count} | " - f"kept_dur={s.kept_duration_s / 60:.1f}min" - ) - return True - - def _load_current_clip(self) -> None: - """Load the clip for the current session cursor.""" - row = self._session.current_row() - if row is None: - self._info_html.content = "No clips to review" - return - - # CSV keeps the historical column name, but current rows point at HDF5 shards. - motion_path = Path(row.resolved_npz_path) - if not motion_path.is_absolute(): - motion_path = self._project_root / motion_path - - try: - self._player.load_clip(motion_path, clip_index=row.clip_index) - except Exception as exc: - self._info_html.content = f"Error loading clip:
{exc}" - return - - self._current_frame = 0 - self._playing = False - - # Update scene first (before touching GUI widgets) - self._player.set_frame(0) - self._scene.update_from_mjdata(self._player.data) - - # Now update GUI widgets — disable slider callback processing - # by setting _pending_frame_scrub to None after we're done. - self._play_btn.label = "Play" - self._play_btn.color = "green" - - max_frame = max(0, self._player.num_frames - 1) - self._frame_slider.max = max_frame - self._frame_slider.value = 0 - - self._decision_dropdown.value = row.decision.capitalize() if row.decision else "" - self._difficulty_dropdown.value = row.difficulty - self._tags_input.value = row.issue_tags - self._note_input.value = row.note - self._jump_input.value = self._session.cursor_display - - # Drain any spurious pending events triggered by the GUI updates above - self._pending_frame_scrub = None - self._pending_jump = None - self._pending_actions.clear() - - self._update_info_display() - self._update_stats_display() - - def _update_frame(self) -> None: - """Set frame on player and update the 3D scene.""" - self._player.set_frame(self._current_frame) - self._scene.update_from_mjdata(self._player.data) - - def _update_info_display(self) -> None: - """Update the Clip Info HTML panel.""" - row = self._session.current_row() - if row is None: - return - - status_color = { - "keep": "green", "drop": "red", "skip": "orange", "": "gray", - } - status_label = row.decision if row.decision else "unreviewed" - color = status_color.get(row.decision, "gray") - - kept_min = self._session.stats().kept_duration_s / 60.0 - - self._info_html.content = f""" -
- #{self._session.cursor_display}/{self._session.total} - [{status_label}] - Kept: {kept_min:.1f} min
- clip_id: {row.clip_id}
- source: {row.source}
- split: {row.resolved_split}
- frames: {row.num_frames} | fps: {row.fps} - | duration: {row.duration_s:.2f}s -
- """ - - def _update_stats_display(self) -> None: - """Update the Stats HTML panel.""" - s = self._session.stats() - kept_min = s.kept_duration_s / 60.0 - kept_h = s.kept_duration_s / 3600.0 - by_source_lines = "".join( - f"{src}: {dur / 60:.1f} min
" - for src, dur in sorted(s.kept_duration_by_source.items()) - ) - self._stats_html.content = f""" -
-

- Kept: {kept_min:.1f} min ({kept_h:.2f} h) -

- Progress: {s.reviewed} / {s.total} ({s.progress_pct:.1f}%)
- Keep: {s.keep_count} | - Drop: {s.drop_count} | - Skip: {s.skip_count}
-

Duration Breakdown

- Train: {s.kept_train_duration_s / 60:.1f} min
- Val: {s.kept_val_duration_s / 60:.1f} min
-

By Source

- {by_source_lines if by_source_lines else "none yet"} -
- """ - - # ------------------------------------------------------------------ - # Main loop — processes pending actions from callbacks - # ------------------------------------------------------------------ - - def _process_pending(self) -> None: - """Process all pending actions from GUI callbacks.""" - # Handle frame scrub from slider (take latest value only) - scrub = self._pending_frame_scrub - if scrub is not None and not self._playing: - self._pending_frame_scrub = None - self._current_frame = scrub - self._update_frame() - - # Handle jump input - jump = self._pending_jump - if jump is not None: - self._pending_jump = None - self._session.jump_to(jump) - self._load_current_clip() - return # load_current_clip clears remaining actions - - # Handle button actions (process one per tick to keep things responsive) - while self._pending_actions: - action = self._pending_actions.pop(0) - - if action == "toggle_play": - self._playing = not self._playing - self._play_btn.label = "Pause" if self._playing else "Play" - self._play_btn.color = "red" if self._playing else "green" - - elif action == "restart": - self._current_frame = 0 - self._update_frame() - self._frame_slider.value = 0 - self._pending_frame_scrub = None # drain spurious slider event - - elif action == "save": - self._do_save() - - elif action == "save_next": - if self._do_save(): - self._session.go_next_unreviewed() - self._load_current_clip() - return # load clears actions - - elif action == "prev": - self._session.go_prev() - self._load_current_clip() - return - - elif action == "next": - self._session.go_next() - self._load_current_clip() - return - - elif action == "next_unreviewed": - self._session.go_next_unreviewed() - self._load_current_clip() - return - - elif action == "hotkey_keep": - self._decision_dropdown.value = "Keep" - self._pending_actions.append("save_next") - - elif action == "hotkey_drop": - self._decision_dropdown.value = "Drop" - self._pending_actions.append("save_next") - - elif 
action == "hotkey_skip": - self._decision_dropdown.value = "Skip" - self._pending_actions.append("save_next") - - elif action == "speed_up": - _speed_levels = [1.0, 2.0, 4.0, 8.0] - try: - idx = _speed_levels.index(self._speed) - self._speed = _speed_levels[(idx + 1) % len(_speed_levels)] - except ValueError: - self._speed = 1.0 - # Update button group to reflect current speed - label = f"{self._speed:g}x" - if label in ("0.25x", "0.5x", "1x", "2x"): - self._speed_group.value = label - - elif action == "set_easy": - self._difficulty_dropdown.value = "easy" - - elif action == "set_medium": - self._difficulty_dropdown.value = "medium" - - elif action == "set_hard": - self._difficulty_dropdown.value = "hard" - - def run(self) -> None: - """Main loop: process callbacks and handle playback timing.""" - self.setup_gui() - self._load_current_clip() - - print(f"\nReview viewer ready at http://localhost:{self._server.get_port()}") - print("Press Ctrl+C to exit.\n") - - last_frame_time = time.time() - try: - while True: - now = time.time() - - # Check if scene visualization settings changed - if self._scene.needs_update: - self._scene.refresh_visualization() - - # Process GUI actions from callbacks - self._process_pending() - - # Advance playback - if self._playing and self._player.num_frames > 0: - dt = now - last_frame_time - frames_to_advance = dt * self._player.fps * self._speed - if frames_to_advance >= 1.0: - self._current_frame += int(frames_to_advance) - if self._current_frame >= self._player.num_frames: - self._current_frame = 0 - self._playing = False - self._play_btn.label = "Play" - self._play_btn.color = "green" - self._update_frame() - self._frame_slider.value = self._current_frame - self._pending_frame_scrub = None # drain spurious event - last_frame_time = now - else: - last_frame_time = now - - time.sleep(1.0 / 60.0) - - except KeyboardInterrupt: - print("\nShutting down...") - self._server.stop() - - -# 
--------------------------------------------------------------------------- -# CLI entry point -# --------------------------------------------------------------------------- - -def main() -> None: - parser = argparse.ArgumentParser(description="Web-based dataset clip review tool") - parser.add_argument("--dataset", type=str, required=True, help="Dataset name") - parser.add_argument( - "--review", type=str, default=None, - help="Path to review_state.csv (default: data/datasets/review/{dataset}/review_state.csv)", - ) - parser.add_argument( - "--xml", - type=str, - default=None, - help="Robot XML path (default: teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml)", - ) - parser.add_argument("--port", type=int, default=8012, help="Viser server port") - parser.add_argument( - "--sort", type=str, default="unreviewed_first", - choices=["unreviewed_first", "source", "duration_desc"], - help="Initial clip sort order", - ) - args = parser.parse_args() - - if args.review: - review_path = Path(args.review) - if not review_path.is_absolute(): - review_path = (PROJECT_ROOT / review_path).resolve() - else: - review_path = PROJECT_ROOT / "data" / "datasets" / "review" / args.dataset / "review_state.csv" - - if not review_path.is_file(): - print(f"ERROR: review state not found: {review_path}", file=sys.stderr) - print("Run init_review_manifest.py first.", file=sys.stderr) - sys.exit(1) - - xml_path = Path(args.xml) if args.xml else DEFAULT_XML - if not xml_path.is_file(): - print( - "ERROR: " - + missing_gmr_assets_message(xml_path, label="Robot XML"), - file=sys.stderr, - ) - sys.exit(1) - - app = ReviewViewerApp( - review_path=review_path, - xml_path=xml_path, - project_root=PROJECT_ROOT, - port=args.port, - sort_mode=args.sort, - ) - app.run() - - -if __name__ == "__main__": - main() diff --git a/scripts/view/view_dataset.py b/scripts/view/view_dataset.py new file mode 100644 index 00000000..44fee0fb --- /dev/null +++ b/scripts/view/view_dataset.py @@ -0,0 +1,451 @@ 
+#!/usr/bin/env python3 +"""Read-only web viewer for Teleopit HDF5 motion datasets.""" + +from __future__ import annotations + +import argparse +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +import h5py +import mujoco +import numpy as np +import viser + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(PROJECT_ROOT)) + +from mjlab.viewer.viser import ViserMujocoScene + +from teleopit.runtime.assets import UNITREE_G1_MJLAB_XML, missing_gmr_assets_message +from train_mimic.data.dataset_lib import find_motion_shards, read_motion_clip + +DEFAULT_XML = UNITREE_G1_MJLAB_XML + + +@dataclass(frozen=True) +class DatasetClip: + clip_id: str + shard_path: Path + clip_index: int + num_frames: int + fps: int + + @property + def duration_s(self) -> float: + return self.num_frames / self.fps if self.fps > 0 else 0.0 + + +def discover_dataset_clips(dataset_path: Path) -> list[DatasetClip]: + clips: list[DatasetClip] = [] + for shard_path in find_motion_shards(dataset_path): + with h5py.File(shard_path, "r") as h5: + required = ["source_clip_lengths", "source_clip_fps"] + missing = [key for key in required if key not in h5] + if missing: + raise ValueError( + f"HDF5 shard {shard_path} missing source clip metadata {missing}. " + "Rebuild the dataset with the current HDF5 writer." 
+ ) + lengths = np.asarray(h5["source_clip_lengths"], dtype=np.int64) + fps = np.asarray(h5["source_clip_fps"], dtype=np.int64) + + if dataset_path.is_file(): + shard_rel = Path(shard_path.name) + else: + try: + shard_rel = shard_path.relative_to(dataset_path) + except ValueError: + shard_rel = Path(shard_path.name) + for clip_index, (num_frames, clip_fps) in enumerate(zip(lengths, fps)): + clips.append( + DatasetClip( + clip_id=f"{shard_rel.as_posix()}#{clip_index}", + shard_path=shard_path, + clip_index=int(clip_index), + num_frames=int(num_frames), + fps=int(clip_fps), + ) + ) + if not clips: + raise ValueError(f"dataset has no source clips: {dataset_path}") + return clips + + +class ClipPlayer: + """Loads one source clip and drives MuJoCo qpos frame by frame.""" + + def __init__(self, mj_model: mujoco.MjModel) -> None: + self.model = mj_model + self.data = mujoco.MjData(mj_model) + self._joint_pos: np.ndarray | None = None + self._pelvis_pos: np.ndarray | None = None + self._pelvis_quat: np.ndarray | None = None + self._fps: int = 30 + self._num_frames: int = 0 + + def load_clip(self, clip: DatasetClip) -> None: + d = read_motion_clip(clip.shard_path, clip.clip_index) + self._joint_pos = np.asarray(d["joint_pos"]) + body_pos_w = np.asarray(d["body_pos_w"]) + body_quat_w = np.asarray(d["body_quat_w"]) + self._pelvis_pos = body_pos_w[:, 0, :] + self._pelvis_quat = body_quat_w[:, 0, :] + self._fps = int(d["fps"]) + self._num_frames = int(self._joint_pos.shape[0]) + + def set_frame(self, frame_idx: int) -> None: + if self._joint_pos is None or self._pelvis_pos is None or self._pelvis_quat is None: + return + idx = max(0, min(frame_idx, self._num_frames - 1)) + self.data.qpos[:3] = self._pelvis_pos[idx] + self.data.qpos[3:7] = self._pelvis_quat[idx] + self.data.qpos[7:] = self._joint_pos[idx] + mujoco.mj_forward(self.model, self.data) + + @property + def fps(self) -> int: + return self._fps + + @property + def num_frames(self) -> int: + return self._num_frames + + 
+class DatasetSession: + def __init__(self, clips: list[DatasetClip], sort_mode: str) -> None: + self._clips = clips + self._order = list(range(len(clips))) + if sort_mode == "duration_desc": + self._order.sort(key=lambda i: -clips[i].duration_s) + elif sort_mode == "shard": + self._order.sort(key=lambda i: clips[i].clip_id) + self._cursor = 0 + + @property + def total(self) -> int: + return len(self._clips) + + @property + def cursor_display(self) -> int: + return self._cursor + 1 + + @property + def total_duration_s(self) -> float: + return sum(clip.duration_s for clip in self._clips) + + def current_clip(self) -> DatasetClip: + return self._clips[self._order[self._cursor]] + + def go_next(self) -> None: + if self._cursor < len(self._order) - 1: + self._cursor += 1 + + def go_prev(self) -> None: + if self._cursor > 0: + self._cursor -= 1 + + def jump_to(self, position: int) -> None: + self._cursor = max(0, min(position - 1, len(self._order) - 1)) + + +class DatasetViewerApp: + def __init__( + self, + *, + dataset_path: Path, + xml_path: Path, + port: int, + sort_mode: str, + ) -> None: + self._dataset_path = dataset_path + self._model = mujoco.MjModel.from_xml_path(str(xml_path)) + self._player = ClipPlayer(self._model) + self._session = DatasetSession(discover_dataset_clips(dataset_path), sort_mode) + + self._server = viser.ViserServer(port=port, label="Dataset Viewer") + self._scene = ViserMujocoScene.create( + server=self._server, + mj_model=self._model, + num_envs=1, + ) + mujoco.mj_forward(self._model, self._player.data) + self._scene.update_from_mjdata(self._player.data) + + self._playing = False + self._speed = 1.0 + self._current_frame = 0 + self._pending_actions: list[str] = [] + self._pending_frame_scrub: int | None = None + self._pending_jump: int | None = None + self._hotkey_clearing = False + + def setup_gui(self) -> None: + gui = self._server.gui + + with gui.add_folder("Clip", order=0): + self._info_html = gui.add_html("Loading...") + + with 
gui.add_folder("Playback", order=1): + self._play_btn = gui.add_button("Play", color="green") + self._frame_slider = gui.add_slider("Frame", min=0, max=1, step=1, initial_value=0) + self._speed_group = gui.add_button_group( + "Speed", + options=["0.25x", "0.5x", "1x", "2x"], + ) + self._restart_btn = gui.add_button("Restart") + + @self._play_btn.on_click + def _(_) -> None: + self._pending_actions.append("toggle_play") + + @self._frame_slider.on_update + def _(_) -> None: + self._pending_frame_scrub = int(self._frame_slider.value) + + @self._speed_group.on_click + def _(event) -> None: + self._speed = {"0.25x": 0.25, "0.5x": 0.5, "1x": 1.0, "2x": 2.0}.get( + event.target.value, + 1.0, + ) + + @self._restart_btn.on_click + def _(_) -> None: + self._pending_actions.append("restart") + + with gui.add_folder("Navigation", order=2): + self._prev_btn = gui.add_button("Prev") + self._next_btn = gui.add_button("Next") + self._jump_input = gui.add_number( + "Jump to #", + initial_value=1, + min=1, + max=self._session.total, + step=1, + ) + + @self._prev_btn.on_click + def _(_) -> None: + self._pending_actions.append("prev") + + @self._next_btn.on_click + def _(_) -> None: + self._pending_actions.append("next") + + @self._jump_input.on_update + def _(_) -> None: + self._pending_jump = int(self._jump_input.value) + + with gui.add_folder("Dataset", order=3, expand_by_default=False): + self._stats_html = gui.add_html("Loading...") + + with gui.add_folder("Shortcuts", order=4): + self._hotkey_input = gui.add_text("Hotkey (click here, type key)", initial_value="") + gui.add_markdown("**N**=Next **P**=Prev **Space**=Play/Pause **R**=Restart **F**=Speed Up") + + @self._hotkey_input.on_update + def _(_) -> None: + if self._hotkey_clearing: + return + raw = self._hotkey_input.value + self._hotkey_clearing = True + self._hotkey_input.value = "" + self._hotkey_clearing = False + if not raw: + return + action = { + "n": "next", + "p": "prev", + " ": "toggle_play", + "r": "restart", + 
"f": "speed_up", + }.get(raw[-1].lower()) + if action: + self._pending_actions.append(action) + + self._scene.create_visualization_gui(show_debug_viz_control=False) + + def _load_current_clip(self) -> None: + clip = self._session.current_clip() + try: + self._player.load_clip(clip) + except Exception as exc: + self._info_html.content = f"Error loading clip:
{exc}" + return + + self._current_frame = 0 + self._playing = False + self._player.set_frame(0) + self._scene.update_from_mjdata(self._player.data) + + self._play_btn.label = "Play" + self._play_btn.color = "green" + self._frame_slider.max = max(0, self._player.num_frames - 1) + self._frame_slider.value = 0 + self._jump_input.value = self._session.cursor_display + + self._pending_frame_scrub = None + self._pending_jump = None + self._pending_actions.clear() + self._update_info_display() + self._update_stats_display() + + def _update_frame(self) -> None: + self._player.set_frame(self._current_frame) + self._scene.update_from_mjdata(self._player.data) + + def _update_info_display(self) -> None: + clip = self._session.current_clip() + self._info_html.content = f""" +
+ #{self._session.cursor_display}/{self._session.total}
+ clip: {clip.clip_id}
+ shard: {clip.shard_path}
+ source clip index: {clip.clip_index}
+ frames: {clip.num_frames} | fps: {clip.fps} + | duration: {clip.duration_s:.2f}s +
+ """ + + def _update_stats_display(self) -> None: + total_min = self._session.total_duration_s / 60.0 + self._stats_html.content = f""" +
+ root: {self._dataset_path}
+ source clips: {self._session.total}
+ duration: {total_min:.1f} min +
+ """ + + def _process_pending(self) -> None: + scrub = self._pending_frame_scrub + if scrub is not None and not self._playing: + self._pending_frame_scrub = None + self._current_frame = scrub + self._update_frame() + + jump = self._pending_jump + if jump is not None: + self._pending_jump = None + self._session.jump_to(jump) + self._load_current_clip() + return + + while self._pending_actions: + action = self._pending_actions.pop(0) + if action == "toggle_play": + self._playing = not self._playing + self._play_btn.label = "Pause" if self._playing else "Play" + self._play_btn.color = "red" if self._playing else "green" + elif action == "restart": + self._current_frame = 0 + self._update_frame() + self._frame_slider.value = 0 + self._pending_frame_scrub = None + elif action == "prev": + self._session.go_prev() + self._load_current_clip() + return + elif action == "next": + self._session.go_next() + self._load_current_clip() + return + elif action == "speed_up": + levels = [1.0, 2.0, 4.0, 8.0] + self._speed = levels[(levels.index(self._speed) + 1) % len(levels)] if self._speed in levels else 1.0 + label = f"{self._speed:g}x" + if label in ("0.25x", "0.5x", "1x", "2x"): + self._speed_group.value = label + + def run(self) -> None: + self.setup_gui() + self._load_current_clip() + + print(f"\nDataset viewer ready at http://localhost:{self._server.get_port()}") + print("Press Ctrl+C to exit.\n") + + last_frame_time = time.time() + try: + while True: + now = time.time() + if self._scene.needs_update: + self._scene.refresh_visualization() + + self._process_pending() + + if self._playing and self._player.num_frames > 0: + dt = now - last_frame_time + frames_to_advance = dt * self._player.fps * self._speed + if frames_to_advance >= 1.0: + self._current_frame += int(frames_to_advance) + if self._current_frame >= self._player.num_frames: + self._current_frame = 0 + self._playing = False + self._play_btn.label = "Play" + self._play_btn.color = "green" + self._update_frame() + 
self._frame_slider.value = self._current_frame + self._pending_frame_scrub = None + last_frame_time = now + else: + last_frame_time = now + + time.sleep(1.0 / 60.0) + except KeyboardInterrupt: + print("\nShutting down...") + self._server.stop() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Read-only web viewer for Teleopit HDF5 datasets") + parser.add_argument( + "--dataset", + type=str, + required=True, + help="Dataset root directory or a single Teleopit .h5 shard", + ) + parser.add_argument( + "--xml", + type=str, + default=None, + help="Robot XML path", + ) + parser.add_argument("--port", type=int, default=8012, help="Viser server port") + parser.add_argument( + "--sort", + type=str, + default="shard", + choices=["shard", "duration_desc"], + help="Initial clip sort order", + ) + args = parser.parse_args() + + dataset_path = Path(args.dataset) + if not dataset_path.is_absolute(): + dataset_path = (PROJECT_ROOT / dataset_path).resolve() + if not dataset_path.exists(): + print(f"ERROR: dataset path not found: {dataset_path}", file=sys.stderr) + sys.exit(1) + + xml_path = Path(args.xml) if args.xml else DEFAULT_XML + if not xml_path.is_file(): + print( + "ERROR: " + missing_gmr_assets_message(xml_path, label="Robot XML"), + file=sys.stderr, + ) + sys.exit(1) + + app = DatasetViewerApp( + dataset_path=dataset_path, + xml_path=xml_path, + port=args.port, + sort_mode=args.sort, + ) + app.run() + + +if __name__ == "__main__": + main() diff --git a/teleopit/recording/pico_motion.py b/teleopit/recording/pico_motion.py index 90968a71..77d9509e 100644 --- a/teleopit/recording/pico_motion.py +++ b/teleopit/recording/pico_motion.py @@ -15,7 +15,7 @@ from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM from teleopit.runtime.assets import PROJECT_ROOT from train_mimic.data.dataset_lib import inspect_clip_dict -from train_mimic.data.motion_fk import MotionFkExtractor, compute_body_velocities +from train_mimic.data.motion_fk import 
MotionFkExtractor, compute_body_velocities, finite_diff_velocity from train_mimic.scripts.convert_pkl_to_npz import _MJLAB_G1_BODY_NAMES @@ -25,7 +25,6 @@ class PicoDatasetSpec: dataset_name: str = "pico_recorded" target_fps: int = 30 - val_percent: int = 5 source_name: str = "pico_clips" @@ -150,8 +149,6 @@ def ensure_pico_dataset_spec( content = ( f"name: {spec.dataset_name}\n" f"target_fps: {int(spec.target_fps)}\n" - f"val_percent: {int(spec.val_percent)}\n" - 'hash_salt: ""\n' "preprocess:\n" " normalize_root_xy: true\n" " ground_align: first_frame_foot\n" @@ -190,7 +187,7 @@ def qpos_sequence_to_motion_clip( raise ValueError(f"joint_pos must have {NUM_JOINTS} columns, got {joint_pos.shape}") dt = 1.0 / float(fps) - joint_vel = np.gradient(joint_pos, dt, axis=0).astype(np.float32) + joint_vel = finite_diff_velocity(joint_pos, dt) fk_extractor = extractor or MotionFkExtractor() body_pos_w, body_quat_w = fk_extractor.extract(root_pos, root_quat_wxyz, joint_pos, names) @@ -198,6 +195,8 @@ def qpos_sequence_to_motion_clip( clip = { "fps": int(fps), + "root_pos": root_pos.astype(np.float32, copy=False), + "root_quat_w": root_quat_wxyz.astype(np.float32, copy=False), "joint_pos": joint_pos.astype(np.float32, copy=False), "joint_vel": joint_vel.astype(np.float32, copy=False), "body_pos_w": np.asarray(body_pos_w, dtype=np.float32), diff --git a/teleopit/runtime/external_assets.py b/teleopit/runtime/external_assets.py index 0513f137..3bd8ed90 100644 --- a/teleopit/runtime/external_assets.py +++ b/teleopit/runtime/external_assets.py @@ -40,7 +40,6 @@ class AssetEntry: ), ], "data": [ - AssetEntry("data/train", "data/datasets/seed/train", repo="dataset"), - AssetEntry("data/val", "data/datasets/seed/val", repo="dataset"), + AssetEntry("data", "data/datasets/seed", repo="dataset"), ], } diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 207d2fe8..b7905098 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -15,7 +15,6 @@ 
SourceInputFile, DatasetSourceSpec, DatasetSpec, - assign_splits, build_dataset_from_spec, convert_source_to_npz_clips, load_dataset_spec, @@ -99,13 +98,10 @@ def test_load_dataset_spec_parses_typed_sources(tmp_path: Path) -> None: spec_path.write_text( f"""name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" sources: - name: clips type: npz input: {tmp_path / 'npz_source'} - weight: 2.5 - name: lafan1 type: bvh input: {tmp_path / 'lafan1'} @@ -119,7 +115,6 @@ def test_load_dataset_spec_parses_typed_sources(tmp_path: Path) -> None: assert spec.name == "demo" assert spec.target_fps == 30 assert spec.sources[0].type == "npz" - assert spec.sources[0].weight == 2.5 assert spec.sources[1].type == "bvh" assert spec.sources[1].bvh_format == "lafan1" @@ -129,8 +124,6 @@ def test_load_dataset_spec_parses_preprocess(tmp_path: Path) -> None: spec_path.write_text( f"""name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true ground_align: none @@ -162,8 +155,6 @@ def test_load_dataset_spec_parses_seed_filter_preset(tmp_path: Path) -> None: spec_path.write_text( f"""name: seed_demo target_fps: 30 -val_percent: 5 -hash_salt: "" sources: - name: seed type: seed_csv @@ -187,8 +178,6 @@ def test_load_dataset_spec_rejects_seed_filter_preset_on_non_seed_source(tmp_pat spec_path.write_text( f"""name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" sources: - name: clips type: npz @@ -210,8 +199,6 @@ def test_load_dataset_spec_rejects_unknown_seed_filter_preset(tmp_path: Path) -> spec_path.write_text( f"""name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" sources: - name: seed type: seed_csv @@ -231,8 +218,6 @@ def test_load_dataset_spec_rejects_bvh_without_format(tmp_path: Path) -> None: spec_path.write_text( """name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" sources: - name: broken type: bvh @@ -245,22 +230,6 @@ def test_load_dataset_spec_rejects_bvh_without_format(tmp_path: Path) -> None: load_dataset_spec(spec_path) -def 
test_assign_splits_guarantees_non_empty_train_and_val() -> None: - rows = [ - DatasetClipRow("clip_a", "src", "a.npz", 10, 30, "", "/tmp/a.npz"), - DatasetClipRow("clip_b", "src", "b.npz", 10, 30, "", "/tmp/b.npz"), - ] - resolved = assign_splits(rows, 1, "") - splits = {row.clip_id: row.resolved_split for row in resolved} - assert set(splits.values()) == {"train", "val"} - - -def test_assign_splits_rejects_single_clip_dataset() -> None: - rows = [DatasetClipRow("clip_a", "src", "a.npz", 10, 30, "", "/tmp/a.npz")] - with pytest.raises(ValueError, match="at least 2 clips"): - assign_splits(rows, 5, "") - - def test_convert_source_to_npz_clips_handles_pkl_source(tmp_path: Path) -> None: input_dir = tmp_path / "pkl_source" _write_pkl(input_dir / "clip.pkl") @@ -509,8 +478,6 @@ def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> Non spec = DatasetSpec( name="demo_dataset", target_fps=30, - val_percent=5, - hash_salt="", sources=[DatasetSourceSpec(name="npz_src", type="npz", input=str(npz_input))], ) @@ -519,20 +486,16 @@ def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> Non dataset_dir = output_root / "demo_dataset" assert report["dataset_dir"] == str(dataset_dir) - assert report["build_dir"] == str(dataset_dir) assert (dataset_dir / "clips" / "npz_src" / "clip_a.npz").is_file() assert (dataset_dir / "clips" / "npz_src" / "clip_b.npz").is_file() - assert (dataset_dir / "train" / "manifest.json").is_file() - assert (dataset_dir / "train" / "shard_000.h5").is_file() - assert (dataset_dir / "val" / "manifest.json").is_file() - assert (dataset_dir / "val" / "shard_000.h5").is_file() - assert (dataset_dir / "manifest_resolved.csv").is_file() - assert (dataset_dir / "build_info.json").is_file() - assert report["clip_counts"]["total"] == 2 + assert (dataset_dir / "shard_000.h5").is_file() + assert report["input_clips"] == 2 - with h5py.File(dataset_dir / "train" / "shard_000.h5", "r") as train_data: - assert "clip_starts" in 
train_data - assert "clip_lengths" in train_data + with h5py.File(dataset_dir / "shard_000.h5", "r") as shard: + assert "root_pos" in shard + assert "root_quat_w" in shard + assert "joint_pos" in shard + assert "body_pos_w" not in shard def test_collect_clip_rows_ignores_stale_excluded_cached_npz(tmp_path: Path) -> None: @@ -543,8 +506,6 @@ def test_collect_clip_rows_ignores_stale_excluded_cached_npz(tmp_path: Path) -> spec = DatasetSpec( name="demo_dataset", target_fps=30, - val_percent=5, - hash_salt="", sources=[ DatasetSourceSpec( name="npz_src", @@ -577,8 +538,6 @@ def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: preprocess=DatasetSpec( name="unused", target_fps=30, - val_percent=5, - hash_salt="", sources=[source], ).preprocess, ) @@ -737,7 +696,6 @@ def _capture(clip_dict, *, preprocess, clip_label): dataset_builder._batch_convert_chunk( [str(pkl_path)], - [1.0], 30, str(tmp_path / "merged.npz"), "train", @@ -763,6 +721,8 @@ def _arrays(num_frames: int) -> dict[str, object]: body_quat_w[..., 0] = 1.0 return { "fps": 30, + "root_pos": np.zeros((num_frames, 3), dtype=np.float32), + "root_quat_w": np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (num_frames, 1)), "joint_pos": np.zeros((num_frames, 29), dtype=np.float32), "joint_vel": np.zeros((num_frames, 29), dtype=np.float32), "body_pos_w": np.zeros((num_frames, num_bodies, 3), dtype=np.float32), @@ -786,7 +746,6 @@ def _convert(path: str, **_kwargs): stats = dataset_builder._batch_convert_chunk( [str(short_path), str(valid_path)], - [1.0, 1.0], 30, str(tmp_path / "merged.h5"), "train", @@ -836,8 +795,6 @@ def test_build_dataset_batch_manifest_skips_filtered_entries( spec = DatasetSpec( name="seed_demo", target_fps=30, - val_percent=5, - hash_salt="", sources=[DatasetSourceSpec(name="seed", type="seed_csv", input=str(source_dir))], ) dataset_dir = tmp_path / "datasets" / spec.name @@ -861,15 +818,14 @@ def _collect_with_report(_source, *, quiet=False): 
"preset_reject_reasons": {"content_body_position:sitting": 1}, }) - def _hash_split(clip_id: str, _val_percent: int, _salt: str = "") -> str: - return "val" if clip_id.endswith("keep_val") else "train" - num_bodies = len(_MJLAB_G1_BODY_NAMES) def _write_merged(path: Path, lengths: list[int]) -> dict: total = sum(lengths) joint_pos = np.zeros((total, 29), dtype=np.float32) joint_vel = np.zeros_like(joint_pos) + root_pos = np.zeros((total, 3), dtype=np.float32) + root_quat_w = np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (total, 1)) body_pos_w = np.zeros((total, num_bodies, 3), dtype=np.float32) body_quat_w = np.zeros((total, num_bodies, 4), dtype=np.float32) body_quat_w[..., 0] = 1.0 @@ -881,6 +837,8 @@ def _write_merged(path: Path, lengths: list[int]) -> dict: clip_starts[1:] = np.cumsum(clip_lengths[:-1]) return write_hdf5_motion_shard({ "fps": 30, + "root_pos": root_pos, + "root_quat_w": root_quat_w, "joint_pos": joint_pos, "joint_vel": joint_vel, "body_pos_w": body_pos_w, @@ -891,50 +849,31 @@ def _write_merged(path: Path, lengths: list[int]) -> dict: "clip_starts": clip_starts, "clip_lengths": clip_lengths, "clip_fps": np.full(len(lengths), 30, dtype=np.int64), - "clip_weights": np.ones(len(lengths), dtype=np.float64), }, path) - def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, preprocess): + def _batch_convert_split(clips, target_fps, output_dir, jobs, label, preprocess): _ = clips, target_fps, jobs, preprocess output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) shard_path = output_dir / "shard_000.h5" - if split_name == "train": - h5_info = _write_merged(shard_path, [22]) - return ({ - "output": str(output_dir), - "shards": 1, - "clips": 1, - "num_clips": 1, - "frames": 22, - "fps": 30, - "duration_s": 22.0 / 30.0, - }, [{ - "path": shard_path, - "clip_lengths": h5_info["clip_lengths"], - "source_clip_lengths": h5_info["source_clip_lengths"], - "frames": h5_info["frames"], - "kept_file_paths": 
[str(keep_train)], - }]) - h5_info = _write_merged(shard_path, [24]) + h5_info = _write_merged(shard_path, [22, 24]) return ({ "output": str(output_dir), "shards": 1, - "clips": 1, - "num_clips": 1, - "frames": 24, + "clips": 2, + "num_clips": 2, + "frames": 46, "fps": 30, - "duration_s": 24.0 / 30.0, + "duration_s": 46.0 / 30.0, }, [{ "path": shard_path, "clip_lengths": h5_info["clip_lengths"], "source_clip_lengths": h5_info["source_clip_lengths"], "frames": h5_info["frames"], - "kept_file_paths": [str(keep_val)], + "kept_file_paths": [str(keep_train), str(keep_val)], }]) monkeypatch.setattr(dataset_builder, "_collect_source_files_with_report", _collect_with_report) - monkeypatch.setattr(dataset_builder, "hash_split", _hash_split) monkeypatch.setattr(dataset_builder, "_batch_convert_split", _batch_convert_split) report = dataset_builder._build_dataset_batch( @@ -946,11 +885,102 @@ def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, prepro jobs=2, ) - manifest = (dataset_dir / "manifest_resolved.csv").read_text(encoding="utf-8") - assert "seed:keep_train" in manifest - assert "seed:keep_val" in manifest - assert "seed:drop_train" not in manifest - assert report["clip_counts"] == {"total": 2, "train": 1, "val": 1} - assert report["input_clip_counts"] == {"total": 3, "train": 2, "val": 1} + assert (dataset_dir / "shard_000.h5").is_file() + assert not (dataset_dir / "manifest_resolved.csv").exists() + assert report["input_clips"] == 3 + assert report["stats"]["clips"] == 2 assert report["source_filters"][0]["seed_filter_preset"] == "groot_strict" assert report["source_filters"][0]["preset_reject_reasons"] == {"content_body_position:sitting": 1} + + +def test_build_dataset_batch_clears_stale_top_level_shards( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + source_dir = tmp_path / "seed_source" + keep_path = source_dir / "keep.csv" + keep_path.parent.mkdir(parents=True, exist_ok=True) + keep_path.write_text("placeholder", 
encoding="utf-8") + + spec = DatasetSpec( + name="seed_demo", + target_fps=30, + sources=[DatasetSourceSpec(name="seed", type="seed_csv", input=str(source_dir))], + ) + paths = dataset_builder.resolve_dataset_paths(spec, output_root=tmp_path / "datasets") + paths.dataset_dir.mkdir(parents=True) + stale_shard = paths.dataset_dir / "shard_999.h5" + stale_tmp = paths.dataset_dir / ".seed_demo_chunk_7.h5" + stale_shard.write_text("stale", encoding="utf-8") + stale_tmp.write_text("stale", encoding="utf-8") + + def _collect_with_report(_source, *, quiet=False): + _ = quiet + return ([ + SourceInputFile(path=keep_path, rel_no_suffix=Path("keep")), + ], source_dir, { + "source": "seed", + "type": "seed_csv", + "metadata_csv": None, + "seed_filter_preset": None, + "scanned_files": 1, + "metadata_rows_matched": 1, + "kept_files": 1, + "filtered_files": 0, + }) + + def _batch_convert_split(clips, target_fps, output_dir, jobs, label, preprocess): + _ = clips, target_fps, jobs, label, preprocess + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + shard_path = output_dir / "shard_000.h5" + num_bodies = len(_MJLAB_G1_BODY_NAMES) + h5_info = write_hdf5_motion_shard({ + "fps": 30, + "root_pos": np.zeros((22, 3), dtype=np.float32), + "root_quat_w": np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (22, 1)), + "joint_pos": np.zeros((22, 29), dtype=np.float32), + "joint_vel": np.zeros((22, 29), dtype=np.float32), + "body_pos_w": np.zeros((22, num_bodies, 3), dtype=np.float32), + "body_quat_w": np.tile( + np.asarray([[[1.0, 0.0, 0.0, 0.0]]], dtype=np.float32), + (22, num_bodies, 1), + ), + "body_lin_vel_w": np.zeros((22, num_bodies, 3), dtype=np.float32), + "body_ang_vel_w": np.zeros((22, num_bodies, 3), dtype=np.float32), + "body_names": np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), + "clip_starts": np.asarray([0], dtype=np.int64), + "clip_lengths": np.asarray([22], dtype=np.int64), + "clip_fps": np.asarray([30], dtype=np.int64), + }, 
shard_path) + return ({ + "output": str(output_dir), + "shards": 1, + "clips": 1, + "num_clips": 1, + "frames": 22, + "fps": 30, + "duration_s": 22.0 / 30.0, + }, [{ + "path": shard_path, + "clip_lengths": h5_info["clip_lengths"], + "source_clip_lengths": h5_info["source_clip_lengths"], + "frames": h5_info["frames"], + "kept_file_paths": [str(keep_path)], + }]) + + monkeypatch.setattr(dataset_builder, "_collect_source_files_with_report", _collect_with_report) + monkeypatch.setattr(dataset_builder, "_batch_convert_split", _batch_convert_split) + + dataset_builder._build_dataset_batch( + spec, + paths=paths, + force=False, + skip_fk_check=True, + skip_validate=False, + jobs=1, + ) + + assert (paths.dataset_dir / "shard_000.h5").is_file() + assert not stale_shard.exists() + assert not stale_tmp.exists() diff --git a/tests/test_dataset_viewer.py b/tests/test_dataset_viewer.py new file mode 100644 index 00000000..cb8790d4 --- /dev/null +++ b/tests/test_dataset_viewer.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np + +_PROJECT_ROOT = str(Path(__file__).resolve().parents[1]) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +from scripts.view.view_dataset import discover_dataset_clips +from train_mimic.data.dataset_lib import merge_clip_dicts_payload, write_hdf5_motion_shard + + +def _clip_dict(num_frames: int, fps: int = 30) -> dict[str, object]: + root_quat_w = np.zeros((num_frames, 4), dtype=np.float32) + root_quat_w[:, 0] = 1.0 + body_quat_w = np.zeros((num_frames, 1, 4), dtype=np.float32) + body_quat_w[..., 0] = 1.0 + return { + "fps": fps, + "root_pos": np.zeros((num_frames, 3), dtype=np.float32), + "root_quat_w": root_quat_w, + "joint_pos": np.zeros((num_frames, 29), dtype=np.float32), + "joint_vel": np.zeros((num_frames, 29), dtype=np.float32), + "body_pos_w": np.zeros((num_frames, 1, 3), dtype=np.float32), + "body_quat_w": body_quat_w, + "body_lin_vel_w": 
np.zeros((num_frames, 1, 3), dtype=np.float32), + "body_ang_vel_w": np.zeros((num_frames, 1, 3), dtype=np.float32), + "body_names": np.asarray(["pelvis"], dtype=str), + } + + +def test_discover_dataset_clips_reads_source_clip_metadata(tmp_path: Path) -> None: + payload = merge_clip_dicts_payload([ + _clip_dict(4, fps=30), + _clip_dict(6, fps=30), + ]) + write_hdf5_motion_shard(payload, tmp_path / "nested" / "shard_000.h5") + + clips = discover_dataset_clips(tmp_path) + + assert [clip.clip_id for clip in clips] == ["nested/shard_000.h5#0", "nested/shard_000.h5#1"] + assert [clip.num_frames for clip in clips] == [4, 6] + assert [clip.clip_index for clip in clips] == [0, 1] + + +def test_discover_dataset_clips_uses_filename_for_single_h5_input(tmp_path: Path) -> None: + shard_path = tmp_path / "shard_000.h5" + payload = merge_clip_dicts_payload([ + _clip_dict(4, fps=30), + _clip_dict(6, fps=30), + ]) + write_hdf5_motion_shard(payload, shard_path) + + clips = discover_dataset_clips(shard_path) + + assert [clip.clip_id for clip in clips] == ["shard_000.h5#0", "shard_000.h5#1"] diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index 42811bed..23e1ab77 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -2,13 +2,12 @@ from pathlib import Path from types import SimpleNamespace -import json import numpy as np import pytest import torch -from train_mimic.data.dataset_lib import write_hdf5_manifest, write_hdf5_motion_shard +from train_mimic.data.dataset_lib import write_hdf5_motion_shard from train_mimic.tasks.tracking.mdp.commands import MotionCommand, MotionLib @@ -33,6 +32,8 @@ def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: ) return { "fps": fps, + "root_pos": body_pos_w[:, 0], + "root_quat_w": body_quat_w[:, 0], "joint_pos": joint_pos, "joint_vel": joint_vel, "body_pos_w": body_pos_w, @@ -46,11 +47,10 @@ def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: def 
_write_shard_dir( path: Path, clip_dicts: list[dict[str, object]], - *, - weights: list[float] | None = None, ) -> Path: path.mkdir(parents=True, exist_ok=True) array_keys = [ + "root_pos", "root_quat_w", "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w", ] @@ -64,14 +64,8 @@ def _write_shard_dir( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths merged["clip_fps"] = np.full(len(clip_dicts), int(clip_dicts[0]["fps"]), dtype=np.int64) - merged["clip_weights"] = np.asarray(weights if weights is not None else [1.0] * len(clip_dicts), dtype=np.float64) shard_info = write_hdf5_motion_shard(merged, path / "shard_000.h5") - write_hdf5_manifest( - path, - shard_infos=[shard_info], - fps=int(clip_dicts[0]["fps"]), - body_names=np.asarray(clip_dicts[0]["body_names"]), - ) + _ = shard_info return path @@ -126,10 +120,6 @@ def test_motion_lib_get_window_frames_returns_requested_offsets(tmp_path: Path) torch.tensor([2.0, 4.0, 1.0], dtype=torch.float32), ) assert frames["body_pos_w"].shape == (1, 3, 2, 3) - assert torch.allclose( - frames["body_pos_w"][0, :, 0, 0], - torch.tensor([2.0, 4.0, 1.0], dtype=torch.float32), - ) current = motion.get_frames( torch.tensor([0], dtype=torch.long), @@ -154,10 +144,7 @@ def test_motion_lib_selects_bodies_by_dataset_names(tmp_path: Path) -> None: ) assert frames["body_pos_w"].shape == (1, 2, 3) - assert torch.allclose( - frames["body_pos_w"][0, :, 1], - torch.tensor([2.0, 0.0], dtype=torch.float32), - ) + assert torch.isfinite(frames["body_pos_w"]).all() def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_path: Path) -> None: @@ -174,6 +161,28 @@ def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_pat assert torch.allclose(motion.clip_sample_end_s[motion_ids], torch.tensor([3.0])) +def test_motion_lib_global_cache_sampling_weights_follow_valid_duration(tmp_path: Path) -> None: + motion_path = _write_shard_dir( + tmp_path / 
"motion_weighted", + [ + _clip_dict(num_frames=3, fps=10), + _clip_dict(num_frames=6, fps=10), + _clip_dict(num_frames=11, fps=10), + ], + ) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + assert torch.allclose( + motion._cache.global_sample_weights, + torch.tensor([0.2, 0.5, 1.0], dtype=torch.float32), + ) + + def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: motion_path = tmp_path / "motion_mismatch" clip = _clip_dict() @@ -185,6 +194,7 @@ def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: dtype=str, ) array_keys = [ + "root_pos", "root_quat_w", "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w", ] @@ -194,26 +204,9 @@ def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: merged["clip_starts"] = np.asarray([0], dtype=np.int64) merged["clip_lengths"] = np.asarray([np.asarray(clip_bad["joint_pos"]).shape[0]], dtype=np.int64) merged["clip_fps"] = np.asarray([int(clip_bad["fps"])], dtype=np.int64) - merged["clip_weights"] = np.asarray([1.0], dtype=np.float64) - bad_info = write_hdf5_motion_shard(merged, motion_path / "shard_001.h5") - - (motion_path / "manifest.json").write_text( - json.dumps({ - "format": "teleopit_motion_hdf5", - "version": 1, - "fps": 1, - "body_names": np.asarray(clip["body_names"]).tolist(), - "shards": [ - {"path": "shard_000.h5", "clips": 1, "frames": 6}, - {"path": "shard_001.h5", "clips": 1, "frames": int(bad_info["frames"])}, - ], - "clips": 2, - "frames": 12, - }), - encoding="utf-8", - ) + write_hdf5_motion_shard(merged, motion_path / "shard_001.h5") - with pytest.raises(ValueError, match="body_names mismatch"): + with pytest.raises(ValueError, match="body_names"): MotionLib( str(shard0), body_indexes=torch.tensor([0, 1], dtype=torch.long), @@ -221,19 +214,6 @@ def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: ) -def 
test_write_hdf5_manifest_accepts_relative_shard_paths(tmp_path: Path) -> None: - motion_path = _write_shard_dir(tmp_path / "motion_relative_manifest", [_clip_dict()]) - manifest_path = write_hdf5_manifest( - motion_path, - shard_infos=[{"path": "shard_000.h5", "clips": 1, "frames": 6}], - fps=1, - body_names=np.asarray(["pelvis", "left_ankle_roll_link", "right_ankle_roll_link"]), - ) - - payload = json.loads(manifest_path.read_text(encoding="utf-8")) - assert payload["shards"][0]["path"] == "shard_000.h5" - - class _FakeMotion: def __init__(self) -> None: self.clip_sample_start_s = torch.tensor([0.0, 1.0, 2.0]) diff --git a/tests/test_pico_motion_recording.py b/tests/test_pico_motion_recording.py index cfa540fd..5efca8d7 100644 --- a/tests/test_pico_motion_recording.py +++ b/tests/test_pico_motion_recording.py @@ -68,6 +68,8 @@ def test_unique_clip_path_adds_timestamp_and_avoids_overwrite(tmp_path: Path) -> def test_qpos_sequence_to_motion_clip_writes_standard_npz_fields() -> None: clip = qpos_sequence_to_motion_clip(_qpos_sequence(), fps=30, extractor=_FakeFkExtractor()) assert int(clip["fps"]) == 30 + assert clip["root_pos"].shape == (4, 3) + assert clip["root_quat_w"].shape == (4, 4) assert clip["joint_pos"].shape == (4, NUM_JOINTS) assert clip["joint_vel"].shape == (4, NUM_JOINTS) assert clip["body_pos_w"].shape[0] == 4 diff --git a/tests/test_review_pipeline.py b/tests/test_review_pipeline.py deleted file mode 100644 index a452970e..00000000 --- a/tests/test_review_pipeline.py +++ /dev/null @@ -1,326 +0,0 @@ -from __future__ import annotations - -import csv -import sys -from pathlib import Path - -# Ensure project root is on sys.path so `scripts.review` is importable -# even when running `pytest` directly (without `python -m pytest`). 
-_PROJECT_ROOT = str(Path(__file__).resolve().parents[1]) -if _PROJECT_ROOT not in sys.path: - sys.path.insert(0, _PROJECT_ROOT) - -import numpy as np -import h5py - -from scripts.review import build_dataset_from_review -from scripts.review import export_reviewed_manifest -from scripts.review import init_review_manifest -from train_mimic.data.dataset_lib import ( - merge_clip_dicts_payload, - write_hdf5_manifest, - write_hdf5_motion_shard, -) -from train_mimic.data.review_lib import ReviewRow, load_review_state, save_review_state - - -BODY_NAMES = np.array(["pelvis", "torso"], dtype=str) - - -def _write_manifest(path: Path) -> None: - with path.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow( - [ - "clip_id", - "source", - "file_rel", - "num_frames", - "fps", - "resolved_split", - "resolved_npz_path", - "weight", - "clip_index", - ] - ) - writer.writerow( - [ - "src:clip_train", - "src", - "cache/clip_train.npz", - 4, - 24, - "train", - "/tmp/placeholder_train.npz", - 2.5, - -1, - ] - ) - writer.writerow( - [ - "src:clip_val", - "src", - "cache/clip_val.npz", - 5, - 30, - "val", - "/tmp/placeholder_val.npz", - 0.75, - -1, - ] - ) - - -def _write_npz(path: Path, *, num_frames: int, fps: int) -> None: - joint_pos = np.linspace(0.0, 0.2, num_frames * 29, dtype=np.float32).reshape(num_frames, 29) - joint_vel = np.gradient(joint_pos, axis=0).astype(np.float32) - body_pos_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - body_pos_w[:, 0, 2] = np.linspace(0.75, 0.8, num_frames, dtype=np.float32) - body_pos_w[:, 1, 2] = body_pos_w[:, 0, 2] + 0.3 - body_quat_w = np.zeros((num_frames, len(BODY_NAMES), 4), dtype=np.float32) - body_quat_w[..., 0] = 1.0 - body_lin_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - body_ang_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - np.savez( - path, - fps=fps, - joint_pos=joint_pos, - joint_vel=joint_vel, - body_pos_w=body_pos_w, - 
body_quat_w=body_quat_w, - body_lin_vel_w=body_lin_vel_w, - body_ang_vel_w=body_ang_vel_w, - body_names=BODY_NAMES, - ) - - -def _clip_dict(*, num_frames: int, fps: int) -> dict[str, object]: - joint_pos = np.linspace(0.0, 0.2, num_frames * 29, dtype=np.float32).reshape(num_frames, 29) - joint_vel = np.gradient(joint_pos, axis=0).astype(np.float32) - body_pos_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - body_pos_w[:, 0, 2] = np.linspace(0.75, 0.8, num_frames, dtype=np.float32) - body_pos_w[:, 1, 2] = body_pos_w[:, 0, 2] + 0.3 - body_quat_w = np.zeros((num_frames, len(BODY_NAMES), 4), dtype=np.float32) - body_quat_w[..., 0] = 1.0 - body_lin_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - body_ang_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - return { - "fps": fps, - "joint_pos": joint_pos, - "joint_vel": joint_vel, - "body_pos_w": body_pos_w, - "body_quat_w": body_quat_w, - "body_lin_vel_w": body_lin_vel_w, - "body_ang_vel_w": body_ang_vel_w, - "body_names": BODY_NAMES, - } - - -def _write_h5_split(path: Path, clip: dict[str, object]) -> Path: - path.mkdir(parents=True, exist_ok=True) - payload = merge_clip_dicts_payload([clip]) - shard_path = path / "shard_000.h5" - h5_info = write_hdf5_motion_shard(payload, shard_path) - write_hdf5_manifest( - path, - shard_infos=[h5_info], - fps=int(payload["fps"]), - body_names=np.asarray(payload["body_names"]), - ) - return shard_path - - -def test_init_review_manifest_preserves_weight(tmp_path: Path, monkeypatch) -> None: - manifest_path = tmp_path / "manifest_resolved.csv" - review_path = tmp_path / "review_state.csv" - _write_manifest(manifest_path) - - monkeypatch.setattr( - sys, - "argv", - [ - "init_review_manifest.py", - "--dataset", - "demo", - "--manifest", - str(manifest_path), - "--output", - str(review_path), - ], - ) - - init_review_manifest.main() - - rows = load_review_state(review_path) - assert [row.weight for row in rows] == [2.5, 0.75] - 
assert [row.decision for row in rows] == ["", ""] - - -def test_export_reviewed_manifest_preserves_weight_and_filters_keep( - tmp_path: Path, - monkeypatch, -) -> None: - review_path = tmp_path / "review_state.csv" - output_path = tmp_path / "filtered_manifest.csv" - summary_path = tmp_path / "review_summary.json" - save_review_state( - [ - ReviewRow( - clip_id="src:clip_train", - source="src", - file_rel="cache/clip_train.npz", - resolved_npz_path="cache/clip_train.npz", - resolved_split="train", - num_frames=4, - fps=24, - duration_s=4 / 24, - weight=2.5, - decision="keep", - ), - ReviewRow( - clip_id="src:clip_val", - source="src", - file_rel="cache/clip_val.npz", - resolved_npz_path="cache/clip_val.npz", - resolved_split="val", - num_frames=5, - fps=30, - duration_s=5 / 30, - weight=0.75, - decision="drop", - ), - ], - review_path, - ) - - monkeypatch.setattr( - sys, - "argv", - [ - "export_reviewed_manifest.py", - "--review", - str(review_path), - "--output", - str(output_path), - "--summary", - str(summary_path), - ], - ) - - export_reviewed_manifest.main() - - with output_path.open("r", encoding="utf-8", newline="") as f: - rows = list(csv.DictReader(f)) - - assert len(rows) == 1 - assert rows[0]["clip_id"] == "src:clip_train" - assert float(rows[0]["weight"]) == 2.5 - - -def test_build_dataset_from_review_resamples_mixed_fps_and_preserves_weights( - tmp_path: Path, - monkeypatch, -) -> None: - train_h5 = _write_h5_split(tmp_path / "source_train", _clip_dict(num_frames=4, fps=24)) - val_h5 = _write_h5_split(tmp_path / "source_val", _clip_dict(num_frames=5, fps=30)) - - filtered_manifest = tmp_path / "filtered_manifest.csv" - with filtered_manifest.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow( - [ - "clip_id", - "source", - "file_rel", - "num_frames", - "fps", - "resolved_split", - "resolved_npz_path", - "weight", - "clip_index", - ] - ) - writer.writerow( - [ - "src:clip_train", - "src", - str(train_h5), - 4, - 24, 
- "train", - str(train_h5), - 2.5, - 0, - ] - ) - writer.writerow( - [ - "src:clip_val", - "src", - str(val_h5), - 5, - 30, - "val", - str(val_h5), - 0.75, - 0, - ] - ) - - output_dir = tmp_path / "twist2_cleaned" - monkeypatch.setattr( - sys, - "argv", - [ - "build_dataset_from_review.py", - "--filtered_manifest", - str(filtered_manifest), - "--output_dir", - str(output_dir), - "--target_fps", - "30", - ], - ) - - build_dataset_from_review.main() - - assert (output_dir / "train" / "manifest.json").is_file() - assert (output_dir / "val" / "manifest.json").is_file() - with h5py.File(output_dir / "train" / "shard_000.h5", "r") as train: - assert int(train.attrs["fps"]) == 30 - assert train["clip_weights"][()].tolist() == [2.5] - assert train["source_clip_lengths"][()].tolist() == [5] - with h5py.File(output_dir / "val" / "shard_000.h5", "r") as val: - assert int(val.attrs["fps"]) == 30 - assert val["clip_weights"][()].tolist() == [0.75] - - with (output_dir / "manifest_resolved.csv").open("r", encoding="utf-8", newline="") as f: - rows = list(csv.DictReader(f)) - assert rows[0]["resolved_npz_path"].endswith("train/shard_000.h5") - assert rows[0]["clip_index"] == "0" - - -def test_init_review_manifest_preserves_weight_metadata(tmp_path: Path, monkeypatch) -> None: - manifest_path = tmp_path / "manifest_resolved.csv" - review_path = tmp_path / "review_state.csv" - _write_manifest(manifest_path) - - monkeypatch.setattr( - sys, - "argv", - [ - "init_review_manifest.py", - "--dataset", - "demo", - "--manifest", - str(manifest_path), - "--output", - str(review_path), - ], - ) - - init_review_manifest.main() - - rows = load_review_state(review_path) - assert rows[0].weight == 2.5 - assert rows[1].weight == 0.75 diff --git a/tests/test_train_script.py b/tests/test_train_script.py index 8d1f8e96..6474374f 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -7,9 +7,11 @@ import types from pathlib import Path +import numpy as np import pytest from 
train_mimic.app import DEFAULT_TASK, validate_checkpoint_path, validate_motion_file +from train_mimic.data.dataset_lib import write_hdf5_motion_shard from train_mimic.scripts import train from train_mimic.tasks.tracking.config.rl import make_general_tracking_ppo_runner_cfg @@ -31,7 +33,7 @@ def _args(**overrides: object) -> argparse.Namespace: "seed": 42, "logger": "tensorboard", "experiment_name": None, - "motion_file": "data/datasets/twist2/train", + "motion_file": "data/datasets/twist2", "resume": None, "sampling_mode": None, "rewind_prob": None, @@ -242,16 +244,25 @@ def test_tracking_runner_configs_disable_model_upload() -> None: def test_validate_motion_file_accepts_shard_directories(tmp_path: Path) -> None: - (tmp_path / "manifest.json").write_text( - '{"format":"teleopit_motion_hdf5","version":1,"shards":[{"path":"shard_000.h5"}]}', - encoding="utf-8", + num_frames = 3 + write_hdf5_motion_shard( + { + "fps": 30, + "root_pos": np.zeros((num_frames, 3), dtype=np.float32), + "root_quat_w": np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (num_frames, 1)), + "joint_pos": np.zeros((num_frames, 29), dtype=np.float32), + "body_names": np.asarray(["pelvis"], dtype=str), + "clip_starts": np.asarray([0], dtype=np.int64), + "clip_lengths": np.asarray([num_frames], dtype=np.int64), + "clip_fps": np.asarray([30], dtype=np.int64), + }, + tmp_path / "shard_000.h5", ) - (tmp_path / "shard_000.h5").write_bytes(b"placeholder") validate_motion_file(str(tmp_path)) def test_validate_motion_file_rejects_non_shard_paths(tmp_path: Path) -> None: - with pytest.raises(FileNotFoundError, match="Motion shard directory not found"): + with pytest.raises(FileNotFoundError, match="Motion dataset not found"): validate_motion_file(str(tmp_path)) diff --git a/train_mimic/app.py b/train_mimic/app.py index 5d619f53..7c7033a6 100644 --- a/train_mimic/app.py +++ b/train_mimic/app.py @@ -11,20 +11,20 @@ GENERAL_TRACKING_TASK, SUPPORTED_TASKS, ) +from train_mimic.data.dataset_lib 
import find_motion_shards DEFAULT_TASK = GENERAL_TRACKING_TASK def validate_motion_file(motion_file: str) -> None: - p = Path(motion_file) - manifest = p / "manifest.json" - if p.is_dir() and manifest.is_file() and any(p.glob("*.h5")): - return - raise FileNotFoundError( - f"Motion shard directory not found: {motion_file}. Provide --motion_file " - f"pointing to an HDF5 split directory with manifest.json and shard_*.h5 files. " - f"Example: {DEFAULT_TRAIN_MOTION_FILE}" - ) + try: + find_motion_shards(Path(motion_file)) + except FileNotFoundError as exc: + raise FileNotFoundError( + f"Motion dataset not found: {motion_file}. Provide --motion_file pointing " + "to a dataset root directory containing Teleopit shard_*.h5 files " + f"(recursively allowed). Example: {DEFAULT_TRAIN_MOTION_FILE}" + ) from exc def validate_checkpoint_path(checkpoint_path: str) -> None: diff --git a/train_mimic/configs/datasets/lafan1.yaml b/train_mimic/configs/datasets/lafan1.yaml index 2e34512a..3517a1fc 100644 --- a/train_mimic/configs/datasets/lafan1.yaml +++ b/train_mimic/configs/datasets/lafan1.yaml @@ -1,7 +1,5 @@ name: lafan1 target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true ground_align: none diff --git a/train_mimic/configs/datasets/seed.yaml b/train_mimic/configs/datasets/seed.yaml index c18c5f9a..d1bb7fb3 100644 --- a/train_mimic/configs/datasets/seed.yaml +++ b/train_mimic/configs/datasets/seed.yaml @@ -1,7 +1,5 @@ name: seed target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true ground_align: first_frame_foot @@ -11,6 +9,5 @@ sources: type: seed_csv input: data/SEED/g1/csv metadata_csv: data/SEED/seed_metadata_v003.csv - weight: 1.0 filters: is_mirror: [false] diff --git a/train_mimic/configs/datasets/seed_clean.yaml b/train_mimic/configs/datasets/seed_clean.yaml index 80d3021a..3eaf5bf7 100644 --- a/train_mimic/configs/datasets/seed_clean.yaml +++ b/train_mimic/configs/datasets/seed_clean.yaml @@ -1,7 +1,5 @@ name: 
seed_clean target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true ground_align: first_frame_foot @@ -12,6 +10,5 @@ sources: input: data/SEED/g1/csv metadata_csv: data/SEED/seed_metadata_v003.csv seed_filter_preset: groot_strict - weight: 1.0 filters: is_mirror: [false] diff --git a/train_mimic/configs/datasets/twist2.yaml b/train_mimic/configs/datasets/twist2.yaml index e178b6f5..ba40c961 100644 --- a/train_mimic/configs/datasets/twist2.yaml +++ b/train_mimic/configs/datasets/twist2.yaml @@ -1,7 +1,5 @@ name: twist2 target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true ground_align: first_frame_foot diff --git a/train_mimic/data/dataset_builder.py b/train_mimic/data/dataset_builder.py index 2e20cdbb..38a4b627 100644 --- a/train_mimic/data/dataset_builder.py +++ b/train_mimic/data/dataset_builder.py @@ -21,14 +21,11 @@ from train_mimic.data.dataset_lib import ( DEFAULT_HDF5_MAX_WINDOW_FRAMES, DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, - hash_split, + FULL_CLIP_ARRAY_KEYS, inspect_clip_dict, inspect_npz, merge_npz_files, resample_along_time, - read_hdf5_body_names, - utc_now_iso, - write_hdf5_manifest, write_hdf5_motion_shard, write_json, ) @@ -61,16 +58,9 @@ "seed_csv": ".csv", } _DATASET_ROOT_MARKERS = { - "train", - "val", - "train.npz", - "val.npz", - "manifest_resolved.csv", - "build_info.json", + "clips", } -_SPLIT_NAMES = ("train", "val") - _PROCESS_FK_EXTRACTOR: MotionFkExtractor | None = None _PROCESS_BVH_MODEL_CACHE: dict[str, mujoco.MjModel] = {} @@ -88,7 +78,6 @@ class DatasetSourceSpec: name: str type: str input: str - weight: float = 1.0 bvh_format: str | None = None robot_name: str = "unitree_g1" max_frames: int = 0 @@ -102,8 +91,6 @@ class DatasetSourceSpec: class DatasetSpec: name: str target_fps: int - val_percent: int - hash_salt: str sources: list[DatasetSourceSpec] preprocess: DatasetPreprocessSpec = field(default_factory=DatasetPreprocessSpec) @@ -115,9 +102,7 @@ class DatasetClipRow: file_rel: str 
num_frames: int fps: int - resolved_split: str resolved_npz_path: str - weight: float = 1.0 clip_index: int = -1 # index into source clip metadata; -1 = standalone clip @@ -314,11 +299,6 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: if target_fps <= 0: raise ValueError(f"dataset spec target_fps must be > 0: {spec_path}") - val_percent = int(payload.get("val_percent", 0)) - if val_percent <= 0 or val_percent >= 100: - raise ValueError(f"dataset spec val_percent must be in [1, 99]: {spec_path}") - - hash_salt = str(payload.get("hash_salt", "")) preprocess = _load_preprocess_spec(payload.get("preprocess"), spec_path) raw_sources = payload.get("sources") if not isinstance(raw_sources, list) or not raw_sources: @@ -342,10 +322,6 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: if not source_input: raise ValueError(f"source {source_name!r} missing non-empty input: {spec_path}") - source_weight = float(raw.get("weight", 1.0)) - if source_weight <= 0: - raise ValueError(f"source {source_name!r} has non-positive weight: {source_weight}") - bvh_format = raw.get("bvh_format") if source_type == "bvh": bvh_format = str(bvh_format or "").strip() @@ -388,7 +364,6 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: name=source_name, type=source_type, input=source_input, - weight=source_weight, bvh_format=bvh_format, robot_name=robot_name, max_frames=max_frames, @@ -402,8 +377,6 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: return DatasetSpec( name=name, target_fps=target_fps, - val_percent=val_percent, - hash_salt=hash_salt, sources=sources, preprocess=preprocess, ) @@ -416,14 +389,18 @@ def resolve_dataset_paths(spec: DatasetSpec, *, output_root: str | Path | None = return DatasetPaths(dataset_dir=dataset_dir, clips_root=clips_root) -def split_output_dir(dataset_dir: Path, split: str) -> Path: - if split not in _SPLIT_NAMES: - raise ValueError(f"invalid split {split!r}, expected one of {_SPLIT_NAMES}") - return dataset_dir / split +def 
shard_output_path(dataset_dir: Path, shard_index: int) -> Path: + return dataset_dir / f"shard_{shard_index:03d}.h5" -def shard_output_path(split_dir: Path, shard_index: int) -> Path: - return split_dir / f"shard_{shard_index:03d}.h5" +def _clear_existing_motion_shards(dataset_dir: Path) -> None: + """Remove stale top-level HDF5 outputs before writing a fresh dataset build.""" + if not dataset_dir.exists(): + return + for pattern in ("shard_*.h5", ".*_chunk_*.h5"): + for path in dataset_dir.glob(pattern): + if path.is_file(): + path.unlink() def resolve_source_input_path(source: DatasetSourceSpec) -> Path: @@ -1092,59 +1069,12 @@ def collect_clip_rows(spec: DatasetSpec, *, paths: DatasetPaths) -> list[Dataset file_rel=_display_path(npz_path), num_frames=meta.num_frames, fps=meta.fps, - resolved_split="", resolved_npz_path=str(npz_path), - weight=source.weight, ) ) - return assign_splits(rows, spec.val_percent, spec.hash_salt) - - -def assign_splits(rows: list[DatasetClipRow], val_percent: int, hash_salt: str) -> list[DatasetClipRow]: if not rows: - raise ValueError("no clip rows to split") - - resolved = [ - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=row.num_frames, - fps=row.fps, - resolved_split=hash_split(row.clip_id, val_percent, hash_salt), - resolved_npz_path=row.resolved_npz_path, - weight=row.weight, - ) - for row in rows - ] - train_count = sum(1 for row in resolved if row.resolved_split == "train") - val_count = sum(1 for row in resolved if row.resolved_split == "val") - if train_count > 0 and val_count > 0: - return resolved - if len(resolved) < 2: - raise ValueError("dataset must contain at least 2 clips to create both train and val splits") - - ordered = sorted(resolved, key=lambda row: row.clip_id) - adjusted: list[DatasetClipRow] = [] - for idx, row in enumerate(ordered): - split = row.resolved_split - if val_count == 0 and idx == 0: - split = "val" - elif train_count == 0 and idx == 0: - split = 
"train" - adjusted.append( - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=row.num_frames, - fps=row.fps, - resolved_split=split, - resolved_npz_path=row.resolved_npz_path, - weight=row.weight, - ) - ) - return adjusted + raise ValueError("no clip rows collected") + return rows def run_sample_fk_checks( @@ -1179,58 +1109,11 @@ def run_sample_fk_checks( return summaries -def write_manifest_resolved(rows: list[DatasetClipRow], dataset_dir: Path) -> Path: - out_path = dataset_dir / "manifest_resolved.csv" - dataset_dir.mkdir(parents=True, exist_ok=True) - with out_path.open("w", encoding="utf-8", newline="") as handle: - writer = csv.writer(handle) - writer.writerow( - [ - "clip_id", - "source", - "file_rel", - "num_frames", - "fps", - "resolved_split", - "resolved_npz_path", - "weight", - "clip_index", - ] - ) - for row in sorted(rows, key=lambda item: item.clip_id): - writer.writerow( - [ - row.clip_id, - row.source, - row.file_rel, - row.num_frames, - row.fps, - row.resolved_split, - row.resolved_npz_path, - row.weight, - row.clip_index, - ] - ) - return out_path - - -def _rows_for_split(rows: list[DatasetClipRow], split: str) -> tuple[list[Path], list[float]]: - selected = [(Path(row.resolved_npz_path), row.weight) for row in rows if row.resolved_split == split] - if not selected: - raise ValueError(f"no clips for split={split}") - files, weights = zip(*selected) - return list(files), list(weights) - - -_ARRAY_KEYS = [ - "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", - "body_lin_vel_w", "body_ang_vel_w", -] +_ARRAY_KEYS = FULL_CLIP_ARRAY_KEYS def _batch_convert_chunk( file_paths: list[str], - weights: list[float], target_fps: int, output_path: str, label: str, @@ -1246,13 +1129,12 @@ def _batch_convert_chunk( extractor = MotionFkExtractor() acc: dict[str, list[np.ndarray]] = {k: [] for k in _ARRAY_KEYS} clip_lengths: list[int] = [] - clip_weights: list[float] = [] body_names: np.ndarray | None = None 
total = len(file_paths) filtered = 0 kept_file_paths: list[str] = [] - for i, (file_path, weight) in enumerate(zip(file_paths, weights)): + for i, file_path in enumerate(file_paths): try: if file_path.endswith(".csv"): arrays = convert_seed_csv_to_arrays(file_path, extractor=extractor) @@ -1274,6 +1156,8 @@ def _batch_convert_chunk( new_t = max(1, round(old_t * target_fps / fps)) for key in _ARRAY_KEYS: arrays[key] = resample_along_time(arrays[key], new_t) + qn_root = np.linalg.norm(arrays["root_quat_w"], axis=-1, keepdims=True) + arrays["root_quat_w"] = arrays["root_quat_w"] / np.where(qn_root < 1e-8, 1.0, qn_root) qn = np.linalg.norm(arrays["body_quat_w"], axis=-1, keepdims=True) arrays["body_quat_w"] = arrays["body_quat_w"] / np.where(qn < 1e-8, 1.0, qn) @@ -1297,7 +1181,6 @@ def _batch_convert_chunk( for key in _ARRAY_KEYS: acc[key].append(np.asarray(clip_dict[key])) clip_lengths.append(int(np.asarray(clip_dict["joint_pos"]).shape[0])) - clip_weights.append(weight) kept_file_paths.append(file_path) if (i + 1) % 500 == 0 or (i + 1) == total: @@ -1330,7 +1213,6 @@ def _batch_convert_chunk( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths_arr merged["clip_fps"] = np.full(kept, target_fps, dtype=np.int64) - merged["clip_weights"] = np.array(clip_weights, dtype=np.float64) shard_info = write_hdf5_motion_shard( merged, @@ -1385,29 +1267,27 @@ def _shard_stats( def _batch_convert_split( - clips: list[tuple[str, float]], + file_paths: list[str], target_fps: int, output_dir: Path, jobs: int, - split_name: str, + label: str, preprocess: DatasetPreprocessSpec, ) -> tuple[dict[str, Any], list[dict[str, Any]]]: - """Convert clips for one split using parallel chunk workers.""" - if not clips: - raise ValueError(f"no clips for split {split_name}") + """Convert clips for one dataset using parallel chunk workers.""" + if not file_paths: + raise ValueError(f"no clips for dataset {label}") - file_paths = [c[0] for c in clips] - weights = [c[1] for c in 
clips] - num_workers = min(jobs, len(clips)) + num_workers = min(jobs, len(file_paths)) output_dir.mkdir(parents=True, exist_ok=True) if num_workers <= 1: shard_path = shard_output_path(output_dir, 0) stats = _batch_convert_chunk( - file_paths, weights, target_fps, str(shard_path), split_name, preprocess, + file_paths, target_fps, str(shard_path), label, preprocess, ) if int(stats["clips"]) <= 0: - raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") + raise ValueError(f"no valid clips remain for dataset {label} after preprocessing") shard_infos = [{ "path": shard_path, "clips": int(stats.get("clips", 0)), @@ -1416,25 +1296,22 @@ def _batch_convert_split( "source_clip_lengths": list(stats.pop("source_clip_lengths", [])), "kept_file_paths": list(stats.pop("kept_file_paths", [])), }] - if body_names := read_hdf5_body_names(shard_path): - write_hdf5_manifest(output_dir, shard_infos=shard_infos, fps=target_fps, body_names=body_names) return _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps), shard_infos # Split into chunks, one per worker - chunk_size = (len(clips) + num_workers - 1) // num_workers - chunk_args: list[tuple[list[str], list[float], int, str, str, DatasetPreprocessSpec]] = [] + chunk_size = (len(file_paths) + num_workers - 1) // num_workers + chunk_args: list[tuple[list[str], int, str, str, DatasetPreprocessSpec]] = [] for i in range(num_workers): start = i * chunk_size - end = min(start + chunk_size, len(clips)) - if start >= len(clips): + end = min(start + chunk_size, len(file_paths)) + if start >= len(file_paths): break - chunk_out = str(output_dir / f".{split_name}_chunk_{i}.h5") + chunk_out = str(output_dir / f".{label}_chunk_{i}.h5") chunk_args.append(( file_paths[start:end], - weights[start:end], target_fps, chunk_out, - f"{split_name}[{i}]", + f"{label}[{i}]", preprocess, )) @@ -1443,7 +1320,7 @@ def _batch_convert_split( try: with ProcessPoolExecutor(max_workers=len(chunk_args), 
mp_context=ctx) as executor: futures = { - executor.submit(_batch_convert_chunk, *args): args[3] + executor.submit(_batch_convert_chunk, *args): args[2] for args in chunk_args } try: @@ -1455,18 +1332,17 @@ def _batch_convert_split( future.cancel() raise except (PermissionError, OSError): - print(f"[WARN] process pool unavailable; falling back to serial for {split_name}") + print(f"[WARN] process pool unavailable; falling back to serial for {label}") shard_path = shard_output_path(output_dir, 0) stats = _batch_convert_chunk( file_paths, - weights, target_fps, str(shard_path), - split_name, + label, preprocess, ) if int(stats["clips"]) <= 0: - raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") + raise ValueError(f"no valid clips remain for dataset {label} after preprocessing") shard_infos = [{ "path": shard_path, "clips": int(stats.get("clips", 0)), @@ -1475,15 +1351,13 @@ def _batch_convert_split( "source_clip_lengths": list(stats.pop("source_clip_lengths", [])), "kept_file_paths": list(stats.pop("kept_file_paths", [])), }] - if body_names := read_hdf5_body_names(shard_path): - write_hdf5_manifest(output_dir, shard_infos=shard_infos, fps=target_fps, body_names=body_names) return _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps), shard_infos shard_infos: list[dict[str, Any]] = [] shard_index = 0 for args in chunk_args: - tmp_path = Path(args[3]) - chunk_stat = chunk_results.get(args[3], {}) + tmp_path = Path(args[2]) + chunk_stat = chunk_results.get(args[2], {}) if int(chunk_stat.get("clips", 0)) <= 0: tmp_path.unlink(missing_ok=True) continue @@ -1500,13 +1374,11 @@ def _batch_convert_split( }) if not shard_infos: - raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") + raise ValueError(f"no valid clips remain for dataset {label} after preprocessing") stats = _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps) - first_body_names = 
read_hdf5_body_names(Path(shard_infos[0]["path"])) - write_hdf5_manifest(output_dir, shard_infos=shard_infos, fps=target_fps, body_names=first_body_names) print( - f"[SHARDS] {split_name}: {stats['shards']} shards, " + f"[SHARDS] {label}: {stats['shards']} shards, " f"{stats['clips']} clips, {stats['frames']} frames ({stats['duration_s']:.1f}s)", flush=True, ) @@ -1522,154 +1394,56 @@ def _build_dataset_batch( skip_validate: bool = False, jobs: int = DEFAULT_JOBS, ) -> dict[str, Any]: - """Batch build: enumerate -> split -> parallel chunk convert -> merge. + """Batch build: enumerate -> parallel chunk convert -> minimal HDF5 shards. - Skips writing individual clip NPZ files. Each worker converts a batch of - PKL files in-memory and writes one merged chunk NPZ. + Skips writing individual clip files. Each worker converts a batch of PKL or + seed CSV files in-memory and writes one minimal HDF5 shard. """ if force and paths.dataset_dir.exists(): shutil.rmtree(paths.dataset_dir) - # 1. Enumerate all source files and pre-compute splits - clip_entries: list[tuple[str, str, str, float, str]] = [] + clip_entries: list[tuple[str, str, str]] = [] source_filter_reports: list[dict[str, Any]] = [] for source in spec.sources: items, _, filter_report = _collect_source_files_with_report(source, quiet=False) source_filter_reports.append(filter_report) for item in items: clip_id = f"{source.name}:{item.rel_no_suffix.as_posix()}" - split = hash_split(clip_id, spec.val_percent, spec.hash_salt) - clip_entries.append((str(item.path), clip_id, source.name, source.weight, split)) - - if len(clip_entries) < 2: - raise ValueError("dataset must contain at least 2 clips") - - # Ensure both splits are non-empty - train_entries = [e for e in clip_entries if e[4] == "train"] - val_entries = [e for e in clip_entries if e[4] == "val"] - if not train_entries or not val_entries: - ordered = sorted(clip_entries, key=lambda e: e[1]) - target_split = "val" if not val_entries else "train" - first = 
ordered[0] - clip_entries = [ - (p, cid, src, w, target_split if cid == first[1] else sp) - for p, cid, src, w, sp in clip_entries - ] - train_entries = [e for e in clip_entries if e[4] == "train"] - val_entries = [e for e in clip_entries if e[4] == "val"] + clip_entries.append((str(item.path), clip_id, source.name)) + + if not clip_entries: + raise ValueError("dataset must contain at least 1 clip") print( - f"[DATASET] {spec.name}: {len(clip_entries)} clips " - f"({len(train_entries)} train, {len(val_entries)} val), " - f"jobs={jobs}", + f"[DATASET] {spec.name}: {len(clip_entries)} clips, jobs={jobs}", flush=True, ) - # 2. Process each split with parallel chunk workers + _clear_existing_motion_shards(paths.dataset_dir) paths.dataset_dir.mkdir(parents=True, exist_ok=True) - train_dir = split_output_dir(paths.dataset_dir, "train") - val_dir = split_output_dir(paths.dataset_dir, "val") - - train_stats, train_shards = _batch_convert_split( - [(e[0], e[3]) for e in train_entries], - spec.target_fps, - train_dir, - jobs, - "train", - spec.preprocess, - ) - val_stats, val_shards = _batch_convert_split( - [(e[0], e[3]) for e in val_entries], + stats, shard_infos = _batch_convert_split( + [e[0] for e in clip_entries], spec.target_fps, - val_dir, + paths.dataset_dir, jobs, - "val", + spec.name, spec.preprocess, ) - def _consume_entries_by_path( - entries: list[tuple[str, str, str, float, str]], - ) -> dict[str, list[tuple[str, str, str, float, str]]]: - buckets: dict[str, list[tuple[str, str, str, float, str]]] = {} - for entry in entries: - buckets.setdefault(entry[0], []).append(entry) - return buckets - - def _build_rows_for_shards( - entries: list[tuple[str, str, str, float, str]], - shard_infos: list[dict[str, Any]], - ) -> list[DatasetClipRow]: - entry_buckets = _consume_entries_by_path(entries) - built_rows: list[DatasetClipRow] = [] - for shard in shard_infos: - shard_path = Path(shard["path"]) - kept_paths = list(shard["kept_file_paths"]) - clip_lengths = 
list(shard.get("source_clip_lengths", shard["clip_lengths"])) - if len(kept_paths) != len(clip_lengths): - raise ValueError( - f"kept path count mismatch for {shard_path}: {len(kept_paths)} vs {len(clip_lengths)}" - ) - for clip_index, (kept_path, num_frames) in enumerate(zip(kept_paths, clip_lengths)): - candidates = entry_buckets.get(kept_path) - if not candidates: - raise ValueError(f"kept path not found in split entries: {kept_path}") - path, clip_id, source, weight, split = candidates.pop(0) - built_rows.append(DatasetClipRow( - clip_id=clip_id, - source=source, - file_rel=_display_path(Path(path)), - num_frames=int(num_frames), - fps=spec.target_fps, - resolved_split=split, - resolved_npz_path=str(shard_path), - weight=weight, - clip_index=clip_index, - )) - return built_rows - - # 3. Write manifest with correct num_frames and clip_index for kept clips only - rows: list[DatasetClipRow] = [] - train_rows = _build_rows_for_shards(train_entries, train_shards) - val_rows = _build_rows_for_shards(val_entries, val_shards) - rows.extend(train_rows) - rows.extend(val_rows) - manifest_path = write_manifest_resolved(rows, paths.dataset_dir) - - # 5. 
Build report - report: dict[str, Any] = { + return { "dataset": spec.name, - "built_at_utc": utc_now_iso(), "target_fps": spec.target_fps, - "val_percent": spec.val_percent, - "hash_salt": spec.hash_salt, "dataset_dir": str(paths.dataset_dir), - "build_dir": str(paths.dataset_dir), - "clips_dir": "", - "manifest_resolved": str(manifest_path), "skip_validate": bool(skip_validate), "skip_fk_check": bool(skip_fk_check), "jobs": int(jobs), "preprocess": spec.preprocess.to_dict(), - "sources": [asdict(source) for source in spec.sources], "source_filters": source_filter_reports, - "splits": { - "train": train_stats, - "val": val_stats, - }, - "clip_counts": { - "total": len(rows), - "train": len(train_rows), - "val": len(val_rows), - }, - "input_clip_counts": { - "total": len(clip_entries), - "train": len(train_entries), - "val": len(val_entries), - }, + "stats": stats, + "shards": [str(info["path"]) for info in shard_infos], + "input_clips": len(clip_entries), "fk_checks": [], } - write_json(paths.dataset_dir / "build_info.json", report) - return report def build_dataset_from_spec( @@ -1695,8 +1469,8 @@ def build_dataset_from_spec( jobs=jobs, ) - # Legacy per-file mode for bvh/npz sources - source_filter_reports = _build_source_filter_reports(spec) + # Per-file mode for BVH/NPZ sources. Converted clips are temporary build + # inputs; final training data is the minimal shard(s) in dataset_dir. 
convert_sources_to_npz(spec, paths=paths, force=force, jobs=jobs) rows = collect_clip_rows(spec, paths=paths) @@ -1704,130 +1478,36 @@ def build_dataset_from_spec( if not skip_fk_check: fk_checks = run_sample_fk_checks(rows) - train_rows = [row for row in rows if row.resolved_split == "train"] - val_rows = [row for row in rows if row.resolved_split == "val"] - train_files, train_weights = _rows_for_split(rows, "train") - val_files, val_weights = _rows_for_split(rows, "val") + _clear_existing_motion_shards(paths.dataset_dir) paths.dataset_dir.mkdir(parents=True, exist_ok=True) - train_dir = split_output_dir(paths.dataset_dir, "train") - val_dir = split_output_dir(paths.dataset_dir, "val") - train_out = shard_output_path(train_dir, 0) - val_out = shard_output_path(val_dir, 0) - train_tmp_npz = train_dir / ".merged_train_tmp.npz" - val_tmp_npz = val_dir / ".merged_val_tmp.npz" - - train_stats = merge_npz_files( - train_files, - train_tmp_npz, + shard_path = shard_output_path(paths.dataset_dir, 0) + tmp_npz = paths.dataset_dir / ".merged_tmp.npz" + stats = merge_npz_files( + [Path(row.resolved_npz_path) for row in rows], + tmp_npz, target_fps=spec.target_fps, - weights=train_weights, - ) - val_stats = merge_npz_files( - val_files, - val_tmp_npz, - target_fps=spec.target_fps, - weights=val_weights, ) + payload_npz = np.load(tmp_npz, allow_pickle=True) + payload = {key: payload_npz[key] for key in payload_npz.files} + shard_info = write_hdf5_motion_shard(payload, shard_path) + tmp_npz.unlink(missing_ok=True) - train_npz = np.load(train_tmp_npz, allow_pickle=True) - train_payload = {key: train_npz[key] for key in train_npz.files} - train_h5_info = write_hdf5_motion_shard(train_payload, train_out) - val_npz = np.load(val_tmp_npz, allow_pickle=True) - val_payload = {key: val_npz[key] for key in val_npz.files} - val_h5_info = write_hdf5_motion_shard(val_payload, val_out) - train_tmp_npz.unlink(missing_ok=True) - val_tmp_npz.unlink(missing_ok=True) - - train_stats["output"] = 
str(train_dir) - train_stats["shards"] = 1 - train_stats["clips"] = int(train_h5_info["clips"]) - train_stats["num_clips"] = int(train_h5_info["clips"]) - val_stats["output"] = str(val_dir) - val_stats["shards"] = 1 - val_stats["clips"] = int(val_h5_info["clips"]) - val_stats["num_clips"] = int(val_h5_info["clips"]) - - train_clip_lengths = np.asarray(train_payload["clip_lengths"]) - val_clip_lengths = np.asarray(val_payload["clip_lengths"]) - if len(train_rows) != len(train_clip_lengths): - raise ValueError( - f"train row count mismatch: {len(train_rows)} vs {len(train_clip_lengths)}" - ) - if len(val_rows) != len(val_clip_lengths): - raise ValueError( - f"val row count mismatch: {len(val_rows)} vs {len(val_clip_lengths)}" - ) + stats["output"] = str(paths.dataset_dir) + stats["shards"] = 1 + stats["clips"] = int(shard_info["clips"]) + stats["num_clips"] = int(shard_info["clips"]) - updated_rows: list[DatasetClipRow] = [] - for clip_index, (row, num_frames) in enumerate(zip(train_rows, train_clip_lengths)): - updated_rows.append( - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=int(num_frames), - fps=spec.target_fps, - resolved_split=row.resolved_split, - resolved_npz_path=str(train_out), - weight=row.weight, - clip_index=clip_index, - ) - ) - - write_hdf5_manifest( - train_dir, - shard_infos=[train_h5_info], - fps=spec.target_fps, - body_names=np.asarray(train_payload["body_names"]), - ) - write_hdf5_manifest( - val_dir, - shard_infos=[val_h5_info], - fps=spec.target_fps, - body_names=np.asarray(val_payload["body_names"]), - ) - for clip_index, (row, num_frames) in enumerate(zip(val_rows, val_clip_lengths)): - updated_rows.append( - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=int(num_frames), - fps=spec.target_fps, - resolved_split=row.resolved_split, - resolved_npz_path=str(val_out), - weight=row.weight, - clip_index=clip_index, - ) - ) - - manifest_path = 
write_manifest_resolved(updated_rows, paths.dataset_dir) - report: dict[str, Any] = { + return { "dataset": spec.name, - "built_at_utc": utc_now_iso(), "target_fps": spec.target_fps, - "val_percent": spec.val_percent, - "hash_salt": spec.hash_salt, "dataset_dir": str(paths.dataset_dir), - "build_dir": str(paths.dataset_dir), "clips_dir": str(paths.clips_root), - "manifest_resolved": str(manifest_path), "skip_validate": bool(skip_validate), "skip_fk_check": bool(skip_fk_check), "jobs": int(jobs), "preprocess": spec.preprocess.to_dict(), - "sources": [asdict(source) for source in spec.sources], - "source_filters": source_filter_reports, - "splits": { - "train": train_stats, - "val": val_stats, - }, - "clip_counts": { - "total": len(updated_rows), - "train": len(train_rows), - "val": len(val_rows), - }, + "stats": stats, + "shards": [str(shard_path)], + "input_clips": len(rows), "fk_checks": [item.to_dict() for item in fk_checks], } - write_json(paths.dataset_dir / "build_info.json", report) - return report diff --git a/train_mimic/data/dataset_lib.py b/train_mimic/data/dataset_lib.py index 8c11aff8..a61ef30e 100644 --- a/train_mimic/data/dataset_lib.py +++ b/train_mimic/data/dataset_lib.py @@ -3,7 +3,6 @@ from __future__ import annotations -import hashlib import json from dataclasses import dataclass from datetime import datetime, timezone @@ -15,6 +14,8 @@ REQUIRED_NPZ_KEYS = [ "fps", + "root_pos", + "root_quat_w", "joint_pos", "joint_vel", "body_pos_w", @@ -29,6 +30,17 @@ "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w", ] +MINIMAL_MOTION_ARRAY_KEYS = ["root_pos", "root_quat_w", "joint_pos"] +FULL_CLIP_ARRAY_KEYS = [ + "root_pos", + "root_quat_w", + "joint_pos", + "joint_vel", + "body_pos_w", + "body_quat_w", + "body_lin_vel_w", + "body_ang_vel_w", +] HDF5_DATASET_VERSION = 1 DEFAULT_HDF5_MAX_WINDOW_FRAMES = 512 DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES = 64 @@ -126,6 +138,8 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> 
NpzMeta: if fps <= 0: raise ValueError(f"invalid fps={fps}") + root_pos = np.asarray(payload["root_pos"]) + root_quat_w = np.asarray(payload["root_quat_w"]) joint_pos = np.asarray(payload["joint_pos"]) joint_vel = np.asarray(payload["joint_vel"]) body_pos_w = np.asarray(payload["body_pos_w"]) @@ -134,6 +148,10 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> NpzMeta: body_ang_vel_w = np.asarray(payload["body_ang_vel_w"]) body_names = np.asarray(payload["body_names"]) + if root_pos.ndim != 2 or root_pos.shape[1] != 3: + raise ValueError(f"root_pos must be (T,3), got {root_pos.shape}") + if root_quat_w.ndim != 2 or root_quat_w.shape[1] != 4: + raise ValueError(f"root_quat_w must be (T,4), got {root_quat_w.shape}") if joint_pos.ndim != 2 or joint_pos.shape[1] != NUM_ACTIONS: raise ValueError(f"joint_pos must be (T,{NUM_ACTIONS}), got {joint_pos.shape}") if joint_vel.ndim != 2 or joint_vel.shape != joint_pos.shape: @@ -153,6 +171,8 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> NpzMeta: raise ValueError("empty time/body dimension") for name, arr in [ + ("root_pos", root_pos), + ("root_quat_w", root_quat_w), ("joint_pos", joint_pos), ("joint_vel", joint_vel), ("body_pos_w", body_pos_w), @@ -170,6 +190,9 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> NpzMeta: if body_names.ndim != 1 or body_names.shape[0] != nb: raise ValueError(f"body_names must be (nb,), got {body_names.shape}") + root_quat_norm = np.linalg.norm(root_quat_w, axis=-1) + if np.min(root_quat_norm) < 1e-6: + raise ValueError("root_quat_w contains near-zero norms") quat_norm = np.linalg.norm(body_quat_w, axis=-1) if np.min(quat_norm) < 1e-6: raise ValueError("body_quat_w contains near-zero norms") @@ -188,12 +211,6 @@ def inspect_npz(path: Path) -> NpzMeta: return inspect_clip_dict({key: data[key] for key in data.files}) -def hash_split(clip_id: str, val_percent: int, salt: str = "") -> str: - payload = f"{salt}:{clip_id}".encode("utf-8") - bucket = 
int(hashlib.md5(payload).hexdigest(), 16) % 100 - return "val" if bucket < val_percent else "train" - - def resample_along_time(arr: np.ndarray, new_t: int) -> np.ndarray: """Resample an array along time axis 0 using linear interpolation.""" old_t = int(arr.shape[0]) @@ -219,29 +236,16 @@ def merge_npz_files( output_path: Path, *, target_fps: int | None = None, - weights: list[float] | None = None, ) -> dict[str, Any]: if not npz_files: raise ValueError("no npz files to merge") - if weights is not None and len(weights) != len(npz_files): - raise ValueError( - f"weights length ({len(weights)}) != npz_files length ({len(npz_files)})" - ) - arrays: dict[str, list[np.ndarray]] = { - "joint_pos": [], - "joint_vel": [], - "body_pos_w": [], - "body_quat_w": [], - "body_lin_vel_w": [], - "body_ang_vel_w": [], - } + arrays: dict[str, list[np.ndarray]] = {key: [] for key in FULL_CLIP_ARRAY_KEYS} fps: int | None = None body_names: np.ndarray[Any, Any] | None = None per_clip_fps: list[int] = [] - per_clip_weights: list[float] = [] - for i, p in enumerate(npz_files): + for p in npz_files: d = np.load(p, allow_pickle=True) cur_fps = int(d["fps"]) if cur_fps <= 0: @@ -250,6 +254,8 @@ def merge_npz_files( if target_fps is not None and target_fps <= 0: raise ValueError(f"target_fps must be > 0, got {target_fps}") + root_pos = np.asarray(d["root_pos"]) + root_quat_w = np.asarray(d["root_quat_w"]) joint_pos = np.asarray(d["joint_pos"]) joint_vel = np.asarray(d["joint_vel"]) body_pos_w = np.asarray(d["body_pos_w"]) @@ -262,6 +268,8 @@ def merge_npz_files( new_t = int(round(old_t * float(target_fps) / float(cur_fps))) new_t = max(new_t, 1) + root_pos = resample_along_time(root_pos, new_t) + root_quat_w = resample_along_time(root_quat_w, new_t) joint_pos = resample_along_time(joint_pos, new_t) joint_vel = resample_along_time(joint_vel, new_t) body_pos_w = resample_along_time(body_pos_w, new_t) @@ -269,6 +277,9 @@ def merge_npz_files( body_lin_vel_w = resample_along_time(body_lin_vel_w, 
new_t) body_ang_vel_w = resample_along_time(body_ang_vel_w, new_t) + root_quat_norm = np.linalg.norm(root_quat_w, axis=-1, keepdims=True) + root_quat_norm = np.where(root_quat_norm < 1e-8, 1.0, root_quat_norm) + root_quat_w = root_quat_w / root_quat_norm quat_norm = np.linalg.norm(body_quat_w, axis=-1, keepdims=True) quat_norm = np.where(quat_norm < 1e-8, 1.0, quat_norm) body_quat_w = body_quat_w / quat_norm @@ -284,8 +295,9 @@ def merge_npz_files( raise ValueError(f"inconsistent body_names in merge: {p}") per_clip_fps.append(cur_fps) - per_clip_weights.append(weights[i] if weights is not None else 1.0) + arrays["root_pos"].append(root_pos) + arrays["root_quat_w"].append(root_quat_w) arrays["joint_pos"].append(joint_pos) arrays["joint_vel"].append(joint_vel) arrays["body_pos_w"].append(body_pos_w) @@ -307,7 +319,6 @@ def merge_npz_files( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths merged["clip_fps"] = np.array(per_clip_fps, dtype=np.int64) - merged["clip_weights"] = np.array(per_clip_weights, dtype=np.float64) output_path.parent.mkdir(parents=True, exist_ok=True) np.savez(output_path, **merged) @@ -328,13 +339,11 @@ def merge_clip_dicts( output_path: Path, *, target_fps: int | None = None, - weights: list[float] | None = None, ) -> dict[str, Any]: """Merge a list of in-memory clip array dicts into a single NPZ.""" merged = merge_clip_dicts_payload( clip_dicts, target_fps=target_fps, - weights=weights, ) output_path.parent.mkdir(parents=True, exist_ok=True) np.savez(output_path, **merged) @@ -354,7 +363,6 @@ def merge_clip_dicts_payload( clip_dicts: list[dict[str, Any]], *, target_fps: int | None = None, - weights: list[float] | None = None, ) -> dict[str, Any]: """Merge in-memory clip array dicts and return a flat motion payload. 
@@ -363,22 +371,14 @@ def merge_clip_dicts_payload( """ if not clip_dicts: raise ValueError("no clip dicts to merge") - if weights is not None and len(weights) != len(clip_dicts): - raise ValueError( - f"weights length ({len(weights)}) != clip_dicts length ({len(clip_dicts)})" - ) if target_fps is not None and target_fps <= 0: raise ValueError(f"target_fps must be > 0, got {target_fps}") - array_keys = [ - "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", - "body_lin_vel_w", "body_ang_vel_w", - ] + array_keys = FULL_CLIP_ARRAY_KEYS arrays: dict[str, list[np.ndarray]] = {k: [] for k in array_keys} fps: int | None = None body_names: np.ndarray | None = None per_clip_fps: list[int] = [] - per_clip_weights: list[float] = [] for i, cd in enumerate(clip_dicts): cur_fps = int(cd["fps"]) @@ -391,6 +391,8 @@ def merge_clip_dicts_payload( new_t = max(1, round(old_t * target_fps / cur_fps)) for k in array_keys: clip_arrays[k] = resample_along_time(clip_arrays[k], new_t) + qn_root = np.linalg.norm(clip_arrays["root_quat_w"], axis=-1, keepdims=True) + clip_arrays["root_quat_w"] = clip_arrays["root_quat_w"] / np.where(qn_root < 1e-8, 1.0, qn_root) qn = np.linalg.norm(clip_arrays["body_quat_w"], axis=-1, keepdims=True) clip_arrays["body_quat_w"] = clip_arrays["body_quat_w"] / np.where(qn < 1e-8, 1.0, qn) cur_fps = target_fps @@ -404,7 +406,6 @@ def merge_clip_dicts_payload( raise ValueError(f"inconsistent body_names: clip {i}") per_clip_fps.append(cur_fps) - per_clip_weights.append(weights[i] if weights is not None else 1.0) for k in array_keys: arrays[k].append(clip_arrays[k]) @@ -419,7 +420,6 @@ def merge_clip_dicts_payload( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths merged["clip_fps"] = np.array(per_clip_fps, dtype=np.int64) - merged["clip_weights"] = np.array(per_clip_weights, dtype=np.float64) return merged @@ -457,13 +457,15 @@ def write_hdf5_motion_shard( max_window_frames: int = DEFAULT_HDF5_MAX_WINDOW_FRAMES, overlap_frames: int = 
DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, ) -> dict[str, Any]: - """Write a merged motion payload as one HDF5 shard with bounded windows. + """Write a merged motion payload as one minimal HDF5 shard. - The frame arrays remain flat in the HDF5 file. ``clip_starts`` and + The frame arrays remain flat in the HDF5 file. ``clip_starts`` and ``clip_lengths`` describe training windows, not necessarily original clips. Long clips are split into overlapping windows to bound runtime cache size. + Runtime loading derives joint velocities and body FK/velocities from the + stored root pose and joint positions. """ - missing = [key for key in [*MOTION_ARRAY_KEYS, "fps", "body_names", "clip_starts", "clip_lengths", "clip_fps", "clip_weights"] if key not in merged] + missing = [key for key in [*MINIMAL_MOTION_ARRAY_KEYS, "fps", "body_names", "clip_starts", "clip_lengths", "clip_fps"] if key not in merged] if missing: raise ValueError(f"merged payload missing required keys: {missing}") @@ -474,12 +476,10 @@ def write_hdf5_motion_shard( original_starts = np.asarray(merged["clip_starts"], dtype=np.int64) original_lengths = np.asarray(merged["clip_lengths"], dtype=np.int64) original_fps = np.asarray(merged["clip_fps"], dtype=np.int64) - original_weights = np.asarray(merged["clip_weights"], dtype=np.float64) window_starts: list[int] = [] window_lengths: list[int] = [] window_fps: list[int] = [] - window_weights: list[float] = [] source_clip_ids: list[int] = [] source_start_frames: list[int] = [] for source_idx, (clip_start, clip_length) in enumerate(zip(original_starts, original_lengths)): @@ -489,12 +489,10 @@ def write_hdf5_motion_shard( max_window_frames=max_window_frames, overlap_frames=overlap_frames, ) - per_window_weight = float(original_weights[source_idx]) / float(len(ranges)) for start, length in ranges: window_starts.append(int(start)) window_lengths.append(int(length)) window_fps.append(int(original_fps[source_idx])) - window_weights.append(per_window_weight) 
source_clip_ids.append(int(source_idx)) source_start_frames.append(int(start - int(clip_start))) @@ -507,19 +505,17 @@ def write_hdf5_motion_shard( h5.attrs["max_window_frames"] = int(max_window_frames) h5.attrs["overlap_frames"] = int(overlap_frames) h5.create_dataset("body_names", data=body_names.astype(object), dtype=str_dt) - for key in MOTION_ARRAY_KEYS: + for key in MINIMAL_MOTION_ARRAY_KEYS: arr = np.asarray(merged[key], dtype=np.float32) h5.create_dataset(key, data=arr, chunks=True) h5.create_dataset("clip_starts", data=np.asarray(window_starts, dtype=np.int64)) h5.create_dataset("clip_lengths", data=np.asarray(window_lengths, dtype=np.int64)) h5.create_dataset("clip_fps", data=np.asarray(window_fps, dtype=np.int64)) - h5.create_dataset("clip_weights", data=np.asarray(window_weights, dtype=np.float64)) h5.create_dataset("source_clip_ids", data=np.asarray(source_clip_ids, dtype=np.int64)) h5.create_dataset("source_start_frames", data=np.asarray(source_start_frames, dtype=np.int64)) h5.create_dataset("source_clip_starts", data=original_starts.astype(np.int64)) h5.create_dataset("source_clip_lengths", data=original_lengths.astype(np.int64)) h5.create_dataset("source_clip_fps", data=original_fps.astype(np.int64)) - h5.create_dataset("source_clip_weights", data=original_weights.astype(np.float64)) total_frames = int(np.asarray(merged["joint_pos"]).shape[0]) return { @@ -552,7 +548,7 @@ def read_motion_clip(path: Path, clip_index: int) -> dict[str, Any]: if path.suffix == ".h5": return read_hdf5_source_clip(path, clip_index) raise ValueError( - f"review/rebuild input must be a current HDF5 shard (.h5), got: {path}" + f"source clip input must be a current HDF5 shard (.h5), got: {path}" ) @@ -565,7 +561,7 @@ def read_hdf5_source_clip(path: Path, clip_index: int) -> dict[str, Any]: "source_clip_lengths", "source_clip_fps", "body_names", - *MOTION_ARRAY_KEYS, + *MINIMAL_MOTION_ARRAY_KEYS, ] missing = [key for key in required if key not in h5] if missing: @@ -583,64 
+579,120 @@ def read_hdf5_source_clip(path: Path, clip_index: int) -> dict[str, Any]: start = int(starts[clip_index]) length = int(lengths[clip_index]) sl = slice(start, start + length) - return { - "fps": int(fps[clip_index]), - "joint_pos": np.asarray(h5["joint_pos"][sl], dtype=np.float32), - "joint_vel": np.asarray(h5["joint_vel"][sl], dtype=np.float32), - "body_pos_w": np.asarray(h5["body_pos_w"][sl], dtype=np.float32), - "body_quat_w": np.asarray(h5["body_quat_w"][sl], dtype=np.float32), - "body_lin_vel_w": np.asarray(h5["body_lin_vel_w"][sl], dtype=np.float32), - "body_ang_vel_w": np.asarray(h5["body_ang_vel_w"][sl], dtype=np.float32), - "body_names": np.asarray(read_hdf5_body_names(path), dtype=str), - } - - -def write_hdf5_manifest( - split_dir: Path, - *, - shard_infos: Sequence[Mapping[str, Any]], - fps: int, - body_names: Sequence[str] | np.ndarray, -) -> Path: - shards = [] + root_pos = np.asarray(h5["root_pos"][sl], dtype=np.float32) + root_quat_w = np.asarray(h5["root_quat_w"][sl], dtype=np.float32) + joint_pos = np.asarray(h5["joint_pos"][sl], dtype=np.float32) + body_names = np.asarray(read_hdf5_body_names(path), dtype=str) + + from train_mimic.data.motion_fk import MotionFkExtractor, compute_body_velocities, finite_diff_velocity + + dt = 1.0 / float(fps[clip_index]) + joint_vel = finite_diff_velocity(joint_pos, dt) + fk = MotionFkExtractor() + body_pos_w, body_quat_w = fk.extract(root_pos, root_quat_w, joint_pos, body_names) + body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) + return { + "fps": int(fps[clip_index]), + "root_pos": root_pos, + "root_quat_w": root_quat_w, + "joint_pos": joint_pos, + "joint_vel": joint_vel, + "body_pos_w": body_pos_w, + "body_quat_w": body_quat_w, + "body_lin_vel_w": body_lin_vel_w, + "body_ang_vel_w": body_ang_vel_w, + "body_names": body_names, + } + + +def find_motion_shards(dataset_dir: str | Path) -> list[Path]: + """Recursively find Teleopit HDF5 motion shards under a root 
directory.""" + root = Path(dataset_dir).expanduser().resolve() + if root.is_file(): + candidates = [root] + elif root.is_dir(): + candidates = sorted(root.rglob("*.h5")) + else: + raise FileNotFoundError(f"motion dataset path not found: {dataset_dir}") + + shards: list[Path] = [] + for path in candidates: + try: + with h5py.File(path, "r") as h5: + if h5.attrs.get("format") == "teleopit_motion_hdf5": + shards.append(path) + except OSError: + continue + if not shards: + raise FileNotFoundError(f"no Teleopit HDF5 motion shards found under {dataset_dir}") + return shards + + +def compute_dataset_stats(dataset_dir: str | Path) -> dict[str, Any]: + shards = find_motion_shards(dataset_dir) total_windows = 0 total_frames = 0 - expected_body_names = [str(name) for name in np.asarray(body_names).tolist()] - for info in shard_infos: - path = Path(str(info["path"])) - shard_path = path if path.is_absolute() else split_dir / path - if shard_path.is_file(): - actual_body_names = read_hdf5_body_names(shard_path) - if actual_body_names != expected_body_names: + total_source_clips = 0 + fps_values: set[int] = set() + body_names_ref: list[str] | None = None + shard_rows: list[dict[str, Any]] = [] + + for shard_path in shards: + with h5py.File(shard_path, "r") as h5: + missing = [ + key for key in [ + *MINIMAL_MOTION_ARRAY_KEYS, + "body_names", + "clip_starts", + "clip_lengths", + "clip_fps", + ] + if key not in h5 + ] + if missing: + raise ValueError(f"HDF5 shard {shard_path} missing required datasets: {missing}") + body_names = read_hdf5_body_names(shard_path) + if body_names_ref is None: + body_names_ref = body_names + elif body_names != body_names_ref: raise ValueError( - f"HDF5 shard body_names mismatch for {shard_path}: " - "all shards in a split must use the same body order" + f"inconsistent body_names in {shard_path}; all shards under one training root must match" ) - if path.is_absolute(): - rel_path = path.name if path.parent == split_dir else 
str(path.relative_to(split_dir)) - else: - rel_path = str(path) - clips = int(info.get("clips", info.get("num_clips", 0))) - frames = int(info.get("frames", 0)) - total_windows += clips - total_frames += frames - shards.append({ - "path": rel_path, - "clips": clips, - "frames": frames, - }) - manifest = { + lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) + fps_arr = np.asarray(h5["clip_fps"], dtype=np.int64) + source_ids = ( + np.asarray(h5["source_clip_ids"], dtype=np.int64) + if "source_clip_ids" in h5 + else np.arange(lengths.shape[0], dtype=np.int64) + ) + fps_values.update(int(v) for v in np.unique(fps_arr)) + windows = int(lengths.shape[0]) + frames = int(np.asarray(h5["joint_pos"]).shape[0]) + source_clips = int(len(np.unique(source_ids))) + total_windows += windows + total_frames += frames + total_source_clips += source_clips + shard_rows.append({ + "path": str(shard_path), + "windows": windows, + "source_clips": source_clips, + "frames": frames, + "min_window_frames": int(lengths.min()) if windows else 0, + "max_window_frames": int(lengths.max()) if windows else 0, + }) + + return { "format": "teleopit_motion_hdf5", "version": HDF5_DATASET_VERSION, - "fps": int(fps), - "body_names": expected_body_names, - "shards": shards, - "clips": total_windows, + "root": str(Path(dataset_dir).expanduser().resolve()), + "shards": len(shards), + "windows": total_windows, + "source_clips": total_source_clips, "frames": total_frames, + "fps": sorted(fps_values), + "body_names": body_names_ref or [], + "shard_details": shard_rows, } - path = split_dir / "manifest.json" - write_json(path, manifest) - return path def write_json(path: Path, payload: dict[str, Any]) -> None: diff --git a/train_mimic/data/motion_fk.py b/train_mimic/data/motion_fk.py index 3b03c671..fd7f2d4d 100644 --- a/train_mimic/data/motion_fk.py +++ b/train_mimic/data/motion_fk.py @@ -33,6 +33,21 @@ def normalize_quaternion(q: np.ndarray) -> np.ndarray: return arr / norms +def 
finite_diff_velocity(x: np.ndarray, dt: float) -> np.ndarray: + """Finite-difference velocity with central interior and one-sided edges.""" + arr = np.asarray(x, dtype=np.float32) + vel = np.zeros_like(arr, dtype=np.float32) + if dt <= 0.0: + raise ValueError(f"dt must be > 0, got {dt}") + if arr.shape[0] < 2: + return vel + inv_dt = 1.0 / float(dt) + vel[1:-1] = (arr[2:] - arr[:-2]) * (0.5 * inv_dt) + vel[0] = (arr[1] - arr[0]) * inv_dt + vel[-1] = (arr[-1] - arr[-2]) * inv_dt + return vel + + def quat_multiply(q1: np.ndarray, q2: np.ndarray) -> np.ndarray: """Multiply two wxyz quaternions: q1 * q2.""" q1 = np.asarray(q1, dtype=np.float32) @@ -75,9 +90,19 @@ def quat_rotate_inverse(q: np.ndarray, v: np.ndarray) -> np.ndarray: def quat_to_angular_velocity(q: np.ndarray, dt: float) -> np.ndarray: """Compute angular velocity from quaternion sequence in wxyz convention.""" - q = normalize_quaternion(q) - q_dot = np.gradient(q, dt, axis=0) - product = quat_multiply(q_dot, quat_conjugate(q)) + quat = normalize_quaternion(q) + if quat.shape[0] < 2: + return np.zeros(quat.shape[:-1] + (3,), dtype=np.float32) + + flat = quat.reshape(quat.shape[0], -1, 4) + dots = np.sum(flat[1:] * flat[:-1], axis=-1) + signs = np.where(dots < 0.0, -1.0, 1.0).astype(np.float32) + signs = np.concatenate([np.ones_like(signs[:1]), signs], axis=0) + flat = flat * np.cumprod(signs, axis=0)[..., None] + quat = flat.reshape(quat.shape) + + q_dot = finite_diff_velocity(quat, dt) + product = quat_multiply(q_dot, quat_conjugate(quat)) return (2.0 * product[..., 1:4]).astype(np.float32) @@ -192,9 +217,7 @@ def compute_body_velocities( """Compute linear and angular velocity sequences from FK outputs.""" body_pos_w = np.asarray(body_pos_w, dtype=np.float32) body_quat_w = normalize_quaternion(body_quat_w) - if dt <= 0.0: - raise ValueError(f"dt must be > 0, got {dt}") - body_lin_vel_w = np.gradient(body_pos_w, dt, axis=0).astype(np.float32) + body_lin_vel_w = finite_diff_velocity(body_pos_w, dt) 
body_ang_vel_w = quat_to_angular_velocity(body_quat_w, dt).astype(np.float32) return body_lin_vel_w, body_ang_vel_w diff --git a/train_mimic/data/preprocess.py b/train_mimic/data/preprocess.py index f6558805..89cad31f 100644 --- a/train_mimic/data/preprocess.py +++ b/train_mimic/data/preprocess.py @@ -101,6 +101,8 @@ def preprocess_clip_dict( result = { "fps": int(payload["fps"]), + "root_pos": np.asarray(payload["root_pos"]).copy(), + "root_quat_w": np.asarray(payload["root_quat_w"]).copy(), "joint_pos": np.asarray(payload["joint_pos"]).copy(), "joint_vel": np.asarray(payload["joint_vel"]).copy(), "body_pos_w": np.asarray(payload["body_pos_w"]).copy(), @@ -112,6 +114,7 @@ def preprocess_clip_dict( inspect_clip_dict(result) fps = int(result["fps"]) + root_pos = result["root_pos"] body_pos_w = result["body_pos_w"] body_lin_vel_w = result["body_lin_vel_w"] body_names = np.asarray(result["body_names"]) @@ -129,6 +132,8 @@ def preprocess_clip_dict( if spec.normalize_root_xy: assert root_index is not None offset_xy = body_pos_w[0, root_index, :2].copy() + root_pos[:, 0] -= offset_xy[0] + root_pos[:, 1] -= offset_xy[1] body_pos_w[..., 0] -= offset_xy[0] body_pos_w[..., 1] -= offset_xy[1] @@ -172,7 +177,9 @@ def preprocess_clip_dict( if spec.ground_align == "first_frame_foot": assert foot_indices is not None foot_z = body_pos_w[:, foot_indices, 2] - body_pos_w[..., 2] -= float(np.min(foot_z[0])) + dz = float(np.min(foot_z[0])) + root_pos[:, 2] -= dz + body_pos_w[..., 2] -= dz if spec.min_peak_body_height is not None: peak_height = float(np.max(body_pos_w[:, :, 2])) diff --git a/train_mimic/data/review_lib.py b/train_mimic/data/review_lib.py deleted file mode 100644 index e77a3700..00000000 --- a/train_mimic/data/review_lib.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Shared utilities for dataset review: data model, I/O, and statistics.""" - -from __future__ import annotations - -import csv -import os -from dataclasses import dataclass, fields -from datetime import datetime, 
timezone -from pathlib import Path -from typing import Any - - -REVIEW_COLUMNS = [ - "clip_id", - "source", - "file_rel", - "resolved_npz_path", - "resolved_split", - "num_frames", - "fps", - "duration_s", - "weight", - "clip_index", - "decision", - "difficulty", - "issue_tags", - "note", - "reviewed_at", -] - -VALID_DECISIONS = {"keep", "drop", "skip", ""} -VALID_DIFFICULTIES = {"easy", "medium", "hard", "bad_data", ""} - - -@dataclass -class ReviewRow: - clip_id: str - source: str - file_rel: str - resolved_npz_path: str - resolved_split: str - num_frames: int - fps: int - duration_s: float - weight: float = 1.0 - clip_index: int = -1 # source-clip index inside HDF5 shard; -1 = standalone source clip - decision: str = "" - difficulty: str = "" - issue_tags: str = "" - note: str = "" - reviewed_at: str = "" - - -@dataclass(frozen=True) -class ReviewStats: - total: int - reviewed: int - keep_count: int - drop_count: int - skip_count: int - progress_pct: float - kept_duration_s: float - kept_train_duration_s: float - kept_val_duration_s: float - kept_duration_by_source: dict[str, float] - - -def load_review_state(path: Path) -> list[ReviewRow]: - """Load review_state.csv and return a list of ReviewRow.""" - if not path.is_file(): - raise FileNotFoundError(f"review state not found: {path}") - - with path.open("r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - if reader.fieldnames is None: - raise ValueError(f"review state is empty: {path}") - optional = {"clip_index"} - required = [c for c in REVIEW_COLUMNS if c not in optional] - missing = [c for c in required if c not in reader.fieldnames] - if missing: - raise ValueError(f"review state missing columns: {missing}") - has_clip_index = "clip_index" in reader.fieldnames - - rows: list[ReviewRow] = [] - for idx, raw in enumerate(reader, start=2): - try: - row = ReviewRow( - clip_id=raw["clip_id"].strip(), - source=raw["source"].strip(), - file_rel=raw["file_rel"].strip(), - 
resolved_npz_path=raw["resolved_npz_path"].strip(), - resolved_split=raw["resolved_split"].strip(), - num_frames=int(raw["num_frames"]), - fps=int(raw["fps"]), - duration_s=float(raw["duration_s"]), - weight=float(raw["weight"]), - clip_index=int(raw["clip_index"]) if has_clip_index else -1, - decision=raw["decision"].strip(), - difficulty=raw["difficulty"].strip(), - issue_tags=raw["issue_tags"].strip(), - note=raw["note"].strip(), - reviewed_at=raw["reviewed_at"].strip(), - ) - except Exception as exc: - raise ValueError(f"line {idx}: {exc}") from exc - - if row.decision not in VALID_DECISIONS: - raise ValueError(f"line {idx}: invalid decision '{row.decision}'") - if row.difficulty not in VALID_DIFFICULTIES: - raise ValueError(f"line {idx}: invalid difficulty '{row.difficulty}'") - rows.append(row) - return rows - - -def save_review_state(rows: list[ReviewRow], path: Path) -> None: - """Atomically write review_state.csv (write to .tmp, then os.replace).""" - path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(".csv.tmp") - with tmp_path.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow(REVIEW_COLUMNS) - for row in rows: - writer.writerow([ - row.clip_id, - row.source, - row.file_rel, - row.resolved_npz_path, - row.resolved_split, - row.num_frames, - row.fps, - f"{row.duration_s:.4f}", - row.weight, - row.clip_index, - row.decision, - row.difficulty, - row.issue_tags, - row.note, - row.reviewed_at, - ]) - os.replace(tmp_path, path) - - -def compute_review_stats(rows: list[ReviewRow]) -> ReviewStats: - """Compute aggregate statistics from review rows.""" - total = len(rows) - reviewed = sum(1 for r in rows if r.decision != "") - keep_count = sum(1 for r in rows if r.decision == "keep") - drop_count = sum(1 for r in rows if r.decision == "drop") - skip_count = sum(1 for r in rows if r.decision == "skip") - - kept_duration_s = 0.0 - kept_train_duration_s = 0.0 - kept_val_duration_s = 0.0 - 
kept_duration_by_source: dict[str, float] = {} - - for r in rows: - if r.decision != "keep": - continue - kept_duration_s += r.duration_s - if r.resolved_split == "train": - kept_train_duration_s += r.duration_s - elif r.resolved_split == "val": - kept_val_duration_s += r.duration_s - kept_duration_by_source[r.source] = ( - kept_duration_by_source.get(r.source, 0.0) + r.duration_s - ) - - progress_pct = (reviewed / total * 100.0) if total > 0 else 0.0 - - return ReviewStats( - total=total, - reviewed=reviewed, - keep_count=keep_count, - drop_count=drop_count, - skip_count=skip_count, - progress_pct=progress_pct, - kept_duration_s=kept_duration_s, - kept_train_duration_s=kept_train_duration_s, - kept_val_duration_s=kept_val_duration_s, - kept_duration_by_source=kept_duration_by_source, - ) - - -def utc_now_iso() -> str: - return datetime.now(timezone.utc).isoformat() diff --git a/train_mimic/scripts/benchmark.py b/train_mimic/scripts/benchmark.py index 557df54b..005b243f 100644 --- a/train_mimic/scripts/benchmark.py +++ b/train_mimic/scripts/benchmark.py @@ -10,7 +10,7 @@ # Benchmark only (no video) python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_tracking/.../model_30000.pt \ - --motion_file data/datasets/twist2/val \ + --motion_file data/datasets/twist2 \ --num_envs 1 # Single video (one continuous clip) @@ -40,6 +40,7 @@ validate_checkpoint_path, validate_motion_file, ) +from train_mimic.data.dataset_lib import find_motion_shards from teleopit.debug.rollout_trace import RolloutTraceWriter @@ -143,7 +144,7 @@ def _stats(values: list[float]) -> dict[str, float]: def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Benchmark G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True) - parser.add_argument("--motion_file", type=str, required=True, help="Path to HDF5 motion shard directory") + parser.add_argument("--motion_file", type=str, required=True, help="Path to dataset root containing 
Teleopit shard_*.h5 files") parser.add_argument("--num_envs", type=int, default=1) parser.add_argument("--num_eval_steps", type=int, default=2000, help="Number of rollout steps for evaluation (default: 2000)") @@ -173,19 +174,9 @@ def parse_args() -> argparse.Namespace: def _load_motion_dir_video_metadata(motion_dir: str) -> tuple[float, int]: - shard_dir = Path(motion_dir) - manifest_path = shard_dir / "manifest.json" - if not manifest_path.is_file(): - raise FileNotFoundError(f"no HDF5 manifest.json found in {motion_dir}") - with manifest_path.open("r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("format") != "teleopit_motion_hdf5": - raise ValueError(f"unsupported motion manifest format in {manifest_path}") - clip_fps: float | None = None max_clip_frames = 0 - for shard in manifest.get("shards", []): - shard_path = shard_dir / str(shard["path"]) + for shard_path in find_motion_shards(motion_dir): with h5py.File(shard_path, "r") as h5: fps_arr = np.asarray(h5["clip_fps"], dtype=np.float32) if fps_arr.size == 0: diff --git a/train_mimic/scripts/convert_pkl_to_npz.py b/train_mimic/scripts/convert_pkl_to_npz.py index 8edd0aa2..382d3744 100644 --- a/train_mimic/scripts/convert_pkl_to_npz.py +++ b/train_mimic/scripts/convert_pkl_to_npz.py @@ -47,6 +47,7 @@ from train_mimic.data.motion_fk import ( MotionFkExtractor, compute_body_velocities, + finite_diff_velocity, normalize_quaternion, quat_xyzw_to_wxyz, ) @@ -131,7 +132,7 @@ def convert_pkl_to_arrays( ) # Joint velocity via finite difference - joint_vel = np.gradient(dof_pos, dt, axis=0).astype(np.float32) + joint_vel = finite_diff_velocity(dof_pos, dt) # Convert root quaternion: xyzw -> wxyz root_rot_wxyz = normalize_quaternion(quat_xyzw_to_wxyz(root_rot_xyzw)) @@ -148,6 +149,8 @@ def convert_pkl_to_arrays( return { "fps": fps, + "root_pos": root_pos, + "root_quat_w": root_rot_wxyz, "joint_pos": dof_pos, "joint_vel": joint_vel, "body_pos_w": body_pos_w, @@ -217,7 +220,7 @@ def 
convert_seed_csv_to_arrays( root_rot_xyzw = np.asarray(pkl_dict["root_rot"], dtype=np.float32) dof_pos = np.asarray(pkl_dict["dof_pos"], dtype=np.float32) - joint_vel = np.gradient(dof_pos, dt, axis=0).astype(np.float32) + joint_vel = finite_diff_velocity(dof_pos, dt) root_rot_wxyz = normalize_quaternion(quat_xyzw_to_wxyz(root_rot_xyzw)) fk_extractor = extractor or MotionFkExtractor() @@ -229,6 +232,8 @@ def convert_seed_csv_to_arrays( return { "fps": fps, + "root_pos": root_pos, + "root_quat_w": root_rot_wxyz, "joint_pos": dof_pos, "joint_vel": joint_vel, "body_pos_w": body_pos_w, diff --git a/train_mimic/scripts/data/build_dataset.py b/train_mimic/scripts/data/build_dataset.py index 74d1739a..0f9c4b2f 100644 --- a/train_mimic/scripts/data/build_dataset.py +++ b/train_mimic/scripts/data/build_dataset.py @@ -41,9 +41,8 @@ def main() -> int: output_root=args.output_root, ) print(f"[DONE] dataset={report['dataset']}") - print(f"[DONE] train={report['splits']['train']['output']}") - print(f"[DONE] val={report['splits']['val']['output']}") - print(f"[DONE] build_info={report['dataset_dir']}/build_info.json") + print(f"[DONE] output={report['dataset_dir']}") + print(f"[DONE] shards={len(report.get('shards', []))}") if args.json: print(json.dumps(report, ensure_ascii=False, indent=2)) return 0 diff --git a/train_mimic/scripts/data/ingest_motion.py b/train_mimic/scripts/data/ingest_motion.py index b8953d52..59c6c534 100644 --- a/train_mimic/scripts/data/ingest_motion.py +++ b/train_mimic/scripts/data/ingest_motion.py @@ -26,7 +26,6 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--input", required=True, help="Input file or directory") parser.add_argument("--output", required=True, help="Output clips directory") parser.add_argument("--source", default="source", help="Logical source name used in logs") - parser.add_argument("--weight", type=float, default=1.0, help="Optional source weight metadata") parser.add_argument("--bvh_format", choices=["lafan1", 
"hc_mocap", "nokov"], default=None) parser.add_argument("--robot_name", default="unitree_g1", help="Robot name for BVH retargeting") parser.add_argument("--max_frames", type=int, default=0, help="Max frames per BVH clip (0 = all)") @@ -42,7 +41,6 @@ def main() -> int: name=args.source, type=args.type, input=args.input, - weight=float(args.weight), bvh_format=args.bvh_format, robot_name=args.robot_name, max_frames=int(args.max_frames), diff --git a/train_mimic/scripts/data/inspect_dataset.py b/train_mimic/scripts/data/inspect_dataset.py new file mode 100644 index 00000000..001c3ed4 --- /dev/null +++ b/train_mimic/scripts/data/inspect_dataset.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json + +from train_mimic.data.dataset_lib import compute_dataset_stats + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Inspect a Teleopit motion dataset root.") + parser.add_argument("dataset", type=str, help="Dataset root directory or a single .h5 shard") + parser.add_argument("--json", action="store_true", help="Print full JSON statistics") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + stats = compute_dataset_stats(args.dataset) + if args.json: + print(json.dumps(stats, ensure_ascii=False, indent=2)) + return 0 + + print(f"root: {stats['root']}") + print(f"shards: {stats['shards']}") + print(f"windows: {stats['windows']}") + print(f"source_clips: {stats['source_clips']}") + print(f"frames: {stats['frames']}") + print(f"fps: {stats['fps']}") + print(f"bodies: {len(stats['body_names'])}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/train_mimic/scripts/play.py b/train_mimic/scripts/play.py index f64e2359..26546f22 100644 --- a/train_mimic/scripts/play.py +++ b/train_mimic/scripts/play.py @@ -9,18 +9,18 @@ # Native window python train_mimic/scripts/play.py \ --checkpoint 
logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2/val + --motion_file data/datasets/twist2 # Browser viewer (no display required) python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2/val \ + --motion_file data/datasets/twist2 \ --viewer viser # Record video instead of interactive viewer python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2/val \ + --motion_file data/datasets/twist2 \ --video """ @@ -44,7 +44,7 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Play trained G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") - parser.add_argument("--motion_file", type=str, required=True, help="Path to HDF5 motion shard directory") + parser.add_argument("--motion_file", type=str, required=True, help="Path to dataset root containing Teleopit shard_*.h5 files") parser.add_argument("--num_envs", type=int, default=1) parser.add_argument( "--viewer", type=str, default="native", choices=["native", "viser"], diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 194f4759..3e4f6a30 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -4,30 +4,30 @@ Usage: python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 18000 \ - --motion_file data/datasets/twist2/train + --motion_file data/datasets/twist2 # Quick verification python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets/twist2/train + --motion_file data/datasets/twist2 # With W&B logging python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 30000 \ - --motion_file data/datasets/twist2/train \ + --motion_file data/datasets/twist2 \ --logger wandb # With SwanLab logging python 
train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 30000 \ - --motion_file data/datasets/twist2/train \ + --motion_file data/datasets/twist2 \ --logger swanlab # Resume for additional iterations python train_mimic/scripts/train.py \ --resume logs/rsl_rl/g1_general_tracking//model_12000.pt \ --max_iterations 18000 \ - --motion_file data/datasets/twist2/train + --motion_file data/datasets/twist2 """ from __future__ import annotations @@ -74,7 +74,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ) parser.add_argument("--experiment_name", type=str, default=None) parser.add_argument("--motion_file", type=str, default=None, - help="HDF5 shard directory path containing manifest.json and shard_*.h5 files") + help="Dataset root containing Teleopit shard_*.h5 files, searched recursively") parser.add_argument( "--resume", type=str, diff --git a/train_mimic/tasks/tracking/config/constants.py b/train_mimic/tasks/tracking/config/constants.py index a96cf8b1..cdba0dc9 100644 --- a/train_mimic/tasks/tracking/config/constants.py +++ b/train_mimic/tasks/tracking/config/constants.py @@ -1,6 +1,6 @@ """Public constants for supported tracking tasks.""" -DEFAULT_TRAIN_MOTION_FILE = "data/datasets/twist2/train" +DEFAULT_TRAIN_MOTION_FILE = "data/datasets/twist2" GENERAL_TRACKING_TASK = "General-Tracking-G1" GENERAL_TRACKING_EXPERIMENT_NAME = "g1_general_tracking" diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 454c7e4c..b76dd004 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import json import logging from dataclasses import dataclass, field from pathlib import Path @@ -15,9 +14,11 @@ from train_mimic.data.dataset_lib import ( MOTION_ARRAY_KEYS, compute_clip_sample_ranges, + compute_dataset_stats, + find_motion_shards, parse_window_steps, - read_hdf5_body_names, ) +from 
train_mimic.data.motion_fk import MotionFkExtractor, compute_body_velocities, finite_diff_velocity from mjlab.managers import CommandTerm, CommandTermCfg from mjlab.utils.lab_api.math import ( @@ -70,7 +71,6 @@ class _Hdf5ClipRef: start: int length: int fps: int - weight: float @dataclass @@ -78,7 +78,6 @@ class _MotionBatch: tensors: dict[str, torch.Tensor] lengths: torch.Tensor fps: torch.Tensor - weights: torch.Tensor sample_starts: torch.Tensor sample_ends: torch.Tensor global_ids: torch.Tensor @@ -97,15 +96,6 @@ def __init__( ) -> None: if cache_num_clips <= 0: raise ValueError(f"cache_num_clips must be positive, got {cache_num_clips}") - manifest_path = motion_dir / "manifest.json" - if not manifest_path.is_file(): - raise FileNotFoundError( - f"motion_file must be an HDF5 split directory containing manifest.json, got: {motion_dir}" - ) - with manifest_path.open("r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("format") != "teleopit_motion_hdf5": - raise ValueError(f"Unsupported motion manifest format in {manifest_path}") self.motion_dir = motion_dir self.body_idx_np = body_idx_np @@ -114,8 +104,19 @@ def __init__( self.cache_num_clips = int(cache_num_clips) self._rng = torch.Generator(device="cpu") self._rng.manual_seed(int(seed)) - self._shard_paths = [motion_dir / shard["path"] for shard in manifest["shards"]] - self.body_names = np.asarray(manifest["body_names"], dtype=str) + self._shard_paths = find_motion_shards(motion_dir) + stats = compute_dataset_stats(motion_dir) + self.body_names = np.asarray(stats["body_names"], dtype=str) + self._fk_extractor = MotionFkExtractor() + _LOG.info( + "Motion dataset: root=%s shards=%d windows=%d source_clips=%d frames=%d fps=%s", + motion_dir, + stats["shards"], + stats["windows"], + stats["source_clips"], + stats["frames"], + stats["fps"], + ) max_future = max((step for step in self.window_steps if step > 0), default=0) max_history = -min((step for step in self.window_steps if step < 0), 
default=0) @@ -125,17 +126,10 @@ def __init__( skipped_short = 0 for shard_index, shard_path in enumerate(self._shard_paths): with h5py.File(shard_path, "r") as h5: - shard_body_names = read_hdf5_body_names(shard_path) - if shard_body_names != self.body_names.tolist(): - raise ValueError( - f"HDF5 shard body_names mismatch for {shard_path}: " - "all shards must match manifest body_names order" - ) starts = np.asarray(h5["clip_starts"], dtype=np.int64) lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) fps = np.asarray(h5["clip_fps"], dtype=np.int64) - weights = np.asarray(h5["clip_weights"], dtype=np.float64) - for start, length, cur_fps, weight in zip(starts, lengths, fps, weights): + for start, length, cur_fps in zip(starts, lengths, fps): if int(length) < min_clip_length: skipped_short += 1 continue @@ -144,7 +138,6 @@ def __init__( start=int(start), length=int(length), fps=int(cur_fps), - weight=float(weight), )) if not refs: raise ValueError(f"HDF5 motion dataset is empty: {motion_dir}") @@ -156,19 +149,29 @@ def __init__( list(self.window_steps), ) self.refs = refs - self.global_weights = torch.tensor( - [max(ref.weight, 0.0) for ref in refs], dtype=torch.float32 + ref_lengths_np = np.asarray([ref.length for ref in refs], dtype=np.int64) + ref_starts_np, ref_ends_np = compute_clip_sample_ranges( + ref_lengths_np, + window_steps=self.window_steps, ) - if float(self.global_weights.sum()) <= 0.0: - raise ValueError(f"All HDF5 motion weights are zero in {motion_dir}") + ref_fps_np = np.asarray([ref.fps for ref in refs], dtype=np.float32) + ref_valid_seconds = (ref_ends_np - ref_starts_np).astype(np.float32) / ref_fps_np + if np.any(ref_valid_seconds <= 0.0): + raise ValueError( + "HDF5 motion dataset contains windows with no valid sample duration " + f"after applying window_steps={list(self.window_steps)}" + ) + self.global_sample_weights = torch.as_tensor(ref_valid_seconds, dtype=torch.float32) + total_weight = 
float(self.global_sample_weights.sum().item()) + if total_weight <= 0.0: + raise ValueError(f"HDF5 motion dataset has no positive sample duration: {motion_dir}") self.generation = 0 self.current = self._load_random_batch() self.next = self._load_random_batch() def _sample_global_ids(self) -> torch.Tensor: - probs = self.global_weights / self.global_weights.sum() return torch.multinomial( - probs, + self.global_sample_weights, self.cache_num_clips, replacement=True, generator=self._rng, @@ -196,10 +199,26 @@ def _load_batch(self, global_ids: torch.Tensor) -> _MotionBatch: shard_path = self._shard_paths[ref.shard_index] sl = slice(ref.start, ref.start + ref.length) with h5py.File(shard_path, "r") as h5: - arrays["joint_pos"][out_i, :ref.length] = np.asarray(h5["joint_pos"][sl], dtype=np.float32) - arrays["joint_vel"][out_i, :ref.length] = np.asarray(h5["joint_vel"][sl], dtype=np.float32) - for key in ("body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w"): - arrays[key][out_i, :ref.length] = np.asarray(h5[key][sl], dtype=np.float32)[:, self.body_idx_np] + root_pos = np.asarray(h5["root_pos"][sl], dtype=np.float32) + root_quat_w = np.asarray(h5["root_quat_w"][sl], dtype=np.float32) + joint_pos = np.asarray(h5["joint_pos"][sl], dtype=np.float32) + + dt = 1.0 / float(ref.fps) + body_pos_w, body_quat_w = self._fk_extractor.extract( + root_pos, + root_quat_w, + joint_pos, + self.body_names, + ) + body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) + joint_vel = finite_diff_velocity(joint_pos, dt) + + arrays["joint_pos"][out_i, :ref.length] = joint_pos + arrays["joint_vel"][out_i, :ref.length] = joint_vel + arrays["body_pos_w"][out_i, :ref.length] = body_pos_w[:, self.body_idx_np] + arrays["body_quat_w"][out_i, :ref.length] = body_quat_w[:, self.body_idx_np] + arrays["body_lin_vel_w"][out_i, :ref.length] = body_lin_vel_w[:, self.body_idx_np] + arrays["body_ang_vel_w"][out_i, :ref.length] = body_ang_vel_w[:, self.body_idx_np] 
lengths_np = np.asarray([ref.length for ref in selected], dtype=np.int64) starts_np, ends_np = compute_clip_sample_ranges(lengths_np, window_steps=self.window_steps) @@ -208,7 +227,6 @@ def _load_batch(self, global_ids: torch.Tensor) -> _MotionBatch: tensors=tensors, lengths=torch.tensor(lengths_np, dtype=torch.long, device=self.device), fps=torch.tensor([ref.fps for ref in selected], dtype=torch.float32, device=self.device), - weights=torch.ones(len(selected), dtype=torch.float32, device=self.device), sample_starts=torch.tensor(starts_np, dtype=torch.long, device=self.device), sample_ends=torch.tensor(ends_np, dtype=torch.long, device=self.device), global_ids=global_ids.to(self.device), @@ -242,24 +260,16 @@ def __init__( self.window_steps = parse_window_steps(window_steps) motion_path = Path(motion_file) - if not motion_path.is_dir(): - raise FileNotFoundError( - f"motion_file must be an HDF5 shard directory, got: {motion_file}" - ) - manifest_path = motion_path / "manifest.json" - if not manifest_path.is_file(): + if not motion_path.exists(): raise FileNotFoundError( - f"motion_file must contain manifest.json for HDF5 loading, got: {motion_file}" + f"motion_file must be a dataset root directory or .h5 shard, got: {motion_file}" ) - with manifest_path.open("r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("format") != "teleopit_motion_hdf5": - raise ValueError(f"Unsupported motion manifest format in {manifest_path}") + stats = compute_dataset_stats(motion_path) if body_names is None: body_idx_np = body_indexes.cpu().numpy() else: - dataset_body_names = [str(name) for name in manifest["body_names"]] + dataset_body_names = [str(name) for name in stats["body_names"]] dataset_body_index_by_name = { name: index for index, name in enumerate(dataset_body_names) } @@ -298,7 +308,6 @@ def _set_batch(self, batch: _MotionBatch) -> None: self._body_ang_vel_w_t = batch.tensors["body_ang_vel_w"] self.clip_lengths = batch.lengths - self.clip_weights = 
batch.weights self.clip_fps = batch.fps self.num_clips = int(batch.lengths.shape[0]) self.time_step_total = int(batch.lengths.max().item()) @@ -321,15 +330,8 @@ def advance_cache(self) -> None: # ------------------------------------------------------------------ def sample_motion_ids(self, n: int) -> torch.Tensor: - """Sample *n* clip indices weighted by ``clip_weights``.""" - total = self.clip_weights.sum() - if total <= 0: - raise ValueError( - "All clip weights are zero — cannot sample. " - "Check that the shard dataset was built with positive weights." - ) - probs = self.clip_weights / total - return torch.multinomial(probs, n, replacement=True) + """Sample *n* cache-local clip indices uniformly.""" + return torch.randint(0, self.num_clips, (n,), device=self._device) def sample_times(self, motion_ids: torch.Tensor) -> torch.Tensor: """Uniform random time over valid center frames for each motion id.""" From fcb1b07b57fc009a1de560178137dcf81f9e032c Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 11:14:07 +0800 Subject: [PATCH 085/122] Update dataset stats and configs --- tests/test_dataset_v2.py | 39 +++++++++++++++++++- train_mimic/configs/datasets/lafan1.yaml | 1 + train_mimic/configs/datasets/seed.yaml | 7 +++- train_mimic/configs/datasets/seed_clean.yaml | 14 ------- train_mimic/configs/datasets/twist2.yaml | 5 ++- train_mimic/data/dataset_lib.py | 12 ++++++ train_mimic/scripts/data/inspect_dataset.py | 10 +++++ 7 files changed, 70 insertions(+), 18 deletions(-) delete mode 100644 train_mimic/configs/datasets/seed_clean.yaml diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index b7905098..cca91872 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -9,7 +9,7 @@ import h5py from train_mimic.data import dataset_builder -from train_mimic.data.dataset_lib import write_hdf5_motion_shard +from train_mimic.data.dataset_lib import compute_dataset_stats, write_hdf5_motion_shard from train_mimic.data.dataset_builder 
import ( DatasetClipRow, SourceInputFile, @@ -779,6 +779,43 @@ def test_shard_stats_counts_real_frames_not_overlapped_windows(tmp_path: Path) - assert stats["duration_s"] == 1000 / 30 +def test_compute_dataset_stats_reports_total_duration_from_source_clips(tmp_path: Path) -> None: + num_bodies = len(_MJLAB_G1_BODY_NAMES) + total_frames = 18 + body_quat_w = np.zeros((total_frames, num_bodies, 4), dtype=np.float32) + body_quat_w[..., 0] = 1.0 + write_hdf5_motion_shard( + { + "fps": 30, + "root_pos": np.zeros((total_frames, 3), dtype=np.float32), + "root_quat_w": np.tile( + np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), + (total_frames, 1), + ), + "joint_pos": np.zeros((total_frames, 29), dtype=np.float32), + "joint_vel": np.zeros((total_frames, 29), dtype=np.float32), + "body_pos_w": np.zeros((total_frames, num_bodies, 3), dtype=np.float32), + "body_quat_w": body_quat_w, + "body_lin_vel_w": np.zeros((total_frames, num_bodies, 3), dtype=np.float32), + "body_ang_vel_w": np.zeros((total_frames, num_bodies, 3), dtype=np.float32), + "body_names": np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), + "clip_starts": np.asarray([0, 12], dtype=np.int64), + "clip_lengths": np.asarray([12, 6], dtype=np.int64), + "clip_fps": np.asarray([30, 60], dtype=np.int64), + }, + tmp_path / "shard_000.h5", + max_window_frames=8, + overlap_frames=2, + ) + + stats = compute_dataset_stats(tmp_path) + + assert stats["frames"] == total_frames + assert stats["duration_s"] == pytest.approx(12 / 30 + 6 / 60) + assert stats["duration_h"] == pytest.approx((12 / 30 + 6 / 60) / 3600) + assert stats["shard_details"][0]["duration_s"] == pytest.approx(12 / 30 + 6 / 60) + + def test_build_dataset_batch_manifest_skips_filtered_entries( tmp_path: Path, diff --git a/train_mimic/configs/datasets/lafan1.yaml b/train_mimic/configs/datasets/lafan1.yaml index 3517a1fc..d179f307 100644 --- a/train_mimic/configs/datasets/lafan1.yaml +++ b/train_mimic/configs/datasets/lafan1.yaml @@ -1,6 +1,7 @@ name: lafan1 
target_fps: 30 preprocess: + min_frames: 22 normalize_root_xy: true ground_align: none max_all_off_ground_s: 2.0 diff --git a/train_mimic/configs/datasets/seed.yaml b/train_mimic/configs/datasets/seed.yaml index d1bb7fb3..680df526 100644 --- a/train_mimic/configs/datasets/seed.yaml +++ b/train_mimic/configs/datasets/seed.yaml @@ -1,13 +1,16 @@ name: seed target_fps: 30 preprocess: - normalize_root_xy: true - ground_align: first_frame_foot min_frames: 22 + normalize_root_xy: true + ground_align: none + max_all_off_ground_s: 2.0 + off_ground_height: 0.2 sources: - name: seed_full type: seed_csv input: data/SEED/g1/csv metadata_csv: data/SEED/seed_metadata_v003.csv + seed_filter_preset: groot_strict filters: is_mirror: [false] diff --git a/train_mimic/configs/datasets/seed_clean.yaml b/train_mimic/configs/datasets/seed_clean.yaml deleted file mode 100644 index 3eaf5bf7..00000000 --- a/train_mimic/configs/datasets/seed_clean.yaml +++ /dev/null @@ -1,14 +0,0 @@ -name: seed_clean -target_fps: 30 -preprocess: - normalize_root_xy: true - ground_align: first_frame_foot - min_frames: 22 -sources: - - name: seed_full - type: seed_csv - input: data/SEED/g1/csv - metadata_csv: data/SEED/seed_metadata_v003.csv - seed_filter_preset: groot_strict - filters: - is_mirror: [false] diff --git a/train_mimic/configs/datasets/twist2.yaml b/train_mimic/configs/datasets/twist2.yaml index ba40c961..ee7f8541 100644 --- a/train_mimic/configs/datasets/twist2.yaml +++ b/train_mimic/configs/datasets/twist2.yaml @@ -1,8 +1,11 @@ name: twist2 target_fps: 30 preprocess: + min_frames: 22 normalize_root_xy: true - ground_align: first_frame_foot + ground_align: none + max_all_off_ground_s: 2.0 + off_ground_height: 0.2 sources: - name: OMOMO_g1_GMR type: pkl diff --git a/train_mimic/data/dataset_lib.py b/train_mimic/data/dataset_lib.py index a61ef30e..ae967be4 100644 --- a/train_mimic/data/dataset_lib.py +++ b/train_mimic/data/dataset_lib.py @@ -632,6 +632,7 @@ def compute_dataset_stats(dataset_dir: 
str | Path) -> dict[str, Any]: shards = find_motion_shards(dataset_dir) total_windows = 0 total_frames = 0 + total_duration_s = 0.0 total_source_clips = 0 fps_values: set[int] = set() body_names_ref: list[str] | None = None @@ -669,14 +670,23 @@ def compute_dataset_stats(dataset_dir: str | Path) -> dict[str, Any]: windows = int(lengths.shape[0]) frames = int(np.asarray(h5["joint_pos"]).shape[0]) source_clips = int(len(np.unique(source_ids))) + if "source_clip_lengths" in h5 and "source_clip_fps" in h5: + source_lengths = np.asarray(h5["source_clip_lengths"], dtype=np.float64) + source_fps = np.asarray(h5["source_clip_fps"], dtype=np.float64) + shard_duration_s = float(np.sum(source_lengths / np.maximum(source_fps, 1.0))) + else: + shard_fps = int(h5.attrs.get("fps", fps_arr[0] if fps_arr.shape[0] else 1)) + shard_duration_s = float(frames / max(shard_fps, 1)) total_windows += windows total_frames += frames + total_duration_s += shard_duration_s total_source_clips += source_clips shard_rows.append({ "path": str(shard_path), "windows": windows, "source_clips": source_clips, "frames": frames, + "duration_s": shard_duration_s, "min_window_frames": int(lengths.min()) if windows else 0, "max_window_frames": int(lengths.max()) if windows else 0, }) @@ -689,6 +699,8 @@ def compute_dataset_stats(dataset_dir: str | Path) -> dict[str, Any]: "windows": total_windows, "source_clips": total_source_clips, "frames": total_frames, + "duration_s": total_duration_s, + "duration_h": total_duration_s / 3600.0, "fps": sorted(fps_values), "body_names": body_names_ref or [], "shard_details": shard_rows, diff --git a/train_mimic/scripts/data/inspect_dataset.py b/train_mimic/scripts/data/inspect_dataset.py index 001c3ed4..84caf97c 100644 --- a/train_mimic/scripts/data/inspect_dataset.py +++ b/train_mimic/scripts/data/inspect_dataset.py @@ -7,6 +7,15 @@ from train_mimic.data.dataset_lib import compute_dataset_stats +def _format_duration(seconds: float) -> str: + total_seconds = 
int(round(seconds)) + hours, rem = divmod(total_seconds, 3600) + minutes, secs = divmod(rem, 60) + if hours: + return f"{hours:d}:{minutes:02d}:{secs:02d} ({seconds / 3600.0:.2f} h)" + return f"{minutes:d}:{secs:02d} ({seconds:.1f} s)" + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Inspect a Teleopit motion dataset root.") parser.add_argument("dataset", type=str, help="Dataset root directory or a single .h5 shard") @@ -26,6 +35,7 @@ def main() -> int: print(f"windows: {stats['windows']}") print(f"source_clips: {stats['source_clips']}") print(f"frames: {stats['frames']}") + print(f"duration: {_format_duration(float(stats['duration_s']))}") print(f"fps: {stats['fps']}") print(f"bodies: {len(stats['body_names'])}") return 0 From b66da1ab1b0bd68089306470f69b12ffe656a26e Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 12:09:40 +0800 Subject: [PATCH 086/122] Rebuild temporary dataset clips --- docs/docs/reference/dataset.md | 10 ++-- .../reference/training-troubleshooting.md | 2 +- .../current/reference/dataset.md | 10 ++-- .../reference/training-troubleshooting.md | 2 +- tests/test_dataset_v2.py | 30 ++++++++++- train_mimic/configs/datasets/lafan1.yaml | 4 +- train_mimic/configs/datasets/twist2.yaml | 4 +- train_mimic/data/dataset_builder.py | 53 ++++++++++++++++++- 8 files changed, 94 insertions(+), 21 deletions(-) diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index 389ba56c..555b105d 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -59,12 +59,10 @@ python train_mimic/scripts/data/build_dataset.py \ ```text data/datasets// -├── clips/ # Optional; only for per-clip intermediates -│ └── /... 
└── shard_*.h5 ``` -- If the spec contains `bvh` or `npz` sources, the builder retains/generates `clips/` +- If the spec contains `bvh` or `npz` sources, the full dataset builder uses a temporary `clips/` directory during conversion and deletes it after shards are written. Rebuilds do not reuse converted clips. - If the spec is all `pkl` or `seed_csv` sources, the builder takes a batch path producing shards directly - Training recursively discovers `*.h5` shards below the specified root, so datasets can be merged by placing multiple shard directories under one parent - Training loads only a subset cache from the discovered shards, derives FK/velocities online, stages the next cache, and swaps caches at the PPO rollout barrier. @@ -101,7 +99,7 @@ sources: | `preprocess.max_root_lin_vel` | Root linear velocity filter threshold | | `preprocess.min_peak_body_height` | Minimum peak body height | | `preprocess.max_all_off_ground_s` | Max duration all feet off ground | -| `sources[].name` | Source name (used for clips subdirectory) | +| `sources[].name` | Source name | | `sources[].type` | `bvh` / `pkl` / `npz` / `seed_csv` | | `sources[].input` | Input file or directory | | `sources[].bvh_format` | Required for BVH: `lafan1` / `hc_mocap` / `nokov` | @@ -149,7 +147,7 @@ Convert raw data to standard NPZ clips without merging: ```bash python train_mimic/scripts/data/ingest_motion.py \ --type bvh --input data/lafan1_bvh \ - --output data/datasets/lafan1/clips/lafan1 \ + --output data/lafan1_clips/lafan1 \ --source lafan1 --bvh_format lafan1 --jobs 8 ``` @@ -157,7 +155,7 @@ python train_mimic/scripts/data/ingest_motion.py \ ```bash python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` Recommended thresholds: `pos_max < 1e-3 m`, `quat_mean < 0.05 rad`, `quat_p95 < 0.10 rad`. 
diff --git a/docs/docs/reference/training-troubleshooting.md b/docs/docs/reference/training-troubleshooting.md index e1aff875..b7c82249 100644 --- a/docs/docs/reference/training-troubleshooting.md +++ b/docs/docs/reference/training-troubleshooting.md @@ -35,7 +35,7 @@ The current `convert_pkl_to_npz.py` fixes these issues. ```bash python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` Expected thresholds: `pos_max < 1e-3 m`, `quat_mean < 0.05 rad`, `quat_p95 < 0.10 rad`. diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index fa92e529..5c65fef1 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -57,12 +57,10 @@ python train_mimic/scripts/data/build_dataset.py \ ```text data/datasets// -├── clips/ # 可选;仅在需要逐 clip 中间产物时存在 -│ └── /... 
└── shard_*.h5 ``` -- 若 spec 包含 `bvh` 或 `npz` source,builder 会保留/生成 `clips/` +- 若 spec 包含 `bvh` 或 `npz` source,完整 dataset builder 会在转换期间使用临时 `clips/` 目录,并在 shard 写入完成后删除。重新 build 不会复用已转换 clips。 - 若 spec 全部是 `pkl` 或 `seed_csv` source,builder 会直接并行产出 shard,默认不写中间 clip 文件 - 训练会递归发现指定根目录下的 `*.h5` shard,因此可以把多个数据集目录放到同一个父目录下完成合并 - 训练时只从发现的 shard 加载一个 subset cache,在线派生 FK/速度,同时预加载下一个 cache,并在 PPO rollout barrier 处切换。 @@ -97,7 +95,7 @@ sources: | `preprocess.ground_align` | `none` / `first_frame_foot` | | `preprocess.min_frames` | clip 最短长度约束 | | `preprocess.max_root_lin_vel` / `min_peak_body_height` / `max_all_off_ground_s` | 基础过滤阈值 | -| `sources[].name` | source 名称;生成 clip 中间产物时也作为 `clips//` 子目录名 | +| `sources[].name` | source 名称 | | `sources[].type` | `bvh` / `pkl` / `npz` / `seed_csv` | | `sources[].input` | 原始输入文件或目录 | | `sources[].bvh_format` | 仅 `bvh` source 必填:`lafan1` / `hc_mocap` / `nokov` | @@ -145,7 +143,7 @@ python train_mimic/scripts/data/inspect_dataset.py data/datasets/twist2 ```bash python train_mimic/scripts/data/ingest_motion.py \ --type bvh --input data/lafan1_bvh \ - --output data/datasets/lafan1/clips/lafan1 \ + --output data/lafan1_clips/lafan1 \ --source lafan1 --bvh_format lafan1 --jobs 8 ``` @@ -153,7 +151,7 @@ python train_mimic/scripts/data/ingest_motion.py \ ```bash python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` 推荐判据:`pos_max < 1e-3 m`、`quat_mean < 0.05 rad`、`quat_p95 < 0.10 rad`。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md index 935affae..ea9e6f95 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md @@ -35,7 +35,7 @@ sidebar_position: 5 ```bash 
python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` 推荐判据:`pos_max < 1e-3 m`、`quat_mean < 0.05 rad`、`quat_p95 < 0.10 rad`。 diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index cca91872..3722ff0e 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -482,12 +482,17 @@ def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> Non ) output_root = tmp_path / "datasets" + stale_cache = output_root / "demo_dataset" / "clips" / "npz_src" / "clip_a.npz" + _write_npz_from_pkl(stale_cache) + stale_payload = dict(np.load(stale_cache, allow_pickle=True)) + stale_payload["root_pos"] = np.asarray(stale_payload["root_pos"], dtype=np.float32) + 100.0 + np.savez(stale_cache, **stale_payload) + report = build_dataset_from_spec(spec, jobs=2, skip_fk_check=True, output_root=output_root) dataset_dir = output_root / "demo_dataset" assert report["dataset_dir"] == str(dataset_dir) - assert (dataset_dir / "clips" / "npz_src" / "clip_a.npz").is_file() - assert (dataset_dir / "clips" / "npz_src" / "clip_b.npz").is_file() + assert not (dataset_dir / "clips").exists() assert (dataset_dir / "shard_000.h5").is_file() assert report["input_clips"] == 2 @@ -496,6 +501,27 @@ def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> Non assert "root_quat_w" in shard assert "joint_pos" in shard assert "body_pos_w" not in shard + assert float(shard["root_pos"][0, 2]) < 10.0 + + +def test_build_dataset_from_spec_rejects_clips_root_source_without_deleting_input( + tmp_path: Path, +) -> None: + output_root = tmp_path / "datasets" + source_dir = output_root / "demo_dataset" / "clips" / "npz_src" + clip_path = source_dir / "clip_a.npz" + _write_npz_from_pkl(clip_path) + + spec = DatasetSpec( + name="demo_dataset", + target_fps=30, + sources=[DatasetSourceSpec(name="npz_src", type="npz", input=str(source_dir))], + ) + + with pytest.raises(ValueError, 
match="temporary clips directory"): + build_dataset_from_spec(spec, jobs=1, skip_fk_check=True, output_root=output_root) + + assert clip_path.is_file() def test_collect_clip_rows_ignores_stale_excluded_cached_npz(tmp_path: Path) -> None: diff --git a/train_mimic/configs/datasets/lafan1.yaml b/train_mimic/configs/datasets/lafan1.yaml index d179f307..19df2d0c 100644 --- a/train_mimic/configs/datasets/lafan1.yaml +++ b/train_mimic/configs/datasets/lafan1.yaml @@ -4,8 +4,8 @@ preprocess: min_frames: 22 normalize_root_xy: true ground_align: none - max_all_off_ground_s: 2.0 - off_ground_height: 0.2 + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 sources: - name: lafan1 type: bvh diff --git a/train_mimic/configs/datasets/twist2.yaml b/train_mimic/configs/datasets/twist2.yaml index ee7f8541..19043b8d 100644 --- a/train_mimic/configs/datasets/twist2.yaml +++ b/train_mimic/configs/datasets/twist2.yaml @@ -4,8 +4,8 @@ preprocess: min_frames: 22 normalize_root_xy: true ground_align: none - max_all_off_ground_s: 2.0 - off_ground_height: 0.2 + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 sources: - name: OMOMO_g1_GMR type: pkl diff --git a/train_mimic/data/dataset_builder.py b/train_mimic/data/dataset_builder.py index 38a4b627..20369814 100644 --- a/train_mimic/data/dataset_builder.py +++ b/train_mimic/data/dataset_builder.py @@ -403,6 +403,52 @@ def _clear_existing_motion_shards(dataset_dir: Path) -> None: path.unlink() +def _clear_intermediate_clips(clips_root: Path) -> None: + if not clips_root.exists() and not clips_root.is_symlink(): + return + if clips_root.is_dir() and not clips_root.is_symlink(): + shutil.rmtree(clips_root) + else: + clips_root.unlink() + + +def _path_contains(path: Path, candidate: Path) -> bool: + try: + candidate.relative_to(path) + except ValueError: + return False + return True + + +def _source_input_candidate_path(source: DatasetSourceSpec) -> Path: + candidate = Path(source.input).expanduser() + if not candidate.is_absolute(): + 
candidate = PROJECT_ROOT / candidate + return candidate.resolve(strict=False) + + +def _ensure_source_inputs_do_not_overlap_intermediate_clips( + spec: DatasetSpec, + clips_root: Path, +) -> None: + clips_path = clips_root.resolve(strict=False) + conflicts: list[tuple[str, Path]] = [] + for source in spec.sources: + input_path = _source_input_candidate_path(source) + if _path_contains(clips_path, input_path) or _path_contains(input_path, clips_path): + conflicts.append((source.name, input_path)) + + if not conflicts: + return + + details = ", ".join(f"{name}={path}" for name, path in conflicts) + raise ValueError( + f"source input overlaps the temporary clips directory {clips_path}: {details}. " + "The dataset builder deletes that directory during full builds, so move source clips " + "outside data/datasets//clips or choose a different output root." + ) + + def resolve_source_input_path(source: DatasetSourceSpec) -> Path: candidate = Path(source.input).expanduser() input_path = candidate.resolve() if candidate.is_absolute() else (PROJECT_ROOT / candidate).resolve() @@ -1029,6 +1075,7 @@ def convert_sources_to_npz( ) -> dict[str, Path]: if jobs <= 0: raise ValueError(f"jobs must be > 0, got {jobs}") + _ensure_source_inputs_do_not_overlap_intermediate_clips(spec, paths.clips_root) if force and paths.dataset_dir.exists(): shutil.rmtree(paths.dataset_dir) paths.clips_root.mkdir(parents=True, exist_ok=True) @@ -1470,7 +1517,10 @@ def build_dataset_from_spec( ) # Per-file mode for BVH/NPZ sources. Converted clips are temporary build - # inputs; final training data is the minimal shard(s) in dataset_dir. + # inputs and are rebuilt every time; final training data is the minimal + # shard(s) in dataset_dir. 
+ _ensure_source_inputs_do_not_overlap_intermediate_clips(spec, paths.clips_root) + _clear_intermediate_clips(paths.clips_root) convert_sources_to_npz(spec, paths=paths, force=force, jobs=jobs) rows = collect_clip_rows(spec, paths=paths) @@ -1491,6 +1541,7 @@ def build_dataset_from_spec( payload = {key: payload_npz[key] for key in payload_npz.files} shard_info = write_hdf5_motion_shard(payload, shard_path) tmp_npz.unlink(missing_ok=True) + _clear_intermediate_clips(paths.clips_root) stats["output"] = str(paths.dataset_dir) stats["shards"] = 1 From bfc8c8e5a8eba2183277a4fc8a63db90282ed532 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 14:56:16 +0800 Subject: [PATCH 087/122] Add async HDF5 motion cache loading --- AGENTS.md | 4 +- docs/docs/tutorials/training.md | 1 + .../current/tutorials/training.md | 1 + train_mimic/scripts/train.py | 12 + train_mimic/tasks/tracking/mdp/commands.py | 356 ++++++++++++++---- 5 files changed, 309 insertions(+), 65 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 1400e12c..a24912e5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -205,8 +205,8 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g - Dataset build spec supports a `preprocess` section for root-xy normalization, ground alignment, and basic clip filtering - Final training dataset outputs are minimal HDF5 shards directly under `data/datasets//` (recursive shard discovery is supported; no train/val split and no manifest file) - Each shard stores only `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, and clip-aware window metadata (`clip_starts`, `clip_lengths`, `clip_fps`); long clips are split into overlapping bounded windows -- Training computes joint velocities and body FK/velocities online when loading the motion cache -- `MotionLib` loads only a configurable HDF5 subset cache into CPU/GPU memory, stages the next cache, and swaps at the PPO rollout barrier +- Training computes joint velocities and body FK/velocities 
online in PyTorch DataLoader workers when loading the motion cache +- `MotionLib` loads only a configurable HDF5 subset cache into CPU/GPU memory, asynchronously stages the next cache, and swaps at the PPO rollout barrier - `MotionLib` samples only valid center frames for the configured `window_steps`; default is `window_steps=[0]` - Training supports `uniform` and `rewind` sampling on the active cache; in distributed training each rank sets a rank-offset `cache_seed` - `scripts/run/record_pico_motion.py` records Pico live body tracking as retargeted G1 motion NPZ clips in `data/pico_motion/clips/`; it opens a live `Retarget` viewer, uses terminal keys `R/S/D/N/Q`, stores semantic labels in filenames, and intentionally does not write per-clip JSON diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index a210ff4f..6b150ac8 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -76,6 +76,7 @@ torchrun \ - Default logger is TensorBoard. 
Use `--logger wandb` or `--logger swanlab` to select W&B or SwanLab; the project name defaults to `experiment_name` - `--motion_file` accepts a dataset root directory or single `.h5` shard; shard discovery is recursive - `--cache_num_clips` controls the active HDF5 subset size; `--cache_swap_interval_steps` controls how often the next subset is swapped in at a rollout barrier +- `--cache_dataloader_num_workers`, `--cache_dataloader_prefetch_factor`, and `--cache_dataloader_pin_memory` tune asynchronous HDF5 cache loading without increasing dataset size - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` ## Export ONNX diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index 3c43a5e9..eea5989a 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -76,6 +76,7 @@ torchrun \ - 默认日志工具为 TensorBoard。使用 `--logger wandb` 或 `--logger swanlab` 可选择 W&B 或 SwanLab;项目名默认使用 `experiment_name` - `--motion_file` 接受数据集根目录或单个 `.h5` shard;shard 会递归发现 - `--cache_num_clips` 控制当前 HDF5 subset cache 大小;`--cache_swap_interval_steps` 控制在 rollout barrier 切换下一个 subset 的频率 +- `--cache_dataloader_num_workers`、`--cache_dataloader_prefetch_factor` 和 `--cache_dataloader_pin_memory` 用于调节异步 HDF5 cache 加载,不会增加数据集大小 - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` ## 导出 ONNX diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 3e4f6a30..24cd9c48 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -97,6 +97,12 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: help="Number of HDF5 motion windows to keep in the active subset cache") 
parser.add_argument("--cache_swap_interval_steps", type=int, default=None, help="Policy steps between HDF5 motion cache swaps; swaps occur at rollout barriers") + parser.add_argument("--cache_dataloader_num_workers", type=int, default=None, + help="Number of PyTorch DataLoader workers for asynchronous HDF5 motion cache loading") + parser.add_argument("--cache_dataloader_prefetch_factor", type=int, default=None, + help="PyTorch DataLoader prefetch factor for asynchronous HDF5 motion cache loading") + parser.add_argument("--cache_dataloader_pin_memory", action=argparse.BooleanOptionalAction, default=None, + help="Pin CPU motion cache batches before asynchronous CUDA staging") parser.add_argument("--device", type=str, default=None) parser.add_argument( "--gpu_ids", @@ -390,6 +396,12 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: env_cfg.commands["motion"].cache_num_clips = args.cache_num_clips if args.cache_swap_interval_steps is not None: env_cfg.commands["motion"].cache_swap_interval_steps = args.cache_swap_interval_steps + if args.cache_dataloader_num_workers is not None: + env_cfg.commands["motion"].cache_dataloader_num_workers = args.cache_dataloader_num_workers + if args.cache_dataloader_prefetch_factor is not None: + env_cfg.commands["motion"].cache_dataloader_prefetch_factor = args.cache_dataloader_prefetch_factor + if args.cache_dataloader_pin_memory is not None: + env_cfg.commands["motion"].cache_dataloader_pin_memory = args.cache_dataloader_pin_memory if args.max_iterations is not None: agent_cfg.max_iterations = args.max_iterations if args.experiment_name is not None: diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index b76dd004..85c5260c 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -2,14 +2,16 @@ import copy import logging +import time from dataclasses import dataclass, field from pathlib import Path -from typing import 
TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Iterator, Literal import h5py import mujoco import numpy as np import torch +from torch.utils.data import DataLoader, Dataset, Sampler from train_mimic.data.dataset_lib import ( MOTION_ARRAY_KEYS, @@ -82,6 +84,157 @@ class _MotionBatch: sample_ends: torch.Tensor global_ids: torch.Tensor + def pin_memory(self) -> "_MotionBatch": + return _MotionBatch( + tensors={key: value.pin_memory() for key, value in self.tensors.items()}, + lengths=self.lengths.pin_memory(), + fps=self.fps.pin_memory(), + sample_starts=self.sample_starts.pin_memory(), + sample_ends=self.sample_ends.pin_memory(), + global_ids=self.global_ids.pin_memory(), + ) + + +@dataclass +class _MotionClipSample: + tensors: dict[str, torch.Tensor] + length: int + fps: int + sample_start: int + sample_end: int + global_id: int + + +class _WeightedInfiniteClipBatchSampler(Sampler[list[int]]): + def __init__( + self, + *, + sample_weights: torch.Tensor, + batch_size: int, + seed: int, + ) -> None: + if batch_size <= 0: + raise ValueError(f"batch_size must be positive, got {batch_size}") + self.sample_weights = sample_weights.cpu().to(dtype=torch.float32) + self.batch_size = int(batch_size) + self._rng = torch.Generator(device="cpu") + self._rng.manual_seed(int(seed)) + + def __iter__(self) -> Iterator[list[int]]: + while True: + ids = torch.multinomial( + self.sample_weights, + self.batch_size, + replacement=True, + generator=self._rng, + ) + yield [int(value) for value in ids.tolist()] + + +class _Hdf5MotionDataset(Dataset[_MotionClipSample]): + def __init__( + self, + *, + refs: list[_Hdf5ClipRef], + shard_paths: list[Path], + body_idx_np: np.ndarray, + body_names: np.ndarray, + window_steps: tuple[int, ...], + ) -> None: + self.refs = refs + self._shard_paths = shard_paths + self.body_idx_np = body_idx_np + self.body_names = body_names + self.window_steps = window_steps + self._fk_extractor = MotionFkExtractor() + self._h5_handles: dict[int, 
h5py.File] = {} + + def __len__(self) -> int: + return len(self.refs) + + def __getitem__(self, index: int) -> _MotionClipSample: + ref = self.refs[int(index)] + sl = slice(ref.start, ref.start + ref.length) + h5 = self._h5_handle(ref.shard_index) + root_pos = np.asarray(h5["root_pos"][sl], dtype=np.float32) + root_quat_w = np.asarray(h5["root_quat_w"][sl], dtype=np.float32) + joint_pos = np.asarray(h5["joint_pos"][sl], dtype=np.float32) + + dt = 1.0 / float(ref.fps) + body_pos_w, body_quat_w = self._fk_extractor.extract( + root_pos, + root_quat_w, + joint_pos, + self.body_names, + ) + body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) + joint_vel = finite_diff_velocity(joint_pos, dt) + sample_starts, sample_ends = compute_clip_sample_ranges( + np.asarray([ref.length], dtype=np.int64), + window_steps=self.window_steps, + ) + tensors = { + "joint_pos": torch.from_numpy(joint_pos), + "joint_vel": torch.from_numpy(joint_vel), + "body_pos_w": torch.from_numpy(body_pos_w[:, self.body_idx_np]), + "body_quat_w": torch.from_numpy(body_quat_w[:, self.body_idx_np]), + "body_lin_vel_w": torch.from_numpy(body_lin_vel_w[:, self.body_idx_np]), + "body_ang_vel_w": torch.from_numpy(body_ang_vel_w[:, self.body_idx_np]), + } + return _MotionClipSample( + tensors=tensors, + length=ref.length, + fps=ref.fps, + sample_start=int(sample_starts[0]), + sample_end=int(sample_ends[0]), + global_id=int(index), + ) + + def _h5_handle(self, shard_index: int) -> h5py.File: + handle = self._h5_handles.get(shard_index) + if handle is not None and handle.id: + return handle + handle = h5py.File(self._shard_paths[shard_index], "r") + self._h5_handles[shard_index] = handle + return handle + + def close(self) -> None: + for handle in self._h5_handles.values(): + if handle.id: + handle.close() + self._h5_handles.clear() + + +def _collate_motion_clips(samples: list[_MotionClipSample]) -> _MotionBatch: + if not samples: + raise ValueError("Motion cache DataLoader 
produced an empty batch") + max_len = max(sample.length for sample in samples) + arrays: dict[str, torch.Tensor] = {} + for key in MOTION_ARRAY_KEYS: + first = samples[0].tensors[key] + arrays[key] = torch.zeros( + (len(samples), max_len, *first.shape[1:]), + dtype=torch.float32, + ) + + for out_i, sample in enumerate(samples): + for key, value in sample.tensors.items(): + arrays[key][out_i, :sample.length] = value + + return _MotionBatch( + tensors=arrays, + lengths=torch.tensor([sample.length for sample in samples], dtype=torch.long), + fps=torch.tensor([sample.fps for sample in samples], dtype=torch.float32), + sample_starts=torch.tensor([sample.sample_start for sample in samples], dtype=torch.long), + sample_ends=torch.tensor([sample.sample_end for sample in samples], dtype=torch.long), + global_ids=torch.tensor([sample.global_id for sample in samples], dtype=torch.long), + ) + + +def _motion_worker_init(worker_id: int) -> None: + del worker_id + torch.set_num_threads(1) + class _Hdf5MotionCache: def __init__( @@ -93,6 +246,9 @@ def __init__( window_steps: tuple[int, ...], cache_num_clips: int, seed: int, + dataloader_num_workers: int, + dataloader_prefetch_factor: int, + dataloader_pin_memory: bool, ) -> None: if cache_num_clips <= 0: raise ValueError(f"cache_num_clips must be positive, got {cache_num_clips}") @@ -102,8 +258,12 @@ def __init__( self.device = device self.window_steps = window_steps self.cache_num_clips = int(cache_num_clips) - self._rng = torch.Generator(device="cpu") - self._rng.manual_seed(int(seed)) + self.dataloader_num_workers = max(0, int(dataloader_num_workers)) + self.dataloader_prefetch_factor = max(1, int(dataloader_prefetch_factor)) + self.dataloader_pin_memory = bool(dataloader_pin_memory) + self._device = torch.device(device) + self._copy_stream: torch.cuda.Stream | None = None + self._next_ready_event: torch.cuda.Event | None = None self._shard_paths = find_motion_shards(motion_dir) stats = compute_dataset_stats(motion_dir) 
self.body_names = np.asarray(stats["body_names"], dtype=str) @@ -166,77 +326,132 @@ def __init__( if total_weight <= 0.0: raise ValueError(f"HDF5 motion dataset has no positive sample duration: {motion_dir}") self.generation = 0 - self.current = self._load_random_batch() - self.next = self._load_random_batch() + self._dataset = _Hdf5MotionDataset( + refs=self.refs, + shard_paths=self._shard_paths, + body_idx_np=self.body_idx_np, + body_names=self.body_names, + window_steps=self.window_steps, + ) + self._sampler = _WeightedInfiniteClipBatchSampler( + sample_weights=self.global_sample_weights, + batch_size=self.cache_num_clips, + seed=seed, + ) + loader_kwargs: dict[str, object] = {} + if self.dataloader_num_workers > 0: + loader_kwargs["prefetch_factor"] = self.dataloader_prefetch_factor + loader_kwargs["persistent_workers"] = True + loader_kwargs["worker_init_fn"] = _motion_worker_init + self._loader = DataLoader( + self._dataset, + batch_sampler=self._sampler, + num_workers=self.dataloader_num_workers, + pin_memory=self.dataloader_pin_memory and self._device.type == "cuda", + collate_fn=_collate_motion_clips, + **loader_kwargs, + ) + self._iterator = iter(self._loader) + self.current = self._stage_batch(self._load_next_cpu_batch(), wait=True) + self._next_batch = self._stage_batch(self._load_next_cpu_batch(), wait=False) + + def _load_next_cpu_batch(self, *, log_wait: bool = False) -> _MotionBatch: + start = time.perf_counter() + batch = next(self._iterator) + elapsed = time.perf_counter() - start + if log_wait and elapsed > 1e-3: + _LOG.info( + "Waited %.3fs for asynchronous HDF5 motion cache DataLoader", + elapsed, + ) + return batch + + def _stage_batch(self, batch: _MotionBatch, *, wait: bool) -> _MotionBatch: + if self._device.type != "cuda": + tensors = {key: value.to(self._device) for key, value in batch.tensors.items()} + return _MotionBatch( + tensors=tensors, + lengths=batch.lengths.to(self._device), + fps=batch.fps.to(self._device), + 
sample_starts=batch.sample_starts.to(self._device), + sample_ends=batch.sample_ends.to(self._device), + global_ids=batch.global_ids.to(self._device), + ) + + if self._copy_stream is None: + self._copy_stream = torch.cuda.Stream(device=self._device) + with torch.cuda.stream(self._copy_stream): + tensors = { + key: value.to(self._device, non_blocking=True) + for key, value in batch.tensors.items() + } + staged = _MotionBatch( + tensors=tensors, + lengths=batch.lengths.to(self._device, non_blocking=True), + fps=batch.fps.to(self._device, non_blocking=True), + sample_starts=batch.sample_starts.to(self._device, non_blocking=True), + sample_ends=batch.sample_ends.to(self._device, non_blocking=True), + global_ids=batch.global_ids.to(self._device, non_blocking=True), + ) + event = torch.cuda.Event() + event.record(self._copy_stream) + if wait: + torch.cuda.current_stream(self._device).wait_event(event) + else: + self._next_ready_event = event + return staged + + def _wait_next_ready(self) -> None: + if self._next_ready_event is None: + return + if self._device.type == "cuda": + torch.cuda.current_stream(self._device).wait_event(self._next_ready_event) + self._next_ready_event = None + + def _materialize_legacy_batch(self, global_ids: torch.Tensor) -> _MotionBatch: + samples = [self._dataset[int(idx)] for idx in global_ids.tolist()] + batch = _collate_motion_clips(samples) + return self._stage_batch(batch, wait=True) def _sample_global_ids(self) -> torch.Tensor: - return torch.multinomial( - self.global_sample_weights, - self.cache_num_clips, - replacement=True, - generator=self._rng, - ) + ids = next(iter(self._sampler)) + return torch.tensor(ids, dtype=torch.long) def _load_random_batch(self) -> _MotionBatch: - return self._load_batch(self._sample_global_ids()) + return self._materialize_legacy_batch(self._sample_global_ids()) def _load_batch(self, global_ids: torch.Tensor) -> _MotionBatch: - ids_np = global_ids.cpu().numpy().astype(np.int64) - selected = 
[self.refs[int(idx)] for idx in ids_np] - max_len = max(ref.length for ref in selected) - arrays: dict[str, np.ndarray] = {} - for key in MOTION_ARRAY_KEYS: - sample_shape: tuple[int, ...] - if key in ("joint_pos", "joint_vel"): - sample_shape = (29,) - elif key == "body_quat_w": - sample_shape = (len(self.body_idx_np), 4) - else: - sample_shape = (len(self.body_idx_np), 3) - arrays[key] = np.zeros((len(selected), max_len, *sample_shape), dtype=np.float32) + return self._materialize_legacy_batch(global_ids.cpu().to(dtype=torch.long)) - for out_i, ref in enumerate(selected): - shard_path = self._shard_paths[ref.shard_index] - sl = slice(ref.start, ref.start + ref.length) - with h5py.File(shard_path, "r") as h5: - root_pos = np.asarray(h5["root_pos"][sl], dtype=np.float32) - root_quat_w = np.asarray(h5["root_quat_w"][sl], dtype=np.float32) - joint_pos = np.asarray(h5["joint_pos"][sl], dtype=np.float32) - - dt = 1.0 / float(ref.fps) - body_pos_w, body_quat_w = self._fk_extractor.extract( - root_pos, - root_quat_w, - joint_pos, - self.body_names, + def advance(self) -> None: + start = time.perf_counter() + self._wait_next_ready() + elapsed = time.perf_counter() - start + if elapsed > 1e-3: + _LOG.info( + "Waited %.3fs for asynchronous HDF5 motion cache staging", + elapsed, ) - body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) - joint_vel = finite_diff_velocity(joint_pos, dt) - - arrays["joint_pos"][out_i, :ref.length] = joint_pos - arrays["joint_vel"][out_i, :ref.length] = joint_vel - arrays["body_pos_w"][out_i, :ref.length] = body_pos_w[:, self.body_idx_np] - arrays["body_quat_w"][out_i, :ref.length] = body_quat_w[:, self.body_idx_np] - arrays["body_lin_vel_w"][out_i, :ref.length] = body_lin_vel_w[:, self.body_idx_np] - arrays["body_ang_vel_w"][out_i, :ref.length] = body_ang_vel_w[:, self.body_idx_np] - - lengths_np = np.asarray([ref.length for ref in selected], dtype=np.int64) - starts_np, ends_np = 
compute_clip_sample_ranges(lengths_np, window_steps=self.window_steps) - tensors = {key: torch.from_numpy(value).to(self.device) for key, value in arrays.items()} - return _MotionBatch( - tensors=tensors, - lengths=torch.tensor(lengths_np, dtype=torch.long, device=self.device), - fps=torch.tensor([ref.fps for ref in selected], dtype=torch.float32, device=self.device), - sample_starts=torch.tensor(starts_np, dtype=torch.long, device=self.device), - sample_ends=torch.tensor(ends_np, dtype=torch.long, device=self.device), - global_ids=global_ids.to(self.device), + self.current = self._next_batch + self._next_batch = self._stage_batch( + self._load_next_cpu_batch(log_wait=True), + wait=False, ) - - def advance(self) -> None: - self.current = self.next - self.next = self._load_random_batch() self.generation += 1 + def close(self) -> None: + self._dataset.close() + iterator = getattr(self, "_iterator", None) + shutdown = getattr(iterator, "_shutdown_workers", None) + if callable(shutdown): + shutdown() + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + class MotionLib: """Clip-aware motion library. @@ -255,6 +470,9 @@ def __init__( window_steps: tuple[int, ...] 
| list[int] | None = None, cache_num_clips: int = 1024, cache_seed: int = 0, + dataloader_num_workers: int = 2, + dataloader_prefetch_factor: int = 1, + dataloader_pin_memory: bool = True, ) -> None: self._device = device self.window_steps = parse_window_steps(window_steps) @@ -295,6 +513,9 @@ def __init__( window_steps=self.window_steps, cache_num_clips=cache_num_clips, seed=cache_seed, + dataloader_num_workers=dataloader_num_workers, + dataloader_prefetch_factor=dataloader_prefetch_factor, + dataloader_pin_memory=dataloader_pin_memory, ) self._set_batch(self._cache.current) @@ -325,6 +546,9 @@ def advance_cache(self) -> None: self._cache.advance() self._set_batch(self._cache.current) + def close(self) -> None: + self._cache.close() + # ------------------------------------------------------------------ # Sampling helpers # ------------------------------------------------------------------ @@ -516,6 +740,9 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): window_steps=self.cfg.window_steps, cache_num_clips=self.cfg.cache_num_clips, cache_seed=self.cfg.cache_seed, + dataloader_num_workers=self.cfg.cache_dataloader_num_workers, + dataloader_prefetch_factor=self.cfg.cache_dataloader_prefetch_factor, + dataloader_pin_memory=self.cfg.cache_dataloader_pin_memory, ) self._motion_cache_step_counter = 0 self._motion_cache_swap_pending = False @@ -1087,6 +1314,9 @@ class MotionCommandCfg(CommandTermCfg): window_steps: tuple[int, ...] 
= (0,) cache_num_clips: int = 1024 cache_swap_interval_steps: int = 500 + cache_dataloader_num_workers: int = 2 + cache_dataloader_prefetch_factor: int = 1 + cache_dataloader_pin_memory: bool = True cache_seed: int = 0 rewind_prob: float = 0.8 rewind_min_steps: int = 25 From 761a2505576fdcba9a44add7f003dafeaaae64f8 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 15:22:50 +0800 Subject: [PATCH 088/122] Add Pico motion dataset config --- train_mimic/configs/datasets/pico_motion.yaml | 12 ++++++++++++ train_mimic/configs/datasets/seed.yaml | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 train_mimic/configs/datasets/pico_motion.yaml diff --git a/train_mimic/configs/datasets/pico_motion.yaml b/train_mimic/configs/datasets/pico_motion.yaml new file mode 100644 index 00000000..11004041 --- /dev/null +++ b/train_mimic/configs/datasets/pico_motion.yaml @@ -0,0 +1,12 @@ +name: pico_record +target_fps: 30 +preprocess: + normalize_root_xy: true + ground_align: first_frame_foot + min_frames: 22 + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 +sources: + - name: pico_clips + type: npz + input: data/pico_motion/clips diff --git a/train_mimic/configs/datasets/seed.yaml b/train_mimic/configs/datasets/seed.yaml index 680df526..e5b5a146 100644 --- a/train_mimic/configs/datasets/seed.yaml +++ b/train_mimic/configs/datasets/seed.yaml @@ -4,8 +4,8 @@ preprocess: min_frames: 22 normalize_root_xy: true ground_align: none - max_all_off_ground_s: 2.0 - off_ground_height: 0.2 + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 sources: - name: seed_full type: seed_csv From a222c523b2171075f7496d6f47c552d82d293ea0 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 15:46:37 +0800 Subject: [PATCH 089/122] Remove legacy G1 sim2sim XML fallback --- tests/conftest.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0d7ba0c8..c3d81bcf 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -52,8 +52,6 @@ def find_g1_xml_path() -> str | None: root = Path(__file__).parent.parent candidates = [ root / "teleopit" / "retargeting" / "gmr" / "assets" / "unitree_g1" / "g1_mjlab.xml", - root / "GMR" / "assets" / "unitree_g1" / "g1_sim2sim_29dof.xml", - root / "teleopit" / "retargeting" / "gmr" / "assets" / "unitree_g1" / "g1_sim2sim_29dof.xml", ] for path in candidates: if path.exists(): From e7c1795625e2f7735c4fe6a5c125fb34c4fd9771 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 17:00:07 +0800 Subject: [PATCH 090/122] Use canonical G1 robot assets --- .gitignore | 5 ++++- AGENTS.md | 11 +++++----- README.md | 6 ++++- docs/docs/getting-started/download-assets.md | 6 +++-- docs/docs/getting-started/quick-start.md | 2 +- docs/docs/reference/assets.md | 13 ++++++----- .../reference/training-troubleshooting.md | 2 +- docs/docs/tutorials/pico-sim2sim.md | 2 +- .../getting-started/download-assets.md | 6 +++-- .../current/getting-started/quick-start.md | 2 +- .../current/reference/assets.md | 15 ++++++++----- .../reference/training-troubleshooting.md | 2 +- .../current/tutorials/pico-sim2sim.md | 2 +- scripts/dev/compute_ik_offsets.py | 12 +++------- scripts/render/render_sim.py | 13 ++--------- scripts/run/record_pico_motion.py | 4 ++-- scripts/run/standalone_standing.py | 2 +- scripts/setup/download_assets.py | 2 +- scripts/setup/prepare_modelscope_assets.py | 2 +- scripts/setup/upload_hf_assets.py | 4 ++-- scripts/view/view_dataset.py | 4 ++-- teleopit/configs/robot/g1.yaml | 2 +- teleopit/retargeting/gmr/params.py | 4 +++- teleopit/robots/mujoco_robot.py | 14 ++++++++---- teleopit/runtime/assets.py | 6 +++-- teleopit/runtime/external_assets.py | 8 +++++++ tests/conftest.py | 2 +- tests/test_download_assets.py | 11 ++++++++++ tests/test_motion_sampling.py | 21 ++++++++++++++++++ train_mimic/data/motion_fk.py | 4 ++-- .../scripts/data/check_motion_npz_fk.py | 2 +- train_mimic/tasks/tracking/config/env.py | 22 
++++++++++++++++++- 32 files changed, 143 insertions(+), 70 deletions(-) diff --git a/.gitignore b/.gitignore index 00126fc3..33034add 100644 --- a/.gitignore +++ b/.gitignore @@ -91,7 +91,10 @@ Thumbs.db # GMR retargeting assets (downloaded from ModelScope) teleopit/retargeting/gmr/assets/ -# Legacy training-side robot assets are no longer tracked; FK tooling reuses GMR unitree_g1. +# Canonical robot XML/meshes (downloaded from ModelScope) +assets/robots/ + +# Legacy training-side robot assets are no longer tracked. train_mimic/assets/ # ModelScope download cache diff --git a/AGENTS.md b/AGENTS.md index a24912e5..04f849e4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -87,7 +87,7 @@ train_mimic/ # Training package ### Sim2Sim Pipeline - Policy runs at 50Hz, PD control at 1000Hz (`decimation=20`, `sim_dt=0.001`) - Action flow: `compute_action()` returns raw action → `get_target_dof_pos()` applies clip `[-10, 10]`, scale, and `default_dof_pos` -- Must use `g1_mjlab.xml` for sim2sim; `g1_mocap_29dof.xml` clamps torques to `±1 Nm` and is only for kinematic retarget visualization +- Must use `assets/robots/unitree_g1/g1_29dof.xml` for training, sim2sim, dataset FK, and retargeting; it is the canonical G1 XML entry point ### Multi-Viewer Support `SimulationLoop` supports multiple simultaneous viewer windows controlled by the `viewers` config: @@ -223,15 +223,16 @@ python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_trac ``` ### GMR Retargeting -- Self-contained in `teleopit/retargeting/gmr/`; assets need `scripts/setup/download_assets.py --only gmr` +- Self-contained in `teleopit/retargeting/gmr/`; assets need `scripts/setup/download_assets.py --only robots gmr` - Supports `lafan1` BVH (22 joints, 30fps, centimeters) - Supports `hc_mocap` BVH (50 joints, 60fps downsampled to 30fps, meters) - `lafan1-resolved` still needs an adapter layer and remains unsupported ### External Assets - Do not commit robot meshes, datasets, checkpoints, or demo media to 
Git; use `scripts/setup/download_assets.py` +- `assets/robots/unitree_g1/g1_29dof.xml` and its meshes are the canonical G1 robot model assets; they are downloaded from the `robots` asset group and are not tracked in Git - `teleopit/retargeting/gmr/assets/` is gitignored; downloaded at runtime -- `train_mimic/assets/` is no longer tracked; FK tooling reuses `teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml` +- `train_mimic/assets/` is no longer tracked; FK tooling reuses `assets/robots/unitree_g1/g1_29dof.xml` - `third_party/linkerhand-python-sdk` and `third_party/somehand` support optional LinkerHand L6 sim2real control - Run `python scripts/check_large_tracked_files.py` before pushing @@ -248,7 +249,7 @@ Asset group → repo mapping is defined in `teleopit/runtime/external_assets.py` ```bash # 1. Prepare upload directory -python scripts/setup/prepare_modelscope_assets.py --only ckpt gmr bvh --clean +python scripts/setup/prepare_modelscope_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/prepare_modelscope_assets.py --only data # 2. Upload to each repo @@ -299,4 +300,4 @@ pytest tests/ -v ## Known Issues 1. `lafan1-resolved` retargeting is still broken because it uses a different BVH skeleton layout. -2. `g1_mocap_29dof.xml` still has `ctrlrange="-1 1"`; never use it for sim2sim. +2. Legacy downloaded GMR XMLs under `teleopit/retargeting/gmr/assets/unitree_g1/` are not the project entry point; use `assets/robots/unitree_g1/g1_29dof.xml`. diff --git a/README.md b/README.md index fc121b3b..ed796e38 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,13 @@ pip install -e . ```bash pip install modelscope -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` +The canonical Unitree G1 robot model is downloaded to +`assets/robots/unitree_g1/g1_29dof.xml`. Training, sim2sim, retargeting, and FK +validation all use this same XML. + **3. 
Run** ```bash diff --git a/docs/docs/getting-started/download-assets.md b/docs/docs/getting-started/download-assets.md index 6db8fde2..7f4de486 100644 --- a/docs/docs/getting-started/download-assets.md +++ b/docs/docs/getting-started/download-assets.md @@ -20,7 +20,7 @@ python scripts/setup/download_assets.py Download only what you need for inference: ```bash -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## Asset Inventory @@ -31,13 +31,15 @@ python scripts/setup/download_assets.py --only gmr ckpt bvh | `track.pt` | 27 MB | PyTorch checkpoint (for resume training) | | `data/datasets/seed/shard_*.h5` | ~26 GB | Training dataset | | `data/sample_bvh/*.bvh` | 5 MB | Sample motion files | -| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR retargeting robot models | +| `assets/robots/unitree_g1/` | ~52 MB | Canonical G1 XML and meshes used by training, sim2sim, retargeting, and FK validation | +| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR retargeting assets, IK configs, and non-canonical robot descriptions | ## Asset Groups | Group | ModelScope Repo | Contents | |-------|----------------|----------| | `ckpt` | `BingqianWu/Teleopit-models` | `track.onnx`, `track.pt` | +| `robots` | `BingqianWu/Teleopit-models` | Canonical robot XML/meshes | | `gmr` | `BingqianWu/Teleopit-models` | GMR retargeting assets | | `bvh` | `BingqianWu/Teleopit-models` | Sample BVH motion files | | `data` | `BingqianWu/Teleopit-datasets` | Training/validation shards | diff --git a/docs/docs/getting-started/quick-start.md b/docs/docs/getting-started/quick-start.md index dc59e17e..41203821 100644 --- a/docs/docs/getting-started/quick-start.md +++ b/docs/docs/getting-started/quick-start.md @@ -9,7 +9,7 @@ This guide walks you through running your first sim2sim playback in under 5 minu ## Prerequisites 1. [Install Teleopit](installation) (inference profile) -2. 
[Download assets](download-assets) (`--only gmr ckpt bvh`) +2. [Download assets](download-assets) (`--only robots gmr ckpt bvh`) ## Run Offline Sim2Sim diff --git a/docs/docs/reference/assets.md b/docs/docs/reference/assets.md index 5408881e..92322c10 100644 --- a/docs/docs/reference/assets.md +++ b/docs/docs/reference/assets.md @@ -4,11 +4,12 @@ sidebar_position: 2 # Asset Management -Robot models, datasets, checkpoints, and demo media are not tracked in Git. They are distributed via ModelScope and HuggingFace. +Datasets, checkpoints, robot models, and demo media are not tracked in Git. They are distributed via ModelScope and HuggingFace. The canonical Unitree G1 model is downloaded to `assets/robots/unitree_g1/g1_29dof.xml`. ## What's Not in Git -- `teleopit/retargeting/gmr/assets/` - Robot meshes, URDF/MJCF +- `assets/robots/` - Canonical robot XML/meshes +- `teleopit/retargeting/gmr/assets/` - GMR retargeting assets, IK configs, and non-canonical robot descriptions - `data/`, checkpoints, caches - Demo media (`assets/demo.gif`, `assets/demo.mp4`) @@ -33,6 +34,7 @@ Robot models, datasets, checkpoints, and demo media are not tracked in Git. 
They | Group | Repository | Remote Path | |-------|-----------|-------------| | `ckpt` | Teleopit-models | `checkpoints/track.onnx`, `checkpoints/track.pt` | +| `robots` | Teleopit-models | `archives/robot_assets.tar.gz` | | `gmr` | Teleopit-models | `archives/gmr_assets.tar.gz` | | `bvh` | Teleopit-models | `archives/sample_bvh.tar.gz` | | `data` | Teleopit-datasets | `data/` | @@ -46,7 +48,7 @@ Use the project download script (defaults to ModelScope): python scripts/setup/download_assets.py # Only inference essentials -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh # Only training data python scripts/setup/download_assets.py --only data @@ -61,6 +63,7 @@ Local paths after download: |--------|-------| | `checkpoints/track.onnx` | `track.onnx` | | `checkpoints/track.pt` | `track.pt` | +| `archives/robot_assets.tar.gz` | `assets/robots/` (extracted) | | `archives/gmr_assets.tar.gz` | `teleopit/retargeting/gmr/assets/` (extracted) | | `archives/sample_bvh.tar.gz` | `data/sample_bvh/` (extracted) | | `data/` | `data/datasets/seed/` | @@ -70,7 +73,7 @@ Local paths after download: ### Step 1: Prepare Upload Directory ```bash -python scripts/setup/prepare_modelscope_assets.py --only ckpt gmr bvh --clean +python scripts/setup/prepare_modelscope_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/prepare_modelscope_assets.py --only data ``` @@ -111,7 +114,7 @@ Tags should match Git tags for traceability. 
```bash # Prepare and upload model assets (--clean ensures no leftover files) -python scripts/setup/upload_hf_assets.py --only ckpt gmr bvh --clean +python scripts/setup/upload_hf_assets.py --only ckpt robots gmr bvh --clean # Prepare and upload dataset python scripts/setup/upload_hf_assets.py --only data --clean diff --git a/docs/docs/reference/training-troubleshooting.md b/docs/docs/reference/training-troubleshooting.md index b7c82249..61475ad9 100644 --- a/docs/docs/reference/training-troubleshooting.md +++ b/docs/docs/reference/training-troubleshooting.md @@ -168,6 +168,6 @@ print(cfg.init_state.joint_pos) # Must match g1.yaml default_angles ### Solution -Update `teleopit/configs/robot/g1.yaml` and `g1_mjlab.xml` to match training environment values (default angles, armature, condim). +Update `teleopit/configs/robot/g1.yaml` and `assets/robots/unitree_g1/g1_29dof.xml` to match training environment values (default angles, armature, condim). This fix also affects the sim2real path since `default_angles` is shared by `rl_policy.py` and `observation.py`. diff --git a/docs/docs/tutorials/pico-sim2sim.md b/docs/docs/tutorials/pico-sim2sim.md index b72dbcd1..ff354dd9 100644 --- a/docs/docs/tutorials/pico-sim2sim.md +++ b/docs/docs/tutorials/pico-sim2sim.md @@ -53,7 +53,7 @@ Teleopit targets pico-bridge 0.2.1 and its `pico_native` tracking semantics. ```bash pip install modelscope -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## 4. 
Run Pico Sim2Sim diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md index 61ce714e..014a0e6b 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md @@ -20,7 +20,7 @@ python scripts/setup/download_assets.py 只下载推理所需的资源: ```bash -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## 资源清单 @@ -31,13 +31,15 @@ python scripts/setup/download_assets.py --only gmr ckpt bvh | `track.pt` | 27 MB | PyTorch 检查点(用于恢复训练) | | `data/datasets/seed/shard_*.h5` | ~26 GB | 训练数据集 | | `data/sample_bvh/*.bvh` | 5 MB | 示例动捕文件 | -| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR 重定向机器人模型 | +| `assets/robots/unitree_g1/` | ~52 MB | 训练、sim2sim、重定向和 FK 校验共用的 G1 canonical XML 与 mesh | +| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR 重定向资源、IK 配置和非 canonical 机器人描述 | ## 资源分组 | 分组 | ModelScope 仓库 | 包含内容 | |------|-----------------|----------| | `ckpt` | `BingqianWu/Teleopit-models` | `track.onnx`、`track.pt` | +| `robots` | `BingqianWu/Teleopit-models` | Canonical 机器人 XML/mesh | | `gmr` | `BingqianWu/Teleopit-models` | GMR 重定向资源 | | `bvh` | `BingqianWu/Teleopit-models` | 示例 BVH 动捕文件 | | `data` | `BingqianWu/Teleopit-datasets` | 训练 / 验证数据分片 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md index 6c1b6f3a..5b05d1c4 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md @@ -9,7 +9,7 @@ sidebar_position: 3 ## 前置条件 1. 
[安装 Teleopit](installation)(推理配置) -2. [下载资源](download-assets)(`--only gmr ckpt bvh`) +2. [下载资源](download-assets)(`--only robots gmr ckpt bvh`) ## 运行离线 Sim2Sim diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md index a0f7643c..2451e7b7 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md @@ -4,11 +4,12 @@ sidebar_position: 2 # 资源管理 -机器人模型、数据集、checkpoint 和演示媒体不进 Git 历史,统一走外部下载。 +数据集、checkpoint、机器人模型和演示媒体不进 Git 历史,统一走外部下载。Unitree G1 的 canonical 模型下载到 `assets/robots/unitree_g1/g1_29dof.xml`。 ## 不入库的内容 -- `teleopit/retargeting/gmr/assets/` — 机器人 mesh、URDF/MJCF 等 +- `assets/robots/` — canonical 机器人 XML/mesh +- `teleopit/retargeting/gmr/assets/` — GMR 重定向资源、IK 配置和非 canonical 机器人描述 - `data/`、checkpoint、缓存等生成产物 - 演示媒体(`assets/demo.gif`、`assets/demo.mp4`) @@ -33,6 +34,7 @@ sidebar_position: 2 | 组 | 仓库 | 远端路径 | |----|------|---------| | `ckpt` | Teleopit-models | `checkpoints/track.onnx`、`checkpoints/track.pt` | +| `robots` | Teleopit-models | `archives/robot_assets.tar.gz` | | `gmr` | Teleopit-models | `archives/gmr_assets.tar.gz` | | `bvh` | Teleopit-models | `archives/sample_bvh.tar.gz` | | `data` | Teleopit-datasets | `data/` | @@ -46,7 +48,7 @@ sidebar_position: 2 python scripts/setup/download_assets.py # 只下载推理必需的资源 -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh # 只下载训练数据 python scripts/setup/download_assets.py --only data @@ -61,6 +63,7 @@ python scripts/setup/download_assets.py --source huggingface |---------|---------| | `checkpoints/track.onnx` | `track.onnx` | | `checkpoints/track.pt` | `track.pt` | +| `archives/robot_assets.tar.gz` | `assets/robots/`(自动解压) | | `archives/gmr_assets.tar.gz` | `teleopit/retargeting/gmr/assets/`(自动解压) | | 
`archives/sample_bvh.tar.gz` | `data/sample_bvh/`(自动解压) | | `data/` | `data/datasets/seed/` | @@ -70,7 +73,7 @@ python scripts/setup/download_assets.py --source huggingface ### 第一步:准备上传目录 ```bash -python scripts/setup/prepare_modelscope_assets.py --only ckpt gmr bvh --clean +python scripts/setup/prepare_modelscope_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/prepare_modelscope_assets.py --only data ``` @@ -114,7 +117,7 @@ tag 与代码仓库的 Git tag 保持一致,方便追溯每个版本对应的 python scripts/setup/upload_hf_assets.py --dry-run --clean # 只准备指定组 -python scripts/setup/upload_hf_assets.py --only ckpt gmr bvh --dry-run +python scripts/setup/upload_hf_assets.py --only ckpt robots gmr bvh --dry-run python scripts/setup/upload_hf_assets.py --only data --dry-run ``` @@ -123,7 +126,7 @@ python scripts/setup/upload_hf_assets.py --only data --dry-run ### 第二步:执行上传 ```bash -python scripts/setup/upload_hf_assets.py --only ckpt gmr bvh --clean +python scripts/setup/upload_hf_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/upload_hf_assets.py --only data --clean ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md index ea9e6f95..fb9f6b0d 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md @@ -168,6 +168,6 @@ print(cfg.init_state.joint_pos) # 必须与 g1.yaml default_angles 一致 ### 解决方案 -更新 `teleopit/configs/robot/g1.yaml` 和 `g1_mjlab.xml`,使其与训练环境的值一致(default angles、armature、condim)。 +更新 `teleopit/configs/robot/g1.yaml` 和 `assets/robots/unitree_g1/g1_29dof.xml`,使其与训练环境的值一致(default angles、armature、condim)。 此修复同时影响 sim2real 路径,因为 `default_angles` 被 `rl_policy.py` 和 `observation.py` 共用。 diff --git 
a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md index 1ea75b66..02eb6143 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md @@ -52,7 +52,7 @@ Teleopit 面向 pico-bridge 0.2.1 及其 `pico_native` tracking 语义。 ```bash pip install modelscope -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## 4. 运行 Pico Sim2Sim diff --git a/scripts/dev/compute_ik_offsets.py b/scripts/dev/compute_ik_offsets.py index 283817af..74366c17 100644 --- a/scripts/dev/compute_ik_offsets.py +++ b/scripts/dev/compute_ik_offsets.py @@ -33,6 +33,8 @@ import numpy as np from scipy.spatial.transform import Rotation as R +from teleopit.runtime.assets import UNITREE_G1_XML + PROJECT_ROOT = Path(__file__).resolve().parents[2] @@ -184,15 +186,7 @@ def main(): if not config_path.is_absolute(): config_path = PROJECT_ROOT / config_path - xml_path = str( - PROJECT_ROOT - / "teleopit" - / "retargeting" - / "gmr" - / "assets" - / "unitree_g1" - / "g1_mocap_29dof.xml" - ) + xml_path = str(UNITREE_G1_XML) with open(config_path) as f: config = json.load(f) diff --git a/scripts/render/render_sim.py b/scripts/render/render_sim.py index 94d71262..fe3d0201 100644 --- a/scripts/render/render_sim.py +++ b/scripts/render/render_sim.py @@ -31,6 +31,7 @@ import mujoco # noqa: E402 from teleopit.debug.rollout_trace import RolloutTraceWriter # noqa: E402 +from teleopit.runtime.assets import UNITREE_G1_XML # noqa: E402 from teleopit.sim.mocap_mujoco import ( # noqa: E402 MocapSkeletonSceneDrawer, fit_mocap_camera, @@ -191,17 +192,7 @@ def render_retarget( project_root = _find_project_root() cfgs = _load_configs(str(bvh_path), project_root, bvh_format, policy_path=None) - # Retarget rendering uses 
the GMR mocap XML (no actuator limits needed) - mocap_xml = ( - project_root - / "teleopit" - / "retargeting" - / "gmr" - / "assets" - / "unitree_g1" - / "g1_mocap_29dof.xml" - ) - cfgs["robot"].xml_path = str(mocap_xml) + cfgs["robot"].xml_path = str(UNITREE_G1_XML) from teleopit.inputs import BVHInputProvider from teleopit.retargeting.core import RetargetingModule diff --git a/scripts/run/record_pico_motion.py b/scripts/run/record_pico_motion.py index 380afbc4..3fb1bff1 100644 --- a/scripts/run/record_pico_motion.py +++ b/scripts/run/record_pico_motion.py @@ -22,7 +22,7 @@ write_motion_clip_npz, ) from teleopit.retargeting.core import RetargetingModule -from teleopit.runtime.assets import PROJECT_ROOT, UNITREE_G1_MJLAB_XML, missing_gmr_assets_message +from teleopit.runtime.assets import PROJECT_ROOT, UNITREE_G1_XML, missing_gmr_assets_message from teleopit.runtime.common import cfg_get from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader from teleopit.sim.viewer_subprocess import start_robot_viewer @@ -31,7 +31,7 @@ class RetargetPreview: """Small wrapper around the existing MuJoCo retarget viewer subprocess.""" - def __init__(self, xml_path: str | Path = UNITREE_G1_MJLAB_XML, *, enabled: bool = True) -> None: + def __init__(self, xml_path: str | Path = UNITREE_G1_XML, *, enabled: bool = True) -> None: self.enabled = bool(enabled) self._proc = None self._arr = None diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index cbab7720..82474188 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -111,7 +111,7 @@ ], dtype=np.float32) # MuJoCo XML path for FK -MJCF_PATH = _REPO_ROOT / "teleopit" / "retargeting" / "gmr" / "assets" / "unitree_g1" / "g1_mjlab.xml" +MJCF_PATH = _REPO_ROOT / "assets" / "robots" / "unitree_g1" / "g1_29dof.xml" JOINT_MAP = list(range(NUM_JOINTS)) GRAVITY_UNIT_W = np.array([0.0, 0.0, -1.0], dtype=np.float32) diff --git a/scripts/setup/download_assets.py 
b/scripts/setup/download_assets.py index fcaa303f..77d3dae3 100755 --- a/scripts/setup/download_assets.py +++ b/scripts/setup/download_assets.py @@ -4,7 +4,7 @@ Usage: python scripts/setup/download_assets.py # download everything (ModelScope) python scripts/setup/download_assets.py --source huggingface # download from HuggingFace - python scripts/setup/download_assets.py --only gmr # only GMR retargeting assets + python scripts/setup/download_assets.py --only robots gmr # only robot XML/meshes + GMR retargeting assets python scripts/setup/download_assets.py --only ckpt # only checkpoints python scripts/setup/download_assets.py --only data # only training data python scripts/setup/download_assets.py --only bvh # only sample BVH diff --git a/scripts/setup/prepare_modelscope_assets.py b/scripts/setup/prepare_modelscope_assets.py index 528fea72..ccb77ae3 100644 --- a/scripts/setup/prepare_modelscope_assets.py +++ b/scripts/setup/prepare_modelscope_assets.py @@ -2,7 +2,7 @@ """Prepare a compact upload directory for the ModelScope asset repos. Asset groups by destination repo: - model (BingqianWu/Teleopit-models): ckpt, gmr, bvh + model (BingqianWu/Teleopit-models): ckpt, robots, gmr, bvh dataset (BingqianWu/Teleopit-datasets): data """ diff --git a/scripts/setup/upload_hf_assets.py b/scripts/setup/upload_hf_assets.py index 1fe4d815..46bfe156 100644 --- a/scripts/setup/upload_hf_assets.py +++ b/scripts/setup/upload_hf_assets.py @@ -2,12 +2,12 @@ """Prepare and upload Teleopit assets to HuggingFace repos. 
Asset groups by destination repo: - model (12e21/Teleopit-models): ckpt, gmr, bvh + model (12e21/Teleopit-models): ckpt, robots, gmr, bvh dataset (12e21/Teleopit-datasets): data Usage: python scripts/setup/upload_hf_assets.py # upload everything - python scripts/setup/upload_hf_assets.py --only gmr # only GMR assets + python scripts/setup/upload_hf_assets.py --only robots gmr # only robot + GMR assets python scripts/setup/upload_hf_assets.py --dry-run # prepare files without uploading """ diff --git a/scripts/view/view_dataset.py b/scripts/view/view_dataset.py index 44fee0fb..582f50e7 100644 --- a/scripts/view/view_dataset.py +++ b/scripts/view/view_dataset.py @@ -19,10 +19,10 @@ from mjlab.viewer.viser import ViserMujocoScene -from teleopit.runtime.assets import UNITREE_G1_MJLAB_XML, missing_gmr_assets_message +from teleopit.runtime.assets import UNITREE_G1_XML, missing_gmr_assets_message from train_mimic.data.dataset_lib import find_motion_shards, read_motion_clip -DEFAULT_XML = UNITREE_G1_MJLAB_XML +DEFAULT_XML = UNITREE_G1_XML @dataclass(frozen=True) diff --git a/teleopit/configs/robot/g1.yaml b/teleopit/configs/robot/g1.yaml index f46d4487..014e84f1 100644 --- a/teleopit/configs/robot/g1.yaml +++ b/teleopit/configs/robot/g1.yaml @@ -46,7 +46,7 @@ ls_iterations: 20 # Use built-in PD actuators matching mjlab training builtin_pd: true -xml_path: "teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml" +xml_path: "assets/robots/unitree_g1/g1_29dof.xml" base_body: "pelvis" ang_vel_scale: 0.25 diff --git a/teleopit/retargeting/gmr/params.py b/teleopit/retargeting/gmr/params.py index 7c87c30d..cca6c564 100644 --- a/teleopit/retargeting/gmr/params.py +++ b/teleopit/retargeting/gmr/params.py @@ -1,5 +1,7 @@ from pathlib import Path +from teleopit.runtime.assets import UNITREE_G1_XML + BASE_DIR = Path(__file__).parent @@ -11,7 +13,7 @@ def _resolve_path(relative_path): ASSET_ROOT = _resolve_path("assets") ROBOT_XML_DICT = { - "unitree_g1": 
_resolve_path("assets/unitree_g1/g1_mocap_29dof.xml"), + "unitree_g1": UNITREE_G1_XML, "unitree_g1_with_hands": _resolve_path("assets/unitree_g1/g1_mocap_29dof_with_hands.xml"), "unitree_h1": _resolve_path("assets/unitree_h1/h1.xml"), "unitree_h1_2": _resolve_path("assets/unitree_h1_2/h1_2_handless.xml"), diff --git a/teleopit/robots/mujoco_robot.py b/teleopit/robots/mujoco_robot.py index b3283921..bdaaf5db 100644 --- a/teleopit/robots/mujoco_robot.py +++ b/teleopit/robots/mujoco_robot.py @@ -8,7 +8,11 @@ from omegaconf import DictConfig from teleopit.interfaces import RobotState -from teleopit.runtime.assets import GMR_ASSETS_ROOT, missing_gmr_assets_message +from teleopit.runtime.assets import ( + GMR_ASSETS_ROOT, + ROBOT_ASSETS_ROOT, + missing_gmr_assets_message, +) def _quat_conjugate(quat_wxyz: np.ndarray) -> np.ndarray: @@ -62,9 +66,11 @@ def __init__(self, cfg: DictConfig) -> None: xml_path = Path.cwd() / xml_path xml_path = xml_path.resolve() if not xml_path.exists(): - try: - xml_path.relative_to(GMR_ASSETS_ROOT) - except ValueError: + asset_roots = (ROBOT_ASSETS_ROOT, GMR_ASSETS_ROOT) + is_external_asset = any( + xml_path.is_relative_to(root) for root in asset_roots + ) + if not is_external_asset: raise FileNotFoundError(f"MuJoCo XML not found: {xml_path}") from None raise FileNotFoundError( missing_gmr_assets_message(xml_path, label="MuJoCo XML") diff --git a/teleopit/runtime/assets.py b/teleopit/runtime/assets.py index e9be0638..23fc44ff 100644 --- a/teleopit/runtime/assets.py +++ b/teleopit/runtime/assets.py @@ -4,8 +4,10 @@ PROJECT_ROOT = Path(__file__).resolve().parents[2] +ROBOT_ASSETS_ROOT = PROJECT_ROOT / "assets" / "robots" GMR_ASSETS_ROOT = PROJECT_ROOT / "teleopit" / "retargeting" / "gmr" / "assets" -UNITREE_G1_MJLAB_XML = GMR_ASSETS_ROOT / "unitree_g1" / "g1_mjlab.xml" +UNITREE_G1_XML = ROBOT_ASSETS_ROOT / "unitree_g1" / "g1_29dof.xml" +UNITREE_G1_MJLAB_XML = UNITREE_G1_XML def missing_gmr_assets_message(path: str | Path, *, label: str = 
"Required asset") -> str: @@ -17,5 +19,5 @@ def missing_gmr_assets_message(path: str | Path, *, label: str = "Required asset return ( f"{label} not found: {resolved}\n" "Download the external robot assets with:\n" - " python scripts/setup/download_assets.py --only gmr" + " python scripts/setup/download_assets.py --only robots gmr" ) diff --git a/teleopit/runtime/external_assets.py b/teleopit/runtime/external_assets.py index 3bd8ed90..6fdfb7c9 100644 --- a/teleopit/runtime/external_assets.py +++ b/teleopit/runtime/external_assets.py @@ -31,6 +31,14 @@ class AssetEntry: mode="extract", ), ], + "robots": [ + AssetEntry( + "archives/robot_assets.tar.gz", + "assets/robots", + repo="model", + mode="extract", + ), + ], "bvh": [ AssetEntry( "archives/sample_bvh.tar.gz", diff --git a/tests/conftest.py b/tests/conftest.py index c3d81bcf..6f7b5110 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,7 +51,7 @@ def find_g1_xml_path() -> str | None: """Return the preferred test XML path for G1 MuJoCo-based tests.""" root = Path(__file__).parent.parent candidates = [ - root / "teleopit" / "retargeting" / "gmr" / "assets" / "unitree_g1" / "g1_mjlab.xml", + root / "assets" / "robots" / "unitree_g1" / "g1_29dof.xml", ] for path in candidates: if path.exists(): diff --git a/tests/test_download_assets.py b/tests/test_download_assets.py index 0e1f14d4..bdd383c1 100644 --- a/tests/test_download_assets.py +++ b/tests/test_download_assets.py @@ -40,3 +40,14 @@ def test_resolve_entry_source_uses_only_current_remote_layout(tmp_path: Path) -> ) assert _resolve_entry_source(tmp_path, entry) == archive + + +def test_robot_asset_group_uses_archive_layout() -> None: + from teleopit.runtime.external_assets import ASSET_GROUPS + + entries = ASSET_GROUPS["robots"] + + assert len(entries) == 1 + assert entries[0].remote_path == "archives/robot_assets.tar.gz" + assert entries[0].local_path == "assets/robots" + assert entries[0].mode == "extract" diff --git a/tests/test_motion_sampling.py 
b/tests/test_motion_sampling.py index 23e1ab77..63e4d5c9 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -183,6 +183,27 @@ def test_motion_lib_global_cache_sampling_weights_follow_valid_duration(tmp_path ) +def test_motion_cache_sampler_draws_global_ids_randomly_by_valid_duration(tmp_path: Path) -> None: + motion_path = _write_shard_dir( + tmp_path / "motion_weighted_global", + [ + _clip_dict(num_frames=3, fps=10), + _clip_dict(num_frames=11, fps=10), + ], + ) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + ids = torch.cat([motion._cache._sample_global_ids() for _ in range(256)]) + counts = torch.bincount(ids.cpu(), minlength=2).float() + + assert counts[1] > counts[0] * 3.0 + + def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: motion_path = tmp_path / "motion_mismatch" clip = _clip_dict() diff --git a/train_mimic/data/motion_fk.py b/train_mimic/data/motion_fk.py index fd7f2d4d..730089c4 100644 --- a/train_mimic/data/motion_fk.py +++ b/train_mimic/data/motion_fk.py @@ -9,10 +9,10 @@ import mujoco import numpy as np -from teleopit.runtime.assets import UNITREE_G1_MJLAB_XML, missing_gmr_assets_message +from teleopit.runtime.assets import UNITREE_G1_XML, missing_gmr_assets_message -DEFAULT_G1_XML_PATH = UNITREE_G1_MJLAB_XML +DEFAULT_G1_XML_PATH = UNITREE_G1_XML def quat_xyzw_to_wxyz(q: np.ndarray) -> np.ndarray: diff --git a/train_mimic/scripts/data/check_motion_npz_fk.py b/train_mimic/scripts/data/check_motion_npz_fk.py index 6e31bc0a..621ac865 100644 --- a/train_mimic/scripts/data/check_motion_npz_fk.py +++ b/train_mimic/scripts/data/check_motion_npz_fk.py @@ -17,7 +17,7 @@ def parse_args() -> argparse.Namespace: "--xml", type=str, default=None, - help="Optional MuJoCo XML path (default: teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml)", + help="Optional MuJoCo XML path (default: 
assets/robots/unitree_g1/g1_29dof.xml)", ) parser.add_argument( "--sample_count", diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index b4453816..540db60f 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -4,6 +4,8 @@ from copy import deepcopy +import mujoco + from mjlab.asset_zoo.robots import G1_ACTION_SCALE, get_g1_robot_cfg from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg @@ -17,6 +19,7 @@ from train_mimic.tasks.tracking.config.constants import DEFAULT_TRAIN_MOTION_FILE from train_mimic.tasks.tracking.mdp import MotionCommandCfg from train_mimic.tasks.tracking.tracking_env_cfg import make_tracking_env_cfg +from teleopit.runtime.assets import UNITREE_G1_XML, missing_gmr_assets_message _TRACKING_BODY_NAMES = ( "pelvis", @@ -44,6 +47,23 @@ ) +def _get_g1_training_spec() -> mujoco.MjSpec: + if not UNITREE_G1_XML.is_file(): + raise FileNotFoundError( + missing_gmr_assets_message(UNITREE_G1_XML, label="G1 training MuJoCo XML") + ) + spec = mujoco.MjSpec.from_file(str(UNITREE_G1_XML)) + for actuator in list(spec.actuators): + spec.delete(actuator) + return spec + + +def _get_g1_training_robot_cfg(): + robot_cfg = get_g1_robot_cfg() + robot_cfg.spec_fn = _get_g1_training_spec + return robot_cfg + + def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: motion_cmd = cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) @@ -169,7 +189,7 @@ def make_general_tracking_env_cfg( """Create the General-Tracking-G1 training env.""" cfg = make_tracking_env_cfg() - cfg.scene.entities = {"robot": get_g1_robot_cfg()} + cfg.scene.entities = {"robot": _get_g1_training_robot_cfg()} joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) From 196bcc84c806d58f4cb963ff27d151eaded61f72 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 20:20:55 
+0800 Subject: [PATCH 091/122] Use flat motion cache for tracking --- teleopit/retargeting/gmr/params.py | 14 +++++- teleopit/runtime/assets.py | 2 + tests/test_motion_sampling.py | 30 +++++++++++++ train_mimic/tasks/tracking/mdp/commands.py | 52 ++++++++++++---------- 4 files changed, 72 insertions(+), 26 deletions(-) diff --git a/teleopit/retargeting/gmr/params.py b/teleopit/retargeting/gmr/params.py index cca6c564..1110eef0 100644 --- a/teleopit/retargeting/gmr/params.py +++ b/teleopit/retargeting/gmr/params.py @@ -1,6 +1,10 @@ from pathlib import Path -from teleopit.runtime.assets import UNITREE_G1_XML +from teleopit.runtime.assets import ( + UNITREE_G1_AVP_O6_XML, + UNITREE_G1_DEX3_XML, + UNITREE_G1_XML, +) BASE_DIR = Path(__file__).parent @@ -14,7 +18,8 @@ def _resolve_path(relative_path): ROBOT_XML_DICT = { "unitree_g1": UNITREE_G1_XML, - "unitree_g1_with_hands": _resolve_path("assets/unitree_g1/g1_mocap_29dof_with_hands.xml"), + "unitree_g1_with_hands": UNITREE_G1_DEX3_XML, + "unitree_g1_avp_o6": UNITREE_G1_AVP_O6_XML, "unitree_h1": _resolve_path("assets/unitree_h1/h1.xml"), "unitree_h1_2": _resolve_path("assets/unitree_h1_2/h1_2_handless.xml"), "booster_t1": _resolve_path("assets/booster_t1/T1_serial.xml"), @@ -38,6 +43,7 @@ def _resolve_path(relative_path): "smplx": { "unitree_g1": _resolve_path("ik_configs/smplx_to_g1.json"), "unitree_g1_with_hands": _resolve_path("ik_configs/smplx_to_g1.json"), + "unitree_g1_avp_o6": _resolve_path("ik_configs/smplx_to_g1.json"), "unitree_h1": _resolve_path("ik_configs/smplx_to_h1.json"), "unitree_h1_2": _resolve_path("ik_configs/smplx_to_h1_2.json"), "booster_t1": _resolve_path("ik_configs/smplx_to_t1.json"), @@ -57,6 +63,7 @@ def _resolve_path(relative_path): "bvh_lafan1": { "unitree_g1": _resolve_path("ik_configs/bvh_lafan1_to_g1.json"), "unitree_g1_with_hands": _resolve_path("ik_configs/bvh_lafan1_to_g1.json"), + "unitree_g1_avp_o6": _resolve_path("ik_configs/bvh_lafan1_to_g1.json"), "booster_t1_29dof": 
_resolve_path("ik_configs/bvh_lafan1_to_t1_29dof.json"), "fourier_n1": _resolve_path("ik_configs/bvh_lafan1_to_n1.json"), "stanford_toddy": _resolve_path("ik_configs/bvh_lafan1_to_toddy.json"), @@ -76,6 +83,7 @@ def _resolve_path(relative_path): "fbx": { "unitree_g1": _resolve_path("ik_configs/fbx_to_g1.json"), "unitree_g1_with_hands": _resolve_path("ik_configs/fbx_to_g1.json"), + "unitree_g1_avp_o6": _resolve_path("ik_configs/fbx_to_g1.json"), }, "fbx_offline": { "unitree_g1": _resolve_path("ik_configs/fbx_offline_to_g1.json"), @@ -89,6 +97,7 @@ def _resolve_path(relative_path): ROBOT_BASE_DICT = { "unitree_g1": "pelvis", "unitree_g1_with_hands": "pelvis", + "unitree_g1_avp_o6": "pelvis", "unitree_h1": "pelvis", "unitree_h1_2": "pelvis", "booster_t1": "Waist", @@ -110,6 +119,7 @@ def _resolve_path(relative_path): VIEWER_CAM_DISTANCE_DICT = { "unitree_g1": 2.0, "unitree_g1_with_hands": 2.0, + "unitree_g1_avp_o6": 2.0, "unitree_h1": 3.0, "unitree_h1_2": 3.0, "booster_t1": 2.0, diff --git a/teleopit/runtime/assets.py b/teleopit/runtime/assets.py index 23fc44ff..00d2e8e7 100644 --- a/teleopit/runtime/assets.py +++ b/teleopit/runtime/assets.py @@ -7,6 +7,8 @@ ROBOT_ASSETS_ROOT = PROJECT_ROOT / "assets" / "robots" GMR_ASSETS_ROOT = PROJECT_ROOT / "teleopit" / "retargeting" / "gmr" / "assets" UNITREE_G1_XML = ROBOT_ASSETS_ROOT / "unitree_g1" / "g1_29dof.xml" +UNITREE_G1_DEX3_XML = ROBOT_ASSETS_ROOT / "unitree_g1" / "g1_29dof_dex3.xml" +UNITREE_G1_AVP_O6_XML = ROBOT_ASSETS_ROOT / "unitree_g1" / "g1_29dof_avp_o6.xml" UNITREE_G1_MJLAB_XML = UNITREE_G1_XML diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index 63e4d5c9..5a2074f1 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -129,6 +129,36 @@ def test_motion_lib_get_window_frames_returns_requested_offsets(tmp_path: Path) assert torch.allclose(current["joint_pos"][0, :1], torch.tensor([2.0], dtype=torch.float32)) +def 
test_motion_lib_flat_cache_offsets_select_requested_clip(tmp_path: Path) -> None: + clip0 = _clip_dict(num_frames=3) + clip1 = _clip_dict(num_frames=7) + clip1["root_pos"] = np.asarray(clip1["root_pos"]).copy() + clip1["root_pos"][:, 0] += 100.0 + clip1["joint_pos"] = np.asarray(clip1["joint_pos"]).copy() + clip1["joint_pos"][:, 0] += 100.0 + clip1["body_pos_w"] = np.asarray(clip1["body_pos_w"]).copy() + clip1["body_pos_w"][:, :, 0] += 100.0 + + motion_path = _write_shard_dir(tmp_path / "motion_flat_offsets", [clip0, clip1]) + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + cache_num_clips=2, + cache_seed=0, + dataloader_num_workers=0, + ) + motion._set_batch(motion._cache._load_batch(torch.tensor([0, 1], dtype=torch.long))) + + frames = motion.get_frames( + torch.tensor([1], dtype=torch.long), + torch.tensor([2.0], dtype=torch.float32), + ) + + assert torch.allclose(frames["joint_pos"][0, :1], torch.tensor([102.0])) + assert torch.allclose(frames["body_pos_w"][0, 0], torch.tensor([102.0, 0.0, 0.0])) + + def test_motion_lib_selects_bodies_by_dataset_names(tmp_path: Path) -> None: motion_path = _write_shard_dir(tmp_path / "motion_named_bodies", [_clip_dict()]) diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 85c5260c..5ec1c017 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -78,6 +78,7 @@ class _Hdf5ClipRef: @dataclass class _MotionBatch: tensors: dict[str, torch.Tensor] + frame_offsets: torch.Tensor lengths: torch.Tensor fps: torch.Tensor sample_starts: torch.Tensor @@ -87,6 +88,7 @@ class _MotionBatch: def pin_memory(self) -> "_MotionBatch": return _MotionBatch( tensors={key: value.pin_memory() for key, value in self.tensors.items()}, + frame_offsets=self.frame_offsets.pin_memory(), lengths=self.lengths.pin_memory(), fps=self.fps.pin_memory(), 
sample_starts=self.sample_starts.pin_memory(), @@ -208,22 +210,19 @@ def close(self) -> None: def _collate_motion_clips(samples: list[_MotionClipSample]) -> _MotionBatch: if not samples: raise ValueError("Motion cache DataLoader produced an empty batch") - max_len = max(sample.length for sample in samples) arrays: dict[str, torch.Tensor] = {} for key in MOTION_ARRAY_KEYS: - first = samples[0].tensors[key] - arrays[key] = torch.zeros( - (len(samples), max_len, *first.shape[1:]), - dtype=torch.float32, - ) + arrays[key] = torch.cat([sample.tensors[key] for sample in samples], dim=0) - for out_i, sample in enumerate(samples): - for key, value in sample.tensors.items(): - arrays[key][out_i, :sample.length] = value + lengths = torch.tensor([sample.length for sample in samples], dtype=torch.long) + frame_offsets = torch.zeros(len(samples), dtype=torch.long) + if len(samples) > 1: + frame_offsets[1:] = torch.cumsum(lengths[:-1], dim=0) return _MotionBatch( tensors=arrays, - lengths=torch.tensor([sample.length for sample in samples], dtype=torch.long), + frame_offsets=frame_offsets, + lengths=lengths, fps=torch.tensor([sample.fps for sample in samples], dtype=torch.float32), sample_starts=torch.tensor([sample.sample_start for sample in samples], dtype=torch.long), sample_ends=torch.tensor([sample.sample_end for sample in samples], dtype=torch.long), @@ -371,6 +370,7 @@ def _stage_batch(self, batch: _MotionBatch, *, wait: bool) -> _MotionBatch: tensors = {key: value.to(self._device) for key, value in batch.tensors.items()} return _MotionBatch( tensors=tensors, + frame_offsets=batch.frame_offsets.to(self._device), lengths=batch.lengths.to(self._device), fps=batch.fps.to(self._device), sample_starts=batch.sample_starts.to(self._device), @@ -387,6 +387,7 @@ def _stage_batch(self, batch: _MotionBatch, *, wait: bool) -> _MotionBatch: } staged = _MotionBatch( tensors=tensors, + frame_offsets=batch.frame_offsets.to(self._device, non_blocking=True), 
lengths=batch.lengths.to(self._device, non_blocking=True), fps=batch.fps.to(self._device, non_blocking=True), sample_starts=batch.sample_starts.to(self._device, non_blocking=True), @@ -468,7 +469,7 @@ def __init__( body_names: tuple[str, ...] | list[str] | None = None, device: str = "cpu", window_steps: tuple[int, ...] | list[int] | None = None, - cache_num_clips: int = 1024, + cache_num_clips: int = 8192, cache_seed: int = 0, dataloader_num_workers: int = 2, dataloader_prefetch_factor: int = 1, @@ -527,6 +528,7 @@ def _set_batch(self, batch: _MotionBatch) -> None: self._body_quat_w_t = batch.tensors["body_quat_w"] self._body_lin_vel_w_t = batch.tensors["body_lin_vel_w"] self._body_ang_vel_w_t = batch.tensors["body_ang_vel_w"] + self.clip_frame_offsets = batch.frame_offsets self.clip_lengths = batch.lengths self.clip_fps = batch.fps @@ -538,8 +540,8 @@ def _set_batch(self, batch: _MotionBatch) -> None: self.clip_sample_ends = batch.sample_ends self.clip_sample_start_s = self.clip_sample_starts.float() * self.clip_dt self.clip_sample_end_s = self.clip_sample_ends.float() * self.clip_dt - # Kept for introspection/logging; frame interpolation is cache-local. - self.clip_starts = torch.zeros(self.num_clips, dtype=torch.long, device=self._device) + # Kept for introspection/logging; these are cache-local flat frame offsets. 
+ self.clip_starts = self.clip_frame_offsets self.generation = self._cache.generation def advance_cache(self) -> None: @@ -633,7 +635,9 @@ def get_window_frames( steps, ) batch = motion_ids.shape[0] - batch_idx = motion_ids[:, None].expand(batch, window).reshape(-1) + frame_offsets = self.clip_frame_offsets[motion_ids] + flat_idx0 = (frame_offsets[:, None] + idx0.reshape(batch, window)).reshape(-1) + flat_idx1 = (frame_offsets[:, None] + idx1.reshape(batch, window)).reshape(-1) want = self._ALL_KEYS if keys is None else keys result: dict[str, torch.Tensor] = {} @@ -643,7 +647,7 @@ def get_window_frames( for key, arr_t in (("joint_pos", self._joint_pos_t), ("joint_vel", self._joint_vel_t)): if key not in want: continue - v0, v1 = arr_t[batch_idx, idx0], arr_t[batch_idx, idx1] + v0, v1 = arr_t[flat_idx0], arr_t[flat_idx1] result[key] = (v0 + a1 * (v1 - v0)).reshape(batch, window, -1) # body arrays: (T, B, D) — GPU gather + lerp, optionally pre-slice bodies @@ -656,21 +660,21 @@ def get_window_frames( if key not in want: continue if body_indices is not None: - v0 = arr_t[batch_idx, idx0][:, body_indices] - v1 = arr_t[batch_idx, idx1][:, body_indices] + v0 = arr_t[flat_idx0][:, body_indices] + v1 = arr_t[flat_idx1][:, body_indices] else: - v0, v1 = arr_t[batch_idx, idx0], arr_t[batch_idx, idx1] + v0, v1 = arr_t[flat_idx0], arr_t[flat_idx1] interp = v0 + a2 * (v1 - v0) result[key] = interp.reshape(batch, window, *interp.shape[1:]) # body_quat_w: GPU slerp, optionally pre-slice bodies if "body_quat_w" in want: if body_indices is not None: - q0 = self._body_quat_w_t[batch_idx, idx0][:, body_indices] - q1 = self._body_quat_w_t[batch_idx, idx1][:, body_indices] + q0 = self._body_quat_w_t[flat_idx0][:, body_indices] + q1 = self._body_quat_w_t[flat_idx1][:, body_indices] else: - q0 = self._body_quat_w_t[batch_idx, idx0] - q1 = self._body_quat_w_t[batch_idx, idx1] + q0 = self._body_quat_w_t[flat_idx0] + q1 = self._body_quat_w_t[flat_idx1] nb = q0.shape[1] q0_flat = 
q0.reshape(-1, 4) q1_flat = q1.reshape(-1, 4) @@ -1312,8 +1316,8 @@ class MotionCommandCfg(CommandTermCfg): joint_position_range: tuple[float, float] = (-0.52, 0.52) sampling_mode: Literal["uniform", "start", "rewind"] = "rewind" window_steps: tuple[int, ...] = (0,) - cache_num_clips: int = 1024 - cache_swap_interval_steps: int = 500 + cache_num_clips: int = 8192 + cache_swap_interval_steps: int = 2000 cache_dataloader_num_workers: int = 2 cache_dataloader_prefetch_factor: int = 1 cache_dataloader_pin_memory: bool = True From 809d480d048611e790996459ca52659bc2e68bd4 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 20:58:27 +0800 Subject: [PATCH 092/122] Add robot XML override for training --- tests/test_train_script.py | 18 +++++++++++++++ train_mimic/scripts/train.py | 19 ++++++++++++++++ train_mimic/tasks/tracking/config/env.py | 29 ++++++++++++++++++------ 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/tests/test_train_script.py b/tests/test_train_script.py index 6474374f..ef82d381 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse +from functools import partial import sys import types from pathlib import Path @@ -34,6 +35,7 @@ def _args(**overrides: object) -> argparse.Namespace: "logger": "tensorboard", "experiment_name": None, "motion_file": "data/datasets/twist2", + "robot_xml": None, "resume": None, "sampling_mode": None, "rewind_prob": None, @@ -70,6 +72,10 @@ def test_parse_args_with_gpu_ids(self) -> None: assert args.gpu_ids == [0, 2, 3] assert args.master_port == 29600 + def test_parse_args_accepts_robot_xml(self) -> None: + args = train.parse_args(["--robot_xml", "assets/robots/unitree_g1/g1_29dof_dex3.xml"]) + assert args.robot_xml == "assets/robots/unitree_g1/g1_29dof_dex3.xml" + def test_should_launch_multi_gpu(self) -> None: args = _args(gpu_ids=[0, 1, 2, 3]) assert train._should_launch_multi_gpu(args, env={"WORLD_SIZE": "1"}) 
is True @@ -185,6 +191,7 @@ def test_configure_swanlab_logger_syncs_tensorboard(self, monkeypatch: pytest.Mo ) }, scene=types.SimpleNamespace(num_envs=64), + robot_xml="/tmp/g1.xml", ) active = train._configure_experiment_logger( @@ -206,6 +213,7 @@ def test_configure_swanlab_logger_syncs_tensorboard(self, monkeypatch: pytest.Mo "config": { "experiment_name": "exp", "motion_file": "data/train", + "robot_xml": "/tmp/g1.xml", "num_envs": 64, "max_iterations": 10, "sampling_mode": "uniform", @@ -243,6 +251,16 @@ def test_tracking_runner_configs_disable_model_upload() -> None: assert make_general_tracking_ppo_runner_cfg().upload_model is False +def test_make_g1_training_robot_cfg_uses_requested_xml() -> None: + from train_mimic.tasks.tracking.config.env import make_g1_training_robot_cfg + + xml_path = Path("assets/robots/unitree_g1/g1_29dof_dex3.xml").resolve() + robot_cfg = make_g1_training_robot_cfg(xml_path) + + assert isinstance(robot_cfg.spec_fn, partial) + assert robot_cfg.spec_fn.args == (xml_path,) + + def test_validate_motion_file_accepts_shard_directories(tmp_path: Path) -> None: num_frames = 3 write_hdf5_motion_shard( diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 24cd9c48..3a918cee 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -50,6 +50,10 @@ validate_motion_file, ) from train_mimic.tasks.tracking.config.constants import DEFAULT_TRAIN_MOTION_FILE +from train_mimic.tasks.tracking.config.env import ( + make_g1_training_robot_cfg, + resolve_g1_training_xml, +) def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: @@ -75,6 +79,15 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: parser.add_argument("--experiment_name", type=str, default=None) parser.add_argument("--motion_file", type=str, default=None, help="Dataset root containing Teleopit shard_*.h5 files, searched recursively") + parser.add_argument( + "--robot_xml", + type=str, + default=None, + 
help=( + "MuJoCo XML used for the G1 training robot " + "(default: assets/robots/unitree_g1/g1_29dof.xml)" + ), + ) parser.add_argument( "--resume", type=str, @@ -297,6 +310,7 @@ def _configure_experiment_logger( config={ "experiment_name": agent_cfg.experiment_name, "motion_file": env_cfg.commands["motion"].motion_file, + "robot_xml": getattr(env_cfg, "robot_xml", None), "num_envs": env_cfg.scene.num_envs, "max_iterations": agent_cfg.max_iterations, "sampling_mode": env_cfg.commands["motion"].sampling_mode, @@ -379,6 +393,11 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: # CLI overrides env_cfg.seed = _resolve_worker_seed(args.seed) env_cfg.commands["motion"].cache_seed = env_cfg.seed + robot_xml = resolve_g1_training_xml(args.robot_xml) + if not robot_xml.is_file(): + raise FileNotFoundError(f"G1 training MuJoCo XML not found: {robot_xml}") + env_cfg.scene.entities["robot"] = make_g1_training_robot_cfg(robot_xml) + env_cfg.robot_xml = str(robot_xml) if args.num_envs is not None: env_cfg.scene.num_envs = args.num_envs if args.motion_file is not None: diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 540db60f..35535524 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -3,6 +3,8 @@ from __future__ import annotations from copy import deepcopy +from functools import partial +from pathlib import Path import mujoco @@ -47,20 +49,33 @@ ) -def _get_g1_training_spec() -> mujoco.MjSpec: - if not UNITREE_G1_XML.is_file(): +def resolve_g1_training_xml(robot_xml: str | Path | None = None) -> Path: + """Resolve the MuJoCo XML used for G1 policy training.""" + if robot_xml is None or str(robot_xml).strip() == "": + return UNITREE_G1_XML.resolve() + + path = Path(robot_xml).expanduser() + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + return path + + +def _get_g1_training_spec(robot_xml: str | Path | None = None) -> mujoco.MjSpec: + xml_path 
= resolve_g1_training_xml(robot_xml) + if not xml_path.is_file(): raise FileNotFoundError( - missing_gmr_assets_message(UNITREE_G1_XML, label="G1 training MuJoCo XML") + missing_gmr_assets_message(xml_path, label="G1 training MuJoCo XML") ) - spec = mujoco.MjSpec.from_file(str(UNITREE_G1_XML)) + spec = mujoco.MjSpec.from_file(str(xml_path)) for actuator in list(spec.actuators): spec.delete(actuator) return spec -def _get_g1_training_robot_cfg(): +def make_g1_training_robot_cfg(robot_xml: str | Path | None = None): robot_cfg = get_g1_robot_cfg() - robot_cfg.spec_fn = _get_g1_training_spec + xml_path = resolve_g1_training_xml(robot_xml) + robot_cfg.spec_fn = partial(_get_g1_training_spec, xml_path) return robot_cfg @@ -189,7 +204,7 @@ def make_general_tracking_env_cfg( """Create the General-Tracking-G1 training env.""" cfg = make_tracking_env_cfg() - cfg.scene.entities = {"robot": _get_g1_training_robot_cfg()} + cfg.scene.entities = {"robot": make_g1_training_robot_cfg()} joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) From 3b9d698d6cb789d6d5feca4386f419897e2bd58d Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 21:37:26 +0800 Subject: [PATCH 093/122] Add additional wrist position reward --- train_mimic/tasks/tracking/config/env.py | 14 +++++++++ train_mimic/tasks/tracking/mdp/rewards.py | 36 +++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 35535524..36a551c6 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -198,6 +198,19 @@ def _configure_feet_acc_reward(cfg: ManagerBasedRlEnvCfg) -> None: ) +def _configure_additional_wrist_pos_reward(cfg: ManagerBasedRlEnvCfg) -> None: + cfg.rewards["additional_wrist_pos"] = RewardTermCfg( + func=mdp.motion_relative_body_point_position_error_exp, + weight=1.0, + params={ + "command_name": 
"motion", + "std": 0.12, + "body_names": ("left_wrist_yaw_link", "right_wrist_yaw_link"), + "body_offsets": ((0.18, -0.025, 0.0), (0.18, 0.025, 0.0)), + }, + ) + + def make_general_tracking_env_cfg( *, play: bool = False, ) -> ManagerBasedRlEnvCfg: @@ -227,6 +240,7 @@ def make_general_tracking_env_cfg( ].body_names = "torso_link" _configure_self_collision_reward(cfg) _configure_feet_acc_reward(cfg) + _configure_additional_wrist_pos_reward(cfg) cfg.terminations["ee_body_pos"].params["body_names"] = ( "left_ankle_roll_link", "right_ankle_roll_link", diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index 6a08dc44..ff1bfe07 100644 --- a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -5,6 +5,7 @@ import torch from mjlab.utils.lab_api.math import ( + quat_apply, quat_error_magnitude, ) @@ -80,6 +81,41 @@ def motion_relative_body_position_error_exp( return torch.exp(-error.mean(-1) / std**2) +def motion_relative_body_point_position_error_exp( + env: ManagerBasedRlEnv, + command_name: str, + std: float, + body_names: tuple[str, ...], + body_offsets: tuple[tuple[float, float, float], ...], +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + body_index_by_name = {name: i for i, name in enumerate(command.cfg.body_names)} + missing_body_names = [name for name in body_names if name not in body_index_by_name] + if missing_body_names: + raise ValueError( + "body_names must exist in the motion command tracking body list: " + f"missing {missing_body_names}" + ) + if len(body_names) != len(body_offsets): + raise ValueError( + "body_offsets must contain one offset for each selected body: " + f"got {len(body_offsets)} offsets for {len(body_names)} bodies" + ) + + body_indexes = [body_index_by_name[name] for name in body_names] + offsets = torch.tensor( + body_offsets, dtype=command.body_pos_relative_w.dtype, device=env.device + 
).expand(env.num_envs, -1, -1) + ref_points = command.body_pos_relative_w[:, body_indexes] + quat_apply( + command.body_quat_relative_w[:, body_indexes], offsets + ) + robot_points = command.robot_body_pos_w[:, body_indexes] + quat_apply( + command.robot_body_quat_w[:, body_indexes], offsets + ) + error = torch.sum(torch.square(ref_points - robot_points), dim=-1) + return torch.exp(-error.mean(-1) / std**2) + + def motion_relative_body_orientation_error_exp( env: ManagerBasedRlEnv, command_name: str, From dbb09796b7cbe20cb9af01cf20d93aa0d88854a8 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 15 Jun 2026 21:40:00 +0800 Subject: [PATCH 094/122] Expand tracking mass randomization bodies --- tests/test_domain_randomization.py | 6 +++++- train_mimic/tasks/tracking/config/env.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index da23645c..b98ce407 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -62,7 +62,11 @@ def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> Non mass = events["randomize_rigid_body_mass"] assert mass.func is dr.pseudo_inertia assert mass.mode == "startup" - assert mass.params["asset_cfg"].body_names == "torso_link" + assert mass.params["asset_cfg"].body_names == ( + "torso_link", + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ) assert mass.params["alpha_range"] == (-0.1, 0.45) diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 36a551c6..3a3c47fb 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -237,7 +237,11 @@ def make_general_tracking_env_cfg( cfg.events["base_com"].params["asset_cfg"].body_names = ("torso_link",) cfg.events["randomize_rigid_body_mass"].params[ "asset_cfg" - ].body_names = "torso_link" + ].body_names = ( + "torso_link", + "left_wrist_yaw_link", + 
"right_wrist_yaw_link", + ) _configure_self_collision_reward(cfg) _configure_feet_acc_reward(cfg) _configure_additional_wrist_pos_reward(cfg) From bd823e9b21644aa4ea7167e7af212230c283d172 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 16 Jun 2026 21:05:33 +0800 Subject: [PATCH 095/122] Precompute motion dataset for training --- AGENTS.md | 11 +- README.md | 2 +- docs/docs/getting-started/download-assets.md | 18 +- docs/docs/reference/dataset.md | 27 +- .../reference/training-troubleshooting.md | 4 +- docs/docs/tutorials/training.md | 28 +- .../getting-started/download-assets.md | 18 +- .../current/reference/dataset.md | 27 +- .../reference/training-troubleshooting.md | 4 +- .../current/tutorials/training.md | 28 +- tests/test_motion_sampling.py | 148 ++++++++- tests/test_train_script.py | 45 ++- train_mimic/app.py | 10 +- train_mimic/data/dataset_lib.py | 303 ++++++++++++++++-- train_mimic/scripts/benchmark.py | 8 +- .../scripts/data/precompute_dataset.py | 106 ++++++ train_mimic/scripts/play.py | 8 +- train_mimic/scripts/train.py | 12 +- .../tasks/tracking/config/constants.py | 2 +- train_mimic/tasks/tracking/mdp/commands.py | 52 ++- 20 files changed, 710 insertions(+), 151 deletions(-) create mode 100644 train_mimic/scripts/data/precompute_dataset.py diff --git a/AGENTS.md b/AGENTS.md index 04f849e4..81d436c4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -203,10 +203,11 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g ### Dataset Pipeline - Dataset build spec supports a `preprocess` section for root-xy normalization, ground alignment, and basic clip filtering -- Final training dataset outputs are minimal HDF5 shards directly under `data/datasets//` (recursive shard discovery is supported; no train/val split and no manifest file) +- Final distributed dataset build outputs are minimal HDF5 shards directly under `data/datasets//` (recursive shard discovery is supported; no train/val split and no manifest file) +- 
`train_mimic/scripts/data/precompute_dataset.py` converts a minimal dataset into a separate precomputed training dataset directory; `build_dataset.py` must not run precompute - Each shard stores only `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, and clip-aware window metadata (`clip_starts`, `clip_lengths`, `clip_fps`); long clips are split into overlapping bounded windows -- Training computes joint velocities and body FK/velocities online in PyTorch DataLoader workers when loading the motion cache -- `MotionLib` loads only a configurable HDF5 subset cache into CPU/GPU memory, asynchronously stages the next cache, and swaps at the PPO rollout barrier +- Training `motion_file` must point to a precomputed training dataset, not the minimal distributed dataset; training reads joint velocities and body FK/velocities from those precomputed shards and must not run MuJoCo FK while loading motion clips into the fixed-size cache +- `MotionLib` loads only a configurable precomputed HDF5 subset cache into CPU/GPU memory, asynchronously stages the next cache, and swaps caches at the PPO rollout barrier - `MotionLib` samples only valid center frames for the configured `window_steps`; default is `window_steps=[0]` - Training supports `uniform` and `rewind` sampling on the active cache; in distributed training each rank sets a rank-offset `cache_seed` - `scripts/run/record_pico_motion.py` records Pico live body tracking as retargeted G1 motion NPZ clips in `data/pico_motion/clips/`; it opens a live `Retarget` viewer, uses terminal keys `R/S/D/N/Q`, stores semantic labels in filenames, and intentionally does not write per-clip JSON @@ -218,7 +219,9 @@ Quick reference: python train_mimic/scripts/data/build_dataset.py --spec train_mimic/configs/datasets/twist2.yaml python scripts/run/record_pico_motion.py python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force -python train_mimic/scripts/train.py --motion_file data/datasets/twist2 
+python train_mimic/scripts/data/precompute_dataset.py data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +python train_mimic/scripts/train.py --motion_file data/datasets/seed_precomputed +python train_mimic/scripts/data/precompute_dataset.py data/datasets/twist2 --outdir data/datasets/twist2_precomputed --jobs 8 --force python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt --output policy.onnx --history_length 10 ``` diff --git a/README.md b/README.md index ed796e38..9ce0bdfa 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. -- Switched training datasets to recursive minimal HDF5 shards with no train/val split or manifest; training derives FK/velocities while loading the motion cache. +- Switched dataset build outputs to recursive minimal HDF5 shards with no train/val split or manifest; `precompute_dataset.py` turns them into separate precomputed training datasets before training. - General-Tracking-G1 training defaults to `rewind` motion sampling and also supports `uniform`; playback/benchmark use `start`. - Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. - Added root velocity, joint tracking, and survival rewards to the General-Tracking-G1 training objective. 
diff --git a/docs/docs/getting-started/download-assets.md b/docs/docs/getting-started/download-assets.md index 7f4de486..994d9486 100644 --- a/docs/docs/getting-started/download-assets.md +++ b/docs/docs/getting-started/download-assets.md @@ -25,14 +25,16 @@ python scripts/setup/download_assets.py --only robots gmr ckpt bvh ## Asset Inventory -| Asset | Size | Purpose | -|-------|------|---------| -| `track.onnx` | 4 MB | ONNX inference model | -| `track.pt` | 27 MB | PyTorch checkpoint (for resume training) | -| `data/datasets/seed/shard_*.h5` | ~26 GB | Training dataset | -| `data/sample_bvh/*.bvh` | 5 MB | Sample motion files | -| `assets/robots/unitree_g1/` | ~52 MB | Canonical G1 XML and meshes used by training, sim2sim, retargeting, and FK validation | -| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR retargeting assets, IK configs, and non-canonical robot descriptions | +Downloaded file sizes change as checkpoints, datasets, and asset bundles are updated. Use the repository paths below as the stable contract. 
+ +| Local Path | Purpose | +|------------|---------| +| `track.onnx` | ONNX inference model | +| `track.pt` | PyTorch checkpoint for resume training | +| `data/datasets/seed/shard_*.h5` | Minimal motion dataset; run precompute before training | +| `data/sample_bvh/*.bvh` | Sample motion files | +| `assets/robots/unitree_g1/` | Canonical G1 XML and meshes used by training, sim2sim, retargeting, and FK validation | +| `teleopit/retargeting/gmr/assets/` | GMR retargeting assets, IK configs, and non-canonical robot descriptions | ## Asset Groups diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index 555b105d..0f3752cb 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -7,13 +7,15 @@ sidebar_position: 3 ## Download Pre-Built Dataset (Recommended) ```bash -python scripts/setup/download_assets.py --only data +python scripts/setup/download_assets.py --only robots data ``` -Then train directly with the dataset root: +Then precompute the training shard and train with the precomputed dataset root: ```bash -python train_mimic/scripts/train.py --motion_file data/datasets/seed +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +python train_mimic/scripts/train.py --motion_file data/datasets/seed_precomputed ``` For custom dataset construction, read on. @@ -48,7 +50,7 @@ At least one valid clip is required after preprocessing. 
## Custom Dataset Construction -Data pipeline: `typed source YAML -> preprocess/filter -> minimal HDF5 shards` +Data pipeline: `typed source YAML -> preprocess/filter -> minimal HDF5 shards -> precomputed training dataset` ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -60,12 +62,17 @@ python train_mimic/scripts/data/build_dataset.py \ ```text data/datasets// └── shard_*.h5 + +data/datasets/_precomputed/ +└── shard_*.h5 ``` - If the spec contains `bvh` or `npz` sources, the full dataset builder uses a temporary `clips/` directory during conversion and deletes it after shards are written. Rebuilds do not reuse converted clips. - If the spec is all `pkl` or `seed_csv` sources, the builder takes a batch path producing shards directly -- Training recursively discovers `*.h5` shards below the specified root, so datasets can be merged by placing multiple shard directories under one parent -- Training loads only a subset cache from the discovered shards, derives FK/velocities online, stages the next cache, and swaps caches at the PPO rollout barrier. +- `build_dataset.py` only writes the minimal distributable dataset. It does not run FK precompute. +- `precompute_dataset.py` writes a separate training dataset containing the minimal motion plus precomputed joint velocities and body FK/velocities. +- Training accepts only the precomputed dataset directory. It recursively discovers precomputed `*.h5` shards below the specified root, so precomputed datasets can be merged by placing multiple shard directories under one parent. +- Training loads only a subset cache from the discovered precomputed shards, stages the next cache asynchronously, and swaps caches at the PPO rollout barrier. Joint velocities and body FK/velocities are not computed during training. ## YAML Spec Format @@ -108,13 +115,13 @@ sources: ## Conversion Rules -All sources are converted to standard training shards. 
Each clip goes through preprocessing/filtering before writing to shards: +All sources are converted to standard minimal shards. Each clip goes through preprocessing/filtering before writing to shards: - `bvh -> retarget pkl -> npz clip` - `pkl -> npz clip` (or direct batch shard for pkl-only datasets) - `npz -> validate + copy/reuse` -Each shard stores minimal motion data: `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, `clip_starts`, `clip_lengths`, `clip_fps`. Joint velocities and body FK/velocities are computed when training loads a cache. +Each minimal shard stores `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, `clip_starts`, `clip_lengths`, and `clip_fps`. The precomputed training shards store `joint_pos`, `joint_vel`, `body_pos_w`, `body_quat_w`, `body_lin_vel_w`, `body_ang_vel_w`, and the same metadata. Training fails fast if `--motion_file` points at a minimal dataset instead of a precomputed training dataset. ## Common Commands @@ -136,6 +143,10 @@ python train_mimic/scripts/data/build_dataset.py \ python train_mimic/scripts/data/build_dataset.py \ --spec train_mimic/configs/datasets/twist2.yaml --json +# Generate a precomputed training dataset from an existing minimal dataset +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/twist2 --outdir data/datasets/twist2_precomputed --jobs 8 --force + # Inspect a dataset root python train_mimic/scripts/data/inspect_dataset.py data/datasets/twist2 ``` diff --git a/docs/docs/reference/training-troubleshooting.md b/docs/docs/reference/training-troubleshooting.md index 61475ad9..b683c6c4 100644 --- a/docs/docs/reference/training-troubleshooting.md +++ b/docs/docs/reference/training-troubleshooting.md @@ -45,7 +45,7 @@ If check fails, regenerate data and run a smoke test: ```bash python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets/ + --motion_file data/datasets/_precomputed ``` Expected: `Mean episode length` significantly > 1, 
`error_anchor_pos` starts decreasing. @@ -126,7 +126,7 @@ Ensure `num_eval_steps >= video_length`: ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/ \ + --motion_file data/datasets/_precomputed \ --num_envs 1 --num_eval_steps 2000 \ --video --video_length 600 ``` diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index 6b150ac8..00d83e9a 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -23,6 +23,14 @@ Verify: python -c "import train_mimic.tasks; print('training OK')" ``` +Download the minimal seed dataset and generate the precomputed training shard: + +```bash +python scripts/setup/download_assets.py --only robots data +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +``` + ## Training ### Smoke Test @@ -31,7 +39,7 @@ python -c "import train_mimic.tasks; print('training OK')" python train_mimic/scripts/train.py \ --num_envs 64 \ --max_iterations 100 \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### Full Training @@ -40,7 +48,7 @@ python train_mimic/scripts/train.py \ python train_mimic/scripts/train.py \ --num_envs 4096 \ --max_iterations 30000 \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### Multi-GPU @@ -50,7 +58,7 @@ python train_mimic/scripts/train.py \ --gpu_ids 0 1 2 3 \ --num_envs 1024 \ --max_iterations 30000 \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### Multi-Node Multi-GPU @@ -67,16 +75,16 @@ torchrun \ train_mimic/scripts/train.py \ --num_envs 1024 \ --max_iterations 1000 \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` **Notes:** - `--num_envs` is per-GPU in multi-GPU mode - `--num_envs` is also per-process in multi-node mode, so total environments scale 
with `world_size` - Default logger is TensorBoard. Use `--logger wandb` or `--logger swanlab` to select W&B or SwanLab; the project name defaults to `experiment_name` -- `--motion_file` accepts a dataset root directory or single `.h5` shard; shard discovery is recursive -- `--cache_num_clips` controls the active HDF5 subset size; `--cache_swap_interval_steps` controls how often the next subset is swapped in at a rollout barrier -- `--cache_dataloader_num_workers`, `--cache_dataloader_prefetch_factor`, and `--cache_dataloader_pin_memory` tune asynchronous HDF5 cache loading without increasing dataset size +- `--motion_file` accepts a precomputed training dataset root directory or a single precomputed `.h5` shard; shard discovery is recursive +- If you only have the minimal distributed shards, first run `python train_mimic/scripts/data/precompute_dataset.py --outdir ` and pass the precomputed output to training. +- `--cache_num_clips` controls how many precomputed HDF5 motion windows are loaded into the active subset cache; the next cache is staged asynchronously and swapped at rollout barriers. - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` ## Export ONNX @@ -97,7 +105,7 @@ The exported model is a dual-input ONNX (`obs` + `obs_history`). 
The inference s ```bash python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### Benchmark @@ -105,7 +113,7 @@ python train_mimic/scripts/play.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 ``` @@ -114,7 +122,7 @@ python train_mimic/scripts/benchmark.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 \ --video \ --video_length 600 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md index 014a0e6b..e8f17538 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md @@ -25,14 +25,16 @@ python scripts/setup/download_assets.py --only robots gmr ckpt bvh ## 资源清单 -| 资源 | 大小 | 用途 | -|------|------|------| -| `track.onnx` | 4 MB | ONNX 推理模型 | -| `track.pt` | 27 MB | PyTorch 检查点(用于恢复训练) | -| `data/datasets/seed/shard_*.h5` | ~26 GB | 训练数据集 | -| `data/sample_bvh/*.bvh` | 5 MB | 示例动捕文件 | -| `assets/robots/unitree_g1/` | ~52 MB | 训练、sim2sim、重定向和 FK 校验共用的 G1 canonical XML 与 mesh | -| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR 重定向资源、IK 配置和非 canonical 机器人描述 | +checkpoint、数据集和资源包更新后,下载文件大小会变化。下表中的仓库路径才是稳定约定。 + +| 本地路径 | 用途 | +|----------|------| +| `track.onnx` | ONNX 推理模型 | +| `track.pt` | 用于恢复训练的 PyTorch checkpoint | +| `data/datasets/seed/shard_*.h5` | 最小运动数据集;训练前需先预计算 | +| `data/sample_bvh/*.bvh` | 示例动捕文件 | +| 
`assets/robots/unitree_g1/` | 训练、sim2sim、重定向和 FK 校验共用的 G1 canonical XML 与 mesh | +| `teleopit/retargeting/gmr/assets/` | GMR 重定向资源、IK 配置和非 canonical 机器人描述 | ## 资源分组 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index 5c65fef1..f45ab730 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -7,13 +7,15 @@ sidebar_position: 3 ## 下载预构建数据集(推荐) ```bash -python scripts/setup/download_assets.py --only data +python scripts/setup/download_assets.py --only robots data ``` -下载后直接传数据集根目录用于训练: +下载后先生成预计算训练 shard,再把预计算数据集根目录用于训练: ```bash -python train_mimic/scripts/train.py --motion_file data/datasets/seed +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +python train_mimic/scripts/train.py --motion_file data/datasets/seed_precomputed ``` 如需自定义构建,继续阅读下文。 @@ -46,7 +48,7 @@ python train_mimic/scripts/data/build_dataset.py \ ## 自定义构建 -数据主线:`typed source YAML -> preprocess/filter -> minimal HDF5 shards` +数据主线:`typed source YAML -> preprocess/filter -> minimal HDF5 shards -> precomputed training dataset` ```bash python train_mimic/scripts/data/build_dataset.py \ @@ -58,12 +60,17 @@ python train_mimic/scripts/data/build_dataset.py \ ```text data/datasets// └── shard_*.h5 + +data/datasets/_precomputed/ +└── shard_*.h5 ``` - 若 spec 包含 `bvh` 或 `npz` source,完整 dataset builder 会在转换期间使用临时 `clips/` 目录,并在 shard 写入完成后删除。重新 build 不会复用已转换 clips。 - 若 spec 全部是 `pkl` 或 `seed_csv` source,builder 会直接并行产出 shard,默认不写中间 clip 文件 -- 训练会递归发现指定根目录下的 `*.h5` shard,因此可以把多个数据集目录放到同一个父目录下完成合并 -- 训练时只从发现的 shard 加载一个 subset cache,在线派生 FK/速度,同时预加载下一个 cache,并在 PPO rollout barrier 处切换。 +- `build_dataset.py` 只写最小分发数据集,不执行 FK 预计算。 +- `precompute_dataset.py` 会写出独立的训练数据集,里面包含最小运动数据以及预计算的 joint 
velocity 和 body FK/velocity。 +- 训练只接受预计算后的数据集目录。它会递归发现指定根目录下的预计算 `*.h5` shard,因此可以把多个预计算数据集目录放到同一个父目录下完成合并。 +- 训练只会从发现的预计算 shard 中加载 subset cache,异步 staging 下一个 cache,并在 PPO rollout barrier 切换 cache。joint velocity 和 body FK/velocity 不会在训练时计算。 ## YAML spec @@ -104,13 +111,13 @@ sources: ## 转换规则 -所有 source 都会转换成标准训练 shard。每段 clip 会先经过预处理/过滤,再写入 shard: +所有 source 都会转换成标准最小 shard。每段 clip 会先经过预处理/过滤,再写入 shard: - `bvh -> retarget pkl -> npz clip` - `pkl -> npz clip`(或在 pkl-only 数据集中直接 batch 写 shard) - `npz -> validate + copy/reuse` -每个 shard 只保存最小运动数据:`root_pos`、`root_quat_w`、`joint_pos`、`body_names`、`clip_starts`、`clip_lengths` 和 `clip_fps`。Joint velocity 和 body FK/velocity 会在训练加载 cache 时计算。 +每个最小 shard 保存 `root_pos`、`root_quat_w`、`joint_pos`、`body_names`、`clip_starts`、`clip_lengths` 和 `clip_fps`。预计算训练 shard 保存 `joint_pos`、`joint_vel`、`body_pos_w`、`body_quat_w`、`body_lin_vel_w`、`body_ang_vel_w` 以及相同的元数据。如果 `--motion_file` 指向最小数据集而不是预计算训练数据集,训练会立即报错。 ## 常用命令 @@ -132,6 +139,10 @@ python train_mimic/scripts/data/build_dataset.py \ python train_mimic/scripts/data/build_dataset.py \ --spec train_mimic/configs/datasets/twist2.yaml --json +# 从已有最小数据集生成预计算训练数据集 +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/twist2 --outdir data/datasets/twist2_precomputed --jobs 8 --force + # 查看数据集统计 python train_mimic/scripts/data/inspect_dataset.py data/datasets/twist2 ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md index fb9f6b0d..14a8500a 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md @@ -45,7 +45,7 @@ python train_mimic/scripts/data/check_motion_npz_fk.py \ ```bash python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - 
--motion_file data/datasets/ + --motion_file data/datasets/_precomputed ``` 预期:`Mean episode length` 明显大于 1,`error_anchor_pos` 开始下降。 @@ -126,7 +126,7 @@ self.sim.nconmax = 150_000 ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/ \ + --motion_file data/datasets/_precomputed \ --num_envs 1 --num_eval_steps 2000 \ --video --video_length 600 ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index eea5989a..ce7c82af 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -23,6 +23,14 @@ pip install -e '.[train]' python -c "import train_mimic.tasks; print('training OK')" ``` +下载最小 seed 数据集,并生成预计算训练 shard: + +```bash +python scripts/setup/download_assets.py --only robots data +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +``` + ## 训练 ### 冒烟测试 @@ -31,7 +39,7 @@ python -c "import train_mimic.tasks; print('training OK')" python train_mimic/scripts/train.py \ --num_envs 64 \ --max_iterations 100 \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### 完整训练 @@ -40,7 +48,7 @@ python train_mimic/scripts/train.py \ python train_mimic/scripts/train.py \ --num_envs 4096 \ --max_iterations 30000 \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### 多卡训练 @@ -50,7 +58,7 @@ python train_mimic/scripts/train.py \ --gpu_ids 0 1 2 3 \ --num_envs 1024 \ --max_iterations 30000 \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### 多机多卡训练 @@ -67,16 +75,16 @@ torchrun \ train_mimic/scripts/train.py \ --num_envs 1024 \ --max_iterations 1000 \ - --motion_file 
data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` **注意事项:** - 多卡模式下 `--num_envs` 为每张 GPU 的环境数量 - 多机模式下 `--num_envs` 也按每个进程计算,因此总环境数会随 `world_size` 线性增长 - 默认日志工具为 TensorBoard。使用 `--logger wandb` 或 `--logger swanlab` 可选择 W&B 或 SwanLab;项目名默认使用 `experiment_name` -- `--motion_file` 接受数据集根目录或单个 `.h5` shard;shard 会递归发现 -- `--cache_num_clips` 控制当前 HDF5 subset cache 大小;`--cache_swap_interval_steps` 控制在 rollout barrier 切换下一个 subset 的频率 -- `--cache_dataloader_num_workers`、`--cache_dataloader_prefetch_factor` 和 `--cache_dataloader_pin_memory` 用于调节异步 HDF5 cache 加载,不会增加数据集大小 +- `--motion_file` 接受预计算训练数据集根目录或单个预计算 `.h5` shard;shard 会递归发现 +- 如果只有最小分发 shard,先运行 `python train_mimic/scripts/data/precompute_dataset.py --outdir `,再把预计算输出传给训练。 +- `--cache_num_clips` 控制加载到 active subset cache 的预计算 HDF5 motion window 数量;下一个 cache 会异步 staging,并在 rollout barrier 切换。 - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` ## 导出 ONNX @@ -97,7 +105,7 @@ python train_mimic/scripts/save_onnx.py \ ```bash python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed + --motion_file data/datasets/seed_precomputed ``` ### 定量评估 @@ -105,7 +113,7 @@ python train_mimic/scripts/play.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 ``` @@ -114,7 +122,7 @@ python train_mimic/scripts/benchmark.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 \ --video \ --video_length 600 diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index 5a2074f1..a834e634 100644 --- a/tests/test_motion_sampling.py +++ 
b/tests/test_motion_sampling.py @@ -6,8 +6,14 @@ import numpy as np import pytest import torch - -from train_mimic.data.dataset_lib import write_hdf5_motion_shard +import h5py + +from train_mimic.data.dataset_lib import ( + PRECOMPUTED_MOTION_VERSION, + compute_dataset_stats, + write_precomputed_motion_shard, + write_hdf5_motion_shard, +) from train_mimic.tasks.tracking.mdp.commands import MotionCommand, MotionLib @@ -64,11 +70,37 @@ def _write_shard_dir( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths merged["clip_fps"] = np.full(len(clip_dicts), int(clip_dicts[0]["fps"]), dtype=np.int64) - shard_info = write_hdf5_motion_shard(merged, path / "shard_000.h5") - _ = shard_info + shard_path = path / "shard_000.h5" + _write_precomputed_from_merged(shard_path, merged) return path +def _write_precomputed_from_merged(shard_path: Path, merged: dict[str, object]) -> None: + shard_path.parent.mkdir(parents=True, exist_ok=True) + str_dt = h5py.string_dtype(encoding="utf-8") + clip_starts = np.asarray(merged["clip_starts"], dtype=np.int64) + clip_lengths = np.asarray(merged["clip_lengths"], dtype=np.int64) + clip_fps = np.asarray(merged["clip_fps"], dtype=np.int64) + with h5py.File(shard_path, "w") as h5: + h5.attrs["format"] = "teleopit_precomputed_motion_hdf5" + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.create_dataset( + "body_names", + data=np.asarray(merged["body_names"]).astype(str).astype(object), + dtype=str_dt, + ) + for key in ("joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w"): + h5.create_dataset(key, data=np.asarray(merged[key], dtype=np.float32), chunks=True) + h5.create_dataset("clip_starts", data=clip_starts) + h5.create_dataset("clip_lengths", data=clip_lengths) + h5.create_dataset("clip_fps", data=clip_fps) + h5.create_dataset("source_clip_ids", data=np.arange(len(clip_lengths), dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.zeros(len(clip_lengths), 
dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=clip_starts) + h5.create_dataset("source_clip_lengths", data=clip_lengths) + h5.create_dataset("source_clip_fps", data=clip_fps) + + def test_motion_lib_sample_times_respect_window_steps(tmp_path: Path) -> None: motion_path = _write_shard_dir(tmp_path / "motion", [_clip_dict()]) @@ -177,6 +209,111 @@ def test_motion_lib_selects_bodies_by_dataset_names(tmp_path: Path) -> None: assert torch.isfinite(frames["body_pos_w"]).all() +def test_motion_lib_rejects_minimal_motion_shards(tmp_path: Path) -> None: + path = tmp_path / "motion_minimal" + path.mkdir() + clip = _clip_dict() + merged = { + "fps": int(clip["fps"]), + "root_pos": np.asarray(clip["root_pos"]), + "root_quat_w": np.asarray(clip["root_quat_w"]), + "joint_pos": np.asarray(clip["joint_pos"]), + "body_names": np.asarray(clip["body_names"]), + "clip_starts": np.asarray([0], dtype=np.int64), + "clip_lengths": np.asarray([np.asarray(clip["joint_pos"]).shape[0]], dtype=np.int64), + "clip_fps": np.asarray([int(clip["fps"])], dtype=np.int64), + } + write_hdf5_motion_shard(merged, path / "shard_000.h5") + + with pytest.raises(FileNotFoundError, match="precomputed Teleopit"): + MotionLib( + str(path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + +def test_precomputed_stats_preserve_source_clip_ids_for_windowed_shards(tmp_path: Path) -> None: + clip = _clip_dict(num_frames=12, fps=30) + shard_path = tmp_path / "motion" / "shard_000.h5" + shard_path.parent.mkdir(parents=True) + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(shard_path, "w") as h5: + h5.attrs["format"] = "teleopit_precomputed_motion_hdf5" + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.create_dataset("body_names", data=np.asarray(clip["body_names"]).astype(object), dtype=str_dt) + for key in ("joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w"): + h5.create_dataset(key, 
data=np.asarray(clip[key], dtype=np.float32), chunks=True) + h5.create_dataset("clip_starts", data=np.asarray([0, 4, 8], dtype=np.int64)) + h5.create_dataset("clip_lengths", data=np.asarray([6, 6, 4], dtype=np.int64)) + h5.create_dataset("clip_fps", data=np.asarray([30, 30, 30], dtype=np.int64)) + h5.create_dataset("source_clip_ids", data=np.asarray([0, 0, 0], dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.asarray([0, 4, 8], dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_clip_lengths", data=np.asarray([12], dtype=np.int64)) + h5.create_dataset("source_clip_fps", data=np.asarray([30], dtype=np.int64)) + + stats = compute_dataset_stats(shard_path.parent, precomputed=True) + + assert stats["windows"] == 3 + assert stats["source_clips"] == 1 + assert stats["duration_s"] == pytest.approx(12 / 30) + + +def test_write_precomputed_motion_shard_copies_window_source_metadata( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeMotionFkExtractor: + def __init__(self, model_path: object = None) -> None: + del model_path + + def extract( + self, + root_pos: np.ndarray, + root_quat_w: np.ndarray, + joint_pos: np.ndarray, + body_names: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + del joint_pos + bodies = int(len(body_names)) + body_pos_w = np.repeat(np.asarray(root_pos, dtype=np.float32)[:, None, :], bodies, axis=1) + body_quat_w = np.repeat(np.asarray(root_quat_w, dtype=np.float32)[:, None, :], bodies, axis=1) + return body_pos_w, body_quat_w + + monkeypatch.setattr("train_mimic.data.motion_fk.MotionFkExtractor", _FakeMotionFkExtractor) + clip = _clip_dict(num_frames=12, fps=30) + merged = { + "fps": int(clip["fps"]), + "root_pos": np.asarray(clip["root_pos"], dtype=np.float32), + "root_quat_w": np.asarray(clip["root_quat_w"], dtype=np.float32), + "joint_pos": np.asarray(clip["joint_pos"], dtype=np.float32), + "body_names": 
np.asarray(clip["body_names"]).astype(str), + "clip_starts": np.asarray([0], dtype=np.int64), + "clip_lengths": np.asarray([12], dtype=np.int64), + "clip_fps": np.asarray([30], dtype=np.int64), + } + minimal_path = tmp_path / "minimal" / "shard_000.h5" + precomputed_path = tmp_path / "precomputed" / "shard_000.h5" + write_hdf5_motion_shard( + merged, + minimal_path, + max_window_frames=6, + overlap_frames=2, + ) + + write_precomputed_motion_shard(minimal_path, precomputed_path) + + with h5py.File(precomputed_path, "r") as h5: + assert h5.attrs["version"] == PRECOMPUTED_MOTION_VERSION + assert h5["clip_starts"][()].tolist() == [0, 4, 6] + assert h5["clip_lengths"][()].tolist() == [6, 6, 6] + assert h5["source_clip_ids"][()].tolist() == [0, 0, 0] + assert h5["source_start_frames"][()].tolist() == [0, 4, 6] + assert h5["source_clip_starts"][()].tolist() == [0] + assert h5["source_clip_lengths"][()].tolist() == [12] + + def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_path: Path) -> None: motion_path = _write_shard_dir(tmp_path / "motion_windowed", [_clip_dict()]) @@ -255,7 +392,8 @@ def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: merged["clip_starts"] = np.asarray([0], dtype=np.int64) merged["clip_lengths"] = np.asarray([np.asarray(clip_bad["joint_pos"]).shape[0]], dtype=np.int64) merged["clip_fps"] = np.asarray([int(clip_bad["fps"])], dtype=np.int64) - write_hdf5_motion_shard(merged, motion_path / "shard_001.h5") + bad_shard = motion_path / "shard_001.h5" + _write_precomputed_from_merged(bad_shard, merged) with pytest.raises(ValueError, match="body_names"): MotionLib( diff --git a/tests/test_train_script.py b/tests/test_train_script.py index ef82d381..d0361f41 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -8,11 +8,12 @@ import types from pathlib import Path +import h5py import numpy as np import pytest from train_mimic.app import DEFAULT_TASK, validate_checkpoint_path, 
validate_motion_file -from train_mimic.data.dataset_lib import write_hdf5_motion_shard +from train_mimic.data.dataset_lib import PRECOMPUTED_MOTION_VERSION from train_mimic.scripts import train from train_mimic.tasks.tracking.config.rl import make_general_tracking_ppo_runner_cfg @@ -34,7 +35,7 @@ def _args(**overrides: object) -> argparse.Namespace: "seed": 42, "logger": "tensorboard", "experiment_name": None, - "motion_file": "data/datasets/twist2", + "motion_file": "data/datasets/seed_precomputed", "robot_xml": None, "resume": None, "sampling_mode": None, @@ -263,19 +264,33 @@ def test_make_g1_training_robot_cfg_uses_requested_xml() -> None: def test_validate_motion_file_accepts_shard_directories(tmp_path: Path) -> None: num_frames = 3 - write_hdf5_motion_shard( - { - "fps": 30, - "root_pos": np.zeros((num_frames, 3), dtype=np.float32), - "root_quat_w": np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (num_frames, 1)), - "joint_pos": np.zeros((num_frames, 29), dtype=np.float32), - "body_names": np.asarray(["pelvis"], dtype=str), - "clip_starts": np.asarray([0], dtype=np.int64), - "clip_lengths": np.asarray([num_frames], dtype=np.int64), - "clip_fps": np.asarray([30], dtype=np.int64), - }, - tmp_path / "shard_000.h5", - ) + shard_path = tmp_path / "shard_000.h5" + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(shard_path, "w") as h5: + h5.attrs["format"] = "teleopit_precomputed_motion_hdf5" + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.create_dataset("body_names", data=np.asarray(["pelvis"], dtype=object), dtype=str_dt) + h5.create_dataset("joint_pos", data=np.zeros((num_frames, 29), dtype=np.float32), chunks=True) + h5.create_dataset("joint_vel", data=np.zeros((num_frames, 29), dtype=np.float32), chunks=True) + h5.create_dataset("body_pos_w", data=np.zeros((num_frames, 1, 3), dtype=np.float32), chunks=True) + h5.create_dataset( + "body_quat_w", + data=np.tile( + np.asarray([[[1.0, 0.0, 0.0, 0.0]]], dtype=np.float32), + 
(num_frames, 1, 1), + ), + chunks=True, + ) + h5.create_dataset("body_lin_vel_w", data=np.zeros((num_frames, 1, 3), dtype=np.float32), chunks=True) + h5.create_dataset("body_ang_vel_w", data=np.zeros((num_frames, 1, 3), dtype=np.float32), chunks=True) + h5.create_dataset("clip_starts", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("clip_lengths", data=np.asarray([num_frames], dtype=np.int64)) + h5.create_dataset("clip_fps", data=np.asarray([30], dtype=np.int64)) + h5.create_dataset("source_clip_ids", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_clip_lengths", data=np.asarray([num_frames], dtype=np.int64)) + h5.create_dataset("source_clip_fps", data=np.asarray([30], dtype=np.int64)) validate_motion_file(str(tmp_path)) diff --git a/train_mimic/app.py b/train_mimic/app.py index 7c7033a6..9d08250f 100644 --- a/train_mimic/app.py +++ b/train_mimic/app.py @@ -11,20 +11,22 @@ GENERAL_TRACKING_TASK, SUPPORTED_TASKS, ) -from train_mimic.data.dataset_lib import find_motion_shards +from train_mimic.data.dataset_lib import find_precomputed_motion_shards, validate_precomputed_motion_dataset DEFAULT_TASK = GENERAL_TRACKING_TASK def validate_motion_file(motion_file: str) -> None: try: - find_motion_shards(Path(motion_file)) + find_precomputed_motion_shards(Path(motion_file)) except FileNotFoundError as exc: raise FileNotFoundError( f"Motion dataset not found: {motion_file}. Provide --motion_file pointing " - "to a dataset root directory containing Teleopit shard_*.h5 files " - f"(recursively allowed). Example: {DEFAULT_TRAIN_MOTION_FILE}" + "to a precomputed training dataset root produced by " + "train_mimic/scripts/data/precompute_dataset.py. 
Example: " + f"{DEFAULT_TRAIN_MOTION_FILE}" ) from exc + validate_precomputed_motion_dataset(Path(motion_file)) def validate_checkpoint_path(checkpoint_path: str) -> None: diff --git a/train_mimic/data/dataset_lib.py b/train_mimic/data/dataset_lib.py index ae967be4..45696d60 100644 --- a/train_mimic/data/dataset_lib.py +++ b/train_mimic/data/dataset_lib.py @@ -30,6 +30,13 @@ "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w", ] +PRECOMPUTED_MOTION_ARRAY_KEYS = [ + "joint_vel", + "body_pos_w", + "body_quat_w", + "body_lin_vel_w", + "body_ang_vel_w", +] MINIMAL_MOTION_ARRAY_KEYS = ["root_pos", "root_quat_w", "joint_pos"] FULL_CLIP_ARRAY_KEYS = [ "root_pos", @@ -42,6 +49,9 @@ "body_ang_vel_w", ] HDF5_DATASET_VERSION = 1 +PRECOMPUTED_MOTION_VERSION = 2 +MINIMAL_HDF5_FORMAT = "teleopit_motion_hdf5" +PRECOMPUTED_HDF5_FORMAT = "teleopit_precomputed_motion_hdf5" DEFAULT_HDF5_MAX_WINDOW_FRAMES = 512 DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES = 64 @@ -499,7 +509,7 @@ def write_hdf5_motion_shard( output_path.parent.mkdir(parents=True, exist_ok=True) str_dt = h5py.string_dtype(encoding="utf-8") with h5py.File(output_path, "w") as h5: - h5.attrs["format"] = "teleopit_motion_hdf5" + h5.attrs["format"] = MINIMAL_HDF5_FORMAT h5.attrs["version"] = HDF5_DATASET_VERSION h5.attrs["fps"] = fps h5.attrs["max_window_frames"] = int(max_window_frames) @@ -539,6 +549,232 @@ def read_hdf5_body_names(path: Path) -> list[str]: ] +def _read_hdf5_string_list(h5: h5py.File, key: str) -> list[str]: + return [ + str(name.decode("utf-8") if isinstance(name, bytes) else name) + for name in h5[key][()] + ] + + +def validate_precomputed_motion_shard( + shard_path: str | Path, +) -> None: + """Validate that a shard is a standalone precomputed training shard.""" + shard = Path(shard_path).expanduser().resolve() + if not shard.is_file(): + raise FileNotFoundError(f"precomputed motion shard not found: {shard}") + + with h5py.File(shard, "r") as h5: + if h5.attrs.get("format") 
!= PRECOMPUTED_HDF5_FORMAT: + raise ValueError( + f"motion dataset must be precomputed before training: {shard}. " + "Run: python train_mimic/scripts/data/precompute_dataset.py " + " --outdir , then pass " + "--motion_file ." + ) + if int(h5.attrs.get("version", 0)) != PRECOMPUTED_MOTION_VERSION: + raise ValueError( + f"unsupported precomputed motion version in {shard}: " + f"{h5.attrs.get('version')}. Regenerate with precompute_dataset.py." + ) + missing = [ + key for key in [ + "joint_pos", + *PRECOMPUTED_MOTION_ARRAY_KEYS, + "body_names", + "clip_starts", + "clip_lengths", + "clip_fps", + "source_clip_ids", + "source_start_frames", + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", + ] + if key not in h5 + ] + if missing: + raise ValueError( + f"precomputed motion shard {shard} missing required datasets: {missing}. " + "Regenerate with precompute_dataset.py." + ) + + body_names = _read_hdf5_string_list(h5, "body_names") + frames = int(h5["joint_pos"].shape[0]) + num_bodies = len(body_names) + expected_shapes = { + "joint_pos": (frames, NUM_ACTIONS), + "joint_vel": (frames, NUM_ACTIONS), + "body_pos_w": (frames, num_bodies, 3), + "body_quat_w": (frames, num_bodies, 4), + "body_lin_vel_w": (frames, num_bodies, 3), + "body_ang_vel_w": (frames, num_bodies, 3), + } + for key, shape in expected_shapes.items(): + if h5[key].shape != shape: + raise ValueError( + f"precomputed motion shard {shard} dataset {key} shape mismatch: " + f"expected {shape}, got {h5[key].shape}. Regenerate with precompute_dataset.py." + ) + windows = int(h5["clip_starts"].shape[0]) + for key in ["clip_lengths", "clip_fps", "source_clip_ids", "source_start_frames"]: + if h5[key].shape != (windows,): + raise ValueError( + f"precomputed motion shard {shard} dataset {key} shape mismatch: " + f"expected {(windows,)}, got {h5[key].shape}. Regenerate with precompute_dataset.py." 
+ ) + source_clips = int(h5["source_clip_starts"].shape[0]) + for key in ["source_clip_lengths", "source_clip_fps"]: + if h5[key].shape != (source_clips,): + raise ValueError( + f"precomputed motion shard {shard} dataset {key} shape mismatch: " + f"expected {(source_clips,)}, got {h5[key].shape}. Regenerate with precompute_dataset.py." + ) + + +def validate_precomputed_motion_dataset(dataset_dir_or_shard: str | Path) -> None: + for shard_path in find_precomputed_motion_shards(dataset_dir_or_shard): + validate_precomputed_motion_shard(shard_path) + + +def write_precomputed_motion_shard( + shard_path: str | Path, + output_path: str | Path, + *, + force: bool = False, + model_path: str | Path | None = None, +) -> dict[str, Any]: + """Write one standalone precomputed training shard from a minimal source shard.""" + from train_mimic.data.motion_fk import ( + MotionFkExtractor, + compute_body_velocities, + finite_diff_velocity, + ) + + source_path = Path(shard_path).expanduser().resolve() + output = Path(output_path).expanduser().resolve() + if output.is_file() and not force: + validate_precomputed_motion_shard(output) + return { + "shard": str(source_path), + "output": str(output), + "status": "existing", + } + + extractor = MotionFkExtractor(model_path) + with h5py.File(source_path, "r") as source: + if source.attrs.get("format") != MINIMAL_HDF5_FORMAT: + raise ValueError( + f"precompute input must be a minimal Teleopit motion shard, got {source_path}" + ) + required = [ + "root_pos", + "root_quat_w", + "joint_pos", + "body_names", + "clip_starts", + "clip_lengths", + "clip_fps", + "source_clip_ids", + "source_start_frames", + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", + ] + missing = [key for key in required if key not in source] + if missing: + raise ValueError( + f"HDF5 shard {source_path} missing required datasets for precompute: {missing}. " + "Rebuild the dataset with the current HDF5 writer." 
+ ) + + root_pos = np.asarray(source["root_pos"], dtype=np.float32) + root_quat_w = np.asarray(source["root_quat_w"], dtype=np.float32) + joint_pos = np.asarray(source["joint_pos"], dtype=np.float32) + body_names = np.asarray(_read_hdf5_string_list(source, "body_names"), dtype=str) + source_starts = np.asarray(source["source_clip_starts"], dtype=np.int64) + source_lengths = np.asarray(source["source_clip_lengths"], dtype=np.int64) + source_fps = np.asarray(source["source_clip_fps"], dtype=np.int64) + + frames = int(joint_pos.shape[0]) + num_bodies = int(body_names.shape[0]) + joint_vel = np.empty_like(joint_pos, dtype=np.float32) + body_pos_w = np.empty((frames, num_bodies, 3), dtype=np.float32) + body_quat_w = np.empty((frames, num_bodies, 4), dtype=np.float32) + body_lin_vel_w = np.empty((frames, num_bodies, 3), dtype=np.float32) + body_ang_vel_w = np.empty((frames, num_bodies, 3), dtype=np.float32) + + for start, length, fps in zip(source_starts, source_lengths, source_fps): + start_i = int(start) + length_i = int(length) + fps_i = int(fps) + if length_i <= 0: + raise ValueError(f"invalid source clip length {length_i} in {source_path}") + if fps_i <= 0: + raise ValueError(f"invalid source clip fps {fps_i} in {source_path}") + sl = slice(start_i, start_i + length_i) + dt = 1.0 / float(fps_i) + cur_body_pos_w, cur_body_quat_w = extractor.extract( + root_pos[sl], + root_quat_w[sl], + joint_pos[sl], + body_names, + ) + cur_body_lin_vel_w, cur_body_ang_vel_w = compute_body_velocities( + cur_body_pos_w, + cur_body_quat_w, + dt, + ) + joint_vel[sl] = finite_diff_velocity(joint_pos[sl], dt) + body_pos_w[sl] = cur_body_pos_w + body_quat_w[sl] = cur_body_quat_w + body_lin_vel_w[sl] = cur_body_lin_vel_w + body_ang_vel_w[sl] = cur_body_ang_vel_w + + clip_starts = np.asarray(source["clip_starts"], dtype=np.int64) + clip_lengths = np.asarray(source["clip_lengths"], dtype=np.int64) + clip_fps = np.asarray(source["clip_fps"], dtype=np.int64) + source_clip_ids = 
np.asarray(source["source_clip_ids"], dtype=np.int64) + source_start_frames = np.asarray(source["source_start_frames"], dtype=np.int64) + + output.parent.mkdir(parents=True, exist_ok=True) + tmp_path = output.with_suffix(f"{output.suffix}.tmp") + if tmp_path.exists(): + tmp_path.unlink() + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(tmp_path, "w") as h5: + h5.attrs["format"] = PRECOMPUTED_HDF5_FORMAT + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.attrs["source_shard"] = str(source_path) + h5.attrs["created_at"] = utc_now_iso() + h5.create_dataset("body_names", data=body_names.astype(object), dtype=str_dt) + h5.create_dataset("joint_pos", data=joint_pos, chunks=True) + h5.create_dataset("joint_vel", data=joint_vel, chunks=True) + h5.create_dataset("body_pos_w", data=body_pos_w, chunks=True) + h5.create_dataset("body_quat_w", data=body_quat_w, chunks=True) + h5.create_dataset("body_lin_vel_w", data=body_lin_vel_w, chunks=True) + h5.create_dataset("body_ang_vel_w", data=body_ang_vel_w, chunks=True) + h5.create_dataset("clip_starts", data=clip_starts) + h5.create_dataset("clip_lengths", data=clip_lengths) + h5.create_dataset("clip_fps", data=clip_fps) + h5.create_dataset("source_clip_ids", data=source_clip_ids) + h5.create_dataset("source_start_frames", data=source_start_frames) + h5.create_dataset("source_clip_starts", data=source_starts) + h5.create_dataset("source_clip_lengths", data=source_lengths) + h5.create_dataset("source_clip_fps", data=source_fps) + tmp_path.replace(output) + validate_precomputed_motion_shard(output) + + return { + "shard": str(source_path), + "output": str(output), + "status": "written", + "frames": frames, + "bodies": num_bodies, + "arrays": ["joint_pos", *PRECOMPUTED_MOTION_ARRAY_KEYS], + } + + def read_motion_clip(path: Path, clip_index: int) -> dict[str, Any]: """Read one source clip from a current HDF5 motion shard path. 
@@ -605,8 +841,12 @@ def read_hdf5_source_clip(path: Path, clip_index: int) -> dict[str, Any]: } -def find_motion_shards(dataset_dir: str | Path) -> list[Path]: - """Recursively find Teleopit HDF5 motion shards under a root directory.""" +def _find_hdf5_shards_by_format( + dataset_dir: str | Path, + *, + expected_format: str, + label: str, +) -> list[Path]: root = Path(dataset_dir).expanduser().resolve() if root.is_file(): candidates = [root] @@ -619,17 +859,40 @@ def find_motion_shards(dataset_dir: str | Path) -> list[Path]: for path in candidates: try: with h5py.File(path, "r") as h5: - if h5.attrs.get("format") == "teleopit_motion_hdf5": + if h5.attrs.get("format") == expected_format: shards.append(path) except OSError: continue if not shards: - raise FileNotFoundError(f"no Teleopit HDF5 motion shards found under {dataset_dir}") + raise FileNotFoundError(f"no {label} HDF5 motion shards found under {dataset_dir}") return shards -def compute_dataset_stats(dataset_dir: str | Path) -> dict[str, Any]: - shards = find_motion_shards(dataset_dir) +def find_motion_shards(dataset_dir: str | Path) -> list[Path]: + """Recursively find minimal Teleopit HDF5 motion shards under a root directory.""" + return _find_hdf5_shards_by_format( + dataset_dir, + expected_format=MINIMAL_HDF5_FORMAT, + label="minimal Teleopit", + ) + + +def find_precomputed_motion_shards(dataset_dir: str | Path) -> list[Path]: + """Recursively find precomputed Teleopit HDF5 training shards.""" + return _find_hdf5_shards_by_format( + dataset_dir, + expected_format=PRECOMPUTED_HDF5_FORMAT, + label="precomputed Teleopit", + ) + + +def compute_dataset_stats( + dataset_dir: str | Path, + *, + precomputed: bool = False, +) -> dict[str, Any]: + shards = find_precomputed_motion_shards(dataset_dir) if precomputed else find_motion_shards(dataset_dir) + array_keys = MOTION_ARRAY_KEYS if precomputed else MINIMAL_MOTION_ARRAY_KEYS total_windows = 0 total_frames = 0 total_duration_s = 0.0 @@ -642,11 +905,15 @@ def 
compute_dataset_stats(dataset_dir: str | Path) -> dict[str, Any]: with h5py.File(shard_path, "r") as h5: missing = [ key for key in [ - *MINIMAL_MOTION_ARRAY_KEYS, + *array_keys, "body_names", "clip_starts", "clip_lengths", "clip_fps", + "source_clip_ids", + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", ] if key not in h5 ] @@ -661,22 +928,14 @@ def compute_dataset_stats(dataset_dir: str | Path) -> dict[str, Any]: ) lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) fps_arr = np.asarray(h5["clip_fps"], dtype=np.int64) - source_ids = ( - np.asarray(h5["source_clip_ids"], dtype=np.int64) - if "source_clip_ids" in h5 - else np.arange(lengths.shape[0], dtype=np.int64) - ) + source_ids = np.asarray(h5["source_clip_ids"], dtype=np.int64) fps_values.update(int(v) for v in np.unique(fps_arr)) windows = int(lengths.shape[0]) frames = int(np.asarray(h5["joint_pos"]).shape[0]) source_clips = int(len(np.unique(source_ids))) - if "source_clip_lengths" in h5 and "source_clip_fps" in h5: - source_lengths = np.asarray(h5["source_clip_lengths"], dtype=np.float64) - source_fps = np.asarray(h5["source_clip_fps"], dtype=np.float64) - shard_duration_s = float(np.sum(source_lengths / np.maximum(source_fps, 1.0))) - else: - shard_fps = int(h5.attrs.get("fps", fps_arr[0] if fps_arr.shape[0] else 1)) - shard_duration_s = float(frames / max(shard_fps, 1)) + source_lengths = np.asarray(h5["source_clip_lengths"], dtype=np.float64) + source_fps = np.asarray(h5["source_clip_fps"], dtype=np.float64) + shard_duration_s = float(np.sum(source_lengths / np.maximum(source_fps, 1.0))) total_windows += windows total_frames += frames total_duration_s += shard_duration_s @@ -692,8 +951,8 @@ def compute_dataset_stats(dataset_dir: str | Path) -> dict[str, Any]: }) return { - "format": "teleopit_motion_hdf5", - "version": HDF5_DATASET_VERSION, + "format": PRECOMPUTED_HDF5_FORMAT if precomputed else MINIMAL_HDF5_FORMAT, + "version": PRECOMPUTED_MOTION_VERSION if precomputed else 
HDF5_DATASET_VERSION, "root": str(Path(dataset_dir).expanduser().resolve()), "shards": len(shards), "windows": total_windows, diff --git a/train_mimic/scripts/benchmark.py b/train_mimic/scripts/benchmark.py index 005b243f..50c59f1a 100644 --- a/train_mimic/scripts/benchmark.py +++ b/train_mimic/scripts/benchmark.py @@ -10,7 +10,7 @@ # Benchmark only (no video) python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_tracking/.../model_30000.pt \ - --motion_file data/datasets/twist2 \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 # Single video (one continuous clip) @@ -40,7 +40,7 @@ validate_checkpoint_path, validate_motion_file, ) -from train_mimic.data.dataset_lib import find_motion_shards +from train_mimic.data.dataset_lib import find_precomputed_motion_shards from teleopit.debug.rollout_trace import RolloutTraceWriter @@ -144,7 +144,7 @@ def _stats(values: list[float]) -> dict[str, float]: def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Benchmark G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True) - parser.add_argument("--motion_file", type=str, required=True, help="Path to dataset root containing Teleopit shard_*.h5 files") + parser.add_argument("--motion_file", type=str, required=True, help="Path to precomputed training dataset root containing Teleopit shard_*.h5 files") parser.add_argument("--num_envs", type=int, default=1) parser.add_argument("--num_eval_steps", type=int, default=2000, help="Number of rollout steps for evaluation (default: 2000)") @@ -176,7 +176,7 @@ def parse_args() -> argparse.Namespace: def _load_motion_dir_video_metadata(motion_dir: str) -> tuple[float, int]: clip_fps: float | None = None max_clip_frames = 0 - for shard_path in find_motion_shards(motion_dir): + for shard_path in find_precomputed_motion_shards(motion_dir): with h5py.File(shard_path, "r") as h5: fps_arr = np.asarray(h5["clip_fps"], dtype=np.float32) if fps_arr.size == 0: diff 
--git a/train_mimic/scripts/data/precompute_dataset.py b/train_mimic/scripts/data/precompute_dataset.py new file mode 100644 index 00000000..323cf009 --- /dev/null +++ b/train_mimic/scripts/data/precompute_dataset.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +"""Precompute derived FK and velocity arrays for minimal Teleopit HDF5 shards.""" + +from __future__ import annotations + +import argparse +import json +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Any + +from train_mimic.data.dataset_lib import ( + find_motion_shards, + write_precomputed_motion_shard, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Precompute FK and velocity arrays for a minimal Teleopit HDF5 dataset." + ) + parser.add_argument("dataset", type=str, help="Dataset root directory or a single .h5 shard") + parser.add_argument( + "--outdir", + type=str, + default=None, + help="Output training dataset directory. 
Defaults to _precomputed.", + ) + parser.add_argument("--force", action="store_true", help="Overwrite existing precomputed shards") + parser.add_argument("--jobs", type=int, default=1, help="Number of shard-level worker processes") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Optional MuJoCo XML override for FK precompute", + ) + parser.add_argument("--json", action="store_true", help="Print final report JSON") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + if args.jobs <= 0: + raise ValueError(f"--jobs must be positive, got {args.jobs}") + + shard_paths = find_motion_shards(args.dataset) + source_root = Path(args.dataset).expanduser().resolve() + if source_root.is_file(): + default_outdir = source_root.parent.with_name(f"{source_root.parent.name}_precomputed") + else: + default_outdir = source_root.with_name(f"{source_root.name}_precomputed") + outdir = Path(args.outdir).expanduser().resolve() if args.outdir is not None else default_outdir + outdir.mkdir(parents=True, exist_ok=True) + + def output_path_for(source_path: Path) -> Path: + rel = Path(source_path.name) if source_root.is_file() else source_path.relative_to(source_root) + return outdir / rel + + results: list[dict[str, Any]] = [] + + if args.jobs == 1: + for shard_path in shard_paths: + result = write_precomputed_motion_shard( + shard_path, + output_path_for(shard_path), + force=args.force, + model_path=args.model_path, + ) + results.append(result) + print(f"[{result['status'].upper()}] {result['output']}") + else: + with ProcessPoolExecutor(max_workers=args.jobs) as executor: + futures = { + executor.submit( + write_precomputed_motion_shard, + shard_path, + output_path_for(shard_path), + force=args.force, + model_path=args.model_path, + ): shard_path + for shard_path in shard_paths + } + for future in as_completed(futures): + result = future.result() + results.append(result) + print(f"[{result['status'].upper()}] {result['output']}") + + 
report = { + "dataset": str(Path(args.dataset).expanduser().resolve()), + "outdir": str(outdir), + "shards": len(shard_paths), + "written": sum(1 for item in results if item["status"] == "written"), + "existing": sum(1 for item in results if item["status"] == "existing"), + "results": sorted(results, key=lambda item: item["shard"]), + } + print( + f"[DONE] shards={report['shards']} written={report['written']} " + f"existing={report['existing']}" + ) + if args.json: + print(json.dumps(report, ensure_ascii=True, indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/train_mimic/scripts/play.py b/train_mimic/scripts/play.py index 26546f22..8ad5a9d7 100644 --- a/train_mimic/scripts/play.py +++ b/train_mimic/scripts/play.py @@ -9,18 +9,18 @@ # Native window python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2 + --motion_file data/datasets/seed_precomputed # Browser viewer (no display required) python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2 \ + --motion_file data/datasets/seed_precomputed \ --viewer viser # Record video instead of interactive viewer python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2 \ + --motion_file data/datasets/seed_precomputed \ --video """ @@ -44,7 +44,7 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Play trained G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") - parser.add_argument("--motion_file", type=str, required=True, help="Path to dataset root containing Teleopit shard_*.h5 files") + parser.add_argument("--motion_file", type=str, required=True, help="Path to precomputed training dataset root containing Teleopit shard_*.h5 files") 
parser.add_argument("--num_envs", type=int, default=1) parser.add_argument( "--viewer", type=str, default="native", choices=["native", "viser"], diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 3a918cee..dac2c0ac 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -4,30 +4,30 @@ Usage: python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 18000 \ - --motion_file data/datasets/twist2 + --motion_file data/datasets/seed_precomputed # Quick verification python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets/twist2 + --motion_file data/datasets/seed_precomputed # With W&B logging python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 30000 \ - --motion_file data/datasets/twist2 \ + --motion_file data/datasets/seed_precomputed \ --logger wandb # With SwanLab logging python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 30000 \ - --motion_file data/datasets/twist2 \ + --motion_file data/datasets/seed_precomputed \ --logger swanlab # Resume for additional iterations python train_mimic/scripts/train.py \ --resume logs/rsl_rl/g1_general_tracking//model_12000.pt \ --max_iterations 18000 \ - --motion_file data/datasets/twist2 + --motion_file data/datasets/seed_precomputed """ from __future__ import annotations @@ -78,7 +78,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ) parser.add_argument("--experiment_name", type=str, default=None) parser.add_argument("--motion_file", type=str, default=None, - help="Dataset root containing Teleopit shard_*.h5 files, searched recursively") + help="Precomputed training dataset root containing Teleopit shard_*.h5 files, searched recursively") parser.add_argument( "--robot_xml", type=str, diff --git a/train_mimic/tasks/tracking/config/constants.py b/train_mimic/tasks/tracking/config/constants.py index cdba0dc9..17560fa5 100644 --- 
a/train_mimic/tasks/tracking/config/constants.py +++ b/train_mimic/tasks/tracking/config/constants.py @@ -1,6 +1,6 @@ """Public constants for supported tracking tasks.""" -DEFAULT_TRAIN_MOTION_FILE = "data/datasets/twist2" +DEFAULT_TRAIN_MOTION_FILE = "data/datasets/seed_precomputed" GENERAL_TRACKING_TASK = "General-Tracking-G1" GENERAL_TRACKING_EXPERIMENT_NAME = "g1_general_tracking" diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 5ec1c017..f695db84 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -17,10 +17,10 @@ MOTION_ARRAY_KEYS, compute_clip_sample_ranges, compute_dataset_stats, - find_motion_shards, + find_precomputed_motion_shards, parse_window_steps, + validate_precomputed_motion_shard, ) -from train_mimic.data.motion_fk import MotionFkExtractor, compute_body_velocities, finite_diff_velocity from mjlab.managers import CommandTerm, CommandTermCfg from mjlab.utils.lab_api.math import ( @@ -140,15 +140,12 @@ def __init__( refs: list[_Hdf5ClipRef], shard_paths: list[Path], body_idx_np: np.ndarray, - body_names: np.ndarray, window_steps: tuple[int, ...], ) -> None: self.refs = refs self._shard_paths = shard_paths self.body_idx_np = body_idx_np - self.body_names = body_names self.window_steps = window_steps - self._fk_extractor = MotionFkExtractor() self._h5_handles: dict[int, h5py.File] = {} def __len__(self) -> int: @@ -158,30 +155,27 @@ def __getitem__(self, index: int) -> _MotionClipSample: ref = self.refs[int(index)] sl = slice(ref.start, ref.start + ref.length) h5 = self._h5_handle(ref.shard_index) - root_pos = np.asarray(h5["root_pos"][sl], dtype=np.float32) - root_quat_w = np.asarray(h5["root_quat_w"][sl], dtype=np.float32) joint_pos = np.asarray(h5["joint_pos"][sl], dtype=np.float32) - dt = 1.0 / float(ref.fps) - body_pos_w, body_quat_w = self._fk_extractor.extract( - root_pos, - root_quat_w, - joint_pos, - self.body_names, - ) - 
body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) - joint_vel = finite_diff_velocity(joint_pos, dt) sample_starts, sample_ends = compute_clip_sample_ranges( np.asarray([ref.length], dtype=np.int64), window_steps=self.window_steps, ) tensors = { "joint_pos": torch.from_numpy(joint_pos), - "joint_vel": torch.from_numpy(joint_vel), - "body_pos_w": torch.from_numpy(body_pos_w[:, self.body_idx_np]), - "body_quat_w": torch.from_numpy(body_quat_w[:, self.body_idx_np]), - "body_lin_vel_w": torch.from_numpy(body_lin_vel_w[:, self.body_idx_np]), - "body_ang_vel_w": torch.from_numpy(body_ang_vel_w[:, self.body_idx_np]), + "joint_vel": torch.from_numpy(np.asarray(h5["joint_vel"][sl], dtype=np.float32)), + "body_pos_w": torch.from_numpy( + np.asarray(h5["body_pos_w"][sl], dtype=np.float32)[:, self.body_idx_np] + ), + "body_quat_w": torch.from_numpy( + np.asarray(h5["body_quat_w"][sl], dtype=np.float32)[:, self.body_idx_np] + ), + "body_lin_vel_w": torch.from_numpy( + np.asarray(h5["body_lin_vel_w"][sl], dtype=np.float32)[:, self.body_idx_np] + ), + "body_ang_vel_w": torch.from_numpy( + np.asarray(h5["body_ang_vel_w"][sl], dtype=np.float32)[:, self.body_idx_np] + ), } return _MotionClipSample( tensors=tensors, @@ -263,10 +257,11 @@ def __init__( self._device = torch.device(device) self._copy_stream: torch.cuda.Stream | None = None self._next_ready_event: torch.cuda.Event | None = None - self._shard_paths = find_motion_shards(motion_dir) - stats = compute_dataset_stats(motion_dir) + self._shard_paths = find_precomputed_motion_shards(motion_dir) + stats = compute_dataset_stats(motion_dir, precomputed=True) self.body_names = np.asarray(stats["body_names"], dtype=str) - self._fk_extractor = MotionFkExtractor() + for shard_path in self._shard_paths: + validate_precomputed_motion_shard(shard_path) _LOG.info( "Motion dataset: root=%s shards=%d windows=%d source_clips=%d frames=%d fps=%s", motion_dir, @@ -329,7 +324,6 @@ def __init__( refs=self.refs, 
shard_paths=self._shard_paths, body_idx_np=self.body_idx_np, - body_names=self.body_names, window_steps=self.window_steps, ) self._sampler = _WeightedInfiniteClipBatchSampler( @@ -409,7 +403,7 @@ def _wait_next_ready(self) -> None: torch.cuda.current_stream(self._device).wait_event(self._next_ready_event) self._next_ready_event = None - def _materialize_legacy_batch(self, global_ids: torch.Tensor) -> _MotionBatch: + def _materialize_batch_by_global_ids(self, global_ids: torch.Tensor) -> _MotionBatch: samples = [self._dataset[int(idx)] for idx in global_ids.tolist()] batch = _collate_motion_clips(samples) return self._stage_batch(batch, wait=True) @@ -419,10 +413,10 @@ def _sample_global_ids(self) -> torch.Tensor: return torch.tensor(ids, dtype=torch.long) def _load_random_batch(self) -> _MotionBatch: - return self._materialize_legacy_batch(self._sample_global_ids()) + return self._materialize_batch_by_global_ids(self._sample_global_ids()) def _load_batch(self, global_ids: torch.Tensor) -> _MotionBatch: - return self._materialize_legacy_batch(global_ids.cpu().to(dtype=torch.long)) + return self._materialize_batch_by_global_ids(global_ids.cpu().to(dtype=torch.long)) def advance(self) -> None: start = time.perf_counter() @@ -483,7 +477,7 @@ def __init__( raise FileNotFoundError( f"motion_file must be a dataset root directory or .h5 shard, got: {motion_file}" ) - stats = compute_dataset_stats(motion_path) + stats = compute_dataset_stats(motion_path, precomputed=True) if body_names is None: body_idx_np = body_indexes.cpu().numpy() From 10338131a877cf04949f0a945ef39dac969a0cf7 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 16 Jun 2026 22:11:45 +0800 Subject: [PATCH 096/122] Load precomputed motions fully at startup --- AGENTS.md | 6 +- docs/docs/reference/dataset.md | 2 +- docs/docs/tutorials/training.md | 2 +- .../current/reference/dataset.md | 2 +- .../current/tutorials/training.md | 2 +- tests/test_motion_sampling.py | 16 +- tests/test_train_script.py | 2 - 
train_mimic/scripts/train.py | 21 - train_mimic/tasks/tracking/mdp/commands.py | 534 ++++-------------- train_mimic/tasks/tracking/rl/runner.py | 8 - 10 files changed, 129 insertions(+), 466 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 81d436c4..da370a6b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -206,10 +206,10 @@ The single supported training task is `General-Tracking-G1` (experiment name: `g - Final distributed dataset build outputs are minimal HDF5 shards directly under `data/datasets//` (recursive shard discovery is supported; no train/val split and no manifest file) - `train_mimic/scripts/data/precompute_dataset.py` converts a minimal dataset into a separate precomputed training dataset directory; `build_dataset.py` must not run precompute - Each shard stores only `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, and clip-aware window metadata (`clip_starts`, `clip_lengths`, `clip_fps`); long clips are split into overlapping bounded windows -- Training `motion_file` must point to a precomputed training dataset, not the minimal distributed dataset; training reads joint velocities and body FK/velocities from those precomputed shards and must not run MuJoCo FK while loading motion clips into the fixed-size cache -- `MotionLib` loads only a configurable precomputed HDF5 subset cache into CPU/GPU memory, asynchronously stages the next cache, and swaps caches at the PPO rollout barrier +- Training `motion_file` must point to a precomputed training dataset, not the minimal distributed dataset; training reads joint velocities and body FK/velocities from those precomputed shards and must not run MuJoCo FK while loading motion clips +- `MotionLib` loads all discovered precomputed HDF5 motion windows into CPU/GPU memory at startup - `MotionLib` samples only valid center frames for the configured `window_steps`; default is `window_steps=[0]` -- Training supports `uniform` and `rewind` sampling on the active cache; in distributed training each rank sets a 
rank-offset `cache_seed` +- Training supports `uniform` and `rewind` sampling over the fully loaded precomputed dataset - `scripts/run/record_pico_motion.py` records Pico live body tracking as retargeted G1 motion NPZ clips in `data/pico_motion/clips/`; it opens a live `Retarget` viewer, uses terminal keys `R/S/D/N/Q`, stores semantic labels in filenames, and intentionally does not write per-clip JSON - Build Pico-recorded clips into shards with `python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force` diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index 0f3752cb..b5580847 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -72,7 +72,7 @@ data/datasets/_precomputed/ - `build_dataset.py` only writes the minimal distributable dataset. It does not run FK precompute. - `precompute_dataset.py` writes a separate training dataset containing the minimal motion plus precomputed joint velocities and body FK/velocities. - Training accepts only the precomputed dataset directory. It recursively discovers precomputed `*.h5` shards below the specified root, so precomputed datasets can be merged by placing multiple shard directories under one parent. -- Training loads only a subset cache from the discovered precomputed shards, stages the next cache asynchronously, and swaps caches at the PPO rollout barrier. Joint velocities and body FK/velocities are not computed during training. +- Training loads all discovered precomputed motion windows into memory at startup. Joint velocities and body FK/velocities are not computed during training. ## YAML Spec Format diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index 00d83e9a..cdd0efb7 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -84,7 +84,7 @@ torchrun \ - Default logger is TensorBoard. 
Use `--logger wandb` or `--logger swanlab` to select W&B or SwanLab; the project name defaults to `experiment_name` - `--motion_file` accepts a precomputed training dataset root directory or a single precomputed `.h5` shard; shard discovery is recursive - If you only have the minimal distributed shards, first run `python train_mimic/scripts/data/precompute_dataset.py --outdir ` and pass the precomputed output to training. -- `--cache_num_clips` controls how many precomputed HDF5 motion windows are loaded into the active subset cache; the next cache is staged asynchronously and swapped at rollout barriers. +- Training loads all discovered precomputed motion windows into memory at startup. - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` ## Export ONNX diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index f45ab730..eb497301 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -70,7 +70,7 @@ data/datasets/_precomputed/ - `build_dataset.py` 只写最小分发数据集,不执行 FK 预计算。 - `precompute_dataset.py` 会写出独立的训练数据集,里面包含最小运动数据以及预计算的 joint velocity 和 body FK/velocity。 - 训练只接受预计算后的数据集目录。它会递归发现指定根目录下的预计算 `*.h5` shard,因此可以把多个预计算数据集目录放到同一个父目录下完成合并。 -- 训练只会从发现的预计算 shard 中加载 subset cache,异步 staging 下一个 cache,并在 PPO rollout barrier 切换 cache。joint velocity 和 body FK/velocity 不会在训练时计算。 +- 训练会在启动时把所有发现的预计算 motion window 全量加载到内存中。joint velocity 和 body FK/velocity 不会在训练时计算。 ## YAML spec diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index ce7c82af..84a9b366 100644 --- 
a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -84,7 +84,7 @@ torchrun \ - 默认日志工具为 TensorBoard。使用 `--logger wandb` 或 `--logger swanlab` 可选择 W&B 或 SwanLab;项目名默认使用 `experiment_name` - `--motion_file` 接受预计算训练数据集根目录或单个预计算 `.h5` shard;shard 会递归发现 - 如果只有最小分发 shard,先运行 `python train_mimic/scripts/data/precompute_dataset.py --outdir `,再把预计算输出传给训练。 -- `--cache_num_clips` 控制加载到 active subset cache 的预计算 HDF5 motion window 数量;下一个 cache 会异步 staging,并在 rollout barrier 切换。 +- 训练会在启动时把所有发现的预计算 motion window 全量加载到内存中。 - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` ## 导出 ONNX diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index a834e634..ac98f1f9 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -161,7 +161,7 @@ def test_motion_lib_get_window_frames_returns_requested_offsets(tmp_path: Path) assert torch.allclose(current["joint_pos"][0, :1], torch.tensor([2.0], dtype=torch.float32)) -def test_motion_lib_flat_cache_offsets_select_requested_clip(tmp_path: Path) -> None: +def test_motion_lib_frame_offsets_select_requested_clip(tmp_path: Path) -> None: clip0 = _clip_dict(num_frames=3) clip1 = _clip_dict(num_frames=7) clip1["root_pos"] = np.asarray(clip1["root_pos"]).copy() @@ -176,11 +176,7 @@ def test_motion_lib_flat_cache_offsets_select_requested_clip(tmp_path: Path) -> str(motion_path), body_indexes=torch.tensor([0, 1], dtype=torch.long), window_steps=(0,), - cache_num_clips=2, - cache_seed=0, - dataloader_num_workers=0, ) - motion._set_batch(motion._cache._load_batch(torch.tensor([0, 1], dtype=torch.long))) frames = motion.get_frames( torch.tensor([1], dtype=torch.long), @@ -328,7 +324,7 @@ def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_pat assert torch.allclose(motion.clip_sample_end_s[motion_ids], 
torch.tensor([3.0])) -def test_motion_lib_global_cache_sampling_weights_follow_valid_duration(tmp_path: Path) -> None: +def test_motion_lib_sampling_weights_follow_valid_duration(tmp_path: Path) -> None: motion_path = _write_shard_dir( tmp_path / "motion_weighted", [ @@ -345,12 +341,12 @@ def test_motion_lib_global_cache_sampling_weights_follow_valid_duration(tmp_path ) assert torch.allclose( - motion._cache.global_sample_weights, - torch.tensor([0.2, 0.5, 1.0], dtype=torch.float32), + motion.sample_weights, + torch.tensor([0.2 / 1.7, 0.5 / 1.7, 1.0 / 1.7], dtype=torch.float32), ) -def test_motion_cache_sampler_draws_global_ids_randomly_by_valid_duration(tmp_path: Path) -> None: +def test_motion_lib_samples_ids_randomly_by_valid_duration(tmp_path: Path) -> None: motion_path = _write_shard_dir( tmp_path / "motion_weighted_global", [ @@ -365,7 +361,7 @@ def test_motion_cache_sampler_draws_global_ids_randomly_by_valid_duration(tmp_pa window_steps=(0,), ) - ids = torch.cat([motion._cache._sample_global_ids() for _ in range(256)]) + ids = motion.sample_motion_ids(2048) counts = torch.bincount(ids.cpu(), minlength=2).float() assert counts[1] > counts[0] * 3.0 diff --git a/tests/test_train_script.py b/tests/test_train_script.py index d0361f41..84657cb6 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -42,8 +42,6 @@ def _args(**overrides: object) -> argparse.Namespace: "rewind_prob": None, "rewind_min_steps": None, "rewind_max_steps": None, - "cache_num_clips": None, - "cache_swap_interval_steps": None, "device": None, "gpu_ids": None, "master_port": 29500, diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index dac2c0ac..4f8bf7b3 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -106,16 +106,6 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: help="Minimum policy steps to rewind for rewind sampling") parser.add_argument("--rewind_max_steps", type=int, default=None, 
help="Maximum policy steps to rewind for rewind sampling") - parser.add_argument("--cache_num_clips", type=int, default=None, - help="Number of HDF5 motion windows to keep in the active subset cache") - parser.add_argument("--cache_swap_interval_steps", type=int, default=None, - help="Policy steps between HDF5 motion cache swaps; swaps occur at rollout barriers") - parser.add_argument("--cache_dataloader_num_workers", type=int, default=None, - help="Number of PyTorch DataLoader workers for asynchronous HDF5 motion cache loading") - parser.add_argument("--cache_dataloader_prefetch_factor", type=int, default=None, - help="PyTorch DataLoader prefetch factor for asynchronous HDF5 motion cache loading") - parser.add_argument("--cache_dataloader_pin_memory", action=argparse.BooleanOptionalAction, default=None, - help="Pin CPU motion cache batches before asynchronous CUDA staging") parser.add_argument("--device", type=str, default=None) parser.add_argument( "--gpu_ids", @@ -392,7 +382,6 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: # CLI overrides env_cfg.seed = _resolve_worker_seed(args.seed) - env_cfg.commands["motion"].cache_seed = env_cfg.seed robot_xml = resolve_g1_training_xml(args.robot_xml) if not robot_xml.is_file(): raise FileNotFoundError(f"G1 training MuJoCo XML not found: {robot_xml}") @@ -411,16 +400,6 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: env_cfg.commands["motion"].rewind_min_steps = args.rewind_min_steps if args.rewind_max_steps is not None: env_cfg.commands["motion"].rewind_max_steps = args.rewind_max_steps - if args.cache_num_clips is not None: - env_cfg.commands["motion"].cache_num_clips = args.cache_num_clips - if args.cache_swap_interval_steps is not None: - env_cfg.commands["motion"].cache_swap_interval_steps = args.cache_swap_interval_steps - if args.cache_dataloader_num_workers is not None: - env_cfg.commands["motion"].cache_dataloader_num_workers = args.cache_dataloader_num_workers - if 
args.cache_dataloader_prefetch_factor is not None: - env_cfg.commands["motion"].cache_dataloader_prefetch_factor = args.cache_dataloader_prefetch_factor - if args.cache_dataloader_pin_memory is not None: - env_cfg.commands["motion"].cache_dataloader_pin_memory = args.cache_dataloader_pin_memory if args.max_iterations is not None: agent_cfg.max_iterations = args.max_iterations if args.experiment_name is not None: diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index f695db84..28519f96 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -2,16 +2,14 @@ import copy import logging -import time from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator, Literal +from typing import TYPE_CHECKING, Any, Literal import h5py import mujoco import numpy as np import torch -from torch.utils.data import DataLoader, Dataset, Sampler from train_mimic.data.dataset_lib import ( MOTION_ARRAY_KEYS, @@ -69,7 +67,6 @@ def _batched_quat_slerp( @dataclass(frozen=True) class _Hdf5ClipRef: - shard_index: int start: int length: int fps: int @@ -83,377 +80,120 @@ class _MotionBatch: fps: torch.Tensor sample_starts: torch.Tensor sample_ends: torch.Tensor - global_ids: torch.Tensor - - def pin_memory(self) -> "_MotionBatch": - return _MotionBatch( - tensors={key: value.pin_memory() for key, value in self.tensors.items()}, - frame_offsets=self.frame_offsets.pin_memory(), - lengths=self.lengths.pin_memory(), - fps=self.fps.pin_memory(), - sample_starts=self.sample_starts.pin_memory(), - sample_ends=self.sample_ends.pin_memory(), - global_ids=self.global_ids.pin_memory(), - ) - - -@dataclass -class _MotionClipSample: - tensors: dict[str, torch.Tensor] - length: int - fps: int - sample_start: int - sample_end: int - global_id: int - -class _WeightedInfiniteClipBatchSampler(Sampler[list[int]]): - def __init__( - self, - *, - 
sample_weights: torch.Tensor, - batch_size: int, - seed: int, - ) -> None: - if batch_size <= 0: - raise ValueError(f"batch_size must be positive, got {batch_size}") - self.sample_weights = sample_weights.cpu().to(dtype=torch.float32) - self.batch_size = int(batch_size) - self._rng = torch.Generator(device="cpu") - self._rng.manual_seed(int(seed)) - - def __iter__(self) -> Iterator[list[int]]: - while True: - ids = torch.multinomial( - self.sample_weights, - self.batch_size, - replacement=True, - generator=self._rng, - ) - yield [int(value) for value in ids.tolist()] +def _load_all_precomputed_motion( + motion_dir: Path, + *, + body_idx_np: np.ndarray, + device: str, + window_steps: tuple[int, ...], +) -> _MotionBatch: + shard_paths = find_precomputed_motion_shards(motion_dir) + stats = compute_dataset_stats(motion_dir, precomputed=True) + for shard_path in shard_paths: + validate_precomputed_motion_shard(shard_path) + _LOG.info( + "Loading full motion dataset: root=%s shards=%d windows=%d source_clips=%d frames=%d fps=%s", + motion_dir, + stats["shards"], + stats["windows"], + stats["source_clips"], + stats["frames"], + stats["fps"], + ) -class _Hdf5MotionDataset(Dataset[_MotionClipSample]): - def __init__( - self, - *, - refs: list[_Hdf5ClipRef], - shard_paths: list[Path], - body_idx_np: np.ndarray, - window_steps: tuple[int, ...], - ) -> None: - self.refs = refs - self._shard_paths = shard_paths - self.body_idx_np = body_idx_np - self.window_steps = window_steps - self._h5_handles: dict[int, h5py.File] = {} - - def __len__(self) -> int: - return len(self.refs) - - def __getitem__(self, index: int) -> _MotionClipSample: - ref = self.refs[int(index)] - sl = slice(ref.start, ref.start + ref.length) - h5 = self._h5_handle(ref.shard_index) - joint_pos = np.asarray(h5["joint_pos"][sl], dtype=np.float32) - - sample_starts, sample_ends = compute_clip_sample_ranges( - np.asarray([ref.length], dtype=np.int64), - window_steps=self.window_steps, - ) - tensors = { - 
"joint_pos": torch.from_numpy(joint_pos), - "joint_vel": torch.from_numpy(np.asarray(h5["joint_vel"][sl], dtype=np.float32)), - "body_pos_w": torch.from_numpy( - np.asarray(h5["body_pos_w"][sl], dtype=np.float32)[:, self.body_idx_np] - ), - "body_quat_w": torch.from_numpy( - np.asarray(h5["body_quat_w"][sl], dtype=np.float32)[:, self.body_idx_np] - ), - "body_lin_vel_w": torch.from_numpy( - np.asarray(h5["body_lin_vel_w"][sl], dtype=np.float32)[:, self.body_idx_np] - ), - "body_ang_vel_w": torch.from_numpy( - np.asarray(h5["body_ang_vel_w"][sl], dtype=np.float32)[:, self.body_idx_np] - ), - } - return _MotionClipSample( - tensors=tensors, - length=ref.length, - fps=ref.fps, - sample_start=int(sample_starts[0]), - sample_end=int(sample_ends[0]), - global_id=int(index), + max_future = max((step for step in window_steps if step > 0), default=0) + max_history = -min((step for step in window_steps if step < 0), default=0) + min_clip_length = max_history + 1 + max_future + 1 # +1 for interpolation + + arrays: dict[str, list[torch.Tensor]] = {key: [] for key in MOTION_ARRAY_KEYS} + lengths_out: list[int] = [] + fps_out: list[int] = [] + sample_starts_out: list[int] = [] + sample_ends_out: list[int] = [] + skipped_short = 0 + for shard_path in shard_paths: + with h5py.File(shard_path, "r") as h5: + starts = np.asarray(h5["clip_starts"], dtype=np.int64) + lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) + fps = np.asarray(h5["clip_fps"], dtype=np.int64) + for start, length, cur_fps in zip(starts, lengths, fps): + if int(length) < min_clip_length: + skipped_short += 1 + continue + ref = _Hdf5ClipRef( + start=int(start), + length=int(length), + fps=int(cur_fps), + ) + sl = slice(ref.start, ref.start + ref.length) + sample_starts, sample_ends = compute_clip_sample_ranges( + np.asarray([ref.length], dtype=np.int64), + window_steps=window_steps, + ) + arrays["joint_pos"].append( + torch.from_numpy(np.asarray(h5["joint_pos"][sl], dtype=np.float32)) + ) + 
arrays["joint_vel"].append( + torch.from_numpy(np.asarray(h5["joint_vel"][sl], dtype=np.float32)) + ) + arrays["body_pos_w"].append( + torch.from_numpy( + np.asarray(h5["body_pos_w"][sl], dtype=np.float32)[:, body_idx_np] + ) + ) + arrays["body_quat_w"].append( + torch.from_numpy( + np.asarray(h5["body_quat_w"][sl], dtype=np.float32)[:, body_idx_np] + ) + ) + arrays["body_lin_vel_w"].append( + torch.from_numpy( + np.asarray(h5["body_lin_vel_w"][sl], dtype=np.float32)[:, body_idx_np] + ) + ) + arrays["body_ang_vel_w"].append( + torch.from_numpy( + np.asarray(h5["body_ang_vel_w"][sl], dtype=np.float32)[:, body_idx_np] + ) + ) + lengths_out.append(ref.length) + fps_out.append(ref.fps) + sample_starts_out.append(int(sample_starts[0])) + sample_ends_out.append(int(sample_ends[0])) + if not lengths_out: + raise ValueError(f"HDF5 motion dataset is empty: {motion_dir}") + if skipped_short > 0: + _LOG.warning( + "Ignoring %d HDF5 motion windows shorter than %d frames (window_steps=%s)", + skipped_short, + min_clip_length, + list(window_steps), ) - def _h5_handle(self, shard_index: int) -> h5py.File: - handle = self._h5_handles.get(shard_index) - if handle is not None and handle.id: - return handle - handle = h5py.File(self._shard_paths[shard_index], "r") - self._h5_handles[shard_index] = handle - return handle - - def close(self) -> None: - for handle in self._h5_handles.values(): - if handle.id: - handle.close() - self._h5_handles.clear() - - -def _collate_motion_clips(samples: list[_MotionClipSample]) -> _MotionBatch: - if not samples: - raise ValueError("Motion cache DataLoader produced an empty batch") - arrays: dict[str, torch.Tensor] = {} - for key in MOTION_ARRAY_KEYS: - arrays[key] = torch.cat([sample.tensors[key] for sample in samples], dim=0) - - lengths = torch.tensor([sample.length for sample in samples], dtype=torch.long) - frame_offsets = torch.zeros(len(samples), dtype=torch.long) - if len(samples) > 1: + device_obj = torch.device(device) + lengths = 
torch.tensor(lengths_out, dtype=torch.long) + frame_offsets = torch.zeros(len(lengths_out), dtype=torch.long) + if len(lengths_out) > 1: frame_offsets[1:] = torch.cumsum(lengths[:-1], dim=0) - return _MotionBatch( - tensors=arrays, - frame_offsets=frame_offsets, - lengths=lengths, - fps=torch.tensor([sample.fps for sample in samples], dtype=torch.float32), - sample_starts=torch.tensor([sample.sample_start for sample in samples], dtype=torch.long), - sample_ends=torch.tensor([sample.sample_end for sample in samples], dtype=torch.long), - global_ids=torch.tensor([sample.global_id for sample in samples], dtype=torch.long), + tensors={ + key: torch.cat(values, dim=0).to(device_obj) + for key, values in arrays.items() + }, + frame_offsets=frame_offsets.to(device_obj), + lengths=lengths.to(device_obj), + fps=torch.tensor(fps_out, dtype=torch.float32, device=device_obj), + sample_starts=torch.tensor(sample_starts_out, dtype=torch.long, device=device_obj), + sample_ends=torch.tensor(sample_ends_out, dtype=torch.long, device=device_obj), ) -def _motion_worker_init(worker_id: int) -> None: - del worker_id - torch.set_num_threads(1) - - -class _Hdf5MotionCache: - def __init__( - self, - motion_dir: Path, - *, - body_idx_np: np.ndarray, - device: str, - window_steps: tuple[int, ...], - cache_num_clips: int, - seed: int, - dataloader_num_workers: int, - dataloader_prefetch_factor: int, - dataloader_pin_memory: bool, - ) -> None: - if cache_num_clips <= 0: - raise ValueError(f"cache_num_clips must be positive, got {cache_num_clips}") - - self.motion_dir = motion_dir - self.body_idx_np = body_idx_np - self.device = device - self.window_steps = window_steps - self.cache_num_clips = int(cache_num_clips) - self.dataloader_num_workers = max(0, int(dataloader_num_workers)) - self.dataloader_prefetch_factor = max(1, int(dataloader_prefetch_factor)) - self.dataloader_pin_memory = bool(dataloader_pin_memory) - self._device = torch.device(device) - self._copy_stream: torch.cuda.Stream | 
None = None - self._next_ready_event: torch.cuda.Event | None = None - self._shard_paths = find_precomputed_motion_shards(motion_dir) - stats = compute_dataset_stats(motion_dir, precomputed=True) - self.body_names = np.asarray(stats["body_names"], dtype=str) - for shard_path in self._shard_paths: - validate_precomputed_motion_shard(shard_path) - _LOG.info( - "Motion dataset: root=%s shards=%d windows=%d source_clips=%d frames=%d fps=%s", - motion_dir, - stats["shards"], - stats["windows"], - stats["source_clips"], - stats["frames"], - stats["fps"], - ) - - max_future = max((step for step in self.window_steps if step > 0), default=0) - max_history = -min((step for step in self.window_steps if step < 0), default=0) - min_clip_length = max_history + 1 + max_future + 1 # +1 for interpolation - - refs: list[_Hdf5ClipRef] = [] - skipped_short = 0 - for shard_index, shard_path in enumerate(self._shard_paths): - with h5py.File(shard_path, "r") as h5: - starts = np.asarray(h5["clip_starts"], dtype=np.int64) - lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) - fps = np.asarray(h5["clip_fps"], dtype=np.int64) - for start, length, cur_fps in zip(starts, lengths, fps): - if int(length) < min_clip_length: - skipped_short += 1 - continue - refs.append(_Hdf5ClipRef( - shard_index=shard_index, - start=int(start), - length=int(length), - fps=int(cur_fps), - )) - if not refs: - raise ValueError(f"HDF5 motion dataset is empty: {motion_dir}") - if skipped_short > 0: - _LOG.warning( - "Ignoring %d HDF5 motion windows shorter than %d frames (window_steps=%s)", - skipped_short, - min_clip_length, - list(self.window_steps), - ) - self.refs = refs - ref_lengths_np = np.asarray([ref.length for ref in refs], dtype=np.int64) - ref_starts_np, ref_ends_np = compute_clip_sample_ranges( - ref_lengths_np, - window_steps=self.window_steps, - ) - ref_fps_np = np.asarray([ref.fps for ref in refs], dtype=np.float32) - ref_valid_seconds = (ref_ends_np - ref_starts_np).astype(np.float32) / 
ref_fps_np - if np.any(ref_valid_seconds <= 0.0): - raise ValueError( - "HDF5 motion dataset contains windows with no valid sample duration " - f"after applying window_steps={list(self.window_steps)}" - ) - self.global_sample_weights = torch.as_tensor(ref_valid_seconds, dtype=torch.float32) - total_weight = float(self.global_sample_weights.sum().item()) - if total_weight <= 0.0: - raise ValueError(f"HDF5 motion dataset has no positive sample duration: {motion_dir}") - self.generation = 0 - self._dataset = _Hdf5MotionDataset( - refs=self.refs, - shard_paths=self._shard_paths, - body_idx_np=self.body_idx_np, - window_steps=self.window_steps, - ) - self._sampler = _WeightedInfiniteClipBatchSampler( - sample_weights=self.global_sample_weights, - batch_size=self.cache_num_clips, - seed=seed, - ) - loader_kwargs: dict[str, object] = {} - if self.dataloader_num_workers > 0: - loader_kwargs["prefetch_factor"] = self.dataloader_prefetch_factor - loader_kwargs["persistent_workers"] = True - loader_kwargs["worker_init_fn"] = _motion_worker_init - self._loader = DataLoader( - self._dataset, - batch_sampler=self._sampler, - num_workers=self.dataloader_num_workers, - pin_memory=self.dataloader_pin_memory and self._device.type == "cuda", - collate_fn=_collate_motion_clips, - **loader_kwargs, - ) - self._iterator = iter(self._loader) - self.current = self._stage_batch(self._load_next_cpu_batch(), wait=True) - self._next_batch = self._stage_batch(self._load_next_cpu_batch(), wait=False) - - def _load_next_cpu_batch(self, *, log_wait: bool = False) -> _MotionBatch: - start = time.perf_counter() - batch = next(self._iterator) - elapsed = time.perf_counter() - start - if log_wait and elapsed > 1e-3: - _LOG.info( - "Waited %.3fs for asynchronous HDF5 motion cache DataLoader", - elapsed, - ) - return batch - - def _stage_batch(self, batch: _MotionBatch, *, wait: bool) -> _MotionBatch: - if self._device.type != "cuda": - tensors = {key: value.to(self._device) for key, value in 
batch.tensors.items()} - return _MotionBatch( - tensors=tensors, - frame_offsets=batch.frame_offsets.to(self._device), - lengths=batch.lengths.to(self._device), - fps=batch.fps.to(self._device), - sample_starts=batch.sample_starts.to(self._device), - sample_ends=batch.sample_ends.to(self._device), - global_ids=batch.global_ids.to(self._device), - ) - - if self._copy_stream is None: - self._copy_stream = torch.cuda.Stream(device=self._device) - with torch.cuda.stream(self._copy_stream): - tensors = { - key: value.to(self._device, non_blocking=True) - for key, value in batch.tensors.items() - } - staged = _MotionBatch( - tensors=tensors, - frame_offsets=batch.frame_offsets.to(self._device, non_blocking=True), - lengths=batch.lengths.to(self._device, non_blocking=True), - fps=batch.fps.to(self._device, non_blocking=True), - sample_starts=batch.sample_starts.to(self._device, non_blocking=True), - sample_ends=batch.sample_ends.to(self._device, non_blocking=True), - global_ids=batch.global_ids.to(self._device, non_blocking=True), - ) - event = torch.cuda.Event() - event.record(self._copy_stream) - if wait: - torch.cuda.current_stream(self._device).wait_event(event) - else: - self._next_ready_event = event - return staged - - def _wait_next_ready(self) -> None: - if self._next_ready_event is None: - return - if self._device.type == "cuda": - torch.cuda.current_stream(self._device).wait_event(self._next_ready_event) - self._next_ready_event = None - - def _materialize_batch_by_global_ids(self, global_ids: torch.Tensor) -> _MotionBatch: - samples = [self._dataset[int(idx)] for idx in global_ids.tolist()] - batch = _collate_motion_clips(samples) - return self._stage_batch(batch, wait=True) - - def _sample_global_ids(self) -> torch.Tensor: - ids = next(iter(self._sampler)) - return torch.tensor(ids, dtype=torch.long) - - def _load_random_batch(self) -> _MotionBatch: - return self._materialize_batch_by_global_ids(self._sample_global_ids()) - - def _load_batch(self, global_ids: 
torch.Tensor) -> _MotionBatch: - return self._materialize_batch_by_global_ids(global_ids.cpu().to(dtype=torch.long)) - - def advance(self) -> None: - start = time.perf_counter() - self._wait_next_ready() - elapsed = time.perf_counter() - start - if elapsed > 1e-3: - _LOG.info( - "Waited %.3fs for asynchronous HDF5 motion cache staging", - elapsed, - ) - self.current = self._next_batch - self._next_batch = self._stage_batch( - self._load_next_cpu_batch(log_wait=True), - wait=False, - ) - self.generation += 1 - - def close(self) -> None: - self._dataset.close() - iterator = getattr(self, "_iterator", None) - shutdown = getattr(iterator, "_shutdown_workers", None) - if callable(shutdown): - shutdown() - - def __del__(self) -> None: - try: - self.close() - except Exception: - pass - - class MotionLib: """Clip-aware motion library. - Loads a bounded subset of HDF5 motion windows into a GPU-resident cache. - Sampling and interpolation operate on cache-local clip ids; the next cache - is staged in memory and swapped at a rollout barrier by ``MotionCommand``. + Loads all precomputed HDF5 motion windows into memory at startup. """ def __init__( @@ -463,11 +203,6 @@ def __init__( body_names: tuple[str, ...] | list[str] | None = None, device: str = "cpu", window_steps: tuple[int, ...] 
| list[int] | None = None, - cache_num_clips: int = 8192, - cache_seed: int = 0, - dataloader_num_workers: int = 2, - dataloader_prefetch_factor: int = 1, - dataloader_pin_memory: bool = True, ) -> None: self._device = device self.window_steps = parse_window_steps(window_steps) @@ -501,18 +236,13 @@ def __init__( dtype=np.int64, ) - self._cache = _Hdf5MotionCache( + batch = _load_all_precomputed_motion( motion_path, body_idx_np=body_idx_np, device=device, window_steps=self.window_steps, - cache_num_clips=cache_num_clips, - seed=cache_seed, - dataloader_num_workers=dataloader_num_workers, - dataloader_prefetch_factor=dataloader_prefetch_factor, - dataloader_pin_memory=dataloader_pin_memory, ) - self._set_batch(self._cache.current) + self._set_batch(batch) def _set_batch(self, batch: _MotionBatch) -> None: self._batch = batch @@ -534,24 +264,27 @@ def _set_batch(self, batch: _MotionBatch) -> None: self.clip_sample_ends = batch.sample_ends self.clip_sample_start_s = self.clip_sample_starts.float() * self.clip_dt self.clip_sample_end_s = self.clip_sample_ends.float() * self.clip_dt - # Kept for introspection/logging; these are cache-local flat frame offsets. 
self.clip_starts = self.clip_frame_offsets - self.generation = self._cache.generation - - def advance_cache(self) -> None: - self._cache.advance() - self._set_batch(self._cache.current) + ref_valid_seconds = ( + (self.clip_sample_ends - self.clip_sample_starts).float() / self.clip_fps + ) + if torch.any(ref_valid_seconds <= 0.0): + raise ValueError( + "HDF5 motion dataset contains windows with no valid sample duration " + f"after applying window_steps={list(self.window_steps)}" + ) + self.sample_weights = ref_valid_seconds / torch.sum(ref_valid_seconds) def close(self) -> None: - self._cache.close() + return None # ------------------------------------------------------------------ # Sampling helpers # ------------------------------------------------------------------ def sample_motion_ids(self, n: int) -> torch.Tensor: - """Sample *n* cache-local clip indices uniformly.""" - return torch.randint(0, self.num_clips, (n,), device=self._device) + """Sample *n* clip indices by valid-duration weights.""" + return torch.multinomial(self.sample_weights, int(n), replacement=True) def sample_times(self, motion_ids: torch.Tensor) -> torch.Tensor: """Uniform random time over valid center frames for each motion id.""" @@ -736,14 +469,7 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): body_names=self.cfg.body_names, device=self.device, window_steps=self.cfg.window_steps, - cache_num_clips=self.cfg.cache_num_clips, - cache_seed=self.cfg.cache_seed, - dataloader_num_workers=self.cfg.cache_dataloader_num_workers, - dataloader_prefetch_factor=self.cfg.cache_dataloader_prefetch_factor, - dataloader_pin_memory=self.cfg.cache_dataloader_pin_memory, ) - self._motion_cache_step_counter = 0 - self._motion_cache_swap_pending = False # Per-env motion state: clip id + elapsed time (seconds) self.motion_ids = torch.zeros(self.num_envs, dtype=torch.long, device=self.device) @@ -1164,10 +890,6 @@ def _refresh_body_local_cache(self) -> None: def _update_command(self): # 
Advance motion time by real elapsed time self.motion_times += self._step_dt - if self.cfg.cache_swap_interval_steps > 0: - self._motion_cache_step_counter += 1 - if self._motion_cache_step_counter >= self.cfg.cache_swap_interval_steps: - self._motion_cache_swap_pending = True # Handle clips that exceeded their duration end_times = self.motion.clip_sample_end_s[self.motion_ids] @@ -1206,24 +928,6 @@ def _update_command(self): self._refresh_body_local_cache() self._update_feet_standing() - def apply_cache_swap_if_pending_barrier(self) -> bool: - """Swap the staged motion cache at a rollout barrier, then resample all envs.""" - if not self._motion_cache_swap_pending: - return False - self.motion.advance_cache() - self._motion_cache_step_counter = 0 - self._motion_cache_swap_pending = False - all_env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) - if self.cfg.sampling_mode == "start": - self.motion_ids[all_env_ids] = self.motion.sample_motion_ids(self.num_envs) - self.motion_times[all_env_ids] = self.motion.sample_start_times(self.motion_ids[all_env_ids]) - else: - # Rewind only makes sense inside one cache generation. After a cache - # swap, local ids refer to different clips, so fall back to uniform. - self._uniform_sampling(all_env_ids) - self._reset_envs_to_current_reference(all_env_ids) - return True - # ------------------------------------------------------------------ # Visualization # ------------------------------------------------------------------ @@ -1310,12 +1014,6 @@ class MotionCommandCfg(CommandTermCfg): joint_position_range: tuple[float, float] = (-0.52, 0.52) sampling_mode: Literal["uniform", "start", "rewind"] = "rewind" window_steps: tuple[int, ...] 
= (0,) - cache_num_clips: int = 8192 - cache_swap_interval_steps: int = 2000 - cache_dataloader_num_workers: int = 2 - cache_dataloader_prefetch_factor: int = 1 - cache_dataloader_pin_memory: bool = True - cache_seed: int = 0 rewind_prob: float = 0.8 rewind_min_steps: int = 25 rewind_max_steps: int = 75 diff --git a/train_mimic/tasks/tracking/rl/runner.py b/train_mimic/tasks/tracking/rl/runner.py index 64acc939..4614ce50 100644 --- a/train_mimic/tasks/tracking/rl/runner.py +++ b/train_mimic/tasks/tracking/rl/runner.py @@ -2,7 +2,6 @@ import pathlib import statistics import time -from typing import cast import torch from rsl_rl.env.vec_env import VecEnv @@ -10,7 +9,6 @@ from mjlab.rl import RslRlVecEnvWrapper from mjlab.rl.runner import MjlabOnPolicyRunner from rsl_rl.utils import check_nan -from train_mimic.tasks.tracking.mdp import MotionCommand def _one_based_iteration_range(start_iteration: int, total_iterations: int) -> range: @@ -55,9 +53,6 @@ def __init__( super().__init__(env, train_cfg, log_dir, device) self.registry_name = registry_name - def _motion_command(self) -> MotionCommand: - return cast(MotionCommand, self.env.unwrapped.command_manager.get_term("motion")) - def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None: """Run the learning loop using 1-based iteration numbering.""" if init_at_random_ep_len: @@ -93,9 +88,6 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals collect_time = stop - start start = stop self.alg.compute_returns(obs) - cmd = self._motion_command() - if cmd.apply_cache_swap_if_pending_barrier(): - obs = self.env.get_observations().to(self.device) loss_dict = self.alg.update() From 81a0ace03449842d476a922931c26fc34203adc0 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 16 Jun 2026 22:45:52 +0800 Subject: [PATCH 097/122] Fix precomputed motion window offsets --- tests/test_motion_sampling.py | 42 +++++++ train_mimic/data/dataset_lib.py | 2 +- 
train_mimic/tasks/tracking/mdp/commands.py | 130 ++++++++++++--------- 3 files changed, 120 insertions(+), 54 deletions(-) diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index ac98f1f9..28d01ea5 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -187,6 +187,48 @@ def test_motion_lib_frame_offsets_select_requested_clip(tmp_path: Path) -> None: assert torch.allclose(frames["body_pos_w"][0, 0], torch.tensor([102.0, 0.0, 0.0])) +def test_motion_lib_uses_original_window_starts_for_overlapped_shards(tmp_path: Path) -> None: + motion_path = tmp_path / "motion_overlapped_offsets" + motion_path.mkdir() + clip = _clip_dict(num_frames=12) + shard_path = motion_path / "shard_000.h5" + merged = { + "fps": int(clip["fps"]), + "root_pos": np.asarray(clip["root_pos"]), + "root_quat_w": np.asarray(clip["root_quat_w"]), + "joint_pos": np.asarray(clip["joint_pos"]), + "joint_vel": np.asarray(clip["joint_vel"]), + "body_pos_w": np.asarray(clip["body_pos_w"]), + "body_quat_w": np.asarray(clip["body_quat_w"]), + "body_lin_vel_w": np.asarray(clip["body_lin_vel_w"]), + "body_ang_vel_w": np.asarray(clip["body_ang_vel_w"]), + "body_names": np.asarray(clip["body_names"]), + "clip_starts": np.asarray([0, 4, 8], dtype=np.int64), + "clip_lengths": np.asarray([6, 6, 4], dtype=np.int64), + "clip_fps": np.asarray([1, 1, 1], dtype=np.int64), + } + _write_precomputed_from_merged(shard_path, merged) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + assert motion.clip_frame_offsets.cpu().tolist() == [0, 4, 8] + assert motion._joint_pos_t.shape[0] == 12 + + frames = motion.get_frames( + torch.tensor([0, 1, 2], dtype=torch.long), + torch.tensor([2.0, 2.0, 2.0], dtype=torch.float32), + ) + + assert torch.allclose( + frames["joint_pos"][:, 0], + torch.tensor([2.0, 6.0, 10.0], dtype=torch.float32), + ) + + def test_motion_lib_selects_bodies_by_dataset_names(tmp_path: 
Path) -> None: motion_path = _write_shard_dir(tmp_path / "motion_named_bodies", [_clip_dict()]) diff --git a/train_mimic/data/dataset_lib.py b/train_mimic/data/dataset_lib.py index 45696d60..76c9411f 100644 --- a/train_mimic/data/dataset_lib.py +++ b/train_mimic/data/dataset_lib.py @@ -931,7 +931,7 @@ def compute_dataset_stats( source_ids = np.asarray(h5["source_clip_ids"], dtype=np.int64) fps_values.update(int(v) for v in np.unique(fps_arr)) windows = int(lengths.shape[0]) - frames = int(np.asarray(h5["joint_pos"]).shape[0]) + frames = int(h5["joint_pos"].shape[0]) source_clips = int(len(np.unique(source_ids))) source_lengths = np.asarray(h5["source_clip_lengths"], dtype=np.float64) source_fps = np.asarray(h5["source_clip_fps"], dtype=np.float64) diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index 28519f96..add39de0 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -65,13 +65,6 @@ def _batched_quat_slerp( return result / result.norm(dim=-1, keepdim=True) -@dataclass(frozen=True) -class _Hdf5ClipRef: - start: int - length: int - fps: int - - @dataclass class _MotionBatch: tensors: dict[str, torch.Tensor] @@ -82,6 +75,35 @@ class _MotionBatch: sample_ends: torch.Tensor +def _read_selected_body_array( + h5: h5py.File, + key: str, + body_idx_np: np.ndarray, +) -> torch.Tensor: + """Read only the requested body axis while preserving caller order.""" + dataset = h5[key] + idx = np.asarray(body_idx_np, dtype=np.int64) + if idx.ndim != 1: + raise ValueError(f"body indexes must be 1-D, got {idx.shape}") + if idx.size == 0: + frames = int(dataset.shape[0]) + width = int(dataset.shape[2]) + return torch.empty((frames, 0, width), dtype=torch.float32) + + if np.any(idx < 0) or np.any(idx >= dataset.shape[1]): + raise IndexError( + f"body indexes out of range for {key}: indexes={idx.tolist()}, " + f"num_bodies={dataset.shape[1]}" + ) + + parts = [ + 
np.asarray(dataset[:, int(body_idx) : int(body_idx) + 1, :], dtype=np.float32) + for body_idx in idx + ] + arr = np.concatenate(parts, axis=1) + return torch.from_numpy(arr) + + def _load_all_precomputed_motion( motion_dir: Path, *, @@ -112,56 +134,60 @@ def _load_all_precomputed_motion( fps_out: list[int] = [] sample_starts_out: list[int] = [] sample_ends_out: list[int] = [] + frame_offsets_out: list[int] = [] skipped_short = 0 + loaded_frames = 0 for shard_path in shard_paths: with h5py.File(shard_path, "r") as h5: starts = np.asarray(h5["clip_starts"], dtype=np.int64) lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) fps = np.asarray(h5["clip_fps"], dtype=np.int64) - for start, length, cur_fps in zip(starts, lengths, fps): - if int(length) < min_clip_length: - skipped_short += 1 - continue - ref = _Hdf5ClipRef( - start=int(start), - length=int(length), - fps=int(cur_fps), - ) - sl = slice(ref.start, ref.start + ref.length) - sample_starts, sample_ends = compute_clip_sample_ranges( - np.asarray([ref.length], dtype=np.int64), - window_steps=window_steps, - ) - arrays["joint_pos"].append( - torch.from_numpy(np.asarray(h5["joint_pos"][sl], dtype=np.float32)) - ) - arrays["joint_vel"].append( - torch.from_numpy(np.asarray(h5["joint_vel"][sl], dtype=np.float32)) - ) - arrays["body_pos_w"].append( - torch.from_numpy( - np.asarray(h5["body_pos_w"][sl], dtype=np.float32)[:, body_idx_np] - ) - ) - arrays["body_quat_w"].append( - torch.from_numpy( - np.asarray(h5["body_quat_w"][sl], dtype=np.float32)[:, body_idx_np] - ) - ) - arrays["body_lin_vel_w"].append( - torch.from_numpy( - np.asarray(h5["body_lin_vel_w"][sl], dtype=np.float32)[:, body_idx_np] - ) - ) - arrays["body_ang_vel_w"].append( - torch.from_numpy( - np.asarray(h5["body_ang_vel_w"][sl], dtype=np.float32)[:, body_idx_np] - ) + frames = int(h5["joint_pos"].shape[0]) + if np.any(starts < 0) or np.any(starts + lengths > frames): + raise ValueError( + f"HDF5 shard {shard_path} has clip windows outside 
joint_pos " + f"frame range: frames={frames}" ) - lengths_out.append(ref.length) - fps_out.append(ref.fps) - sample_starts_out.append(int(sample_starts[0])) - sample_ends_out.append(int(sample_ends[0])) + + valid_mask = lengths >= min_clip_length + skipped_short += int(np.count_nonzero(~valid_mask)) + if not np.any(valid_mask): + continue + + shard_frame_base = loaded_frames + arrays["joint_pos"].append( + torch.from_numpy(np.asarray(h5["joint_pos"], dtype=np.float32)) + ) + arrays["joint_vel"].append( + torch.from_numpy(np.asarray(h5["joint_vel"], dtype=np.float32)) + ) + arrays["body_pos_w"].append( + _read_selected_body_array(h5, "body_pos_w", body_idx_np) + ) + arrays["body_quat_w"].append( + _read_selected_body_array(h5, "body_quat_w", body_idx_np) + ) + arrays["body_lin_vel_w"].append( + _read_selected_body_array(h5, "body_lin_vel_w", body_idx_np) + ) + arrays["body_ang_vel_w"].append( + _read_selected_body_array(h5, "body_ang_vel_w", body_idx_np) + ) + loaded_frames += frames + + valid_starts = starts[valid_mask] + valid_lengths = lengths[valid_mask] + valid_fps = fps[valid_mask] + sample_starts, sample_ends = compute_clip_sample_ranges( + valid_lengths, + window_steps=window_steps, + ) + valid_frame_offsets = (shard_frame_base + valid_starts).astype(np.int64) + frame_offsets_out.extend(int(offset) for offset in valid_frame_offsets) + lengths_out.extend(int(length) for length in valid_lengths) + fps_out.extend(int(cur_fps) for cur_fps in valid_fps) + sample_starts_out.extend(int(start) for start in sample_starts) + sample_ends_out.extend(int(end) for end in sample_ends) if not lengths_out: raise ValueError(f"HDF5 motion dataset is empty: {motion_dir}") if skipped_short > 0: @@ -174,9 +200,7 @@ def _load_all_precomputed_motion( device_obj = torch.device(device) lengths = torch.tensor(lengths_out, dtype=torch.long) - frame_offsets = torch.zeros(len(lengths_out), dtype=torch.long) - if len(lengths_out) > 1: - frame_offsets[1:] = torch.cumsum(lengths[:-1], 
dim=0) + frame_offsets = torch.tensor(frame_offsets_out, dtype=torch.long) return _MotionBatch( tensors={ key: torch.cat(values, dim=0).to(device_obj) From 2ad91e1b8bebb07a2f06f39126bb08767b56a328 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 16 Jun 2026 22:56:51 +0800 Subject: [PATCH 098/122] Remove keyframes from training G1 spec --- tests/test_train_script.py | 3 +++ train_mimic/tasks/tracking/config/env.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/tests/test_train_script.py b/tests/test_train_script.py index 84657cb6..53681c6f 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -258,6 +258,9 @@ def test_make_g1_training_robot_cfg_uses_requested_xml() -> None: assert isinstance(robot_cfg.spec_fn, partial) assert robot_cfg.spec_fn.args == (xml_path,) + spec = robot_cfg.spec_fn() + assert len(spec.actuators) == 0 + assert len(spec.keys) == 0 def test_validate_motion_file_accepts_shard_directories(tmp_path: Path) -> None: diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 3a3c47fb..21dbf3cc 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -69,6 +69,8 @@ def _get_g1_training_spec(robot_xml: str | Path | None = None) -> mujoco.MjSpec: spec = mujoco.MjSpec.from_file(str(xml_path)) for actuator in list(spec.actuators): spec.delete(actuator) + for key in list(spec.keys): + spec.delete(key) return spec From bad15865ef9ab036bf8a091a6e62bf6ee8216953 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 17 Jun 2026 16:29:59 +0800 Subject: [PATCH 099/122] Add LinkerHand O6 gripper support --- AGENTS.md | 13 +- README.md | 1 + docs/docs/configuration/config-reference.md | 16 +- docs/docs/getting-started/installation.md | 4 +- docs/docs/tutorials/pico-sim2real.md | 52 +++-- .../current/configuration/config-reference.md | 28 +-- .../current/getting-started/installation.md | 6 +- .../current/tutorials/pico-sim2real.md | 42 +++- 
scripts/dev/test_linkerhand_l6.py | 199 ++++++++---------- scripts/run/run_sim2real.py | 2 +- teleopit/configs/pico4_sim2real.yaml | 14 +- teleopit/configs/sim2real.yaml | 14 +- teleopit/sim2real/hands/linkerhand_l6.py | 22 +- teleopit/sim2real/hands/linkerhand_o6.py | 163 ++++++++++++++ teleopit/sim2real/hands/worker.py | 10 +- tests/test_dexterous_hand.py | 88 ++++++++ 16 files changed, 496 insertions(+), 178 deletions(-) create mode 100644 teleopit/sim2real/hands/linkerhand_o6.py diff --git a/AGENTS.md b/AGENTS.md index da370a6b..6cd487e3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -56,7 +56,7 @@ teleopit/ # Core inference package │ └── loop.py # SimulationLoop — PD control at 1000Hz, policy at 50Hz ├── sim2real/ │ ├── mp/ # Process-isolated sim2real runtime and IPC -│ └── hands/ # Optional LinkerHand L6 driver/mapper plugins +│ └── hands/ # Optional LinkerHand driver/mapper plugins └── recording/ # HDF5Recorder and Pico motion NPZ recording helpers scripts/ ├── run/run_sim.py # Offline sim2sim pipeline @@ -145,13 +145,14 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Pico4 sim2sim/sim2real support `ARMS` mode toggled from `MOCAP` with Pico/controller `B`; retargeting continues, while the control loop sends the motion tracker a composed reference with stand-pose body/legs/waist and live retargeted arms - `ARMS` entering/exiting/resume resets policy/reference alignment and uses Kp ramp; offline BVH sim2real does not use `ARMS`, and Unitree remote `B` remains BVH replay - Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference state, smoothers, and reference alignment are reset, while the GMR IK warm-start is retained -- Optional LinkerHand L6 control uses `hands.enabled=true` and `hands.mode=gripper|vr_hand_pose`; default is disabled -- `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control +- Optional LinkerHand control uses 
`hands.enabled=true`, `hands.driver=linkerhand_l6|linkerhand_o6`, and `hands.mode=gripper|vr_hand_pose`; default is disabled +- `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control and supports LinkerHand L6 and O6 - `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and somehand 0.2.0 public `somehand.api` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control - Teleopit owns Pico 26-joint hand-state to 21-landmark conversion; do not import `somehand.pico_input` -- `gripper` mode uses the configured `hands.linkerhand_l6.speed` (default `[50]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` +- LinkerHand O6 supports only `hands.mode=gripper`; its default `close_pose` is `[86, 73, 118, 111, 110, 111]` +- L6 `gripper` mode uses the configured `hands.linkerhand_l6.speed` (default `[50]*6`); O6 `gripper` mode uses `hands.linkerhand_o6.speed` (default `[255]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` - `vr_hand_pose` defaults to a low-latency somehand path: `hands.somehand.rate_hz=60`, `max_iterations=12`, `temporal_filter_alpha=1.0`, and `output_alpha=1.0`; this prioritizes response speed over smoothing -- LinkerHand L6 control is active in sim2real `MOCAP` and `ARMS`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose +- LinkerHand control is active in sim2real `MOCAP` and `ARMS`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose - In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand ### SimulationLoop Runtime Behavior @@ -236,7 +237,7 @@ python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_trac - `assets/robots/unitree_g1/g1_29dof.xml` and its meshes are the canonical G1 robot model assets; they are downloaded from the `robots` asset group and are not tracked 
in Git - `teleopit/retargeting/gmr/assets/` is gitignored; downloaded at runtime - `train_mimic/assets/` is no longer tracked; FK tooling reuses `assets/robots/unitree_g1/g1_29dof.xml` -- `third_party/linkerhand-python-sdk` and `third_party/somehand` support optional LinkerHand L6 sim2real control +- `third_party/linkerhand-python-sdk` and `third_party/somehand` support optional LinkerHand sim2real control - Run `python scripts/check_large_tracked_files.py` before pushing Assets are split across two ModelScope repos by type: diff --git a/README.md b/README.md index 9ce0bdfa..56d6559c 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added Pico sim2real `ARMS` mode: Pico/controller `B` toggles between whole-body `MOCAP` and stand-pose body/legs with live retargeted arms. - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. - Added optional LinkerHand L6 sim2real modes under `hands.*`: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand 0.2.0 public API. +- Added LinkerHand O6 support for Pico `gripper` mode with an O6-specific grasp pose. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. 
diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 3f49c70c..ecb66b58 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -124,16 +124,17 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin ### Dexterous Hand (Pico sim2real) `hands.enabled=true` requires `input.provider=pico4` and the optional `dexhand` -extra. Control is active in `MOCAP` and `ARMS`; inactive modes send the open pose. In -`vr_hand_pose`, missing hand pose holds the last command for that side. -`gripper` uses the configured `hands.linkerhand_l6.speed`; `vr_hand_pose` -always sets LinkerHand L6 speed to the maximum. Teleopit converts Pico hand -state to 21 landmarks and embeds somehand 0.2.0 through `somehand.api` only. +extra. Control is active in `MOCAP` and `ARMS`; inactive modes send the open pose. +`gripper` supports `linkerhand_l6` and `linkerhand_o6` by interpolating Pico +trigger input between the configured open and close poses. `vr_hand_pose` is +L6-only: missing hand pose holds the last command for that side, L6 speed is +set to the maximum, and Teleopit converts Pico hand state to 21 landmarks before +calling somehand 0.2.0 through `somehand.api` only. | Field | Description | Default | |-------|-------------|---------| | `hands.enabled` | Enable optional hand worker | `false` | -| `hands.driver` | Hand driver plugin | `linkerhand_l6` | +| `hands.driver` | Hand driver plugin: `linkerhand_l6` or `linkerhand_o6` | `linkerhand_l6` | | `hands.mode` | `gripper` or `vr_hand_pose` | `gripper` | | `hands.sides` | Controlled sides | `[left, right]` | | `hands.rate_hz` | Maximum gripper command rate in Hz | `30.0` | @@ -141,6 +142,9 @@ state to 21 landmarks and embeds somehand 0.2.0 through `somehand.api` only. 
| `hands.linkerhand_l6.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | | `hands.linkerhand_l6.speed` | L6 speed used by `gripper`; `vr_hand_pose` overrides this to maximum speed | see config | | `hands.linkerhand_l6.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | +| `hands.linkerhand_o6.left_can` / `right_can` | CAN channels for each O6 hand | `can0` / `can1` | +| `hands.linkerhand_o6.speed` | O6 speed used by `gripper` | see config | +| `hands.linkerhand_o6.open_pose` / `close_pose` | Six-value O6 open/closed poses | see config | | `hands.somehand.config_path` | Official somehand 0.2.0 bi-hand L6 config used by `vr_hand_pose` | see config | | `hands.somehand.rate_hz` | Low-latency `vr_hand_pose` command rate in Hz | `60.0` | | `hands.somehand.max_iterations` | somehand solver iteration cap for `vr_hand_pose` | `12` | diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 807d4159..14cf05f7 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -61,9 +61,9 @@ The receiver can run on a workstation PC or the robot onboard computer. See [Pico Sim2Sim](../tutorials/pico-sim2sim) and [Pico Sim2Real](../tutorials/pico-sim2real) for the full setup guides. -Optional LinkerHand L6 control for Pico sim2real is installed through the +Optional LinkerHand control for Pico sim2real is installed through the `dexhand` extra. It includes the LinkerHand SDK submodule and the remote -somehand package used by VR hand-pose mode: +somehand package used by the L6 VR hand-pose mode: ```bash git submodule update --init --recursive diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index c7f7cfd1..ab0f3e2a 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -146,18 +146,19 @@ Resume while standing still and close to the paused pose. 
This reduces sudden reference changes when live tracking resumes. ::: -## Optional LinkerHand L6 Control +## Optional LinkerHand Control -Pico sim2real can drive LinkerHand L6 hands in two modes: +Pico sim2real can drive LinkerHand hands from Pico input: - `gripper`: hold the matching side grip as a deadman switch; the matching - trigger closes that hand. This mode uses the configured - `hands.linkerhand_l6.speed`, which defaults to 50. -- `vr_hand_pose`: retarget Pico hand pose through somehand and command the - continuous L6 hand target. If a hand pose disappears, that side keeps its last - commanded pose. This mode uses Teleopit's Pico landmark adapter and the - public `somehand.api` from somehand 0.2.0. It always sets L6 speed to the - maximum. + trigger closes that hand. This mode supports `hands.driver=linkerhand_l6` and + `hands.driver=linkerhand_o6`; speed and open/close poses come from the matching + driver config. +- `vr_hand_pose`: L6-only mode that retargets Pico hand pose through somehand and + commands the continuous L6 hand target. If a hand pose disappears, that side + keeps its last commanded pose. This mode uses Teleopit's Pico landmark adapter + and the public `somehand.api` from somehand 0.2.0. It always sets L6 speed to + the maximum. Hand control is active in `MOCAP` and `ARMS`. It sends the open pose in `STANDING`, `DAMPING`, paused mocap, and shutdown. @@ -178,7 +179,7 @@ sudo /usr/sbin/ip link set can1 up type can bitrate 1000000 ``` Before enabling full sim2real, verify the hand connection with a standalone -open/close test: +open/close test. 
The test runs until Ctrl-C: ```bash python scripts/dev/test_linkerhand_l6.py \ @@ -187,19 +188,43 @@ python scripts/dev/test_linkerhand_l6.py \ --right-can can1 ``` -Then enable L6 control in Pico sim2real: +For an O6 standalone open/close test, add the O6 driver: + +```bash +python scripts/dev/test_linkerhand_l6.py \ + --driver linkerhand_o6 \ + --hand-type both \ + --left-can can0 \ + --right-can can1 +``` + +To test O6 with live Pico gripper input, add `--mode gripper`. + +Then enable L6 gripper control in Pico sim2real: ```bash hands.enabled=true +hands.driver=linkerhand_l6 hands.mode=gripper hands.linkerhand_l6.left_can=can0 hands.linkerhand_l6.right_can=can1 ``` +For O6 gripper control, use: + +```bash +hands.enabled=true +hands.driver=linkerhand_o6 +hands.mode=gripper +hands.linkerhand_o6.left_can=can0 +hands.linkerhand_o6.right_can=can1 +``` + For continuous VR hand-pose control, use: ```bash hands.enabled=true +hands.driver=linkerhand_l6 hands.mode=vr_hand_pose hands.linkerhand_l6.left_can=can0 hands.linkerhand_l6.right_can=can1 @@ -238,8 +263,9 @@ mocap_switch.check_frames=10 # Change Pico pause button input.pause_button=right_axis_click -# Enable LinkerHand L6 control +# Enable LinkerHand gripper control hands.enabled=true +hands.driver=linkerhand_l6 hands.mode=gripper # Enable headset video preview @@ -255,5 +281,5 @@ input.video.enabled=true | Cannot enter debug mode | Unitree mode release failed | Stop other robot modes and press `Start` again | | Robot enters `STANDING` but not `MOCAP` | Mocap validation failed | Keep tracking active and stable; check `mocap_switch.check_frames` logs | | Pico pause does not return to `STANDING` | Expected behavior | Pico pause freezes mocap; press remote `X` for `STANDING` | -| LinkerHand does not move | `hands.enabled=false`, not in `MOCAP`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Enable `hands.enabled`, enter `MOCAP`, run `scripts/dev/test_linkerhand_l6.py`, and check 
`hands.linkerhand_l6.left_can` / `right_can` | +| LinkerHand does not move | `hands.enabled=false`, not in `MOCAP`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Enable `hands.enabled`, enter `MOCAP`, run `scripts/dev/test_linkerhand_l6.py`, and check the selected driver's `left_can` / `right_can` | | Video preview is unavailable | RealSense or video source failed | Check camera permissions, `input.video.source`, and logs | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index e42000af..3e7e4dc4 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -142,30 +142,30 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` ### 灵巧手(Pico sim2real) -`hands.mode=gripper` 或 `hands.mode=vr_hand_pose` 要求 `input.provider=pico4`, -并安装可选的 `dexhand` extra。控制在 `MOCAP` 和 `ARMS` 中生效;非活动模式会发送张开姿态。 -在 `vr_hand_pose` 中,Teleopit 将 Pico 手部 pose 适配成 somehand 0.2.0 的 -landmark 输入,只调用公开的 `somehand.api`;手部 pose 消失时,对应侧会保持上一条命令。 -`gripper` 使用配置的 `hands.linkerhand_l6.speed`;`vr_hand_pose` 始终将 -LinkerHand L6 速度设为最大值。默认的 `vr_hand_pose` 路径优先降低延时:它会按 -`hands.somehand.rate` 在后台线程运行,并关闭大部分 somehand 输入/输出平滑,因此手指运动可能更抖。 +`hands.enabled=true` 要求 `input.provider=pico4`,并安装可选的 `dexhand` +extra。控制在 `MOCAP` 和 `ARMS` 中生效;非活动模式会发送张开姿态。 +`gripper` 支持 `linkerhand_l6` 和 `linkerhand_o6`,会用 Pico trigger 在配置的张开和闭合姿态之间插值。 +`vr_hand_pose` 只支持 L6:手部 pose 消失时,对应侧会保持上一条命令;L6 速度会设为最大值; +Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 somehand 0.2.0 公开的 `somehand.api` 调用。 | 字段 | 说明 | 默认值 | |---|---|---| | `hands.enabled` | 启用可选手部运行时 | `false` | -| `hands.mode` | `off`、`gripper` 或 `vr_hand_pose` | `off` | -| `hands.driver` | 手部设备驱动;当前支持 `linkerhand_l6` | `linkerhand_l6` | -| 
`hands.linkerhand_l6.hand_type` | 控制侧:`left`、`right` 或 `both`;`vr_hand_pose` 要求 `both` | `both` | +| `hands.mode` | `gripper` 或 `vr_hand_pose` | `gripper` | +| `hands.driver` | 手部设备驱动:`linkerhand_l6` 或 `linkerhand_o6` | `linkerhand_l6` | +| `hands.sides` | 控制侧 | `[left, right]` | +| `hands.rate_hz` | gripper 最大命令频率(Hz) | `30.0` | +| `hands.frame_timeout_s` | 手柄或手部 pose 过期阈值 | `0.3` | | `hands.linkerhand_l6.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | -| `hands.linkerhand_l6.rate` | gripper 最大命令频率(Hz) | `30.0` | -| `hands.linkerhand_l6.frame_timeout` | gripper 手柄超时或 VR 手部 pose 过期阈值 | `0.3` | | `hands.linkerhand_l6.speed` | `gripper` 使用的 L6 速度;`vr_hand_pose` 会覆盖为最大速度 | 见配置 | | `hands.linkerhand_l6.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | | `hands.linkerhand_l6.trigger_deadzone` | trigger 两端死区 | `0.05` | | `hands.linkerhand_l6.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | +| `hands.linkerhand_o6.left_can` / `right_can` | 左右 O6 手 CAN 通道 | `can0` / `can1` | +| `hands.linkerhand_o6.speed` | `gripper` 使用的 O6 速度 | 见配置 | +| `hands.linkerhand_o6.open_pose` / `close_pose` | O6 的 6 维张开/闭合姿态 | 见配置 | | `hands.somehand.config_path` | `vr_hand_pose` 使用的 somehand 双手 L6 配置 | 见配置 | -| `hands.somehand.rate` | 低延时 `vr_hand_pose` 命令频率(Hz) | `60.0` | -| `hands.somehand.threaded` | 在机器人控制循环外运行 `vr_hand_pose` 手部重定向 | `true` | +| `hands.somehand.rate_hz` | 低延时 `vr_hand_pose` 命令频率(Hz) | `60.0` | | `hands.somehand.max_iterations` | `vr_hand_pose` 的 somehand solver 迭代上限 | `12` | | `hands.somehand.temporal_filter_alpha` | somehand 输入 landmarks 平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | | `hands.somehand.output_alpha` | somehand qpos 输出平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 3bc1ac17..6ebaba87 100644 --- 
a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -61,8 +61,8 @@ receiver 可以运行在工作站 PC,也可以运行在机器人 onboard 计 完整设置流程详见 [Pico Sim2Sim](../tutorials/pico-sim2sim) 和 [Pico Sim2Real](../tutorials/pico-sim2real)。 -Pico sim2real 可选的 LinkerHand L6 控制通过 `dexhand` extra 安装。它包含 -LinkerHand SDK submodule,以及 `vr_hand_pose` 模式使用的远程 somehand 包: +Pico sim2real 可选的 LinkerHand 控制通过 `dexhand` extra 安装。它包含 +LinkerHand SDK submodule,以及 L6 `vr_hand_pose` 模式使用的远程 somehand 包: ```bash git submodule update --init --recursive @@ -70,7 +70,7 @@ pip install -e '.[dexhand]' scripts/setup/download_somehand_l6_assets.sh ``` -只有在 `hands.mode=gripper` 或 `hands.mode=vr_hand_pose` 时才需要安装这个 extra。 +只有在 `hands.enabled=true` 时才需要安装这个 extra。 ## 验证安装 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 37573f0c..1425fcc0 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -137,13 +137,14 @@ Pico 暂停/恢复是 mocap-session control event。 恢复时请保持静止,并尽量接近暂停时的姿态。这样可以减少实时追踪恢复时的参考突变。 ::: -## 可选 LinkerHand L6 控制 +## 可选 LinkerHand 控制 -Pico sim2real 可以用两种模式控制 LinkerHand L6: +Pico sim2real 可以用 Pico 输入控制 LinkerHand: - `gripper`:按住同侧 grip 作为 deadman,同侧 trigger 控制对应手闭合。 - 该模式使用配置的 `hands.linkerhand_l6.speed`,默认值为 50。 -- `vr_hand_pose`:通过 somehand 重定向 Pico 手部 pose,并下发连续 L6 手部目标。 + 该模式支持 `hands.driver=linkerhand_l6` 和 `hands.driver=linkerhand_o6`; + 速度和张开/闭合姿态来自对应 driver 配置。 +- `vr_hand_pose`:只支持 L6,通过 somehand 重定向 Pico 手部 pose,并下发连续 L6 手部目标。 如果某侧手部 pose 消失,该侧会保持上一条手势命令。这个模式使用 Teleopit 的 Pico landmark 适配器和 somehand 0.2.0 公开的 `somehand.api`,并始终将 L6 速度设为最大值。默认配置使用 60 Hz 的低延时 somehand 路径并减少平滑,所以响应会更快, @@ -165,7 +166,7 @@ sudo 
/usr/sbin/ip link set can0 up type can bitrate 1000000 sudo /usr/sbin/ip link set can1 up type can bitrate 1000000 ``` -启用完整 sim2real 前,先用独立开合测试验证灵巧手连接: +启用完整 sim2real 前,先用独立开合测试验证灵巧手连接。测试默认一直运行到 Ctrl-C: ```bash python scripts/dev/test_linkerhand_l6.py \ @@ -174,19 +175,43 @@ python scripts/dev/test_linkerhand_l6.py \ --right-can can1 ``` -然后在 Pico sim2real 中启用 L6 控制: +O6 独立开合测试需要加上 O6 driver: + +```bash +python scripts/dev/test_linkerhand_l6.py \ + --driver linkerhand_o6 \ + --hand-type both \ + --left-can can0 \ + --right-can can1 +``` + +如果要用实时 Pico gripper 输入测试 O6,再加 `--mode gripper`。 + +然后在 Pico sim2real 中启用 L6 gripper 控制: ```bash hands.enabled=true +hands.driver=linkerhand_l6 hands.mode=gripper hands.linkerhand_l6.left_can=can0 hands.linkerhand_l6.right_can=can1 ``` +O6 gripper 控制使用: + +```bash +hands.enabled=true +hands.driver=linkerhand_o6 +hands.mode=gripper +hands.linkerhand_o6.left_can=can0 +hands.linkerhand_o6.right_can=can1 +``` + 连续 VR 手部 pose 控制使用: ```bash hands.enabled=true +hands.driver=linkerhand_l6 hands.mode=vr_hand_pose hands.linkerhand_l6.left_can=can0 hands.linkerhand_l6.right_can=can1 @@ -225,8 +250,9 @@ mocap_switch.check_frames=10 # 更换 Pico 暂停键 input.pause_button=right_axis_click -# 开启 LinkerHand L6 控制 +# 开启 LinkerHand gripper 控制 hands.enabled=true +hands.driver=linkerhand_l6 hands.mode=gripper # 开启头显视频预览 @@ -242,5 +268,5 @@ input.video.enabled=true | 无法进入 debug mode | Unitree mode 释放失败 | 停止其他机器人模式后再次按 `Start` | | 机器人进入 `STANDING` 但不进入 `MOCAP` | 动捕验证失败 | 保持追踪稳定,查看 `mocap_switch.check_frames` 日志 | | Pico 暂停没有返回 `STANDING` | 这是预期行为 | Pico 暂停只冻结 mocap;按遥控器 `X` 返回 `STANDING` | -| LinkerHand 不动 | `hands.enabled=false`、模式为 `off`、不在 `MOCAP`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `hands.enabled=true` 和 `hands.mode`,进入 `MOCAP`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查 `hands.linkerhand_l6.left_can` / `right_can` | +| LinkerHand 不动 | `hands.enabled=false`、不在 `MOCAP`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `hands.enabled=true` 和 
`hands.mode`,进入 `MOCAP`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查所选 driver 的 `left_can` / `right_can` | | 视频预览不可用 | RealSense 或视频源失败 | 检查相机权限、`input.video.source` 和日志 | diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py index e3f73826..3b40d58c 100644 --- a/scripts/dev/test_linkerhand_l6.py +++ b/scripts/dev/test_linkerhand_l6.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Exercise LinkerHand L6 dexterous-hand control modes.""" +"""Exercise LinkerHand dexterous-hand control modes.""" from __future__ import annotations @@ -24,35 +24,24 @@ Pico4InputProvider, ) from teleopit.sim2real.hands.linkerhand_l6 import VR_HAND_POSE_SPEED, build_linkerhand_l6 # noqa: E402 +from teleopit.sim2real.hands.linkerhand_o6 import build_linkerhand_o6 # noqa: E402 THUMB_YAW_DEFAULT = 10 OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] +O6_OPEN_POSE = [250, 250, 250, 250, 250, 250] +O6_CLOSE_POSE = [86, 73, 118, 111, 110, 111] +O6_DEFAULT_SPEED = [255, 255, 255, 255, 255, 255] DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" -DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" - - -def uint8(value: str) -> int: - parsed = int(value) - if parsed < 0 or parsed > 255: - raise argparse.ArgumentTypeError("value must be in range 0-255") - return parsed - - -def positive_float(value: str) -> float: - parsed = float(value) - if parsed <= 0.0: - raise argparse.ArgumentTypeError("value must be greater than 0") - return parsed - - -def positive_int(value: str) -> int: - parsed = int(value) - if parsed <= 0: - raise argparse.ArgumentTypeError("value must be greater than 0") - return parsed +OPEN_CLOSE_HOLD_S = 1.0 +GRIPPER_RATE_HZ = 30.0 +VR_HAND_POSE_RATE_HZ = 60.0 +PICO_START_TIMEOUT_S = 60.0 +FRAME_TIMEOUT_S = 0.3 +TRIGGER_DEADZONE = 0.05 +DEADMAN_THRESHOLD = 0.5 def selected_hand_types(hand_type: 
str) -> tuple[str, ...]: @@ -62,7 +51,13 @@ def selected_hand_types(hand_type: str) -> tuple[str, ...]: def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Test LinkerHand L6 dexterous-hand control modes") + parser = argparse.ArgumentParser(description="Test LinkerHand dexterous-hand control modes") + parser.add_argument( + "--driver", + choices=["linkerhand_l6", "linkerhand_o6"], + default="linkerhand_l6", + help="Hand driver to test. O6 currently supports open_close and gripper only.", + ) parser.add_argument( "--mode", choices=["open_close", "gripper", "vr_hand_pose"], @@ -81,85 +76,45 @@ def parse_args() -> argparse.Namespace: default="None", help='RS485 serial port such as /dev/ttyUSB0; "None" uses CAN', ) - parser.add_argument("--cycles", type=positive_int, default=3) - parser.add_argument("--hold-s", type=positive_float, default=1.0) - parser.add_argument( - "--duration-s", - type=positive_float, - default=30.0, - help="Live Pico test duration for gripper/vr_hand_pose modes.", - ) - parser.add_argument("--rate", type=positive_float, default=30.0) - parser.add_argument("--frame-timeout", type=positive_float, default=0.3) - parser.add_argument("--trigger-deadzone", type=float, default=0.05) - parser.add_argument("--deadman-threshold", type=float, default=0.5) - parser.add_argument("--thumb-yaw-center", type=uint8, default=THUMB_YAW_DEFAULT) - parser.add_argument("--print-input", action="store_true") - parser.add_argument( - "--speed", - type=uint8, - nargs=6, - default=DEFAULT_SPEED, - help="L6 speed for open_close and gripper modes. 
vr_hand_pose always uses max speed.", - metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), - ) - parser.add_argument( - "--open-pose", - type=uint8, - nargs=6, - default=OPEN_POSE, - metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), - ) - parser.add_argument( - "--close-pose", - type=uint8, - nargs=6, - default=CLOSE_POSE, - metavar=("THUMB_PITCH", "THUMB_YAW", "INDEX", "MIDDLE", "RING", "LITTLE"), - ) - parser.add_argument("--somehand-config-path", default=DEFAULT_SOMEHAND_CONFIG_PATH) - parser.add_argument("--somehand-sdk-root", default=DEFAULT_LINKERHAND_SDK_ROOT) - parser.add_argument("--bridge-host", default="0.0.0.0") - parser.add_argument("--bridge-port", type=positive_int, default=63901) - parser.add_argument("--bridge-advertise-ip", default=None) - parser.add_argument("--bridge-start-timeout", type=positive_float, default=10.0) - parser.add_argument("--no-bridge-discovery", action="store_true") args = parser.parse_args() - args.open_pose[1] = args.thumb_yaw_center - args.close_pose[1] = args.thumb_yaw_center - if args.trigger_deadzone < 0.0 or args.trigger_deadzone >= 0.5: - raise SystemExit("--trigger-deadzone must be in [0, 0.5)") - if args.deadman_threshold <= 0.0 or args.deadman_threshold >= 1.0: - raise SystemExit("--deadman-threshold must be in (0, 1)") + if args.driver == "linkerhand_o6" and args.mode == "vr_hand_pose": + raise SystemExit("hands.driver=linkerhand_o6 supports only --mode open_close or gripper") + args.speed = list(O6_DEFAULT_SPEED if args.driver == "linkerhand_o6" else DEFAULT_SPEED) + args.open_pose = list(O6_OPEN_POSE if args.driver == "linkerhand_o6" else OPEN_POSE) + args.close_pose = list(O6_CLOSE_POSE if args.driver == "linkerhand_o6" else CLOSE_POSE) return args def make_config(args: argparse.Namespace, *, mode: str) -> dict[str, object]: speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else args.speed + rate_hz = VR_HAND_POSE_RATE_HZ if mode == "vr_hand_pose" else 
GRIPPER_RATE_HZ + driver_section = "linkerhand_o6" if args.driver == "linkerhand_o6" else "linkerhand_l6" + driver_cfg = { + "left_can": args.left_can, + "right_can": args.right_can, + "modbus": args.modbus, + "trigger_deadzone": TRIGGER_DEADZONE, + "deadman_threshold": DEADMAN_THRESHOLD, + "speed": list(speed), + "open_pose": list(args.open_pose), + "close_pose": list(args.close_pose), + "print_input": False, + } + if args.driver == "linkerhand_l6": + driver_cfg["thumb_yaw_center"] = THUMB_YAW_DEFAULT return { "input": {"provider": "pico4"}, "hands": { "enabled": True, - "driver": "linkerhand_l6", + "driver": args.driver, "mode": mode, "sides": list(selected_hand_types(args.hand_type)), - "rate_hz": args.rate, - "frame_timeout_s": args.frame_timeout, - "linkerhand_l6": { - "left_can": args.left_can, - "right_can": args.right_can, - "modbus": args.modbus, - "trigger_deadzone": args.trigger_deadzone, - "deadman_threshold": args.deadman_threshold, - "thumb_yaw_center": args.thumb_yaw_center, - "speed": list(speed), - "open_pose": list(args.open_pose), - "close_pose": list(args.close_pose), - "print_input": args.print_input, - }, + "rate_hz": rate_hz, + "frame_timeout_s": FRAME_TIMEOUT_S, + driver_section: driver_cfg, "somehand": { - "config_path": args.somehand_config_path, - "rate_hz": args.rate, + "config_path": DEFAULT_SOMEHAND_CONFIG_PATH, + "rate_hz": VR_HAND_POSE_RATE_HZ, "max_iterations": 12, "temporal_filter_alpha": 1.0, "output_alpha": 1.0, @@ -168,6 +123,12 @@ def make_config(args: argparse.Namespace, *, mode: str) -> dict[str, object]: } +def build_driver_runtime(config: dict[str, object], *, driver: str): + if driver == "linkerhand_o6": + return build_linkerhand_o6(config) + return build_linkerhand_l6(config) + + def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> None: print(f"{label}: {list(pose)}", flush=True) for hand_type, hand in hands.items(): @@ -175,15 +136,15 @@ def send_all(hands: dict[str, object], pose: 
Sequence[int], *, label: str) -> No hand.finger_move(pose=list(pose)) -def make_pico_provider(args: argparse.Namespace) -> Pico4InputProvider: +def make_pico_provider() -> Pico4InputProvider: return Pico4InputProvider( - timeout=args.duration_s, + timeout=PICO_START_TIMEOUT_S, pause_button=None, - bridge_host=args.bridge_host, - bridge_port=args.bridge_port, - bridge_discovery=not args.no_bridge_discovery, - bridge_advertise_ip=args.bridge_advertise_ip, - bridge_start_timeout=args.bridge_start_timeout, + bridge_host="0.0.0.0", + bridge_port=63901, + bridge_discovery=True, + bridge_advertise_ip=None, + bridge_start_timeout=10.0, bridge_video=None, bridge_video_enabled=False, ) @@ -193,14 +154,12 @@ def run_live_until_done( runtime: object, *, provider: Pico4InputProvider, - duration_s: float, mode_label: str, rate_hz: float, ) -> None: - deadline = time.monotonic() + duration_s last_seq: int | None = None - print(f"Running {mode_label} for {duration_s:.1f}s; press Ctrl-C to stop early.", flush=True) - while time.monotonic() < deadline: + print(f"Running {mode_label}; press Ctrl-C to stop.", flush=True) + while True: now_s = time.monotonic() controller_snapshot = provider.get_controller_snapshot() hand_snapshot = provider.get_hand_snapshot() @@ -233,7 +192,7 @@ def run_open_close(args: argparse.Namespace) -> None: hands: dict[str, object] = {} print( - "Testing LinkerHand L6 | " + f"Testing {args.driver} | " f"hands={','.join(hand_types)} | " f"can={','.join(f'{hand}:{can_channels[hand]}' for hand in hand_types)} | " f"modbus={args.modbus}", @@ -242,7 +201,7 @@ def run_open_close(args: argparse.Namespace) -> None: try: for hand_type in hand_types: hand = LinkerHandApi( - hand_joint="L6", + hand_joint="O6" if args.driver == "linkerhand_o6" else "L6", hand_type=hand_type, modbus=args.modbus, can=can_channels[hand_type], @@ -251,13 +210,15 @@ def run_open_close(args: argparse.Namespace) -> None: hands[hand_type] = hand send_all(hands, args.open_pose, label="startup 
open") - time.sleep(args.hold_s) - for cycle in range(args.cycles): - print(f"cycle {cycle + 1}/{args.cycles}", flush=True) + time.sleep(OPEN_CLOSE_HOLD_S) + cycle = 0 + while True: + cycle += 1 + print(f"cycle {cycle}", flush=True) send_all(hands, args.close_pose, label="close") - time.sleep(args.hold_s) + time.sleep(OPEN_CLOSE_HOLD_S) send_all(hands, args.open_pose, label="open") - time.sleep(args.hold_s) + time.sleep(OPEN_CLOSE_HOLD_S) except KeyboardInterrupt: print("Interrupted; opening hands before exit", flush=True) finally: @@ -272,8 +233,8 @@ def run_open_close(args: argparse.Namespace) -> None: def run_gripper(args: argparse.Namespace) -> None: config = make_config(args, mode="gripper") - provider = make_pico_provider(args) - device, mapper = build_linkerhand_l6(config) + provider = make_pico_provider() + device, mapper = build_driver_runtime(config, driver=args.driver) from teleopit.sim2real.hands.worker import HandRuntime runtime = HandRuntime(device, mapper) @@ -284,7 +245,12 @@ def run_gripper(args: argparse.Namespace) -> None: ) try: runtime.start() - run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="gripper", rate_hz=args.rate) + run_live_until_done( + runtime, + provider=provider, + mode_label="gripper", + rate_hz=GRIPPER_RATE_HZ, + ) except KeyboardInterrupt: print("Interrupted; opening hands before exit", flush=True) finally: @@ -298,8 +264,8 @@ def run_vr_hand_pose(args: argparse.Namespace) -> None: raise SystemExit("hands.mode=vr_hand_pose currently requires --hand-type both") config = make_config(args, mode="vr_hand_pose") - provider = make_pico_provider(args) - device, mapper = build_linkerhand_l6(config) + provider = make_pico_provider() + device, mapper = build_driver_runtime(config, driver=args.driver) from teleopit.sim2real.hands.worker import HandRuntime runtime = HandRuntime(device, mapper) @@ -310,7 +276,12 @@ def run_vr_hand_pose(args: argparse.Namespace) -> None: ) try: runtime.start() - 
run_live_until_done(runtime, provider=provider, duration_s=args.duration_s, mode_label="vr_hand_pose", rate_hz=args.rate) + run_live_until_done( + runtime, + provider=provider, + mode_label="vr_hand_pose", + rate_hz=VR_HAND_POSE_RATE_HZ, + ) except KeyboardInterrupt: print("Interrupted; opening hands before exit", flush=True) finally: diff --git a/scripts/run/run_sim2real.py b/scripts/run/run_sim2real.py index 080a7682..f36d6673 100644 --- a/scripts/run/run_sim2real.py +++ b/scripts/run/run_sim2real.py @@ -19,7 +19,7 @@ def _print_sim2real_controls(cfg: DictConfig) -> None: if provider == "pico4": print(" Mocap pause/resume: Pico/controller A.") print(" Arm-only mode: Pico/controller B toggles MOCAP <-> ARMS.") - print(" Dexterous hand: hands.enabled=true hands.mode=gripper|vr_hand_pose.") + print(" Dexterous hand: hands.enabled=true hands.driver=linkerhand_l6|linkerhand_o6 hands.mode=gripper|vr_hand_pose.") print(" State flow: IDLE -> STANDING -> MOCAP <-> ARMS, X -> STANDING, Any -> DAMPING.") else: print(" Offline playback: A pause/resume, B replay from start.") diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 4ff0f072..2cb15cb5 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -44,10 +44,10 @@ joint_vel_limit: 10.0 arm_mocap: controlled_joint_indices: [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] -# Optional LinkerHand L6 control from Pico controller grip/trigger or VR hand pose. +# Optional LinkerHand control from Pico controller grip/trigger or VR hand pose. 
hands: enabled: false - driver: linkerhand_l6 + driver: linkerhand_l6 # linkerhand_l6 | linkerhand_o6 mode: gripper # gripper | vr_hand_pose sides: [left, right] rate_hz: 30.0 @@ -63,6 +63,16 @@ hands: open_pose: [250, 10, 250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] print_input: false + linkerhand_o6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + speed: [255, 255, 255, 255, 255, 255] + open_pose: [250, 250, 250, 250, 250, 250] + close_pose: [86, 73, 118, 111, 110, 111] + print_input: false somehand: config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml # Low-latency vr_hand_pose path. This favors response speed over smoothing. diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 3a84fbd9..644eb446 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -40,10 +40,10 @@ standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 -# Optional LinkerHand L6 control. Use only with input.provider=pico4. +# Optional LinkerHand control. Use only with input.provider=pico4. hands: enabled: false - driver: linkerhand_l6 + driver: linkerhand_l6 # linkerhand_l6 | linkerhand_o6 mode: gripper # gripper | vr_hand_pose sides: [left, right] rate_hz: 30.0 @@ -59,6 +59,16 @@ hands: open_pose: [250, 10, 250, 250, 250, 250] close_pose: [79, 10, 0, 0, 0, 0] print_input: false + linkerhand_o6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + speed: [255, 255, 255, 255, 255, 255] + open_pose: [250, 250, 250, 250, 250, 250] + close_pose: [86, 73, 118, 111, 110, 111] + print_input: false somehand: config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml # Low-latency vr_hand_pose path. This favors response speed over smoothing. 
diff --git a/teleopit/sim2real/hands/linkerhand_l6.py b/teleopit/sim2real/hands/linkerhand_l6.py index 52c11f86..d7ed00be 100644 --- a/teleopit/sim2real/hands/linkerhand_l6.py +++ b/teleopit/sim2real/hands/linkerhand_l6.py @@ -48,6 +48,7 @@ class LinkerHandL6Config: speed: tuple[int, ...] open_pose: tuple[int, ...] close_pose: tuple[int, ...] + fixed_thumb_yaw: int | None print_input: bool somehand_config_path: str somehand_rate_hz: float @@ -86,6 +87,7 @@ def parse_linkerhand_l6_config(cfg: Any) -> LinkerHandL6Config: speed=tuple(speed), open_pose=tuple(open_pose), close_pose=tuple(close_pose), + fixed_thumb_yaw=thumb_yaw, print_input=bool(cfg_get(l6_cfg, "print_input", False)), somehand_config_path=str(cfg_get(somehand_cfg, "config_path", DEFAULT_SOMEHAND_CONFIG)), somehand_rate_hz=_positive_float(cfg_get(somehand_cfg, "rate_hz", cfg_get(somehand_cfg, "rate", 60.0)), "somehand.rate_hz"), @@ -155,8 +157,9 @@ def close(self) -> None: class GripperMapper(HandInputMapper): - def __init__(self, config: LinkerHandL6Config): + def __init__(self, config: Any): self.config = config + self._fixed_thumb_yaw = getattr(config, "fixed_thumb_yaw", getattr(config, "thumb_yaw_center", None)) self._active = False self._next_tick_s = 0.0 @@ -192,7 +195,7 @@ def map(self, *, controller_snapshot: object | None, hand_snapshot: object | Non open_pose=self.config.open_pose, close_pose=self.config.close_pose, deadzone=self.config.trigger_deadzone, - thumb_yaw_default=self.config.thumb_yaw_center, + fixed_thumb_yaw=self._fixed_thumb_yaw, ) commands.append(HandPoseCommand(side, tuple(pose), False, "controller")) return tuple(commands) @@ -311,10 +314,21 @@ def build_linkerhand_l6(cfg: Any) -> tuple[HandDevice, HandInputMapper]: return device, mapper -def trigger_to_pose(trigger: float, *, open_pose: Sequence[int], close_pose: Sequence[int], deadzone: float, thumb_yaw_default: int) -> list[int]: +def trigger_to_pose( + trigger: float, + *, + open_pose: Sequence[int], + close_pose: 
Sequence[int], + deadzone: float, + fixed_thumb_yaw: int | None = None, + thumb_yaw_default: int | None = None, +) -> list[int]: + if fixed_thumb_yaw is None: + fixed_thumb_yaw = thumb_yaw_default alpha = _normalize_trigger(trigger, deadzone) pose = [int(round(float(a) + alpha * (float(b) - float(a)))) for a, b in zip(open_pose, close_pose)] - pose[1] = int(thumb_yaw_default) + if fixed_thumb_yaw is not None: + pose[1] = int(fixed_thumb_yaw) return pose diff --git a/teleopit/sim2real/hands/linkerhand_o6.py b/teleopit/sim2real/hands/linkerhand_o6.py new file mode 100644 index 00000000..6d4de349 --- /dev/null +++ b/teleopit/sim2real/hands/linkerhand_o6.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from dataclasses import dataclass +import logging +from typing import Any, Sequence + +from teleopit.runtime.common import cfg_get +from teleopit.sim2real.hands.base import HAND_SIDES, HandDevice, HandInputMapper +from teleopit.sim2real.hands.linkerhand_l6 import GripperMapper + +logger = logging.getLogger(__name__) + +OPEN_POSE = (250, 250, 250, 250, 250, 250) +CLOSE_POSE = (86, 73, 118, 111, 110, 111) +DEFAULT_SPEED = (255, 255, 255, 255, 255, 255) + + +@dataclass(frozen=True) +class LinkerHandO6Config: + mode: str + sides: tuple[str, ...] + left_can: str + right_can: str + modbus: str + rate_hz: float + frame_timeout_s: float + trigger_deadzone: float + deadman_threshold: float + speed: tuple[int, ...] + open_pose: tuple[int, ...] + close_pose: tuple[int, ...] 
+ fixed_thumb_yaw: int | None + print_input: bool + + +def parse_linkerhand_o6_config(cfg: Any) -> LinkerHandO6Config: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + o6_cfg = cfg_get(hands_cfg, "linkerhand_o6", {}) or {} + mode = str(cfg_get(hands_cfg, "mode", "gripper")).strip().lower() + if mode != "gripper": + raise ValueError(f"hands.driver=linkerhand_o6 supports only hands.mode=gripper, got {mode!r}") + sides = tuple(str(side).strip().lower() for side in cfg_get(hands_cfg, "sides", HAND_SIDES)) + if not sides or any(side not in HAND_SIDES for side in sides): + raise ValueError("hands.sides must contain left, right, or both sides") + return LinkerHandO6Config( + mode=mode, + sides=sides, + left_can=str(cfg_get(o6_cfg, "left_can", "can0")), + right_can=str(cfg_get(o6_cfg, "right_can", "can1")), + modbus=str(cfg_get(o6_cfg, "modbus", "None")), + rate_hz=_positive_float(cfg_get(hands_cfg, "rate_hz", cfg_get(o6_cfg, "rate_hz", 30.0)), "rate_hz"), + frame_timeout_s=_positive_float(cfg_get(hands_cfg, "frame_timeout_s", 0.3), "frame_timeout_s"), + trigger_deadzone=_deadzone(cfg_get(o6_cfg, "trigger_deadzone", 0.05)), + deadman_threshold=_threshold(cfg_get(o6_cfg, "deadman_threshold", 0.5)), + speed=tuple(_pose_values(cfg_get(o6_cfg, "speed", DEFAULT_SPEED), "speed")), + open_pose=tuple(_pose_values(cfg_get(o6_cfg, "open_pose", OPEN_POSE), "open_pose")), + close_pose=tuple(_pose_values(cfg_get(o6_cfg, "close_pose", CLOSE_POSE), "close_pose")), + fixed_thumb_yaw=None, + print_input=bool(cfg_get(o6_cfg, "print_input", False)), + ) + + +class LinkerHandO6Device(HandDevice): + def __init__(self, config: LinkerHandO6Config): + self.config = config + self._hands: dict[str, Any] = {} + self._last_pose: dict[str, tuple[int, ...] 
| None] = {side: None for side in config.sides} + + def connect(self) -> None: + try: + from LinkerHand.linker_hand_api import LinkerHandApi + except ImportError as exc: + raise ImportError( + "LinkerHand SDK is required for hands.driver=linkerhand_o6. " + "Install it with: pip install -e third_party/linkerhand-python-sdk" + ) from exc + try: + for side in self.config.sides: + hand = LinkerHandApi( + hand_joint="O6", + hand_type=side, + modbus=self.config.modbus, + can=self.config.left_can if side == "left" else self.config.right_can, + ) + hand.set_speed(speed=list(self.config.speed)) + self._hands[side] = hand + except (Exception, SystemExit) as exc: + self.close() + if isinstance(exc, SystemExit): + raise RuntimeError("LinkerHand SDK exited during startup") from exc + raise + self.open_all(force=True, reason="startup") + + def send_pose(self, side: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + del reason + next_pose = tuple(_uint8(value, f"{side}.pose") for value in pose) + if not force and self._last_pose.get(side) == next_pose: + return + hand = self._hands.get(side) + if hand is None: + return + hand.finger_move(pose=list(next_pose)) + self._last_pose[side] = next_pose + + def open_all(self, *, force: bool = False, reason: str = "") -> None: + for side in self.config.sides: + self.send_pose(side, self.config.open_pose, force=force, reason=reason) + + def close(self) -> None: + try: + self.open_all(force=True, reason="shutdown") + except Exception: + logger.exception("Failed to open LinkerHand O6 on shutdown") + for hand in self._hands.values(): + inner = getattr(hand, "hand", None) + close_backend = getattr(inner, "close_can_interface", None) + if callable(close_backend): + close_backend() + continue + close = getattr(inner, "close", None) + if callable(close): + close() + self._hands.clear() + + +def build_linkerhand_o6(cfg: Any) -> tuple[HandDevice, HandInputMapper]: + config = parse_linkerhand_o6_config(cfg) + return 
LinkerHandO6Device(config), GripperMapper(config) + + +def _uint8(value: object, field_name: str) -> int: + parsed = int(value) + if parsed < 0 or parsed > 255: + raise ValueError(f"hands.linkerhand_o6.{field_name} must be in 0-255, got {value!r}") + return parsed + + +def _pose_values(value: object, field_name: str) -> list[int]: + parsed = [_uint8(item, field_name) for item in value] # type: ignore[union-attr] + if len(parsed) != 6: + raise ValueError(f"hands.linkerhand_o6.{field_name} must contain 6 values") + return parsed + + +def _positive_float(value: object, field_name: str) -> float: + parsed = float(value) + if parsed <= 0: + raise ValueError(f"hands.{field_name} must be > 0") + return parsed + + +def _deadzone(value: object) -> float: + parsed = float(value) + if parsed < 0.0 or parsed >= 0.5: + raise ValueError("hands.linkerhand_o6.trigger_deadzone must be in [0, 0.5)") + return parsed + + +def _threshold(value: object) -> float: + parsed = float(value) + if parsed <= 0.0 or parsed >= 1.0: + raise ValueError("hands.linkerhand_o6.deadman_threshold must be in (0, 1)") + return parsed diff --git a/teleopit/sim2real/hands/worker.py b/teleopit/sim2real/hands/worker.py index c1a9cdb5..b969d851 100644 --- a/teleopit/sim2real/hands/worker.py +++ b/teleopit/sim2real/hands/worker.py @@ -7,6 +7,7 @@ from teleopit.runtime.common import cfg_get from teleopit.sim2real.hands.base import HandDevice, HandInputMapper from teleopit.sim2real.hands.linkerhand_l6 import build_linkerhand_l6 +from teleopit.sim2real.hands.linkerhand_o6 import build_linkerhand_o6 logger = logging.getLogger(__name__) @@ -73,7 +74,10 @@ def build_hand_runtime(cfg: Any) -> HandRuntime | DisabledHandRuntime: if not bool(cfg_get(hands_cfg, "enabled", False)): return DisabledHandRuntime() driver = str(cfg_get(hands_cfg, "driver", "linkerhand_l6")).strip().lower() - if driver != "linkerhand_l6": - raise ValueError(f"Unsupported hands.driver={driver!r}; only linkerhand_l6 is implemented") - device, 
mapper = build_linkerhand_l6(cfg) + if driver == "linkerhand_l6": + device, mapper = build_linkerhand_l6(cfg) + elif driver == "linkerhand_o6": + device, mapper = build_linkerhand_o6(cfg) + else: + raise ValueError(f"Unsupported hands.driver={driver!r}; supported drivers: linkerhand_l6, linkerhand_o6") return HandRuntime(device, mapper) diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 0e65b644..272fa6a4 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -14,6 +14,11 @@ parse_linkerhand_l6_config, trigger_to_pose, ) +from teleopit.sim2real.hands.linkerhand_o6 import ( + CLOSE_POSE as O6_CLOSE_POSE, + LinkerHandO6Device, + parse_linkerhand_o6_config, +) from teleopit.sim2real.hands.pico_landmarks import pico_hand_to_landmarks from teleopit.sim2real.hands.worker import HandRuntime @@ -21,10 +26,14 @@ class FakeInnerHand: def __init__(self) -> None: self.close_calls = 0 + self.close_can_interface_calls = 0 def close(self) -> None: self.close_calls += 1 + def close_can_interface(self) -> None: + self.close_can_interface_calls += 1 + class FakeLinkerHandApi: instances: list["FakeLinkerHandApi"] = [] @@ -37,6 +46,7 @@ def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> self.hand = FakeInnerHand() self.speed: list[int] | None = None self.poses: list[list[int]] = [] + self.close_can_calls = 0 FakeLinkerHandApi.instances.append(self) def set_speed(self, speed: list[int]) -> None: @@ -45,6 +55,9 @@ def set_speed(self, speed: list[int]) -> None: def finger_move(self, pose: list[int]) -> None: self.poses.append(list(pose)) + def close_can(self) -> None: + self.close_can_calls += 1 + def _cfg(mode: str = "gripper") -> dict[str, object]: return { @@ -75,6 +88,27 @@ def _cfg(mode: str = "gripper") -> dict[str, object]: } +def _o6_cfg(mode: str = "gripper") -> dict[str, object]: + return { + "input": {"provider": "pico4"}, + "hands": { + "enabled": True, + "driver": "linkerhand_o6", + "mode": 
mode, + "sides": ["left", "right"], + "rate_hz": 30.0, + "frame_timeout_s": 0.3, + "linkerhand_o6": { + "left_can": "can0", + "right_can": "can1", + "modbus": "None", + "trigger_deadzone": 0.05, + "deadman_threshold": 0.5, + }, + }, + } + + def test_pico_hand_to_landmarks_uses_teleopit_adapter() -> None: joints = np.zeros((26, 7), dtype=np.float64) joints[:, 0] = np.arange(26) @@ -140,6 +174,15 @@ def test_trigger_to_pose_applies_deadzone_and_fixed_thumb_yaw() -> None: ) == [164, 10, 125, 125, 125, 125] +def test_trigger_to_pose_can_interpolate_thumb_yaw_for_o6() -> None: + assert trigger_to_pose( + 1.0, + open_pose=[250, 250, 250, 250, 250, 250], + close_pose=[86, 73, 118, 111, 110, 111], + deadzone=0.05, + ) == list(O6_CLOSE_POSE) + + def test_linkerhand_l6_device_starts_sdk(monkeypatch) -> None: FakeLinkerHandApi.instances = [] monkeypatch.setitem( @@ -160,6 +203,51 @@ def test_linkerhand_l6_device_starts_sdk(monkeypatch) -> None: assert [hand.hand.close_calls for hand in FakeLinkerHandApi.instances] == [1, 1] +def test_linkerhand_o6_gripper_defaults_to_reference_grasp_pose() -> None: + cfg = parse_linkerhand_o6_config(_o6_cfg()) + mapper = GripperMapper(cfg) + snapshot = PicoControllerSnapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + right=PicoControllerState(raw=True, grip=0.1, trigger=1.0, present=True), + timestamp_s=10.0, + seq=1, + ) + + commands = mapper.map(controller_snapshot=snapshot, hand_snapshot=None, active=True, now_s=10.0) + + assert cfg.close_pose == O6_CLOSE_POSE + assert commands[0].pose == O6_CLOSE_POSE + assert commands[1].pose == cfg.open_pose + + +def test_linkerhand_o6_device_starts_sdk(monkeypatch) -> None: + FakeLinkerHandApi.instances = [] + monkeypatch.setitem( + sys.modules, + "LinkerHand.linker_hand_api", + SimpleNamespace(LinkerHandApi=FakeLinkerHandApi), + ) + cfg = parse_linkerhand_o6_config(_o6_cfg()) + device = LinkerHandO6Device(cfg) + + device.connect() + device.send_pose("left", 
cfg.close_pose) + device.close() + + assert [hand.hand_joint for hand in FakeLinkerHandApi.instances] == ["O6", "O6"] + assert [hand.can for hand in FakeLinkerHandApi.instances] == ["can0", "can1"] + assert FakeLinkerHandApi.instances[0].speed == [255, 255, 255, 255, 255, 255] + assert FakeLinkerHandApi.instances[0].poses[-2] == list(O6_CLOSE_POSE) + assert [hand.close_can_calls for hand in FakeLinkerHandApi.instances] == [0, 0] + assert [hand.hand.close_can_interface_calls for hand in FakeLinkerHandApi.instances] == [1, 1] + assert [hand.hand.close_calls for hand in FakeLinkerHandApi.instances] == [0, 0] + + +def test_linkerhand_o6_rejects_vr_hand_pose() -> None: + with pytest.raises(ValueError, match="supports only hands.mode=gripper"): + parse_linkerhand_o6_config(_o6_cfg(mode="vr_hand_pose")) + + def test_hand_runtime_closes_device_when_mapper_start_fails() -> None: calls: list[str] = [] From e74262aef3d04db5aad83c6acc587e32b17b6b01 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 18 Jun 2026 14:49:02 +0800 Subject: [PATCH 100/122] Fix Pico G1 IK body targets --- .../gmr/ik_configs/pico_bridge_to_g1.json | 10 +++--- teleopit/retargeting/gmr/motion_retarget.py | 11 +++++++ tests/test_retargeting.py | 31 ++++++++++++++++++- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json b/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json index 673bbc1e..d114a7f3 100644 --- a/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json +++ b/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json @@ -25,10 +25,10 @@ "pelvis": ["Pelvis", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_hip_yaw_link": ["Left_Hip", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_knee_link": ["Left_Knee", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "left_toe_link": ["Left_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "left_ankle_roll_link": ["Left_Foot", 100, 10, [0.05, 0.0, 
0.0], [-0.5, 0.5, -0.5, -0.5]], "right_hip_yaw_link": ["Right_Hip", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "right_knee_link": ["Right_Knee", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "right_toe_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "right_ankle_roll_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "torso_link": ["Spine3", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_shoulder_yaw_link": ["Left_Shoulder", 0, 10, [0.0, 0.0, 0.0], [0.7071067811865475, 0.0, 0.7071067811865475, 0.0]], "left_elbow_link": ["Left_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], @@ -41,10 +41,10 @@ "pelvis": ["Pelvis", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_hip_yaw_link": ["Left_Hip", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_knee_link": ["Left_Knee", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "left_toe_link": ["Left_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "left_ankle_roll_link": ["Left_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "right_hip_yaw_link": ["Right_Hip", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "right_knee_link": ["Right_Knee", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "right_toe_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "right_ankle_roll_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "torso_link": ["Spine3", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_shoulder_yaw_link": ["Left_Shoulder", 0, 10, [0.0, 0.0, 0.0], [0.7071067811865475, 0.0, 0.7071067811865475, 0.0]], "left_elbow_link": ["Left_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], @@ -53,4 +53,4 @@ "right_elbow_link": ["Right_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], "right_wrist_yaw_link": ["Right_Wrist", 0, 10, [0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] } -} \ No newline at end of file +} diff --git 
a/teleopit/retargeting/gmr/motion_retarget.py b/teleopit/retargeting/gmr/motion_retarget.py index 28508e13..8a2f39da 100644 --- a/teleopit/retargeting/gmr/motion_retarget.py +++ b/teleopit/retargeting/gmr/motion_retarget.py @@ -107,6 +107,15 @@ def __init__( self._warmup_max_iter = 200 self._warmup_dt = 0.1 # large integration step for fast convergence during warmup + def _validate_body_frame(self, frame_name, table_name): + if mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_BODY, frame_name) >= 0: + return + available = ", ".join(self.robot_body_names.keys()) + raise ValueError( + f"IK config {table_name} references body '{frame_name}', but it does not exist " + f"in robot model '{self.xml_file}'. Update the IK config to use one of: {available}" + ) + def reset_configuration(self): """Reset the IK configuration to the model's default qpos. @@ -130,6 +139,7 @@ def setup_retarget_configuration(self): for frame_name, entry in self.ik_match_table1.items(): body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry if pos_weight != 0 or rot_weight != 0: + self._validate_body_frame(frame_name, "ik_match_table1") task = mink.FrameTask( frame_name=frame_name, frame_type="body", @@ -148,6 +158,7 @@ def setup_retarget_configuration(self): for frame_name, entry in self.ik_match_table2.items(): body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry if pos_weight != 0 or rot_weight != 0: + self._validate_body_frame(frame_name, "ik_match_table2") task = mink.FrameTask( frame_name=frame_name, frame_type="body", diff --git a/tests/test_retargeting.py b/tests/test_retargeting.py index 0041213f..8baf3e1f 100644 --- a/tests/test_retargeting.py +++ b/tests/test_retargeting.py @@ -5,12 +5,14 @@ the RetargetingModule.retarget output contract via a mock GMR, and the extract_mimic_obs helper. 
""" +import json +from pathlib import Path from unittest.mock import MagicMock import numpy as np import pytest -from conftest import requires_mink, requires_mujoco +from conftest import find_g1_xml_path, requires_mink, requires_mujoco class TestRetargetingModuleImport: @@ -123,3 +125,30 @@ def test_init_with_invalid_robot_raises(self): robot_name="nonexistent_robot_xyz", human_format="nonexistent_format", ) + + def test_pico_bridge_g1_ik_frames_exist_in_canonical_xml(self): + import mujoco + + from teleopit.retargeting.gmr.params import IK_CONFIG_DICT, ROBOT_XML_DICT + + robot_name = "unitree_g1" + xml_path = find_g1_xml_path() + if xml_path is None: + pytest.skip("G1 robot XML asset not available; run scripts/setup/download_assets.py --only robots gmr") + + assert Path(xml_path).resolve() == Path(ROBOT_XML_DICT[robot_name]).resolve() + model = mujoco.MjModel.from_xml_path(xml_path) + with open(IK_CONFIG_DICT["pico_bridge"][robot_name], encoding="utf-8") as f: + ik_config = json.load(f) + + missing = [] + for table_name in ("ik_match_table1", "ik_match_table2"): + for frame_name, entry in ik_config[table_name].items(): + _, pos_weight, rot_weight, _, _ = entry + if pos_weight == 0 and rot_weight == 0: + continue + frame_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, frame_name) + if frame_id < 0: + missing.append(f"{table_name}:{frame_name}") + + assert missing == [] From 4394146cd6f21177b882d1186f0d5b3b9bc1b655 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 18 Jun 2026 15:22:29 +0800 Subject: [PATCH 101/122] Support site frames in GMR IK config --- .../gmr/ik_configs/pico_bridge_to_g1.json | 8 +-- teleopit/retargeting/gmr/motion_retarget.py | 59 +++++++++++++++---- tests/test_retargeting.py | 24 +++++++- 3 files changed, 74 insertions(+), 17 deletions(-) diff --git a/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json b/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json index d114a7f3..75568756 100644 --- 
a/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json +++ b/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json @@ -25,10 +25,10 @@ "pelvis": ["Pelvis", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_hip_yaw_link": ["Left_Hip", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_knee_link": ["Left_Knee", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "left_ankle_roll_link": ["Left_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "left_foot": ["Left_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, -0.5], "site"], "right_hip_yaw_link": ["Right_Hip", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "right_knee_link": ["Right_Knee", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "right_ankle_roll_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "right_foot": ["Right_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, -0.5], "site"], "torso_link": ["Spine3", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_shoulder_yaw_link": ["Left_Shoulder", 0, 10, [0.0, 0.0, 0.0], [0.7071067811865475, 0.0, 0.7071067811865475, 0.0]], "left_elbow_link": ["Left_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], @@ -41,10 +41,10 @@ "pelvis": ["Pelvis", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_hip_yaw_link": ["Left_Hip", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_knee_link": ["Left_Knee", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "left_ankle_roll_link": ["Left_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "left_foot": ["Left_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, -0.5], "site"], "right_hip_yaw_link": ["Right_Hip", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "right_knee_link": ["Right_Knee", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "right_ankle_roll_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "right_foot": ["Right_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, 
-0.5], "site"], "torso_link": ["Spine3", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_shoulder_yaw_link": ["Left_Shoulder", 0, 10, [0.0, 0.0, 0.0], [0.7071067811865475, 0.0, 0.7071067811865475, 0.0]], "left_elbow_link": ["Left_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], diff --git a/teleopit/retargeting/gmr/motion_retarget.py b/teleopit/retargeting/gmr/motion_retarget.py index 8a2f39da..8b0d6637 100644 --- a/teleopit/retargeting/gmr/motion_retarget.py +++ b/teleopit/retargeting/gmr/motion_retarget.py @@ -7,6 +7,13 @@ from .params import ROBOT_XML_DICT, IK_CONFIG_DICT from rich import print +_FRAME_TYPE_TO_MJ_OBJ = { + "body": mj.mjtObj.mjOBJ_BODY, + "geom": mj.mjtObj.mjOBJ_GEOM, + "site": mj.mjtObj.mjOBJ_SITE, +} + + class GeneralMotionRetargeting: """General Motion Retargeting (GMR). """ @@ -107,12 +114,44 @@ def __init__( self._warmup_max_iter = 200 self._warmup_dt = 0.1 # large integration step for fast convergence during warmup - def _validate_body_frame(self, frame_name, table_name): - if mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_BODY, frame_name) >= 0: + def _parse_ik_entry(self, entry): + if len(entry) == 5: + body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry + frame_type = "body" + elif len(entry) == 6: + body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type = entry + else: + raise ValueError( + "IK config entries must be [human_body, pos_weight, rot_weight, " + "pos_offset, rot_offset] or the same list plus frame_type" + ) + frame_type = str(frame_type) + if frame_type not in _FRAME_TYPE_TO_MJ_OBJ: + supported = ", ".join(sorted(_FRAME_TYPE_TO_MJ_OBJ)) + raise ValueError(f"Unsupported IK frame_type '{frame_type}'. 
Supported values: {supported}") + return body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type + + def _available_frame_names(self, frame_type): + obj_type = _FRAME_TYPE_TO_MJ_OBJ[frame_type] + count = { + "body": self.model.nbody, + "geom": self.model.ngeom, + "site": self.model.nsite, + }[frame_type] + names = [] + for idx in range(count): + name = mj.mj_id2name(self.model, obj_type, idx) + if name: + names.append(name) + return names + + def _validate_frame(self, frame_name, frame_type, table_name): + obj_type = _FRAME_TYPE_TO_MJ_OBJ[frame_type] + if mj.mj_name2id(self.model, obj_type, frame_name) >= 0: return - available = ", ".join(self.robot_body_names.keys()) + available = ", ".join(self._available_frame_names(frame_type)) raise ValueError( - f"IK config {table_name} references body '{frame_name}', but it does not exist " + f"IK config {table_name} references {frame_type} '{frame_name}', but it does not exist " f"in robot model '{self.xml_file}'. Update the IK config to use one of: {available}" ) @@ -137,12 +176,12 @@ def setup_retarget_configuration(self): self.tasks2 = [] for frame_name, entry in self.ik_match_table1.items(): - body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry + body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type = self._parse_ik_entry(entry) if pos_weight != 0 or rot_weight != 0: - self._validate_body_frame(frame_name, "ik_match_table1") + self._validate_frame(frame_name, frame_type, "ik_match_table1") task = mink.FrameTask( frame_name=frame_name, - frame_type="body", + frame_type=frame_type, position_cost=pos_weight, orientation_cost=rot_weight, lm_damping=1, @@ -156,12 +195,12 @@ def setup_retarget_configuration(self): self.task_errors1[task] = [] for frame_name, entry in self.ik_match_table2.items(): - body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry + body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type = self._parse_ik_entry(entry) if pos_weight != 0 
or rot_weight != 0: - self._validate_body_frame(frame_name, "ik_match_table2") + self._validate_frame(frame_name, frame_type, "ik_match_table2") task = mink.FrameTask( frame_name=frame_name, - frame_type="body", + frame_type=frame_type, position_cost=pos_weight, orientation_cost=rot_weight, lm_damping=1, diff --git a/tests/test_retargeting.py b/tests/test_retargeting.py index 8baf3e1f..3d0cd8a1 100644 --- a/tests/test_retargeting.py +++ b/tests/test_retargeting.py @@ -141,14 +141,32 @@ def test_pico_bridge_g1_ik_frames_exist_in_canonical_xml(self): with open(IK_CONFIG_DICT["pico_bridge"][robot_name], encoding="utf-8") as f: ik_config = json.load(f) + object_types = { + "body": mujoco.mjtObj.mjOBJ_BODY, + "geom": mujoco.mjtObj.mjOBJ_GEOM, + "site": mujoco.mjtObj.mjOBJ_SITE, + } missing = [] for table_name in ("ik_match_table1", "ik_match_table2"): for frame_name, entry in ik_config[table_name].items(): - _, pos_weight, rot_weight, _, _ = entry + _, pos_weight, rot_weight, _, _, *rest = entry if pos_weight == 0 and rot_weight == 0: continue - frame_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, frame_name) + frame_type = rest[0] if rest else "body" + frame_id = mujoco.mj_name2id(model, object_types[frame_type], frame_name) if frame_id < 0: - missing.append(f"{table_name}:{frame_name}") + missing.append(f"{table_name}:{frame_type}:{frame_name}") assert missing == [] + + def test_pico_bridge_g1_foot_ik_uses_canonical_foot_sites(self): + from teleopit.retargeting.gmr.params import IK_CONFIG_DICT + + with open(IK_CONFIG_DICT["pico_bridge"]["unitree_g1"], encoding="utf-8") as f: + ik_config = json.load(f) + + for table_name in ("ik_match_table1", "ik_match_table2"): + assert ik_config[table_name]["left_foot"][-1] == "site" + assert ik_config[table_name]["left_foot"][0] == "Left_Foot" + assert ik_config[table_name]["right_foot"][-1] == "site" + assert ik_config[table_name]["right_foot"][0] == "Right_Foot" From 19ebce74d46b4f905e5122858979833e36ac1298 Mon Sep 17 
00:00:00 2001 From: Wu Bingqian Date: Thu, 18 Jun 2026 15:52:53 +0800 Subject: [PATCH 102/122] Tune tracking action rate penalty --- train_mimic/tasks/tracking/tracking_env_cfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index c0301f56..85801faf 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -269,7 +269,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: params={"command_name": "motion", "std": 3.0}, ), "survival": RewardTermCfg(func=mdp.survival, weight=3.0), - "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-1e-1), + "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-0.5), "joint_limit": RewardTermCfg( func=mdp.joint_pos_limits, weight=-10.0, From 78d130e646724513a203d9342cc067125581fda9 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 19 Jun 2026 21:00:57 +0800 Subject: [PATCH 103/122] Reuse sim2real standing path in standalone script --- docs/docs/tutorials/standalone-standing.md | 15 +- .../current/tutorials/standalone-standing.md | 12 +- scripts/run/standalone_standing.py | 1219 ++++------------- teleopit/sim2real/unitree_g1.py | 3 +- 4 files changed, 252 insertions(+), 997 deletions(-) diff --git a/docs/docs/tutorials/standalone-standing.md b/docs/docs/tutorials/standalone-standing.md index 214ee6e2..772b0e91 100644 --- a/docs/docs/tutorials/standalone-standing.md +++ b/docs/docs/tutorials/standalone-standing.md @@ -5,8 +5,9 @@ sidebar_position: 3 # Standalone Standing Test Run this before full sim2real control when bringing up a new robot, network -setup, or policy. It verifies the G1 bridge and RL standing path without Pico, -BVH playback, retargeting, or the full Teleopit mocap pipeline. +setup, or policy. 
It verifies the G1 bridge and the same RL standing path used +by sim2real, without Pico, BVH playback, retargeting, or the full mocap +pipeline. ```text G1 LowState -> standing observation -> RL policy -> G1 LowCmd targets @@ -58,9 +59,11 @@ python scripts/run/standalone_standing.py \ --network-interface eth0 ``` -Standalone standing uses the same Kp ramp semantics as sim2real: after locking -the current joints, policy targets are sent immediately while Kp ramps from 10% -to the configured gains over 2 seconds. To tune this startup behavior: +Standalone standing reuses the sim2real standing components: `UnitreeG1Robot`, +`Sim2RealSafetyManager`, `RLPolicyController`, `VelCmdObservationBuilder`, and +`Sim2RealReferenceProcessor`. After locking the current joints, policy targets +are sent while Kp ramps from 10% to the configured gains over 2 seconds. To tune +this startup behavior: ```bash python scripts/run/standalone_standing.py \ @@ -76,6 +79,8 @@ python scripts/run/standalone_standing.py \ - LowState is received from the robot. - The dual-input ONNX policy can run the standing observation path. - Low-level position targets can be published through the C++ bridge. +- Observation construction, action scaling, default standing pose, Kp ramp, and + joint-limit clipping match the sim2real standing runtime. 
## Next Steps diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md index 5201df02..54bb06e6 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md @@ -5,8 +5,8 @@ sidebar_position: 3 # 独立站立测试 在接入完整 sim2real 控制之前,如果正在调试新机器人、网络配置或 policy,先运行此测试。 -它不使用 Pico、BVH 回放、retargeting,也不走完整 Teleopit mocap pipeline,只验证 G1 bridge -和 RL standing 路径。 +它不使用 Pico、BVH 回放、retargeting,也不走完整 mocap pipeline,只验证 G1 bridge +和 sim2real 使用的同一条 RL standing 路径。 ```text G1 LowState -> standing observation -> RL policy -> G1 LowCmd targets @@ -57,8 +57,10 @@ python scripts/run/standalone_standing.py \ --network-interface eth0 ``` -standalone standing 使用与 sim2real 相同的 Kp ramp 语义:锁住当前关节后立即发送 -policy target,同时在 2 秒内把 Kp 从 10% 逐步升到配置的增益。可以这样调整启动行为: +standalone standing 复用 sim2real standing 组件:`UnitreeG1Robot`、 +`Sim2RealSafetyManager`、`RLPolicyController`、`VelCmdObservationBuilder` 和 +`Sim2RealReferenceProcessor`。锁住当前关节后发送 policy target,同时在 2 秒内把 Kp +从 10% 逐步升到配置的增益。可以这样调整启动行为: ```bash python scripts/run/standalone_standing.py \ @@ -74,6 +76,8 @@ python scripts/run/standalone_standing.py \ - 能从机器人收到 LowState。 - dual-input ONNX policy 能运行 standing observation 路径。 - 能通过 C++ bridge 发布 low-level position targets。 +- observation 构建、action scale、默认站姿、Kp ramp 和 joint-limit clipping 与 sim2real + standing runtime 保持一致。 ## 下一步 diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 82474188..49e7989a 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -1,1042 +1,287 @@ #!/usr/bin/env python3 -"""Standalone G1 standing script with RL policy -- no Teleopit/Pico dependency. 
- -Uses ONNX RL policy inference to maintain balanced standing, matching the -STANDING mode used by the sim2real robot-control runtime. Only depends on: - - g1_bridge_sdk (C++ DDS bridge) - - onnxruntime - - mujoco - - numpy - -Usage: - python scripts/run/standalone_standing.py \ - --policy track.onnx \ - --network-interface enp130s0 - -Flow: - 1. Init DDS, subscribe to rt/lowstate - 2. Load ONNX policy + MuJoCo model for observation building - 3. Enter debug mode (release MotionSwitcher modes) - 4. Lock joints, then run RL policy standing loop - 5. Hold standing until Ctrl-C or L1+R1 - 6. On exit: set damping, restore ai mode -""" +"""Standalone G1 standing controller using the sim2real standing path.""" from __future__ import annotations import argparse import logging -import math import signal -import struct -import sys -import threading import time -from collections import deque from pathlib import Path +from typing import Any -_REPO_ROOT = Path(__file__).resolve().parents[2] - -import mujoco import numpy as np -import onnxruntime as ort +from omegaconf import OmegaConf + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.controllers.observation import align_motion_qpos_yaw +from teleopit.controllers.rl_policy import RLPolicyController +from teleopit.runtime.common import cfg_get +from teleopit.runtime.factory import _build_policy_components, build_simulation_cfg +from teleopit.sim.reference_timeline import ReferenceWindowBuilder +from teleopit.sim.reference_utils import build_static_reference_window, obs_builder_requires_reference_window +from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor +from teleopit.sim2real.remote import UnitreeRemote +from teleopit.sim2real.safety import Sim2RealSafetyManager +from teleopit.sim2real.unitree_g1 import UnitreeG1Robot + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +DEFAULT_POLICY_HZ = 50.0 logging.basicConfig(level=logging.INFO, format="%(asctime)s 
[%(levelname)s] %(message)s") logger = logging.getLogger(__name__) -# ---- G1 29-DOF Constants ---- -NUM_JOINTS = 29 -NUM_MOTORS = 35 -MODE_PR = 0 -MODE_MACHINE = 5 -PUBLISH_HZ = 200 -POLICY_HZ = 50.0 -POS_STOP_F = 2146000000.0 -VEL_STOP_F = 16000.0 -KD_DAMPING = 8.0 -JOINT_VEL_LIMIT = 10.0 -DEFAULT_KP_RAMP_DURATION = 2.0 -DEFAULT_KP_RAMP_FLOOR_RATIO = 0.1 - -# Default standing pose (from g1_constants.py HOME_KEYFRAME) -DEFAULT_ANGLES = np.array([ - -0.312, 0, 0, 0.669, -0.363, 0, # Left leg (0-5) - -0.312, 0, 0, 0.669, -0.363, 0, # Right leg (6-11) - 0, 0, 0, # Waist (12-14) - 0.2, 0.2, 0, 0.6, 0, 0, 0, # Left arm (15-21) - 0.2, -0.2, 0, 0.6, 0, 0, 0, # Right arm (22-28) -], dtype=np.float32) - -# Action scale (from g1.yaml) -ACTION_SCALE = np.array([ - 0.5475, 0.3507, 0.5475, 0.3507, 0.4386, 0.4386, - 0.5475, 0.3507, 0.5475, 0.3507, 0.4386, 0.4386, - 0.5475, 0.4386, 0.4386, - 0.4386, 0.4386, 0.4386, 0.4386, 0.4386, 0.0745, 0.0745, - 0.4386, 0.4386, 0.4386, 0.4386, 0.4386, 0.0745, 0.0745, -], dtype=np.float32) - -# PD gains (from mjlab g1_constants.py) -KP = np.array([ - 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, - 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, - 40.2, 28.5, 28.5, - 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8, - 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8, -], dtype=np.float32) - -KD = np.array([ - 2.6, 6.3, 2.6, 6.3, 1.8, 1.8, - 2.6, 6.3, 2.6, 6.3, 1.8, 1.8, - 2.6, 1.8, 1.8, - 0.9, 0.9, 0.9, 0.9, 0.9, 1.1, 1.1, - 0.9, 0.9, 0.9, 0.9, 0.9, 1.1, 1.1, -], dtype=np.float32) - -# Joint position limits (from pico4_sim2real.yaml) -JOINT_POS_LOWER = np.array([ - -2.5307, -0.5236, -2.7576, -0.087267, -0.87267, -0.2618, - -2.5307, -2.9671, -2.7576, -0.087267, -0.87267, -0.2618, - -2.618, -0.52, -0.52, - -3.0892, -1.5882, -2.618, -1.0472, -1.972222054, -1.61443, -1.61443, - -3.0892, -2.2515, -2.618, -1.0472, -1.972222054, -1.61443, -1.61443, -], dtype=np.float32) - -JOINT_POS_UPPER = np.array([ - 2.8798, 2.9671, 2.7576, 2.8798, 0.5236, 0.2618, - 2.8798, 0.5236, 2.7576, 2.8798, 
0.5236, 0.2618, - 2.618, 0.52, 0.52, - 2.6704, 2.2515, 2.618, 2.0944, 1.972222054, 1.61443, 1.61443, - 2.6704, 1.5882, 2.618, 2.0944, 1.972222054, 1.61443, 1.61443, -], dtype=np.float32) - -# MuJoCo XML path for FK -MJCF_PATH = _REPO_ROOT / "assets" / "robots" / "unitree_g1" / "g1_29dof.xml" - -JOINT_MAP = list(range(NUM_JOINTS)) -GRAVITY_UNIT_W = np.array([0.0, 0.0, -1.0], dtype=np.float32) - -# Wireless remote button masks -_KEYS_OFFSET = 2 -_R1 = 0x0001 -_L1 = 0x0002 - - -# ===================================================================== -# Quaternion helpers (from teleopit.controllers.observation) -# ===================================================================== - -def quat_inv(q): - inv = q.copy() - inv[..., 1:] = -inv[..., 1:] - return inv - - -def quat_mul(q1, q2): - w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3] - w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3] - return np.stack([ - w1*w2 - x1*x2 - y1*y2 - z1*z2, - w1*x2 + x1*w2 + y1*z2 - z1*y2, - w1*y2 - x1*z2 + y1*w2 + z1*x2, - w1*z2 + x1*y2 - y1*x2 + z1*w2, - ], axis=-1).astype(np.float32) - - -def quat_rotate(q, v): - v_quat = np.zeros((*v.shape[:-1], 4), dtype=np.float32) - v_quat[..., 1:4] = v - result = quat_mul(quat_mul(q, v_quat), quat_inv(q)) - return result[..., 1:4] - - -def yaw_quat(q): - w, x, y, z = float(q[0]), float(q[1]), float(q[2]), float(q[3]) - yaw = math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z)) - out = np.array([math.cos(yaw / 2.0), 0.0, 0.0, math.sin(yaw / 2.0)], dtype=np.float32) - out /= max(float(np.linalg.norm(out)), 1e-8) - return out - - -def align_motion_qpos_yaw(robot_quat_wxyz, motion_qpos): - """Align motion_qpos[3:7] quaternion yaw to match robot yaw.""" - motion_quat = np.asarray(motion_qpos[3:7], dtype=np.float32) - robot_quat = np.asarray(robot_quat_wxyz, dtype=np.float32) - delta = quat_mul(robot_quat, quat_inv(motion_quat)) - delta_yaw = yaw_quat(delta) - motion_qpos[3:7] = quat_mul(delta_yaw, 
motion_quat).astype(motion_qpos.dtype) - return motion_qpos - - -def quat_to_rot6d(q): - w, x, y, z = q[0], q[1], q[2], q[3] - r00 = 1 - 2 * (y*y + z*z) - r01 = 2 * (x*y - w*z) - r10 = 2 * (x*y + w*z) - r11 = 1 - 2 * (x*x + z*z) - r20 = 2 * (x*z - w*y) - r21 = 2 * (y*z + w*x) - return np.array([r00, r01, r10, r11, r20, r21], dtype=np.float32) - - -# ===================================================================== -# Observation builder (from teleopit.controllers.observation) -# ===================================================================== - -class ObservationBuilder: - """167D VelCmd observation builder using MuJoCo FK.""" - - def __init__(self, xml_path: str): - self._mj_model = mujoco.MjModel.from_xml_path(xml_path) - self._mj_data = mujoco.MjData(self._mj_model) - - # Find anchor body (torso_link) - body_names = {self._mj_model.body(i).name: i for i in range(self._mj_model.nbody)} - self._anchor_body_id = body_names["torso_link"] - - # Base obs: command(29+29) + anchor_ori_b(6) + ang_vel(3) + joint_pos_rel(29) + qvel(29) + last_act(29) - # = 154 - # VelCmd extra: projected_gravity(3) + ref_lin_vel_b(3) + ref_ang_vel_b(3) - # + ref_proj_gravity(3) + ref_base_height(1) = 13 - # Total = 167 - self.total_obs_size = NUM_JOINTS * 2 + 6 + 3 + NUM_JOINTS * 3 + 13 - - # Precompute motion torso offset for standing (DEFAULT_ANGLES with identity base) - # torso_quat_world = quat_mul(base_quat, torso_offset) for constant joint angles - self._run_fk(np.zeros(3), np.array([1, 0, 0, 0], dtype=np.float32), DEFAULT_ANGLES) - self._standing_torso_offset = self._get_body_quat(self._anchor_body_id).copy() - - def _run_fk(self, base_pos, base_quat, joint_pos): - self._mj_data.qpos[:] = 0.0 - self._mj_data.qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(3) - quat = np.asarray(base_quat, dtype=np.float64).reshape(4) - quat = quat / max(np.linalg.norm(quat), 1e-8) - self._mj_data.qpos[3:7] = quat - n = min(len(joint_pos), self._mj_model.nq - 7) - 
self._mj_data.qpos[7:7 + n] = np.asarray(joint_pos, dtype=np.float64)[:n] - mujoco.mj_kinematics(self._mj_model, self._mj_data) - - def _get_body_quat(self, body_id): - return np.asarray(self._mj_data.xquat[body_id], dtype=np.float32).copy() - - def _get_body_pos(self, body_id): - return np.asarray(self._mj_data.xpos[body_id], dtype=np.float32).copy() - - def build(self, robot_qpos, robot_qvel, robot_quat, robot_ang_vel, - motion_qpos, motion_joint_vel, last_action): - """Build 167D observation for VelCmd policy. - - Args: - robot_qpos: (29,) current joint positions - robot_qvel: (29,) current joint velocities - robot_quat: (4,) base orientation quaternion (w,x,y,z) - robot_ang_vel: (3,) base angular velocity - motion_qpos: (36,) reference motion [pos(3) + quat(4) + joints(29)] - motion_joint_vel: (29,) reference joint velocity - last_action: (29,) previous policy action - """ - qpos = np.asarray(robot_qpos, dtype=np.float32)[:NUM_JOINTS] - qvel = np.asarray(robot_qvel, dtype=np.float32)[:NUM_JOINTS] - robot_q = np.asarray(robot_quat, dtype=np.float32) - ang_vel = np.asarray(robot_ang_vel, dtype=np.float32) - motion = np.asarray(motion_qpos, dtype=np.float32) - m_joint_vel = np.asarray(motion_joint_vel, dtype=np.float32)[:NUM_JOINTS] - last_act = np.asarray(last_action, dtype=np.float32) - - # Anchor quaternions: skip MuJoCo FK, use base quat * precomputed offset - # For standing, waist joints ≈ 0 so torso ≈ base orientation - robot_anchor_quat = quat_mul(robot_q, self._standing_torso_offset) - - motion_base_quat = motion[3:7] - motion_joint_pos = motion[7:7 + NUM_JOINTS] - self._run_fk(motion[0:3], motion_base_quat, motion_joint_pos) - motion_anchor_pos = self._get_body_pos(self._anchor_body_id) - motion_anchor_quat = quat_mul(motion_base_quat, self._standing_torso_offset) - - # Base observation (154D) - command = np.concatenate((motion_joint_pos, m_joint_vel), dtype=np.float32) - rel_quat = quat_mul(quat_inv(robot_anchor_quat), motion_anchor_quat) - 
motion_anchor_ori_b = quat_to_rot6d(rel_quat) - joint_pos_rel = qpos - DEFAULT_ANGLES - - base_obs = np.concatenate([ - command, # 29 + 29 = 58 - motion_anchor_ori_b, # 6 - ang_vel, # 3 - joint_pos_rel, # 29 - qvel, # 29 - last_act, # 29 - ], dtype=np.float32) - - # VelCmd extra (13D) -- standing has zero reference velocities - projected_gravity = quat_rotate(quat_inv(robot_q), GRAVITY_UNIT_W) - robot_inv = quat_inv(robot_anchor_quat) - # Zero reference velocities for standing - ref_lin_vel_b = np.zeros(3, dtype=np.float32) - ref_ang_vel_b = np.zeros(3, dtype=np.float32) - ref_proj_gravity = quat_rotate(quat_inv(motion_anchor_quat), GRAVITY_UNIT_W) - ref_base_height = motion_anchor_pos[2:3] - - velcmd_obs = np.concatenate([ - projected_gravity, # 3 - ref_lin_vel_b, # 3 - ref_ang_vel_b, # 3 - ref_proj_gravity, # 3 - ref_base_height, # 1 - ], dtype=np.float32) - - obs = np.concatenate([base_obs, velcmd_obs], dtype=np.float32) - return obs - - -# ===================================================================== -# ONNX Policy wrapper (from teleopit.controllers.rl_policy) -# ===================================================================== - -class PolicyInference: - """Minimal ONNX policy inference wrapper.""" - - def __init__(self, policy_path: str): - providers = ["CPUExecutionProvider"] - available = set(ort.get_available_providers()) - if "CUDAExecutionProvider" in available: - providers.insert(0, "CUDAExecutionProvider") - self._session = ort.InferenceSession(policy_path, providers=providers) - onnx_inputs = self._session.get_inputs() - self._input_name = onnx_inputs[0].name - self._output_name = self._session.get_outputs()[0].name +class StandaloneStandingController: + """Small wrapper around the production sim2real STANDING implementation.""" - # Check for dual-input (obs + obs_history) model - self._multi_input = False - self._history_length = 0 - self._history_buf: deque[np.ndarray] = deque() - if len(onnx_inputs) == 2 and onnx_inputs[1].name == 
"obs_history": - self._multi_input = True - self._history_length = int(onnx_inputs[1].shape[1]) - self._history_buf = deque(maxlen=self._history_length) + def __init__(self, cfg: Any, *, dry_run: bool = False, no_policy: bool = False) -> None: + self.cfg = cfg + self.dry_run = dry_run + self.no_policy = no_policy + self.shutdown_requested = False - # Extract expected obs dim - self._expected_obs_dim = None - if len(onnx_inputs[0].shape) >= 2 and isinstance(onnx_inputs[0].shape[-1], int): - self._expected_obs_dim = int(onnx_inputs[0].shape[-1]) + self.policy_hz = float(cfg_get(cfg, "policy_hz", DEFAULT_POLICY_HZ)) + self.dt = 1.0 / self.policy_hz + self.robot_cfg = cfg_get(cfg, "robot") + self.real_robot_cfg = cfg_get(cfg, "real_robot") + self.default_angles = np.asarray(cfg_get(self.robot_cfg, "default_angles"), dtype=np.float32) + self.num_actions = int(cfg_get(self.robot_cfg, "num_actions", NUM_JOINTS)) - logger.info( - "Policy loaded: input=%s, obs_dim=%s, multi_input=%s, history_len=%d", - self._input_name, self._expected_obs_dim, self._multi_input, self._history_length, - ) - - def compute_action(self, observation: np.ndarray) -> np.ndarray: - obs = np.asarray(observation, dtype=np.float32) - if obs.ndim == 1: - obs = obs[np.newaxis, :] - obs_flat = obs.reshape(-1) - - if self._multi_input: - if len(self._history_buf) == 0: - for _ in range(self._history_length): - self._history_buf.append(obs_flat.copy()) - else: - self._history_buf.append(obs_flat.copy()) - obs_history = np.stack(list(self._history_buf), axis=0)[np.newaxis].astype(np.float32) - feed = {self._input_name: obs, "obs_history": obs_history} - else: - feed = {self._input_name: obs} - - raw_action = np.asarray( - self._session.run([self._output_name], feed)[0], dtype=np.float32 + default_root_qpos = np.asarray( + cfg_get(self.robot_cfg, "mujoco_default_qpos", [0.0, 0.0, 0.0]), dtype=np.float64 ).reshape(-1) - return raw_action - - def get_target_dof_pos(self, raw_action: np.ndarray) -> np.ndarray: 
- clipped = np.clip(raw_action, -10.0, 10.0) - scaled = clipped * ACTION_SCALE - return scaled + DEFAULT_ANGLES - - def reset(self): - self._history_buf.clear() - - -# ===================================================================== -# Main controller -# ===================================================================== - -class StandingController: - """RL-policy-based standing controller matching sim2real STANDING behavior.""" - - def __init__(self, network_interface: str, policy_path: str, - no_policy: bool = False, - publish_hz: int = 250, - obs_delay: float = 0.0, - command_delay: float = 0.0, - kp_ramp_duration: float = DEFAULT_KP_RAMP_DURATION, - kp_ramp_floor_ratio: float = DEFAULT_KP_RAMP_FLOOR_RATIO) -> None: - self._network_interface = network_interface - self._shutdown = False - if obs_delay < 0.0: - raise ValueError("obs_delay must be >= 0") - if command_delay < 0.0: - raise ValueError("command_delay must be >= 0") - if kp_ramp_duration < 0.0: - raise ValueError("kp_ramp_duration must be >= 0") - if not 0.0 <= kp_ramp_floor_ratio <= 1.0: - raise ValueError("kp_ramp_floor_ratio must be in [0, 1]") - - # ---- Load policy and observation builder ---- - self._policy = PolicyInference(policy_path) - xml_path = str(MJCF_PATH) - if not MJCF_PATH.exists(): - raise FileNotFoundError(f"MuJoCo XML not found: {xml_path}") - self._obs_builder = ObservationBuilder(xml_path) - - # Verify observation dimension matches policy expectation - if (self._policy._expected_obs_dim is not None and - self._policy._expected_obs_dim != self._obs_builder.total_obs_size): - raise ValueError( - f"Obs dimension mismatch: builder={self._obs_builder.total_obs_size}, " - f"policy={self._policy._expected_obs_dim}" - ) - - self._no_policy = no_policy - self._dry_run = False - self._state_delay = 0.0 - self._obs_delay = float(obs_delay) - self._command_delay = float(command_delay) - self._state_history: deque[tuple[float, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]] = 
deque(maxlen=512) - self._state_history_lock = threading.Lock() - self._state_sampler_thread: threading.Thread | None = None - self._state_sampler_running = False - self._pending_targets: deque[tuple[float, np.ndarray]] = deque() - self._pending_targets_cv = threading.Condition() - self._command_sender_thread: threading.Thread | None = None - self._command_sender_running = False - self._last_obs_age_s = 0.0 - self._last_command_queue_len = 0 - - # ---- Policy state ---- - self._step_count = 0 - self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) - # Standing reference qpos: [pos(3), quat(4), joints(29)] = 36D - self._standing_qpos = np.zeros(36, dtype=np.float64) - self._standing_qpos[3] = 1.0 # identity quaternion w=1 - self._standing_qpos[7:36] = DEFAULT_ANGLES.astype(np.float64) - - # ---- Pipeline state ---- - self._inference_thread: threading.Thread | None = None - self._inference_running = False - - self._publish_hz = publish_hz - self._kp_ramp_duration_steps = max(1, int(kp_ramp_duration * POLICY_HZ)) - self._kp_ramp_floor_ratio = float(kp_ramp_floor_ratio) - self._kp_ramp_step = 0 - self._kp_ramp_active = False - - self._init_cpp_backend() - - # ================================================================== - # Backend init - # ================================================================== - - def _init_cpp_backend(self) -> None: - import g1_bridge_sdk - logger.info("Using C++ bridge backend (%dHz publish)", self._publish_hz) - self._bridge = g1_bridge_sdk.G1Bridge(self._network_interface, self._publish_hz) - - logger.info("Waiting for LowState on %s ...", self._network_interface) - if not self._bridge.wait_for_state(5.0): - raise RuntimeError("No LowState received within 5s -- check network and robot power") - logger.info("LowState received, robot connected") - - # ================================================================== - # Robot state reading - # ================================================================== - - def 
_get_robot_state(self): - return self._read_robot_state() - - def _read_robot_state(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - qpos, qvel, quat, ang_vel = self._bridge.get_state() - return ( - np.asarray(qpos, dtype=np.float32).copy(), - np.asarray(qvel, dtype=np.float32).copy(), - np.asarray(quat, dtype=np.float32).copy(), - np.asarray(ang_vel, dtype=np.float32).copy(), + self._default_root_pos = np.zeros(3, dtype=np.float64) + if default_root_qpos.shape[0] >= 3: + self._default_root_pos[:] = default_root_qpos[:3] + + self.policy, self.obs_builder = self._build_policy_and_obs() + self.robot = UnitreeG1Robot(self.real_robot_cfg) + self.remote = UnitreeRemote() + self.safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) + self._entered_debug = False + + sim_cfg = build_simulation_cfg(cfg) + self.ref_proc = Sim2RealReferenceProcessor( + obs_builder=self.obs_builder, + policy=self.policy, + policy_hz=self.policy_hz, + num_actions=self.num_actions, + reference_velocity_smoothing_alpha=float(sim_cfg["reference_velocity_smoothing_alpha"]), + reference_anchor_velocity_smoothing_alpha=float(sim_cfg["reference_anchor_velocity_smoothing_alpha"]), ) - - def _record_robot_state(self) -> tuple[float, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: - now = time.monotonic() - state = self._read_robot_state() - with self._state_history_lock: - self._state_history.append((now, state)) - return now, state - - def _get_observation_state(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - now, current = self._record_robot_state() - if self._obs_delay <= 0.0: - self._last_obs_age_s = 0.0 - return current - - target_time = now - self._obs_delay - with self._state_history_lock: - history = list(self._state_history) - selected_time, selected_state = history[0] - for sample_time, sample_state in reversed(history): - if sample_time <= target_time: - selected_time, selected_state = sample_time, sample_state - break - 
self._last_obs_age_s = max(0.0, now - selected_time) - return selected_state - - def _start_state_sampler(self) -> None: - if self._obs_delay <= 0.0 or self._state_sampler_thread is not None: - return - self._state_sampler_running = True - self._state_sampler_thread = threading.Thread(target=self._state_sampler_loop, daemon=True) - self._state_sampler_thread.start() - - def _stop_state_sampler(self) -> None: - self._state_sampler_running = False - if self._state_sampler_thread is not None: - self._state_sampler_thread.join(timeout=1.0) - self._state_sampler_thread = None - - def _state_sampler_loop(self) -> None: - sample_dt = 1.0 / max(float(self._publish_hz), POLICY_HZ) - while self._state_sampler_running and not self._shutdown: - self._record_robot_state() - time.sleep(sample_dt) - - # ================================================================== - # Publish thread - # ================================================================== - - def _start_publish(self) -> None: - self._bridge.start_publish() - - def _stop_publish(self) -> None: - self._bridge.stop_publish() - - # ================================================================== - # Motion switcher - # ================================================================== - - def enter_debug_mode(self) -> bool: - try: - for _ in range(10): - code, name = self._bridge.check_mode() - logger.info("check_mode -> code=%s, name=%s", code, name) - if code != 0: - logger.error("check_mode RPC failed with code=%s", code) - return False - if not name: - logger.info("Debug mode ready (no active mode)") - return True - logger.info("Releasing mode: %s", name) - release_code = self._bridge.release_mode() - logger.info("release_mode('%s') -> code=%s", name, release_code) - if release_code != 0: - logger.error("Failed to release mode '%s' (code=%s)", name, release_code) - return False - time.sleep(1) - logger.warning("Could not release modes after 10 attempts") - return False - except Exception as exc: - 
logger.error("enter_debug_mode failed: %s", exc) - return False - - def exit_debug_mode(self) -> bool: - self._stop_publish() - try: - code = self._bridge.select_mode("ai") - logger.info("select_mode('ai') -> code=%s", code) - return code == 0 - except Exception as exc: - logger.error("exit_debug_mode failed: %s", exc) - return False - - # ================================================================== - # Motor commands - # ================================================================== - - def _set_damping(self) -> None: - self._bridge.set_damping() - - def _lock_joints(self) -> None: - qpos, _, _, _ = self._get_robot_state() - self._bridge.set_target(qpos, KP, KD) - self._bridge.lock_joints() - - def _start_kp_ramp(self) -> None: - self._kp_ramp_step = 0 - self._kp_ramp_active = True - logger.info( - "Kp ramp armed: %d steps (%.1fs), floor_ratio=%.2f", - self._kp_ramp_duration_steps, - self._kp_ramp_duration_steps / POLICY_HZ, - self._kp_ramp_floor_ratio, + self.reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=self.dt, + reference_steps=cfg_get(cfg, "reference_steps", [0]), ) - def _compute_kp_ramp_gains(self) -> tuple[np.ndarray, np.ndarray] | None: - if not self._kp_ramp_active: - return None - - factor = min(1.0, self._kp_ramp_step / self._kp_ramp_duration_steps) - kp = KP * (self._kp_ramp_floor_ratio + (1.0 - self._kp_ramp_floor_ratio) * factor) - - self._kp_ramp_step += 1 - if self._kp_ramp_step >= self._kp_ramp_duration_steps: - self._kp_ramp_active = False - logger.info("Kp ramp complete (%d steps)", self._kp_ramp_duration_steps) - - return np.asarray(kp, dtype=np.float32), KD.copy() - - def _write_target_now(self, target: np.ndarray) -> None: - gains = self._compute_kp_ramp_gains() - if gains is None: - self._bridge.set_target(target, KP, KD) - return - kp, kd = gains - self._bridge.set_target(target, kp, kd) - - def _pop_due_targets(self, now: float) -> list[np.ndarray]: - due: list[np.ndarray] = [] - with self._pending_targets_cv: 
- while self._pending_targets and self._pending_targets[0][0] <= now: - _, target = self._pending_targets.popleft() - due.append(target) - self._last_command_queue_len = len(self._pending_targets) - return due - - def _flush_pending_targets(self, now: float | None = None) -> None: - if now is None: - now = time.monotonic() - due = self._pop_due_targets(now) - for target in due: - self._write_target_now(target) - if len(due) > 1: - logger.warning("Flushed %d delayed targets in one control tick", len(due)) - - def _send_target(self, target: np.ndarray) -> None: - if self._command_delay <= 0.0: - self._write_target_now(target) - self._last_command_queue_len = len(self._pending_targets) - return - with self._pending_targets_cv: - self._pending_targets.append((time.monotonic() + self._command_delay, np.asarray(target, dtype=np.float32).copy())) - self._last_command_queue_len = len(self._pending_targets) - self._pending_targets_cv.notify() - - def _start_command_sender(self) -> None: - if self._command_delay <= 0.0 or self._command_sender_thread is not None: - return - self._command_sender_running = True - self._command_sender_thread = threading.Thread(target=self._command_sender_loop, daemon=True) - self._command_sender_thread.start() - - def _stop_command_sender(self) -> None: - self._command_sender_running = False - with self._pending_targets_cv: - self._pending_targets_cv.notify_all() - if self._command_sender_thread is not None: - self._command_sender_thread.join(timeout=1.0) - self._command_sender_thread = None - with self._pending_targets_cv: - self._pending_targets.clear() - self._last_command_queue_len = 0 - - def _command_sender_loop(self) -> None: - while self._command_sender_running and not self._shutdown: - now = time.monotonic() - due = self._pop_due_targets(now) - if due: - self._write_target_now(due[-1]) - if len(due) > 1: - logger.warning("Dropped %d stale delayed targets", len(due) - 1) - continue + self._standing_qpos = np.zeros(FULL_QPOS_DIM, 
dtype=np.float64) + self._standing_qpos[3] = 1.0 + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._step_count = 0 - with self._pending_targets_cv: - if not self._pending_targets: - self._pending_targets_cv.wait(timeout=0.02) - continue - wait_s = max(0.0, self._pending_targets[0][0] - time.monotonic()) - self._pending_targets_cv.wait(timeout=min(wait_s, 0.02)) + def _build_policy_and_obs(self) -> tuple[Any, Any]: + controller_cfg = cfg_get(self.cfg, "controller") + policy, obs_builder = _build_policy_components( + robot_cfg=self.robot_cfg, + controller_cfg=controller_cfg, + sim_cfg=build_simulation_cfg(self.cfg), + project_root=PROJECT_ROOT, + controller_cls=RLPolicyController, + ) + if not bool(getattr(policy, "_multi_input", False)): + raise ValueError("Standalone standing requires a dual-input ONNX policy ('obs', 'obs_history').") + return policy, obs_builder - # ================================================================== - # Safety checks - # ================================================================== + def run(self) -> None: + signal.signal(signal.SIGINT, self._signal_handler) + signal.signal(signal.SIGTERM, self._signal_handler) - def _check_emergency_stop(self) -> bool: try: - remote_bytes = self._bridge.get_wireless_remote() - if len(remote_bytes) < 4: - return False - keys = struct.unpack_from(" bool: - max_vel = np.max(np.abs(qvel)) - if max_vel > JOINT_VEL_LIMIT: - logger.error("SAFETY: joint vel %.2f rad/s > limit %.2f -- damping!", max_vel, JOINT_VEL_LIMIT) - return True - return False + if self.dry_run: + self._run_dry() + return + self._enter_standing() + self._run_control_loop() + finally: + self._cleanup() - # ---- Standing step (matches sim2real robot-control standing step) ---- + def _enter_standing(self) -> None: + logger.info("Entering debug mode...") + if not self.robot.enter_debug_mode(): + raise RuntimeError("Failed 
to enter debug mode") + self._entered_debug = True + time.sleep(0.5) - def _standing_step(self) -> np.ndarray: - """One step of RL policy standing inference. Returns target joint positions.""" - _t0 = time.monotonic() - qpos, qvel, quat, ang_vel = self._get_observation_state() + logger.info("Locking joints to current position...") + self.robot.lock_all_joints() + time.sleep(0.3) - # Build standing reference aligned to robot's current yaw - ref_qpos = self._standing_qpos.copy() - align_motion_qpos_yaw(quat, ref_qpos) + self._warmup_policy() + state = self.robot.get_state() + self._set_default_standing_reference(state) + self._reset_policy_state() + self.safety.start_kp_ramp() + logger.info("Mode -> STANDING (standalone)") - # Zero joint velocity reference for standing - motion_joint_vel = np.zeros(NUM_JOINTS, dtype=np.float32) - motion_qpos = np.asarray(ref_qpos[:7 + NUM_JOINTS], dtype=np.float32) + def _run_control_loop(self) -> None: + while not self.shutdown_requested: + t0 = time.monotonic() + self.remote.update(self.robot.get_wireless_remote()) + if self.remote.LB.pressed and self.remote.RB.pressed: + logger.warning("L1+R1 pressed -- damping") + self.shutdown_requested = True + break + if self.safety.check_joint_velocity_safety(): + self.robot.set_damping() + self.shutdown_requested = True + break - _t1 = time.monotonic() - # Build observation - obs = self._obs_builder.build( - robot_qpos=qpos, - robot_qvel=qvel, - robot_quat=quat, - robot_ang_vel=ang_vel, - motion_qpos=ref_qpos, + self._standing_step() + self._sleep_until(t0) + + def _standing_step(self) -> None: + robot_state = self.robot.get_state() + qpos = self._standing_qpos.copy() + motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) + motion_qpos = np.asarray(qpos[: ROOT_DIM + self.num_actions], dtype=np.float32) + reference_window = None + if obs_builder_requires_reference_window(self.obs_builder): + reference_window = build_static_reference_window(qpos, self.reference_window_builder, 
self.policy_hz) + obs = self.ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, motion_joint_vel=motion_joint_vel, last_action=self._last_action, + anchor_lin_vel_w=np.zeros(3, dtype=np.float32), + anchor_ang_vel_w=np.zeros(3, dtype=np.float32), + reference_window=reference_window, ) - - _t2 = time.monotonic() - # Policy inference - action = self._policy.compute_action(obs) - - target_dof_pos = self._policy.get_target_dof_pos(action) - _t3 = time.monotonic() - - # Diagnostic - self._step_count += 1 - step_ms = (_t3 - _t0) * 1000 - if self._step_count % 25 == 1 or step_ms > (1000.0 / POLICY_HZ): - tag = "OVERRUN" if step_ms > (1000.0 / POLICY_HZ) else "DIAG" - logger.info( - "%s step=%d | state=%.2fms obs=%.2fms infer=%.2fms total=%.1fms | " - "obs_age=%.1fms cmd_q=%d | qvel_norm=%.4f | action_norm=%.4f | " - "target[:6]=%s | qpos[:6]=%s", - tag, self._step_count, - (_t1 - _t0) * 1000, (_t2 - _t1) * 1000, (_t3 - _t2) * 1000, - step_ms, - self._last_obs_age_s * 1000, - self._last_command_queue_len, - float(np.linalg.norm(qvel)), - float(np.linalg.norm(action)), - np.array2string(target_dof_pos[:6], precision=4, separator=','), - np.array2string(qpos[:6], precision=4, separator=','), - ) - - # Joint limits - target_dof_pos = np.clip(target_dof_pos, JOINT_POS_LOWER, JOINT_POS_UPPER) - + obs = self.ref_proc.validate_observation(obs) + action = np.zeros(self.num_actions, dtype=np.float32) if self.no_policy else self.policy.compute_action(obs) + target_dof_pos = self.safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + if self.dry_run: + self._log_step(robot_state.qvel, action, target_dof_pos, dry=True) + else: + self.safety.send_positions(target_dof_pos) + self._log_step(robot_state.qvel, action, target_dof_pos, dry=False) self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - return target_dof_pos - - # ---- Inference thread ---- - - def _start_inference(self) -> None: - if self._inference_thread is not 
None: - return - self._inference_running = True - self._inference_thread = threading.Thread(target=self._inference_loop, daemon=True) - self._inference_thread.start() - - def _stop_inference(self) -> None: - self._inference_running = False - if self._inference_thread is not None: - self._inference_thread.join(timeout=2.0) - self._inference_thread = None - - def _inference_loop(self) -> None: - """~50Hz soft-realtime inference loop. - - Runs policy inference and writes target positions to _target_buf. - The 250Hz publish thread reads _target_buf independently — even if - inference takes 30ms, the robot keeps receiving 250Hz commands with - the last good target. - """ - dt = 1.0 / POLICY_HZ - loop_count = 0 - overrun_count = 0 - max_elapsed = 0.0 - elapsed_sum = 0.0 - - while self._inference_running and not self._shutdown: - t0 = time.monotonic() - if self._command_delay <= 0.0: - self._flush_pending_targets(t0) - - # Emergency stop check - if self._check_emergency_stop(): - logger.warning("L1+R1 pressed -- emergency damping!") - self._set_damping() - self._shutdown = True - break - - # Joint velocity safety check (warning only, no shutdown) - _, qvel, _, _ = self._get_robot_state() - self._check_joint_vel_safety(qvel) - - # Artificial state delay (for debugging) - if self._state_delay > 0: - time.sleep(self._state_delay) - - # Policy step - if self._no_policy: - target = np.clip(DEFAULT_ANGLES.copy(), JOINT_POS_LOWER, JOINT_POS_UPPER) - else: - target = self._standing_step() - - # Write target to publish thread. Kp ramps after standing entry; - # policy targets stay unchanged, matching sim2real STANDING. 
- self._send_target(target) - - # Timing diagnostics (informational only — not a control failure) - elapsed = time.monotonic() - t0 - loop_count += 1 - elapsed_sum += elapsed - max_elapsed = max(max_elapsed, elapsed) - if elapsed > dt: - overrun_count += 1 - - if loop_count % 50 == 0: - avg_ms = (elapsed_sum / loop_count) * 1000 - logger.info( - "Inference stats: avg=%.1fms, max=%.1fms, overruns=%d/%d (target=%.1fms)", - avg_ms, max_elapsed * 1000, overrun_count, loop_count, dt * 1000, - ) - - remain = dt - (time.monotonic() - t0) - if remain > 0: - time.sleep(remain) - - if self._command_delay <= 0.0: - self._flush_pending_targets() - - # ---- Main loop ---- def _run_dry(self) -> None: - """Dry-run: read state + build obs + infer, no motor commands. Safe for timing tests.""" logger.info("=== DRY-RUN MODE: no motor commands will be sent ===") - self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) - self._policy.reset() - - dt = 1.0 / POLICY_HZ - loop_count = 0 - overrun_count = 0 - elapsed_sum = 0.0 - max_elapsed = 0.0 - state_sum = 0.0 - obs_sum = 0.0 - infer_sum = 0.0 - - self._start_state_sampler() - while not self._shutdown: + self._warmup_policy() + state = self.robot.get_state() + self._set_default_standing_reference(state) + self._reset_policy_state() + while not self.shutdown_requested: t0 = time.monotonic() - - # 1. Read state - qpos, qvel, quat, ang_vel = self._get_observation_state() - t1 = time.monotonic() - - # 2. Build reference - ref_qpos = self._standing_qpos.copy() - align_motion_qpos_yaw(quat, ref_qpos) - motion_joint_vel = np.zeros(NUM_JOINTS, dtype=np.float32) - - # 3. Build observation - obs = self._obs_builder.build( - robot_qpos=qpos, robot_qvel=qvel, - robot_quat=quat, robot_ang_vel=ang_vel, - motion_qpos=ref_qpos, motion_joint_vel=motion_joint_vel, - last_action=self._last_action, - ) - t2 = time.monotonic() - - # 4. 
Policy inference - raw_action = self._policy.compute_action(obs) - target = self._policy.get_target_dof_pos(raw_action) - t3 = time.monotonic() - - self._last_action = raw_action.copy() - loop_count += 1 - e = t3 - t0 - elapsed_sum += e - max_elapsed = max(max_elapsed, e) - state_sum += (t1 - t0) - obs_sum += (t2 - t1) - infer_sum += (t3 - t2) - if e > dt: - overrun_count += 1 - - if loop_count % 50 == 0: - n = loop_count - logger.info( - "DRY step=%d | state=%.2fms obs=%.2fms infer=%.2fms total=%.2fms | " - "max=%.2fms overruns=%d/%d | obs_age=%.1fms | target[:6]=%s", - n, - (state_sum / n) * 1000, (obs_sum / n) * 1000, - (infer_sum / n) * 1000, (elapsed_sum / n) * 1000, - max_elapsed * 1000, overrun_count, n, - self._last_obs_age_s * 1000, - np.array2string(target[:6], precision=4, separator=','), - ) - - remain = dt - (time.monotonic() - t0) - if remain > 0: - time.sleep(remain) - - self._stop_state_sampler() + self._standing_step() + self._sleep_until(t0) + + def _set_default_standing_reference(self, state: object) -> None: + self._standing_qpos[:] = 0.0 + self._standing_qpos[0:3] = self._default_root_pos + self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) + align_motion_qpos_yaw(np.asarray(getattr(state, "quat"), dtype=np.float32), self._standing_qpos) + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + + def _reset_policy_state(self) -> None: + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self.ref_proc.reset_smoothers() + self.ref_proc.reset_alignment() + self.policy.reset() + self.obs_builder.reset() + + def _warmup_policy(self) -> None: + if self.no_policy: + return + logger.info("Warming up ONNX runtime...") + dummy_obs = np.zeros(int(self.obs_builder.total_obs_size), dtype=np.float32) + for _ in range(3): + self.policy.compute_action(dummy_obs) + self.policy.reset() + logger.info("ONNX warmup complete") + + def _log_step(self, qvel: np.ndarray, action: np.ndarray, target: 
np.ndarray, *, dry: bool) -> None: + self._step_count += 1 + if self._step_count % 50 != 1: + return + tag = "DRY" if dry else "STANDING" logger.info( - "DRY-RUN finished: %d steps, avg total=%.2fms " - "(state=%.2f obs=%.2f infer=%.2f) max=%.2fms overruns=%d", - loop_count, - (elapsed_sum / max(loop_count, 1)) * 1000, - (state_sum / max(loop_count, 1)) * 1000, - (obs_sum / max(loop_count, 1)) * 1000, - (infer_sum / max(loop_count, 1)) * 1000, - max_elapsed * 1000, overrun_count, + "%s step=%d | qvel_norm=%.4f | action_norm=%.4f | target[:6]=%s", + tag, + self._step_count, + float(np.linalg.norm(qvel)), + float(np.linalg.norm(action)), + np.array2string(target[:6], precision=4, separator=","), ) - def run(self) -> None: - signal.signal(signal.SIGINT, self._signal_handler) - signal.signal(signal.SIGTERM, self._signal_handler) + def _sleep_until(self, t0: float) -> None: + remaining = self.dt - (time.monotonic() - t0) + if remaining > 0.0: + time.sleep(remaining) - if self._dry_run: - self._run_dry() - return + def _signal_handler(self, _signum: int, _frame: object) -> None: + logger.info("Shutdown signal received") + self.shutdown_requested = True + def _cleanup(self) -> None: + if self.dry_run: + self.robot.close() + return + if not self._entered_debug: + self.robot.close() + return try: - # 1. Enter debug mode - logger.info("Entering debug mode ...") - if not self.enter_debug_mode(): - logger.error("Failed to enter debug mode, aborting") - return + logger.info("Shutting down: setting damping...") + self.robot.set_damping() time.sleep(0.5) - - # 2. Start publish thread (C++: 500Hz, Python: 250Hz) - self._start_publish() - - # 3. Lock joints to current position - logger.info("Locking joints to current position ...") - self._lock_joints() - time.sleep(0.3) - - # 4. 
ONNX warmup — eliminate first-inference spike - logger.info("Warming up ONNX runtime ...") - dummy_obs = np.zeros(self._obs_builder.total_obs_size, dtype=np.float32) - for _ in range(3): - self._policy.compute_action(dummy_obs) - self._policy.reset() - logger.info("ONNX warmup complete") - - # 5. Initialize policy state - self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) - self._start_kp_ramp() - self._start_state_sampler() - self._start_command_sender() - - logger.info("Starting RL policy standing (pipelined)") - - # 6. Start inference thread (~50Hz, soft deadline) - self._start_inference() - - # 7. Main thread waits for shutdown signal - while not self._shutdown: - time.sleep(0.1) - - # 8. Stop inference thread - self._stop_inference() - - except Exception as exc: - logger.error("Error in main loop: %s", exc) finally: - self._cleanup() - - def _signal_handler(self, signum, frame) -> None: - logger.info("Shutdown signal received") - self._shutdown = True - - def _cleanup(self) -> None: - self._inference_running = False - self._stop_state_sampler() - self._stop_command_sender() - logger.info("Shutting down: setting damping ...") - self._set_damping() - time.sleep(0.5) - logger.info("Stopping publish and restoring ai mode ...") - self.exit_debug_mode() - logger.info("Done.") - - -def main(): - parser = argparse.ArgumentParser(description="G1 standalone standing with RL policy") - parser.add_argument( - "--policy", type=str, required=True, - help="Path to ONNX policy file (e.g. track.onnx)", - ) - parser.add_argument( - "--network-interface", type=str, default="eth0", - help="Network interface for DDS (e.g. eth0, enp130s0)", - ) - parser.add_argument( - "--no-policy", action="store_true", - help="Skip RL policy, just send fixed DEFAULT_ANGLES (diagnostic mode)", - ) - parser.add_argument( - "--state-delay", type=float, default=0.0, - help=( - "Legacy loop delay before the policy step. 
This consumes timing budget but does not make " - "the observation stale; prefer --obs-delay or --command-delay for latency tests." - ), - ) - parser.add_argument( - "--obs-delay", type=float, default=0.0, - help="Use LowState sampled this many seconds in the past when building the policy observation.", - ) - parser.add_argument( - "--command-delay", type=float, default=0.0, - help="Delay writing each computed target to the C++ publish thread by this many seconds.", - ) - parser.add_argument( - "--dry-run", action="store_true", - help="Read state + build obs + infer only, no motor commands (safe timing test)", - ) - parser.add_argument( - "--publish-hz", type=int, default=200, - help="C++ publish frequency in Hz (default: 200, matching training pd_hz)", - ) - parser.add_argument( - "--kp-ramp-duration", type=float, default=DEFAULT_KP_RAMP_DURATION, - help="Seconds to ramp Kp after entering standing (default: 2.0, matches sim2real)", - ) - parser.add_argument( - "--kp-ramp-floor-ratio", type=float, default=DEFAULT_KP_RAMP_FLOOR_RATIO, - help="Initial Kp ratio for the standing ramp (default: 0.1, matches sim2real)", + logger.info("Stopping publish and restoring ai mode...") + self.robot.exit_debug_mode() + self.robot.close() + logger.info("Done.") + + +def _build_cfg(args: argparse.Namespace) -> Any: + cfg = OmegaConf.create( + { + "policy_hz": DEFAULT_POLICY_HZ, + "startup_ramp_duration": args.kp_ramp_duration, + "kp_ramp_floor_ratio": args.kp_ramp_floor_ratio, + "joint_vel_limit": args.joint_vel_limit, + "reference_steps": [0], + "reference_velocity_smoothing_alpha": 1.0, + "reference_anchor_velocity_smoothing_alpha": 1.0, + "robot": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "robot" / "g1.yaml"), + "controller": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "controller" / "rl_policy.yaml"), + "real_robot": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "sim2real.yaml").real_robot, + } ) + cfg.controller.policy_path = str(args.policy) + 
cfg.real_robot.network_interface = str(args.network_interface) + cfg.real_robot.publish_hz = int(args.publish_hz) + return cfg + + +def main() -> None: + parser = argparse.ArgumentParser(description="G1 standalone standing using sim2real standing implementation") + parser.add_argument("--policy", type=str, required=True, help="Path to a dual-input ONNX policy") + parser.add_argument("--network-interface", type=str, default="eth0", help="DDS network interface") + parser.add_argument("--dry-run", action="store_true", help="Read state and run policy without motor commands") + parser.add_argument("--no-policy", action="store_true", help="Send zero-action standing target") + parser.add_argument("--publish-hz", type=int, default=200, help="C++ bridge publish frequency") + parser.add_argument("--kp-ramp-duration", type=float, default=2.0, help="Startup Kp ramp duration in seconds") + parser.add_argument("--kp-ramp-floor-ratio", type=float, default=0.1, help="Initial Kp ratio during startup") + parser.add_argument("--joint-vel-limit", type=float, default=10.0, help="Damp if any joint exceeds this velocity") args = parser.parse_args() - controller = StandingController( - network_interface=args.network_interface, - policy_path=args.policy, - no_policy=args.no_policy, - publish_hz=args.publish_hz, - obs_delay=args.obs_delay, - command_delay=args.command_delay, - kp_ramp_duration=args.kp_ramp_duration, - kp_ramp_floor_ratio=args.kp_ramp_floor_ratio, + controller = StandaloneStandingController( + _build_cfg(args), + dry_run=bool(args.dry_run), + no_policy=bool(args.no_policy), ) - controller._state_delay = args.state_delay - controller._dry_run = args.dry_run controller.run() diff --git a/teleopit/sim2real/unitree_g1.py b/teleopit/sim2real/unitree_g1.py index 704b612c..d36c95f6 100644 --- a/teleopit/sim2real/unitree_g1.py +++ b/teleopit/sim2real/unitree_g1.py @@ -30,6 +30,7 @@ class UnitreeG1Robot: def __init__(self, cfg: Any) -> None: self._network_interface: str = 
str(cfg_get(cfg, "network_interface", "eth0")) + self._publish_hz: int = int(cfg_get(cfg, "publish_hz", 200)) self._kp = np.asarray(cfg_get(cfg, "kp_real", [100] * NUM_JOINTS), dtype=np.float32) self._kd = np.asarray(cfg_get(cfg, "kd_real", [2] * NUM_JOINTS), dtype=np.float32) @@ -40,7 +41,7 @@ def __init__(self, cfg: Any) -> None: import g1_bridge_sdk - self._bridge = g1_bridge_sdk.G1Bridge(self._network_interface) + self._bridge = g1_bridge_sdk.G1Bridge(self._network_interface, self._publish_hz) self._publishing: bool = False if not self._bridge.wait_for_state(3.0): From 574d092a51ccf56310d6e6e478801dc609d2113c Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 19 Jun 2026 21:40:22 +0800 Subject: [PATCH 104/122] Add standalone standing timing telemetry --- scripts/run/standalone_standing.py | 191 +++++++++++++++++++++++++++-- 1 file changed, 184 insertions(+), 7 deletions(-) diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 49e7989a..c16f535f 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -4,6 +4,7 @@ from __future__ import annotations import argparse +from dataclasses import dataclass import logging import signal import time @@ -32,6 +33,144 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class _StepTiming: + state_s: float + obs_s: float + infer_s: float + target_s: float + send_s: float + action_delta: float | None + target_delta: float | None + qvel_norm: float + ang_vel_norm: float + + +class _StandaloneTimingReporter: + def __init__( + self, + *, + target_period_s: float, + log_interval_s: float = 1.0, + deadline_miss_tolerance_s: float = 0.001, + ) -> None: + self._target_period_s = float(target_period_s) + self._log_interval_s = float(log_interval_s) + self._deadline_miss_tolerance_s = float(deadline_miss_tolerance_s) + self._window_start_s: float | None = None + self._loop_ms: list[float] = [] + self._late_ms: list[float] = [] + self._work_ms: 
list[float] = [] + self._state_ms: list[float] = [] + self._obs_ms: list[float] = [] + self._infer_ms: list[float] = [] + self._target_ms: list[float] = [] + self._send_ms: list[float] = [] + self._action_delta: list[float] = [] + self._target_delta: list[float] = [] + self._qvel_norm: list[float] = [] + self._ang_vel_norm: list[float] = [] + self._deadline_miss_count = 0 + self._work_overrun_count = 0 + + def record( + self, + *, + loop_start_s: float, + work_elapsed_s: float, + cycle_elapsed_s: float, + step: _StepTiming, + ) -> None: + if self._window_start_s is None: + self._window_start_s = float(loop_start_s) + self._loop_ms.append(float(cycle_elapsed_s) * 1000.0) + self._late_ms.append(max(0.0, float(cycle_elapsed_s) - self._target_period_s) * 1000.0) + self._work_ms.append(float(work_elapsed_s) * 1000.0) + self._state_ms.append(float(step.state_s) * 1000.0) + self._obs_ms.append(float(step.obs_s) * 1000.0) + self._infer_ms.append(float(step.infer_s) * 1000.0) + self._target_ms.append(float(step.target_s) * 1000.0) + self._send_ms.append(float(step.send_s) * 1000.0) + if step.action_delta is not None: + self._action_delta.append(float(step.action_delta)) + if step.target_delta is not None: + self._target_delta.append(float(step.target_delta)) + self._qvel_norm.append(float(step.qvel_norm)) + self._ang_vel_norm.append(float(step.ang_vel_norm)) + if cycle_elapsed_s > self._target_period_s + self._deadline_miss_tolerance_s: + self._deadline_miss_count += 1 + if work_elapsed_s > self._target_period_s + 1e-9: + self._work_overrun_count += 1 + if loop_start_s - self._window_start_s >= self._log_interval_s: + self._emit(loop_start_s) + + def _emit(self, end_s: float) -> None: + sample_count = len(self._loop_ms) + if sample_count <= 0: + self._reset(end_s) + return + logger.info( + "Standalone timing | samples=%d window=%.1fs | " + "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "late_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f deadline_miss(>%.2fms)=%d/%d | " + 
"work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f work_overrun=%d/%d | " + "state_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "obs_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "infer_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "target_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "send_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "action_delta p50=%.4f p95=%.4f p99=%.4f max=%.4f | " + "target_delta p50=%.4f p95=%.4f p99=%.4f max=%.4f | " + "qvel_norm p50=%.4f p95=%.4f p99=%.4f max=%.4f | " + "ang_vel_norm p50=%.4f p95=%.4f p99=%.4f max=%.4f", + sample_count, + end_s - float(self._window_start_s), + *self._summarize(self._loop_ms), + *self._summarize(self._late_ms), + self._deadline_miss_tolerance_s * 1000.0, + self._deadline_miss_count, + sample_count, + *self._summarize(self._work_ms), + self._work_overrun_count, + sample_count, + *self._summarize(self._state_ms), + *self._summarize(self._obs_ms), + *self._summarize(self._infer_ms), + *self._summarize(self._target_ms), + *self._summarize(self._send_ms), + *self._summarize(self._action_delta), + *self._summarize(self._target_delta), + *self._summarize(self._qvel_norm), + *self._summarize(self._ang_vel_norm), + ) + self._reset(end_s) + + def _reset(self, window_start_s: float) -> None: + self._window_start_s = float(window_start_s) + self._loop_ms.clear() + self._late_ms.clear() + self._work_ms.clear() + self._state_ms.clear() + self._obs_ms.clear() + self._infer_ms.clear() + self._target_ms.clear() + self._send_ms.clear() + self._action_delta.clear() + self._target_delta.clear() + self._qvel_norm.clear() + self._ang_vel_norm.clear() + self._deadline_miss_count = 0 + self._work_overrun_count = 0 + + @staticmethod + def _summarize(samples: list[float]) -> tuple[float, float, float, float]: + if not samples: + return 0.0, 0.0, 0.0, 0.0 + values = np.asarray(samples, dtype=np.float64) + p50, p95, p99 = np.percentile(values, [50.0, 95.0, 99.0]) + return float(p50), float(p95), float(p99), float(np.max(values)) + + class 
StandaloneStandingController: """Small wrapper around the production sim2real STANDING implementation.""" @@ -79,7 +218,9 @@ def __init__(self, cfg: Any, *, dry_run: bool = False, no_policy: bool = False) self._standing_qpos[3] = 1.0 self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._last_target: np.ndarray | None = None self._step_count = 0 + self._timing = _StandaloneTimingReporter(target_period_s=self.dt) def _build_policy_and_obs(self) -> tuple[Any, Any]: controller_cfg = cfg_get(self.cfg, "controller") @@ -138,11 +279,20 @@ def _run_control_loop(self) -> None: self.shutdown_requested = True break - self._standing_step() - self._sleep_until(t0) - - def _standing_step(self) -> None: + step_timing = self._standing_step() + work_elapsed_s = time.monotonic() - t0 + cycle_elapsed_s = self._sleep_until(t0) + self._timing.record( + loop_start_s=t0, + work_elapsed_s=work_elapsed_s, + cycle_elapsed_s=cycle_elapsed_s, + step=step_timing, + ) + + def _standing_step(self) -> _StepTiming: + t0 = time.monotonic() robot_state = self.robot.get_state() + t_state = time.monotonic() qpos = self._standing_qpos.copy() motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) motion_qpos = np.asarray(qpos[: ROOT_DIM + self.num_actions], dtype=np.float32) @@ -159,14 +309,32 @@ def _standing_step(self) -> None: reference_window=reference_window, ) obs = self.ref_proc.validate_observation(obs) + t_obs = time.monotonic() action = np.zeros(self.num_actions, dtype=np.float32) if self.no_policy else self.policy.compute_action(obs) + t_infer = time.monotonic() target_dof_pos = self.safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + t_target = time.monotonic() + action_delta = float(np.linalg.norm(action - self._last_action)) + target_delta = None if self._last_target is None else float(np.linalg.norm(target_dof_pos - self._last_target)) if self.dry_run: 
self._log_step(robot_state.qvel, action, target_dof_pos, dry=True) else: self.safety.send_positions(target_dof_pos) self._log_step(robot_state.qvel, action, target_dof_pos, dry=False) + t_send = time.monotonic() self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) + self._last_target = np.asarray(target_dof_pos, dtype=np.float32).reshape(-1).copy() + return _StepTiming( + state_s=t_state - t0, + obs_s=t_obs - t_state, + infer_s=t_infer - t_obs, + target_s=t_target - t_infer, + send_s=t_send - t_target, + action_delta=action_delta, + target_delta=target_delta, + qvel_norm=float(np.linalg.norm(robot_state.qvel)), + ang_vel_norm=float(np.linalg.norm(robot_state.ang_vel)), + ) def _run_dry(self) -> None: logger.info("=== DRY-RUN MODE: no motor commands will be sent ===") @@ -176,8 +344,15 @@ def _run_dry(self) -> None: self._reset_policy_state() while not self.shutdown_requested: t0 = time.monotonic() - self._standing_step() - self._sleep_until(t0) + step_timing = self._standing_step() + work_elapsed_s = time.monotonic() - t0 + cycle_elapsed_s = self._sleep_until(t0) + self._timing.record( + loop_start_s=t0, + work_elapsed_s=work_elapsed_s, + cycle_elapsed_s=cycle_elapsed_s, + step=step_timing, + ) def _set_default_standing_reference(self, state: object) -> None: self._standing_qpos[:] = 0.0 @@ -192,6 +367,7 @@ def _reset_policy_state(self) -> None: self.ref_proc.reset_alignment() self.policy.reset() self.obs_builder.reset() + self._last_target = None def _warmup_policy(self) -> None: if self.no_policy: @@ -217,10 +393,11 @@ def _log_step(self, qvel: np.ndarray, action: np.ndarray, target: np.ndarray, *, np.array2string(target[:6], precision=4, separator=","), ) - def _sleep_until(self, t0: float) -> None: + def _sleep_until(self, t0: float) -> float: remaining = self.dt - (time.monotonic() - t0) if remaining > 0.0: time.sleep(remaining) + return time.monotonic() - t0 def _signal_handler(self, _signum: int, _frame: object) -> None: logger.info("Shutdown 
signal received") From e5f8771094478e2bc2781b4a71a489ca7eb8e1b5 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 19 Jun 2026 22:47:33 +0800 Subject: [PATCH 105/122] Add standalone standing delay diagnostics --- scripts/run/standalone_standing.py | 44 +++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index c16f535f..1670746e 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -174,10 +174,20 @@ def _summarize(samples: list[float]) -> tuple[float, float, float, float]: class StandaloneStandingController: """Small wrapper around the production sim2real STANDING implementation.""" - def __init__(self, cfg: Any, *, dry_run: bool = False, no_policy: bool = False) -> None: + def __init__( + self, + cfg: Any, + *, + dry_run: bool = False, + no_policy: bool = False, + obs_delay_s: float = 0.0, + command_delay_s: float = 0.0, + ) -> None: self.cfg = cfg self.dry_run = dry_run self.no_policy = no_policy + self.obs_delay_s = self._validate_delay_s(obs_delay_s, name="obs_delay_s") + self.command_delay_s = self._validate_delay_s(command_delay_s, name="command_delay_s") self.shutdown_requested = False self.policy_hz = float(cfg_get(cfg, "policy_hz", DEFAULT_POLICY_HZ)) @@ -222,6 +232,20 @@ def __init__(self, cfg: Any, *, dry_run: bool = False, no_policy: bool = False) self._step_count = 0 self._timing = _StandaloneTimingReporter(target_period_s=self.dt) + if self.obs_delay_s > 0.0 or self.command_delay_s > 0.0: + logger.info( + "Diagnostic delay injection enabled | obs_delay=%.3fms command_delay=%.3fms", + self.obs_delay_s * 1000.0, + self.command_delay_s * 1000.0, + ) + + @staticmethod + def _validate_delay_s(value: float, *, name: str) -> float: + delay_s = float(value) + if not np.isfinite(delay_s) or delay_s < 0.0: + raise ValueError(f"{name} must be finite and >= 0, got {value!r}") + return delay_s + def 
_build_policy_and_obs(self) -> tuple[Any, Any]: controller_cfg = cfg_get(self.cfg, "controller") policy, obs_builder = _build_policy_components( @@ -293,6 +317,8 @@ def _standing_step(self) -> _StepTiming: t0 = time.monotonic() robot_state = self.robot.get_state() t_state = time.monotonic() + if self.obs_delay_s > 0.0: + time.sleep(self.obs_delay_s) qpos = self._standing_qpos.copy() motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) motion_qpos = np.asarray(qpos[: ROOT_DIM + self.num_actions], dtype=np.float32) @@ -319,6 +345,8 @@ def _standing_step(self) -> _StepTiming: if self.dry_run: self._log_step(robot_state.qvel, action, target_dof_pos, dry=True) else: + if self.command_delay_s > 0.0: + time.sleep(self.command_delay_s) self.safety.send_positions(target_dof_pos) self._log_step(robot_state.qvel, action, target_dof_pos, dry=False) t_send = time.monotonic() @@ -452,12 +480,26 @@ def main() -> None: parser.add_argument("--kp-ramp-duration", type=float, default=2.0, help="Startup Kp ramp duration in seconds") parser.add_argument("--kp-ramp-floor-ratio", type=float, default=0.1, help="Initial Kp ratio during startup") parser.add_argument("--joint-vel-limit", type=float, default=10.0, help="Damp if any joint exceeds this velocity") + parser.add_argument( + "--obs-delay-ms", + type=float, + default=0.0, + help="Diagnostic delay after LowState read, before observation build/inference", + ) + parser.add_argument( + "--command-delay-ms", + type=float, + default=0.0, + help="Diagnostic delay after target computation, before C++ bridge set_target", + ) args = parser.parse_args() controller = StandaloneStandingController( _build_cfg(args), dry_run=bool(args.dry_run), no_policy=bool(args.no_policy), + obs_delay_s=float(args.obs_delay_ms) / 1000.0, + command_delay_s=float(args.command_delay_ms) / 1000.0, ) controller.run() From 64dab5642adbb1e1338e3b6eaa0e2a70e81f8e4c Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Fri, 19 Jun 2026 23:31:02 +0800 Subject: 
[PATCH 106/122] Tune training action latency and rewards --- tests/test_domain_randomization.py | 42 +++++++++++++++++++ train_mimic/scripts/train.py | 5 ++- train_mimic/tasks/tracking/config/env.py | 40 +++++++++++++++++- .../tasks/tracking/tracking_env_cfg.py | 2 +- 4 files changed, 85 insertions(+), 4 deletions(-) diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index b98ce407..4b17d894 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -83,3 +83,45 @@ def test_play_env_disables_training_only_domain_randomization() -> None: assert "physics_material" not in play_cfg.events assert "randomize_rigid_body_mass" not in play_cfg.events assert play_cfg.events == {} + + +def test_general_tracking_enables_action_latency_randomization_for_training_only() -> None: + import mjlab.tasks # noqa: F401 + import train_mimic.tasks # noqa: F401 + from mjlab.actuator.delayed_actuator import DelayedActuatorCfg + from mjlab.tasks.registry import load_env_cfg + + env_cfg = load_env_cfg(DEFAULT_TASK) + for group_name in ("actor", "actor_history", "critic", "critic_history"): + for term_name, term_cfg in env_cfg.observations[group_name].terms.items(): + assert term_cfg.delay_min_lag == 0, (group_name, term_name) + assert term_cfg.delay_max_lag == 0, (group_name, term_name) + + actuators = env_cfg.scene.entities["robot"].articulation.actuators + assert actuators + assert all(isinstance(actuator, DelayedActuatorCfg) for actuator in actuators) + assert all(actuator.delay_target == "position" for actuator in actuators) + assert all(actuator.delay_min_lag == 0 for actuator in actuators) + assert all(actuator.delay_max_lag == 1 for actuator in actuators) + assert all(actuator.delay_hold_prob == 0.8 for actuator in actuators) + assert all(actuator.delay_update_period == 4 for actuator in actuators) + + play_cfg = load_env_cfg(DEFAULT_TASK, play=True) + for group_name in ("actor", "actor_history", "critic", 
"critic_history"): + for term_name, term_cfg in play_cfg.observations[group_name].terms.items(): + assert term_cfg.delay_max_lag == 0, (group_name, term_name) + assert not any( + isinstance(actuator, DelayedActuatorCfg) + for actuator in play_cfg.scene.entities["robot"].articulation.actuators + ) + + +def test_g1_training_robot_cfg_can_enable_action_latency_randomization() -> None: + from mjlab.actuator.delayed_actuator import DelayedActuatorCfg + from train_mimic.tasks.tracking.config.env import make_g1_training_robot_cfg + + robot_cfg = make_g1_training_robot_cfg(action_latency_randomization=True) + + actuators = robot_cfg.articulation.actuators + assert actuators + assert all(isinstance(actuator, DelayedActuatorCfg) for actuator in actuators) diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 4f8bf7b3..353aa0de 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -385,7 +385,10 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: robot_xml = resolve_g1_training_xml(args.robot_xml) if not robot_xml.is_file(): raise FileNotFoundError(f"G1 training MuJoCo XML not found: {robot_xml}") - env_cfg.scene.entities["robot"] = make_g1_training_robot_cfg(robot_xml) + env_cfg.scene.entities["robot"] = make_g1_training_robot_cfg( + robot_xml, + action_latency_randomization=True, + ) env_cfg.robot_xml = str(robot_xml) if args.num_envs is not None: env_cfg.scene.num_envs = args.num_envs diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 21dbf3cc..56a2dab7 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -8,6 +8,7 @@ import mujoco +from mjlab.actuator.delayed_actuator import DelayedActuatorCfg from mjlab.asset_zoo.robots import G1_ACTION_SCALE, get_g1_robot_cfg from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg @@ -48,6 +49,11 @@ "randomize_rigid_body_mass", ) 
+_ACTION_LATENCY_RANDOMIZATION_MIN_LAG_PHYSICS_STEPS = 0 +_ACTION_LATENCY_RANDOMIZATION_MAX_LAG_PHYSICS_STEPS = 1 +_ACTION_LATENCY_RANDOMIZATION_HOLD_PROB = 0.8 +_ACTION_LATENCY_RANDOMIZATION_UPDATE_PERIOD_PHYSICS_STEPS = 4 + def resolve_g1_training_xml(robot_xml: str | Path | None = None) -> Path: """Resolve the MuJoCo XML used for G1 policy training.""" @@ -74,13 +80,39 @@ def _get_g1_training_spec(robot_xml: str | Path | None = None) -> mujoco.MjSpec: return spec -def make_g1_training_robot_cfg(robot_xml: str | Path | None = None): +def make_g1_training_robot_cfg( + robot_xml: str | Path | None = None, + *, + action_latency_randomization: bool = False, +): robot_cfg = get_g1_robot_cfg() + robot_cfg.articulation = deepcopy(robot_cfg.articulation) + if action_latency_randomization: + _enable_robot_action_latency_randomization(robot_cfg) xml_path = resolve_g1_training_xml(robot_xml) robot_cfg.spec_fn = partial(_get_g1_training_spec, xml_path) return robot_cfg +def _enable_robot_action_latency_randomization(robot_cfg) -> None: + articulation = robot_cfg.articulation + if articulation is None: + raise ValueError( + "G1 robot cfg must define articulation actuators before enabling action latency randomization" + ) + articulation.actuators = tuple( + DelayedActuatorCfg( + base_cfg=actuator, + delay_target="position", + delay_min_lag=_ACTION_LATENCY_RANDOMIZATION_MIN_LAG_PHYSICS_STEPS, + delay_max_lag=_ACTION_LATENCY_RANDOMIZATION_MAX_LAG_PHYSICS_STEPS, + delay_hold_prob=_ACTION_LATENCY_RANDOMIZATION_HOLD_PROB, + delay_update_period=_ACTION_LATENCY_RANDOMIZATION_UPDATE_PERIOD_PHYSICS_STEPS, + ) + for actuator in articulation.actuators + ) + + def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: motion_cmd = cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) @@ -219,7 +251,11 @@ def make_general_tracking_env_cfg( """Create the General-Tracking-G1 training env.""" cfg = make_tracking_env_cfg() - cfg.scene.entities = {"robot": 
make_g1_training_robot_cfg()} + cfg.scene.entities = { + "robot": make_g1_training_robot_cfg( + action_latency_randomization=not play, + ) + } joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 85801faf..885e247f 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -269,7 +269,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: params={"command_name": "motion", "std": 3.0}, ), "survival": RewardTermCfg(func=mdp.survival, weight=3.0), - "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-0.5), + "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-0.1), "joint_limit": RewardTermCfg( func=mdp.joint_pos_limits, weight=-10.0, From 2c6b89c5d0ed82b8fd8c6ddb349f09d6c7ea5997 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Sat, 20 Jun 2026 00:00:17 +0800 Subject: [PATCH 107/122] Update training stack for mjlab 1.4 --- pyproject.toml | 8 ++++---- tests/test_domain_randomization.py | 10 +++------- train_mimic/tasks/tracking/config/env.py | 19 ++++++++++--------- .../tasks/tracking/tracking_env_cfg.py | 2 +- 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ca01605b..508afa4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,11 +41,11 @@ sim2real = [ # bash scripts/setup/setup_g1_bridge.sh ] train = [ - "torch>=2.0.0", + "torch>=2.7.0", "numpy>=1.20.0", - "rsl-rl-lib", - "mjlab>=1.2.0", - "wandb>=0.15.0", + "rsl-rl-lib==5.2.0", + "mjlab==1.4.0", + "wandb>=0.22.3", "swanlab", "tqdm>=4.65.0", ] diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index 4b17d894..263e42bf 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -88,7 +88,6 @@ def 
test_play_env_disables_training_only_domain_randomization() -> None: def test_general_tracking_enables_action_latency_randomization_for_training_only() -> None: import mjlab.tasks # noqa: F401 import train_mimic.tasks # noqa: F401 - from mjlab.actuator.delayed_actuator import DelayedActuatorCfg from mjlab.tasks.registry import load_env_cfg env_cfg = load_env_cfg(DEFAULT_TASK) @@ -99,8 +98,6 @@ def test_general_tracking_enables_action_latency_randomization_for_training_only actuators = env_cfg.scene.entities["robot"].articulation.actuators assert actuators - assert all(isinstance(actuator, DelayedActuatorCfg) for actuator in actuators) - assert all(actuator.delay_target == "position" for actuator in actuators) assert all(actuator.delay_min_lag == 0 for actuator in actuators) assert all(actuator.delay_max_lag == 1 for actuator in actuators) assert all(actuator.delay_hold_prob == 0.8 for actuator in actuators) @@ -110,18 +107,17 @@ def test_general_tracking_enables_action_latency_randomization_for_training_only for group_name in ("actor", "actor_history", "critic", "critic_history"): for term_name, term_cfg in play_cfg.observations[group_name].terms.items(): assert term_cfg.delay_max_lag == 0, (group_name, term_name) - assert not any( - isinstance(actuator, DelayedActuatorCfg) + assert all( + actuator.delay_max_lag == 0 for actuator in play_cfg.scene.entities["robot"].articulation.actuators ) def test_g1_training_robot_cfg_can_enable_action_latency_randomization() -> None: - from mjlab.actuator.delayed_actuator import DelayedActuatorCfg from train_mimic.tasks.tracking.config.env import make_g1_training_robot_cfg robot_cfg = make_g1_training_robot_cfg(action_latency_randomization=True) actuators = robot_cfg.articulation.actuators assert actuators - assert all(isinstance(actuator, DelayedActuatorCfg) for actuator in actuators) + assert all(actuator.delay_max_lag == 1 for actuator in actuators) diff --git a/train_mimic/tasks/tracking/config/env.py 
b/train_mimic/tasks/tracking/config/env.py index 56a2dab7..403d2c48 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -8,7 +8,6 @@ import mujoco -from mjlab.actuator.delayed_actuator import DelayedActuatorCfg from mjlab.asset_zoo.robots import G1_ACTION_SCALE, get_g1_robot_cfg from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg @@ -101,18 +100,20 @@ def _enable_robot_action_latency_randomization(robot_cfg) -> None: "G1 robot cfg must define articulation actuators before enabling action latency randomization" ) articulation.actuators = tuple( - DelayedActuatorCfg( - base_cfg=actuator, - delay_target="position", - delay_min_lag=_ACTION_LATENCY_RANDOMIZATION_MIN_LAG_PHYSICS_STEPS, - delay_max_lag=_ACTION_LATENCY_RANDOMIZATION_MAX_LAG_PHYSICS_STEPS, - delay_hold_prob=_ACTION_LATENCY_RANDOMIZATION_HOLD_PROB, - delay_update_period=_ACTION_LATENCY_RANDOMIZATION_UPDATE_PERIOD_PHYSICS_STEPS, - ) + _with_action_latency_randomization(actuator) for actuator in articulation.actuators ) +def _with_action_latency_randomization(actuator): + actuator = deepcopy(actuator) + actuator.delay_min_lag = _ACTION_LATENCY_RANDOMIZATION_MIN_LAG_PHYSICS_STEPS + actuator.delay_max_lag = _ACTION_LATENCY_RANDOMIZATION_MAX_LAG_PHYSICS_STEPS + actuator.delay_hold_prob = _ACTION_LATENCY_RANDOMIZATION_HOLD_PROB + actuator.delay_update_period = _ACTION_LATENCY_RANDOMIZATION_UPDATE_PERIOD_PHYSICS_STEPS + return actuator + + def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: motion_cmd = cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 885e247f..53aea94c 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -1,6 +1,6 @@ """Base motion tracking task configuration. 
-Copied from mjlab 1.2.0 ``mjlab.tasks.tracking.tracking_env_cfg`` for local +Copied from mjlab 1.4.0 ``mjlab.tasks.tracking.tracking_env_cfg`` for local customisation. All observation / reward / termination / event terms still reference ``mjlab.tasks.tracking.mdp`` — only the *wiring* lives here. """ From b83bccdc2ffd45b085db7e3d8e046fa0511387dd Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Sun, 21 Jun 2026 17:33:01 +0800 Subject: [PATCH 108/122] Remove training action latency randomization --- tests/test_domain_randomization.py | 38 ----------------- train_mimic/scripts/train.py | 5 +-- train_mimic/tasks/tracking/config/env.py | 41 +------------------ .../tasks/tracking/tracking_env_cfg.py | 2 +- 4 files changed, 4 insertions(+), 82 deletions(-) diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py index 263e42bf..b98ce407 100644 --- a/tests/test_domain_randomization.py +++ b/tests/test_domain_randomization.py @@ -83,41 +83,3 @@ def test_play_env_disables_training_only_domain_randomization() -> None: assert "physics_material" not in play_cfg.events assert "randomize_rigid_body_mass" not in play_cfg.events assert play_cfg.events == {} - - -def test_general_tracking_enables_action_latency_randomization_for_training_only() -> None: - import mjlab.tasks # noqa: F401 - import train_mimic.tasks # noqa: F401 - from mjlab.tasks.registry import load_env_cfg - - env_cfg = load_env_cfg(DEFAULT_TASK) - for group_name in ("actor", "actor_history", "critic", "critic_history"): - for term_name, term_cfg in env_cfg.observations[group_name].terms.items(): - assert term_cfg.delay_min_lag == 0, (group_name, term_name) - assert term_cfg.delay_max_lag == 0, (group_name, term_name) - - actuators = env_cfg.scene.entities["robot"].articulation.actuators - assert actuators - assert all(actuator.delay_min_lag == 0 for actuator in actuators) - assert all(actuator.delay_max_lag == 1 for actuator in actuators) - assert all(actuator.delay_hold_prob == 0.8 for 
actuator in actuators) - assert all(actuator.delay_update_period == 4 for actuator in actuators) - - play_cfg = load_env_cfg(DEFAULT_TASK, play=True) - for group_name in ("actor", "actor_history", "critic", "critic_history"): - for term_name, term_cfg in play_cfg.observations[group_name].terms.items(): - assert term_cfg.delay_max_lag == 0, (group_name, term_name) - assert all( - actuator.delay_max_lag == 0 - for actuator in play_cfg.scene.entities["robot"].articulation.actuators - ) - - -def test_g1_training_robot_cfg_can_enable_action_latency_randomization() -> None: - from train_mimic.tasks.tracking.config.env import make_g1_training_robot_cfg - - robot_cfg = make_g1_training_robot_cfg(action_latency_randomization=True) - - actuators = robot_cfg.articulation.actuators - assert actuators - assert all(actuator.delay_max_lag == 1 for actuator in actuators) diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index 353aa0de..4f8bf7b3 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -385,10 +385,7 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: robot_xml = resolve_g1_training_xml(args.robot_xml) if not robot_xml.is_file(): raise FileNotFoundError(f"G1 training MuJoCo XML not found: {robot_xml}") - env_cfg.scene.entities["robot"] = make_g1_training_robot_cfg( - robot_xml, - action_latency_randomization=True, - ) + env_cfg.scene.entities["robot"] = make_g1_training_robot_cfg(robot_xml) env_cfg.robot_xml = str(robot_xml) if args.num_envs is not None: env_cfg.scene.num_envs = args.num_envs diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index 403d2c48..48bc3ecb 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -48,12 +48,6 @@ "randomize_rigid_body_mass", ) -_ACTION_LATENCY_RANDOMIZATION_MIN_LAG_PHYSICS_STEPS = 0 -_ACTION_LATENCY_RANDOMIZATION_MAX_LAG_PHYSICS_STEPS = 1 -_ACTION_LATENCY_RANDOMIZATION_HOLD_PROB = 
0.8 -_ACTION_LATENCY_RANDOMIZATION_UPDATE_PERIOD_PHYSICS_STEPS = 4 - - def resolve_g1_training_xml(robot_xml: str | Path | None = None) -> Path: """Resolve the MuJoCo XML used for G1 policy training.""" if robot_xml is None or str(robot_xml).strip() == "": @@ -79,41 +73,14 @@ def _get_g1_training_spec(robot_xml: str | Path | None = None) -> mujoco.MjSpec: return spec -def make_g1_training_robot_cfg( - robot_xml: str | Path | None = None, - *, - action_latency_randomization: bool = False, -): +def make_g1_training_robot_cfg(robot_xml: str | Path | None = None): robot_cfg = get_g1_robot_cfg() robot_cfg.articulation = deepcopy(robot_cfg.articulation) - if action_latency_randomization: - _enable_robot_action_latency_randomization(robot_cfg) xml_path = resolve_g1_training_xml(robot_xml) robot_cfg.spec_fn = partial(_get_g1_training_spec, xml_path) return robot_cfg -def _enable_robot_action_latency_randomization(robot_cfg) -> None: - articulation = robot_cfg.articulation - if articulation is None: - raise ValueError( - "G1 robot cfg must define articulation actuators before enabling action latency randomization" - ) - articulation.actuators = tuple( - _with_action_latency_randomization(actuator) - for actuator in articulation.actuators - ) - - -def _with_action_latency_randomization(actuator): - actuator = deepcopy(actuator) - actuator.delay_min_lag = _ACTION_LATENCY_RANDOMIZATION_MIN_LAG_PHYSICS_STEPS - actuator.delay_max_lag = _ACTION_LATENCY_RANDOMIZATION_MAX_LAG_PHYSICS_STEPS - actuator.delay_hold_prob = _ACTION_LATENCY_RANDOMIZATION_HOLD_PROB - actuator.delay_update_period = _ACTION_LATENCY_RANDOMIZATION_UPDATE_PERIOD_PHYSICS_STEPS - return actuator - - def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: motion_cmd = cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) @@ -252,11 +219,7 @@ def make_general_tracking_env_cfg( """Create the General-Tracking-G1 training env.""" cfg = make_tracking_env_cfg() - cfg.scene.entities = { - 
"robot": make_g1_training_robot_cfg( - action_latency_randomization=not play, - ) - } + cfg.scene.entities = {"robot": make_g1_training_robot_cfg()} joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index 53aea94c..eb1faaae 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -269,7 +269,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: params={"command_name": "motion", "std": 3.0}, ), "survival": RewardTermCfg(func=mdp.survival, weight=3.0), - "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-0.1), + "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-0.5), "joint_limit": RewardTermCfg( func=mdp.joint_pos_limits, weight=-10.0, From 48b35a66e49ca8b7af93aa2bbd83cbf1254afe35 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 15:06:01 +0800 Subject: [PATCH 109/122] Remove sim HDF5 recording support --- AGENTS.md | 2 +- docs/docs/contributing.md | 2 +- docs/docs/getting-started/quick-start.md | 2 +- docs/docs/reference/architecture.md | 2 +- docs/docs/tutorials/offline-sim2sim.md | 14 --- .../current/contributing.md | 2 +- .../current/getting-started/quick-start.md | 2 +- .../current/reference/architecture.md | 2 +- .../current/tutorials/offline-sim2sim.md | 14 --- scripts/render/render_sim.py | 1 - scripts/run/run_sim.py | 3 +- teleopit/interfaces.py | 9 -- teleopit/pipeline.py | 25 +--- teleopit/recording/__init__.py | 2 - teleopit/recording/hdf5_recorder.py | 114 ------------------ teleopit/sim/loop.py | 24 +--- teleopit/sim/runtime_components.py | 27 +---- teleopit/sim/session.py | 5 +- tests/test_e2e.py | 27 +---- tests/test_interfaces.py | 9 -- tests/test_recording.py | 84 ------------- tests/test_sim_loop.py | 16 +-- 22 files changed, 20 insertions(+), 368 deletions(-) delete mode 100644 
teleopit/recording/hdf5_recorder.py delete mode 100644 tests/test_recording.py diff --git a/AGENTS.md b/AGENTS.md index 6cd487e3..d41a5c05 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -57,7 +57,7 @@ teleopit/ # Core inference package ├── sim2real/ │ ├── mp/ # Process-isolated sim2real runtime and IPC │ └── hands/ # Optional LinkerHand driver/mapper plugins -└── recording/ # HDF5Recorder and Pico motion NPZ recording helpers +└── recording/ # Pico motion NPZ recording helpers scripts/ ├── run/run_sim.py # Offline sim2sim pipeline ├── run/run_sim2real.py # G1 sim2real control; supports offline BVH playback and Pico4 diff --git a/docs/docs/contributing.md b/docs/docs/contributing.md index b38bda60..465fc06d 100644 --- a/docs/docs/contributing.md +++ b/docs/docs/contributing.md @@ -34,7 +34,7 @@ teleopit/ # Core inference & deployment package ├── retargeting/ # GMR motion retargeting ├── sim/ # SimulationLoop, reference motion utilities ├── sim2real/ # Hardware state machines -├── recording/ # HDF5Recorder +├── recording/ # Pico motion recording helpers ├── runtime/ # Config parsing, factories, external assets ├── bus/ # InProcessBus for inter-component communication └── configs/ # Hydra YAML configurations diff --git a/docs/docs/getting-started/quick-start.md b/docs/docs/getting-started/quick-start.md index 41203821..bede79a1 100644 --- a/docs/docs/getting-started/quick-start.md +++ b/docs/docs/getting-started/quick-start.md @@ -55,7 +55,7 @@ python scripts/run/run_sim.py controller.policy_path=track.onnx 'viewers=[retarg ## What's Next -- [Offline Sim2Sim Tutorial](../tutorials/offline-sim2sim) - Full guide with recording and rendering +- [Offline Sim2Sim Tutorial](../tutorials/offline-sim2sim) - Full guide with rendering - [Pico Sim2Sim](../tutorials/pico-sim2sim) - Verify Pico tracking in MuJoCo - [Standalone Standing](../tutorials/standalone-standing) - Check G1 bridge, network, and policy standing - [Pico Sim2Real](../tutorials/pico-sim2real) - Deploy Pico 
teleoperation to Unitree G1 diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index b4d149ab..8d5bfddc 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -40,7 +40,7 @@ train_mimic/scripts/data | Module | Role | |--------|------| -| `teleopit/interfaces.py` | Stable protocols: InputProvider, Retargeter, Controller, Robot, ObservationBuilder, Recorder | +| `teleopit/interfaces.py` | Stable protocols: InputProvider, Retargeter, Controller, Robot, ObservationBuilder | | `teleopit/runtime/` | Config parsing, path normalization, component assembly, CLI validation | | `teleopit/pipeline.py` | Lightweight facade for offline sim | | `teleopit/sim2real/mp/` | Process-isolated sim2real state machine, IPC, and robot-control loop | diff --git a/docs/docs/tutorials/offline-sim2sim.md b/docs/docs/tutorials/offline-sim2sim.md index 172e1868..57455374 100644 --- a/docs/docs/tutorials/offline-sim2sim.md +++ b/docs/docs/tutorials/offline-sim2sim.md @@ -68,20 +68,6 @@ viewers=none # Headless When all active viewer windows are closed, the simulation ends automatically. ::: -## Recording - -Record simulation data to HDF5: - -```bash -python scripts/run/run_sim.py \ - controller.policy_path=track.onnx \ - input.bvh_file=data/sample_bvh/aiming1_subject1.bvh \ - +record=true \ - recording.output_path=outputs/session.h5 -``` - -Recorded fields: `joint_pos`, `joint_vel`, `mimic_obs`, `action`, `target_dof_pos`, `torque`, `timestamp`. 
- ## Offline Rendering Render simulation to video (headless): diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md index ab2e7b2c..2056aa8b 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md @@ -34,7 +34,7 @@ teleopit/ # 核心推理与部署包 ├── retargeting/ # GMR 动作重定向 ├── sim/ # SimulationLoop、参考运动工具 ├── sim2real/ # 真机状态机 -├── recording/ # HDF5Recorder +├── recording/ # Pico motion 录制辅助工具 ├── runtime/ # 配置解析、工厂、外部资源管理 ├── bus/ # InProcessBus 进程内通信 └── configs/ # Hydra YAML 配置 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md index 5b05d1c4..1fee2d36 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md @@ -55,7 +55,7 @@ python scripts/run/run_sim.py controller.policy_path=track.onnx 'viewers=[retarg ## 下一步 -- [离线 Sim2Sim 教程](../tutorials/offline-sim2sim) - 包含录制和渲染的完整指南 +- [离线 Sim2Sim 教程](../tutorials/offline-sim2sim) - 包含渲染的完整指南 - [Pico Sim2Sim](../tutorials/pico-sim2sim) - 在 MuJoCo 中验证 Pico 追踪 - [独立站立测试](../tutorials/standalone-standing) - 检查 G1 bridge、网络和 policy 站立 - [Pico Sim2Real](../tutorials/pico-sim2real) - 将 Pico 遥操作部署到 Unitree G1 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index 4f48fdbe..3f868eb1 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -40,7 +40,7 @@ train_mimic/scripts/data | 模块 | 职责 | 
|------|------| -| `teleopit/interfaces.py` | 稳定协议:InputProvider、Retargeter、Controller、Robot、ObservationBuilder、Recorder | +| `teleopit/interfaces.py` | 稳定协议:InputProvider、Retargeter、Controller、Robot、ObservationBuilder | | `teleopit/runtime/` | 配置解析、路径规范化、组件装配、CLI 校验 | | `teleopit/pipeline.py` | 离线仿真的轻量 facade | | `teleopit/sim2real/mp/` | 进程隔离的 sim2real 状态机、IPC 和机器人控制循环 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md index 66d317f7..b5731659 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md @@ -68,20 +68,6 @@ viewers=none # 无头模式(不显示窗口) 当所有 Viewer 窗口被关闭后,仿真会自动结束。 ::: -## 录制 - -将仿真数据录制为 HDF5 文件: - -```bash -python scripts/run/run_sim.py \ - controller.policy_path=track.onnx \ - input.bvh_file=data/sample_bvh/aiming1_subject1.bvh \ - +record=true \ - recording.output_path=outputs/session.h5 -``` - -录制包含以下字段:`joint_pos`、`joint_vel`、`mimic_obs`、`action`、`target_dof_pos`、`torque`、`timestamp`。 - ## 离线渲染 在无头模式下将仿真渲染为视频: diff --git a/scripts/render/render_sim.py b/scripts/render/render_sim.py index fe3d0201..9ec33146 100644 --- a/scripts/render/render_sim.py +++ b/scripts/render/render_sim.py @@ -297,7 +297,6 @@ def render_sim2sim( "policy_hz": POLICY_HZ, "pd_hz": PD_HZ, "debug_trace_path": debug_trace_path, - "recording": {"output_path": "/tmp/teleopit_render.h5"}, } ) pipeline = TeleopPipeline(cfg) diff --git a/scripts/run/run_sim.py b/scripts/run/run_sim.py index a695945a..491b9b19 100644 --- a/scripts/run/run_sim.py +++ b/scripts/run/run_sim.py @@ -27,11 +27,10 @@ def main(cfg: DictConfig) -> None: validate_policy_path(cfg, "run_sim.py") pipeline = TeleopPipeline(cfg) num_steps = int(cfg.get("num_steps", 0)) - record = bool(cfg.get("record", False)) if cfg.input.get("provider") == "pico4": 
print("Waiting for Pico4 body tracking data...") _print_sim_controls(cfg) - result = pipeline.run(num_steps=num_steps, record=record) + result = pipeline.run(num_steps=num_steps) print(result) diff --git a/teleopit/interfaces.py b/teleopit/interfaces.py index c115906f..68c79ffc 100644 --- a/teleopit/interfaces.py +++ b/teleopit/interfaces.py @@ -122,15 +122,6 @@ def subscribe(self, topic: str) -> Any: ... -@runtime_checkable -class Recorder(Protocol): - """Records teleoperation data.""" - - def add_frame(self, data: Dict[str, Any]) -> None: - """Record a single frame of data.""" - ... - - @runtime_checkable class ObservationBuilder(Protocol): """Builds observations for controller from robot state.""" diff --git a/teleopit/pipeline.py b/teleopit/pipeline.py index 5ff4b5f7..234247b6 100644 --- a/teleopit/pipeline.py +++ b/teleopit/pipeline.py @@ -10,7 +10,6 @@ from teleopit.controllers.rl_policy import RLPolicyController from teleopit.inputs import BVHInputProvider, Pico4InputProvider from teleopit.inputs.pico_video import PicoVideoRuntime, parse_pico_video_config -from teleopit.recording.hdf5_recorder import HDF5Recorder from teleopit.retargeting.core import RetargetingModule from teleopit.robots.mujoco_robot import MuJoCoRobot from teleopit.runtime.common import cfg_get @@ -56,29 +55,9 @@ def __init__(self, cfg: DictConfig | dict[str, Any]) -> None: video_runtime=self.video_runtime, ) - def run(self, num_steps: int, record: bool = False) -> dict[str, float | int | str]: + def run(self, num_steps: int) -> dict[str, float | int | str]: if num_steps < 0: raise ValueError("num_steps must be non-negative (0 = infinite)") self.controller.reset() - - if not record: - return dict(self.loop.run(cast(Any, self.input_provider), cast(Any, self.retargeter), num_steps=num_steps)) - - rec_cfg = cast(Any, cfg_get(self.cfg, "recording", {})) - output_path = Path(str(cfg_get(rec_cfg, "output_path", "teleop_session.h5"))).expanduser() - if not output_path.is_absolute(): - output_path 
= (Path.cwd() / output_path).resolve() - output_path.parent.mkdir(parents=True, exist_ok=True) - - with HDF5Recorder(output_path) as recorder: - result = self.loop.run( - cast(Any, self.input_provider), - cast(Any, self.retargeter), - num_steps=num_steps, - recorder=cast(Any, recorder), - ) - - result_with_path: dict[str, float | int | str] = dict(result) - result_with_path["record_path"] = str(output_path) - return result_with_path + return dict(self.loop.run(cast(Any, self.input_provider), cast(Any, self.retargeter), num_steps=num_steps)) diff --git a/teleopit/recording/__init__.py b/teleopit/recording/__init__.py index 3b6edc87..b8b60782 100644 --- a/teleopit/recording/__init__.py +++ b/teleopit/recording/__init__.py @@ -1,4 +1,3 @@ -from teleopit.recording.hdf5_recorder import HDF5Recorder from teleopit.recording.pico_motion import ( PicoDatasetSpec, RecordingState, @@ -10,7 +9,6 @@ ) __all__ = [ - "HDF5Recorder", "PicoDatasetSpec", "RecordingState", "ensure_pico_dataset_spec", diff --git a/teleopit/recording/hdf5_recorder.py b/teleopit/recording/hdf5_recorder.py deleted file mode 100644 index b9138ec1..00000000 --- a/teleopit/recording/hdf5_recorder.py +++ /dev/null @@ -1,114 +0,0 @@ -"""HDF5-based data recorder for teleoperation sessions.""" -from __future__ import annotations - -import time -from pathlib import Path -from typing import Any, Dict - -import h5py -import numpy as np - - -class HDF5Recorder: - """Records teleoperation data to HDF5 format with chunked storage and compression. - - Supports numerical data fields: joint_pos, joint_vel, mimic_obs, action, timestamp. - First call to add_frame creates datasets based on data keys/shapes (resizable). - Subsequent calls resize and append data. - """ - - def __init__(self, path: str | Path, chunk_size: int = 100): - """Initialize HDF5 recorder. 
- - Args: - path: Path to HDF5 file to create - chunk_size: Chunk size for HDF5 datasets (affects compression/performance) - """ - self.path = Path(path) - self.chunk_size = chunk_size - self.file: h5py.File | None = None - self.datasets: Dict[str, h5py.Dataset] = {} - self.frame_count = 0 - self.start_time = time.time() - self._initialized = False - - def __enter__(self) -> HDF5Recorder: - """Context manager entry.""" - self.file = h5py.File(self.path, 'w') - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - """Context manager exit.""" - self.close() - - def add_frame(self, data: Dict[str, np.ndarray]) -> None: - """Add a frame of data to the recording. - - First call creates datasets based on data keys/shapes. - Subsequent calls resize and append data. - - Args: - data: Dictionary mapping field names to numpy arrays - Supported fields: joint_pos, joint_vel, mimic_obs, action, timestamp - """ - if self.file is None: - raise RuntimeError("Recorder not opened. Use context manager or call __enter__") - - if not self._initialized: - self._create_datasets(data) - self._initialized = True - - # Resize and write to each dataset - for key, value in data.items(): - if key not in self.datasets: - raise ValueError(f"Unexpected key '{key}' not in initial frame") - - dataset = self.datasets[key] - # Resize to accommodate new frame - dataset.resize(self.frame_count + 1, axis=0) - # Write data - dataset[self.frame_count] = value - - self.frame_count += 1 - - def _create_datasets(self, data: Dict[str, np.ndarray]) -> None: - """Create resizable HDF5 datasets based on first frame data. 
- - Args: - data: First frame of data defining dataset structure - """ - for key, value in data.items(): - arr = np.asarray(value) - shape = arr.shape - dtype = arr.dtype - - # Create resizable dataset with chunking and compression - chunk_shape = (self.chunk_size,) + shape - max_shape = (None,) + shape # Unlimited first dimension - - self.datasets[key] = self.file.create_dataset( - key, - shape=(0,) + shape, - maxshape=max_shape, - chunks=chunk_shape, - dtype=dtype, - compression='gzip', - compression_opts=4 - ) - - def close(self) -> None: - """Close the recorder and write metadata.""" - if self.file is None: - return - - # Write metadata - end_time = time.time() - recording_duration = end_time - self.start_time - - self.file.attrs['total_frames'] = self.frame_count - self.file.attrs['recording_time'] = recording_duration - self.file.attrs['start_time'] = self.start_time - self.file.attrs['end_time'] = end_time - - self.file.close() - self.file = None diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index 36918b7b..8e058135 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -12,13 +12,13 @@ from teleopit.runtime.reference_config import parse_reference_config from teleopit.runtime.arm_mocap import compose_arm_reference, parse_arm_joint_indices from teleopit.inputs.realtime_packet import RealtimeInputPacket -from teleopit.interfaces import Controller, InputProvider, MessageBus, ObservationBuilder, Recorder, Retargeter, Robot, RobotState +from teleopit.interfaces import Controller, InputProvider, MessageBus, ObservationBuilder, Retargeter, Robot, RobotState from teleopit.sim.reference_timeline import ( ReferenceWindow, ReferenceWindowBuilder, ) from teleopit.sim.realtime_utils import RealtimeReferenceDiagnostics -from teleopit.sim.runtime_components import MotionPreparation, PolicyStepRunner, RunRecorder, RuntimePublisher, ViewerManager +from teleopit.sim.runtime_components import MotionPreparation, PolicyStepRunner, RuntimePublisher, 
ViewerManager from teleopit.sim.viewer_subprocess import mocap_viewer_proc, start_camera_viewer, start_robot_viewer from teleopit.runtime.mocap_session import MocapSessionManager from teleopit.runtime.offline_playback import OfflinePlaybackController @@ -111,7 +111,7 @@ def _init_reference_config(self) -> None: ) def _init_components(self, viewers: set[str] | None) -> None: - """Build PolicyStepRunner, publisher, recorder helper, and viewer manager.""" + """Build PolicyStepRunner, publisher, and viewer manager.""" self._viewers: set[str] = set(viewers or set()) self._step_runner = PolicyStepRunner( robot=self.robot, @@ -128,7 +128,6 @@ def _init_components(self, viewers: set[str] | None) -> None: reference_anchor_velocity_smoothing_alpha=self._ref_cfg.reference_anchor_velocity_smoothing_alpha, ) self._publisher = RuntimePublisher(self.bus) - self._recorder_helper = RunRecorder() self._viewer_manager = ViewerManager( robot=self.robot, viewers=self._viewers, @@ -142,11 +141,10 @@ def run( input_provider: InputProvider, retargeter: Retargeter, num_steps: int, - recorder: Recorder | None = None, ) -> dict[str, float | int]: from teleopit.sim.session import SimLoopSession - session = SimLoopSession(self, input_provider, retargeter, num_steps, recorder) + session = SimLoopSession(self, input_provider, retargeter, num_steps) return session.run() def run_headless( @@ -154,9 +152,8 @@ def run_headless( input_provider: InputProvider, retargeter: Retargeter, num_steps: int, - recorder: Recorder | None = None, ) -> dict[str, float | int]: - return self.run(input_provider=input_provider, retargeter=retargeter, num_steps=num_steps, recorder=recorder) + return self.run(input_provider=input_provider, retargeter=retargeter, num_steps=num_steps) def _compute_target_dof_pos(self, action: Float32Array) -> Float32Array: return self._step_runner.compute_target_dof_pos(action) @@ -328,17 +325,6 @@ def _build_observation( def _publish(self, mimic_obs: Float32Array, action: Float32Array, 
robot_state: object) -> None: self._publisher.publish(mimic_obs, action, robot_state) - def _record( - self, - recorder: Recorder | None, - state: object, - mimic_obs: Float32Array, - action: Float32Array, - target_dof_pos: Float32Array, - torque: Float32Array, - ) -> None: - self._recorder_helper.record(recorder, state, mimic_obs, action, target_dof_pos, torque) - def _retarget_to_qpos(self, retargeted: object) -> Float64Array: return self._step_runner._retarget_to_qpos(retargeted) diff --git a/teleopit/sim/runtime_components.py b/teleopit/sim/runtime_components.py index 3c44db23..ae450813 100644 --- a/teleopit/sim/runtime_components.py +++ b/teleopit/sim/runtime_components.py @@ -15,7 +15,7 @@ from teleopit.bus.topics import TOPIC_ACTION, TOPIC_MIMIC_OBS, TOPIC_ROBOT_STATE from teleopit.controllers.observation import VelCmdObservationBuilder from teleopit.controllers import reference_processing as ref_proc -from teleopit.interfaces import MessageBus, ObservationBuilder, Recorder, Robot, RobotState +from teleopit.interfaces import MessageBus, ObservationBuilder, Robot, RobotState from teleopit.retargeting.core import extract_mimic_obs from teleopit.sim.reference_timeline import ReferenceWindow from teleopit.sim.reference_utils import obs_builder_requires_reference_window @@ -54,31 +54,6 @@ def publish(self, mimic_obs: Float32Array, action: Float32Array, robot_state: ob self._bus.publish(TOPIC_ROBOT_STATE, robot_state) -class RunRecorder: - def record( - self, - recorder: Recorder | None, - state: RobotState, - mimic_obs: Float32Array, - action: Float32Array, - target_dof_pos: Float32Array, - torque: Float32Array, - ) -> None: - if recorder is None: - return - - payload: dict[str, object] = { - "joint_pos": np.asarray(state.qpos, dtype=np.float32), - "joint_vel": np.asarray(state.qvel, dtype=np.float32), - "mimic_obs": mimic_obs.astype(np.float32, copy=False), - "action": action.astype(np.float32, copy=False), - "target_dof_pos": target_dof_pos.astype(np.float32, 
copy=False), - "torque": torque.astype(np.float32, copy=False), - "timestamp": np.asarray(float(state.timestamp), dtype=np.float64), - } - recorder.add_frame(payload) - - class PolicyStepRunner: def __init__( self, diff --git a/teleopit/sim/session.py b/teleopit/sim/session.py index 53f7e3ab..b5b5280e 100644 --- a/teleopit/sim/session.py +++ b/teleopit/sim/session.py @@ -19,7 +19,7 @@ from numpy.typing import NDArray from teleopit.debug.rollout_trace import RolloutTraceWriter -from teleopit.interfaces import InputProvider, Recorder, Retargeter, RobotState +from teleopit.interfaces import InputProvider, Retargeter, RobotState from teleopit.sim.reference_motion import ( OfflineReferenceMotion, interpolate_human_frames, @@ -64,12 +64,10 @@ def __init__( input_provider: InputProvider, retargeter: Retargeter, num_steps: int, - recorder: Recorder | None, ) -> None: self._loop = loop self._input_provider = input_provider self._retargeter = retargeter - self._recorder = recorder # Convenience aliases for heavily-used loop attributes self._step_runner = loop._step_runner @@ -659,7 +657,6 @@ def run(self) -> dict[str, float | int]: target_dof_pos = self._step_runner.compute_target_dof_pos(action) torque, final_state = self._step_runner.apply_control(target_dof_pos) loop._publisher.publish(preparation.mimic_obs, action, final_state) - loop._recorder_helper.record(self._recorder, final_state, preparation.mimic_obs, action, target_dof_pos, torque) self._viewer_manager.write_sim2sim(loop.robot) self._viewer_manager.write_camera(loop.robot) if loop._video_runtime is not None: diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 904515cd..8fbfa306 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -2,7 +2,6 @@ import os from pathlib import Path -from typing import cast import numpy as np import pytest @@ -21,7 +20,6 @@ def _has_module(name: str) -> bool: requires_mujoco = pytest.mark.skipif(not _has_module("mujoco"), reason="mujoco not installed") requires_onnxruntime = 
pytest.mark.skipif(not _has_module("onnxruntime"), reason="onnxruntime not installed") -requires_h5py = pytest.mark.skipif(not _has_module("h5py"), reason="h5py not installed") requires_mink = pytest.mark.skipif(not _has_module("mink"), reason="mink not installed") @@ -37,9 +35,7 @@ def _asset_paths(project_root: Path) -> tuple[Path, Path, Path]: @requires_mujoco @requires_onnxruntime @requires_mink -@requires_h5py -def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: Path) -> None: - import h5py +def test_bvh_to_mujoco_pipeline_stands(project_root: Path) -> None: from omegaconf import OmegaConf from teleopit.pipeline import TeleopPipeline @@ -64,7 +60,6 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: input_cfg.human_format = "bvh_xsens" input_cfg.robot_name = "unitree_g1" - recording_path = tmp_dir / "e2e.h5" cfg = OmegaConf.create( { "robot": robot_cfg, @@ -72,7 +67,6 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: "input": input_cfg, "policy_hz": 50, "pd_hz": 50, - "recording": {"output_path": str(recording_path)}, } ) @@ -86,7 +80,7 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: pipeline.bus.subscribe(TOPIC_MIMIC_OBS, lambda _: mimic_count.append(1)) pipeline.bus.subscribe(TOPIC_ROBOT_STATE, lambda _: state_count.append(1)) - result = pipeline.run(num_steps=100, record=True) + result = pipeline.run(num_steps=100) assert float(result["root_height"]) > 0.3 assert int(result["steps"]) == 100 @@ -104,20 +98,3 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: assert isinstance(latest_mimic, np.ndarray) assert latest_mimic.shape == (35,) assert latest_state is not None - - record_path = Path(str(result["record_path"])) - assert record_path.exists() - - with h5py.File(record_path, "r") as f: - total_frames = int(np.asarray(f.attrs["total_frames"]).item()) - assert total_frames == 100 - - joint_pos = 
cast(h5py.Dataset, f["joint_pos"]) - joint_vel = cast(h5py.Dataset, f["joint_vel"]) - mimic_obs = cast(h5py.Dataset, f["mimic_obs"]) - action = cast(h5py.Dataset, f["action"]) - - assert joint_pos.shape[0] == 100 - assert joint_vel.shape[0] == 100 - assert mimic_obs.shape == (100, 35) - assert action.shape == (100, 29) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index df57aba9..0847800a 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -6,7 +6,6 @@ InputProvider, MessageBus, ObservationBuilder, - Recorder, RealtimeInputProvider, Retargeter, Robot, @@ -127,11 +126,6 @@ def subscribe(self, topic): return None -class _FakeRecorder: - def add_frame(self, data): - pass - - class _FakeObservationBuilder: def build_observation(self, state, history, action_mimic): return np.zeros(10) @@ -169,9 +163,6 @@ def test_robot_isinstance(self): def test_message_bus_isinstance(self): assert isinstance(_FakeMessageBus(), MessageBus) - def test_recorder_isinstance(self): - assert isinstance(_FakeRecorder(), Recorder) - def test_observation_builder_isinstance(self): assert isinstance(_FakeObservationBuilder(), ObservationBuilder) diff --git a/tests/test_recording.py b/tests/test_recording.py deleted file mode 100644 index 02a65bdb..00000000 --- a/tests/test_recording.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Tests for teleopit.recording.hdf5_recorder — HDF5 write/read, context manager, empty file.""" -import numpy as np -import pytest - -from conftest import requires_h5py - - -@requires_h5py -class TestHDF5RecorderContextManager: - """Context manager and basic write/read.""" - - def test_context_manager_creates_file(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "test.h5" - with HDF5Recorder(path) as rec: - assert rec.file is not None - # After exit, file should be closed - assert rec.file is None - assert path.exists() - - def test_write_and_read_frames(self, tmp_dir): - import h5py - from 
teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "test.h5" - n_frames = 5 - joint_dim = 29 - - with HDF5Recorder(path) as rec: - for i in range(n_frames): - rec.add_frame({ - "joint_pos": np.full(joint_dim, float(i), dtype=np.float32), - "timestamp": np.array(float(i)), - }) - assert rec.frame_count == n_frames - - # Read back - with h5py.File(path, "r") as f: - assert f["joint_pos"].shape == (n_frames, joint_dim) - assert f["timestamp"].shape == (n_frames,) - assert f.attrs["total_frames"] == n_frames - np.testing.assert_allclose(f["joint_pos"][0], 0.0) - np.testing.assert_allclose(f["joint_pos"][4], 4.0) - - def test_empty_file_has_metadata(self, tmp_dir): - import h5py - from teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "empty.h5" - with HDF5Recorder(path) as rec: - pass # no frames added - - with h5py.File(path, "r") as f: - assert f.attrs["total_frames"] == 0 - assert "recording_time" in f.attrs - - -@requires_h5py -class TestHDF5RecorderErrors: - """Error handling.""" - - def test_add_frame_without_context_raises(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - rec = HDF5Recorder(tmp_dir / "test.h5") - with pytest.raises(RuntimeError, match="not opened"): - rec.add_frame({"x": np.array([1.0])}) - - def test_unexpected_key_raises(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "test.h5" - with HDF5Recorder(path) as rec: - rec.add_frame({"a": np.array([1.0])}) - with pytest.raises(ValueError, match="Unexpected key"): - rec.add_frame({"b": np.array([2.0])}) - - def test_close_idempotent(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - rec = HDF5Recorder(tmp_dir / "test.h5") - rec.close() # should not raise even if never opened - rec.close() # double close is fine diff --git a/tests/test_sim_loop.py b/tests/test_sim_loop.py index 590ad9d8..11cdebd8 100644 --- a/tests/test_sim_loop.py +++ 
b/tests/test_sim_loop.py @@ -127,14 +127,6 @@ def reset(self) -> None: self.reset_calls += 1 -class _DummyRecorder: - def __init__(self) -> None: - self.frames: list[dict[str, object]] = [] - - def add_frame(self, data: dict[str, object]) -> None: - self.frames.append(data) - - def _quat_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: aw, ax, ay, az = a bw, bx, by, bz = b @@ -226,7 +218,7 @@ def test_standing_reference_is_fixed_after_initialization() -> None: @requires_mujoco -def test_simulation_loop_runs_and_records_without_viewers() -> None: +def test_simulation_loop_runs_without_viewers() -> None: from teleopit.sim.loop import SimulationLoop bus = InProcessBus() @@ -240,16 +232,13 @@ def test_simulation_loop_runs_and_records_without_viewers() -> None: viewers=set(), ) - recorder = _DummyRecorder() result = loop.run( input_provider=_DummyInputProvider(), retargeter=_DummyRetargeter(), num_steps=2, - recorder=recorder, ) assert result["steps"] == 2 - assert len(recorder.frames) == 2 latest_action = bus.get_latest(TOPIC_ACTION) latest_mimic = bus.get_latest(TOPIC_MIMIC_OBS) @@ -261,9 +250,6 @@ def test_simulation_loop_runs_and_records_without_viewers() -> None: assert latest_mimic.shape == (35,) assert latest_state is not None - target = np.asarray(recorder.frames[-1]["target_dof_pos"], dtype=np.float32) - np.testing.assert_allclose(target, np.array([0.6, -0.6], dtype=np.float32)) - @requires_mujoco def test_simulation_loop_interpolates_realtime_input_with_one_frame_delay(monkeypatch) -> None: From 333d32002e5f39416b58cc35bdb26a4289451238 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 17:31:06 +0800 Subject: [PATCH 110/122] Add Pico sim2real recording support --- AGENTS.md | 3 + README.md | 18 + docs/docs/configuration/config-reference.md | 42 ++ docs/docs/getting-started/installation.md | 9 + docs/docs/tutorials/pico-sim2real.md | 24 ++ .../current/configuration/config-reference.md | 40 ++ .../current/getting-started/installation.md | 9 + 
.../current/tutorials/pico-sim2real.md | 23 ++ pyproject.toml | 7 + teleopit/configs/pico4_sim2real.yaml | 21 + teleopit/configs/sim2real.yaml | 21 + teleopit/configs/sim2real_record.yaml | 16 + teleopit/inputs/pico_video.py | 43 +- teleopit/recording/lerobot_v3.py | 286 +++++++++++++ teleopit/sim2real/mp/ipc.py | 9 +- teleopit/sim2real/mp/messages.py | 12 + teleopit/sim2real/mp/runtime.py | 378 +++++++++++++++++- tests/test_pico_video.py | 55 ++- tests/test_sim2real_multiprocess.py | 288 ++++++++++++- 19 files changed, 1288 insertions(+), 16 deletions(-) create mode 100644 teleopit/configs/sim2real_record.yaml create mode 100644 teleopit/recording/lerobot_v3.py diff --git a/AGENTS.md b/AGENTS.md index d41a5c05..ceaa5264 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -146,6 +146,9 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - `ARMS` entering/exiting/resume resets policy/reference alignment and uses Kp ramp; offline BVH sim2real does not use `ARMS`, and Unitree remote `B` remains BVH replay - Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference state, smoothers, and reference alignment are reset, while the GMR IK warm-start is retained - Optional LinkerHand control uses `hands.enabled=true`, `hands.driver=linkerhand_l6|linkerhand_o6`, and `hands.mode=gripper|vr_hand_pose`; default is disabled +- Optional Pico sim2real LeRobot v3 recording uses `--config-name sim2real_record` or `recording.enabled=true`; it requires `input.provider=pico4`, `input.video.enabled=true`, `input.video.source=realsense`, an interactive terminal, and the `recording` extra +- Recording is manual only: terminal `R` starts an episode, `S` saves, `D` discards the active episode, and `Q` shuts down; `STANDING`, `MOCAP`, `ARMS`, and paused mocap are recordable +- Recording captures `observation.images.d435i_rgb` RealSense RGB video at 30Hz plus `observation.state(68)`, `observation.mode(1)`, and `action(36)`; RealSense capture 
lives in `pico_input` through the normal `input.video` path - `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control and supports LinkerHand L6 and O6 - `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and somehand 0.2.0 public `somehand.api` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control - Teleopit owns Pico 26-joint hand-state to 21-landmark conversion; do not import `somehand.pico_input` diff --git a/README.md b/README.md index 56d6559c..bd43f945 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,23 @@ python train_mimic/scripts/data/build_dataset.py \ --spec data/pico_motion/pico_recorded.yaml --force ``` +## Sim2Real LeRobot Recording + +Pico sim2real can also record manual LeRobot v3 episodes from the real G1: + +```bash +pip install -e '.[recording]' +python scripts/run/run_sim2real.py --config-name sim2real_record \ + controller.policy_path=track.onnx \ + recording.task="walk forward" +``` + +Recording uses the terminal controls `R` start, `S` save, `D` discard, and `Q` +shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded. The +dataset schema is `observation.images.d435i_rgb` video at 30 Hz, +`observation.state(68)`, `observation.mode(1)`, and `action(36)` as the aligned +reference qpos sent to the policy path. + ## Documentation Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Teleopit/)**, covering installation profiles, all tutorials, configuration reference, and architecture. @@ -95,6 +112,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added Pico sim2real `ARMS` mode: Pico/controller `B` toggles between whole-body `MOCAP` and stand-pose body/legs with live retargeted arms. - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. 
- Added optional LinkerHand L6 sim2real modes under `hands.*`: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand 0.2.0 public API. +- Added manual Pico sim2real LeRobot v3 recording with RealSense D435i RGB video, 68D robot state, mode labels, and 36D reference-qpos action labels. - Added LinkerHand O6 support for Pico `gripper` mode with an O6-specific grasp pose. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index ecb66b58..3e4690e8 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -151,6 +151,48 @@ calling somehand 0.2.0 through `somehand.api` only. | `hands.somehand.temporal_filter_alpha` | somehand input landmark smoothing alpha; `1.0` disables smoothing delay | `1.0` | | `hands.somehand.output_alpha` | somehand qpos output smoothing alpha; `1.0` disables smoothing delay | `1.0` | +### LeRobot Recording (Pico sim2real) + +`recording.enabled=true` is supported only with `input.provider=pico4`, +`input.video.enabled=true`, `input.video.source=realsense`, and an interactive +terminal. The recorder is manual: `R` starts an episode, `S` saves the active +episode, `D` discards the active episode, and `Q` shuts down. `STANDING`, +`MOCAP`, `ARMS`, and paused mocap can be recorded. + +`sim2real_record.yaml` enables both recording and the required RealSense +`input.video` path. Recording does not open a second camera; it consumes the +same frames produced by `pico_input`. 
+ +| Field | Description | Default | +|-------|-------------|---------| +| `recording.enabled` | Enable manual LeRobot v3 recording | `false` | +| `recording.output_dir` | Dataset root directory | `data/lerobot` | +| `recording.repo_id` / `dataset_name` | LeRobot dataset identity | `null` | +| `recording.task` | Task string stored with frames | `demo` | +| `recording.fps` | Recording/video clock rate | `30` | +| `recording.min_episode_seconds` | Discard saved episodes shorter than this duration | `1.0` | +| `recording.record_modes` | Modes that allow recording start and frame writes | `[standing, mocap, arms, pause]` | +| `recording.camera.key` | LeRobot video feature key | `observation.images.d435i_rgb` | +| `recording.camera.width` / `height` / `fps` | RealSense RGB capture settings | `640` / `480` / `30` | +| `recording.camera.device` | Optional RealSense serial | `null` | + +Camera failure behavior is controlled by `input.video.fail_on_error`. + +LeRobot features: + +```text +observation.images.d435i_rgb video [480,640,3] uint8 +observation.state float32[68] +observation.mode float32[1] +action float32[36] +``` + +`observation.state` is ordered as `joint_pos(29)`, `joint_vel(29)`, +`base_quat_wxyz(4)`, `base_ang_vel(3)`, and `projected_gravity(3)`. +`observation.mode` is a numeric categorical: `standing=0`, `mocap=1`, +`arms=2`, and `pause=3`. `action` is the current reference qpos: +`root_pos(3) + root_quat_wxyz(4) + joint_pos(29)`. + ## Critical: `default_dof_pos` The RL policy outputs action **offsets** relative to the default standing pose, not absolute joint angles: diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 14cf05f7..c8e17094 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -73,6 +73,15 @@ scripts/setup/download_somehand_l6_assets.sh This extra is only required when `hands.enabled=true`. 
+### Sim2Real Recording + +```bash +pip install -e '.[recording]' +``` + +Adds the Pico sim2real stack plus LeRobot, RealSense, and video encoding +dependencies used by `sim2real_record.yaml`. + ## Verify Installation ```bash diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index ab0f3e2a..3fd1334d 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -95,6 +95,30 @@ python scripts/run/run_sim2real.py \ real_robot.network_interface=eth0 ``` +## Optional LeRobot Recording + +Install the recording extra on the machine that owns Pico input and RealSense: + +```bash +pip install -e '.[recording]' +``` + +Run the recording config: + +```bash +python scripts/run/run_sim2real.py \ + --config-name sim2real_record \ + controller.policy_path=track.onnx \ + real_robot.network_interface=enp130s0 \ + recording.task="walk forward" +``` + +Terminal controls are `R` start episode, `S` save, `D` discard, and `Q` +shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded; +saved episodes cannot be discarded afterward. The v1 schema records +`observation.images.d435i_rgb`, `observation.state(68)`, +`observation.mode(1)`, and `action(36)` at 30 Hz. + ## Operator Flow Keep the Unitree remote in hand. 
`L1+R1` is the emergency stop path into diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 3e7e4dc4..6b92edbf 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -169,3 +169,43 @@ Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 someh | `hands.somehand.max_iterations` | `vr_hand_pose` 的 somehand solver 迭代上限 | `12` | | `hands.somehand.temporal_filter_alpha` | somehand 输入 landmarks 平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | | `hands.somehand.output_alpha` | somehand qpos 输出平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | + +### LeRobot 录制(Pico sim2real) + +`recording.enabled=true` 只支持 `input.provider=pico4`、 +`input.video.enabled=true`、`input.video.source=realsense`,并且需要交互式终端。 +录制是手动控制:`R` 开始 episode,`S` 保存当前 episode,`D` 丢弃当前 episode, +`Q` 关闭。可以录制 `STANDING`、`MOCAP`、`ARMS` 和暂停状态的 mocap。 + +`sim2real_record.yaml` 会同时启用录制和必需的 RealSense `input.video` +路径。录制不会打开第二路相机,而是消费 `pico_input` 已经产生的同一批帧。 + +| 字段 | 说明 | 默认值 | +|---|---|---| +| `recording.enabled` | 启用手动 LeRobot v3 录制 | `false` | +| `recording.output_dir` | 数据集根目录 | `data/lerobot` | +| `recording.repo_id` / `dataset_name` | LeRobot 数据集标识 | `null` | +| `recording.task` | 写入 frame 的任务字符串 | `demo` | +| `recording.fps` | 录制/视频主时钟频率 | `30` | +| `recording.min_episode_seconds` | 保存时短于该时长的 episode 会被丢弃 | `1.0` | +| `recording.record_modes` | 允许开始录制和写帧的模式 | `[standing, mocap, arms, pause]` | +| `recording.camera.key` | LeRobot 视频 feature key | `observation.images.d435i_rgb` | +| `recording.camera.width` / `height` / `fps` | RealSense RGB 采集设置 | `640` / `480` / `30` | +| `recording.camera.device` | 可选 RealSense 序列号 | `null` | + +相机失败时的行为由 `input.video.fail_on_error` 控制。 + +LeRobot features: + +```text +observation.images.d435i_rgb video [480,640,3] 
uint8 +observation.state float32[68] +observation.mode float32[1] +action float32[36] +``` + +`observation.state` 的顺序是 `joint_pos(29)`、`joint_vel(29)`、 +`base_quat_wxyz(4)`、`base_ang_vel(3)` 和 `projected_gravity(3)`。 +`observation.mode` 是数值类别:`standing=0`、`mocap=1`、 +`arms=2`、`pause=3`。`action` 是当前 reference qpos: +`root_pos(3) + root_quat_wxyz(4) + joint_pos(29)`。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 6ebaba87..919dc1b2 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -72,6 +72,15 @@ scripts/setup/download_somehand_l6_assets.sh 只有在 `hands.enabled=true` 时才需要安装这个 extra。 +### Sim2Real 录制 + +```bash +pip install -e '.[recording]' +``` + +该配置包含 Pico sim2real 栈,以及 `sim2real_record.yaml` 使用的 LeRobot、 +RealSense 和视频编码依赖。 + ## 验证安装 ```bash diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 1425fcc0..8bf778ad 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -91,6 +91,29 @@ python scripts/run/run_sim2real.py \ real_robot.network_interface=eth0 ``` +## 可选 LeRobot 录制 + +在负责 Pico 输入和 RealSense 的机器上安装 recording extra: + +```bash +pip install -e '.[recording]' +``` + +运行录制配置: + +```bash +python scripts/run/run_sim2real.py \ + --config-name sim2real_record \ + controller.policy_path=track.onnx \ + real_robot.network_interface=enp130s0 \ + recording.task="walk forward" +``` + +终端控制为:`R` 开始 episode,`S` 保存,`D` 丢弃,`Q` 关闭。可以录制 +`STANDING`、`MOCAP`、`ARMS` 和暂停状态的 mocap;已经保存的 episode 不支持再丢弃。 +v1 schema 以 30 Hz 
记录 `observation.images.d435i_rgb`、`observation.state(68)`、 +`observation.mode(1)` 和 `action(36)`。 + ## 操作流程 始终把 Unitree 遥控器拿在手里。`L1+R1` 是进入 `DAMPING` 的急停路径。 diff --git a/pyproject.toml b/pyproject.toml index 508afa4e..0f95aa7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,13 @@ pico4 = [ "pico-bridge[camera] @ https://github.com/BotRunner64/pico-bridge/releases/download/v0.2.1/pico_bridge-0.2.1-py3-none-any.whl", "teleopit[sim2real]", ] +recording = [ + "teleopit[pico4]", + "lerobot", + "pyrealsense2", + "opencv-python", + "imageio[ffmpeg]", +] dexhand = [ "linkerhand-python-sdk @ file:third_party/linkerhand-python-sdk", "somehand @ file:third_party/somehand", diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 2cb15cb5..c8c8d181 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -18,6 +18,27 @@ reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] reference_debug_log: false +recording: + enabled: false + format: lerobot_v3 + output_dir: data/lerobot + dataset_name: null + repo_id: null + task: demo + fps: 30 + control: terminal + min_episode_seconds: 1.0 + discard_on_shutdown: true + record_modes: [standing, mocap, arms, pause] + camera: + enabled: true + key: observation.images.d435i_rgb + source: realsense + width: 640 + height: 480 + fps: 30 + device: null + runtime: host: 127.0.0.1 base_port: 39700 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 644eb446..cc53009f 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -17,6 +17,27 @@ reference_debug_log: false playback: pause_on_end: true +recording: + enabled: false + format: lerobot_v3 + output_dir: data/lerobot + dataset_name: null + repo_id: null + task: demo + fps: 30 + control: terminal + min_episode_seconds: 1.0 + discard_on_shutdown: true + record_modes: [standing, mocap, arms, pause] + camera: + enabled: true + key: 
observation.images.d435i_rgb + source: realsense + width: 640 + height: 480 + fps: 30 + device: null + runtime: host: 127.0.0.1 base_port: 39700 diff --git a/teleopit/configs/sim2real_record.yaml b/teleopit/configs/sim2real_record.yaml new file mode 100644 index 00000000..2fda8bf4 --- /dev/null +++ b/teleopit/configs/sim2real_record.yaml @@ -0,0 +1,16 @@ +defaults: + - pico4_sim2real + - _self_ + +recording: + enabled: true + +input: + video: + enabled: true + source: realsense + width: 640 + height: 480 + fps: 30 + device: null + fail_on_error: true diff --git a/teleopit/inputs/pico_video.py b/teleopit/inputs/pico_video.py index a1c3191b..fd1ed25e 100644 --- a/teleopit/inputs/pico_video.py +++ b/teleopit/inputs/pico_video.py @@ -6,7 +6,7 @@ import logging import threading import time -from typing import Any +from typing import Any, Callable import numpy as np @@ -70,11 +70,13 @@ def __init__( config: PicoVideoConfig, mode: str, robot: Any | None = None, + frame_callback: Callable[[np.ndarray, float], None] | None = None, ) -> None: self._provider = provider self._config = config self._mode = mode self._robot = robot + self._frame_callback = frame_callback self._producer: _VideoProducer | None = None self._stopped = False @@ -96,16 +98,16 @@ def start(self) -> None: if self._config.source == "test-pattern": logger.info("Pico video enabled via pico-bridge test-pattern source") return - if not callable(getattr(self._provider, "push_video_frame", None)): + if self._frame_callback is None and not callable(getattr(self._provider, "push_video_frame", None)): self._handle_error(RuntimeError("Pico input provider does not support push_video_frame")) return producer: _VideoProducer | None = None try: if self._config.source == "realsense": - producer = _RealSenseVideoProducer(self._provider, self._config) + producer = _RealSenseVideoProducer(self._provider, self._config, self._frame_callback) elif self._config.source == "mujoco": - producer = 
_MujocoCameraVideoProducer(self._provider, self._config, self._robot) + producer = _MujocoCameraVideoProducer(self._provider, self._config, self._robot, self._frame_callback) else: raise ValueError(f"Unsupported Pico video source: {self._config.source!r}") self._producer = producer @@ -152,9 +154,15 @@ def stop(self) -> None: ... class _RealSenseVideoProducer(_VideoProducer): - def __init__(self, provider: Any, config: PicoVideoConfig) -> None: + def __init__( + self, + provider: Any, + config: PicoVideoConfig, + frame_callback: Callable[[np.ndarray, float], None] | None = None, + ) -> None: self._provider = provider self._config = config + self._frame_callback = frame_callback self._stop_event = threading.Event() self._ready_event = threading.Event() self._thread = threading.Thread(target=self._run, name="pico_realsense_video", daemon=True) @@ -206,7 +214,13 @@ def _run(self) -> None: if not color_frame: continue rgb = np.ascontiguousarray(np.asanyarray(color_frame.get_data()), dtype=np.uint8) - self._pushed_frames = int(self._provider.push_video_frame(rgb)) + timestamp_s = time.monotonic() + if self._frame_callback is not None: + self._frame_callback(rgb, timestamp_s) + if callable(getattr(self._provider, "push_video_frame", None)): + self._pushed_frames = int(self._provider.push_video_frame(rgb)) + else: + self._pushed_frames += 1 finally: pipeline.stop() except BaseException as exc: @@ -216,10 +230,17 @@ def _run(self) -> None: class _MujocoCameraVideoProducer(_VideoProducer): - def __init__(self, provider: Any, config: PicoVideoConfig, robot: Any | None) -> None: + def __init__( + self, + provider: Any, + config: PicoVideoConfig, + robot: Any | None, + frame_callback: Callable[[np.ndarray, float], None] | None = None, + ) -> None: self._provider = provider self._config = config self._robot = robot + self._frame_callback = frame_callback self._renderer: Any | None = None self._next_frame_time = 0.0 self._camera_name = "d435i_rgb" @@ -253,7 +274,13 @@ def 
tick(self) -> None: raise RuntimeError("MuJoCo Pico video requires robot.data") self._renderer.update_scene(data, camera=self._camera_name) frame = np.ascontiguousarray(self._renderer.render(), dtype=np.uint8) - self._pushed_frames = int(self._provider.push_video_frame(frame)) + timestamp_s = time.monotonic() + if self._frame_callback is not None: + self._frame_callback(frame, timestamp_s) + if callable(getattr(self._provider, "push_video_frame", None)): + self._pushed_frames = int(self._provider.push_video_frame(frame)) + else: + self._pushed_frames += 1 self._next_frame_time = now + 1.0 / float(self._config.fps) def stop(self) -> None: diff --git a/teleopit/recording/lerobot_v3.py b/teleopit/recording/lerobot_v3.py new file mode 100644 index 00000000..3c2cf00d --- /dev/null +++ b/teleopit/recording/lerobot_v3.py @@ -0,0 +1,286 @@ +"""LeRobot v3 adapter and schema helpers for Teleopit sim2real recording.""" + +from __future__ import annotations + +from dataclasses import dataclass +import json +from pathlib import Path +from typing import Any + +import numpy as np + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS +from teleopit.controllers.observation import _quat_rotate_np +from teleopit.math_utils import quat_inv_np +from teleopit.runtime.common import cfg_get + + +IMAGE_KEY = "observation.images.d435i_rgb" +STATE_KEY = "observation.state" +MODE_KEY = "observation.mode" +ACTION_KEY = "action" +STATE_DIM = 68 +MODE_DIM = 1 +ACTION_DIM = FULL_QPOS_DIM +DEFAULT_IMAGE_SHAPE = (480, 640, 3) +MODE_CODES = { + "standing": 0, + "mocap": 1, + "arms": 2, + "pause": 3, +} + + +@dataclass(frozen=True) +class RecordingSchema: + image_key: str + image_shape: tuple[int, int, int] + state_key: str = STATE_KEY + state_dim: int = STATE_DIM + mode_key: str = MODE_KEY + mode_dim: int = MODE_DIM + action_key: str = ACTION_KEY + action_dim: int = ACTION_DIM + + +def build_recording_schema(camera_cfg: Any) -> RecordingSchema: + key = str(cfg_get(camera_cfg, "key", 
IMAGE_KEY)) + width = int(cfg_get(camera_cfg, "width", DEFAULT_IMAGE_SHAPE[1])) + height = int(cfg_get(camera_cfg, "height", DEFAULT_IMAGE_SHAPE[0])) + if width <= 0 or height <= 0: + raise ValueError("recording.camera.width and recording.camera.height must be positive") + return RecordingSchema(image_key=key, image_shape=(height, width, 3)) + + +def lerobot_features(schema: RecordingSchema) -> dict[str, dict[str, object]]: + return { + schema.image_key: { + "dtype": "video", + "shape": schema.image_shape, + "names": ["height", "width", "channel"], + }, + schema.state_key: { + "dtype": "float32", + "shape": (schema.state_dim,), + "names": ["state"], + }, + schema.mode_key: { + "dtype": "float32", + "shape": (schema.mode_dim,), + "names": ["mode"], + }, + schema.action_key: { + "dtype": "float32", + "shape": (schema.action_dim,), + "names": ["action"], + }, + } + + +def modality_sidecar(schema: RecordingSchema) -> dict[str, object]: + return { + "version": 1, + "features": { + schema.image_key: { + "type": "video", + "shape": list(schema.image_shape), + "dtype": "uint8", + }, + schema.state_key: { + "type": "low_dim", + "shape": [schema.state_dim], + "dtype": "float32", + "slices": { + "joint_pos": [0, 29], + "joint_vel": [29, 58], + "base_quat_wxyz": [58, 62], + "base_ang_vel": [62, 65], + "projected_gravity": [65, 68], + }, + }, + schema.mode_key: { + "type": "categorical", + "shape": [schema.mode_dim], + "dtype": "float32", + "codes": MODE_CODES, + }, + schema.action_key: { + "type": "low_dim", + "shape": [schema.action_dim], + "dtype": "float32", + "slices": { + "root_pos": [0, 3], + "root_quat_wxyz": [3, 7], + "joint_pos": [7, 36], + }, + }, + }, + } + + +def build_observation_state(robot_state: object) -> np.ndarray: + joint_pos = np.asarray(getattr(robot_state, "qpos"), dtype=np.float32).reshape(-1)[:NUM_JOINTS] + joint_vel = np.asarray(getattr(robot_state, "qvel"), dtype=np.float32).reshape(-1)[:NUM_JOINTS] + base_quat = np.asarray(getattr(robot_state, 
"quat"), dtype=np.float32).reshape(-1)[:4] + base_ang_vel = np.asarray(getattr(robot_state, "ang_vel"), dtype=np.float32).reshape(-1)[:3] + if joint_pos.shape[0] != NUM_JOINTS: + raise ValueError(f"robot_state.qpos must contain {NUM_JOINTS} joints, got {joint_pos.shape[0]}") + if joint_vel.shape[0] != NUM_JOINTS: + raise ValueError(f"robot_state.qvel must contain {NUM_JOINTS} joints, got {joint_vel.shape[0]}") + if base_quat.shape[0] != 4: + raise ValueError(f"robot_state.quat must be 4D (wxyz), got {base_quat.shape[0]}") + if base_ang_vel.shape[0] != 3: + raise ValueError(f"robot_state.ang_vel must be 3D, got {base_ang_vel.shape[0]}") + gravity_w = np.array([0.0, 0.0, -1.0], dtype=np.float32) + projected_gravity = _quat_rotate_np(quat_inv_np(base_quat), gravity_w) + state = np.concatenate( + [joint_pos, joint_vel, base_quat, base_ang_vel, projected_gravity], + dtype=np.float32, + ) + if state.shape[0] != STATE_DIM: + raise ValueError(f"recording observation.state must be {STATE_DIM}D, got {state.shape[0]}") + return state + + +def normalize_action_reference_qpos(reference_qpos: object) -> np.ndarray: + action = np.asarray(reference_qpos, dtype=np.float32).reshape(-1)[:ACTION_DIM] + if action.shape[0] != ACTION_DIM: + raise ValueError(f"recording action reference qpos must be {ACTION_DIM}D, got {action.shape[0]}") + return action + + +def build_mode_observation(mode: str) -> np.ndarray: + normalized = str(mode).strip().lower() + if normalized not in MODE_CODES: + raise ValueError(f"Unsupported recording mode {mode!r}; expected one of {sorted(MODE_CODES)}") + return np.array([MODE_CODES[normalized]], dtype=np.float32) + + +class TeleopitLeRobotV3Recorder: + """Small adapter around LeRobot v3 dataset writing.""" + + def __init__( + self, + *, + dataset: Any, + output_dir: Path, + schema: RecordingSchema, + ) -> None: + self._dataset = dataset + self._output_dir = output_dir + self._schema = schema + self._active = False + self._frames_in_episode = 0 + + @classmethod 
+ def create( + cls, + *, + output_dir: str | Path, + dataset_name: str | None, + repo_id: str | None, + task: str, + fps: int, + schema: RecordingSchema, + ) -> "TeleopitLeRobotV3Recorder": + try: + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + except Exception as exc: # pragma: no cover - exercised in environments without optional extra. + raise RuntimeError( + "recording.enabled=true requires the optional LeRobot dependency. " + "Install Teleopit with the recording extra, for example: pip install -e '.[recording]'." + ) from exc + + root = Path(output_dir) + root.mkdir(parents=True, exist_ok=True) + dataset_repo_id = repo_id or dataset_name or "teleopit/sim2real" + features = lerobot_features(schema) + + try: + dataset = LeRobotDataset.create( + repo_id=dataset_repo_id, + fps=int(fps), + root=root, + features=features, + use_videos=True, + ) + except TypeError: + dataset = LeRobotDataset.create( + repo_id=dataset_repo_id, + fps=int(fps), + root=root, + features=features, + ) + recorder = cls(dataset=dataset, output_dir=root, schema=schema) + recorder._write_modality_sidecar() + return recorder + + def start_episode(self) -> None: + if self._active: + raise RuntimeError("Cannot start a new recording episode while one is active") + self._active = True + self._frames_in_episode = 0 + + def add_frame( + self, + *, + image: np.ndarray, + state: np.ndarray, + mode: np.ndarray, + action: np.ndarray, + task: str, + ) -> None: + if not self._active: + raise RuntimeError("Cannot add a recording frame without an active episode") + image_arr = np.asarray(image, dtype=np.uint8) + if tuple(image_arr.shape) != self._schema.image_shape: + raise ValueError(f"{self._schema.image_key} frame shape {image_arr.shape} != {self._schema.image_shape}") + state_arr = np.asarray(state, dtype=np.float32).reshape(-1) + mode_arr = np.asarray(mode, dtype=np.float32).reshape(-1) + action_arr = np.asarray(action, dtype=np.float32).reshape(-1) + if state_arr.shape[0] != 
self._schema.state_dim: + raise ValueError(f"{self._schema.state_key} must be {self._schema.state_dim}D") + if mode_arr.shape[0] != self._schema.mode_dim: + raise ValueError(f"{self._schema.mode_key} must be {self._schema.mode_dim}D") + if action_arr.shape[0] != self._schema.action_dim: + raise ValueError(f"{self._schema.action_key} must be {self._schema.action_dim}D") + self._dataset.add_frame( + { + self._schema.image_key: image_arr, + self._schema.state_key: state_arr, + self._schema.mode_key: mode_arr, + self._schema.action_key: action_arr, + "task": str(task), + } + ) + self._frames_in_episode += 1 + + def save_episode(self) -> None: + if not self._active: + return + self._dataset.save_episode() + self._active = False + self._frames_in_episode = 0 + + def discard_episode(self) -> None: + if not self._active: + return + clear = getattr(self._dataset, "clear_episode_buffer", None) + if callable(clear): + clear() + else: + buffer_attr = getattr(self._dataset, "episode_buffer", None) + if isinstance(buffer_attr, dict): + buffer_attr.clear() + self._active = False + self._frames_in_episode = 0 + + def finalize(self) -> None: + consolidate = getattr(self._dataset, "consolidate", None) + if callable(consolidate): + consolidate() + + def _write_modality_sidecar(self) -> None: + path = self._output_dir / "meta" / "modality.json" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(modality_sidecar(self._schema), indent=2) + "\n", encoding="utf-8") diff --git a/teleopit/sim2real/mp/ipc.py b/teleopit/sim2real/mp/ipc.py index 109d41d8..c12cd878 100644 --- a/teleopit/sim2real/mp/ipc.py +++ b/teleopit/sim2real/mp/ipc.py @@ -17,6 +17,7 @@ REFERENCE_TOPIC = "reference" MODE_TOPIC = "mode" VIDEO_TOPIC = "video" +RECORD_TOPIC = "record" HEALTH_TOPIC = "health" COMMAND_TOPIC = "command" @@ -30,6 +31,7 @@ class Sim2RealIpcEndpoints: reference_pub: str mode_pub: str video_pub: str + record_pub: str health_pub: str command_pub: str reference_command_pub: 
str @@ -46,9 +48,10 @@ def default_endpoints(*, host: str = "127.0.0.1", base_port: int = 39700) -> Sim reference_pub=f"{prefix}{base_port + 4}", mode_pub=f"{prefix}{base_port + 5}", video_pub=f"{prefix}{base_port + 6}", - health_pub=f"{prefix}{base_port + 7}", - command_pub=f"{prefix}{base_port + 8}", - reference_command_pub=f"{prefix}{base_port + 9}", + record_pub=f"{prefix}{base_port + 7}", + health_pub=f"{prefix}{base_port + 8}", + command_pub=f"{prefix}{base_port + 9}", + reference_command_pub=f"{prefix}{base_port + 10}", ) diff --git a/teleopit/sim2real/mp/messages.py b/teleopit/sim2real/mp/messages.py index d16a1496..ec22eda5 100644 --- a/teleopit/sim2real/mp/messages.py +++ b/teleopit/sim2real/mp/messages.py @@ -59,6 +59,18 @@ class ModeStatePacket: seq: int +@dataclass(frozen=True) +class RecordStepPacket: + timestamp_s: float + mode: str + mocap_active: bool + recordable: bool + observation_state: Float64Array + observation_mode: Float64Array + action_reference_qpos: Float64Array + seq: int + + @dataclass(frozen=True) class HealthPacket: worker: str diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 668e467d..a8af3f70 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -6,7 +6,9 @@ import multiprocessing as mp from multiprocessing.synchronize import Event as MpEvent from enum import Enum +import importlib.util from pathlib import Path +import sys import time from typing import Any, Callable @@ -32,6 +34,12 @@ ) from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState from teleopit.runtime.reference_config import parse_reference_config +from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader +from teleopit.recording.lerobot_v3 import ( + build_mode_observation, + build_observation_state, + normalize_action_reference_qpos, +) from teleopit.sim.reference_motion import OfflineReferenceMotion from teleopit.sim.reference_timeline import ReferenceTimeline, 
ReferenceWindow, ReferenceWindowBuilder from teleopit.sim.reference_utils import ( @@ -50,7 +58,9 @@ HAND_TOPIC, HEALTH_TOPIC, MODE_TOPIC, + RECORD_TOPIC, REFERENCE_TOPIC, + VIDEO_TOPIC, LatestSubscriber, Sim2RealIpcEndpoints, ZmqPublisher, @@ -63,8 +73,11 @@ HealthPacket, ModeStatePacket, ReferencePacket, + RecordStepPacket, SnapshotPacket, + SharedFrameDescriptor, ) +from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor from teleopit.sim2real.remote import UnitreeRemote from teleopit.sim2real.safety import Sim2RealSafetyManager @@ -234,6 +247,18 @@ def _input_provider_kind(cfg: Any) -> str: return str(cfg_get(cfg_get(cfg, "input", {}) or {}, "provider", "bvh")).strip().lower() +def _recording_cfg(cfg: Any) -> Any: + return cfg_get(cfg, "recording", {}) or {} + + +def _recording_enabled(cfg: Any) -> bool: + return bool(cfg_get(_recording_cfg(cfg), "enabled", False)) + + +def _recording_camera_cfg(cfg: Any) -> Any: + return cfg_get(_recording_cfg(cfg), "camera", {}) or {} + + def _validate_new_runtime_config(cfg: Any) -> None: legacy_keys = [key for key in ("sim2real_runtime", "multiprocess", "dexterous_hand") if cfg_get(cfg, key, None) is not None] if legacy_keys: @@ -247,6 +272,51 @@ def _validate_new_runtime_config(cfg: Any) -> None: hands_cfg = cfg_get(cfg, "hands", {}) or {} if bool(cfg_get(hands_cfg, "enabled", False)) and provider != "pico4": raise ValueError("hands.enabled=true requires input.provider=pico4") + if _recording_enabled(cfg): + if provider != "pico4": + raise ValueError("recording.enabled=true requires input.provider=pico4") + rec_cfg = _recording_cfg(cfg) + if str(cfg_get(rec_cfg, "format", "lerobot_v3")) != "lerobot_v3": + raise ValueError("Only recording.format=lerobot_v3 is supported") + if str(cfg_get(rec_cfg, "control", "terminal")) != "terminal": + raise ValueError("Only recording.control=terminal is supported") + camera_cfg = 
_recording_camera_cfg(cfg) + if not bool(cfg_get(camera_cfg, "enabled", True)): + raise ValueError("recording.camera.enabled=false is not supported for LeRobot recording") + if str(cfg_get(camera_cfg, "source", "realsense")).lower() != "realsense": + raise ValueError("recording.camera.source must be realsense") + if int(cfg_get(rec_cfg, "fps", 30)) != int(cfg_get(camera_cfg, "fps", 30)): + raise ValueError("recording.fps must match recording.camera.fps") + input_video = parse_pico_video_config(cfg_get(cfg, "input", {}) or {}) + if not input_video.enabled: + raise ValueError("recording.enabled=true requires input.video.enabled=true") + if input_video.source != "realsense": + raise ValueError("recording.enabled=true requires input.video.source=realsense") + if int(input_video.width) != int(cfg_get(camera_cfg, "width", 640)): + raise ValueError("recording.camera.width must match input.video.width") + if int(input_video.height) != int(cfg_get(camera_cfg, "height", 480)): + raise ValueError("recording.camera.height must match input.video.height") + if int(input_video.fps) != int(cfg_get(camera_cfg, "fps", 30)): + raise ValueError("recording.camera.fps must match input.video.fps") + input_device = input_video.device + camera_device = cfg_get(camera_cfg, "device", None) + camera_device = None if camera_device in (None, "", "null") else str(camera_device) + if input_device != camera_device: + raise ValueError("recording.camera.device must match input.video.device") + + +def _require_recording_dependencies() -> None: + if importlib.util.find_spec("lerobot") is None: + raise RuntimeError( + "recording.enabled=true requires the recording dependencies and LeRobot v3 adapter. " + "Install Teleopit with: pip install -e '.[recording]'." 
+ ) + try: + from teleopit.recording.lerobot_v3 import TeleopitLeRobotV3Recorder + + TeleopitLeRobotV3Recorder.create + except Exception as exc: + raise RuntimeError("LeRobot v3 recording adapter is unavailable") from exc def _worker_loop(name: str, fn: Callable[[], None]) -> None: @@ -285,12 +355,23 @@ def __init__(self, cfg: Any) -> None: host=str(cfg_get(mp_cfg, "host", "127.0.0.1")), base_port=int(cfg_get(mp_cfg, "base_port", 39700)), ) + self._command_pub: ZmqPublisher | None = None + self._keyboard: TerminalKeyboardReader | None = None + if _recording_enabled(self.cfg): + _require_recording_dependencies() + if not sys.stdin.isatty(): + raise RuntimeError("recording.enabled=true requires an interactive TTY for terminal controls") def run(self) -> None: logger.info("Starting sim2real runtime") try: self._start_processes() + if _recording_enabled(self.cfg): + self._command_pub = ZmqPublisher(self._endpoints.command_pub) + self._keyboard = TerminalKeyboardReader() + logger.info("Recording controls: R=start, S=save, D=discard, Q=shutdown") while not self._stop_event.is_set(): + self._poll_terminal_recording_controls() time.sleep(0.2) critical_names = {"robot_control", "reference"} if _input_provider_kind(self.cfg) == "pico4": @@ -323,6 +404,8 @@ def run(self) -> None: def shutdown(self) -> None: self._stop_event.set() + if self._command_pub is not None: + self._command_pub.publish(COMMAND_TOPIC, CommandPacket(command="shutdown", timestamp_s=time.monotonic())) for process in self._processes: process.join(timeout=self._shutdown_timeout_s) for process in self._processes: @@ -331,6 +414,12 @@ def shutdown(self) -> None: process.terminate() process.join(timeout=1.0) self._processes.clear() + if self._keyboard is not None: + self._keyboard.close() + self._keyboard = None + if self._command_pub is not None: + self._command_pub.close() + self._command_pub = None def _start_processes(self) -> None: if self._processes: @@ -348,6 +437,8 @@ def _start_processes(self) -> None: 
hands_cfg = cfg_get(self.cfg, "hands", {}) or {} if bool(cfg_get(hands_cfg, "enabled", False)): specs.append(("hand_worker", _run_hand_worker)) + if _recording_enabled(self.cfg): + specs.append(("recording_worker", _run_recording_worker)) video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) if video_cfg.enabled: logger.info("Pico video runs inside pico_input so frames are pushed directly to PicoBridge") @@ -361,6 +452,33 @@ def _start_processes(self) -> None: process.start() self._processes.append(process) + def _poll_terminal_recording_controls(self) -> None: + if self._keyboard is None or self._command_pub is None: + return + events = self._keyboard.poll() + if not events: + return + for event in events: + command = map_recording_key_to_command(event.key) + if command is None: + continue + self._command_pub.publish(COMMAND_TOPIC, CommandPacket(command=command, timestamp_s=time.monotonic())) + if command == "shutdown": + self._stop_event.set() + + +def map_recording_key_to_command(key: str) -> str | None: + normalized = str(key).strip().lower() + if normalized == "r": + return "record_start" + if normalized == "s": + return "record_save" + if normalized == "d": + return "record_discard" + if normalized == "q": + return "shutdown" + return None + def _run_pico_io_worker( cfg: dict[str, Any], @@ -394,11 +512,28 @@ def _main() -> None: controller_pub = ZmqPublisher(endpoints.controller_pub) events_pub = ZmqPublisher(endpoints.control_events_pub) health_pub = ZmqPublisher(endpoints.health_pub) + video_pub = ZmqPublisher(endpoints.video_pub) if _recording_enabled(cfg) else None command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + frame_writer: SharedFrameRingWriter | None = None + + def _publish_recording_frame(frame: NDArray[np.generic], timestamp_s: float) -> None: + nonlocal frame_writer + if video_pub is None: + return + if frame_writer is None: + frame_writer = SharedFrameRingWriter( + shape=tuple(np.asarray(frame).shape), + 
dtype=np.uint8, + slots=int(cfg_get(_mp_cfg(cfg), "video_slots", 3)), + ) + descriptor = frame_writer.write(np.asarray(frame, dtype=np.uint8), timestamp_s=float(timestamp_s)) + video_pub.publish(VIDEO_TOPIC, descriptor) + video_runtime = PicoVideoRuntime( provider=provider, config=video_cfg, mode="sim2real", + frame_callback=_publish_recording_frame if _recording_enabled(cfg) else None, ) hz = float(cfg_get(_mp_cfg(cfg), "pico_input_hz", 120.0)) @@ -484,9 +619,12 @@ def _main() -> None: time.sleep(sleep_s) finally: video_runtime.stop() + if frame_writer is not None: + frame_writer.close(unlink=True) command_sub.close() - for publisher in (body_pub, hand_pub, controller_pub, events_pub, health_pub): - publisher.close() + for publisher in (body_pub, hand_pub, controller_pub, events_pub, health_pub, video_pub): + if publisher is not None: + publisher.close() provider.close() _worker_loop("pico_input", _main) @@ -867,6 +1005,7 @@ def __init__( self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) self._reference_command_pub = ZmqPublisher(endpoints.reference_command_pub) self._mode_pub = ZmqPublisher(endpoints.mode_pub) + self._record_pub = ZmqPublisher(endpoints.record_pub) if _recording_enabled(cfg) else None viewers = _parse_sim2real_viewers(cfg) self._retarget_viewer = _Sim2RealRetargetViewer( @@ -925,6 +1064,8 @@ def shutdown(self) -> None: self._command_sub.close() self._reference_command_pub.close() self._mode_pub.close() + if self._record_pub is not None: + self._record_pub.close() self.robot.close() def _build_policy_and_obs(self) -> tuple[Any, Any]: @@ -1015,6 +1156,7 @@ def _standing_step(self) -> None: self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) self._last_retarget_qpos = qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + self._publish_record_step(robot_state=robot_state, reference_qpos=qpos) self._write_retarget_viewer(qpos) def _mocap_step(self) -> None: @@ -1090,6 +1232,7 @@ def 
_execute_mocap_pipeline( self._ref_proc.last_reference_qpos = reference_qpos.copy() self._last_commanded_motion_qpos = qpos.copy() self._last_mocap_hold_reason = None + self._publish_record_step(robot_state=robot_state, reference_qpos=qpos) self._write_retarget_viewer(qpos) def _compose_arm_reference(self, retarget_qpos: Float64Array) -> Float64Array: @@ -1212,6 +1355,7 @@ def _enter_damping(self) -> None: logger.info("DAMPING: exiting debug mode...") self.robot.exit_debug_mode() self.mode = RobotMode.DAMPING + self._publish_damping_record_step() self._ref_proc.last_reference_qpos = None self._mocap_reentry_armed = False self._mocap_session.reset() @@ -1362,6 +1506,7 @@ def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: self._last_retarget_qpos = qpos.copy() self._ref_proc.last_reference_qpos = qpos.copy() self._last_commanded_motion_qpos = qpos.copy() + self._publish_record_step(robot_state=robot_state, reference_qpos=qpos) self._write_retarget_viewer(qpos) def _hold_mocap_reference(self, reason: str, *, detail: str | None = None) -> None: @@ -1388,6 +1533,60 @@ def _publish_mode_state(self) -> None: ), ) + def _publish_record_step(self, *, robot_state: object, reference_qpos: Float64Array) -> None: + if self._record_pub is None: + return + record_mode = self._recording_mode_label() + mocap_like = self.mode in (RobotMode.MOCAP, RobotMode.ARMS) + active = mocap_like and self._mocap_session.state == MocapSessionState.ACTIVE + recordable = self.mode != RobotMode.DAMPING + try: + self._record_pub.publish( + RECORD_TOPIC, + RecordStepPacket( + timestamp_s=time.monotonic(), + mode=record_mode, + mocap_active=active, + recordable=recordable, + observation_state=build_observation_state(robot_state).astype(np.float32, copy=True), + observation_mode=build_mode_observation(record_mode).astype(np.float32, copy=True), + action_reference_qpos=normalize_action_reference_qpos(reference_qpos).astype(np.float32, copy=True), + seq=self._mode_seq, + ), + ) + except 
Exception: + logger.exception("Failed to publish sim2real recording step") + + def _publish_damping_record_step(self) -> None: + if self._record_pub is None: + return + try: + robot_state = self.robot.get_state() + reference_qpos = self._build_robot_state_qpos(robot_state) + self._record_pub.publish( + RECORD_TOPIC, + RecordStepPacket( + timestamp_s=time.monotonic(), + mode=RobotMode.DAMPING.value, + mocap_active=False, + recordable=False, + observation_state=build_observation_state(robot_state).astype(np.float32, copy=True), + observation_mode=np.array([-1.0], dtype=np.float32), + action_reference_qpos=normalize_action_reference_qpos(reference_qpos).astype(np.float32, copy=True), + seq=self._mode_seq, + ), + ) + except Exception: + logger.exception("Failed to publish sim2real damping recording state") + + def _recording_mode_label(self) -> str: + if ( + self.mode in (RobotMode.MOCAP, RobotMode.ARMS) + and self._mocap_session.state == MocapSessionState.PAUSED + ): + return "pause" + return self.mode.value + def _write_retarget_viewer(self, qpos: Float64Array) -> None: try: self._retarget_viewer.write(qpos) @@ -1437,6 +1636,181 @@ def _main() -> None: _worker_loop("robot_control", _main) +class _RecordingWorker: + def __init__( + self, + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, + *, + recorder_factory: Callable[..., Any] | None = None, + frame_reader: SharedFrameRingReader | None = None, + ) -> None: + self.cfg = cfg + self.endpoints = endpoints + self.stop_event = stop_event + self.rec_cfg = _recording_cfg(cfg) + self.camera_cfg = _recording_camera_cfg(cfg) + self.record_modes = { + str(mode).lower() + for mode in cfg_get(self.rec_cfg, "record_modes", ["standing", "mocap", "arms", "pause"]) + } + self.min_episode_seconds = float(cfg_get(self.rec_cfg, "min_episode_seconds", 1.0)) + self.discard_on_shutdown = bool(cfg_get(self.rec_cfg, "discard_on_shutdown", True)) + self.task = str(cfg_get(self.rec_cfg, "task", "demo")) + 
self.fps = int(cfg_get(self.rec_cfg, "fps", 30)) + self._record_sub = LatestSubscriber(endpoints.record_pub, RECORD_TOPIC) + self._video_sub = LatestSubscriber(endpoints.video_pub, VIDEO_TOPIC) + self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + self._frame_reader = frame_reader or SharedFrameRingReader() + self._latest_record: RecordStepPacket | None = None + self._latest_video_seq = -1 + self._active = False + self._episode_started_s = 0.0 + self._episode_frames = 0 + + from teleopit.recording.lerobot_v3 import ( + TeleopitLeRobotV3Recorder, + build_recording_schema, + ) + + self._schema = build_recording_schema(self.camera_cfg) + factory = recorder_factory or TeleopitLeRobotV3Recorder.create + self._recorder = factory( + output_dir=cfg_get(self.rec_cfg, "output_dir", "data/lerobot"), + dataset_name=cfg_get(self.rec_cfg, "dataset_name", None), + repo_id=cfg_get(self.rec_cfg, "repo_id", None), + task=self.task, + fps=self.fps, + schema=self._schema, + ) + + def run(self) -> None: + logger.info("Recording worker started | fps=%d | modes=%s", self.fps, sorted(self.record_modes)) + idle_sleep_s = 1.0 / max(float(self.fps) * 4.0, 1.0) + try: + while not self.stop_event.is_set(): + command = self._command_sub.recv_latest() + if isinstance(command, CommandPacket): + if self._handle_command(command): + break + + record = self._record_sub.recv_latest() + if isinstance(record, RecordStepPacket): + self._latest_record = record + + video = self._video_sub.recv_latest() + if isinstance(video, SharedFrameDescriptor): + self._handle_video(video) + + time.sleep(idle_sleep_s) + finally: + if self._active: + if self.discard_on_shutdown: + self._discard_episode("shutdown") + else: + self._save_episode() + try: + self._recorder.finalize() + finally: + self._record_sub.close() + self._video_sub.close() + self._command_sub.close() + self._frame_reader.close() + + def _handle_command(self, command: CommandPacket) -> bool: + name = command.command + if name == 
"shutdown": + self.stop_event.set() + return True + if name == "record_start": + self._start_episode() + elif name == "record_save": + self._save_episode() + elif name == "record_discard": + self._discard_episode("manual discard") + return False + + def _start_episode(self) -> None: + if self._active: + logger.warning("Recording episode already active; ignoring R") + return + record = self._latest_record + if record is None: + logger.warning("Cannot start recording: no robot record packet yet") + return + mode = str(record.mode).lower() + if mode not in self.record_modes or not bool(record.recordable): + logger.warning( + "Cannot start recording: mode=%s recordable=%s", + record.mode, + record.recordable, + ) + return + self._recorder.start_episode() + self._active = True + self._episode_started_s = time.monotonic() + self._episode_frames = 0 + logger.info("Recording episode started") + + def _save_episode(self) -> None: + if not self._active: + logger.info("No active recording episode to save") + return + duration_s = time.monotonic() - self._episode_started_s + if duration_s < self.min_episode_seconds: + self._discard_episode(f"short episode ({duration_s:.2f}s < {self.min_episode_seconds:.2f}s)") + return + self._recorder.save_episode() + logger.info("Recording episode saved | frames=%d duration=%.2fs", self._episode_frames, duration_s) + self._active = False + self._episode_frames = 0 + + def _discard_episode(self, reason: str) -> None: + if not self._active: + logger.info("No active recording episode to discard") + return + self._recorder.discard_episode() + logger.info("Recording episode discarded | reason=%s | frames=%d", reason, self._episode_frames) + self._active = False + self._episode_frames = 0 + + def _handle_video(self, descriptor: SharedFrameDescriptor) -> None: + if int(descriptor.seq) == self._latest_video_seq: + return + self._latest_video_seq = int(descriptor.seq) + if not self._active: + return + record = self._latest_record + if record is None: 
+ return + mode = str(record.mode).lower() + if mode not in self.record_modes or not bool(record.recordable): + logger.warning("Recording stopped because mode is no longer recordable: %s", record.mode) + self._discard_episode("mode not recordable") + return + image = self._frame_reader.read(descriptor, copy=True) + self._recorder.add_frame( + image=np.asarray(image, dtype=np.uint8), + state=np.asarray(record.observation_state, dtype=np.float32), + mode=np.asarray(record.observation_mode, dtype=np.float32), + action=np.asarray(record.action_reference_qpos, dtype=np.float32), + task=self.task, + ) + self._episode_frames += 1 + +def _run_recording_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + worker = _RecordingWorker(cfg, endpoints, stop_event) + worker.run() + + _worker_loop("recording_worker", _main) + + class _HandSnapshotProxy: def __init__(self) -> None: self.hand_snapshot: Any | None = None diff --git a/tests/test_pico_video.py b/tests/test_pico_video.py index 075ec194..6e0d3cb7 100644 --- a/tests/test_pico_video.py +++ b/tests/test_pico_video.py @@ -87,11 +87,62 @@ def stop(self) -> None: assert sink.frames[-1].dtype == np.uint8 +def test_realsense_video_runtime_invokes_frame_callback(monkeypatch: pytest.MonkeyPatch) -> None: + fake_rs = ModuleType("pyrealsense2") + fake_rs.stream = SimpleNamespace(color="color") + fake_rs.format = SimpleNamespace(rgb8="rgb8") + + class FakeConfig: + def enable_stream(self, *_args: object) -> None: + pass + + class FakeColorFrame: + def get_data(self) -> np.ndarray: + return np.full((2, 2, 3), 7, dtype=np.uint8) + + class FakeFrames: + def get_color_frame(self) -> FakeColorFrame: + return FakeColorFrame() + + class FakePipeline: + def start(self, _config: object) -> None: + pass + + def wait_for_frames(self) -> FakeFrames: + time.sleep(0.005) + return FakeFrames() + + def stop(self) -> None: + pass + + fake_rs.config = FakeConfig + fake_rs.pipeline 
= FakePipeline + monkeypatch.setitem(sys.modules, "pyrealsense2", fake_rs) + + sink = _FrameSink() + callback_frames: list[np.ndarray] = [] + config = parse_pico_video_config({"video": {"enabled": True, "source": "realsense", "width": 2, "height": 2}}) + runtime = PicoVideoRuntime( + provider=sink, + config=config, + mode="sim2real", + frame_callback=lambda frame, _timestamp_s: callback_frames.append(frame.copy()), + ) + + runtime.start() + time.sleep(0.03) + runtime.stop() + + assert sink.frames + assert callback_frames + np.testing.assert_array_equal(callback_frames[-1], sink.frames[-1]) + + def test_video_runtime_stops_producer_after_startup_error(monkeypatch: pytest.MonkeyPatch) -> None: stopped = False class FailingProducer: - def __init__(self, _provider: object, _config: object) -> None: + def __init__(self, _provider: object, _config: object, _frame_callback: object = None) -> None: pass def start(self) -> None: @@ -120,7 +171,7 @@ def test_video_runtime_stops_producer_before_reraising_tick_error(monkeypatch: p stopped = False class FailingProducer: - def __init__(self, _provider: object, _config: object) -> None: + def __init__(self, _provider: object, _config: object, _frame_callback: object = None) -> None: pass def start(self) -> None: diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 2a4dd7c4..50abd5ac 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -11,13 +11,26 @@ from teleopit.runtime.mocap_session import MocapSessionState from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType from teleopit.runtime.arm_mocap import compose_arm_reference, compose_arm_reference_window +from teleopit.recording.lerobot_v3 import ( + ACTION_KEY, + IMAGE_KEY, + MODE_KEY, + STATE_KEY, + build_mode_observation, + build_observation_state, + build_recording_schema, + lerobot_features, + modality_sidecar, +) from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, 
LatestSubscriber, ZmqPublisher -from teleopit.sim2real.mp.messages import ReferencePacket, SharedFrameDescriptor +from teleopit.sim2real.mp.messages import RecordStepPacket, ReferencePacket, SharedFrameDescriptor from teleopit.sim.reference_timeline import ReferenceSample, ReferenceWindow from teleopit.sim2real.mp.runtime import ( + map_recording_key_to_command, RobotMode, Sim2RealRuntime, _LoopTimingReporter, + _RecordingWorker, _RobotControlWorker, _human_frame_is_valid, ) @@ -60,6 +73,26 @@ def test_sim2real_runtime_rejects_hands_without_pico_provider() -> None: Sim2RealRuntime(cfg) +def test_sim2real_runtime_rejects_recording_without_pico_provider() -> None: + cfg = { + "input": {"provider": "bvh"}, + "runtime": {"shutdown_timeout_s": 0.01}, + "recording": {"enabled": True}, + } + with pytest.raises(ValueError, match="recording.enabled=true requires input.provider=pico4"): + Sim2RealRuntime(cfg) + + +def test_sim2real_runtime_rejects_recording_without_input_video() -> None: + cfg = { + "input": {"provider": "pico4", "video": {"enabled": False, "source": "realsense"}}, + "runtime": {"shutdown_timeout_s": 0.01}, + "recording": {"enabled": True}, + } + with pytest.raises(ValueError, match="recording.enabled=true requires input.video.enabled=true"): + Sim2RealRuntime(cfg) + + def test_shared_frame_ring_roundtrip() -> None: writer = SharedFrameRingWriter(shape=(2, 3, 1), dtype=np.uint8, slots=2) reader = SharedFrameRingReader() @@ -213,6 +246,84 @@ def Process(self, *, name: str, target: object, args: tuple[object, ...]) -> Fak assert started_names == ["pico_input", "reference", "robot_control"] +def test_recording_enabled_adds_recording_worker(monkeypatch) -> None: + started_names: list[str] = [] + + class FakeProcess: + def __init__(self, *, name: str, target: object, args: tuple[object, ...]) -> None: + del target, args + self.name = name + self.exitcode = 0 + + def start(self) -> None: + started_names.append(self.name) + + class FakeContext: + def Event(self) -> 
object: + return SimpleNamespace(set=lambda: None, is_set=lambda: False) + + def Process(self, *, name: str, target: object, args: tuple[object, ...]) -> FakeProcess: + return FakeProcess(name=name, target=target, args=args) + + cfg = { + "input": { + "provider": "pico4", + "video": {"enabled": True, "source": "realsense", "width": 640, "height": 480, "fps": 30}, + }, + "runtime": {"shutdown_timeout_s": 0.01}, + "recording": {"enabled": True}, + } + monkeypatch.setattr("teleopit.sim2real.mp.runtime._require_recording_dependencies", lambda: None) + monkeypatch.setattr("sys.stdin.isatty", lambda: True) + runtime = Sim2RealRuntime(cfg) + runtime._ctx = FakeContext() # type: ignore[assignment] + + runtime._start_processes() + + assert started_names == ["pico_input", "reference", "robot_control", "recording_worker"] + + +def test_recording_key_mapping() -> None: + assert map_recording_key_to_command("R") == "record_start" + assert map_recording_key_to_command("s") == "record_save" + assert map_recording_key_to_command("D") == "record_discard" + assert map_recording_key_to_command("q") == "shutdown" + assert map_recording_key_to_command("x") is None + + +def test_lerobot_recording_schema_and_modality_sidecar() -> None: + schema = build_recording_schema({"width": 640, "height": 480, "key": IMAGE_KEY}) + features = lerobot_features(schema) + sidecar = modality_sidecar(schema) + + assert features[IMAGE_KEY]["shape"] == (480, 640, 3) + assert features[STATE_KEY]["shape"] == (68,) + assert features[MODE_KEY]["shape"] == (1,) + assert features[ACTION_KEY]["shape"] == (36,) + assert sidecar["features"][STATE_KEY]["slices"]["joint_pos"] == [0, 29] + assert sidecar["features"][STATE_KEY]["slices"]["projected_gravity"] == [65, 68] + assert sidecar["features"][MODE_KEY]["codes"]["pause"] == 3 + assert sidecar["features"][ACTION_KEY]["slices"]["joint_pos"] == [7, 36] + + +def test_record_observation_state_concat_order() -> None: + state = SimpleNamespace( + qpos=np.arange(29, 
dtype=np.float32), + qvel=np.arange(29, dtype=np.float32) + 100.0, + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.array([1.0, 2.0, 3.0], dtype=np.float32), + ) + + out = build_observation_state(state) + + assert out.shape == (68,) + np.testing.assert_allclose(out[0:29], state.qpos) + np.testing.assert_allclose(out[29:58], state.qvel) + np.testing.assert_allclose(out[58:62], state.quat) + np.testing.assert_allclose(out[62:65], state.ang_vel) + np.testing.assert_allclose(out[65:68], np.array([0.0, 0.0, -1.0], dtype=np.float32)) + + def test_human_frame_validation_rejects_bad_inputs() -> None: valid_frame = { "Pelvis": (np.zeros(3, dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), @@ -478,3 +589,178 @@ def test_robot_worker_mode_state_marks_arms_as_mocap_active() -> None: assert packet.mode == "arms" assert packet.mocap_active is True assert packet.mocap_paused is False + + +def test_robot_worker_publish_record_step() -> None: + worker = object.__new__(_RobotControlWorker) + worker.mode = RobotMode.ARMS + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._mode_seq = 7 + published: list[tuple[str, object]] = [] + worker._record_pub = SimpleNamespace(publish=lambda topic, packet: published.append((topic, packet))) + robot_state = SimpleNamespace( + qpos=np.arange(29, dtype=np.float32), + qvel=np.arange(29, dtype=np.float32) + 10.0, + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.array([0.1, 0.2, 0.3], dtype=np.float32), + ) + reference_qpos = np.arange(36, dtype=np.float64) + + worker._publish_record_step(robot_state=robot_state, reference_qpos=reference_qpos) + + assert len(published) == 1 + packet = published[0][1] + assert isinstance(packet, RecordStepPacket) + assert packet.mode == "arms" + assert packet.mocap_active is True + assert packet.recordable is True + assert packet.observation_state.shape == (68,) + assert packet.observation_mode.shape == (1,) + assert 
packet.action_reference_qpos.shape == (36,) + np.testing.assert_allclose(packet.observation_mode, build_mode_observation("arms")) + np.testing.assert_allclose(packet.action_reference_qpos, reference_qpos.astype(np.float32)) + + +def test_robot_worker_enter_damping_publishes_non_recordable_packet() -> None: + worker = object.__new__(_RobotControlWorker) + worker.mode = RobotMode.MOCAP + worker._mode_seq = 9 + worker._mocap_reentry_armed = True + worker._last_commanded_motion_qpos = np.ones(36, dtype=np.float64) + worker._last_mocap_hold_reason = "stale" + worker._default_root_pos = np.zeros(3, dtype=np.float64) + worker.num_actions = 29 + worker._mocap_session = SimpleNamespace(reset=lambda: None) + worker._ref_proc = SimpleNamespace(last_reference_qpos=np.zeros(36, dtype=np.float64)) + robot_state = SimpleNamespace( + qpos=np.arange(29, dtype=np.float32), + qvel=np.arange(29, dtype=np.float32) + 10.0, + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.array([0.1, 0.2, 0.3], dtype=np.float32), + base_pos=np.array([0.0, 0.0, 0.8], dtype=np.float32), + ) + worker.robot = SimpleNamespace( + set_damping=lambda: None, + exit_debug_mode=lambda: None, + get_state=lambda: robot_state, + ) + published: list[tuple[str, object]] = [] + worker._record_pub = SimpleNamespace(publish=lambda topic, packet: published.append((topic, packet))) + + worker._enter_damping() + + assert worker.mode == RobotMode.DAMPING + assert published + packet = published[-1][1] + assert isinstance(packet, RecordStepPacket) + assert packet.mode == "damping" + assert packet.recordable is False + assert packet.mocap_active is False + np.testing.assert_allclose(packet.observation_mode, np.array([-1.0], dtype=np.float32)) + + +def test_recording_worker_start_save_discard_with_fake_adapter() -> None: + from teleopit.sim2real.mp.ipc import default_endpoints + + calls: list[str] = [] + frames: list[dict[str, np.ndarray]] = [] + + class FakeRecorder: + def start_episode(self) -> None: + 
calls.append("start") + + def add_frame( + self, + *, + image: np.ndarray, + state: np.ndarray, + mode: np.ndarray, + action: np.ndarray, + task: str, + ) -> None: + calls.append(f"frame:{task}") + frames.append({"image": image.copy(), "state": state.copy(), "mode": mode.copy(), "action": action.copy()}) + + def save_episode(self) -> None: + calls.append("save") + + def discard_episode(self) -> None: + calls.append("discard") + + def finalize(self) -> None: + calls.append("finalize") + + def fake_factory(**_kwargs: object) -> FakeRecorder: + return FakeRecorder() + + stop_event = SimpleNamespace(is_set=lambda: False, set=lambda: None) + endpoints = default_endpoints(base_port=39850) + worker = _RecordingWorker( + { + "recording": { + "enabled": True, + "task": "walk", + "fps": 30, + "min_episode_seconds": 0.0, + "camera": {"width": 2, "height": 2, "key": IMAGE_KEY}, + } + }, + endpoints, + stop_event, # type: ignore[arg-type] + recorder_factory=fake_factory, + ) + writer = SharedFrameRingWriter(shape=(2, 2, 3), dtype=np.uint8, slots=2) + try: + worker._latest_record = RecordStepPacket( + timestamp_s=1.0, + mode="damping", + mocap_active=False, + recordable=False, + observation_state=np.ones(68, dtype=np.float32), + observation_mode=build_mode_observation("standing"), + action_reference_qpos=np.ones(36, dtype=np.float32), + seq=1, + ) + worker._start_episode() + assert calls == [] + + worker._latest_record = RecordStepPacket( + timestamp_s=2.0, + mode="standing", + mocap_active=False, + recordable=True, + observation_state=np.arange(68, dtype=np.float32), + observation_mode=build_mode_observation("standing"), + action_reference_qpos=np.arange(36, dtype=np.float32), + seq=2, + ) + worker._start_episode() + desc = writer.write(np.full((2, 2, 3), 5, dtype=np.uint8), timestamp_s=2.1) + worker._handle_video(desc) + worker._save_episode() + + assert calls == ["start", "frame:walk", "save"] + assert frames[0]["image"].shape == (2, 2, 3) + 
np.testing.assert_allclose(frames[0]["state"], np.arange(68, dtype=np.float32)) + np.testing.assert_allclose(frames[0]["mode"], build_mode_observation("standing")) + np.testing.assert_allclose(frames[0]["action"], np.arange(36, dtype=np.float32)) + + worker._latest_record = RecordStepPacket( + timestamp_s=3.0, + mode="pause", + mocap_active=False, + recordable=True, + observation_state=np.zeros(68, dtype=np.float32), + observation_mode=build_mode_observation("pause"), + action_reference_qpos=np.zeros(36, dtype=np.float32), + seq=3, + ) + worker._start_episode() + worker._discard_episode("test") + assert calls[-2:] == ["start", "discard"] + finally: + writer.close(unlink=True) + worker._record_sub.close() + worker._video_sub.close() + worker._command_sub.close() + worker._frame_reader.close() From 97a354afcec71da56cec72497f9223b92ed4ab4e Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 19:09:18 +0800 Subject: [PATCH 111/122] Update dexhand install docs --- docs/docs/configuration/config-reference.md | 5 +++-- docs/docs/getting-started/installation.md | 10 +++++----- docs/docs/tutorials/pico-sim2real.md | 8 +++++--- .../current/configuration/config-reference.md | 5 +++-- .../current/getting-started/installation.md | 9 +++++---- .../current/tutorials/pico-sim2real.md | 6 ++++-- pyproject.toml | 5 +---- 7 files changed, 26 insertions(+), 22 deletions(-) diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 3e4690e8..6853f438 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -123,8 +123,9 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin ### Dexterous Hand (Pico sim2real) -`hands.enabled=true` requires `input.provider=pico4` and the optional `dexhand` -extra. Control is active in `MOCAP` and `ARMS`; inactive modes send the open pose. 
+`hands.enabled=true` requires `input.provider=pico4` plus local editable +installs of `third_party/linkerhand-python-sdk` and `third_party/somehand`. +Control is active in `MOCAP` and `ARMS`; inactive modes send the open pose. `gripper` supports `linkerhand_l6` and `linkerhand_o6` by interpolating Pico trigger input between the configured open and close poses. `vr_hand_pose` is L6-only: missing hand pose holds the last command for that side, L6 speed is diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index c8e17094..615153c5 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -61,17 +61,17 @@ The receiver can run on a workstation PC or the robot onboard computer. See [Pico Sim2Sim](../tutorials/pico-sim2sim) and [Pico Sim2Real](../tutorials/pico-sim2real) for the full setup guides. -Optional LinkerHand control for Pico sim2real is installed through the -`dexhand` extra. It includes the LinkerHand SDK submodule and the remote -somehand package used by the L6 VR hand-pose mode: +Optional LinkerHand control for Pico sim2real uses local third-party packages. +Install those packages directly after initializing the submodules: ```bash git submodule update --init --recursive -pip install -e '.[dexhand]' +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand scripts/setup/download_somehand_l6_assets.sh ``` -This extra is only required when `hands.enabled=true`. +These packages are only required when `hands.enabled=true`. ### Sim2Real Recording diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 3fd1334d..957cdd95 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -187,11 +187,13 @@ Pico sim2real can drive LinkerHand hands from Pico input: Hand control is active in `MOCAP` and `ARMS`. 
It sends the open pose in `STANDING`, `DAMPING`, paused mocap, and shutdown. -Install the dexhand extra first if it was not installed with the main Pico -profile: +Install the local hand-control packages first if they were not installed with +the main Pico profile: ```bash -pip install -e '.[dexhand]' +git submodule update --init --recursive +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand scripts/setup/download_somehand_l6_assets.sh ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 6b92edbf..d241e69d 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -142,8 +142,9 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` ### 灵巧手(Pico sim2real) -`hands.enabled=true` 要求 `input.provider=pico4`,并安装可选的 `dexhand` -extra。控制在 `MOCAP` 和 `ARMS` 中生效;非活动模式会发送张开姿态。 +`hands.enabled=true` 要求 `input.provider=pico4`,并以本地 editable 方式安装 +`third_party/linkerhand-python-sdk` 和 `third_party/somehand`。控制在 `MOCAP` +和 `ARMS` 中生效;非活动模式会发送张开姿态。 `gripper` 支持 `linkerhand_l6` 和 `linkerhand_o6`,会用 Pico trigger 在配置的张开和闭合姿态之间插值。 `vr_hand_pose` 只支持 L6:手部 pose 消失时,对应侧会保持上一条命令;L6 速度会设为最大值; Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 somehand 0.2.0 公开的 `somehand.api` 调用。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 919dc1b2..5fa7f38d 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -61,16 +61,17 @@ receiver 可以运行在工作站 PC,也可以运行在机器人 onboard 计 完整设置流程详见 [Pico 
Sim2Sim](../tutorials/pico-sim2sim) 和 [Pico Sim2Real](../tutorials/pico-sim2real)。 -Pico sim2real 可选的 LinkerHand 控制通过 `dexhand` extra 安装。它包含 -LinkerHand SDK submodule,以及 L6 `vr_hand_pose` 模式使用的远程 somehand 包: +Pico sim2real 可选的 LinkerHand 控制使用本地 third-party 包。初始化 +submodule 后,直接安装这些包: ```bash git submodule update --init --recursive -pip install -e '.[dexhand]' +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand scripts/setup/download_somehand_l6_assets.sh ``` -只有在 `hands.enabled=true` 时才需要安装这个 extra。 +只有在 `hands.enabled=true` 时才需要安装这些包。 ### Sim2Real 录制 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 8bf778ad..ef67634d 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -175,10 +175,12 @@ Pico sim2real 可以用 Pico 输入控制 LinkerHand: 手控在 `MOCAP` 和 `ARMS` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 -如果主 Pico profile 没有包含手控支持,先安装 dexhand extra: +如果主 Pico profile 没有包含手控支持,先安装本地手控包: ```bash -pip install -e '.[dexhand]' +git submodule update --init --recursive +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand scripts/setup/download_somehand_l6_assets.sh ``` diff --git a/pyproject.toml b/pyproject.toml index 0f95aa7c..fbe6288d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,10 +60,7 @@ recording = [ "opencv-python", "imageio[ffmpeg]", ] -dexhand = [ - "linkerhand-python-sdk @ file:third_party/linkerhand-python-sdk", - "somehand @ file:third_party/somehand", -] +dexhand = [] [tool.setuptools.packages.find] where = ["."] From ab728f89662a9931bc3952b5d56c350878de1ee2 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 19:11:44 +0800 Subject: [PATCH 112/122] Discard empty recording episodes --- 
teleopit/sim2real/mp/runtime.py | 3 +++ tests/test_sim2real_multiprocess.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index a8af3f70..001eaeef 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -1758,6 +1758,9 @@ def _save_episode(self) -> None: logger.info("No active recording episode to save") return duration_s = time.monotonic() - self._episode_started_s + if self._episode_frames <= 0: + self._discard_episode("empty episode") + return if duration_s < self.min_episode_seconds: self._discard_episode(f"short episode ({duration_s:.2f}s < {self.min_episode_seconds:.2f}s)") return diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 50abd5ac..2be0b10c 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -734,12 +734,16 @@ def fake_factory(**_kwargs: object) -> FakeRecorder: action_reference_qpos=np.arange(36, dtype=np.float32), seq=2, ) + worker._start_episode() + worker._save_episode() + assert calls == ["start", "discard"] + worker._start_episode() desc = writer.write(np.full((2, 2, 3), 5, dtype=np.uint8), timestamp_s=2.1) worker._handle_video(desc) worker._save_episode() - assert calls == ["start", "frame:walk", "save"] + assert calls == ["start", "discard", "start", "frame:walk", "save"] assert frames[0]["image"].shape == (2, 2, 3) np.testing.assert_allclose(frames[0]["state"], np.arange(68, dtype=np.float32)) np.testing.assert_allclose(frames[0]["mode"], build_mode_observation("standing")) From ea7cb7d72b4d72f45b75f5a4b8f7b35b09134793 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 20:28:31 +0800 Subject: [PATCH 113/122] Add runtime console feedback --- scripts/run/run_sim.py | 48 ++-- scripts/run/run_sim2real.py | 55 ++-- scripts/run/standalone_standing.py | 23 +- teleopit/configs/default.yaml | 5 + teleopit/configs/sim2real.yaml | 5 + 
teleopit/pipeline.py | 4 +- teleopit/retargeting/gmr/motion_retarget.py | 9 +- teleopit/runtime/console.py | 271 ++++++++++++++++++++ teleopit/sim/loop.py | 4 + teleopit/sim/session.py | 42 ++- teleopit/sim2real/mp/runtime.py | 112 +++++--- tests/test_console.py | 84 ++++++ 12 files changed, 574 insertions(+), 88 deletions(-) create mode 100644 teleopit/runtime/console.py create mode 100644 tests/test_console.py diff --git a/scripts/run/run_sim.py b/scripts/run/run_sim.py index 491b9b19..216179c3 100644 --- a/scripts/run/run_sim.py +++ b/scripts/run/run_sim.py @@ -4,34 +4,48 @@ from omegaconf import DictConfig from teleopit.pipeline import TeleopPipeline +from teleopit.runtime.common import cfg_get +from teleopit.runtime.console import ( + PlainConsole, + configure_runtime_logging, + sim_keyboard_controls, +) from teleopit.runtime.cli import validate_policy_path -def _print_sim_controls(cfg: DictConfig) -> None: - provider = str(cfg.input.get("provider", "bvh")).lower() +def _sim_status(cfg: DictConfig) -> tuple[tuple[str, str], ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", "bvh")).lower() + viewers = str(cfg_get(cfg, "viewers", "none")) if provider == "pico4": - print("Pico sim2sim controls:") - if bool(cfg.get("keyboard", {}).get("enabled", False)): - print(" Keyboard: starts in STANDING; Y mocap, A pause/resume, B arms, X standing, Q quit.") - else: - print(" Pico controller: A pause/resume, B arms.") - print(" State flow: STANDING -> MOCAP <-> ARMS, X -> STANDING.") - return - if bool(cfg.get("playback", {}).get("keyboard", {}).get("enabled", False)): - print("Offline sim2sim controls:") - print(" Keyboard: Space/P pause/resume, R replay, Q stop.") + keyboard_cfg = cfg_get(cfg, "keyboard", {}) or {} + state = "STANDING" if bool(cfg_get(keyboard_cfg, "enabled", False)) else "MOCAP" + return ( + ("State", state), + ("Input", "Pico4 live"), + ("Viewers", viewers), + ) + return ( + ("State", "MOCAP"), + 
("Input", "BVH"), + ("Viewers", viewers), + ) @hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="default") def main(cfg: DictConfig) -> None: + configure_runtime_logging(cfg, force=True) validate_policy_path(cfg, "run_sim.py") - pipeline = TeleopPipeline(cfg) + console = PlainConsole(title="Teleopit sim2sim") + pipeline = TeleopPipeline(cfg, console=console) num_steps = int(cfg.get("num_steps", 0)) - if cfg.input.get("provider") == "pico4": - print("Waiting for Pico4 body tracking data...") - _print_sim_controls(cfg) + events = [] + input_cfg = cfg_get(cfg, "input", {}) or {} + if cfg_get(input_cfg, "provider", None) == "pico4": + events.append("waiting for Pico4 body tracking data") + console.start(status=_sim_status(cfg), controls=sim_keyboard_controls(cfg), events=events) result = pipeline.run(num_steps=num_steps) - print(result) + console.event(str(result)) if __name__ == "__main__": diff --git a/scripts/run/run_sim2real.py b/scripts/run/run_sim2real.py index f36d6673..4e260a35 100644 --- a/scripts/run/run_sim2real.py +++ b/scripts/run/run_sim2real.py @@ -2,28 +2,33 @@ from __future__ import annotations +import inspect + import hydra from omegaconf import DictConfig +from teleopit.runtime.common import cfg_get +from teleopit.runtime.console import ( + PlainConsole, + configure_runtime_logging, + sim2real_operator_controls, +) from teleopit.runtime.cli import validate_policy_path from teleopit.sim2real.mp import Sim2RealRuntime -def _print_sim2real_controls(cfg: DictConfig) -> None: - provider = str(cfg.input.get("provider", "bvh")).lower() - print("Sim2real controls:") - print(" Remote Start: enter STANDING.") - print(" Remote Y: enter MOCAP.") - print(" Remote X: return to STANDING.") - print(" Remote L1+R1: DAMPING / estop.") - if provider == "pico4": - print(" Mocap pause/resume: Pico/controller A.") - print(" Arm-only mode: Pico/controller B toggles MOCAP <-> ARMS.") - print(" Dexterous hand: hands.enabled=true 
hands.driver=linkerhand_l6|linkerhand_o6 hands.mode=gripper|vr_hand_pose.") - print(" State flow: IDLE -> STANDING -> MOCAP <-> ARMS, X -> STANDING, Any -> DAMPING.") - else: - print(" Offline playback: A pause/resume, B replay from start.") - print(" State flow: IDLE -> STANDING -> MOCAP -> STANDING, Any -> DAMPING.") +def _sim2real_status(cfg: DictConfig) -> tuple[tuple[str, str], ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", "bvh")).lower() + input_label = "Pico4 live" if provider == "pico4" else "BVH" + recording_cfg = cfg_get(cfg, "recording", {}) or {} + recording = "enabled" if bool(cfg_get(recording_cfg, "enabled", False)) else "off" + return ( + ("State", "IDLE"), + ("Runtime", "multiprocess"), + ("Input", input_label), + ("Recording", recording), + ) @hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="sim2real") @@ -32,12 +37,22 @@ def main(cfg: DictConfig) -> None: def _run_sim2real(cfg: DictConfig) -> None: + configure_runtime_logging(cfg, force=True) validate_policy_path(cfg, "run_sim2real.py") - controller = Sim2RealRuntime(cfg) - if cfg.input.get("provider") == "pico4": - print("Waiting for Pico4 body tracking data...") - print("Sim2real runtime: multiprocess") - _print_sim2real_controls(cfg) + console = PlainConsole(title="Teleopit sim2real") + runtime_params = inspect.signature(Sim2RealRuntime).parameters + controller = Sim2RealRuntime(cfg, console=console) if "console" in runtime_params else Sim2RealRuntime(cfg) + events = [] + input_cfg = cfg_get(cfg, "input", {}) or {} + if cfg_get(input_cfg, "provider", None) == "pico4": + events.append("waiting for Pico4 body tracking data") + console.start( + status=_sim2real_status(cfg), + controls=sim2real_operator_controls(cfg), + events=events, + control_section="Controls", + show_help_key=False, + ) try: controller.run() finally: diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 
1670746e..1553393a 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -53,10 +53,12 @@ def __init__( target_period_s: float, log_interval_s: float = 1.0, deadline_miss_tolerance_s: float = 0.001, + enabled: bool = True, ) -> None: self._target_period_s = float(target_period_s) self._log_interval_s = float(log_interval_s) self._deadline_miss_tolerance_s = float(deadline_miss_tolerance_s) + self._enabled = bool(enabled) self._window_start_s: float | None = None self._loop_ms: list[float] = [] self._late_ms: list[float] = [] @@ -109,6 +111,9 @@ def _emit(self, end_s: float) -> None: if sample_count <= 0: self._reset(end_s) return + if not self._enabled: + self._reset(end_s) + return logger.info( "Standalone timing | samples=%d window=%.1fs | " "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " @@ -230,7 +235,12 @@ def __init__( self._last_action = np.zeros(self.num_actions, dtype=np.float32) self._last_target: np.ndarray | None = None self._step_count = 0 - self._timing = _StandaloneTimingReporter(target_period_s=self.dt) + console_cfg = cfg_get(cfg, "console", {}) or {} + self._timing = _StandaloneTimingReporter( + target_period_s=self.dt, + log_interval_s=float(cfg_get(console_cfg, "timing_log_interval_s", 10.0)), + enabled=bool(cfg_get(console_cfg, "show_timing", False)), + ) if self.obs_delay_s > 0.0 or self.command_delay_s > 0.0: logger.info( @@ -462,6 +472,10 @@ def _build_cfg(args: argparse.Namespace) -> Any: "robot": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "robot" / "g1.yaml"), "controller": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "controller" / "rl_policy.yaml"), "real_robot": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "sim2real.yaml").real_robot, + "console": { + "show_timing": bool(args.show_timing), + "timing_log_interval_s": float(args.timing_log_interval_s), + }, } ) cfg.controller.policy_path = str(args.policy) @@ -480,6 +494,13 @@ def main() -> None: 
parser.add_argument("--kp-ramp-duration", type=float, default=2.0, help="Startup Kp ramp duration in seconds") parser.add_argument("--kp-ramp-floor-ratio", type=float, default=0.1, help="Initial Kp ratio during startup") parser.add_argument("--joint-vel-limit", type=float, default=10.0, help="Damp if any joint exceeds this velocity") + parser.add_argument("--show-timing", action="store_true", help="Print periodic timing diagnostics") + parser.add_argument( + "--timing-log-interval-s", + type=float, + default=10.0, + help="Timing diagnostic print interval when --show-timing is set", + ) parser.add_argument( "--obs-delay-ms", type=float, diff --git a/teleopit/configs/default.yaml b/teleopit/configs/default.yaml index 17036bb1..41eae1e4 100644 --- a/teleopit/configs/default.yaml +++ b/teleopit/configs/default.yaml @@ -14,6 +14,11 @@ playback: keyboard: enabled: false +console: + log_level: warning + show_timing: false + timing_log_interval_s: 10.0 + hydra: run: dir: . diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index cc53009f..262338af 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -17,6 +17,11 @@ reference_debug_log: false playback: pause_on_end: true +console: + log_level: warning + show_timing: false + timing_log_interval_s: 10.0 + recording: enabled: false format: lerobot_v3 diff --git a/teleopit/pipeline.py b/teleopit/pipeline.py index 234247b6..181a1912 100644 --- a/teleopit/pipeline.py +++ b/teleopit/pipeline.py @@ -13,12 +13,13 @@ from teleopit.retargeting.core import RetargetingModule from teleopit.robots.mujoco_robot import MuJoCoRobot from teleopit.runtime.common import cfg_get +from teleopit.runtime.console import PlainConsole from teleopit.runtime.factory import build_inference_components from teleopit.sim.loop import SimulationLoop class TeleopPipeline: - def __init__(self, cfg: DictConfig | dict[str, Any]) -> None: + def __init__(self, cfg: DictConfig | dict[str, Any], *, console: 
PlainConsole | None = None) -> None: self.cfg = cfg self._project_root = Path(__file__).resolve().parent.parent components = build_inference_components( @@ -53,6 +54,7 @@ def __init__(self, cfg: DictConfig | dict[str, Any]) -> None: components.sim_cfg, viewers=components.viewers, video_runtime=self.video_runtime, + console=console, ) def run(self, num_steps: int) -> dict[str, float | int | str]: diff --git a/teleopit/retargeting/gmr/motion_retarget.py b/teleopit/retargeting/gmr/motion_retarget.py index 8b0d6637..c97400f9 100644 --- a/teleopit/retargeting/gmr/motion_retarget.py +++ b/teleopit/retargeting/gmr/motion_retarget.py @@ -35,8 +35,9 @@ def __init__( self.model = mj.MjModel.from_xml_path(self.xml_file) # Print DoF names in order - print("[GMR] Robot Degrees of Freedom (DoF) names and their order:") self.robot_dof_names = {} + if verbose: + print("[GMR] Robot Degrees of Freedom (DoF) names and their order:") for i in range(self.model.nv): # 'nv' is the number of DoFs dof_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, self.model.dof_jntid[i]) self.robot_dof_names[dof_name] = i @@ -44,16 +45,18 @@ def __init__( print(f"DoF {i}: {dof_name}") - print("[GMR] Robot Body names and their IDs:") self.robot_body_names = {} + if verbose: + print("[GMR] Robot Body names and their IDs:") for i in range(self.model.nbody): # 'nbody' is the number of bodies body_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, i) self.robot_body_names[body_name] = i if verbose: print(f"Body ID {i}: {body_name}") - print("[GMR] Robot Motor (Actuator) names and their IDs:") self.robot_motor_names = {} + if verbose: + print("[GMR] Robot Motor (Actuator) names and their IDs:") for i in range(self.model.nu): # 'nu' is the number of actuators (motors) motor_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_ACTUATOR, i) self.robot_motor_names[motor_name] = i diff --git a/teleopit/runtime/console.py b/teleopit/runtime/console.py new file mode 100644 index 00000000..2b6d7b12 --- 
/dev/null +++ b/teleopit/runtime/console.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import sys +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Iterable + +from teleopit.runtime.common import cfg_get + + +RESET = "\033[0m" +BOLD = "\033[1m" +DIM = "\033[2m" +CYAN = "\033[36m" +GREEN = "\033[32m" +YELLOW = "\033[33m" +RED = "\033[31m" +MAGENTA = "\033[35m" +OPERATOR_LOGGER_NAME = "teleopit.operator" + + +@dataclass(frozen=True) +class KeyboardControl: + keys: str + action: str + + +def console_show_timing(cfg: Any) -> bool: + console_cfg = cfg_get(cfg, "console", {}) or {} + return bool(cfg_get(console_cfg, "show_timing", False)) + + +def console_timing_interval_s(cfg: Any, default: float = 10.0) -> float: + console_cfg = cfg_get(cfg, "console", {}) or {} + return float(cfg_get(console_cfg, "timing_log_interval_s", default)) + + +def console_log_level(cfg: Any, default: str = "warning") -> str: + console_cfg = cfg_get(cfg, "console", {}) or {} + return str(cfg_get(console_cfg, "log_level", default)).strip().lower() + + +def configure_runtime_logging(cfg: Any, *, force: bool = False) -> None: + """Keep operator output quiet by default while preserving warnings/errors.""" + + level_name = console_log_level(cfg) + level = getattr(logging, level_name.upper(), logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s:%(name)s:%(message)s", force=force) + + # Operator events are intentional console feedback; ordinary subsystem INFO stays hidden. 
+ logging.getLogger(OPERATOR_LOGGER_NAME).setLevel(logging.INFO) + + if level > logging.INFO: + noisy_names = ( + "pico_bridge", + "onnxruntime", + "teleopit.inputs.pico4_provider", + "teleopit.sim2real.unitree_g1", + "teleopit.sim2real.safety", + ) + for name in noisy_names: + logging.getLogger(name).setLevel(logging.WARNING) + + +class PlainConsole: + """Small runtime console for keyboard controls and operator feedback.""" + + def __init__(self, *, title: str, enabled: bool = True, color: bool | None = None) -> None: + self._title = str(title) + self._enabled = bool(enabled) + self._color = sys.stdout.isatty() if color is None else bool(color) + self._started = False + + @property + def enabled(self) -> bool: + return self._enabled + + def start( + self, + *, + status: Iterable[tuple[str, str]] = (), + controls: Iterable[KeyboardControl] = (), + events: Iterable[str] = (), + control_section: str = "Keyboard", + show_help_key: bool = True, + ) -> None: + if not self._enabled: + return + self._started = True + print( + self.render( + status=status, + controls=controls, + events=events, + control_section=control_section, + show_help_key=show_help_key, + ), + flush=True, + ) + + def event(self, message: str) -> None: + if not self._enabled: + return + prefix = datetime.now().strftime("%H:%M:%S") + print(f"{self._dim(prefix)} {self._highlight_text(message)}", flush=True) + + def key_feedback(self, key: str, action: str, *, result: str | None = None) -> None: + details = f" -> {result}" if result else "" + self.event(f"{self.format_key(key)} {action}{details}") + + def help(self, controls: Iterable[KeyboardControl]) -> None: + if not self._enabled: + return + control_items = list(controls) + if not control_items: + self.event("no keyboard controls active") + return + lines = [self._section("Keyboard help")] + lines.extend(f" {self.format_key(item.keys)} {item.action}" for item in control_items) + print("\n".join(lines), flush=True) + + def render( + self, + *, + 
status: Iterable[tuple[str, str]] = (), + controls: Iterable[KeyboardControl] = (), + events: Iterable[str] = (), + control_section: str = "Keyboard", + show_help_key: bool = True, + ) -> str: + lines = [self._title_text(self._title)] + status_items = [(str(key), str(value)) for key, value in status if str(value)] + if status_items: + width = max(len(key) for key, _value in status_items) + lines.append("") + lines.extend( + f"{self._label(key.ljust(width))} {self._status_value(value)}" + for key, value in status_items + ) + + control_items = list(controls) + if control_items: + lines.append("") + lines.append(self._section(control_section)) + lines.append(" ".join(f"{self.format_key(item.keys)} {item.action}" for item in control_items)) + if show_help_key: + lines.append(f"{self.format_key('H')} help") + + event_items = [str(event) for event in events if str(event)] + if event_items: + lines.append("") + lines.append(self._section("Events")) + lines.extend(self._highlight_text(event) for event in event_items) + return "\n".join(lines) + + def _style(self, text: str, code: str) -> str: + if not self._color: + return text + return f"{code}{text}{RESET}" + + def _title_text(self, text: str) -> str: + return self._style(text, BOLD) + + def _section(self, text: str) -> str: + return self._style(text, CYAN + BOLD) + + def _label(self, text: str) -> str: + return self._style(text, DIM) + + def _dim(self, text: str) -> str: + return self._style(text, DIM) + + def format_key(self, text: str) -> str: + return self._style(f"[{text}]", YELLOW + BOLD) + + def _status_value(self, text: str) -> str: + normalized = text.strip().lower() + if normalized in {"ok", "enabled", "running", "mocap", "standing", "active"}: + return self._style(text, GREEN + BOLD) + if normalized in {"idle", "waiting", "paused", "off", "none"} or "waiting" in normalized: + return self._style(text, YELLOW + BOLD) + if normalized in {"damping", "error", "failed"} or "stale" in normalized: + return 
self._style(text, RED + BOLD) + if normalized in {"arms", "pico4 live"}: + return self._style(text, MAGENTA + BOLD) + return self._style(text, BOLD) + + def _highlight_text(self, text: str) -> str: + if not self._color: + return text + highlighted = text + replacements = { + "waiting": YELLOW + BOLD, + "paused": YELLOW + BOLD, + "resumed": GREEN + BOLD, + "replay": GREEN + BOLD, + "stopping": RED + BOLD, + "shutdown": RED + BOLD, + "MOCAP": GREEN + BOLD, + "STANDING": GREEN + BOLD, + "ARMS": MAGENTA + BOLD, + } + for word, code in replacements.items(): + highlighted = highlighted.replace(word, f"{code}{word}{RESET}") + return highlighted + + +def sim_keyboard_controls(cfg: Any) -> tuple[KeyboardControl, ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", "bvh")).lower() + if provider == "pico4": + keyboard_cfg = cfg_get(cfg, "keyboard", {}) or {} + if not bool(cfg_get(keyboard_cfg, "enabled", False)): + return () + return ( + KeyboardControl("Y", "mocap"), + KeyboardControl("A", "pause/resume"), + KeyboardControl("B", "arms"), + KeyboardControl("X", "standing"), + KeyboardControl("Q", "quit"), + ) + + playback_cfg = cfg_get(cfg, "playback", {}) or {} + keyboard_cfg = cfg_get(playback_cfg, "keyboard", {}) or {} + if not bool(cfg_get(keyboard_cfg, "enabled", False)): + return () + return ( + KeyboardControl("Space/P", "pause/resume"), + KeyboardControl("R", "replay"), + KeyboardControl("Q", "stop"), + ) + + +def sim2real_keyboard_controls(cfg: Any) -> tuple[KeyboardControl, ...]: + recording_cfg = cfg_get(cfg, "recording", {}) or {} + if not bool(cfg_get(recording_cfg, "enabled", False)): + return () + return ( + KeyboardControl("R", "start"), + KeyboardControl("S", "save"), + KeyboardControl("D", "discard"), + KeyboardControl("Q", "shutdown"), + ) + + +def sim2real_operator_controls(cfg: Any) -> tuple[KeyboardControl, ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", 
"bvh")).lower() + controls = [ + KeyboardControl("Remote Start", "standing"), + KeyboardControl("Remote Y", "mocap"), + KeyboardControl("Remote X", "standing"), + KeyboardControl("Remote L1+R1", "damping / estop"), + ] + if provider == "pico4": + controls.extend( + [ + KeyboardControl("Pico/Controller A", "pause/resume"), + KeyboardControl("Pico/Controller B", "arms"), + ] + ) + else: + controls.extend( + [ + KeyboardControl("Remote A", "pause/resume playback"), + KeyboardControl("Remote B", "replay from start"), + ] + ) + controls.extend(sim2real_keyboard_controls(cfg)) + return tuple(controls) diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index 8e058135..5f1a211d 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -11,6 +11,7 @@ from teleopit.controllers.observation import align_motion_qpos_yaw from teleopit.runtime.reference_config import parse_reference_config from teleopit.runtime.arm_mocap import compose_arm_reference, parse_arm_joint_indices +from teleopit.runtime.console import PlainConsole, sim_keyboard_controls from teleopit.inputs.realtime_packet import RealtimeInputPacket from teleopit.interfaces import Controller, InputProvider, MessageBus, ObservationBuilder, Retargeter, Robot, RobotState from teleopit.sim.reference_timeline import ( @@ -49,6 +50,7 @@ def __init__( cfg: object, viewers: set[str] | None = None, video_runtime: object | None = None, + console: PlainConsole | None = None, ) -> None: self.robot: Robot = robot self.controller: Controller = controller @@ -56,6 +58,7 @@ def __init__( self.bus: MessageBus = bus self.cfg: object = cfg self._video_runtime = video_runtime + self._console = console or PlainConsole(title="Teleopit sim2sim") self.policy_hz: float = self._to_float(self._get_cfg("policy_hz", "sim.policy_hz", "control.policy_hz", "policy_frequency")) self.pd_hz: float = self._to_float(self._get_cfg("pd_hz", "sim.pd_hz", "control.pd_hz", "pd_frequency")) @@ -94,6 +97,7 @@ def _init_reference_config(self) -> None: 
self._playback_pause_on_end = bool(self._try_get_cfg("playback.pause_on_end", False)) self._playback_keyboard_enabled = bool(self._try_get_cfg("playback.keyboard.enabled", False)) self._realtime_keyboard_enabled = bool(self._try_get_cfg("keyboard.enabled", False)) + self._console_controls = sim_keyboard_controls(self.cfg) # Shared reference config (parsed once, used by both sim and sim2real) self._ref_cfg = parse_reference_config(self.cfg) diff --git a/teleopit/sim/session.py b/teleopit/sim/session.py index b5b5280e..c404200c 100644 --- a/teleopit/sim/session.py +++ b/teleopit/sim/session.py @@ -244,26 +244,27 @@ def enter_standing_mode(self) -> None: self._loop._set_standing_reference(self._loop.robot.get_state()) self.simulation_mode = SimulationMode.STANDING - def enter_mocap_mode(self) -> None: + def enter_mocap_mode(self) -> bool: from teleopit.sim.loop import SimulationMode loop = self._loop if not loop._realtime_input_has_frame(self._input_provider): _logger.warning("Cannot switch to MOCAP yet: realtime input has no frame available") - return + return False state = loop.robot.get_state() start_qpos = loop._resolve_hold_qpos(None, None, None, state) self.reset_policy_reference_state() self._step_runner.last_retarget_qpos = start_qpos.copy() self.last_commanded_motion_qpos = start_qpos.copy() self.simulation_mode = SimulationMode.MOCAP + return True - def toggle_arms_mode(self) -> None: + def toggle_arms_mode(self) -> bool: from teleopit.sim.loop import SimulationMode if not self.realtime_interpolated_input or self.simulation_mode not in (SimulationMode.MOCAP, SimulationMode.ARMS): - return + return False if self.mocap_session.state == MocapSessionState.PAUSED: _logger.info("Ignoring arm-only mode toggle while mocap session is paused") - return + return False loop = self._loop state = loop.robot.get_state() resume_qpos = loop._build_resume_alignment_qpos(self.last_commanded_motion_qpos, state) @@ -280,8 +281,9 @@ def toggle_arms_mode(self) -> None: 
self._step_runner.reset_reference_alignment(resume_qpos) self.last_commanded_motion_qpos = resume_qpos.copy() _logger.info("Simulation mode -> %s", self.simulation_mode.value.upper()) + return True - def toggle_realtime_mocap_pause(self) -> None: + def toggle_realtime_mocap_pause(self) -> str: loop = self._loop if self.mocap_session.state == MocapSessionState.PAUSED: hold_qpos = self.mocap_session.hold_qpos @@ -291,7 +293,7 @@ def toggle_realtime_mocap_pause(self) -> None: self.reset_policy_reference_state() self._step_runner.reset_reference_alignment(resume_qpos) self.last_commanded_motion_qpos = resume_qpos.copy() - return + return "resumed" hold_qpos = loop._resolve_hold_qpos( self.last_commanded_motion_qpos, self._step_runner.last_retarget_qpos, @@ -301,6 +303,7 @@ def toggle_realtime_mocap_pause(self) -> None: self.reset_policy_reference_state() self.mocap_session.pause(hold_qpos) self.last_commanded_motion_qpos = hold_qpos.copy() + return "paused" # ------------------------------------------------------------------ # Keyboard handling @@ -312,21 +315,32 @@ def _handle_realtime_keyboard(self) -> bool: assert self.keyboard_reader is not None for key_event in self.keyboard_reader.poll(): key = key_event.key.lower() + if key == "h": + self._loop._console.help(self._loop._console_controls) + continue if key == "q": self.playback_stop_requested = True + self._loop._console.key_feedback("Q", "quit", result="stopping") return True if self.simulation_mode == SimulationMode.STANDING: if key == "y": - self.enter_mocap_mode() + if self.enter_mocap_mode(): + self._loop._console.key_feedback("Y", "mocap", result="MOCAP") + else: + self._loop._console.key_feedback("Y", "mocap", result="waiting for input") continue if key == "x": self.enter_standing_mode() + self._loop._console.key_feedback("X", "standing", result="STANDING") continue if key == "b": - self.toggle_arms_mode() + if self.toggle_arms_mode(): + self._loop._console.key_feedback("B", "arms", 
result=self.simulation_mode.value.upper()) + else: + self._loop._console.key_feedback("B", "arms", result="ignored") continue if key == "a": - self.toggle_realtime_mocap_pause() + self._loop._console.key_feedback("A", "pause/resume", result=self.toggle_realtime_mocap_pause()) return False def _handle_offline_keyboard(self) -> bool: @@ -336,8 +350,12 @@ def _handle_offline_keyboard(self) -> bool: loop = self._loop for key_event in self.keyboard_reader.poll(): key = key_event.key.lower() + if key == "h": + self._loop._console.help(self._loop._console_controls) + continue if key == "q": self.playback_stop_requested = True + self._loop._console.key_feedback("Q", "stop", result="stopping") return True if key == "r": loop._restart_offline_playback( @@ -347,6 +365,7 @@ def _handle_offline_keyboard(self) -> bool: self.cached_human_frame = None self.cached_retargeted = None self.last_commanded_motion_qpos = None + self._loop._console.key_feedback("R", "replay", result="frame 0") continue if key not in (" ", "p"): continue @@ -355,6 +374,7 @@ def _handle_offline_keyboard(self) -> bool: logging.getLogger(__name__).info( "Offline playback already ended; press r to replay from frame 0." 
) + self._loop._console.key_feedback("Space/P", "pause/resume", result="ended; press R") else: loop._resume_offline_playback( offline_playback=self.offline_playback, @@ -362,6 +382,7 @@ def _handle_offline_keyboard(self) -> bool: state=loop.robot.get_state(), ) self.last_commanded_motion_qpos = None + self._loop._console.key_feedback("Space/P", "pause/resume", result="resumed") else: hold_qpos = loop._resolve_hold_qpos( self.last_commanded_motion_qpos, @@ -374,6 +395,7 @@ def _handle_offline_keyboard(self) -> bool: mocap_session=self.mocap_session, hold_qpos=hold_qpos, ) + self._loop._console.key_feedback("Space/P", "pause/resume", result="paused") return False # ------------------------------------------------------------------ diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 001eaeef..cde6ffdc 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -26,6 +26,14 @@ from teleopit.retargeting.core import RetargetingModule from teleopit.runtime.offline_playback import OfflinePlaybackController from teleopit.runtime.common import cfg_get, parse_viewers, require_section +from teleopit.runtime.console import ( + OPERATOR_LOGGER_NAME, + PlainConsole, + configure_runtime_logging, + console_show_timing, + console_timing_interval_s, + sim2real_keyboard_controls, +) from teleopit.runtime.factory import _build_policy_components, build_simulation_cfg from teleopit.runtime.arm_mocap import ( compose_arm_reference, @@ -90,6 +98,7 @@ logger = logging.getLogger(__name__) +operator_logger = logging.getLogger(OPERATOR_LOGGER_NAME) Float32Array = NDArray[np.float32] Float64Array = NDArray[np.float64] @@ -111,10 +120,12 @@ def __init__( target_period_s: float, log_interval_s: float = 1.0, deadline_miss_tolerance_s: float = 0.001, + enabled: bool = True, ) -> None: self._target_period_s = float(target_period_s) self._log_interval_s = float(log_interval_s) self._deadline_miss_tolerance_s = float(deadline_miss_tolerance_s) 
+ self._enabled = bool(enabled) self._window_start_s: float | None = None self._loop_ms: list[float] = [] self._late_ms: list[float] = [] @@ -143,6 +154,9 @@ def _emit(self, end_s: float) -> None: if sample_count <= 0: self._reset(end_s) return + if not self._enabled: + self._reset(end_s) + return loop_summary = self._summarize(self._loop_ms) late_summary = self._summarize(self._late_ms) work_summary = self._summarize(self._work_ms) @@ -167,7 +181,7 @@ def _emit(self, end_s: float) -> None: if self._pico_age_ms: message += " | reference_age_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" args.extend(self._summarize(self._pico_age_ms)) - logger.info(message, *args) + operator_logger.info(message, *args) self._reset(end_s) def _reset(self, window_start_s: float) -> None: @@ -319,8 +333,8 @@ def _require_recording_dependencies() -> None: raise RuntimeError("LeRobot v3 recording adapter is unavailable") from exc -def _worker_loop(name: str, fn: Callable[[], None]) -> None: - logging.basicConfig(level=logging.INFO) +def _worker_loop(name: str, cfg: dict[str, Any], fn: Callable[[], None]) -> None: + configure_runtime_logging(cfg, force=True) try: fn() except KeyboardInterrupt: @@ -337,7 +351,7 @@ def _human_frame_is_valid(frame: object) -> bool: class Sim2RealRuntime: """Supervisor facade for the process-isolated sim2real runtime.""" - def __init__(self, cfg: Any) -> None: + def __init__(self, cfg: Any, *, console: PlainConsole | None = None) -> None: self.cfg = _plain_cfg(cfg) _validate_new_runtime_config(self.cfg) @@ -357,19 +371,23 @@ def __init__(self, cfg: Any) -> None: ) self._command_pub: ZmqPublisher | None = None self._keyboard: TerminalKeyboardReader | None = None + self._console = console or PlainConsole(title="Teleopit sim2real", enabled=False) + self._console_controls = sim2real_keyboard_controls(self.cfg) if _recording_enabled(self.cfg): _require_recording_dependencies() if not sys.stdin.isatty(): raise RuntimeError("recording.enabled=true requires an interactive 
TTY for terminal controls") + if console is None: + self._console = PlainConsole(title="Teleopit sim2real") def run(self) -> None: - logger.info("Starting sim2real runtime") + operator_logger.info("runtime starting") try: self._start_processes() if _recording_enabled(self.cfg): self._command_pub = ZmqPublisher(self._endpoints.command_pub) self._keyboard = TerminalKeyboardReader() - logger.info("Recording controls: R=start, S=save, D=discard, Q=shutdown") + operator_logger.info("keyboard recording controls active: R start, S save, D discard, Q shutdown, H help") while not self._stop_event.is_set(): self._poll_terminal_recording_controls() time.sleep(0.2) @@ -384,7 +402,7 @@ def run(self) -> None: and process.name in critical_names ] if critical_dead: - logger.error("Critical sim2real worker exited: %s", ", ".join(critical_dead)) + operator_logger.error("critical worker exited: %s", ", ".join(critical_dead)) self._stop_event.set() break noncritical_dead = [ @@ -395,9 +413,9 @@ def run(self) -> None: and process.name not in critical_names ] if noncritical_dead: - logger.warning("Non-critical sim2real worker exited: %s", ", ".join(noncritical_dead)) + operator_logger.warning("non-critical worker exited: %s", ", ".join(noncritical_dead)) except KeyboardInterrupt: - logger.info("KeyboardInterrupt -- shutting down sim2real") + operator_logger.info("keyboard interrupt -> shutting down") self._stop_event.set() finally: self.shutdown() @@ -410,7 +428,7 @@ def shutdown(self) -> None: process.join(timeout=self._shutdown_timeout_s) for process in self._processes: if process.is_alive(): - logger.warning("Terminating sim2real worker %s", process.name) + operator_logger.warning("terminating worker %s", process.name) process.terminate() process.join(timeout=1.0) self._processes.clear() @@ -459,10 +477,15 @@ def _poll_terminal_recording_controls(self) -> None: if not events: return for event in events: + normalized = str(event.key).strip().lower() + if normalized == "h": + 
self._console.help(self._console_controls) + continue command = map_recording_key_to_command(event.key) if command is None: continue self._command_pub.publish(COMMAND_TOPIC, CommandPacket(command=command, timestamp_s=time.monotonic())) + self._console.key_feedback(str(event.key).upper(), _recording_command_label(command)) if command == "shutdown": self._stop_event.set() @@ -480,6 +503,18 @@ def map_recording_key_to_command(key: str) -> str | None: return None +def _recording_command_label(command: str) -> str: + if command == "record_start": + return "start recording" + if command == "record_save": + return "save recording" + if command == "record_discard": + return "discard recording" + if command == "shutdown": + return "shutdown" + return command + + def _run_pico_io_worker( cfg: dict[str, Any], endpoints: Sim2RealIpcEndpoints, @@ -627,7 +662,7 @@ def _publish_recording_frame(frame: NDArray[np.generic], timestamp_s: float) -> publisher.close() provider.close() - _worker_loop("pico_input", _main) + _worker_loop("pico_input", cfg, _main) def _run_reference_worker( @@ -799,7 +834,7 @@ def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> command_sub.close() ref_pub.close() - _worker_loop("reference", _main) + _worker_loop("reference", cfg, _main) def _run_bvh_reference_worker( @@ -929,7 +964,7 @@ def _publish(sample_time_s: float, *, frame_valid: bool = True) -> Float64Array ref_pub.close() health_pub.close() - _worker_loop("reference", _main) + _worker_loop("reference", cfg, _main) class _RobotControlWorker: @@ -1015,8 +1050,12 @@ def __init__( self._mode_seq = 0 def run(self) -> None: - logger.info("Robot control worker started | mode=IDLE | policy_hz=%.0f", self.policy_hz) - timing = _LoopTimingReporter(target_period_s=self.dt) + operator_logger.info("robot control ready | mode=IDLE | policy_hz=%.0f", self.policy_hz) + timing = _LoopTimingReporter( + target_period_s=self.dt, + log_interval_s=console_timing_interval_s(self.cfg), + 
enabled=console_show_timing(self.cfg), + ) try: while not self.stop_event.is_set(): t0 = time.monotonic() @@ -1027,6 +1066,7 @@ def run(self) -> None: if self.remote.LB.pressed and self.remote.RB.pressed: if self.mode != RobotMode.DAMPING: logger.warning("EMERGENCY STOP (L1+R1)") + operator_logger.warning("DAMPING requested by emergency stop") self._enter_damping() else: self._handle_transitions() @@ -1098,38 +1138,38 @@ def _drain_ipc(self) -> None: def _handle_transitions(self) -> None: if self.mode == RobotMode.IDLE: if self.remote.start.on_pressed: - logger.info("Start pressed (from IDLE)") + operator_logger.info("Start -> STANDING") self._enter_standing() elif self.mode == RobotMode.STANDING: reentry_request = self._mocap_reentry_armed and self.remote.Y.pressed if self.remote.Y.on_pressed or reentry_request: if self._can_switch_to_mocap(): - logger.info("Y pressed -> entering MOCAP") + operator_logger.info("Y -> MOCAP") self._transition_to_mocap() else: - logger.warning("Cannot switch to MOCAP -- no fresh retarget reference") + operator_logger.warning("Y -> waiting for fresh retarget reference") elif self.mode in (RobotMode.MOCAP, RobotMode.ARMS): if self.provider_kind == "bvh" and self.remote.B.on_pressed: - logger.info("B pressed -> replaying BVH motion from start") + operator_logger.info("B -> replay BVH from frame 0") self._send_reference_command("replay_mocap") self._resume_paused_mocap_if_needed() return if self.remote.A.on_pressed: if self._mocap_session.state == MocapSessionState.PAUSED: - logger.info("A pressed -> resuming playback") + operator_logger.info("A -> resume playback") self._send_reference_command("resume_mocap") self._resume_paused_mocap() else: - logger.info("A pressed -> pausing playback") + operator_logger.info("A -> pause playback") self._send_reference_command("pause_mocap") self._pause_active_mocap() return if self.remote.X.on_pressed: - logger.info("X pressed -> returning to STANDING") + operator_logger.info("X -> STANDING") 
self._enter_standing() elif self.mode == RobotMode.DAMPING: if self.remote.start.on_pressed: - logger.info("Start pressed (from DAMPING)") + operator_logger.info("Start -> STANDING") self._enter_standing() def _standing_step(self) -> None: @@ -1284,7 +1324,7 @@ def _enter_standing(self) -> None: self._safety.start_kp_ramp() self._mocap_reentry_armed = prev_mode in (RobotMode.MOCAP, RobotMode.ARMS) self.mode = RobotMode.STANDING - logger.info("Mode -> STANDING (multiprocess robot control)") + operator_logger.info("mode -> STANDING") def _can_switch_to_mocap(self) -> bool: age_s = self._reference_age_s() @@ -1318,7 +1358,7 @@ def _transition_to_mocap(self) -> None: if self.provider_kind == "bvh": self._send_reference_command("replay_mocap") self.mode = RobotMode.MOCAP - logger.info("Mode -> MOCAP (tracking multiprocess retarget reference)") + operator_logger.info("mode -> MOCAP") def _toggle_arms_mode(self) -> None: if self.provider_kind != "pico4" or self.mode not in (RobotMode.MOCAP, RobotMode.ARMS): @@ -1341,7 +1381,7 @@ def _toggle_arms_mode(self) -> None: floor_ratio=self._standing_return_kp_ramp_floor_ratio, ) self.mode = next_mode - logger.info("Mode -> %s (Pico B toggle)", next_mode.value.upper()) + operator_logger.info("mode -> %s", next_mode.value.upper()) def _resume_paused_mocap_if_needed(self) -> None: if self._mocap_session.state == MocapSessionState.PAUSED: @@ -1361,7 +1401,7 @@ def _enter_damping(self) -> None: self._mocap_session.reset() self._last_commanded_motion_qpos = None self._last_mocap_hold_reason = None - logger.info("Mode -> DAMPING (press Start to re-enter STANDING)") + operator_logger.warning("mode -> DAMPING") def _reset_policy_state(self) -> None: self._last_action = np.zeros(self.num_actions, dtype=np.float32) @@ -1633,7 +1673,7 @@ def _main() -> None: worker = _RobotControlWorker(cfg, endpoints, stop_event) worker.run() - _worker_loop("robot_control", _main) + _worker_loop("robot_control", cfg, _main) class _RecordingWorker: @@ 
-1686,7 +1726,7 @@ def __init__( ) def run(self) -> None: - logger.info("Recording worker started | fps=%d | modes=%s", self.fps, sorted(self.record_modes)) + operator_logger.info("recording worker ready | fps=%d | modes=%s", self.fps, sorted(self.record_modes)) idle_sleep_s = 1.0 / max(float(self.fps) * 4.0, 1.0) try: while not self.stop_event.is_set(): @@ -1751,11 +1791,11 @@ def _start_episode(self) -> None: self._active = True self._episode_started_s = time.monotonic() self._episode_frames = 0 - logger.info("Recording episode started") + operator_logger.info("recording episode started") def _save_episode(self) -> None: if not self._active: - logger.info("No active recording episode to save") + operator_logger.info("no active recording episode to save") return duration_s = time.monotonic() - self._episode_started_s if self._episode_frames <= 0: @@ -1765,16 +1805,16 @@ def _save_episode(self) -> None: self._discard_episode(f"short episode ({duration_s:.2f}s < {self.min_episode_seconds:.2f}s)") return self._recorder.save_episode() - logger.info("Recording episode saved | frames=%d duration=%.2fs", self._episode_frames, duration_s) + operator_logger.info("recording episode saved | frames=%d duration=%.2fs", self._episode_frames, duration_s) self._active = False self._episode_frames = 0 def _discard_episode(self, reason: str) -> None: if not self._active: - logger.info("No active recording episode to discard") + operator_logger.info("no active recording episode to discard") return self._recorder.discard_episode() - logger.info("Recording episode discarded | reason=%s | frames=%d", reason, self._episode_frames) + operator_logger.info("recording episode discarded | reason=%s | frames=%d", reason, self._episode_frames) self._active = False self._episode_frames = 0 @@ -1811,7 +1851,7 @@ def _main() -> None: worker = _RecordingWorker(cfg, endpoints, stop_event) worker.run() - _worker_loop("recording_worker", _main) + _worker_loop("recording_worker", cfg, _main) class 
_HandSnapshotProxy: @@ -1875,4 +1915,4 @@ def _main() -> None: mode_sub.close() command_sub.close() - _worker_loop("hand_worker", _main) + _worker_loop("hand_worker", cfg, _main) diff --git a/tests/test_console.py b/tests/test_console.py new file mode 100644 index 00000000..e4e163b8 --- /dev/null +++ b/tests/test_console.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging + +from teleopit.runtime.console import ( + OPERATOR_LOGGER_NAME, + PlainConsole, + configure_runtime_logging, + sim2real_operator_controls, + sim_keyboard_controls, +) + + +def test_sim_console_shows_only_enabled_keyboard_controls() -> None: + cfg = { + "input": {"provider": "pico4"}, + "keyboard": {"enabled": True}, + } + + labels = [control.keys for control in sim_keyboard_controls(cfg)] + + assert labels == ["Y", "A", "B", "X", "Q"] + + +def test_sim_console_hides_non_keyboard_controls() -> None: + cfg = { + "input": {"provider": "pico4"}, + "keyboard": {"enabled": False}, + } + + assert sim_keyboard_controls(cfg) == () + + +def test_sim2real_console_shows_remote_and_pico_controls() -> None: + cfg = {"input": {"provider": "pico4"}, "recording": {"enabled": False}} + + rendered = PlainConsole(title="Teleopit sim2real", color=False).render( + controls=sim2real_operator_controls(cfg), + control_section="Controls", + show_help_key=False, + ) + + assert "Controls" in rendered + assert "[Remote Start] standing" in rendered + assert "[Remote L1+R1] damping / estop" in rendered + assert "[Pico/Controller A] pause/resume" in rendered + assert "[Pico/Controller B] arms" in rendered + assert "[H] help" not in rendered + + +def test_sim2real_console_includes_terminal_recording_controls_when_enabled() -> None: + cfg = {"recording": {"enabled": True}} + + rendered = PlainConsole(title="Teleopit sim2real", color=False).render( + controls=sim2real_operator_controls(cfg), + control_section="Controls", + show_help_key=False, + ) + + assert "Controls" in rendered + assert "[Remote Start] 
standing" in rendered + assert "[R] start" in rendered + assert "[Q] shutdown" in rendered + assert "[H] help" not in rendered + + +def test_console_can_highlight_important_words_with_ansi_color() -> None: + rendered = PlainConsole(title="Teleopit sim2sim", color=True).render( + status=(("State", "MOCAP"),), + controls=sim_keyboard_controls({"input": {"provider": "bvh"}, "playback": {"keyboard": {"enabled": True}}}), + events=("waiting for input",), + ) + + assert "\033[" in rendered + assert "[Space/P]" in rendered + assert "MOCAP" in rendered + + +def test_runtime_logging_keeps_operator_info_but_hides_noisy_info() -> None: + configure_runtime_logging({"console": {"log_level": "warning"}}, force=True) + + assert logging.getLogger().getEffectiveLevel() == logging.WARNING + assert logging.getLogger(OPERATOR_LOGGER_NAME).getEffectiveLevel() == logging.INFO + assert logging.getLogger("pico_bridge").getEffectiveLevel() == logging.WARNING From 6da60d7b1fcc47ef129bd0c34f9eee7dbcb41db1 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 21:22:23 +0800 Subject: [PATCH 114/122] Record LinkerHand actions in sim2real datasets --- AGENTS.md | 2 +- README.md | 7 +- docs/docs/configuration/config-reference.md | 3 + docs/docs/tutorials/pico-sim2real.md | 2 +- .../current/configuration/config-reference.md | 3 + .../current/tutorials/pico-sim2real.md | 2 +- teleopit/recording/lerobot_v3.py | 37 ++++++ teleopit/sim2real/hands/worker.py | 75 ++++++++++--- teleopit/sim2real/mp/ipc.py | 21 ++-- teleopit/sim2real/mp/messages.py | 11 ++ teleopit/sim2real/mp/runtime.py | 105 +++++++++++++++++- tests/test_dexterous_hand.py | 53 +++++++++ tests/test_sim2real_multiprocess.py | 57 +++++++++- 13 files changed, 342 insertions(+), 36 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ceaa5264..6af3f639 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -148,7 +148,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Optional LinkerHand control uses 
`hands.enabled=true`, `hands.driver=linkerhand_l6|linkerhand_o6`, and `hands.mode=gripper|vr_hand_pose`; default is disabled - Optional Pico sim2real LeRobot v3 recording uses `--config-name sim2real_record` or `recording.enabled=true`; it requires `input.provider=pico4`, `input.video.enabled=true`, `input.video.source=realsense`, an interactive terminal, and the `recording` extra - Recording is manual only: terminal `R` starts an episode, `S` saves, `D` discards the active episode, and `Q` shuts down; `STANDING`, `MOCAP`, `ARMS`, and paused mocap are recordable -- Recording captures `observation.images.d435i_rgb` RealSense RGB video at 30Hz plus `observation.state(68)`, `observation.mode(1)`, and `action(36)`; RealSense capture lives in `pico_input` through the normal `input.video` path +- Recording captures `observation.images.d435i_rgb` RealSense RGB video at 30Hz plus `observation.state(68)`, `observation.mode(1)`, `action(36)`, and `action.hand(12)`; RealSense capture lives in `pico_input` through the normal `input.video` path - `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control and supports LinkerHand L6 and O6 - `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and somehand 0.2.0 public `somehand.api` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control - Teleopit owns Pico 26-joint hand-state to 21-landmark conversion; do not import `somehand.pico_input` diff --git a/README.md b/README.md index bd43f945..adb47fdf 100644 --- a/README.md +++ b/README.md @@ -98,8 +98,9 @@ python scripts/run/run_sim2real.py --config-name sim2real_record \ Recording uses the terminal controls `R` start, `S` save, `D` discard, and `Q` shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded. 
The dataset schema is `observation.images.d435i_rgb` video at 30 Hz, -`observation.state(68)`, `observation.mode(1)`, and `action(36)` as the aligned -reference qpos sent to the policy path. +`observation.state(68)`, `observation.mode(1)`, `action(36)` as the aligned +reference qpos sent to the policy path, and `action.hand(12)` as the latest +LinkerHand left/right 6D pose commands. ## Documentation @@ -112,7 +113,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added Pico sim2real `ARMS` mode: Pico/controller `B` toggles between whole-body `MOCAP` and stand-pose body/legs with live retargeted arms. - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. - Added optional LinkerHand L6 sim2real modes under `hands.*`: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand 0.2.0 public API. -- Added manual Pico sim2real LeRobot v3 recording with RealSense D435i RGB video, 68D robot state, mode labels, and 36D reference-qpos action labels. +- Added manual Pico sim2real LeRobot v3 recording with RealSense D435i RGB video, 68D robot state, mode labels, 36D reference-qpos action labels, and 12D LinkerHand pose action labels. - Added LinkerHand O6 support for Pico `gripper` mode with an O6-specific grasp pose. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. 
diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 6853f438..362b3df8 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -186,6 +186,7 @@ observation.images.d435i_rgb video [480,640,3] uint8 observation.state float32[68] observation.mode float32[1] action float32[36] +action.hand float32[12] ``` `observation.state` is ordered as `joint_pos(29)`, `joint_vel(29)`, @@ -193,6 +194,8 @@ action float32[36] `observation.mode` is a numeric categorical: `standing=0`, `mocap=1`, `arms=2`, and `pause=3`. `action` is the current reference qpos: `root_pos(3) + root_quat_wxyz(4) + joint_pos(29)`. +`action.hand` is the latest LinkerHand command from the hand worker: +`left_pose(6) + right_pose(6)`, using the SDK's 0-255 pose values. ## Critical: `default_dof_pos` diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 957cdd95..d172e4a3 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -117,7 +117,7 @@ Terminal controls are `R` start episode, `S` save, `D` discard, and `Q` shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded; saved episodes cannot be discarded afterward. The v1 schema records `observation.images.d435i_rgb`, `observation.state(68)`, -`observation.mode(1)`, and `action(36)` at 30 Hz. +`observation.mode(1)`, `action(36)`, and `action.hand(12)` at 30 Hz. 
## Operator Flow diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index d241e69d..7d573842 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -203,6 +203,7 @@ observation.images.d435i_rgb video [480,640,3] uint8 observation.state float32[68] observation.mode float32[1] action float32[36] +action.hand float32[12] ``` `observation.state` 的顺序是 `joint_pos(29)`、`joint_vel(29)`、 @@ -210,3 +211,5 @@ action float32[36] `observation.mode` 是数值类别:`standing=0`、`mocap=1`、 `arms=2`、`pause=3`。`action` 是当前 reference qpos: `root_pos(3) + root_quat_wxyz(4) + joint_pos(29)`。 +`action.hand` 是手部 worker 最新的 LinkerHand 命令: +`left_pose(6) + right_pose(6)`,使用 SDK 的 0-255 pose 数值。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index ef67634d..3127eeeb 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -112,7 +112,7 @@ python scripts/run/run_sim2real.py \ 终端控制为:`R` 开始 episode,`S` 保存,`D` 丢弃,`Q` 关闭。可以录制 `STANDING`、`MOCAP`、`ARMS` 和暂停状态的 mocap;已经保存的 episode 不支持再丢弃。 v1 schema 以 30 Hz 记录 `observation.images.d435i_rgb`、`observation.state(68)`、 -`observation.mode(1)` 和 `action(36)`。 +`observation.mode(1)`、`action(36)` 和 `action.hand(12)`。 ## 操作流程 diff --git a/teleopit/recording/lerobot_v3.py b/teleopit/recording/lerobot_v3.py index 3c2cf00d..80391f6d 100644 --- a/teleopit/recording/lerobot_v3.py +++ b/teleopit/recording/lerobot_v3.py @@ -19,9 +19,11 @@ STATE_KEY = "observation.state" MODE_KEY = "observation.mode" ACTION_KEY = "action" 
+HAND_ACTION_KEY = "action.hand" STATE_DIM = 68 MODE_DIM = 1 ACTION_DIM = FULL_QPOS_DIM +HAND_ACTION_DIM = 12 DEFAULT_IMAGE_SHAPE = (480, 640, 3) MODE_CODES = { "standing": 0, @@ -41,6 +43,8 @@ class RecordingSchema: mode_dim: int = MODE_DIM action_key: str = ACTION_KEY action_dim: int = ACTION_DIM + hand_action_key: str = HAND_ACTION_KEY + hand_action_dim: int = HAND_ACTION_DIM def build_recording_schema(camera_cfg: Any) -> RecordingSchema: @@ -74,6 +78,11 @@ def lerobot_features(schema: RecordingSchema) -> dict[str, dict[str, object]]: "shape": (schema.action_dim,), "names": ["action"], }, + schema.hand_action_key: { + "dtype": "float32", + "shape": (schema.hand_action_dim,), + "names": ["hand_action"], + }, } @@ -114,6 +123,16 @@ def modality_sidecar(schema: RecordingSchema) -> dict[str, object]: "joint_pos": [7, 36], }, }, + schema.hand_action_key: { + "type": "low_dim", + "shape": [schema.hand_action_dim], + "dtype": "float32", + "units": "linkerhand_uint8_pose", + "slices": { + "left_pose": [0, 6], + "right_pose": [6, 12], + }, + }, }, } @@ -149,6 +168,19 @@ def normalize_action_reference_qpos(reference_qpos: object) -> np.ndarray: return action +def normalize_hand_action(left_pose: object, right_pose: object) -> np.ndarray: + left = np.asarray(left_pose, dtype=np.float32).reshape(-1) + right = np.asarray(right_pose, dtype=np.float32).reshape(-1) + if left.shape[0] != 6: + raise ValueError(f"recording left hand pose must be 6D, got {left.shape[0]}") + if right.shape[0] != 6: + raise ValueError(f"recording right hand pose must be 6D, got {right.shape[0]}") + action = np.concatenate([left, right], dtype=np.float32) + if action.shape[0] != HAND_ACTION_DIM: + raise ValueError(f"recording action.hand must be {HAND_ACTION_DIM}D, got {action.shape[0]}") + return action + + def build_mode_observation(mode: str) -> np.ndarray: normalized = str(mode).strip().lower() if normalized not in MODE_CODES: @@ -228,6 +260,7 @@ def add_frame( state: np.ndarray, mode: np.ndarray, 
action: np.ndarray, + hand_action: np.ndarray, task: str, ) -> None: if not self._active: @@ -238,18 +271,22 @@ def add_frame( state_arr = np.asarray(state, dtype=np.float32).reshape(-1) mode_arr = np.asarray(mode, dtype=np.float32).reshape(-1) action_arr = np.asarray(action, dtype=np.float32).reshape(-1) + hand_action_arr = np.asarray(hand_action, dtype=np.float32).reshape(-1) if state_arr.shape[0] != self._schema.state_dim: raise ValueError(f"{self._schema.state_key} must be {self._schema.state_dim}D") if mode_arr.shape[0] != self._schema.mode_dim: raise ValueError(f"{self._schema.mode_key} must be {self._schema.mode_dim}D") if action_arr.shape[0] != self._schema.action_dim: raise ValueError(f"{self._schema.action_key} must be {self._schema.action_dim}D") + if hand_action_arr.shape[0] != self._schema.hand_action_dim: + raise ValueError(f"{self._schema.hand_action_key} must be {self._schema.hand_action_dim}D") self._dataset.add_frame( { self._schema.image_key: image_arr, self._schema.state_key: state_arr, self._schema.mode_key: mode_arr, self._schema.action_key: action_arr, + self._schema.hand_action_key: hand_action_arr, "task": str(task), } ) diff --git a/teleopit/sim2real/hands/worker.py b/teleopit/sim2real/hands/worker.py index b969d851..e9fe323a 100644 --- a/teleopit/sim2real/hands/worker.py +++ b/teleopit/sim2real/hands/worker.py @@ -2,10 +2,10 @@ import logging import time -from typing import Any +from typing import Any, Sequence from teleopit.runtime.common import cfg_get -from teleopit.sim2real.hands.base import HandDevice, HandInputMapper +from teleopit.sim2real.hands.base import HandDevice, HandInputMapper, HandPoseCommand from teleopit.sim2real.hands.linkerhand_l6 import build_linkerhand_l6 from teleopit.sim2real.hands.linkerhand_o6 import build_linkerhand_o6 @@ -13,34 +13,53 @@ class HandRuntime: - def __init__(self, device: HandDevice, mapper: HandInputMapper): + def __init__( + self, + device: HandDevice, + mapper: HandInputMapper, + *, + 
open_commands: Sequence[HandPoseCommand] = (), + ): self._device = device self._mapper = mapper self.enabled = True self._failed = False + self._open_commands = tuple(open_commands) - def start(self) -> None: + def start(self) -> tuple[HandPoseCommand, ...]: try: self._device.connect() self._mapper.start() + return self._open_pose_commands("startup") except Exception: try: self._device.close() finally: raise - def tick(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float | None = None) -> None: + def tick( + self, + *, + controller_snapshot: object | None, + hand_snapshot: object | None, + active: bool, + now_s: float | None = None, + ) -> tuple[HandPoseCommand, ...]: if self._failed: - return + return () now = time.monotonic() if now_s is None else float(now_s) try: - for command in self._mapper.map( + commands = self._mapper.map( controller_snapshot=controller_snapshot, hand_snapshot=hand_snapshot, active=active, now_s=now, - ): + ) + sent: list[HandPoseCommand] = [] + for command in commands: self._device.send_pose(command.side, command.pose, force=command.force, reason=command.reason) + sent.append(command) + return tuple(sent) except Exception: self._failed = True logger.exception("Hand runtime failed; disabling hand control") @@ -48,25 +67,42 @@ def tick(self, *, controller_snapshot: object | None, hand_snapshot: object | No self._device.open_all(force=True, reason="failure") except Exception: logger.exception("Failed to open hand after hand runtime failure") + return () + return self._open_pose_commands("failure") - def close(self) -> None: + def close(self) -> tuple[HandPoseCommand, ...]: try: self._mapper.close() finally: self._device.close() + return self._open_pose_commands("shutdown") + + def _open_pose_commands(self, reason: str) -> tuple[HandPoseCommand, ...]: + return tuple( + HandPoseCommand(command.side, command.pose, True, reason) + for command in self._open_commands + ) class DisabledHandRuntime: 
enabled = False - def start(self) -> None: - pass + def start(self) -> tuple[HandPoseCommand, ...]: + return () - def tick(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float | None = None) -> None: + def tick( + self, + *, + controller_snapshot: object | None, + hand_snapshot: object | None, + active: bool, + now_s: float | None = None, + ) -> tuple[HandPoseCommand, ...]: del controller_snapshot, hand_snapshot, active, now_s + return () - def close(self) -> None: - pass + def close(self) -> tuple[HandPoseCommand, ...]: + return () def build_hand_runtime(cfg: Any) -> HandRuntime | DisabledHandRuntime: @@ -80,4 +116,13 @@ def build_hand_runtime(cfg: Any) -> HandRuntime | DisabledHandRuntime: device, mapper = build_linkerhand_o6(cfg) else: raise ValueError(f"Unsupported hands.driver={driver!r}; supported drivers: linkerhand_l6, linkerhand_o6") - return HandRuntime(device, mapper) + return HandRuntime(device, mapper, open_commands=_open_commands_from_device(device)) + + +def _open_commands_from_device(device: HandDevice) -> tuple[HandPoseCommand, ...]: + config = getattr(device, "config", None) + sides = tuple(str(side).strip().lower() for side in getattr(config, "sides", ())) + open_pose = tuple(int(value) for value in getattr(config, "open_pose", ())) + if not sides or len(open_pose) != 6: + return () + return tuple(HandPoseCommand(side, open_pose, True, "open") for side in sides) diff --git a/teleopit/sim2real/mp/ipc.py b/teleopit/sim2real/mp/ipc.py index c12cd878..e0c3b1c9 100644 --- a/teleopit/sim2real/mp/ipc.py +++ b/teleopit/sim2real/mp/ipc.py @@ -12,6 +12,7 @@ BODY_TOPIC = "body" HAND_TOPIC = "hand" +HAND_COMMAND_TOPIC = "hand_command" CONTROLLER_TOPIC = "controller" CONTROL_EVENTS_TOPIC = "control_events" REFERENCE_TOPIC = "reference" @@ -26,6 +27,7 @@ class Sim2RealIpcEndpoints: body_pub: str hand_pub: str + hand_command_pub: str controller_pub: str control_events_pub: str reference_pub: str @@ -43,15 +45,16 
@@ def default_endpoints(*, host: str = "127.0.0.1", base_port: int = 39700) -> Sim return Sim2RealIpcEndpoints( body_pub=f"{prefix}{base_port}", hand_pub=f"{prefix}{base_port + 1}", - controller_pub=f"{prefix}{base_port + 2}", - control_events_pub=f"{prefix}{base_port + 3}", - reference_pub=f"{prefix}{base_port + 4}", - mode_pub=f"{prefix}{base_port + 5}", - video_pub=f"{prefix}{base_port + 6}", - record_pub=f"{prefix}{base_port + 7}", - health_pub=f"{prefix}{base_port + 8}", - command_pub=f"{prefix}{base_port + 9}", - reference_command_pub=f"{prefix}{base_port + 10}", + hand_command_pub=f"{prefix}{base_port + 2}", + controller_pub=f"{prefix}{base_port + 3}", + control_events_pub=f"{prefix}{base_port + 4}", + reference_pub=f"{prefix}{base_port + 5}", + mode_pub=f"{prefix}{base_port + 6}", + video_pub=f"{prefix}{base_port + 7}", + record_pub=f"{prefix}{base_port + 8}", + health_pub=f"{prefix}{base_port + 9}", + command_pub=f"{prefix}{base_port + 10}", + reference_command_pub=f"{prefix}{base_port + 11}", ) diff --git a/teleopit/sim2real/mp/messages.py b/teleopit/sim2real/mp/messages.py index ec22eda5..a56bb1e7 100644 --- a/teleopit/sim2real/mp/messages.py +++ b/teleopit/sim2real/mp/messages.py @@ -71,6 +71,17 @@ class RecordStepPacket: seq: int +@dataclass(frozen=True) +class HandCommandPacket: + timestamp_s: float + driver: str + mode: str + active: bool + left_pose: Float64Array + right_pose: Float64Array + seq: int + + @dataclass(frozen=True) class HealthPacket: worker: str diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index cde6ffdc..472785c0 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -46,6 +46,7 @@ from teleopit.recording.lerobot_v3 import ( build_mode_observation, build_observation_state, + normalize_hand_action, normalize_action_reference_qpos, ) from teleopit.sim.reference_motion import OfflineReferenceMotion @@ -58,11 +59,15 @@ from teleopit.sim.realtime_utils import 
RealtimeReferenceManager from teleopit.sim.viewer_subprocess import start_robot_viewer from teleopit.sim2real.hands.worker import build_hand_runtime +from teleopit.sim2real.hands.base import HandPoseCommand +from teleopit.sim2real.hands.linkerhand_l6 import parse_linkerhand_l6_config +from teleopit.sim2real.hands.linkerhand_o6 import parse_linkerhand_o6_config from teleopit.sim2real.mp.ipc import ( BODY_TOPIC, COMMAND_TOPIC, CONTROL_EVENTS_TOPIC, CONTROLLER_TOPIC, + HAND_COMMAND_TOPIC, HAND_TOPIC, HEALTH_TOPIC, MODE_TOPIC, @@ -78,6 +83,7 @@ BodyFramePacket, CommandPacket, ControlEventsPacket, + HandCommandPacket, HealthPacket, ModeStatePacket, ReferencePacket, @@ -273,6 +279,26 @@ def _recording_camera_cfg(cfg: Any) -> Any: return cfg_get(_recording_cfg(cfg), "camera", {}) or {} +def _configured_open_hand_pose(cfg: Any) -> tuple[np.ndarray, np.ndarray]: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + driver = str(cfg_get(hands_cfg, "driver", "linkerhand_l6")).strip().lower() + if bool(cfg_get(hands_cfg, "enabled", False)): + if driver == "linkerhand_o6": + hand_cfg = parse_linkerhand_o6_config(cfg) + else: + hand_cfg = parse_linkerhand_l6_config(cfg) + pose = np.asarray(hand_cfg.open_pose, dtype=np.float32).reshape(-1) + elif driver == "linkerhand_o6": + pose = np.array([250, 250, 250, 250, 250, 250], dtype=np.float32) + else: + driver_cfg = cfg_get(hands_cfg, "linkerhand_l6", {}) or {} + thumb_yaw = int(cfg_get(driver_cfg, "thumb_yaw_center", 10)) + pose = np.array([250, thumb_yaw, 250, 250, 250, 250], dtype=np.float32) + if pose.shape[0] != 6: + raise ValueError(f"hands.{driver}.open_pose must contain 6 values") + return pose.copy(), pose.copy() + + def _validate_new_runtime_config(cfg: Any) -> None: legacy_keys = [key for key in ("sim2real_runtime", "multiprocess", "dexterous_hand") if cfg_get(cfg, key, None) is not None] if legacy_keys: @@ -1701,9 +1727,20 @@ def __init__( self.fps = int(cfg_get(self.rec_cfg, "fps", 30)) self._record_sub = 
LatestSubscriber(endpoints.record_pub, RECORD_TOPIC) self._video_sub = LatestSubscriber(endpoints.video_pub, VIDEO_TOPIC) + self._hand_command_sub = LatestSubscriber(endpoints.hand_command_pub, HAND_COMMAND_TOPIC) self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) self._frame_reader = frame_reader or SharedFrameRingReader() self._latest_record: RecordStepPacket | None = None + left_open, right_open = _configured_open_hand_pose(cfg) + self._latest_hand_command = HandCommandPacket( + timestamp_s=0.0, + driver=str(cfg_get(cfg_get(cfg, "hands", {}) or {}, "driver", "linkerhand_l6")).strip().lower(), + mode=str(cfg_get(cfg_get(cfg, "hands", {}) or {}, "mode", "gripper")).strip().lower(), + active=False, + left_pose=left_open.astype(np.float32, copy=True), + right_pose=right_open.astype(np.float32, copy=True), + seq=0, + ) self._latest_video_seq = -1 self._active = False self._episode_started_s = 0.0 @@ -1739,6 +1776,10 @@ def run(self) -> None: if isinstance(record, RecordStepPacket): self._latest_record = record + hand_command = self._hand_command_sub.recv_latest() + if isinstance(hand_command, HandCommandPacket): + self._latest_hand_command = hand_command + video = self._video_sub.recv_latest() if isinstance(video, SharedFrameDescriptor): self._handle_video(video) @@ -1755,6 +1796,7 @@ def run(self) -> None: finally: self._record_sub.close() self._video_sub.close() + self._hand_command_sub.close() self._command_sub.close() self._frame_reader.close() @@ -1838,6 +1880,10 @@ def _handle_video(self, descriptor: SharedFrameDescriptor) -> None: state=np.asarray(record.observation_state, dtype=np.float32), mode=np.asarray(record.observation_mode, dtype=np.float32), action=np.asarray(record.action_reference_qpos, dtype=np.float32), + hand_action=normalize_hand_action( + self._latest_hand_command.left_pose, + self._latest_hand_command.right_pose, + ), task=self.task, ) self._episode_frames += 1 @@ -1878,11 +1924,55 @@ def _main() -> None: controller_sub 
= LatestSubscriber(endpoints.controller_pub, CONTROLLER_TOPIC) mode_sub = LatestSubscriber(endpoints.mode_pub, MODE_TOPIC) command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + hand_command_pub = ZmqPublisher(endpoints.hand_command_pub) active = False hz = float(cfg_get(_mp_cfg(cfg), "hand_worker_hz", 120.0)) sleep_s = 1.0 / max(hz, 1.0) + hands_cfg = cfg_get(cfg, "hands", {}) or {} + driver = str(cfg_get(hands_cfg, "driver", "linkerhand_l6")).strip().lower() + hand_mode = str(cfg_get(hands_cfg, "mode", "gripper")).strip().lower() + left_pose, right_pose = _configured_open_hand_pose(cfg) + command_seq = 0 + + def _apply_hand_commands(commands: tuple[HandPoseCommand, ...]) -> bool: + nonlocal left_pose, right_pose + changed = False + for hand_command in commands: + pose = np.asarray(hand_command.pose, dtype=np.float32).reshape(-1) + if pose.shape[0] != 6: + logger.warning("Ignoring %s hand command with invalid pose shape %s", hand_command.side, pose.shape) + continue + if hand_command.side == "left": + left_pose = pose.copy() + changed = True + elif hand_command.side == "right": + right_pose = pose.copy() + changed = True + else: + logger.warning("Ignoring hand command with unsupported side %r", hand_command.side) + return changed + + def _publish_hand_command(*, timestamp_s: float, active_state: bool) -> None: + nonlocal command_seq + command_seq += 1 + hand_command_pub.publish( + HAND_COMMAND_TOPIC, + HandCommandPacket( + timestamp_s=float(timestamp_s), + driver=driver, + mode=hand_mode, + active=bool(active_state), + left_pose=np.asarray(left_pose, dtype=np.float32).copy(), + right_pose=np.asarray(right_pose, dtype=np.float32).copy(), + seq=command_seq, + ), + ) + try: - runtime.start() + startup_commands = runtime.start() + startup_s = time.monotonic() + _apply_hand_commands(startup_commands) + _publish_hand_command(timestamp_s=startup_s, active_state=False) while not stop_event.is_set(): command = command_sub.recv_latest() if 
isinstance(command, CommandPacket) and command.command == "shutdown": @@ -1898,21 +1988,30 @@ def _main() -> None: if isinstance(mode_packet, ModeStatePacket): active = bool(mode_packet.mocap_active) try: - runtime.tick( + now_s = time.monotonic() + commands = runtime.tick( controller_snapshot=proxy.controller_snapshot, hand_snapshot=proxy.hand_snapshot, active=active, + now_s=now_s, ) + if commands: + if _apply_hand_commands(commands): + _publish_hand_command(timestamp_s=now_s, active_state=active) except Exception: logger.exception("Dexterous hand worker tick failed; hand control continues") time.sleep(sleep_s) finally: try: - runtime.close() + shutdown_commands = runtime.close() + shutdown_s = time.monotonic() + if _apply_hand_commands(shutdown_commands): + _publish_hand_command(timestamp_s=shutdown_s, active_state=False) finally: hand_sub.close() controller_sub.close() mode_sub.close() command_sub.close() + hand_command_pub.close() _worker_loop("hand_worker", cfg, _main) diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py index 272fa6a4..53aad28a 100644 --- a/tests/test_dexterous_hand.py +++ b/tests/test_dexterous_hand.py @@ -14,6 +14,7 @@ parse_linkerhand_l6_config, trigger_to_pose, ) +from teleopit.sim2real.hands.base import HandPoseCommand from teleopit.sim2real.hands.linkerhand_o6 import ( CLOSE_POSE as O6_CLOSE_POSE, LinkerHandO6Device, @@ -283,6 +284,58 @@ def close(self) -> None: assert calls == ["connect", "mapper_start", "close"] +def test_hand_runtime_reports_actual_open_commands() -> None: + calls: list[tuple[str, object, object]] = [] + open_commands = ( + HandPoseCommand("left", (250, 10, 250, 250, 250, 250), True, "open"), + HandPoseCommand("right", (250, 10, 250, 250, 250, 250), True, "open"), + ) + + class FakeDevice: + def connect(self) -> None: + calls.append(("connect", None, None)) + + def send_pose(self, side, pose, *, force=False, reason="") -> None: + calls.append((side, tuple(pose), reason)) + + def open_all(self, *, 
force=False, reason="") -> None: + calls.append(("open_all", force, reason)) + + def close(self) -> None: + calls.append(("close", None, None)) + + class Mapper: + def __init__(self) -> None: + self.fail = False + + def start(self) -> None: + calls.append(("mapper_start", None, None)) + + def map(self, *args, **kwargs): + if self.fail: + raise RuntimeError("tick failed") + return (HandPoseCommand("left", (1, 2, 3, 4, 5, 6), False, "mapped"),) + + def close(self) -> None: + calls.append(("mapper_close", None, None)) + + mapper = Mapper() + runtime = HandRuntime(FakeDevice(), mapper, open_commands=open_commands) + + startup = runtime.start() + ticked = runtime.tick(controller_snapshot=None, hand_snapshot=None, active=True, now_s=1.0) + mapper.fail = True + failure = runtime.tick(controller_snapshot=None, hand_snapshot=None, active=True, now_s=2.0) + shutdown = runtime.close() + + assert [command.reason for command in startup] == ["startup", "startup"] + assert ticked[0].pose == (1, 2, 3, 4, 5, 6) + assert [command.reason for command in failure] == ["failure", "failure"] + assert [command.reason for command in shutdown] == ["shutdown", "shutdown"] + assert ("open_all", True, "failure") in calls + assert ("close", None, None) in calls + + def test_linkerhand_l6_device_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: created_hands = [] diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 2be0b10c..1688892e 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -13,6 +13,7 @@ from teleopit.runtime.arm_mocap import compose_arm_reference, compose_arm_reference_window from teleopit.recording.lerobot_v3 import ( ACTION_KEY, + HAND_ACTION_KEY, IMAGE_KEY, MODE_KEY, STATE_KEY, @@ -23,7 +24,7 @@ modality_sidecar, ) from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher -from teleopit.sim2real.mp.messages import RecordStepPacket, ReferencePacket, 
SharedFrameDescriptor +from teleopit.sim2real.mp.messages import HandCommandPacket, RecordStepPacket, ReferencePacket, SharedFrameDescriptor from teleopit.sim.reference_timeline import ReferenceSample, ReferenceWindow from teleopit.sim2real.mp.runtime import ( map_recording_key_to_command, @@ -32,6 +33,7 @@ _LoopTimingReporter, _RecordingWorker, _RobotControlWorker, + _configured_open_hand_pose, _human_frame_is_valid, ) from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter @@ -40,7 +42,7 @@ def test_loop_timing_reporter_separates_late_sleep_from_work_overrun(caplog) -> None: reporter = _LoopTimingReporter(target_period_s=0.02, log_interval_s=1.0, deadline_miss_tolerance_s=0.001) - with caplog.at_level(logging.INFO, logger="teleopit.sim2real.mp.runtime"): + with caplog.at_level(logging.INFO, logger="teleopit.operator"): reporter.record(loop_start_s=0.0, work_elapsed_s=0.0004, cycle_elapsed_s=0.02006, pico_age_s=None) reporter.record(loop_start_s=1.0, work_elapsed_s=0.021, cycle_elapsed_s=0.0212, pico_age_s=None) @@ -300,10 +302,39 @@ def test_lerobot_recording_schema_and_modality_sidecar() -> None: assert features[STATE_KEY]["shape"] == (68,) assert features[MODE_KEY]["shape"] == (1,) assert features[ACTION_KEY]["shape"] == (36,) + assert features[HAND_ACTION_KEY]["shape"] == (12,) assert sidecar["features"][STATE_KEY]["slices"]["joint_pos"] == [0, 29] assert sidecar["features"][STATE_KEY]["slices"]["projected_gravity"] == [65, 68] assert sidecar["features"][MODE_KEY]["codes"]["pause"] == 3 assert sidecar["features"][ACTION_KEY]["slices"]["joint_pos"] == [7, 36] + assert sidecar["features"][HAND_ACTION_KEY]["slices"]["left_pose"] == [0, 6] + assert sidecar["features"][HAND_ACTION_KEY]["slices"]["right_pose"] == [6, 12] + + +def test_configured_open_hand_pose_matches_linkerhand_l6_parser() -> None: + left, right = _configured_open_hand_pose( + { + "hands": { + "enabled": True, + "driver": "linkerhand_l6", + "mode": "gripper", + 
"linkerhand_l6": { + "thumb_yaw_center": 42, + "open_pose": [250, 99, 250, 250, 250, 250], + }, + }, + } + ) + + np.testing.assert_allclose(left, np.array([250, 42, 250, 250, 250, 250], dtype=np.float32)) + np.testing.assert_allclose(right, left) + + +def test_configured_open_hand_pose_defaults_without_enabled_hands() -> None: + left, right = _configured_open_hand_pose({}) + + np.testing.assert_allclose(left, np.array([250, 10, 250, 250, 250, 250], dtype=np.float32)) + np.testing.assert_allclose(right, left) def test_record_observation_state_concat_order() -> None: @@ -676,10 +707,19 @@ def add_frame( state: np.ndarray, mode: np.ndarray, action: np.ndarray, + hand_action: np.ndarray, task: str, ) -> None: calls.append(f"frame:{task}") - frames.append({"image": image.copy(), "state": state.copy(), "mode": mode.copy(), "action": action.copy()}) + frames.append( + { + "image": image.copy(), + "state": state.copy(), + "mode": mode.copy(), + "action": action.copy(), + "hand_action": hand_action.copy(), + } + ) def save_episode(self) -> None: calls.append("save") @@ -739,6 +779,15 @@ def fake_factory(**_kwargs: object) -> FakeRecorder: assert calls == ["start", "discard"] worker._start_episode() + worker._latest_hand_command = HandCommandPacket( + timestamp_s=2.05, + driver="linkerhand_l6", + mode="gripper", + active=True, + left_pose=np.arange(6, dtype=np.float32), + right_pose=np.arange(6, 12, dtype=np.float32), + seq=1, + ) desc = writer.write(np.full((2, 2, 3), 5, dtype=np.uint8), timestamp_s=2.1) worker._handle_video(desc) worker._save_episode() @@ -748,6 +797,7 @@ def fake_factory(**_kwargs: object) -> FakeRecorder: np.testing.assert_allclose(frames[0]["state"], np.arange(68, dtype=np.float32)) np.testing.assert_allclose(frames[0]["mode"], build_mode_observation("standing")) np.testing.assert_allclose(frames[0]["action"], np.arange(36, dtype=np.float32)) + np.testing.assert_allclose(frames[0]["hand_action"], np.arange(12, dtype=np.float32)) worker._latest_record = 
RecordStepPacket( timestamp_s=3.0, @@ -766,5 +816,6 @@ def fake_factory(**_kwargs: object) -> FakeRecorder: writer.close(unlink=True) worker._record_sub.close() worker._video_sub.close() + worker._hand_command_sub.close() worker._command_sub.close() worker._frame_reader.close() From 7614da14a21e854060276b4c7754aabb0bd5bc82 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 22:05:10 +0800 Subject: [PATCH 115/122] Fix LeRobot recording API compatibility --- teleopit/recording/lerobot_v3.py | 43 ++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/teleopit/recording/lerobot_v3.py b/teleopit/recording/lerobot_v3.py index 80391f6d..d8d2a87d 100644 --- a/teleopit/recording/lerobot_v3.py +++ b/teleopit/recording/lerobot_v3.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +import importlib import json from pathlib import Path from typing import Any @@ -33,6 +34,34 @@ } +def _import_lerobot_dataset() -> Any: + try: + module = importlib.import_module("lerobot.datasets.lerobot_dataset") + except ModuleNotFoundError as new_exc: + if new_exc.name not in {"lerobot", "lerobot.datasets", "lerobot.datasets.lerobot_dataset"}: + raise RuntimeError("Failed to import LeRobotDataset from the current LeRobot API") from new_exc + try: + module = importlib.import_module("lerobot.common.datasets.lerobot_dataset") + except ModuleNotFoundError as old_exc: + if old_exc.name not in { + "lerobot", + "lerobot.common", + "lerobot.common.datasets", + "lerobot.common.datasets.lerobot_dataset", + }: + raise RuntimeError("Failed to import LeRobotDataset from the legacy LeRobot API") from old_exc + raise RuntimeError( + "recording.enabled=true requires a LeRobot version that provides " + "LeRobotDataset. Install Teleopit with the recording extra, for example: " + "pip install -e '.[recording]'." 
+ ) from old_exc + except Exception as old_exc: + raise RuntimeError("Failed to import LeRobotDataset from the legacy LeRobot API") from old_exc + except Exception as new_exc: + raise RuntimeError("Failed to import LeRobotDataset from the current LeRobot API") from new_exc + return module.LeRobotDataset + + @dataclass(frozen=True) class RecordingSchema: image_key: str @@ -215,16 +244,10 @@ def create( fps: int, schema: RecordingSchema, ) -> "TeleopitLeRobotV3Recorder": - try: - from lerobot.common.datasets.lerobot_dataset import LeRobotDataset - except Exception as exc: # pragma: no cover - exercised in environments without optional extra. - raise RuntimeError( - "recording.enabled=true requires the optional LeRobot dependency. " - "Install Teleopit with the recording extra, for example: pip install -e '.[recording]'." - ) from exc + LeRobotDataset = _import_lerobot_dataset() root = Path(output_dir) - root.mkdir(parents=True, exist_ok=True) + root.parent.mkdir(parents=True, exist_ok=True) dataset_repo_id = repo_id or dataset_name or "teleopit/sim2real" features = lerobot_features(schema) @@ -313,6 +336,10 @@ def discard_episode(self) -> None: self._frames_in_episode = 0 def finalize(self) -> None: + finalize = getattr(self._dataset, "finalize", None) + if callable(finalize): + finalize() + return consolidate = getattr(self._dataset, "consolidate", None) if callable(consolidate): consolidate() From fe7926e3cce2486a988ebe2dab489fe07362b005 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Mon, 22 Jun 2026 22:29:28 +0800 Subject: [PATCH 116/122] Replace LeRobot recording with HDF5 --- AGENTS.md | 2 +- README.md | 11 +- docs/docs/configuration/config-reference.md | 17 +- docs/docs/getting-started/installation.md | 4 +- docs/docs/tutorials/pico-sim2real.md | 5 +- .../current/configuration/config-reference.md | 17 +- .../current/getting-started/installation.md | 4 +- .../current/tutorials/pico-sim2real.md | 3 +- pyproject.toml | 2 - teleopit/configs/pico4_sim2real.yaml | 
6 +- teleopit/configs/sim2real.yaml | 6 +- teleopit/recording/{lerobot_v3.py => hdf5.py} | 274 +++++++++--------- teleopit/sim2real/mp/runtime.py | 30 +- tests/test_sim2real_multiprocess.py | 64 +++- 14 files changed, 237 insertions(+), 208 deletions(-) rename teleopit/recording/{lerobot_v3.py => hdf5.py} (55%) diff --git a/AGENTS.md b/AGENTS.md index 6af3f639..2c730fc2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -146,7 +146,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - `ARMS` entering/exiting/resume resets policy/reference alignment and uses Kp ramp; offline BVH sim2real does not use `ARMS`, and Unitree remote `B` remains BVH replay - Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference state, smoothers, and reference alignment are reset, while the GMR IK warm-start is retained - Optional LinkerHand control uses `hands.enabled=true`, `hands.driver=linkerhand_l6|linkerhand_o6`, and `hands.mode=gripper|vr_hand_pose`; default is disabled -- Optional Pico sim2real LeRobot v3 recording uses `--config-name sim2real_record` or `recording.enabled=true`; it requires `input.provider=pico4`, `input.video.enabled=true`, `input.video.source=realsense`, an interactive terminal, and the `recording` extra +- Optional Pico sim2real HDF5 recording uses `--config-name sim2real_record` or `recording.enabled=true`; it requires `input.provider=pico4`, `input.video.enabled=true`, `input.video.source=realsense`, an interactive terminal, and the `recording` extra - Recording is manual only: terminal `R` starts an episode, `S` saves, `D` discards the active episode, and `Q` shuts down; `STANDING`, `MOCAP`, `ARMS`, and paused mocap are recordable - Recording captures `observation.images.d435i_rgb` RealSense RGB video at 30Hz plus `observation.state(68)`, `observation.mode(1)`, `action(36)`, and `action.hand(12)`; RealSense capture lives in `pico_input` through the normal `input.video` path - `gripper` mode reuses 
`Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control and supports LinkerHand L6 and O6 diff --git a/README.md b/README.md index adb47fdf..34930d87 100644 --- a/README.md +++ b/README.md @@ -84,9 +84,9 @@ python train_mimic/scripts/data/build_dataset.py \ --spec data/pico_motion/pico_recorded.yaml --force ``` -## Sim2Real LeRobot Recording +## Sim2Real HDF5 Recording -Pico sim2real can also record manual LeRobot v3 episodes from the real G1: +Pico sim2real can also record manual HDF5 episodes from the real G1: ```bash pip install -e '.[recording]' @@ -96,8 +96,9 @@ python scripts/run/run_sim2real.py --config-name sim2real_record \ ``` Recording uses the terminal controls `R` start, `S` save, `D` discard, and `Q` -shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded. The -dataset schema is `observation.images.d435i_rgb` video at 30 Hz, +shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded. Saved +episodes are written as `.h5` files under `data/recordings/sim2real_hdf5/episodes/`. +The dataset schema is `observation.images.d435i_rgb` RGB frames at 30 Hz, `observation.state(68)`, `observation.mode(1)`, `action(36)` as the aligned reference qpos sent to the policy path, and `action.hand(12)` as the latest LinkerHand left/right 6D pose commands. @@ -113,7 +114,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added Pico sim2real `ARMS` mode: Pico/controller `B` toggles between whole-body `MOCAP` and stand-pose body/legs with live retargeted arms. - Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. - Added optional LinkerHand L6 sim2real modes under `hands.*`: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand 0.2.0 public API. 
-- Added manual Pico sim2real LeRobot v3 recording with RealSense D435i RGB video, 68D robot state, mode labels, 36D reference-qpos action labels, and 12D LinkerHand pose action labels. +- Added manual Pico sim2real HDF5 recording with RealSense D435i RGB video, 68D robot state, mode labels, 36D reference-qpos action labels, and 12D LinkerHand pose action labels. - Added LinkerHand O6 support for Pico `gripper` mode with an O6-specific grasp pose. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 362b3df8..788b90a3 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -152,7 +152,7 @@ calling somehand 0.2.0 through `somehand.api` only. | `hands.somehand.temporal_filter_alpha` | somehand input landmark smoothing alpha; `1.0` disables smoothing delay | `1.0` | | `hands.somehand.output_alpha` | somehand qpos output smoothing alpha; `1.0` disables smoothing delay | `1.0` | -### LeRobot Recording (Pico sim2real) +### HDF5 Recording (Pico sim2real) `recording.enabled=true` is supported only with `input.provider=pico4`, `input.video.enabled=true`, `input.video.source=realsense`, and an interactive @@ -166,29 +166,32 @@ same frames produced by `pico_input`. 
| Field | Description | Default | |-------|-------------|---------| -| `recording.enabled` | Enable manual LeRobot v3 recording | `false` | -| `recording.output_dir` | Dataset root directory | `data/lerobot` | -| `recording.repo_id` / `dataset_name` | LeRobot dataset identity | `null` | +| `recording.enabled` | Enable manual HDF5 recording | `false` | +| `recording.output_dir` | Dataset root directory | `data/recordings/sim2real_hdf5` | | `recording.task` | Task string stored with frames | `demo` | | `recording.fps` | Recording/video clock rate | `30` | | `recording.min_episode_seconds` | Discard saved episodes shorter than this duration | `1.0` | | `recording.record_modes` | Modes that allow recording start and frame writes | `[standing, mocap, arms, pause]` | -| `recording.camera.key` | LeRobot video feature key | `observation.images.d435i_rgb` | +| `recording.camera.key` | RGB image dataset key | `observation.images.d435i_rgb` | | `recording.camera.width` / `height` / `fps` | RealSense RGB capture settings | `640` / `480` / `30` | | `recording.camera.device` | Optional RealSense serial | `null` | Camera failure behavior is controlled by `input.video.fail_on_error`. -LeRobot features: +HDF5 datasets: ```text -observation.images.d435i_rgb video [480,640,3] uint8 +observation.images.d435i_rgb uint8[N,480,640,3] observation.state float32[68] observation.mode float32[1] action float32[36] action.hand float32[12] ``` +Each saved episode is a standalone `.h5` file under +`recording.output_dir/episodes/`. The root attributes include the Teleopit HDF5 +recording format, schema version, task, fps, and frame count. + `observation.state` is ordered as `joint_pos(29)`, `joint_vel(29)`, `base_quat_wxyz(4)`, `base_ang_vel(3)`, and `projected_gravity(3)`. 
`observation.mode` is a numeric categorical: `standing=0`, `mocap=1`, diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 615153c5..582e4f81 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -79,8 +79,8 @@ These packages are only required when `hands.enabled=true`. pip install -e '.[recording]' ``` -Adds the Pico sim2real stack plus LeRobot, RealSense, and video encoding -dependencies used by `sim2real_record.yaml`. +Adds the Pico sim2real stack plus RealSense and video dependencies used by +`sim2real_record.yaml`. ## Verify Installation diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index d172e4a3..25367df6 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -95,7 +95,7 @@ python scripts/run/run_sim2real.py \ real_robot.network_interface=eth0 ``` -## Optional LeRobot Recording +## Optional HDF5 Recording Install the recording extra on the machine that owns Pico input and RealSense: @@ -115,7 +115,8 @@ python scripts/run/run_sim2real.py \ Terminal controls are `R` start episode, `S` save, `D` discard, and `Q` shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded; -saved episodes cannot be discarded afterward. The v1 schema records +saved episodes cannot be discarded afterward. Episodes are saved as `.h5` files +under `data/recordings/sim2real_hdf5/episodes/`. The v1 schema records `observation.images.d435i_rgb`, `observation.state(68)`, `observation.mode(1)`, `action(36)`, and `action.hand(12)` at 30 Hz. 
diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 7d573842..8575e9b3 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -171,7 +171,7 @@ Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 someh | `hands.somehand.temporal_filter_alpha` | somehand 输入 landmarks 平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | | `hands.somehand.output_alpha` | somehand qpos 输出平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | -### LeRobot 录制(Pico sim2real) +### HDF5 录制(Pico sim2real) `recording.enabled=true` 只支持 `input.provider=pico4`、 `input.video.enabled=true`、`input.video.source=realsense`,并且需要交互式终端。 @@ -183,29 +183,32 @@ Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 someh | 字段 | 说明 | 默认值 | |---|---|---| -| `recording.enabled` | 启用手动 LeRobot v3 录制 | `false` | -| `recording.output_dir` | 数据集根目录 | `data/lerobot` | -| `recording.repo_id` / `dataset_name` | LeRobot 数据集标识 | `null` | +| `recording.enabled` | 启用手动 HDF5 录制 | `false` | +| `recording.output_dir` | 数据集根目录 | `data/recordings/sim2real_hdf5` | | `recording.task` | 写入 frame 的任务字符串 | `demo` | | `recording.fps` | 录制/视频主时钟频率 | `30` | | `recording.min_episode_seconds` | 保存时短于该时长的 episode 会被丢弃 | `1.0` | | `recording.record_modes` | 允许开始录制和写帧的模式 | `[standing, mocap, arms, pause]` | -| `recording.camera.key` | LeRobot 视频 feature key | `observation.images.d435i_rgb` | +| `recording.camera.key` | RGB 图像数据集 key | `observation.images.d435i_rgb` | | `recording.camera.width` / `height` / `fps` | RealSense RGB 采集设置 | `640` / `480` / `30` | | `recording.camera.device` | 可选 RealSense 序列号 | `null` | 相机失败时的行为由 `input.video.fail_on_error` 控制。 -LeRobot features: +HDF5 datasets: ```text -observation.images.d435i_rgb video [480,640,3] uint8 +observation.images.d435i_rgb uint8[N,480,640,3] 
observation.state float32[68] observation.mode float32[1] action float32[36] action.hand float32[12] ``` +每个保存的 episode 都是 `recording.output_dir/episodes/` 下的独立 `.h5` +文件。根属性包含 Teleopit HDF5 recording format、schema version、task、fps +和 frame count。 + `observation.state` 的顺序是 `joint_pos(29)`、`joint_vel(29)`、 `base_quat_wxyz(4)`、`base_ang_vel(3)` 和 `projected_gravity(3)`。 `observation.mode` 是数值类别:`standing=0`、`mocap=1`、 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 5fa7f38d..42d208fa 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -79,8 +79,8 @@ scripts/setup/download_somehand_l6_assets.sh pip install -e '.[recording]' ``` -该配置包含 Pico sim2real 栈,以及 `sim2real_record.yaml` 使用的 LeRobot、 -RealSense 和视频编码依赖。 +该配置包含 Pico sim2real 栈,以及 `sim2real_record.yaml` 使用的 RealSense +和视频依赖。 ## 验证安装 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 3127eeeb..4b828afe 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -91,7 +91,7 @@ python scripts/run/run_sim2real.py \ real_robot.network_interface=eth0 ``` -## 可选 LeRobot 录制 +## 可选 HDF5 录制 在负责 Pico 输入和 RealSense 的机器上安装 recording extra: @@ -111,6 +111,7 @@ python scripts/run/run_sim2real.py \ 终端控制为:`R` 开始 episode,`S` 保存,`D` 丢弃,`Q` 关闭。可以录制 `STANDING`、`MOCAP`、`ARMS` 和暂停状态的 mocap;已经保存的 episode 不支持再丢弃。 +episode 会保存为 `data/recordings/sim2real_hdf5/episodes/` 下的 `.h5` 文件。 v1 schema 以 30 Hz 记录 `observation.images.d435i_rgb`、`observation.state(68)`、 
`observation.mode(1)`、`action(36)` 和 `action.hand(12)`。 diff --git a/pyproject.toml b/pyproject.toml index fbe6288d..3c1f8dfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,8 +55,6 @@ pico4 = [ ] recording = [ "teleopit[pico4]", - "lerobot", - "pyrealsense2", "opencv-python", "imageio[ffmpeg]", ] diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index c8c8d181..61d50a51 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -20,10 +20,8 @@ reference_debug_log: false recording: enabled: false - format: lerobot_v3 - output_dir: data/lerobot - dataset_name: null - repo_id: null + format: hdf5 + output_dir: data/recordings/sim2real_hdf5 task: demo fps: 30 control: terminal diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 262338af..45373fa4 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -24,10 +24,8 @@ console: recording: enabled: false - format: lerobot_v3 - output_dir: data/lerobot - dataset_name: null - repo_id: null + format: hdf5 + output_dir: data/recordings/sim2real_hdf5 task: demo fps: 30 control: terminal diff --git a/teleopit/recording/lerobot_v3.py b/teleopit/recording/hdf5.py similarity index 55% rename from teleopit/recording/lerobot_v3.py rename to teleopit/recording/hdf5.py index d8d2a87d..29a124f1 100644 --- a/teleopit/recording/lerobot_v3.py +++ b/teleopit/recording/hdf5.py @@ -1,13 +1,14 @@ -"""LeRobot v3 adapter and schema helpers for Teleopit sim2real recording.""" +"""HDF5 recorder and schema helpers for Teleopit sim2real recording.""" from __future__ import annotations from dataclasses import dataclass -import importlib import json from pathlib import Path +import time from typing import Any +import h5py import numpy as np from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS @@ -26,6 +27,8 @@ ACTION_DIM = FULL_QPOS_DIM HAND_ACTION_DIM = 12 DEFAULT_IMAGE_SHAPE = (480, 640, 3) 
+HDF5_RECORDING_FORMAT = "teleopit_sim2real_recording_hdf5" +HDF5_RECORDING_VERSION = 1 MODE_CODES = { "standing": 0, "mocap": 1, @@ -34,34 +37,6 @@ } -def _import_lerobot_dataset() -> Any: - try: - module = importlib.import_module("lerobot.datasets.lerobot_dataset") - except ModuleNotFoundError as new_exc: - if new_exc.name not in {"lerobot", "lerobot.datasets", "lerobot.datasets.lerobot_dataset"}: - raise RuntimeError("Failed to import LeRobotDataset from the current LeRobot API") from new_exc - try: - module = importlib.import_module("lerobot.common.datasets.lerobot_dataset") - except ModuleNotFoundError as old_exc: - if old_exc.name not in { - "lerobot", - "lerobot.common", - "lerobot.common.datasets", - "lerobot.common.datasets.lerobot_dataset", - }: - raise RuntimeError("Failed to import LeRobotDataset from the legacy LeRobot API") from old_exc - raise RuntimeError( - "recording.enabled=true requires a LeRobot version that provides " - "LeRobotDataset. Install Teleopit with the recording extra, for example: " - "pip install -e '.[recording]'." 
- ) from old_exc - except Exception as old_exc: - raise RuntimeError("Failed to import LeRobotDataset from the legacy LeRobot API") from old_exc - except Exception as new_exc: - raise RuntimeError("Failed to import LeRobotDataset from the current LeRobot API") from new_exc - return module.LeRobotDataset - - @dataclass(frozen=True) class RecordingSchema: image_key: str @@ -85,42 +60,13 @@ def build_recording_schema(camera_cfg: Any) -> RecordingSchema: return RecordingSchema(image_key=key, image_shape=(height, width, 3)) -def lerobot_features(schema: RecordingSchema) -> dict[str, dict[str, object]]: - return { - schema.image_key: { - "dtype": "video", - "shape": schema.image_shape, - "names": ["height", "width", "channel"], - }, - schema.state_key: { - "dtype": "float32", - "shape": (schema.state_dim,), - "names": ["state"], - }, - schema.mode_key: { - "dtype": "float32", - "shape": (schema.mode_dim,), - "names": ["mode"], - }, - schema.action_key: { - "dtype": "float32", - "shape": (schema.action_dim,), - "names": ["action"], - }, - schema.hand_action_key: { - "dtype": "float32", - "shape": (schema.hand_action_dim,), - "names": ["hand_action"], - }, - } - - -def modality_sidecar(schema: RecordingSchema) -> dict[str, object]: +def hdf5_schema(schema: RecordingSchema) -> dict[str, object]: return { - "version": 1, + "format": HDF5_RECORDING_FORMAT, + "version": HDF5_RECORDING_VERSION, "features": { schema.image_key: { - "type": "video", + "type": "image", "shape": list(schema.image_shape), "dtype": "uint8", }, @@ -217,62 +163,59 @@ def build_mode_observation(mode: str) -> np.ndarray: return np.array([MODE_CODES[normalized]], dtype=np.float32) -class TeleopitLeRobotV3Recorder: - """Small adapter around LeRobot v3 dataset writing.""" +class TeleopitHDF5Recorder: + """Writes one HDF5 file per saved sim2real recording episode.""" def __init__( self, *, - dataset: Any, output_dir: Path, + task: str, + fps: int, schema: RecordingSchema, ) -> None: - self._dataset = dataset 
self._output_dir = output_dir + self._task = str(task) + self._fps = int(fps) self._schema = schema self._active = False self._frames_in_episode = 0 + self._episode_index = 0 + self._h5: h5py.File | None = None + self._tmp_path: Path | None = None + self._episode_path: Path | None = None + self._datasets: dict[str, h5py.Dataset] = {} @classmethod def create( cls, *, output_dir: str | Path, - dataset_name: str | None, - repo_id: str | None, task: str, fps: int, schema: RecordingSchema, - ) -> "TeleopitLeRobotV3Recorder": - LeRobotDataset = _import_lerobot_dataset() - + ) -> "TeleopitHDF5Recorder": root = Path(output_dir) - root.parent.mkdir(parents=True, exist_ok=True) - dataset_repo_id = repo_id or dataset_name or "teleopit/sim2real" - features = lerobot_features(schema) - - try: - dataset = LeRobotDataset.create( - repo_id=dataset_repo_id, - fps=int(fps), - root=root, - features=features, - use_videos=True, - ) - except TypeError: - dataset = LeRobotDataset.create( - repo_id=dataset_repo_id, - fps=int(fps), - root=root, - features=features, - ) - recorder = cls(dataset=dataset, output_dir=root, schema=schema) - recorder._write_modality_sidecar() + root.mkdir(parents=True, exist_ok=True) + recorder = cls(output_dir=root, task=task, fps=fps, schema=schema) + recorder._write_schema_sidecar() return recorder def start_episode(self) -> None: if self._active: raise RuntimeError("Cannot start a new recording episode while one is active") + self._episode_index += 1 + timestamp = time.strftime("%Y%m%d_%H%M%S") + stem = f"episode_{timestamp}_{time.time_ns()}_{self._episode_index:06d}" + tmp_dir = self._output_dir / ".tmp" + episodes_dir = self._output_dir / "episodes" + tmp_dir.mkdir(parents=True, exist_ok=True) + episodes_dir.mkdir(parents=True, exist_ok=True) + self._tmp_path = tmp_dir / f"{stem}.h5" + self._episode_path = episodes_dir / f"{stem}.h5" + self._h5 = h5py.File(self._tmp_path, "w") + self._write_episode_header(self._h5) + self._datasets = 
self._create_datasets(self._h5) self._active = True self._frames_in_episode = 0 @@ -286,65 +229,116 @@ def add_frame( hand_action: np.ndarray, task: str, ) -> None: - if not self._active: + if not self._active or self._h5 is None: raise RuntimeError("Cannot add a recording frame without an active episode") image_arr = np.asarray(image, dtype=np.uint8) if tuple(image_arr.shape) != self._schema.image_shape: raise ValueError(f"{self._schema.image_key} frame shape {image_arr.shape} != {self._schema.image_shape}") - state_arr = np.asarray(state, dtype=np.float32).reshape(-1) - mode_arr = np.asarray(mode, dtype=np.float32).reshape(-1) - action_arr = np.asarray(action, dtype=np.float32).reshape(-1) - hand_action_arr = np.asarray(hand_action, dtype=np.float32).reshape(-1) - if state_arr.shape[0] != self._schema.state_dim: - raise ValueError(f"{self._schema.state_key} must be {self._schema.state_dim}D") - if mode_arr.shape[0] != self._schema.mode_dim: - raise ValueError(f"{self._schema.mode_key} must be {self._schema.mode_dim}D") - if action_arr.shape[0] != self._schema.action_dim: - raise ValueError(f"{self._schema.action_key} must be {self._schema.action_dim}D") - if hand_action_arr.shape[0] != self._schema.hand_action_dim: - raise ValueError(f"{self._schema.hand_action_key} must be {self._schema.hand_action_dim}D") - self._dataset.add_frame( - { - self._schema.image_key: image_arr, - self._schema.state_key: state_arr, - self._schema.mode_key: mode_arr, - self._schema.action_key: action_arr, - self._schema.hand_action_key: hand_action_arr, - "task": str(task), - } - ) + state_arr = self._validate_vector(state, self._schema.state_key, self._schema.state_dim) + mode_arr = self._validate_vector(mode, self._schema.mode_key, self._schema.mode_dim) + action_arr = self._validate_vector(action, self._schema.action_key, self._schema.action_dim) + hand_action_arr = self._validate_vector(hand_action, self._schema.hand_action_key, self._schema.hand_action_dim) + + row = 
self._frames_in_episode + for dataset in self._datasets.values(): + dataset.resize((row + 1, *dataset.shape[1:])) + self._datasets[self._schema.image_key][row] = image_arr + self._datasets[self._schema.state_key][row] = state_arr + self._datasets[self._schema.mode_key][row] = mode_arr + self._datasets[self._schema.action_key][row] = action_arr + self._datasets[self._schema.hand_action_key][row] = hand_action_arr self._frames_in_episode += 1 + self._h5.attrs["frames"] = self._frames_in_episode + self._h5.attrs["task"] = str(task) def save_episode(self) -> None: if not self._active: return - self._dataset.save_episode() - self._active = False - self._frames_in_episode = 0 + tmp_path = self._require_tmp_path() + episode_path = self._require_episode_path() + self._close_active_file() + tmp_path.replace(episode_path) + self._reset_episode() def discard_episode(self) -> None: if not self._active: return - clear = getattr(self._dataset, "clear_episode_buffer", None) - if callable(clear): - clear() - else: - buffer_attr = getattr(self._dataset, "episode_buffer", None) - if isinstance(buffer_attr, dict): - buffer_attr.clear() - self._active = False - self._frames_in_episode = 0 + tmp_path = self._tmp_path + self._close_active_file() + if tmp_path is not None and tmp_path.exists(): + tmp_path.unlink() + self._reset_episode() def finalize(self) -> None: - finalize = getattr(self._dataset, "finalize", None) - if callable(finalize): - finalize() - return - consolidate = getattr(self._dataset, "consolidate", None) - if callable(consolidate): - consolidate() - - def _write_modality_sidecar(self) -> None: - path = self._output_dir / "meta" / "modality.json" - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(modality_sidecar(self._schema), indent=2) + "\n", encoding="utf-8") + if self._active: + self.discard_episode() + + def _write_schema_sidecar(self) -> None: + path = self._output_dir / "schema.json" + 
path.write_text(json.dumps(hdf5_schema(self._schema), indent=2) + "\n", encoding="utf-8") + + def _write_episode_header(self, h5: h5py.File) -> None: + h5.attrs["format"] = HDF5_RECORDING_FORMAT + h5.attrs["version"] = HDF5_RECORDING_VERSION + h5.attrs["task"] = self._task + h5.attrs["fps"] = self._fps + h5.attrs["frames"] = 0 + h5.attrs["schema_json"] = json.dumps(hdf5_schema(self._schema), sort_keys=True) + + def _create_datasets(self, h5: h5py.File) -> dict[str, h5py.Dataset]: + image_shape = self._schema.image_shape + return { + self._schema.image_key: h5.create_dataset( + self._schema.image_key, + shape=(0, *image_shape), + maxshape=(None, *image_shape), + chunks=(1, *image_shape), + dtype=np.uint8, + compression="lzf", + ), + self._schema.state_key: self._create_vector_dataset(h5, self._schema.state_key, self._schema.state_dim), + self._schema.mode_key: self._create_vector_dataset(h5, self._schema.mode_key, self._schema.mode_dim), + self._schema.action_key: self._create_vector_dataset(h5, self._schema.action_key, self._schema.action_dim), + self._schema.hand_action_key: self._create_vector_dataset( + h5, self._schema.hand_action_key, self._schema.hand_action_dim + ), + } + + @staticmethod + def _create_vector_dataset(h5: h5py.File, key: str, dim: int) -> h5py.Dataset: + return h5.create_dataset( + key, + shape=(0, dim), + maxshape=(None, dim), + chunks=(1024, dim), + dtype=np.float32, + ) + + @staticmethod + def _validate_vector(value: object, key: str, dim: int) -> np.ndarray: + arr = np.asarray(value, dtype=np.float32).reshape(-1) + if arr.shape[0] != dim: + raise ValueError(f"{key} must be {dim}D") + return arr + + def _close_active_file(self) -> None: + if self._h5 is not None: + self._h5.close() + self._h5 = None + self._datasets = {} + + def _reset_episode(self) -> None: + self._active = False + self._frames_in_episode = 0 + self._tmp_path = None + self._episode_path = None + + def _require_tmp_path(self) -> Path: + if self._tmp_path is None: + raise 
RuntimeError("recording episode has no temporary path") + return self._tmp_path + + def _require_episode_path(self) -> Path: + if self._episode_path is None: + raise RuntimeError("recording episode has no output path") + return self._episode_path diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 472785c0..34f54c0f 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -6,7 +6,6 @@ import multiprocessing as mp from multiprocessing.synchronize import Event as MpEvent from enum import Enum -import importlib.util from pathlib import Path import sys import time @@ -43,7 +42,7 @@ from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState from teleopit.runtime.reference_config import parse_reference_config from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader -from teleopit.recording.lerobot_v3 import ( +from teleopit.recording.hdf5 import ( build_mode_observation, build_observation_state, normalize_hand_action, @@ -316,13 +315,13 @@ def _validate_new_runtime_config(cfg: Any) -> None: if provider != "pico4": raise ValueError("recording.enabled=true requires input.provider=pico4") rec_cfg = _recording_cfg(cfg) - if str(cfg_get(rec_cfg, "format", "lerobot_v3")) != "lerobot_v3": - raise ValueError("Only recording.format=lerobot_v3 is supported") + if str(cfg_get(rec_cfg, "format", "hdf5")) != "hdf5": + raise ValueError("Only recording.format=hdf5 is supported") if str(cfg_get(rec_cfg, "control", "terminal")) != "terminal": raise ValueError("Only recording.control=terminal is supported") camera_cfg = _recording_camera_cfg(cfg) if not bool(cfg_get(camera_cfg, "enabled", True)): - raise ValueError("recording.camera.enabled=false is not supported for LeRobot recording") + raise ValueError("recording.camera.enabled=false is not supported for HDF5 recording") if str(cfg_get(camera_cfg, "source", "realsense")).lower() != "realsense": raise ValueError("recording.camera.source 
must be realsense") if int(cfg_get(rec_cfg, "fps", 30)) != int(cfg_get(camera_cfg, "fps", 30)): @@ -346,17 +345,12 @@ def _validate_new_runtime_config(cfg: Any) -> None: def _require_recording_dependencies() -> None: - if importlib.util.find_spec("lerobot") is None: - raise RuntimeError( - "recording.enabled=true requires the recording dependencies and LeRobot v3 adapter. " - "Install Teleopit with: pip install -e '.[recording]'." - ) try: - from teleopit.recording.lerobot_v3 import TeleopitLeRobotV3Recorder + from teleopit.recording.hdf5 import TeleopitHDF5Recorder - TeleopitLeRobotV3Recorder.create + TeleopitHDF5Recorder.create except Exception as exc: - raise RuntimeError("LeRobot v3 recording adapter is unavailable") from exc + raise RuntimeError("HDF5 recording adapter is unavailable") from exc def _worker_loop(name: str, cfg: dict[str, Any], fn: Callable[[], None]) -> None: @@ -1746,17 +1740,15 @@ def __init__( self._episode_started_s = 0.0 self._episode_frames = 0 - from teleopit.recording.lerobot_v3 import ( - TeleopitLeRobotV3Recorder, + from teleopit.recording.hdf5 import ( + TeleopitHDF5Recorder, build_recording_schema, ) self._schema = build_recording_schema(self.camera_cfg) - factory = recorder_factory or TeleopitLeRobotV3Recorder.create + factory = recorder_factory or TeleopitHDF5Recorder.create self._recorder = factory( - output_dir=cfg_get(self.rec_cfg, "output_dir", "data/lerobot"), - dataset_name=cfg_get(self.rec_cfg, "dataset_name", None), - repo_id=cfg_get(self.rec_cfg, "repo_id", None), + output_dir=cfg_get(self.rec_cfg, "output_dir", "data/recordings/sim2real_hdf5"), task=self.task, fps=self.fps, schema=self._schema, diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 1688892e..9fa23259 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -5,23 +5,24 @@ from pathlib import Path from types import SimpleNamespace +import h5py import numpy as np import pytest from 
teleopit.runtime.mocap_session import MocapSessionState from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType from teleopit.runtime.arm_mocap import compose_arm_reference, compose_arm_reference_window -from teleopit.recording.lerobot_v3 import ( +from teleopit.recording.hdf5 import ( ACTION_KEY, HAND_ACTION_KEY, + HDF5_RECORDING_FORMAT, IMAGE_KEY, MODE_KEY, STATE_KEY, build_mode_observation, build_observation_state, build_recording_schema, - lerobot_features, - modality_sidecar, + hdf5_schema, ) from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher from teleopit.sim2real.mp.messages import HandCommandPacket, RecordStepPacket, ReferencePacket, SharedFrameDescriptor @@ -293,16 +294,17 @@ def test_recording_key_mapping() -> None: assert map_recording_key_to_command("x") is None -def test_lerobot_recording_schema_and_modality_sidecar() -> None: +def test_hdf5_recording_schema() -> None: schema = build_recording_schema({"width": 640, "height": 480, "key": IMAGE_KEY}) - features = lerobot_features(schema) - sidecar = modality_sidecar(schema) - - assert features[IMAGE_KEY]["shape"] == (480, 640, 3) - assert features[STATE_KEY]["shape"] == (68,) - assert features[MODE_KEY]["shape"] == (1,) - assert features[ACTION_KEY]["shape"] == (36,) - assert features[HAND_ACTION_KEY]["shape"] == (12,) + sidecar = hdf5_schema(schema) + features = sidecar["features"] + + assert sidecar["format"] == HDF5_RECORDING_FORMAT + assert features[IMAGE_KEY]["shape"] == [480, 640, 3] + assert features[STATE_KEY]["shape"] == [68] + assert features[MODE_KEY]["shape"] == [1] + assert features[ACTION_KEY]["shape"] == [36] + assert features[HAND_ACTION_KEY]["shape"] == [12] assert sidecar["features"][STATE_KEY]["slices"]["joint_pos"] == [0, 29] assert sidecar["features"][STATE_KEY]["slices"]["projected_gravity"] == [65, 68] assert sidecar["features"][MODE_KEY]["codes"]["pause"] == 3 @@ -311,6 +313,44 @@ def 
test_lerobot_recording_schema_and_modality_sidecar() -> None: assert sidecar["features"][HAND_ACTION_KEY]["slices"]["right_pose"] == [6, 12] +def test_hdf5_recorder_writes_episode_file(tmp_path: Path) -> None: + from teleopit.recording.hdf5 import TeleopitHDF5Recorder + + schema = build_recording_schema({"width": 2, "height": 2, "key": IMAGE_KEY}) + recorder = TeleopitHDF5Recorder.create(output_dir=tmp_path, task="walk", fps=30, schema=schema) + + recorder.start_episode() + recorder.add_frame( + image=np.full((2, 2, 3), 7, dtype=np.uint8), + state=np.arange(68, dtype=np.float32), + mode=build_mode_observation("mocap"), + action=np.arange(36, dtype=np.float32), + hand_action=np.arange(12, dtype=np.float32), + task="walk", + ) + recorder.save_episode() + recorder.finalize() + + episodes = sorted((tmp_path / "episodes").glob("*.h5")) + assert len(episodes) == 1 + assert (tmp_path / "schema.json").exists() + assert not list((tmp_path / ".tmp").glob("*.h5")) + + with h5py.File(episodes[0], "r") as h5: + assert h5.attrs["format"] == HDF5_RECORDING_FORMAT + assert h5.attrs["version"] == 1 + assert h5.attrs["task"] == "walk" + assert h5.attrs["fps"] == 30 + assert h5.attrs["frames"] == 1 + assert h5[IMAGE_KEY].shape == (1, 2, 2, 3) + assert h5[STATE_KEY].shape == (1, 68) + assert h5[MODE_KEY].shape == (1, 1) + assert h5[ACTION_KEY].shape == (1, 36) + assert h5[HAND_ACTION_KEY].shape == (1, 12) + np.testing.assert_array_equal(h5[IMAGE_KEY][0], np.full((2, 2, 3), 7, dtype=np.uint8)) + np.testing.assert_allclose(h5[HAND_ACTION_KEY][0], np.arange(12, dtype=np.float32)) + + def test_configured_open_hand_pose_matches_linkerhand_l6_parser() -> None: left, right = _configured_open_hand_pose( { From 5ab2e473b8f12845fd490e2cbd77ab9cf5182f7d Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 23 Jun 2026 14:02:07 +0800 Subject: [PATCH 117/122] Store sim2real recording video as MP4 sidecars --- AGENTS.md | 1 + README.md | 4 +- docs/docs/configuration/config-reference.md | 16 +- 
docs/docs/tutorials/pico-sim2real.md | 5 +- .../current/configuration/config-reference.md | 15 +- .../current/tutorials/pico-sim2real.md | 8 +- teleopit/configs/pico4_sim2real.yaml | 4 + teleopit/configs/sim2real.yaml | 4 + teleopit/recording/hdf5.py | 206 ++++++++++++++++-- teleopit/sim2real/mp/runtime.py | 4 + tests/test_sim2real_multiprocess.py | 105 +++++++-- 11 files changed, 318 insertions(+), 54 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 2c730fc2..ee21cf0b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -149,6 +149,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - Optional Pico sim2real HDF5 recording uses `--config-name sim2real_record` or `recording.enabled=true`; it requires `input.provider=pico4`, `input.video.enabled=true`, `input.video.source=realsense`, an interactive terminal, and the `recording` extra - Recording is manual only: terminal `R` starts an episode, `S` saves, `D` discards the active episode, and `Q` shuts down; `STANDING`, `MOCAP`, `ARMS`, and paused mocap are recordable - Recording captures `observation.images.d435i_rgb` RealSense RGB video at 30Hz plus `observation.state(68)`, `observation.mode(1)`, `action(36)`, and `action.hand(12)`; RealSense capture lives in `pico_input` through the normal `input.video` path +- HDF5 recording writes compressed MP4 sidecar videos under `recording.output_dir/videos//` while HDF5 episodes store `frame_index`, `timestamp`, low-dimensional data, and video sync attributes; raw RGB image datasets are not supported - `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control and supports LinkerHand L6 and O6 - `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and somehand 0.2.0 public `somehand.api` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control - Teleopit owns Pico 26-joint hand-state to 21-landmark conversion; do not import `somehand.pico_input` diff --git 
a/README.md b/README.md index 34930d87..882312eb 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,9 @@ python scripts/run/run_sim2real.py --config-name sim2real_record \ Recording uses the terminal controls `R` start, `S` save, `D` discard, and `Q` shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded. Saved episodes are written as `.h5` files under `data/recordings/sim2real_hdf5/episodes/`. -The dataset schema is `observation.images.d435i_rgb` RGB frames at 30 Hz, +`sim2real_record.yaml` stores camera frames as compressed MP4 sidecar files under +`data/recordings/sim2real_hdf5/videos/` and keeps `frame_index` / `timestamp` +sync metadata in the HDF5 episode. The low-dimensional HDF5 schema records `observation.state(68)`, `observation.mode(1)`, `action(36)` as the aligned reference qpos sent to the policy path, and `action.hand(12)` as the latest LinkerHand left/right 6D pose commands. diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 788b90a3..02d4db38 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -175,22 +175,30 @@ same frames produced by `pico_input`. | `recording.camera.key` | RGB image dataset key | `observation.images.d435i_rgb` | | `recording.camera.width` / `height` / `fps` | RealSense RGB capture settings | `640` / `480` / `30` | | `recording.camera.device` | Optional RealSense serial | `null` | +| `recording.video.codec` / `quality` / `pixelformat` | MP4 sidecar encoder settings | `libx264` / `8` / `yuv420p` | Camera failure behavior is controlled by `input.video.fail_on_error`. +Each saved episode has one `.h5` file under `recording.output_dir/episodes/` +and one compressed MP4 sidecar under +`recording.output_dir/videos//`. The HDF5 episode stores +`frame_index` and `timestamp` arrays, plus `video_path`, `video_fps`, and +`video_frames` root attributes for synchronization. 
Raw RGB image datasets are +not written. + HDF5 datasets: ```text -observation.images.d435i_rgb uint8[N,480,640,3] +frame_index int64[N] +timestamp float64[N] observation.state float32[68] observation.mode float32[1] action float32[36] action.hand float32[12] ``` -Each saved episode is a standalone `.h5` file under -`recording.output_dir/episodes/`. The root attributes include the Teleopit HDF5 -recording format, schema version, task, fps, and frame count. +The root attributes include the Teleopit HDF5 recording format, schema version, +task, fps, frame count, and video sync metadata. `observation.state` is ordered as `joint_pos(29)`, `joint_vel(29)`, `base_quat_wxyz(4)`, `base_ang_vel(3)`, and `projected_gravity(3)`. diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index 25367df6..df0529a2 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -116,8 +116,9 @@ python scripts/run/run_sim2real.py \ Terminal controls are `R` start episode, `S` save, `D` discard, and `Q` shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded; saved episodes cannot be discarded afterward. Episodes are saved as `.h5` files -under `data/recordings/sim2real_hdf5/episodes/`. The v1 schema records -`observation.images.d435i_rgb`, `observation.state(68)`, +under `data/recordings/sim2real_hdf5/episodes/`, with compressed MP4 sidecar +videos under `data/recordings/sim2real_hdf5/videos/`. The HDF5 episode stores +`frame_index` and `timestamp` sync arrays plus `observation.state(68)`, `observation.mode(1)`, `action(36)`, and `action.hand(12)` at 30 Hz. 
## Operator Flow diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 8575e9b3..bd35c27f 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -192,22 +192,29 @@ Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 someh | `recording.camera.key` | RGB 图像数据集 key | `observation.images.d435i_rgb` | | `recording.camera.width` / `height` / `fps` | RealSense RGB 采集设置 | `640` / `480` / `30` | | `recording.camera.device` | 可选 RealSense 序列号 | `null` | +| `recording.video.codec` / `quality` / `pixelformat` | MP4 sidecar 编码设置 | `libx264` / `8` / `yuv420p` | 相机失败时的行为由 `input.video.fail_on_error` 控制。 +每个保存的 episode 会在 `recording.output_dir/episodes/` 下写入一个 `.h5` +文件,并在 `recording.output_dir/videos//` 下写入一个压缩 MP4 +sidecar。HDF5 episode 保存 `frame_index` 和 `timestamp` 数组,并在根属性中 +记录 `video_path`、`video_fps` 和 `video_frames` 用于同步。录制不会写入原始 +RGB 图像 dataset。 + HDF5 datasets: ```text -observation.images.d435i_rgb uint8[N,480,640,3] +frame_index int64[N] +timestamp float64[N] observation.state float32[68] observation.mode float32[1] action float32[36] action.hand float32[12] ``` -每个保存的 episode 都是 `recording.output_dir/episodes/` 下的独立 `.h5` -文件。根属性包含 Teleopit HDF5 recording format、schema version、task、fps -和 frame count。 +根属性包含 Teleopit HDF5 recording format、schema version、task、fps、 +frame count 和视频同步元数据。 `observation.state` 的顺序是 `joint_pos(29)`、`joint_vel(29)`、 `base_quat_wxyz(4)`、`base_ang_vel(3)` 和 `projected_gravity(3)`。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index 4b828afe..c883d5a1 100644 --- 
a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -111,9 +111,11 @@ python scripts/run/run_sim2real.py \ 终端控制为:`R` 开始 episode,`S` 保存,`D` 丢弃,`Q` 关闭。可以录制 `STANDING`、`MOCAP`、`ARMS` 和暂停状态的 mocap;已经保存的 episode 不支持再丢弃。 -episode 会保存为 `data/recordings/sim2real_hdf5/episodes/` 下的 `.h5` 文件。 -v1 schema 以 30 Hz 记录 `observation.images.d435i_rgb`、`observation.state(68)`、 -`observation.mode(1)`、`action(36)` 和 `action.hand(12)`。 +episode 会保存为 `data/recordings/sim2real_hdf5/episodes/` 下的 `.h5` 文件, +压缩 MP4 sidecar 视频保存在 `data/recordings/sim2real_hdf5/videos/` 下。 +HDF5 episode 以 30 Hz 保存 `frame_index` 和 `timestamp` 同步数组,以及 +`observation.state(68)`、`observation.mode(1)`、`action(36)` 和 +`action.hand(12)`。 ## 操作流程 diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 61d50a51..9b335e58 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -36,6 +36,10 @@ recording: height: 480 fps: 30 device: null + video: + codec: libx264 + quality: 8 + pixelformat: yuv420p runtime: host: 127.0.0.1 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index 45373fa4..f2dbe24f 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -40,6 +40,10 @@ recording: height: 480 fps: 30 device: null + video: + codec: libx264 + quality: 8 + pixelformat: yuv420p runtime: host: 127.0.0.1 diff --git a/teleopit/recording/hdf5.py b/teleopit/recording/hdf5.py index 29a124f1..36849179 100644 --- a/teleopit/recording/hdf5.py +++ b/teleopit/recording/hdf5.py @@ -4,7 +4,9 @@ from dataclasses import dataclass import json +import logging from pathlib import Path +import re import time from typing import Any @@ -17,11 +19,15 @@ from teleopit.runtime.common import cfg_get +logger = logging.getLogger(__name__) + IMAGE_KEY = "observation.images.d435i_rgb" STATE_KEY = 
"observation.state" MODE_KEY = "observation.mode" ACTION_KEY = "action" HAND_ACTION_KEY = "action.hand" +FRAME_INDEX_KEY = "frame_index" +TIMESTAMP_KEY = "timestamp" STATE_DIM = 68 MODE_DIM = 1 ACTION_DIM = FULL_QPOS_DIM @@ -51,6 +57,13 @@ class RecordingSchema: hand_action_dim: int = HAND_ACTION_DIM +@dataclass(frozen=True) +class MP4VideoConfig: + codec: str = "libx264" + quality: int = 8 + pixelformat: str = "yuv420p" + + def build_recording_schema(camera_cfg: Any) -> RecordingSchema: key = str(cfg_get(camera_cfg, "key", IMAGE_KEY)) width = int(cfg_get(camera_cfg, "width", DEFAULT_IMAGE_SHAPE[1])) @@ -60,15 +73,48 @@ def build_recording_schema(camera_cfg: Any) -> RecordingSchema: return RecordingSchema(image_key=key, image_shape=(height, width, 3)) -def hdf5_schema(schema: RecordingSchema) -> dict[str, object]: +def build_mp4_video_config(video_cfg: Any) -> MP4VideoConfig: + quality = int(cfg_get(video_cfg, "quality", 8)) + if quality < 0 or quality > 10: + raise ValueError("recording.video.quality must be in [0, 10]") + return MP4VideoConfig( + codec=str(cfg_get(video_cfg, "codec", "libx264")), + quality=quality, + pixelformat=str(cfg_get(video_cfg, "pixelformat", "yuv420p")), + ) + + +def hdf5_schema( + schema: RecordingSchema, + *, + video_config: MP4VideoConfig | None = None, +) -> dict[str, object]: + video_cfg = video_config or MP4VideoConfig() return { "format": HDF5_RECORDING_FORMAT, "version": HDF5_RECORDING_VERSION, "features": { schema.image_key: { - "type": "image", + "type": "video", + "format": "mp4", + "codec": video_cfg.codec, "shape": list(schema.image_shape), "dtype": "uint8", + "sync": { + "frame_index": FRAME_INDEX_KEY, + "timestamp": TIMESTAMP_KEY, + }, + }, + FRAME_INDEX_KEY: { + "type": "index", + "shape": [], + "dtype": "int64", + }, + TIMESTAMP_KEY: { + "type": "timestamp", + "shape": [], + "dtype": "float64", + "units": "seconds", }, schema.state_key: { "type": "low_dim", @@ -173,17 +219,23 @@ def __init__( task: str, fps: int, schema: 
RecordingSchema, + video_config: MP4VideoConfig | None = None, ) -> None: self._output_dir = output_dir self._task = str(task) self._fps = int(fps) self._schema = schema + self._video_config = video_config or MP4VideoConfig() self._active = False self._frames_in_episode = 0 self._episode_index = 0 self._h5: h5py.File | None = None self._tmp_path: Path | None = None self._episode_path: Path | None = None + self._tmp_video_path: Path | None = None + self._episode_video_path: Path | None = None + self._video_rel_path: str | None = None + self._video_writer: Any | None = None self._datasets: dict[str, h5py.Dataset] = {} @classmethod @@ -194,10 +246,17 @@ def create( task: str, fps: int, schema: RecordingSchema, + video_config: MP4VideoConfig | None = None, ) -> "TeleopitHDF5Recorder": root = Path(output_dir) root.mkdir(parents=True, exist_ok=True) - recorder = cls(output_dir=root, task=task, fps=fps, schema=schema) + recorder = cls( + output_dir=root, + task=task, + fps=fps, + schema=schema, + video_config=video_config, + ) recorder._write_schema_sidecar() return recorder @@ -214,10 +273,22 @@ def start_episode(self) -> None: self._tmp_path = tmp_dir / f"{stem}.h5" self._episode_path = episodes_dir / f"{stem}.h5" self._h5 = h5py.File(self._tmp_path, "w") - self._write_episode_header(self._h5) - self._datasets = self._create_datasets(self._h5) - self._active = True - self._frames_in_episode = 0 + video_dir = self._output_dir / "videos" / _safe_path_component(self._schema.image_key) + tmp_video_dir = tmp_dir / "videos" / _safe_path_component(self._schema.image_key) + video_dir.mkdir(parents=True, exist_ok=True) + tmp_video_dir.mkdir(parents=True, exist_ok=True) + self._tmp_video_path = tmp_video_dir / f"{stem}.mp4" + self._episode_video_path = video_dir / f"{stem}.mp4" + self._video_rel_path = self._episode_video_path.relative_to(self._output_dir).as_posix() + try: + self._video_writer = self._create_video_writer(self._tmp_video_path) + 
self._write_episode_header(self._h5) + self._datasets = self._create_datasets(self._h5) + self._active = True + self._frames_in_episode = 0 + except Exception: + self._cleanup_partial_episode() + raise def add_frame( self, @@ -242,7 +313,11 @@ def add_frame( row = self._frames_in_episode for dataset in self._datasets.values(): dataset.resize((row + 1, *dataset.shape[1:])) - self._datasets[self._schema.image_key][row] = image_arr + if self._video_writer is None: + raise RuntimeError("MP4 recording writer is not open") + self._video_writer.append_data(image_arr) + self._datasets[FRAME_INDEX_KEY][row] = row + self._datasets[TIMESTAMP_KEY][row] = float(row) / float(self._fps) self._datasets[self._schema.state_key][row] = state_arr self._datasets[self._schema.mode_key][row] = mode_arr self._datasets[self._schema.action_key][row] = action_arr @@ -256,7 +331,12 @@ def save_episode(self) -> None: return tmp_path = self._require_tmp_path() episode_path = self._require_episode_path() - self._close_active_file() + tmp_video_path = self._tmp_video_path + episode_video_path = self._episode_video_path + self._close_active_outputs() + if tmp_video_path is None or episode_video_path is None: + raise RuntimeError("recording episode has no video output path") + tmp_video_path.replace(episode_video_path) tmp_path.replace(episode_path) self._reset_episode() @@ -264,9 +344,12 @@ def discard_episode(self) -> None: if not self._active: return tmp_path = self._tmp_path - self._close_active_file() + tmp_video_path = self._tmp_video_path + self._close_active_outputs() if tmp_path is not None and tmp_path.exists(): tmp_path.unlink() + if tmp_video_path is not None and tmp_video_path.exists(): + tmp_video_path.unlink() self._reset_episode() def finalize(self) -> None: @@ -275,7 +358,7 @@ def finalize(self) -> None: def _write_schema_sidecar(self) -> None: path = self._output_dir / "schema.json" - path.write_text(json.dumps(hdf5_schema(self._schema), indent=2) + "\n", encoding="utf-8") + 
path.write_text(json.dumps(self._schema_dict(), indent=2) + "\n", encoding="utf-8") def _write_episode_header(self, h5: h5py.File) -> None: h5.attrs["format"] = HDF5_RECORDING_FORMAT @@ -283,18 +366,32 @@ def _write_episode_header(self, h5: h5py.File) -> None: h5.attrs["task"] = self._task h5.attrs["fps"] = self._fps h5.attrs["frames"] = 0 - h5.attrs["schema_json"] = json.dumps(hdf5_schema(self._schema), sort_keys=True) + h5.attrs["schema_json"] = json.dumps(self._schema_dict(), sort_keys=True) + h5.attrs["video_key"] = self._schema.image_key + h5.attrs["video_path"] = self._video_rel_path or "" + h5.attrs["video_format"] = "mp4" + h5.attrs["video_codec"] = self._video_config.codec + h5.attrs["video_pixelformat"] = self._video_config.pixelformat + h5.attrs["video_fps"] = self._fps + h5.attrs["video_frames"] = 0 + h5.attrs["video_from_timestamp_s"] = 0.0 + h5.attrs["video_to_timestamp_s"] = 0.0 def _create_datasets(self, h5: h5py.File) -> dict[str, h5py.Dataset]: - image_shape = self._schema.image_shape return { - self._schema.image_key: h5.create_dataset( - self._schema.image_key, - shape=(0, *image_shape), - maxshape=(None, *image_shape), - chunks=(1, *image_shape), - dtype=np.uint8, - compression="lzf", + FRAME_INDEX_KEY: h5.create_dataset( + FRAME_INDEX_KEY, + shape=(0,), + maxshape=(None,), + chunks=(1024,), + dtype=np.int64, + ), + TIMESTAMP_KEY: h5.create_dataset( + TIMESTAMP_KEY, + shape=(0,), + maxshape=(None,), + chunks=(1024,), + dtype=np.float64, ), self._schema.state_key: self._create_vector_dataset(h5, self._schema.state_key, self._schema.state_dim), self._schema.mode_key: self._create_vector_dataset(h5, self._schema.mode_key, self._schema.mode_dim), @@ -321,8 +418,38 @@ def _validate_vector(value: object, key: str, dim: int) -> np.ndarray: raise ValueError(f"{key} must be {dim}D") return arr - def _close_active_file(self) -> None: + def _schema_dict(self) -> dict[str, object]: + return hdf5_schema( + self._schema, + video_config=self._video_config, + 
) + + def _create_video_writer(self, path: Path) -> Any: + try: + import imageio.v2 as imageio + except Exception as exc: + raise RuntimeError("MP4 recording requires imageio[ffmpeg]") from exc + return imageio.get_writer( + str(path), + fps=self._fps, + codec=self._video_config.codec, + quality=self._video_config.quality, + macro_block_size=1, + pixelformat=self._video_config.pixelformat, + ) + + def _close_active_outputs(self) -> None: + if self._video_writer is not None: + self._video_writer.close() + self._video_writer = None if self._h5 is not None: + self._h5.attrs["video_frames"] = self._frames_in_episode + self._h5.attrs["video_to_timestamp_s"] = ( + float(max(self._frames_in_episode - 1, 0)) / float(self._fps) + if self._frames_in_episode > 0 + else 0.0 + ) + self._h5.attrs["video_path"] = self._video_rel_path or "" self._h5.close() self._h5 = None self._datasets = {} @@ -332,6 +459,38 @@ def _reset_episode(self) -> None: self._frames_in_episode = 0 self._tmp_path = None self._episode_path = None + self._tmp_video_path = None + self._episode_video_path = None + self._video_rel_path = None + self._video_writer = None + + def _cleanup_partial_episode(self) -> None: + if self._video_writer is not None: + try: + self._video_writer.close() + except Exception: + logger.exception("Failed to close partial MP4 recording writer") + self._video_writer = None + if self._h5 is not None: + try: + self._h5.close() + except Exception: + logger.exception("Failed to close partial HDF5 recording file") + self._h5 = None + tmp_path = self._tmp_path + tmp_video_path = self._tmp_video_path + if tmp_path is not None and tmp_path.exists(): + try: + tmp_path.unlink() + except Exception: + logger.exception("Failed to remove partial HDF5 recording file: %s", tmp_path) + if tmp_video_path is not None and tmp_video_path.exists(): + try: + tmp_video_path.unlink() + except Exception: + logger.exception("Failed to remove partial MP4 recording file: %s", tmp_video_path) + self._datasets = 
{} + self._reset_episode() def _require_tmp_path(self) -> Path: if self._tmp_path is None: @@ -342,3 +501,8 @@ def _require_episode_path(self) -> Path: if self._episode_path is None: raise RuntimeError("recording episode has no output path") return self._episode_path + + +def _safe_path_component(value: str) -> str: + safe = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") + return safe or "camera" diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 34f54c0f..83805e64 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -44,6 +44,7 @@ from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader from teleopit.recording.hdf5 import ( build_mode_observation, + build_mp4_video_config, build_observation_state, normalize_hand_action, normalize_action_reference_qpos, @@ -319,6 +320,7 @@ def _validate_new_runtime_config(cfg: Any) -> None: raise ValueError("Only recording.format=hdf5 is supported") if str(cfg_get(rec_cfg, "control", "terminal")) != "terminal": raise ValueError("Only recording.control=terminal is supported") + build_mp4_video_config(cfg_get(rec_cfg, "video", {}) or {}) camera_cfg = _recording_camera_cfg(cfg) if not bool(cfg_get(camera_cfg, "enabled", True)): raise ValueError("recording.camera.enabled=false is not supported for HDF5 recording") @@ -1746,12 +1748,14 @@ def __init__( ) self._schema = build_recording_schema(self.camera_cfg) + self._video_config = build_mp4_video_config(cfg_get(self.rec_cfg, "video", {}) or {}) factory = recorder_factory or TeleopitHDF5Recorder.create self._recorder = factory( output_dir=cfg_get(self.rec_cfg, "output_dir", "data/recordings/sim2real_hdf5"), task=self.task, fps=self.fps, schema=self._schema, + video_config=self._video_config, ) def run(self) -> None: diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 9fa23259..398540a1 100644 --- a/tests/test_sim2real_multiprocess.py +++ 
b/tests/test_sim2real_multiprocess.py @@ -14,11 +14,13 @@ from teleopit.runtime.arm_mocap import compose_arm_reference, compose_arm_reference_window from teleopit.recording.hdf5 import ( ACTION_KEY, + FRAME_INDEX_KEY, HAND_ACTION_KEY, HDF5_RECORDING_FORMAT, IMAGE_KEY, MODE_KEY, STATE_KEY, + TIMESTAMP_KEY, build_mode_observation, build_observation_state, build_recording_schema, @@ -300,7 +302,11 @@ def test_hdf5_recording_schema() -> None: features = sidecar["features"] assert sidecar["format"] == HDF5_RECORDING_FORMAT + assert features[IMAGE_KEY]["type"] == "video" + assert features[IMAGE_KEY]["format"] == "mp4" assert features[IMAGE_KEY]["shape"] == [480, 640, 3] + assert features[FRAME_INDEX_KEY]["dtype"] == "int64" + assert features[TIMESTAMP_KEY]["dtype"] == "float64" assert features[STATE_KEY]["shape"] == [68] assert features[MODE_KEY]["shape"] == [1] assert features[ACTION_KEY]["shape"] == [36] @@ -313,26 +319,36 @@ def test_hdf5_recording_schema() -> None: assert sidecar["features"][HAND_ACTION_KEY]["slices"]["right_pose"] == [6, 12] -def test_hdf5_recorder_writes_episode_file(tmp_path: Path) -> None: - from teleopit.recording.hdf5 import TeleopitHDF5Recorder +def test_hdf5_recorder_mp4_sidecar_writes_sync_metadata(tmp_path: Path) -> None: + from teleopit.recording.hdf5 import MP4VideoConfig, TeleopitHDF5Recorder schema = build_recording_schema({"width": 2, "height": 2, "key": IMAGE_KEY}) - recorder = TeleopitHDF5Recorder.create(output_dir=tmp_path, task="walk", fps=30, schema=schema) - - recorder.start_episode() - recorder.add_frame( - image=np.full((2, 2, 3), 7, dtype=np.uint8), - state=np.arange(68, dtype=np.float32), - mode=build_mode_observation("mocap"), - action=np.arange(36, dtype=np.float32), - hand_action=np.arange(12, dtype=np.float32), + recorder = TeleopitHDF5Recorder.create( + output_dir=tmp_path, task="walk", + fps=30, + schema=schema, + video_config=MP4VideoConfig(quality=5), ) + + recorder.start_episode() + for idx in range(2): + 
recorder.add_frame( + image=np.full((2, 2, 3), idx * 64, dtype=np.uint8), + state=np.arange(68, dtype=np.float32), + mode=build_mode_observation("mocap"), + action=np.arange(36, dtype=np.float32), + hand_action=np.arange(12, dtype=np.float32), + task="walk", + ) recorder.save_episode() recorder.finalize() episodes = sorted((tmp_path / "episodes").glob("*.h5")) + videos = sorted((tmp_path / "videos" / "observation.images.d435i_rgb").glob("*.mp4")) assert len(episodes) == 1 + assert len(videos) == 1 + assert videos[0].stat().st_size > 0 assert (tmp_path / "schema.json").exists() assert not list((tmp_path / ".tmp").glob("*.h5")) @@ -341,14 +357,65 @@ def test_hdf5_recorder_writes_episode_file(tmp_path: Path) -> None: assert h5.attrs["version"] == 1 assert h5.attrs["task"] == "walk" assert h5.attrs["fps"] == 30 - assert h5.attrs["frames"] == 1 - assert h5[IMAGE_KEY].shape == (1, 2, 2, 3) - assert h5[STATE_KEY].shape == (1, 68) - assert h5[MODE_KEY].shape == (1, 1) - assert h5[ACTION_KEY].shape == (1, 36) - assert h5[HAND_ACTION_KEY].shape == (1, 12) - np.testing.assert_array_equal(h5[IMAGE_KEY][0], np.full((2, 2, 3), 7, dtype=np.uint8)) - np.testing.assert_allclose(h5[HAND_ACTION_KEY][0], np.arange(12, dtype=np.float32)) + assert h5.attrs["frames"] == 2 + assert h5.attrs["video_path"] == videos[0].relative_to(tmp_path).as_posix() + assert h5.attrs["video_key"] == IMAGE_KEY + assert h5.attrs["video_frames"] == 2 + assert h5.attrs["video_fps"] == 30 + assert IMAGE_KEY not in h5 + assert h5[FRAME_INDEX_KEY].shape == (2,) + assert h5[TIMESTAMP_KEY].shape == (2,) + np.testing.assert_array_equal(h5[FRAME_INDEX_KEY][...], np.array([0, 1], dtype=np.int64)) + np.testing.assert_allclose(h5[TIMESTAMP_KEY][...], np.array([0.0, 1.0 / 30.0], dtype=np.float64)) + assert h5[STATE_KEY].shape == (2, 68) + assert h5[MODE_KEY].shape == (2, 1) + assert h5[ACTION_KEY].shape == (2, 36) + assert h5[HAND_ACTION_KEY].shape == (2, 12) + + +def 
test_hdf5_recorder_cleans_partial_episode_when_video_writer_fails(tmp_path: Path) -> None: + from teleopit.recording.hdf5 import TeleopitHDF5Recorder + + class FailingVideoRecorder(TeleopitHDF5Recorder): + def _create_video_writer(self, path: Path) -> object: + path.write_bytes(b"partial") + raise RuntimeError("writer failed") + + schema = build_recording_schema({"width": 2, "height": 2, "key": IMAGE_KEY}) + recorder = FailingVideoRecorder.create(output_dir=tmp_path, task="walk", fps=30, schema=schema) + + with pytest.raises(RuntimeError, match="writer failed"): + recorder.start_episode() + + recorder.finalize() + assert not list((tmp_path / ".tmp").glob("*.h5")) + assert not list((tmp_path / ".tmp" / "videos" / "observation.images.d435i_rgb").glob("*.mp4")) + assert not list((tmp_path / "episodes").glob("*.h5")) + assert not list((tmp_path / "videos" / "observation.images.d435i_rgb").glob("*.mp4")) + + +def test_hdf5_recorder_keeps_startup_error_when_partial_cleanup_fails(tmp_path: Path) -> None: + from teleopit.recording.hdf5 import TeleopitHDF5Recorder + + class BrokenWriter: + def close(self) -> None: + raise RuntimeError("cleanup failed") + + class FailingDatasetRecorder(TeleopitHDF5Recorder): + def _create_video_writer(self, path: Path) -> object: + return BrokenWriter() + + def _create_datasets(self, h5: h5py.File) -> dict[str, h5py.Dataset]: + raise RuntimeError("startup failed") + + schema = build_recording_schema({"width": 2, "height": 2, "key": IMAGE_KEY}) + recorder = FailingDatasetRecorder.create(output_dir=tmp_path, task="walk", fps=30, schema=schema) + + with pytest.raises(RuntimeError, match="startup failed"): + recorder.start_episode() + + recorder.finalize() + assert not list((tmp_path / ".tmp").glob("*.h5")) def test_configured_open_hand_pose_matches_linkerhand_l6_parser() -> None: From e52cf70333e71e7519cb2af9b09b43f0bb3b132f Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Tue, 23 Jun 2026 17:12:48 +0800 Subject: [PATCH 118/122] docs: document 
pyrealsense2 installation --- README.md | 3 +++ docs/docs/getting-started/installation.md | 11 +++++++++-- .../current/getting-started/installation.md | 10 ++++++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 882312eb..4ad82ca1 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,9 @@ Pico sim2real can also record manual HDF5 episodes from the real G1: ```bash pip install -e '.[recording]' +# If you use RealSense video, install pyrealsense2 manually for your platform. +# On Arm machines, prefer conda-forge: +# conda install -c conda-forge pyrealsense2 python scripts/run/run_sim2real.py --config-name sim2real_record \ controller.policy_path=track.onnx \ recording.task="walk forward" diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 582e4f81..54ab7d79 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -79,8 +79,15 @@ These packages are only required when `hands.enabled=true`. pip install -e '.[recording]' ``` -Adds the Pico sim2real stack plus RealSense and video dependencies used by -`sim2real_record.yaml`. +Adds the Pico sim2real stack plus the video dependencies used by +`sim2real_record.yaml`. RealSense Python bindings are platform-specific: install +`pyrealsense2` manually in the active environment when using +`input.video.source=realsense`. 
On Arm machines, use conda-forge rather than the +pip package: + +```bash +conda install -c conda-forge pyrealsense2 +``` ## Verify Installation diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 42d208fa..86ca2e09 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -79,8 +79,14 @@ scripts/setup/download_somehand_l6_assets.sh pip install -e '.[recording]' ``` -该配置包含 Pico sim2real 栈,以及 `sim2real_record.yaml` 使用的 RealSense -和视频依赖。 +该配置包含 Pico sim2real 栈,以及 `sim2real_record.yaml` 使用的视频依赖。 +RealSense Python 绑定与平台相关;使用 `input.video.source=realsense` 时, +需要在当前环境中手动安装 `pyrealsense2`。在 Arm 机器上,请使用 +conda-forge,而不是 pip 包: + +```bash +conda install -c conda-forge pyrealsense2 +``` ## 验证安装 From 071ba8ecddc75d86e369478b4d513d1e044ea65a Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Wed, 24 Jun 2026 17:57:28 +0800 Subject: [PATCH 119/122] Add Xsens BVH input loading --- teleopit/inputs/bvh_provider.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/teleopit/inputs/bvh_provider.py b/teleopit/inputs/bvh_provider.py index d88c5a43..338ce5a9 100644 --- a/teleopit/inputs/bvh_provider.py +++ b/teleopit/inputs/bvh_provider.py @@ -6,6 +6,7 @@ import numpy as np from teleopit.retargeting.gmr.utils.lafan_vendor.extract import read_bvh from teleopit.retargeting.gmr.utils.lafan_vendor import utils +from teleopit.retargeting.gmr.utils.xsens_vendor.BVHParser import BVHParser, Anim from scipy.spatial.transform import Rotation as R @@ -117,6 +118,9 @@ def process_single_bvh_frame( def _load_bvh_file(bvh_file: str, format: str = "lafan1"): + if format == "xsens": + return _load_xsens_bvh_file(bvh_file) + data = read_bvh(bvh_file) bone_names = list(data.bones) 
bone_parents = np.array(data.parents, dtype=np.int32) @@ -182,6 +186,44 @@ def _load_bvh_file(bvh_file: str, format: str = "lafan1"): return frames, human_height, fps, bone_names, bone_parents +def _load_xsens_bvh_file(bvh_file: str): + parser = BVHParser(axis_order="zxy", scale=0.01) + with open(bvh_file, "r") as f: + bvh_text = f.read() + + rotations, positions = parser.parse(bvh_text, reset_to_zero=True) + quats, processed_positions, offsets, parents = parser._MOTION_data_post_processing( + rotations, + positions, + reset_to_zero=True, + ) + anim = Anim(quats, processed_positions, offsets, parents, parser.names) + global_quats, global_pos = utils.quat_fk(anim.quats, anim.pos, anim.parents) + + frames = [] + for frame in range(anim.pos.shape[0]): + result = {} + for i, bone in enumerate(anim.bones): + result[bone] = (global_pos[frame, i], global_quats[frame, i]) + + result["LeftFootMod"] = (np.array(result["LeftAnkle"][0], copy=True), result["LeftAnkle"][1]) + result["RightFootMod"] = (np.array(result["RightAnkle"][0], copy=True), result["RightAnkle"][1]) + frames.append(result) + + if not frames: + raise ValueError(f"No frames parsed from xsens BVH input: {bvh_file}") + + frame_time = float(parser.frame_time) + fps = int(round(1.0 / frame_time)) if frame_time > 0.0 else 60 + last_frame = frames[-1] + human_height = float( + last_frame["Head_end_site"][0][2] + - min(last_frame["LeftToe_end_site"][0][2], last_frame["RightToe_end_site"][0][2]) + ) + + return frames, human_height, fps, list(anim.bones), np.array(anim.parents, dtype=np.int32) + + class BVHInputProvider: def __init__(self, bvh_path: str, human_format: str = "lafan1"): From 873928c6d5bac2b887ebb749666d4aa8881dc8e4 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 25 Jun 2026 17:11:54 +0800 Subject: [PATCH 120/122] Keep sim2real hand control active across modes --- AGENTS.md | 2 +- README.md | 1 + docs/docs/configuration/config-reference.md | 2 +- docs/docs/tutorials/pico-sim2real.md | 6 +-- 
.../current/configuration/config-reference.md | 3 +- .../current/tutorials/pico-sim2real.md | 4 +- teleopit/sim2real/mp/runtime.py | 7 ++- tests/test_sim2real_multiprocess.py | 50 ++++++++++++++++++- 8 files changed, 64 insertions(+), 11 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ee21cf0b..8471e293 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -156,7 +156,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos - LinkerHand O6 supports only `hands.mode=gripper`; its default `close_pose` is `[86, 73, 118, 111, 110, 111]` - L6 `gripper` mode uses the configured `hands.linkerhand_l6.speed` (default `[50]*6`); O6 `gripper` mode uses `hands.linkerhand_o6.speed` (default `[255]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` - `vr_hand_pose` defaults to a low-latency somehand path: `hands.somehand.rate_hz=60`, `max_iterations=12`, `temporal_filter_alpha=1.0`, and `output_alpha=1.0`; this prioritizes response speed over smoothing -- LinkerHand control is active in sim2real `MOCAP` and `ARMS`; `STANDING`, `DAMPING`, mocap pause, and shutdown must send the configured open pose +- LinkerHand control is active in all sim2real modes when `hands.enabled=true`; shutdown and hand-runtime failure must send the configured open pose - In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand ### SimulationLoop Runtime Behavior diff --git a/README.md b/README.md index 4ad82ca1..062a83e9 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added LinkerHand O6 support for Pico `gripper` mode with an O6-specific grasp pose. - Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. - Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. 
+- LinkerHand sim2real control remains active across all sim2real modes after the runtime mode state is initialized. - Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. - Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. - Switched dataset build outputs to recursive minimal HDF5 shards with no train/val split or manifest; `precompute_dataset.py` turns them into separate precomputed training datasets before training. diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 02d4db38..f6bf965f 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -125,7 +125,7 @@ Realtime Pico resume re-centers heading and ground-plane position before trackin `hands.enabled=true` requires `input.provider=pico4` plus local editable installs of `third_party/linkerhand-python-sdk` and `third_party/somehand`. -Control is active in `MOCAP` and `ARMS`; inactive modes send the open pose. +When enabled, hand control remains active in all sim2real modes. `gripper` supports `linkerhand_l6` and `linkerhand_o6` by interpolating Pico trigger input between the configured open and close poses. `vr_hand_pose` is L6-only: missing hand pose holds the last command for that side, L6 speed is diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index df0529a2..4b89d3f6 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -186,8 +186,8 @@ Pico sim2real can drive LinkerHand hands from Pico input: and the public `somehand.api` from somehand 0.2.0. It always sets L6 speed to the maximum. -Hand control is active in `MOCAP` and `ARMS`. It sends the open pose in -`STANDING`, `DAMPING`, paused mocap, and shutdown. +When `hands.enabled=true`, hand control remains active in all sim2real modes. 
+Shutdown and hand-runtime failure send the configured open pose. Install the local hand-control packages first if they were not installed with the main Pico profile: @@ -309,5 +309,5 @@ input.video.enabled=true | Cannot enter debug mode | Unitree mode release failed | Stop other robot modes and press `Start` again | | Robot enters `STANDING` but not `MOCAP` | Mocap validation failed | Keep tracking active and stable; check `mocap_switch.check_frames` logs | | Pico pause does not return to `STANDING` | Expected behavior | Pico pause freezes mocap; press remote `X` for `STANDING` | -| LinkerHand does not move | `hands.enabled=false`, not in `MOCAP`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Enable `hands.enabled`, enter `MOCAP`, run `scripts/dev/test_linkerhand_l6.py`, and check the selected driver's `left_can` / `right_can` | +| LinkerHand does not move | `hands.enabled=false`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Enable `hands.enabled`, set `hands.mode`, run `scripts/dev/test_linkerhand_l6.py`, and check the selected driver's `left_can` / `right_can` | | Video preview is unavailable | RealSense or video source failed | Check camera permissions, `input.video.source`, and logs | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index bd35c27f..0366815c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -143,8 +143,7 @@ MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` ### 灵巧手(Pico sim2real) `hands.enabled=true` 要求 `input.provider=pico4`,并以本地 editable 方式安装 -`third_party/linkerhand-python-sdk` 和 `third_party/somehand`。控制在 `MOCAP` -和 `ARMS` 中生效;非活动模式会发送张开姿态。 +`third_party/linkerhand-python-sdk` 和 
`third_party/somehand`。启用后,手控会在所有 sim2real 模式中保持生效。 `gripper` 支持 `linkerhand_l6` 和 `linkerhand_o6`,会用 Pico trigger 在配置的张开和闭合姿态之间插值。 `vr_hand_pose` 只支持 L6:手部 pose 消失时,对应侧会保持上一条命令;L6 速度会设为最大值; Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 somehand 0.2.0 公开的 `somehand.api` 调用。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index c883d5a1..23fd2ce4 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -176,7 +176,7 @@ Pico sim2real 可以用 Pico 输入控制 LinkerHand: 速度设为最大值。默认配置使用 60 Hz 的低延时 somehand 路径并减少平滑,所以响应会更快, 但可能比标准 somehand 设置更抖。 -手控在 `MOCAP` 和 `ARMS` 中生效;在 `STANDING`、`DAMPING`、mocap 暂停和退出时都会发送张开姿态。 +`hands.enabled=true` 时,手控会在所有 sim2real 模式中保持生效。退出和手控运行时失败会发送配置的张开姿态。 如果主 Pico profile 没有包含手控支持,先安装本地手控包: @@ -296,5 +296,5 @@ input.video.enabled=true | 无法进入 debug mode | Unitree mode 释放失败 | 停止其他机器人模式后再次按 `Start` | | 机器人进入 `STANDING` 但不进入 `MOCAP` | 动捕验证失败 | 保持追踪稳定,查看 `mocap_switch.check_frames` 日志 | | Pico 暂停没有返回 `STANDING` | 这是预期行为 | Pico 暂停只冻结 mocap;按遥控器 `X` 返回 `STANDING` | -| LinkerHand 不动 | `hands.enabled=false`、不在 `MOCAP`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `hands.enabled=true` 和 `hands.mode`,进入 `MOCAP`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查所选 driver 的 `left_can` / `right_can` | +| LinkerHand 不动 | `hands.enabled=false`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `hands.enabled=true` 和 `hands.mode`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查所选 driver 的 `left_can` / `right_can` | | 视频预览不可用 | RealSense 或视频源失败 | 检查相机权限、`input.video.source` 和日志 | diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py index 83805e64..7cd4e99a 100644 --- a/teleopit/sim2real/mp/runtime.py +++ b/teleopit/sim2real/mp/runtime.py @@ -1908,6 +1908,11 @@ def get_controller_snapshot(self) -> Any 
| None: return self.controller_snapshot +def _hand_worker_active_for_mode(mode_packet: ModeStatePacket) -> bool: + del mode_packet + return True + + def _run_hand_worker( cfg: dict[str, Any], endpoints: Sim2RealIpcEndpoints, @@ -1982,7 +1987,7 @@ def _publish_hand_command(*, timestamp_s: float, active_state: bool) -> None: proxy.controller_snapshot = controller_packet.snapshot mode_packet = mode_sub.recv_latest() if isinstance(mode_packet, ModeStatePacket): - active = bool(mode_packet.mocap_active) + active = _hand_worker_active_for_mode(mode_packet) try: now_s = time.monotonic() commands = runtime.tick( diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py index 398540a1..b59b8e1b 100644 --- a/tests/test_sim2real_multiprocess.py +++ b/tests/test_sim2real_multiprocess.py @@ -27,7 +27,7 @@ hdf5_schema, ) from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher -from teleopit.sim2real.mp.messages import HandCommandPacket, RecordStepPacket, ReferencePacket, SharedFrameDescriptor +from teleopit.sim2real.mp.messages import HandCommandPacket, ModeStatePacket, RecordStepPacket, ReferencePacket, SharedFrameDescriptor from teleopit.sim.reference_timeline import ReferenceSample, ReferenceWindow from teleopit.sim2real.mp.runtime import ( map_recording_key_to_command, @@ -37,6 +37,7 @@ _RecordingWorker, _RobotControlWorker, _configured_open_hand_pose, + _hand_worker_active_for_mode, _human_frame_is_valid, ) from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter @@ -729,6 +730,53 @@ def test_robot_worker_mode_state_marks_arms_as_mocap_active() -> None: assert packet.mocap_paused is False +@pytest.mark.parametrize( + ("mode", "mocap_active", "mocap_paused"), + [ + ("standing", False, False), + ("mocap", True, False), + ("arms", True, False), + ("mocap", False, True), + ("arms", False, True), + ("damping", False, False), + ], +) +def test_hand_worker_stays_active_in_all_modes(mode: str, 
mocap_active: bool, mocap_paused: bool) -> None: + packet = ModeStatePacket( + mode=mode, + mocap_active=mocap_active, + mocap_paused=mocap_paused, + timestamp_s=1.0, + seq=1, + ) + + assert _hand_worker_active_for_mode(packet) is True + + +def test_hand_worker_active_state_only_updates_from_mode_packets() -> None: + active = False + mode_packet = None + if isinstance(mode_packet, ModeStatePacket): + active = _hand_worker_active_for_mode(mode_packet) + assert active is False + + mode_packet = ModeStatePacket( + mode="standing", + mocap_active=False, + mocap_paused=False, + timestamp_s=1.0, + seq=1, + ) + if isinstance(mode_packet, ModeStatePacket): + active = _hand_worker_active_for_mode(mode_packet) + assert active is True + + mode_packet = None + if isinstance(mode_packet, ModeStatePacket): + active = _hand_worker_active_for_mode(mode_packet) + assert active is True + + def test_robot_worker_publish_record_step() -> None: worker = object.__new__(_RobotControlWorker) worker.mode = RobotMode.ARMS From db771d6a45c5eee4d58852700703a033b96450f3 Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 25 Jun 2026 20:42:20 +0800 Subject: [PATCH 121/122] docs: align release notes and metadata --- AGENTS.md | 12 +++++----- CHANGELOG.md | 9 +++++--- README.md | 23 +++++-------------- docs/docs/getting-started/installation.md | 2 +- docs/docs/intro.md | 2 +- docs/docs/reference/architecture.md | 2 +- .../current/getting-started/installation.md | 2 +- .../current/intro.md | 2 +- .../current/reference/architecture.md | 2 +- pyproject.toml | 2 +- 10 files changed, 25 insertions(+), 33 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 8471e293..ffbeb26c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -53,7 +53,7 @@ teleopit/ # Core inference package ├── robots/ │ └── mujoco_robot.py # MuJoCoRobot — MuJoCo sim wrapper ├── sim/ -│ └── loop.py # SimulationLoop — PD control at 1000Hz, policy at 50Hz +│ └── loop.py # SimulationLoop — PD control at 200Hz, policy at 50Hz ├── sim2real/ │ 
├── mp/ # Process-isolated sim2real runtime and IPC │ └── hands/ # Optional LinkerHand driver/mapper plugins @@ -62,8 +62,8 @@ scripts/ ├── run/run_sim.py # Offline sim2sim pipeline ├── run/run_sim2real.py # G1 sim2real control; supports offline BVH playback and Pico4 ├── run/record_pico_motion.py # Interactive Pico recording → G1 motion NPZ clips -├── render_sim.py # Render single BVH → 3 MuJoCo videos (mocap input, retarget, sim2sim) -└── compute_ik_offsets.py # Compute IK quaternion offsets for new BVH formats +├── render/render_sim.py # Render single BVH → 3 MuJoCo videos (mocap input, retarget, sim2sim) +└── dev/compute_ik_offsets.py # Compute IK quaternion offsets for new BVH formats train_mimic/ # Training package ├── app.py # Shared app helpers for train/play/benchmark ├── tasks/tracking/config/ @@ -85,7 +85,7 @@ train_mimic/ # Training package ## Key Technical Details ### Sim2Sim Pipeline -- Policy runs at 50Hz, PD control at 1000Hz (`decimation=20`, `sim_dt=0.001`) +- Policy runs at 50Hz, PD control at 200Hz (`decimation=4`, `sim_dt=0.005`) - Action flow: `compute_action()` returns raw action → `get_target_dof_pos()` applies clip `[-10, 10]`, scale, and `default_dof_pos` - Must use `assets/robots/unitree_g1/g1_29dof.xml` for training, sim2sim, dataset FK, and retargeting; it is the canonical G1 XML entry point @@ -242,7 +242,7 @@ python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_trac - `teleopit/retargeting/gmr/assets/` is gitignored; downloaded at runtime - `train_mimic/assets/` is no longer tracked; FK tooling reuses `assets/robots/unitree_g1/g1_29dof.xml` - `third_party/linkerhand-python-sdk` and `third_party/somehand` support optional LinkerHand sim2real control -- Run `python scripts/check_large_tracked_files.py` before pushing +- Run `python scripts/dev/check_large_tracked_files.py` before pushing Assets are split across two ModelScope repos by type: @@ -284,7 +284,7 @@ R_offset = R_human_tpose^{-1} * R_robot_tpose Critical 
note: align robot root orientation to the BVH human forward direction before computing `R_robot_tpose`. For `hc_mocap`, G1 default faces `+X` while the BVH human faces `-Y` (`Z-up`), so the robot root must receive a `-90°` Z rotation first. -`scripts/compute_ik_offsets.py` can print or write calibrated offsets. +`scripts/dev/compute_ik_offsets.py` can print or write calibrated offsets. ## Development diff --git a/CHANGELOG.md b/CHANGELOG.md index bbf3975d..fc4daeec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,10 @@ ## [Unreleased] -- 支持 pico-bridge 0.2.1,并适配其修正后的 tracking pose 语义。 +- 改进 Pico 实时控制:支持 pico-bridge 0.2.1、`ARMS` 模式,以及保留 retargeter warm-start 的模式切换/暂停恢复。 +- 新增可选 LinkerHand L6/O6 sim2real 控制,支持 Pico gripper 输入和低延迟 L6 `vr_hand_pose`。 +- 新增 Pico sim2real 手动 HDF5 录制,以及用于训练数据采集的交互式 Pico motion recorder。 +- 优化训练数据流程:minimal HDF5 shards、显式 precompute、rewind 采样和更新后的 tracking rewards。 ## [0.3.0] - 2026-05-12 @@ -17,12 +20,12 @@ - 新增独立 Standing 控制器、离线播放键盘控制与 Pico sim2sim 模式控制。 - 优化实时 mocap 缓冲与 catch-up,并将发布模型升级至 30k checkpoint。 -## [0.1.1] - 2025-03-28 +## [0.1.1] - 2026-03-28 - 数据集改为 shard-only 输出。 - 引入外部资源管理并瘦身仓库。 -## [0.1.0] - 2025-03-25 +## [0.1.0] - 2026-03-25 - 首个公开版本。 - 支持 General-Tracking-G1 全身追踪训练与 ONNX sim2sim 推理。 diff --git a/README.md b/README.md index 062a83e9..79eea955 100644 --- a/README.md +++ b/README.md @@ -116,21 +116,10 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ### Unreleased -- Added Pico sim2real `ARMS` mode: Pico/controller `B` toggles between whole-body `MOCAP` and stand-pose body/legs with live retargeted arms. -- Bumped Pico input support to pico-bridge 0.2.1 and its corrected tracking pose semantics. -- Added optional LinkerHand L6 sim2real modes under `hands.*`: `gripper` from Pico grip/trigger and `vr_hand_pose` from Pico hand pose through somehand 0.2.0 public API. 
-- Added manual Pico sim2real HDF5 recording with RealSense D435i RGB video, 68D robot state, mode labels, 36D reference-qpos action labels, and 12D LinkerHand pose action labels. -- Added LinkerHand O6 support for Pico `gripper` mode with an O6-specific grasp pose. -- Set LinkerHand L6 `vr_hand_pose` control to maximum speed while keeping `gripper` at the configured default speed. -- Switched default `vr_hand_pose` to a low-latency somehand path with 60 Hz hand retargeting and reduced smoothing. -- LinkerHand sim2real control remains active across all sim2real modes after the runtime mode state is initialized. -- Realtime mode switches and pause/resume now preserve GMR IK warm-starts instead of cold-starting the retargeter on each transition. -- Added an interactive Pico motion recorder that saves retargeted G1 motion clips as training-ready NPZ files. -- Switched dataset build outputs to recursive minimal HDF5 shards with no train/val split or manifest; `precompute_dataset.py` turns them into separate precomputed training datasets before training. -- General-Tracking-G1 training defaults to `rewind` motion sampling and also supports `uniform`; playback/benchmark use `start`. -- Added optional `sampling_mode=rewind` for training, which restarts failed episodes from the same clip after rewinding a configurable number of policy steps. -- Added root velocity, joint tracking, and survival rewards to the General-Tracking-G1 training objective. -- Renamed General-Tracking-G1 observation terms to explicit `ref_*`, `robot_*`, and `prev_action` keys. +- Improved Pico realtime control with pico-bridge 0.2.1, `ARMS` mode, and retargeter-preserving mode/pause resets. +- Added optional LinkerHand L6/O6 sim2real control, including Pico gripper input and low-latency L6 `vr_hand_pose`. +- Added manual Pico sim2real HDF5 recording and an interactive Pico motion recorder for training NPZ clips. 
+- Refined the training data path with minimal HDF5 shards, explicit precompute, rewind sampling, and updated tracking rewards. ### v0.3.0 (2026-05-12) @@ -145,12 +134,12 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added offline playback keyboard controls, Pico sim2sim mode control, and a standalone standing controller. - Improved realtime mocap buffering/catch-up and upgraded the released model to the 30k checkpoint. -### v0.1.1 (2025-03-28) +### v0.1.1 (2026-03-28) - Dataset shard-only refactor - External asset management (ModelScope), repository slimming -### v0.1.0 (2025-03-25) +### v0.1.0 (2026-03-25) - Initial public release: General-Tracking-G1 training, ONNX sim2sim inference, Pico 4 VR teleoperation, Unitree G1 hardware deployment diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 54ab7d79..ed9daf6c 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -40,7 +40,7 @@ Adds `rsl-rl-lib`, `mjlab`, `wandb`, `swanlab`, and training dependencies. pip install -e '.[sim2real]' ``` -Adds `opencv-python` and `g1_bridge_sdk`. You also need to initialize submodules and build the C++ bridge: +Adds `opencv-python`. 
You also need to initialize submodules and build/install the C++ `g1_bridge_sdk` bridge: ```bash git submodule update --init --recursive diff --git a/docs/docs/intro.md b/docs/docs/intro.md index 9e8cb8be..8555ed00 100644 --- a/docs/docs/intro.md +++ b/docs/docs/intro.md @@ -32,7 +32,7 @@ InputProvider (BVH / Pico4 VR) | Spec | Value | |------|-------| | Policy frequency | 50 Hz | -| PD control frequency | 1000 Hz | +| PD control frequency | 200 Hz | | Observation dimension | 167D | | Action dimension | 29D (G1 joints) | | ONNX model | Dual-input TemporalCNN | diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index 8d5bfddc..f979d687 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -76,4 +76,4 @@ train_mimic/scripts/data **Stable training entry points:** `train.py`, `play.py`, `benchmark.py`, `save_onnx.py` -**Stable data entry points:** `build_dataset.py` +**Stable data entry points:** `build_dataset.py`, `precompute_dataset.py` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 86ca2e09..4f3028c0 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -40,7 +40,7 @@ pip install -e '.[train]' pip install -e '.[sim2real]' ``` -额外安装 `opencv-python` 和 `g1_bridge_sdk`。此外还需要初始化子模块并编译 C++ 桥接库: +额外安装 `opencv-python`。此外还需要初始化子模块并编译/安装 C++ `g1_bridge_sdk` 桥接库: ```bash git submodule update --init --recursive diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md index 39969812..30fa4e3c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md +++ 
b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md @@ -30,7 +30,7 @@ InputProvider (BVH / Pico4 VR) | 项目 | 参数 | |------|------| | 策略频率 | 50 Hz | -| PD 控制频率 | 1000 Hz | +| PD 控制频率 | 200 Hz | | 观测维度 | 167D | | 动作维度 | 29D(G1 关节) | | ONNX 模型 | 双输入 TemporalCNN | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index 3f868eb1..5baef6ff 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -76,4 +76,4 @@ train_mimic/scripts/data **稳定训练入口:** `train.py`、`play.py`、`benchmark.py`、`save_onnx.py` -**稳定数据入口:** `build_dataset.py` +**稳定数据入口:** `build_dataset.py`、`precompute_dataset.py` diff --git a/pyproject.toml b/pyproject.toml index 3c1f8dfa..48b178d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ ] readme = "README.md" requires-python = ">=3.10" -license = {text = "MIT"} +license = {text = "Apache-2.0"} dependencies = [ "mujoco", "mink", From ec406328bbdf472df9e02fedf85af895cf466a6a Mon Sep 17 00:00:00 2001 From: Wu Bingqian Date: Thu, 25 Jun 2026 21:27:11 +0800 Subject: [PATCH 122/122] Release v0.4.0 --- CHANGELOG.md | 2 +- README.md | 2 +- pyproject.toml | 2 +- teleopit/__init__.py | 2 +- train_mimic/__init__.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc4daeec..e72a1b9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## [Unreleased] +## [0.4.0] - 2026-06-25 - 改进 Pico 实时控制:支持 pico-bridge 0.2.1、`ARMS` 模式,以及保留 retargeter warm-start 的模式切换/暂停恢复。 - 新增可选 LinkerHand L6/O6 sim2real 控制,支持 Pico gripper 输入和低延迟 L6 `vr_hand_pose`。 diff --git a/README.md b/README.md index 79eea955..94d4f117 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ Full docs at 
**[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te ## Changelog -### Unreleased +### v0.4.0 (2026-06-25) - Improved Pico realtime control with pico-bridge 0.2.1, `ARMS` mode, and retargeter-preserving mode/pause resets. - Added optional LinkerHand L6/O6 sim2real control, including Pico gripper input and low-latency L6 `vr_hand_pose`. diff --git a/pyproject.toml b/pyproject.toml index 48b178d5..4bb6ce39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "teleopit" -version = "0.3.0" +version = "0.4.0" description = "Teleoperation framework for humanoid robots with motion retargeting" authors = [ {name = "Teleopit Team"} diff --git a/teleopit/__init__.py b/teleopit/__init__.py index 3dc1f76b..6a9beea8 100644 --- a/teleopit/__init__.py +++ b/teleopit/__init__.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.4.0" diff --git a/train_mimic/__init__.py b/train_mimic/__init__.py index fc8aa849..4796fe89 100644 --- a/train_mimic/__init__.py +++ b/train_mimic/__init__.py @@ -1,6 +1,6 @@ -__version__ = "0.1.0" +__version__ = "0.4.0" import os # TRAIN_MIMIC_ROOT_DIR points to train_mimic/ directory -TRAIN_MIMIC_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) \ No newline at end of file +TRAIN_MIMIC_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))