diff --git a/.gitignore b/.gitignore index 00126fc3..33034add 100644 --- a/.gitignore +++ b/.gitignore @@ -91,7 +91,10 @@ Thumbs.db # GMR retargeting assets (downloaded from ModelScope) teleopit/retargeting/gmr/assets/ -# Legacy training-side robot assets are no longer tracked; FK tooling reuses GMR unitree_g1. +# Canonical robot XML/meshes (downloaded from ModelScope) +assets/robots/ + +# Legacy training-side robot assets are no longer tracked. train_mimic/assets/ # ModelScope download cache diff --git a/.gitmodules b/.gitmodules index a5085287..1b97b5ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,9 @@ [submodule "third_party/unitree_sdk2_python"] path = third_party/unitree_sdk2_python url = https://github.com/unitreerobotics/unitree_sdk2_python.git +[submodule "third_party/linkerhand-python-sdk"] + path = third_party/linkerhand-python-sdk + url = https://github.com/BotRunner64/linkerhand-python-sdk.git +[submodule "third_party/somehand"] + path = third_party/somehand + url = https://github.com/BotRunner64/somehand.git diff --git a/AGENTS.md b/AGENTS.md index 8d62dc4f..ffbeb26c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,7 +11,7 @@ Config: Hydra/OmegaConf YAML files in `teleopit/configs/` ## Architecture ``` -InputProvider (BVH file / Pico4 VR) → Retargeter (GMR) → ObservationBuilder (166D) → Controller (dual-input TemporalCNN ONNX) → Robot (MuJoCo + PD / Unitree SDK) +InputProvider (BVH file / Pico4 VR) → Retargeter (GMR) → ObservationBuilder (167D) → Controller (dual-input TemporalCNN ONNX) → Robot (MuJoCo + PD / Unitree SDK) ``` Module-internal isolation: all modules run in-process and communicate via `InProcessBus` (zero-copy). Core interfaces are defined as `typing.Protocol` in `teleopit/interfaces.py`. @@ -19,8 +19,8 @@ Module-internal isolation: all modules run in-process and communicate via `InPro ## Supported Surface - Training task: `General-Tracking-G1` -- Inference observation: `velcmd_history` (166D, dual-input ONNX with `obs` + `obs_history`) -- TemporalCNN actor/critic with larger dims (1024,512,256,256,128) +- Inference observation: `velcmd_history` (167D, dual-input ONNX with `obs` + `obs_history`) +- TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) - Realtime inference uses a retargeted-reference timeline before observation build; `reference_steps=[0]` is the default production path ## Directory Structure @@ -53,13 +53,17 @@ teleopit/ # Core inference package ├── robots/ │ └── mujoco_robot.py # MuJoCoRobot — MuJoCo sim wrapper ├── sim/ -│ └── loop.py # SimulationLoop — PD control at 1000Hz, policy at 50Hz -└── recording/ # HDF5Recorder +│ └── loop.py # SimulationLoop — PD control at 200Hz, policy at 50Hz +├── sim2real/ +│ ├── mp/ # Process-isolated sim2real runtime and IPC +│ └── hands/ # Optional LinkerHand driver/mapper plugins +└── recording/ # Pico motion NPZ recording helpers scripts/ -├── run_sim.py # Offline sim2sim pipeline -├── run_sim2real.py # G1 sim2real control; supports offline BVH playback and Pico4 -├── render_sim.py # Render single BVH → 3 MuJoCo videos (mocap input, retarget, sim2sim) -└── compute_ik_offsets.py # Compute IK quaternion offsets for new BVH formats +├── run/run_sim.py # Offline sim2sim pipeline +├── run/run_sim2real.py # G1 sim2real control; supports offline BVH playback and Pico4 +├── run/record_pico_motion.py # Interactive Pico recording → G1 motion NPZ clips +├── render/render_sim.py # Render single BVH → 3 MuJoCo videos (mocap input, retarget, sim2sim) +└── dev/compute_ik_offsets.py # Compute IK quaternion offsets for new BVH formats train_mimic/ # Training package ├── app.py # Shared app helpers for train/play/benchmark ├── tasks/tracking/config/ @@ -68,7 +72,7 @@ train_mimic/ # Training package │ ├── env.py # General-Tracking-G1 env builder │ └── rl.py # TemporalCNN PPO cfg ├── tasks/tracking/rl/ -│ ├── runner.py # ONNX export wrapper for policy + motion labels +│ ├── runner.py # Training runner and policy ONNX export wrapper │ ├── conv1d_encoder.py # 1-D CNN encoder for temporal history groups │ └── temporal_cnn_model.py # TemporalCNN actor/critic model └── scripts/ @@ -81,9 +85,9 @@ train_mimic/ # Training package ## Key Technical Details ### Sim2Sim Pipeline -- Policy runs at 50Hz, PD control at 1000Hz (`decimation=20`, `sim_dt=0.001`) +- Policy runs at 50Hz, PD control at 200Hz (`decimation=4`, `sim_dt=0.005`) - Action flow: `compute_action()` returns raw action → `get_target_dof_pos()` applies clip `[-10, 10]`, scale, and `default_dof_pos` -- Must use `g1_mjlab.xml` for sim2sim; `g1_mocap_29dof.xml` clamps torques to `±1 Nm` and is only for kinematic retarget visualization +- Must use `assets/robots/unitree_g1/g1_29dof.xml` for training, sim2sim, dataset FK, and retargeting; it is the canonical G1 XML entry point ### Multi-Viewer Support `SimulationLoop` supports multiple simultaneous viewer windows controlled by the `viewers` config: @@ -105,6 +109,7 @@ python scripts/run/run_sim.py controller.policy_path=policy.onnx viewers=none - `viewers=all` opens `mocap`, `retarget`, and `sim2sim`; add `camera` explicitly when needed - All viewers run in separate subprocesses because GLFW/GLX only supports one window per process - Simulation exits when all active viewer windows are closed +- sim2real defaults to `viewers=none`; it supports only optional `viewers=retarget` - `viewers` is the only supported viewer key; legacy `viewer` alias is removed ### default_dof_pos Propagation @@ -119,7 +124,7 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos ### Offline Playback - Offline sim2sim and default sim2real both read `input.bvh_file` directly; no UDP relay path remains - Offline sim2sim playback can be keyboard-controlled: `Space/P` pause/resume, `R` replay from frame 0, `Q` stop -- Pause/resume now includes a short hold window; users should stay still during resume and pause again if visible distortion appears +- Offline pause holds the commanded pose; resume resets policy/reference state and reanchors yaw/XY without qpos interpolation or retargeter IK reset - sim2sim keyboard playback is optional via `playback.keyboard.enabled=true` - sim2real reuses the Unitree remote: `Start` → `STANDING`, `Y` → playback, `X` → back to `STANDING`, `L1+R1` → `DAMPING` - `playback.pause_on_end=true` keeps the final pose and waits for manual replay @@ -127,42 +132,61 @@ target_dof_pos = clip(action, -10, 10) × action_scale + default_dof_pos ### Pico4 Realtime Input - `Pico4InputProvider` reads realtime body tracking from the in-process `pico_bridge.PicoBridge` - The pico-bridge receiver runs on the Teleopit host, which can be a workstation PC or robot onboard computer; do not maintain a separate onboard Pico input mode -- pico-bridge 0.2.0 is the supported runtime; camera preview uses `PicoBridge(video="frames").push_video_frame(rgb_uint8)` +- pico-bridge 0.2.1 is the supported runtime; camera preview uses `PicoBridge(video="frames").push_video_frame(rgb_uint8)` - Pico video preview is optional and disabled by default; sim2sim uses the MuJoCo `d435i_rgb` camera and sim2real uses RealSense when `input.video.enabled=true` - Bone naming follows `pico_bridge_to_g1.json` - The provider applies an input-space transform to match the current retarget config - Do not hardcode that transform as a public coordinate-system contract; validate against actual retarget/sim2sim behavior when SDK or firmware changes - Pico4 realtime control uses the same retargeted-reference timeline path as the shared realtime input stack -- Pico sim2sim supports a keyboard-driven top-level mode state machine: `STANDING → MOCAP → STANDING` -- Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `X` → back to `STANDING`, `Q` → quit +- Pico sim2sim supports a keyboard-driven top-level mode state machine: `STANDING → MOCAP ↔ ARMS`, `X` returns to `STANDING` +- Default Pico sim2sim keyboard mappings are `Y` → `MOCAP`, `A` → pause/resume mocap, `B` → toggle `MOCAP`/`ARMS`, `X` → back to `STANDING`, `Q` → quit - Pico4 sim2real pause/resume is handled as a mocap-session control event (`toggle_pause`), not as a mode switch to `STANDING` -- Default Pico pause button is `A`; restore tracking by rebuilding the realtime buffer and yaw/XY root-offset alignment, then accept the current live mocap retarget reference directly +- Default Pico pause button is `A`; resume resets policy/reference state and yaw/XY root-offset alignment while the process-isolated realtime reference worker continues its live input timeline +- Pico4 sim2sim/sim2real support `ARMS` mode toggled from `MOCAP` with Pico/controller `B`; retargeting continues, while the control loop sends the motion tracker a composed reference with stand-pose body/legs/waist and live retargeted arms +- `ARMS` entering/exiting/resume resets policy/reference alignment and uses Kp ramp; offline BVH sim2real does not use `ARMS`, and Unitree remote `B` remains BVH replay +- Realtime mode switches and pause/resume use a retargeter-preserving soft reset: policy/reference state, smoothers, and reference alignment are reset, while the GMR IK warm-start is retained +- Optional LinkerHand control uses `hands.enabled=true`, `hands.driver=linkerhand_l6|linkerhand_o6`, and `hands.mode=gripper|vr_hand_pose`; default is disabled +- Optional Pico sim2real HDF5 recording uses `--config-name sim2real_record` or `recording.enabled=true`; it requires `input.provider=pico4`, `input.video.enabled=true`, `input.video.source=realsense`, an interactive terminal, and the `recording` extra +- Recording is manual only: terminal `R` starts an episode, `S` saves, `D` discards the active episode, and `Q` shuts down; `STANDING`, `MOCAP`, `ARMS`, and paused mocap are recordable +- Recording captures `observation.images.d435i_rgb` RealSense RGB video at 30Hz plus `observation.state(68)`, `observation.mode(1)`, `action(36)`, and `action.hand(12)`; RealSense capture lives in `pico_input` through the normal `input.video` path +- HDF5 recording writes compressed MP4 sidecar videos under `recording.output_dir/videos//` while HDF5 episodes store `frame_index`, `timestamp`, low-dimensional data, and video sync attributes; raw RGB image datasets are not supported +- `gripper` mode reuses `Pico4InputProvider.get_controller_snapshot()` for Pico grip/trigger open-close control and supports LinkerHand L6 and O6 +- `vr_hand_pose` mode reuses `Pico4InputProvider.get_hand_snapshot()` and somehand 0.2.0 public `somehand.api` for continuous Pico hand-pose retargeting; do not start a second `PicoBridge` for hand control +- Teleopit owns Pico 26-joint hand-state to 21-landmark conversion; do not import `somehand.pico_input` +- LinkerHand O6 supports only `hands.mode=gripper`; its default `close_pose` is `[86, 73, 118, 111, 110, 111]` +- L6 `gripper` mode uses the configured `hands.linkerhand_l6.speed` (default `[50]*6`); O6 `gripper` mode uses `hands.linkerhand_o6.speed` (default `[255]*6`); `vr_hand_pose` always sets LinkerHand L6 speed to `[255]*6` +- `vr_hand_pose` defaults to a low-latency somehand path: `hands.somehand.rate_hz=60`, `max_iterations=12`, `temporal_filter_alpha=1.0`, and `output_alpha=1.0`; this prioritizes response speed over smoothing +- LinkerHand control is active in all sim2real modes when `hands.enabled=true`; shutdown and hand-runtime failure must send the configured open pose +- In `vr_hand_pose` mode, missing/inactive hand pose holds the last commanded pose for that side instead of opening the hand ### SimulationLoop Runtime Behavior - `realtime=true` enforces wall-clock pacing even without a viewer - `num_steps=0` means infinite loop (`max_steps = 2**63`) - `KeyboardInterrupt` is handled for clean shutdown - BVH frame alignment is time-based: `bvh_idx = int(policy_time × input_fps)` -- Realtime reference buffering is controlled by `retarget_buffer_enabled`, `retarget_buffer_window_s`, `retarget_buffer_delay_s`, `reference_steps`, `realtime_buffer_warmup_steps`, and the low/high watermark knobs +- Realtime reference buffering is controlled by `retarget_buffer_enabled`, `retarget_buffer_window_s`, `retarget_buffer_delay_s`, `reference_steps`, and `realtime_buffer_warmup_steps` - Realtime inferred `motion_joint_vel`, anchor linear velocity, and anchor angular velocity can be EMA-smoothed via `reference_velocity_smoothing_alpha` and `reference_anchor_velocity_smoothing_alpha` -- Sim2real Pico pause/resume uses mocap-session states `ACTIVE ↔ PAUSED`; resume clears policy/reference state, warms the realtime buffer, rebuilds yaw/XY root alignment, and does not interpolate retarget qpos from the paused pose -- Realtime sim2sim with Pico control events uses the same mocap-session pause/resume semantics and rebuilds the realtime reference path on resume +- Sim2real Pico pause/resume uses mocap-session states `ACTIVE ↔ PAUSED`; resume clears policy/reference state, rebuilds yaw/XY root alignment, and does not interpolate retarget qpos from the paused pose +- Realtime sim2sim with Pico control events uses the same mocap-session pause/resume semantics and rebuilds the realtime reference path on resume, including the configured warmup +- Realtime sim2sim/sim2real `STANDING ↔ MOCAP` transitions use the same retargeter-preserving soft reset, rather than cold-starting the retargeter from its default qpos - Realtime Pico sim2sim can start directly in `STANDING` with keyboard mode control enabled via top-level `keyboard.enabled` ### Inference Observation -Observation format: `velcmd_history` (166D, dual-input ONNX) +Observation format: `velcmd_history` (167D, dual-input ONNX) ``` -command(58) -+ motion_anchor_ori_b(6) -+ base_ang_vel(3) -+ joint_pos_rel(29) -+ joint_vel(29) -+ last_action(29) -+ projected_gravity(3) -+ ref_base_lin_vel_b(3) -+ ref_base_ang_vel_b(3) +ref_joint_pos(29) ++ ref_joint_vel(29) ++ ref_anchor_ori_b(6) ++ robot_base_ang_vel_b(3) ++ robot_joint_pos_rel(29) ++ robot_joint_vel(29) ++ prev_action(29) ++ robot_projected_gravity_b(3) ++ ref_anchor_lin_vel_b(3) ++ ref_anchor_ang_vel_b(3) + ref_projected_gravity_b(3) ++ ref_anchor_height(1) ``` Runtime constraints: @@ -173,38 +197,52 @@ Runtime constraints: ### Training Task The single supported training task is `General-Tracking-G1` (experiment name: `g1_general_tracking`). -- Uses TemporalCNN actor/critic with larger dims (1024,512,256,256,128) -- 166D `velcmd_history` observation, dual-input ONNX export -- Training env uses `sampling_mode="uniform"` +- Uses TemporalCNN actor/critic with scaled dims (2048,1024,512,256,128) +- 167D `velcmd_history` observation, dual-input ONNX export +- Training env uses `sampling_mode="rewind"` +- Tracking rewards include root position/orientation/linear velocity/angular velocity, body pose/velocity, joint position/velocity, survival, action-rate, joint-limit, self-collision, and ankle acceleration terms +- Supported motion sampling modes are `uniform`, `start`, and `rewind`; `rewind` restarts failed environments from the same clip after stepping back `rewind_min_steps..rewind_max_steps` with probability `rewind_prob`, otherwise it falls back to uniform sampling - Playback/benchmark use `play=True`, which switches motion sampling to `start` - `window_steps=[0]` - `save_onnx.py` exports dual-input TemporalCNN ONNX ### Dataset Pipeline - Dataset build spec supports a `preprocess` section for root-xy normalization, ground alignment, and basic clip filtering -- Final dataset outputs are shard-only: `data/datasets//train/shard_*.npz` and `data/datasets//val/shard_*.npz` -- Each shard stores clip-aware metadata (`clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`); `MotionLib` loads only shard directories +- Final distributed dataset build outputs are minimal HDF5 shards directly under `data/datasets//` (recursive shard discovery is supported; no train/val split and no manifest file) +- `train_mimic/scripts/data/precompute_dataset.py` converts a minimal dataset into a separate precomputed training dataset directory; `build_dataset.py` must not run precompute +- Each shard stores only `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, and clip-aware window metadata (`clip_starts`, `clip_lengths`, `clip_fps`); long clips are split into overlapping bounded windows +- Training `motion_file` must point to a precomputed training dataset, not the minimal distributed dataset; training reads joint velocities and body FK/velocities from those precomputed shards and must not run MuJoCo FK while loading motion clips +- `MotionLib` loads all discovered precomputed HDF5 motion windows into CPU/GPU memory at startup - `MotionLib` samples only valid center frames for the configured `window_steps`; default is `window_steps=[0]` +- Training supports `uniform` and `rewind` sampling over the fully loaded precomputed dataset +- `scripts/run/record_pico_motion.py` records Pico live body tracking as retargeted G1 motion NPZ clips in `data/pico_motion/clips/`; it opens a live `Retarget` viewer, uses terminal keys `R/S/D/N/Q`, stores semantic labels in filenames, and intentionally does not write per-clip JSON +- Build Pico-recorded clips into shards with `python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force` Quick reference: ```bash -python train_mimic/scripts/data/build_dataset.py --spec train_mimic/configs/datasets/twist2_full.yaml -python train_mimic/scripts/train.py --motion_file data/datasets/twist2_full/train +python train_mimic/scripts/data/build_dataset.py --spec train_mimic/configs/datasets/twist2.yaml +python scripts/run/record_pico_motion.py +python train_mimic/scripts/data/build_dataset.py --spec data/pico_motion/pico_recorded.yaml --force +python train_mimic/scripts/data/precompute_dataset.py data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +python train_mimic/scripts/train.py --motion_file data/datasets/seed_precomputed +python train_mimic/scripts/data/precompute_dataset.py data/datasets/twist2 --outdir data/datasets/twist2_precomputed --jobs 8 --force python train_mimic/scripts/save_onnx.py --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt --output policy.onnx --history_length 10 ``` ### GMR Retargeting -- Self-contained in `teleopit/retargeting/gmr/`; assets need `scripts/setup/download_assets.py --only gmr` +- Self-contained in `teleopit/retargeting/gmr/`; assets need `scripts/setup/download_assets.py --only robots gmr` - Supports `lafan1` BVH (22 joints, 30fps, centimeters) - Supports `hc_mocap` BVH (50 joints, 60fps downsampled to 30fps, meters) - `lafan1-resolved` still needs an adapter layer and remains unsupported ### External Assets - Do not commit robot meshes, datasets, checkpoints, or demo media to Git; use `scripts/setup/download_assets.py` +- `assets/robots/unitree_g1/g1_29dof.xml` and its meshes are the canonical G1 robot model assets; they are downloaded from the `robots` asset group and are not tracked in Git - `teleopit/retargeting/gmr/assets/` is gitignored; downloaded at runtime -- `train_mimic/assets/` is no longer tracked; FK tooling reuses `teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml` -- Run `python scripts/check_large_tracked_files.py` before pushing +- `train_mimic/assets/` is no longer tracked; FK tooling reuses `assets/robots/unitree_g1/g1_29dof.xml` +- `third_party/linkerhand-python-sdk` and `third_party/somehand` support optional LinkerHand sim2real control +- Run `python scripts/dev/check_large_tracked_files.py` before pushing Assets are split across two ModelScope repos by type: @@ -219,7 +257,7 @@ Asset group → repo mapping is defined in `teleopit/runtime/external_assets.py` ```bash # 1. Prepare upload directory -python scripts/setup/prepare_modelscope_assets.py --only ckpt gmr bvh --clean +python scripts/setup/prepare_modelscope_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/prepare_modelscope_assets.py --only data # 2. Upload to each repo @@ -246,7 +284,7 @@ R_offset = R_human_tpose^{-1} * R_robot_tpose Critical note: align robot root orientation to the BVH human forward direction before computing `R_robot_tpose`. For `hc_mocap`, G1 default faces `+X` while the BVH human faces `-Y` (`Z-up`), so the robot root must receive a `-90°` Z rotation first. -`scripts/compute_ik_offsets.py` can print or write calibrated offsets. +`scripts/dev/compute_ik_offsets.py` can print or write calibrated offsets. ## Development @@ -270,4 +308,4 @@ pytest tests/ -v ## Known Issues 1. `lafan1-resolved` retargeting is still broken because it uses a different BVH skeleton layout. -2. `g1_mocap_29dof.xml` still has `ctrlrange="-1 1"`; never use it for sim2sim. +2. Legacy downloaded GMR XMLs under `teleopit/retargeting/gmr/assets/unitree_g1/` are not the project entry point; use `assets/robots/unitree_g1/g1_29dof.xml`. diff --git a/CHANGELOG.md b/CHANGELOG.md index 56573946..e72a1b9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.4.0] - 2026-06-25 + +- 改进 Pico 实时控制:支持 pico-bridge 0.2.1、`ARMS` 模式,以及保留 retargeter warm-start 的模式切换/暂停恢复。 +- 新增可选 LinkerHand L6/O6 sim2real 控制,支持 Pico gripper 输入和低延迟 L6 `vr_hand_pose`。 +- 新增 Pico sim2real 手动 HDF5 录制,以及用于训练数据采集的交互式 Pico motion recorder。 +- 优化训练数据流程:minimal HDF5 shards、显式 precompute、rewind 采样和更新后的 tracking rewards。 + ## [0.3.0] - 2026-05-12 - 重构实时输入栈,Pico 4 统一使用 pico-bridge 0.2.0 in-process receiver,并移除旧 ZMQ/onboard Pico 路径。 @@ -13,13 +20,12 @@ - 新增独立 Standing 控制器、离线播放键盘控制与 Pico sim2sim 模式控制。 - 优化实时 mocap 缓冲与 catch-up,并将发布模型升级至 30k checkpoint。 -## [0.1.1] - 2025-03-28 +## [0.1.1] - 2026-03-28 - 数据集改为 shard-only 输出。 -- 新增 adaptive_bin 采样。 - 引入外部资源管理并瘦身仓库。 -## [0.1.0] - 2025-03-25 +## [0.1.0] - 2026-03-25 - 首个公开版本。 - 支持 General-Tracking-G1 全身追踪训练与 ONNX sim2sim 推理。 diff --git a/README.md b/README.md index baba966a..94d4f117 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,13 @@ pip install -e . ```bash pip install modelscope -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` +The canonical Unitree G1 robot model is downloaded to +`assets/robots/unitree_g1/g1_29dof.xml`. Training, sim2sim, retargeting, and FK +validation all use this same XML. + **3. Run** ```bash @@ -54,12 +58,69 @@ python scripts/run/run_sim.py \ 'viewers=[sim2sim,camera]' ``` +For sim2real, viewers are disabled by default. Add `viewers=retarget` to show +the retargeted reference in an optional MuJoCo window. + +## Pico Motion Recording + +Record many Pico clips as training-ready G1 motion NPZ files: + +```bash +pip install -e '.[pico4]' +python scripts/run/record_pico_motion.py +``` + +The recorder starts the Pico receiver and live Retarget viewer before waiting +for clip names, so preview keeps running while the terminal is idle. Enter a +semantic clip name, then use `R` to start, `S` to save, `D` to discard, `N` for +a new name, and `Q` to quit. Saved clips are written to +`data/pico_motion/clips/` using the semantic label in the filename, with no +sidecar JSON. + +Merge recorded clips into the standard HDF5 shard dataset: + +```bash +python train_mimic/scripts/data/build_dataset.py \ + --spec data/pico_motion/pico_recorded.yaml --force +``` + +## Sim2Real HDF5 Recording + +Pico sim2real can also record manual HDF5 episodes from the real G1: + +```bash +pip install -e '.[recording]' +# If you use RealSense video, install pyrealsense2 manually for your platform. +# On Arm machines, prefer conda-forge: +# conda install -c conda-forge pyrealsense2 +python scripts/run/run_sim2real.py --config-name sim2real_record \ + controller.policy_path=track.onnx \ + recording.task="walk forward" +``` + +Recording uses the terminal controls `R` start, `S` save, `D` discard, and `Q` +shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded. Saved +episodes are written as `.h5` files under `data/recordings/sim2real_hdf5/episodes/`. +`sim2real_record.yaml` stores camera frames as compressed MP4 sidecar files under +`data/recordings/sim2real_hdf5/videos/` and keeps `frame_index` / `timestamp` +sync metadata in the HDF5 episode. The low-dimensional HDF5 schema records +`observation.state(68)`, `observation.mode(1)`, `action(36)` as the aligned +reference qpos sent to the policy path, and `action.hand(12)` as the latest +LinkerHand left/right 6D pose commands. + ## Documentation Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Teleopit/)**, covering installation profiles, all tutorials, configuration reference, and architecture. ## Changelog +### v0.4.0 (2026-06-25) + +- Improved Pico realtime control with pico-bridge 0.2.1, `ARMS` mode, and retargeter-preserving mode/pause resets. +- Added optional LinkerHand L6/O6 sim2real control, including Pico gripper input and low-latency L6 `vr_hand_pose`. +- Added manual Pico sim2real HDF5 recording and an interactive Pico motion recorder for training NPZ clips. +- Refined the training data path with minimal HDF5 shards, explicit precompute, rewind sampling, and updated tracking rewards. + ### v0.3.0 (2026-05-12) - Consolidated realtime input around pico-bridge 0.2.0 and removed the old ZMQ/onboard Pico path. @@ -73,12 +134,12 @@ Full docs at **[BotRunner64.github.io/Teleopit](https://BotRunner64.github.io/Te - Added offline playback keyboard controls, Pico sim2sim mode control, and a standalone standing controller. - Improved realtime mocap buffering/catch-up and upgraded the released model to the 30k checkpoint. -### v0.1.1 (2025-03-28) +### v0.1.1 (2026-03-28) -- Dataset shard-only refactor and `adaptive_bin` sampling +- Dataset shard-only refactor - External asset management (ModelScope), repository slimming -### v0.1.0 (2025-03-25) +### v0.1.0 (2026-03-25) - Initial public release: General-Tracking-G1 training, ONNX sim2sim inference, Pico 4 VR teleoperation, Unitree G1 hardware deployment diff --git a/docs/docs/configuration/config-reference.md b/docs/docs/configuration/config-reference.md index 96b86f21..f6bf965f 100644 --- a/docs/docs/configuration/config-reference.md +++ b/docs/docs/configuration/config-reference.md @@ -15,7 +15,6 @@ Complete reference for all configurable fields. | `viewers` | Viewer set: `mocap`, `retarget`, `sim2sim`, `camera`, `all`, `none`. `all` opens `mocap`, `retarget`, and `sim2sim`; add `camera` explicitly. | `sim2sim` | | `realtime` | Rate-limit to wall clock | `false` | | `num_steps` | Number of steps; `0` = infinite | `0` | -| `transition_duration` | Smooth transition time (seconds) from current pose to retarget command | - | | `keyboard.enabled` | Enable realtime keyboard mode control for sim2sim | `false` | | `playback.pause_on_end` | Pause at last frame when offline motion ends | `false` | | `playback.keyboard.enabled` | Enable keyboard control for offline playback | `false` | @@ -63,13 +62,15 @@ Complete reference for all configurable fields. | `input.pico4_buffer_size` | Frame buffer size | `60` | | `input.pause_button` | Button for pause/resume | `A` | | `input.pause_debounce_s` | Debounce time for pause button | `0.25` | +| `input.arms_button` | Button for Pico `MOCAP` / `ARMS` toggle | `B` | +| `input.arms_debounce_s` | Debounce time for arms-mode button | `0.25` | | `input.bridge_host` | Teleopit host receiver bind host | `0.0.0.0` | | `input.bridge_port` | Teleopit host receiver TCP/UDP port | `63901` | | `input.bridge_discovery` | Enable pico-bridge discovery advertising | `true` | | `input.bridge_advertise_ip` | Optional advertised host IP override | `null` | | `input.bridge_start_timeout` | Timeout while starting the bridge | `10.0` | | `input.bridge_history_size` | Pico frame history retained by the bridge | `120` | -| `input.video.enabled` | Stream host camera preview back to Pico through pico-bridge 0.2.0 | `false` | +| `input.video.enabled` | Stream host camera preview back to Pico through pico-bridge 0.2.1 | `false` | | `input.video.source` | Video source: `mujoco`, `realsense`, or `test-pattern` | `null` | | `input.video.width` / `height` / `fps` | Video capture/render settings | `1280` / `720` / `30` | | `input.video.device` | Optional RealSense serial | `null` | @@ -84,8 +85,6 @@ Complete reference for all configurable fields. | `retarget_buffer_delay_s` | Buffer delay | | `reference_steps` | Reference window steps | | `realtime_buffer_warmup_steps` | Warmup before playback | -| `realtime_buffer_low_watermark_steps` | Low watermark | -| `realtime_buffer_high_watermark_steps` | High watermark | | `reference_velocity_smoothing_alpha` | Velocity smoothing | | `reference_anchor_velocity_smoothing_alpha` | Anchor velocity smoothing | @@ -93,14 +92,18 @@ Complete reference for all configurable fields. Fields used by sim2real configs (`sim2real.yaml`, `pico4_sim2real.yaml`). +Sim2real defaults to `viewers=none`. Set `viewers=retarget` to open an optional +MuJoCo window showing the retargeted reference; `sim2sim`, `mocap`, `camera`, +and `all` are simulation-only viewer modes. + ### Safety | Field | Description | Default | |-------|-------------|---------| -| `startup_ramp_duration` | Seconds to smoothly blend from locked to policy positions | `2.0` | +| `startup_ramp_duration` | Kp ramp duration after entering `STANDING`; gradually increases PD gains without changing policy targets | `2.0` | | `joint_vel_limit` | Joint velocity limit (rad/s); triggers emergency damping if exceeded | `10.0` | | `mocap_switch.check_frames` | Consecutive valid frames required before switching to MOCAP | `10` | -| `mocap_switch.max_position_value` | Position sanity threshold in meters | `5.0` | +| `arm_mocap.controlled_joint_indices` | G1 joints driven by live retargeting in Pico `ARMS` mode | `[15..28]` | ### Real Robot @@ -116,22 +119,94 @@ Fields used by sim2real configs (`sim2real.yaml`, `pico4_sim2real.yaml`). ### Pause/Resume (Pico sim2real) -| Field | Description | Default | -|-------|-------------|---------| -| `pause_resume_transition_duration` | Resume blend duration for offline playback; realtime Pico resume uses live tracking re-centering | `1.0` | -| `pause_resume_warmup_steps` | Realtime mocap frames to collect before tracking resumes | `2` | - Realtime Pico resume re-centers heading and ground-plane position before tracking continues. Operators should keep still and stay as close as practical to the paused pose to reduce sudden reference changes. -### Realtime Catch-up (Pico sim2real) +### Dexterous Hand (Pico sim2real) + +`hands.enabled=true` requires `input.provider=pico4` plus local editable +installs of `third_party/linkerhand-python-sdk` and `third_party/somehand`. +When enabled, hand control remains active in all sim2real modes. +`gripper` supports `linkerhand_l6` and `linkerhand_o6` by interpolating Pico +trigger input between the configured open and close poses. `vr_hand_pose` is +L6-only: missing hand pose holds the last command for that side, L6 speed is +set to the maximum, and Teleopit converts Pico hand state to 21 landmarks before +calling somehand 0.2.0 through `somehand.api` only. | Field | Description | Default | |-------|-------------|---------| -| `realtime_catchup_enabled` | Enable catch-up when buffer grows too large | `true` | -| `realtime_catchup_trigger_steps` | Buffer depth that triggers catch-up | `6` | -| `realtime_catchup_release_steps` | Buffer depth to release catch-up | `3` | -| `realtime_catchup_target_delay_s` | Target delay for catch-up | `0.04` | -| `reference_qpos_smoothing_alpha` | Joint position smoothing (1.0 = no smoothing) | `0.4` | +| `hands.enabled` | Enable optional hand worker | `false` | +| `hands.driver` | Hand driver plugin: `linkerhand_l6` or `linkerhand_o6` | `linkerhand_l6` | +| `hands.mode` | `gripper` or `vr_hand_pose` | `gripper` | +| `hands.sides` | Controlled sides | `[left, right]` | +| `hands.rate_hz` | Maximum gripper command rate in Hz | `30.0` | +| `hands.frame_timeout_s` | Controller or hand-pose staleness threshold | `0.3` | +| `hands.linkerhand_l6.left_can` / `right_can` | CAN channels for each hand | `can0` / `can1` | +| `hands.linkerhand_l6.speed` | L6 speed used by `gripper`; `vr_hand_pose` overrides this to maximum speed | see config | +| `hands.linkerhand_l6.open_pose` / `close_pose` | Six-value L6 open/closed poses | see config | +| `hands.linkerhand_o6.left_can` / `right_can` | CAN channels for each O6 hand | `can0` / `can1` | +| `hands.linkerhand_o6.speed` | O6 speed used by `gripper` | see config | +| `hands.linkerhand_o6.open_pose` / `close_pose` | Six-value O6 open/closed poses | see config | +| `hands.somehand.config_path` | Official somehand 0.2.0 bi-hand L6 config used by `vr_hand_pose` | see config | +| `hands.somehand.rate_hz` | Low-latency `vr_hand_pose` command rate in Hz | `60.0` | +| `hands.somehand.max_iterations` | somehand solver iteration cap for `vr_hand_pose` | `12` | +| `hands.somehand.temporal_filter_alpha` | somehand input landmark smoothing alpha; `1.0` disables smoothing delay | `1.0` | +| `hands.somehand.output_alpha` | somehand qpos output smoothing alpha; `1.0` disables smoothing delay | `1.0` | + +### HDF5 Recording (Pico sim2real) + +`recording.enabled=true` is supported only with `input.provider=pico4`, +`input.video.enabled=true`, `input.video.source=realsense`, and an interactive +terminal. The recorder is manual: `R` starts an episode, `S` saves the active +episode, `D` discards the active episode, and `Q` shuts down. `STANDING`, +`MOCAP`, `ARMS`, and paused mocap can be recorded. + +`sim2real_record.yaml` enables both recording and the required RealSense +`input.video` path. Recording does not open a second camera; it consumes the +same frames produced by `pico_input`. + +| Field | Description | Default | +|-------|-------------|---------| +| `recording.enabled` | Enable manual HDF5 recording | `false` | +| `recording.output_dir` | Dataset root directory | `data/recordings/sim2real_hdf5` | +| `recording.task` | Task string stored with frames | `demo` | +| `recording.fps` | Recording/video clock rate | `30` | +| `recording.min_episode_seconds` | Discard saved episodes shorter than this duration | `1.0` | +| `recording.record_modes` | Modes that allow recording start and frame writes | `[standing, mocap, arms, pause]` | +| `recording.camera.key` | RGB image dataset key | `observation.images.d435i_rgb` | +| `recording.camera.width` / `height` / `fps` | RealSense RGB capture settings | `640` / `480` / `30` | +| `recording.camera.device` | Optional RealSense serial | `null` | +| `recording.video.codec` / `quality` / `pixelformat` | MP4 sidecar encoder settings | `libx264` / `8` / `yuv420p` | + +Camera failure behavior is controlled by `input.video.fail_on_error`. + +Each saved episode has one `.h5` file under `recording.output_dir/episodes/` +and one compressed MP4 sidecar under +`recording.output_dir/videos//`. The HDF5 episode stores +`frame_index` and `timestamp` arrays, plus `video_path`, `video_fps`, and +`video_frames` root attributes for synchronization. Raw RGB image datasets are +not written. + +HDF5 datasets: + +```text +frame_index int64[N] +timestamp float64[N] +observation.state float32[68] +observation.mode float32[1] +action float32[36] +action.hand float32[12] +``` + +The root attributes include the Teleopit HDF5 recording format, schema version, +task, fps, frame count, and video sync metadata. + +`observation.state` is ordered as `joint_pos(29)`, `joint_vel(29)`, +`base_quat_wxyz(4)`, `base_ang_vel(3)`, and `projected_gravity(3)`. +`observation.mode` is a numeric categorical: `standing=0`, `mocap=1`, +`arms=2`, and `pause=3`. `action` is the current reference qpos: +`root_pos(3) + root_quat_wxyz(4) + joint_pos(29)`. +`action.hand` is the latest LinkerHand command from the hand worker: +`left_pose(6) + right_pose(6)`, using the SDK's 0-255 pose values. ## Critical: `default_dof_pos` diff --git a/docs/docs/configuration/faq.md b/docs/docs/configuration/faq.md index 2509abcb..ccc28537 100644 --- a/docs/docs/configuration/faq.md +++ b/docs/docs/configuration/faq.md @@ -7,8 +7,7 @@ sidebar_position: 3 ## Why does it fail even though I set `policy_path`? 1. Verify the file exists -2. Confirm it's not an old 1402D / TWIST2 ONNX model -3. Confirm the input dimension is `166` with dual inputs (`obs` + `obs_history`) +2. Confirm the input dimension is `167` with dual inputs (`obs` + `obs_history`) ## Why must I specify `input.bvh_file` explicitly? diff --git a/docs/docs/contributing.md b/docs/docs/contributing.md index b38bda60..465fc06d 100644 --- a/docs/docs/contributing.md +++ b/docs/docs/contributing.md @@ -34,7 +34,7 @@ teleopit/ # Core inference & deployment package ├── retargeting/ # GMR motion retargeting ├── sim/ # SimulationLoop, reference motion utilities ├── sim2real/ # Hardware state machines -├── recording/ # HDF5Recorder +├── recording/ # Pico motion recording helpers ├── runtime/ # Config parsing, factories, external assets ├── bus/ # InProcessBus for inter-component communication └── configs/ # Hydra YAML configurations diff --git a/docs/docs/getting-started/download-assets.md b/docs/docs/getting-started/download-assets.md index fc8e6182..994d9486 100644 --- a/docs/docs/getting-started/download-assets.md +++ b/docs/docs/getting-started/download-assets.md @@ -20,25 +20,28 @@ python scripts/setup/download_assets.py Download only what you need for inference: ```bash -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## Asset Inventory -| Asset | Size | Purpose | -|-------|------|---------| -| `track.onnx` | 4 MB | ONNX inference model | -| `track.pt` | 27 MB | PyTorch checkpoint (for resume training) | -| `data/datasets/seed/train/shard_*.npz` | ~25 GB | Training dataset | -| `data/datasets/seed/val/shard_*.npz` | ~1.4 GB | Validation dataset | -| `data/sample_bvh/*.bvh` | 5 MB | Sample motion files | -| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR retargeting robot models | +Downloaded file sizes change as checkpoints, datasets, and asset bundles are updated. Use the repository paths below as the stable contract. + +| Local Path | Purpose | +|------------|---------| +| `track.onnx` | ONNX inference model | +| `track.pt` | PyTorch checkpoint for resume training | +| `data/datasets/seed/shard_*.h5` | Minimal motion dataset; run precompute before training | +| `data/sample_bvh/*.bvh` | Sample motion files | +| `assets/robots/unitree_g1/` | Canonical G1 XML and meshes used by training, sim2sim, retargeting, and FK validation | +| `teleopit/retargeting/gmr/assets/` | GMR retargeting assets, IK configs, and non-canonical robot descriptions | ## Asset Groups | Group | ModelScope Repo | Contents | |-------|----------------|----------| | `ckpt` | `BingqianWu/Teleopit-models` | `track.onnx`, `track.pt` | +| `robots` | `BingqianWu/Teleopit-models` | Canonical robot XML/meshes | | `gmr` | `BingqianWu/Teleopit-models` | GMR retargeting assets | | `bvh` | `BingqianWu/Teleopit-models` | Sample BVH motion files | | `data` | `BingqianWu/Teleopit-datasets` | Training/validation shards | diff --git a/docs/docs/getting-started/installation.md b/docs/docs/getting-started/installation.md index 3a0ceff0..ed9daf6c 100644 --- a/docs/docs/getting-started/installation.md +++ b/docs/docs/getting-started/installation.md @@ -32,7 +32,7 @@ This is sufficient for offline BVH playback and MuJoCo simulation. pip install -e '.[train]' ``` -Adds `rsl-rl-lib`, `mjlab`, `wandb`, and training dependencies. +Adds `rsl-rl-lib`, `mjlab`, `wandb`, `swanlab`, and training dependencies. ### Sim2Real (Hardware Deployment) @@ -40,7 +40,7 @@ Adds `rsl-rl-lib`, `mjlab`, `wandb`, and training dependencies. pip install -e '.[sim2real]' ``` -Adds `opencv-python` and `g1_bridge_sdk`. You also need to initialize submodules and build the C++ bridge: +Adds `opencv-python`. You also need to initialize submodules and build/install the C++ `g1_bridge_sdk` bridge: ```bash git submodule update --init --recursive @@ -56,13 +56,39 @@ pip install -e '.[pico4]' ``` Teleopit uses the in-process `pico_bridge.PicoBridge` receiver for Pico tracking. -For Teleopit 0.3.0, use pico-bridge 0.2.0. Do not upgrade the receiver to -pico-bridge 0.2.1, because 0.2.1 changes interface semantics that Teleopit -0.3.0 does not target. +Teleopit targets pico-bridge 0.2.1 and its `pico_native` tracking semantics. The receiver can run on a workstation PC or the robot onboard computer. See [Pico Sim2Sim](../tutorials/pico-sim2sim) and [Pico Sim2Real](../tutorials/pico-sim2real) for the full setup guides. +Optional LinkerHand control for Pico sim2real uses local third-party packages. +Install those packages directly after initializing the submodules: + +```bash +git submodule update --init --recursive +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand +scripts/setup/download_somehand_l6_assets.sh +``` + +These packages are only required when `hands.enabled=true`. + +### Sim2Real Recording + +```bash +pip install -e '.[recording]' +``` + +Adds the Pico sim2real stack plus the video dependencies used by +`sim2real_record.yaml`. RealSense Python bindings are platform-specific: install +`pyrealsense2` manually in the active environment when using +`input.video.source=realsense`. On Arm machines, use conda-forge rather than the +pip package: + +```bash +conda install -c conda-forge pyrealsense2 +``` + ## Verify Installation ```bash diff --git a/docs/docs/getting-started/quick-start.md b/docs/docs/getting-started/quick-start.md index dc59e17e..bede79a1 100644 --- a/docs/docs/getting-started/quick-start.md +++ b/docs/docs/getting-started/quick-start.md @@ -9,7 +9,7 @@ This guide walks you through running your first sim2sim playback in under 5 minu ## Prerequisites 1. [Install Teleopit](installation) (inference profile) -2. [Download assets](download-assets) (`--only gmr ckpt bvh`) +2. [Download assets](download-assets) (`--only robots gmr ckpt bvh`) ## Run Offline Sim2Sim @@ -55,7 +55,7 @@ python scripts/run/run_sim.py controller.policy_path=track.onnx 'viewers=[retarg ## What's Next -- [Offline Sim2Sim Tutorial](../tutorials/offline-sim2sim) - Full guide with recording and rendering +- [Offline Sim2Sim Tutorial](../tutorials/offline-sim2sim) - Full guide with rendering - [Pico Sim2Sim](../tutorials/pico-sim2sim) - Verify Pico tracking in MuJoCo - [Standalone Standing](../tutorials/standalone-standing) - Check G1 bridge, network, and policy standing - [Pico Sim2Real](../tutorials/pico-sim2real) - Deploy Pico teleoperation to Unitree G1 diff --git a/docs/docs/intro.md b/docs/docs/intro.md index b26720b9..8555ed00 100644 --- a/docs/docs/intro.md +++ b/docs/docs/intro.md @@ -22,7 +22,7 @@ slug: / ```text InputProvider (BVH / Pico4 VR) -> Retargeter (GMR) - -> ObservationBuilder (166D) + -> ObservationBuilder (167D) -> Controller (dual-input TemporalCNN ONNX) -> Robot (MuJoCo sim or Unitree G1) ``` @@ -32,8 +32,8 @@ InputProvider (BVH / Pico4 VR) | Spec | Value | |------|-------| | Policy frequency | 50 Hz | -| PD control frequency | 1000 Hz | -| Observation dimension | 166D | +| PD control frequency | 200 Hz | +| Observation dimension | 167D | | Action dimension | 29D (G1 joints) | | ONNX model | Dual-input TemporalCNN | | Retargeting | GMR (General Motion Retargeting) | diff --git a/docs/docs/reference/architecture.md b/docs/docs/reference/architecture.md index fb0915e0..f979d687 100644 --- a/docs/docs/reference/architecture.md +++ b/docs/docs/reference/architecture.md @@ -11,12 +11,12 @@ System internals and technical constraints for developers. ```text InputProvider (BVH file / Pico4) -> Retargeter (GMR) - -> ObservationBuilder (166D) + -> ObservationBuilder (167D) -> Controller (dual-input TemporalCNN ONNX) -> Robot (MuJoCo sim or Unitree G1) ``` -Offline/online inference is assembled by `teleopit/runtime/` and `teleopit/pipeline.py`. The hardware state machine lives in `teleopit/sim2real/controller.py`. Training is provided by `train_mimic/`. +Offline/online inference is assembled by `teleopit/runtime/` and `teleopit/pipeline.py`. The hardware state machine runs through the process-isolated runtime in `teleopit/sim2real/mp/`. Training is provided by `train_mimic/`. ## Code Structure @@ -40,12 +40,12 @@ train_mimic/scripts/data | Module | Role | |--------|------| -| `teleopit/interfaces.py` | Stable protocols: InputProvider, Retargeter, Controller, Robot, ObservationBuilder, Recorder | +| `teleopit/interfaces.py` | Stable protocols: InputProvider, Retargeter, Controller, Robot, ObservationBuilder | | `teleopit/runtime/` | Config parsing, path normalization, component assembly, CLI validation | | `teleopit/pipeline.py` | Lightweight facade for offline sim | -| `teleopit/sim2real/controller.py` | Hardware state machine and control logic | +| `teleopit/sim2real/mp/` | Process-isolated sim2real state machine, IPC, and robot-control loop | | `teleopit/controllers/observation.py` | ObservationBuilder | -| `teleopit/controllers/rl_policy.py` | Only accepts 166D dual-input ONNX | +| `teleopit/controllers/rl_policy.py` | Accepts dual-input ONNX whose observation dimension matches the runtime builder | | `train_mimic/app.py` | Shared train/play/benchmark assembly | | `train_mimic/tasks/tracking/config/` | Single task registration (`General-Tracking-G1`) | | `train_mimic/data/dataset_builder.py` | Sole official dataset construction entry | @@ -55,12 +55,12 @@ train_mimic/scripts/data | Spec | Value | |------|-------| | Training task | `General-Tracking-G1` | -| Inference observation | `velcmd_history` (166D) | -| ONNX signature | Dual-input `obs` (166D) + `obs_history` | -| Actor/Critic | TemporalCNN (1024, 512, 256, 256, 128) | -| Training sampling | `uniform`; playback/benchmark use `start` | +| Inference observation | `velcmd_history` (167D) | +| ONNX signature | Dual-input `obs` (167D) + `obs_history` | +| Actor/Critic | TemporalCNN (2048, 1024, 512, 256, 128) | +| Training sampling | Default `rewind`; also supports `uniform`; playback/benchmark use `start` | | Training `window_steps` | `[0]` | -| Data format | Shard directories only (`shard_*.npz`) | +| Data format | Minimal recursive HDF5 shards (`shard_*.h5`) | ## Constraints @@ -68,7 +68,7 @@ train_mimic/scripts/data - Offline BVH runs require explicit `input.bvh_file` - `viewers` is the sole viewer configuration entry - Observation/ONNX dimension mismatch causes immediate startup error -- sim2real also only supports 166D dual-input ONNX +- sim2real also requires a dual-input ONNX whose observation dimension matches the runtime builder ## Public Surface @@ -76,4 +76,4 @@ train_mimic/scripts/data **Stable training entry points:** `train.py`, `play.py`, `benchmark.py`, `save_onnx.py` -**Stable data entry points:** `build_dataset.py`, `split_shards.py` +**Stable data entry points:** `build_dataset.py`, `precompute_dataset.py` diff --git a/docs/docs/reference/assets.md b/docs/docs/reference/assets.md index c41f8bb1..92322c10 100644 --- a/docs/docs/reference/assets.md +++ b/docs/docs/reference/assets.md @@ -4,11 +4,12 @@ sidebar_position: 2 # Asset Management -Robot models, datasets, checkpoints, and demo media are not tracked in Git. They are distributed via ModelScope and HuggingFace. +Datasets, checkpoints, robot models, and demo media are not tracked in Git. They are distributed via ModelScope and HuggingFace. The canonical Unitree G1 model is downloaded to `assets/robots/unitree_g1/g1_29dof.xml`. ## What's Not in Git -- `teleopit/retargeting/gmr/assets/` - Robot meshes, URDF/MJCF +- `assets/robots/` - Canonical robot XML/meshes +- `teleopit/retargeting/gmr/assets/` - GMR retargeting assets, IK configs, and non-canonical robot descriptions - `data/`, checkpoints, caches - Demo media (`assets/demo.gif`, `assets/demo.mp4`) @@ -33,9 +34,10 @@ Robot models, datasets, checkpoints, and demo media are not tracked in Git. They | Group | Repository | Remote Path | |-------|-----------|-------------| | `ckpt` | Teleopit-models | `checkpoints/track.onnx`, `checkpoints/track.pt` | +| `robots` | Teleopit-models | `archives/robot_assets.tar.gz` | | `gmr` | Teleopit-models | `archives/gmr_assets.tar.gz` | | `bvh` | Teleopit-models | `archives/sample_bvh.tar.gz` | -| `data` | Teleopit-datasets | `data/train/`, `data/val/` | +| `data` | Teleopit-datasets | `data/` | ## Download @@ -46,7 +48,7 @@ Use the project download script (defaults to ModelScope): python scripts/setup/download_assets.py # Only inference essentials -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh # Only training data python scripts/setup/download_assets.py --only data @@ -61,17 +63,17 @@ Local paths after download: |--------|-------| | `checkpoints/track.onnx` | `track.onnx` | | `checkpoints/track.pt` | `track.pt` | +| `archives/robot_assets.tar.gz` | `assets/robots/` (extracted) | | `archives/gmr_assets.tar.gz` | `teleopit/retargeting/gmr/assets/` (extracted) | | `archives/sample_bvh.tar.gz` | `data/sample_bvh/` (extracted) | -| `data/train/` | `data/datasets/seed/train/` | -| `data/val/` | `data/datasets/seed/val/` | +| `data/` | `data/datasets/seed/` | ## Upload to ModelScope ### Step 1: Prepare Upload Directory ```bash -python scripts/setup/prepare_modelscope_assets.py --only ckpt gmr bvh --clean +python scripts/setup/prepare_modelscope_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/prepare_modelscope_assets.py --only data ``` @@ -112,7 +114,7 @@ Tags should match Git tags for traceability. ```bash # Prepare and upload model assets (--clean ensures no leftover files) -python scripts/setup/upload_hf_assets.py --only ckpt gmr bvh --clean +python scripts/setup/upload_hf_assets.py --only ckpt robots gmr bvh --clean # Prepare and upload dataset python scripts/setup/upload_hf_assets.py --only data --clean diff --git a/docs/docs/reference/dataset.md b/docs/docs/reference/dataset.md index 348db65a..b5580847 100644 --- a/docs/docs/reference/dataset.md +++ b/docs/docs/reference/dataset.md @@ -7,62 +7,88 @@ sidebar_position: 3 ## Download Pre-Built Dataset (Recommended) ```bash -python scripts/setup/download_assets.py --only data +python scripts/setup/download_assets.py --only robots data ``` -Then train directly with the shard directory: +Then precompute the training shard and train with the precomputed dataset root: ```bash -python train_mimic/scripts/train.py --motion_file data/datasets/seed/train +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +python train_mimic/scripts/train.py --motion_file data/datasets/seed_precomputed ``` For custom dataset construction, read on. --- +## Record Pico Clips + +Use the interactive Pico recorder to create training-ready NPZ clips from live +body tracking: + +```bash +pip install -e '.[pico4]' +python scripts/run/record_pico_motion.py +``` + +The recorder starts the Pico receiver and live `Retarget` viewer before waiting +for clip names, so preview keeps running while the terminal is idle. Enter a +semantic clip name, then use `R` to start, `S` to save, `D` to discard, `N` to +enter a new name, and `Q` to quit. Saved clips go to +`data/pico_motion/clips/` as `_.npz`; no per-clip +JSON is written, so clips can be renamed or deleted manually. + +Build all recorded clips into the standard HDF5 shard dataset: + +```bash +python train_mimic/scripts/data/build_dataset.py \ + --spec data/pico_motion/pico_recorded.yaml --force +``` + +At least one valid clip is required after preprocessing. + ## Custom Dataset Construction -Data pipeline: `typed source YAML -> preprocess/filter -> shard-only training data` +Data pipeline: `typed source YAML -> preprocess/filter -> minimal HDF5 shards -> precomputed training dataset` ```bash python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml + --spec train_mimic/configs/datasets/twist2.yaml ``` ## Output Structure ```text data/datasets// -├── clips/ # Optional; only for per-clip intermediates -│ └── /... -├── train/ -│ └── shard_*.npz -├── val/ -│ └── shard_*.npz -├── manifest_resolved.csv -└── build_info.json +└── shard_*.h5 + +data/datasets/_precomputed/ +└── shard_*.h5 ``` -- If the spec contains `bvh` or `npz` sources, the builder retains/generates `clips/` -- If the spec is all `pkl` or `seed_csv` sources, the builder takes a batch path producing split-level shards directly +- If the spec contains `bvh` or `npz` sources, the full dataset builder uses a temporary `clips/` directory during conversion and deletes it after shards are written. Rebuilds do not reuse converted clips. +- If the spec is all `pkl` or `seed_csv` sources, the builder takes a batch path producing shards directly +- `build_dataset.py` only writes the minimal distributable dataset. It does not run FK precompute. +- `precompute_dataset.py` writes a separate training dataset containing the minimal motion plus precomputed joint velocities and body FK/velocities. +- Training accepts only the precomputed dataset directory. It recursively discovers precomputed `*.h5` shards below the specified root, so precomputed datasets can be merged by placing multiple shard directories under one parent. +- Training loads all discovered precomputed motion windows into memory at startup. Joint velocities and body FK/velocities are not computed during training. ## YAML Spec Format -Example (`train_mimic/configs/datasets/twist2_full.yaml`): +Example (`train_mimic/configs/datasets/twist2.yaml`): ```yaml -name: twist2_full +name: twist2 target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: first_frame_foot sources: - name: OMOMO_g1_GMR type: pkl input: data/twist2_retarget_pkl/OMOMO_g1_GMR - - name: lafan1_v1 + - name: lafan1 type: bvh input: data/lafan1_bvh bvh_format: lafan1 @@ -74,51 +100,55 @@ sources: |-------|-------------| | `name` | Dataset name, maps to output directory | | `target_fps` | Target frame rate for resampling | -| `val_percent` | Validation split percentage (hash-based on clip_id) | -| `hash_salt` | Optional split salt | | `preprocess.normalize_root_xy` | Normalize root body first-frame xy to origin | -| `preprocess.ground_align` | `none` / `clip_min_foot` | +| `preprocess.ground_align` | `none` / `first_frame_foot` | | `preprocess.min_frames` | Minimum clip length | | `preprocess.max_root_lin_vel` | Root linear velocity filter threshold | | `preprocess.min_peak_body_height` | Minimum peak body height | | `preprocess.max_all_off_ground_s` | Max duration all feet off ground | -| `sources[].name` | Source name (used for clips subdirectory) | +| `sources[].name` | Source name | | `sources[].type` | `bvh` / `pkl` / `npz` / `seed_csv` | | `sources[].input` | Input file or directory | -| `sources[].weight` | Optional sampling weight (default `1.0`) | | `sources[].bvh_format` | Required for BVH: `lafan1` / `hc_mocap` / `nokov` | | `sources[].robot_name` | BVH only, default `unitree_g1` | | `sources[].max_frames` | BVH only, `0` = full length | ## Conversion Rules -All sources are converted to standard training shards. Each clip goes through preprocess/filter before writing to shards: +All sources are converted to standard minimal shards. Each clip goes through preprocessing/filtering before writing to shards: - `bvh -> retarget pkl -> npz clip` - `pkl -> npz clip` (or direct batch shard for pkl-only datasets) - `npz -> validate + copy/reuse` -Each shard contains: `clip_starts`, `clip_lengths`, `clip_fps`, `clip_weights`. +Each minimal shard stores `root_pos`, `root_quat_w`, `joint_pos`, `body_names`, `clip_starts`, `clip_lengths`, and `clip_fps`. The precomputed training shards store `joint_pos`, `joint_vel`, `body_pos_w`, `body_quat_w`, `body_lin_vel_w`, `body_ang_vel_w`, and the same metadata. Training fails fast if `--motion_file` points at a minimal dataset instead of a precomputed training dataset. ## Common Commands ```bash # Force rebuild python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --force + --spec train_mimic/configs/datasets/twist2.yaml --force # Parallel processing python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --jobs 8 + --spec train_mimic/configs/datasets/twist2.yaml --jobs 8 # Custom output root python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml \ + --spec train_mimic/configs/datasets/twist2.yaml \ --output_root /tmp/my_datasets # Print build report python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --json + --spec train_mimic/configs/datasets/twist2.yaml --json + +# Generate a precomputed training dataset from an existing minimal dataset +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/twist2 --outdir data/datasets/twist2_precomputed --jobs 8 --force + +# Inspect a dataset root +python train_mimic/scripts/data/inspect_dataset.py data/datasets/twist2 ``` ## Batch Ingest to NPZ Clips @@ -128,28 +158,15 @@ Convert raw data to standard NPZ clips without merging: ```bash python train_mimic/scripts/data/ingest_motion.py \ --type bvh --input data/lafan1_bvh \ - --output data/datasets/lafan1_v1/clips/lafan1_v1 \ - --source lafan1_v1 --bvh_format lafan1 --jobs 8 + --output data/lafan1_clips/lafan1 \ + --source lafan1 --bvh_format lafan1 --jobs 8 ``` ## Check Clip FK Consistency ```bash python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` Recommended thresholds: `pos_max < 1e-3 m`, `quat_mean < 0.05 rad`, `quat_p95 < 0.10 rad`. - -## Re-shard - -Split large shards for distribution: - -```bash -python train_mimic/scripts/data/split_shards.py \ - --input data/datasets/seed/train \ - --output data/datasets/seed/train_small_shards \ - --max_size_gb 2 -``` - -Each shard is self-contained with full clip metadata. diff --git a/docs/docs/reference/training-troubleshooting.md b/docs/docs/reference/training-troubleshooting.md index f2e56a77..b683c6c4 100644 --- a/docs/docs/reference/training-troubleshooting.md +++ b/docs/docs/reference/training-troubleshooting.md @@ -35,7 +35,7 @@ The current `convert_pkl_to_npz.py` fixes these issues. ```bash python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` Expected thresholds: `pos_max < 1e-3 m`, `quat_mean < 0.05 rad`, `quat_p95 < 0.10 rad`. @@ -45,7 +45,7 @@ If check fails, regenerate data and run a smoke test: ```bash python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets//train + --motion_file data/datasets/_precomputed ``` Expected: `Mean episode length` significantly > 1, `error_anchor_pos` starts decreasing. @@ -126,7 +126,7 @@ Ensure `num_eval_steps >= video_length`: ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets//val \ + --motion_file data/datasets/_precomputed \ --num_envs 1 --num_eval_steps 2000 \ --video --video_length 600 ``` @@ -168,6 +168,6 @@ print(cfg.init_state.joint_pos) # Must match g1.yaml default_angles ### Solution -Update `teleopit/configs/robot/g1.yaml` and `g1_mjlab.xml` to match training environment values (default angles, armature, condim). +Update `teleopit/configs/robot/g1.yaml` and `assets/robots/unitree_g1/g1_29dof.xml` to match training environment values (default angles, armature, condim). This fix also affects the sim2real path since `default_angles` is shared by `rl_policy.py` and `observation.py`. diff --git a/docs/docs/tutorials/bvh-sim2real.md b/docs/docs/tutorials/bvh-sim2real.md index 18741b2a..75c9a83b 100644 --- a/docs/docs/tutorials/bvh-sim2real.md +++ b/docs/docs/tutorials/bvh-sim2real.md @@ -76,12 +76,6 @@ real_robot.network_interface=enp130s0 # Pause at the final BVH frame playback.pause_on_end=true -# Smooth transition from standing/current robot state into playback -transition_duration=2.0 - -# Resume blend for offline playback -pause_resume_transition_duration=1.0 - # Control loop rate policy_hz=50 ``` diff --git a/docs/docs/tutorials/offline-sim2sim.md b/docs/docs/tutorials/offline-sim2sim.md index 172e1868..57455374 100644 --- a/docs/docs/tutorials/offline-sim2sim.md +++ b/docs/docs/tutorials/offline-sim2sim.md @@ -68,20 +68,6 @@ viewers=none # Headless When all active viewer windows are closed, the simulation ends automatically. ::: -## Recording - -Record simulation data to HDF5: - -```bash -python scripts/run/run_sim.py \ - controller.policy_path=track.onnx \ - input.bvh_file=data/sample_bvh/aiming1_subject1.bvh \ - +record=true \ - recording.output_path=outputs/session.h5 -``` - -Recorded fields: `joint_pos`, `joint_vel`, `mimic_obs`, `action`, `target_dof_pos`, `torque`, `timestamp`. - ## Offline Rendering Render simulation to video (headless): diff --git a/docs/docs/tutorials/pico-sim2real.md b/docs/docs/tutorials/pico-sim2real.md index dcd3e399..4b89d3f6 100644 --- a/docs/docs/tutorials/pico-sim2real.md +++ b/docs/docs/tutorials/pico-sim2real.md @@ -21,9 +21,7 @@ There are two deployment styles: Both styles use `Pico4InputProvider` and the in-process pico-bridge receiver. There is no separate onboard Pico input mode. -For Teleopit 0.3.0, keep the host receiver on pico-bridge 0.2.0. pico-bridge -0.2.1 changes interface semantics and is not the supported receiver version for -this Teleopit release. +Teleopit targets pico-bridge 0.2.1 and its `pico_native` tracking semantics. ## 1. Install Runtime Dependencies @@ -97,6 +95,32 @@ python scripts/run/run_sim2real.py \ real_robot.network_interface=eth0 ``` +## Optional HDF5 Recording + +Install the recording extra on the machine that owns Pico input and RealSense: + +```bash +pip install -e '.[recording]' +``` + +Run the recording config: + +```bash +python scripts/run/run_sim2real.py \ + --config-name sim2real_record \ + controller.policy_path=track.onnx \ + real_robot.network_interface=enp130s0 \ + recording.task="walk forward" +``` + +Terminal controls are `R` start episode, `S` save, `D` discard, and `Q` +shutdown. `STANDING`, `MOCAP`, `ARMS`, and paused mocap can be recorded; +saved episodes cannot be discarded afterward. Episodes are saved as `.h5` files +under `data/recordings/sim2real_hdf5/episodes/`, with compressed MP4 sidecar +videos under `data/recordings/sim2real_hdf5/videos/`. The HDF5 episode stores +`frame_index` and `timestamp` sync arrays plus `observation.state(68)`, +`observation.mode(1)`, `action(36)`, and `action.hand(12)` at 30 Hz. + ## Operator Flow Keep the Unitree remote in hand. `L1+R1` is the emergency stop path into @@ -107,6 +131,7 @@ Keep the Unitree remote in hand. `L1+R1` is the emergency stop path into | Unitree remote `Start` | Enter `STANDING` | | Unitree remote `Y` | Enter `MOCAP` | | Pico/controller `A` | Pause / resume live mocap | +| Pico/controller `B` | Toggle `MOCAP` / `ARMS` | | Unitree remote `X` | Return to `STANDING` | | Unitree remote `L1+R1` | Emergency stop (`DAMPING`) | @@ -124,10 +149,15 @@ Pico body frames -> retarget -> reference buffer -> observation -> policy -> G1 When entering `STANDING`, Teleopit releases active Unitree modes, enters debug/low-level control, locks the current joints briefly, resets policy state, -and ramps Kp to reduce startup spikes. +and ramps Kp without changing policy targets. -When entering `MOCAP`, Teleopit resets policy/reference state and blends the -reference from the current robot state into the live mocap command. +When entering `MOCAP`, Teleopit resets policy/reference state and starts tracking +the live mocap command through the realtime reference timeline. + +`ARMS` keeps the same live retargeting timeline running, but sends the motion +tracker a composed reference: body, waist, and legs stay at the standing pose +while both arms follow the live retargeted result. Entering or leaving `ARMS` +resets policy/reference alignment and uses the same Kp ramp safety path. ## Pause / Resume @@ -142,6 +172,92 @@ Resume while standing still and close to the paused pose. This reduces sudden reference changes when live tracking resumes. ::: +## Optional LinkerHand Control + +Pico sim2real can drive LinkerHand hands from Pico input: + +- `gripper`: hold the matching side grip as a deadman switch; the matching + trigger closes that hand. This mode supports `hands.driver=linkerhand_l6` and + `hands.driver=linkerhand_o6`; speed and open/close poses come from the matching + driver config. +- `vr_hand_pose`: L6-only mode that retargets Pico hand pose through somehand and + commands the continuous L6 hand target. If a hand pose disappears, that side + keeps its last commanded pose. This mode uses Teleopit's Pico landmark adapter + and the public `somehand.api` from somehand 0.2.0. It always sets L6 speed to + the maximum. + +When `hands.enabled=true`, hand control remains active in all sim2real modes. +Shutdown and hand-runtime failure send the configured open pose. + +Install the local hand-control packages first if they were not installed with +the main Pico profile: + +```bash +git submodule update --init --recursive +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand +scripts/setup/download_somehand_l6_assets.sh +``` + +Bring up the CAN interfaces before testing or running hand control: + +```bash +sudo /usr/sbin/ip link set can0 up type can bitrate 1000000 +sudo /usr/sbin/ip link set can1 up type can bitrate 1000000 +``` + +Before enabling full sim2real, verify the hand connection with a standalone +open/close test. The test runs until Ctrl-C: + +```bash +python scripts/dev/test_linkerhand_l6.py \ + --hand-type both \ + --left-can can0 \ + --right-can can1 +``` + +For an O6 standalone open/close test, add the O6 driver: + +```bash +python scripts/dev/test_linkerhand_l6.py \ + --driver linkerhand_o6 \ + --hand-type both \ + --left-can can0 \ + --right-can can1 +``` + +To test O6 with live Pico gripper input, add `--mode gripper`. + +Then enable L6 gripper control in Pico sim2real: + +```bash +hands.enabled=true +hands.driver=linkerhand_l6 +hands.mode=gripper +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 +``` + +For O6 gripper control, use: + +```bash +hands.enabled=true +hands.driver=linkerhand_o6 +hands.mode=gripper +hands.linkerhand_o6.left_can=can0 +hands.linkerhand_o6.right_can=can1 +``` + +For continuous VR hand-pose control, use: + +```bash +hands.enabled=true +hands.driver=linkerhand_l6 +hands.mode=vr_hand_pose +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 +``` + ## Optional RealSense Preview Stream the G1 RealSense color camera back to the Pico headset: @@ -172,15 +288,14 @@ input.bridge_advertise_ip=192.168.1.20 # Consecutive valid mocap frames required before MOCAP mocap_switch.check_frames=10 -# Smooth transition into mocap reference -transition_duration=2.0 - -# Realtime frames to collect before resume -pause_resume_warmup_steps=2 - # Change Pico pause button input.pause_button=right_axis_click +# Enable LinkerHand gripper control +hands.enabled=true +hands.driver=linkerhand_l6 +hands.mode=gripper + # Enable headset video preview input.video.enabled=true ``` @@ -194,4 +309,5 @@ input.video.enabled=true | Cannot enter debug mode | Unitree mode release failed | Stop other robot modes and press `Start` again | | Robot enters `STANDING` but not `MOCAP` | Mocap validation failed | Keep tracking active and stable; check `mocap_switch.check_frames` logs | | Pico pause does not return to `STANDING` | Expected behavior | Pico pause freezes mocap; press remote `X` for `STANDING` | +| LinkerHand does not move | `hands.enabled=false`, gripper deadman released, SDK/assets not installed, or CAN channel wrong | Enable `hands.enabled`, set `hands.mode`, run `scripts/dev/test_linkerhand_l6.py`, and check the selected driver's `left_can` / `right_can` | | Video preview is unavailable | RealSense or video source failed | Check camera permissions, `input.video.source`, and logs | diff --git a/docs/docs/tutorials/pico-sim2sim.md b/docs/docs/tutorials/pico-sim2sim.md index 06bccf2c..ff354dd9 100644 --- a/docs/docs/tutorials/pico-sim2sim.md +++ b/docs/docs/tutorials/pico-sim2sim.md @@ -47,15 +47,13 @@ Teleopit starts `pico_bridge.PicoBridge` in-process through `Pico4InputProvider`. The same Pico input path is used later for wired and onboard sim2real deployment. -For Teleopit 0.3.0, keep the host receiver on pico-bridge 0.2.0. pico-bridge -0.2.1 changes interface semantics and is not the supported receiver version for -this Teleopit release. +Teleopit targets pico-bridge 0.2.1 and its `pico_native` tracking semantics. ## 3. Download Assets ```bash pip install modelscope -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## 4. Run Pico Sim2Sim @@ -73,6 +71,7 @@ enter `MOCAP`. |----------|--------| | `Y` | Enter `MOCAP` | | `A` | Pause / resume live mocap | +| `B` | Toggle `MOCAP` / `ARMS` | | `X` | Return to `STANDING` | | `Q` | Quit | @@ -94,9 +93,12 @@ The default Pico pause button is `A`. Supported overrides include `B`, `X`, `Y`, `left_axis_click`, `right_axis_click`, `left_menu_button`, and `right_menu_button`. +The default Pico arms-mode button is `B`. `ARMS` keeps body, waist, and legs at +the standing pose while both arms follow the live retargeted result. + ## Optional Headset Video Preview -pico-bridge 0.2.0 can show a host-side camera stream in the headset. In +pico-bridge 0.2.1 can show a host-side camera stream in the headset. In simulation, Teleopit can stream the MuJoCo `d435i_rgb` camera: ```bash @@ -141,7 +143,7 @@ input.video.enabled=true | Symptom | Likely Cause | Fix | |---------|--------------|-----| | `ImportError: pico_bridge` | Pico extra not installed | Run `pip install -e '.[pico4]'` | -| Startup says pico-bridge is too old | Installed receiver does not support video args | Reinstall the Pico extra so pico-bridge 0.2.0 is used | +| Startup says pico-bridge is too old | Installed receiver does not support the required API or tracking semantics | Reinstall the Pico extra so pico-bridge 0.2.1 is used | | `TimeoutError: No Pico4 body data` | Headset is not connected or body tracking is inactive | Check the headset app, network, and `input.pico4_timeout` | | Discovery cannot find the host | Wrong advertised IP or blocked UDP | Set `input.bridge_advertise_ip=` and confirm UDP port `63901` is reachable | | Sim robot does not follow | Loop is still in `STANDING` | Press `Y` after tracking is ready | diff --git a/docs/docs/tutorials/standalone-standing.md b/docs/docs/tutorials/standalone-standing.md index a40c63bf..772b0e91 100644 --- a/docs/docs/tutorials/standalone-standing.md +++ b/docs/docs/tutorials/standalone-standing.md @@ -5,8 +5,9 @@ sidebar_position: 3 # Standalone Standing Test Run this before full sim2real control when bringing up a new robot, network -setup, or policy. It verifies the G1 bridge and RL standing path without Pico, -BVH playback, retargeting, or the full Teleopit mocap pipeline. +setup, or policy. It verifies the G1 bridge and the same RL standing path used +by sim2real, without Pico, BVH playback, retargeting, or the full mocap +pipeline. ```text G1 LowState -> standing observation -> RL policy -> G1 LowCmd targets @@ -58,12 +59,28 @@ python scripts/run/standalone_standing.py \ --network-interface eth0 ``` +Standalone standing reuses the sim2real standing components: `UnitreeG1Robot`, +`Sim2RealSafetyManager`, `RLPolicyController`, `VelCmdObservationBuilder`, and +`Sim2RealReferenceProcessor`. After locking the current joints, policy targets +are sent while Kp ramps from 10% to the configured gains over 2 seconds. To tune +this startup behavior: + +```bash +python scripts/run/standalone_standing.py \ + --policy track.onnx \ + --network-interface eth0 \ + --kp-ramp-duration 2.0 \ + --kp-ramp-floor-ratio 0.1 +``` + ## What It Checks - `g1_bridge_sdk` imports correctly. - LowState is received from the robot. - The dual-input ONNX policy can run the standing observation path. - Low-level position targets can be published through the C++ bridge. +- Observation construction, action scaling, default standing pose, Kp ramp, and + joint-limit clipping match the sim2real standing runtime. ## Next Steps diff --git a/docs/docs/tutorials/training.md b/docs/docs/tutorials/training.md index ead24ecb..cdd0efb7 100644 --- a/docs/docs/tutorials/training.md +++ b/docs/docs/tutorials/training.md @@ -23,6 +23,14 @@ Verify: python -c "import train_mimic.tasks; print('training OK')" ``` +Download the minimal seed dataset and generate the precomputed training shard: + +```bash +python scripts/setup/download_assets.py --only robots data +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +``` + ## Training ### Smoke Test @@ -31,7 +39,7 @@ python -c "import train_mimic.tasks; print('training OK')" python train_mimic/scripts/train.py \ --num_envs 64 \ --max_iterations 100 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed_precomputed ``` ### Full Training @@ -40,7 +48,7 @@ python train_mimic/scripts/train.py \ python train_mimic/scripts/train.py \ --num_envs 4096 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed_precomputed ``` ### Multi-GPU @@ -50,13 +58,33 @@ python train_mimic/scripts/train.py \ --gpu_ids 0 1 2 3 \ --num_envs 1024 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed_precomputed +``` + +### Multi-Node Multi-GPU + +Use `torchrun` directly when training across multiple machines: + +```bash +torchrun \ + --nnodes=$PET_NNODES \ + --nproc_per_node=$PET_NPROC_PER_NODE \ + --node_rank=$PET_NODE_RANK \ + --master_addr=$PET_MASTER_ADDR \ + --master_port=$PET_MASTER_PORT \ + train_mimic/scripts/train.py \ + --num_envs 1024 \ + --max_iterations 1000 \ + --motion_file data/datasets/seed_precomputed ``` **Notes:** - `--num_envs` is per-GPU in multi-GPU mode -- Default logger is TensorBoard; pass `--wandb_project ` to enable W&B -- `--motion_file` accepts only shard directories (containing `shard_*.npz` files) +- `--num_envs` is also per-process in multi-node mode, so total environments scale with `world_size` +- Default logger is TensorBoard. Use `--logger wandb` or `--logger swanlab` to select W&B or SwanLab; the project name defaults to `experiment_name` +- `--motion_file` accepts a precomputed training dataset root directory or a single precomputed `.h5` shard; shard discovery is recursive +- If you only have the minimal distributed shards, first run `python train_mimic/scripts/data/precompute_dataset.py --outdir ` and pass the precomputed output to training. +- Training loads all discovered precomputed motion windows into memory at startup. - `--max_iterations` means additional iterations; resuming from `model_12000.pt` with `--max_iterations 18000` trains to `model_30000.pt` ## Export ONNX @@ -68,7 +96,7 @@ python train_mimic/scripts/save_onnx.py \ --history_length 10 ``` -The exported model is a dual-input ONNX (`obs` + `obs_history`). The inference side only supports 166D dual-input ONNX. +The exported model is a dual-input ONNX (`obs` + `obs_history`). The inference side expects a 167D dual-input ONNX policy matching the current `velcmd_history` observation. ## Evaluation @@ -77,7 +105,7 @@ The exported model is a dual-input ONNX (`obs` + `obs_history`). The inference s ```bash python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val + --motion_file data/datasets/seed_precomputed ``` ### Benchmark @@ -85,7 +113,7 @@ python train_mimic/scripts/play.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 ``` @@ -94,7 +122,7 @@ python train_mimic/scripts/benchmark.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 \ --video \ --video_length 600 @@ -113,4 +141,4 @@ Key files: - `train_mimic/app.py` - Shared entry point for train/play/benchmark - `train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 env builder - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO config -- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `adaptive` / `uniform` / `start` sampling modes +- `train_mimic/tasks/tracking/mdp/commands.py` - Supports `uniform`, `start`, and `rewind` sampling modes. Training defaults to `rewind`; playback/benchmark use `start`. diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md index 27641896..0366815c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/config-reference.md @@ -15,7 +15,6 @@ sidebar_position: 2 | `viewers` | str/list | `sim2sim` | 可视化窗口集合:`mocap`、`retarget`、`sim2sim`、`camera`、`all`、`none`。`all` 打开 `mocap`、`retarget` 和 `sim2sim`;如需相机画面需显式加入 `camera` | | `realtime` | bool | `false` | 是否启用实时模式(实机部署时需开启) | | `num_steps` | int | — | 仿真总步数;设为 `-1` 表示无限运行 | -| `transition_duration` | float | — | 从静止姿态过渡到策略控制的时长(秒) | | `keyboard.enabled` | bool | `false` | 是否启用 sim2sim 实时键盘模式控制 | | `playback.pause_on_end` | bool | `false` | 回放结束后是否暂停(而非退出) | | `playback.keyboard.enabled` | bool | `false` | 是否启用键盘控制回放进度 | @@ -80,13 +79,15 @@ target = clip(action, clip_range) * action_scale + default_dof_pos | `pico4_buffer_size` | int | `60` | 帧缓冲区大小 | | `pause_button` | str | `A` | 用于暂停/恢复的手柄按钮名称 | | `pause_debounce_s` | float | `0.25` | 暂停按钮防抖时间 | +| `arms_button` | str | `B` | Pico 中用于切换 `MOCAP` / `ARMS` 的按钮 | +| `arms_debounce_s` | float | `0.25` | 双臂模式按钮防抖时间 | | `bridge_host` | str | `0.0.0.0` | Teleopit host receiver 绑定地址 | | `bridge_port` | int | `63901` | Teleopit host receiver TCP/UDP 端口 | | `bridge_discovery` | bool | `true` | 是否启用 pico-bridge 发现广播 | | `bridge_advertise_ip` | str/null | `null` | 可选的 host 广播 IP 覆盖 | | `bridge_start_timeout` | float | `10.0` | 启动 bridge 的超时时间 | | `bridge_history_size` | int | `120` | bridge 保留的 Pico 帧历史长度 | -| `video.enabled` | bool | `false` | 通过 pico-bridge 0.2.0 将 host 相机预览发送回 Pico | +| `video.enabled` | bool | `false` | 通过 pico-bridge 0.2.1 将 host 相机预览发送回 Pico | | `video.source` | str/null | `null` | 视频源:`mujoco`、`realsense` 或 `test-pattern` | | `video.width` / `height` / `fps` | int | `1280` / `720` / `30` | 视频采集/渲染设置 | | `video.device` | str/null | `null` | 可选的 RealSense 序列号 | @@ -103,8 +104,6 @@ target = clip(action, clip_range) * action_scale + default_dof_pos | `retarget_buffer_delay_s` | 缓冲延迟 | | `reference_steps` | 参考轨迹窗口步数 | | `realtime_buffer_warmup_steps` | 播放前预热帧数 | -| `realtime_buffer_low_watermark_steps` | 低水位线 | -| `realtime_buffer_high_watermark_steps` | 高水位线 | | `reference_velocity_smoothing_alpha` | 速度平滑系数 | | `reference_anchor_velocity_smoothing_alpha` | 锚点速度平滑系数 | @@ -112,14 +111,18 @@ target = clip(action, clip_range) * action_scale + default_dof_pos 以下字段用于 sim2real 配置(`sim2real.yaml`、`pico4_sim2real.yaml`)。 +sim2real 默认使用 `viewers=none`。设置 `viewers=retarget` 可打开一个可选的 +MuJoCo 窗口显示重定向参考;`sim2sim`、`mocap`、`camera` 和 `all` +仅用于仿真 viewer。 + ### 安全相关 | 字段 | 说明 | 默认值 | |---|---|---| -| `startup_ramp_duration` | 从锁定位置平滑过渡到策略控制的时长(秒) | `2.0` | +| `startup_ramp_duration` | 进入 `STANDING` 后的 Kp ramp 时长;逐步提高 PD 增益,不改变 policy target | `2.0` | | `joint_vel_limit` | 关节速度限制(rad/s),超过时触发急停 | `10.0` | | `mocap_switch.check_frames` | 切换到 MOCAP 前所需的连续有效帧数 | `10` | -| `mocap_switch.max_position_value` | 位置合理性阈值(米) | `5.0` | +| `arm_mocap.controlled_joint_indices` | Pico `ARMS` 模式下由实时 retargeting 驱动的 G1 关节 | `[15..28]` | ### 真机 SDK @@ -135,19 +138,87 @@ target = clip(action, clip_range) * action_scale + default_dof_pos ### 暂停/恢复(Pico sim2real) -| 字段 | 说明 | 默认值 | -|---|---|---| -| `pause_resume_transition_duration` | 离线回放恢复时的混合时长;实时 Pico 恢复使用实时追踪重新居中 | `1.0` | -| `pause_resume_warmup_steps` | 恢复追踪前采集的实时动捕帧数 | `2` | - 实时 Pico 恢复追踪时会先重新居中航向和地面平面位置。操作者应保持静止,并尽量贴近暂停时的姿态,以减少参考突变。 -### 实时追赶(Pico sim2real) +### 灵巧手(Pico sim2real) + +`hands.enabled=true` 要求 `input.provider=pico4`,并以本地 editable 方式安装 +`third_party/linkerhand-python-sdk` 和 `third_party/somehand`。启用后,手控会在所有 sim2real 模式中保持生效。 +`gripper` 支持 `linkerhand_l6` 和 `linkerhand_o6`,会用 Pico trigger 在配置的张开和闭合姿态之间插值。 +`vr_hand_pose` 只支持 L6:手部 pose 消失时,对应侧会保持上一条命令;L6 速度会设为最大值; +Teleopit 会先将 Pico 手部状态转成 21 个 landmarks,再只通过 somehand 0.2.0 公开的 `somehand.api` 调用。 | 字段 | 说明 | 默认值 | |---|---|---| -| `realtime_catchup_enabled` | 缓冲区过大时启用追赶 | `true` | -| `realtime_catchup_trigger_steps` | 触发追赶的缓冲区深度 | `6` | -| `realtime_catchup_release_steps` | 释放追赶的缓冲区深度 | `3` | -| `realtime_catchup_target_delay_s` | 追赶目标延迟 | `0.04` | -| `reference_qpos_smoothing_alpha` | 关节位置平滑系数(1.0 = 无平滑) | `0.4` | +| `hands.enabled` | 启用可选手部运行时 | `false` | +| `hands.mode` | `gripper` 或 `vr_hand_pose` | `gripper` | +| `hands.driver` | 手部设备驱动:`linkerhand_l6` 或 `linkerhand_o6` | `linkerhand_l6` | +| `hands.sides` | 控制侧 | `[left, right]` | +| `hands.rate_hz` | gripper 最大命令频率(Hz) | `30.0` | +| `hands.frame_timeout_s` | 手柄或手部 pose 过期阈值 | `0.3` | +| `hands.linkerhand_l6.left_can` / `right_can` | 左右手 CAN 通道 | `can0` / `can1` | +| `hands.linkerhand_l6.speed` | `gripper` 使用的 L6 速度;`vr_hand_pose` 会覆盖为最大速度 | 见配置 | +| `hands.linkerhand_l6.deadman_threshold` | 启用单侧控制所需的最小 grip 值 | `0.5` | +| `hands.linkerhand_l6.trigger_deadzone` | trigger 两端死区 | `0.05` | +| `hands.linkerhand_l6.open_pose` / `close_pose` | L6 的 6 维张开/闭合姿态 | 见配置 | +| `hands.linkerhand_o6.left_can` / `right_can` | 左右 O6 手 CAN 通道 | `can0` / `can1` | +| `hands.linkerhand_o6.speed` | `gripper` 使用的 O6 速度 | 见配置 | +| `hands.linkerhand_o6.open_pose` / `close_pose` | O6 的 6 维张开/闭合姿态 | 见配置 | +| `hands.somehand.config_path` | `vr_hand_pose` 使用的 somehand 双手 L6 配置 | 见配置 | +| `hands.somehand.rate_hz` | 低延时 `vr_hand_pose` 命令频率(Hz) | `60.0` | +| `hands.somehand.max_iterations` | `vr_hand_pose` 的 somehand solver 迭代上限 | `12` | +| `hands.somehand.temporal_filter_alpha` | somehand 输入 landmarks 平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | +| `hands.somehand.output_alpha` | somehand qpos 输出平滑 alpha;`1.0` 表示关闭平滑延时 | `1.0` | + +### HDF5 录制(Pico sim2real) + +`recording.enabled=true` 只支持 `input.provider=pico4`、 +`input.video.enabled=true`、`input.video.source=realsense`,并且需要交互式终端。 +录制是手动控制:`R` 开始 episode,`S` 保存当前 episode,`D` 丢弃当前 episode, +`Q` 关闭。可以录制 `STANDING`、`MOCAP`、`ARMS` 和暂停状态的 mocap。 + +`sim2real_record.yaml` 会同时启用录制和必需的 RealSense `input.video` +路径。录制不会打开第二路相机,而是消费 `pico_input` 已经产生的同一批帧。 + +| 字段 | 说明 | 默认值 | +|---|---|---| +| `recording.enabled` | 启用手动 HDF5 录制 | `false` | +| `recording.output_dir` | 数据集根目录 | `data/recordings/sim2real_hdf5` | +| `recording.task` | 写入 frame 的任务字符串 | `demo` | +| `recording.fps` | 录制/视频主时钟频率 | `30` | +| `recording.min_episode_seconds` | 保存时短于该时长的 episode 会被丢弃 | `1.0` | +| `recording.record_modes` | 允许开始录制和写帧的模式 | `[standing, mocap, arms, pause]` | +| `recording.camera.key` | RGB 图像数据集 key | `observation.images.d435i_rgb` | +| `recording.camera.width` / `height` / `fps` | RealSense RGB 采集设置 | `640` / `480` / `30` | +| `recording.camera.device` | 可选 RealSense 序列号 | `null` | +| `recording.video.codec` / `quality` / `pixelformat` | MP4 sidecar 编码设置 | `libx264` / `8` / `yuv420p` | + +相机失败时的行为由 `input.video.fail_on_error` 控制。 + +每个保存的 episode 会在 `recording.output_dir/episodes/` 下写入一个 `.h5` +文件,并在 `recording.output_dir/videos//` 下写入一个压缩 MP4 +sidecar。HDF5 episode 保存 `frame_index` 和 `timestamp` 数组,并在根属性中 +记录 `video_path`、`video_fps` 和 `video_frames` 用于同步。录制不会写入原始 +RGB 图像 dataset。 + +HDF5 datasets: + +```text +frame_index int64[N] +timestamp float64[N] +observation.state float32[68] +observation.mode float32[1] +action float32[36] +action.hand float32[12] +``` + +根属性包含 Teleopit HDF5 recording format、schema version、task、fps、 +frame count 和视频同步元数据。 + +`observation.state` 的顺序是 `joint_pos(29)`、`joint_vel(29)`、 +`base_quat_wxyz(4)`、`base_ang_vel(3)` 和 `projected_gravity(3)`。 +`observation.mode` 是数值类别:`standing=0`、`mocap=1`、 +`arms=2`、`pause=3`。`action` 是当前 reference qpos: +`root_pos(3) + root_quat_wxyz(4) + joint_pos(29)`。 +`action.hand` 是手部 worker 最新的 LinkerHand 命令: +`left_pose(6) + right_pose(6)`,使用 SDK 的 0-255 pose 数值。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md index a3a75e25..65856f4e 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/configuration/faq.md @@ -7,8 +7,7 @@ sidebar_position: 3 ## 为什么设置了 `policy_path` 还是启动不了? 1. 确认文件存在 -2. 确认不是旧的 1402D / TWIST2 ONNX 模型 -3. 确认输入维度是 `166`,且为双输入 ONNX(`obs` + `obs_history`) +2. 确认输入维度是 `167`,且为双输入 ONNX(`obs` + `obs_history`) ## 为什么必须显式指定 `input.bvh_file`? diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md index ab2e7b2c..2056aa8b 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/contributing.md @@ -34,7 +34,7 @@ teleopit/ # 核心推理与部署包 ├── retargeting/ # GMR 动作重定向 ├── sim/ # SimulationLoop、参考运动工具 ├── sim2real/ # 真机状态机 -├── recording/ # HDF5Recorder +├── recording/ # Pico motion 录制辅助工具 ├── runtime/ # 配置解析、工厂、外部资源管理 ├── bus/ # InProcessBus 进程内通信 └── configs/ # Hydra YAML 配置 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md index 419a2222..e8f17538 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/download-assets.md @@ -20,25 +20,28 @@ python scripts/setup/download_assets.py 只下载推理所需的资源: ```bash -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## 资源清单 -| 资源 | 大小 | 用途 | -|------|------|------| -| `track.onnx` | 4 MB | ONNX 推理模型 | -| `track.pt` | 27 MB | PyTorch 检查点(用于恢复训练) | -| `data/datasets/seed/train/shard_*.npz` | ~25 GB | 训练数据集 | -| `data/datasets/seed/val/shard_*.npz` | ~1.4 GB | 验证数据集 | -| `data/sample_bvh/*.bvh` | 5 MB | 示例动捕文件 | -| `teleopit/retargeting/gmr/assets/` | ~1.2 GB | GMR 重定向机器人模型 | +checkpoint、数据集和资源包更新后,下载文件大小会变化。下表中的仓库路径才是稳定约定。 + +| 本地路径 | 用途 | +|----------|------| +| `track.onnx` | ONNX 推理模型 | +| `track.pt` | 用于恢复训练的 PyTorch checkpoint | +| `data/datasets/seed/shard_*.h5` | 最小运动数据集;训练前需先预计算 | +| `data/sample_bvh/*.bvh` | 示例动捕文件 | +| `assets/robots/unitree_g1/` | 训练、sim2sim、重定向和 FK 校验共用的 G1 canonical XML 与 mesh | +| `teleopit/retargeting/gmr/assets/` | GMR 重定向资源、IK 配置和非 canonical 机器人描述 | ## 资源分组 | 分组 | ModelScope 仓库 | 包含内容 | |------|-----------------|----------| | `ckpt` | `BingqianWu/Teleopit-models` | `track.onnx`、`track.pt` | +| `robots` | `BingqianWu/Teleopit-models` | Canonical 机器人 XML/mesh | | `gmr` | `BingqianWu/Teleopit-models` | GMR 重定向资源 | | `bvh` | `BingqianWu/Teleopit-models` | 示例 BVH 动捕文件 | | `data` | `BingqianWu/Teleopit-datasets` | 训练 / 验证数据分片 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md index 466d84f4..4f3028c0 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/installation.md @@ -32,7 +32,7 @@ pip install -e . pip install -e '.[train]' ``` -额外安装 `rsl-rl-lib`、`mjlab`、`wandb` 等训练相关依赖。 +额外安装 `rsl-rl-lib`、`mjlab`、`wandb`、`swanlab` 等训练相关依赖。 ### Sim2Real(硬件部署) @@ -40,7 +40,7 @@ pip install -e '.[train]' pip install -e '.[sim2real]' ``` -额外安装 `opencv-python` 和 `g1_bridge_sdk`。此外还需要初始化子模块并编译 C++ 桥接库: +额外安装 `opencv-python`。此外还需要初始化子模块并编译/安装 C++ `g1_bridge_sdk` 桥接库: ```bash git submodule update --init --recursive @@ -56,12 +56,38 @@ pip install -e '.[pico4]' ``` Teleopit 使用进程内的 `pico_bridge.PicoBridge` receiver 接收 Pico 追踪数据。 -Teleopit 0.3.0 请使用 pico-bridge 0.2.0,不要升级到 pico-bridge 0.2.1; -0.2.1 修改了接口语义,Teleopit 0.3.0 未按该语义适配。 +Teleopit 面向 pico-bridge 0.2.1 及其 `pico_native` tracking 语义。 receiver 可以运行在工作站 PC,也可以运行在机器人 onboard 计算机。 完整设置流程详见 [Pico Sim2Sim](../tutorials/pico-sim2sim) 和 [Pico Sim2Real](../tutorials/pico-sim2real)。 +Pico sim2real 可选的 LinkerHand 控制使用本地 third-party 包。初始化 +submodule 后,直接安装这些包: + +```bash +git submodule update --init --recursive +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand +scripts/setup/download_somehand_l6_assets.sh +``` + +只有在 `hands.enabled=true` 时才需要安装这些包。 + +### Sim2Real 录制 + +```bash +pip install -e '.[recording]' +``` + +该配置包含 Pico sim2real 栈,以及 `sim2real_record.yaml` 使用的视频依赖。 +RealSense Python 绑定与平台相关;使用 `input.video.source=realsense` 时, +需要在当前环境中手动安装 `pyrealsense2`。在 Arm 机器上,请使用 +conda-forge,而不是 pip 包: + +```bash +conda install -c conda-forge pyrealsense2 +``` + ## 验证安装 ```bash diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md index 6c1b6f3a..1fee2d36 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/getting-started/quick-start.md @@ -9,7 +9,7 @@ sidebar_position: 3 ## 前置条件 1. [安装 Teleopit](installation)(推理配置) -2. [下载资源](download-assets)(`--only gmr ckpt bvh`) +2. [下载资源](download-assets)(`--only robots gmr ckpt bvh`) ## 运行离线 Sim2Sim @@ -55,7 +55,7 @@ python scripts/run/run_sim.py controller.policy_path=track.onnx 'viewers=[retarg ## 下一步 -- [离线 Sim2Sim 教程](../tutorials/offline-sim2sim) - 包含录制和渲染的完整指南 +- [离线 Sim2Sim 教程](../tutorials/offline-sim2sim) - 包含渲染的完整指南 - [Pico Sim2Sim](../tutorials/pico-sim2sim) - 在 MuJoCo 中验证 Pico 追踪 - [独立站立测试](../tutorials/standalone-standing) - 检查 G1 bridge、网络和 policy 站立 - [Pico Sim2Real](../tutorials/pico-sim2real) - 将 Pico 遥操作部署到 Unitree G1 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md index 8856a069..30fa4e3c 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/intro.md @@ -20,7 +20,7 @@ slug: / ```text InputProvider (BVH / Pico4 VR) -> Retargeter (GMR) - -> ObservationBuilder (166D) + -> ObservationBuilder (167D) -> Controller (双输入 TemporalCNN ONNX) -> Robot (MuJoCo 仿真 或 Unitree G1) ``` @@ -30,8 +30,8 @@ InputProvider (BVH / Pico4 VR) | 项目 | 参数 | |------|------| | 策略频率 | 50 Hz | -| PD 控制频率 | 1000 Hz | -| 观测维度 | 166D | +| PD 控制频率 | 200 Hz | +| 观测维度 | 167D | | 动作维度 | 29D(G1 关节) | | ONNX 模型 | 双输入 TemporalCNN | | 运动重定向 | GMR(General Motion Retargeting) | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md index cbeaa013..5baef6ff 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/architecture.md @@ -4,76 +4,76 @@ sidebar_position: 1 # 架构 -本页介绍 Teleopit 的整体架构设计、核心模块边界及技术规格。 +面向开发者的系统内部结构和技术约束。 -## 运行主线 Pipeline +## Pipeline -Teleopit 的运行主线是一条线性数据流管线: - -``` -输入源 (Input) - → 重定向 (Retarget) - → 观测构建 (Observation) - → 策略推理 (Policy) - → PD 控制 (Controller) - → 仿真/实机执行 (Sim / Real) +```text +InputProvider(BVH 文件 / Pico4) + -> Retargeter(GMR) + -> ObservationBuilder(167D) + -> Controller(双输入 TemporalCNN ONNX) + -> Robot(MuJoCo 仿真或 Unitree G1) ``` -每个环节职责单一,通过明确定义的接口相互连接。 +离线/在线推理由 `teleopit/runtime/` 和 `teleopit/pipeline.py` 装配。硬件状态机通过 `teleopit/sim2real/mp/` 中的进程隔离运行时执行。训练由 `train_mimic/` 提供。 ## 代码结构 -``` -teleopit/ -├── app.py # 应用入口 -├── interfaces.py # 核心接口定义 -├── runtime/ # 运行时配置装配与启动逻辑 -├── pipeline/ # 数据流管线 -├── sim2real/ # 实机部署适配层 -├── observation/ # 观测构建 -├── rl_policy/ # 强化学习策略推理 -├── task/ # 任务配置 -└── dataset_builder/ # 数据集构建工具 +```text +configs / scripts + -> runtime + -> interfaces + pipeline state machines + -> adapters(inputs / retargeting / controller / robot / recording) + +train_mimic/scripts + -> train_mimic/app.py + -> single task registry / env builder / runner cfg + -> mjlab / rsl_rl + +train_mimic/scripts/data + -> train_mimic/data/dataset_builder.py + -> dataset_lib / motion_fk / convert_pkl_to_npz ``` ## 核心模块边界 -| 模块 | 文件/目录 | 职责 | -|---|---|---| -| 接口层 | `interfaces.py` | 定义所有核心抽象接口,模块间仅通过接口通信 | -| 运行时 | `runtime/` | Hydra 配置加载、对象组装、依赖注入 | -| Pipeline | `pipeline/` | 数据流编排,驱动每一帧的采样-推理-执行循环 | -| Sim2Real | `sim2real/` | 实机通信适配(DDS 桥接、状态同步) | -| 观测 | `observation/` | 从仿真/实机状态构建策略所需的观测向量 | -| 策略 | `rl_policy/` | ONNX 模型加载与推理,action 后处理 | -| 入口 | `app.py` | 命令行入口,调用 runtime 装配并启动 pipeline | -| 任务配置 | `task/` | Hydra 配置文件(YAML) | -| 数据集 | `dataset_builder/` | 动捕数据转换、NPZ 打包、数据集分片 | +| 模块 | 职责 | +|------|------| +| `teleopit/interfaces.py` | 稳定协议:InputProvider、Retargeter、Controller、Robot、ObservationBuilder | +| `teleopit/runtime/` | 配置解析、路径规范化、组件装配、CLI 校验 | +| `teleopit/pipeline.py` | 离线仿真的轻量 facade | +| `teleopit/sim2real/mp/` | 进程隔离的 sim2real 状态机、IPC 和机器人控制循环 | +| `teleopit/controllers/observation.py` | ObservationBuilder | +| `teleopit/controllers/rl_policy.py` | 接受观测维度与运行时 builder 匹配的双输入 ONNX | +| `train_mimic/app.py` | 共享的训练/播放/benchmark 装配 | +| `train_mimic/tasks/tracking/config/` | 单一任务注册(`General-Tracking-G1`) | +| `train_mimic/data/dataset_builder.py` | 唯一官方数据集构建入口 | ## 技术规格 | 项目 | 规格 | |---|---| -| 仿真引擎 | MuJoCo | -| 策略推理 | ONNX Runtime | -| 配置系统 | Hydra | -| 实机通信 | DDS(Cyclone DDS) | -| 支持的机器人 | 宇树 G1 | -| 输入源 | BVH 文件、Pico 4 VR 头显 | +| 训练任务 | `General-Tracking-G1` | +| 推理观测 | `velcmd_history`(167D) | +| ONNX 签名 | 双输入 `obs`(167D)+ `obs_history` | +| Actor/Critic | TemporalCNN(2048、1024、512、256、128) | +| 训练采样 | 默认 `rewind`;也支持 `uniform`;播放/评估使用 `start` | +| 训练 `window_steps` | `[0]` | +| 数据格式 | 可递归发现的最小 HDF5 shard(`shard_*.h5`) | ## 约束 -- **单机器人**:当前架构假设同一时刻只控制一台机器人。 -- **固定观测格式**:观测构建器的输出维度在初始化时确定,运行时不可变。 -- **同步 pipeline**:pipeline 各阶段串行执行,策略推理频率即为 pipeline 的帧率。 -- **ONNX 模型**:策略必须导出为 ONNX 格式,不直接支持 PyTorch checkpoint。 +- 必须显式提供 `controller.policy_path`,且文件必须存在 +- 离线 BVH 运行必须显式提供 `input.bvh_file` +- `viewers` 是唯一的 viewer 配置入口 +- 观测/ONNX 维度不匹配会在启动时立即报错 +- sim2real 也要求双输入 ONNX,且观测维度必须与运行时 builder 匹配 ## 公共接口 -以下接口构成 Teleopit 的公共 API 面,外部扩展应仅依赖这些接口: +**稳定运行模式:** 离线 sim2sim、离线 sim2real playback、Pico4 sim2sim、G1 sim2real -- `interfaces.py` 中定义的抽象基类(InputProvider、ObsBuilder、Controller 等) -- `runtime/` 提供的工厂注册机制 -- Hydra 配置 schema +**稳定训练入口:** `train.py`、`play.py`、`benchmark.py`、`save_onnx.py` -内部实现细节(如 pipeline 的具体调度逻辑)不属于公共 API,可能在版本迭代中变更。 +**稳定数据入口:** `build_dataset.py`、`precompute_dataset.py` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md index 733afcf6..2451e7b7 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/assets.md @@ -4,11 +4,12 @@ sidebar_position: 2 # 资源管理 -机器人模型、数据集、checkpoint 和演示媒体不进 Git 历史,统一走外部下载。 +数据集、checkpoint、机器人模型和演示媒体不进 Git 历史,统一走外部下载。Unitree G1 的 canonical 模型下载到 `assets/robots/unitree_g1/g1_29dof.xml`。 ## 不入库的内容 -- `teleopit/retargeting/gmr/assets/` — 机器人 mesh、URDF/MJCF 等 +- `assets/robots/` — canonical 机器人 XML/mesh +- `teleopit/retargeting/gmr/assets/` — GMR 重定向资源、IK 配置和非 canonical 机器人描述 - `data/`、checkpoint、缓存等生成产物 - 演示媒体(`assets/demo.gif`、`assets/demo.mp4`) @@ -33,9 +34,10 @@ sidebar_position: 2 | 组 | 仓库 | 远端路径 | |----|------|---------| | `ckpt` | Teleopit-models | `checkpoints/track.onnx`、`checkpoints/track.pt` | +| `robots` | Teleopit-models | `archives/robot_assets.tar.gz` | | `gmr` | Teleopit-models | `archives/gmr_assets.tar.gz` | | `bvh` | Teleopit-models | `archives/sample_bvh.tar.gz` | -| `data` | Teleopit-datasets | `data/train/`、`data/val/` | +| `data` | Teleopit-datasets | `data/` | ## 下载 @@ -46,7 +48,7 @@ sidebar_position: 2 python scripts/setup/download_assets.py # 只下载推理必需的资源 -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh # 只下载训练数据 python scripts/setup/download_assets.py --only data @@ -61,17 +63,17 @@ python scripts/setup/download_assets.py --source huggingface |---------|---------| | `checkpoints/track.onnx` | `track.onnx` | | `checkpoints/track.pt` | `track.pt` | +| `archives/robot_assets.tar.gz` | `assets/robots/`(自动解压) | | `archives/gmr_assets.tar.gz` | `teleopit/retargeting/gmr/assets/`(自动解压) | | `archives/sample_bvh.tar.gz` | `data/sample_bvh/`(自动解压) | -| `data/train/` | `data/datasets/seed/train/` | -| `data/val/` | `data/datasets/seed/val/` | +| `data/` | `data/datasets/seed/` | ## 上传到 ModelScope ### 第一步:准备上传目录 ```bash -python scripts/setup/prepare_modelscope_assets.py --only ckpt gmr bvh --clean +python scripts/setup/prepare_modelscope_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/prepare_modelscope_assets.py --only data ``` @@ -115,7 +117,7 @@ tag 与代码仓库的 Git tag 保持一致,方便追溯每个版本对应的 python scripts/setup/upload_hf_assets.py --dry-run --clean # 只准备指定组 -python scripts/setup/upload_hf_assets.py --only ckpt gmr bvh --dry-run +python scripts/setup/upload_hf_assets.py --only ckpt robots gmr bvh --dry-run python scripts/setup/upload_hf_assets.py --only data --dry-run ``` @@ -124,7 +126,7 @@ python scripts/setup/upload_hf_assets.py --only data --dry-run ### 第二步:执行上传 ```bash -python scripts/setup/upload_hf_assets.py --only ckpt gmr bvh --clean +python scripts/setup/upload_hf_assets.py --only ckpt robots gmr bvh --clean python scripts/setup/upload_hf_assets.py --only data --clean ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md index b7b17831..eb497301 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/dataset.md @@ -7,62 +7,86 @@ sidebar_position: 3 ## 下载预构建数据集(推荐) ```bash -python scripts/setup/download_assets.py --only data +python scripts/setup/download_assets.py --only robots data ``` -下载后直接传 shard 目录用于训练: +下载后先生成预计算训练 shard,再把预计算数据集根目录用于训练: ```bash -python train_mimic/scripts/train.py --motion_file data/datasets/seed/train +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +python train_mimic/scripts/train.py --motion_file data/datasets/seed_precomputed ``` 如需自定义构建,继续阅读下文。 --- +## 录制 Pico clips + +使用交互式 Pico 录制脚本,从实时 body tracking 生成训练可用的 NPZ clips: + +```bash +pip install -e '.[pico4]' +python scripts/run/record_pico_motion.py +``` + +录制器会先启动 Pico receiver 和实时 `Retarget` viewer,再等待输入 clip 名; +因此终端空闲时预览仍会持续运行。输入动作语义名后,用 `R` 开始录制、`S` +保存、`D` 丢弃、`N` 输入新名字、`Q` 退出。保存的 clip 会写入 +`data/pico_motion/clips/`,文件名格式为 `_.npz`;不会写 +每段 clip 的 JSON,因此可以手动改名或删除。 + +将所有已录制 clips 构建为标准 HDF5 shard 数据集: + +```bash +python train_mimic/scripts/data/build_dataset.py \ + --spec data/pico_motion/pico_recorded.yaml --force +``` + +预处理后至少需要保留一段有效 clip。 + ## 自定义构建 -数据主线:`typed source YAML -> preprocess/filter -> shard-only 训练数据` +数据主线:`typed source YAML -> preprocess/filter -> minimal HDF5 shards -> precomputed training dataset` ```bash python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml + --spec train_mimic/configs/datasets/twist2.yaml ``` ## 输出目录结构 ```text data/datasets// -├── clips/ # 可选;仅在需要逐 clip 中间产物时存在 -│ └── /... -├── train/ -│ └── shard_*.npz -├── val/ -│ └── shard_*.npz -├── manifest_resolved.csv -└── build_info.json +└── shard_*.h5 + +data/datasets/_precomputed/ +└── shard_*.h5 ``` -- 若 spec 包含 `bvh` 或 `npz` source,builder 会保留/生成 `clips/` -- 若 spec 全部是 `pkl` 或 `seed_csv` source,直接并行产出 split 级别的 shard,默认不写中间 clip 文件 +- 若 spec 包含 `bvh` 或 `npz` source,完整 dataset builder 会在转换期间使用临时 `clips/` 目录,并在 shard 写入完成后删除。重新 build 不会复用已转换 clips。 +- 若 spec 全部是 `pkl` 或 `seed_csv` source,builder 会直接并行产出 shard,默认不写中间 clip 文件 +- `build_dataset.py` 只写最小分发数据集,不执行 FK 预计算。 +- `precompute_dataset.py` 会写出独立的训练数据集,里面包含最小运动数据以及预计算的 joint velocity 和 body FK/velocity。 +- 训练只接受预计算后的数据集目录。它会递归发现指定根目录下的预计算 `*.h5` shard,因此可以把多个预计算数据集目录放到同一个父目录下完成合并。 +- 训练会在启动时把所有发现的预计算 motion window 全量加载到内存中。joint velocity 和 body FK/velocity 不会在训练时计算。 ## YAML spec -示例(`train_mimic/configs/datasets/twist2_full.yaml`): +示例(`train_mimic/configs/datasets/twist2.yaml`): ```yaml -name: twist2_full +name: twist2 target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: first_frame_foot sources: - name: OMOMO_g1_GMR type: pkl input: data/twist2_retarget_pkl/OMOMO_g1_GMR - - name: lafan1_v1 + - name: lafan1 type: bvh input: data/lafan1_bvh bvh_format: lafan1 @@ -74,68 +98,71 @@ sources: |------|------| | `name` | 数据集名称,对应输出目录 `data/datasets//` | | `target_fps` | 写入 shard 前统一重采样到的目标帧率 | -| `val_percent` | 基于 `clip_id` hash 的验证集比例 | -| `hash_salt` | 可选 split salt | | `preprocess.normalize_root_xy` | 是否把根 body 首帧 xy 平移到原点 | -| `preprocess.ground_align` | `none` / `clip_min_foot` | +| `preprocess.ground_align` | `none` / `first_frame_foot` | | `preprocess.min_frames` | clip 最短长度约束 | | `preprocess.max_root_lin_vel` / `min_peak_body_height` / `max_all_off_ground_s` | 基础过滤阈值 | -| `sources[].name` | source 名称;生成 clip 中间产物时也作为 `clips//` 子目录名 | +| `sources[].name` | source 名称 | | `sources[].type` | `bvh` / `pkl` / `npz` / `seed_csv` | | `sources[].input` | 原始输入文件或目录 | -| `sources[].weight` | 可选源级别采样权重,默认 `1.0` | | `sources[].bvh_format` | 仅 `bvh` source 必填:`lafan1` / `hc_mocap` / `nokov` | | `sources[].robot_name` | 仅 `bvh` source,默认 `unitree_g1` | | `sources[].max_frames` | 仅 `bvh` source,`0` 表示全长 | +## 转换规则 + +所有 source 都会转换成标准最小 shard。每段 clip 会先经过预处理/过滤,再写入 shard: + +- `bvh -> retarget pkl -> npz clip` +- `pkl -> npz clip`(或在 pkl-only 数据集中直接 batch 写 shard) +- `npz -> validate + copy/reuse` + +每个最小 shard 保存 `root_pos`、`root_quat_w`、`joint_pos`、`body_names`、`clip_starts`、`clip_lengths` 和 `clip_fps`。预计算训练 shard 保存 `joint_pos`、`joint_vel`、`body_pos_w`、`body_quat_w`、`body_lin_vel_w`、`body_ang_vel_w` 以及相同的元数据。如果 `--motion_file` 指向最小数据集而不是预计算训练数据集,训练会立即报错。 + ## 常用命令 ```bash # 强制重建 python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --force + --spec train_mimic/configs/datasets/twist2.yaml --force # 多进程并行 python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --jobs 8 + --spec train_mimic/configs/datasets/twist2.yaml --jobs 8 # 自定义输出根目录 python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml \ + --spec train_mimic/configs/datasets/twist2.yaml \ --output_root /tmp/my_datasets # 打印 build report python train_mimic/scripts/data/build_dataset.py \ - --spec train_mimic/configs/datasets/twist2_full.yaml --json + --spec train_mimic/configs/datasets/twist2.yaml --json + +# 从已有最小数据集生成预计算训练数据集 +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/twist2 --outdir data/datasets/twist2_precomputed --jobs 8 --force + +# 查看数据集统计 +python train_mimic/scripts/data/inspect_dataset.py data/datasets/twist2 ``` ## 批量转换为 NPZ clips -只把某批原始数据转成标准 NPZ clip,不做 train/val merge: +只把某批原始数据转成标准 NPZ clip,不合并为 shard: ```bash python train_mimic/scripts/data/ingest_motion.py \ --type bvh --input data/lafan1_bvh \ - --output data/datasets/lafan1_v1/clips/lafan1_v1 \ - --source lafan1_v1 --bvh_format lafan1 --jobs 8 + --output data/lafan1_clips/lafan1 \ + --source lafan1 --bvh_format lafan1 --jobs 8 ``` ## FK 一致性检查 ```bash python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` 推荐判据:`pos_max < 1e-3 m`、`quat_mean < 0.05 rad`、`quat_p95 < 0.10 rad`。 - -## 重新切分 shard - -```bash -python train_mimic/scripts/data/split_shards.py \ - --input data/datasets/seed/train \ - --output data/datasets/seed/train_small_shards \ - --max_size_gb 2 -``` - -每个 shard 是自包含的 merged NPZ(含完整 clip metadata),训练时直接传目录。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md index 84d119c3..14a8500a 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/reference/training-troubleshooting.md @@ -35,7 +35,7 @@ sidebar_position: 5 ```bash python train_mimic/scripts/data/check_motion_npz_fk.py \ - --npz data/datasets//clips//.npz + --npz data/lafan1_clips/lafan1/.npz ``` 推荐判据:`pos_max < 1e-3 m`、`quat_mean < 0.05 rad`、`quat_p95 < 0.10 rad`。 @@ -45,7 +45,7 @@ python train_mimic/scripts/data/check_motion_npz_fk.py \ ```bash python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets//train + --motion_file data/datasets/_precomputed ``` 预期:`Mean episode length` 明显大于 1,`error_anchor_pos` 开始下降。 @@ -126,7 +126,7 @@ self.sim.nconmax = 150_000 ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets//val \ + --motion_file data/datasets/_precomputed \ --num_envs 1 --num_eval_steps 2000 \ --video --video_length 600 ``` @@ -168,6 +168,6 @@ print(cfg.init_state.joint_pos) # 必须与 g1.yaml default_angles 一致 ### 解决方案 -更新 `teleopit/configs/robot/g1.yaml` 和 `g1_mjlab.xml`,使其与训练环境的值一致(default angles、armature、condim)。 +更新 `teleopit/configs/robot/g1.yaml` 和 `assets/robots/unitree_g1/g1_29dof.xml`,使其与训练环境的值一致(default angles、armature、condim)。 此修复同时影响 sim2real 路径,因为 `default_angles` 被 `rl_policy.py` 和 `observation.py` 共用。 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md index 80644d14..59acadaa 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/bvh-sim2real.md @@ -73,12 +73,6 @@ real_robot.network_interface=enp130s0 # 在 BVH 最后一帧暂停 playback.pause_on_end=true -# 从 standing/当前机器人状态平滑进入回放 -transition_duration=2.0 - -# 离线回放恢复混合时长 -pause_resume_transition_duration=1.0 - # 控制循环频率 policy_hz=50 ``` diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md index 66d317f7..b5731659 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/offline-sim2sim.md @@ -68,20 +68,6 @@ viewers=none # 无头模式(不显示窗口) 当所有 Viewer 窗口被关闭后,仿真会自动结束。 ::: -## 录制 - -将仿真数据录制为 HDF5 文件: - -```bash -python scripts/run/run_sim.py \ - controller.policy_path=track.onnx \ - input.bvh_file=data/sample_bvh/aiming1_subject1.bvh \ - +record=true \ - recording.output_path=outputs/session.h5 -``` - -录制包含以下字段:`joint_pos`、`joint_vel`、`mimic_obs`、`action`、`target_dof_pos`、`torque`、`timestamp`。 - ## 离线渲染 在无头模式下将仿真渲染为视频: diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md index ba163c74..23fd2ce4 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2real.md @@ -21,8 +21,7 @@ Pico 头显 -> Teleopit host -> retarget -> RL policy -> g1_bridge_sdk -> G1 两种方式都使用 `Pico4InputProvider` 和进程内 pico-bridge receiver。不存在单独的 onboard Pico 输入模式。 -Teleopit 0.3.0 请保持 host receiver 使用 pico-bridge 0.2.0。pico-bridge -0.2.1 修改了接口语义,不是该 Teleopit 版本支持的 receiver 版本。 +Teleopit 面向 pico-bridge 0.2.1 及其 `pico_native` tracking 语义。 ## 1. 安装运行时依赖 @@ -92,6 +91,32 @@ python scripts/run/run_sim2real.py \ real_robot.network_interface=eth0 ``` +## 可选 HDF5 录制 + +在负责 Pico 输入和 RealSense 的机器上安装 recording extra: + +```bash +pip install -e '.[recording]' +``` + +运行录制配置: + +```bash +python scripts/run/run_sim2real.py \ + --config-name sim2real_record \ + controller.policy_path=track.onnx \ + real_robot.network_interface=enp130s0 \ + recording.task="walk forward" +``` + +终端控制为:`R` 开始 episode,`S` 保存,`D` 丢弃,`Q` 关闭。可以录制 +`STANDING`、`MOCAP`、`ARMS` 和暂停状态的 mocap;已经保存的 episode 不支持再丢弃。 +episode 会保存为 `data/recordings/sim2real_hdf5/episodes/` 下的 `.h5` 文件, +压缩 MP4 sidecar 视频保存在 `data/recordings/sim2real_hdf5/videos/` 下。 +HDF5 episode 以 30 Hz 保存 `frame_index` 和 `timestamp` 同步数组,以及 +`observation.state(68)`、`observation.mode(1)`、`action(36)` 和 +`action.hand(12)`。 + ## 操作流程 始终把 Unitree 遥控器拿在手里。`L1+R1` 是进入 `DAMPING` 的急停路径。 @@ -101,6 +126,7 @@ python scripts/run/run_sim2real.py \ | Unitree remote `Start` | 进入 `STANDING` | | Unitree remote `Y` | 进入 `MOCAP` | | Pico/controller `A` | 暂停 / 恢复实时动捕 | +| Pico/controller `B` | 在 `MOCAP` / `ARMS` 之间切换 | | Unitree remote `X` | 返回 `STANDING` | | Unitree remote `L1+R1` | 急停(`DAMPING`) | @@ -116,11 +142,15 @@ Pico body frames -> retarget -> reference buffer -> observation -> policy -> G1 ``` 进入 `STANDING` 时,Teleopit 会释放当前 Unitree 模式,进入 debug/low-level 控制, -短暂锁住当前关节,重置 policy 状态,并通过 Kp ramp 减少启动冲击。 +短暂锁住当前关节,重置 policy 状态,并在不改变 policy target 的情况下执行 Kp ramp。 -进入 `MOCAP` 时,Teleopit 会重置 policy/reference 状态,并将参考从当前机器人状态平滑过渡到 +进入 `MOCAP` 时,Teleopit 会重置 policy/reference 状态,并通过实时参考时间线开始跟踪 实时 mocap 命令。 +`ARMS` 会保持同一条实时 retargeting 时间线继续运行,但发送给 motion tracker 的参考会被组合: +身体、腰部和腿部保持站立姿态,双臂跟随实时 retarget 结果。进入或离开 `ARMS` 时会重置 +policy/reference 对齐,并使用同一套 Kp ramp 安全路径。 + ## 暂停 / 恢复 Pico 暂停/恢复是 mocap-session control event。 @@ -133,6 +163,88 @@ Pico 暂停/恢复是 mocap-session control event。 恢复时请保持静止,并尽量接近暂停时的姿态。这样可以减少实时追踪恢复时的参考突变。 ::: +## 可选 LinkerHand 控制 + +Pico sim2real 可以用 Pico 输入控制 LinkerHand: + +- `gripper`:按住同侧 grip 作为 deadman,同侧 trigger 控制对应手闭合。 + 该模式支持 `hands.driver=linkerhand_l6` 和 `hands.driver=linkerhand_o6`; + 速度和张开/闭合姿态来自对应 driver 配置。 +- `vr_hand_pose`:只支持 L6,通过 somehand 重定向 Pico 手部 pose,并下发连续 L6 手部目标。 + 如果某侧手部 pose 消失,该侧会保持上一条手势命令。这个模式使用 Teleopit 的 + Pico landmark 适配器和 somehand 0.2.0 公开的 `somehand.api`,并始终将 L6 + 速度设为最大值。默认配置使用 60 Hz 的低延时 somehand 路径并减少平滑,所以响应会更快, + 但可能比标准 somehand 设置更抖。 + +`hands.enabled=true` 时,手控会在所有 sim2real 模式中保持生效。退出和手控运行时失败会发送配置的张开姿态。 + +如果主 Pico profile 没有包含手控支持,先安装本地手控包: + +```bash +git submodule update --init --recursive +pip install -e third_party/linkerhand-python-sdk +pip install -e third_party/somehand +scripts/setup/download_somehand_l6_assets.sh +``` + +测试或运行手控前,先开启 CAN 接口: + +```bash +sudo /usr/sbin/ip link set can0 up type can bitrate 1000000 +sudo /usr/sbin/ip link set can1 up type can bitrate 1000000 +``` + +启用完整 sim2real 前,先用独立开合测试验证灵巧手连接。测试默认一直运行到 Ctrl-C: + +```bash +python scripts/dev/test_linkerhand_l6.py \ + --hand-type both \ + --left-can can0 \ + --right-can can1 +``` + +O6 独立开合测试需要加上 O6 driver: + +```bash +python scripts/dev/test_linkerhand_l6.py \ + --driver linkerhand_o6 \ + --hand-type both \ + --left-can can0 \ + --right-can can1 +``` + +如果要用实时 Pico gripper 输入测试 O6,再加 `--mode gripper`。 + +然后在 Pico sim2real 中启用 L6 gripper 控制: + +```bash +hands.enabled=true +hands.driver=linkerhand_l6 +hands.mode=gripper +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 +``` + +O6 gripper 控制使用: + +```bash +hands.enabled=true +hands.driver=linkerhand_o6 +hands.mode=gripper +hands.linkerhand_o6.left_can=can0 +hands.linkerhand_o6.right_can=can1 +``` + +连续 VR 手部 pose 控制使用: + +```bash +hands.enabled=true +hands.driver=linkerhand_l6 +hands.mode=vr_hand_pose +hands.linkerhand_l6.left_can=can0 +hands.linkerhand_l6.right_can=can1 +``` + ## 可选 RealSense 预览 将 G1 RealSense 彩色相机推送回 Pico 头显: @@ -163,15 +275,14 @@ input.bridge_advertise_ip=192.168.1.20 # 进入 MOCAP 前要求的连续有效动捕帧数 mocap_switch.check_frames=10 -# 平滑过渡到 mocap 参考 -transition_duration=2.0 - -# 恢复前采集的实时帧数 -pause_resume_warmup_steps=2 - # 更换 Pico 暂停键 input.pause_button=right_axis_click +# 开启 LinkerHand gripper 控制 +hands.enabled=true +hands.driver=linkerhand_l6 +hands.mode=gripper + # 开启头显视频预览 input.video.enabled=true ``` @@ -185,4 +296,5 @@ input.video.enabled=true | 无法进入 debug mode | Unitree mode 释放失败 | 停止其他机器人模式后再次按 `Start` | | 机器人进入 `STANDING` 但不进入 `MOCAP` | 动捕验证失败 | 保持追踪稳定,查看 `mocap_switch.check_frames` 日志 | | Pico 暂停没有返回 `STANDING` | 这是预期行为 | Pico 暂停只冻结 mocap;按遥控器 `X` 返回 `STANDING` | +| LinkerHand 不动 | `hands.enabled=false`、gripper deadman 未按住、SDK/资产未安装,或 CAN 通道错误 | 设置 `hands.enabled=true` 和 `hands.mode`,运行 `scripts/dev/test_linkerhand_l6.py`,并检查所选 driver 的 `left_can` / `right_can` | | 视频预览不可用 | RealSense 或视频源失败 | 检查相机权限、`input.video.source` 和日志 | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md index 867f78f7..02eb6143 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/pico-sim2sim.md @@ -46,14 +46,13 @@ python -c "from pico_bridge import PicoBridge; print('OK')" Teleopit 会通过 `Pico4InputProvider` 在进程内启动 `pico_bridge.PicoBridge`。 后续 wired 和 onboard sim2real 部署也使用同一条 Pico 输入路径。 -Teleopit 0.3.0 请保持 host receiver 使用 pico-bridge 0.2.0。pico-bridge -0.2.1 修改了接口语义,不是该 Teleopit 版本支持的 receiver 版本。 +Teleopit 面向 pico-bridge 0.2.1 及其 `pico_native` tracking 语义。 ## 3. 下载资源 ```bash pip install modelscope -python scripts/setup/download_assets.py --only gmr ckpt bvh +python scripts/setup/download_assets.py --only robots gmr ckpt bvh ``` ## 4. 运行 Pico Sim2Sim @@ -70,6 +69,7 @@ python scripts/run/run_sim.py \ |------|------| | `Y` | 进入 `MOCAP` | | `A` | 暂停 / 恢复实时动捕 | +| `B` | 在 `MOCAP` / `ARMS` 之间切换 | | `X` | 返回 `STANDING` | | `Q` | 退出 | @@ -86,9 +86,12 @@ Pico 暂停/恢复会冻结 mocap session;它不是切回 `STANDING`。 默认 Pico 暂停键是 `A`。支持的覆盖值包括 `B`、`X`、`Y`、`left_axis_click`、 `right_axis_click`、`left_menu_button` 和 `right_menu_button`。 +默认 Pico 双臂模式按钮是 `B`。`ARMS` 会让身体、腰部和腿部保持站立姿态,同时双臂跟随 +实时 retarget 结果。 + ## 可选头显视频预览 -pico-bridge 0.2.0 可以在头显中显示 host 侧视频流。在仿真中,Teleopit 可以推送 +pico-bridge 0.2.1 可以在头显中显示 host 侧视频流。在仿真中,Teleopit 可以推送 MuJoCo `d435i_rgb` 相机: ```bash @@ -132,7 +135,7 @@ input.video.enabled=true | 现象 | 可能原因 | 解决方法 | |------|----------|----------| | `ImportError: pico_bridge` | 未安装 Pico extra | 执行 `pip install -e '.[pico4]'` | -| 启动提示 pico-bridge 太旧 | 已安装 receiver 不支持视频参数 | 重新安装 Pico extra,确保使用 pico-bridge 0.2.0 | +| 启动提示 pico-bridge 太旧 | 已安装 receiver 不支持所需 API 或 tracking 语义 | 重新安装 Pico extra,确保使用 pico-bridge 0.2.1 | | `TimeoutError: No Pico4 body data` | 头显未连接或 body tracking 未激活 | 检查头显 app、网络和 `input.pico4_timeout` | | discovery 找不到 host | 广播 IP 不对或 UDP 被阻断 | 设置 `input.bridge_advertise_ip=`,确认 UDP 端口 `63901` 可达 | | 仿真机器人不跟随 | 循环仍在 `STANDING` | 追踪准备好后按 `Y` | diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md index 14802910..54bb06e6 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/standalone-standing.md @@ -5,8 +5,8 @@ sidebar_position: 3 # 独立站立测试 在接入完整 sim2real 控制之前,如果正在调试新机器人、网络配置或 policy,先运行此测试。 -它不使用 Pico、BVH 回放、retargeting,也不走完整 Teleopit mocap pipeline,只验证 G1 bridge -和 RL standing 路径。 +它不使用 Pico、BVH 回放、retargeting,也不走完整 mocap pipeline,只验证 G1 bridge +和 sim2real 使用的同一条 RL standing 路径。 ```text G1 LowState -> standing observation -> RL policy -> G1 LowCmd targets @@ -57,12 +57,27 @@ python scripts/run/standalone_standing.py \ --network-interface eth0 ``` +standalone standing 复用 sim2real standing 组件:`UnitreeG1Robot`、 +`Sim2RealSafetyManager`、`RLPolicyController`、`VelCmdObservationBuilder` 和 +`Sim2RealReferenceProcessor`。锁住当前关节后发送 policy target,同时在 2 秒内把 Kp +从 10% 逐步升到配置的增益。可以这样调整启动行为: + +```bash +python scripts/run/standalone_standing.py \ + --policy track.onnx \ + --network-interface eth0 \ + --kp-ramp-duration 2.0 \ + --kp-ramp-floor-ratio 0.1 +``` + ## 它会检查什么 - `g1_bridge_sdk` 能正确导入。 - 能从机器人收到 LowState。 - dual-input ONNX policy 能运行 standing observation 路径。 - 能通过 C++ bridge 发布 low-level position targets。 +- observation 构建、action scale、默认站姿、Kp ramp 和 joint-limit clipping 与 sim2real + standing runtime 保持一致。 ## 下一步 diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md index e8f23f80..84a9b366 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/tutorials/training.md @@ -23,6 +23,14 @@ pip install -e '.[train]' python -c "import train_mimic.tasks; print('training OK')" ``` +下载最小 seed 数据集,并生成预计算训练 shard: + +```bash +python scripts/setup/download_assets.py --only robots data +python train_mimic/scripts/data/precompute_dataset.py \ + data/datasets/seed --outdir data/datasets/seed_precomputed --jobs 8 +``` + ## 训练 ### 冒烟测试 @@ -31,7 +39,7 @@ python -c "import train_mimic.tasks; print('training OK')" python train_mimic/scripts/train.py \ --num_envs 64 \ --max_iterations 100 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed_precomputed ``` ### 完整训练 @@ -40,7 +48,7 @@ python train_mimic/scripts/train.py \ python train_mimic/scripts/train.py \ --num_envs 4096 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed_precomputed ``` ### 多卡训练 @@ -50,13 +58,33 @@ python train_mimic/scripts/train.py \ --gpu_ids 0 1 2 3 \ --num_envs 1024 \ --max_iterations 30000 \ - --motion_file data/datasets/seed/train + --motion_file data/datasets/seed_precomputed +``` + +### 多机多卡训练 + +跨多台机器训练时,直接使用 `torchrun`: + +```bash +torchrun \ + --nnodes=$PET_NNODES \ + --nproc_per_node=$PET_NPROC_PER_NODE \ + --node_rank=$PET_NODE_RANK \ + --master_addr=$PET_MASTER_ADDR \ + --master_port=$PET_MASTER_PORT \ + train_mimic/scripts/train.py \ + --num_envs 1024 \ + --max_iterations 1000 \ + --motion_file data/datasets/seed_precomputed ``` **注意事项:** - 多卡模式下 `--num_envs` 为每张 GPU 的环境数量 -- 默认日志工具为 TensorBoard;传入 `--wandb_project ` 可启用 W&B -- `--motion_file` 仅接受分片目录(包含 `shard_*.npz` 文件的目录) +- 多机模式下 `--num_envs` 也按每个进程计算,因此总环境数会随 `world_size` 线性增长 +- 默认日志工具为 TensorBoard。使用 `--logger wandb` 或 `--logger swanlab` 可选择 W&B 或 SwanLab;项目名默认使用 `experiment_name` +- `--motion_file` 接受预计算训练数据集根目录或单个预计算 `.h5` shard;shard 会递归发现 +- 如果只有最小分发 shard,先运行 `python train_mimic/scripts/data/precompute_dataset.py --outdir `,再把预计算输出传给训练。 +- 训练会在启动时把所有发现的预计算 motion window 全量加载到内存中。 - `--max_iterations` 表示追加迭代次数;例如从 `model_12000.pt` 恢复训练并设置 `--max_iterations 18000`,最终将训练到 `model_30000.pt` ## 导出 ONNX @@ -68,7 +96,7 @@ python train_mimic/scripts/save_onnx.py \ --history_length 10 ``` -导出的模型为双输入 ONNX(`obs` + `obs_history`)。推理端仅支持 166D 双输入 ONNX 格式。 +导出的模型为双输入 ONNX(`obs` + `obs_history`)。推理端需要与当前 `velcmd_history` 观测匹配的 167D 双输入 ONNX 策略。 ## 评估 @@ -77,7 +105,7 @@ python train_mimic/scripts/save_onnx.py \ ```bash python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val + --motion_file data/datasets/seed_precomputed ``` ### 定量评估 @@ -85,7 +113,7 @@ python train_mimic/scripts/play.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 ``` @@ -94,7 +122,7 @@ python train_mimic/scripts/benchmark.py \ ```bash python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_general_tracking//model_30000.pt \ - --motion_file data/datasets/seed/val \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 \ --video \ --video_length 600 @@ -113,4 +141,4 @@ train_mimic/scripts - `train_mimic/app.py` - 训练/播放/评估的统一入口 - `train_mimic/tasks/tracking/config/env.py` - General-Tracking-G1 环境构建器 - `train_mimic/tasks/tracking/config/rl.py` - TemporalCNN PPO 配置 -- `train_mimic/tasks/tracking/mdp/commands.py` - 支持 `adaptive` / `uniform` / `start` 三种采样模式 +- `train_mimic/tasks/tracking/mdp/commands.py` - 支持 `uniform`、`start` 和 `rewind` 采样模式。训练默认使用 `rewind`;播放/评估使用 `start`。 diff --git a/pyproject.toml b/pyproject.toml index a17b4714..4bb6ce39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,14 +4,14 @@ build-backend = "setuptools.build_meta" [project] name = "teleopit" -version = "0.3.0" +version = "0.4.0" description = "Teleoperation framework for humanoid robots with motion retargeting" authors = [ {name = "Teleopit Team"} ] readme = "README.md" requires-python = ">=3.10" -license = {text = "MIT"} +license = {text = "Apache-2.0"} dependencies = [ "mujoco", "mink", @@ -23,6 +23,7 @@ dependencies = [ "omegaconf", "h5py", "onnxruntime", + "pyzmq", "rich", "loop-rate-limiters", "imageio", @@ -40,17 +41,24 @@ sim2real = [ # bash scripts/setup/setup_g1_bridge.sh ] train = [ - "torch>=2.0.0", + "torch>=2.7.0", "numpy>=1.20.0", - "rsl-rl-lib", - "mjlab>=1.2.0", - "wandb>=0.15.0", + "rsl-rl-lib==5.2.0", + "mjlab==1.4.0", + "wandb>=0.22.3", + "swanlab", "tqdm>=4.65.0", ] pico4 = [ - "pico-bridge[camera] @ https://github.com/BotRunner64/pico-bridge/releases/download/v0.2.0/pico_bridge-0.2.0-py3-none-any.whl", + "pico-bridge[camera] @ https://github.com/BotRunner64/pico-bridge/releases/download/v0.2.1/pico_bridge-0.2.1-py3-none-any.whl", "teleopit[sim2real]", ] +recording = [ + "teleopit[pico4]", + "opencv-python", + "imageio[ffmpeg]", +] +dexhand = [] [tool.setuptools.packages.find] where = ["."] diff --git a/scripts/dev/bench_policy_onnx.py b/scripts/dev/bench_policy_onnx.py new file mode 100644 index 00000000..b40b7c2a --- /dev/null +++ b/scripts/dev/bench_policy_onnx.py @@ -0,0 +1,495 @@ +"""Benchmark Teleopit ONNX policy inference latency. + +This is a policy-only micro-benchmark intended for onboard model-size checks. +It does not require MuJoCo, robot hardware, GMR assets, or Pico input. + +Examples: + python scripts/dev/bench_policy_onnx.py --policy track.onnx + python scripts/dev/bench_policy_onnx.py --policy track.onnx --runs 20000 --device cpu + python scripts/dev/bench_policy_onnx.py --policy track.onnx --mode direct +""" + +from __future__ import annotations + +import argparse +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Sequence + +import numpy as np + + +@dataclass(frozen=True) +class InputSpec: + name: str + shape: tuple[Any, ...] + dtype: str + + +@dataclass(frozen=True) +class PolicySignature: + obs_name: str + obs_dim: int + history_name: str | None + history_length: int + history_obs_dim: int + output_name: str + + @property + def is_dual_input(self) -> bool: + return self.history_name is not None + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark ONNX policy latency for Teleopit runtime sizing.", + ) + parser.add_argument("--policy", required=True, help="Path to exported policy.onnx") + parser.add_argument( + "--device", + choices=["cpu", "cuda", "auto"], + default="cpu", + help="ONNX Runtime execution provider preference (default: cpu)", + ) + parser.add_argument( + "--mode", + choices=["controller", "direct"], + default="controller", + help=( + "controller simulates RLPolicyController obs_history stacking; " + "direct reuses prebuilt feed tensors and measures session.run only" + ), + ) + parser.add_argument("--runs", type=int, default=5000, help="Measured iterations") + parser.add_argument("--warmup", type=int, default=200, help="Warmup iterations") + parser.add_argument("--policy-hz", type=float, default=50.0, help="Runtime policy frequency") + parser.add_argument( + "--input-mode", + choices=["random", "zeros"], + default="random", + help="Synthetic observation contents; generated outside the timed region", + ) + parser.add_argument( + "--obs-dim", + type=int, + default=None, + help="Required only when the ONNX obs feature dimension is dynamic", + ) + parser.add_argument( + "--history-length", + type=int, + default=None, + help="Required only when the ONNX obs_history length dimension is dynamic", + ) + parser.add_argument( + "--intra-op-threads", + type=int, + default=0, + help="ONNX Runtime intra-op threads; 0 keeps ORT default", + ) + parser.add_argument( + "--inter-op-threads", + type=int, + default=0, + help="ONNX Runtime inter-op threads; 0 keeps ORT default", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--max-p99-ms", + type=float, + default=None, + help="Exit non-zero if p99 latency exceeds this value", + ) + parser.add_argument( + "--json", + type=str, + default=None, + help="Optional path to write benchmark summary JSON", + ) + return parser.parse_args() + + +def _dim_to_int(dim: Any) -> int | None: + if isinstance(dim, int): + return dim + if isinstance(dim, np.integer): + return int(dim) + if isinstance(dim, float) and dim.is_integer(): + return int(dim) + if isinstance(dim, str): + try: + return int(dim) + except ValueError: + return None + return None + + +def _feature_dim(shape: Sequence[Any], fallback: int | None, label: str) -> int: + if not shape: + raise ValueError(f"{label} input has empty shape; cannot infer feature dimension") + dim = _dim_to_int(shape[-1]) + if dim is not None: + return dim + if fallback is not None: + return int(fallback) + raise ValueError( + f"{label} feature dimension is dynamic ({shape[-1]!r}). " + f"Pass --obs-dim to make the benchmark input explicit." + ) + + +def _history_len(shape: Sequence[Any], fallback: int | None) -> int: + if len(shape) < 3: + raise ValueError(f"obs_history must be rank 3 [batch, history, obs_dim], got shape={shape}") + dim = _dim_to_int(shape[1]) + if dim is not None: + return dim + if fallback is not None: + return int(fallback) + raise ValueError( + f"obs_history length dimension is dynamic ({shape[1]!r}). " + f"Pass --history-length to make the benchmark input explicit." + ) + + +def _select_providers(ort: Any, device: str) -> list[str]: + available = set(ort.get_available_providers()) + providers: list[str] = [] + if (device == "cuda" or device == "auto") and "CUDAExecutionProvider" in available: + providers.append("CUDAExecutionProvider") + if device == "cuda" and not providers: + raise RuntimeError( + "CUDAExecutionProvider was requested but is not available. " + f"Available providers: {sorted(available)}" + ) + providers.append("CPUExecutionProvider") + return providers + + +def _make_session(policy_path: Path, args: argparse.Namespace) -> tuple[Any, list[str], list[str]]: + try: + import onnxruntime as ort + except ModuleNotFoundError as exc: + raise RuntimeError("onnxruntime is required; install the Teleopit inference dependencies") from exc + + options = ort.SessionOptions() + options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + if args.intra_op_threads > 0: + options.intra_op_num_threads = int(args.intra_op_threads) + if args.inter_op_threads > 0: + options.inter_op_num_threads = int(args.inter_op_threads) + + providers = _select_providers(ort, str(args.device)) + session = ort.InferenceSession(str(policy_path), sess_options=options, providers=providers) + return session, providers, list(ort.get_available_providers()) + + +def _inspect_signature(session: Any, args: argparse.Namespace) -> tuple[PolicySignature, list[InputSpec]]: + inputs = session.get_inputs() + outputs = session.get_outputs() + if len(outputs) < 1: + raise ValueError("ONNX policy has no outputs") + if len(inputs) not in (1, 2): + names = [inp.name for inp in inputs] + raise ValueError( + "Unsupported ONNX policy input signature. Expected one obs input or " + f"dual inputs ('obs', 'obs_history'), got {names}." + ) + + specs = [ + InputSpec(name=inp.name, shape=tuple(inp.shape), dtype=str(getattr(inp, "type", ""))) + for inp in inputs + ] + obs_dim = _feature_dim(inputs[0].shape, args.obs_dim, inputs[0].name) + history_name: str | None = None + history_length = 0 + history_obs_dim = 0 + + if len(inputs) == 2: + if inputs[1].name != "obs_history": + names = [inp.name for inp in inputs] + raise ValueError( + "Unsupported dual-input policy. Expected second input named " + f"'obs_history', got {names}." + ) + history_name = inputs[1].name + history_length = _history_len(inputs[1].shape, args.history_length) + history_obs_dim = _feature_dim(inputs[1].shape, args.obs_dim, "obs_history") + if history_obs_dim != obs_dim: + raise ValueError( + f"obs_dim mismatch: obs has {obs_dim}, obs_history has {history_obs_dim}" + ) + + signature = PolicySignature( + obs_name=inputs[0].name, + obs_dim=obs_dim, + history_name=history_name, + history_length=history_length, + history_obs_dim=history_obs_dim, + output_name=outputs[0].name, + ) + return signature, specs + + +def _make_obs(total: int, obs_dim: int, input_mode: str, seed: int) -> np.ndarray: + if input_mode == "zeros": + return np.zeros((total, obs_dim), dtype=np.float32) + rng = np.random.default_rng(seed) + return rng.standard_normal((total, obs_dim), dtype=np.float32) + + +def _stats_ms(samples_ms: np.ndarray) -> dict[str, float]: + return { + "mean": float(np.mean(samples_ms)), + "std": float(np.std(samples_ms)), + "min": float(np.min(samples_ms)), + "p50": float(np.percentile(samples_ms, 50)), + "p90": float(np.percentile(samples_ms, 90)), + "p95": float(np.percentile(samples_ms, 95)), + "p99": float(np.percentile(samples_ms, 99)), + "max": float(np.max(samples_ms)), + } + + +def _run_direct( + session: Any, + signature: PolicySignature, + obs_samples: np.ndarray, + runs: int, + warmup: int, +) -> tuple[np.ndarray, tuple[int, ...]]: + obs = obs_samples[0:1] + feed = {signature.obs_name: obs} + if signature.history_name is not None: + feed[signature.history_name] = np.repeat( + obs[:, np.newaxis, :], + signature.history_length, + axis=1, + ).astype(np.float32, copy=False) + + for _ in range(warmup): + session.run([signature.output_name], feed) + + timings = np.empty(runs, dtype=np.float64) + output_shape: tuple[int, ...] = () + for idx in range(runs): + t0 = time.perf_counter_ns() + output = session.run([signature.output_name], feed)[0] + t1 = time.perf_counter_ns() + timings[idx] = (t1 - t0) / 1_000_000.0 + if idx == 0: + output_shape = tuple(np.asarray(output).shape) + return timings, output_shape + + +def _run_controller_like( + session: Any, + signature: PolicySignature, + obs_samples: np.ndarray, + runs: int, + warmup: int, +) -> tuple[np.ndarray, tuple[int, ...]]: + from collections import deque + + history_buf: deque[np.ndarray] = deque(maxlen=max(signature.history_length, 1)) + total = warmup + runs + timings = np.empty(runs, dtype=np.float64) + output_shape: tuple[int, ...] = () + + for idx in range(total): + measured = idx >= warmup + t0 = time.perf_counter_ns() if measured else 0 + obs_flat = obs_samples[idx] + obs = obs_flat[np.newaxis, :] + if signature.history_name is not None: + if len(history_buf) == 0: + for _ in range(signature.history_length): + history_buf.append(obs_flat.copy()) + else: + history_buf.append(obs_flat.copy()) + obs_history = np.stack(list(history_buf), axis=0)[np.newaxis].astype(np.float32) + feed = { + signature.obs_name: obs, + signature.history_name: obs_history, + } + else: + feed = {signature.obs_name: obs} + + output = session.run([signature.output_name], feed)[0] + if measured: + t1 = time.perf_counter_ns() + out_idx = idx - warmup + timings[out_idx] = (t1 - t0) / 1_000_000.0 + if out_idx == 0: + output_shape = tuple(np.asarray(output).shape) + return timings, output_shape + + +def _print_summary( + policy_path: Path, + args: argparse.Namespace, + providers: list[str], + available_providers: list[str], + signature: PolicySignature, + input_specs: list[InputSpec], + stats: dict[str, float], + output_shape: tuple[int, ...], + over_budget: int, +) -> None: + budget_ms = 1000.0 / float(args.policy_hz) + measured_fps_mean = 1000.0 / stats["mean"] if stats["mean"] > 0 else float("inf") + p95_margin_ms = budget_ms - stats["p95"] + p99_margin_ms = budget_ms - stats["p99"] + + print("=" * 72) + print("ONNX Policy Benchmark") + print("=" * 72) + print(f"Policy: {policy_path}") + print(f"Mode: {args.mode}") + print(f"Runs / warmup: {args.runs} / {args.warmup}") + print(f"Policy rate budget: {args.policy_hz:.1f} Hz = {budget_ms:.2f} ms/step") + print(f"Providers selected: {providers}") + print(f"Providers available: {available_providers}") + print() + print("Input signature:") + for spec in input_specs: + print(f" {spec.name}: shape={spec.shape}, type={spec.dtype}") + print(f"Output: {signature.output_name}, measured shape={output_shape}") + print() + print("Latency (ms):") + print(f" mean: {stats['mean']:.4f} std: {stats['std']:.4f}") + print(f" min: {stats['min']:.4f} max: {stats['max']:.4f}") + print(f" p50: {stats['p50']:.4f}") + print(f" p90: {stats['p90']:.4f}") + print(f" p95: {stats['p95']:.4f} margin vs budget: {p95_margin_ms:.4f} ms") + print(f" p99: {stats['p99']:.4f} margin vs budget: {p99_margin_ms:.4f} ms") + print() + print(f"Mean throughput: {measured_fps_mean:.1f} policy calls/s") + print(f"Over budget: {over_budget} / {args.runs}") + print("=" * 72) + + if p99_margin_ms < 0: + print("WARNING: p99 latency exceeds the policy-rate budget.") + elif p99_margin_ms < 0.25 * budget_ms: + print("NOTE: p99 latency fits but leaves less than 25% budget margin.") + else: + print("OK: p99 latency leaves at least 25% policy-rate budget margin.") + + +def _write_json( + path: Path, + policy_path: Path, + args: argparse.Namespace, + providers: list[str], + available_providers: list[str], + signature: PolicySignature, + stats: dict[str, float], + output_shape: tuple[int, ...], + over_budget: int, +) -> None: + budget_ms = 1000.0 / float(args.policy_hz) + payload = { + "policy": str(policy_path), + "mode": args.mode, + "runs": int(args.runs), + "warmup": int(args.warmup), + "policy_hz": float(args.policy_hz), + "budget_ms": budget_ms, + "providers": providers, + "available_providers": available_providers, + "signature": { + "obs_name": signature.obs_name, + "obs_dim": signature.obs_dim, + "history_name": signature.history_name, + "history_length": signature.history_length, + "history_obs_dim": signature.history_obs_dim, + "output_name": signature.output_name, + "output_shape": list(output_shape), + }, + "latency_ms": stats, + "mean_policy_calls_per_s": 1000.0 / stats["mean"] if stats["mean"] > 0 else None, + "over_budget": int(over_budget), + "p95_margin_ms": budget_ms - stats["p95"], + "p99_margin_ms": budget_ms - stats["p99"], + } + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def main() -> int: + args = _parse_args() + if args.runs <= 0: + raise ValueError("--runs must be > 0") + if args.warmup < 0: + raise ValueError("--warmup must be >= 0") + if args.policy_hz <= 0: + raise ValueError("--policy-hz must be > 0") + + policy_path = Path(args.policy).expanduser().resolve() + if not policy_path.is_file(): + raise FileNotFoundError(f"ONNX policy not found: {policy_path}") + + session, providers, available_providers = _make_session(policy_path, args) + signature, input_specs = _inspect_signature(session, args) + total_samples = max(1, args.warmup + args.runs) + obs_samples = _make_obs(total_samples, signature.obs_dim, args.input_mode, args.seed) + + if args.mode == "direct": + timings_ms, output_shape = _run_direct( + session=session, + signature=signature, + obs_samples=obs_samples, + runs=args.runs, + warmup=args.warmup, + ) + else: + timings_ms, output_shape = _run_controller_like( + session=session, + signature=signature, + obs_samples=obs_samples, + runs=args.runs, + warmup=args.warmup, + ) + + stats = _stats_ms(timings_ms) + budget_ms = 1000.0 / float(args.policy_hz) + over_budget = int(np.count_nonzero(timings_ms > budget_ms)) + _print_summary( + policy_path=policy_path, + args=args, + providers=providers, + available_providers=available_providers, + signature=signature, + input_specs=input_specs, + stats=stats, + output_shape=output_shape, + over_budget=over_budget, + ) + + if args.json is not None: + json_path = Path(args.json).expanduser() + _write_json( + path=json_path, + policy_path=policy_path, + args=args, + providers=providers, + available_providers=available_providers, + signature=signature, + stats=stats, + output_shape=output_shape, + over_budget=over_budget, + ) + print(f"Wrote JSON summary: {json_path}") + + if args.max_p99_ms is not None and stats["p99"] > args.max_p99_ms: + print( + f"FAIL: p99={stats['p99']:.4f}ms exceeds --max-p99-ms={args.max_p99_ms:.4f}ms" + ) + return 2 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/dev/compute_ik_offsets.py b/scripts/dev/compute_ik_offsets.py index 283817af..74366c17 100644 --- a/scripts/dev/compute_ik_offsets.py +++ b/scripts/dev/compute_ik_offsets.py @@ -33,6 +33,8 @@ import numpy as np from scipy.spatial.transform import Rotation as R +from teleopit.runtime.assets import UNITREE_G1_XML + PROJECT_ROOT = Path(__file__).resolve().parents[2] @@ -184,15 +186,7 @@ def main(): if not config_path.is_absolute(): config_path = PROJECT_ROOT / config_path - xml_path = str( - PROJECT_ROOT - / "teleopit" - / "retargeting" - / "gmr" - / "assets" - / "unitree_g1" - / "g1_mocap_29dof.xml" - ) + xml_path = str(UNITREE_G1_XML) with open(config_path) as f: config = json.load(f) diff --git a/scripts/dev/test_linkerhand_l6.py b/scripts/dev/test_linkerhand_l6.py new file mode 100644 index 00000000..3b40d58c --- /dev/null +++ b/scripts/dev/test_linkerhand_l6.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +"""Exercise LinkerHand dexterous-hand control modes.""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path +import sys +import time +from typing import Sequence + + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) +SDK_PATH = REPO_ROOT / "third_party" / "linkerhand-python-sdk" +if SDK_PATH.exists(): + sys.path.insert(0, str(SDK_PATH)) +SOMEHAND_SRC_PATH = REPO_ROOT / "third_party" / "somehand" / "src" +if SOMEHAND_SRC_PATH.exists(): + sys.path.insert(0, str(SOMEHAND_SRC_PATH)) + +from teleopit.inputs.pico4_provider import ( # noqa: E402 + Pico4InputProvider, +) +from teleopit.sim2real.hands.linkerhand_l6 import VR_HAND_POSE_SPEED, build_linkerhand_l6 # noqa: E402 +from teleopit.sim2real.hands.linkerhand_o6 import build_linkerhand_o6 # noqa: E402 + + +THUMB_YAW_DEFAULT = 10 +OPEN_POSE = [250, THUMB_YAW_DEFAULT, 250, 250, 250, 250] +CLOSE_POSE = [79, THUMB_YAW_DEFAULT, 0, 0, 0, 0] +DEFAULT_SPEED = [50, 50, 50, 50, 50, 50] +O6_OPEN_POSE = [250, 250, 250, 250, 250, 250] +O6_CLOSE_POSE = [86, 73, 118, 111, 110, 111] +O6_DEFAULT_SPEED = [255, 255, 255, 255, 255, 255] +DEFAULT_SOMEHAND_CONFIG_PATH = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" +OPEN_CLOSE_HOLD_S = 1.0 +GRIPPER_RATE_HZ = 30.0 +VR_HAND_POSE_RATE_HZ = 60.0 +PICO_START_TIMEOUT_S = 60.0 +FRAME_TIMEOUT_S = 0.3 +TRIGGER_DEADZONE = 0.05 +DEADMAN_THRESHOLD = 0.5 + + +def selected_hand_types(hand_type: str) -> tuple[str, ...]: + if hand_type == "both": + return ("left", "right") + return (hand_type,) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Test LinkerHand dexterous-hand control modes") + parser.add_argument( + "--driver", + choices=["linkerhand_l6", "linkerhand_o6"], + default="linkerhand_l6", + help="Hand driver to test. O6 currently supports open_close and gripper only.", + ) + parser.add_argument( + "--mode", + choices=["open_close", "gripper", "vr_hand_pose"], + default="open_close", + help=( + "open_close sends fixed poses directly; gripper reads real Pico controller " + "grip/trigger input; vr_hand_pose reads real Pico hand pose input through Teleopit " + "and uses somehand only for hand retargeting." + ), + ) + parser.add_argument("--hand-type", choices=["left", "right", "both"], default="both") + parser.add_argument("--left-can", default="can0") + parser.add_argument("--right-can", default="can1") + parser.add_argument( + "--modbus", + default="None", + help='RS485 serial port such as /dev/ttyUSB0; "None" uses CAN', + ) + args = parser.parse_args() + if args.driver == "linkerhand_o6" and args.mode == "vr_hand_pose": + raise SystemExit("hands.driver=linkerhand_o6 supports only --mode open_close or gripper") + args.speed = list(O6_DEFAULT_SPEED if args.driver == "linkerhand_o6" else DEFAULT_SPEED) + args.open_pose = list(O6_OPEN_POSE if args.driver == "linkerhand_o6" else OPEN_POSE) + args.close_pose = list(O6_CLOSE_POSE if args.driver == "linkerhand_o6" else CLOSE_POSE) + return args + + +def make_config(args: argparse.Namespace, *, mode: str) -> dict[str, object]: + speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else args.speed + rate_hz = VR_HAND_POSE_RATE_HZ if mode == "vr_hand_pose" else GRIPPER_RATE_HZ + driver_section = "linkerhand_o6" if args.driver == "linkerhand_o6" else "linkerhand_l6" + driver_cfg = { + "left_can": args.left_can, + "right_can": args.right_can, + "modbus": args.modbus, + "trigger_deadzone": TRIGGER_DEADZONE, + "deadman_threshold": DEADMAN_THRESHOLD, + "speed": list(speed), + "open_pose": list(args.open_pose), + "close_pose": list(args.close_pose), + "print_input": False, + } + if args.driver == "linkerhand_l6": + driver_cfg["thumb_yaw_center"] = THUMB_YAW_DEFAULT + return { + "input": {"provider": "pico4"}, + "hands": { + "enabled": True, + "driver": args.driver, + "mode": mode, + "sides": list(selected_hand_types(args.hand_type)), + "rate_hz": rate_hz, + "frame_timeout_s": FRAME_TIMEOUT_S, + driver_section: driver_cfg, + "somehand": { + "config_path": DEFAULT_SOMEHAND_CONFIG_PATH, + "rate_hz": VR_HAND_POSE_RATE_HZ, + "max_iterations": 12, + "temporal_filter_alpha": 1.0, + "output_alpha": 1.0, + }, + }, + } + + +def build_driver_runtime(config: dict[str, object], *, driver: str): + if driver == "linkerhand_o6": + return build_linkerhand_o6(config) + return build_linkerhand_l6(config) + + +def send_all(hands: dict[str, object], pose: Sequence[int], *, label: str) -> None: + print(f"{label}: {list(pose)}", flush=True) + for hand_type, hand in hands.items(): + print(f" {hand_type}", flush=True) + hand.finger_move(pose=list(pose)) + + +def make_pico_provider() -> Pico4InputProvider: + return Pico4InputProvider( + timeout=PICO_START_TIMEOUT_S, + pause_button=None, + bridge_host="0.0.0.0", + bridge_port=63901, + bridge_discovery=True, + bridge_advertise_ip=None, + bridge_start_timeout=10.0, + bridge_video=None, + bridge_video_enabled=False, + ) + + +def run_live_until_done( + runtime: object, + *, + provider: Pico4InputProvider, + mode_label: str, + rate_hz: float, +) -> None: + last_seq: int | None = None + print(f"Running {mode_label}; press Ctrl-C to stop.", flush=True) + while True: + now_s = time.monotonic() + controller_snapshot = provider.get_controller_snapshot() + hand_snapshot = provider.get_hand_snapshot() + runtime.tick( + controller_snapshot=controller_snapshot, + hand_snapshot=hand_snapshot, + active=True, + now_s=now_s, + ) + snapshot = controller_snapshot if mode_label == "gripper" else hand_snapshot + if snapshot is not None and snapshot.seq != last_seq: + last_seq = snapshot.seq + age_ms = max((now_s - snapshot.timestamp_s) * 1000.0, 0.0) + print(f" pico seq={snapshot.seq} age={age_ms:.1f}ms", flush=True) + time.sleep(max(1.0 / rate_hz, 0.001)) + + +def run_open_close(args: argparse.Namespace) -> None: + try: + from LinkerHand.linker_hand_api import LinkerHandApi + except ImportError as exc: + raise SystemExit( + "LinkerHand SDK import failed. Run: " + "git submodule update --init third_party/linkerhand-python-sdk && " + "pip install -e third_party/linkerhand-python-sdk" + ) from exc + + hand_types = selected_hand_types(args.hand_type) + can_channels = {"left": args.left_can, "right": args.right_can} + hands: dict[str, object] = {} + + print( + f"Testing {args.driver} | " + f"hands={','.join(hand_types)} | " + f"can={','.join(f'{hand}:{can_channels[hand]}' for hand in hand_types)} | " + f"modbus={args.modbus}", + flush=True, + ) + try: + for hand_type in hand_types: + hand = LinkerHandApi( + hand_joint="O6" if args.driver == "linkerhand_o6" else "L6", + hand_type=hand_type, + modbus=args.modbus, + can=can_channels[hand_type], + ) + hand.set_speed(speed=list(args.speed)) + hands[hand_type] = hand + + send_all(hands, args.open_pose, label="startup open") + time.sleep(OPEN_CLOSE_HOLD_S) + cycle = 0 + while True: + cycle += 1 + print(f"cycle {cycle}", flush=True) + send_all(hands, args.close_pose, label="close") + time.sleep(OPEN_CLOSE_HOLD_S) + send_all(hands, args.open_pose, label="open") + time.sleep(OPEN_CLOSE_HOLD_S) + except KeyboardInterrupt: + print("Interrupted; opening hands before exit", flush=True) + finally: + if hands: + send_all(hands, args.open_pose, label="exit open") + time.sleep(0.2) + print( + "Exit cleanup intentionally leaves CAN interfaces up to avoid SDK network-down noise.", + flush=True, + ) + + +def run_gripper(args: argparse.Namespace) -> None: + config = make_config(args, mode="gripper") + provider = make_pico_provider() + device, mapper = build_driver_runtime(config, driver=args.driver) + from teleopit.sim2real.hands.worker import HandRuntime + runtime = HandRuntime(device, mapper) + + print( + "Testing hands.mode=gripper with real Pico controller input. " + "Hold grip above the deadman threshold, then use trigger to close/open.", + flush=True, + ) + try: + runtime.start() + run_live_until_done( + runtime, + provider=provider, + mode_label="gripper", + rate_hz=GRIPPER_RATE_HZ, + ) + except KeyboardInterrupt: + print("Interrupted; opening hands before exit", flush=True) + finally: + runtime.tick(controller_snapshot=None, hand_snapshot=None, active=False) + runtime.close() + provider.close() + + +def run_vr_hand_pose(args: argparse.Namespace) -> None: + if args.hand_type != "both": + raise SystemExit("hands.mode=vr_hand_pose currently requires --hand-type both") + + config = make_config(args, mode="vr_hand_pose") + provider = make_pico_provider() + device, mapper = build_driver_runtime(config, driver=args.driver) + from teleopit.sim2real.hands.worker import HandRuntime + runtime = HandRuntime(device, mapper) + + print( + "Testing hands.mode=vr_hand_pose with real Pico hand-pose input. " + "Enable Pico hand tracking and move both hands; start with the robot clear of contacts.", + flush=True, + ) + try: + runtime.start() + run_live_until_done( + runtime, + provider=provider, + mode_label="vr_hand_pose", + rate_hz=VR_HAND_POSE_RATE_HZ, + ) + except KeyboardInterrupt: + print("Interrupted; opening hands before exit", flush=True) + finally: + runtime.tick(controller_snapshot=None, hand_snapshot=None, active=False) + runtime.close() + provider.close() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") + args = parse_args() + if args.mode == "open_close": + run_open_close(args) + elif args.mode == "gripper": + run_gripper(args) + elif args.mode == "vr_hand_pose": + run_vr_hand_pose(args) + else: + raise AssertionError(f"Unhandled mode: {args.mode}") + + +if __name__ == "__main__": + main() diff --git a/scripts/render/render_motion_npz.py b/scripts/render/render_motion_npz.py index 40b1c8d2..12cb306c 100644 --- a/scripts/render/render_motion_npz.py +++ b/scripts/render/render_motion_npz.py @@ -27,7 +27,7 @@ def parse_args() -> argparse.Namespace: "--robot", type=str, default="unitree_g1", - help="Robot viewer type. For lafan1_v1 clips this should usually stay as unitree_g1.", + help="Robot viewer type. For lafan1 clips this should usually stay as unitree_g1.", ) parser.add_argument("--start", type=int, default=0, help="Start frame index") parser.add_argument("--end", type=int, default=-1, help="Exclusive end frame index; -1 means full clip") diff --git a/scripts/render/render_sim.py b/scripts/render/render_sim.py index 94d71262..9ec33146 100644 --- a/scripts/render/render_sim.py +++ b/scripts/render/render_sim.py @@ -31,6 +31,7 @@ import mujoco # noqa: E402 from teleopit.debug.rollout_trace import RolloutTraceWriter # noqa: E402 +from teleopit.runtime.assets import UNITREE_G1_XML # noqa: E402 from teleopit.sim.mocap_mujoco import ( # noqa: E402 MocapSkeletonSceneDrawer, fit_mocap_camera, @@ -191,17 +192,7 @@ def render_retarget( project_root = _find_project_root() cfgs = _load_configs(str(bvh_path), project_root, bvh_format, policy_path=None) - # Retarget rendering uses the GMR mocap XML (no actuator limits needed) - mocap_xml = ( - project_root - / "teleopit" - / "retargeting" - / "gmr" - / "assets" - / "unitree_g1" - / "g1_mocap_29dof.xml" - ) - cfgs["robot"].xml_path = str(mocap_xml) + cfgs["robot"].xml_path = str(UNITREE_G1_XML) from teleopit.inputs import BVHInputProvider from teleopit.retargeting.core import RetargetingModule @@ -306,7 +297,6 @@ def render_sim2sim( "policy_hz": POLICY_HZ, "pd_hz": PD_HZ, "debug_trace_path": debug_trace_path, - "recording": {"output_path": "/tmp/teleopit_render.h5"}, } ) pipeline = TeleopPipeline(cfg) diff --git a/scripts/review/__init__.py b/scripts/review/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/scripts/review/build_dataset_from_review.py b/scripts/review/build_dataset_from_review.py deleted file mode 100644 index 9018f313..00000000 --- a/scripts/review/build_dataset_from_review.py +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python3 -"""Rebuild train/val shard directories from a filtered manifest (review results). - -Reads filtered_manifest.csv (output of export_reviewed_manifest.py), -verifies all NPZ files exist, and rebuilds cleaned train/val shard splits. - -Usage: - python scripts/data/build_dataset_from_review.py \ - --filtered_manifest data/datasets/review/twist2_full/filtered_manifest.csv \ - --output_dir data/datasets/builds/twist2_full_cleaned -""" - -from __future__ import annotations - -import argparse -import csv -import sys -from pathlib import Path - -import numpy as np - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from train_mimic.data.dataset_lib import ( - extract_clip_arrays, - merge_clip_dicts, - merge_npz_files, - utc_now_iso, - write_json, -) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Build cleaned dataset from filtered manifest" - ) - parser.add_argument( - "--filtered_manifest", type=str, required=True, - help="Path to filtered_manifest.csv", - ) - parser.add_argument( - "--output_dir", type=str, required=True, - help="Output directory, e.g. data/datasets/builds/twist2_full_cleaned", - ) - parser.add_argument( - "--target_fps", type=int, default=None, - help="Resample all clips to this FPS. " - "Required when source clips have mixed FPS values.", - ) - args = parser.parse_args() - - manifest_path = Path(args.filtered_manifest) - if not manifest_path.is_absolute(): - manifest_path = (PROJECT_ROOT / manifest_path).resolve() - - output_dir = Path(args.output_dir) - if not output_dir.is_absolute(): - output_dir = (PROJECT_ROOT / output_dir).resolve() - - if not manifest_path.is_file(): - print(f"ERROR: filtered manifest not found: {manifest_path}", file=sys.stderr) - sys.exit(1) - - # Read filtered manifest - # Columns: clip_id, source, file_rel, num_frames, fps, resolved_split, resolved_npz_path, - # weight, clip_index - rows = [] - with manifest_path.open("r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - has_clip_index = "clip_index" in (reader.fieldnames or []) - for idx, raw in enumerate(reader, start=2): - rows.append({ - "clip_id": raw["clip_id"].strip(), - "source": raw["source"].strip(), - "file_rel": raw["file_rel"].strip(), - "num_frames": int(raw["num_frames"]), - "fps": int(raw["fps"]), - "resolved_split": raw["resolved_split"].strip(), - "resolved_npz_path": raw["resolved_npz_path"].strip(), - "weight": float(raw["weight"]), - "clip_index": int(raw["clip_index"]) if has_clip_index else -1, - "line_no": idx, - }) - - if not rows: - print("ERROR: filtered manifest has no data rows", file=sys.stderr) - sys.exit(1) - - # Verify all NPZ files exist - missing = [] - for row in rows: - p = Path(row["resolved_npz_path"]) - if not p.is_file(): - # Try resolving from file_rel - alt = Path(row["file_rel"]) - if not alt.is_absolute(): - alt = PROJECT_ROOT / alt - if alt.is_file(): - row["resolved_npz_path"] = str(alt.resolve()) - else: - missing.append(f" line {row['line_no']}: {row['clip_id']} -> {row['resolved_npz_path']}") - - if missing: - print(f"ERROR: {len(missing)} NPZ files not found:", file=sys.stderr) - for m in missing[:20]: - print(m, file=sys.stderr) - if len(missing) > 20: - print(f" ... and {len(missing) - 20} more", file=sys.stderr) - sys.exit(1) - - # Check for mixed FPS - fps_values = sorted(set(r["fps"] for r in rows)) - if len(fps_values) > 1: - print(f"[INFO] Mixed FPS detected: {fps_values}. Resampling all clips to {args.target_fps} FPS.") - if args.target_fps is None: - print( - "ERROR: clips have mixed FPS values but --target_fps is not set.\n" - "Please specify --target_fps (e.g. --target_fps 30).", - file=sys.stderr, - ) - sys.exit(1) - - # Split into train / val - train_rows = [r for r in rows if r["resolved_split"] == "train"] - val_rows = [r for r in rows if r["resolved_split"] == "val"] - - if not train_rows: - print("ERROR: no train clips in filtered manifest", file=sys.stderr) - sys.exit(1) - if not val_rows: - print("WARNING: no val clips in filtered manifest", file=sys.stderr) - - output_dir.mkdir(parents=True, exist_ok=True) - - # Check if any rows use clip_index (batch-built dataset) - has_indexed_clips = any(r["clip_index"] >= 0 for r in rows) - - def _merge_split(split_rows: list[dict], split_name: str) -> dict | None: - if not split_rows: - return None - print(f"Merging {len(split_rows)} {split_name} clips...") - split_dir = output_dir / split_name - split_dir.mkdir(parents=True, exist_ok=True) - out = split_dir / "shard_000.npz" - - if has_indexed_clips: - # Extract individual clip slices from shard NPZ files - clip_dicts = [] - for r in split_rows: - npz_path = Path(r["resolved_npz_path"]) - ci = r["clip_index"] - if ci >= 0: - clip_dicts.append(extract_clip_arrays(npz_path, ci)) - else: - # Standalone clip — load entire file as one clip dict - d = np.load(npz_path, allow_pickle=True) - clip_dicts.append({ - "fps": int(d["fps"]), - "joint_pos": np.asarray(d["joint_pos"]), - "joint_vel": np.asarray(d["joint_vel"]), - "body_pos_w": np.asarray(d["body_pos_w"]), - "body_quat_w": np.asarray(d["body_quat_w"]), - "body_lin_vel_w": np.asarray(d["body_lin_vel_w"]), - "body_ang_vel_w": np.asarray(d["body_ang_vel_w"]), - "body_names": np.asarray(d["body_names"]), - }) - weights_list = [r["weight"] for r in split_rows] - stats = merge_clip_dicts( - clip_dicts, out, - target_fps=args.target_fps, weights=weights_list, - ) - else: - # Legacy per-clip NPZ files — use file-based merge - files = [Path(r["resolved_npz_path"]) for r in split_rows] - weights_list = [r["weight"] for r in split_rows] - stats = merge_npz_files( - files, out, - target_fps=args.target_fps, weights=weights_list, - ) - - stats["output"] = str(split_dir) - stats["shards"] = 1 - print(f" {split_name}/: {stats['frames']} frames, {stats['duration_s'] / 60:.1f} min") - return stats - - train_stats = _merge_split(train_rows, "train") - val_stats = _merge_split(val_rows, "val") - - # Copy manifest into output dir - import shutil - shutil.copy2(manifest_path, output_dir / "manifest_resolved.csv") - - # Write build info - report = { - "built_at_utc": utc_now_iso(), - "source_manifest": str(manifest_path), - "output_dir": str(output_dir), - "target_fps": args.target_fps, - "clip_counts": { - "total": len(rows), - "train": len(train_rows), - "val": len(val_rows), - }, - "splits": { - "train": train_stats, - "val": val_stats, - }, - } - write_json(output_dir / "build_info.json", report) - - print(f"\nCleaned dataset built at: {output_dir}") - print(f" Total clips: {len(rows)} (train={len(train_rows)}, val={len(val_rows)})") - - -if __name__ == "__main__": - main() diff --git a/scripts/review/export_reviewed_manifest.py b/scripts/review/export_reviewed_manifest.py deleted file mode 100644 index b2a893f5..00000000 --- a/scripts/review/export_reviewed_manifest.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 -"""Export a filtered manifest from review results. - -Reads review_state.csv, keeps only clips with decision == 'keep', -and outputs filtered_manifest.csv + review_summary.json. - -Usage: - python scripts/data/export_reviewed_manifest.py \ - --review data/datasets/review/twist2_full/review_state.csv \ - --output data/datasets/review/twist2_full/filtered_manifest.csv -""" - -from __future__ import annotations - -import argparse -import csv -import sys -from pathlib import Path - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from train_mimic.data.dataset_lib import write_json -from train_mimic.data.review_lib import ( - compute_review_stats, - load_review_state, - utc_now_iso, -) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Export filtered manifest from review") - parser.add_argument("--review", type=str, required=True, help="Path to review_state.csv") - parser.add_argument( - "--output", type=str, default=None, - help="Path for filtered_manifest.csv (default: same dir as review)", - ) - parser.add_argument( - "--summary", type=str, default=None, - help="Path for review_summary.json (default: same dir as review)", - ) - parser.add_argument( - "--require_complete", action="store_true", - help="Fail if any clips are unreviewed", - ) - args = parser.parse_args() - - review_path = Path(args.review) - if not review_path.is_absolute(): - review_path = (PROJECT_ROOT / review_path).resolve() - - rows = load_review_state(review_path) - stats = compute_review_stats(rows) - - # Check completeness - unreviewed = stats.total - stats.reviewed - if args.require_complete and unreviewed > 0: - print( - f"ERROR: {unreviewed} clips are unreviewed. " - "Complete the review or remove --require_complete.", - file=sys.stderr, - ) - sys.exit(1) - - # Filter to keep only - kept_rows = [r for r in rows if r.decision == "keep"] - - # Output paths - review_dir = review_path.parent - if args.output: - output_path = Path(args.output) - if not output_path.is_absolute(): - output_path = (PROJECT_ROOT / output_path).resolve() - else: - output_path = review_dir / "filtered_manifest.csv" - - if args.summary: - summary_path = Path(args.summary) - if not summary_path.is_absolute(): - summary_path = (PROJECT_ROOT / summary_path).resolve() - else: - summary_path = review_dir / "review_summary.json" - - # Write filtered_manifest.csv in manifest_resolved.csv format - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow([ - "clip_id", "source", "file_rel", "num_frames", "fps", - "resolved_split", "resolved_npz_path", "weight", "clip_index", - ]) - for r in sorted(kept_rows, key=lambda x: x.clip_id): - writer.writerow([ - r.clip_id, r.source, r.file_rel, r.num_frames, r.fps, - r.resolved_split, r.resolved_npz_path, r.weight, r.clip_index, - ]) - - # Write summary JSON - summary = { - "exported_at_utc": utc_now_iso(), - "review_file": str(review_path), - "output_file": str(output_path), - "total_clips": stats.total, - "reviewed_clips": stats.reviewed, - "unreviewed_clips": unreviewed, - "keep_count": stats.keep_count, - "drop_count": stats.drop_count, - "skip_count": stats.skip_count, - "progress_pct": stats.progress_pct, - "kept_duration_s": stats.kept_duration_s, - "kept_duration_min": stats.kept_duration_s / 60.0, - "kept_train_duration_s": stats.kept_train_duration_s, - "kept_val_duration_s": stats.kept_val_duration_s, - "kept_duration_by_source": { - src: {"duration_s": dur, "duration_min": dur / 60.0} - for src, dur in sorted(stats.kept_duration_by_source.items()) - }, - } - write_json(summary_path, summary) - - print(f"Exported filtered manifest: {output_path}") - print(f" Kept clips: {len(kept_rows)}") - print(f" Kept duration: {stats.kept_duration_s / 60:.1f} min") - print(f" Train: {stats.kept_train_duration_s / 60:.1f} min") - print(f" Val: {stats.kept_val_duration_s / 60:.1f} min") - print(f"Summary: {summary_path}") - - -if __name__ == "__main__": - main() diff --git a/scripts/review/init_review_manifest.py b/scripts/review/init_review_manifest.py deleted file mode 100644 index afd5c01b..00000000 --- a/scripts/review/init_review_manifest.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python3 -"""Initialize a review_state.csv from an existing manifest_resolved.csv. - -Usage: - python scripts/data/init_review_manifest.py \ - --dataset twist2_full \ - --manifest data/datasets/builds/twist2_full/manifest_resolved.csv -""" - -from __future__ import annotations - -import argparse -import csv -import sys -from pathlib import Path - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from train_mimic.data.review_lib import ReviewRow, save_review_state - - -def main() -> None: - parser = argparse.ArgumentParser(description="Initialize review state from manifest") - parser.add_argument("--dataset", type=str, required=True, help="Dataset name, e.g. twist2_full") - parser.add_argument( - "--manifest", type=str, required=True, help="Path to manifest_resolved.csv" - ) - parser.add_argument( - "--output", - type=str, - default=None, - help="Output path (default: data/datasets/review/{dataset}/review_state.csv)", - ) - parser.add_argument("--force", action="store_true", help="Overwrite existing review file") - args = parser.parse_args() - - manifest_path = Path(args.manifest) - if not manifest_path.is_absolute(): - manifest_path = (PROJECT_ROOT / manifest_path).resolve() - if not manifest_path.is_file(): - print(f"ERROR: manifest not found: {manifest_path}", file=sys.stderr) - sys.exit(1) - - if args.output: - output_path = Path(args.output) - if not output_path.is_absolute(): - output_path = (PROJECT_ROOT / output_path).resolve() - else: - output_path = PROJECT_ROOT / "data" / "datasets" / "review" / args.dataset / "review_state.csv" - - if output_path.is_file() and not args.force: - print( - f"ERROR: review file already exists: {output_path}\n" - "Use --force to overwrite.", - file=sys.stderr, - ) - sys.exit(1) - - # Read manifest_resolved.csv - # Columns: clip_id, source, file_rel, num_frames, fps, resolved_split, resolved_npz_path, - # weight, clip_index - rows: list[ReviewRow] = [] - with manifest_path.open("r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - if reader.fieldnames is None: - print(f"ERROR: manifest is empty: {manifest_path}", file=sys.stderr) - sys.exit(1) - - required = ["clip_id", "source", "file_rel", "num_frames", "fps", "resolved_split", "resolved_npz_path", "weight"] - missing = [c for c in required if c not in reader.fieldnames] - if missing: - print(f"ERROR: manifest missing columns: {missing}", file=sys.stderr) - sys.exit(1) - has_clip_index = "clip_index" in reader.fieldnames - - for idx, raw in enumerate(reader, start=2): - clip_id = raw["clip_id"].strip() - source = raw["source"].strip() - file_rel = raw["file_rel"].strip() - resolved_npz_path = raw["resolved_npz_path"].strip() - num_frames = int(raw["num_frames"]) - fps = int(raw["fps"]) - resolved_split = raw["resolved_split"].strip() - weight = float(raw["weight"]) - clip_index = int(raw["clip_index"]) if has_clip_index else -1 - duration_s = num_frames / fps if fps > 0 else 0.0 - - rows.append( - ReviewRow( - clip_id=clip_id, - source=source, - file_rel=file_rel, - resolved_npz_path=resolved_npz_path, - resolved_split=resolved_split, - num_frames=num_frames, - fps=fps, - duration_s=duration_s, - weight=weight, - clip_index=clip_index, - ) - ) - - if not rows: - print("ERROR: manifest has no data rows", file=sys.stderr) - sys.exit(1) - - save_review_state(rows, output_path) - - # Print summary - total_duration_s = sum(r.duration_s for r in rows) - sources = {} - for r in rows: - sources[r.source] = sources.get(r.source, 0) + 1 - - print(f"Initialized review state: {output_path}") - print(f" Total clips: {len(rows)}") - print(f" Total duration: {total_duration_s / 60:.1f} min ({total_duration_s / 3600:.1f} h)") - print(f" Sources:") - for src, count in sorted(sources.items()): - print(f" {src}: {count} clips") - - -if __name__ == "__main__": - main() diff --git a/scripts/review/review_dataset.py b/scripts/review/review_dataset.py deleted file mode 100644 index c0dfa8a2..00000000 --- a/scripts/review/review_dataset.py +++ /dev/null @@ -1,748 +0,0 @@ -#!/usr/bin/env python3 -"""Web-based dataset clip review tool using viser. - -Plays reference motion clips one-by-one in a browser and provides -GUI controls for annotation (keep/drop/skip, difficulty, notes). - -Usage: - python scripts/review/review_dataset.py \ - --dataset lafan1_v1 \ - --review data/datasets/review/lafan1_v1/review_state.csv -""" - -from __future__ import annotations - -import argparse -import sys -import time -from pathlib import Path -from threading import Lock - - -import mujoco -import numpy as np -import viser - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(PROJECT_ROOT)) - -from mjlab.viewer.viser import ViserMujocoScene - -from teleopit.runtime.assets import UNITREE_G1_MJLAB_XML, missing_gmr_assets_message -from train_mimic.data.review_lib import ( - ReviewRow, - ReviewStats, - compute_review_stats, - load_review_state, - save_review_state, - utc_now_iso, -) - -DEFAULT_XML = UNITREE_G1_MJLAB_XML - - -# --------------------------------------------------------------------------- -# ClipPlayer: loads NPZ clip and drives MuJoCo qpos per frame -# --------------------------------------------------------------------------- - -class ClipPlayer: - """Loads a single clip NPZ and sets MuJoCo qpos frame-by-frame.""" - - def __init__(self, mj_model: mujoco.MjModel) -> None: - self.model = mj_model - self.data = mujoco.MjData(mj_model) - self._joint_pos: np.ndarray | None = None # (T, 29) - self._pelvis_pos: np.ndarray | None = None # (T, 3) - self._pelvis_quat: np.ndarray | None = None # (T, 4) wxyz - self._fps: int = 30 - self._num_frames: int = 0 - # Cache for shard NPZ: avoid re-reading large shard files on every clip switch - self._cached_npz_path: str | None = None - self._cached_npz_data: dict[str, np.ndarray] | None = None - - def _get_npz_data(self, npz_path: Path) -> dict[str, np.ndarray]: - """Return NPZ data, using cache for shard files.""" - path_str = str(npz_path) - if self._cached_npz_path == path_str and self._cached_npz_data is not None: - return self._cached_npz_data - d = dict(np.load(path_str, allow_pickle=True)) - # Only cache shard NPZ files (those with clip_starts) - if "clip_starts" in d: - self._cached_npz_path = path_str - self._cached_npz_data = d - else: - self._cached_npz_path = None - self._cached_npz_data = None - return d - - def load_clip(self, npz_path: Path, clip_index: int = -1) -> None: - """Load NPZ clip data. - - Args: - npz_path: Path to NPZ file (standalone clip or shard file). - clip_index: If >= 0, extract this clip from a shard NPZ using - clip_starts/clip_lengths. If -1, load the entire file - as a single clip. - """ - d = self._get_npz_data(npz_path) - - if clip_index >= 0 and "clip_starts" in d and "clip_lengths" in d: - start = int(d["clip_starts"][clip_index]) - length = int(d["clip_lengths"][clip_index]) - s = slice(start, start + length) - self._joint_pos = np.asarray(d["joint_pos"][s]) - body_pos_w = np.asarray(d["body_pos_w"][s]) - body_quat_w = np.asarray(d["body_quat_w"][s]) - else: - self._joint_pos = np.asarray(d["joint_pos"]) - body_pos_w = np.asarray(d["body_pos_w"]) - body_quat_w = np.asarray(d["body_quat_w"]) - - self._pelvis_pos = body_pos_w[:, 0, :] # pelvis = body 0 - self._pelvis_quat = body_quat_w[:, 0, :] - self._fps = int(d["fps"]) - self._num_frames = self._joint_pos.shape[0] - - def set_frame(self, frame_idx: int) -> None: - """Set qpos from frame data and run mj_forward.""" - if self._joint_pos is None: - return - idx = max(0, min(frame_idx, self._num_frames - 1)) - self.data.qpos[:3] = self._pelvis_pos[idx] - self.data.qpos[3:7] = self._pelvis_quat[idx] - self.data.qpos[7:] = self._joint_pos[idx] - mujoco.mj_forward(self.model, self.data) - - @property - def fps(self) -> int: - return self._fps - - @property - def num_frames(self) -> int: - return self._num_frames - - @property - def duration_s(self) -> float: - return self._num_frames / self._fps if self._fps > 0 else 0.0 - - -# --------------------------------------------------------------------------- -# ReviewSession: manages review state, navigation, persistence -# --------------------------------------------------------------------------- - -class ReviewSession: - """Manages review state, navigation order, and persistence.""" - - def __init__( - self, - review_path: Path, - sort_mode: str = "unreviewed_first", - ) -> None: - self._review_path = review_path - self._rows: list[ReviewRow] = load_review_state(review_path) - self._order: list[int] = [] # indices into _rows - self._cursor: int = 0 - self._lock = Lock() - self._reorder(sort_mode) - - @property - def total(self) -> int: - return len(self._rows) - - @property - def cursor_display(self) -> int: - """1-based position in the current ordering.""" - return self._cursor + 1 - - def current_row(self) -> ReviewRow | None: - if not self._order: - return None - return self._rows[self._order[self._cursor]] - - def go_next(self) -> ReviewRow | None: - if self._cursor < len(self._order) - 1: - self._cursor += 1 - return self.current_row() - - def go_prev(self) -> ReviewRow | None: - if self._cursor > 0: - self._cursor -= 1 - return self.current_row() - - def go_next_unreviewed(self) -> ReviewRow | None: - """Jump to the next unreviewed clip after the current cursor.""" - for i in range(self._cursor + 1, len(self._order)): - if self._rows[self._order[i]].decision == "": - self._cursor = i - return self.current_row() - # Wrap around from beginning - for i in range(0, self._cursor): - if self._rows[self._order[i]].decision == "": - self._cursor = i - return self.current_row() - return self.current_row() - - def jump_to(self, position: int) -> ReviewRow | None: - """Jump to a 1-based position in the ordering.""" - idx = max(0, min(position - 1, len(self._order) - 1)) - self._cursor = idx - return self.current_row() - - def annotate( - self, - decision: str, - difficulty: str = "", - issue_tags: str = "", - note: str = "", - ) -> None: - """Set annotation on current row and save to disk.""" - with self._lock: - row = self.current_row() - if row is None: - return - row.decision = decision - row.difficulty = difficulty - row.issue_tags = issue_tags - row.note = note - row.reviewed_at = utc_now_iso() - self.save() - - def save(self) -> None: - save_review_state(self._rows, self._review_path) - - def stats(self) -> ReviewStats: - return compute_review_stats(self._rows) - - def _reorder(self, sort_mode: str) -> None: - indices = list(range(len(self._rows))) - - if sort_mode == "unreviewed_first": - indices.sort(key=lambda i: ( - 0 if self._rows[i].decision == "" else 1, - self._rows[i].source, - self._rows[i].clip_id, - )) - elif sort_mode == "source": - indices.sort(key=lambda i: (self._rows[i].source, self._rows[i].clip_id)) - elif sort_mode == "duration_desc": - indices.sort(key=lambda i: -self._rows[i].duration_s) - else: - indices.sort(key=lambda i: self._rows[i].clip_id) - - self._order = indices - self._cursor = 0 - - -# --------------------------------------------------------------------------- -# ReviewViewerApp: viser server + GUI + main loop -# --------------------------------------------------------------------------- - -class ReviewViewerApp: - """Main application: ties viser, ClipPlayer, and ReviewSession together.""" - - def __init__( - self, - review_path: Path, - xml_path: Path, - project_root: Path, - *, - port: int = 8012, - sort_mode: str = "unreviewed_first", - ) -> None: - self._project_root = project_root - self._model = mujoco.MjModel.from_xml_path(str(xml_path)) - self._player = ClipPlayer(self._model) - self._session = ReviewSession(review_path, sort_mode) - - self._server = viser.ViserServer(port=port, label="Clip Review") - self._scene = ViserMujocoScene.create( - server=self._server, mj_model=self._model, num_envs=1, - ) - - # Pre-warm viser's cached_property type hint resolution at a shallow - # call stack. Without this, the first update_from_mjdata triggered - # from a deep callback chain causes RecursionError in Python 3.10's - # get_type_hints / _eval_type. - mujoco.mj_forward(self._model, self._player.data) - self._scene.update_from_mjdata(self._player.data) - - # Playback state - self._playing: bool = False - self._speed: float = 1.0 - self._current_frame: int = 0 - - # Pending action queue: callbacks only append here, main loop processes. - # This avoids deep recursion inside viser's callback / message queue. - self._pending_actions: list[str] = [] - self._pending_frame_scrub: int | None = None # from slider drag - self._pending_jump: int | None = None # from jump input - self._hotkey_clearing: bool = False # guard against recursive on_update - - def setup_gui(self) -> None: - """Build all viser GUI elements. Callbacks only set flags.""" - gui = self._server.gui - - # --- Clip Info --- - with gui.add_folder("Clip Info", order=0): - self._info_html = gui.add_html("Loading...") - - # --- Playback --- - with gui.add_folder("Playback", order=1): - self._play_btn = gui.add_button("Play", color="green") - self._frame_slider = gui.add_slider( - "Frame", min=0, max=1, step=1, initial_value=0, - ) - self._speed_group = gui.add_button_group( - "Speed", options=["0.25x", "0.5x", "1x", "2x"], - ) - self._restart_btn = gui.add_button("Restart") - - @self._play_btn.on_click - def _(_) -> None: - self._pending_actions.append("toggle_play") - - @self._frame_slider.on_update - def _(_) -> None: - self._pending_frame_scrub = int(self._frame_slider.value) - - @self._speed_group.on_click - def _(event) -> None: - speed_map = {"0.25x": 0.25, "0.5x": 0.5, "1x": 1.0, "2x": 2.0} - self._speed = speed_map.get(event.target.value, 1.0) - - @self._restart_btn.on_click - def _(_) -> None: - self._pending_actions.append("restart") - - # --- Annotation --- - with gui.add_folder("Annotation", order=2): - self._decision_dropdown = gui.add_dropdown( - "Decision", - options=["", "Keep", "Drop", "Skip"], - initial_value="", - hint="Must select Keep/Drop/Skip before saving", - ) - self._difficulty_dropdown = gui.add_dropdown( - "Difficulty", - options=["", "easy", "medium", "hard", "bad_data"], - initial_value="", - ) - self._tags_input = gui.add_text("Issue Tags", initial_value="") - self._note_input = gui.add_text("Note", initial_value="") - self._save_next_btn = gui.add_button("Save & Next", color="blue") - self._save_btn = gui.add_button("Save") - - @self._save_next_btn.on_click - def _(_) -> None: - self._pending_actions.append("save_next") - - @self._save_btn.on_click - def _(_) -> None: - self._pending_actions.append("save") - - # --- Navigation --- - with gui.add_folder("Navigation", order=3): - self._prev_btn = gui.add_button("Prev") - self._next_btn = gui.add_button("Next") - self._next_unreviewed_btn = gui.add_button("Next Unreviewed", color="orange") - self._jump_input = gui.add_number( - "Jump to #", initial_value=1, min=1, max=self._session.total, step=1, - ) - - @self._prev_btn.on_click - def _(_) -> None: - self._pending_actions.append("prev") - - @self._next_btn.on_click - def _(_) -> None: - self._pending_actions.append("next") - - @self._next_unreviewed_btn.on_click - def _(_) -> None: - self._pending_actions.append("next_unreviewed") - - @self._jump_input.on_update - def _(_) -> None: - self._pending_jump = int(self._jump_input.value) - - # --- Stats --- - with gui.add_folder("Stats", order=4, expand_by_default=False): - self._stats_html = gui.add_html("Loading...") - - # --- Keyboard Shortcuts --- - with gui.add_folder("Shortcuts", order=5): - self._hotkey_input = gui.add_text( - "Hotkey (click here, type key)", - initial_value="", - ) - gui.add_markdown( - "**K**=Keep+Next **D**=Drop+Next **S**=Skip+Next\n\n" - "**N**=Next **P**=Prev **U**=Next Unreviewed\n\n" - "**Space**=Play/Pause **R**=Restart **F**=Speed Up\n\n" - "**1**=Easy **2**=Medium **3**=Hard" - ) - - @self._hotkey_input.on_update - def _(_) -> None: - if self._hotkey_clearing: - return - raw = self._hotkey_input.value - self._hotkey_clearing = True - self._hotkey_input.value = "" - self._hotkey_clearing = False - if not raw: - return - ch = raw[-1].lower() - key_map = { - "k": "hotkey_keep", - "d": "hotkey_drop", - "s": "hotkey_skip", - "n": "next", - "p": "prev", - "u": "next_unreviewed", - " ": "toggle_play", - "r": "restart", - "f": "speed_up", - "1": "set_easy", - "2": "set_medium", - "3": "set_hard", - } - action = key_map.get(ch) - if action: - self._pending_actions.append(action) - - # Visualization options - self._scene.create_visualization_gui(show_debug_viz_control=False) - - # ------------------------------------------------------------------ - # Actions executed from the main loop (shallow call stack) - # ------------------------------------------------------------------ - - def _do_save(self) -> bool: - """Save the current annotation. Returns False if no decision selected.""" - decision = self._decision_dropdown.value.lower() if self._decision_dropdown.value else "" - if decision not in ("keep", "drop", "skip"): - print("[REVIEW] WARNING: no decision selected, save skipped") - return False - self._session.annotate( - decision=decision, - difficulty=self._difficulty_dropdown.value, - issue_tags=self._tags_input.value, - note=self._note_input.value, - ) - self._update_stats_display() - - # Terminal summary - s = self._session.stats() - row = self._session.current_row() - clip_id = row.clip_id if row else "?" - print( - f"[REVIEW] {clip_id} -> {decision} | " - f"{s.reviewed}/{s.total} ({s.progress_pct:.1f}%) | " - f"keep={s.keep_count} drop={s.drop_count} skip={s.skip_count} | " - f"kept_dur={s.kept_duration_s / 60:.1f}min" - ) - return True - - def _load_current_clip(self) -> None: - """Load the clip for the current session cursor.""" - row = self._session.current_row() - if row is None: - self._info_html.content = "No clips to review" - return - - # Resolve NPZ path (use resolved_npz_path which always points to .npz) - npz_path = Path(row.resolved_npz_path) - if not npz_path.is_absolute(): - npz_path = self._project_root / npz_path - - try: - self._player.load_clip(npz_path, clip_index=row.clip_index) - except Exception as exc: - self._info_html.content = f"Error loading clip:
{exc}" - return - - self._current_frame = 0 - self._playing = False - - # Update scene first (before touching GUI widgets) - self._player.set_frame(0) - self._scene.update_from_mjdata(self._player.data) - - # Now update GUI widgets — disable slider callback processing - # by setting _pending_frame_scrub to None after we're done. - self._play_btn.label = "Play" - self._play_btn.color = "green" - - max_frame = max(0, self._player.num_frames - 1) - self._frame_slider.max = max_frame - self._frame_slider.value = 0 - - self._decision_dropdown.value = row.decision.capitalize() if row.decision else "" - self._difficulty_dropdown.value = row.difficulty - self._tags_input.value = row.issue_tags - self._note_input.value = row.note - self._jump_input.value = self._session.cursor_display - - # Drain any spurious pending events triggered by the GUI updates above - self._pending_frame_scrub = None - self._pending_jump = None - self._pending_actions.clear() - - self._update_info_display() - self._update_stats_display() - - def _update_frame(self) -> None: - """Set frame on player and update the 3D scene.""" - self._player.set_frame(self._current_frame) - self._scene.update_from_mjdata(self._player.data) - - def _update_info_display(self) -> None: - """Update the Clip Info HTML panel.""" - row = self._session.current_row() - if row is None: - return - - status_color = { - "keep": "green", "drop": "red", "skip": "orange", "": "gray", - } - status_label = row.decision if row.decision else "unreviewed" - color = status_color.get(row.decision, "gray") - - kept_min = self._session.stats().kept_duration_s / 60.0 - - self._info_html.content = f""" -
- #{self._session.cursor_display}/{self._session.total} - [{status_label}] - Kept: {kept_min:.1f} min
- clip_id: {row.clip_id}
- source: {row.source}
- split: {row.resolved_split}
- frames: {row.num_frames} | fps: {row.fps} - | duration: {row.duration_s:.2f}s -
- """ - - def _update_stats_display(self) -> None: - """Update the Stats HTML panel.""" - s = self._session.stats() - kept_min = s.kept_duration_s / 60.0 - kept_h = s.kept_duration_s / 3600.0 - by_source_lines = "".join( - f"{src}: {dur / 60:.1f} min
" - for src, dur in sorted(s.kept_duration_by_source.items()) - ) - self._stats_html.content = f""" -
-

- Kept: {kept_min:.1f} min ({kept_h:.2f} h) -

- Progress: {s.reviewed} / {s.total} ({s.progress_pct:.1f}%)
- Keep: {s.keep_count} | - Drop: {s.drop_count} | - Skip: {s.skip_count}
-

Duration Breakdown

- Train: {s.kept_train_duration_s / 60:.1f} min
- Val: {s.kept_val_duration_s / 60:.1f} min
-

By Source

- {by_source_lines if by_source_lines else "none yet"} -
- """ - - # ------------------------------------------------------------------ - # Main loop — processes pending actions from callbacks - # ------------------------------------------------------------------ - - def _process_pending(self) -> None: - """Process all pending actions from GUI callbacks.""" - # Handle frame scrub from slider (take latest value only) - scrub = self._pending_frame_scrub - if scrub is not None and not self._playing: - self._pending_frame_scrub = None - self._current_frame = scrub - self._update_frame() - - # Handle jump input - jump = self._pending_jump - if jump is not None: - self._pending_jump = None - self._session.jump_to(jump) - self._load_current_clip() - return # load_current_clip clears remaining actions - - # Handle button actions (process one per tick to keep things responsive) - while self._pending_actions: - action = self._pending_actions.pop(0) - - if action == "toggle_play": - self._playing = not self._playing - self._play_btn.label = "Pause" if self._playing else "Play" - self._play_btn.color = "red" if self._playing else "green" - - elif action == "restart": - self._current_frame = 0 - self._update_frame() - self._frame_slider.value = 0 - self._pending_frame_scrub = None # drain spurious slider event - - elif action == "save": - self._do_save() - - elif action == "save_next": - if self._do_save(): - self._session.go_next_unreviewed() - self._load_current_clip() - return # load clears actions - - elif action == "prev": - self._session.go_prev() - self._load_current_clip() - return - - elif action == "next": - self._session.go_next() - self._load_current_clip() - return - - elif action == "next_unreviewed": - self._session.go_next_unreviewed() - self._load_current_clip() - return - - elif action == "hotkey_keep": - self._decision_dropdown.value = "Keep" - self._pending_actions.append("save_next") - - elif action == "hotkey_drop": - self._decision_dropdown.value = "Drop" - self._pending_actions.append("save_next") - - elif action == "hotkey_skip": - self._decision_dropdown.value = "Skip" - self._pending_actions.append("save_next") - - elif action == "speed_up": - _speed_levels = [1.0, 2.0, 4.0, 8.0] - try: - idx = _speed_levels.index(self._speed) - self._speed = _speed_levels[(idx + 1) % len(_speed_levels)] - except ValueError: - self._speed = 1.0 - # Update button group to reflect current speed - label = f"{self._speed:g}x" - if label in ("0.25x", "0.5x", "1x", "2x"): - self._speed_group.value = label - - elif action == "set_easy": - self._difficulty_dropdown.value = "easy" - - elif action == "set_medium": - self._difficulty_dropdown.value = "medium" - - elif action == "set_hard": - self._difficulty_dropdown.value = "hard" - - def run(self) -> None: - """Main loop: process callbacks and handle playback timing.""" - self.setup_gui() - self._load_current_clip() - - print(f"\nReview viewer ready at http://localhost:{self._server.get_port()}") - print("Press Ctrl+C to exit.\n") - - last_frame_time = time.time() - try: - while True: - now = time.time() - - # Check if scene visualization settings changed - if self._scene.needs_update: - self._scene.refresh_visualization() - - # Process GUI actions from callbacks - self._process_pending() - - # Advance playback - if self._playing and self._player.num_frames > 0: - dt = now - last_frame_time - frames_to_advance = dt * self._player.fps * self._speed - if frames_to_advance >= 1.0: - self._current_frame += int(frames_to_advance) - if self._current_frame >= self._player.num_frames: - self._current_frame = 0 - self._playing = False - self._play_btn.label = "Play" - self._play_btn.color = "green" - self._update_frame() - self._frame_slider.value = self._current_frame - self._pending_frame_scrub = None # drain spurious event - last_frame_time = now - else: - last_frame_time = now - - time.sleep(1.0 / 60.0) - - except KeyboardInterrupt: - print("\nShutting down...") - self._server.stop() - - -# --------------------------------------------------------------------------- -# CLI entry point -# --------------------------------------------------------------------------- - -def main() -> None: - parser = argparse.ArgumentParser(description="Web-based dataset clip review tool") - parser.add_argument("--dataset", type=str, required=True, help="Dataset name") - parser.add_argument( - "--review", type=str, default=None, - help="Path to review_state.csv (default: data/datasets/review/{dataset}/review_state.csv)", - ) - parser.add_argument( - "--xml", - type=str, - default=None, - help="Robot XML path (default: teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml)", - ) - parser.add_argument("--port", type=int, default=8012, help="Viser server port") - parser.add_argument( - "--sort", type=str, default="unreviewed_first", - choices=["unreviewed_first", "source", "duration_desc"], - help="Initial clip sort order", - ) - args = parser.parse_args() - - if args.review: - review_path = Path(args.review) - if not review_path.is_absolute(): - review_path = (PROJECT_ROOT / review_path).resolve() - else: - review_path = PROJECT_ROOT / "data" / "datasets" / "review" / args.dataset / "review_state.csv" - - if not review_path.is_file(): - print(f"ERROR: review state not found: {review_path}", file=sys.stderr) - print("Run init_review_manifest.py first.", file=sys.stderr) - sys.exit(1) - - xml_path = Path(args.xml) if args.xml else DEFAULT_XML - if not xml_path.is_file(): - print( - "ERROR: " - + missing_gmr_assets_message(xml_path, label="Robot XML"), - file=sys.stderr, - ) - sys.exit(1) - - app = ReviewViewerApp( - review_path=review_path, - xml_path=xml_path, - project_root=PROJECT_ROOT, - port=args.port, - sort_mode=args.sort, - ) - app.run() - - -if __name__ == "__main__": - main() diff --git a/scripts/run/check_pico_signal.py b/scripts/run/check_pico_signal.py new file mode 100644 index 00000000..44eac250 --- /dev/null +++ b/scripts/run/check_pico_signal.py @@ -0,0 +1,296 @@ +"""Pico mocap/video signal diagnostic entry point.""" + +from __future__ import annotations + +from collections import Counter +import logging +import os +import signal +import time +import threading +from typing import Any + +import hydra +import numpy as np +from omegaconf import DictConfig + +from teleopit.inputs.human_frame_validation import HumanFrameValidationResult, validate_human_frame +from teleopit.inputs.pico4_provider import Pico4InputProvider +from teleopit.inputs.pico_video import PicoVideoRuntime, bridge_video_source, parse_pico_video_config +from teleopit.runtime.common import cfg_get + + +logger = logging.getLogger("teleopit.tools.check_pico_signal") + + +def _fmt_vec(values: tuple[float, ...] | None) -> str: + if values is None: + return "None" + return "[" + ", ".join(f"{value:.4f}" for value in values) + "]" + + +def _frame_stats(frame: dict[str, Any]) -> dict[str, Any]: + positions = [] + quat_norms = [] + pelvis_pos = None + for name, value in frame.items(): + try: + pos, quat = value + except Exception: + continue + try: + pos_arr = np.asarray(pos, dtype=np.float64).reshape(-1) + quat_arr = np.asarray(quat, dtype=np.float64).reshape(-1) + except Exception: + continue + if pos_arr.shape[0] >= 3 and np.all(np.isfinite(pos_arr[:3])): + positions.append(pos_arr[:3]) + if str(name) == "Pelvis": + pelvis_pos = pos_arr[:3].copy() + if quat_arr.size > 0 and np.all(np.isfinite(quat_arr)): + quat_norms.append(float(np.linalg.norm(quat_arr))) + + if not positions: + return {} + + pos = np.asarray(positions, dtype=np.float64) + return { + "pelvis_pos": pelvis_pos, + "min_pos": np.min(pos, axis=0), + "max_pos": np.max(pos, axis=0), + "extent": np.ptp(pos, axis=0), + "max_abs_pos": float(np.max(np.abs(pos))), + "quat_norm_min": min(quat_norms) if quat_norms else None, + "quat_norm_max": max(quat_norms) if quat_norms else None, + } + + +def _log_invalid(seq: int, age_ms: float, result: HumanFrameValidationResult) -> None: + logger.warning( + "Invalid Pico body frame | seq=%s age_ms=%.1f reason=%s joint=%s " + "max_abs_pos=%s pos=%s quat=%s detail=%s", + seq, + age_ms, + result.reason, + result.joint_name, + f"{result.max_abs_pos:.4f}" if result.max_abs_pos is not None else "None", + _fmt_vec(result.pos), + _fmt_vec(result.quat), + result.detail, + ) + + +def _log_summary( + *, + window_s: float, + total: int, + valid: int, + invalid_reasons: Counter[str], + provider_fps: float, + last_seq: int | None, + last_age_ms: float | None, + last_stats: dict[str, Any], + pushed_video_frames: int, +) -> None: + if total <= 0: + logger.info( + "Pico signal summary | window=%.1fs samples=0 provider_fps=%.1f " + "last_seq=%s video_frames=%d", + window_s, + provider_fps, + last_seq, + pushed_video_frames, + ) + return + + invalid = total - valid + reason_text = ",".join(f"{reason}:{count}" for reason, count in invalid_reasons.most_common()) or "none" + pelvis = last_stats.get("pelvis_pos") + extent = last_stats.get("extent") + min_pos = last_stats.get("min_pos") + max_pos = last_stats.get("max_pos") + logger.info( + "Pico signal summary | window=%.1fs samples=%d valid=%d invalid=%d reasons=%s " + "provider_fps=%.1f last_seq=%s last_age_ms=%s video_frames=%d " + "max_abs_pos=%s pelvis=%s extent=%s min=%s max=%s quat_norm=[%s,%s]", + window_s, + total, + valid, + invalid, + reason_text, + provider_fps, + last_seq, + f"{last_age_ms:.1f}" if last_age_ms is not None else "None", + pushed_video_frames, + f"{last_stats.get('max_abs_pos'):.4f}" if "max_abs_pos" in last_stats else "None", + _fmt_np_vec(pelvis), + _fmt_np_vec(extent), + _fmt_np_vec(min_pos), + _fmt_np_vec(max_pos), + _fmt_float(last_stats.get("quat_norm_min")), + _fmt_float(last_stats.get("quat_norm_max")), + ) + + +def _fmt_np_vec(values: Any) -> str: + if values is None: + return "None" + arr = np.asarray(values, dtype=np.float64).reshape(-1) + return "[" + ", ".join(f"{float(value):.4f}" for value in arr) + "]" + + +def _fmt_float(value: Any) -> str: + if value is None: + return "None" + return f"{float(value):.4f}" + + +def _build_provider(cfg: DictConfig, video_enabled: bool) -> Pico4InputProvider: + input_cfg = cfg_get(cfg, "input", {}) or {} + video_cfg = parse_pico_video_config(input_cfg) + return Pico4InputProvider( + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + timeout=float(cfg_get(input_cfg, "pico4_timeout", 60.0)), + buffer_size=int(cfg_get(input_cfg, "pico4_buffer_size", 60)), + timestamp_gap_reset_s=float(cfg_get(input_cfg, "pico4_timestamp_gap_reset_s", 0.15)), + pause_button=cfg_get(input_cfg, "pause_button", "A"), + pause_debounce_s=float(cfg_get(input_cfg, "pause_debounce_s", 0.25)), + bridge_host=str(cfg_get(input_cfg, "bridge_host", "0.0.0.0")), + bridge_port=int(cfg_get(input_cfg, "bridge_port", 63901)), + bridge_discovery=bool(cfg_get(input_cfg, "bridge_discovery", True)), + bridge_advertise_ip=cfg_get(input_cfg, "bridge_advertise_ip", None), + bridge_video=bridge_video_source(video_cfg), + bridge_video_enabled=video_enabled, + bridge_start_timeout=float(cfg_get(input_cfg, "bridge_start_timeout", 10.0)), + bridge_history_size=int(cfg_get(input_cfg, "bridge_history_size", 120)), + ) + + +def _install_signal_handlers(stop_event: threading.Event) -> None: + def _handle_signal(signum: int, _frame: Any) -> None: + if stop_event.is_set(): + os._exit(130) + logger.info("Received signal %s -- shutting down", signum) + stop_event.set() + + signal.signal(signal.SIGINT, _handle_signal) + signal.signal(signal.SIGTERM, _handle_signal) + + +def _start_video_runtime_async(video_runtime: PicoVideoRuntime) -> threading.Event: + done = threading.Event() + + def _run() -> None: + try: + video_runtime.start() + finally: + done.set() + + thread = threading.Thread(target=_run, name="pico_video_start", daemon=True) + thread.start() + return done + + +@hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="pico4_sim2real") +def main(cfg: DictConfig) -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") + input_cfg = cfg_get(cfg, "input", {}) or {} + video_cfg = parse_pico_video_config(input_cfg) + diag_cfg = cfg_get(cfg, "diagnostic", {}) or {} + poll_hz = float(cfg_get(diag_cfg, "poll_hz", cfg_get(cfg_get(cfg, "runtime", {}) or {}, "pico_input_hz", 120.0))) + summary_interval_s = float(cfg_get(diag_cfg, "summary_interval_s", 1.0)) + duration_s = float(cfg_get(diag_cfg, "duration_s", 0.0)) + + logger.info("Starting Pico signal diagnostic") + logger.info( + "Pico bridge | host=%s port=%s discovery=%s advertise_ip=%s", + cfg_get(input_cfg, "bridge_host", "0.0.0.0"), + cfg_get(input_cfg, "bridge_port", 63901), + cfg_get(input_cfg, "bridge_discovery", True), + cfg_get(input_cfg, "bridge_advertise_ip", None), + ) + logger.info( + "Signal check | validation=finite_values poll_hz=%.1f summary_interval_s=%.1f " + "duration_s=%s video_enabled=%s video_source=%s", + poll_hz, + summary_interval_s, + f"{duration_s:.1f}" if duration_s > 0.0 else "until Ctrl-C", + video_cfg.enabled, + video_cfg.source, + ) + + stop_event = threading.Event() + _install_signal_handlers(stop_event) + provider = _build_provider(cfg, video_cfg.enabled) + video_runtime = PicoVideoRuntime(provider=provider, config=video_cfg, mode="sim2real") + total = 0 + valid = 0 + invalid_reasons: Counter[str] = Counter() + last_seq: int | None = None + last_age_ms: float | None = None + last_stats: dict[str, Any] = {} + window_start_s = time.monotonic() + start_s = window_start_s + sleep_s = 1.0 / max(poll_hz, 1.0) + video_start_done: threading.Event | None = None + + try: + if video_cfg.enabled: + logger.info("Starting Pico video backend asynchronously") + video_start_done = _start_video_runtime_async(video_runtime) + while not stop_event.is_set(): + now = time.monotonic() + if duration_s > 0.0 and now - start_s >= duration_s: + break + + video_runtime.tick() + if provider.has_frame(): + try: + frame, timestamp_s, seq = provider.get_frame_packet() + except Exception: + logger.exception("Failed to read Pico body frame packet") + else: + seq = int(seq) + if seq != last_seq: + last_seq = seq + last_age_ms = max((time.monotonic() - float(timestamp_s)) * 1000.0, 0.0) + last_stats = _frame_stats(frame) + result = validate_human_frame(frame) + total += 1 + if result.valid: + valid += 1 + else: + invalid_reasons[result.reason] += 1 + _log_invalid(seq, last_age_ms, result) + + now = time.monotonic() + if now - window_start_s >= summary_interval_s: + _log_summary( + window_s=now - window_start_s, + total=total, + valid=valid, + invalid_reasons=invalid_reasons, + provider_fps=float(provider.fps), + last_seq=last_seq, + last_age_ms=last_age_ms, + last_stats=last_stats, + pushed_video_frames=video_runtime.pushed_frames, + ) + total = 0 + valid = 0 + invalid_reasons.clear() + window_start_s = now + + if video_start_done is not None and not video_start_done.is_set() and now - start_s >= 5.0: + logger.info("Waiting for Pico video backend to become ready in the background") + video_start_done = None + stop_event.wait(timeout=sleep_s) + except KeyboardInterrupt: + logger.info("KeyboardInterrupt -- stopping Pico signal diagnostic") + finally: + video_runtime.stop() + provider.close() + + +if __name__ == "__main__": + main() diff --git a/scripts/run/record_pico_motion.py b/scripts/run/record_pico_motion.py new file mode 100644 index 00000000..3fb1bff1 --- /dev/null +++ b/scripts/run/record_pico_motion.py @@ -0,0 +1,373 @@ +"""Interactively record Pico-retargeted G1 motion clips as training NPZ files.""" + +from __future__ import annotations + +import logging +import threading +import time +from pathlib import Path + +import hydra +import numpy as np +from omegaconf import DictConfig + +from teleopit.constants import FULL_QPOS_DIM +from teleopit.inputs.pico4_provider import Pico4InputProvider +from teleopit.recording.pico_motion import ( + PicoDatasetSpec, + RecordingState, + ensure_pico_dataset_spec, + sanitize_clip_name, + unique_clip_path, + write_motion_clip_npz, +) +from teleopit.retargeting.core import RetargetingModule +from teleopit.runtime.assets import PROJECT_ROOT, UNITREE_G1_XML, missing_gmr_assets_message +from teleopit.runtime.common import cfg_get +from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader +from teleopit.sim.viewer_subprocess import start_robot_viewer + + +class RetargetPreview: + """Small wrapper around the existing MuJoCo retarget viewer subprocess.""" + + def __init__(self, xml_path: str | Path = UNITREE_G1_XML, *, enabled: bool = True) -> None: + self.enabled = bool(enabled) + self._proc = None + self._arr = None + self._alive = None + self._shutdown = None + self._qpos_len = FULL_QPOS_DIM + if not self.enabled: + return + + path = Path(xml_path).expanduser().resolve() + if not path.is_file(): + raise FileNotFoundError(missing_gmr_assets_message(path, label="G1 MuJoCo XML for retarget viewer")) + + import mujoco + + model = mujoco.MjModel.from_xml_path(str(path)) + self._qpos_len = int(model.nq) + self._proc, self._arr, self._alive, self._shutdown = start_robot_viewer( + str(path), + self._qpos_len, + True, + "Retarget", + 900, + 50, + ) + initial = np.zeros((self._qpos_len,), dtype=np.float64) + if self._qpos_len > 3: + initial[3] = 1.0 + self.update(initial) + + def update(self, qpos: np.ndarray) -> None: + if not self.enabled or self._arr is None: + return + qpos_arr = np.asarray(qpos, dtype=np.float64).reshape(-1) + if qpos_arr.shape[0] > self._qpos_len: + raise ValueError(f"retarget qpos has {qpos_arr.shape[0]} values, viewer accepts {self._qpos_len}") + out = np.zeros((self._qpos_len,), dtype=np.float64) + out[: qpos_arr.shape[0]] = qpos_arr + with self._arr.get_lock(): + self._arr[: self._qpos_len] = out.tolist() + + def close(self) -> None: + if self._shutdown is not None: + self._shutdown.set() + if self._proc is not None: + self._proc.join(timeout=3.0) + if self._proc.is_alive(): + self._proc.terminate() + + +class RetargetRecordingWorker: + """Continuously retarget Pico frames without blocking terminal input.""" + + def __init__( + self, + *, + provider: Pico4InputProvider, + retargeter: RetargetingModule, + preview: RetargetPreview, + state: RecordingState, + target_fps: int, + ) -> None: + if target_fps <= 0: + raise ValueError(f"target_fps must be > 0, got {target_fps}") + self._provider = provider + self._retargeter = retargeter + self._preview = preview + self._state = state + self._target_fps = int(target_fps) + self._stop_event = threading.Event() + self._thread = threading.Thread(target=self._run, name="pico_retarget_recorder", daemon=True) + self._last_error: BaseException | None = None + + def start(self) -> None: + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + self._thread.join(timeout=3.0) + + def raise_if_failed(self) -> None: + if self._last_error is not None: + raise RuntimeError(f"retarget worker failed: {self._last_error}") from self._last_error + + def _run(self) -> None: + last_seq: int | None = None + dt = 1.0 / float(self._target_fps) + next_tick = time.monotonic() + try: + while not self._stop_event.is_set(): + now = time.monotonic() + if now < next_tick: + self._stop_event.wait(timeout=min(next_tick - now, 0.01)) + continue + next_tick = now + dt + qpos, last_seq = _retarget_latest_frame(self._provider, self._retargeter, last_seq=last_seq) + if qpos is None: + continue + self._preview.update(qpos) + self._state.append(qpos) + except BaseException as exc: + self._last_error = exc + + +def _project_path(raw: str | Path) -> Path: + path = Path(raw).expanduser() + return path if path.is_absolute() else PROJECT_ROOT / path + + +def _build_provider(cfg: DictConfig) -> Pico4InputProvider: + input_cfg = cfg_get(cfg, "input", {}) or {} + return Pico4InputProvider( + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + timeout=float(cfg_get(input_cfg, "pico4_timeout", 60.0)), + buffer_size=int(cfg_get(input_cfg, "pico4_buffer_size", 60)), + timestamp_gap_reset_s=float(cfg_get(input_cfg, "pico4_timestamp_gap_reset_s", 0.15)), + pause_button=cfg_get(input_cfg, "pause_button", "A"), + pause_debounce_s=float(cfg_get(input_cfg, "pause_debounce_s", 0.25)), + bridge_host=str(cfg_get(input_cfg, "bridge_host", "0.0.0.0")), + bridge_port=int(cfg_get(input_cfg, "bridge_port", 63901)), + bridge_discovery=bool(cfg_get(input_cfg, "bridge_discovery", True)), + bridge_advertise_ip=cfg_get(input_cfg, "bridge_advertise_ip", None), + bridge_video=None, + bridge_video_enabled=False, + bridge_start_timeout=float(cfg_get(input_cfg, "bridge_start_timeout", 10.0)), + bridge_history_size=int(cfg_get(input_cfg, "bridge_history_size", 120)), + ) + + +def _build_retargeter(cfg: DictConfig, provider: Pico4InputProvider) -> RetargetingModule: + input_cfg = cfg_get(cfg, "input", {}) or {} + human_height = cfg_get(input_cfg, "human_height", 1.75) + return RetargetingModule( + robot_name=str(cfg_get(input_cfg, "robot_name", "unitree_g1")), + human_format=str(provider.human_format), + actual_human_height=float(human_height), + ) + + +def _prompt_clip_name() -> str | None: + while True: + raw = input("\nClip name (semantic label, or q to quit): ").strip() + if raw.lower() in {"q", "quit", "exit"}: + return None + try: + return sanitize_clip_name(raw) + except ValueError as exc: + print(f"Invalid clip name: {exc}") + + +def _retarget_latest_frame( + provider: Pico4InputProvider, + retargeter: RetargetingModule, + *, + last_seq: int | None, +) -> tuple[np.ndarray | None, int | None]: + if not provider.has_frame(): + return None, last_seq + frame, _timestamp_s, seq = provider.get_frame_packet() + seq = int(seq) + if last_seq is not None and seq == last_seq: + return None, last_seq + qpos = np.asarray(retargeter.retarget(frame), dtype=np.float64).reshape(-1) + if qpos.shape[0] != FULL_QPOS_DIM: + raise ValueError(f"retarget qpos must be {FULL_QPOS_DIM}D, got {qpos.shape[0]}") + if not np.isfinite(qpos).all(): + raise ValueError("retarget qpos contains NaN/Inf") + return qpos, seq + + +def _record_one_clip( + *, + cfg: DictConfig, + state: RecordingState, +) -> str: + record_cfg = cfg_get(cfg, "record", {}) or {} + target_fps = int(cfg_get(record_cfg, "target_fps", 30)) + min_frames = int(cfg_get(record_cfg, "min_frames", 30)) + max_duration_s = float(cfg_get(record_cfg, "max_duration_s", 0.0)) + output_dir = _project_path(str(cfg_get(record_cfg, "output_dir", "data/pico_motion/clips"))) + if target_fps <= 0: + raise ValueError(f"record.target_fps must be > 0, got {target_fps}") + if min_frames < 2: + raise ValueError(f"record.min_frames must be >= 2, got {min_frames}") + + keyboard = TerminalKeyboardReader() + if not keyboard.active: + keyboard.close() + raise RuntimeError("record_pico_motion.py requires an interactive TTY for keyboard controls") + + print("Controls: R=start S=save D=discard N=new name Q=quit") + clip_name, recording, frame_count, _elapsed = state.status() + if clip_name is None: + print("Preview is running. Press N to enter a clip name before recording.") + else: + print(f"Ready: {clip_name}") + + try: + while True: + for event in keyboard.poll(): + key = event.key.lower() + if key == "q": + _clip_name, was_recording, _frame_count, _elapsed = state.status() + if was_recording: + discarded_name, _ = state.discard() + print(f"\nDiscarded unsaved clip: {discarded_name}") + return "quit" + if key == "n": + _clip_name, was_recording, _frame_count, _elapsed = state.status() + if was_recording: + print("\nPress S to save or D to discard before changing clip name.") + continue + return "next" + if key == "r": + current_name, _recording, _frame_count, _elapsed = state.status() + if current_name is None: + print("\nPress N and enter a clip name before recording.") + continue + clip_name = state.start() + print(f"\nRecording: {clip_name}") + elif key == "d": + clip_name, frame_count = state.discard() + label = "" if clip_name is None else clip_name + print(f"\nDiscarded: {label} ({frame_count} frames)") + return "next" + elif key == "s": + clip_name, _was_recording, frames = state.snapshot() + if clip_name is None: + print("\nPress N and enter a clip name before saving.") + continue + if len(frames) < min_frames: + print(f"\nNeed at least {min_frames} frames before saving; current={len(frames)}") + continue + output_path = unique_clip_path(output_dir, clip_name) + write_motion_clip_npz(output_path, frames, fps=target_fps) + state.mark_saved() + duration_s = len(frames) / float(target_fps) + print(f"\nSaved: {output_path} ({len(frames)} frames, {duration_s:.2f}s)") + return "next" + + clip_name, recording, frame_count, elapsed = state.status() + if recording and clip_name is not None and max_duration_s > 0.0 and elapsed is not None: + if elapsed >= max_duration_s: + clip_name, _was_recording, frames = state.snapshot() + if clip_name is None: + raise RuntimeError("recording reached max_duration_s without a clip name") + if len(frames) < min_frames: + raise ValueError( + f"max_duration_s reached but only recorded {len(frames)} frames; " + f"min_frames={min_frames}" + ) + output_path = unique_clip_path(output_dir, clip_name) + write_motion_clip_npz(output_path, frames, fps=target_fps) + state.mark_saved() + print(f"\nSaved by max_duration_s: {output_path}") + return "next" + + time.sleep(0.02) + finally: + keyboard.close() + + +def _maybe_write_dataset_spec(cfg: DictConfig) -> None: + record_cfg = cfg_get(cfg, "record", {}) or {} + if not bool(cfg_get(record_cfg, "write_dataset_spec", True)): + return + output_dir = _project_path(str(cfg_get(record_cfg, "output_dir", "data/pico_motion/clips"))) + spec_path = _project_path(str(cfg_get(record_cfg, "dataset_spec_path", "data/pico_motion/pico_recorded.yaml"))) + dataset_name = str(cfg_get(record_cfg, "dataset_name", "pico_recorded")) + target_fps = int(cfg_get(record_cfg, "target_fps", 30)) + overwrite = bool(cfg_get(record_cfg, "overwrite_dataset_spec", False)) + path = ensure_pico_dataset_spec( + spec_path, + output_dir, + spec=PicoDatasetSpec(dataset_name=dataset_name, target_fps=target_fps), + overwrite=overwrite, + ) + print(f"Dataset spec: {path}") + + +def _configure_logging(cfg: DictConfig) -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") + record_cfg = cfg_get(cfg, "record", {}) or {} + if bool(cfg_get(record_cfg, "quiet_logs", True)): + logging.getLogger("pico_bridge").setLevel(logging.WARNING) + logging.getLogger("teleopit.inputs.pico4_provider").setLevel(logging.ERROR) + + +@hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="pico4_record") +def main(cfg: DictConfig) -> None: + _configure_logging(cfg) + _maybe_write_dataset_spec(cfg) + + print("Starting Pico receiver; waiting for body tracking...") + provider = _build_provider(cfg) + preview = RetargetPreview(enabled=False) + try: + preview.close() + preview = RetargetPreview(enabled=bool(cfg_get(cfg_get(cfg, "record", {}) or {}, "viewer_enabled", True))) + retargeter = _build_retargeter(cfg, provider) + print("Pico recorder is ready. Retarget viewer will update after the first frame.") + record_cfg = cfg_get(cfg, "record", {}) or {} + target_fps = int(cfg_get(record_cfg, "target_fps", 30)) + state = RecordingState() + worker = RetargetRecordingWorker( + provider=provider, + retargeter=retargeter, + preview=preview, + state=state, + target_fps=target_fps, + ) + worker.start() + try: + clip_name = _prompt_clip_name() + if clip_name is not None: + state.set_clip_name(clip_name) + while True: + worker.raise_if_failed() + if clip_name is None: + break + command = _record_one_clip(cfg=cfg, state=state) + worker.raise_if_failed() + if command == "quit": + break + clip_name = _prompt_clip_name() + if clip_name is not None: + state.set_clip_name(clip_name) + finally: + worker.stop() + worker.raise_if_failed() + except KeyboardInterrupt: + print("\nInterrupted; unsaved in-progress clip was discarded.") + finally: + preview.close() + provider.close() + + +if __name__ == "__main__": + main() diff --git a/scripts/run/run_sim.py b/scripts/run/run_sim.py index fb729844..216179c3 100644 --- a/scripts/run/run_sim.py +++ b/scripts/run/run_sim.py @@ -4,35 +4,48 @@ from omegaconf import DictConfig from teleopit.pipeline import TeleopPipeline +from teleopit.runtime.common import cfg_get +from teleopit.runtime.console import ( + PlainConsole, + configure_runtime_logging, + sim_keyboard_controls, +) from teleopit.runtime.cli import validate_policy_path -def _print_sim_controls(cfg: DictConfig) -> None: - provider = str(cfg.input.get("provider", "bvh")).lower() +def _sim_status(cfg: DictConfig) -> tuple[tuple[str, str], ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", "bvh")).lower() + viewers = str(cfg_get(cfg, "viewers", "none")) if provider == "pico4": - print("Pico sim2sim controls:") - if bool(cfg.get("keyboard", {}).get("enabled", False)): - print(" Keyboard: starts in STANDING; Y mocap, A pause/resume, X standing, Q quit.") - else: - print(" Pico controller: A pause/resume.") - print(" State flow: STANDING -> MOCAP -> STANDING.") - return - if bool(cfg.get("playback", {}).get("keyboard", {}).get("enabled", False)): - print("Offline sim2sim controls:") - print(" Keyboard: Space/P pause/resume, R replay, Q stop.") + keyboard_cfg = cfg_get(cfg, "keyboard", {}) or {} + state = "STANDING" if bool(cfg_get(keyboard_cfg, "enabled", False)) else "MOCAP" + return ( + ("State", state), + ("Input", "Pico4 live"), + ("Viewers", viewers), + ) + return ( + ("State", "MOCAP"), + ("Input", "BVH"), + ("Viewers", viewers), + ) @hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="default") def main(cfg: DictConfig) -> None: + configure_runtime_logging(cfg, force=True) validate_policy_path(cfg, "run_sim.py") - pipeline = TeleopPipeline(cfg) + console = PlainConsole(title="Teleopit sim2sim") + pipeline = TeleopPipeline(cfg, console=console) num_steps = int(cfg.get("num_steps", 0)) - record = bool(cfg.get("record", False)) - if cfg.input.get("provider") == "pico4": - print("Waiting for Pico4 body tracking data...") - _print_sim_controls(cfg) - result = pipeline.run(num_steps=num_steps, record=record) - print(result) + events = [] + input_cfg = cfg_get(cfg, "input", {}) or {} + if cfg_get(input_cfg, "provider", None) == "pico4": + events.append("waiting for Pico4 body tracking data") + console.start(status=_sim_status(cfg), controls=sim_keyboard_controls(cfg), events=events) + result = pipeline.run(num_steps=num_steps) + console.event(str(result)) if __name__ == "__main__": diff --git a/scripts/run/run_sim2real.py b/scripts/run/run_sim2real.py index d546d8a3..4e260a35 100644 --- a/scripts/run/run_sim2real.py +++ b/scripts/run/run_sim2real.py @@ -2,34 +2,57 @@ from __future__ import annotations +import inspect + import hydra from omegaconf import DictConfig +from teleopit.runtime.common import cfg_get +from teleopit.runtime.console import ( + PlainConsole, + configure_runtime_logging, + sim2real_operator_controls, +) from teleopit.runtime.cli import validate_policy_path -from teleopit.sim2real.controller import Sim2RealController +from teleopit.sim2real.mp import Sim2RealRuntime -def _print_sim2real_controls(cfg: DictConfig) -> None: - provider = str(cfg.input.get("provider", "bvh")).lower() - print("Sim2real controls:") - print(" Remote Start: enter STANDING.") - print(" Remote Y: enter MOCAP.") - print(" Remote X: return to STANDING.") - print(" Remote L1+R1: DAMPING / estop.") - if provider == "pico4": - print(" Mocap pause/resume: Pico/controller A.") - else: - print(" Offline playback: A pause/resume, B replay from start.") - print(" State flow: IDLE -> STANDING -> MOCAP -> STANDING, Any -> DAMPING.") +def _sim2real_status(cfg: DictConfig) -> tuple[tuple[str, str], ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", "bvh")).lower() + input_label = "Pico4 live" if provider == "pico4" else "BVH" + recording_cfg = cfg_get(cfg, "recording", {}) or {} + recording = "enabled" if bool(cfg_get(recording_cfg, "enabled", False)) else "off" + return ( + ("State", "IDLE"), + ("Runtime", "multiprocess"), + ("Input", input_label), + ("Recording", recording), + ) @hydra.main(version_base=None, config_path="../../teleopit/configs", config_name="sim2real") def main(cfg: DictConfig) -> None: + _run_sim2real(cfg) + + +def _run_sim2real(cfg: DictConfig) -> None: + configure_runtime_logging(cfg, force=True) validate_policy_path(cfg, "run_sim2real.py") - controller = Sim2RealController(cfg) - if cfg.input.get("provider") == "pico4": - print("Waiting for Pico4 body tracking data...") - _print_sim2real_controls(cfg) + console = PlainConsole(title="Teleopit sim2real") + runtime_params = inspect.signature(Sim2RealRuntime).parameters + controller = Sim2RealRuntime(cfg, console=console) if "console" in runtime_params else Sim2RealRuntime(cfg) + events = [] + input_cfg = cfg_get(cfg, "input", {}) or {} + if cfg_get(input_cfg, "provider", None) == "pico4": + events.append("waiting for Pico4 body tracking data") + console.start( + status=_sim2real_status(cfg), + controls=sim2real_operator_controls(cfg), + events=events, + control_section="Controls", + show_help_key=False, + ) try: controller.run() finally: diff --git a/scripts/run/standalone_standing.py b/scripts/run/standalone_standing.py index 81bb4b60..1553393a 100644 --- a/scripts/run/standalone_standing.py +++ b/scripts/run/standalone_standing.py @@ -1,861 +1,527 @@ #!/usr/bin/env python3 -"""Standalone G1 standing script with RL policy -- no Teleopit/Pico dependency. - -Uses ONNX RL policy inference to maintain balanced standing, matching the -STANDING mode in Sim2RealController. Only depends on: - - g1_bridge_sdk (C++ DDS bridge) - - onnxruntime - - mujoco - - numpy - -Usage: - python scripts/run/standalone_standing.py \ - --policy track.onnx \ - --network-interface enp130s0 - -Flow: - 1. Init DDS, subscribe to rt/lowstate - 2. Load ONNX policy + MuJoCo model for observation building - 3. Enter debug mode (release MotionSwitcher modes) - 4. Lock joints, then run RL policy standing loop with startup ramp - 5. Hold standing until Ctrl-C or L1+R1 - 6. On exit: set damping, restore ai mode -""" +"""Standalone G1 standing controller using the sim2real standing path.""" from __future__ import annotations import argparse +from dataclasses import dataclass import logging -import math import signal -import struct -import sys -import threading import time -from collections import deque from pathlib import Path +from typing import Any -_REPO_ROOT = Path(__file__).resolve().parents[2] - -import mujoco import numpy as np -import onnxruntime as ort +from omegaconf import OmegaConf + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.controllers.observation import align_motion_qpos_yaw +from teleopit.controllers.rl_policy import RLPolicyController +from teleopit.runtime.common import cfg_get +from teleopit.runtime.factory import _build_policy_components, build_simulation_cfg +from teleopit.sim.reference_timeline import ReferenceWindowBuilder +from teleopit.sim.reference_utils import build_static_reference_window, obs_builder_requires_reference_window +from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor +from teleopit.sim2real.remote import UnitreeRemote +from teleopit.sim2real.safety import Sim2RealSafetyManager +from teleopit.sim2real.unitree_g1 import UnitreeG1Robot + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +DEFAULT_POLICY_HZ = 50.0 logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) -# ---- G1 29-DOF Constants ---- -NUM_JOINTS = 29 -NUM_MOTORS = 35 -MODE_PR = 0 -MODE_MACHINE = 5 -PUBLISH_HZ = 200 -POLICY_HZ = 50.0 -POS_STOP_F = 2146000000.0 -VEL_STOP_F = 16000.0 -KD_DAMPING = 8.0 -RAMP_DURATION = 2.0 -JOINT_VEL_LIMIT = 10.0 - -# Default standing pose (from g1_constants.py HOME_KEYFRAME) -DEFAULT_ANGLES = np.array([ - -0.312, 0, 0, 0.669, -0.363, 0, # Left leg (0-5) - -0.312, 0, 0, 0.669, -0.363, 0, # Right leg (6-11) - 0, 0, 0, # Waist (12-14) - 0.2, 0.2, 0, 0.6, 0, 0, 0, # Left arm (15-21) - 0.2, -0.2, 0, 0.6, 0, 0, 0, # Right arm (22-28) -], dtype=np.float32) - -# Action scale (from g1.yaml) -ACTION_SCALE = np.array([ - 0.5475, 0.3507, 0.5475, 0.3507, 0.4386, 0.4386, - 0.5475, 0.3507, 0.5475, 0.3507, 0.4386, 0.4386, - 0.5475, 0.4386, 0.4386, - 0.4386, 0.4386, 0.4386, 0.4386, 0.4386, 0.0745, 0.0745, - 0.4386, 0.4386, 0.4386, 0.4386, 0.4386, 0.0745, 0.0745, -], dtype=np.float32) - -# PD gains (from mjlab g1_constants.py) -KP = np.array([ - 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, - 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, - 40.2, 28.5, 28.5, - 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8, - 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8, -], dtype=np.float32) - -KD = np.array([ - 2.6, 6.3, 2.6, 6.3, 1.8, 1.8, - 2.6, 6.3, 2.6, 6.3, 1.8, 1.8, - 2.6, 1.8, 1.8, - 0.9, 0.9, 0.9, 0.9, 0.9, 1.1, 1.1, - 0.9, 0.9, 0.9, 0.9, 0.9, 1.1, 1.1, -], dtype=np.float32) - -# Joint position limits (from pico4_sim2real.yaml) -JOINT_POS_LOWER = np.array([ - -2.5307, -0.5236, -2.7576, -0.087267, -0.87267, -0.2618, - -2.5307, -2.9671, -2.7576, -0.087267, -0.87267, -0.2618, - -2.618, -0.52, -0.52, - -3.0892, -1.5882, -2.618, -1.0472, -1.972222054, -1.61443, -1.61443, - -3.0892, -2.2515, -2.618, -1.0472, -1.972222054, -1.61443, -1.61443, -], dtype=np.float32) - -JOINT_POS_UPPER = np.array([ - 2.8798, 2.9671, 2.7576, 2.8798, 0.5236, 0.2618, - 2.8798, 0.5236, 2.7576, 2.8798, 0.5236, 0.2618, - 2.618, 0.52, 0.52, - 2.6704, 2.2515, 2.618, 2.0944, 1.972222054, 1.61443, 1.61443, - 2.6704, 1.5882, 2.618, 2.0944, 1.972222054, 1.61443, 1.61443, -], dtype=np.float32) - -# MuJoCo XML path for FK -MJCF_PATH = _REPO_ROOT / "teleopit" / "retargeting" / "gmr" / "assets" / "unitree_g1" / "g1_mjlab.xml" - -JOINT_MAP = list(range(NUM_JOINTS)) -GRAVITY_UNIT_W = np.array([0.0, 0.0, -1.0], dtype=np.float32) - -# Wireless remote button masks -_KEYS_OFFSET = 2 -_R1 = 0x0001 -_L1 = 0x0002 - - -# ===================================================================== -# Quaternion helpers (from teleopit.controllers.observation) -# ===================================================================== - -def quat_inv(q): - inv = q.copy() - inv[..., 1:] = -inv[..., 1:] - return inv - - -def quat_mul(q1, q2): - w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3] - w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3] - return np.stack([ - w1*w2 - x1*x2 - y1*y2 - z1*z2, - w1*x2 + x1*w2 + y1*z2 - z1*y2, - w1*y2 - x1*z2 + y1*w2 + z1*x2, - w1*z2 + x1*y2 - y1*x2 + z1*w2, - ], axis=-1).astype(np.float32) - - -def quat_rotate(q, v): - v_quat = np.zeros((*v.shape[:-1], 4), dtype=np.float32) - v_quat[..., 1:4] = v - result = quat_mul(quat_mul(q, v_quat), quat_inv(q)) - return result[..., 1:4] - - -def yaw_quat(q): - w, x, y, z = float(q[0]), float(q[1]), float(q[2]), float(q[3]) - yaw = math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z)) - out = np.array([math.cos(yaw / 2.0), 0.0, 0.0, math.sin(yaw / 2.0)], dtype=np.float32) - out /= max(float(np.linalg.norm(out)), 1e-8) - return out - - -def align_motion_qpos_yaw(robot_quat_wxyz, motion_qpos): - """Align motion_qpos[3:7] quaternion yaw to match robot yaw.""" - motion_quat = np.asarray(motion_qpos[3:7], dtype=np.float32) - robot_quat = np.asarray(robot_quat_wxyz, dtype=np.float32) - delta = quat_mul(robot_quat, quat_inv(motion_quat)) - delta_yaw = yaw_quat(delta) - motion_qpos[3:7] = quat_mul(delta_yaw, motion_quat).astype(motion_qpos.dtype) - return motion_qpos - - -def quat_to_rot6d(q): - w, x, y, z = q[0], q[1], q[2], q[3] - r00 = 1 - 2 * (y*y + z*z) - r01 = 2 * (x*y - w*z) - r10 = 2 * (x*y + w*z) - r11 = 1 - 2 * (x*x + z*z) - r20 = 2 * (x*z - w*y) - r21 = 2 * (y*z + w*x) - return np.array([r00, r01, r10, r11, r20, r21], dtype=np.float32) - - -# ===================================================================== -# Observation builder (from teleopit.controllers.observation) -# ===================================================================== - -class ObservationBuilder: - """166D VelCmd observation builder using MuJoCo FK.""" - - def __init__(self, xml_path: str): - self._mj_model = mujoco.MjModel.from_xml_path(xml_path) - self._mj_data = mujoco.MjData(self._mj_model) - - # Find anchor body (torso_link) - body_names = {self._mj_model.body(i).name: i for i in range(self._mj_model.nbody)} - self._anchor_body_id = body_names["torso_link"] - - # Base obs: command(29+29) + anchor_ori_b(6) + ang_vel(3) + joint_pos_rel(29) + qvel(29) + last_act(29) - # = 154 - # VelCmd extra: projected_gravity(3) + ref_lin_vel_b(3) + ref_ang_vel_b(3) + ref_proj_gravity(3) - # = 12 - # Total = 166 - self.total_obs_size = NUM_JOINTS * 2 + 6 + 3 + NUM_JOINTS * 3 + 12 - - # Precompute motion torso offset for standing (DEFAULT_ANGLES with identity base) - # torso_quat_world = quat_mul(base_quat, torso_offset) for constant joint angles - self._run_fk(np.zeros(3), np.array([1, 0, 0, 0], dtype=np.float32), DEFAULT_ANGLES) - self._standing_torso_offset = self._get_body_quat(self._anchor_body_id).copy() - - def _run_fk(self, base_pos, base_quat, joint_pos): - self._mj_data.qpos[:] = 0.0 - self._mj_data.qpos[0:3] = np.asarray(base_pos, dtype=np.float64).reshape(3) - quat = np.asarray(base_quat, dtype=np.float64).reshape(4) - quat = quat / max(np.linalg.norm(quat), 1e-8) - self._mj_data.qpos[3:7] = quat - n = min(len(joint_pos), self._mj_model.nq - 7) - self._mj_data.qpos[7:7 + n] = np.asarray(joint_pos, dtype=np.float64)[:n] - mujoco.mj_kinematics(self._mj_model, self._mj_data) - - def _get_body_quat(self, body_id): - return np.asarray(self._mj_data.xquat[body_id], dtype=np.float32).copy() - - def build(self, robot_qpos, robot_qvel, robot_quat, robot_ang_vel, - motion_qpos, motion_joint_vel, last_action): - """Build 166D observation for VelCmd policy. - - Args: - robot_qpos: (29,) current joint positions - robot_qvel: (29,) current joint velocities - robot_quat: (4,) base orientation quaternion (w,x,y,z) - robot_ang_vel: (3,) base angular velocity - motion_qpos: (36,) reference motion [pos(3) + quat(4) + joints(29)] - motion_joint_vel: (29,) reference joint velocity - last_action: (29,) previous policy action - """ - qpos = np.asarray(robot_qpos, dtype=np.float32)[:NUM_JOINTS] - qvel = np.asarray(robot_qvel, dtype=np.float32)[:NUM_JOINTS] - robot_q = np.asarray(robot_quat, dtype=np.float32) - ang_vel = np.asarray(robot_ang_vel, dtype=np.float32) - motion = np.asarray(motion_qpos, dtype=np.float32) - m_joint_vel = np.asarray(motion_joint_vel, dtype=np.float32)[:NUM_JOINTS] - last_act = np.asarray(last_action, dtype=np.float32) - - # Anchor quaternions: skip MuJoCo FK, use base quat * precomputed offset - # For standing, waist joints ≈ 0 so torso ≈ base orientation - robot_anchor_quat = quat_mul(robot_q, self._standing_torso_offset) - - motion_base_quat = motion[3:7] - motion_joint_pos = motion[7:7 + NUM_JOINTS] - motion_anchor_quat = quat_mul(motion_base_quat, self._standing_torso_offset) - - # Base observation (154D) - command = np.concatenate((motion_joint_pos, m_joint_vel), dtype=np.float32) - rel_quat = quat_mul(quat_inv(robot_anchor_quat), motion_anchor_quat) - motion_anchor_ori_b = quat_to_rot6d(rel_quat) - joint_pos_rel = qpos - DEFAULT_ANGLES - - base_obs = np.concatenate([ - command, # 29 + 29 = 58 - motion_anchor_ori_b, # 6 - ang_vel, # 3 - joint_pos_rel, # 29 - qvel, # 29 - last_act, # 29 - ], dtype=np.float32) - - # VelCmd extra (12D) -- standing has zero reference velocities - projected_gravity = quat_rotate(quat_inv(robot_q), GRAVITY_UNIT_W) - robot_inv = quat_inv(robot_anchor_quat) - # Zero reference velocities for standing - ref_lin_vel_b = np.zeros(3, dtype=np.float32) - ref_ang_vel_b = np.zeros(3, dtype=np.float32) - ref_proj_gravity = quat_rotate(quat_inv(motion_anchor_quat), GRAVITY_UNIT_W) - - velcmd_obs = np.concatenate([ - projected_gravity, # 3 - ref_lin_vel_b, # 3 - ref_ang_vel_b, # 3 - ref_proj_gravity, # 3 - ], dtype=np.float32) - - obs = np.concatenate([base_obs, velcmd_obs], dtype=np.float32) - return obs - - -# ===================================================================== -# ONNX Policy wrapper (from teleopit.controllers.rl_policy) -# ===================================================================== - -class PolicyInference: - """Minimal ONNX policy inference wrapper.""" - - def __init__(self, policy_path: str): - providers = ["CPUExecutionProvider"] - available = set(ort.get_available_providers()) - if "CUDAExecutionProvider" in available: - providers.insert(0, "CUDAExecutionProvider") - - self._session = ort.InferenceSession(policy_path, providers=providers) - onnx_inputs = self._session.get_inputs() - self._input_name = onnx_inputs[0].name - self._output_name = self._session.get_outputs()[0].name - - # Check for dual-input (obs + obs_history) model - self._multi_input = False - self._history_length = 0 - self._history_buf: deque[np.ndarray] = deque() - if len(onnx_inputs) == 2 and onnx_inputs[1].name == "obs_history": - self._multi_input = True - self._history_length = int(onnx_inputs[1].shape[1]) - self._history_buf = deque(maxlen=self._history_length) - - # Extract expected obs dim - self._expected_obs_dim = None - if len(onnx_inputs[0].shape) >= 2 and isinstance(onnx_inputs[0].shape[-1], int): - self._expected_obs_dim = int(onnx_inputs[0].shape[-1]) +@dataclass(frozen=True) +class _StepTiming: + state_s: float + obs_s: float + infer_s: float + target_s: float + send_s: float + action_delta: float | None + target_delta: float | None + qvel_norm: float + ang_vel_norm: float + + +class _StandaloneTimingReporter: + def __init__( + self, + *, + target_period_s: float, + log_interval_s: float = 1.0, + deadline_miss_tolerance_s: float = 0.001, + enabled: bool = True, + ) -> None: + self._target_period_s = float(target_period_s) + self._log_interval_s = float(log_interval_s) + self._deadline_miss_tolerance_s = float(deadline_miss_tolerance_s) + self._enabled = bool(enabled) + self._window_start_s: float | None = None + self._loop_ms: list[float] = [] + self._late_ms: list[float] = [] + self._work_ms: list[float] = [] + self._state_ms: list[float] = [] + self._obs_ms: list[float] = [] + self._infer_ms: list[float] = [] + self._target_ms: list[float] = [] + self._send_ms: list[float] = [] + self._action_delta: list[float] = [] + self._target_delta: list[float] = [] + self._qvel_norm: list[float] = [] + self._ang_vel_norm: list[float] = [] + self._deadline_miss_count = 0 + self._work_overrun_count = 0 + + def record( + self, + *, + loop_start_s: float, + work_elapsed_s: float, + cycle_elapsed_s: float, + step: _StepTiming, + ) -> None: + if self._window_start_s is None: + self._window_start_s = float(loop_start_s) + self._loop_ms.append(float(cycle_elapsed_s) * 1000.0) + self._late_ms.append(max(0.0, float(cycle_elapsed_s) - self._target_period_s) * 1000.0) + self._work_ms.append(float(work_elapsed_s) * 1000.0) + self._state_ms.append(float(step.state_s) * 1000.0) + self._obs_ms.append(float(step.obs_s) * 1000.0) + self._infer_ms.append(float(step.infer_s) * 1000.0) + self._target_ms.append(float(step.target_s) * 1000.0) + self._send_ms.append(float(step.send_s) * 1000.0) + if step.action_delta is not None: + self._action_delta.append(float(step.action_delta)) + if step.target_delta is not None: + self._target_delta.append(float(step.target_delta)) + self._qvel_norm.append(float(step.qvel_norm)) + self._ang_vel_norm.append(float(step.ang_vel_norm)) + if cycle_elapsed_s > self._target_period_s + self._deadline_miss_tolerance_s: + self._deadline_miss_count += 1 + if work_elapsed_s > self._target_period_s + 1e-9: + self._work_overrun_count += 1 + if loop_start_s - self._window_start_s >= self._log_interval_s: + self._emit(loop_start_s) + + def _emit(self, end_s: float) -> None: + sample_count = len(self._loop_ms) + if sample_count <= 0: + self._reset(end_s) + return + if not self._enabled: + self._reset(end_s) + return logger.info( - "Policy loaded: input=%s, obs_dim=%s, multi_input=%s, history_len=%d", - self._input_name, self._expected_obs_dim, self._multi_input, self._history_length, + "Standalone timing | samples=%d window=%.1fs | " + "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "late_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f deadline_miss(>%.2fms)=%d/%d | " + "work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f work_overrun=%d/%d | " + "state_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "obs_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "infer_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "target_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "send_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "action_delta p50=%.4f p95=%.4f p99=%.4f max=%.4f | " + "target_delta p50=%.4f p95=%.4f p99=%.4f max=%.4f | " + "qvel_norm p50=%.4f p95=%.4f p99=%.4f max=%.4f | " + "ang_vel_norm p50=%.4f p95=%.4f p99=%.4f max=%.4f", + sample_count, + end_s - float(self._window_start_s), + *self._summarize(self._loop_ms), + *self._summarize(self._late_ms), + self._deadline_miss_tolerance_s * 1000.0, + self._deadline_miss_count, + sample_count, + *self._summarize(self._work_ms), + self._work_overrun_count, + sample_count, + *self._summarize(self._state_ms), + *self._summarize(self._obs_ms), + *self._summarize(self._infer_ms), + *self._summarize(self._target_ms), + *self._summarize(self._send_ms), + *self._summarize(self._action_delta), + *self._summarize(self._target_delta), + *self._summarize(self._qvel_norm), + *self._summarize(self._ang_vel_norm), ) - - def compute_action(self, observation: np.ndarray) -> np.ndarray: - obs = np.asarray(observation, dtype=np.float32) - if obs.ndim == 1: - obs = obs[np.newaxis, :] - obs_flat = obs.reshape(-1) - - if self._multi_input: - if len(self._history_buf) == 0: - for _ in range(self._history_length): - self._history_buf.append(obs_flat.copy()) - else: - self._history_buf.append(obs_flat.copy()) - obs_history = np.stack(list(self._history_buf), axis=0)[np.newaxis].astype(np.float32) - feed = {self._input_name: obs, "obs_history": obs_history} - else: - feed = {self._input_name: obs} - - raw_action = np.asarray( - self._session.run([self._output_name], feed)[0], dtype=np.float32 + self._reset(end_s) + + def _reset(self, window_start_s: float) -> None: + self._window_start_s = float(window_start_s) + self._loop_ms.clear() + self._late_ms.clear() + self._work_ms.clear() + self._state_ms.clear() + self._obs_ms.clear() + self._infer_ms.clear() + self._target_ms.clear() + self._send_ms.clear() + self._action_delta.clear() + self._target_delta.clear() + self._qvel_norm.clear() + self._ang_vel_norm.clear() + self._deadline_miss_count = 0 + self._work_overrun_count = 0 + + @staticmethod + def _summarize(samples: list[float]) -> tuple[float, float, float, float]: + if not samples: + return 0.0, 0.0, 0.0, 0.0 + values = np.asarray(samples, dtype=np.float64) + p50, p95, p99 = np.percentile(values, [50.0, 95.0, 99.0]) + return float(p50), float(p95), float(p99), float(np.max(values)) + + +class StandaloneStandingController: + """Small wrapper around the production sim2real STANDING implementation.""" + + def __init__( + self, + cfg: Any, + *, + dry_run: bool = False, + no_policy: bool = False, + obs_delay_s: float = 0.0, + command_delay_s: float = 0.0, + ) -> None: + self.cfg = cfg + self.dry_run = dry_run + self.no_policy = no_policy + self.obs_delay_s = self._validate_delay_s(obs_delay_s, name="obs_delay_s") + self.command_delay_s = self._validate_delay_s(command_delay_s, name="command_delay_s") + self.shutdown_requested = False + + self.policy_hz = float(cfg_get(cfg, "policy_hz", DEFAULT_POLICY_HZ)) + self.dt = 1.0 / self.policy_hz + self.robot_cfg = cfg_get(cfg, "robot") + self.real_robot_cfg = cfg_get(cfg, "real_robot") + self.default_angles = np.asarray(cfg_get(self.robot_cfg, "default_angles"), dtype=np.float32) + self.num_actions = int(cfg_get(self.robot_cfg, "num_actions", NUM_JOINTS)) + + default_root_qpos = np.asarray( + cfg_get(self.robot_cfg, "mujoco_default_qpos", [0.0, 0.0, 0.0]), dtype=np.float64 ).reshape(-1) - return raw_action - - def get_target_dof_pos(self, raw_action: np.ndarray) -> np.ndarray: - clipped = np.clip(raw_action, -10.0, 10.0) - scaled = clipped * ACTION_SCALE - return scaled + DEFAULT_ANGLES - - def reset(self): - self._history_buf.clear() - - -# ===================================================================== -# Main controller -# ===================================================================== - -class StandingController: - """RL-policy-based standing controller matching Sim2RealController.STANDING.""" - - def __init__(self, network_interface: str, policy_path: str, - ramp_duration: float = RAMP_DURATION, - no_policy: bool = False, - publish_hz: int = 250) -> None: - self._network_interface = network_interface - self._ramp_duration = ramp_duration - self._shutdown = False - - # ---- Load policy and observation builder ---- - self._policy = PolicyInference(policy_path) - xml_path = str(MJCF_PATH) - if not MJCF_PATH.exists(): - raise FileNotFoundError(f"MuJoCo XML not found: {xml_path}") - self._obs_builder = ObservationBuilder(xml_path) - - # Verify observation dimension matches policy expectation - if (self._policy._expected_obs_dim is not None and - self._policy._expected_obs_dim != self._obs_builder.total_obs_size): - raise ValueError( - f"Obs dimension mismatch: builder={self._obs_builder.total_obs_size}, " - f"policy={self._policy._expected_obs_dim}" - ) - - self._no_policy = no_policy - self._dry_run = False - self._state_delay = 0.0 + self._default_root_pos = np.zeros(3, dtype=np.float64) + if default_root_qpos.shape[0] >= 3: + self._default_root_pos[:] = default_root_qpos[:3] + + self.policy, self.obs_builder = self._build_policy_and_obs() + self.robot = UnitreeG1Robot(self.real_robot_cfg) + self.remote = UnitreeRemote() + self.safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) + self._entered_debug = False + + sim_cfg = build_simulation_cfg(cfg) + self.ref_proc = Sim2RealReferenceProcessor( + obs_builder=self.obs_builder, + policy=self.policy, + policy_hz=self.policy_hz, + num_actions=self.num_actions, + reference_velocity_smoothing_alpha=float(sim_cfg["reference_velocity_smoothing_alpha"]), + reference_anchor_velocity_smoothing_alpha=float(sim_cfg["reference_anchor_velocity_smoothing_alpha"]), + ) + self.reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=self.dt, + reference_steps=cfg_get(cfg, "reference_steps", [0]), + ) - # ---- Policy state ---- + self._standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + self._standing_qpos[3] = 1.0 + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._last_target: np.ndarray | None = None self._step_count = 0 - self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) - # Standing reference qpos: [pos(3), quat(4), joints(29)] = 36D - self._standing_qpos = np.zeros(36, dtype=np.float64) - self._standing_qpos[3] = 1.0 # identity quaternion w=1 - self._standing_qpos[7:36] = DEFAULT_ANGLES.astype(np.float64) - - # Startup ramp state - self._ramp_duration_steps = max(1, int(ramp_duration * POLICY_HZ)) - self._ramp_step = 0 - self._ramp_start_positions: np.ndarray | None = None - self._ramp_active = False - - # ---- Pipeline state ---- - self._inference_thread: threading.Thread | None = None - self._inference_running = False - - self._publish_hz = publish_hz - - self._init_cpp_backend() + console_cfg = cfg_get(cfg, "console", {}) or {} + self._timing = _StandaloneTimingReporter( + target_period_s=self.dt, + log_interval_s=float(cfg_get(console_cfg, "timing_log_interval_s", 10.0)), + enabled=bool(cfg_get(console_cfg, "show_timing", False)), + ) - # ================================================================== - # Backend init - # ================================================================== + if self.obs_delay_s > 0.0 or self.command_delay_s > 0.0: + logger.info( + "Diagnostic delay injection enabled | obs_delay=%.3fms command_delay=%.3fms", + self.obs_delay_s * 1000.0, + self.command_delay_s * 1000.0, + ) - def _init_cpp_backend(self) -> None: - import g1_bridge_sdk - logger.info("Using C++ bridge backend (%dHz publish)", self._publish_hz) - self._bridge = g1_bridge_sdk.G1Bridge(self._network_interface, self._publish_hz) + @staticmethod + def _validate_delay_s(value: float, *, name: str) -> float: + delay_s = float(value) + if not np.isfinite(delay_s) or delay_s < 0.0: + raise ValueError(f"{name} must be finite and >= 0, got {value!r}") + return delay_s + + def _build_policy_and_obs(self) -> tuple[Any, Any]: + controller_cfg = cfg_get(self.cfg, "controller") + policy, obs_builder = _build_policy_components( + robot_cfg=self.robot_cfg, + controller_cfg=controller_cfg, + sim_cfg=build_simulation_cfg(self.cfg), + project_root=PROJECT_ROOT, + controller_cls=RLPolicyController, + ) + if not bool(getattr(policy, "_multi_input", False)): + raise ValueError("Standalone standing requires a dual-input ONNX policy ('obs', 'obs_history').") + return policy, obs_builder - logger.info("Waiting for LowState on %s ...", self._network_interface) - if not self._bridge.wait_for_state(5.0): - raise RuntimeError("No LowState received within 5s -- check network and robot power") - logger.info("LowState received, robot connected") + def run(self) -> None: + signal.signal(signal.SIGINT, self._signal_handler) + signal.signal(signal.SIGTERM, self._signal_handler) - # ================================================================== - # Robot state reading - # ================================================================== + try: + if self.dry_run: + self._run_dry() + return + self._enter_standing() + self._run_control_loop() + finally: + self._cleanup() - def _get_robot_state(self): - return self._bridge.get_state() + def _enter_standing(self) -> None: + logger.info("Entering debug mode...") + if not self.robot.enter_debug_mode(): + raise RuntimeError("Failed to enter debug mode") + self._entered_debug = True + time.sleep(0.5) - # ================================================================== - # Publish thread - # ================================================================== + logger.info("Locking joints to current position...") + self.robot.lock_all_joints() + time.sleep(0.3) - def _start_publish(self) -> None: - self._bridge.start_publish() + self._warmup_policy() + state = self.robot.get_state() + self._set_default_standing_reference(state) + self._reset_policy_state() + self.safety.start_kp_ramp() + logger.info("Mode -> STANDING (standalone)") - def _stop_publish(self) -> None: - self._bridge.stop_publish() + def _run_control_loop(self) -> None: + while not self.shutdown_requested: + t0 = time.monotonic() + self.remote.update(self.robot.get_wireless_remote()) + if self.remote.LB.pressed and self.remote.RB.pressed: + logger.warning("L1+R1 pressed -- damping") + self.shutdown_requested = True + break + if self.safety.check_joint_velocity_safety(): + self.robot.set_damping() + self.shutdown_requested = True + break - # ================================================================== - # Motion switcher - # ================================================================== + step_timing = self._standing_step() + work_elapsed_s = time.monotonic() - t0 + cycle_elapsed_s = self._sleep_until(t0) + self._timing.record( + loop_start_s=t0, + work_elapsed_s=work_elapsed_s, + cycle_elapsed_s=cycle_elapsed_s, + step=step_timing, + ) - def enter_debug_mode(self) -> bool: - try: - for _ in range(10): - code, name = self._bridge.check_mode() - logger.info("check_mode -> code=%s, name=%s", code, name) - if code != 0: - logger.error("check_mode RPC failed with code=%s", code) - return False - if not name: - logger.info("Debug mode ready (no active mode)") - return True - logger.info("Releasing mode: %s", name) - release_code = self._bridge.release_mode() - logger.info("release_mode('%s') -> code=%s", name, release_code) - if release_code != 0: - logger.error("Failed to release mode '%s' (code=%s)", name, release_code) - return False - time.sleep(1) - logger.warning("Could not release modes after 10 attempts") - return False - except Exception as exc: - logger.error("enter_debug_mode failed: %s", exc) - return False - - def exit_debug_mode(self) -> bool: - self._stop_publish() - try: - code = self._bridge.select_mode("ai") - logger.info("select_mode('ai') -> code=%s", code) - return code == 0 - except Exception as exc: - logger.error("exit_debug_mode failed: %s", exc) - return False - - # ================================================================== - # Motor commands - # ================================================================== - - def _set_damping(self) -> None: - self._bridge.set_damping() - - def _lock_joints(self) -> None: - qpos, _, _, _ = self._get_robot_state() - self._bridge.set_target(qpos, KP, KD) - self._bridge.lock_joints() - - # ================================================================== - # Safety checks - # ================================================================== - - def _check_emergency_stop(self) -> bool: - try: - remote_bytes = self._bridge.get_wireless_remote() - if len(remote_bytes) < 4: - return False - keys = struct.unpack_from(" bool: - max_vel = np.max(np.abs(qvel)) - if max_vel > JOINT_VEL_LIMIT: - logger.error("SAFETY: joint vel %.2f rad/s > limit %.2f -- damping!", max_vel, JOINT_VEL_LIMIT) - return True - return False - - # ---- Startup ramp ---- - - def _apply_startup_ramp(self, target_dof_pos: np.ndarray) -> np.ndarray: - if not self._ramp_active or self._ramp_start_positions is None: - return target_dof_pos - - ramp_factor = min(1.0, self._ramp_step / self._ramp_duration_steps) - ramped = self._ramp_start_positions + ramp_factor * (target_dof_pos - self._ramp_start_positions) - - self._ramp_step += 1 - if self._ramp_step >= self._ramp_duration_steps: - self._ramp_active = False - logger.info("Startup ramp complete (%d steps)", self._ramp_duration_steps) - - return np.asarray(ramped, dtype=np.float32) - - # ---- Standing step (matches Sim2RealController._standing_step) ---- - - def _standing_step(self) -> np.ndarray: - """One step of RL policy standing inference. Returns target joint positions.""" - _t0 = time.monotonic() - qpos, qvel, quat, ang_vel = self._get_robot_state() - - # Build standing reference aligned to robot's current yaw - ref_qpos = self._standing_qpos.copy() - align_motion_qpos_yaw(quat, ref_qpos) - - # Zero joint velocity reference for standing - motion_joint_vel = np.zeros(NUM_JOINTS, dtype=np.float32) - motion_qpos = np.asarray(ref_qpos[:7 + NUM_JOINTS], dtype=np.float32) - - _t1 = time.monotonic() - # Build observation - obs = self._obs_builder.build( - robot_qpos=qpos, - robot_qvel=qvel, - robot_quat=quat, - robot_ang_vel=ang_vel, - motion_qpos=ref_qpos, + def _standing_step(self) -> _StepTiming: + t0 = time.monotonic() + robot_state = self.robot.get_state() + t_state = time.monotonic() + if self.obs_delay_s > 0.0: + time.sleep(self.obs_delay_s) + qpos = self._standing_qpos.copy() + motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) + motion_qpos = np.asarray(qpos[: ROOT_DIM + self.num_actions], dtype=np.float32) + reference_window = None + if obs_builder_requires_reference_window(self.obs_builder): + reference_window = build_static_reference_window(qpos, self.reference_window_builder, self.policy_hz) + obs = self.ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, motion_joint_vel=motion_joint_vel, last_action=self._last_action, + anchor_lin_vel_w=np.zeros(3, dtype=np.float32), + anchor_ang_vel_w=np.zeros(3, dtype=np.float32), + reference_window=reference_window, ) - - _t2 = time.monotonic() - # Policy inference - action = self._policy.compute_action(obs) - - target_dof_pos = self._policy.get_target_dof_pos(action) - _t3 = time.monotonic() - - # Diagnostic - self._step_count += 1 - step_ms = (_t3 - _t0) * 1000 - if self._step_count % 25 == 1 or step_ms > (1000.0 / POLICY_HZ): - tag = "OVERRUN" if step_ms > (1000.0 / POLICY_HZ) else "DIAG" - logger.info( - "%s step=%d | state=%.2fms obs=%.2fms infer=%.2fms total=%.1fms | " - "qvel_norm=%.4f | action_norm=%.4f | " - "target[:6]=%s | qpos[:6]=%s", - tag, self._step_count, - (_t1 - _t0) * 1000, (_t2 - _t1) * 1000, (_t3 - _t2) * 1000, - step_ms, - float(np.linalg.norm(qvel)), - float(np.linalg.norm(action)), - np.array2string(target_dof_pos[:6], precision=4, separator=','), - np.array2string(qpos[:6], precision=4, separator=','), - ) - - # Startup ramp - target_dof_pos = self._apply_startup_ramp(target_dof_pos) - - # Joint limits - target_dof_pos = np.clip(target_dof_pos, JOINT_POS_LOWER, JOINT_POS_UPPER) - + obs = self.ref_proc.validate_observation(obs) + t_obs = time.monotonic() + action = np.zeros(self.num_actions, dtype=np.float32) if self.no_policy else self.policy.compute_action(obs) + t_infer = time.monotonic() + target_dof_pos = self.safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + t_target = time.monotonic() + action_delta = float(np.linalg.norm(action - self._last_action)) + target_delta = None if self._last_target is None else float(np.linalg.norm(target_dof_pos - self._last_target)) + if self.dry_run: + self._log_step(robot_state.qvel, action, target_dof_pos, dry=True) + else: + if self.command_delay_s > 0.0: + time.sleep(self.command_delay_s) + self.safety.send_positions(target_dof_pos) + self._log_step(robot_state.qvel, action, target_dof_pos, dry=False) + t_send = time.monotonic() self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - return target_dof_pos - - # ---- Inference thread ---- - - def _start_inference(self) -> None: - if self._inference_thread is not None: - return - self._inference_running = True - self._inference_thread = threading.Thread(target=self._inference_loop, daemon=True) - self._inference_thread.start() - - def _stop_inference(self) -> None: - self._inference_running = False - if self._inference_thread is not None: - self._inference_thread.join(timeout=2.0) - self._inference_thread = None - - def _inference_loop(self) -> None: - """~50Hz soft-realtime inference loop. - - Runs policy inference and writes target positions to _target_buf. - The 250Hz publish thread reads _target_buf independently — even if - inference takes 30ms, the robot keeps receiving 250Hz commands with - the last good target. - """ - dt = 1.0 / POLICY_HZ - loop_count = 0 - overrun_count = 0 - max_elapsed = 0.0 - elapsed_sum = 0.0 - - while self._inference_running and not self._shutdown: - t0 = time.monotonic() - - # Emergency stop check - if self._check_emergency_stop(): - logger.warning("L1+R1 pressed -- emergency damping!") - self._set_damping() - self._shutdown = True - break - - # Joint velocity safety check (warning only, no shutdown) - _, qvel, _, _ = self._get_robot_state() - self._check_joint_vel_safety(qvel) - - # Artificial state delay (for debugging) - if self._state_delay > 0: - time.sleep(self._state_delay) - - # Policy step - if self._no_policy: - target = self._apply_startup_ramp(DEFAULT_ANGLES.copy()) - target = np.clip(target, JOINT_POS_LOWER, JOINT_POS_UPPER) - else: - target = self._standing_step() - - # Write target to publish thread - self._bridge.set_target(target, KP, KD) - - # Timing diagnostics (informational only — not a control failure) - elapsed = time.monotonic() - t0 - loop_count += 1 - elapsed_sum += elapsed - max_elapsed = max(max_elapsed, elapsed) - if elapsed > dt: - overrun_count += 1 - - if loop_count % 50 == 0: - avg_ms = (elapsed_sum / loop_count) * 1000 - logger.info( - "Inference stats: avg=%.1fms, max=%.1fms, overruns=%d/%d (target=%.1fms)", - avg_ms, max_elapsed * 1000, overrun_count, loop_count, dt * 1000, - ) - - remain = dt - (time.monotonic() - t0) - if remain > 0: - time.sleep(remain) - - # ---- Main loop ---- + self._last_target = np.asarray(target_dof_pos, dtype=np.float32).reshape(-1).copy() + return _StepTiming( + state_s=t_state - t0, + obs_s=t_obs - t_state, + infer_s=t_infer - t_obs, + target_s=t_target - t_infer, + send_s=t_send - t_target, + action_delta=action_delta, + target_delta=target_delta, + qvel_norm=float(np.linalg.norm(robot_state.qvel)), + ang_vel_norm=float(np.linalg.norm(robot_state.ang_vel)), + ) def _run_dry(self) -> None: - """Dry-run: read state + build obs + infer, no motor commands. Safe for timing tests.""" logger.info("=== DRY-RUN MODE: no motor commands will be sent ===") - self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) - self._policy.reset() - self._ramp_active = False - - dt = 1.0 / POLICY_HZ - loop_count = 0 - overrun_count = 0 - elapsed_sum = 0.0 - max_elapsed = 0.0 - state_sum = 0.0 - obs_sum = 0.0 - infer_sum = 0.0 - - while not self._shutdown: + self._warmup_policy() + state = self.robot.get_state() + self._set_default_standing_reference(state) + self._reset_policy_state() + while not self.shutdown_requested: t0 = time.monotonic() - - # 1. Read state - qpos, qvel, quat, ang_vel = self._get_robot_state() - t1 = time.monotonic() - - # 2. Build reference - ref_qpos = self._standing_qpos.copy() - align_motion_qpos_yaw(quat, ref_qpos) - motion_joint_vel = np.zeros(NUM_JOINTS, dtype=np.float32) - - # 3. Build observation - obs = self._obs_builder.build( - robot_qpos=qpos, robot_qvel=qvel, - robot_quat=quat, robot_ang_vel=ang_vel, - motion_qpos=ref_qpos, motion_joint_vel=motion_joint_vel, - last_action=self._last_action, + step_timing = self._standing_step() + work_elapsed_s = time.monotonic() - t0 + cycle_elapsed_s = self._sleep_until(t0) + self._timing.record( + loop_start_s=t0, + work_elapsed_s=work_elapsed_s, + cycle_elapsed_s=cycle_elapsed_s, + step=step_timing, ) - t2 = time.monotonic() - - # 4. Policy inference - raw_action = self._policy.compute_action(obs) - target = self._policy.get_target_dof_pos(raw_action) - t3 = time.monotonic() - - self._last_action = raw_action.copy() - loop_count += 1 - e = t3 - t0 - elapsed_sum += e - max_elapsed = max(max_elapsed, e) - state_sum += (t1 - t0) - obs_sum += (t2 - t1) - infer_sum += (t3 - t2) - if e > dt: - overrun_count += 1 - - if loop_count % 50 == 0: - n = loop_count - logger.info( - "DRY step=%d | state=%.2fms obs=%.2fms infer=%.2fms total=%.2fms | " - "max=%.2fms overruns=%d/%d | target[:6]=%s", - n, - (state_sum / n) * 1000, (obs_sum / n) * 1000, - (infer_sum / n) * 1000, (elapsed_sum / n) * 1000, - max_elapsed * 1000, overrun_count, n, - np.array2string(target[:6], precision=4, separator=','), - ) - - remain = dt - (time.monotonic() - t0) - if remain > 0: - time.sleep(remain) + def _set_default_standing_reference(self, state: object) -> None: + self._standing_qpos[:] = 0.0 + self._standing_qpos[0:3] = self._default_root_pos + self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) + align_motion_qpos_yaw(np.asarray(getattr(state, "quat"), dtype=np.float32), self._standing_qpos) + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + + def _reset_policy_state(self) -> None: + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self.ref_proc.reset_smoothers() + self.ref_proc.reset_alignment() + self.policy.reset() + self.obs_builder.reset() + self._last_target = None + + def _warmup_policy(self) -> None: + if self.no_policy: + return + logger.info("Warming up ONNX runtime...") + dummy_obs = np.zeros(int(self.obs_builder.total_obs_size), dtype=np.float32) + for _ in range(3): + self.policy.compute_action(dummy_obs) + self.policy.reset() + logger.info("ONNX warmup complete") + + def _log_step(self, qvel: np.ndarray, action: np.ndarray, target: np.ndarray, *, dry: bool) -> None: + self._step_count += 1 + if self._step_count % 50 != 1: + return + tag = "DRY" if dry else "STANDING" logger.info( - "DRY-RUN finished: %d steps, avg total=%.2fms " - "(state=%.2f obs=%.2f infer=%.2f) max=%.2fms overruns=%d", - loop_count, - (elapsed_sum / max(loop_count, 1)) * 1000, - (state_sum / max(loop_count, 1)) * 1000, - (obs_sum / max(loop_count, 1)) * 1000, - (infer_sum / max(loop_count, 1)) * 1000, - max_elapsed * 1000, overrun_count, + "%s step=%d | qvel_norm=%.4f | action_norm=%.4f | target[:6]=%s", + tag, + self._step_count, + float(np.linalg.norm(qvel)), + float(np.linalg.norm(action)), + np.array2string(target[:6], precision=4, separator=","), ) - def run(self) -> None: - signal.signal(signal.SIGINT, self._signal_handler) - signal.signal(signal.SIGTERM, self._signal_handler) + def _sleep_until(self, t0: float) -> float: + remaining = self.dt - (time.monotonic() - t0) + if remaining > 0.0: + time.sleep(remaining) + return time.monotonic() - t0 - if self._dry_run: - self._run_dry() - return + def _signal_handler(self, _signum: int, _frame: object) -> None: + logger.info("Shutdown signal received") + self.shutdown_requested = True + def _cleanup(self) -> None: + if self.dry_run: + self.robot.close() + return + if not self._entered_debug: + self.robot.close() + return try: - # 1. Enter debug mode - logger.info("Entering debug mode ...") - if not self.enter_debug_mode(): - logger.error("Failed to enter debug mode, aborting") - return + logger.info("Shutting down: setting damping...") + self.robot.set_damping() time.sleep(0.5) - - # 2. Start publish thread (C++: 500Hz, Python: 250Hz) - self._start_publish() - - # 3. Lock joints to current position - logger.info("Locking joints to current position ...") - self._lock_joints() - time.sleep(0.3) - - # 4. ONNX warmup — eliminate first-inference spike - logger.info("Warming up ONNX runtime ...") - dummy_obs = np.zeros(self._obs_builder.total_obs_size, dtype=np.float32) - for _ in range(3): - self._policy.compute_action(dummy_obs) - self._policy.reset() - logger.info("ONNX warmup complete") - - # 5. Initialize policy state and startup ramp - qpos, _, quat, _ = self._get_robot_state() - self._last_action = np.zeros(NUM_JOINTS, dtype=np.float32) - self._ramp_start_positions = qpos.copy() - self._ramp_step = 0 - self._ramp_active = True - - logger.info( - "Starting RL policy standing (pipelined) | ramp=%d steps (%.1fs)", - self._ramp_duration_steps, self._ramp_duration_steps / POLICY_HZ, - ) - - # 6. Start inference thread (~50Hz, soft deadline) - self._start_inference() - - # 7. Main thread waits for shutdown signal - while not self._shutdown: - time.sleep(0.1) - - # 8. Stop inference thread - self._stop_inference() - - except Exception as exc: - logger.error("Error in main loop: %s", exc) finally: - self._cleanup() - - def _signal_handler(self, signum, frame) -> None: - logger.info("Shutdown signal received") - self._shutdown = True - - def _cleanup(self) -> None: - self._inference_running = False - logger.info("Shutting down: setting damping ...") - self._set_damping() - time.sleep(0.5) - logger.info("Stopping publish and restoring ai mode ...") - self.exit_debug_mode() - logger.info("Done.") - - -def main(): - parser = argparse.ArgumentParser(description="G1 standalone standing with RL policy") - parser.add_argument( - "--policy", type=str, required=True, - help="Path to ONNX policy file (e.g. track.onnx)", - ) - parser.add_argument( - "--network-interface", type=str, default="eth0", - help="Network interface for DDS (e.g. eth0, enp130s0)", - ) - parser.add_argument( - "--ramp-duration", type=float, default=RAMP_DURATION, - help="Seconds for startup ramp (default: 2.0)", - ) - parser.add_argument( - "--no-policy", action="store_true", - help="Skip RL policy, just send fixed DEFAULT_ANGLES (diagnostic mode)", + logger.info("Stopping publish and restoring ai mode...") + self.robot.exit_debug_mode() + self.robot.close() + logger.info("Done.") + + +def _build_cfg(args: argparse.Namespace) -> Any: + cfg = OmegaConf.create( + { + "policy_hz": DEFAULT_POLICY_HZ, + "startup_ramp_duration": args.kp_ramp_duration, + "kp_ramp_floor_ratio": args.kp_ramp_floor_ratio, + "joint_vel_limit": args.joint_vel_limit, + "reference_steps": [0], + "reference_velocity_smoothing_alpha": 1.0, + "reference_anchor_velocity_smoothing_alpha": 1.0, + "robot": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "robot" / "g1.yaml"), + "controller": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "controller" / "rl_policy.yaml"), + "real_robot": OmegaConf.load(PROJECT_ROOT / "teleopit" / "configs" / "sim2real.yaml").real_robot, + "console": { + "show_timing": bool(args.show_timing), + "timing_log_interval_s": float(args.timing_log_interval_s), + }, + } ) + cfg.controller.policy_path = str(args.policy) + cfg.real_robot.network_interface = str(args.network_interface) + cfg.real_robot.publish_hz = int(args.publish_hz) + return cfg + + +def main() -> None: + parser = argparse.ArgumentParser(description="G1 standalone standing using sim2real standing implementation") + parser.add_argument("--policy", type=str, required=True, help="Path to a dual-input ONNX policy") + parser.add_argument("--network-interface", type=str, default="eth0", help="DDS network interface") + parser.add_argument("--dry-run", action="store_true", help="Read state and run policy without motor commands") + parser.add_argument("--no-policy", action="store_true", help="Send zero-action standing target") + parser.add_argument("--publish-hz", type=int, default=200, help="C++ bridge publish frequency") + parser.add_argument("--kp-ramp-duration", type=float, default=2.0, help="Startup Kp ramp duration in seconds") + parser.add_argument("--kp-ramp-floor-ratio", type=float, default=0.1, help="Initial Kp ratio during startup") + parser.add_argument("--joint-vel-limit", type=float, default=10.0, help="Damp if any joint exceeds this velocity") + parser.add_argument("--show-timing", action="store_true", help="Print periodic timing diagnostics") parser.add_argument( - "--state-delay", type=float, default=0.0, - help="Artificial delay (seconds) before reading state, simulates network latency (e.g. 0.005)", + "--timing-log-interval-s", + type=float, + default=10.0, + help="Timing diagnostic print interval when --show-timing is set", ) parser.add_argument( - "--dry-run", action="store_true", - help="Read state + build obs + infer only, no motor commands (safe timing test)", + "--obs-delay-ms", + type=float, + default=0.0, + help="Diagnostic delay after LowState read, before observation build/inference", ) parser.add_argument( - "--publish-hz", type=int, default=200, - help="C++ publish frequency in Hz (default: 200, matching training pd_hz)", + "--command-delay-ms", + type=float, + default=0.0, + help="Diagnostic delay after target computation, before C++ bridge set_target", ) args = parser.parse_args() - controller = StandingController( - network_interface=args.network_interface, - policy_path=args.policy, - ramp_duration=args.ramp_duration, - no_policy=args.no_policy, - publish_hz=args.publish_hz, + controller = StandaloneStandingController( + _build_cfg(args), + dry_run=bool(args.dry_run), + no_policy=bool(args.no_policy), + obs_delay_s=float(args.obs_delay_ms) / 1000.0, + command_delay_s=float(args.command_delay_ms) / 1000.0, ) - controller._state_delay = args.state_delay - controller._dry_run = args.dry_run controller.run() diff --git a/scripts/setup/download_assets.py b/scripts/setup/download_assets.py index fcaa303f..77d3dae3 100755 --- a/scripts/setup/download_assets.py +++ b/scripts/setup/download_assets.py @@ -4,7 +4,7 @@ Usage: python scripts/setup/download_assets.py # download everything (ModelScope) python scripts/setup/download_assets.py --source huggingface # download from HuggingFace - python scripts/setup/download_assets.py --only gmr # only GMR retargeting assets + python scripts/setup/download_assets.py --only robots gmr # only robot XML/meshes + GMR retargeting assets python scripts/setup/download_assets.py --only ckpt # only checkpoints python scripts/setup/download_assets.py --only data # only training data python scripts/setup/download_assets.py --only bvh # only sample BVH diff --git a/scripts/setup/download_somehand_l6_assets.sh b/scripts/setup/download_somehand_l6_assets.sh new file mode 100755 index 00000000..5729e731 --- /dev/null +++ b/scripts/setup/download_somehand_l6_assets.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +SOMEHAND_DIR="${PROJECT_ROOT}/third_party/somehand" + +if [[ ! -f "${SOMEHAND_DIR}/scripts/setup/download_assets.py" ]]; then + echo "somehand submodule is not initialized. Run: git submodule update --init third_party/somehand" >&2 + exit 1 +fi + +cd "${SOMEHAND_DIR}" +python scripts/setup/download_assets.py --only mjcf "$@" diff --git a/scripts/setup/prepare_modelscope_assets.py b/scripts/setup/prepare_modelscope_assets.py index 528fea72..ccb77ae3 100644 --- a/scripts/setup/prepare_modelscope_assets.py +++ b/scripts/setup/prepare_modelscope_assets.py @@ -2,7 +2,7 @@ """Prepare a compact upload directory for the ModelScope asset repos. Asset groups by destination repo: - model (BingqianWu/Teleopit-models): ckpt, gmr, bvh + model (BingqianWu/Teleopit-models): ckpt, robots, gmr, bvh dataset (BingqianWu/Teleopit-datasets): data """ diff --git a/scripts/setup/upload_hf_assets.py b/scripts/setup/upload_hf_assets.py index 1fe4d815..46bfe156 100644 --- a/scripts/setup/upload_hf_assets.py +++ b/scripts/setup/upload_hf_assets.py @@ -2,12 +2,12 @@ """Prepare and upload Teleopit assets to HuggingFace repos. Asset groups by destination repo: - model (12e21/Teleopit-models): ckpt, gmr, bvh + model (12e21/Teleopit-models): ckpt, robots, gmr, bvh dataset (12e21/Teleopit-datasets): data Usage: python scripts/setup/upload_hf_assets.py # upload everything - python scripts/setup/upload_hf_assets.py --only gmr # only GMR assets + python scripts/setup/upload_hf_assets.py --only robots gmr # only robot + GMR assets python scripts/setup/upload_hf_assets.py --dry-run # prepare files without uploading """ diff --git a/scripts/view/view_dataset.py b/scripts/view/view_dataset.py new file mode 100644 index 00000000..582f50e7 --- /dev/null +++ b/scripts/view/view_dataset.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python3 +"""Read-only web viewer for Teleopit HDF5 motion datasets.""" + +from __future__ import annotations + +import argparse +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +import h5py +import mujoco +import numpy as np +import viser + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(PROJECT_ROOT)) + +from mjlab.viewer.viser import ViserMujocoScene + +from teleopit.runtime.assets import UNITREE_G1_XML, missing_gmr_assets_message +from train_mimic.data.dataset_lib import find_motion_shards, read_motion_clip + +DEFAULT_XML = UNITREE_G1_XML + + +@dataclass(frozen=True) +class DatasetClip: + clip_id: str + shard_path: Path + clip_index: int + num_frames: int + fps: int + + @property + def duration_s(self) -> float: + return self.num_frames / self.fps if self.fps > 0 else 0.0 + + +def discover_dataset_clips(dataset_path: Path) -> list[DatasetClip]: + clips: list[DatasetClip] = [] + for shard_path in find_motion_shards(dataset_path): + with h5py.File(shard_path, "r") as h5: + required = ["source_clip_lengths", "source_clip_fps"] + missing = [key for key in required if key not in h5] + if missing: + raise ValueError( + f"HDF5 shard {shard_path} missing source clip metadata {missing}. " + "Rebuild the dataset with the current HDF5 writer." + ) + lengths = np.asarray(h5["source_clip_lengths"], dtype=np.int64) + fps = np.asarray(h5["source_clip_fps"], dtype=np.int64) + + if dataset_path.is_file(): + shard_rel = Path(shard_path.name) + else: + try: + shard_rel = shard_path.relative_to(dataset_path) + except ValueError: + shard_rel = Path(shard_path.name) + for clip_index, (num_frames, clip_fps) in enumerate(zip(lengths, fps)): + clips.append( + DatasetClip( + clip_id=f"{shard_rel.as_posix()}#{clip_index}", + shard_path=shard_path, + clip_index=int(clip_index), + num_frames=int(num_frames), + fps=int(clip_fps), + ) + ) + if not clips: + raise ValueError(f"dataset has no source clips: {dataset_path}") + return clips + + +class ClipPlayer: + """Loads one source clip and drives MuJoCo qpos frame by frame.""" + + def __init__(self, mj_model: mujoco.MjModel) -> None: + self.model = mj_model + self.data = mujoco.MjData(mj_model) + self._joint_pos: np.ndarray | None = None + self._pelvis_pos: np.ndarray | None = None + self._pelvis_quat: np.ndarray | None = None + self._fps: int = 30 + self._num_frames: int = 0 + + def load_clip(self, clip: DatasetClip) -> None: + d = read_motion_clip(clip.shard_path, clip.clip_index) + self._joint_pos = np.asarray(d["joint_pos"]) + body_pos_w = np.asarray(d["body_pos_w"]) + body_quat_w = np.asarray(d["body_quat_w"]) + self._pelvis_pos = body_pos_w[:, 0, :] + self._pelvis_quat = body_quat_w[:, 0, :] + self._fps = int(d["fps"]) + self._num_frames = int(self._joint_pos.shape[0]) + + def set_frame(self, frame_idx: int) -> None: + if self._joint_pos is None or self._pelvis_pos is None or self._pelvis_quat is None: + return + idx = max(0, min(frame_idx, self._num_frames - 1)) + self.data.qpos[:3] = self._pelvis_pos[idx] + self.data.qpos[3:7] = self._pelvis_quat[idx] + self.data.qpos[7:] = self._joint_pos[idx] + mujoco.mj_forward(self.model, self.data) + + @property + def fps(self) -> int: + return self._fps + + @property + def num_frames(self) -> int: + return self._num_frames + + +class DatasetSession: + def __init__(self, clips: list[DatasetClip], sort_mode: str) -> None: + self._clips = clips + self._order = list(range(len(clips))) + if sort_mode == "duration_desc": + self._order.sort(key=lambda i: -clips[i].duration_s) + elif sort_mode == "shard": + self._order.sort(key=lambda i: clips[i].clip_id) + self._cursor = 0 + + @property + def total(self) -> int: + return len(self._clips) + + @property + def cursor_display(self) -> int: + return self._cursor + 1 + + @property + def total_duration_s(self) -> float: + return sum(clip.duration_s for clip in self._clips) + + def current_clip(self) -> DatasetClip: + return self._clips[self._order[self._cursor]] + + def go_next(self) -> None: + if self._cursor < len(self._order) - 1: + self._cursor += 1 + + def go_prev(self) -> None: + if self._cursor > 0: + self._cursor -= 1 + + def jump_to(self, position: int) -> None: + self._cursor = max(0, min(position - 1, len(self._order) - 1)) + + +class DatasetViewerApp: + def __init__( + self, + *, + dataset_path: Path, + xml_path: Path, + port: int, + sort_mode: str, + ) -> None: + self._dataset_path = dataset_path + self._model = mujoco.MjModel.from_xml_path(str(xml_path)) + self._player = ClipPlayer(self._model) + self._session = DatasetSession(discover_dataset_clips(dataset_path), sort_mode) + + self._server = viser.ViserServer(port=port, label="Dataset Viewer") + self._scene = ViserMujocoScene.create( + server=self._server, + mj_model=self._model, + num_envs=1, + ) + mujoco.mj_forward(self._model, self._player.data) + self._scene.update_from_mjdata(self._player.data) + + self._playing = False + self._speed = 1.0 + self._current_frame = 0 + self._pending_actions: list[str] = [] + self._pending_frame_scrub: int | None = None + self._pending_jump: int | None = None + self._hotkey_clearing = False + + def setup_gui(self) -> None: + gui = self._server.gui + + with gui.add_folder("Clip", order=0): + self._info_html = gui.add_html("Loading...") + + with gui.add_folder("Playback", order=1): + self._play_btn = gui.add_button("Play", color="green") + self._frame_slider = gui.add_slider("Frame", min=0, max=1, step=1, initial_value=0) + self._speed_group = gui.add_button_group( + "Speed", + options=["0.25x", "0.5x", "1x", "2x"], + ) + self._restart_btn = gui.add_button("Restart") + + @self._play_btn.on_click + def _(_) -> None: + self._pending_actions.append("toggle_play") + + @self._frame_slider.on_update + def _(_) -> None: + self._pending_frame_scrub = int(self._frame_slider.value) + + @self._speed_group.on_click + def _(event) -> None: + self._speed = {"0.25x": 0.25, "0.5x": 0.5, "1x": 1.0, "2x": 2.0}.get( + event.target.value, + 1.0, + ) + + @self._restart_btn.on_click + def _(_) -> None: + self._pending_actions.append("restart") + + with gui.add_folder("Navigation", order=2): + self._prev_btn = gui.add_button("Prev") + self._next_btn = gui.add_button("Next") + self._jump_input = gui.add_number( + "Jump to #", + initial_value=1, + min=1, + max=self._session.total, + step=1, + ) + + @self._prev_btn.on_click + def _(_) -> None: + self._pending_actions.append("prev") + + @self._next_btn.on_click + def _(_) -> None: + self._pending_actions.append("next") + + @self._jump_input.on_update + def _(_) -> None: + self._pending_jump = int(self._jump_input.value) + + with gui.add_folder("Dataset", order=3, expand_by_default=False): + self._stats_html = gui.add_html("Loading...") + + with gui.add_folder("Shortcuts", order=4): + self._hotkey_input = gui.add_text("Hotkey (click here, type key)", initial_value="") + gui.add_markdown("**N**=Next **P**=Prev **Space**=Play/Pause **R**=Restart **F**=Speed Up") + + @self._hotkey_input.on_update + def _(_) -> None: + if self._hotkey_clearing: + return + raw = self._hotkey_input.value + self._hotkey_clearing = True + self._hotkey_input.value = "" + self._hotkey_clearing = False + if not raw: + return + action = { + "n": "next", + "p": "prev", + " ": "toggle_play", + "r": "restart", + "f": "speed_up", + }.get(raw[-1].lower()) + if action: + self._pending_actions.append(action) + + self._scene.create_visualization_gui(show_debug_viz_control=False) + + def _load_current_clip(self) -> None: + clip = self._session.current_clip() + try: + self._player.load_clip(clip) + except Exception as exc: + self._info_html.content = f"Error loading clip:
{exc}" + return + + self._current_frame = 0 + self._playing = False + self._player.set_frame(0) + self._scene.update_from_mjdata(self._player.data) + + self._play_btn.label = "Play" + self._play_btn.color = "green" + self._frame_slider.max = max(0, self._player.num_frames - 1) + self._frame_slider.value = 0 + self._jump_input.value = self._session.cursor_display + + self._pending_frame_scrub = None + self._pending_jump = None + self._pending_actions.clear() + self._update_info_display() + self._update_stats_display() + + def _update_frame(self) -> None: + self._player.set_frame(self._current_frame) + self._scene.update_from_mjdata(self._player.data) + + def _update_info_display(self) -> None: + clip = self._session.current_clip() + self._info_html.content = f""" +
+ #{self._session.cursor_display}/{self._session.total}
+ clip: {clip.clip_id}
+ shard: {clip.shard_path}
+ source clip index: {clip.clip_index}
+ frames: {clip.num_frames} | fps: {clip.fps} + | duration: {clip.duration_s:.2f}s +
+ """ + + def _update_stats_display(self) -> None: + total_min = self._session.total_duration_s / 60.0 + self._stats_html.content = f""" +
+ root: {self._dataset_path}
+ source clips: {self._session.total}
+ duration: {total_min:.1f} min +
+ """ + + def _process_pending(self) -> None: + scrub = self._pending_frame_scrub + if scrub is not None and not self._playing: + self._pending_frame_scrub = None + self._current_frame = scrub + self._update_frame() + + jump = self._pending_jump + if jump is not None: + self._pending_jump = None + self._session.jump_to(jump) + self._load_current_clip() + return + + while self._pending_actions: + action = self._pending_actions.pop(0) + if action == "toggle_play": + self._playing = not self._playing + self._play_btn.label = "Pause" if self._playing else "Play" + self._play_btn.color = "red" if self._playing else "green" + elif action == "restart": + self._current_frame = 0 + self._update_frame() + self._frame_slider.value = 0 + self._pending_frame_scrub = None + elif action == "prev": + self._session.go_prev() + self._load_current_clip() + return + elif action == "next": + self._session.go_next() + self._load_current_clip() + return + elif action == "speed_up": + levels = [1.0, 2.0, 4.0, 8.0] + self._speed = levels[(levels.index(self._speed) + 1) % len(levels)] if self._speed in levels else 1.0 + label = f"{self._speed:g}x" + if label in ("0.25x", "0.5x", "1x", "2x"): + self._speed_group.value = label + + def run(self) -> None: + self.setup_gui() + self._load_current_clip() + + print(f"\nDataset viewer ready at http://localhost:{self._server.get_port()}") + print("Press Ctrl+C to exit.\n") + + last_frame_time = time.time() + try: + while True: + now = time.time() + if self._scene.needs_update: + self._scene.refresh_visualization() + + self._process_pending() + + if self._playing and self._player.num_frames > 0: + dt = now - last_frame_time + frames_to_advance = dt * self._player.fps * self._speed + if frames_to_advance >= 1.0: + self._current_frame += int(frames_to_advance) + if self._current_frame >= self._player.num_frames: + self._current_frame = 0 + self._playing = False + self._play_btn.label = "Play" + self._play_btn.color = "green" + self._update_frame() + self._frame_slider.value = self._current_frame + self._pending_frame_scrub = None + last_frame_time = now + else: + last_frame_time = now + + time.sleep(1.0 / 60.0) + except KeyboardInterrupt: + print("\nShutting down...") + self._server.stop() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Read-only web viewer for Teleopit HDF5 datasets") + parser.add_argument( + "--dataset", + type=str, + required=True, + help="Dataset root directory or a single Teleopit .h5 shard", + ) + parser.add_argument( + "--xml", + type=str, + default=None, + help="Robot XML path", + ) + parser.add_argument("--port", type=int, default=8012, help="Viser server port") + parser.add_argument( + "--sort", + type=str, + default="shard", + choices=["shard", "duration_desc"], + help="Initial clip sort order", + ) + args = parser.parse_args() + + dataset_path = Path(args.dataset) + if not dataset_path.is_absolute(): + dataset_path = (PROJECT_ROOT / dataset_path).resolve() + if not dataset_path.exists(): + print(f"ERROR: dataset path not found: {dataset_path}", file=sys.stderr) + sys.exit(1) + + xml_path = Path(args.xml) if args.xml else DEFAULT_XML + if not xml_path.is_file(): + print( + "ERROR: " + missing_gmr_assets_message(xml_path, label="Robot XML"), + file=sys.stderr, + ) + sys.exit(1) + + app = DatasetViewerApp( + dataset_path=dataset_path, + xml_path=xml_path, + port=args.port, + sort_mode=args.sort, + ) + app.run() + + +if __name__ == "__main__": + main() diff --git a/teleopit/__init__.py b/teleopit/__init__.py index 3dc1f76b..6a9beea8 100644 --- a/teleopit/__init__.py +++ b/teleopit/__init__.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.4.0" diff --git a/teleopit/configs/default.yaml b/teleopit/configs/default.yaml index 9e45aefb..41eae1e4 100644 --- a/teleopit/configs/default.yaml +++ b/teleopit/configs/default.yaml @@ -6,7 +6,6 @@ defaults: policy_hz: 50.0 pd_hz: 200.0 -transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) viewers: "sim2sim" realtime: false debug_trace_path: null @@ -15,6 +14,11 @@ playback: keyboard: enabled: false +console: + log_level: warning + show_timing: false + timing_log_interval_s: 10.0 + hydra: run: dir: . diff --git a/teleopit/configs/input/pico4.yaml b/teleopit/configs/input/pico4.yaml index be0ef2cd..8aa4d7dc 100644 --- a/teleopit/configs/input/pico4.yaml +++ b/teleopit/configs/input/pico4.yaml @@ -7,6 +7,8 @@ pico4_buffer_size: 60 pico4_timestamp_gap_reset_s: 0.15 pause_button: A pause_debounce_s: 0.25 +arms_button: B +arms_debounce_s: 0.25 bridge_host: "0.0.0.0" bridge_port: 63901 bridge_discovery: true diff --git a/teleopit/configs/online.yaml b/teleopit/configs/online.yaml index 7fa0a2db..f9a5533e 100644 --- a/teleopit/configs/online.yaml +++ b/teleopit/configs/online.yaml @@ -6,15 +6,10 @@ defaults: policy_hz: 50.0 pd_hz: 200.0 -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 2 retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 -reference_qpos_smoothing_alpha: 0.4 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] diff --git a/teleopit/configs/pico4_record.yaml b/teleopit/configs/pico4_record.yaml new file mode 100644 index 00000000..684a380e --- /dev/null +++ b/teleopit/configs/pico4_record.yaml @@ -0,0 +1,24 @@ +defaults: + - input: pico4 + - _self_ + +input: + # Recording does not need aggressive realtime timeline reset warnings. + pico4_timestamp_gap_reset_s: 0.5 + +record: + output_dir: data/pico_motion/clips + target_fps: 30 + min_frames: 30 + dataset_name: pico_recorded + dataset_spec_path: data/pico_motion/pico_recorded.yaml + write_dataset_spec: true + overwrite_dataset_spec: false + viewer_enabled: true + quiet_logs: true + # Optional safety cap in seconds. 0 means record until S/D/Q. + max_duration_s: 0.0 + +hydra: + run: + dir: . diff --git a/teleopit/configs/pico4_sim.yaml b/teleopit/configs/pico4_sim.yaml index a5313ae0..9abf8243 100644 --- a/teleopit/configs/pico4_sim.yaml +++ b/teleopit/configs/pico4_sim.yaml @@ -6,8 +6,6 @@ defaults: policy_hz: 50.0 pd_hz: 200.0 -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 0 # realtime mocap frames to accumulate before resuming live tracking in sim2sim keyboard: enabled: true # starts in STANDING; Y=enter mocap, A=pause/resume, X=return standing, Q=quit input: @@ -16,13 +14,13 @@ input: retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] reference_debug_log: false +arm_mocap: + controlled_joint_indices: [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] viewers: "all" realtime: true num_steps: 0 # 0 = infinite loop until Ctrl+C or device disconnect diff --git a/teleopit/configs/pico4_sim2real.yaml b/teleopit/configs/pico4_sim2real.yaml index 3c55d1c4..9b335e58 100644 --- a/teleopit/configs/pico4_sim2real.yaml +++ b/teleopit/configs/pico4_sim2real.yaml @@ -5,33 +5,105 @@ defaults: - _self_ policy_hz: 50.0 -transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 2 # realtime mocap frames to accumulate before resuming live tracking +viewers: "none" # Optional: set viewers=retarget to show the retargeted reference input: video: source: realsense retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 -retarget_buffer_delay_s: 0.05 -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 +retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling realtime_buffer_warmup_steps: 2 -realtime_catchup_enabled: true -realtime_catchup_trigger_steps: 6 -realtime_catchup_release_steps: 3 -realtime_catchup_target_delay_s: 0.04 -reference_qpos_smoothing_alpha: 0.4 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] reference_debug_log: false -# Startup ramp duration (seconds) -- smoothly blend from locked to policy positions +recording: + enabled: false + format: hdf5 + output_dir: data/recordings/sim2real_hdf5 + task: demo + fps: 30 + control: terminal + min_episode_seconds: 1.0 + discard_on_shutdown: true + record_modes: [standing, mocap, arms, pause] + camera: + enabled: true + key: observation.images.d435i_rgb + source: realsense + width: 640 + height: 480 + fps: 30 + device: null + video: + codec: libx264 + quality: 8 + pixelformat: yuv420p + +runtime: + host: 127.0.0.1 + base_port: 39700 + start_method: spawn + shutdown_timeout_s: 3.0 + pico_input_hz: 120.0 + hand_worker_hz: 120.0 + retarget_idle_sleep_s: 0.001 + video_slots: 3 + # In active MOCAP, stale/invalid realtime references hold the last command instead of damping. + # max_reference_age_s only rejects entering MOCAP from STANDING on an old reference. + stale_reference_hold_s: 0.08 + max_reference_age_s: 0.25 + +# Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 +kp_ramp_floor_ratio: 0.1 +# Faster Kp ramp used when returning from MOCAP to default STANDING with X +standing_return_ramp_duration: 0.5 +standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 +arm_mocap: + controlled_joint_indices: [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] + +# Optional LinkerHand control from Pico controller grip/trigger or VR hand pose. +hands: + enabled: false + driver: linkerhand_l6 # linkerhand_l6 | linkerhand_o6 + mode: gripper # gripper | vr_hand_pose + sides: [left, right] + rate_hz: 30.0 + frame_timeout_s: 0.3 + linkerhand_l6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + thumb_yaw_center: 10 + speed: [50, 50, 50, 50, 50, 50] + open_pose: [250, 10, 250, 250, 250, 250] + close_pose: [79, 10, 0, 0, 0, 0] + print_input: false + linkerhand_o6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + speed: [255, 255, 255, 255, 255, 255] + open_pose: [250, 250, 250, 250, 250, 250] + close_pose: [86, 73, 118, 111, 110, 111] + print_input: false + somehand: + config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml + # Low-latency vr_hand_pose path. This favors response speed over smoothing. + rate_hz: 60.0 + max_iterations: 12 + temporal_filter_alpha: 1.0 + output_alpha: 1.0 + # Physical robot SDK configuration real_robot: network_interface: "eth0" @@ -68,7 +140,6 @@ real_robot: # Mocap mode switching safety checks mocap_switch: check_frames: 10 # Consecutive valid frames required before switching - max_position_value: 5.0 # Position sanity threshold (meters) hydra: run: diff --git a/teleopit/configs/robot/g1.yaml b/teleopit/configs/robot/g1.yaml index f1cfa837..014e84f1 100644 --- a/teleopit/configs/robot/g1.yaml +++ b/teleopit/configs/robot/g1.yaml @@ -3,7 +3,7 @@ # Do NOT copy from deploy.yaml (contains known waist_yaw scale error). num_actions: 29 -# Inference uses the 166D velcmd_history observation path (General-Tracking-G1). +# Inference uses the 167D velcmd_history observation path (General-Tracking-G1). kps: [40.2, 99.1, 40.2, 99.1, 28.5, 28.5, 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, @@ -46,7 +46,7 @@ ls_iterations: 20 # Use built-in PD actuators matching mjlab training builtin_pd: true -xml_path: "teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml" +xml_path: "assets/robots/unitree_g1/g1_29dof.xml" base_body: "pelvis" ang_vel_scale: 0.25 diff --git a/teleopit/configs/sim2real.yaml b/teleopit/configs/sim2real.yaml index a677d3a9..f2dbe24f 100644 --- a/teleopit/configs/sim2real.yaml +++ b/teleopit/configs/sim2real.yaml @@ -5,16 +5,11 @@ defaults: - _self_ policy_hz: 50.0 -transition_duration: 2.0 # seconds to interpolate from default pose to motion command (0.0 to disable) -pause_resume_transition_duration: 1.0 # offline playback resume blend; realtime mocap resumes from re-centered live tracking -pause_resume_warmup_steps: 2 # realtime mocap frames to accumulate before resuming live tracking +viewers: "none" # Optional: set viewers=retarget to show the retargeted reference retarget_buffer_enabled: true retarget_buffer_window_s: 0.5 retarget_buffer_delay_s: null # null = auto use one input-frame delay for timeline sampling -realtime_buffer_low_watermark_steps: 2 -realtime_buffer_high_watermark_steps: 4 realtime_buffer_warmup_steps: 2 -reference_qpos_smoothing_alpha: 1.0 reference_velocity_smoothing_alpha: 0.35 reference_anchor_velocity_smoothing_alpha: 0.25 reference_steps: [0] @@ -22,11 +17,94 @@ reference_debug_log: false playback: pause_on_end: true -# Startup ramp duration (seconds) -- smoothly blend from locked to policy positions +console: + log_level: warning + show_timing: false + timing_log_interval_s: 10.0 + +recording: + enabled: false + format: hdf5 + output_dir: data/recordings/sim2real_hdf5 + task: demo + fps: 30 + control: terminal + min_episode_seconds: 1.0 + discard_on_shutdown: true + record_modes: [standing, mocap, arms, pause] + camera: + enabled: true + key: observation.images.d435i_rgb + source: realsense + width: 640 + height: 480 + fps: 30 + device: null + video: + codec: libx264 + quality: 8 + pixelformat: yuv420p + +runtime: + host: 127.0.0.1 + base_port: 39700 + start_method: spawn + shutdown_timeout_s: 3.0 + pico_input_hz: 120.0 + hand_worker_hz: 120.0 + retarget_idle_sleep_s: 0.001 + video_slots: 3 + # In active MOCAP, delayed/invalid realtime references hold the last command instead of damping. + # max_reference_age_s only rejects entering MOCAP from STANDING on an old reference. + stale_reference_hold_s: 0.08 + max_reference_age_s: 0.25 + +# Kp ramp duration (seconds) -- gradually increases PD gains after entering STANDING startup_ramp_duration: 2.0 +kp_ramp_floor_ratio: 0.1 +# Faster Kp ramp used when returning from MOCAP to default STANDING with X +standing_return_ramp_duration: 0.5 +standing_return_kp_ramp_floor_ratio: 0.5 # Joint velocity safety limit (rad/s) -- trigger emergency damping if exceeded joint_vel_limit: 10.0 +# Optional LinkerHand control. Use only with input.provider=pico4. +hands: + enabled: false + driver: linkerhand_l6 # linkerhand_l6 | linkerhand_o6 + mode: gripper # gripper | vr_hand_pose + sides: [left, right] + rate_hz: 30.0 + frame_timeout_s: 0.3 + linkerhand_l6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + thumb_yaw_center: 10 + speed: [50, 50, 50, 50, 50, 50] # gripper mode; vr_hand_pose always uses max speed + open_pose: [250, 10, 250, 250, 250, 250] + close_pose: [79, 10, 0, 0, 0, 0] + print_input: false + linkerhand_o6: + left_can: can0 + right_can: can1 + modbus: "None" + trigger_deadzone: 0.05 + deadman_threshold: 0.5 + speed: [255, 255, 255, 255, 255, 255] + open_pose: [250, 250, 250, 250, 250, 250] + close_pose: [86, 73, 118, 111, 110, 111] + print_input: false + somehand: + config_path: third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml + # Low-latency vr_hand_pose path. This favors response speed over smoothing. + rate_hz: 60.0 + max_iterations: 12 + temporal_filter_alpha: 1.0 + output_alpha: 1.0 + # Physical robot SDK configuration real_robot: network_interface: "eth0" @@ -63,7 +141,6 @@ real_robot: # Mocap mode switching safety checks mocap_switch: check_frames: 10 # Consecutive valid frames required before switching - max_position_value: 5.0 # Position sanity threshold (meters) hydra: run: diff --git a/teleopit/configs/sim2real_record.yaml b/teleopit/configs/sim2real_record.yaml new file mode 100644 index 00000000..2fda8bf4 --- /dev/null +++ b/teleopit/configs/sim2real_record.yaml @@ -0,0 +1,16 @@ +defaults: + - pico4_sim2real + - _self_ + +recording: + enabled: true + +input: + video: + enabled: true + source: realsense + width: 640 + height: 480 + fps: 30 + device: null + fail_on_error: true diff --git a/teleopit/controllers/observation.py b/teleopit/controllers/observation.py index b637b5b8..606ceaea 100644 --- a/teleopit/controllers/observation.py +++ b/teleopit/controllers/observation.py @@ -107,7 +107,7 @@ def _quat_to_rot6d_np(q: FloatVec) -> FloatVec: @final class _VelCmdBaseObservationBuilder: - """Internal base block used by the public 166D VelCmd builder.""" + """Internal base block used by the public 167D VelCmd builder.""" def __init__(self, cfg: ConfigType) -> None: self.num_actions: int = _as_int_scalar(cfg_get(cfg, "num_actions"), "num_actions") @@ -164,13 +164,13 @@ def build( last_action: FloatVec, ) -> FloatVec: qpos = np.asarray(robot_state.qpos, dtype=np.float32).reshape(-1)[: self.num_actions] - qvel = np.asarray(robot_state.qvel, dtype=np.float32).reshape(-1)[: self.num_actions] + robot_joint_vel = np.asarray(robot_state.qvel, dtype=np.float32).reshape(-1)[: self.num_actions] robot_quat = np.asarray(robot_state.quat, dtype=np.float32).reshape(-1) - base_ang_vel = np.asarray(robot_state.ang_vel, dtype=np.float32).reshape(-1) + robot_base_ang_vel_b = np.asarray(robot_state.ang_vel, dtype=np.float32).reshape(-1) if robot_quat.shape[0] != 4: raise ValueError(f"robot_state.quat must be 4D (wxyz), got {robot_quat.shape[0]}") - if base_ang_vel.shape[0] != 3: - raise ValueError(f"robot_state.ang_vel must be 3D, got {base_ang_vel.shape[0]}") + if robot_base_ang_vel_b.shape[0] != 3: + raise ValueError(f"robot_state.ang_vel must be 3D, got {robot_base_ang_vel_b.shape[0]}") motion = np.asarray(motion_qpos, dtype=np.float32).reshape(-1) if motion.shape[0] < 7 + self.num_actions: @@ -182,9 +182,9 @@ def build( raise ValueError( f"motion_joint_vel must be {self.num_actions}D, got {motion_joint_vel_vec.shape[0]}" ) - last_act = np.asarray(last_action, dtype=np.float32).reshape(-1) - if last_act.shape[0] != self.num_actions: - raise ValueError(f"last_action length must be {self.num_actions}, got {last_act.shape[0]}") + prev_action = np.asarray(last_action, dtype=np.float32).reshape(-1) + if prev_action.shape[0] != self.num_actions: + raise ValueError(f"last_action length must be {self.num_actions}, got {prev_action.shape[0]}") self._run_fk(np.zeros(3, dtype=np.float32), robot_quat, qpos) robot_anchor_quat = self._get_body_quat(self._anchor_body_id) @@ -193,20 +193,22 @@ def build( motion_base_quat = motion[3:7] motion_joint_pos = motion[7:7 + self.num_actions] self._run_fk(motion_base_pos, motion_base_quat, motion_joint_pos) - motion_anchor_quat = self._get_body_quat(self._anchor_body_id) + ref_anchor_quat = self._get_body_quat(self._anchor_body_id) - command = np.concatenate((motion_joint_pos, motion_joint_vel_vec), dtype=np.float32) - rel_quat = _quat_mul_np(_quat_inv_np(robot_anchor_quat), motion_anchor_quat) - motion_anchor_ori_b = _quat_to_rot6d_np(rel_quat) - joint_pos_rel = qpos - self.default_dof_pos + ref_joint_pos = motion_joint_pos + ref_joint_vel = motion_joint_vel_vec + rel_quat = _quat_mul_np(_quat_inv_np(robot_anchor_quat), ref_anchor_quat) + ref_anchor_ori_b = _quat_to_rot6d_np(rel_quat) + robot_joint_pos_rel = qpos - self.default_dof_pos obs = np.concatenate([ - command, - motion_anchor_ori_b, - base_ang_vel, - joint_pos_rel, - qvel, - last_act, + ref_joint_pos, + ref_joint_vel, + ref_anchor_ori_b, + robot_base_ang_vel_b, + robot_joint_pos_rel, + robot_joint_vel, + prev_action, ], dtype=np.float32) if obs.shape[0] != self.total_obs_size: raise ValueError(f"Expected {self.total_obs_size}D base observation, got {obs.shape[0]}") @@ -215,12 +217,12 @@ def build( @final class VelCmdObservationBuilder: - """166D observation builder for the only supported VelCmdHistory policy.""" + """167D observation builder for the only supported VelCmdHistory policy.""" def __init__(self, cfg: ConfigType) -> None: self._base = _VelCmdBaseObservationBuilder(cfg) self.num_actions = self._base.num_actions - self.total_obs_size = self._base.total_obs_size + 12 + self.total_obs_size = self._base.total_obs_size + 13 def reset(self) -> None: self._base.reset() @@ -239,7 +241,8 @@ def build( robot_quat = np.asarray(robot_state.quat, dtype=np.float32).reshape(-1) motion = np.asarray(motion_qpos, dtype=np.float32).reshape(-1) self._base._run_fk(motion[0:3], motion[3:7], motion[7:7 + self.num_actions]) - motion_anchor_quat = self._base._get_body_quat(self._base._anchor_body_id) + ref_anchor_pos = self._base._get_body_pos(self._base._anchor_body_id) + ref_anchor_quat = self._base._get_body_quat(self._base._anchor_body_id) qpos = np.asarray(robot_state.qpos, dtype=np.float32).reshape(-1)[: self.num_actions] robot_base_pos = np.zeros(3, dtype=np.float32) @@ -250,17 +253,19 @@ def build( ref_lin_vel_w = np.asarray(motion_anchor_lin_vel_w, dtype=np.float32).reshape(3) ref_ang_vel_w = np.asarray(motion_anchor_ang_vel_w, dtype=np.float32).reshape(3) - projected_gravity = _quat_rotate_np(_quat_inv_np(robot_quat), _GRAVITY_UNIT_W) + robot_projected_gravity_b = _quat_rotate_np(_quat_inv_np(robot_quat), _GRAVITY_UNIT_W) robot_inv = _quat_inv_np(robot_anchor_quat) - ref_base_lin_vel_b = _quat_rotate_np(robot_inv, ref_lin_vel_w) - ref_base_ang_vel_b = _quat_rotate_np(robot_inv, ref_ang_vel_w) - ref_projected_gravity_b = _quat_rotate_np(_quat_inv_np(motion_anchor_quat), _GRAVITY_UNIT_W) + ref_anchor_lin_vel_b = _quat_rotate_np(robot_inv, ref_lin_vel_w) + ref_anchor_ang_vel_b = _quat_rotate_np(robot_inv, ref_ang_vel_w) + ref_projected_gravity_b = _quat_rotate_np(_quat_inv_np(ref_anchor_quat), _GRAVITY_UNIT_W) + ref_anchor_height = ref_anchor_pos[2:3] velcmd_obs = np.concatenate([ - projected_gravity, - ref_base_lin_vel_b, - ref_base_ang_vel_b, + robot_projected_gravity_b, + ref_anchor_lin_vel_b, + ref_anchor_ang_vel_b, ref_projected_gravity_b, + ref_anchor_height, ], dtype=np.float32) obs = np.concatenate([base_obs, velcmd_obs], dtype=np.float32) if obs.shape[0] != self.total_obs_size: @@ -282,4 +287,3 @@ def build_observation( "Use build(robot_state, motion_qpos, motion_joint_vel, last_action, " "motion_anchor_lin_vel_w, motion_anchor_ang_vel_w)." ) - diff --git a/teleopit/controllers/qpos_interpolator.py b/teleopit/controllers/qpos_interpolator.py deleted file mode 100644 index da1af66e..00000000 --- a/teleopit/controllers/qpos_interpolator.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Smooth transition interpolator for retargeted 36D qpos. - -Prevents violent robot motion when switching from default/idle pose to a new -motion command by gradually blending between the start pose and the live -retargeted target over a configurable duration. -""" - -from __future__ import annotations - -import numpy as np -from numpy.typing import NDArray - - -def _slerp(q0: NDArray, q1: NDArray, t: float) -> NDArray: - """Spherical linear interpolation between two wxyz quaternions.""" - q0 = q0 / max(np.linalg.norm(q0), 1e-8) - q1 = q1 / max(np.linalg.norm(q1), 1e-8) - dot = float(np.dot(q0, q1)) - # Ensure shortest path - if dot < 0.0: - q1 = -q1 - dot = -dot - # Fall back to lerp for nearly identical quaternions - if dot > 0.9995: - result = q0 + t * (q1 - q0) - return result / max(np.linalg.norm(result), 1e-8) - theta = np.arccos(np.clip(dot, -1.0, 1.0)) - sin_theta = np.sin(theta) - a = np.sin((1.0 - t) * theta) / sin_theta - b = np.sin(t * theta) / sin_theta - return a * q0 + b * q1 - - -class QposInterpolator: - """Smoothly interpolates retargeted qpos from a start pose to the live target. - - Operates on 36D qpos: pos(3) + quat_wxyz(4) + joints(29). - Position and joints use linear interpolation; quaternion uses SLERP. - - Parameters - ---------- - duration : float - Transition duration in seconds. 0.0 disables interpolation. - policy_hz : float - Policy frequency (steps per second) for step-based progress. - """ - - def __init__(self, duration: float, policy_hz: float) -> None: - self._policy_hz = policy_hz - self._duration = 0.0 - self._total_steps = 0 - self._step = 0 - self._start_qpos: NDArray | None = None - self._active = False - self._last_alpha = np.float64(1.0) - self.configure(duration) - - @property - def duration(self) -> float: - return self._duration - - @property - def is_active(self) -> bool: - return self._active - - @property - def last_alpha(self) -> float: - return float(self._last_alpha) - - def reset(self) -> None: - self._step = 0 - self._start_qpos = None - self._active = False - self._last_alpha = np.float64(1.0) - - def configure(self, duration: float) -> None: - self._duration = max(float(duration), 0.0) - self._total_steps = int(self._duration * self._policy_hz) - - def start(self, start_qpos: NDArray) -> None: - """Begin interpolation from *start_qpos* toward future targets.""" - if self._total_steps <= 0: - return - self._start_qpos = np.array(start_qpos, dtype=np.float64).ravel() - self._step = 0 - self._active = True - self._last_alpha = np.float64(0.0) - - def apply(self, target_qpos: NDArray) -> NDArray: - """Return interpolated qpos. Passthrough when inactive or finished.""" - if not self._active or self._start_qpos is None: - self._last_alpha = np.float64(1.0) - return target_qpos - - if self._step >= self._total_steps: - self._active = False - self._last_alpha = np.float64(1.0) - return target_qpos - - alpha = self._step / self._total_steps - self._step += 1 - self._last_alpha = np.float64(alpha) - - result = np.empty_like(target_qpos) - # Position: lerp - result[0:3] = (1.0 - alpha) * self._start_qpos[0:3] + alpha * target_qpos[0:3] - # Quaternion: SLERP - result[3:7] = _slerp(self._start_qpos[3:7], target_qpos[3:7], alpha) - # Joints: lerp - result[7:] = (1.0 - alpha) * self._start_qpos[7:] + alpha * target_qpos[7:] - return result - - -class QposLowPassFilter: - """Low-pass filter for retargeted qpos.""" - - def __init__(self, alpha: float) -> None: - alpha_f = float(alpha) - if not np.isfinite(alpha_f) or alpha_f <= 0.0 or alpha_f > 1.0: - raise ValueError(f"alpha must be finite and in (0, 1], got {alpha}") - self._alpha = alpha_f - self._state: NDArray | None = None - - def reset(self) -> None: - self._state = None - - def apply(self, target_qpos: NDArray) -> NDArray: - target = np.asarray(target_qpos, dtype=np.float64).reshape(-1) - if self._state is None or self._state.shape != target.shape or self._alpha >= 1.0 - 1e-6: - self._state = target.copy() - return self._state.copy() - - alpha = float(self._alpha) - filtered = np.empty_like(target) - filtered[0:3] = (1.0 - alpha) * self._state[0:3] + alpha * target[0:3] - filtered[3:7] = _slerp(self._state[3:7], target[3:7], alpha) - filtered[7:] = (1.0 - alpha) * self._state[7:] + alpha * target[7:] - self._state = filtered - return filtered.copy() diff --git a/teleopit/inputs/bvh_provider.py b/teleopit/inputs/bvh_provider.py index d88c5a43..338ce5a9 100644 --- a/teleopit/inputs/bvh_provider.py +++ b/teleopit/inputs/bvh_provider.py @@ -6,6 +6,7 @@ import numpy as np from teleopit.retargeting.gmr.utils.lafan_vendor.extract import read_bvh from teleopit.retargeting.gmr.utils.lafan_vendor import utils +from teleopit.retargeting.gmr.utils.xsens_vendor.BVHParser import BVHParser, Anim from scipy.spatial.transform import Rotation as R @@ -117,6 +118,9 @@ def process_single_bvh_frame( def _load_bvh_file(bvh_file: str, format: str = "lafan1"): + if format == "xsens": + return _load_xsens_bvh_file(bvh_file) + data = read_bvh(bvh_file) bone_names = list(data.bones) bone_parents = np.array(data.parents, dtype=np.int32) @@ -182,6 +186,44 @@ def _load_bvh_file(bvh_file: str, format: str = "lafan1"): return frames, human_height, fps, bone_names, bone_parents +def _load_xsens_bvh_file(bvh_file: str): + parser = BVHParser(axis_order="zxy", scale=0.01) + with open(bvh_file, "r") as f: + bvh_text = f.read() + + rotations, positions = parser.parse(bvh_text, reset_to_zero=True) + quats, processed_positions, offsets, parents = parser._MOTION_data_post_processing( + rotations, + positions, + reset_to_zero=True, + ) + anim = Anim(quats, processed_positions, offsets, parents, parser.names) + global_quats, global_pos = utils.quat_fk(anim.quats, anim.pos, anim.parents) + + frames = [] + for frame in range(anim.pos.shape[0]): + result = {} + for i, bone in enumerate(anim.bones): + result[bone] = (global_pos[frame, i], global_quats[frame, i]) + + result["LeftFootMod"] = (np.array(result["LeftAnkle"][0], copy=True), result["LeftAnkle"][1]) + result["RightFootMod"] = (np.array(result["RightAnkle"][0], copy=True), result["RightAnkle"][1]) + frames.append(result) + + if not frames: + raise ValueError(f"No frames parsed from xsens BVH input: {bvh_file}") + + frame_time = float(parser.frame_time) + fps = int(round(1.0 / frame_time)) if frame_time > 0.0 else 60 + last_frame = frames[-1] + human_height = float( + last_frame["Head_end_site"][0][2] + - min(last_frame["LeftToe_end_site"][0][2], last_frame["RightToe_end_site"][0][2]) + ) + + return frames, human_height, fps, list(anim.bones), np.array(anim.parents, dtype=np.int32) + + class BVHInputProvider: def __init__(self, bvh_path: str, human_format: str = "lafan1"): diff --git a/teleopit/inputs/human_frame_validation.py b/teleopit/inputs/human_frame_validation.py new file mode 100644 index 00000000..ab854505 --- /dev/null +++ b/teleopit/inputs/human_frame_validation.py @@ -0,0 +1,107 @@ +"""HumanFrame finite-value validation shared by realtime diagnostics and runtimes.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass(frozen=True) +class HumanFrameValidationResult: + valid: bool + reason: str = "ok" + joint_name: str | None = None + pos: tuple[float, ...] | None = None + quat: tuple[float, ...] | None = None + max_abs_pos: float | None = None + detail: str = "" + + +def validate_human_frame(frame: object) -> HumanFrameValidationResult: + """Validate HumanFrame numeric values before realtime retargeting.""" + if not isinstance(frame, dict): + return HumanFrameValidationResult(False, reason="frame_not_dict", detail=f"type={type(frame)!r}") + + max_frame_abs_pos: float | None = None + for name, value in frame.items(): + joint_name = str(name) + try: + pos, quat = value + except Exception as exc: + return HumanFrameValidationResult( + False, + reason="joint_unpack_failed", + joint_name=joint_name, + detail=str(exc), + ) + + try: + pos_arr = np.asarray(pos, dtype=np.float64).reshape(-1) + except Exception as exc: + return HumanFrameValidationResult( + False, + reason="position_cast_failed", + joint_name=joint_name, + detail=str(exc), + ) + try: + quat_arr = np.asarray(quat, dtype=np.float64).reshape(-1) + except Exception as exc: + return HumanFrameValidationResult( + False, + reason="quaternion_cast_failed", + joint_name=joint_name, + pos=_to_tuple(pos_arr), + detail=str(exc), + ) + + pos_tuple = _to_tuple(pos_arr) + quat_tuple = _to_tuple(quat_arr) + max_abs_pos = float(np.max(np.abs(pos_arr))) if pos_arr.size > 0 else 0.0 + max_frame_abs_pos = ( + max_abs_pos if max_frame_abs_pos is None else max(max_frame_abs_pos, max_abs_pos) + ) + + if np.any(np.isnan(pos_arr)): + return HumanFrameValidationResult( + False, + reason="position_nan", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + ) + if np.any(np.isinf(pos_arr)): + return HumanFrameValidationResult( + False, + reason="position_inf", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + ) + if np.any(np.isnan(quat_arr)): + return HumanFrameValidationResult( + False, + reason="quaternion_nan", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + ) + if np.any(np.isinf(quat_arr)): + return HumanFrameValidationResult( + False, + reason="quaternion_inf", + joint_name=joint_name, + pos=pos_tuple, + quat=quat_tuple, + max_abs_pos=max_abs_pos, + ) + + return HumanFrameValidationResult(True, max_abs_pos=max_frame_abs_pos) + + +def _to_tuple(values: np.ndarray) -> tuple[float, ...]: + return tuple(float(value) for value in np.asarray(values, dtype=np.float64).reshape(-1)) diff --git a/teleopit/inputs/pico4_provider.py b/teleopit/inputs/pico4_provider.py index b6bc021e..1780b03f 100644 --- a/teleopit/inputs/pico4_provider.py +++ b/teleopit/inputs/pico4_provider.py @@ -1,14 +1,16 @@ """Pico4 VR full-body motion capture input provider. Uses the in-process ``pico_bridge`` receiver to collect PICO tracking frames. -The provider converts native PICO/Unity poses (meters, xyzw quaternions) into +The provider converts native PICO poses (meters, xyzw quaternions) into Teleopit's realtime ``HumanFrame`` format. """ from __future__ import annotations from collections import deque +from dataclasses import dataclass import inspect +from importlib.metadata import PackageNotFoundError, version import logging import threading import time @@ -31,7 +33,7 @@ logger = logging.getLogger(__name__) -# PICO/Unity -> Teleopit retarget input space. +# PICO native -> Teleopit retarget input space. _INPUT_TO_TELEOPIT_MATRIX = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float64) _INPUT_TO_TELEOPIT_QUAT = R.from_matrix(_INPUT_TO_TELEOPIT_MATRIX).as_quat(scalar_first=True) @@ -52,6 +54,46 @@ dtype=np.int32, ) + +@dataclass(frozen=True) +class PicoControllerState: + """Latest per-controller input state exposed by pico_bridge.""" + + raw: bool + grip: float + trigger: float + present: bool = True + + +@dataclass(frozen=True) +class PicoControllerSnapshot: + """Immutable snapshot of Pico controller inputs for auxiliary runtimes.""" + + left: PicoControllerState + right: PicoControllerState + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class PicoHandState: + """Latest per-hand pose state exposed by pico_bridge.""" + + active: bool + joints: NDArray[np.float64] + present: bool = True + + +@dataclass(frozen=True) +class PicoHandSnapshot: + """Immutable snapshot of Pico hand poses for auxiliary runtimes.""" + + left: PicoHandState + right: PicoHandState + timestamp_s: float + seq: int + + _PAUSE_BUTTON_MAP: dict[str, tuple[str, str]] = { "A": ("right", "primaryButton"), "B": ("right", "secondaryButton"), @@ -64,6 +106,32 @@ } +def _has_non_degenerate_positions(positions: NDArray[np.float64]) -> bool: + pos = np.asarray(positions, dtype=np.float64).reshape(-1, 3) + if pos.size == 0: + return False + finite_mask = np.all(np.isfinite(pos), axis=1) + valid_pos = pos[finite_mask] + if valid_pos.shape[0] < 2: + return False + nonzero_pos = valid_pos[np.linalg.norm(valid_pos, axis=1) > 1e-9] + if nonzero_pos.shape[0] < 2: + return False + extent = float(np.max(np.ptp(nonzero_pos, axis=0))) + return extent > 1e-6 + + +def _compute_ground_alignment_offset(positions: NDArray[np.float64]) -> float: + pos = np.asarray(positions, dtype=np.float64).reshape(-1, 3) + if pos.size == 0: + return 0.0 + finite_mask = np.all(np.isfinite(pos), axis=1) + if not np.any(finite_mask): + return 0.0 + min_z = float(np.min(pos[finite_mask, 2])) + return -min_z + + def _bridge_accepts_video_enabled(bridge_cls: type[Any]) -> bool: try: signature = inspect.signature(bridge_cls) @@ -75,6 +143,20 @@ def _bridge_accepts_video_enabled(bridge_cls: type[Any]) -> bool: return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values()) +def _installed_pico_bridge_version() -> tuple[int, ...] | None: + try: + raw_version = version("pico-bridge") + except PackageNotFoundError: + return None + release = raw_version.split("+", 1)[0].split("-", 1)[0] + parts: list[int] = [] + for part in release.split("."): + if not part.isdigit(): + break + parts.append(int(part)) + return tuple(parts) if parts else None + + def _coordinate_transform_input(body_pose_dict: dict[str, list]) -> dict[str, list]: """Transform provider-space poses into Teleopit's expected coordinates.""" for body_name, value in body_pose_dict.items(): @@ -102,6 +184,8 @@ def __init__( timestamp_gap_reset_s: float = 0.15, pause_button: str | None = "A", pause_debounce_s: float = 0.25, + arms_button: str | None = "B", + arms_debounce_s: float | None = None, bridge_host: str = "0.0.0.0", bridge_port: int = 63901, bridge_discovery: bool = True, @@ -120,10 +204,16 @@ def __init__( "pico_bridge is required for Pico4 input. Install the receiver package, " "for example: pip install -e '.[pico4]'" ) from exc + installed_version = _installed_pico_bridge_version() + if installed_version is None or installed_version < (0, 2, 1): + raise RuntimeError( + "pico_bridge >= 0.2.1 is required for Pico4 input. Reinstall the Pico extra with " + "pip install -e '.[pico4]' so Teleopit receives pico_native tracking semantics." + ) bridge_cls = PicoBridge if not _bridge_accepts_video_enabled(bridge_cls): raise RuntimeError( - "pico_bridge >= 0.2.0 is required for Pico4 input. Reinstall the Pico extra with " + "pico_bridge >= 0.2.1 is required for Pico4 input. Reinstall the Pico extra with " "pip install -e '.[pico4]' so PicoBridge accepts video_enabled and push_video_frame()." ) @@ -136,13 +226,21 @@ def __init__( self._timestamp_gap_reset_s = float(timestamp_gap_reset_s) self._pending_control_events: deque[ControlEvent] = deque() self._pause_button = None if pause_button in (None, "", "null") else str(pause_button) + self._arms_button = None if arms_button in (None, "", "null") else str(arms_button) self._pause_debounce_s = max(float(pause_debounce_s), 0.0) + self._arms_debounce_s = self._pause_debounce_s if arms_debounce_s is None else max(float(arms_debounce_s), 0.0) self._pause_button_path = self._resolve_button_path(self._pause_button) + self._arms_button_path = self._resolve_button_path(self._arms_button) self._last_pause_button_pressed = False + self._last_arms_button_pressed = False self._last_pause_toggle_timestamp: float | None = None + self._last_arms_toggle_timestamp: float | None = None self._last_raw_body_joints: NDArray[np.float64] | None = None self._last_frame_timestamp: float | None = None self._last_source_seq: int | None = None + self._controller_snapshot: PicoControllerSnapshot | None = None + self._hand_snapshot: PicoHandSnapshot | None = None + self._ground_alignment_offset: float | None = None self._bridge = bridge_cls( host=bridge_host, port=int(bridge_port), @@ -161,6 +259,11 @@ def __init__( "Pico4InputProvider pause button '%s' is unsupported by pico_bridge; pause events disabled", self._pause_button, ) + if self._arms_button is not None and self._arms_button_path is None: + logger.warning( + "Pico4InputProvider arms button '%s' is unsupported by pico_bridge; arms events disabled", + self._arms_button, + ) logger.info("Pico4InputProvider initialized (pico_bridge)") @property @@ -223,11 +326,21 @@ def pop_control_events(self) -> tuple[ControlEvent, ...]: self._pending_control_events.clear() return control_events + def get_controller_snapshot(self) -> PicoControllerSnapshot | None: + """Return the latest Pico controller-axis snapshot, if one has arrived.""" + with self._lock: + return self._controller_snapshot + + def get_hand_snapshot(self) -> PicoHandSnapshot | None: + """Return the latest Pico hand-pose snapshot, if one has arrived.""" + with self._lock: + return self._hand_snapshot + def push_video_frame(self, frame: NDArray[np.uint8]) -> int: - """Push one RGB camera frame to pico-bridge 0.2.0 video output.""" + """Push one RGB camera frame to pico-bridge 0.2.1 video output.""" push_video_frame = getattr(self._bridge, "push_video_frame", None) if not callable(push_video_frame): - raise RuntimeError("Installed pico_bridge does not expose push_video_frame(); use pico-bridge 0.2.0") + raise RuntimeError("Installed pico_bridge does not expose push_video_frame(); use pico-bridge 0.2.1") return int(push_video_frame(frame)) def has_frame(self) -> bool: @@ -287,6 +400,8 @@ def _poll_loop(self) -> None: def _accept_pico_frame(self, frame: Any) -> bool: timestamp = float(getattr(frame, "receive_time_s", time.monotonic())) + self._accept_controller_snapshot(frame, timestamp=timestamp) + self._accept_hand_snapshot(frame, timestamp=timestamp) self._poll_control_events(frame, timestamp=timestamp) body = getattr(frame, "body", None) @@ -312,6 +427,7 @@ def _accept_pico_frame(self, frame: Any) -> bool: and timestamp - self._last_frame_timestamp > self._timestamp_gap_reset_s ): self._frame_cache.clear() + self._ground_alignment_offset = None logger.warning( "Pico4InputProvider timestamp-gap reset | gap=%.4fs", timestamp - self._last_frame_timestamp, @@ -319,6 +435,7 @@ def _accept_pico_frame(self, frame: Any) -> bool: if self._last_frame_timestamp is not None and timestamp <= self._last_frame_timestamp + 1e-9: timestamp = self._last_frame_timestamp + 1e-6 + human_frame = self._apply_ground_alignment(human_frame) self._frame_cache.append(human_frame, timestamp, fps_timestamp=timestamp) self._last_raw_body_joints = body_joints.copy() self._last_frame_timestamp = timestamp @@ -327,47 +444,133 @@ def _accept_pico_frame(self, frame: Any) -> bool: self._frame_ready.set() return True + def _accept_controller_snapshot(self, frame: Any, *, timestamp: float) -> None: + seq = int(getattr(frame, "seq", self._last_source_seq or -1)) + controllers = getattr(frame, "controllers", None) + snapshot = PicoControllerSnapshot( + left=self._read_controller_state(None if controllers is None else getattr(controllers, "left", None)), + right=self._read_controller_state(None if controllers is None else getattr(controllers, "right", None)), + timestamp_s=float(timestamp), + seq=seq, + ) + with self._lock: + self._controller_snapshot = snapshot + + def _accept_hand_snapshot(self, frame: Any, *, timestamp: float) -> None: + seq = int(getattr(frame, "seq", self._last_source_seq or -1)) + snapshot = PicoHandSnapshot( + left=self._read_hand_state(getattr(frame, "left_hand", None)), + right=self._read_hand_state(getattr(frame, "right_hand", None)), + timestamp_s=float(timestamp), + seq=seq, + ) + with self._lock: + self._hand_snapshot = snapshot + def _poll_control_events(self, frame: Any, *, timestamp: float) -> bool: - if self._pause_button_path is None: + emitted = False + emitted = self._poll_button_control_event( + frame, + timestamp=timestamp, + button_path=self._pause_button_path, + button_label=self._pause_button, + event_type=ControlEventType.TOGGLE_PAUSE, + last_pressed_attr="_last_pause_button_pressed", + last_toggle_attr="_last_pause_toggle_timestamp", + debounce_s=self._pause_debounce_s, + ) or emitted + emitted = self._poll_button_control_event( + frame, + timestamp=timestamp, + button_path=self._arms_button_path, + button_label=self._arms_button, + event_type=ControlEventType.TOGGLE_ARMS, + last_pressed_attr="_last_arms_button_pressed", + last_toggle_attr="_last_arms_toggle_timestamp", + debounce_s=self._arms_debounce_s, + ) or emitted + return emitted + + def _poll_button_control_event( + self, + frame: Any, + *, + timestamp: float, + button_path: tuple[str, str] | None, + button_label: str | None, + event_type: ControlEventType, + last_pressed_attr: str, + last_toggle_attr: str, + debounce_s: float, + ) -> bool: + if button_path is None: return False - side, button_name = self._pause_button_path + side, button_name = button_path controllers = getattr(frame, "controllers", None) controller = None if controllers is None else getattr(controllers, side, None) buttons = {} if controller is None else getattr(controller, "buttons", {}) or {} pressed = bool(buttons.get(button_name, False)) + last_pressed = bool(getattr(self, last_pressed_attr)) emitted = False - if pressed and not self._last_pause_button_pressed: - if ( - self._last_pause_toggle_timestamp is None - or timestamp - self._last_pause_toggle_timestamp >= self._pause_debounce_s - 1e-9 - ): + if pressed and not last_pressed: + last_toggle = getattr(self, last_toggle_attr) + if last_toggle is None or timestamp - float(last_toggle) >= debounce_s - 1e-9: with self._lock: self._pending_control_events.append( ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source=f"pico4:{self._pause_button}", + event_type=event_type, + source=f"pico4:{button_label}", timestamp_s=float(timestamp), ) ) - self._last_pause_toggle_timestamp = float(timestamp) + logger.info("Pico control event: %s from %s", event_type.value, button_label) + setattr(self, last_toggle_attr, float(timestamp)) emitted = True - self._last_pause_button_pressed = pressed + setattr(self, last_pressed_attr, pressed) return emitted @staticmethod - def _resolve_button_path(pause_button: str | None) -> tuple[str, str] | None: - if pause_button is None: + def _resolve_button_path(button: str | None) -> tuple[str, str] | None: + if button is None: return None - return _PAUSE_BUTTON_MAP.get(pause_button) + return _PAUSE_BUTTON_MAP.get(button) + + @staticmethod + def _read_controller_state(controller: Any) -> PicoControllerState: + axis = {} if controller is None else getattr(controller, "axis", {}) or {} + return PicoControllerState( + raw=bool(False if controller is None else getattr(controller, "raw", False)), + grip=float(axis.get("grip", 0.0)), + trigger=float(axis.get("trigger", 0.0)), + present=controller is not None, + ) + + @staticmethod + def _read_hand_state(hand: Any) -> PicoHandState: + joints = np.zeros((26, 7), dtype=np.float64) + valid_shape = False + if hand is None: + return PicoHandState(active=False, joints=joints, present=False) + try: + raw_joints = np.asarray(getattr(hand, "joints"), dtype=np.float64) + if raw_joints.shape == (26, 7): + joints = raw_joints.copy() + valid_shape = True + except (TypeError, ValueError): + pass + return PicoHandState( + active=bool(getattr(hand, "active", False)) and valid_shape, + joints=joints, + present=True, + ) @staticmethod def _convert_body_joints_to_frame(body_joints: NDArray[np.float64]) -> HumanFrame: - body_joints = Pico4InputProvider._normalize_pico_bridge_body_joints(body_joints) body_pose_dict: dict[str, list] = {} for i, joint_name in enumerate(BODY_JOINT_NAMES): pos = [body_joints[i][0], body_joints[i][1], body_joints[i][2]] - # pico_bridge returns [x, y, z, qx, qy, qz, qw]. + # pico_bridge 0.2.1 returns pico_native [x, y, z, qx, qy, qz, qw]. rot = [body_joints[i][6], body_joints[i][3], body_joints[i][4], body_joints[i][5]] body_pose_dict[joint_name] = [pos, rot] @@ -378,11 +581,21 @@ def _convert_body_joints_to_frame(body_joints: NDArray[np.float64]) -> HumanFram result[name] = (np.asarray(pos, dtype=np.float64), np.asarray(quat, dtype=np.float64)) return result - @staticmethod - def _normalize_pico_bridge_body_joints(body_joints: NDArray[np.float64]) -> NDArray[np.float64]: - """Match Teleopit's calibrated Pico body-pose convention.""" - converted = np.array(body_joints, dtype=np.float64, copy=True) - converted[:, 2] *= -1.0 - converted[:, 5] *= -1.0 - converted[:, 6] *= -1.0 - return converted + def _apply_ground_alignment(self, human_frame: HumanFrame) -> HumanFrame: + """Apply one fixed Z offset so the initial Pico skeleton sits on the floor.""" + if self._ground_alignment_offset is None: + positions = np.asarray([value[0] for value in human_frame.values()], dtype=np.float64) + if _has_non_degenerate_positions(positions): + self._ground_alignment_offset = _compute_ground_alignment_offset(positions) + else: + return human_frame + + offset = float(self._ground_alignment_offset) + if abs(offset) <= 1e-12: + return human_frame + + z_offset = np.array([0.0, 0.0, offset], dtype=np.float64) + lifted: HumanFrame = {} + for name, (pos, quat) in human_frame.items(): + lifted[name] = (np.asarray(pos, dtype=np.float64) + z_offset, np.asarray(quat, dtype=np.float64)) + return lifted diff --git a/teleopit/inputs/pico_video.py b/teleopit/inputs/pico_video.py index 500b3104..fd1ed25e 100644 --- a/teleopit/inputs/pico_video.py +++ b/teleopit/inputs/pico_video.py @@ -1,4 +1,4 @@ -"""Optional camera-to-Pico video streaming for pico-bridge 0.2.0.""" +"""Optional camera-to-Pico video streaming for pico-bridge 0.2.1.""" from __future__ import annotations @@ -6,7 +6,7 @@ import logging import threading import time -from typing import Any +from typing import Any, Callable import numpy as np @@ -70,11 +70,13 @@ def __init__( config: PicoVideoConfig, mode: str, robot: Any | None = None, + frame_callback: Callable[[np.ndarray, float], None] | None = None, ) -> None: self._provider = provider self._config = config self._mode = mode self._robot = robot + self._frame_callback = frame_callback self._producer: _VideoProducer | None = None self._stopped = False @@ -82,6 +84,13 @@ def __init__( def enabled(self) -> bool: return self._config.enabled + @property + def pushed_frames(self) -> int: + producer = self._producer + if producer is None: + return 0 + return int(getattr(producer, "pushed_frames", 0)) + def start(self) -> None: if not self._config.enabled: return @@ -89,16 +98,16 @@ def start(self) -> None: if self._config.source == "test-pattern": logger.info("Pico video enabled via pico-bridge test-pattern source") return - if not callable(getattr(self._provider, "push_video_frame", None)): + if self._frame_callback is None and not callable(getattr(self._provider, "push_video_frame", None)): self._handle_error(RuntimeError("Pico input provider does not support push_video_frame")) return producer: _VideoProducer | None = None try: if self._config.source == "realsense": - producer = _RealSenseVideoProducer(self._provider, self._config) + producer = _RealSenseVideoProducer(self._provider, self._config, self._frame_callback) elif self._config.source == "mujoco": - producer = _MujocoCameraVideoProducer(self._provider, self._config, self._robot) + producer = _MujocoCameraVideoProducer(self._provider, self._config, self._robot, self._frame_callback) else: raise ValueError(f"Unsupported Pico video source: {self._config.source!r}") self._producer = producer @@ -145,13 +154,24 @@ def stop(self) -> None: ... class _RealSenseVideoProducer(_VideoProducer): - def __init__(self, provider: Any, config: PicoVideoConfig) -> None: + def __init__( + self, + provider: Any, + config: PicoVideoConfig, + frame_callback: Callable[[np.ndarray, float], None] | None = None, + ) -> None: self._provider = provider self._config = config + self._frame_callback = frame_callback self._stop_event = threading.Event() self._ready_event = threading.Event() self._thread = threading.Thread(target=self._run, name="pico_realsense_video", daemon=True) self._error: BaseException | None = None + self._pushed_frames = 0 + + @property + def pushed_frames(self) -> int: + return int(self._pushed_frames) def start(self) -> None: self._thread.start() @@ -194,7 +214,13 @@ def _run(self) -> None: if not color_frame: continue rgb = np.ascontiguousarray(np.asanyarray(color_frame.get_data()), dtype=np.uint8) - self._provider.push_video_frame(rgb) + timestamp_s = time.monotonic() + if self._frame_callback is not None: + self._frame_callback(rgb, timestamp_s) + if callable(getattr(self._provider, "push_video_frame", None)): + self._pushed_frames = int(self._provider.push_video_frame(rgb)) + else: + self._pushed_frames += 1 finally: pipeline.stop() except BaseException as exc: @@ -204,13 +230,25 @@ def _run(self) -> None: class _MujocoCameraVideoProducer(_VideoProducer): - def __init__(self, provider: Any, config: PicoVideoConfig, robot: Any | None) -> None: + def __init__( + self, + provider: Any, + config: PicoVideoConfig, + robot: Any | None, + frame_callback: Callable[[np.ndarray, float], None] | None = None, + ) -> None: self._provider = provider self._config = config self._robot = robot + self._frame_callback = frame_callback self._renderer: Any | None = None self._next_frame_time = 0.0 self._camera_name = "d435i_rgb" + self._pushed_frames = 0 + + @property + def pushed_frames(self) -> int: + return int(self._pushed_frames) def start(self) -> None: if self._robot is None: @@ -236,7 +274,13 @@ def tick(self) -> None: raise RuntimeError("MuJoCo Pico video requires robot.data") self._renderer.update_scene(data, camera=self._camera_name) frame = np.ascontiguousarray(self._renderer.render(), dtype=np.uint8) - self._provider.push_video_frame(frame) + timestamp_s = time.monotonic() + if self._frame_callback is not None: + self._frame_callback(frame, timestamp_s) + if callable(getattr(self._provider, "push_video_frame", None)): + self._pushed_frames = int(self._provider.push_video_frame(frame)) + else: + self._pushed_frames += 1 self._next_frame_time = now + 1.0 / float(self._config.fps) def stop(self) -> None: diff --git a/teleopit/inputs/realtime_packet.py b/teleopit/inputs/realtime_packet.py index e10cbd4b..5299fe03 100644 --- a/teleopit/inputs/realtime_packet.py +++ b/teleopit/inputs/realtime_packet.py @@ -16,6 +16,7 @@ class ControlEventType(str, Enum): TOGGLE_PAUSE = "toggle_pause" + TOGGLE_ARMS = "toggle_arms" @dataclass(frozen=True) diff --git a/teleopit/interfaces.py b/teleopit/interfaces.py index c115906f..68c79ffc 100644 --- a/teleopit/interfaces.py +++ b/teleopit/interfaces.py @@ -122,15 +122,6 @@ def subscribe(self, topic: str) -> Any: ... -@runtime_checkable -class Recorder(Protocol): - """Records teleoperation data.""" - - def add_frame(self, data: Dict[str, Any]) -> None: - """Record a single frame of data.""" - ... - - @runtime_checkable class ObservationBuilder(Protocol): """Builds observations for controller from robot state.""" diff --git a/teleopit/pipeline.py b/teleopit/pipeline.py index 5ff4b5f7..181a1912 100644 --- a/teleopit/pipeline.py +++ b/teleopit/pipeline.py @@ -10,16 +10,16 @@ from teleopit.controllers.rl_policy import RLPolicyController from teleopit.inputs import BVHInputProvider, Pico4InputProvider from teleopit.inputs.pico_video import PicoVideoRuntime, parse_pico_video_config -from teleopit.recording.hdf5_recorder import HDF5Recorder from teleopit.retargeting.core import RetargetingModule from teleopit.robots.mujoco_robot import MuJoCoRobot from teleopit.runtime.common import cfg_get +from teleopit.runtime.console import PlainConsole from teleopit.runtime.factory import build_inference_components from teleopit.sim.loop import SimulationLoop class TeleopPipeline: - def __init__(self, cfg: DictConfig | dict[str, Any]) -> None: + def __init__(self, cfg: DictConfig | dict[str, Any], *, console: PlainConsole | None = None) -> None: self.cfg = cfg self._project_root = Path(__file__).resolve().parent.parent components = build_inference_components( @@ -54,31 +54,12 @@ def __init__(self, cfg: DictConfig | dict[str, Any]) -> None: components.sim_cfg, viewers=components.viewers, video_runtime=self.video_runtime, + console=console, ) - def run(self, num_steps: int, record: bool = False) -> dict[str, float | int | str]: + def run(self, num_steps: int) -> dict[str, float | int | str]: if num_steps < 0: raise ValueError("num_steps must be non-negative (0 = infinite)") self.controller.reset() - - if not record: - return dict(self.loop.run(cast(Any, self.input_provider), cast(Any, self.retargeter), num_steps=num_steps)) - - rec_cfg = cast(Any, cfg_get(self.cfg, "recording", {})) - output_path = Path(str(cfg_get(rec_cfg, "output_path", "teleop_session.h5"))).expanduser() - if not output_path.is_absolute(): - output_path = (Path.cwd() / output_path).resolve() - output_path.parent.mkdir(parents=True, exist_ok=True) - - with HDF5Recorder(output_path) as recorder: - result = self.loop.run( - cast(Any, self.input_provider), - cast(Any, self.retargeter), - num_steps=num_steps, - recorder=cast(Any, recorder), - ) - - result_with_path: dict[str, float | int | str] = dict(result) - result_with_path["record_path"] = str(output_path) - return result_with_path + return dict(self.loop.run(cast(Any, self.input_provider), cast(Any, self.retargeter), num_steps=num_steps)) diff --git a/teleopit/recording/__init__.py b/teleopit/recording/__init__.py index 82c5b0c5..b8b60782 100644 --- a/teleopit/recording/__init__.py +++ b/teleopit/recording/__init__.py @@ -1,3 +1,19 @@ -from teleopit.recording.hdf5_recorder import HDF5Recorder +from teleopit.recording.pico_motion import ( + PicoDatasetSpec, + RecordingState, + ensure_pico_dataset_spec, + qpos_sequence_to_motion_clip, + sanitize_clip_name, + unique_clip_path, + write_motion_clip_npz, +) -__all__ = ["HDF5Recorder"] +__all__ = [ + "PicoDatasetSpec", + "RecordingState", + "ensure_pico_dataset_spec", + "qpos_sequence_to_motion_clip", + "sanitize_clip_name", + "unique_clip_path", + "write_motion_clip_npz", +] diff --git a/teleopit/recording/hdf5.py b/teleopit/recording/hdf5.py new file mode 100644 index 00000000..36849179 --- /dev/null +++ b/teleopit/recording/hdf5.py @@ -0,0 +1,508 @@ +"""HDF5 recorder and schema helpers for Teleopit sim2real recording.""" + +from __future__ import annotations + +from dataclasses import dataclass +import json +import logging +from pathlib import Path +import re +import time +from typing import Any + +import h5py +import numpy as np + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS +from teleopit.controllers.observation import _quat_rotate_np +from teleopit.math_utils import quat_inv_np +from teleopit.runtime.common import cfg_get + + +logger = logging.getLogger(__name__) + +IMAGE_KEY = "observation.images.d435i_rgb" +STATE_KEY = "observation.state" +MODE_KEY = "observation.mode" +ACTION_KEY = "action" +HAND_ACTION_KEY = "action.hand" +FRAME_INDEX_KEY = "frame_index" +TIMESTAMP_KEY = "timestamp" +STATE_DIM = 68 +MODE_DIM = 1 +ACTION_DIM = FULL_QPOS_DIM +HAND_ACTION_DIM = 12 +DEFAULT_IMAGE_SHAPE = (480, 640, 3) +HDF5_RECORDING_FORMAT = "teleopit_sim2real_recording_hdf5" +HDF5_RECORDING_VERSION = 1 +MODE_CODES = { + "standing": 0, + "mocap": 1, + "arms": 2, + "pause": 3, +} + + +@dataclass(frozen=True) +class RecordingSchema: + image_key: str + image_shape: tuple[int, int, int] + state_key: str = STATE_KEY + state_dim: int = STATE_DIM + mode_key: str = MODE_KEY + mode_dim: int = MODE_DIM + action_key: str = ACTION_KEY + action_dim: int = ACTION_DIM + hand_action_key: str = HAND_ACTION_KEY + hand_action_dim: int = HAND_ACTION_DIM + + +@dataclass(frozen=True) +class MP4VideoConfig: + codec: str = "libx264" + quality: int = 8 + pixelformat: str = "yuv420p" + + +def build_recording_schema(camera_cfg: Any) -> RecordingSchema: + key = str(cfg_get(camera_cfg, "key", IMAGE_KEY)) + width = int(cfg_get(camera_cfg, "width", DEFAULT_IMAGE_SHAPE[1])) + height = int(cfg_get(camera_cfg, "height", DEFAULT_IMAGE_SHAPE[0])) + if width <= 0 or height <= 0: + raise ValueError("recording.camera.width and recording.camera.height must be positive") + return RecordingSchema(image_key=key, image_shape=(height, width, 3)) + + +def build_mp4_video_config(video_cfg: Any) -> MP4VideoConfig: + quality = int(cfg_get(video_cfg, "quality", 8)) + if quality < 0 or quality > 10: + raise ValueError("recording.video.quality must be in [0, 10]") + return MP4VideoConfig( + codec=str(cfg_get(video_cfg, "codec", "libx264")), + quality=quality, + pixelformat=str(cfg_get(video_cfg, "pixelformat", "yuv420p")), + ) + + +def hdf5_schema( + schema: RecordingSchema, + *, + video_config: MP4VideoConfig | None = None, +) -> dict[str, object]: + video_cfg = video_config or MP4VideoConfig() + return { + "format": HDF5_RECORDING_FORMAT, + "version": HDF5_RECORDING_VERSION, + "features": { + schema.image_key: { + "type": "video", + "format": "mp4", + "codec": video_cfg.codec, + "shape": list(schema.image_shape), + "dtype": "uint8", + "sync": { + "frame_index": FRAME_INDEX_KEY, + "timestamp": TIMESTAMP_KEY, + }, + }, + FRAME_INDEX_KEY: { + "type": "index", + "shape": [], + "dtype": "int64", + }, + TIMESTAMP_KEY: { + "type": "timestamp", + "shape": [], + "dtype": "float64", + "units": "seconds", + }, + schema.state_key: { + "type": "low_dim", + "shape": [schema.state_dim], + "dtype": "float32", + "slices": { + "joint_pos": [0, 29], + "joint_vel": [29, 58], + "base_quat_wxyz": [58, 62], + "base_ang_vel": [62, 65], + "projected_gravity": [65, 68], + }, + }, + schema.mode_key: { + "type": "categorical", + "shape": [schema.mode_dim], + "dtype": "float32", + "codes": MODE_CODES, + }, + schema.action_key: { + "type": "low_dim", + "shape": [schema.action_dim], + "dtype": "float32", + "slices": { + "root_pos": [0, 3], + "root_quat_wxyz": [3, 7], + "joint_pos": [7, 36], + }, + }, + schema.hand_action_key: { + "type": "low_dim", + "shape": [schema.hand_action_dim], + "dtype": "float32", + "units": "linkerhand_uint8_pose", + "slices": { + "left_pose": [0, 6], + "right_pose": [6, 12], + }, + }, + }, + } + + +def build_observation_state(robot_state: object) -> np.ndarray: + joint_pos = np.asarray(getattr(robot_state, "qpos"), dtype=np.float32).reshape(-1)[:NUM_JOINTS] + joint_vel = np.asarray(getattr(robot_state, "qvel"), dtype=np.float32).reshape(-1)[:NUM_JOINTS] + base_quat = np.asarray(getattr(robot_state, "quat"), dtype=np.float32).reshape(-1)[:4] + base_ang_vel = np.asarray(getattr(robot_state, "ang_vel"), dtype=np.float32).reshape(-1)[:3] + if joint_pos.shape[0] != NUM_JOINTS: + raise ValueError(f"robot_state.qpos must contain {NUM_JOINTS} joints, got {joint_pos.shape[0]}") + if joint_vel.shape[0] != NUM_JOINTS: + raise ValueError(f"robot_state.qvel must contain {NUM_JOINTS} joints, got {joint_vel.shape[0]}") + if base_quat.shape[0] != 4: + raise ValueError(f"robot_state.quat must be 4D (wxyz), got {base_quat.shape[0]}") + if base_ang_vel.shape[0] != 3: + raise ValueError(f"robot_state.ang_vel must be 3D, got {base_ang_vel.shape[0]}") + gravity_w = np.array([0.0, 0.0, -1.0], dtype=np.float32) + projected_gravity = _quat_rotate_np(quat_inv_np(base_quat), gravity_w) + state = np.concatenate( + [joint_pos, joint_vel, base_quat, base_ang_vel, projected_gravity], + dtype=np.float32, + ) + if state.shape[0] != STATE_DIM: + raise ValueError(f"recording observation.state must be {STATE_DIM}D, got {state.shape[0]}") + return state + + +def normalize_action_reference_qpos(reference_qpos: object) -> np.ndarray: + action = np.asarray(reference_qpos, dtype=np.float32).reshape(-1)[:ACTION_DIM] + if action.shape[0] != ACTION_DIM: + raise ValueError(f"recording action reference qpos must be {ACTION_DIM}D, got {action.shape[0]}") + return action + + +def normalize_hand_action(left_pose: object, right_pose: object) -> np.ndarray: + left = np.asarray(left_pose, dtype=np.float32).reshape(-1) + right = np.asarray(right_pose, dtype=np.float32).reshape(-1) + if left.shape[0] != 6: + raise ValueError(f"recording left hand pose must be 6D, got {left.shape[0]}") + if right.shape[0] != 6: + raise ValueError(f"recording right hand pose must be 6D, got {right.shape[0]}") + action = np.concatenate([left, right], dtype=np.float32) + if action.shape[0] != HAND_ACTION_DIM: + raise ValueError(f"recording action.hand must be {HAND_ACTION_DIM}D, got {action.shape[0]}") + return action + + +def build_mode_observation(mode: str) -> np.ndarray: + normalized = str(mode).strip().lower() + if normalized not in MODE_CODES: + raise ValueError(f"Unsupported recording mode {mode!r}; expected one of {sorted(MODE_CODES)}") + return np.array([MODE_CODES[normalized]], dtype=np.float32) + + +class TeleopitHDF5Recorder: + """Writes one HDF5 file per saved sim2real recording episode.""" + + def __init__( + self, + *, + output_dir: Path, + task: str, + fps: int, + schema: RecordingSchema, + video_config: MP4VideoConfig | None = None, + ) -> None: + self._output_dir = output_dir + self._task = str(task) + self._fps = int(fps) + self._schema = schema + self._video_config = video_config or MP4VideoConfig() + self._active = False + self._frames_in_episode = 0 + self._episode_index = 0 + self._h5: h5py.File | None = None + self._tmp_path: Path | None = None + self._episode_path: Path | None = None + self._tmp_video_path: Path | None = None + self._episode_video_path: Path | None = None + self._video_rel_path: str | None = None + self._video_writer: Any | None = None + self._datasets: dict[str, h5py.Dataset] = {} + + @classmethod + def create( + cls, + *, + output_dir: str | Path, + task: str, + fps: int, + schema: RecordingSchema, + video_config: MP4VideoConfig | None = None, + ) -> "TeleopitHDF5Recorder": + root = Path(output_dir) + root.mkdir(parents=True, exist_ok=True) + recorder = cls( + output_dir=root, + task=task, + fps=fps, + schema=schema, + video_config=video_config, + ) + recorder._write_schema_sidecar() + return recorder + + def start_episode(self) -> None: + if self._active: + raise RuntimeError("Cannot start a new recording episode while one is active") + self._episode_index += 1 + timestamp = time.strftime("%Y%m%d_%H%M%S") + stem = f"episode_{timestamp}_{time.time_ns()}_{self._episode_index:06d}" + tmp_dir = self._output_dir / ".tmp" + episodes_dir = self._output_dir / "episodes" + tmp_dir.mkdir(parents=True, exist_ok=True) + episodes_dir.mkdir(parents=True, exist_ok=True) + self._tmp_path = tmp_dir / f"{stem}.h5" + self._episode_path = episodes_dir / f"{stem}.h5" + self._h5 = h5py.File(self._tmp_path, "w") + video_dir = self._output_dir / "videos" / _safe_path_component(self._schema.image_key) + tmp_video_dir = tmp_dir / "videos" / _safe_path_component(self._schema.image_key) + video_dir.mkdir(parents=True, exist_ok=True) + tmp_video_dir.mkdir(parents=True, exist_ok=True) + self._tmp_video_path = tmp_video_dir / f"{stem}.mp4" + self._episode_video_path = video_dir / f"{stem}.mp4" + self._video_rel_path = self._episode_video_path.relative_to(self._output_dir).as_posix() + try: + self._video_writer = self._create_video_writer(self._tmp_video_path) + self._write_episode_header(self._h5) + self._datasets = self._create_datasets(self._h5) + self._active = True + self._frames_in_episode = 0 + except Exception: + self._cleanup_partial_episode() + raise + + def add_frame( + self, + *, + image: np.ndarray, + state: np.ndarray, + mode: np.ndarray, + action: np.ndarray, + hand_action: np.ndarray, + task: str, + ) -> None: + if not self._active or self._h5 is None: + raise RuntimeError("Cannot add a recording frame without an active episode") + image_arr = np.asarray(image, dtype=np.uint8) + if tuple(image_arr.shape) != self._schema.image_shape: + raise ValueError(f"{self._schema.image_key} frame shape {image_arr.shape} != {self._schema.image_shape}") + state_arr = self._validate_vector(state, self._schema.state_key, self._schema.state_dim) + mode_arr = self._validate_vector(mode, self._schema.mode_key, self._schema.mode_dim) + action_arr = self._validate_vector(action, self._schema.action_key, self._schema.action_dim) + hand_action_arr = self._validate_vector(hand_action, self._schema.hand_action_key, self._schema.hand_action_dim) + + row = self._frames_in_episode + for dataset in self._datasets.values(): + dataset.resize((row + 1, *dataset.shape[1:])) + if self._video_writer is None: + raise RuntimeError("MP4 recording writer is not open") + self._video_writer.append_data(image_arr) + self._datasets[FRAME_INDEX_KEY][row] = row + self._datasets[TIMESTAMP_KEY][row] = float(row) / float(self._fps) + self._datasets[self._schema.state_key][row] = state_arr + self._datasets[self._schema.mode_key][row] = mode_arr + self._datasets[self._schema.action_key][row] = action_arr + self._datasets[self._schema.hand_action_key][row] = hand_action_arr + self._frames_in_episode += 1 + self._h5.attrs["frames"] = self._frames_in_episode + self._h5.attrs["task"] = str(task) + + def save_episode(self) -> None: + if not self._active: + return + tmp_path = self._require_tmp_path() + episode_path = self._require_episode_path() + tmp_video_path = self._tmp_video_path + episode_video_path = self._episode_video_path + self._close_active_outputs() + if tmp_video_path is None or episode_video_path is None: + raise RuntimeError("recording episode has no video output path") + tmp_video_path.replace(episode_video_path) + tmp_path.replace(episode_path) + self._reset_episode() + + def discard_episode(self) -> None: + if not self._active: + return + tmp_path = self._tmp_path + tmp_video_path = self._tmp_video_path + self._close_active_outputs() + if tmp_path is not None and tmp_path.exists(): + tmp_path.unlink() + if tmp_video_path is not None and tmp_video_path.exists(): + tmp_video_path.unlink() + self._reset_episode() + + def finalize(self) -> None: + if self._active: + self.discard_episode() + + def _write_schema_sidecar(self) -> None: + path = self._output_dir / "schema.json" + path.write_text(json.dumps(self._schema_dict(), indent=2) + "\n", encoding="utf-8") + + def _write_episode_header(self, h5: h5py.File) -> None: + h5.attrs["format"] = HDF5_RECORDING_FORMAT + h5.attrs["version"] = HDF5_RECORDING_VERSION + h5.attrs["task"] = self._task + h5.attrs["fps"] = self._fps + h5.attrs["frames"] = 0 + h5.attrs["schema_json"] = json.dumps(self._schema_dict(), sort_keys=True) + h5.attrs["video_key"] = self._schema.image_key + h5.attrs["video_path"] = self._video_rel_path or "" + h5.attrs["video_format"] = "mp4" + h5.attrs["video_codec"] = self._video_config.codec + h5.attrs["video_pixelformat"] = self._video_config.pixelformat + h5.attrs["video_fps"] = self._fps + h5.attrs["video_frames"] = 0 + h5.attrs["video_from_timestamp_s"] = 0.0 + h5.attrs["video_to_timestamp_s"] = 0.0 + + def _create_datasets(self, h5: h5py.File) -> dict[str, h5py.Dataset]: + return { + FRAME_INDEX_KEY: h5.create_dataset( + FRAME_INDEX_KEY, + shape=(0,), + maxshape=(None,), + chunks=(1024,), + dtype=np.int64, + ), + TIMESTAMP_KEY: h5.create_dataset( + TIMESTAMP_KEY, + shape=(0,), + maxshape=(None,), + chunks=(1024,), + dtype=np.float64, + ), + self._schema.state_key: self._create_vector_dataset(h5, self._schema.state_key, self._schema.state_dim), + self._schema.mode_key: self._create_vector_dataset(h5, self._schema.mode_key, self._schema.mode_dim), + self._schema.action_key: self._create_vector_dataset(h5, self._schema.action_key, self._schema.action_dim), + self._schema.hand_action_key: self._create_vector_dataset( + h5, self._schema.hand_action_key, self._schema.hand_action_dim + ), + } + + @staticmethod + def _create_vector_dataset(h5: h5py.File, key: str, dim: int) -> h5py.Dataset: + return h5.create_dataset( + key, + shape=(0, dim), + maxshape=(None, dim), + chunks=(1024, dim), + dtype=np.float32, + ) + + @staticmethod + def _validate_vector(value: object, key: str, dim: int) -> np.ndarray: + arr = np.asarray(value, dtype=np.float32).reshape(-1) + if arr.shape[0] != dim: + raise ValueError(f"{key} must be {dim}D") + return arr + + def _schema_dict(self) -> dict[str, object]: + return hdf5_schema( + self._schema, + video_config=self._video_config, + ) + + def _create_video_writer(self, path: Path) -> Any: + try: + import imageio.v2 as imageio + except Exception as exc: + raise RuntimeError("MP4 recording requires imageio[ffmpeg]") from exc + return imageio.get_writer( + str(path), + fps=self._fps, + codec=self._video_config.codec, + quality=self._video_config.quality, + macro_block_size=1, + pixelformat=self._video_config.pixelformat, + ) + + def _close_active_outputs(self) -> None: + if self._video_writer is not None: + self._video_writer.close() + self._video_writer = None + if self._h5 is not None: + self._h5.attrs["video_frames"] = self._frames_in_episode + self._h5.attrs["video_to_timestamp_s"] = ( + float(max(self._frames_in_episode - 1, 0)) / float(self._fps) + if self._frames_in_episode > 0 + else 0.0 + ) + self._h5.attrs["video_path"] = self._video_rel_path or "" + self._h5.close() + self._h5 = None + self._datasets = {} + + def _reset_episode(self) -> None: + self._active = False + self._frames_in_episode = 0 + self._tmp_path = None + self._episode_path = None + self._tmp_video_path = None + self._episode_video_path = None + self._video_rel_path = None + self._video_writer = None + + def _cleanup_partial_episode(self) -> None: + if self._video_writer is not None: + try: + self._video_writer.close() + except Exception: + logger.exception("Failed to close partial MP4 recording writer") + self._video_writer = None + if self._h5 is not None: + try: + self._h5.close() + except Exception: + logger.exception("Failed to close partial HDF5 recording file") + self._h5 = None + tmp_path = self._tmp_path + tmp_video_path = self._tmp_video_path + if tmp_path is not None and tmp_path.exists(): + try: + tmp_path.unlink() + except Exception: + logger.exception("Failed to remove partial HDF5 recording file: %s", tmp_path) + if tmp_video_path is not None and tmp_video_path.exists(): + try: + tmp_video_path.unlink() + except Exception: + logger.exception("Failed to remove partial MP4 recording file: %s", tmp_video_path) + self._datasets = {} + self._reset_episode() + + def _require_tmp_path(self) -> Path: + if self._tmp_path is None: + raise RuntimeError("recording episode has no temporary path") + return self._tmp_path + + def _require_episode_path(self) -> Path: + if self._episode_path is None: + raise RuntimeError("recording episode has no output path") + return self._episode_path + + +def _safe_path_component(value: str) -> str: + safe = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") + return safe or "camera" diff --git a/teleopit/recording/hdf5_recorder.py b/teleopit/recording/hdf5_recorder.py deleted file mode 100644 index b9138ec1..00000000 --- a/teleopit/recording/hdf5_recorder.py +++ /dev/null @@ -1,114 +0,0 @@ -"""HDF5-based data recorder for teleoperation sessions.""" -from __future__ import annotations - -import time -from pathlib import Path -from typing import Any, Dict - -import h5py -import numpy as np - - -class HDF5Recorder: - """Records teleoperation data to HDF5 format with chunked storage and compression. - - Supports numerical data fields: joint_pos, joint_vel, mimic_obs, action, timestamp. - First call to add_frame creates datasets based on data keys/shapes (resizable). - Subsequent calls resize and append data. - """ - - def __init__(self, path: str | Path, chunk_size: int = 100): - """Initialize HDF5 recorder. - - Args: - path: Path to HDF5 file to create - chunk_size: Chunk size for HDF5 datasets (affects compression/performance) - """ - self.path = Path(path) - self.chunk_size = chunk_size - self.file: h5py.File | None = None - self.datasets: Dict[str, h5py.Dataset] = {} - self.frame_count = 0 - self.start_time = time.time() - self._initialized = False - - def __enter__(self) -> HDF5Recorder: - """Context manager entry.""" - self.file = h5py.File(self.path, 'w') - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - """Context manager exit.""" - self.close() - - def add_frame(self, data: Dict[str, np.ndarray]) -> None: - """Add a frame of data to the recording. - - First call creates datasets based on data keys/shapes. - Subsequent calls resize and append data. - - Args: - data: Dictionary mapping field names to numpy arrays - Supported fields: joint_pos, joint_vel, mimic_obs, action, timestamp - """ - if self.file is None: - raise RuntimeError("Recorder not opened. Use context manager or call __enter__") - - if not self._initialized: - self._create_datasets(data) - self._initialized = True - - # Resize and write to each dataset - for key, value in data.items(): - if key not in self.datasets: - raise ValueError(f"Unexpected key '{key}' not in initial frame") - - dataset = self.datasets[key] - # Resize to accommodate new frame - dataset.resize(self.frame_count + 1, axis=0) - # Write data - dataset[self.frame_count] = value - - self.frame_count += 1 - - def _create_datasets(self, data: Dict[str, np.ndarray]) -> None: - """Create resizable HDF5 datasets based on first frame data. - - Args: - data: First frame of data defining dataset structure - """ - for key, value in data.items(): - arr = np.asarray(value) - shape = arr.shape - dtype = arr.dtype - - # Create resizable dataset with chunking and compression - chunk_shape = (self.chunk_size,) + shape - max_shape = (None,) + shape # Unlimited first dimension - - self.datasets[key] = self.file.create_dataset( - key, - shape=(0,) + shape, - maxshape=max_shape, - chunks=chunk_shape, - dtype=dtype, - compression='gzip', - compression_opts=4 - ) - - def close(self) -> None: - """Close the recorder and write metadata.""" - if self.file is None: - return - - # Write metadata - end_time = time.time() - recording_duration = end_time - self.start_time - - self.file.attrs['total_frames'] = self.frame_count - self.file.attrs['recording_time'] = recording_duration - self.file.attrs['start_time'] = self.start_time - self.file.attrs['end_time'] = end_time - - self.file.close() - self.file = None diff --git a/teleopit/recording/pico_motion.py b/teleopit/recording/pico_motion.py new file mode 100644 index 00000000..77d9509e --- /dev/null +++ b/teleopit/recording/pico_motion.py @@ -0,0 +1,232 @@ +"""Helpers for recording Pico-retargeted G1 motion clips.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +import re +import threading +import time +from typing import Any, Iterable + +import numpy as np + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.runtime.assets import PROJECT_ROOT +from train_mimic.data.dataset_lib import inspect_clip_dict +from train_mimic.data.motion_fk import MotionFkExtractor, compute_body_velocities, finite_diff_velocity +from train_mimic.scripts.convert_pkl_to_npz import _MJLAB_G1_BODY_NAMES + + +@dataclass(frozen=True) +class PicoDatasetSpec: + """Dataset spec defaults for Pico-recorded NPZ clips.""" + + dataset_name: str = "pico_recorded" + target_fps: int = 30 + source_name: str = "pico_clips" + + +class RecordingState: + """Thread-safe state shared by terminal UI and retarget worker.""" + + def __init__(self, clip_name: str | None = None) -> None: + self._lock = threading.Lock() + self.clip_name = clip_name + self.recording = False + self.qpos_buffer: list[np.ndarray] = [] + self.record_start_s: float | None = None + + def set_clip_name(self, clip_name: str) -> None: + with self._lock: + if self.recording: + raise RuntimeError("cannot change clip name while recording") + self.clip_name = clip_name + self.qpos_buffer.clear() + self.record_start_s = None + + def start(self) -> str: + with self._lock: + if self.clip_name is None: + raise RuntimeError("clip name must be set before recording") + if self.recording: + return self.clip_name + self.qpos_buffer.clear() + self.recording = True + self.record_start_s = time.monotonic() + return self.clip_name + + def discard(self) -> tuple[str | None, int]: + with self._lock: + clip_name = self.clip_name + frame_count = len(self.qpos_buffer) + self.recording = False + self.qpos_buffer.clear() + self.record_start_s = None + return clip_name, frame_count + + def snapshot(self) -> tuple[str | None, bool, list[np.ndarray]]: + with self._lock: + clip_name = self.clip_name + recording = self.recording + frames = [frame.copy() for frame in self.qpos_buffer] + return clip_name, recording, frames + + def mark_saved(self) -> None: + with self._lock: + self.recording = False + self.qpos_buffer.clear() + self.record_start_s = None + + def append(self, qpos: np.ndarray) -> None: + with self._lock: + if self.recording: + self.qpos_buffer.append(qpos.copy()) + + def status(self) -> tuple[str | None, bool, int, float | None]: + with self._lock: + elapsed = None if self.record_start_s is None else time.monotonic() - self.record_start_s + return self.clip_name, self.recording, len(self.qpos_buffer), elapsed + + +def sanitize_clip_name(raw_name: str) -> str: + """Return a filesystem-friendly semantic clip name.""" + name = raw_name.strip().lower() + name = re.sub(r"\s+", "_", name) + name = re.sub(r"[^a-z0-9_.-]+", "_", name) + name = re.sub(r"_+", "_", name).strip("._-") + if not name: + raise ValueError("clip name must contain at least one letter or digit") + return name + + +def timestamp_suffix(now: datetime | None = None) -> str: + current = now or datetime.now() + return current.strftime("%Y%m%d_%H%M%S") + + +def unique_clip_path(output_dir: str | Path, clip_name: str, *, now: datetime | None = None) -> Path: + """Build a non-overwriting NPZ path for one semantic clip.""" + out_dir = Path(output_dir).expanduser() + safe_name = sanitize_clip_name(clip_name) + stem = f"{safe_name}_{timestamp_suffix(now)}" + candidate = out_dir / f"{stem}.npz" + if not candidate.exists(): + return candidate + for index in range(1, 1000): + candidate = out_dir / f"{stem}_{index:03d}.npz" + if not candidate.exists(): + return candidate + raise RuntimeError(f"could not allocate a unique clip path in {out_dir}") + + +def _display_path(path: Path, *, project_root: Path = PROJECT_ROOT) -> str: + try: + return path.resolve().relative_to(project_root.resolve()).as_posix() + except ValueError: + return str(path) + + +def ensure_pico_dataset_spec( + spec_path: str | Path, + clips_dir: str | Path, + *, + spec: PicoDatasetSpec = PicoDatasetSpec(), + overwrite: bool = False, +) -> Path: + """Create the dataset YAML spec used to merge recorded Pico clips. + + Existing specs are preserved by default so hand-edited settings are not lost. + """ + path = Path(spec_path).expanduser() + if path.exists() and not overwrite: + return path + + clips_path = Path(clips_dir).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + clips_display = _display_path(clips_path) + content = ( + f"name: {spec.dataset_name}\n" + f"target_fps: {int(spec.target_fps)}\n" + "preprocess:\n" + " normalize_root_xy: true\n" + " ground_align: first_frame_foot\n" + "sources:\n" + f" - name: {spec.source_name}\n" + " type: npz\n" + f" input: {clips_display}\n" + ) + path.write_text(content, encoding="utf-8") + return path + + +def qpos_sequence_to_motion_clip( + qpos_sequence: Iterable[np.ndarray] | np.ndarray, + *, + fps: int, + extractor: Any | None = None, + body_names: list[str] | None = None, +) -> dict[str, Any]: + """Convert retargeted G1 qpos frames into a standard training motion clip.""" + qpos = np.asarray(list(qpos_sequence) if not isinstance(qpos_sequence, np.ndarray) else qpos_sequence, dtype=np.float32) + if qpos.ndim != 2 or qpos.shape[1] != FULL_QPOS_DIM: + raise ValueError(f"qpos_sequence must have shape (T,{FULL_QPOS_DIM}), got {qpos.shape}") + if qpos.shape[0] < 2: + raise ValueError("qpos_sequence must contain at least 2 frames") + if int(fps) <= 0: + raise ValueError(f"fps must be > 0, got {fps}") + if not np.isfinite(qpos).all(): + raise ValueError("qpos_sequence contains NaN/Inf") + + names = list(_MJLAB_G1_BODY_NAMES if body_names is None else body_names) + root_pos = qpos[:, 0:3].astype(np.float32, copy=False) + root_quat_wxyz = qpos[:, 3:7].astype(np.float32, copy=False) + joint_pos = qpos[:, ROOT_DIM:ROOT_DIM + NUM_JOINTS].astype(np.float32, copy=False) + if joint_pos.shape[1] != NUM_JOINTS: + raise ValueError(f"joint_pos must have {NUM_JOINTS} columns, got {joint_pos.shape}") + + dt = 1.0 / float(fps) + joint_vel = finite_diff_velocity(joint_pos, dt) + + fk_extractor = extractor or MotionFkExtractor() + body_pos_w, body_quat_w = fk_extractor.extract(root_pos, root_quat_wxyz, joint_pos, names) + body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) + + clip = { + "fps": int(fps), + "root_pos": root_pos.astype(np.float32, copy=False), + "root_quat_w": root_quat_wxyz.astype(np.float32, copy=False), + "joint_pos": joint_pos.astype(np.float32, copy=False), + "joint_vel": joint_vel.astype(np.float32, copy=False), + "body_pos_w": np.asarray(body_pos_w, dtype=np.float32), + "body_quat_w": np.asarray(body_quat_w, dtype=np.float32), + "body_lin_vel_w": np.asarray(body_lin_vel_w, dtype=np.float32), + "body_ang_vel_w": np.asarray(body_ang_vel_w, dtype=np.float32), + "body_names": np.asarray(names, dtype=str), + } + inspect_clip_dict(clip) + return clip + + +def write_motion_clip_npz( + output_path: str | Path, + qpos_sequence: Iterable[np.ndarray] | np.ndarray, + *, + fps: int, + extractor: Any | None = None, +) -> Path: + """Write a retargeted qpos sequence as an atomically replaced motion NPZ.""" + path = Path(output_path).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + clip = qpos_sequence_to_motion_clip(qpos_sequence, fps=fps, extractor=extractor) + tmp_path = path.with_name(f"{path.name}.tmp") + try: + with tmp_path.open("wb") as handle: + np.savez(handle, **clip) + path_tmp = tmp_path + path_tmp.replace(path) + finally: + if tmp_path.exists(): + tmp_path.unlink() + return path diff --git a/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json b/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json index 673bbc1e..75568756 100644 --- a/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json +++ b/teleopit/retargeting/gmr/ik_configs/pico_bridge_to_g1.json @@ -25,10 +25,10 @@ "pelvis": ["Pelvis", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_hip_yaw_link": ["Left_Hip", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_knee_link": ["Left_Knee", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "left_toe_link": ["Left_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "left_foot": ["Left_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, -0.5], "site"], "right_hip_yaw_link": ["Right_Hip", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "right_knee_link": ["Right_Knee", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "right_toe_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "right_foot": ["Right_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, -0.5], "site"], "torso_link": ["Spine3", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_shoulder_yaw_link": ["Left_Shoulder", 0, 10, [0.0, 0.0, 0.0], [0.7071067811865475, 0.0, 0.7071067811865475, 0.0]], "left_elbow_link": ["Left_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], @@ -41,10 +41,10 @@ "pelvis": ["Pelvis", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_hip_yaw_link": ["Left_Hip", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_knee_link": ["Left_Knee", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "left_toe_link": ["Left_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "left_foot": ["Left_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, -0.5], "site"], "right_hip_yaw_link": ["Right_Hip", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "right_knee_link": ["Right_Knee", 10, 5, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], - "right_toe_link": ["Right_Foot", 100, 10, [0.05, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], + "right_foot": ["Right_Foot", 100, 10, [-0.01, 0.0, -0.015], [-0.5, 0.5, -0.5, -0.5], "site"], "torso_link": ["Spine3", 0, 10, [0.0, 0.0, 0.0], [-0.5, 0.5, -0.5, -0.5]], "left_shoulder_yaw_link": ["Left_Shoulder", 0, 10, [0.0, 0.0, 0.0], [0.7071067811865475, 0.0, 0.7071067811865475, 0.0]], "left_elbow_link": ["Left_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], @@ -53,4 +53,4 @@ "right_elbow_link": ["Right_Elbow", 0, 10, [0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], "right_wrist_yaw_link": ["Right_Wrist", 0, 10, [0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] } -} \ No newline at end of file +} diff --git a/teleopit/retargeting/gmr/motion_retarget.py b/teleopit/retargeting/gmr/motion_retarget.py index 28508e13..c97400f9 100644 --- a/teleopit/retargeting/gmr/motion_retarget.py +++ b/teleopit/retargeting/gmr/motion_retarget.py @@ -7,6 +7,13 @@ from .params import ROBOT_XML_DICT, IK_CONFIG_DICT from rich import print +_FRAME_TYPE_TO_MJ_OBJ = { + "body": mj.mjtObj.mjOBJ_BODY, + "geom": mj.mjtObj.mjOBJ_GEOM, + "site": mj.mjtObj.mjOBJ_SITE, +} + + class GeneralMotionRetargeting: """General Motion Retargeting (GMR). """ @@ -28,8 +35,9 @@ def __init__( self.model = mj.MjModel.from_xml_path(self.xml_file) # Print DoF names in order - print("[GMR] Robot Degrees of Freedom (DoF) names and their order:") self.robot_dof_names = {} + if verbose: + print("[GMR] Robot Degrees of Freedom (DoF) names and their order:") for i in range(self.model.nv): # 'nv' is the number of DoFs dof_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, self.model.dof_jntid[i]) self.robot_dof_names[dof_name] = i @@ -37,16 +45,18 @@ def __init__( print(f"DoF {i}: {dof_name}") - print("[GMR] Robot Body names and their IDs:") self.robot_body_names = {} + if verbose: + print("[GMR] Robot Body names and their IDs:") for i in range(self.model.nbody): # 'nbody' is the number of bodies body_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, i) self.robot_body_names[body_name] = i if verbose: print(f"Body ID {i}: {body_name}") - print("[GMR] Robot Motor (Actuator) names and their IDs:") self.robot_motor_names = {} + if verbose: + print("[GMR] Robot Motor (Actuator) names and their IDs:") for i in range(self.model.nu): # 'nu' is the number of actuators (motors) motor_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_ACTUATOR, i) self.robot_motor_names[motor_name] = i @@ -107,6 +117,47 @@ def __init__( self._warmup_max_iter = 200 self._warmup_dt = 0.1 # large integration step for fast convergence during warmup + def _parse_ik_entry(self, entry): + if len(entry) == 5: + body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry + frame_type = "body" + elif len(entry) == 6: + body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type = entry + else: + raise ValueError( + "IK config entries must be [human_body, pos_weight, rot_weight, " + "pos_offset, rot_offset] or the same list plus frame_type" + ) + frame_type = str(frame_type) + if frame_type not in _FRAME_TYPE_TO_MJ_OBJ: + supported = ", ".join(sorted(_FRAME_TYPE_TO_MJ_OBJ)) + raise ValueError(f"Unsupported IK frame_type '{frame_type}'. Supported values: {supported}") + return body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type + + def _available_frame_names(self, frame_type): + obj_type = _FRAME_TYPE_TO_MJ_OBJ[frame_type] + count = { + "body": self.model.nbody, + "geom": self.model.ngeom, + "site": self.model.nsite, + }[frame_type] + names = [] + for idx in range(count): + name = mj.mj_id2name(self.model, obj_type, idx) + if name: + names.append(name) + return names + + def _validate_frame(self, frame_name, frame_type, table_name): + obj_type = _FRAME_TYPE_TO_MJ_OBJ[frame_type] + if mj.mj_name2id(self.model, obj_type, frame_name) >= 0: + return + available = ", ".join(self._available_frame_names(frame_type)) + raise ValueError( + f"IK config {table_name} references {frame_type} '{frame_name}', but it does not exist " + f"in robot model '{self.xml_file}'. Update the IK config to use one of: {available}" + ) + def reset_configuration(self): """Reset the IK configuration to the model's default qpos. @@ -128,11 +179,12 @@ def setup_retarget_configuration(self): self.tasks2 = [] for frame_name, entry in self.ik_match_table1.items(): - body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry + body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type = self._parse_ik_entry(entry) if pos_weight != 0 or rot_weight != 0: + self._validate_frame(frame_name, frame_type, "ik_match_table1") task = mink.FrameTask( frame_name=frame_name, - frame_type="body", + frame_type=frame_type, position_cost=pos_weight, orientation_cost=rot_weight, lm_damping=1, @@ -146,11 +198,12 @@ def setup_retarget_configuration(self): self.task_errors1[task] = [] for frame_name, entry in self.ik_match_table2.items(): - body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry + body_name, pos_weight, rot_weight, pos_offset, rot_offset, frame_type = self._parse_ik_entry(entry) if pos_weight != 0 or rot_weight != 0: + self._validate_frame(frame_name, frame_type, "ik_match_table2") task = mink.FrameTask( frame_name=frame_name, - frame_type="body", + frame_type=frame_type, position_cost=pos_weight, orientation_cost=rot_weight, lm_damping=1, diff --git a/teleopit/retargeting/gmr/params.py b/teleopit/retargeting/gmr/params.py index 7c87c30d..1110eef0 100644 --- a/teleopit/retargeting/gmr/params.py +++ b/teleopit/retargeting/gmr/params.py @@ -1,5 +1,11 @@ from pathlib import Path +from teleopit.runtime.assets import ( + UNITREE_G1_AVP_O6_XML, + UNITREE_G1_DEX3_XML, + UNITREE_G1_XML, +) + BASE_DIR = Path(__file__).parent @@ -11,8 +17,9 @@ def _resolve_path(relative_path): ASSET_ROOT = _resolve_path("assets") ROBOT_XML_DICT = { - "unitree_g1": _resolve_path("assets/unitree_g1/g1_mocap_29dof.xml"), - "unitree_g1_with_hands": _resolve_path("assets/unitree_g1/g1_mocap_29dof_with_hands.xml"), + "unitree_g1": UNITREE_G1_XML, + "unitree_g1_with_hands": UNITREE_G1_DEX3_XML, + "unitree_g1_avp_o6": UNITREE_G1_AVP_O6_XML, "unitree_h1": _resolve_path("assets/unitree_h1/h1.xml"), "unitree_h1_2": _resolve_path("assets/unitree_h1_2/h1_2_handless.xml"), "booster_t1": _resolve_path("assets/booster_t1/T1_serial.xml"), @@ -36,6 +43,7 @@ def _resolve_path(relative_path): "smplx": { "unitree_g1": _resolve_path("ik_configs/smplx_to_g1.json"), "unitree_g1_with_hands": _resolve_path("ik_configs/smplx_to_g1.json"), + "unitree_g1_avp_o6": _resolve_path("ik_configs/smplx_to_g1.json"), "unitree_h1": _resolve_path("ik_configs/smplx_to_h1.json"), "unitree_h1_2": _resolve_path("ik_configs/smplx_to_h1_2.json"), "booster_t1": _resolve_path("ik_configs/smplx_to_t1.json"), @@ -55,6 +63,7 @@ def _resolve_path(relative_path): "bvh_lafan1": { "unitree_g1": _resolve_path("ik_configs/bvh_lafan1_to_g1.json"), "unitree_g1_with_hands": _resolve_path("ik_configs/bvh_lafan1_to_g1.json"), + "unitree_g1_avp_o6": _resolve_path("ik_configs/bvh_lafan1_to_g1.json"), "booster_t1_29dof": _resolve_path("ik_configs/bvh_lafan1_to_t1_29dof.json"), "fourier_n1": _resolve_path("ik_configs/bvh_lafan1_to_n1.json"), "stanford_toddy": _resolve_path("ik_configs/bvh_lafan1_to_toddy.json"), @@ -74,6 +83,7 @@ def _resolve_path(relative_path): "fbx": { "unitree_g1": _resolve_path("ik_configs/fbx_to_g1.json"), "unitree_g1_with_hands": _resolve_path("ik_configs/fbx_to_g1.json"), + "unitree_g1_avp_o6": _resolve_path("ik_configs/fbx_to_g1.json"), }, "fbx_offline": { "unitree_g1": _resolve_path("ik_configs/fbx_offline_to_g1.json"), @@ -87,6 +97,7 @@ def _resolve_path(relative_path): ROBOT_BASE_DICT = { "unitree_g1": "pelvis", "unitree_g1_with_hands": "pelvis", + "unitree_g1_avp_o6": "pelvis", "unitree_h1": "pelvis", "unitree_h1_2": "pelvis", "booster_t1": "Waist", @@ -108,6 +119,7 @@ def _resolve_path(relative_path): VIEWER_CAM_DISTANCE_DICT = { "unitree_g1": 2.0, "unitree_g1_with_hands": 2.0, + "unitree_g1_avp_o6": 2.0, "unitree_h1": 3.0, "unitree_h1_2": 3.0, "booster_t1": 2.0, diff --git a/teleopit/robots/mujoco_robot.py b/teleopit/robots/mujoco_robot.py index b3283921..bdaaf5db 100644 --- a/teleopit/robots/mujoco_robot.py +++ b/teleopit/robots/mujoco_robot.py @@ -8,7 +8,11 @@ from omegaconf import DictConfig from teleopit.interfaces import RobotState -from teleopit.runtime.assets import GMR_ASSETS_ROOT, missing_gmr_assets_message +from teleopit.runtime.assets import ( + GMR_ASSETS_ROOT, + ROBOT_ASSETS_ROOT, + missing_gmr_assets_message, +) def _quat_conjugate(quat_wxyz: np.ndarray) -> np.ndarray: @@ -62,9 +66,11 @@ def __init__(self, cfg: DictConfig) -> None: xml_path = Path.cwd() / xml_path xml_path = xml_path.resolve() if not xml_path.exists(): - try: - xml_path.relative_to(GMR_ASSETS_ROOT) - except ValueError: + asset_roots = (ROBOT_ASSETS_ROOT, GMR_ASSETS_ROOT) + is_external_asset = any( + xml_path.is_relative_to(root) for root in asset_roots + ) + if not is_external_asset: raise FileNotFoundError(f"MuJoCo XML not found: {xml_path}") from None raise FileNotFoundError( missing_gmr_assets_message(xml_path, label="MuJoCo XML") diff --git a/teleopit/runtime/arm_mocap.py b/teleopit/runtime/arm_mocap.py new file mode 100644 index 00000000..3144749d --- /dev/null +++ b/teleopit/runtime/arm_mocap.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from teleopit.constants import ROOT_DIM +from teleopit.runtime.common import cfg_get +from teleopit.sim.reference_timeline import ReferenceSample, ReferenceWindow + + +Float64Array = NDArray[np.float64] + + +def parse_arm_joint_indices(cfg: Any, *, num_actions: int) -> NDArray[np.int64]: + arm_mocap_cfg = cfg_get(cfg, "arm_mocap", {}) or {} + raw = cfg_get(arm_mocap_cfg, "controlled_joint_indices", None) + default = list(range(15, num_actions)) if num_actions > 15 else [num_actions - 1] + indices = np.asarray(default if raw is None else raw, dtype=np.int64).reshape(-1) + if indices.size == 0 or np.any(indices < 0) or np.any(indices >= num_actions): + raise ValueError( + f"arm_mocap.controlled_joint_indices must contain valid joint indices in [0, {num_actions}), " + f"got {indices.tolist()}" + ) + if np.unique(indices).shape[0] != indices.shape[0]: + raise ValueError("arm_mocap.controlled_joint_indices must not contain duplicates") + return indices + + +def compose_arm_reference( + *, + standing_qpos: Float64Array, + retarget_qpos: Float64Array, + arm_joint_indices: NDArray[np.int64], + num_actions: int, +) -> Float64Array: + retarget = np.asarray(retarget_qpos, dtype=np.float64).reshape(-1) + if retarget.shape[0] < ROOT_DIM + num_actions: + raise ValueError(f"Retargeted qpos too short: {retarget.shape[0]} (need >= {ROOT_DIM + num_actions})") + composed = np.asarray(standing_qpos, dtype=np.float64).reshape(-1).copy() + joint_indices = ROOT_DIM + np.asarray(arm_joint_indices, dtype=np.int64).reshape(-1) + composed[joint_indices] = retarget[joint_indices] + return composed + + +def compose_arm_reference_window( + reference_window: ReferenceWindow | None, + *, + standing_qpos: Float64Array, + arm_joint_indices: NDArray[np.int64], + num_actions: int, +) -> ReferenceWindow | None: + if reference_window is None: + return None + samples = tuple( + ReferenceSample( + qpos=compose_arm_reference( + standing_qpos=standing_qpos, + retarget_qpos=sample.qpos, + arm_joint_indices=arm_joint_indices, + num_actions=num_actions, + ), + timestamp_s=float(sample.timestamp_s), + mode=str(sample.mode), + used_fallback=bool(sample.used_fallback), + older_timestamp_s=sample.older_timestamp_s, + newer_timestamp_s=sample.newer_timestamp_s, + alpha=sample.alpha, + ) + for sample in reference_window.samples + ) + return ReferenceWindow( + base_time_s=float(reference_window.base_time_s), + policy_dt_s=float(reference_window.policy_dt_s), + reference_steps=tuple(reference_window.reference_steps), + samples=samples, + ) diff --git a/teleopit/runtime/assets.py b/teleopit/runtime/assets.py index e9be0638..00d2e8e7 100644 --- a/teleopit/runtime/assets.py +++ b/teleopit/runtime/assets.py @@ -4,8 +4,12 @@ PROJECT_ROOT = Path(__file__).resolve().parents[2] +ROBOT_ASSETS_ROOT = PROJECT_ROOT / "assets" / "robots" GMR_ASSETS_ROOT = PROJECT_ROOT / "teleopit" / "retargeting" / "gmr" / "assets" -UNITREE_G1_MJLAB_XML = GMR_ASSETS_ROOT / "unitree_g1" / "g1_mjlab.xml" +UNITREE_G1_XML = ROBOT_ASSETS_ROOT / "unitree_g1" / "g1_29dof.xml" +UNITREE_G1_DEX3_XML = ROBOT_ASSETS_ROOT / "unitree_g1" / "g1_29dof_dex3.xml" +UNITREE_G1_AVP_O6_XML = ROBOT_ASSETS_ROOT / "unitree_g1" / "g1_29dof_avp_o6.xml" +UNITREE_G1_MJLAB_XML = UNITREE_G1_XML def missing_gmr_assets_message(path: str | Path, *, label: str = "Required asset") -> str: @@ -17,5 +21,5 @@ def missing_gmr_assets_message(path: str | Path, *, label: str = "Required asset return ( f"{label} not found: {resolved}\n" "Download the external robot assets with:\n" - " python scripts/setup/download_assets.py --only gmr" + " python scripts/setup/download_assets.py --only robots gmr" ) diff --git a/teleopit/runtime/console.py b/teleopit/runtime/console.py new file mode 100644 index 00000000..2b6d7b12 --- /dev/null +++ b/teleopit/runtime/console.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import sys +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Iterable + +from teleopit.runtime.common import cfg_get + + +RESET = "\033[0m" +BOLD = "\033[1m" +DIM = "\033[2m" +CYAN = "\033[36m" +GREEN = "\033[32m" +YELLOW = "\033[33m" +RED = "\033[31m" +MAGENTA = "\033[35m" +OPERATOR_LOGGER_NAME = "teleopit.operator" + + +@dataclass(frozen=True) +class KeyboardControl: + keys: str + action: str + + +def console_show_timing(cfg: Any) -> bool: + console_cfg = cfg_get(cfg, "console", {}) or {} + return bool(cfg_get(console_cfg, "show_timing", False)) + + +def console_timing_interval_s(cfg: Any, default: float = 10.0) -> float: + console_cfg = cfg_get(cfg, "console", {}) or {} + return float(cfg_get(console_cfg, "timing_log_interval_s", default)) + + +def console_log_level(cfg: Any, default: str = "warning") -> str: + console_cfg = cfg_get(cfg, "console", {}) or {} + return str(cfg_get(console_cfg, "log_level", default)).strip().lower() + + +def configure_runtime_logging(cfg: Any, *, force: bool = False) -> None: + """Keep operator output quiet by default while preserving warnings/errors.""" + + level_name = console_log_level(cfg) + level = getattr(logging, level_name.upper(), logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s:%(name)s:%(message)s", force=force) + + # Operator events are intentional console feedback; ordinary subsystem INFO stays hidden. + logging.getLogger(OPERATOR_LOGGER_NAME).setLevel(logging.INFO) + + if level > logging.INFO: + noisy_names = ( + "pico_bridge", + "onnxruntime", + "teleopit.inputs.pico4_provider", + "teleopit.sim2real.unitree_g1", + "teleopit.sim2real.safety", + ) + for name in noisy_names: + logging.getLogger(name).setLevel(logging.WARNING) + + +class PlainConsole: + """Small runtime console for keyboard controls and operator feedback.""" + + def __init__(self, *, title: str, enabled: bool = True, color: bool | None = None) -> None: + self._title = str(title) + self._enabled = bool(enabled) + self._color = sys.stdout.isatty() if color is None else bool(color) + self._started = False + + @property + def enabled(self) -> bool: + return self._enabled + + def start( + self, + *, + status: Iterable[tuple[str, str]] = (), + controls: Iterable[KeyboardControl] = (), + events: Iterable[str] = (), + control_section: str = "Keyboard", + show_help_key: bool = True, + ) -> None: + if not self._enabled: + return + self._started = True + print( + self.render( + status=status, + controls=controls, + events=events, + control_section=control_section, + show_help_key=show_help_key, + ), + flush=True, + ) + + def event(self, message: str) -> None: + if not self._enabled: + return + prefix = datetime.now().strftime("%H:%M:%S") + print(f"{self._dim(prefix)} {self._highlight_text(message)}", flush=True) + + def key_feedback(self, key: str, action: str, *, result: str | None = None) -> None: + details = f" -> {result}" if result else "" + self.event(f"{self.format_key(key)} {action}{details}") + + def help(self, controls: Iterable[KeyboardControl]) -> None: + if not self._enabled: + return + control_items = list(controls) + if not control_items: + self.event("no keyboard controls active") + return + lines = [self._section("Keyboard help")] + lines.extend(f" {self.format_key(item.keys)} {item.action}" for item in control_items) + print("\n".join(lines), flush=True) + + def render( + self, + *, + status: Iterable[tuple[str, str]] = (), + controls: Iterable[KeyboardControl] = (), + events: Iterable[str] = (), + control_section: str = "Keyboard", + show_help_key: bool = True, + ) -> str: + lines = [self._title_text(self._title)] + status_items = [(str(key), str(value)) for key, value in status if str(value)] + if status_items: + width = max(len(key) for key, _value in status_items) + lines.append("") + lines.extend( + f"{self._label(key.ljust(width))} {self._status_value(value)}" + for key, value in status_items + ) + + control_items = list(controls) + if control_items: + lines.append("") + lines.append(self._section(control_section)) + lines.append(" ".join(f"{self.format_key(item.keys)} {item.action}" for item in control_items)) + if show_help_key: + lines.append(f"{self.format_key('H')} help") + + event_items = [str(event) for event in events if str(event)] + if event_items: + lines.append("") + lines.append(self._section("Events")) + lines.extend(self._highlight_text(event) for event in event_items) + return "\n".join(lines) + + def _style(self, text: str, code: str) -> str: + if not self._color: + return text + return f"{code}{text}{RESET}" + + def _title_text(self, text: str) -> str: + return self._style(text, BOLD) + + def _section(self, text: str) -> str: + return self._style(text, CYAN + BOLD) + + def _label(self, text: str) -> str: + return self._style(text, DIM) + + def _dim(self, text: str) -> str: + return self._style(text, DIM) + + def format_key(self, text: str) -> str: + return self._style(f"[{text}]", YELLOW + BOLD) + + def _status_value(self, text: str) -> str: + normalized = text.strip().lower() + if normalized in {"ok", "enabled", "running", "mocap", "standing", "active"}: + return self._style(text, GREEN + BOLD) + if normalized in {"idle", "waiting", "paused", "off", "none"} or "waiting" in normalized: + return self._style(text, YELLOW + BOLD) + if normalized in {"damping", "error", "failed"} or "stale" in normalized: + return self._style(text, RED + BOLD) + if normalized in {"arms", "pico4 live"}: + return self._style(text, MAGENTA + BOLD) + return self._style(text, BOLD) + + def _highlight_text(self, text: str) -> str: + if not self._color: + return text + highlighted = text + replacements = { + "waiting": YELLOW + BOLD, + "paused": YELLOW + BOLD, + "resumed": GREEN + BOLD, + "replay": GREEN + BOLD, + "stopping": RED + BOLD, + "shutdown": RED + BOLD, + "MOCAP": GREEN + BOLD, + "STANDING": GREEN + BOLD, + "ARMS": MAGENTA + BOLD, + } + for word, code in replacements.items(): + highlighted = highlighted.replace(word, f"{code}{word}{RESET}") + return highlighted + + +def sim_keyboard_controls(cfg: Any) -> tuple[KeyboardControl, ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", "bvh")).lower() + if provider == "pico4": + keyboard_cfg = cfg_get(cfg, "keyboard", {}) or {} + if not bool(cfg_get(keyboard_cfg, "enabled", False)): + return () + return ( + KeyboardControl("Y", "mocap"), + KeyboardControl("A", "pause/resume"), + KeyboardControl("B", "arms"), + KeyboardControl("X", "standing"), + KeyboardControl("Q", "quit"), + ) + + playback_cfg = cfg_get(cfg, "playback", {}) or {} + keyboard_cfg = cfg_get(playback_cfg, "keyboard", {}) or {} + if not bool(cfg_get(keyboard_cfg, "enabled", False)): + return () + return ( + KeyboardControl("Space/P", "pause/resume"), + KeyboardControl("R", "replay"), + KeyboardControl("Q", "stop"), + ) + + +def sim2real_keyboard_controls(cfg: Any) -> tuple[KeyboardControl, ...]: + recording_cfg = cfg_get(cfg, "recording", {}) or {} + if not bool(cfg_get(recording_cfg, "enabled", False)): + return () + return ( + KeyboardControl("R", "start"), + KeyboardControl("S", "save"), + KeyboardControl("D", "discard"), + KeyboardControl("Q", "shutdown"), + ) + + +def sim2real_operator_controls(cfg: Any) -> tuple[KeyboardControl, ...]: + input_cfg = cfg_get(cfg, "input", {}) or {} + provider = str(cfg_get(input_cfg, "provider", "bvh")).lower() + controls = [ + KeyboardControl("Remote Start", "standing"), + KeyboardControl("Remote Y", "mocap"), + KeyboardControl("Remote X", "standing"), + KeyboardControl("Remote L1+R1", "damping / estop"), + ] + if provider == "pico4": + controls.extend( + [ + KeyboardControl("Pico/Controller A", "pause/resume"), + KeyboardControl("Pico/Controller B", "arms"), + ] + ) + else: + controls.extend( + [ + KeyboardControl("Remote A", "pause/resume playback"), + KeyboardControl("Remote B", "replay from start"), + ] + ) + controls.extend(sim2real_keyboard_controls(cfg)) + return tuple(controls) diff --git a/teleopit/runtime/external_assets.py b/teleopit/runtime/external_assets.py index 0513f137..6fdfb7c9 100644 --- a/teleopit/runtime/external_assets.py +++ b/teleopit/runtime/external_assets.py @@ -31,6 +31,14 @@ class AssetEntry: mode="extract", ), ], + "robots": [ + AssetEntry( + "archives/robot_assets.tar.gz", + "assets/robots", + repo="model", + mode="extract", + ), + ], "bvh": [ AssetEntry( "archives/sample_bvh.tar.gz", @@ -40,7 +48,6 @@ class AssetEntry: ), ], "data": [ - AssetEntry("data/train", "data/datasets/seed/train", repo="dataset"), - AssetEntry("data/val", "data/datasets/seed/val", repo="dataset"), + AssetEntry("data", "data/datasets/seed", repo="dataset"), ], } diff --git a/teleopit/runtime/factory.py b/teleopit/runtime/factory.py index 970b8a13..ba581c8f 100644 --- a/teleopit/runtime/factory.py +++ b/teleopit/runtime/factory.py @@ -37,25 +37,13 @@ def build_simulation_cfg(cfg: Any) -> dict[str, object]: return { "policy_hz": float(cfg_get(cfg, "policy_hz", 50.0)), "pd_hz": float(cfg_get(cfg, "pd_hz", 1000.0)), - "transition_duration": float(cfg_get(cfg, "transition_duration", 0.0) or 0.0), - "pause_resume_transition_duration": float( - cfg_get(cfg, "pause_resume_transition_duration", cfg_get(cfg, "transition_duration", 0.0)) or 0.0 - ), - "pause_resume_warmup_steps": cfg_get(cfg, "pause_resume_warmup_steps", None), "retarget_buffer_enabled": bool(cfg_get(cfg, "retarget_buffer_enabled", True)), "retarget_buffer_window_s": float(cfg_get(cfg, "retarget_buffer_window_s", 0.5)), "retarget_buffer_delay_s": cfg_get(cfg, "retarget_buffer_delay_s", None), "reference_steps": cfg_get(cfg, "reference_steps", [0]), "reference_debug_log": bool(cfg_get(cfg, "reference_debug_log", False)), "realtime_input_delay_s": cfg_get(cfg, "realtime_input_delay_s", None), - "realtime_buffer_low_watermark_steps": cfg_get(cfg, "realtime_buffer_low_watermark_steps", None), - "realtime_buffer_high_watermark_steps": cfg_get(cfg, "realtime_buffer_high_watermark_steps", None), "realtime_buffer_warmup_steps": cfg_get(cfg, "realtime_buffer_warmup_steps", None), - "realtime_catchup_enabled": bool(cfg_get(cfg, "realtime_catchup_enabled", False)), - "realtime_catchup_trigger_steps": cfg_get(cfg, "realtime_catchup_trigger_steps", None), - "realtime_catchup_release_steps": cfg_get(cfg, "realtime_catchup_release_steps", None), - "realtime_catchup_target_delay_s": cfg_get(cfg, "realtime_catchup_target_delay_s", None), - "reference_qpos_smoothing_alpha": float(cfg_get(cfg, "reference_qpos_smoothing_alpha", 1.0)), "reference_velocity_smoothing_alpha": float(cfg_get(cfg, "reference_velocity_smoothing_alpha", 1.0)), "reference_anchor_velocity_smoothing_alpha": float( cfg_get(cfg, "reference_anchor_velocity_smoothing_alpha", 1.0) @@ -164,11 +152,6 @@ def _build_policy_components( policy_dim = getattr(controller, "_expected_obs_dim", None) builder_dim = getattr(obs_builder, "total_obs_size", None) if policy_dim is not None and builder_dim is not None and policy_dim != builder_dim: - if builder_dim == 166: - raise ValueError( - f"Only 166D velcmd_history ONNX policies are supported here; " - f"obs_builder produces 166D but policy expects {policy_dim}D." - ) raise ValueError( f"Observation dimension mismatch at startup: obs_builder produces {builder_dim}D " f"but policy expects {policy_dim}D. Use a matching ONNX model." diff --git a/teleopit/runtime/reference_config.py b/teleopit/runtime/reference_config.py index 468e29c6..cf0e7ab5 100644 --- a/teleopit/runtime/reference_config.py +++ b/teleopit/runtime/reference_config.py @@ -1,7 +1,7 @@ """Shared reference-window / realtime-buffer configuration. Parsed once from the top-level config and consumed by both -``SimulationLoop`` and ``Sim2RealController``. +``SimulationLoop`` and the process-isolated sim2real runtime. """ from __future__ import annotations @@ -13,7 +13,6 @@ cfg_get, parse_alpha, parse_nonnegative_int, - parse_optional_nonnegative_int, ) @@ -23,17 +22,9 @@ class ReferenceConfig: retarget_buffer_window_s: float reference_delay_s: float | None reference_debug_log: bool - realtime_buffer_low_watermark_steps: int - realtime_buffer_high_watermark_steps: int | None realtime_buffer_warmup_steps: int - pause_resume_warmup_steps: int - realtime_catchup_enabled: bool - realtime_catchup_trigger_steps: int | None - realtime_catchup_release_steps: int | None - realtime_catchup_target_delay_s: float | None reference_velocity_smoothing_alpha: float reference_anchor_velocity_smoothing_alpha: float - reference_qpos_smoothing_alpha: float def _resolve_delay(cfg: Any, *, provider_fps: float | None) -> float | None: @@ -48,13 +39,6 @@ def _resolve_delay(cfg: Any, *, provider_fps: float | None) -> float | None: return None -def _resolve_catchup_target_delay(cfg: Any) -> float | None: - raw = cfg_get(cfg, "realtime_catchup_target_delay_s", None) - if raw in (None, "", "null"): - return None - return float(raw) - - def parse_reference_config( cfg: Any, *, @@ -80,38 +64,11 @@ def parse_reference_config( reference_debug_log = bool(cfg_get(cfg, "reference_debug_log", False)) reference_delay_s = _resolve_delay(cfg, provider_fps=provider_fps) - low = parse_nonnegative_int( - cfg_get(cfg, "realtime_buffer_low_watermark_steps", 0), - field_name="realtime_buffer_low_watermark_steps", - default=0, - ) - high = parse_optional_nonnegative_int( - cfg_get(cfg, "realtime_buffer_high_watermark_steps", None), - field_name="realtime_buffer_high_watermark_steps", - ) - if high is not None and high < low: - raise ValueError("realtime_buffer_high_watermark_steps must be >= realtime_buffer_low_watermark_steps") - warmup = parse_nonnegative_int( cfg_get(cfg, "realtime_buffer_warmup_steps", 0), field_name="realtime_buffer_warmup_steps", default=0, ) - pause_resume_warmup = parse_nonnegative_int( - cfg_get(cfg, "pause_resume_warmup_steps", warmup), - field_name="pause_resume_warmup_steps", - default=warmup, - ) - catchup_enabled = bool(cfg_get(cfg, "realtime_catchup_enabled", False)) - catchup_trigger = parse_optional_nonnegative_int( - cfg_get(cfg, "realtime_catchup_trigger_steps", None), - field_name="realtime_catchup_trigger_steps", - ) - catchup_release = parse_optional_nonnegative_int( - cfg_get(cfg, "realtime_catchup_release_steps", None), - field_name="realtime_catchup_release_steps", - ) - catchup_target_delay = _resolve_catchup_target_delay(cfg) vel_alpha = parse_alpha( cfg_get(cfg, "reference_velocity_smoothing_alpha", 1.0), @@ -123,26 +80,13 @@ def parse_reference_config( field_name="reference_anchor_velocity_smoothing_alpha", default=1.0, ) - qpos_alpha = parse_alpha( - cfg_get(cfg, "reference_qpos_smoothing_alpha", 1.0), - field_name="reference_qpos_smoothing_alpha", - default=1.0, - ) return ReferenceConfig( retarget_buffer_enabled=retarget_buffer_enabled, retarget_buffer_window_s=retarget_buffer_window_s, reference_delay_s=reference_delay_s, reference_debug_log=reference_debug_log, - realtime_buffer_low_watermark_steps=low, - realtime_buffer_high_watermark_steps=high, realtime_buffer_warmup_steps=warmup, - pause_resume_warmup_steps=pause_resume_warmup, - realtime_catchup_enabled=catchup_enabled, - realtime_catchup_trigger_steps=catchup_trigger, - realtime_catchup_release_steps=catchup_release, - realtime_catchup_target_delay_s=catchup_target_delay, reference_velocity_smoothing_alpha=vel_alpha, reference_anchor_velocity_smoothing_alpha=anchor_vel_alpha, - reference_qpos_smoothing_alpha=qpos_alpha, ) diff --git a/teleopit/sim/loop.py b/teleopit/sim/loop.py index a726c281..5f1a211d 100644 --- a/teleopit/sim/loop.py +++ b/teleopit/sim/loop.py @@ -8,16 +8,18 @@ from numpy.typing import NDArray from teleopit.constants import FULL_QPOS_DIM, ROOT_DIM -from teleopit.controllers.qpos_interpolator import QposInterpolator +from teleopit.controllers.observation import align_motion_qpos_yaw from teleopit.runtime.reference_config import parse_reference_config +from teleopit.runtime.arm_mocap import compose_arm_reference, parse_arm_joint_indices +from teleopit.runtime.console import PlainConsole, sim_keyboard_controls from teleopit.inputs.realtime_packet import RealtimeInputPacket -from teleopit.interfaces import Controller, InputProvider, MessageBus, ObservationBuilder, Recorder, Retargeter, Robot, RobotState +from teleopit.interfaces import Controller, InputProvider, MessageBus, ObservationBuilder, Retargeter, Robot, RobotState from teleopit.sim.reference_timeline import ( ReferenceWindow, ReferenceWindowBuilder, ) from teleopit.sim.realtime_utils import RealtimeReferenceDiagnostics -from teleopit.sim.runtime_components import MotionPreparation, PolicyStepRunner, RunRecorder, RuntimePublisher, ViewerManager +from teleopit.sim.runtime_components import MotionPreparation, PolicyStepRunner, RuntimePublisher, ViewerManager from teleopit.sim.viewer_subprocess import mocap_viewer_proc, start_camera_viewer, start_robot_viewer from teleopit.runtime.mocap_session import MocapSessionManager from teleopit.runtime.offline_playback import OfflinePlaybackController @@ -34,6 +36,7 @@ class SimulationMode(Enum): IDLE = "idle" STANDING = "standing" MOCAP = "mocap" + ARMS = "arms" @final @@ -47,6 +50,7 @@ def __init__( cfg: object, viewers: set[str] | None = None, video_runtime: object | None = None, + console: PlainConsole | None = None, ) -> None: self.robot: Robot = robot self.controller: Controller = controller @@ -54,6 +58,7 @@ def __init__( self.bus: MessageBus = bus self.cfg: object = cfg self._video_runtime = video_runtime + self._console = console or PlainConsole(title="Teleopit sim2sim") self.policy_hz: float = self._to_float(self._get_cfg("policy_hz", "sim.policy_hz", "control.policy_hz", "policy_frequency")) self.pd_hz: float = self._to_float(self._get_cfg("pd_hz", "sim.pd_hz", "control.pd_hz", "pd_frequency")) @@ -75,20 +80,14 @@ def __init__( self._last_action: Float32Array = np.zeros((self._num_actions,), dtype=np.float32) self._last_retarget_qpos: Float64Array | None = None + self._standing_qpos: Float64Array | None = None + self._arm_joint_indices = parse_arm_joint_indices(cfg, num_actions=self._num_actions) self._realtime: bool = bool(self._try_get_cfg("realtime") or False) raw_debug_trace_path = self._try_get_cfg("debug_trace_path") self._debug_trace_path: str | None = None if raw_debug_trace_path not in (None, "", "null"): self._debug_trace_path = str(raw_debug_trace_path) - # Motion command transition smoothing - transition_dur = float(self._try_get_cfg("transition_duration") or 0.0) - self._mocap_transition_duration = transition_dur - self._pause_resume_transition_duration = float( - self._try_get_cfg("pause_resume_transition_duration") or transition_dur - ) - self._qpos_interpolator = QposInterpolator(transition_dur, self.policy_hz) - self._init_reference_config() self._init_components(viewers) @@ -98,6 +97,7 @@ def _init_reference_config(self) -> None: self._playback_pause_on_end = bool(self._try_get_cfg("playback.pause_on_end", False)) self._playback_keyboard_enabled = bool(self._try_get_cfg("playback.keyboard.enabled", False)) self._realtime_keyboard_enabled = bool(self._try_get_cfg("keyboard.enabled", False)) + self._console_controls = sim_keyboard_controls(self.cfg) # Shared reference config (parsed once, used by both sim and sim2real) self._ref_cfg = parse_reference_config(self.cfg) @@ -115,7 +115,7 @@ def _init_reference_config(self) -> None: ) def _init_components(self, viewers: set[str] | None) -> None: - """Build PolicyStepRunner, publisher, recorder helper, and viewer manager.""" + """Build PolicyStepRunner, publisher, and viewer manager.""" self._viewers: set[str] = set(viewers or set()) self._step_runner = PolicyStepRunner( robot=self.robot, @@ -128,13 +128,10 @@ def _init_components(self, viewers: set[str] | None) -> None: kds=self._kds, torque_limits=self._torque_limits, default_dof_pos=self._default_dof_pos, - qpos_interpolator=self._qpos_interpolator, reference_velocity_smoothing_alpha=self._ref_cfg.reference_velocity_smoothing_alpha, reference_anchor_velocity_smoothing_alpha=self._ref_cfg.reference_anchor_velocity_smoothing_alpha, - reference_qpos_smoothing_alpha=self._ref_cfg.reference_qpos_smoothing_alpha, ) self._publisher = RuntimePublisher(self.bus) - self._recorder_helper = RunRecorder() self._viewer_manager = ViewerManager( robot=self.robot, viewers=self._viewers, @@ -148,11 +145,10 @@ def run( input_provider: InputProvider, retargeter: Retargeter, num_steps: int, - recorder: Recorder | None = None, ) -> dict[str, float | int]: from teleopit.sim.session import SimLoopSession - session = SimLoopSession(self, input_provider, retargeter, num_steps, recorder) + session = SimLoopSession(self, input_provider, retargeter, num_steps) return session.run() def run_headless( @@ -160,9 +156,8 @@ def run_headless( input_provider: InputProvider, retargeter: Retargeter, num_steps: int, - recorder: Recorder | None = None, ) -> dict[str, float | int]: - return self.run(input_provider=input_provider, retargeter=retargeter, num_steps=num_steps, recorder=recorder) + return self.run(input_provider=input_provider, retargeter=retargeter, num_steps=num_steps) def _compute_target_dof_pos(self, action: Float32Array) -> Float32Array: return self._step_runner.compute_target_dof_pos(action) @@ -171,10 +166,27 @@ def _build_standing_qpos(self, state: RobotState) -> Float64Array: standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) if state.base_pos is not None: standing_qpos[0:3] = np.asarray(state.base_pos, dtype=np.float64)[:3] - standing_qpos[3:7] = np.asarray(state.quat, dtype=np.float64)[:4] + standing_qpos[3] = 1.0 + align_motion_qpos_yaw(np.asarray(state.quat, dtype=np.float32), standing_qpos) standing_qpos[7:7 + self._num_actions] = self._default_dof_pos.astype(np.float64)[: self._num_actions] return standing_qpos + def _set_standing_reference(self, state: RobotState) -> Float64Array: + standing_qpos = self._build_standing_qpos(state) + self._standing_qpos = standing_qpos.copy() + return standing_qpos + + def _compose_arm_reference(self, retarget_qpos: Float64Array) -> Float64Array: + if self._standing_qpos is None: + self._set_standing_reference(self.robot.get_state()) + assert self._standing_qpos is not None + return compose_arm_reference( + standing_qpos=self._standing_qpos, + retarget_qpos=retarget_qpos, + arm_joint_indices=self._arm_joint_indices, + num_actions=self._num_actions, + ) + @staticmethod def _drain_realtime_control_events(input_provider: InputProvider) -> tuple[object, ...]: pop_control_events = getattr(input_provider, "pop_control_events", None) @@ -258,7 +270,6 @@ def _restart_offline_playback( *, offline_playback: OfflinePlaybackController, mocap_session: MocapSessionManager, - retargeter: Retargeter, ) -> None: offline_playback.replay() mocap_session.reset() @@ -266,7 +277,6 @@ def _restart_offline_playback( self._last_action = np.zeros((self._num_actions,), dtype=np.float32) self.controller.reset() self.obs_builder.reset() - retargeter.reset() self.robot.reset() def _pause_offline_playback( @@ -275,14 +285,12 @@ def _pause_offline_playback( offline_playback: OfflinePlaybackController, mocap_session: MocapSessionManager, hold_qpos: Float64Array, - retargeter: Retargeter, ) -> None: offline_playback.pause() self._step_runner.reset() self._last_action = np.zeros((self._num_actions,), dtype=np.float32) self.controller.reset() self.obs_builder.reset() - retargeter.reset() mocap_session.pause(hold_qpos) def _resume_offline_playback( @@ -290,24 +298,19 @@ def _resume_offline_playback( *, offline_playback: OfflinePlaybackController, mocap_session: MocapSessionManager, - retargeter: Retargeter, state: RobotState, ) -> None: if mocap_session.hold_qpos is None: raise RuntimeError("Cannot resume offline playback without a paused hold qpos") - resume_qpos = self._build_robot_state_qpos(state) + resume_qpos = self._build_resume_alignment_qpos(mocap_session.hold_qpos, state) offline_playback.resume() mocap_session.reset() self._step_runner.reset() self._last_action = np.zeros((self._num_actions,), dtype=np.float32) self.controller.reset() self.obs_builder.reset() - retargeter.reset() + self._step_runner.reset_reference_alignment(resume_qpos) self._step_runner.last_retarget_qpos = resume_qpos.copy() - self._step_runner.arm_motion_transition( - resume_qpos, - duration_s=self._pause_resume_transition_duration, - ) def _build_observation( self, @@ -326,17 +329,6 @@ def _build_observation( def _publish(self, mimic_obs: Float32Array, action: Float32Array, robot_state: object) -> None: self._publisher.publish(mimic_obs, action, robot_state) - def _record( - self, - recorder: Recorder | None, - state: object, - mimic_obs: Float32Array, - action: Float32Array, - target_dof_pos: Float32Array, - torque: Float32Array, - ) -> None: - self._recorder_helper.record(recorder, state, mimic_obs, action, target_dof_pos, torque) - def _retarget_to_qpos(self, retargeted: object) -> Float64Array: return self._step_runner._retarget_to_qpos(retargeted) @@ -444,8 +436,6 @@ def _write_debug_trace( reference_future_horizon_steps=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "future_horizon_steps"), dtype=np.int64)), reference_real_frame_count=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "real_frame_count"), dtype=np.int64)), reference_warmup_done=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "warmup_done"), dtype=np.bool_)), - reference_used_repeat_padding=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "used_repeat_padding"), dtype=np.bool_)), - reference_padding_active=(None if realtime_reference_diag is None else np.asarray(getattr(realtime_reference_diag, "padding_active"), dtype=np.bool_)), ) def _log_reference_window(self, reference_window: ReferenceWindow, buffer_len: int) -> None: @@ -458,18 +448,3 @@ def _log_reference_window(self, reference_window: ReferenceWindow, buffer_len: i list(reference_window.reference_steps), list(reference_window.modes()), ) - - def _log_repeat_padding( - self, - reference_window: ReferenceWindow, - diagnostics: RealtimeReferenceDiagnostics, - buffer_len: int, - ) -> None: - import logging - - logging.getLogger(__name__).warning( - "Reference timeline repeat padding | buffer_len=%d | future_horizon_steps=%d | steps=%s", - buffer_len, - diagnostics.future_horizon_steps, - list(reference_window.reference_steps), - ) diff --git a/teleopit/sim/realtime_utils.py b/teleopit/sim/realtime_utils.py index 7a51283a..589e1e2f 100644 --- a/teleopit/sim/realtime_utils.py +++ b/teleopit/sim/realtime_utils.py @@ -5,12 +5,7 @@ import numpy as np from numpy.typing import NDArray -from teleopit.sim.reference_timeline import ( - ReferenceSample, - ReferenceTimeline, - ReferenceWindow, - ReferenceWindowBuilder, -) +from teleopit.sim.reference_timeline import ReferenceTimeline, ReferenceWindow, ReferenceWindowBuilder Float32Array = NDArray[np.float32] @@ -44,13 +39,9 @@ class RealtimeReferenceDiagnostics: future_horizon_steps: int real_frame_count: int warmup_done: bool - used_repeat_padding: bool - padding_active: bool requested_base_time_s: float effective_base_time_s: float latest_timestamp_s: float | None - used_catchup: bool - catchup_active: bool class RealtimeReferenceManager: @@ -58,62 +49,13 @@ def __init__( self, *, reference_window_builder: ReferenceWindowBuilder, - low_watermark_steps: int = 0, - high_watermark_steps: int | None = None, warmup_steps: int = 0, - catchup_enabled: bool = False, - catchup_trigger_steps: int | None = None, - catchup_release_steps: int | None = None, - catchup_target_delay_s: float | None = None, ) -> None: self._builder = reference_window_builder - self._low_watermark_steps = max(int(low_watermark_steps), self._builder.max_future_step) - if high_watermark_steps is None: - self._high_watermark_steps = self._low_watermark_steps - else: - self._high_watermark_steps = int(high_watermark_steps) - if self._low_watermark_steps < 0: - raise ValueError("low_watermark_steps must be >= 0") - if self._high_watermark_steps < self._low_watermark_steps: - raise ValueError( - "high_watermark_steps must be >= low_watermark_steps, " - f"got {self._high_watermark_steps} < {self._low_watermark_steps}" - ) self._warmup_steps = max(int(warmup_steps), 0) - self._catchup_enabled = bool(catchup_enabled) - self._catchup_trigger_steps = ( - max(int(catchup_trigger_steps), 0) - if catchup_trigger_steps is not None - else max(self._high_watermark_steps, self._low_watermark_steps + 2) - ) - self._catchup_release_steps = ( - max(int(catchup_release_steps), 0) - if catchup_release_steps is not None - else max(self._low_watermark_steps, self._builder.max_future_step) - ) - if self._catchup_release_steps > self._catchup_trigger_steps: - raise ValueError( - "catchup_release_steps must be <= catchup_trigger_steps, " - f"got {self._catchup_release_steps} > {self._catchup_trigger_steps}" - ) - self._catchup_target_delay_s = ( - float(catchup_target_delay_s) - if catchup_target_delay_s is not None - else float(self._builder.policy_dt_s * max(self._builder.max_future_step, 1)) - ) - if ( - not np.isfinite(self._catchup_target_delay_s) - or self._catchup_target_delay_s < 0.0 - ): - raise ValueError( - "catchup_target_delay_s must be finite and >= 0, " - f"got {catchup_target_delay_s}" - ) self.reset() def reset(self) -> None: - self._padding_active = False - self._catchup_active = False self._real_frame_count = 0 def set_warmup_steps(self, warmup_steps: int) -> None: @@ -148,115 +90,14 @@ def sample( ) -> tuple[ReferenceWindow, RealtimeReferenceDiagnostics]: requested_base_time_s = float(base_time_s) latest_timestamp = timeline.latest_timestamp() - requested_future_horizon_steps = self.future_horizon_steps(timeline, requested_base_time_s) - effective_base_time_s, used_catchup = self._apply_catchup( - requested_base_time_s, - latest_timestamp, - requested_future_horizon_steps, - ) + effective_base_time_s = requested_base_time_s window = self._builder.sample(timeline, effective_base_time_s) future_horizon_steps = self.future_horizon_steps(timeline, effective_base_time_s) - self._update_padding_state(future_horizon_steps) - target_padding_horizon = ( - self._high_watermark_steps - if self._padding_active - else self._builder.max_future_step - ) - padded_window, used_repeat_padding = self._pad_future_window( - timeline, - window, - effective_base_time_s, - target_padding_horizon, - ) - return padded_window, RealtimeReferenceDiagnostics( + return window, RealtimeReferenceDiagnostics( future_horizon_steps=future_horizon_steps, real_frame_count=self._real_frame_count, warmup_done=self.warmup_done, - used_repeat_padding=used_repeat_padding, - padding_active=self._padding_active, requested_base_time_s=requested_base_time_s, effective_base_time_s=effective_base_time_s, latest_timestamp_s=None if latest_timestamp is None else float(latest_timestamp), - used_catchup=used_catchup, - catchup_active=self._catchup_active, - ) - - def _apply_catchup( - self, - requested_base_time_s: float, - latest_timestamp: float | None, - requested_future_horizon_steps: int, - ) -> tuple[float, bool]: - if not self._catchup_enabled or latest_timestamp is None: - self._catchup_active = False - return requested_base_time_s, False - - if self._catchup_active: - if requested_future_horizon_steps <= self._catchup_release_steps: - self._catchup_active = False - elif requested_future_horizon_steps >= self._catchup_trigger_steps: - self._catchup_active = True - - if not self._catchup_active: - return requested_base_time_s, False - - target_base_time_s = max( - requested_base_time_s, - float(latest_timestamp) - self._catchup_target_delay_s, ) - used_catchup = target_base_time_s > requested_base_time_s + 1e-9 - return target_base_time_s, used_catchup - - def _update_padding_state(self, future_horizon_steps: int) -> None: - if self._high_watermark_steps <= 0: - self._padding_active = False - return - if future_horizon_steps <= self._low_watermark_steps: - self._padding_active = True - elif future_horizon_steps >= self._high_watermark_steps: - self._padding_active = False - - def _pad_future_window( - self, - timeline: ReferenceTimeline, - window: ReferenceWindow, - base_time_s: float, - target_padding_horizon: int, - ) -> tuple[ReferenceWindow, bool]: - latest_timestamp = timeline.latest_timestamp() - if latest_timestamp is None or target_padding_horizon <= 0: - return window, False - - latest_sample = timeline.sample_at(latest_timestamp) - latest_ts = float(latest_timestamp) - policy_dt_s = self._builder.policy_dt_s - samples: list[ReferenceSample] = [] - used_repeat_padding = False - - for step, sample in zip(window.reference_steps, window.samples): - target_time_s = float(base_time_s) + float(step) * policy_dt_s - if step > 0 and step <= target_padding_horizon and target_time_s > latest_ts + 1e-9: - used_repeat_padding = True - samples.append( - ReferenceSample( - qpos=np.asarray(latest_sample.qpos, dtype=np.float64).copy(), - timestamp_s=target_time_s, - mode="repeat_latest", - used_fallback=False, - older_timestamp_s=latest_ts, - newer_timestamp_s=latest_ts, - alpha=None, - ) - ) - else: - samples.append(sample) - - if not used_repeat_padding: - return window, False - - return ReferenceWindow( - base_time_s=float(window.base_time_s), - policy_dt_s=float(window.policy_dt_s), - reference_steps=tuple(window.reference_steps), - samples=tuple(samples), - ), True diff --git a/teleopit/sim/reference_utils.py b/teleopit/sim/reference_utils.py index 8f862eae..9b5807ad 100644 --- a/teleopit/sim/reference_utils.py +++ b/teleopit/sim/reference_utils.py @@ -1,6 +1,6 @@ """Shared reference-window utilities used by both offline simulation and sim2real. -These are pure functions extracted from ``SimLoop`` and ``Sim2RealController`` to avoid +These are pure functions extracted from ``SimLoop`` and sim2real runtime code to avoid code duplication. Each function takes explicit arguments instead of ``self``. """ diff --git a/teleopit/sim/runtime_components.py b/teleopit/sim/runtime_components.py index 91e83ff4..ae450813 100644 --- a/teleopit/sim/runtime_components.py +++ b/teleopit/sim/runtime_components.py @@ -11,12 +11,11 @@ _logger = logging.getLogger(__name__) -from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.constants import FULL_QPOS_DIM from teleopit.bus.topics import TOPIC_ACTION, TOPIC_MIMIC_OBS, TOPIC_ROBOT_STATE from teleopit.controllers.observation import VelCmdObservationBuilder -from teleopit.controllers.qpos_interpolator import QposInterpolator, QposLowPassFilter from teleopit.controllers import reference_processing as ref_proc -from teleopit.interfaces import MessageBus, ObservationBuilder, Recorder, Robot, RobotState +from teleopit.interfaces import MessageBus, ObservationBuilder, Robot, RobotState from teleopit.retargeting.core import extract_mimic_obs from teleopit.sim.reference_timeline import ReferenceWindow from teleopit.sim.reference_utils import obs_builder_requires_reference_window @@ -55,31 +54,6 @@ def publish(self, mimic_obs: Float32Array, action: Float32Array, robot_state: ob self._bus.publish(TOPIC_ROBOT_STATE, robot_state) -class RunRecorder: - def record( - self, - recorder: Recorder | None, - state: RobotState, - mimic_obs: Float32Array, - action: Float32Array, - target_dof_pos: Float32Array, - torque: Float32Array, - ) -> None: - if recorder is None: - return - - payload: dict[str, object] = { - "joint_pos": np.asarray(state.qpos, dtype=np.float32), - "joint_vel": np.asarray(state.qvel, dtype=np.float32), - "mimic_obs": mimic_obs.astype(np.float32, copy=False), - "action": action.astype(np.float32, copy=False), - "target_dof_pos": target_dof_pos.astype(np.float32, copy=False), - "torque": torque.astype(np.float32, copy=False), - "timestamp": np.asarray(float(state.timestamp), dtype=np.float64), - } - recorder.add_frame(payload) - - class PolicyStepRunner: def __init__( self, @@ -94,10 +68,8 @@ def __init__( kds: Float32Array, torque_limits: Float32Array, default_dof_pos: Float32Array, - qpos_interpolator: QposInterpolator, reference_velocity_smoothing_alpha: float = 1.0, reference_anchor_velocity_smoothing_alpha: float = 1.0, - reference_qpos_smoothing_alpha: float = 1.0, ) -> None: self.robot = robot self.controller = controller @@ -109,11 +81,9 @@ def __init__( self.kds = kds self.torque_limits = torque_limits self.default_dof_pos = default_dof_pos - self.qpos_interpolator = qpos_interpolator self._motion_joint_vel_smoother = ExponentialVecSmoother(reference_velocity_smoothing_alpha) self._motion_anchor_lin_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) self._motion_anchor_ang_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) - self._reference_qpos_smoother = QposLowPassFilter(reference_qpos_smoothing_alpha) self.last_action: Float32Array = np.zeros((self.num_actions,), dtype=np.float32) self.last_retarget_qpos: Float64Array | None = None self.last_reference_qpos: Float64Array | None = None @@ -135,8 +105,6 @@ def reset(self) -> None: self._motion_joint_vel_smoother.reset() self._motion_anchor_lin_vel_smoother.reset() self._motion_anchor_ang_vel_smoother.reset() - self._reference_qpos_smoother.reset() - self.qpos_interpolator.reset() def soft_reset_reference_state(self, *, reset_alignment: bool = True) -> None: self.last_reference_qpos = None @@ -149,7 +117,6 @@ def soft_reset_reference_state(self, *, reset_alignment: bool = True) -> None: self._motion_joint_vel_smoother.reset() self._motion_anchor_lin_vel_smoother.reset() self._motion_anchor_ang_vel_smoother.reset() - self._reference_qpos_smoother.reset() def reset_reference_alignment(self, target_qpos: Float64Array | None = None) -> None: self._fixed_reference_yaw_quat = None @@ -161,11 +128,6 @@ def reset_reference_alignment(self, target_qpos: Float64Array | None = None) -> else np.asarray(target_qpos[0:2], dtype=np.float32).reshape(2).copy() ) - def arm_motion_transition(self, start_qpos: Float64Array, *, duration_s: float) -> None: - self.qpos_interpolator.reset() - self.qpos_interpolator.configure(duration_s) - self.qpos_interpolator.start(start_qpos) - def prepare_static_motion_command(self, qpos: Float64Array) -> MotionPreparation: hold_qpos = self._retarget_to_qpos(qpos) mimic_obs = extract_mimic_obs( @@ -191,16 +153,8 @@ def prepare_static_motion_command(self, qpos: Float64Array) -> MotionPreparation def prepare_motion_command(self, retargeted: object, state: object) -> MotionPreparation: reference_qpos = self._retarget_to_qpos(retargeted) reference_qpos = self._align_velcmd_reference_yaw(reference_qpos, state) - reference_qpos = self._reference_qpos_smoother.apply(reference_qpos) self._pending_reference_qpos = reference_qpos.copy() - - if self.last_retarget_qpos is None and self.qpos_interpolator.duration > 0: - start_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - start_qpos[0:3] = np.asarray(state.base_pos[:3], dtype=np.float64) - start_qpos[3:7] = np.asarray(state.quat[:4], dtype=np.float64) - start_qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(state.qpos[:NUM_JOINTS], dtype=np.float64) - self.qpos_interpolator.start(start_qpos) - qpos = self.qpos_interpolator.apply(reference_qpos) + qpos = reference_qpos.copy() mimic_obs = extract_mimic_obs(qpos=qpos, last_qpos=self.last_retarget_qpos, dt=1.0 / self.policy_hz) retarget_viewer_qpos = qpos.copy() @@ -214,13 +168,6 @@ def prepare_motion_command(self, retargeted: object, state: object) -> MotionPre if obs_builder_requires_reference_window(self.obs_builder): motion_anchor_lin_vel_w = None motion_anchor_ang_vel_w = None - elif self.qpos_interpolator.is_active: - true_lin_vel_w, true_ang_vel_w = self._compute_anchor_velocities(reference_qpos) - blend = np.float32(self.qpos_interpolator.last_alpha) - raw_motion_anchor_lin_vel_w = np.asarray(true_lin_vel_w * blend, dtype=np.float32) - raw_motion_anchor_ang_vel_w = np.asarray(true_ang_vel_w * blend, dtype=np.float32) - motion_anchor_lin_vel_w = self._motion_anchor_lin_vel_smoother.apply(raw_motion_anchor_lin_vel_w) - motion_anchor_ang_vel_w = self._motion_anchor_ang_vel_smoother.apply(raw_motion_anchor_ang_vel_w) else: raw_motion_anchor_lin_vel_w, raw_motion_anchor_ang_vel_w = self._compute_anchor_velocities(reference_qpos) motion_anchor_lin_vel_w = self._motion_anchor_lin_vel_smoother.apply(raw_motion_anchor_lin_vel_w) diff --git a/teleopit/sim/session.py b/teleopit/sim/session.py index 703430fd..c404200c 100644 --- a/teleopit/sim/session.py +++ b/teleopit/sim/session.py @@ -19,7 +19,7 @@ from numpy.typing import NDArray from teleopit.debug.rollout_trace import RolloutTraceWriter -from teleopit.interfaces import InputProvider, Recorder, Retargeter, RobotState +from teleopit.interfaces import InputProvider, Retargeter, RobotState from teleopit.sim.reference_motion import ( OfflineReferenceMotion, interpolate_human_frames, @@ -33,6 +33,7 @@ ) from teleopit.sim.realtime_utils import RealtimeReferenceDiagnostics, RealtimeReferenceManager from teleopit.sim.runtime_components import MotionPreparation +from teleopit.runtime.arm_mocap import compose_arm_reference_window from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState from teleopit.runtime.offline_playback import OfflinePlaybackController from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader @@ -63,12 +64,10 @@ def __init__( input_provider: InputProvider, retargeter: Retargeter, num_steps: int, - recorder: Recorder | None, ) -> None: self._loop = loop self._input_provider = input_provider self._retargeter = retargeter - self._recorder = recorder # Convenience aliases for heavily-used loop attributes self._step_runner = loop._step_runner @@ -148,13 +147,7 @@ def __init__( if self.reference_timeline is not None: self.realtime_reference_manager = RealtimeReferenceManager( reference_window_builder=loop._reference_window_builder, - low_watermark_steps=ref_cfg.realtime_buffer_low_watermark_steps, - high_watermark_steps=ref_cfg.realtime_buffer_high_watermark_steps, warmup_steps=ref_cfg.realtime_buffer_warmup_steps, - catchup_enabled=ref_cfg.realtime_catchup_enabled, - catchup_trigger_steps=ref_cfg.realtime_catchup_trigger_steps, - catchup_release_steps=ref_cfg.realtime_catchup_release_steps, - catchup_target_delay_s=ref_cfg.realtime_catchup_target_delay_s, ) # Realtime live-frame tracking @@ -198,6 +191,8 @@ def __init__( self.simulation_mode: SimulationMode = ( SimulationMode.STANDING if self.realtime_keyboard_mode_enabled else SimulationMode.MOCAP ) + if self.simulation_mode == SimulationMode.STANDING: + loop._set_standing_reference(loop.robot.get_state()) # Debug writer self.debug_writer: RolloutTraceWriter | None = None @@ -217,14 +212,12 @@ def __init__( # State-management methods (formerly closures with nonlocal) # ------------------------------------------------------------------ - def reset_runtime_tracking(self, *, warmup_steps: int | None = None) -> None: + def reset_runtime_tracking(self) -> None: ref_cfg = self._loop._ref_cfg if self.reference_timeline is not None: self.reference_timeline.clear() if self.realtime_reference_manager is not None: - self.realtime_reference_manager.set_warmup_steps( - ref_cfg.realtime_buffer_warmup_steps if warmup_steps is None else warmup_steps - ) + self.realtime_reference_manager.set_warmup_steps(ref_cfg.realtime_buffer_warmup_steps) self.realtime_reference_manager.reset() self.previous_live_human_frame = None self.previous_live_retargeted = None @@ -236,57 +229,81 @@ def reset_runtime_tracking(self, *, warmup_steps: int | None = None) -> None: self.cached_human_frame = None self.cached_retargeted = None - def full_policy_reset(self, *, warmup_steps: int | None = None) -> None: + def reset_policy_reference_state(self, *, reset_mocap_session: bool = True) -> None: self._step_runner.reset() self._loop.controller.reset() self._loop.obs_builder.reset() - self._retargeter.reset() - self.mocap_session.reset() + if reset_mocap_session: + self.mocap_session.reset() self.last_commanded_motion_qpos = None - self.reset_runtime_tracking(warmup_steps=warmup_steps) + self.reset_runtime_tracking() def enter_standing_mode(self) -> None: from teleopit.sim.loop import SimulationMode - self.full_policy_reset() + self.reset_policy_reference_state() + self._loop._set_standing_reference(self._loop.robot.get_state()) self.simulation_mode = SimulationMode.STANDING - def enter_mocap_mode(self) -> None: + def enter_mocap_mode(self) -> bool: from teleopit.sim.loop import SimulationMode loop = self._loop if not loop._realtime_input_has_frame(self._input_provider): _logger.warning("Cannot switch to MOCAP yet: realtime input has no frame available") - return + return False state = loop.robot.get_state() start_qpos = loop._resolve_hold_qpos(None, None, None, state) - self.full_policy_reset() + self.reset_policy_reference_state() self._step_runner.last_retarget_qpos = start_qpos.copy() - self._step_runner.arm_motion_transition( - start_qpos, - duration_s=loop._mocap_transition_duration, - ) self.last_commanded_motion_qpos = start_qpos.copy() self.simulation_mode = SimulationMode.MOCAP + return True + + def toggle_arms_mode(self) -> bool: + from teleopit.sim.loop import SimulationMode + if not self.realtime_interpolated_input or self.simulation_mode not in (SimulationMode.MOCAP, SimulationMode.ARMS): + return False + if self.mocap_session.state == MocapSessionState.PAUSED: + _logger.info("Ignoring arm-only mode toggle while mocap session is paused") + return False + loop = self._loop + state = loop.robot.get_state() + resume_qpos = loop._build_resume_alignment_qpos(self.last_commanded_motion_qpos, state) + if self.simulation_mode == SimulationMode.MOCAP: + loop._set_standing_reference(state) + self.simulation_mode = SimulationMode.ARMS + else: + self.simulation_mode = SimulationMode.MOCAP + self._step_runner.reset() + loop.controller.reset() + loop.obs_builder.reset() + self.mocap_session.reset() + self.last_commanded_motion_qpos = None + self._step_runner.reset_reference_alignment(resume_qpos) + self.last_commanded_motion_qpos = resume_qpos.copy() + _logger.info("Simulation mode -> %s", self.simulation_mode.value.upper()) + return True - def toggle_realtime_mocap_pause(self) -> None: + def toggle_realtime_mocap_pause(self) -> str: loop = self._loop if self.mocap_session.state == MocapSessionState.PAUSED: hold_qpos = self.mocap_session.hold_qpos if hold_qpos is None: raise RuntimeError("Cannot resume mocap without a paused hold qpos") resume_qpos = loop._build_resume_alignment_qpos(hold_qpos, loop.robot.get_state()) - self.full_policy_reset(warmup_steps=loop._ref_cfg.pause_resume_warmup_steps) + self.reset_policy_reference_state() self._step_runner.reset_reference_alignment(resume_qpos) self.last_commanded_motion_qpos = resume_qpos.copy() - return + return "resumed" hold_qpos = loop._resolve_hold_qpos( self.last_commanded_motion_qpos, self._step_runner.last_retarget_qpos, self.latest_live_retargeted, loop.robot.get_state(), ) - self.full_policy_reset() + self.reset_policy_reference_state() self.mocap_session.pause(hold_qpos) self.last_commanded_motion_qpos = hold_qpos.copy() + return "paused" # ------------------------------------------------------------------ # Keyboard handling @@ -298,18 +315,32 @@ def _handle_realtime_keyboard(self) -> bool: assert self.keyboard_reader is not None for key_event in self.keyboard_reader.poll(): key = key_event.key.lower() + if key == "h": + self._loop._console.help(self._loop._console_controls) + continue if key == "q": self.playback_stop_requested = True + self._loop._console.key_feedback("Q", "quit", result="stopping") return True if self.simulation_mode == SimulationMode.STANDING: if key == "y": - self.enter_mocap_mode() + if self.enter_mocap_mode(): + self._loop._console.key_feedback("Y", "mocap", result="MOCAP") + else: + self._loop._console.key_feedback("Y", "mocap", result="waiting for input") continue if key == "x": self.enter_standing_mode() + self._loop._console.key_feedback("X", "standing", result="STANDING") + continue + if key == "b": + if self.toggle_arms_mode(): + self._loop._console.key_feedback("B", "arms", result=self.simulation_mode.value.upper()) + else: + self._loop._console.key_feedback("B", "arms", result="ignored") continue if key == "a": - self.toggle_realtime_mocap_pause() + self._loop._console.key_feedback("A", "pause/resume", result=self.toggle_realtime_mocap_pause()) return False def _handle_offline_keyboard(self) -> bool: @@ -319,18 +350,22 @@ def _handle_offline_keyboard(self) -> bool: loop = self._loop for key_event in self.keyboard_reader.poll(): key = key_event.key.lower() + if key == "h": + self._loop._console.help(self._loop._console_controls) + continue if key == "q": self.playback_stop_requested = True + self._loop._console.key_feedback("Q", "stop", result="stopping") return True if key == "r": loop._restart_offline_playback( offline_playback=self.offline_playback, mocap_session=self.mocap_session, - retargeter=self._retargeter, ) self.cached_human_frame = None self.cached_retargeted = None self.last_commanded_motion_qpos = None + self._loop._console.key_feedback("R", "replay", result="frame 0") continue if key not in (" ", "p"): continue @@ -339,14 +374,15 @@ def _handle_offline_keyboard(self) -> bool: logging.getLogger(__name__).info( "Offline playback already ended; press r to replay from frame 0." ) + self._loop._console.key_feedback("Space/P", "pause/resume", result="ended; press R") else: loop._resume_offline_playback( offline_playback=self.offline_playback, mocap_session=self.mocap_session, - retargeter=self._retargeter, state=loop.robot.get_state(), ) self.last_commanded_motion_qpos = None + self._loop._console.key_feedback("Space/P", "pause/resume", result="resumed") else: hold_qpos = loop._resolve_hold_qpos( self.last_commanded_motion_qpos, @@ -358,8 +394,8 @@ def _handle_offline_keyboard(self) -> bool: offline_playback=self.offline_playback, mocap_session=self.mocap_session, hold_qpos=hold_qpos, - retargeter=self._retargeter, ) + self._loop._console.key_feedback("Space/P", "pause/resume", result="paused") return False # ------------------------------------------------------------------ @@ -368,9 +404,11 @@ def _handle_offline_keyboard(self) -> bool: def _fetch_standing_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeReferenceDiagnostics | None]: """Fetch input when in STANDING mode (keyboard). Returns (new_bvh_frame, ref_window, diag).""" - state = self._loop.robot.get_state() self.cached_human_frame = None - self.cached_retargeted = self._loop._build_standing_qpos(state) + if self._loop._standing_qpos is None: + self.cached_retargeted = self._loop._set_standing_reference(self._loop.robot.get_state()) + else: + self.cached_retargeted = self._loop._standing_qpos.copy() return False, None, None def _fetch_offline_reference_input( @@ -423,9 +461,11 @@ def _fetch_realtime_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeR frame_timestamp = float(packet.timestamp_s) frame_seq = int(packet.seq) for control_event in packet.control_events: - if control_event.event_type != ControlEventType.TOGGLE_PAUSE: + if control_event.event_type == ControlEventType.TOGGLE_ARMS: + self.toggle_arms_mode() continue - self.toggle_realtime_mocap_pause() + if control_event.event_type == ControlEventType.TOGGLE_PAUSE: + self.toggle_realtime_mocap_pause() new_bvh_frame = frame_seq != self.last_live_packet_seq if self.mocap_session.state == MocapSessionState.PAUSED: @@ -485,11 +525,8 @@ def _fetch_realtime_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeR self.reference_timeline, target_base_time, ) - if loop._ref_cfg.reference_debug_log: - if any(reference_window.fallback_mask()): - loop._log_reference_window(reference_window, len(self.reference_timeline)) - if realtime_reference_diag.used_repeat_padding: - loop._log_repeat_padding(reference_window, realtime_reference_diag, len(self.reference_timeline)) + if loop._ref_cfg.reference_debug_log and any(reference_window.fallback_mask()): + loop._log_reference_window(reference_window, len(self.reference_timeline)) self.cached_retargeted = reference_window.current_sample().qpos else: if self.latest_live_retargeted is None: @@ -512,6 +549,18 @@ def _fetch_realtime_input(self) -> tuple[bool, ReferenceWindow | None, RealtimeR else: self.cached_retargeted = self.latest_live_retargeted + from teleopit.sim.loop import SimulationMode + if self.simulation_mode == SimulationMode.ARMS: + self.cached_retargeted = loop._compose_arm_reference(cast(Float64Array, self.cached_retargeted)) + if reference_window is not None: + assert loop._standing_qpos is not None + reference_window = compose_arm_reference_window( + reference_window, + standing_qpos=loop._standing_qpos, + arm_joint_indices=loop._arm_joint_indices, + num_actions=loop._num_actions, + ) + return new_bvh_frame, reference_window, realtime_reference_diag, False def _fetch_simple_bvh_input(self, frame_f: float) -> tuple[bool, bool]: @@ -560,7 +609,7 @@ def run(self) -> dict[str, float | int]: if self.playback_stop_requested: break - if self.realtime_keyboard_mode_enabled and self.simulation_mode != SimulationMode.MOCAP: + if self.realtime_keyboard_mode_enabled and self.simulation_mode == SimulationMode.STANDING: loop._drain_realtime_control_events(self._input_provider) # --- Compute time/frame --- @@ -630,7 +679,6 @@ def run(self) -> dict[str, float | int]: target_dof_pos = self._step_runner.compute_target_dof_pos(action) torque, final_state = self._step_runner.apply_control(target_dof_pos) loop._publisher.publish(preparation.mimic_obs, action, final_state) - loop._recorder_helper.record(self._recorder, final_state, preparation.mimic_obs, action, target_dof_pos, torque) self._viewer_manager.write_sim2sim(loop.robot) self._viewer_manager.write_camera(loop.robot) if loop._video_runtime is not None: diff --git a/teleopit/sim2real/__init__.py b/teleopit/sim2real/__init__.py index 7fe810de..2c22fb9c 100644 --- a/teleopit/sim2real/__init__.py +++ b/teleopit/sim2real/__init__.py @@ -1,10 +1,22 @@ -from teleopit.sim2real.controller import Sim2RealController -from teleopit.sim2real.unitree_g1 import UnitreeG1Robot -from teleopit.sim2real.remote import UnitreeRemote, Button - __all__ = [ - "Sim2RealController", + "Sim2RealRuntime", "UnitreeG1Robot", "UnitreeRemote", "Button", ] + + +def __getattr__(name: str): + if name == "Sim2RealRuntime": + from teleopit.sim2real.mp import Sim2RealRuntime + + return Sim2RealRuntime + if name == "UnitreeG1Robot": + from teleopit.sim2real.unitree_g1 import UnitreeG1Robot + + return UnitreeG1Robot + if name in ("UnitreeRemote", "Button"): + from teleopit.sim2real.remote import Button, UnitreeRemote + + return {"UnitreeRemote": UnitreeRemote, "Button": Button}[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/teleopit/sim2real/controller.py b/teleopit/sim2real/controller.py deleted file mode 100644 index 8324743b..00000000 --- a/teleopit/sim2real/controller.py +++ /dev/null @@ -1,917 +0,0 @@ -"""Sim2Real controller -- state machine + dual-mode control loop. - -Supports two operating modes for a physical Unitree G1: -- **Standing**: RL policy maintains balance with fixed default-pose reference -- **Mocap**: RL policy tracks retargeted motion commands - -State machine: - IDLE ──Start──▶ STANDING ──Y──▶ MOCAP ──X──▶ STANDING - Any ──L1+R1──▶ DAMPING ──Start──▶ STANDING -""" - -from __future__ import annotations - -import logging -import time -from enum import Enum -from pathlib import Path -from typing import Any - -import numpy as np -from numpy.typing import NDArray - -from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM -from teleopit.controllers.observation import ( - VelCmdObservationBuilder, - align_motion_qpos_yaw, -) -from teleopit.controllers.qpos_interpolator import QposInterpolator -from teleopit.controllers.rl_policy import RLPolicyController -from teleopit.inputs.bvh_provider import BVHInputProvider -from teleopit.inputs.pico4_provider import Pico4InputProvider -from teleopit.inputs.pico_video import PicoVideoRuntime, parse_pico_video_config -from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType, RealtimeInputPacket -from teleopit.retargeting.core import RetargetingModule -from teleopit.runtime.common import cfg_get -from teleopit.runtime.reference_config import parse_reference_config -from teleopit.runtime.factory import build_sim2real_mocap_components -from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState -from teleopit.runtime.offline_playback import OfflinePlaybackController -from teleopit.sim.reference_motion import OfflineReferenceMotion -from teleopit.sim.reference_timeline import ReferenceTimeline, ReferenceWindow, ReferenceWindowBuilder -from teleopit.sim.reference_utils import ( - build_offline_reference_window, - build_static_reference_window, - obs_builder_requires_reference_window, -) -from teleopit.sim.realtime_utils import RealtimeReferenceManager -from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor -from teleopit.sim2real.remote import UnitreeRemote -from teleopit.sim2real.safety import Sim2RealSafetyManager -from teleopit.sim2real.unitree_g1 import UnitreeG1Robot - -logger = logging.getLogger(__name__) - -Float32Array = NDArray[np.float32] -Float64Array = NDArray[np.float64] - - -class RobotMode(Enum): - IDLE = "idle" # Script waiting, robot controlled by remote - STANDING = "standing" # Debug mode, RL policy holds default pose - MOCAP = "mocap" # Debug mode, RL policy tracks motion commands - DAMPING = "damping" # Emergency stop / recovery - - -class Sim2RealController: - """G1 real-robot controller -- standing/mocap dual mode with state machine. - - Standing mode: enter debug mode, RL policy maintains balance at default pose. - Mocap mode: RL policy tracks retargeted motion commands. - Both modes share the same RL policy inference pipeline. - """ - - def __init__(self, cfg: Any) -> None: - self.cfg = cfg - self.mode = RobotMode.IDLE - - self.policy_hz: float = float(cfg_get(cfg, "policy_hz", 50.0)) - self._project_root = Path(__file__).resolve().parent.parent.parent - - # Motion command transition smoothing - transition_dur = float(cfg_get(cfg, "transition_duration", 0.0) or 0.0) - self._mocap_transition_duration = transition_dur - self._pause_resume_transition_duration = float( - cfg_get(cfg, "pause_resume_transition_duration", transition_dur) or 0.0 - ) - self._qpos_interpolator = QposInterpolator(transition_dur, self.policy_hz) - - self._init_components(cfg) - self._init_reference_config(cfg) - self._safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) - - logger.info( - "Sim2RealController ready | mode=IDLE | policy_hz=%.0f", - self.policy_hz, - ) - - def _init_components(self, cfg: Any) -> None: - """Build robot hardware and mocap pipeline components.""" - real_cfg = cfg_get(cfg, "real_robot") - self.robot = UnitreeG1Robot(real_cfg) - self.remote = UnitreeRemote() - - robot_cfg = cfg_get(cfg, "robot") - mocap_components = build_sim2real_mocap_components( - cfg, - self._project_root, - controller_cls=RLPolicyController, - obs_builder_cls=VelCmdObservationBuilder, - bvh_input_cls=BVHInputProvider, - pico4_input_cls=Pico4InputProvider, - retargeter_cls=RetargetingModule, - ) - self.input_provider = mocap_components.input_provider - self.retargeter = mocap_components.retargeter - self.policy = mocap_components.controller - self.obs_builder = mocap_components.obs_builder - self._video_runtime = PicoVideoRuntime( - provider=self.input_provider, - config=parse_pico_video_config(cfg_get(cfg, "input", {})), - mode="sim2real", - ) - self._offline_reference: OfflineReferenceMotion | None = None - self._offline_playback: OfflinePlaybackController | None = None - if hasattr(self.input_provider, "__len__") and hasattr(self.input_provider, "get_frame_by_index"): - playback_cfg = cfg_get(cfg, "playback", {}) - self._offline_reference = OfflineReferenceMotion(self.input_provider, self.retargeter) - self._offline_playback = OfflinePlaybackController( - duration_s=self._offline_reference.duration_s, - step_dt_s=1.0 / self.policy_hz, - pause_on_end=bool(cfg_get(playback_cfg, "pause_on_end", True)), - ) - if not bool(getattr(self.policy, "_multi_input", False)): - raise ValueError( - "Sim2real requires an ONNX policy with dual inputs ('obs' and 'obs_history')." - ) - - # Default standing pose (29-DOF) - self.default_angles = np.asarray( - cfg_get(robot_cfg, "default_angles"), dtype=np.float32 - ) - self.num_actions: int = int(cfg_get(robot_cfg, "num_actions", NUM_JOINTS)) - - # Standing mode reference qpos - self._standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - self._standing_qpos[3] = 1.0 # identity quaternion w=1 - self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) - - # Policy state (shared by STANDING and MOCAP) - self._last_action: Float32Array = np.zeros(self.num_actions, dtype=np.float32) - self._last_retarget_qpos: Float64Array | None = None - self._mocap_reentry_armed: bool = False - self._mocap_session = MocapSessionManager() - self._last_commanded_motion_qpos: Float64Array | None = None - - def _init_reference_config(self, cfg: Any) -> None: - """Parse reference-window / realtime-buffer configuration.""" - provider_fps = float(getattr(self.input_provider, "fps", 30.0)) - self._ref_cfg = parse_reference_config(cfg, provider_fps=provider_fps) - rc = self._ref_cfg - - self._reference_window_builder = ReferenceWindowBuilder( - policy_dt_s=1.0 / self.policy_hz, - reference_steps=cfg_get(cfg, "reference_steps", [0]), - ) - if not rc.retarget_buffer_enabled and self._reference_window_builder.requires_timeline: - raise ValueError( - "Non-zero reference_steps require retarget_buffer_enabled=true in sim2real so " - "the realtime reference timeline can sample future/history horizons." - ) - if rc.retarget_buffer_enabled: - self._reference_window_builder.validate_runtime_support( - delay_s=rc.reference_delay_s, - window_s=rc.retarget_buffer_window_s, - config_label="Sim2Real reference timeline", - ) - self._reference_timeline: ReferenceTimeline | None = ( - ReferenceTimeline(window_s=rc.retarget_buffer_window_s) - if rc.retarget_buffer_enabled - else None - ) - self._reference_manager: RealtimeReferenceManager | None = ( - RealtimeReferenceManager( - reference_window_builder=self._reference_window_builder, - low_watermark_steps=rc.realtime_buffer_low_watermark_steps, - high_watermark_steps=rc.realtime_buffer_high_watermark_steps, - warmup_steps=rc.realtime_buffer_warmup_steps, - catchup_enabled=rc.realtime_catchup_enabled, - catchup_trigger_steps=rc.realtime_catchup_trigger_steps, - catchup_release_steps=rc.realtime_catchup_release_steps, - catchup_target_delay_s=rc.realtime_catchup_target_delay_s, - ) - if self._reference_timeline is not None - else None - ) - mocap_sw = cfg_get(cfg, "mocap_switch", {}) - self._ref_proc = Sim2RealReferenceProcessor( - obs_builder=self.obs_builder, - policy=self.policy, - policy_hz=self.policy_hz, - num_actions=self.num_actions, - reference_velocity_smoothing_alpha=rc.reference_velocity_smoothing_alpha, - reference_anchor_velocity_smoothing_alpha=rc.reference_anchor_velocity_smoothing_alpha, - reference_qpos_smoothing_alpha=rc.reference_qpos_smoothing_alpha, - max_pos_value=float(cfg_get(mocap_sw, "max_position_value", 5.0)), - ) - self._last_live_packet_seq = -1 - - # Mocap switch safety - mocap_sw = cfg_get(cfg, "mocap_switch", {}) - self._check_frames: int = int(cfg_get(mocap_sw, "check_frames", 10)) - - # ------------------------------------------------------------------ - # Main control loop - # ------------------------------------------------------------------ - - def run(self) -> None: - """Main control loop at policy_hz.""" - logger.info( - "Control loop started | mode=IDLE | press Start to enter STANDING" - ) - dt = 1.0 / self.policy_hz - - try: - self._video_runtime.start() - while True: - t0 = time.monotonic() - self._video_runtime.tick() - - # 1. Read remote state - remote_bytes = self.robot.get_wireless_remote() - self.remote.update(remote_bytes) - - # 2. Emergency stop (highest priority) - if self.remote.LB.pressed and self.remote.RB.pressed: - if self.mode != RobotMode.DAMPING: - logger.warning("EMERGENCY STOP (L1+R1)") - self._enter_damping() - self._sleep_until(t0, dt) - continue - - # 3. Mode transitions - self._handle_transitions() - - # 5. Execute current mode - if self.mode == RobotMode.STANDING: - self._standing_step() - elif self.mode == RobotMode.MOCAP: - self._mocap_step() - - # 6. Rate control - self._sleep_until(t0, dt) - - except KeyboardInterrupt: - logger.info("KeyboardInterrupt -- shutting down") - - # ------------------------------------------------------------------ - # Mode execution - # ------------------------------------------------------------------ - - def _standing_step(self) -> None: - """Standing mode: feed fixed default-pose reference to RL policy.""" - robot_state = self.robot.get_state() - - # Build standing reference qpos aligned to robot's current yaw - qpos = self._standing_qpos.copy() - align_motion_qpos_yaw( - np.asarray(robot_state.quat, dtype=np.float32), qpos - ) - - # Standing → zero joint velocity reference - motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) - motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) - - reference_window = None - if obs_builder_requires_reference_window(self.obs_builder): - reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) - - obs = self._ref_proc.build_observation( - robot_state=robot_state, - motion_qpos=motion_qpos, - motion_joint_vel=motion_joint_vel, - last_action=self._last_action, - anchor_lin_vel_w=np.zeros(3, dtype=np.float32), - anchor_ang_vel_w=np.zeros(3, dtype=np.float32), - reference_window=reference_window, - ) - obs = self._ref_proc.validate_observation(obs) - - action = self.policy.compute_action(obs) - target_dof_pos = self.policy.get_target_dof_pos(action) - target_dof_pos = self._safety.clip_to_joint_limits(target_dof_pos) - - self._safety.send_positions(target_dof_pos) - - self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - self._last_retarget_qpos = qpos.copy() - self._last_commanded_motion_qpos = qpos.copy() - - def _mocap_step(self) -> None: - """Mocap mode: input provider -> retarget -> policy -> update LowCmd targets.""" - if self._offline_reference is not None: - self._offline_mocap_step() - return - - if not self.input_provider.is_available(): - logger.warning("Input provider unavailable -- entering damping") - self._enter_damping() - return - - try: - packet = self._fetch_realtime_input_packet() - except (TimeoutError, RuntimeError): - logger.warning("Input provider error -- entering damping") - self._enter_damping() - return - - self._handle_mocap_control_events(packet.control_events) - if self._mocap_session.state == MocapSessionState.PAUSED: - self._paused_mocap_step() - return - - human_frame = packet.frame - frame_timestamp = float(packet.timestamp_s) - frame_seq = int(packet.seq) - - reference_window: ReferenceWindow | None = None - if self._reference_timeline is not None: - if int(frame_seq) != self._last_live_packet_seq: - retargeted = self.retargeter.retarget(human_frame) - self._reference_timeline.append( - self._ref_proc.retarget_to_qpos(retargeted), - float(frame_timestamp), - ) - if self._reference_manager is not None: - self._reference_manager.note_realtime_frame() - self._last_live_packet_seq = int(frame_seq) - if self._reference_manager is None: - raise RuntimeError("Realtime reference manager must be initialized when using reference_timeline") - if not self._reference_manager.warmup_done: - return - reference_window, reference_diag = self._reference_manager.sample( - self._reference_timeline, - time.monotonic() - self._ref_cfg.reference_delay_s, - ) - if self._ref_cfg.reference_debug_log and any(reference_window.fallback_mask()): - logger.warning( - "Reference timeline fallback | buffer_len=%d | steps=%s | modes=%s", - len(self._reference_timeline), - list(reference_window.reference_steps), - list(reference_window.modes()), - ) - if self._ref_cfg.reference_debug_log and reference_diag.used_repeat_padding: - logger.warning( - "Reference timeline repeat padding | buffer_len=%d | future_horizon_steps=%d | steps=%s", - len(self._reference_timeline), - reference_diag.future_horizon_steps, - list(reference_window.reference_steps), - ) - if self._ref_cfg.reference_debug_log and reference_diag.used_catchup: - logger.warning( - "Reference timeline catch-up | requested_base=%.6f | effective_base=%.6f | latest=%.6f | future_horizon_steps=%d", - reference_diag.requested_base_time_s, - reference_diag.effective_base_time_s, - -1.0 if reference_diag.latest_timestamp_s is None else reference_diag.latest_timestamp_s, - reference_diag.future_horizon_steps, - ) - reference_qpos = reference_window.current_sample().qpos - else: - retargeted = self.retargeter.retarget(human_frame) - reference_qpos = self._ref_proc.retarget_to_qpos(retargeted) - - robot_state = self.robot.get_state() - self._execute_mocap_pipeline(reference_qpos, robot_state, reference_window) - - def _offline_mocap_step(self) -> None: - if self._offline_reference is None or self._offline_playback is None: - raise RuntimeError("Offline playback step requires an offline reference motion") - - if self._mocap_session.state == MocapSessionState.PAUSED: - self._paused_mocap_step() - return - - sample_time_s = self._offline_playback.current_time_s - sampled = self._offline_reference.sample(sample_time_s) - if sampled is None: - self._hold_completed_offline_playback(self._resolve_mocap_hold_qpos()) - self._paused_mocap_step() - return - - reference_window: ReferenceWindow | None = None - if obs_builder_requires_reference_window(self.obs_builder): - reference_window = build_offline_reference_window( - self._offline_reference, - sample_time_s, - self._reference_window_builder, - self.policy_hz, - ) - - reference_qpos = np.asarray(sampled.qpos, dtype=np.float64).copy() - robot_state = self.robot.get_state() - self._execute_mocap_pipeline(reference_qpos, robot_state, reference_window) - - if self._offline_playback.advance(): - self._hold_completed_offline_playback(self._last_commanded_motion_qpos) - - def _execute_mocap_pipeline( - self, - reference_qpos: Float64Array, - robot_state: object, - reference_window: ReferenceWindow | None, - ) -> None: - """Shared mocap control pipeline: align → smooth → interpolate → infer → send.""" - reference_qpos = self._ref_proc.align_reference_yaw(reference_qpos, robot_state=robot_state) - reference_qpos = self._ref_proc.apply_qpos_smoothing(reference_qpos) - qpos = self._qpos_interpolator.apply(reference_qpos) - - # Compute joint velocities via finite difference - if qpos.shape[0] < 7 + self.num_actions: - raise ValueError( - f"Retargeted qpos too short: {qpos.shape[0]} (need >= {7 + self.num_actions})" - ) - motion_joint_pos = np.asarray(qpos[7:7 + self.num_actions], dtype=np.float32) - if self._last_retarget_qpos is None: - raw_motion_joint_vel = np.zeros((self.num_actions,), dtype=np.float32) - else: - prev_joint_pos = np.asarray(self._last_retarget_qpos[7:7 + self.num_actions], dtype=np.float32) - raw_motion_joint_vel = (motion_joint_pos - prev_joint_pos) * np.float32(self.policy_hz) - motion_joint_vel = self._ref_proc.apply_joint_vel_smoothing(raw_motion_joint_vel) - - # Compute anchor velocities (with interpolator blending if active) - anchor_lin_vel_w = np.zeros(3, dtype=np.float32) - anchor_ang_vel_w = np.zeros(3, dtype=np.float32) - if not obs_builder_requires_reference_window(self.obs_builder): - if self._qpos_interpolator.is_active: - true_lin_vel_w, true_ang_vel_w = self._ref_proc.compute_anchor_velocities(reference_qpos) - blend = np.float32(self._qpos_interpolator.last_alpha) - raw_anchor_lin_vel_w = np.asarray(true_lin_vel_w * blend, dtype=np.float32) - raw_anchor_ang_vel_w = np.asarray(true_ang_vel_w * blend, dtype=np.float32) - else: - raw_anchor_lin_vel_w, raw_anchor_ang_vel_w = self._ref_proc.compute_anchor_velocities(reference_qpos) - anchor_lin_vel_w, anchor_ang_vel_w = self._ref_proc.apply_anchor_vel_smoothing( - raw_anchor_lin_vel_w, raw_anchor_ang_vel_w, - ) - - # Build observation and run policy - motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) - obs = self._ref_proc.build_observation( - robot_state=robot_state, - motion_qpos=motion_qpos, - motion_joint_vel=motion_joint_vel, - last_action=self._last_action, - anchor_lin_vel_w=anchor_lin_vel_w, - anchor_ang_vel_w=anchor_ang_vel_w, - reference_window=reference_window, - ) - obs = self._ref_proc.validate_observation(obs) - - action = self.policy.compute_action(obs) - target_dof_pos = self.policy.get_target_dof_pos(action) - target_dof_pos = self._safety.clip_to_joint_limits(target_dof_pos) - self._safety.send_positions(target_dof_pos) - - # Update state - self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - self._last_retarget_qpos = qpos.copy() - self._ref_proc.last_reference_qpos = reference_qpos.copy() - self._last_commanded_motion_qpos = qpos.copy() - - # ------------------------------------------------------------------ - # State machine transitions - # ------------------------------------------------------------------ - - def _handle_transitions(self) -> None: - """Handle remote-triggered mode transitions.""" - if self.mode == RobotMode.IDLE: - if self.remote.start.on_pressed: - logger.info("Start pressed (from IDLE)") - self._enter_standing() - - elif self.mode == RobotMode.STANDING: - reentry_request = self._mocap_reentry_armed and self.remote.Y.pressed - if self.remote.Y.on_pressed or reentry_request: - if self._can_switch_to_mocap(): - if reentry_request and not self.remote.Y.on_pressed: - logger.info("Y held after STANDING return -> re-entering MOCAP") - else: - logger.info("Y pressed -> entering MOCAP") - self._transition_to_mocap() - else: - logger.warning("Cannot switch to MOCAP -- input check failed") - - elif self.mode == RobotMode.MOCAP: - if self.remote.B.on_pressed and self._offline_playback is not None: - logger.info("B pressed -> replaying offline motion from start") - self._restart_offline_playback() - return - if self.remote.A.on_pressed: - if self._mocap_session.state == MocapSessionState.PAUSED: - if self._offline_playback is not None and self._offline_playback.finished: - logger.info("Playback already ended; press B to replay from the start") - else: - logger.info("A pressed -> resuming playback") - self._resume_paused_mocap() - else: - logger.info("A pressed -> pausing playback") - self._pause_active_mocap() - return - if self.remote.X.on_pressed: - logger.info("X pressed -> returning to STANDING") - self._enter_standing() - - elif self.mode == RobotMode.DAMPING: - if self.remote.start.on_pressed: - logger.info("Start pressed (from DAMPING)") - self._enter_standing() - - # ------------------------------------------------------------------ - # Enter STANDING (from IDLE, MOCAP, or DAMPING) - # ------------------------------------------------------------------ - - def _enter_standing(self) -> None: - """Enter standing mode: debug mode + RL policy holds default pose. - - Works from IDLE, MOCAP (already in debug mode), or DAMPING. - """ - prev_mode = self.mode - already_in_debug = self.mode in (RobotMode.STANDING, RobotMode.MOCAP) - - if not already_in_debug: - logger.info("Entering debug mode...") - ok = self.robot.enter_debug_mode() - if not ok: - logger.error("Failed to enter debug mode -- staying in %s", self.mode.value) - return - time.sleep(0.5) - - # Lock joints to current position (prevent collapse during init) - logger.info("Locking joints to current position...") - self.robot.lock_all_joints() - time.sleep(0.3) - - # Episode-reset semantics: reference = current robot state, full policy reset. - # This matches training where robot is teleported to reference position at - # episode start, so policy sees reference ≈ robot state with clean history. - state = self.robot.get_state() - init_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - init_qpos[3:7] = state.quat.astype(np.float64) - init_qpos[ROOT_DIM:FULL_QPOS_DIM] = state.qpos.astype(np.float64) - self._last_retarget_qpos = init_qpos - self._ref_proc.last_reference_qpos = None - self._mocap_session.reset() - self._last_commanded_motion_qpos = None - - # Always do a full policy reset (episode-reset semantics) to ensure - # the TemporalCNN history is clean and action-state causality holds. - self._reset_policy_state() - - # Kp ramp: gradually increase PD gains to avoid torque spike. - # Unlike the old position ramp, this does NOT break action-state causality. - self._safety.start_kp_ramp() - - self._mocap_reentry_armed = prev_mode == RobotMode.MOCAP - - self.mode = RobotMode.STANDING - logger.info("Mode -> STANDING (RL policy maintaining balance at default pose)") - - # ------------------------------------------------------------------ - # STANDING -> MOCAP - # ------------------------------------------------------------------ - - def _can_switch_to_mocap(self) -> bool: - """Verify input signal is stable and values are reasonable.""" - if not self.input_provider.is_available(): - logger.warning("Mocap check: input provider not available") - return False - - if self._offline_reference is not None: - frame_count = min(self._check_frames, self._offline_reference.num_frames) - valid_count = 0 - for frame_index in range(frame_count): - try: - frame = self.input_provider.get_frame_by_index(frame_index) - except (IndexError, RuntimeError, ValueError): - return False - if self._ref_proc.frame_is_valid(frame): - valid_count += 1 - else: - break - if valid_count >= frame_count: - return True - logger.warning("Mocap check: only %d/%d valid offline frames", valid_count, frame_count) - return False - - has_frame = getattr(self.input_provider, "has_frame", None) - if callable(has_frame): - try: - if not bool(has_frame()): - logger.warning("Mocap check: realtime input has no frame available yet") - return False - except Exception: - logger.warning("Mocap check: failed to query realtime input availability") - return False - - valid_count = 0 - for _ in range(self._check_frames + 5): - try: - frame = self.input_provider.get_frame() - except (TimeoutError, RuntimeError): - return False - - if self._ref_proc.frame_is_valid(frame): - valid_count += 1 - else: - valid_count = 0 - - if valid_count >= self._check_frames: - return True - - time.sleep(0.02) - - logger.warning("Mocap check: only %d/%d valid frames", valid_count, self._check_frames) - return False - - def _transition_to_mocap(self) -> None: - """Switch from STANDING -> MOCAP. - - Episode-reset + reference-side interpolation. The policy state is - fully reset (clean history, zero last_action) so the TemporalCNN - starts fresh. A QposInterpolator smoothly blends the *reference* - from the current robot state toward incoming live mocap so the policy - never sees a large instantaneous tracking error. - """ - state = self.robot.get_state() - init_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - init_qpos[3:7] = state.quat.astype(np.float64) - init_qpos[ROOT_DIM:FULL_QPOS_DIM] = state.qpos.astype(np.float64) - self._last_retarget_qpos = init_qpos - self._last_commanded_motion_qpos = init_qpos.copy() - self._mocap_reentry_armed = False - - # Full episode reset: clean policy state, alignment, timeline. - self._reset_policy_state() - - # Reference-side interpolation: smoothly blend reference from current - # robot state toward incoming live mocap. This is done AFTER the - # episode reset so the interpolator starts with a clean slate. - self._arm_qpos_transition(init_qpos, duration_s=self._mocap_transition_duration) - if self._offline_playback is not None: - self._offline_playback.replay() - - self.mode = RobotMode.MOCAP - logger.info("Mode -> MOCAP (tracking motion commands)") - - # ------------------------------------------------------------------ - # Emergency stop / damping - # ------------------------------------------------------------------ - - def _enter_damping(self) -> None: - """Enter damping mode from any state.""" - if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): - logger.info("DAMPING: sending LowCmd damping...") - self.robot.set_damping() - time.sleep(0.5) - logger.info("DAMPING: exiting debug mode...") - self.robot.exit_debug_mode() - - self.mode = RobotMode.DAMPING - self._ref_proc.last_reference_qpos = None - if self._reference_timeline is not None: - self._reference_timeline.clear() - self._last_live_packet_seq = -1 - self._mocap_reentry_armed = False - self._mocap_session.reset() - self._last_commanded_motion_qpos = None - logger.info("Mode -> DAMPING (press Start to re-enter STANDING)") - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _reset_policy_state(self) -> None: - """Full episode-reset: clear all policy state so the TemporalCNN sees - a clean start identical to training episode reset.""" - self._last_action = np.zeros(self.num_actions, dtype=np.float32) - self._qpos_interpolator.reset() - self._reset_mocap_reference_state() - self._ref_proc.reset_alignment() - self._mocap_session.reset() - self._last_commanded_motion_qpos = None - self.policy.reset() - self.obs_builder.reset() - self.retargeter.reset() - - def _reset_mocap_reference_state(self, *, warmup_steps: int | None = None) -> None: - """Reset mocap-specific reference state without disrupting policy observation continuity. - - Unlike ``_reset_policy_state``, this preserves ``_last_action``, the - policy history buffer, and the observation builder state so that the - TemporalCNN sees a continuous observation stream across mode switches. - """ - if self._reference_timeline is not None: - self._reference_timeline.clear() - if self._reference_manager is not None: - self._reference_manager.set_warmup_steps( - self._ref_cfg.realtime_buffer_warmup_steps if warmup_steps is None else warmup_steps - ) - self._reference_manager.reset() - self._ref_proc.reset_smoothers() - self._last_live_packet_seq = -1 - - def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: object) -> Float64Array: - qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - if hold_qpos is not None: - qpos[0:2] = np.asarray(hold_qpos, dtype=np.float64).reshape(-1)[0:2] - base_pos = getattr(state, "base_pos", None) - if base_pos is not None: - qpos[0:2] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[0:2] - qpos[3:7] = np.asarray(getattr(state, "quat"), dtype=np.float64) - qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(getattr(state, "qpos"), dtype=np.float64) - return qpos - - def _restart_offline_playback(self) -> None: - if self._offline_playback is None: - raise RuntimeError("Offline playback replay is only available for indexed BVH input") - - state = self.robot.get_state() - restart_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - restart_qpos[3:7] = state.quat.astype(np.float64) - restart_qpos[ROOT_DIM:FULL_QPOS_DIM] = state.qpos.astype(np.float64) - - self._last_retarget_qpos = restart_qpos.copy() - self._last_commanded_motion_qpos = restart_qpos.copy() - self._offline_playback.replay() - self._reset_policy_state() - self._arm_qpos_transition(restart_qpos, duration_s=self._mocap_transition_duration) - logger.info("Offline playback restarted from frame 0") - - def _hold_completed_offline_playback(self, hold_qpos: Float64Array) -> None: - if self._offline_playback is None or self._mocap_session.state == MocapSessionState.PAUSED: - return - self._offline_playback.finish() - self._mocap_session.pause(hold_qpos) - logger.info("Offline playback reached the end; press B to replay") - - def _arm_qpos_transition(self, start_qpos: Float64Array, *, duration_s: float) -> None: - self._qpos_interpolator.reset() - self._qpos_interpolator.configure(duration_s) - self._qpos_interpolator.start(start_qpos) - - def _fetch_realtime_input_packet(self) -> RealtimeInputPacket[object]: - get_realtime_input_packet = getattr(self.input_provider, "get_realtime_input_packet", None) - if callable(get_realtime_input_packet): - return get_realtime_input_packet() - - get_packet = getattr(self.input_provider, "get_frame_packet", None) - if callable(get_packet): - frame, frame_timestamp, frame_seq = get_packet() - return RealtimeInputPacket( - frame=frame, - timestamp_s=float(frame_timestamp), - seq=int(frame_seq), - control_events=(), - ) - - frame = self.input_provider.get_frame() - return RealtimeInputPacket( - frame=frame, - timestamp_s=time.monotonic(), - seq=self._last_live_packet_seq + 1, - control_events=(), - ) - - def _handle_mocap_control_events(self, control_events: tuple[ControlEvent, ...]) -> None: - for event in control_events: - if event.event_type != ControlEventType.TOGGLE_PAUSE: - continue - if self._mocap_session.state == MocapSessionState.PAUSED: - self._resume_paused_mocap() - else: - self._pause_active_mocap() - - def _pause_active_mocap(self) -> None: - # Episode-reset semantics: treat pause as a new episode starting at - # the hold pose. Full policy reset ensures TemporalCNN history is - # clean -- no stale frames from the previous motion that would create - # an OOD discontinuity (reference jumps from motion to static). - hold_qpos = self._resolve_mocap_hold_qpos() - self._last_retarget_qpos = hold_qpos.copy() - self._ref_proc.last_reference_qpos = hold_qpos.copy() - self._last_commanded_motion_qpos = hold_qpos.copy() - - # Reset policy state (clears last_action, history, smoothers, etc.) - # Note: _reset_policy_state resets _mocap_session to ACTIVE, so we - # must call pause() *after* it to set the correct PAUSED state. - self._reset_policy_state() - self._mocap_session.pause(hold_qpos) - if self._offline_playback is not None: - self._offline_playback.pause() - logger.info("Mocap session -> PAUSED (episode-reset)") - - def _resume_paused_mocap(self) -> None: - if self._offline_playback is not None and self._offline_playback.finished: - logger.info("Offline playback already ended; press B to replay from the start") - return - - hold_qpos = self._mocap_session.hold_qpos - if hold_qpos is None: - raise RuntimeError("Cannot resume mocap without a paused hold qpos") - state = self.robot.get_state() - resume_qpos = self._build_resume_alignment_qpos(hold_qpos, state) - - self._last_commanded_motion_qpos = resume_qpos.copy() - - # Full policy reset -- clean history, zero last_action, smoothers, - # timeline, alignment. Also resets _mocap_session to ACTIVE. - self._reset_policy_state() - self._last_retarget_qpos = None - self._last_commanded_motion_qpos = resume_qpos.copy() - - # Override warmup steps for the resume-specific buffer warmup. - if self._reference_manager is not None: - self._reference_manager.set_warmup_steps(self._ref_cfg.pause_resume_warmup_steps) - self._reference_manager.reset() - - self._ref_proc.reset_alignment(target_qpos=resume_qpos) - if self._offline_playback is not None: - self._last_retarget_qpos = resume_qpos.copy() - self._arm_qpos_transition(resume_qpos, duration_s=self._pause_resume_transition_duration) - self._offline_playback.resume() - - logger.info("Mocap session -> ACTIVE (episode-reset + reference realignment)") - - def _resolve_mocap_hold_qpos(self) -> Float64Array: - if self._last_commanded_motion_qpos is not None: - return self._last_commanded_motion_qpos.copy() - if self._last_retarget_qpos is not None: - return np.asarray(self._last_retarget_qpos, dtype=np.float64).copy() - state = self.robot.get_state() - hold_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) - hold_qpos[3:7] = np.asarray(state.quat, dtype=np.float64) - hold_qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(state.qpos, dtype=np.float64) - return hold_qpos - - def _paused_mocap_step(self) -> None: - hold_qpos = self._mocap_session.hold_qpos - if hold_qpos is None: - raise RuntimeError("Paused mocap session is missing a hold_qpos") - self._run_static_mocap_step(hold_qpos) - - def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: - robot_state = self.robot.get_state() - qpos = np.asarray(hold_qpos, dtype=np.float64).copy() - motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) - motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) - reference_window = None - if obs_builder_requires_reference_window(self.obs_builder): - reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) - - obs = self._ref_proc.build_observation( - robot_state=robot_state, - motion_qpos=motion_qpos, - motion_joint_vel=motion_joint_vel, - last_action=self._last_action, - anchor_lin_vel_w=np.zeros(3, dtype=np.float32), - anchor_ang_vel_w=np.zeros(3, dtype=np.float32), - reference_window=reference_window, - ) - obs = self._ref_proc.validate_observation(obs) - - action = self.policy.compute_action(obs) - target_dof_pos = self.policy.get_target_dof_pos(action) - target_dof_pos = self._safety.clip_to_joint_limits(target_dof_pos) - - self._safety.send_positions(target_dof_pos) - self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) - self._last_retarget_qpos = qpos.copy() - self._ref_proc.last_reference_qpos = qpos.copy() - self._last_commanded_motion_qpos = qpos.copy() - - @staticmethod - def _sleep_until(t0: float, dt: float) -> None: - """Sleep to maintain control frequency.""" - elapsed = time.monotonic() - t0 - remaining = dt - elapsed - if remaining > 0: - time.sleep(remaining) - - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - - def shutdown(self) -> None: - """Clean shutdown.""" - logger.info("Shutting down Sim2RealController") - if self.mode in (RobotMode.STANDING, RobotMode.MOCAP): - try: - self.robot.set_damping() - time.sleep(0.5) - except Exception: - pass - try: - self.robot.exit_debug_mode() - except Exception: - pass - try: - self._video_runtime.stop() - except Exception: - pass - try: - self.input_provider.close() - except Exception: - pass - try: - self.robot.close() - except Exception: - pass diff --git a/teleopit/sim2real/hands/__init__.py b/teleopit/sim2real/hands/__init__.py new file mode 100644 index 00000000..de2de569 --- /dev/null +++ b/teleopit/sim2real/hands/__init__.py @@ -0,0 +1,3 @@ +from teleopit.sim2real.hands.worker import build_hand_runtime + +__all__ = ["build_hand_runtime"] diff --git a/teleopit/sim2real/hands/base.py b/teleopit/sim2real/hands/base.py new file mode 100644 index 00000000..76257789 --- /dev/null +++ b/teleopit/sim2real/hands/base.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol, Sequence + + +HAND_SIDES = ("left", "right") + + +@dataclass(frozen=True) +class HandPoseCommand: + side: str + pose: tuple[int, ...] + force: bool = False + reason: str = "" + + +class HandDevice(Protocol): + def connect(self) -> None: ... + + def send_pose(self, side: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: ... + + def open_all(self, *, force: bool = False, reason: str = "") -> None: ... + + def close(self) -> None: ... + + +class HandInputMapper(Protocol): + def start(self) -> None: ... + + def map(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float) -> tuple[HandPoseCommand, ...]: ... + + def close(self) -> None: ... diff --git a/teleopit/sim2real/hands/linkerhand_l6.py b/teleopit/sim2real/hands/linkerhand_l6.py new file mode 100644 index 00000000..d7ed00be --- /dev/null +++ b/teleopit/sim2real/hands/linkerhand_l6.py @@ -0,0 +1,453 @@ +from __future__ import annotations + +from dataclasses import dataclass +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +import importlib.util +import logging +from typing import Any, Sequence + +import numpy as np + +from teleopit.runtime.common import cfg_get +from teleopit.sim2real.hands.base import HAND_SIDES, HandDevice, HandInputMapper, HandPoseCommand +from teleopit.sim2real.hands.pico_landmarks import pico_hand_to_landmarks + +logger = logging.getLogger(__name__) + +PROJECT_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_SOMEHAND_CONFIG = "third_party/somehand/configs/retargeting/bihand/linkerhand_l6_bihand.yaml" +DEFAULT_LINKERHAND_SDK_ROOT = "third_party/linkerhand-python-sdk" +THUMB_YAW_DEFAULT = 10 +OPEN_POSE = (250, THUMB_YAW_DEFAULT, 250, 250, 250, 250) +CLOSE_POSE = (79, THUMB_YAW_DEFAULT, 0, 0, 0, 0) +DEFAULT_SPEED = (50, 50, 50, 50, 50, 50) +VR_HAND_POSE_SPEED = (255, 255, 255, 255, 255, 255) +L6_SDK_JOINT_ORDER = ( + "thumb_cmc_pitch", + "thumb_cmc_roll", + "index_mcp_pitch", + "middle_mcp_pitch", + "ring_mcp_pitch", + "pinky_mcp_pitch", +) + + +@dataclass(frozen=True) +class LinkerHandL6Config: + mode: str + sides: tuple[str, ...] + left_can: str + right_can: str + modbus: str + rate_hz: float + frame_timeout_s: float + trigger_deadzone: float + deadman_threshold: float + thumb_yaw_center: int + speed: tuple[int, ...] + open_pose: tuple[int, ...] + close_pose: tuple[int, ...] + fixed_thumb_yaw: int | None + print_input: bool + somehand_config_path: str + somehand_rate_hz: float + somehand_max_iterations: int | None + somehand_temporal_filter_alpha: float | None + somehand_output_alpha: float | None + + +def parse_linkerhand_l6_config(cfg: Any) -> LinkerHandL6Config: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + l6_cfg = cfg_get(hands_cfg, "linkerhand_l6", {}) or {} + somehand_cfg = cfg_get(hands_cfg, "somehand", {}) or {} + mode = str(cfg_get(hands_cfg, "mode", "gripper")).strip().lower() + if mode not in ("gripper", "vr_hand_pose"): + raise ValueError(f"hands.mode must be gripper or vr_hand_pose, got {mode!r}") + sides = tuple(str(side).strip().lower() for side in cfg_get(hands_cfg, "sides", HAND_SIDES)) + if not sides or any(side not in HAND_SIDES for side in sides): + raise ValueError("hands.sides must contain left, right, or both sides") + thumb_yaw = _uint8(cfg_get(l6_cfg, "thumb_yaw_center", THUMB_YAW_DEFAULT), "thumb_yaw_center") + open_pose = _pose_values(cfg_get(l6_cfg, "open_pose", OPEN_POSE), "open_pose") + close_pose = _pose_values(cfg_get(l6_cfg, "close_pose", CLOSE_POSE), "close_pose") + open_pose[1] = thumb_yaw + close_pose[1] = thumb_yaw + speed = VR_HAND_POSE_SPEED if mode == "vr_hand_pose" else tuple(_pose_values(cfg_get(l6_cfg, "speed", DEFAULT_SPEED), "speed")) + return LinkerHandL6Config( + mode=mode, + sides=sides, + left_can=str(cfg_get(l6_cfg, "left_can", "can0")), + right_can=str(cfg_get(l6_cfg, "right_can", "can1")), + modbus=str(cfg_get(l6_cfg, "modbus", "None")), + rate_hz=_positive_float(cfg_get(hands_cfg, "rate_hz", cfg_get(l6_cfg, "rate_hz", 30.0)), "rate_hz"), + frame_timeout_s=_positive_float(cfg_get(hands_cfg, "frame_timeout_s", 0.3), "frame_timeout_s"), + trigger_deadzone=_deadzone(cfg_get(l6_cfg, "trigger_deadzone", 0.05)), + deadman_threshold=_threshold(cfg_get(l6_cfg, "deadman_threshold", 0.5)), + thumb_yaw_center=thumb_yaw, + speed=tuple(speed), + open_pose=tuple(open_pose), + close_pose=tuple(close_pose), + fixed_thumb_yaw=thumb_yaw, + print_input=bool(cfg_get(l6_cfg, "print_input", False)), + somehand_config_path=str(cfg_get(somehand_cfg, "config_path", DEFAULT_SOMEHAND_CONFIG)), + somehand_rate_hz=_positive_float(cfg_get(somehand_cfg, "rate_hz", cfg_get(somehand_cfg, "rate", 60.0)), "somehand.rate_hz"), + somehand_max_iterations=_optional_positive_int(cfg_get(somehand_cfg, "max_iterations", None), "somehand.max_iterations"), + somehand_temporal_filter_alpha=_optional_alpha(cfg_get(somehand_cfg, "temporal_filter_alpha", None), "somehand.temporal_filter_alpha"), + somehand_output_alpha=_optional_alpha(cfg_get(somehand_cfg, "output_alpha", None), "somehand.output_alpha"), + ) + + +class LinkerHandL6Device(HandDevice): + def __init__(self, config: LinkerHandL6Config): + self.config = config + self._hands: dict[str, Any] = {} + self._last_pose: dict[str, tuple[int, ...] | None] = {side: None for side in config.sides} + + def connect(self) -> None: + try: + from LinkerHand.linker_hand_api import LinkerHandApi + except ImportError as exc: + raise ImportError( + "LinkerHand SDK is required for hands.driver=linkerhand_l6. " + "Install it with: pip install -e third_party/linkerhand-python-sdk" + ) from exc + try: + for side in self.config.sides: + hand = LinkerHandApi( + hand_joint="L6", + hand_type=side, + modbus=self.config.modbus, + can=self.config.left_can if side == "left" else self.config.right_can, + ) + hand.set_speed(speed=list(self.config.speed)) + self._hands[side] = hand + except (Exception, SystemExit) as exc: + self.close() + if isinstance(exc, SystemExit): + raise RuntimeError("LinkerHand SDK exited during startup") from exc + raise + self.open_all(force=True, reason="startup") + + def send_pose(self, side: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + del reason + next_pose = tuple(_uint8(value, f"{side}.pose") for value in pose) + if not force and self._last_pose.get(side) == next_pose: + return + hand = self._hands.get(side) + if hand is None: + return + hand.finger_move(pose=list(next_pose)) + self._last_pose[side] = next_pose + + def open_all(self, *, force: bool = False, reason: str = "") -> None: + for side in self.config.sides: + self.send_pose(side, self.config.open_pose, force=force, reason=reason) + + def close(self) -> None: + try: + self.open_all(force=True, reason="shutdown") + except Exception: + logger.exception("Failed to open LinkerHand L6 on shutdown") + for hand in self._hands.values(): + inner = getattr(hand, "hand", None) + close = getattr(inner, "close", None) + if callable(close): + close() + self._hands.clear() + + +class GripperMapper(HandInputMapper): + def __init__(self, config: Any): + self.config = config + self._fixed_thumb_yaw = getattr(config, "fixed_thumb_yaw", getattr(config, "thumb_yaw_center", None)) + self._active = False + self._next_tick_s = 0.0 + + def start(self) -> None: + pass + + def map(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float) -> tuple[HandPoseCommand, ...]: + del hand_snapshot + if not active: + if not self._active: + return () + self._active = False + return tuple(HandPoseCommand(side, self.config.open_pose, True, "inactive") for side in self.config.sides) + if now_s < self._next_tick_s: + return () + self._active = True + self._next_tick_s = now_s + 1.0 / self.config.rate_hz + if controller_snapshot is None or now_s - float(getattr(controller_snapshot, "timestamp_s", 0.0)) > self.config.frame_timeout_s: + return tuple(HandPoseCommand(side, self.config.open_pose, False, "timeout") for side in self.config.sides) + commands: list[HandPoseCommand] = [] + for side in self.config.sides: + state = getattr(controller_snapshot, side) + if not bool(getattr(state, "present", True)): + commands.append(HandPoseCommand(side, self.config.open_pose, False, "missing-controller")) + continue + grip = _clamp01(getattr(state, "grip", 0.0)) + trigger = _clamp01(getattr(state, "trigger", 0.0)) + if grip < self.config.deadman_threshold: + commands.append(HandPoseCommand(side, self.config.open_pose, False, "deadman")) + continue + pose = trigger_to_pose( + trigger, + open_pose=self.config.open_pose, + close_pose=self.config.close_pose, + deadzone=self.config.trigger_deadzone, + fixed_thumb_yaw=self._fixed_thumb_yaw, + ) + commands.append(HandPoseCommand(side, tuple(pose), False, "controller")) + return tuple(commands) + + def close(self) -> None: + pass + + +class SomehandL6Mapper(HandInputMapper): + def __init__(self, config: LinkerHandL6Config): + self.config = config + self._engine: Any | None = None + self._hand_frame_cls: Any | None = None + self._bihand_frame_cls: Any | None = None + self._mappers: dict[str, L6RetargetPoseMapper] = {} + self._next_tick_s = 0.0 + self._active = False + + def start(self) -> None: + _require_somehand_020() + from somehand.api import HandFrame, RetargetingEngine, load_bihand_config, load_retargeting_config + + config_path = _resolve_project_path(self.config.somehand_config_path) + if not config_path.exists(): + raise FileNotFoundError(f"somehand L6 config not found: {config_path}") + bihand_config = load_bihand_config(str(config_path)) + self._engine = {} + for side, path in (("left", bihand_config.left_config_path), ("right", bihand_config.right_config_path)): + retarget_cfg = load_retargeting_config(path) + self._apply_low_latency_overrides(retarget_cfg) + self._engine[side] = RetargetingEngine(retarget_cfg) + self._hand_frame_cls = HandFrame + for side, engine in self._engine.items(): + if side in self.config.sides: + self._mappers[side] = L6RetargetPoseMapper(getattr(engine, "hand_model", None), side=side) + + def map(self, *, controller_snapshot: object | None, hand_snapshot: object | None, active: bool, now_s: float) -> tuple[HandPoseCommand, ...]: + del controller_snapshot + if not active: + if not self._active: + return () + self._active = False + return tuple(HandPoseCommand(side, self.config.open_pose, True, "inactive") for side in self.config.sides) + if now_s < self._next_tick_s: + return () + self._active = True + self._next_tick_s = now_s + 1.0 / self.config.somehand_rate_hz + if hand_snapshot is None or now_s - float(getattr(hand_snapshot, "timestamp_s", 0.0)) > self.config.frame_timeout_s: + return () + left_frame = self._make_frame("left", getattr(hand_snapshot, "left", None)) + right_frame = self._make_frame("right", getattr(hand_snapshot, "right", None)) + if left_frame is None and right_frame is None: + return () + commands: list[HandPoseCommand] = [] + for side, frame in (("left", left_frame), ("right", right_frame)): + if side not in self.config.sides or frame is None: + continue + step = self._engine[side].process(frame) + commands.append(HandPoseCommand(side, tuple(self._mappers[side].qpos_to_pose(step.qpos)), False, "vr-hand-pose")) + return tuple(commands) + + def close(self) -> None: + pass + + def _make_frame(self, side: str, state: object | None) -> object | None: + if side not in self.config.sides or state is None: + return None + if not bool(getattr(state, "present", False)) or not bool(getattr(state, "active", False)): + return None + landmarks = pico_hand_to_landmarks(getattr(state, "joints")) + return self._hand_frame_cls(landmarks_3d=landmarks, landmarks_2d=None, hand_side=side) + + def _apply_low_latency_overrides(self, cfg: object) -> None: + if self.config.somehand_max_iterations is not None: + cfg.solver.max_iterations = int(self.config.somehand_max_iterations) + if self.config.somehand_output_alpha is not None: + cfg.solver.output_alpha = float(self.config.somehand_output_alpha) + if self.config.somehand_temporal_filter_alpha is not None: + cfg.preprocess.temporal_filter_alpha = float(self.config.somehand_temporal_filter_alpha) + + +class L6RetargetPoseMapper: + def __init__(self, hand_model: Any | None, *, side: str): + if hand_model is None: + raise ValueError("somehand L6 hand model is missing") + get_index = getattr(hand_model, "get_joint_name_to_qpos_index", None) + if not callable(get_index): + raise ValueError("somehand L6 hand model does not expose get_joint_name_to_qpos_index()") + joint_index = get_index() + self._indices = np.asarray([_resolve_l6_joint_index(joint_index, name, side=side) for name in L6_SDK_JOINT_ORDER], dtype=np.int64) + mapping = _load_linkerhand_mapping_module() + side_key = "l" if side == "left" else "r" + self._mapping = mapping + self._arc_min = np.asarray(getattr(mapping, f"l6_{side_key}_min"), dtype=np.float64) + self._arc_max = np.asarray(getattr(mapping, f"l6_{side_key}_max"), dtype=np.float64) + self._direction = np.asarray(getattr(mapping, f"l6_{side_key}_derict"), dtype=np.int8) + + def qpos_to_pose(self, qpos: object) -> list[int]: + values = np.asarray(qpos, dtype=np.float64).reshape(-1) + selected = values[self._indices] + pose = [] + for index, value in enumerate(selected): + arc = self._mapping.is_within_range(float(value), float(self._arc_min[index]), float(self._arc_max[index])) + if int(self._direction[index]) == -1: + scaled = self._mapping.scale_value(arc, float(self._arc_min[index]), float(self._arc_max[index]), 255.0, 0.0) + else: + scaled = self._mapping.scale_value(arc, float(self._arc_min[index]), float(self._arc_max[index]), 0.0, 255.0) + pose.append(_uint8(round(float(scaled)), "somehand.pose")) + return pose + + +def build_linkerhand_l6(cfg: Any) -> tuple[HandDevice, HandInputMapper]: + config = parse_linkerhand_l6_config(cfg) + device = LinkerHandL6Device(config) + mapper: HandInputMapper = SomehandL6Mapper(config) if config.mode == "vr_hand_pose" else GripperMapper(config) + return device, mapper + + +def trigger_to_pose( + trigger: float, + *, + open_pose: Sequence[int], + close_pose: Sequence[int], + deadzone: float, + fixed_thumb_yaw: int | None = None, + thumb_yaw_default: int | None = None, +) -> list[int]: + if fixed_thumb_yaw is None: + fixed_thumb_yaw = thumb_yaw_default + alpha = _normalize_trigger(trigger, deadzone) + pose = [int(round(float(a) + alpha * (float(b) - float(a)))) for a, b in zip(open_pose, close_pose)] + if fixed_thumb_yaw is not None: + pose[1] = int(fixed_thumb_yaw) + return pose + + +def _require_somehand_020() -> None: + try: + installed = version("somehand") + except PackageNotFoundError as exc: + raise ImportError("somehand==0.2.0 is required for hands.mode=vr_hand_pose") from exc + if installed != "0.2.0": + raise ImportError(f"somehand==0.2.0 is required for hands.mode=vr_hand_pose, found {installed}") + + +def _resolve_l6_joint_index(joint_index: dict[str, int], semantic_name: str, *, side: str) -> int: + for candidate in _l6_joint_candidates(semantic_name, side=side): + if candidate in joint_index: + return int(joint_index[candidate]) + suffixes = tuple(f"_{alias}" for alias in _l6_aliases(semantic_name)) + for name, index in joint_index.items(): + if name in _l6_aliases(semantic_name) or any(name.endswith(suffix) for suffix in suffixes): + return int(index) + raise ValueError(f"Cannot resolve LinkerHand L6 SDK joint {semantic_name!r} in somehand hand model") + + +def _l6_joint_candidates(semantic_name: str, *, side: str) -> tuple[str, ...]: + prefixes = ("", f"{side}_", f"{side[0]}_", f"{side[0].upper()}_", f"{'lh' if side == 'left' else 'rh'}_") + return tuple(f"{prefix}{alias}" for alias in _l6_aliases(semantic_name) for prefix in prefixes) + + +def _l6_aliases(semantic_name: str) -> tuple[str, ...]: + if semantic_name == "thumb_cmc_pitch": + return ("thumb_cmc_pitch", "thumb_pitch") + if semantic_name == "thumb_cmc_roll": + return ("thumb_cmc_roll", "thumb_roll") + aliases = [semantic_name] + if semantic_name.endswith("_mcp_pitch"): + finger = semantic_name[: -len("_mcp_pitch")] + aliases.append(f"{finger}_pitch") + if finger == "pinky": + aliases.extend(("little_mcp_pitch", "little_pitch")) + return tuple(dict.fromkeys(aliases)) + + +def _load_linkerhand_mapping_module() -> Any: + mapping_path = _resolve_project_path(DEFAULT_LINKERHAND_SDK_ROOT) / "LinkerHand" / "utils" / "mapping.py" + spec = importlib.util.spec_from_file_location("teleopit_linkerhand_mapping", mapping_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load LinkerHand mapping module from {mapping_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _resolve_project_path(path_value: str) -> Path: + path = Path(path_value).expanduser() + return path if path.is_absolute() else (PROJECT_ROOT / path).resolve() + + +def _clamp01(value: object) -> float: + return max(0.0, min(1.0, float(value))) + + +def _normalize_trigger(value: float, deadzone: float) -> float: + value = _clamp01(value) + if value <= deadzone: + return 0.0 + upper = 1.0 - deadzone + if value >= upper: + return 1.0 + return (value - deadzone) / (upper - deadzone) + + +def _uint8(value: object, field_name: str) -> int: + parsed = int(value) + if parsed < 0 or parsed > 255: + raise ValueError(f"hands.linkerhand_l6.{field_name} must be in 0-255, got {value!r}") + return parsed + + +def _pose_values(value: object, field_name: str) -> list[int]: + parsed = [_uint8(item, field_name) for item in value] # type: ignore[union-attr] + if len(parsed) != 6: + raise ValueError(f"hands.linkerhand_l6.{field_name} must contain 6 values") + return parsed + + +def _positive_float(value: object, field_name: str) -> float: + parsed = float(value) + if parsed <= 0: + raise ValueError(f"hands.{field_name} must be > 0") + return parsed + + +def _optional_positive_int(value: object, field_name: str) -> int | None: + if value is None: + return None + parsed = int(value) + if parsed <= 0: + raise ValueError(f"hands.{field_name} must be > 0") + return parsed + + +def _optional_alpha(value: object, field_name: str) -> float | None: + if value is None: + return None + parsed = float(value) + if parsed <= 0.0 or parsed > 1.0: + raise ValueError(f"hands.{field_name} must be in (0, 1]") + return parsed + + +def _deadzone(value: object) -> float: + parsed = float(value) + if parsed < 0.0 or parsed >= 0.5: + raise ValueError("hands.linkerhand_l6.trigger_deadzone must be in [0, 0.5)") + return parsed + + +def _threshold(value: object) -> float: + parsed = float(value) + if parsed <= 0.0 or parsed >= 1.0: + raise ValueError("hands.linkerhand_l6.deadman_threshold must be in (0, 1)") + return parsed diff --git a/teleopit/sim2real/hands/linkerhand_o6.py b/teleopit/sim2real/hands/linkerhand_o6.py new file mode 100644 index 00000000..6d4de349 --- /dev/null +++ b/teleopit/sim2real/hands/linkerhand_o6.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from dataclasses import dataclass +import logging +from typing import Any, Sequence + +from teleopit.runtime.common import cfg_get +from teleopit.sim2real.hands.base import HAND_SIDES, HandDevice, HandInputMapper +from teleopit.sim2real.hands.linkerhand_l6 import GripperMapper + +logger = logging.getLogger(__name__) + +OPEN_POSE = (250, 250, 250, 250, 250, 250) +CLOSE_POSE = (86, 73, 118, 111, 110, 111) +DEFAULT_SPEED = (255, 255, 255, 255, 255, 255) + + +@dataclass(frozen=True) +class LinkerHandO6Config: + mode: str + sides: tuple[str, ...] + left_can: str + right_can: str + modbus: str + rate_hz: float + frame_timeout_s: float + trigger_deadzone: float + deadman_threshold: float + speed: tuple[int, ...] + open_pose: tuple[int, ...] + close_pose: tuple[int, ...] + fixed_thumb_yaw: int | None + print_input: bool + + +def parse_linkerhand_o6_config(cfg: Any) -> LinkerHandO6Config: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + o6_cfg = cfg_get(hands_cfg, "linkerhand_o6", {}) or {} + mode = str(cfg_get(hands_cfg, "mode", "gripper")).strip().lower() + if mode != "gripper": + raise ValueError(f"hands.driver=linkerhand_o6 supports only hands.mode=gripper, got {mode!r}") + sides = tuple(str(side).strip().lower() for side in cfg_get(hands_cfg, "sides", HAND_SIDES)) + if not sides or any(side not in HAND_SIDES for side in sides): + raise ValueError("hands.sides must contain left, right, or both sides") + return LinkerHandO6Config( + mode=mode, + sides=sides, + left_can=str(cfg_get(o6_cfg, "left_can", "can0")), + right_can=str(cfg_get(o6_cfg, "right_can", "can1")), + modbus=str(cfg_get(o6_cfg, "modbus", "None")), + rate_hz=_positive_float(cfg_get(hands_cfg, "rate_hz", cfg_get(o6_cfg, "rate_hz", 30.0)), "rate_hz"), + frame_timeout_s=_positive_float(cfg_get(hands_cfg, "frame_timeout_s", 0.3), "frame_timeout_s"), + trigger_deadzone=_deadzone(cfg_get(o6_cfg, "trigger_deadzone", 0.05)), + deadman_threshold=_threshold(cfg_get(o6_cfg, "deadman_threshold", 0.5)), + speed=tuple(_pose_values(cfg_get(o6_cfg, "speed", DEFAULT_SPEED), "speed")), + open_pose=tuple(_pose_values(cfg_get(o6_cfg, "open_pose", OPEN_POSE), "open_pose")), + close_pose=tuple(_pose_values(cfg_get(o6_cfg, "close_pose", CLOSE_POSE), "close_pose")), + fixed_thumb_yaw=None, + print_input=bool(cfg_get(o6_cfg, "print_input", False)), + ) + + +class LinkerHandO6Device(HandDevice): + def __init__(self, config: LinkerHandO6Config): + self.config = config + self._hands: dict[str, Any] = {} + self._last_pose: dict[str, tuple[int, ...] | None] = {side: None for side in config.sides} + + def connect(self) -> None: + try: + from LinkerHand.linker_hand_api import LinkerHandApi + except ImportError as exc: + raise ImportError( + "LinkerHand SDK is required for hands.driver=linkerhand_o6. " + "Install it with: pip install -e third_party/linkerhand-python-sdk" + ) from exc + try: + for side in self.config.sides: + hand = LinkerHandApi( + hand_joint="O6", + hand_type=side, + modbus=self.config.modbus, + can=self.config.left_can if side == "left" else self.config.right_can, + ) + hand.set_speed(speed=list(self.config.speed)) + self._hands[side] = hand + except (Exception, SystemExit) as exc: + self.close() + if isinstance(exc, SystemExit): + raise RuntimeError("LinkerHand SDK exited during startup") from exc + raise + self.open_all(force=True, reason="startup") + + def send_pose(self, side: str, pose: Sequence[int], *, force: bool = False, reason: str = "") -> None: + del reason + next_pose = tuple(_uint8(value, f"{side}.pose") for value in pose) + if not force and self._last_pose.get(side) == next_pose: + return + hand = self._hands.get(side) + if hand is None: + return + hand.finger_move(pose=list(next_pose)) + self._last_pose[side] = next_pose + + def open_all(self, *, force: bool = False, reason: str = "") -> None: + for side in self.config.sides: + self.send_pose(side, self.config.open_pose, force=force, reason=reason) + + def close(self) -> None: + try: + self.open_all(force=True, reason="shutdown") + except Exception: + logger.exception("Failed to open LinkerHand O6 on shutdown") + for hand in self._hands.values(): + inner = getattr(hand, "hand", None) + close_backend = getattr(inner, "close_can_interface", None) + if callable(close_backend): + close_backend() + continue + close = getattr(inner, "close", None) + if callable(close): + close() + self._hands.clear() + + +def build_linkerhand_o6(cfg: Any) -> tuple[HandDevice, HandInputMapper]: + config = parse_linkerhand_o6_config(cfg) + return LinkerHandO6Device(config), GripperMapper(config) + + +def _uint8(value: object, field_name: str) -> int: + parsed = int(value) + if parsed < 0 or parsed > 255: + raise ValueError(f"hands.linkerhand_o6.{field_name} must be in 0-255, got {value!r}") + return parsed + + +def _pose_values(value: object, field_name: str) -> list[int]: + parsed = [_uint8(item, field_name) for item in value] # type: ignore[union-attr] + if len(parsed) != 6: + raise ValueError(f"hands.linkerhand_o6.{field_name} must contain 6 values") + return parsed + + +def _positive_float(value: object, field_name: str) -> float: + parsed = float(value) + if parsed <= 0: + raise ValueError(f"hands.{field_name} must be > 0") + return parsed + + +def _deadzone(value: object) -> float: + parsed = float(value) + if parsed < 0.0 or parsed >= 0.5: + raise ValueError("hands.linkerhand_o6.trigger_deadzone must be in [0, 0.5)") + return parsed + + +def _threshold(value: object) -> float: + parsed = float(value) + if parsed <= 0.0 or parsed >= 1.0: + raise ValueError("hands.linkerhand_o6.deadman_threshold must be in (0, 1)") + return parsed diff --git a/teleopit/sim2real/hands/pico_landmarks.py b/teleopit/sim2real/hands/pico_landmarks.py new file mode 100644 index 00000000..c747046c --- /dev/null +++ b/teleopit/sim2real/hands/pico_landmarks.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import numpy as np + +PICO_BRIDGE_TO_MEDIAPIPE = ( + 1, + 2, + 3, + 4, + 5, + 7, + 8, + 9, + 10, + 12, + 13, + 14, + 15, + 17, + 18, + 19, + 20, + 22, + 23, + 24, + 25, +) + +PICO_NATIVE_TO_RH = np.array( + [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]], + dtype=np.float64, +) + + +def pico_hand_to_landmarks(hand_state: object) -> np.ndarray: + state = np.asarray(hand_state, dtype=np.float64) + if state.shape != (26, 7): + state = state.reshape(26, 7) + positions = state[:, :3] @ PICO_NATIVE_TO_RH.T + landmarks = np.empty((21, 3), dtype=np.float64) + for mp_index, pico_index in enumerate(PICO_BRIDGE_TO_MEDIAPIPE): + landmarks[mp_index] = positions[pico_index] + return landmarks diff --git a/teleopit/sim2real/hands/worker.py b/teleopit/sim2real/hands/worker.py new file mode 100644 index 00000000..e9fe323a --- /dev/null +++ b/teleopit/sim2real/hands/worker.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import logging +import time +from typing import Any, Sequence + +from teleopit.runtime.common import cfg_get +from teleopit.sim2real.hands.base import HandDevice, HandInputMapper, HandPoseCommand +from teleopit.sim2real.hands.linkerhand_l6 import build_linkerhand_l6 +from teleopit.sim2real.hands.linkerhand_o6 import build_linkerhand_o6 + +logger = logging.getLogger(__name__) + + +class HandRuntime: + def __init__( + self, + device: HandDevice, + mapper: HandInputMapper, + *, + open_commands: Sequence[HandPoseCommand] = (), + ): + self._device = device + self._mapper = mapper + self.enabled = True + self._failed = False + self._open_commands = tuple(open_commands) + + def start(self) -> tuple[HandPoseCommand, ...]: + try: + self._device.connect() + self._mapper.start() + return self._open_pose_commands("startup") + except Exception: + try: + self._device.close() + finally: + raise + + def tick( + self, + *, + controller_snapshot: object | None, + hand_snapshot: object | None, + active: bool, + now_s: float | None = None, + ) -> tuple[HandPoseCommand, ...]: + if self._failed: + return () + now = time.monotonic() if now_s is None else float(now_s) + try: + commands = self._mapper.map( + controller_snapshot=controller_snapshot, + hand_snapshot=hand_snapshot, + active=active, + now_s=now, + ) + sent: list[HandPoseCommand] = [] + for command in commands: + self._device.send_pose(command.side, command.pose, force=command.force, reason=command.reason) + sent.append(command) + return tuple(sent) + except Exception: + self._failed = True + logger.exception("Hand runtime failed; disabling hand control") + try: + self._device.open_all(force=True, reason="failure") + except Exception: + logger.exception("Failed to open hand after hand runtime failure") + return () + return self._open_pose_commands("failure") + + def close(self) -> tuple[HandPoseCommand, ...]: + try: + self._mapper.close() + finally: + self._device.close() + return self._open_pose_commands("shutdown") + + def _open_pose_commands(self, reason: str) -> tuple[HandPoseCommand, ...]: + return tuple( + HandPoseCommand(command.side, command.pose, True, reason) + for command in self._open_commands + ) + + +class DisabledHandRuntime: + enabled = False + + def start(self) -> tuple[HandPoseCommand, ...]: + return () + + def tick( + self, + *, + controller_snapshot: object | None, + hand_snapshot: object | None, + active: bool, + now_s: float | None = None, + ) -> tuple[HandPoseCommand, ...]: + del controller_snapshot, hand_snapshot, active, now_s + return () + + def close(self) -> tuple[HandPoseCommand, ...]: + return () + + +def build_hand_runtime(cfg: Any) -> HandRuntime | DisabledHandRuntime: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + if not bool(cfg_get(hands_cfg, "enabled", False)): + return DisabledHandRuntime() + driver = str(cfg_get(hands_cfg, "driver", "linkerhand_l6")).strip().lower() + if driver == "linkerhand_l6": + device, mapper = build_linkerhand_l6(cfg) + elif driver == "linkerhand_o6": + device, mapper = build_linkerhand_o6(cfg) + else: + raise ValueError(f"Unsupported hands.driver={driver!r}; supported drivers: linkerhand_l6, linkerhand_o6") + return HandRuntime(device, mapper, open_commands=_open_commands_from_device(device)) + + +def _open_commands_from_device(device: HandDevice) -> tuple[HandPoseCommand, ...]: + config = getattr(device, "config", None) + sides = tuple(str(side).strip().lower() for side in getattr(config, "sides", ())) + open_pose = tuple(int(value) for value in getattr(config, "open_pose", ())) + if not sides or len(open_pose) != 6: + return () + return tuple(HandPoseCommand(side, open_pose, True, "open") for side in sides) diff --git a/teleopit/sim2real/mp/__init__.py b/teleopit/sim2real/mp/__init__.py new file mode 100644 index 00000000..f32f55cd --- /dev/null +++ b/teleopit/sim2real/mp/__init__.py @@ -0,0 +1,9 @@ +"""Process-isolated sim2real runtime.""" + +from teleopit.sim2real.mp.runtime import ( + Sim2RealRuntime, +) + +__all__ = [ + "Sim2RealRuntime", +] diff --git a/teleopit/sim2real/mp/ipc.py b/teleopit/sim2real/mp/ipc.py new file mode 100644 index 00000000..e0c3b1c9 --- /dev/null +++ b/teleopit/sim2real/mp/ipc.py @@ -0,0 +1,129 @@ +"""Small-message IPC helpers for multiprocess sim2real.""" + +from __future__ import annotations + +from dataclasses import dataclass +import pickle +import time +from typing import Any + +import zmq + + +BODY_TOPIC = "body" +HAND_TOPIC = "hand" +HAND_COMMAND_TOPIC = "hand_command" +CONTROLLER_TOPIC = "controller" +CONTROL_EVENTS_TOPIC = "control_events" +REFERENCE_TOPIC = "reference" +MODE_TOPIC = "mode" +VIDEO_TOPIC = "video" +RECORD_TOPIC = "record" +HEALTH_TOPIC = "health" +COMMAND_TOPIC = "command" + + +@dataclass(frozen=True) +class Sim2RealIpcEndpoints: + body_pub: str + hand_pub: str + hand_command_pub: str + controller_pub: str + control_events_pub: str + reference_pub: str + mode_pub: str + video_pub: str + record_pub: str + health_pub: str + command_pub: str + reference_command_pub: str + + +def default_endpoints(*, host: str = "127.0.0.1", base_port: int = 39700) -> Sim2RealIpcEndpoints: + """Return deterministic localhost TCP endpoints for one sim2real runtime.""" + prefix = f"tcp://{host}:" + return Sim2RealIpcEndpoints( + body_pub=f"{prefix}{base_port}", + hand_pub=f"{prefix}{base_port + 1}", + hand_command_pub=f"{prefix}{base_port + 2}", + controller_pub=f"{prefix}{base_port + 3}", + control_events_pub=f"{prefix}{base_port + 4}", + reference_pub=f"{prefix}{base_port + 5}", + mode_pub=f"{prefix}{base_port + 6}", + video_pub=f"{prefix}{base_port + 7}", + record_pub=f"{prefix}{base_port + 8}", + health_pub=f"{prefix}{base_port + 9}", + command_pub=f"{prefix}{base_port + 10}", + reference_command_pub=f"{prefix}{base_port + 11}", + ) + + +class ZmqPublisher: + """Topic publisher with low watermarks for realtime latest-only streams.""" + + def __init__(self, endpoint: str, *, context: zmq.Context[Any] | None = None) -> None: + self._own_context = context is None + self._context = zmq.Context() if context is None else context + self._socket = self._context.socket(zmq.PUB) + self._socket.setsockopt(zmq.SNDHWM, 1) + self._socket.setsockopt(zmq.LINGER, 0) + self._socket.bind(endpoint) + self._endpoint = endpoint + # Give subscribers a short chance to connect during process startup. + time.sleep(0.05) + + @property + def endpoint(self) -> str: + return self._endpoint + + def publish(self, topic: str, payload: object) -> None: + try: + self._socket.send_multipart( + [topic.encode("utf-8"), pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)], + flags=zmq.NOBLOCK, + ) + except zmq.Again: + # Realtime streams are latest-only; dropping beats backpressure. + return + + def close(self) -> None: + self._socket.close(linger=0) + if self._own_context: + self._context.term() + + +class LatestSubscriber: + """Subscriber that drains all pending messages and returns only the latest.""" + + def __init__( + self, + endpoint: str, + topic: str, + *, + context: zmq.Context[Any] | None = None, + ) -> None: + self._own_context = context is None + self._context = zmq.Context() if context is None else context + self._socket = self._context.socket(zmq.SUB) + self._socket.setsockopt(zmq.RCVHWM, 1) + self._socket.setsockopt(zmq.LINGER, 0) + self._socket.setsockopt_string(zmq.SUBSCRIBE, topic) + self._socket.connect(endpoint) + self._topic = topic + + def recv_latest(self) -> object | None: + latest: object | None = None + while True: + try: + topic_raw, payload_raw = self._socket.recv_multipart(flags=zmq.NOBLOCK) + except zmq.Again: + return latest + topic = topic_raw.decode("utf-8") + if topic != self._topic: + continue + latest = pickle.loads(payload_raw) + + def close(self) -> None: + self._socket.close(linger=0) + if self._own_context: + self._context.term() diff --git a/teleopit/sim2real/mp/messages.py b/teleopit/sim2real/mp/messages.py new file mode 100644 index 00000000..a56bb1e7 --- /dev/null +++ b/teleopit/sim2real/mp/messages.py @@ -0,0 +1,108 @@ +"""Pickle-serializable IPC message contracts for multiprocess sim2real.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from teleopit.inputs.realtime_packet import ControlEvent, HumanFrame +from teleopit.sim.reference_timeline import ReferenceWindow + + +Float64Array = NDArray[np.float64] + + +@dataclass(frozen=True) +class BodyFramePacket: + frame: HumanFrame + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class ReferencePacket: + qpos: Float64Array + timestamp_s: float + seq: int + source_timestamp_s: float + source_seq: int + frame_valid: bool = True + reference_window: ReferenceWindow | None = None + retarget_elapsed_s: float = 0.0 + playback_paused: bool = False + playback_finished: bool = False + + +@dataclass(frozen=True) +class ControlEventsPacket: + events: tuple[ControlEvent, ...] + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class SnapshotPacket: + snapshot: Any + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class ModeStatePacket: + mode: str + mocap_active: bool + mocap_paused: bool + timestamp_s: float + seq: int + + +@dataclass(frozen=True) +class RecordStepPacket: + timestamp_s: float + mode: str + mocap_active: bool + recordable: bool + observation_state: Float64Array + observation_mode: Float64Array + action_reference_qpos: Float64Array + seq: int + + +@dataclass(frozen=True) +class HandCommandPacket: + timestamp_s: float + driver: str + mode: str + active: bool + left_pose: Float64Array + right_pose: Float64Array + seq: int + + +@dataclass(frozen=True) +class HealthPacket: + worker: str + timestamp_s: float + status: str = "ok" + metrics: dict[str, float | int | str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class CommandPacket: + command: str + timestamp_s: float + payload: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class SharedFrameDescriptor: + shm_name: str + slot: int + seq: int + timestamp_s: float + shape: tuple[int, ...] + dtype: str + slots: int diff --git a/teleopit/sim2real/mp/runtime.py b/teleopit/sim2real/mp/runtime.py new file mode 100644 index 00000000..7cd4e99a --- /dev/null +++ b/teleopit/sim2real/mp/runtime.py @@ -0,0 +1,2018 @@ +"""Multiprocess sim2real runtime using ZMQ and shared memory.""" + +from __future__ import annotations + +import logging +import multiprocessing as mp +from multiprocessing.synchronize import Event as MpEvent +from enum import Enum +from pathlib import Path +import sys +import time +from typing import Any, Callable + +import numpy as np +from numpy.typing import NDArray + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS, ROOT_DIM +from teleopit.controllers.observation import VelCmdObservationBuilder, align_motion_qpos_yaw +from teleopit.controllers.rl_policy import RLPolicyController +from teleopit.inputs.bvh_provider import BVHInputProvider +from teleopit.inputs.human_frame_validation import validate_human_frame +from teleopit.inputs.pico4_provider import Pico4InputProvider +from teleopit.inputs.pico_video import PicoVideoRuntime, bridge_video_source, parse_pico_video_config +from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType +from teleopit.retargeting.core import RetargetingModule +from teleopit.runtime.offline_playback import OfflinePlaybackController +from teleopit.runtime.common import cfg_get, parse_viewers, require_section +from teleopit.runtime.console import ( + OPERATOR_LOGGER_NAME, + PlainConsole, + configure_runtime_logging, + console_show_timing, + console_timing_interval_s, + sim2real_keyboard_controls, +) +from teleopit.runtime.factory import _build_policy_components, build_simulation_cfg +from teleopit.runtime.arm_mocap import ( + compose_arm_reference, + compose_arm_reference_window, + parse_arm_joint_indices, +) +from teleopit.runtime.mocap_session import MocapSessionManager, MocapSessionState +from teleopit.runtime.reference_config import parse_reference_config +from teleopit.runtime.terminal_keyboard import TerminalKeyboardReader +from teleopit.recording.hdf5 import ( + build_mode_observation, + build_mp4_video_config, + build_observation_state, + normalize_hand_action, + normalize_action_reference_qpos, +) +from teleopit.sim.reference_motion import OfflineReferenceMotion +from teleopit.sim.reference_timeline import ReferenceTimeline, ReferenceWindow, ReferenceWindowBuilder +from teleopit.sim.reference_utils import ( + build_offline_reference_window, + build_static_reference_window, + obs_builder_requires_reference_window, +) +from teleopit.sim.realtime_utils import RealtimeReferenceManager +from teleopit.sim.viewer_subprocess import start_robot_viewer +from teleopit.sim2real.hands.worker import build_hand_runtime +from teleopit.sim2real.hands.base import HandPoseCommand +from teleopit.sim2real.hands.linkerhand_l6 import parse_linkerhand_l6_config +from teleopit.sim2real.hands.linkerhand_o6 import parse_linkerhand_o6_config +from teleopit.sim2real.mp.ipc import ( + BODY_TOPIC, + COMMAND_TOPIC, + CONTROL_EVENTS_TOPIC, + CONTROLLER_TOPIC, + HAND_COMMAND_TOPIC, + HAND_TOPIC, + HEALTH_TOPIC, + MODE_TOPIC, + RECORD_TOPIC, + REFERENCE_TOPIC, + VIDEO_TOPIC, + LatestSubscriber, + Sim2RealIpcEndpoints, + ZmqPublisher, + default_endpoints, +) +from teleopit.sim2real.mp.messages import ( + BodyFramePacket, + CommandPacket, + ControlEventsPacket, + HandCommandPacket, + HealthPacket, + ModeStatePacket, + ReferencePacket, + RecordStepPacket, + SnapshotPacket, + SharedFrameDescriptor, +) +from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter +from teleopit.sim2real.reference_processor import Sim2RealReferenceProcessor +from teleopit.sim2real.remote import UnitreeRemote +from teleopit.sim2real.safety import Sim2RealSafetyManager +from teleopit.sim2real.unitree_g1 import UnitreeG1Robot + +try: + from omegaconf import OmegaConf +except ImportError: # pragma: no cover - OmegaConf is a project dependency. + OmegaConf = None # type: ignore[assignment] + + +logger = logging.getLogger(__name__) +operator_logger = logging.getLogger(OPERATOR_LOGGER_NAME) + +Float32Array = NDArray[np.float32] +Float64Array = NDArray[np.float64] +PROJECT_ROOT = Path(__file__).resolve().parents[3] + + +class RobotMode(Enum): + IDLE = "idle" + STANDING = "standing" + MOCAP = "mocap" + ARMS = "arms" + DAMPING = "damping" + + +class _LoopTimingReporter: + def __init__( + self, + *, + target_period_s: float, + log_interval_s: float = 1.0, + deadline_miss_tolerance_s: float = 0.001, + enabled: bool = True, + ) -> None: + self._target_period_s = float(target_period_s) + self._log_interval_s = float(log_interval_s) + self._deadline_miss_tolerance_s = float(deadline_miss_tolerance_s) + self._enabled = bool(enabled) + self._window_start_s: float | None = None + self._loop_ms: list[float] = [] + self._late_ms: list[float] = [] + self._work_ms: list[float] = [] + self._pico_age_ms: list[float] = [] + self._deadline_miss_count = 0 + self._work_overrun_count = 0 + + def record(self, *, loop_start_s: float, work_elapsed_s: float, cycle_elapsed_s: float, pico_age_s: float | None) -> None: + if self._window_start_s is None: + self._window_start_s = float(loop_start_s) + self._loop_ms.append(float(cycle_elapsed_s) * 1000.0) + self._late_ms.append(max(0.0, float(cycle_elapsed_s) - self._target_period_s) * 1000.0) + self._work_ms.append(float(work_elapsed_s) * 1000.0) + if pico_age_s is not None: + self._pico_age_ms.append(float(pico_age_s) * 1000.0) + if cycle_elapsed_s > self._target_period_s + self._deadline_miss_tolerance_s: + self._deadline_miss_count += 1 + if work_elapsed_s > self._target_period_s + 1e-9: + self._work_overrun_count += 1 + if loop_start_s - self._window_start_s >= self._log_interval_s: + self._emit(loop_start_s) + + def _emit(self, end_s: float) -> None: + sample_count = len(self._loop_ms) + if sample_count <= 0: + self._reset(end_s) + return + if not self._enabled: + self._reset(end_s) + return + loop_summary = self._summarize(self._loop_ms) + late_summary = self._summarize(self._late_ms) + work_summary = self._summarize(self._work_ms) + message = ( + "Timing stats | samples=%d window=%.1fs | " + "loop_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f | " + "late_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f deadline_miss(>%.2fms)=%d/%d | " + "work_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f work_overrun=%d/%d" + ) + args: list[object] = [ + sample_count, + end_s - float(self._window_start_s), + *loop_summary, + *late_summary, + self._deadline_miss_tolerance_s * 1000.0, + self._deadline_miss_count, + sample_count, + *work_summary, + self._work_overrun_count, + sample_count, + ] + if self._pico_age_ms: + message += " | reference_age_ms p50=%.2f p95=%.2f p99=%.2f max=%.2f" + args.extend(self._summarize(self._pico_age_ms)) + operator_logger.info(message, *args) + self._reset(end_s) + + def _reset(self, window_start_s: float) -> None: + self._window_start_s = float(window_start_s) + self._loop_ms.clear() + self._late_ms.clear() + self._work_ms.clear() + self._pico_age_ms.clear() + self._deadline_miss_count = 0 + self._work_overrun_count = 0 + + @staticmethod + def _summarize(samples: list[float]) -> tuple[float, float, float, float]: + values = np.asarray(samples, dtype=np.float64) + if values.size <= 0: + return 0.0, 0.0, 0.0, 0.0 + p50, p95, p99 = np.percentile(values, [50.0, 95.0, 99.0]) + return float(p50), float(p95), float(p99), float(np.max(values)) + + +def _parse_sim2real_viewers(cfg: Any) -> set[str]: + viewers = parse_viewers(cfg) + unsupported = viewers.difference({"retarget"}) + if unsupported: + raise ValueError( + f"Sim2real supports only the optional 'retarget' viewer; got unsupported viewers {sorted(unsupported)}. " + "Use viewers=retarget or viewers=none." + ) + return viewers + + +class _Sim2RealRetargetViewer: + def __init__(self, *, xml_path: str | None, enabled: bool) -> None: + self._entry: tuple[Any, Any, Any, Any] | None = None + if not enabled: + return + if not xml_path: + raise ValueError("Sim2real retarget viewer requires robot.xml_path to be set.") + self._entry = start_robot_viewer(xml_path, FULL_QPOS_DIM, True, "Retarget", 900, 50) + + def write(self, qpos: Float64Array) -> None: + if self._entry is None: + return + _, arr, alive, _ = self._entry + if not alive.value: + return + qpos = np.asarray(qpos, dtype=np.float64).reshape(-1) + if qpos.shape[0] < FULL_QPOS_DIM: + return + with arr.get_lock(): + arr[:FULL_QPOS_DIM] = qpos[:FULL_QPOS_DIM].tolist() + + def shutdown(self) -> None: + if self._entry is None: + return + proc, _, _, shutdown = self._entry + shutdown.set() + proc.join(timeout=3) + if proc.is_alive(): + proc.terminate() + self._entry = None + + +def _plain_cfg(cfg: Any) -> dict[str, Any]: + if isinstance(cfg, dict): + return dict(cfg) + if OmegaConf is not None and OmegaConf.is_config(cfg): + return OmegaConf.to_container(cfg, resolve=True) # type: ignore[return-value] + raise TypeError(f"Unsupported sim2real cfg type for multiprocessing: {type(cfg)!r}") + + +def _mp_cfg(cfg: Any) -> Any: + return cfg_get(cfg, "runtime", {}) or {} + + +def _input_provider_kind(cfg: Any) -> str: + return str(cfg_get(cfg_get(cfg, "input", {}) or {}, "provider", "bvh")).strip().lower() + + +def _recording_cfg(cfg: Any) -> Any: + return cfg_get(cfg, "recording", {}) or {} + + +def _recording_enabled(cfg: Any) -> bool: + return bool(cfg_get(_recording_cfg(cfg), "enabled", False)) + + +def _recording_camera_cfg(cfg: Any) -> Any: + return cfg_get(_recording_cfg(cfg), "camera", {}) or {} + + +def _configured_open_hand_pose(cfg: Any) -> tuple[np.ndarray, np.ndarray]: + hands_cfg = cfg_get(cfg, "hands", {}) or {} + driver = str(cfg_get(hands_cfg, "driver", "linkerhand_l6")).strip().lower() + if bool(cfg_get(hands_cfg, "enabled", False)): + if driver == "linkerhand_o6": + hand_cfg = parse_linkerhand_o6_config(cfg) + else: + hand_cfg = parse_linkerhand_l6_config(cfg) + pose = np.asarray(hand_cfg.open_pose, dtype=np.float32).reshape(-1) + elif driver == "linkerhand_o6": + pose = np.array([250, 250, 250, 250, 250, 250], dtype=np.float32) + else: + driver_cfg = cfg_get(hands_cfg, "linkerhand_l6", {}) or {} + thumb_yaw = int(cfg_get(driver_cfg, "thumb_yaw_center", 10)) + pose = np.array([250, thumb_yaw, 250, 250, 250, 250], dtype=np.float32) + if pose.shape[0] != 6: + raise ValueError(f"hands.{driver}.open_pose must contain 6 values") + return pose.copy(), pose.copy() + + +def _validate_new_runtime_config(cfg: Any) -> None: + legacy_keys = [key for key in ("sim2real_runtime", "multiprocess", "dexterous_hand") if cfg_get(cfg, key, None) is not None] + if legacy_keys: + raise ValueError( + "Legacy sim2real config keys are no longer supported: " + f"{', '.join(legacy_keys)}. Use input.provider, runtime, and hands instead." + ) + provider = _input_provider_kind(cfg) + if provider not in ("pico4", "bvh"): + raise ValueError(f"sim2real input.provider must be pico4 or bvh, got {provider!r}") + hands_cfg = cfg_get(cfg, "hands", {}) or {} + if bool(cfg_get(hands_cfg, "enabled", False)) and provider != "pico4": + raise ValueError("hands.enabled=true requires input.provider=pico4") + if _recording_enabled(cfg): + if provider != "pico4": + raise ValueError("recording.enabled=true requires input.provider=pico4") + rec_cfg = _recording_cfg(cfg) + if str(cfg_get(rec_cfg, "format", "hdf5")) != "hdf5": + raise ValueError("Only recording.format=hdf5 is supported") + if str(cfg_get(rec_cfg, "control", "terminal")) != "terminal": + raise ValueError("Only recording.control=terminal is supported") + build_mp4_video_config(cfg_get(rec_cfg, "video", {}) or {}) + camera_cfg = _recording_camera_cfg(cfg) + if not bool(cfg_get(camera_cfg, "enabled", True)): + raise ValueError("recording.camera.enabled=false is not supported for HDF5 recording") + if str(cfg_get(camera_cfg, "source", "realsense")).lower() != "realsense": + raise ValueError("recording.camera.source must be realsense") + if int(cfg_get(rec_cfg, "fps", 30)) != int(cfg_get(camera_cfg, "fps", 30)): + raise ValueError("recording.fps must match recording.camera.fps") + input_video = parse_pico_video_config(cfg_get(cfg, "input", {}) or {}) + if not input_video.enabled: + raise ValueError("recording.enabled=true requires input.video.enabled=true") + if input_video.source != "realsense": + raise ValueError("recording.enabled=true requires input.video.source=realsense") + if int(input_video.width) != int(cfg_get(camera_cfg, "width", 640)): + raise ValueError("recording.camera.width must match input.video.width") + if int(input_video.height) != int(cfg_get(camera_cfg, "height", 480)): + raise ValueError("recording.camera.height must match input.video.height") + if int(input_video.fps) != int(cfg_get(camera_cfg, "fps", 30)): + raise ValueError("recording.camera.fps must match input.video.fps") + input_device = input_video.device + camera_device = cfg_get(camera_cfg, "device", None) + camera_device = None if camera_device in (None, "", "null") else str(camera_device) + if input_device != camera_device: + raise ValueError("recording.camera.device must match input.video.device") + + +def _require_recording_dependencies() -> None: + try: + from teleopit.recording.hdf5 import TeleopitHDF5Recorder + + TeleopitHDF5Recorder.create + except Exception as exc: + raise RuntimeError("HDF5 recording adapter is unavailable") from exc + + +def _worker_loop(name: str, cfg: dict[str, Any], fn: Callable[[], None]) -> None: + configure_runtime_logging(cfg, force=True) + try: + fn() + except KeyboardInterrupt: + pass + except BaseException: + logger.exception("%s worker crashed", name) + raise + + +def _human_frame_is_valid(frame: object) -> bool: + return validate_human_frame(frame).valid + + +class Sim2RealRuntime: + """Supervisor facade for the process-isolated sim2real runtime.""" + + def __init__(self, cfg: Any, *, console: PlainConsole | None = None) -> None: + self.cfg = _plain_cfg(cfg) + _validate_new_runtime_config(self.cfg) + + mp_cfg = _mp_cfg(self.cfg) + video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) + if video_cfg.enabled and video_cfg.source not in ("realsense", "test-pattern"): + raise ValueError( + "Sim2RealRuntime only supports input.video.source=realsense or test-pattern" + ) + self._ctx = mp.get_context(str(cfg_get(mp_cfg, "start_method", "spawn"))) + self._stop_event = self._ctx.Event() + self._processes: list[mp.Process] = [] + self._shutdown_timeout_s = float(cfg_get(mp_cfg, "shutdown_timeout_s", 3.0)) + self._endpoints = default_endpoints( + host=str(cfg_get(mp_cfg, "host", "127.0.0.1")), + base_port=int(cfg_get(mp_cfg, "base_port", 39700)), + ) + self._command_pub: ZmqPublisher | None = None + self._keyboard: TerminalKeyboardReader | None = None + self._console = console or PlainConsole(title="Teleopit sim2real", enabled=False) + self._console_controls = sim2real_keyboard_controls(self.cfg) + if _recording_enabled(self.cfg): + _require_recording_dependencies() + if not sys.stdin.isatty(): + raise RuntimeError("recording.enabled=true requires an interactive TTY for terminal controls") + if console is None: + self._console = PlainConsole(title="Teleopit sim2real") + + def run(self) -> None: + operator_logger.info("runtime starting") + try: + self._start_processes() + if _recording_enabled(self.cfg): + self._command_pub = ZmqPublisher(self._endpoints.command_pub) + self._keyboard = TerminalKeyboardReader() + operator_logger.info("keyboard recording controls active: R start, S save, D discard, Q shutdown, H help") + while not self._stop_event.is_set(): + self._poll_terminal_recording_controls() + time.sleep(0.2) + critical_names = {"robot_control", "reference"} + if _input_provider_kind(self.cfg) == "pico4": + critical_names.add("pico_input") + critical_dead = [ + process.name + for process in self._processes + if not process.is_alive() + and process.exitcode not in (None, 0) + and process.name in critical_names + ] + if critical_dead: + operator_logger.error("critical worker exited: %s", ", ".join(critical_dead)) + self._stop_event.set() + break + noncritical_dead = [ + process.name + for process in self._processes + if not process.is_alive() + and process.exitcode not in (None, 0) + and process.name not in critical_names + ] + if noncritical_dead: + operator_logger.warning("non-critical worker exited: %s", ", ".join(noncritical_dead)) + except KeyboardInterrupt: + operator_logger.info("keyboard interrupt -> shutting down") + self._stop_event.set() + finally: + self.shutdown() + + def shutdown(self) -> None: + self._stop_event.set() + if self._command_pub is not None: + self._command_pub.publish(COMMAND_TOPIC, CommandPacket(command="shutdown", timestamp_s=time.monotonic())) + for process in self._processes: + process.join(timeout=self._shutdown_timeout_s) + for process in self._processes: + if process.is_alive(): + operator_logger.warning("terminating worker %s", process.name) + process.terminate() + process.join(timeout=1.0) + self._processes.clear() + if self._keyboard is not None: + self._keyboard.close() + self._keyboard = None + if self._command_pub is not None: + self._command_pub.close() + self._command_pub = None + + def _start_processes(self) -> None: + if self._processes: + return + + specs: list[tuple[str, Callable[..., None]]] = [] + if _input_provider_kind(self.cfg) == "pico4": + specs.append(("pico_input", _run_pico_io_worker)) + specs.extend( + [ + ("reference", _run_reference_worker), + ("robot_control", _run_robot_control_worker), + ] + ) + hands_cfg = cfg_get(self.cfg, "hands", {}) or {} + if bool(cfg_get(hands_cfg, "enabled", False)): + specs.append(("hand_worker", _run_hand_worker)) + if _recording_enabled(self.cfg): + specs.append(("recording_worker", _run_recording_worker)) + video_cfg = parse_pico_video_config(cfg_get(self.cfg, "input", {})) + if video_cfg.enabled: + logger.info("Pico video runs inside pico_input so frames are pushed directly to PicoBridge") + + for name, target in specs: + process = self._ctx.Process( + name=name, + target=target, + args=(self.cfg, self._endpoints, self._stop_event), + ) + process.start() + self._processes.append(process) + + def _poll_terminal_recording_controls(self) -> None: + if self._keyboard is None or self._command_pub is None: + return + events = self._keyboard.poll() + if not events: + return + for event in events: + normalized = str(event.key).strip().lower() + if normalized == "h": + self._console.help(self._console_controls) + continue + command = map_recording_key_to_command(event.key) + if command is None: + continue + self._command_pub.publish(COMMAND_TOPIC, CommandPacket(command=command, timestamp_s=time.monotonic())) + self._console.key_feedback(str(event.key).upper(), _recording_command_label(command)) + if command == "shutdown": + self._stop_event.set() + + +def map_recording_key_to_command(key: str) -> str | None: + normalized = str(key).strip().lower() + if normalized == "r": + return "record_start" + if normalized == "s": + return "record_save" + if normalized == "d": + return "record_discard" + if normalized == "q": + return "shutdown" + return None + + +def _recording_command_label(command: str) -> str: + if command == "record_start": + return "start recording" + if command == "record_save": + return "save recording" + if command == "record_discard": + return "discard recording" + if command == "shutdown": + return "shutdown" + return command + + +def _run_pico_io_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + input_cfg = cfg_get(cfg, "input", {}) or {} + video_cfg = parse_pico_video_config(input_cfg) + provider = Pico4InputProvider( + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + timeout=float(cfg_get(input_cfg, "pico4_timeout", 60.0)), + buffer_size=int(cfg_get(input_cfg, "pico4_buffer_size", 60)), + timestamp_gap_reset_s=float(cfg_get(input_cfg, "pico4_timestamp_gap_reset_s", 0.15)), + pause_button=cfg_get(input_cfg, "pause_button", "A"), + pause_debounce_s=float(cfg_get(input_cfg, "pause_debounce_s", 0.25)), + arms_button=cfg_get(input_cfg, "arms_button", "B"), + arms_debounce_s=float(cfg_get(input_cfg, "arms_debounce_s", cfg_get(input_cfg, "pause_debounce_s", 0.25))), + bridge_host=str(cfg_get(input_cfg, "bridge_host", "0.0.0.0")), + bridge_port=int(cfg_get(input_cfg, "bridge_port", 63901)), + bridge_discovery=bool(cfg_get(input_cfg, "bridge_discovery", True)), + bridge_advertise_ip=cfg_get(input_cfg, "bridge_advertise_ip", None), + bridge_video=bridge_video_source(video_cfg), + bridge_video_enabled=video_cfg.enabled, + bridge_start_timeout=float(cfg_get(input_cfg, "bridge_start_timeout", 10.0)), + bridge_history_size=int(cfg_get(input_cfg, "bridge_history_size", 120)), + ) + + body_pub = ZmqPublisher(endpoints.body_pub) + hand_pub = ZmqPublisher(endpoints.hand_pub) + controller_pub = ZmqPublisher(endpoints.controller_pub) + events_pub = ZmqPublisher(endpoints.control_events_pub) + health_pub = ZmqPublisher(endpoints.health_pub) + video_pub = ZmqPublisher(endpoints.video_pub) if _recording_enabled(cfg) else None + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + frame_writer: SharedFrameRingWriter | None = None + + def _publish_recording_frame(frame: NDArray[np.generic], timestamp_s: float) -> None: + nonlocal frame_writer + if video_pub is None: + return + if frame_writer is None: + frame_writer = SharedFrameRingWriter( + shape=tuple(np.asarray(frame).shape), + dtype=np.uint8, + slots=int(cfg_get(_mp_cfg(cfg), "video_slots", 3)), + ) + descriptor = frame_writer.write(np.asarray(frame, dtype=np.uint8), timestamp_s=float(timestamp_s)) + video_pub.publish(VIDEO_TOPIC, descriptor) + + video_runtime = PicoVideoRuntime( + provider=provider, + config=video_cfg, + mode="sim2real", + frame_callback=_publish_recording_frame if _recording_enabled(cfg) else None, + ) + + hz = float(cfg_get(_mp_cfg(cfg), "pico_input_hz", 120.0)) + sleep_s = 1.0 / max(hz, 1.0) + last_body_seq = -1 + last_hand_seq = -1 + last_controller_seq = -1 + last_video_seq = -1 + last_health_s = 0.0 + try: + video_runtime.start() + while not stop_event.is_set(): + video_runtime.tick() + command = command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + stop_event.set() + break + + now = time.monotonic() + if callable(getattr(provider, "has_frame", None)) and provider.has_frame(): + try: + frame, timestamp_s, seq = provider.get_frame_packet() + except Exception: + logger.exception("pico_input failed to read body frame") + else: + if int(seq) != last_body_seq: + body_pub.publish( + BODY_TOPIC, + BodyFramePacket(frame=frame, timestamp_s=float(timestamp_s), seq=int(seq)), + ) + last_body_seq = int(seq) + + events = provider.pop_control_events() + if events: + events_pub.publish( + CONTROL_EVENTS_TOPIC, + ControlEventsPacket(events=tuple(events), timestamp_s=now, seq=last_body_seq), + ) + + controller_snapshot = provider.get_controller_snapshot() + if controller_snapshot is not None and int(controller_snapshot.seq) != last_controller_seq: + controller_pub.publish( + CONTROLLER_TOPIC, + SnapshotPacket( + snapshot=controller_snapshot, + timestamp_s=float(controller_snapshot.timestamp_s), + seq=int(controller_snapshot.seq), + ), + ) + last_controller_seq = int(controller_snapshot.seq) + + hand_snapshot = provider.get_hand_snapshot() + if hand_snapshot is not None and int(hand_snapshot.seq) != last_hand_seq: + hand_pub.publish( + HAND_TOPIC, + SnapshotPacket( + snapshot=hand_snapshot, + timestamp_s=float(hand_snapshot.timestamp_s), + seq=int(hand_snapshot.seq), + ), + ) + last_hand_seq = int(hand_snapshot.seq) + + if video_cfg.enabled: + last_video_seq = int(video_runtime.pushed_frames) + + if now - last_health_s >= 1.0: + health_pub.publish( + HEALTH_TOPIC, + HealthPacket( + worker="pico_input", + timestamp_s=now, + metrics={ + "body_seq": last_body_seq, + "body_fps": float(provider.fps), + "hand_seq": last_hand_seq, + "controller_seq": last_controller_seq, + "video_seq": last_video_seq, + }, + ), + ) + last_health_s = now + time.sleep(sleep_s) + finally: + video_runtime.stop() + if frame_writer is not None: + frame_writer.close(unlink=True) + command_sub.close() + for publisher in (body_pub, hand_pub, controller_pub, events_pub, health_pub, video_pub): + if publisher is not None: + publisher.close() + provider.close() + + _worker_loop("pico_input", cfg, _main) + + +def _run_reference_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + if _input_provider_kind(cfg) == "bvh": + _run_bvh_reference_worker(cfg, endpoints, stop_event) + return + _run_pico_reference_worker(cfg, endpoints, stop_event) + + +def _run_pico_reference_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + input_cfg = cfg_get(cfg, "input", {}) or {} + policy_hz = float(cfg_get(cfg, "policy_hz", 50.0)) + ref_cfg = parse_reference_config(cfg, provider_fps=None) + reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=1.0 / policy_hz, + reference_steps=cfg_get(cfg, "reference_steps", [0]), + ) + if ref_cfg.retarget_buffer_enabled and ref_cfg.reference_delay_s is not None: + reference_window_builder.validate_runtime_support( + delay_s=float(ref_cfg.reference_delay_s or 0.0), + window_s=ref_cfg.retarget_buffer_window_s, + config_label="Multiprocess sim2real reference timeline", + ) + timeline = ReferenceTimeline(window_s=ref_cfg.retarget_buffer_window_s) if ref_cfg.retarget_buffer_enabled else None + reference_manager = ( + RealtimeReferenceManager( + reference_window_builder=reference_window_builder, + warmup_steps=ref_cfg.realtime_buffer_warmup_steps, + ) + if timeline is not None + else None + ) + + retargeter = RetargetingModule( + robot_name=str(cfg_get(input_cfg, "robot_name", "unitree_g1")), + human_format=str(cfg_get(input_cfg, "human_format", "pico_bridge")), + actual_human_height=float(cfg_get(input_cfg, "human_height", 1.75)), + ) + body_sub = LatestSubscriber(endpoints.body_pub, BODY_TOPIC) + health_sub = LatestSubscriber(endpoints.health_pub, HEALTH_TOPIC) + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + ref_pub = ZmqPublisher(endpoints.reference_pub) + idle_sleep_s = float(cfg_get(_mp_cfg(cfg), "retarget_idle_sleep_s", 0.001)) + last_body_seq = -1 + last_body_timestamp_s: float | None = None + body_dt_s_ema: float | None = None + latest_body_fps: float | None = None + resolved_reference_delay_s = ( + float(ref_cfg.reference_delay_s) if ref_cfg.reference_delay_s is not None else None + ) + runtime_support_validated = ref_cfg.reference_delay_s is not None or not reference_window_builder.requires_timeline + last_valid_qpos: Float64Array | None = None + + def _publish_invalid_reference(packet: BodyFramePacket, *, elapsed_s: float) -> None: + qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + qpos[3] = 1.0 + if last_valid_qpos is not None: + qpos = np.asarray(last_valid_qpos, dtype=np.float64).copy() + ref_pub.publish( + REFERENCE_TOPIC, + ReferencePacket( + qpos=qpos, + timestamp_s=time.monotonic(), + seq=int(packet.seq), + source_timestamp_s=float(packet.timestamp_s), + source_seq=int(packet.seq), + frame_valid=False, + retarget_elapsed_s=elapsed_s, + ), + ) + + try: + while not stop_event.is_set(): + command = command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + stop_event.set() + break + + health_packet = health_sub.recv_latest() + if isinstance(health_packet, HealthPacket) and health_packet.worker == "pico_input": + metric_fps = health_packet.metrics.get("body_fps") + if isinstance(metric_fps, (int, float)) and float(metric_fps) > 0.0: + latest_body_fps = float(metric_fps) + + packet = body_sub.recv_latest() + if packet is None: + time.sleep(idle_sleep_s) + continue + if not isinstance(packet, BodyFramePacket) or int(packet.seq) == last_body_seq: + continue + start_s = time.monotonic() + frame_valid = _human_frame_is_valid(packet.frame) + if not frame_valid: + last_body_seq = int(packet.seq) + last_body_timestamp_s = None + body_dt_s_ema = None + _publish_invalid_reference(packet, elapsed_s=time.monotonic() - start_s) + logger.warning("reference worker dropped invalid body frame seq=%s", packet.seq) + continue + + try: + retargeted = retargeter.retarget(packet.frame) + qpos = np.asarray(retargeted, dtype=np.float64).reshape(-1) + reference_window: ReferenceWindow | None = None + if timeline is not None: + timeline.append(qpos, float(packet.timestamp_s)) + if reference_manager is not None: + reference_manager.note_realtime_frame() + if reference_manager is None or not reference_manager.warmup_done: + last_body_timestamp_s = float(packet.timestamp_s) + last_body_seq = int(packet.seq) + continue + if last_body_timestamp_s is not None: + dt_s = float(packet.timestamp_s) - float(last_body_timestamp_s) + if dt_s > 1e-6: + body_dt_s_ema = dt_s if body_dt_s_ema is None else 0.9 * body_dt_s_ema + 0.1 * dt_s + last_body_timestamp_s = float(packet.timestamp_s) + if resolved_reference_delay_s is None: + if latest_body_fps is not None and latest_body_fps > 1e-6: + resolved_reference_delay_s = 1.0 / latest_body_fps + elif body_dt_s_ema is not None and body_dt_s_ema > 1e-6: + resolved_reference_delay_s = float(body_dt_s_ema) + elif reference_window_builder.requires_timeline: + last_body_seq = int(packet.seq) + continue + else: + resolved_reference_delay_s = 0.0 + if not runtime_support_validated: + reference_window_builder.validate_runtime_support( + delay_s=float(resolved_reference_delay_s), + window_s=ref_cfg.retarget_buffer_window_s, + config_label="Multiprocess sim2real reference timeline", + ) + runtime_support_validated = True + reference_window, _diag = reference_manager.sample( + timeline, + time.monotonic() - float(resolved_reference_delay_s), + ) + qpos = reference_window.current_sample().qpos + last_valid_qpos = np.asarray(qpos, dtype=np.float64).copy() + ref_pub.publish( + REFERENCE_TOPIC, + ReferencePacket( + qpos=np.asarray(qpos, dtype=np.float64).copy(), + timestamp_s=time.monotonic(), + seq=int(packet.seq), + source_timestamp_s=float(packet.timestamp_s), + source_seq=int(packet.seq), + frame_valid=True, + reference_window=reference_window, + retarget_elapsed_s=time.monotonic() - start_s, + ), + ) + last_body_seq = int(packet.seq) + except Exception: + logger.exception("reference worker failed to retarget body seq=%s", getattr(packet, "seq", None)) + finally: + body_sub.close() + health_sub.close() + command_sub.close() + ref_pub.close() + + _worker_loop("reference", cfg, _main) + + +def _run_bvh_reference_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + input_cfg = cfg_get(cfg, "input", {}) or {} + policy_hz = float(cfg_get(cfg, "policy_hz", 50.0)) + provider = BVHInputProvider( + str(cfg_get(input_cfg, "bvh_file", "")), + human_format=str(cfg_get(input_cfg, "bvh_format", cfg_get(input_cfg, "human_format", "lafan1"))), + ) + retargeter = RetargetingModule( + robot_name=str(cfg_get(input_cfg, "robot_name", "unitree_g1")), + human_format=str(cfg_get(input_cfg, "human_format", cfg_get(input_cfg, "bvh_format", "lafan1"))), + actual_human_height=float(cfg_get(input_cfg, "human_height", provider.human_height)), + ) + offline_reference = OfflineReferenceMotion(provider, retargeter) + playback_cfg = cfg_get(cfg, "playback", {}) or {} + playback = OfflinePlaybackController( + duration_s=offline_reference.duration_s, + step_dt_s=1.0 / policy_hz, + pause_on_end=bool(cfg_get(playback_cfg, "pause_on_end", True)), + ) + reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=1.0 / policy_hz, + reference_steps=cfg_get(cfg, "reference_steps", [0]), + ) + ref_pub = ZmqPublisher(endpoints.reference_pub) + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + reference_command_sub = LatestSubscriber(endpoints.reference_command_pub, COMMAND_TOPIC) + mode_sub = LatestSubscriber(endpoints.mode_pub, MODE_TOPIC) + health_pub = ZmqPublisher(endpoints.health_pub) + tick_s = 1.0 / policy_hz + seq = 0 + last_health_s = 0.0 + mocap_active = False + + def _publish(sample_time_s: float, *, frame_valid: bool = True) -> Float64Array | None: + nonlocal seq + start_s = time.monotonic() + sampled = offline_reference.sample(sample_time_s) + if sampled is None: + return None + reference_window = None + if reference_window_builder.requires_timeline: + reference_window = build_offline_reference_window( + offline_reference, + sample_time_s, + reference_window_builder, + policy_hz, + ) + qpos = np.asarray(sampled.qpos, dtype=np.float64).copy() + ref_pub.publish( + REFERENCE_TOPIC, + ReferencePacket( + qpos=qpos, + timestamp_s=time.monotonic(), + seq=seq, + source_timestamp_s=float(sample_time_s), + source_seq=int(sampled.frame_idx0), + frame_valid=frame_valid, + reference_window=reference_window, + retarget_elapsed_s=time.monotonic() - start_s, + playback_paused=playback.paused, + playback_finished=playback.finished, + ), + ) + seq += 1 + return qpos + + try: + while not stop_event.is_set(): + t0 = time.monotonic() + command = command_sub.recv_latest() + if isinstance(command, CommandPacket): + if command.command == "shutdown": + stop_event.set() + break + reference_command = reference_command_sub.recv_latest() + if isinstance(reference_command, CommandPacket): + command = reference_command + if command.command == "pause_mocap": + playback.pause() + elif command.command == "resume_mocap": + if not playback.finished: + playback.resume() + elif command.command == "replay_mocap": + playback.replay() + mode_packet = mode_sub.recv_latest() + if isinstance(mode_packet, ModeStatePacket): + mocap_active = bool(mode_packet.mocap_active) + + qpos = _publish(playback.current_time_s) + if qpos is None: + playback.finish() + _publish(playback.current_time_s) + elif mocap_active: + playback.advance() + + now = time.monotonic() + if now - last_health_s >= 1.0: + health_pub.publish( + HEALTH_TOPIC, + HealthPacket( + worker="reference", + timestamp_s=now, + metrics={ + "source": "bvh", + "seq": seq, + "playback_time_s": float(playback.current_time_s), + "paused": int(playback.paused), + "finished": int(playback.finished), + }, + ), + ) + last_health_s = now + elapsed = time.monotonic() - t0 + if elapsed < tick_s: + time.sleep(tick_s - elapsed) + finally: + command_sub.close() + reference_command_sub.close() + mode_sub.close() + ref_pub.close() + health_pub.close() + + _worker_loop("reference", cfg, _main) + + +class _RobotControlWorker: + def __init__( + self, + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, + ) -> None: + self.cfg = cfg + self.endpoints = endpoints + self.stop_event = stop_event + self.provider_kind = _input_provider_kind(cfg) + self.mode = RobotMode.IDLE + self.policy_hz = float(cfg_get(cfg, "policy_hz", 50.0)) + self.dt = 1.0 / self.policy_hz + + self.robot = UnitreeG1Robot(cfg_get(cfg, "real_robot")) + self.remote = UnitreeRemote() + self.policy, self.obs_builder = self._build_policy_and_obs() + + robot_cfg = cfg_get(cfg, "robot") + self.default_angles = np.asarray(cfg_get(robot_cfg, "default_angles"), dtype=np.float32) + default_root_qpos = np.asarray( + cfg_get(robot_cfg, "mujoco_default_qpos", [0.0, 0.0, 0.0]), dtype=np.float64 + ).reshape(-1) + self._default_root_pos = np.zeros(3, dtype=np.float64) + if default_root_qpos.shape[0] >= 3: + self._default_root_pos[:] = default_root_qpos[:3] + self.num_actions = int(cfg_get(robot_cfg, "num_actions", NUM_JOINTS)) + self._safety = Sim2RealSafetyManager(cfg, self.robot, self.policy_hz, self.num_actions) + self._standing_return_ramp_duration = float(cfg_get(cfg, "standing_return_ramp_duration", 0.5)) + self._standing_return_kp_ramp_floor_ratio = float( + cfg_get(cfg, "standing_return_kp_ramp_floor_ratio", 0.5) + ) + self._arm_joint_indices = parse_arm_joint_indices(cfg, num_actions=self.num_actions) + + self._ref_cfg = parse_reference_config(cfg, provider_fps=None) + self._reference_window_builder = ReferenceWindowBuilder( + policy_dt_s=self.dt, + reference_steps=cfg_get(cfg, "reference_steps", [0]), + ) + self._ref_proc = Sim2RealReferenceProcessor( + obs_builder=self.obs_builder, + policy=self.policy, + policy_hz=self.policy_hz, + num_actions=self.num_actions, + reference_velocity_smoothing_alpha=self._ref_cfg.reference_velocity_smoothing_alpha, + reference_anchor_velocity_smoothing_alpha=self._ref_cfg.reference_anchor_velocity_smoothing_alpha, + ) + + self._standing_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + self._standing_qpos[3] = 1.0 + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._last_retarget_qpos: Float64Array | None = None + self._last_commanded_motion_qpos: Float64Array | None = None + self._last_mocap_hold_reason: str | None = None + self._mocap_reentry_armed = False + self._mocap_session = MocapSessionManager() + + self._latest_reference: ReferencePacket | None = None + mp_cfg = _mp_cfg(cfg) + self._max_reference_age_s = float(cfg_get(mp_cfg, "max_reference_age_s", 0.25)) + self._stale_reference_hold_s = float(cfg_get(mp_cfg, "stale_reference_hold_s", 0.08)) + mocap_sw = cfg_get(cfg, "mocap_switch", {}) or {} + self._check_frames = int(cfg_get(mocap_sw, "check_frames", 10)) + self._last_reference_seq = -1 + self._consecutive_valid_references = 0 + + self._reference_sub = LatestSubscriber(endpoints.reference_pub, REFERENCE_TOPIC) + self._events_sub = LatestSubscriber(endpoints.control_events_pub, CONTROL_EVENTS_TOPIC) + self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + self._reference_command_pub = ZmqPublisher(endpoints.reference_command_pub) + self._mode_pub = ZmqPublisher(endpoints.mode_pub) + self._record_pub = ZmqPublisher(endpoints.record_pub) if _recording_enabled(cfg) else None + + viewers = _parse_sim2real_viewers(cfg) + self._retarget_viewer = _Sim2RealRetargetViewer( + xml_path=str(cfg_get(robot_cfg, "xml_path", "")) if "retarget" in viewers else None, + enabled="retarget" in viewers, + ) + self._mode_seq = 0 + + def run(self) -> None: + operator_logger.info("robot control ready | mode=IDLE | policy_hz=%.0f", self.policy_hz) + timing = _LoopTimingReporter( + target_period_s=self.dt, + log_interval_s=console_timing_interval_s(self.cfg), + enabled=console_show_timing(self.cfg), + ) + try: + while not self.stop_event.is_set(): + t0 = time.monotonic() + self._drain_ipc() + + remote_bytes = self.robot.get_wireless_remote() + self.remote.update(remote_bytes) + if self.remote.LB.pressed and self.remote.RB.pressed: + if self.mode != RobotMode.DAMPING: + logger.warning("EMERGENCY STOP (L1+R1)") + operator_logger.warning("DAMPING requested by emergency stop") + self._enter_damping() + else: + self._handle_transitions() + if self.mode == RobotMode.STANDING: + self._standing_step() + elif self.mode in (RobotMode.MOCAP, RobotMode.ARMS): + self._mocap_step() + + self._publish_mode_state() + work_elapsed_s = time.monotonic() - t0 + cycle_elapsed_s = self._sleep_until(t0, self.dt) + timing.record( + loop_start_s=t0, + work_elapsed_s=work_elapsed_s, + cycle_elapsed_s=cycle_elapsed_s, + pico_age_s=self._reference_age_s(), + ) + finally: + self.shutdown() + + def shutdown(self) -> None: + if self.mode in (RobotMode.STANDING, RobotMode.MOCAP, RobotMode.ARMS): + try: + self.robot.set_damping() + time.sleep(0.5) + except Exception: + logger.exception("Failed to send damping during robot_control shutdown") + try: + self.robot.exit_debug_mode() + except Exception: + logger.exception("Failed to exit debug mode during robot_control shutdown") + self._retarget_viewer.shutdown() + self._reference_sub.close() + self._events_sub.close() + self._command_sub.close() + self._reference_command_pub.close() + self._mode_pub.close() + if self._record_pub is not None: + self._record_pub.close() + self.robot.close() + + def _build_policy_and_obs(self) -> tuple[Any, Any]: + robot_cfg = require_section(self.cfg, "robot") + controller_cfg = require_section(self.cfg, "controller") + sim_cfg = build_simulation_cfg(self.cfg) + policy, obs_builder = _build_policy_components( + robot_cfg=robot_cfg, + controller_cfg=controller_cfg, + sim_cfg=sim_cfg, + project_root=PROJECT_ROOT, + controller_cls=RLPolicyController, + ) + if not bool(getattr(policy, "_multi_input", False)): + raise ValueError("Sim2real requires an ONNX policy with dual inputs ('obs' and 'obs_history').") + return policy, obs_builder + + def _drain_ipc(self) -> None: + command = self._command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + self.stop_event.set() + return + reference = self._reference_sub.recv_latest() + if isinstance(reference, ReferencePacket): + self._note_reference_packet(reference) + events = self._events_sub.recv_latest() + if isinstance(events, ControlEventsPacket): + self._handle_mocap_control_events(events.events) + + def _handle_transitions(self) -> None: + if self.mode == RobotMode.IDLE: + if self.remote.start.on_pressed: + operator_logger.info("Start -> STANDING") + self._enter_standing() + elif self.mode == RobotMode.STANDING: + reentry_request = self._mocap_reentry_armed and self.remote.Y.pressed + if self.remote.Y.on_pressed or reentry_request: + if self._can_switch_to_mocap(): + operator_logger.info("Y -> MOCAP") + self._transition_to_mocap() + else: + operator_logger.warning("Y -> waiting for fresh retarget reference") + elif self.mode in (RobotMode.MOCAP, RobotMode.ARMS): + if self.provider_kind == "bvh" and self.remote.B.on_pressed: + operator_logger.info("B -> replay BVH from frame 0") + self._send_reference_command("replay_mocap") + self._resume_paused_mocap_if_needed() + return + if self.remote.A.on_pressed: + if self._mocap_session.state == MocapSessionState.PAUSED: + operator_logger.info("A -> resume playback") + self._send_reference_command("resume_mocap") + self._resume_paused_mocap() + else: + operator_logger.info("A -> pause playback") + self._send_reference_command("pause_mocap") + self._pause_active_mocap() + return + if self.remote.X.on_pressed: + operator_logger.info("X -> STANDING") + self._enter_standing() + elif self.mode == RobotMode.DAMPING: + if self.remote.start.on_pressed: + operator_logger.info("Start -> STANDING") + self._enter_standing() + + def _standing_step(self) -> None: + robot_state = self.robot.get_state() + qpos = self._standing_qpos.copy() + motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) + motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) + reference_window = None + if obs_builder_requires_reference_window(self.obs_builder): + reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) + obs = self._ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, + motion_joint_vel=motion_joint_vel, + last_action=self._last_action, + anchor_lin_vel_w=np.zeros(3, dtype=np.float32), + anchor_ang_vel_w=np.zeros(3, dtype=np.float32), + reference_window=reference_window, + ) + obs = self._ref_proc.validate_observation(obs) + action = self.policy.compute_action(obs) + target_dof_pos = self._safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + self._safety.send_positions(target_dof_pos) + self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) + self._last_retarget_qpos = qpos.copy() + self._last_commanded_motion_qpos = qpos.copy() + self._publish_record_step(robot_state=robot_state, reference_qpos=qpos) + self._write_retarget_viewer(qpos) + + def _mocap_step(self) -> None: + if self._mocap_session.state == MocapSessionState.PAUSED: + self._paused_mocap_step() + return + + reference = self._latest_reference + age_s = self._reference_age_s() + if reference is None or age_s is None: + self._hold_mocap_reference("no retarget reference") + return + if not reference.frame_valid: + self._hold_mocap_reference("invalid retarget reference") + return + if age_s > self._stale_reference_hold_s: + self._hold_mocap_reference( + "delayed retarget reference", + detail=f"age={age_s:.3f}s", + ) + return + + robot_state = self.robot.get_state() + self._execute_mocap_pipeline(reference.qpos, robot_state, reference.reference_window) + + def _execute_mocap_pipeline( + self, + reference_qpos: Float64Array, + robot_state: object, + reference_window: ReferenceWindow | None, + ) -> None: + reference_window_aligned = False + reference_qpos = self._ref_proc.align_reference_yaw(reference_qpos, robot_state=robot_state) + if self.mode == RobotMode.ARMS: + reference_qpos = self._compose_arm_reference(reference_qpos) + aligned_window = self._ref_proc.align_reference_window(reference_window, robot_state) + reference_window = self._compose_arm_reference_window(aligned_window) + reference_window_aligned = True + qpos = reference_qpos.copy() + if qpos.shape[0] < 7 + self.num_actions: + raise ValueError(f"Retargeted qpos too short: {qpos.shape[0]} (need >= {7 + self.num_actions})") + motion_joint_pos = np.asarray(qpos[7:7 + self.num_actions], dtype=np.float32) + if self._last_retarget_qpos is None: + raw_motion_joint_vel = np.zeros((self.num_actions,), dtype=np.float32) + else: + prev_joint_pos = np.asarray(self._last_retarget_qpos[7:7 + self.num_actions], dtype=np.float32) + raw_motion_joint_vel = (motion_joint_pos - prev_joint_pos) * np.float32(self.policy_hz) + motion_joint_vel = self._ref_proc.apply_joint_vel_smoothing(raw_motion_joint_vel) + + anchor_lin_vel_w = np.zeros(3, dtype=np.float32) + anchor_ang_vel_w = np.zeros(3, dtype=np.float32) + if not obs_builder_requires_reference_window(self.obs_builder): + raw_lin, raw_ang = self._ref_proc.compute_anchor_velocities(reference_qpos) + anchor_lin_vel_w, anchor_ang_vel_w = self._ref_proc.apply_anchor_vel_smoothing(raw_lin, raw_ang) + + motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) + obs = self._ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, + motion_joint_vel=motion_joint_vel, + last_action=self._last_action, + anchor_lin_vel_w=anchor_lin_vel_w, + anchor_ang_vel_w=anchor_ang_vel_w, + reference_window=reference_window, + reference_window_aligned=reference_window_aligned, + ) + obs = self._ref_proc.validate_observation(obs) + action = self.policy.compute_action(obs) + target_dof_pos = self._safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + self._safety.send_positions(target_dof_pos) + self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) + self._last_retarget_qpos = qpos.copy() + self._ref_proc.last_reference_qpos = reference_qpos.copy() + self._last_commanded_motion_qpos = qpos.copy() + self._last_mocap_hold_reason = None + self._publish_record_step(robot_state=robot_state, reference_qpos=qpos) + self._write_retarget_viewer(qpos) + + def _compose_arm_reference(self, retarget_qpos: Float64Array) -> Float64Array: + return compose_arm_reference( + standing_qpos=self._standing_qpos, + retarget_qpos=retarget_qpos, + arm_joint_indices=self._arm_joint_indices, + num_actions=self.num_actions, + ) + + def _compose_arm_reference_window(self, reference_window: ReferenceWindow | None) -> ReferenceWindow | None: + return compose_arm_reference_window( + reference_window, + standing_qpos=self._standing_qpos, + arm_joint_indices=self._arm_joint_indices, + num_actions=self.num_actions, + ) + + def _enter_standing(self) -> None: + prev_mode = self.mode + already_in_debug = self.mode in (RobotMode.STANDING, RobotMode.MOCAP, RobotMode.ARMS) + if not already_in_debug: + logger.info("Entering debug mode...") + ok = self.robot.enter_debug_mode() + if not ok: + logger.error("Failed to enter debug mode -- staying in %s", self.mode.value) + return + time.sleep(0.5) + + state = self.robot.get_state() + if prev_mode not in (RobotMode.MOCAP, RobotMode.ARMS): + logger.info("Locking joints to current position...") + self.robot.lock_all_joints() + time.sleep(0.3) + + init_qpos = self._build_robot_state_qpos(state) + self._last_retarget_qpos = init_qpos + self._ref_proc.last_reference_qpos = None + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self._set_default_standing_reference(state) + self._reset_policy_state() + if prev_mode in (RobotMode.MOCAP, RobotMode.ARMS): + self._safety.start_kp_ramp( + duration_s=self._standing_return_ramp_duration, + floor_ratio=self._standing_return_kp_ramp_floor_ratio, + ) + else: + self._safety.start_kp_ramp() + self._mocap_reentry_armed = prev_mode in (RobotMode.MOCAP, RobotMode.ARMS) + self.mode = RobotMode.STANDING + operator_logger.info("mode -> STANDING") + + def _can_switch_to_mocap(self) -> bool: + age_s = self._reference_age_s() + if self._latest_reference is None or age_s is None: + return False + if not self._latest_reference.frame_valid: + return False + if self.provider_kind == "bvh": + return True + if age_s > self._max_reference_age_s: + return False + if self._consecutive_valid_references < self._check_frames: + logger.warning( + "Mocap check: only %d/%d valid references", + self._consecutive_valid_references, + self._check_frames, + ) + return False + return True + + def _transition_to_mocap(self) -> None: + state = self.robot.get_state() + last_commanded = getattr(self, "_last_commanded_motion_qpos", None) + hold_qpos = last_commanded if last_commanded is not None else self._standing_qpos + resume_qpos = self._build_resume_alignment_qpos(hold_qpos, state) + self._mocap_reentry_armed = False + self._reset_policy_state() + self._last_retarget_qpos = None + self._last_commanded_motion_qpos = resume_qpos.copy() + self._ref_proc.reset_alignment(target_qpos=resume_qpos) + if self.provider_kind == "bvh": + self._send_reference_command("replay_mocap") + self.mode = RobotMode.MOCAP + operator_logger.info("mode -> MOCAP") + + def _toggle_arms_mode(self) -> None: + if self.provider_kind != "pico4" or self.mode not in (RobotMode.MOCAP, RobotMode.ARMS): + return + if self._mocap_session.state == MocapSessionState.PAUSED: + logger.info("Ignoring Pico B mode toggle while mocap session is paused") + return + + state = self.robot.get_state() + resume_qpos = self._build_resume_alignment_qpos(self._last_commanded_motion_qpos, state) + next_mode = RobotMode.ARMS if self.mode == RobotMode.MOCAP else RobotMode.MOCAP + if next_mode == RobotMode.ARMS: + self._set_default_standing_reference(state) + self._reset_policy_state() + self._last_retarget_qpos = None + self._last_commanded_motion_qpos = resume_qpos.copy() + self._ref_proc.reset_alignment(target_qpos=resume_qpos) + self._safety.start_kp_ramp( + duration_s=self._standing_return_ramp_duration, + floor_ratio=self._standing_return_kp_ramp_floor_ratio, + ) + self.mode = next_mode + operator_logger.info("mode -> %s", next_mode.value.upper()) + + def _resume_paused_mocap_if_needed(self) -> None: + if self._mocap_session.state == MocapSessionState.PAUSED: + self._resume_paused_mocap() + + def _enter_damping(self) -> None: + if self.mode in (RobotMode.STANDING, RobotMode.MOCAP, RobotMode.ARMS): + logger.info("DAMPING: sending LowCmd damping...") + self.robot.set_damping() + time.sleep(0.5) + logger.info("DAMPING: exiting debug mode...") + self.robot.exit_debug_mode() + self.mode = RobotMode.DAMPING + self._publish_damping_record_step() + self._ref_proc.last_reference_qpos = None + self._mocap_reentry_armed = False + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None + operator_logger.warning("mode -> DAMPING") + + def _reset_policy_state(self) -> None: + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._ref_proc.reset_smoothers() + self._ref_proc.reset_alignment() + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None + self.policy.reset() + self.obs_builder.reset() + + def _reset_policy_reference_state(self) -> None: + self._last_action = np.zeros(self.num_actions, dtype=np.float32) + self._ref_proc.reset_smoothers() + self._ref_proc.reset_alignment() + self._mocap_session.reset() + self._last_commanded_motion_qpos = None + self._last_mocap_hold_reason = None + self.policy.reset() + self.obs_builder.reset() + + def _build_robot_state_qpos(self, state: object) -> Float64Array: + qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + qpos[0:3] = self._resolve_base_pos(state) + qpos[3:7] = np.asarray(getattr(state, "quat"), dtype=np.float64).reshape(-1)[:4] + qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(getattr(state, "qpos"), dtype=np.float64).reshape(-1)[ + : self.num_actions + ] + return qpos + + def _set_default_standing_reference(self, state: object) -> None: + self._standing_qpos[:] = 0.0 + self._standing_qpos[0:3] = self._resolve_base_pos(state) + self._standing_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) + align_motion_qpos_yaw(np.asarray(getattr(state, "quat"), dtype=np.float32), self._standing_qpos) + self._standing_qpos[ROOT_DIM:FULL_QPOS_DIM] = self.default_angles.astype(np.float64) + + def _resolve_base_pos(self, state: object) -> Float64Array: + base_pos = getattr(state, "base_pos", None) + if base_pos is None: + return self._default_root_pos.copy() + resolved = self._default_root_pos.copy() + live = np.asarray(base_pos, dtype=np.float64).reshape(-1) + resolved[: min(3, live.shape[0])] = live[:3] + return resolved + + def _build_resume_alignment_qpos(self, hold_qpos: Float64Array | None, state: object) -> Float64Array: + qpos = self._build_robot_state_qpos(state) + if hold_qpos is not None: + qpos[0:2] = np.asarray(hold_qpos, dtype=np.float64).reshape(-1)[0:2] + base_pos = getattr(state, "base_pos", None) + if base_pos is not None: + qpos[0:2] = np.asarray(base_pos, dtype=np.float64).reshape(-1)[0:2] + return qpos + + def _handle_mocap_control_events(self, control_events: tuple[ControlEvent, ...]) -> None: + for event in control_events: + if event.event_type == ControlEventType.TOGGLE_ARMS: + self._toggle_arms_mode() + continue + if event.event_type == ControlEventType.TOGGLE_PAUSE: + if self.mode not in (RobotMode.MOCAP, RobotMode.ARMS): + continue + if self._mocap_session.state == MocapSessionState.PAUSED: + self._resume_paused_mocap() + else: + self._pause_active_mocap() + + def _pause_active_mocap(self) -> None: + hold_qpos = self._resolve_mocap_hold_qpos() + self._last_retarget_qpos = hold_qpos.copy() + self._ref_proc.last_reference_qpos = hold_qpos.copy() + self._last_commanded_motion_qpos = hold_qpos.copy() + self._reset_policy_reference_state() + self._mocap_session.pause(hold_qpos) + logger.info("Mocap session -> PAUSED (multiprocess episode-reset)") + + def _resume_paused_mocap(self) -> None: + hold_qpos = self._mocap_session.hold_qpos + if hold_qpos is None: + raise RuntimeError("Cannot resume mocap without a paused hold qpos") + state = self.robot.get_state() + resume_qpos = self._build_resume_alignment_qpos(hold_qpos, state) + self._last_commanded_motion_qpos = resume_qpos.copy() + self._reset_policy_reference_state() + self._last_retarget_qpos = None + self._last_commanded_motion_qpos = resume_qpos.copy() + self._ref_proc.reset_alignment(target_qpos=resume_qpos) + if self.mode == RobotMode.ARMS: + self._set_default_standing_reference(state) + self._safety.start_kp_ramp( + duration_s=self._standing_return_ramp_duration, + floor_ratio=self._standing_return_kp_ramp_floor_ratio, + ) + logger.info("Mocap session -> ACTIVE (multiprocess episode-reset + reference realignment)") + + def _send_reference_command(self, command: str) -> None: + self._reference_command_pub.publish( + COMMAND_TOPIC, + CommandPacket(command=command, timestamp_s=time.monotonic()), + ) + + def _resolve_mocap_hold_qpos(self) -> Float64Array: + if self._last_commanded_motion_qpos is not None: + return self._last_commanded_motion_qpos.copy() + if self._last_retarget_qpos is not None: + return np.asarray(self._last_retarget_qpos, dtype=np.float64).copy() + state = self.robot.get_state() + hold_qpos = np.zeros(FULL_QPOS_DIM, dtype=np.float64) + hold_qpos[3:7] = np.asarray(state.quat, dtype=np.float64) + hold_qpos[ROOT_DIM:FULL_QPOS_DIM] = np.asarray(state.qpos, dtype=np.float64) + return hold_qpos + + def _paused_mocap_step(self) -> None: + hold_qpos = self._mocap_session.hold_qpos + if hold_qpos is None: + raise RuntimeError("Paused mocap session is missing a hold_qpos") + self._run_static_mocap_step(hold_qpos) + + def _run_static_mocap_step(self, hold_qpos: Float64Array) -> None: + robot_state = self.robot.get_state() + qpos = np.asarray(hold_qpos, dtype=np.float64).copy() + motion_joint_vel = np.zeros(self.num_actions, dtype=np.float32) + motion_qpos = np.asarray(qpos[:7 + self.num_actions], dtype=np.float32) + reference_window = None + if obs_builder_requires_reference_window(self.obs_builder): + reference_window = build_static_reference_window(qpos, self._reference_window_builder, self.policy_hz) + obs = self._ref_proc.build_observation( + robot_state=robot_state, + motion_qpos=motion_qpos, + motion_joint_vel=motion_joint_vel, + last_action=self._last_action, + anchor_lin_vel_w=np.zeros(3, dtype=np.float32), + anchor_ang_vel_w=np.zeros(3, dtype=np.float32), + reference_window=reference_window, + ) + obs = self._ref_proc.validate_observation(obs) + action = self.policy.compute_action(obs) + target_dof_pos = self._safety.clip_to_joint_limits(self.policy.get_target_dof_pos(action)) + self._safety.send_positions(target_dof_pos) + self._last_action = np.asarray(action, dtype=np.float32).reshape(-1) + self._last_retarget_qpos = qpos.copy() + self._ref_proc.last_reference_qpos = qpos.copy() + self._last_commanded_motion_qpos = qpos.copy() + self._publish_record_step(robot_state=robot_state, reference_qpos=qpos) + self._write_retarget_viewer(qpos) + + def _hold_mocap_reference(self, reason: str, *, detail: str | None = None) -> None: + if self._last_mocap_hold_reason != reason: + suffix = f" ({detail})" if detail else "" + logger.warning("Mocap reference not fresh: %s%s -- holding command", reason, suffix) + self._last_mocap_hold_reason = reason + hold_qpos = self._resolve_mocap_hold_qpos() + self._run_static_mocap_step(hold_qpos) + + def _publish_mode_state(self) -> None: + self._mode_seq += 1 + mocap_like = self.mode in (RobotMode.MOCAP, RobotMode.ARMS) + active = mocap_like and self._mocap_session.state == MocapSessionState.ACTIVE + paused = mocap_like and self._mocap_session.state == MocapSessionState.PAUSED + self._mode_pub.publish( + MODE_TOPIC, + ModeStatePacket( + mode=self.mode.value, + mocap_active=active, + mocap_paused=paused, + timestamp_s=time.monotonic(), + seq=self._mode_seq, + ), + ) + + def _publish_record_step(self, *, robot_state: object, reference_qpos: Float64Array) -> None: + if self._record_pub is None: + return + record_mode = self._recording_mode_label() + mocap_like = self.mode in (RobotMode.MOCAP, RobotMode.ARMS) + active = mocap_like and self._mocap_session.state == MocapSessionState.ACTIVE + recordable = self.mode != RobotMode.DAMPING + try: + self._record_pub.publish( + RECORD_TOPIC, + RecordStepPacket( + timestamp_s=time.monotonic(), + mode=record_mode, + mocap_active=active, + recordable=recordable, + observation_state=build_observation_state(robot_state).astype(np.float32, copy=True), + observation_mode=build_mode_observation(record_mode).astype(np.float32, copy=True), + action_reference_qpos=normalize_action_reference_qpos(reference_qpos).astype(np.float32, copy=True), + seq=self._mode_seq, + ), + ) + except Exception: + logger.exception("Failed to publish sim2real recording step") + + def _publish_damping_record_step(self) -> None: + if self._record_pub is None: + return + try: + robot_state = self.robot.get_state() + reference_qpos = self._build_robot_state_qpos(robot_state) + self._record_pub.publish( + RECORD_TOPIC, + RecordStepPacket( + timestamp_s=time.monotonic(), + mode=RobotMode.DAMPING.value, + mocap_active=False, + recordable=False, + observation_state=build_observation_state(robot_state).astype(np.float32, copy=True), + observation_mode=np.array([-1.0], dtype=np.float32), + action_reference_qpos=normalize_action_reference_qpos(reference_qpos).astype(np.float32, copy=True), + seq=self._mode_seq, + ), + ) + except Exception: + logger.exception("Failed to publish sim2real damping recording state") + + def _recording_mode_label(self) -> str: + if ( + self.mode in (RobotMode.MOCAP, RobotMode.ARMS) + and self._mocap_session.state == MocapSessionState.PAUSED + ): + return "pause" + return self.mode.value + + def _write_retarget_viewer(self, qpos: Float64Array) -> None: + try: + self._retarget_viewer.write(qpos) + except Exception: + logger.exception("Sim2real retarget viewer update failed; control continues") + + def _reference_age_s(self) -> float | None: + if self._latest_reference is None: + return None + return max(0.0, time.monotonic() - float(self._latest_reference.timestamp_s)) + + def _note_reference_packet(self, reference: ReferencePacket) -> None: + if int(reference.seq) <= self._last_reference_seq: + return + self._last_reference_seq = int(reference.seq) + self._latest_reference = reference + if ( + self.provider_kind == "bvh" + and bool(getattr(reference, "playback_paused", False)) + and self.mode in (RobotMode.MOCAP, RobotMode.ARMS) + and self._mocap_session.state == MocapSessionState.ACTIVE + ): + self._pause_active_mocap() + if not reference.frame_valid: + self._consecutive_valid_references = 0 + return + self._consecutive_valid_references += 1 + + @staticmethod + def _sleep_until(t0: float, dt: float) -> float: + elapsed = time.monotonic() - t0 + remaining = dt - elapsed + if remaining > 0: + time.sleep(remaining) + return time.monotonic() - t0 + + +def _run_robot_control_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + worker = _RobotControlWorker(cfg, endpoints, stop_event) + worker.run() + + _worker_loop("robot_control", cfg, _main) + + +class _RecordingWorker: + def __init__( + self, + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, + *, + recorder_factory: Callable[..., Any] | None = None, + frame_reader: SharedFrameRingReader | None = None, + ) -> None: + self.cfg = cfg + self.endpoints = endpoints + self.stop_event = stop_event + self.rec_cfg = _recording_cfg(cfg) + self.camera_cfg = _recording_camera_cfg(cfg) + self.record_modes = { + str(mode).lower() + for mode in cfg_get(self.rec_cfg, "record_modes", ["standing", "mocap", "arms", "pause"]) + } + self.min_episode_seconds = float(cfg_get(self.rec_cfg, "min_episode_seconds", 1.0)) + self.discard_on_shutdown = bool(cfg_get(self.rec_cfg, "discard_on_shutdown", True)) + self.task = str(cfg_get(self.rec_cfg, "task", "demo")) + self.fps = int(cfg_get(self.rec_cfg, "fps", 30)) + self._record_sub = LatestSubscriber(endpoints.record_pub, RECORD_TOPIC) + self._video_sub = LatestSubscriber(endpoints.video_pub, VIDEO_TOPIC) + self._hand_command_sub = LatestSubscriber(endpoints.hand_command_pub, HAND_COMMAND_TOPIC) + self._command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + self._frame_reader = frame_reader or SharedFrameRingReader() + self._latest_record: RecordStepPacket | None = None + left_open, right_open = _configured_open_hand_pose(cfg) + self._latest_hand_command = HandCommandPacket( + timestamp_s=0.0, + driver=str(cfg_get(cfg_get(cfg, "hands", {}) or {}, "driver", "linkerhand_l6")).strip().lower(), + mode=str(cfg_get(cfg_get(cfg, "hands", {}) or {}, "mode", "gripper")).strip().lower(), + active=False, + left_pose=left_open.astype(np.float32, copy=True), + right_pose=right_open.astype(np.float32, copy=True), + seq=0, + ) + self._latest_video_seq = -1 + self._active = False + self._episode_started_s = 0.0 + self._episode_frames = 0 + + from teleopit.recording.hdf5 import ( + TeleopitHDF5Recorder, + build_recording_schema, + ) + + self._schema = build_recording_schema(self.camera_cfg) + self._video_config = build_mp4_video_config(cfg_get(self.rec_cfg, "video", {}) or {}) + factory = recorder_factory or TeleopitHDF5Recorder.create + self._recorder = factory( + output_dir=cfg_get(self.rec_cfg, "output_dir", "data/recordings/sim2real_hdf5"), + task=self.task, + fps=self.fps, + schema=self._schema, + video_config=self._video_config, + ) + + def run(self) -> None: + operator_logger.info("recording worker ready | fps=%d | modes=%s", self.fps, sorted(self.record_modes)) + idle_sleep_s = 1.0 / max(float(self.fps) * 4.0, 1.0) + try: + while not self.stop_event.is_set(): + command = self._command_sub.recv_latest() + if isinstance(command, CommandPacket): + if self._handle_command(command): + break + + record = self._record_sub.recv_latest() + if isinstance(record, RecordStepPacket): + self._latest_record = record + + hand_command = self._hand_command_sub.recv_latest() + if isinstance(hand_command, HandCommandPacket): + self._latest_hand_command = hand_command + + video = self._video_sub.recv_latest() + if isinstance(video, SharedFrameDescriptor): + self._handle_video(video) + + time.sleep(idle_sleep_s) + finally: + if self._active: + if self.discard_on_shutdown: + self._discard_episode("shutdown") + else: + self._save_episode() + try: + self._recorder.finalize() + finally: + self._record_sub.close() + self._video_sub.close() + self._hand_command_sub.close() + self._command_sub.close() + self._frame_reader.close() + + def _handle_command(self, command: CommandPacket) -> bool: + name = command.command + if name == "shutdown": + self.stop_event.set() + return True + if name == "record_start": + self._start_episode() + elif name == "record_save": + self._save_episode() + elif name == "record_discard": + self._discard_episode("manual discard") + return False + + def _start_episode(self) -> None: + if self._active: + logger.warning("Recording episode already active; ignoring R") + return + record = self._latest_record + if record is None: + logger.warning("Cannot start recording: no robot record packet yet") + return + mode = str(record.mode).lower() + if mode not in self.record_modes or not bool(record.recordable): + logger.warning( + "Cannot start recording: mode=%s recordable=%s", + record.mode, + record.recordable, + ) + return + self._recorder.start_episode() + self._active = True + self._episode_started_s = time.monotonic() + self._episode_frames = 0 + operator_logger.info("recording episode started") + + def _save_episode(self) -> None: + if not self._active: + operator_logger.info("no active recording episode to save") + return + duration_s = time.monotonic() - self._episode_started_s + if self._episode_frames <= 0: + self._discard_episode("empty episode") + return + if duration_s < self.min_episode_seconds: + self._discard_episode(f"short episode ({duration_s:.2f}s < {self.min_episode_seconds:.2f}s)") + return + self._recorder.save_episode() + operator_logger.info("recording episode saved | frames=%d duration=%.2fs", self._episode_frames, duration_s) + self._active = False + self._episode_frames = 0 + + def _discard_episode(self, reason: str) -> None: + if not self._active: + operator_logger.info("no active recording episode to discard") + return + self._recorder.discard_episode() + operator_logger.info("recording episode discarded | reason=%s | frames=%d", reason, self._episode_frames) + self._active = False + self._episode_frames = 0 + + def _handle_video(self, descriptor: SharedFrameDescriptor) -> None: + if int(descriptor.seq) == self._latest_video_seq: + return + self._latest_video_seq = int(descriptor.seq) + if not self._active: + return + record = self._latest_record + if record is None: + return + mode = str(record.mode).lower() + if mode not in self.record_modes or not bool(record.recordable): + logger.warning("Recording stopped because mode is no longer recordable: %s", record.mode) + self._discard_episode("mode not recordable") + return + image = self._frame_reader.read(descriptor, copy=True) + self._recorder.add_frame( + image=np.asarray(image, dtype=np.uint8), + state=np.asarray(record.observation_state, dtype=np.float32), + mode=np.asarray(record.observation_mode, dtype=np.float32), + action=np.asarray(record.action_reference_qpos, dtype=np.float32), + hand_action=normalize_hand_action( + self._latest_hand_command.left_pose, + self._latest_hand_command.right_pose, + ), + task=self.task, + ) + self._episode_frames += 1 + +def _run_recording_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + worker = _RecordingWorker(cfg, endpoints, stop_event) + worker.run() + + _worker_loop("recording_worker", cfg, _main) + + +class _HandSnapshotProxy: + def __init__(self) -> None: + self.hand_snapshot: Any | None = None + self.controller_snapshot: Any | None = None + + def get_hand_snapshot(self) -> Any | None: + return self.hand_snapshot + + def get_controller_snapshot(self) -> Any | None: + return self.controller_snapshot + + +def _hand_worker_active_for_mode(mode_packet: ModeStatePacket) -> bool: + del mode_packet + return True + + +def _run_hand_worker( + cfg: dict[str, Any], + endpoints: Sim2RealIpcEndpoints, + stop_event: MpEvent, +) -> None: + def _main() -> None: + proxy = _HandSnapshotProxy() + runtime = build_hand_runtime(cfg) + hand_sub = LatestSubscriber(endpoints.hand_pub, HAND_TOPIC) + controller_sub = LatestSubscriber(endpoints.controller_pub, CONTROLLER_TOPIC) + mode_sub = LatestSubscriber(endpoints.mode_pub, MODE_TOPIC) + command_sub = LatestSubscriber(endpoints.command_pub, COMMAND_TOPIC) + hand_command_pub = ZmqPublisher(endpoints.hand_command_pub) + active = False + hz = float(cfg_get(_mp_cfg(cfg), "hand_worker_hz", 120.0)) + sleep_s = 1.0 / max(hz, 1.0) + hands_cfg = cfg_get(cfg, "hands", {}) or {} + driver = str(cfg_get(hands_cfg, "driver", "linkerhand_l6")).strip().lower() + hand_mode = str(cfg_get(hands_cfg, "mode", "gripper")).strip().lower() + left_pose, right_pose = _configured_open_hand_pose(cfg) + command_seq = 0 + + def _apply_hand_commands(commands: tuple[HandPoseCommand, ...]) -> bool: + nonlocal left_pose, right_pose + changed = False + for hand_command in commands: + pose = np.asarray(hand_command.pose, dtype=np.float32).reshape(-1) + if pose.shape[0] != 6: + logger.warning("Ignoring %s hand command with invalid pose shape %s", hand_command.side, pose.shape) + continue + if hand_command.side == "left": + left_pose = pose.copy() + changed = True + elif hand_command.side == "right": + right_pose = pose.copy() + changed = True + else: + logger.warning("Ignoring hand command with unsupported side %r", hand_command.side) + return changed + + def _publish_hand_command(*, timestamp_s: float, active_state: bool) -> None: + nonlocal command_seq + command_seq += 1 + hand_command_pub.publish( + HAND_COMMAND_TOPIC, + HandCommandPacket( + timestamp_s=float(timestamp_s), + driver=driver, + mode=hand_mode, + active=bool(active_state), + left_pose=np.asarray(left_pose, dtype=np.float32).copy(), + right_pose=np.asarray(right_pose, dtype=np.float32).copy(), + seq=command_seq, + ), + ) + + try: + startup_commands = runtime.start() + startup_s = time.monotonic() + _apply_hand_commands(startup_commands) + _publish_hand_command(timestamp_s=startup_s, active_state=False) + while not stop_event.is_set(): + command = command_sub.recv_latest() + if isinstance(command, CommandPacket) and command.command == "shutdown": + stop_event.set() + break + hand_packet = hand_sub.recv_latest() + if isinstance(hand_packet, SnapshotPacket): + proxy.hand_snapshot = hand_packet.snapshot + controller_packet = controller_sub.recv_latest() + if isinstance(controller_packet, SnapshotPacket): + proxy.controller_snapshot = controller_packet.snapshot + mode_packet = mode_sub.recv_latest() + if isinstance(mode_packet, ModeStatePacket): + active = _hand_worker_active_for_mode(mode_packet) + try: + now_s = time.monotonic() + commands = runtime.tick( + controller_snapshot=proxy.controller_snapshot, + hand_snapshot=proxy.hand_snapshot, + active=active, + now_s=now_s, + ) + if commands: + if _apply_hand_commands(commands): + _publish_hand_command(timestamp_s=now_s, active_state=active) + except Exception: + logger.exception("Dexterous hand worker tick failed; hand control continues") + time.sleep(sleep_s) + finally: + try: + shutdown_commands = runtime.close() + shutdown_s = time.monotonic() + if _apply_hand_commands(shutdown_commands): + _publish_hand_command(timestamp_s=shutdown_s, active_state=False) + finally: + hand_sub.close() + controller_sub.close() + mode_sub.close() + command_sub.close() + hand_command_pub.close() + + _worker_loop("hand_worker", cfg, _main) diff --git a/teleopit/sim2real/mp/shm.py b/teleopit/sim2real/mp/shm.py new file mode 100644 index 00000000..36feaf9a --- /dev/null +++ b/teleopit/sim2real/mp/shm.py @@ -0,0 +1,106 @@ +"""Shared-memory ring buffer for large sim2real video frames.""" + +from __future__ import annotations + +from multiprocessing import shared_memory +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from teleopit.sim2real.mp.messages import SharedFrameDescriptor + + +class SharedFrameRingWriter: + """Owns a fixed-size ring of frame slots in shared memory.""" + + def __init__( + self, + *, + shape: tuple[int, ...], + dtype: str | np.dtype[Any] = np.uint8, + slots: int = 3, + name: str | None = None, + ) -> None: + if slots <= 0: + raise ValueError(f"slots must be positive, got {slots}") + dtype_np = np.dtype(dtype) + shape_tuple = tuple(int(dim) for dim in shape) + if not shape_tuple or any(dim <= 0 for dim in shape_tuple): + raise ValueError(f"shape must contain positive dimensions, got {shape}") + + self.shape = shape_tuple + self.dtype = dtype_np + self.slots = int(slots) + self._slot_size = int(np.prod(shape_tuple)) * dtype_np.itemsize + self._shm = shared_memory.SharedMemory( + name=name, + create=True, + size=self._slot_size * self.slots, + ) + self._seq = 0 + + @property + def name(self) -> str: + return self._shm.name + + def write(self, frame: NDArray[np.generic], *, timestamp_s: float) -> SharedFrameDescriptor: + frame_arr = np.asarray(frame, dtype=self.dtype) + if tuple(frame_arr.shape) != self.shape: + raise ValueError(f"frame shape {frame_arr.shape} does not match shared ring shape {self.shape}") + slot = self._seq % self.slots + view = self._slot_view(slot) + np.copyto(view, np.ascontiguousarray(frame_arr)) + descriptor = SharedFrameDescriptor( + shm_name=self._shm.name, + slot=slot, + seq=self._seq, + timestamp_s=float(timestamp_s), + shape=self.shape, + dtype=str(self.dtype), + slots=self.slots, + ) + self._seq += 1 + return descriptor + + def close(self, *, unlink: bool = True) -> None: + self._shm.close() + if unlink: + try: + self._shm.unlink() + except FileNotFoundError: + pass + + def _slot_view(self, slot: int) -> NDArray[np.generic]: + if slot < 0 or slot >= self.slots: + raise ValueError(f"slot must be in [0, {self.slots}), got {slot}") + start = slot * self._slot_size + return np.ndarray(self.shape, dtype=self.dtype, buffer=self._shm.buf, offset=start) + + +class SharedFrameRingReader: + """Attaches to shared frame rings lazily and reads descriptor-selected slots.""" + + def __init__(self) -> None: + self._rings: dict[str, shared_memory.SharedMemory] = {} + + def read(self, descriptor: SharedFrameDescriptor, *, copy: bool = False) -> NDArray[np.generic]: + shm = self._rings.get(descriptor.shm_name) + if shm is None: + shm = shared_memory.SharedMemory(name=descriptor.shm_name) + self._rings[descriptor.shm_name] = shm + + dtype = np.dtype(descriptor.dtype) + shape = tuple(int(dim) for dim in descriptor.shape) + slot_size = int(np.prod(shape)) * dtype.itemsize + if descriptor.slot < 0 or descriptor.slot >= descriptor.slots: + raise ValueError(f"descriptor slot {descriptor.slot} out of range for {descriptor.slots} slots") + view = np.ndarray(shape, dtype=dtype, buffer=shm.buf, offset=descriptor.slot * slot_size) + if copy: + return view.copy() + return view + + def close(self) -> None: + for shm in self._rings.values(): + shm.close() + self._rings.clear() diff --git a/teleopit/sim2real/reference_processor.py b/teleopit/sim2real/reference_processor.py index e0dc4841..315ba795 100644 --- a/teleopit/sim2real/reference_processor.py +++ b/teleopit/sim2real/reference_processor.py @@ -15,7 +15,7 @@ from teleopit.controllers import reference_processing as ref_proc from teleopit.controllers.observation import VelCmdObservationBuilder -from teleopit.controllers.qpos_interpolator import QposLowPassFilter +from teleopit.inputs.human_frame_validation import validate_human_frame from teleopit.sim.realtime_utils import ExponentialVecSmoother from teleopit.sim.reference_timeline import ReferenceWindow @@ -35,14 +35,11 @@ def __init__( num_actions: int, reference_velocity_smoothing_alpha: float, reference_anchor_velocity_smoothing_alpha: float, - reference_qpos_smoothing_alpha: float, - max_pos_value: float, ) -> None: self._obs_builder = obs_builder self._policy = policy self._policy_hz = policy_hz self._num_actions = num_actions - self._max_pos_value = max_pos_value # Yaw alignment state (lazy-init) self._fixed_reference_yaw_quat: Float32Array | None = None @@ -54,7 +51,6 @@ def __init__( self._motion_joint_vel_smoother = ExponentialVecSmoother(reference_velocity_smoothing_alpha) self._motion_anchor_lin_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) self._motion_anchor_ang_vel_smoother = ExponentialVecSmoother(reference_anchor_velocity_smoothing_alpha) - self._reference_qpos_smoother = QposLowPassFilter(reference_qpos_smoothing_alpha) # Last reference qpos for velocity computation self._last_reference_qpos: Float64Array | None = None @@ -80,14 +76,7 @@ def retarget_to_qpos(self, retargeted: object) -> Float64Array: return ref_proc.retarget_to_qpos(retargeted) def frame_is_valid(self, frame: dict[str, tuple[np.ndarray, np.ndarray]]) -> bool: - for pos, quat in frame.values(): - if np.any(np.isnan(pos)) or np.any(np.isinf(pos)): - return False - if np.any(np.abs(pos) > self._max_pos_value): - return False - if np.any(np.isnan(quat)) or np.any(np.isinf(quat)): - return False - return True + return validate_human_frame(frame).valid # ------------------------------------------------------------------ # Yaw alignment @@ -176,8 +165,11 @@ def build_observation( anchor_lin_vel_w: Float32Array, anchor_ang_vel_w: Float32Array, reference_window: ReferenceWindow | None, + reference_window_aligned: bool = False, ) -> Float32Array: - aligned_reference_window = self.align_reference_window(reference_window, robot_state) + aligned_reference_window = ( + reference_window if reference_window_aligned else self.align_reference_window(reference_window, robot_state) + ) return ref_proc.dispatch_build_observation( self._obs_builder, robot_state, reference_window, aligned_reference_window, motion_qpos, motion_joint_vel, last_action, @@ -201,9 +193,6 @@ def validate_observation(self, obs: Float32Array) -> Float32Array: # Smoothing # ------------------------------------------------------------------ - def apply_qpos_smoothing(self, qpos: Float64Array) -> Float64Array: - return self._reference_qpos_smoother.apply(qpos) - def apply_joint_vel_smoothing(self, vel: Float32Array) -> Float32Array: return self._motion_joint_vel_smoother.apply(vel) @@ -219,5 +208,4 @@ def reset_smoothers(self) -> None: self._motion_joint_vel_smoother.reset() self._motion_anchor_lin_vel_smoother.reset() self._motion_anchor_ang_vel_smoother.reset() - self._reference_qpos_smoother.reset() self._last_reference_qpos = None diff --git a/teleopit/sim2real/safety.py b/teleopit/sim2real/safety.py index 7cd05423..6ce08ebe 100644 --- a/teleopit/sim2real/safety.py +++ b/teleopit/sim2real/safety.py @@ -30,18 +30,18 @@ def __init__( ) -> None: self._robot = robot self._policy_hz = policy_hz - real_cfg = cfg_get(cfg, "real_robot") - # KP ramp (gradually increase PD gains after episode-reset) - _legacy_ramp_dur = cfg_get(cfg, "startup_ramp_duration", cfg_get(real_cfg, "startup_ramp_duration", 2.0)) - kp_ramp_dur = float(cfg_get(cfg, "kp_ramp_duration", _legacy_ramp_dur)) - self._kp_ramp_duration_steps: int = max(1, int(kp_ramp_dur * policy_hz)) + # Kp ramp gradually increases PD gains after episode-reset. + startup_ramp_duration = float(cfg_get(cfg, "startup_ramp_duration", 2.0)) + self._default_kp_ramp_duration_steps: int = max(1, int(startup_ramp_duration * policy_hz)) + self._kp_ramp_duration_steps: int = self._default_kp_ramp_duration_steps self._kp_ramp_step: int = 0 self._kp_ramp_active: bool = False self._kp_nominal = np.asarray(cfg_get(real_cfg, "kp_real", [100] * num_actions), dtype=np.float32) self._kd_nominal = np.asarray(cfg_get(real_cfg, "kd_real", [2] * num_actions), dtype=np.float32) - self._kp_ramp_floor_ratio: float = float(cfg_get(cfg, "kp_ramp_floor_ratio", 0.1)) + self._default_kp_ramp_floor_ratio: float = float(cfg_get(cfg, "kp_ramp_floor_ratio", 0.1)) + self._kp_ramp_floor_ratio: float = self._default_kp_ramp_floor_ratio # Joint safety limits self._joint_vel_limit: float = float( @@ -71,8 +71,21 @@ def compute_kp_ramp_gains(self) -> tuple[Float32Array, Float32Array] | None: return np.asarray(kp, dtype=np.float32), self._kd_nominal.copy() - def start_kp_ramp(self) -> None: + def start_kp_ramp( + self, + *, + duration_s: float | None = None, + floor_ratio: float | None = None, + ) -> None: """Arm the Kp ramp for gradual PD gain increase.""" + if duration_s is None: + self._kp_ramp_duration_steps = self._default_kp_ramp_duration_steps + else: + self._kp_ramp_duration_steps = max(1, int(float(duration_s) * self._policy_hz)) + if floor_ratio is None: + self._kp_ramp_floor_ratio = self._default_kp_ramp_floor_ratio + else: + self._kp_ramp_floor_ratio = float(floor_ratio) self._kp_ramp_step = 0 self._kp_ramp_active = True logger.info( diff --git a/teleopit/sim2real/unitree_g1.py b/teleopit/sim2real/unitree_g1.py index 704b612c..d36c95f6 100644 --- a/teleopit/sim2real/unitree_g1.py +++ b/teleopit/sim2real/unitree_g1.py @@ -30,6 +30,7 @@ class UnitreeG1Robot: def __init__(self, cfg: Any) -> None: self._network_interface: str = str(cfg_get(cfg, "network_interface", "eth0")) + self._publish_hz: int = int(cfg_get(cfg, "publish_hz", 200)) self._kp = np.asarray(cfg_get(cfg, "kp_real", [100] * NUM_JOINTS), dtype=np.float32) self._kd = np.asarray(cfg_get(cfg, "kd_real", [2] * NUM_JOINTS), dtype=np.float32) @@ -40,7 +41,7 @@ def __init__(self, cfg: Any) -> None: import g1_bridge_sdk - self._bridge = g1_bridge_sdk.G1Bridge(self._network_interface) + self._bridge = g1_bridge_sdk.G1Bridge(self._network_interface, self._publish_hz) self._publishing: bool = False if not self._bridge.wait_for_state(3.0): diff --git a/tests/conftest.py b/tests/conftest.py index 0d7ba0c8..6f7b5110 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,9 +51,7 @@ def find_g1_xml_path() -> str | None: """Return the preferred test XML path for G1 MuJoCo-based tests.""" root = Path(__file__).parent.parent candidates = [ - root / "teleopit" / "retargeting" / "gmr" / "assets" / "unitree_g1" / "g1_mjlab.xml", - root / "GMR" / "assets" / "unitree_g1" / "g1_sim2sim_29dof.xml", - root / "teleopit" / "retargeting" / "gmr" / "assets" / "unitree_g1" / "g1_sim2sim_29dof.xml", + root / "assets" / "robots" / "unitree_g1" / "g1_29dof.xml", ] for path in candidates: if path.exists(): diff --git a/tests/test_adaptive_sampling.py b/tests/test_adaptive_sampling.py deleted file mode 100644 index 49f932f5..00000000 --- a/tests/test_adaptive_sampling.py +++ /dev/null @@ -1,284 +0,0 @@ -from __future__ import annotations - -import pytest -import torch - -from train_mimic.tasks.tracking.mdp.commands import ( - _cap_failure_rates, - _compute_clip_counts, - _compute_clip_failure_rate, - _normalize_sampling_probabilities, - _validate_legacy_adaptive_config, -) - - -# --------------------------------------------------------------------------- -# Existing tests (refactored helper, same semantics) -# --------------------------------------------------------------------------- - - -def test_compute_clip_failure_rate_is_invariant_to_parallel_env_count() -> None: - motion_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) - episode_failed = torch.tensor([True, False, True, False, False, True]) - - small_rate = _compute_clip_failure_rate(motion_ids, episode_failed, bin_count=3) - large_rate = _compute_clip_failure_rate( - motion_ids.repeat(16), episode_failed.repeat(16), bin_count=3 - ) - - expected = torch.tensor([2.0 / 3.0, 1.0 / 3.0, 0.0], dtype=torch.float32) - assert torch.allclose(small_rate, expected) - assert torch.allclose(large_rate, expected) - - -def test_compute_clip_failure_rate_ignores_unseen_clips() -> None: - motion_ids = torch.tensor([1, 1, 3, 3], dtype=torch.long) - episode_failed = torch.tensor([True, False, False, False]) - - rate = _compute_clip_failure_rate(motion_ids, episode_failed, bin_count=5) - - expected = torch.tensor([0.0, 0.5, 0.0, 0.0, 0.0], dtype=torch.float32) - assert torch.allclose(rate, expected) - - -# --------------------------------------------------------------------------- -# New tests for _compute_clip_counts -# --------------------------------------------------------------------------- - - -def test_compute_clip_counts_returns_correct_counts() -> None: - motion_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) - episode_failed = torch.tensor([True, False, True, False, False, True]) - - exposure, failure = _compute_clip_counts(motion_ids, episode_failed, bin_count=3) - - assert torch.allclose(exposure, torch.tensor([3.0, 3.0, 0.0])) - assert torch.allclose(failure, torch.tensor([2.0, 1.0, 0.0])) - - -# --------------------------------------------------------------------------- -# EMA no-decay on empty steps -# --------------------------------------------------------------------------- - - -def test_ema_skips_update_when_no_data() -> None: - """bin_failed_rate must NOT decay when accumulators are empty.""" - bin_count = 4 - alpha = 0.01 - bin_failed_rate = torch.tensor([0.5, 0.3, 0.0, 0.2]) - accum_exposure = torch.zeros(bin_count) - accum_failure = torch.zeros(bin_count) - - # Simulate the guarded EMA logic from _update_command - if accum_exposure.sum() > 0: - valid = accum_exposure > 0 - global_rate = torch.zeros_like(bin_failed_rate) - global_rate[valid] = accum_failure[valid] / accum_exposure[valid] - bin_failed_rate = alpha * global_rate + (1 - alpha) * bin_failed_rate - - expected = torch.tensor([0.5, 0.3, 0.0, 0.2]) - assert torch.allclose(bin_failed_rate, expected) - - -# --------------------------------------------------------------------------- -# Fail-fast validation for invalid adaptive distributions -# --------------------------------------------------------------------------- - - -def test_normalize_sampling_probabilities_raises_on_zero_mass() -> None: - sampling_probabilities = torch.zeros(5) - - with pytest.raises(ValueError, match="invalid probability mass"): - _normalize_sampling_probabilities( - sampling_probabilities, - adaptive_uniform_ratio=0.0, - bin_count=5, - ) - - -def test_validate_legacy_adaptive_config_rejects_nondefault_kernel_size() -> None: - with pytest.raises(ValueError, match="adaptive_kernel_size"): - _validate_legacy_adaptive_config(adaptive_kernel_size=3, adaptive_lambda=0.8) - - -def test_validate_legacy_adaptive_config_rejects_nondefault_lambda() -> None: - with pytest.raises(ValueError, match="adaptive_lambda"): - _validate_legacy_adaptive_config(adaptive_kernel_size=1, adaptive_lambda=0.5) - - -# --------------------------------------------------------------------------- -# Accumulation sums across multiple resamples -# --------------------------------------------------------------------------- - - -def test_accumulation_sums_across_resamples() -> None: - """Two accumulate calls before EMA update should sum counts correctly.""" - bin_count = 3 - accum_exposure = torch.zeros(bin_count) - accum_failure = torch.zeros(bin_count) - - # First resample batch - ids1 = torch.tensor([0, 0, 1], dtype=torch.long) - failed1 = torch.tensor([True, False, True]) - e1, f1 = _compute_clip_counts(ids1, failed1, bin_count) - accum_exposure += e1 - accum_failure += f1 - - # Second resample batch - ids2 = torch.tensor([1, 2, 2], dtype=torch.long) - failed2 = torch.tensor([False, True, False]) - e2, f2 = _compute_clip_counts(ids2, failed2, bin_count) - accum_exposure += e2 - accum_failure += f2 - - assert torch.allclose(accum_exposure, torch.tensor([2.0, 2.0, 2.0])) - assert torch.allclose(accum_failure, torch.tensor([1.0, 1.0, 1.0])) - - -# --------------------------------------------------------------------------- -# adaptive_bin: _cap_failure_rates -# --------------------------------------------------------------------------- - - -def test_cap_failure_rates_clamps_outliers() -> None: - rates = torch.tensor([0.1, 0.9, 0.2]) - # mean = 0.4, beta=2.0 -> cap = 0.8 - capped = _cap_failure_rates(rates, beta=2.0) - expected = torch.tensor([0.1, 0.8, 0.2]) - assert torch.allclose(capped, expected) - - -def test_cap_failure_rates_all_zero() -> None: - rates = torch.zeros(5) - capped = _cap_failure_rates(rates, beta=5.0) - assert torch.allclose(capped, torch.zeros(5)) - - -def test_cap_failure_rates_uniform() -> None: - """When all rates are equal, capping changes nothing.""" - rates = torch.tensor([0.3, 0.3, 0.3]) - capped = _cap_failure_rates(rates, beta=2.0) - assert torch.allclose(capped, rates) - - -# --------------------------------------------------------------------------- -# adaptive_bin: build_time_bins (via MotionLib mock) -# --------------------------------------------------------------------------- - - -def _make_mock_motion_lib( - clip_sample_start_s: list[float], - clip_sample_end_s: list[float], - clip_weights: list[float], -): - """Create a lightweight mock with just the fields build_time_bins needs.""" - import math as _math - from types import SimpleNamespace - - from train_mimic.tasks.tracking.mdp.commands import MotionLib - - num_clips = len(clip_weights) - device = "cpu" - - mock = SimpleNamespace() - mock.num_clips = num_clips - mock._device = device - mock.clip_weights = torch.tensor(clip_weights, dtype=torch.float32, device=device) - mock.clip_sample_start_s = torch.tensor( - clip_sample_start_s, dtype=torch.float32, device=device - ) - mock.clip_sample_end_s = torch.tensor( - clip_sample_end_s, dtype=torch.float32, device=device - ) - # Bind the real method - mock.build_time_bins = MotionLib.build_time_bins.__get__(mock, type(mock)) - return mock - - -def test_build_time_bins_single_clip() -> None: - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0], - clip_sample_end_s=[12.0], - clip_weights=[1.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - assert result["num_bins"] == 3 # [0,5), [5,10), [10,12) - assert torch.equal(result["bin_clip_id"], torch.tensor([0, 0, 0])) - assert torch.allclose(result["bin_start_s"], torch.tensor([0.0, 5.0, 10.0])) - assert torch.allclose(result["bin_end_s"], torch.tensor([5.0, 10.0, 12.0])) - assert torch.allclose(result["bin_duration"], torch.tensor([5.0, 5.0, 2.0])) - assert result["clip_bin_offset"][0].item() == 0 - - -def test_build_time_bins_multiple_clips() -> None: - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0, 0.0, 0.0], - clip_sample_end_s=[3.0, 7.0, 10.0], - clip_weights=[1.0, 1.0, 1.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - # clip 0: 3s -> 1 bin; clip 1: 7s -> 2 bins; clip 2: 10s -> 2 bins - assert result["num_bins"] == 5 - expected_clip_ids = torch.tensor([0, 1, 1, 2, 2]) - assert torch.equal(result["bin_clip_id"], expected_clip_ids) - - -def test_build_time_bins_skips_zero_weight() -> None: - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0, 0.0], - clip_sample_end_s=[10.0, 10.0], - clip_weights=[1.0, 0.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - assert result["num_bins"] == 2 # only clip 0 - assert torch.equal(result["bin_clip_id"], torch.tensor([0, 0])) - assert result["clip_bin_offset"][1].item() == -1 - - -def test_build_time_bins_short_clip() -> None: - """A clip shorter than bin_duration produces exactly 1 bin.""" - ml = _make_mock_motion_lib( - clip_sample_start_s=[0.0], - clip_sample_end_s=[2.0], - clip_weights=[1.0], - ) - result = ml.build_time_bins(bin_duration_s=5.0) - - assert result["num_bins"] == 1 - assert torch.allclose(result["bin_duration"], torch.tensor([2.0])) - - -# --------------------------------------------------------------------------- -# adaptive_bin: probability formula -# --------------------------------------------------------------------------- - - -def test_adaptive_bin_probability_formula() -> None: - """Verify the full capped + blended probability computation.""" - rates = torch.tensor([0.1, 0.9, 0.2, 0.0]) - beta = 2.0 - alpha = 0.8 - N = len(rates) - - # Step 1: cap - capped = _cap_failure_rates(rates, beta) - # mean=0.3, cap=0.6 -> [0.1, 0.6, 0.2, 0.0] - expected_capped = torch.tensor([0.1, 0.6, 0.2, 0.0]) - assert torch.allclose(capped, expected_capped) - - # Step 2: normalize to p_hat - capped_sum = capped.sum() - p_hat = capped / capped_sum # [0.1/0.9, 0.6/0.9, 0.2/0.9, 0] - - # Step 3: blend - p_final = alpha * p_hat + (1.0 - alpha) / N - p_final = p_final / p_final.sum() - - # Verify it's a valid distribution - assert torch.allclose(p_final.sum(), torch.tensor(1.0)) - assert (p_final > 0).all(), "All bins should have positive probability" - # Bin 1 (highest failure) should have highest probability - assert p_final[1] == p_final.max() diff --git a/tests/test_console.py b/tests/test_console.py new file mode 100644 index 00000000..e4e163b8 --- /dev/null +++ b/tests/test_console.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging + +from teleopit.runtime.console import ( + OPERATOR_LOGGER_NAME, + PlainConsole, + configure_runtime_logging, + sim2real_operator_controls, + sim_keyboard_controls, +) + + +def test_sim_console_shows_only_enabled_keyboard_controls() -> None: + cfg = { + "input": {"provider": "pico4"}, + "keyboard": {"enabled": True}, + } + + labels = [control.keys for control in sim_keyboard_controls(cfg)] + + assert labels == ["Y", "A", "B", "X", "Q"] + + +def test_sim_console_hides_non_keyboard_controls() -> None: + cfg = { + "input": {"provider": "pico4"}, + "keyboard": {"enabled": False}, + } + + assert sim_keyboard_controls(cfg) == () + + +def test_sim2real_console_shows_remote_and_pico_controls() -> None: + cfg = {"input": {"provider": "pico4"}, "recording": {"enabled": False}} + + rendered = PlainConsole(title="Teleopit sim2real", color=False).render( + controls=sim2real_operator_controls(cfg), + control_section="Controls", + show_help_key=False, + ) + + assert "Controls" in rendered + assert "[Remote Start] standing" in rendered + assert "[Remote L1+R1] damping / estop" in rendered + assert "[Pico/Controller A] pause/resume" in rendered + assert "[Pico/Controller B] arms" in rendered + assert "[H] help" not in rendered + + +def test_sim2real_console_includes_terminal_recording_controls_when_enabled() -> None: + cfg = {"recording": {"enabled": True}} + + rendered = PlainConsole(title="Teleopit sim2real", color=False).render( + controls=sim2real_operator_controls(cfg), + control_section="Controls", + show_help_key=False, + ) + + assert "Controls" in rendered + assert "[Remote Start] standing" in rendered + assert "[R] start" in rendered + assert "[Q] shutdown" in rendered + assert "[H] help" not in rendered + + +def test_console_can_highlight_important_words_with_ansi_color() -> None: + rendered = PlainConsole(title="Teleopit sim2sim", color=True).render( + status=(("State", "MOCAP"),), + controls=sim_keyboard_controls({"input": {"provider": "bvh"}, "playback": {"keyboard": {"enabled": True}}}), + events=("waiting for input",), + ) + + assert "\033[" in rendered + assert "[Space/P]" in rendered + assert "MOCAP" in rendered + + +def test_runtime_logging_keeps_operator_info_but_hides_noisy_info() -> None: + configure_runtime_logging({"console": {"log_level": "warning"}}, force=True) + + assert logging.getLogger().getEffectiveLevel() == logging.WARNING + assert logging.getLogger(OPERATOR_LOGGER_NAME).getEffectiveLevel() == logging.INFO + assert logging.getLogger("pico_bridge").getEffectiveLevel() == logging.WARNING diff --git a/tests/test_controller.py b/tests/test_controller.py index 3a07d103..e86e41ca 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -41,11 +41,11 @@ def test_normalize_clip_range_inverted_raises(self): def test_extract_feature_dim_int(self): from teleopit.controllers.rl_policy import RLPolicyController - assert RLPolicyController._extract_feature_dim([1, 166]) == 166 + assert RLPolicyController._extract_feature_dim([1, 167]) == 167 def test_extract_feature_dim_string(self): from teleopit.controllers.rl_policy import RLPolicyController - assert RLPolicyController._extract_feature_dim(["batch", "166"]) == 166 + assert RLPolicyController._extract_feature_dim(["batch", "167"]) == 167 def test_extract_feature_dim_dynamic(self): from teleopit.controllers.rl_policy import RLPolicyController @@ -105,7 +105,7 @@ class TestRLPolicyControllerInference: def test_compute_action_shape(self): from teleopit.controllers.rl_policy import RLPolicyController - obs_dim = 166 + obs_dim = 167 action_dim = 29 # Build a controller with mocked internals @@ -138,7 +138,7 @@ def test_compute_action_wrong_dim_raises(self): from teleopit.controllers.rl_policy import RLPolicyController ctrl = RLPolicyController.__new__(RLPolicyController) - ctrl._expected_obs_dim = 166 + ctrl._expected_obs_dim = 167 ctrl.clip_range = (-10.0, 10.0) ctrl.action_scale = np.ones(29, dtype=np.float32) ctrl._session = MagicMock() @@ -154,8 +154,8 @@ def test_reset_is_noop(self): ctrl = RLPolicyController.__new__(RLPolicyController) from collections import deque ctrl._history_buf = deque(maxlen=3) - ctrl._last_obs_input = np.zeros(166, dtype=np.float32) - ctrl._last_obs_history_input = np.zeros((3, 166), dtype=np.float32) + ctrl._last_obs_input = np.zeros(167, dtype=np.float32) + ctrl._last_obs_history_input = np.zeros((3, 167), dtype=np.float32) assert ctrl.reset() is None assert len(ctrl._history_buf) == 0 assert ctrl._last_obs_input is None @@ -165,7 +165,7 @@ def test_multi_input_debug_inputs_capture_history(self): from teleopit.controllers.rl_policy import RLPolicyController ctrl = RLPolicyController.__new__(RLPolicyController) - ctrl._expected_obs_dim = 166 + ctrl._expected_obs_dim = 167 ctrl.clip_range = (-10.0, 10.0) ctrl.action_scale = np.ones(2, dtype=np.float32) ctrl.default_dof_pos = np.zeros(2, dtype=np.float32) @@ -173,7 +173,7 @@ def test_multi_input_debug_inputs_capture_history(self): ctrl._output_name = "action" ctrl._multi_input = True ctrl._history_length = 3 - ctrl._history_obs_dim = 166 + ctrl._history_obs_dim = 167 from collections import deque ctrl._history_buf = deque(maxlen=3) ctrl._last_obs_input = None @@ -182,9 +182,9 @@ def test_multi_input_debug_inputs_capture_history(self): mock_session.run.return_value = [np.zeros((1, 2), dtype=np.float32)] ctrl._session = mock_session - obs = np.zeros(166, dtype=np.float32) + obs = np.zeros(167, dtype=np.float32) ctrl.compute_action(obs) debug = ctrl.get_debug_inputs() assert debug["obs"] is not None assert debug["obs_history"] is not None - assert debug["obs_history"].shape == (3, 166) + assert debug["obs_history"].shape == (3, 167) diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 07f5d894..3722ff0e 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,18 +1,20 @@ from __future__ import annotations +import csv import pickle from pathlib import Path import numpy as np import pytest +import h5py from train_mimic.data import dataset_builder +from train_mimic.data.dataset_lib import compute_dataset_stats, write_hdf5_motion_shard from train_mimic.data.dataset_builder import ( DatasetClipRow, SourceInputFile, DatasetSourceSpec, DatasetSpec, - assign_splits, build_dataset_from_spec, convert_source_to_npz_clips, load_dataset_spec, @@ -96,13 +98,10 @@ def test_load_dataset_spec_parses_typed_sources(tmp_path: Path) -> None: spec_path.write_text( f"""name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" sources: - name: clips type: npz input: {tmp_path / 'npz_source'} - weight: 2.5 - name: lafan1 type: bvh input: {tmp_path / 'lafan1'} @@ -116,7 +115,6 @@ def test_load_dataset_spec_parses_typed_sources(tmp_path: Path) -> None: assert spec.name == "demo" assert spec.target_fps == 30 assert spec.sources[0].type == "npz" - assert spec.sources[0].weight == 2.5 assert spec.sources[1].type == "bvh" assert spec.sources[1].bvh_format == "lafan1" @@ -126,24 +124,93 @@ def test_load_dataset_spec_parses_preprocess(tmp_path: Path) -> None: spec_path.write_text( f"""name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: normalize_root_xy: true - ground_align: clip_min_foot + ground_align: none min_frames: 10 + max_all_off_ground_s: 0.5 + off_ground_height: 0.08 sources: - name: clips type: npz input: {tmp_path / 'npz_source'} + exclude_patterns: ["*obstacle*"] """, encoding="utf-8", ) spec = load_dataset_spec(spec_path) assert spec.preprocess.normalize_root_xy is True - assert spec.preprocess.ground_align == "clip_min_foot" + assert spec.preprocess.ground_align == "none" assert spec.preprocess.min_frames == 10 + assert spec.preprocess.max_all_off_ground_s == 0.5 + assert spec.preprocess.off_ground_height == 0.08 + assert spec.sources[0].exclude_patterns == ("*obstacle*",) + + +def test_load_dataset_spec_parses_seed_filter_preset(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + metadata_csv.write_text("move_g1_path,is_mirror\n", encoding="utf-8") + spec_path = tmp_path / "seed.yaml" + spec_path.write_text( + f"""name: seed_demo +target_fps: 30 +sources: + - name: seed + type: seed_csv + input: {tmp_path / 'seed_source'} + metadata_csv: {metadata_csv} + seed_filter_preset: groot_strict + filters: + is_mirror: [false] +""", + encoding="utf-8", + ) + + spec = load_dataset_spec(spec_path) + assert spec.sources[0].seed_filter_preset == "groot_strict" + + +def test_load_dataset_spec_rejects_seed_filter_preset_on_non_seed_source(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + metadata_csv.write_text("move_g1_path,is_mirror\n", encoding="utf-8") + spec_path = tmp_path / "bad_seed_preset.yaml" + spec_path.write_text( + f"""name: demo +target_fps: 30 +sources: + - name: clips + type: npz + input: {tmp_path / 'npz_source'} + metadata_csv: {metadata_csv} + seed_filter_preset: groot_strict +""", + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="supported only for seed_csv sources"): + load_dataset_spec(spec_path) + + +def test_load_dataset_spec_rejects_unknown_seed_filter_preset(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + metadata_csv.write_text("move_g1_path,is_mirror\n", encoding="utf-8") + spec_path = tmp_path / "bad_seed_preset.yaml" + spec_path.write_text( + f"""name: demo +target_fps: 30 +sources: + - name: seed + type: seed_csv + input: {tmp_path / 'seed_source'} + metadata_csv: {metadata_csv} + seed_filter_preset: unknown_preset +""", + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="unknown seed_filter_preset"): + load_dataset_spec(spec_path) def test_load_dataset_spec_rejects_bvh_without_format(tmp_path: Path) -> None: @@ -151,8 +218,6 @@ def test_load_dataset_spec_rejects_bvh_without_format(tmp_path: Path) -> None: spec_path.write_text( """name: demo target_fps: 30 -val_percent: 5 -hash_salt: "" sources: - name: broken type: bvh @@ -165,22 +230,6 @@ def test_load_dataset_spec_rejects_bvh_without_format(tmp_path: Path) -> None: load_dataset_spec(spec_path) -def test_assign_splits_guarantees_non_empty_train_and_val() -> None: - rows = [ - DatasetClipRow("clip_a", "src", "a.npz", 10, 30, "", "/tmp/a.npz"), - DatasetClipRow("clip_b", "src", "b.npz", 10, 30, "", "/tmp/b.npz"), - ] - resolved = assign_splits(rows, 1, "") - splits = {row.clip_id: row.resolved_split for row in resolved} - assert set(splits.values()) == {"train", "val"} - - -def test_assign_splits_rejects_single_clip_dataset() -> None: - rows = [DatasetClipRow("clip_a", "src", "a.npz", 10, 30, "", "/tmp/a.npz")] - with pytest.raises(ValueError, match="at least 2 clips"): - assign_splits(rows, 5, "") - - def test_convert_source_to_npz_clips_handles_pkl_source(tmp_path: Path) -> None: input_dir = tmp_path / "pkl_source" _write_pkl(input_dir / "clip.pkl") @@ -273,6 +322,154 @@ def test_convert_source_to_npz_clips_rejects_dataset_root_for_npz_source(tmp_pat convert_source_to_npz_clips(source, tmp_path / "dataset" / "clips" / "npz_src", jobs=1) +def test_collect_source_files_with_report_applies_seed_filter_preset(tmp_path: Path) -> None: + metadata_csv = tmp_path / "seed_metadata.csv" + with metadata_csv.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter( + handle, + fieldnames=[ + "move_g1_path", + "is_mirror", + "content_body_position", + "content_type_of_movement", + "content_props", + "filename", + "move_name", + ], + ) + writer.writeheader() + writer.writerow( + { + "move_g1_path": "g1/csv/240101/walk_forward.csv", + "is_mirror": "False", + "content_body_position": "standing", + "content_type_of_movement": "walking", + "content_props": "0", + "filename": "walk_forward", + "move_name": "walk_forward", + } + ) + writer.writerow( + { + "move_g1_path": "g1/csv/240101/sit_pose.csv", + "is_mirror": "False", + "content_body_position": "sitting", + "content_type_of_movement": "sitting", + "content_props": "chair", + "filename": "sit_pose", + "move_name": "sit_pose", + } + ) + + source = DatasetSourceSpec( + name="seed", + type="seed_csv", + input=str(tmp_path / "seed_source" / "g1" / "csv"), + metadata_csv=str(metadata_csv), + filters={"is_mirror": [False]}, + seed_filter_preset="groot_strict", + ) + + input_root = tmp_path / "seed_source" / "g1" / "csv" + input_root.mkdir(parents=True, exist_ok=True) + (input_root / "240101").mkdir(parents=True, exist_ok=True) + (input_root / "240101" / "walk_forward.csv").write_text("placeholder", encoding="utf-8") + (input_root / "240101" / "sit_pose.csv").write_text("placeholder", encoding="utf-8") + + items, _scan_root, report = dataset_builder._collect_source_files_with_report(source, quiet=True) + + assert [item.rel_no_suffix.as_posix() for item in items] == ["240101/walk_forward"] + assert report["scanned_files"] == 2 + assert report["metadata_rows_matched"] == 2 + assert report["preset_rejected_rows"] == 1 + assert report["kept_files"] == 1 + assert report["filtered_files"] == 1 + assert report["preset_reject_reasons"]["content_body_position:sitting"] == 1 + + +def test_collect_source_files_with_report_handles_single_file_source(tmp_path: Path) -> None: + npz_path = tmp_path / "clip_a.npz" + _write_npz_from_pkl(npz_path) + source = DatasetSourceSpec(name="clip", type="npz", input=str(npz_path)) + + items, scan_root, report = dataset_builder._collect_source_files_with_report( + source, quiet=True + ) + legacy_items, legacy_scan_root = dataset_builder._collect_source_files(source) + + assert scan_root == tmp_path + assert [item.rel_no_suffix.as_posix() for item in items] == ["clip_a"] + assert report["scanned_files"] == 1 + assert report["kept_files"] == 1 + assert legacy_scan_root == scan_root + assert [item.rel_no_suffix.as_posix() for item in legacy_items] == ["clip_a"] + + +def test_collect_source_files_with_report_applies_exclude_patterns(tmp_path: Path) -> None: + input_root = tmp_path / "lafan1" + input_root.mkdir(parents=True, exist_ok=True) + (input_root / "walk1.bvh").write_text("placeholder", encoding="utf-8") + (input_root / "obstacle_run.bvh").write_text("placeholder", encoding="utf-8") + nested = input_root / "subject1" + nested.mkdir() + (nested / "obstacle_jump.bvh").write_text("placeholder", encoding="utf-8") + + source = DatasetSourceSpec( + name="lafan1", + type="bvh", + input=str(input_root), + bvh_format="lafan1", + exclude_patterns=("*obstacle*",), + ) + + items, _scan_root, report = dataset_builder._collect_source_files_with_report( + source, + quiet=True, + ) + + assert [item.rel_no_suffix.as_posix() for item in items] == ["walk1"] + assert report["scanned_files"] == 3 + assert report["path_rejected_files"] == 2 + assert report["kept_files"] == 1 + assert report["filtered_files"] == 2 + assert report["path_reject_reasons"] == {"*obstacle*": 2} + + +def test_collect_source_files_with_report_preserves_path_excludes_with_metadata(tmp_path: Path) -> None: + input_root = tmp_path / "seed_source" / "g1" / "csv" + input_root.mkdir(parents=True, exist_ok=True) + for name in ("walk_a.csv", "walk_b.csv", "obstacle_walk.csv"): + (input_root / name).write_text("placeholder", encoding="utf-8") + + metadata_csv = tmp_path / "metadata.csv" + with metadata_csv.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=["move_g1_path", "is_mirror"]) + writer.writeheader() + for name in ("walk_a.csv", "walk_b.csv", "obstacle_walk.csv"): + writer.writerow({"move_g1_path": f"g1/csv/{name}", "is_mirror": "False"}) + + source = DatasetSourceSpec( + name="seed", + type="seed_csv", + input=str(input_root), + metadata_csv=str(metadata_csv), + filters={"is_mirror": [False]}, + exclude_patterns=("*obstacle*",), + ) + + items, _scan_root, report = dataset_builder._collect_source_files_with_report( + source, + quiet=True, + ) + + assert [item.rel_no_suffix.as_posix() for item in items] == ["walk_a", "walk_b"] + assert report["scanned_files"] == 3 + assert report["path_rejected_files"] == 1 + assert report["kept_files"] == 2 + assert report["filtered_files"] == 1 + assert report["path_reject_reasons"] == {"*obstacle*": 1} + + def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> None: npz_input = tmp_path / "npz_source" _write_npz_from_pkl(npz_input / "clip_a.npz") @@ -281,28 +478,77 @@ def test_build_dataset_from_spec_writes_shard_directories(tmp_path: Path) -> Non spec = DatasetSpec( name="demo_dataset", target_fps=30, - val_percent=5, - hash_salt="", sources=[DatasetSourceSpec(name="npz_src", type="npz", input=str(npz_input))], ) output_root = tmp_path / "datasets" + stale_cache = output_root / "demo_dataset" / "clips" / "npz_src" / "clip_a.npz" + _write_npz_from_pkl(stale_cache) + stale_payload = dict(np.load(stale_cache, allow_pickle=True)) + stale_payload["root_pos"] = np.asarray(stale_payload["root_pos"], dtype=np.float32) + 100.0 + np.savez(stale_cache, **stale_payload) + report = build_dataset_from_spec(spec, jobs=2, skip_fk_check=True, output_root=output_root) dataset_dir = output_root / "demo_dataset" assert report["dataset_dir"] == str(dataset_dir) - assert report["build_dir"] == str(dataset_dir) - assert (dataset_dir / "clips" / "npz_src" / "clip_a.npz").is_file() - assert (dataset_dir / "clips" / "npz_src" / "clip_b.npz").is_file() - assert (dataset_dir / "train" / "shard_000.npz").is_file() - assert (dataset_dir / "val" / "shard_000.npz").is_file() - assert (dataset_dir / "manifest_resolved.csv").is_file() - assert (dataset_dir / "build_info.json").is_file() - assert report["clip_counts"]["total"] == 2 + assert not (dataset_dir / "clips").exists() + assert (dataset_dir / "shard_000.h5").is_file() + assert report["input_clips"] == 2 + + with h5py.File(dataset_dir / "shard_000.h5", "r") as shard: + assert "root_pos" in shard + assert "root_quat_w" in shard + assert "joint_pos" in shard + assert "body_pos_w" not in shard + assert float(shard["root_pos"][0, 2]) < 10.0 - train_data = np.load(dataset_dir / "train" / "shard_000.npz", allow_pickle=True) - assert "clip_starts" in train_data.files - assert "clip_lengths" in train_data.files + +def test_build_dataset_from_spec_rejects_clips_root_source_without_deleting_input( + tmp_path: Path, +) -> None: + output_root = tmp_path / "datasets" + source_dir = output_root / "demo_dataset" / "clips" / "npz_src" + clip_path = source_dir / "clip_a.npz" + _write_npz_from_pkl(clip_path) + + spec = DatasetSpec( + name="demo_dataset", + target_fps=30, + sources=[DatasetSourceSpec(name="npz_src", type="npz", input=str(source_dir))], + ) + + with pytest.raises(ValueError, match="temporary clips directory"): + build_dataset_from_spec(spec, jobs=1, skip_fk_check=True, output_root=output_root) + + assert clip_path.is_file() + + +def test_collect_clip_rows_ignores_stale_excluded_cached_npz(tmp_path: Path) -> None: + npz_input = tmp_path / "npz_source" + for name in ("keep_a.npz", "keep_b.npz", "obstacle_old.npz"): + _write_npz_from_pkl(npz_input / name) + + spec = DatasetSpec( + name="demo_dataset", + target_fps=30, + sources=[ + DatasetSourceSpec( + name="npz_src", + type="npz", + input=str(npz_input), + exclude_patterns=("*obstacle*",), + ) + ], + ) + paths = dataset_builder.resolve_dataset_paths(spec, output_root=tmp_path / "datasets") + source_dir = paths.clips_root / "npz_src" + for name in ("keep_a.npz", "keep_b.npz", "obstacle_old.npz"): + _write_npz_from_pkl(source_dir / name) + + rows = dataset_builder.collect_clip_rows(spec, paths=paths) + + assert sorted(row.clip_id for row in rows) == ["npz_src:keep_a", "npz_src:keep_b"] def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: @@ -318,8 +564,6 @@ def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: preprocess=DatasetSpec( name="unused", target_fps=30, - val_percent=5, - hash_salt="", sources=[source], ).preprocess, ) @@ -335,7 +579,7 @@ def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: jobs=1, preprocess=DatasetPreprocessSpec( normalize_root_xy=True, - ground_align="clip_min_foot", + ground_align="first_frame_foot", ), ) clip = np.load(output_dir_2 / "clip_a.npz", allow_pickle=True) @@ -345,7 +589,67 @@ def test_convert_source_to_npz_clips_applies_preprocess(tmp_path: Path) -> None: right_idx = body_names.index("right_ankle_roll_link") assert np.allclose(clip["body_pos_w"][0, pelvis_idx, :2], 0.0) foot_z = clip["body_pos_w"][:, [left_idx, right_idx], 2] - assert np.isclose(float(np.min(foot_z)), 0.0) + assert np.isclose(float(np.min(foot_z[0])), 0.0) + + +def test_convert_source_to_npz_clips_skips_all_off_ground_clips_before_ground_align(tmp_path: Path) -> None: + npz_input = tmp_path / "npz_source" + _write_npz_from_pkl(npz_input / "keep.npz") + _write_npz_from_pkl(npz_input / "float.npz") + + for name, floating in (("keep.npz", False), ("float.npz", True)): + path = npz_input / name + clip = dict(np.load(path, allow_pickle=True)) + body_names = [str(body_name) for body_name in clip["body_names"].tolist()] + left_idx = body_names.index("left_ankle_roll_link") + right_idx = body_names.index("right_ankle_roll_link") + body_pos_w = np.asarray(clip["body_pos_w"]).copy() + if floating: + body_pos_w[..., 2] = np.maximum(body_pos_w[..., 2], 0.3) + else: + body_pos_w[:, left_idx, 2] = 0.0 + body_pos_w[:, right_idx, 2] = 0.3 + clip["body_pos_w"] = body_pos_w + np.savez(path, **clip) + + source = DatasetSourceSpec(name="npz_src", type="npz", input=str(npz_input)) + output_dir = tmp_path / "dataset" / "clips" / "npz_src" + report = convert_source_to_npz_clips( + source, + output_dir, + jobs=1, + preprocess=dataset_builder.DatasetPreprocessSpec( + ground_align="first_frame_foot", + max_all_off_ground_s=0.05, + off_ground_height=0.08, + ), + ) + + assert report["clips"] == 1 + assert (output_dir / "keep.npz").is_file() + assert not (output_dir / "float.npz").exists() + assert (output_dir / "float.npz.filtered.json").is_file() + + def _unexpected_run_conversion_tasks(*_args, **_kwargs): + raise AssertionError("filtered clip should be skipped by marker on incremental rebuild") + + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(dataset_builder, "run_conversion_tasks", _unexpected_run_conversion_tasks) + try: + second_report = convert_source_to_npz_clips( + source, + output_dir, + jobs=1, + preprocess=dataset_builder.DatasetPreprocessSpec( + ground_align="first_frame_foot", + max_all_off_ground_s=0.05, + off_ground_height=0.08, + ), + ) + finally: + monkeypatch.undo() + + assert second_report["clips"] == 1 def test_merge_clip_dicts_rejects_inconsistent_body_names(tmp_path: Path) -> None: @@ -418,7 +722,6 @@ def _capture(clip_dict, *, preprocess, clip_label): dataset_builder._batch_convert_chunk( [str(pkl_path)], - [1.0], 30, str(tmp_path / "merged.npz"), "train", @@ -444,6 +747,8 @@ def _arrays(num_frames: int) -> dict[str, object]: body_quat_w[..., 0] = 1.0 return { "fps": 30, + "root_pos": np.zeros((num_frames, 3), dtype=np.float32), + "root_quat_w": np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (num_frames, 1)), "joint_pos": np.zeros((num_frames, 29), dtype=np.float32), "joint_vel": np.zeros((num_frames, 29), dtype=np.float32), "body_pos_w": np.zeros((num_frames, num_bodies, 3), dtype=np.float32), @@ -467,21 +772,74 @@ def _convert(path: str, **_kwargs): stats = dataset_builder._batch_convert_chunk( [str(short_path), str(valid_path)], - [1.0, 1.0], 30, - str(tmp_path / "merged.npz"), + str(tmp_path / "merged.h5"), "train", preprocess=dataset_builder.DatasetPreprocessSpec( normalize_root_xy=True, - ground_align="clip_min_foot", + ground_align="first_frame_foot", min_frames=22, ), ) - merged = np.load(tmp_path / "merged.npz", allow_pickle=True) assert stats["clips"] == 1 assert stats["kept_file_paths"] == [str(valid_path)] - assert merged["clip_lengths"].tolist() == [22] + with h5py.File(tmp_path / "merged.h5", "r") as merged: + assert merged["clip_lengths"][()].tolist() == [22] + + +def test_shard_stats_counts_real_frames_not_overlapped_windows(tmp_path: Path) -> None: + stats = dataset_builder._shard_stats( + output_dir=tmp_path, + shard_infos=[{ + "path": tmp_path / "shard_000.h5", + "clips": 3, + "frames": 1000, + "clip_lengths": [512, 512, 512], + "source_clip_lengths": [1000], + }], + fps=30, + ) + + assert stats["frames"] == 1000 + assert stats["duration_s"] == 1000 / 30 + + +def test_compute_dataset_stats_reports_total_duration_from_source_clips(tmp_path: Path) -> None: + num_bodies = len(_MJLAB_G1_BODY_NAMES) + total_frames = 18 + body_quat_w = np.zeros((total_frames, num_bodies, 4), dtype=np.float32) + body_quat_w[..., 0] = 1.0 + write_hdf5_motion_shard( + { + "fps": 30, + "root_pos": np.zeros((total_frames, 3), dtype=np.float32), + "root_quat_w": np.tile( + np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), + (total_frames, 1), + ), + "joint_pos": np.zeros((total_frames, 29), dtype=np.float32), + "joint_vel": np.zeros((total_frames, 29), dtype=np.float32), + "body_pos_w": np.zeros((total_frames, num_bodies, 3), dtype=np.float32), + "body_quat_w": body_quat_w, + "body_lin_vel_w": np.zeros((total_frames, num_bodies, 3), dtype=np.float32), + "body_ang_vel_w": np.zeros((total_frames, num_bodies, 3), dtype=np.float32), + "body_names": np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), + "clip_starts": np.asarray([0, 12], dtype=np.int64), + "clip_lengths": np.asarray([12, 6], dtype=np.int64), + "clip_fps": np.asarray([30, 60], dtype=np.int64), + }, + tmp_path / "shard_000.h5", + max_window_frames=8, + overlap_frames=2, + ) + + stats = compute_dataset_stats(tmp_path) + + assert stats["frames"] == total_frames + assert stats["duration_s"] == pytest.approx(12 / 30 + 6 / 60) + assert stats["duration_h"] == pytest.approx((12 / 30 + 6 / 60) / 3600) + assert stats["shard_details"][0]["duration_s"] == pytest.approx(12 / 30 + 6 / 60) @@ -500,28 +858,37 @@ def test_build_dataset_batch_manifest_skips_filtered_entries( spec = DatasetSpec( name="seed_demo", target_fps=30, - val_percent=5, - hash_salt="", sources=[DatasetSourceSpec(name="seed", type="seed_csv", input=str(source_dir))], ) dataset_dir = tmp_path / "datasets" / spec.name - def _collect(_source): + def _collect_with_report(_source, *, quiet=False): + _ = quiet return ([ SourceInputFile(path=keep_train, rel_no_suffix=Path("keep_train")), SourceInputFile(path=drop_train, rel_no_suffix=Path("drop_train")), SourceInputFile(path=keep_val, rel_no_suffix=Path("keep_val")), - ], source_dir) - - def _hash_split(clip_id: str, _val_percent: int, _salt: str = "") -> str: - return "val" if clip_id.endswith("keep_val") else "train" + ], source_dir, { + "source": "seed", + "type": "seed_csv", + "metadata_csv": None, + "seed_filter_preset": "groot_strict", + "scanned_files": 3, + "metadata_rows_matched": 3, + "preset_rejected_rows": 1, + "kept_files": 2, + "filtered_files": 1, + "preset_reject_reasons": {"content_body_position:sitting": 1}, + }) num_bodies = len(_MJLAB_G1_BODY_NAMES) - def _write_merged(path: Path, lengths: list[int]) -> None: + def _write_merged(path: Path, lengths: list[int]) -> dict: total = sum(lengths) joint_pos = np.zeros((total, 29), dtype=np.float32) joint_vel = np.zeros_like(joint_pos) + root_pos = np.zeros((total, 3), dtype=np.float32) + root_quat_w = np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (total, 1)) body_pos_w = np.zeros((total, num_bodies, 3), dtype=np.float32) body_quat_w = np.zeros((total, num_bodies, 4), dtype=np.float32) body_quat_w[..., 0] = 1.0 @@ -531,59 +898,45 @@ def _write_merged(path: Path, lengths: list[int]) -> None: clip_starts = np.zeros(len(lengths), dtype=np.int64) if len(lengths) > 1: clip_starts[1:] = np.cumsum(clip_lengths[:-1]) - np.savez( - path, - fps=30, - joint_pos=joint_pos, - joint_vel=joint_vel, - body_pos_w=body_pos_w, - body_quat_w=body_quat_w, - body_lin_vel_w=body_lin_vel_w, - body_ang_vel_w=body_ang_vel_w, - body_names=np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), - clip_starts=clip_starts, - clip_lengths=clip_lengths, - clip_fps=np.full(len(lengths), 30, dtype=np.int64), - clip_weights=np.ones(len(lengths), dtype=np.float64), - ) + return write_hdf5_motion_shard({ + "fps": 30, + "root_pos": root_pos, + "root_quat_w": root_quat_w, + "joint_pos": joint_pos, + "joint_vel": joint_vel, + "body_pos_w": body_pos_w, + "body_quat_w": body_quat_w, + "body_lin_vel_w": body_lin_vel_w, + "body_ang_vel_w": body_ang_vel_w, + "body_names": np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), + "clip_starts": clip_starts, + "clip_lengths": clip_lengths, + "clip_fps": np.full(len(lengths), 30, dtype=np.int64), + }, path) - def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, preprocess): + def _batch_convert_split(clips, target_fps, output_dir, jobs, label, preprocess): _ = clips, target_fps, jobs, preprocess output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - shard_path = output_dir / "shard_000.npz" - if split_name == "train": - _write_merged(shard_path, [22]) - return ({ - "output": str(output_dir), - "shards": 1, - "clips": 1, - "num_clips": 1, - "frames": 22, - "fps": 30, - "duration_s": 22.0 / 30.0, - }, [{ - "path": shard_path, - "clip_lengths": [22], - "kept_file_paths": [str(keep_train)], - }]) - _write_merged(shard_path, [24]) + shard_path = output_dir / "shard_000.h5" + h5_info = _write_merged(shard_path, [22, 24]) return ({ "output": str(output_dir), "shards": 1, - "clips": 1, - "num_clips": 1, - "frames": 24, + "clips": 2, + "num_clips": 2, + "frames": 46, "fps": 30, - "duration_s": 24.0 / 30.0, + "duration_s": 46.0 / 30.0, }, [{ "path": shard_path, - "clip_lengths": [24], - "kept_file_paths": [str(keep_val)], + "clip_lengths": h5_info["clip_lengths"], + "source_clip_lengths": h5_info["source_clip_lengths"], + "frames": h5_info["frames"], + "kept_file_paths": [str(keep_train), str(keep_val)], }]) - monkeypatch.setattr(dataset_builder, "_collect_source_files", _collect) - monkeypatch.setattr(dataset_builder, "hash_split", _hash_split) + monkeypatch.setattr(dataset_builder, "_collect_source_files_with_report", _collect_with_report) monkeypatch.setattr(dataset_builder, "_batch_convert_split", _batch_convert_split) report = dataset_builder._build_dataset_batch( @@ -595,9 +948,102 @@ def _batch_convert_split(clips, target_fps, output_dir, jobs, split_name, prepro jobs=2, ) - manifest = (dataset_dir / "manifest_resolved.csv").read_text(encoding="utf-8") - assert "seed:keep_train" in manifest - assert "seed:keep_val" in manifest - assert "seed:drop_train" not in manifest - assert report["clip_counts"] == {"total": 2, "train": 1, "val": 1} - assert report["input_clip_counts"] == {"total": 3, "train": 2, "val": 1} + assert (dataset_dir / "shard_000.h5").is_file() + assert not (dataset_dir / "manifest_resolved.csv").exists() + assert report["input_clips"] == 3 + assert report["stats"]["clips"] == 2 + assert report["source_filters"][0]["seed_filter_preset"] == "groot_strict" + assert report["source_filters"][0]["preset_reject_reasons"] == {"content_body_position:sitting": 1} + + +def test_build_dataset_batch_clears_stale_top_level_shards( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + source_dir = tmp_path / "seed_source" + keep_path = source_dir / "keep.csv" + keep_path.parent.mkdir(parents=True, exist_ok=True) + keep_path.write_text("placeholder", encoding="utf-8") + + spec = DatasetSpec( + name="seed_demo", + target_fps=30, + sources=[DatasetSourceSpec(name="seed", type="seed_csv", input=str(source_dir))], + ) + paths = dataset_builder.resolve_dataset_paths(spec, output_root=tmp_path / "datasets") + paths.dataset_dir.mkdir(parents=True) + stale_shard = paths.dataset_dir / "shard_999.h5" + stale_tmp = paths.dataset_dir / ".seed_demo_chunk_7.h5" + stale_shard.write_text("stale", encoding="utf-8") + stale_tmp.write_text("stale", encoding="utf-8") + + def _collect_with_report(_source, *, quiet=False): + _ = quiet + return ([ + SourceInputFile(path=keep_path, rel_no_suffix=Path("keep")), + ], source_dir, { + "source": "seed", + "type": "seed_csv", + "metadata_csv": None, + "seed_filter_preset": None, + "scanned_files": 1, + "metadata_rows_matched": 1, + "kept_files": 1, + "filtered_files": 0, + }) + + def _batch_convert_split(clips, target_fps, output_dir, jobs, label, preprocess): + _ = clips, target_fps, jobs, label, preprocess + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + shard_path = output_dir / "shard_000.h5" + num_bodies = len(_MJLAB_G1_BODY_NAMES) + h5_info = write_hdf5_motion_shard({ + "fps": 30, + "root_pos": np.zeros((22, 3), dtype=np.float32), + "root_quat_w": np.tile(np.asarray([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32), (22, 1)), + "joint_pos": np.zeros((22, 29), dtype=np.float32), + "joint_vel": np.zeros((22, 29), dtype=np.float32), + "body_pos_w": np.zeros((22, num_bodies, 3), dtype=np.float32), + "body_quat_w": np.tile( + np.asarray([[[1.0, 0.0, 0.0, 0.0]]], dtype=np.float32), + (22, num_bodies, 1), + ), + "body_lin_vel_w": np.zeros((22, num_bodies, 3), dtype=np.float32), + "body_ang_vel_w": np.zeros((22, num_bodies, 3), dtype=np.float32), + "body_names": np.asarray(_MJLAB_G1_BODY_NAMES, dtype=str), + "clip_starts": np.asarray([0], dtype=np.int64), + "clip_lengths": np.asarray([22], dtype=np.int64), + "clip_fps": np.asarray([30], dtype=np.int64), + }, shard_path) + return ({ + "output": str(output_dir), + "shards": 1, + "clips": 1, + "num_clips": 1, + "frames": 22, + "fps": 30, + "duration_s": 22.0 / 30.0, + }, [{ + "path": shard_path, + "clip_lengths": h5_info["clip_lengths"], + "source_clip_lengths": h5_info["source_clip_lengths"], + "frames": h5_info["frames"], + "kept_file_paths": [str(keep_path)], + }]) + + monkeypatch.setattr(dataset_builder, "_collect_source_files_with_report", _collect_with_report) + monkeypatch.setattr(dataset_builder, "_batch_convert_split", _batch_convert_split) + + dataset_builder._build_dataset_batch( + spec, + paths=paths, + force=False, + skip_fk_check=True, + skip_validate=False, + jobs=1, + ) + + assert (paths.dataset_dir / "shard_000.h5").is_file() + assert not stale_shard.exists() + assert not stale_tmp.exists() diff --git a/tests/test_dataset_viewer.py b/tests/test_dataset_viewer.py new file mode 100644 index 00000000..cb8790d4 --- /dev/null +++ b/tests/test_dataset_viewer.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np + +_PROJECT_ROOT = str(Path(__file__).resolve().parents[1]) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +from scripts.view.view_dataset import discover_dataset_clips +from train_mimic.data.dataset_lib import merge_clip_dicts_payload, write_hdf5_motion_shard + + +def _clip_dict(num_frames: int, fps: int = 30) -> dict[str, object]: + root_quat_w = np.zeros((num_frames, 4), dtype=np.float32) + root_quat_w[:, 0] = 1.0 + body_quat_w = np.zeros((num_frames, 1, 4), dtype=np.float32) + body_quat_w[..., 0] = 1.0 + return { + "fps": fps, + "root_pos": np.zeros((num_frames, 3), dtype=np.float32), + "root_quat_w": root_quat_w, + "joint_pos": np.zeros((num_frames, 29), dtype=np.float32), + "joint_vel": np.zeros((num_frames, 29), dtype=np.float32), + "body_pos_w": np.zeros((num_frames, 1, 3), dtype=np.float32), + "body_quat_w": body_quat_w, + "body_lin_vel_w": np.zeros((num_frames, 1, 3), dtype=np.float32), + "body_ang_vel_w": np.zeros((num_frames, 1, 3), dtype=np.float32), + "body_names": np.asarray(["pelvis"], dtype=str), + } + + +def test_discover_dataset_clips_reads_source_clip_metadata(tmp_path: Path) -> None: + payload = merge_clip_dicts_payload([ + _clip_dict(4, fps=30), + _clip_dict(6, fps=30), + ]) + write_hdf5_motion_shard(payload, tmp_path / "nested" / "shard_000.h5") + + clips = discover_dataset_clips(tmp_path) + + assert [clip.clip_id for clip in clips] == ["nested/shard_000.h5#0", "nested/shard_000.h5#1"] + assert [clip.num_frames for clip in clips] == [4, 6] + assert [clip.clip_index for clip in clips] == [0, 1] + + +def test_discover_dataset_clips_uses_filename_for_single_h5_input(tmp_path: Path) -> None: + shard_path = tmp_path / "shard_000.h5" + payload = merge_clip_dicts_payload([ + _clip_dict(4, fps=30), + _clip_dict(6, fps=30), + ]) + write_hdf5_motion_shard(payload, shard_path) + + clips = discover_dataset_clips(shard_path) + + assert [clip.clip_id for clip in clips] == ["shard_000.h5#0", "shard_000.h5#1"] diff --git a/tests/test_dexterous_hand.py b/tests/test_dexterous_hand.py new file mode 100644 index 00000000..53aad28a --- /dev/null +++ b/tests/test_dexterous_hand.py @@ -0,0 +1,368 @@ +from __future__ import annotations + +import sys +from types import SimpleNamespace + +import numpy as np +import pytest + +from teleopit.inputs.pico4_provider import PicoControllerSnapshot, PicoControllerState +from teleopit.sim2real.hands.linkerhand_l6 import ( + GripperMapper, + LinkerHandL6Device, + SomehandL6Mapper, + parse_linkerhand_l6_config, + trigger_to_pose, +) +from teleopit.sim2real.hands.base import HandPoseCommand +from teleopit.sim2real.hands.linkerhand_o6 import ( + CLOSE_POSE as O6_CLOSE_POSE, + LinkerHandO6Device, + parse_linkerhand_o6_config, +) +from teleopit.sim2real.hands.pico_landmarks import pico_hand_to_landmarks +from teleopit.sim2real.hands.worker import HandRuntime + + +class FakeInnerHand: + def __init__(self) -> None: + self.close_calls = 0 + self.close_can_interface_calls = 0 + + def close(self) -> None: + self.close_calls += 1 + + def close_can_interface(self) -> None: + self.close_can_interface_calls += 1 + + +class FakeLinkerHandApi: + instances: list["FakeLinkerHandApi"] = [] + + def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: + self.hand_joint = hand_joint + self.hand_type = hand_type + self.modbus = modbus + self.can = can + self.hand = FakeInnerHand() + self.speed: list[int] | None = None + self.poses: list[list[int]] = [] + self.close_can_calls = 0 + FakeLinkerHandApi.instances.append(self) + + def set_speed(self, speed: list[int]) -> None: + self.speed = list(speed) + + def finger_move(self, pose: list[int]) -> None: + self.poses.append(list(pose)) + + def close_can(self) -> None: + self.close_can_calls += 1 + + +def _cfg(mode: str = "gripper") -> dict[str, object]: + return { + "input": {"provider": "pico4"}, + "hands": { + "enabled": True, + "driver": "linkerhand_l6", + "mode": mode, + "sides": ["left", "right"], + "rate_hz": 30.0, + "frame_timeout_s": 0.3, + "linkerhand_l6": { + "left_can": "can0", + "right_can": "can1", + "modbus": "None", + "trigger_deadzone": 0.05, + "deadman_threshold": 0.5, + "open_pose": [250, 10, 250, 250, 250, 250], + "close_pose": [79, 10, 0, 0, 0, 0], + }, + "somehand": { + "rate_hz": 60.0, + "max_iterations": 12, + "temporal_filter_alpha": 1.0, + "output_alpha": 1.0, + }, + }, + } + + +def _o6_cfg(mode: str = "gripper") -> dict[str, object]: + return { + "input": {"provider": "pico4"}, + "hands": { + "enabled": True, + "driver": "linkerhand_o6", + "mode": mode, + "sides": ["left", "right"], + "rate_hz": 30.0, + "frame_timeout_s": 0.3, + "linkerhand_o6": { + "left_can": "can0", + "right_can": "can1", + "modbus": "None", + "trigger_deadzone": 0.05, + "deadman_threshold": 0.5, + }, + }, + } + + +def test_pico_hand_to_landmarks_uses_teleopit_adapter() -> None: + joints = np.zeros((26, 7), dtype=np.float64) + joints[:, 0] = np.arange(26) + joints[:, 1] = np.arange(26) + 100 + joints[:, 2] = np.arange(26) + 200 + + landmarks = pico_hand_to_landmarks(joints) + + assert landmarks.shape == (21, 3) + np.testing.assert_allclose(landmarks[0], [1.0, -201.0, 101.0]) + np.testing.assert_allclose(landmarks[-1], [25.0, -225.0, 125.0]) + + +def test_gripper_mapper_maps_trigger_and_deadman() -> None: + cfg = parse_linkerhand_l6_config(_cfg()) + mapper = GripperMapper(cfg) + snapshot = PicoControllerSnapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + right=PicoControllerState(raw=True, grip=0.1, trigger=1.0, present=True), + timestamp_s=10.0, + seq=1, + ) + + commands = mapper.map(controller_snapshot=snapshot, hand_snapshot=None, active=True, now_s=10.0) + + assert commands[0].side == "left" + assert commands[0].pose == cfg.close_pose + assert commands[1].side == "right" + assert commands[1].pose == cfg.open_pose + + +def test_hand_mappers_force_open_once_when_inactive() -> None: + cfg = parse_linkerhand_l6_config(_cfg()) + snapshot = PicoControllerSnapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + right=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + timestamp_s=10.0, + seq=1, + ) + + gripper = GripperMapper(cfg) + assert gripper.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=9.0) == () + assert gripper.map(controller_snapshot=snapshot, hand_snapshot=None, active=True, now_s=10.0) + first_inactive = gripper.map(controller_snapshot=snapshot, hand_snapshot=None, active=False, now_s=10.1) + assert [command.force for command in first_inactive] == [True, True] + assert gripper.map(controller_snapshot=snapshot, hand_snapshot=None, active=False, now_s=10.2) == () + + somehand = SomehandL6Mapper(cfg) + assert somehand.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=9.0) == () + somehand._active = True + first_inactive = somehand.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=10.0) + assert [command.force for command in first_inactive] == [True, True] + assert somehand.map(controller_snapshot=None, hand_snapshot=None, active=False, now_s=10.1) == () + + +def test_trigger_to_pose_applies_deadzone_and_fixed_thumb_yaw() -> None: + assert trigger_to_pose( + 0.5, + open_pose=[250, 10, 250, 250, 250, 250], + close_pose=[79, 10, 0, 0, 0, 0], + deadzone=0.05, + thumb_yaw_default=10, + ) == [164, 10, 125, 125, 125, 125] + + +def test_trigger_to_pose_can_interpolate_thumb_yaw_for_o6() -> None: + assert trigger_to_pose( + 1.0, + open_pose=[250, 250, 250, 250, 250, 250], + close_pose=[86, 73, 118, 111, 110, 111], + deadzone=0.05, + ) == list(O6_CLOSE_POSE) + + +def test_linkerhand_l6_device_starts_sdk(monkeypatch) -> None: + FakeLinkerHandApi.instances = [] + monkeypatch.setitem( + sys.modules, + "LinkerHand.linker_hand_api", + SimpleNamespace(LinkerHandApi=FakeLinkerHandApi), + ) + cfg = parse_linkerhand_l6_config(_cfg()) + device = LinkerHandL6Device(cfg) + + device.connect() + device.send_pose("left", cfg.close_pose) + device.close() + + assert [hand.can for hand in FakeLinkerHandApi.instances] == ["can0", "can1"] + assert FakeLinkerHandApi.instances[0].speed == [50, 50, 50, 50, 50, 50] + assert FakeLinkerHandApi.instances[0].poses[-2] == list(cfg.close_pose) + assert [hand.hand.close_calls for hand in FakeLinkerHandApi.instances] == [1, 1] + + +def test_linkerhand_o6_gripper_defaults_to_reference_grasp_pose() -> None: + cfg = parse_linkerhand_o6_config(_o6_cfg()) + mapper = GripperMapper(cfg) + snapshot = PicoControllerSnapshot( + left=PicoControllerState(raw=True, grip=1.0, trigger=1.0, present=True), + right=PicoControllerState(raw=True, grip=0.1, trigger=1.0, present=True), + timestamp_s=10.0, + seq=1, + ) + + commands = mapper.map(controller_snapshot=snapshot, hand_snapshot=None, active=True, now_s=10.0) + + assert cfg.close_pose == O6_CLOSE_POSE + assert commands[0].pose == O6_CLOSE_POSE + assert commands[1].pose == cfg.open_pose + + +def test_linkerhand_o6_device_starts_sdk(monkeypatch) -> None: + FakeLinkerHandApi.instances = [] + monkeypatch.setitem( + sys.modules, + "LinkerHand.linker_hand_api", + SimpleNamespace(LinkerHandApi=FakeLinkerHandApi), + ) + cfg = parse_linkerhand_o6_config(_o6_cfg()) + device = LinkerHandO6Device(cfg) + + device.connect() + device.send_pose("left", cfg.close_pose) + device.close() + + assert [hand.hand_joint for hand in FakeLinkerHandApi.instances] == ["O6", "O6"] + assert [hand.can for hand in FakeLinkerHandApi.instances] == ["can0", "can1"] + assert FakeLinkerHandApi.instances[0].speed == [255, 255, 255, 255, 255, 255] + assert FakeLinkerHandApi.instances[0].poses[-2] == list(O6_CLOSE_POSE) + assert [hand.close_can_calls for hand in FakeLinkerHandApi.instances] == [0, 0] + assert [hand.hand.close_can_interface_calls for hand in FakeLinkerHandApi.instances] == [1, 1] + assert [hand.hand.close_calls for hand in FakeLinkerHandApi.instances] == [0, 0] + + +def test_linkerhand_o6_rejects_vr_hand_pose() -> None: + with pytest.raises(ValueError, match="supports only hands.mode=gripper"): + parse_linkerhand_o6_config(_o6_cfg(mode="vr_hand_pose")) + + +def test_hand_runtime_closes_device_when_mapper_start_fails() -> None: + calls: list[str] = [] + + class FakeDevice: + def connect(self) -> None: + calls.append("connect") + + def send_pose(self, *args, **kwargs) -> None: + raise AssertionError("send_pose should not be called") + + def open_all(self, *args, **kwargs) -> None: + calls.append("open_all") + + def close(self) -> None: + calls.append("close") + + class FailingMapper: + def start(self) -> None: + calls.append("mapper_start") + raise RuntimeError("mapper failed") + + def map(self, *args, **kwargs): + return () + + def close(self) -> None: + calls.append("mapper_close") + + runtime = HandRuntime(FakeDevice(), FailingMapper()) + + with pytest.raises(RuntimeError, match="mapper failed"): + runtime.start() + + assert calls == ["connect", "mapper_start", "close"] + + +def test_hand_runtime_reports_actual_open_commands() -> None: + calls: list[tuple[str, object, object]] = [] + open_commands = ( + HandPoseCommand("left", (250, 10, 250, 250, 250, 250), True, "open"), + HandPoseCommand("right", (250, 10, 250, 250, 250, 250), True, "open"), + ) + + class FakeDevice: + def connect(self) -> None: + calls.append(("connect", None, None)) + + def send_pose(self, side, pose, *, force=False, reason="") -> None: + calls.append((side, tuple(pose), reason)) + + def open_all(self, *, force=False, reason="") -> None: + calls.append(("open_all", force, reason)) + + def close(self) -> None: + calls.append(("close", None, None)) + + class Mapper: + def __init__(self) -> None: + self.fail = False + + def start(self) -> None: + calls.append(("mapper_start", None, None)) + + def map(self, *args, **kwargs): + if self.fail: + raise RuntimeError("tick failed") + return (HandPoseCommand("left", (1, 2, 3, 4, 5, 6), False, "mapped"),) + + def close(self) -> None: + calls.append(("mapper_close", None, None)) + + mapper = Mapper() + runtime = HandRuntime(FakeDevice(), mapper, open_commands=open_commands) + + startup = runtime.start() + ticked = runtime.tick(controller_snapshot=None, hand_snapshot=None, active=True, now_s=1.0) + mapper.fail = True + failure = runtime.tick(controller_snapshot=None, hand_snapshot=None, active=True, now_s=2.0) + shutdown = runtime.close() + + assert [command.reason for command in startup] == ["startup", "startup"] + assert ticked[0].pose == (1, 2, 3, 4, 5, 6) + assert [command.reason for command in failure] == ["failure", "failure"] + assert [command.reason for command in shutdown] == ["shutdown", "shutdown"] + assert ("open_all", True, "failure") in calls + assert ("close", None, None) in calls + + +def test_linkerhand_l6_device_wraps_sdk_system_exit_and_cleans_up(monkeypatch) -> None: + created_hands = [] + + class ExitingLinkerHandApi: + def __init__(self, *, hand_joint: str, hand_type: str, modbus: str, can: str) -> None: + del hand_joint, modbus, can + if hand_type == "right": + raise SystemExit(1) + self.hand = FakeInnerHand() + created_hands.append(self) + + def set_speed(self, speed: list[int]) -> None: + self.speed = list(speed) + + def finger_move(self, pose: list[int]) -> None: + self.pose = list(pose) + + monkeypatch.setitem( + sys.modules, + "LinkerHand.linker_hand_api", + SimpleNamespace(LinkerHandApi=ExitingLinkerHandApi), + ) + cfg = parse_linkerhand_l6_config(_cfg()) + device = LinkerHandL6Device(cfg) + + with pytest.raises(RuntimeError, match="LinkerHand SDK exited during startup"): + device.connect() + + assert len(created_hands) == 1 + assert created_hands[0].hand.close_calls == 1 diff --git a/tests/test_domain_randomization.py b/tests/test_domain_randomization.py new file mode 100644 index 00000000..b98ce407 --- /dev/null +++ b/tests/test_domain_randomization.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from train_mimic.app import DEFAULT_TASK +from train_mimic.tasks.tracking import mdp + + +def test_general_tracking_domain_randomization_matches_gr00t_active_set() -> None: + import mjlab.tasks # noqa: F401 + import train_mimic.tasks # noqa: F401 + from mjlab.envs.mdp import dr + from mjlab.tasks.registry import load_env_cfg + + env_cfg = load_env_cfg(DEFAULT_TASK) + events = env_cfg.events + + assert set(events) == { + "push_robot", + "base_com", + "add_joint_default_pos", + "physics_material", + "randomize_rigid_body_mass", + } + + push_robot = events["push_robot"] + assert push_robot.func is mdp.push_by_setting_velocity + assert push_robot.mode == "interval" + assert push_robot.interval_range_s == (4.0, 6.0) + assert push_robot.params["velocity_range"] == { + "x": (-0.5, 0.5), + "y": (-0.5, 0.5), + "z": (-0.2, 0.2), + "roll": (-0.52, 0.52), + "pitch": (-0.52, 0.52), + "yaw": (-0.78, 0.78), + } + + base_com = events["base_com"] + assert base_com.func is dr.body_com_offset + assert base_com.mode == "startup" + assert base_com.params["asset_cfg"].body_names == ("torso_link",) + assert base_com.params["operation"] == "add" + assert base_com.params["ranges"] == { + 0: (-0.025, 0.025), + 1: (-0.05, 0.05), + 2: (-0.05, 0.05), + } + + add_joint_default_pos = events["add_joint_default_pos"] + assert add_joint_default_pos.func is dr.joint_default_pos + assert add_joint_default_pos.mode == "startup" + assert add_joint_default_pos.params["asset_cfg"].joint_names == ".*" + assert add_joint_default_pos.params["operation"] == "add" + assert add_joint_default_pos.params["ranges"] == (-0.01, 0.01) + + physics_material = events["physics_material"] + assert physics_material.func is dr.geom_friction + assert physics_material.mode == "startup" + assert physics_material.params["asset_cfg"].geom_names == r".*_collision$" + assert physics_material.params["operation"] == "abs" + assert physics_material.params["ranges"] == (0.3, 1.6) + + mass = events["randomize_rigid_body_mass"] + assert mass.func is dr.pseudo_inertia + assert mass.mode == "startup" + assert mass.params["asset_cfg"].body_names == ( + "torso_link", + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ) + assert mass.params["alpha_range"] == (-0.1, 0.45) + + +def test_play_env_disables_training_only_domain_randomization() -> None: + import mjlab.tasks # noqa: F401 + import train_mimic.tasks # noqa: F401 + from mjlab.tasks.registry import load_env_cfg + + play_cfg = load_env_cfg(DEFAULT_TASK, play=True) + + assert "push_robot" not in play_cfg.events + assert "base_com" not in play_cfg.events + assert "add_joint_default_pos" not in play_cfg.events + assert "physics_material" not in play_cfg.events + assert "randomize_rigid_body_mass" not in play_cfg.events + assert play_cfg.events == {} diff --git a/tests/test_download_assets.py b/tests/test_download_assets.py index 0e1f14d4..bdd383c1 100644 --- a/tests/test_download_assets.py +++ b/tests/test_download_assets.py @@ -40,3 +40,14 @@ def test_resolve_entry_source_uses_only_current_remote_layout(tmp_path: Path) -> ) assert _resolve_entry_source(tmp_path, entry) == archive + + +def test_robot_asset_group_uses_archive_layout() -> None: + from teleopit.runtime.external_assets import ASSET_GROUPS + + entries = ASSET_GROUPS["robots"] + + assert len(entries) == 1 + assert entries[0].remote_path == "archives/robot_assets.tar.gz" + assert entries[0].local_path == "assets/robots" + assert entries[0].mode == "extract" diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 879210c2..8fbfa306 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -2,7 +2,6 @@ import os from pathlib import Path -from typing import cast import numpy as np import pytest @@ -21,7 +20,6 @@ def _has_module(name: str) -> bool: requires_mujoco = pytest.mark.skipif(not _has_module("mujoco"), reason="mujoco not installed") requires_onnxruntime = pytest.mark.skipif(not _has_module("onnxruntime"), reason="onnxruntime not installed") -requires_h5py = pytest.mark.skipif(not _has_module("h5py"), reason="h5py not installed") requires_mink = pytest.mark.skipif(not _has_module("mink"), reason="mink not installed") @@ -37,16 +35,14 @@ def _asset_paths(project_root: Path) -> tuple[Path, Path, Path]: @requires_mujoco @requires_onnxruntime @requires_mink -@requires_h5py -def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: Path) -> None: - import h5py +def test_bvh_to_mujoco_pipeline_stands(project_root: Path) -> None: from omegaconf import OmegaConf from teleopit.pipeline import TeleopPipeline policy_path, bvh_path, xml_path = _asset_paths(project_root) if not policy_path.exists() or not bvh_path.exists() or not xml_path.exists(): - pytest.skip("set TELEOPIT_TEST_POLICY_ONNX to a compatible 166D ONNX policy to run this e2e test") + pytest.skip("set TELEOPIT_TEST_POLICY_ONNX to a compatible 167D ONNX policy to run this e2e test") robot_cfg = OmegaConf.load(project_root / "teleopit" / "configs" / "robot" / "g1.yaml") controller_cfg = OmegaConf.load(project_root / "teleopit" / "configs" / "controller" / "rl_policy.yaml") @@ -64,7 +60,6 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: input_cfg.human_format = "bvh_xsens" input_cfg.robot_name = "unitree_g1" - recording_path = tmp_dir / "e2e.h5" cfg = OmegaConf.create( { "robot": robot_cfg, @@ -72,7 +67,6 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: "input": input_cfg, "policy_hz": 50, "pd_hz": 50, - "recording": {"output_path": str(recording_path)}, } ) @@ -86,7 +80,7 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: pipeline.bus.subscribe(TOPIC_MIMIC_OBS, lambda _: mimic_count.append(1)) pipeline.bus.subscribe(TOPIC_ROBOT_STATE, lambda _: state_count.append(1)) - result = pipeline.run(num_steps=100, record=True) + result = pipeline.run(num_steps=100) assert float(result["root_height"]) > 0.3 assert int(result["steps"]) == 100 @@ -104,20 +98,3 @@ def test_bvh_to_mujoco_pipeline_stands_and_records(project_root: Path, tmp_dir: assert isinstance(latest_mimic, np.ndarray) assert latest_mimic.shape == (35,) assert latest_state is not None - - record_path = Path(str(result["record_path"])) - assert record_path.exists() - - with h5py.File(record_path, "r") as f: - total_frames = int(np.asarray(f.attrs["total_frames"]).item()) - assert total_frames == 100 - - joint_pos = cast(h5py.Dataset, f["joint_pos"]) - joint_vel = cast(h5py.Dataset, f["joint_vel"]) - mimic_obs = cast(h5py.Dataset, f["mimic_obs"]) - action = cast(h5py.Dataset, f["action"]) - - assert joint_pos.shape[0] == 100 - assert joint_vel.shape[0] == 100 - assert mimic_obs.shape == (100, 35) - assert action.shape == (100, 29) diff --git a/tests/test_human_frame_validation.py b/tests/test_human_frame_validation.py new file mode 100644 index 00000000..d1c7d358 --- /dev/null +++ b/tests/test_human_frame_validation.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import numpy as np + +from teleopit.inputs.human_frame_validation import validate_human_frame + + +def test_validate_human_frame_accepts_large_finite_positions() -> None: + frame = { + "Pelvis": (np.array([100.0, -100.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + "Head": (np.array([100.1, -100.1, 1.7]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame) + + assert result.valid + assert result.reason == "ok" + + +def test_validate_human_frame_reports_position_inf() -> None: + frame = { + "Pelvis": (np.array([np.inf, 0.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame) + + assert not result.valid + assert result.reason == "position_inf" + assert result.joint_name == "Pelvis" + + +def test_validate_human_frame_reports_position_nan() -> None: + frame = { + "Pelvis": (np.array([np.nan, 0.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame) + + assert not result.valid + assert result.reason == "position_nan" + assert result.joint_name == "Pelvis" + + +def test_validate_human_frame_reports_quaternion_nan() -> None: + frame = { + "Head": (np.array([0.0, 0.0, 1.5]), np.array([1.0, np.nan, 0.0, 0.0])), + } + + result = validate_human_frame(frame) + + assert not result.valid + assert result.reason == "quaternion_nan" + assert result.joint_name == "Head" + + +def test_validate_human_frame_accepts_finite_frame() -> None: + frame = { + "Pelvis": (np.array([0.0, 0.0, 1.0]), np.array([1.0, 0.0, 0.0, 0.0])), + "Head": (np.array([0.0, 0.0, 1.7]), np.array([1.0, 0.0, 0.0, 0.0])), + } + + result = validate_human_frame(frame) + + assert result.valid + assert result.reason == "ok" diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index df57aba9..0847800a 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -6,7 +6,6 @@ InputProvider, MessageBus, ObservationBuilder, - Recorder, RealtimeInputProvider, Retargeter, Robot, @@ -127,11 +126,6 @@ def subscribe(self, topic): return None -class _FakeRecorder: - def add_frame(self, data): - pass - - class _FakeObservationBuilder: def build_observation(self, state, history, action_mimic): return np.zeros(10) @@ -169,9 +163,6 @@ def test_robot_isinstance(self): def test_message_bus_isinstance(self): assert isinstance(_FakeMessageBus(), MessageBus) - def test_recorder_isinstance(self): - assert isinstance(_FakeRecorder(), Recorder) - def test_observation_builder_isinstance(self): assert isinstance(_FakeObservationBuilder(), ObservationBuilder) diff --git a/tests/test_mocap_mujoco.py b/tests/test_mocap_mujoco.py index e0a077c7..8d128641 100644 --- a/tests/test_mocap_mujoco.py +++ b/tests/test_mocap_mujoco.py @@ -1,5 +1,8 @@ from __future__ import annotations +import threading +from types import SimpleNamespace + import mujoco import numpy as np import pytest @@ -157,3 +160,28 @@ def test_pico4_provider_exposes_mocap_skeleton_metadata() -> None: assert pico4.bone_names == list(BODY_JOINT_NAMES) np.testing.assert_array_equal(pico4.bone_parents, BODY_JOINT_PARENTS) + + +def test_pico4_provider_exposes_controller_snapshot() -> None: + pico4 = object.__new__(Pico4InputProvider) + pico4._lock = threading.Lock() + pico4._controller_snapshot = None + pico4._last_source_seq = None + + frame = SimpleNamespace( + seq=42, + controllers=SimpleNamespace( + left=SimpleNamespace(raw=True, axis={"grip": 0.75, "trigger": 0.25}), + right=SimpleNamespace(raw=False, axis={}), + ), + ) + pico4._accept_controller_snapshot(frame, timestamp=12.5) + + snapshot = pico4.get_controller_snapshot() + assert snapshot is not None + assert snapshot.seq == 42 + assert snapshot.timestamp_s == pytest.approx(12.5) + assert snapshot.left.raw is True + assert snapshot.left.grip == pytest.approx(0.75) + assert snapshot.left.trigger == pytest.approx(0.25) + assert snapshot.right.raw is False diff --git a/tests/test_motion_sampling.py b/tests/test_motion_sampling.py index d9b6b5de..28d01ea5 100644 --- a/tests/test_motion_sampling.py +++ b/tests/test_motion_sampling.py @@ -1,12 +1,20 @@ from __future__ import annotations from pathlib import Path +from types import SimpleNamespace import numpy as np +import pytest import torch +import h5py -from train_mimic.data.dataset_lib import merge_clip_dicts -from train_mimic.tasks.tracking.mdp.commands import MotionLib +from train_mimic.data.dataset_lib import ( + PRECOMPUTED_MOTION_VERSION, + compute_dataset_stats, + write_precomputed_motion_shard, + write_hdf5_motion_shard, +) +from train_mimic.tasks.tracking.mdp.commands import MotionCommand, MotionLib def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: @@ -18,6 +26,7 @@ def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: body_pos_w = np.zeros((num_frames, 3, 3), dtype=np.float32) body_pos_w[:, :, 0] = time[:, None] + body_pos_w[:, :, 1] = np.arange(3, dtype=np.float32)[None, :] body_quat_w = np.zeros((num_frames, 3, 4), dtype=np.float32) body_quat_w[:, :, 0] = 1.0 body_lin_vel_w = np.zeros((num_frames, 3, 3), dtype=np.float32) @@ -29,6 +38,8 @@ def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: ) return { "fps": fps, + "root_pos": body_pos_w[:, 0], + "root_quat_w": body_quat_w[:, 0], "joint_pos": joint_pos, "joint_vel": joint_vel, "body_pos_w": body_pos_w, @@ -39,12 +50,57 @@ def _clip_dict(num_frames: int = 6, fps: int = 1) -> dict[str, object]: } -def _write_shard_dir(path: Path, clip_dicts: list[dict[str, object]]) -> Path: +def _write_shard_dir( + path: Path, + clip_dicts: list[dict[str, object]], +) -> Path: path.mkdir(parents=True, exist_ok=True) - merge_clip_dicts(clip_dicts, path / "shard_000.npz") + array_keys = [ + "root_pos", "root_quat_w", + "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", + "body_lin_vel_w", "body_ang_vel_w", + ] + clip_lengths = np.asarray([np.asarray(cd["joint_pos"]).shape[0] for cd in clip_dicts], dtype=np.int64) + clip_starts = np.zeros(len(clip_lengths), dtype=np.int64) + if len(clip_lengths) > 1: + clip_starts[1:] = np.cumsum(clip_lengths[:-1]) + merged = {key: np.concatenate([np.asarray(cd[key]) for cd in clip_dicts], axis=0) for key in array_keys} + merged["fps"] = int(clip_dicts[0]["fps"]) + merged["body_names"] = np.asarray(clip_dicts[0]["body_names"]) + merged["clip_starts"] = clip_starts + merged["clip_lengths"] = clip_lengths + merged["clip_fps"] = np.full(len(clip_dicts), int(clip_dicts[0]["fps"]), dtype=np.int64) + shard_path = path / "shard_000.h5" + _write_precomputed_from_merged(shard_path, merged) return path +def _write_precomputed_from_merged(shard_path: Path, merged: dict[str, object]) -> None: + shard_path.parent.mkdir(parents=True, exist_ok=True) + str_dt = h5py.string_dtype(encoding="utf-8") + clip_starts = np.asarray(merged["clip_starts"], dtype=np.int64) + clip_lengths = np.asarray(merged["clip_lengths"], dtype=np.int64) + clip_fps = np.asarray(merged["clip_fps"], dtype=np.int64) + with h5py.File(shard_path, "w") as h5: + h5.attrs["format"] = "teleopit_precomputed_motion_hdf5" + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.create_dataset( + "body_names", + data=np.asarray(merged["body_names"]).astype(str).astype(object), + dtype=str_dt, + ) + for key in ("joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w"): + h5.create_dataset(key, data=np.asarray(merged[key], dtype=np.float32), chunks=True) + h5.create_dataset("clip_starts", data=clip_starts) + h5.create_dataset("clip_lengths", data=clip_lengths) + h5.create_dataset("clip_fps", data=clip_fps) + h5.create_dataset("source_clip_ids", data=np.arange(len(clip_lengths), dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.zeros(len(clip_lengths), dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=clip_starts) + h5.create_dataset("source_clip_lengths", data=clip_lengths) + h5.create_dataset("source_clip_fps", data=clip_fps) + + def test_motion_lib_sample_times_respect_window_steps(tmp_path: Path) -> None: motion_path = _write_shard_dir(tmp_path / "motion", [_clip_dict()]) @@ -96,10 +152,6 @@ def test_motion_lib_get_window_frames_returns_requested_offsets(tmp_path: Path) torch.tensor([2.0, 4.0, 1.0], dtype=torch.float32), ) assert frames["body_pos_w"].shape == (1, 3, 2, 3) - assert torch.allclose( - frames["body_pos_w"][0, :, 0, 0], - torch.tensor([2.0, 4.0, 1.0], dtype=torch.float32), - ) current = motion.get_frames( torch.tensor([0], dtype=torch.long), @@ -109,6 +161,197 @@ def test_motion_lib_get_window_frames_returns_requested_offsets(tmp_path: Path) assert torch.allclose(current["joint_pos"][0, :1], torch.tensor([2.0], dtype=torch.float32)) +def test_motion_lib_frame_offsets_select_requested_clip(tmp_path: Path) -> None: + clip0 = _clip_dict(num_frames=3) + clip1 = _clip_dict(num_frames=7) + clip1["root_pos"] = np.asarray(clip1["root_pos"]).copy() + clip1["root_pos"][:, 0] += 100.0 + clip1["joint_pos"] = np.asarray(clip1["joint_pos"]).copy() + clip1["joint_pos"][:, 0] += 100.0 + clip1["body_pos_w"] = np.asarray(clip1["body_pos_w"]).copy() + clip1["body_pos_w"][:, :, 0] += 100.0 + + motion_path = _write_shard_dir(tmp_path / "motion_flat_offsets", [clip0, clip1]) + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + frames = motion.get_frames( + torch.tensor([1], dtype=torch.long), + torch.tensor([2.0], dtype=torch.float32), + ) + + assert torch.allclose(frames["joint_pos"][0, :1], torch.tensor([102.0])) + assert torch.allclose(frames["body_pos_w"][0, 0], torch.tensor([102.0, 0.0, 0.0])) + + +def test_motion_lib_uses_original_window_starts_for_overlapped_shards(tmp_path: Path) -> None: + motion_path = tmp_path / "motion_overlapped_offsets" + motion_path.mkdir() + clip = _clip_dict(num_frames=12) + shard_path = motion_path / "shard_000.h5" + merged = { + "fps": int(clip["fps"]), + "root_pos": np.asarray(clip["root_pos"]), + "root_quat_w": np.asarray(clip["root_quat_w"]), + "joint_pos": np.asarray(clip["joint_pos"]), + "joint_vel": np.asarray(clip["joint_vel"]), + "body_pos_w": np.asarray(clip["body_pos_w"]), + "body_quat_w": np.asarray(clip["body_quat_w"]), + "body_lin_vel_w": np.asarray(clip["body_lin_vel_w"]), + "body_ang_vel_w": np.asarray(clip["body_ang_vel_w"]), + "body_names": np.asarray(clip["body_names"]), + "clip_starts": np.asarray([0, 4, 8], dtype=np.int64), + "clip_lengths": np.asarray([6, 6, 4], dtype=np.int64), + "clip_fps": np.asarray([1, 1, 1], dtype=np.int64), + } + _write_precomputed_from_merged(shard_path, merged) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + assert motion.clip_frame_offsets.cpu().tolist() == [0, 4, 8] + assert motion._joint_pos_t.shape[0] == 12 + + frames = motion.get_frames( + torch.tensor([0, 1, 2], dtype=torch.long), + torch.tensor([2.0, 2.0, 2.0], dtype=torch.float32), + ) + + assert torch.allclose( + frames["joint_pos"][:, 0], + torch.tensor([2.0, 6.0, 10.0], dtype=torch.float32), + ) + + +def test_motion_lib_selects_bodies_by_dataset_names(tmp_path: Path) -> None: + motion_path = _write_shard_dir(tmp_path / "motion_named_bodies", [_clip_dict()]) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([99, 0], dtype=torch.long), + body_names=["right_ankle_roll_link", "pelvis"], + window_steps=(0,), + ) + frames = motion.get_frames( + torch.tensor([0], dtype=torch.long), + torch.tensor([2.0], dtype=torch.float32), + ) + + assert frames["body_pos_w"].shape == (1, 2, 3) + assert torch.isfinite(frames["body_pos_w"]).all() + + +def test_motion_lib_rejects_minimal_motion_shards(tmp_path: Path) -> None: + path = tmp_path / "motion_minimal" + path.mkdir() + clip = _clip_dict() + merged = { + "fps": int(clip["fps"]), + "root_pos": np.asarray(clip["root_pos"]), + "root_quat_w": np.asarray(clip["root_quat_w"]), + "joint_pos": np.asarray(clip["joint_pos"]), + "body_names": np.asarray(clip["body_names"]), + "clip_starts": np.asarray([0], dtype=np.int64), + "clip_lengths": np.asarray([np.asarray(clip["joint_pos"]).shape[0]], dtype=np.int64), + "clip_fps": np.asarray([int(clip["fps"])], dtype=np.int64), + } + write_hdf5_motion_shard(merged, path / "shard_000.h5") + + with pytest.raises(FileNotFoundError, match="precomputed Teleopit"): + MotionLib( + str(path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + +def test_precomputed_stats_preserve_source_clip_ids_for_windowed_shards(tmp_path: Path) -> None: + clip = _clip_dict(num_frames=12, fps=30) + shard_path = tmp_path / "motion" / "shard_000.h5" + shard_path.parent.mkdir(parents=True) + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(shard_path, "w") as h5: + h5.attrs["format"] = "teleopit_precomputed_motion_hdf5" + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.create_dataset("body_names", data=np.asarray(clip["body_names"]).astype(object), dtype=str_dt) + for key in ("joint_pos", "joint_vel", "body_pos_w", "body_quat_w", "body_lin_vel_w", "body_ang_vel_w"): + h5.create_dataset(key, data=np.asarray(clip[key], dtype=np.float32), chunks=True) + h5.create_dataset("clip_starts", data=np.asarray([0, 4, 8], dtype=np.int64)) + h5.create_dataset("clip_lengths", data=np.asarray([6, 6, 4], dtype=np.int64)) + h5.create_dataset("clip_fps", data=np.asarray([30, 30, 30], dtype=np.int64)) + h5.create_dataset("source_clip_ids", data=np.asarray([0, 0, 0], dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.asarray([0, 4, 8], dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_clip_lengths", data=np.asarray([12], dtype=np.int64)) + h5.create_dataset("source_clip_fps", data=np.asarray([30], dtype=np.int64)) + + stats = compute_dataset_stats(shard_path.parent, precomputed=True) + + assert stats["windows"] == 3 + assert stats["source_clips"] == 1 + assert stats["duration_s"] == pytest.approx(12 / 30) + + +def test_write_precomputed_motion_shard_copies_window_source_metadata( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeMotionFkExtractor: + def __init__(self, model_path: object = None) -> None: + del model_path + + def extract( + self, + root_pos: np.ndarray, + root_quat_w: np.ndarray, + joint_pos: np.ndarray, + body_names: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + del joint_pos + bodies = int(len(body_names)) + body_pos_w = np.repeat(np.asarray(root_pos, dtype=np.float32)[:, None, :], bodies, axis=1) + body_quat_w = np.repeat(np.asarray(root_quat_w, dtype=np.float32)[:, None, :], bodies, axis=1) + return body_pos_w, body_quat_w + + monkeypatch.setattr("train_mimic.data.motion_fk.MotionFkExtractor", _FakeMotionFkExtractor) + clip = _clip_dict(num_frames=12, fps=30) + merged = { + "fps": int(clip["fps"]), + "root_pos": np.asarray(clip["root_pos"], dtype=np.float32), + "root_quat_w": np.asarray(clip["root_quat_w"], dtype=np.float32), + "joint_pos": np.asarray(clip["joint_pos"], dtype=np.float32), + "body_names": np.asarray(clip["body_names"]).astype(str), + "clip_starts": np.asarray([0], dtype=np.int64), + "clip_lengths": np.asarray([12], dtype=np.int64), + "clip_fps": np.asarray([30], dtype=np.int64), + } + minimal_path = tmp_path / "minimal" / "shard_000.h5" + precomputed_path = tmp_path / "precomputed" / "shard_000.h5" + write_hdf5_motion_shard( + merged, + minimal_path, + max_window_frames=6, + overlap_frames=2, + ) + + write_precomputed_motion_shard(minimal_path, precomputed_path) + + with h5py.File(precomputed_path, "r") as h5: + assert h5.attrs["version"] == PRECOMPUTED_MOTION_VERSION + assert h5["clip_starts"][()].tolist() == [0, 4, 6] + assert h5["clip_lengths"][()].tolist() == [6, 6, 6] + assert h5["source_clip_ids"][()].tolist() == [0, 0, 0] + assert h5["source_start_frames"][()].tolist() == [0, 4, 6] + assert h5["source_clip_starts"][()].tolist() == [0] + assert h5["source_clip_lengths"][()].tolist() == [12] + + def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_path: Path) -> None: motion_path = _write_shard_dir(tmp_path / "motion_windowed", [_clip_dict()]) @@ -121,3 +364,153 @@ def test_motion_lib_window_start_and_end_times_follow_valid_center_range(tmp_pat assert torch.allclose(motion.sample_start_times(motion_ids), torch.tensor([1.0])) assert torch.allclose(motion.clip_sample_end_s[motion_ids], torch.tensor([3.0])) + + +def test_motion_lib_sampling_weights_follow_valid_duration(tmp_path: Path) -> None: + motion_path = _write_shard_dir( + tmp_path / "motion_weighted", + [ + _clip_dict(num_frames=3, fps=10), + _clip_dict(num_frames=6, fps=10), + _clip_dict(num_frames=11, fps=10), + ], + ) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + assert torch.allclose( + motion.sample_weights, + torch.tensor([0.2 / 1.7, 0.5 / 1.7, 1.0 / 1.7], dtype=torch.float32), + ) + + +def test_motion_lib_samples_ids_randomly_by_valid_duration(tmp_path: Path) -> None: + motion_path = _write_shard_dir( + tmp_path / "motion_weighted_global", + [ + _clip_dict(num_frames=3, fps=10), + _clip_dict(num_frames=11, fps=10), + ], + ) + + motion = MotionLib( + str(motion_path), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + ids = motion.sample_motion_ids(2048) + counts = torch.bincount(ids.cpu(), minlength=2).float() + + assert counts[1] > counts[0] * 3.0 + + +def test_motion_lib_rejects_shard_body_name_mismatch(tmp_path: Path) -> None: + motion_path = tmp_path / "motion_mismatch" + clip = _clip_dict() + shard0 = _write_shard_dir(motion_path, [clip]) + + clip_bad = _clip_dict() + clip_bad["body_names"] = np.asarray( + ["pelvis", "right_ankle_roll_link", "left_ankle_roll_link"], + dtype=str, + ) + array_keys = [ + "root_pos", "root_quat_w", + "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", + "body_lin_vel_w", "body_ang_vel_w", + ] + merged = {key: np.asarray(clip_bad[key]) for key in array_keys} + merged["fps"] = int(clip_bad["fps"]) + merged["body_names"] = np.asarray(clip_bad["body_names"]) + merged["clip_starts"] = np.asarray([0], dtype=np.int64) + merged["clip_lengths"] = np.asarray([np.asarray(clip_bad["joint_pos"]).shape[0]], dtype=np.int64) + merged["clip_fps"] = np.asarray([int(clip_bad["fps"])], dtype=np.int64) + bad_shard = motion_path / "shard_001.h5" + _write_precomputed_from_merged(bad_shard, merged) + + with pytest.raises(ValueError, match="body_names"): + MotionLib( + str(shard0), + body_indexes=torch.tensor([0, 1], dtype=torch.long), + window_steps=(0,), + ) + + +class _FakeMotion: + def __init__(self) -> None: + self.clip_sample_start_s = torch.tensor([0.0, 1.0, 2.0]) + + def sample_motion_ids(self, n: int) -> torch.Tensor: + return torch.full((n,), 2, dtype=torch.long) + + def sample_times(self, motion_ids: torch.Tensor) -> torch.Tensor: + return torch.full_like(motion_ids, 9.0, dtype=torch.float32) + + +def test_motion_command_rewind_sampling_uses_failed_env_previous_time() -> None: + command = MotionCommand.__new__(MotionCommand) + command.cfg = SimpleNamespace( + sampling_mode="rewind", + rewind_prob=1.0, + rewind_min_steps=2, + rewind_max_steps=2, + ) + command._env = SimpleNamespace( + device="cpu", + termination_manager=SimpleNamespace( + terminated=torch.tensor([True, False, True]) + ), + ) + command.motion = _FakeMotion() + command.motion_ids = torch.tensor([0, 1, 1], dtype=torch.long) + command.motion_times = torch.tensor([5.0, 6.0, 1.25], dtype=torch.float32) + command._step_dt = 0.5 + + command._rewind_sampling(torch.tensor([0, 1, 2], dtype=torch.long)) + + assert torch.equal(command.motion_ids, torch.tensor([0, 2, 1])) + assert torch.allclose(command.motion_times, torch.tensor([4.0, 9.0, 1.0])) + + +def test_motion_command_rewind_sampling_falls_back_to_uniform_when_disabled() -> None: + command = MotionCommand.__new__(MotionCommand) + command.cfg = SimpleNamespace( + sampling_mode="rewind", + rewind_prob=0.0, + rewind_min_steps=2, + rewind_max_steps=2, + ) + command._env = SimpleNamespace( + device="cpu", + termination_manager=SimpleNamespace(terminated=torch.tensor([True])), + ) + command.motion = _FakeMotion() + command.motion_ids = torch.tensor([0], dtype=torch.long) + command.motion_times = torch.tensor([5.0], dtype=torch.float32) + command._step_dt = 0.5 + + command._rewind_sampling(torch.tensor([0], dtype=torch.long)) + + assert torch.equal(command.motion_ids, torch.tensor([2])) + assert torch.allclose(command.motion_times, torch.tensor([9.0])) + + +def test_motion_command_rewind_sampling_rejects_invalid_step_range() -> None: + command = MotionCommand.__new__(MotionCommand) + command.cfg = SimpleNamespace( + sampling_mode="rewind", + rewind_prob=1.0, + rewind_min_steps=3, + rewind_max_steps=2, + ) + command.motion_ids = torch.tensor([0], dtype=torch.long) + command.motion_times = torch.tensor([5.0], dtype=torch.float32) + command.motion = _FakeMotion() + + with pytest.raises(ValueError, match="rewind_max_steps"): + command._rewind_sampling(torch.tensor([0], dtype=torch.long)) diff --git a/tests/test_observation.py b/tests/test_observation.py index 0086125b..64dcbddb 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -70,7 +70,7 @@ def test_rotate_motion_qpos_by_yaw_keeps_first_frame_position_with_pivot() -> No @requires_mujoco @_skip_no_xml class TestVelCmdObservationBuilder: - def test_output_dimension_is_166(self) -> None: + def test_output_dimension_is_167(self) -> None: builder = VelCmdObservationBuilder(_velcmd_cfg()) obs = builder.build( _make_state(), @@ -80,15 +80,17 @@ def test_output_dimension_is_166(self) -> None: np.zeros(3, dtype=np.float32), np.zeros(3, dtype=np.float32), ) - assert builder.total_obs_size == 166 - assert obs.shape == (166,) + assert builder.total_obs_size == 167 + assert obs.shape == (167,) assert obs.dtype == np.float32 + assert obs[-1] > 0.0 def test_projected_gravity_uses_root_quat_not_anchor_quat(self) -> None: builder = VelCmdObservationBuilder(_velcmd_cfg()) builder._base.build = lambda *args, **kwargs: np.zeros(154, dtype=np.float32) # type: ignore[method-assign] builder._base._run_fk = lambda *args, **kwargs: None # type: ignore[method-assign] builder._base._anchor_body_id = 0 + builder._base._get_body_pos = lambda _body_id: np.array([0.0, 0.0, 1.0], dtype=np.float32) # type: ignore[method-assign] builder._base._get_body_quat = lambda _body_id: np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) # type: ignore[method-assign] root_quat = np.array( diff --git a/tests/test_pico4_provider.py b/tests/test_pico4_provider.py index 4dcde8e5..e785aaae 100644 --- a/tests/test_pico4_provider.py +++ b/tests/test_pico4_provider.py @@ -27,6 +27,7 @@ def _pico_frame( timestamp: float, body_active: bool = True, right_primary: bool = False, + right_secondary: bool = False, ) -> SimpleNamespace: return SimpleNamespace( seq=seq, @@ -34,11 +35,17 @@ def _pico_frame( body=SimpleNamespace(active=body_active, joints=body_poses), controllers=SimpleNamespace( left=SimpleNamespace(buttons={}), - right=SimpleNamespace(buttons={"primaryButton": right_primary}), + right=SimpleNamespace(buttons={"primaryButton": right_primary, "secondaryButton": right_secondary}), ), ) +def _hand_state(*, active: bool, value: float) -> SimpleNamespace: + joints = np.zeros((26, 7), dtype=np.float64) + joints[:, 0:3] = value + return SimpleNamespace(active=active, joints=joints) + + def _make_provider() -> Pico4InputProvider: provider = object.__new__(Pico4InputProvider) provider._lock = threading.Lock() @@ -48,13 +55,21 @@ def _make_provider() -> Pico4InputProvider: provider._timestamp_gap_reset_s = 0.15 provider._pending_control_events = deque() provider._pause_button = "A" + provider._arms_button = "B" provider._pause_debounce_s = 0.0 + provider._arms_debounce_s = 0.0 provider._pause_button_path = provider._resolve_button_path(provider._pause_button) + provider._arms_button_path = provider._resolve_button_path(provider._arms_button) provider._last_pause_button_pressed = False + provider._last_arms_button_pressed = False provider._last_pause_toggle_timestamp = None + provider._last_arms_toggle_timestamp = None provider._last_raw_body_joints = None provider._last_frame_timestamp = None provider._last_source_seq = None + provider._ground_alignment_offset = None + provider._controller_snapshot = None + provider._hand_snapshot = None provider._closed = False return provider @@ -129,8 +144,8 @@ def test_pico4_provider_starts_pico_bridge_receiver_with_config() -> None: assert bridge.closed is True -def test_pico4_provider_requires_pico_bridge_0_2_signature() -> None: - with pytest.raises(RuntimeError, match=r"pico_bridge >= 0\.2\.0"): +def test_pico4_provider_requires_pico_bridge_0_2_1_signature() -> None: + with pytest.raises(RuntimeError, match=r"pico_bridge >= 0\.2\.1"): Pico4InputProvider(timeout=0.01, bridge_cls=_LegacyBridge) @@ -152,15 +167,86 @@ def push_video_frame(self: _FakeBridge, frame: np.ndarray) -> int: delattr(_FakeBridge, "push_video_frame") -def test_pico4_provider_normalizes_pico_bridge_body_pose_convention() -> None: +def test_pico4_provider_converts_pico_native_body_pose_convention() -> None: body_poses = _body_poses(0.0) body_poses[0] = [1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9] - converted = Pico4InputProvider._normalize_pico_bridge_body_joints(body_poses) - np.testing.assert_allclose(converted[0], [1.0, 2.0, -3.0, 0.1, 0.2, -0.3, -0.9]) - frame = Pico4InputProvider._convert_body_joints_to_frame(body_poses) - np.testing.assert_allclose(frame["Pelvis"][0], [1.0, 3.0, 2.0], atol=1e-6) + np.testing.assert_allclose(frame["Pelvis"][0], [1.0, -3.0, 2.0], atol=1e-6) + + +def test_pico4_provider_applies_fixed_ground_alignment_from_first_real_frame() -> None: + provider = _make_provider() + body_poses = np.zeros((len(BODY_JOINT_NAMES), 7), dtype=np.float64) + pelvis_idx = BODY_JOINT_NAMES.index("Pelvis") + left_ankle_idx = BODY_JOINT_NAMES.index("Left_Ankle") + right_ankle_idx = BODY_JOINT_NAMES.index("Right_Ankle") + body_poses[pelvis_idx, 0:3] = [0.0, 0.8, 0.0] + body_poses[left_ankle_idx, 0:3] = [0.1, -0.2, 0.0] + body_poses[right_ankle_idx, 0:3] = [-0.1, 0.1, 0.0] + body_poses[:, 6] = 1.0 + + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=1, timestamp=1.0)) is True + first_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(first_frame["Pelvis"][0][2], 0.8 + 0.2, atol=1e-6) + np.testing.assert_allclose(first_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) + assert provider._ground_alignment_offset == pytest.approx(0.2) + + body_poses[:, 1] += 0.3 + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=2, timestamp=1.1)) is True + second_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(second_frame["Pelvis"][0][2], first_frame["Pelvis"][0][2] + 0.3, atol=1e-6) + np.testing.assert_allclose(second_frame["Left_Ankle"][0][2], 0.3, atol=1e-6) + + +def test_pico4_provider_aligns_floating_first_frame_down_to_ground() -> None: + provider = _make_provider() + body_poses = np.zeros((len(BODY_JOINT_NAMES), 7), dtype=np.float64) + pelvis_idx = BODY_JOINT_NAMES.index("Pelvis") + left_ankle_idx = BODY_JOINT_NAMES.index("Left_Ankle") + right_ankle_idx = BODY_JOINT_NAMES.index("Right_Ankle") + body_poses[:, 1] = 0.2 + body_poses[pelvis_idx, 0:3] = [0.0, 0.9, 0.0] + body_poses[left_ankle_idx, 0:3] = [0.1, 0.2, 0.0] + body_poses[right_ankle_idx, 0:3] = [-0.1, 0.4, 0.0] + body_poses[:, 6] = 1.0 + + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=1, timestamp=1.0)) is True + first_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(first_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) + np.testing.assert_allclose(first_frame["Pelvis"][0][2], 0.7, atol=1e-6) + assert provider._ground_alignment_offset == pytest.approx(-0.2) + + body_poses[:, 1] += 0.3 + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=2, timestamp=1.1)) is True + second_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(second_frame["Left_Ankle"][0][2], 0.3, atol=1e-6) + np.testing.assert_allclose(second_frame["Pelvis"][0][2], 1.0, atol=1e-6) + + +def test_pico4_provider_recomputes_ground_alignment_after_timestamp_gap_reset() -> None: + provider = _make_provider() + body_poses = np.zeros((len(BODY_JOINT_NAMES), 7), dtype=np.float64) + pelvis_idx = BODY_JOINT_NAMES.index("Pelvis") + left_ankle_idx = BODY_JOINT_NAMES.index("Left_Ankle") + right_ankle_idx = BODY_JOINT_NAMES.index("Right_Ankle") + body_poses[pelvis_idx, 0:3] = [0.0, 0.8, 0.0] + body_poses[left_ankle_idx, 0:3] = [0.1, -0.2, 0.0] + body_poses[right_ankle_idx, 0:3] = [-0.1, 0.1, 0.0] + body_poses[:, 6] = 1.0 + + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=1, timestamp=1.0)) is True + assert provider._ground_alignment_offset == pytest.approx(0.2) + + body_poses[pelvis_idx, 1] = 0.7 + body_poses[left_ankle_idx, 1] = -0.5 + body_poses[right_ankle_idx, 1] = 0.2 + assert provider._accept_pico_frame(_pico_frame(body_poses, seq=2, timestamp=1.3)) is True + latest_frame, _, _ = provider._frame_cache.latest_packet() + np.testing.assert_allclose(latest_frame["Left_Ankle"][0][2], 0.0, atol=1e-6) + np.testing.assert_allclose(latest_frame["Pelvis"][0][2], 1.2, atol=1e-6) + assert provider._ground_alignment_offset == pytest.approx(0.5) + assert len(provider._frame_cache) == 1 def test_pico4_provider_drops_duplicate_raw_body_pose() -> None: @@ -201,6 +287,35 @@ def test_pico4_provider_exposes_pause_control_events_once() -> None: assert packet.control_events == () +def test_pico4_provider_exposes_arms_control_events_once() -> None: + provider = _make_provider() + + assert provider._accept_pico_frame( + _pico_frame(_body_poses(1.0), seq=1, timestamp=1.0, right_secondary=True) + ) is True + + packet = provider.get_realtime_input_packet() + assert [event.event_type for event in packet.control_events] == [ControlEventType.TOGGLE_ARMS] + + packet = provider.get_realtime_input_packet() + assert packet.control_events == () + + +def test_pico4_provider_marks_controller_present_without_raw_field() -> None: + provider = _make_provider() + frame = _pico_frame(_body_poses(1.0), seq=1, timestamp=1.0) + frame.controllers.left.axis = {"grip": 0.8, "trigger": 0.4} + + assert provider._accept_pico_frame(frame) is True + + snapshot = provider.get_controller_snapshot() + assert snapshot is not None + assert snapshot.left.present is True + assert snapshot.left.raw is False + assert snapshot.left.grip == pytest.approx(0.8) + assert snapshot.left.trigger == pytest.approx(0.4) + + def test_pico4_provider_reads_pause_control_events_when_body_inactive() -> None: provider = _make_provider() @@ -211,3 +326,22 @@ def test_pico4_provider_reads_pause_control_events_when_body_inactive() -> None: events = provider.pop_control_events() assert [event.event_type for event in events] == [ControlEventType.TOGGLE_PAUSE] assert provider._last_source_seq == 1 + + +def test_pico4_provider_exposes_hand_snapshot_when_body_inactive() -> None: + provider = _make_provider() + frame = _pico_frame(_body_poses(1.0), seq=4, timestamp=2.0, body_active=False) + frame.left_hand = _hand_state(active=True, value=1.5) + frame.right_hand = _hand_state(active=False, value=2.5) + + assert provider._accept_pico_frame(frame) is False + + snapshot = provider.get_hand_snapshot() + assert snapshot is not None + assert snapshot.seq == 4 + assert snapshot.timestamp_s == pytest.approx(2.0) + assert snapshot.left.present is True + assert snapshot.left.active is True + assert snapshot.right.present is True + assert snapshot.right.active is False + np.testing.assert_allclose(snapshot.left.joints[:, 0:3], 1.5) diff --git a/tests/test_pico_motion_recording.py b/tests/test_pico_motion_recording.py new file mode 100644 index 00000000..5efca8d7 --- /dev/null +++ b/tests/test_pico_motion_recording.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from datetime import datetime +from pathlib import Path + +import numpy as np +import pytest + +from teleopit.constants import FULL_QPOS_DIM, NUM_JOINTS +from teleopit.recording.pico_motion import ( + PicoDatasetSpec, + RecordingState, + ensure_pico_dataset_spec, + qpos_sequence_to_motion_clip, + sanitize_clip_name, + unique_clip_path, + write_motion_clip_npz, +) +from train_mimic.data.dataset_builder import load_dataset_spec +from train_mimic.data.dataset_lib import inspect_npz + + +class _FakeFkExtractor: + num_actions = NUM_JOINTS + + def extract( + self, + root_pos: np.ndarray, + root_quat_wxyz: np.ndarray, + joint_pos: np.ndarray, + body_names: list[str], + ) -> tuple[np.ndarray, np.ndarray]: + del root_quat_wxyz, joint_pos + offsets = np.zeros((len(body_names), 3), dtype=np.float32) + offsets[:, 2] = np.linspace(0.0, 0.2, len(body_names), dtype=np.float32) + body_pos_w = root_pos[:, None, :] + offsets[None, :, :] + body_quat_w = np.zeros((root_pos.shape[0], len(body_names), 4), dtype=np.float32) + body_quat_w[..., 0] = 1.0 + return body_pos_w.astype(np.float32), body_quat_w + + +def _qpos_sequence(num_frames: int = 4) -> np.ndarray: + qpos = np.zeros((num_frames, FULL_QPOS_DIM), dtype=np.float32) + qpos[:, 0] = np.linspace(0.0, 0.3, num_frames, dtype=np.float32) + qpos[:, 2] = 0.76 + qpos[:, 3] = 1.0 + qpos[:, 7] = np.linspace(0.0, 0.2, num_frames, dtype=np.float32) + return qpos + + +def test_sanitize_clip_name_keeps_semantic_label_filesystem_safe() -> None: + assert sanitize_clip_name(" Walk Forward Slow ") == "walk_forward_slow" + assert sanitize_clip_name("turn-left/fast") == "turn-left_fast" + with pytest.raises(ValueError, match="clip name"): + sanitize_clip_name("...") + + +def test_unique_clip_path_adds_timestamp_and_avoids_overwrite(tmp_path: Path) -> None: + now = datetime(2026, 6, 10, 14, 22, 33) + first = unique_clip_path(tmp_path, "walk forward", now=now) + assert first.name == "walk_forward_20260610_142233.npz" + first.write_bytes(b"placeholder") + + second = unique_clip_path(tmp_path, "walk forward", now=now) + assert second.name == "walk_forward_20260610_142233_001.npz" + + +def test_qpos_sequence_to_motion_clip_writes_standard_npz_fields() -> None: + clip = qpos_sequence_to_motion_clip(_qpos_sequence(), fps=30, extractor=_FakeFkExtractor()) + assert int(clip["fps"]) == 30 + assert clip["root_pos"].shape == (4, 3) + assert clip["root_quat_w"].shape == (4, 4) + assert clip["joint_pos"].shape == (4, NUM_JOINTS) + assert clip["joint_vel"].shape == (4, NUM_JOINTS) + assert clip["body_pos_w"].shape[0] == 4 + assert clip["body_quat_w"].shape[-1] == 4 + + +def test_qpos_sequence_to_motion_clip_rejects_invalid_input() -> None: + with pytest.raises(ValueError, match=r"shape"): + qpos_sequence_to_motion_clip(np.zeros((3, FULL_QPOS_DIM - 1)), fps=30, extractor=_FakeFkExtractor()) + + bad = _qpos_sequence() + bad[0, 0] = np.nan + with pytest.raises(ValueError, match="NaN/Inf"): + qpos_sequence_to_motion_clip(bad, fps=30, extractor=_FakeFkExtractor()) + + with pytest.raises(ValueError, match="at least 2"): + qpos_sequence_to_motion_clip(_qpos_sequence(1), fps=30, extractor=_FakeFkExtractor()) + + +def test_write_motion_clip_npz_and_inspect(tmp_path: Path) -> None: + out = tmp_path / "clip.npz" + write_motion_clip_npz(out, _qpos_sequence(), fps=30, extractor=_FakeFkExtractor()) + assert out.exists() + meta = inspect_npz(out) + assert meta.fps == 30 + assert meta.num_frames == 4 + + +def test_ensure_pico_dataset_spec_preserves_existing_file(tmp_path: Path) -> None: + clips_dir = tmp_path / "clips" + clips_dir.mkdir() + spec_path = tmp_path / "pico_recorded.yaml" + + ensure_pico_dataset_spec( + spec_path, + clips_dir, + spec=PicoDatasetSpec(dataset_name="pico_recorded", target_fps=30), + ) + spec = load_dataset_spec(spec_path) + assert spec.name == "pico_recorded" + assert spec.target_fps == 30 + assert spec.sources[0].type == "npz" + assert spec.sources[0].input == str(clips_dir) + + spec_path.write_text("name: hand_edited\n", encoding="utf-8") + ensure_pico_dataset_spec(spec_path, clips_dir) + assert spec_path.read_text(encoding="utf-8") == "name: hand_edited\n" + + +def test_recording_state_snapshot_does_not_clear_buffer() -> None: + state = RecordingState("walk") + state.start() + state.append(_qpos_sequence(2)[0]) + state.append(_qpos_sequence(2)[1]) + + clip_name, recording, frames = state.snapshot() + assert clip_name == "walk" + assert recording is True + assert len(frames) == 2 + assert state.status()[2] == 2 + + state.mark_saved() + assert state.status()[1] is False + assert state.status()[2] == 0 + + +def test_recording_state_discard_clears_buffer() -> None: + state = RecordingState("turn") + state.start() + state.append(_qpos_sequence(2)[0]) + + clip_name, frame_count = state.discard() + assert clip_name == "turn" + assert frame_count == 1 + assert state.status()[1] is False + assert state.status()[2] == 0 diff --git a/tests/test_pico_video.py b/tests/test_pico_video.py index 075ec194..6e0d3cb7 100644 --- a/tests/test_pico_video.py +++ b/tests/test_pico_video.py @@ -87,11 +87,62 @@ def stop(self) -> None: assert sink.frames[-1].dtype == np.uint8 +def test_realsense_video_runtime_invokes_frame_callback(monkeypatch: pytest.MonkeyPatch) -> None: + fake_rs = ModuleType("pyrealsense2") + fake_rs.stream = SimpleNamespace(color="color") + fake_rs.format = SimpleNamespace(rgb8="rgb8") + + class FakeConfig: + def enable_stream(self, *_args: object) -> None: + pass + + class FakeColorFrame: + def get_data(self) -> np.ndarray: + return np.full((2, 2, 3), 7, dtype=np.uint8) + + class FakeFrames: + def get_color_frame(self) -> FakeColorFrame: + return FakeColorFrame() + + class FakePipeline: + def start(self, _config: object) -> None: + pass + + def wait_for_frames(self) -> FakeFrames: + time.sleep(0.005) + return FakeFrames() + + def stop(self) -> None: + pass + + fake_rs.config = FakeConfig + fake_rs.pipeline = FakePipeline + monkeypatch.setitem(sys.modules, "pyrealsense2", fake_rs) + + sink = _FrameSink() + callback_frames: list[np.ndarray] = [] + config = parse_pico_video_config({"video": {"enabled": True, "source": "realsense", "width": 2, "height": 2}}) + runtime = PicoVideoRuntime( + provider=sink, + config=config, + mode="sim2real", + frame_callback=lambda frame, _timestamp_s: callback_frames.append(frame.copy()), + ) + + runtime.start() + time.sleep(0.03) + runtime.stop() + + assert sink.frames + assert callback_frames + np.testing.assert_array_equal(callback_frames[-1], sink.frames[-1]) + + def test_video_runtime_stops_producer_after_startup_error(monkeypatch: pytest.MonkeyPatch) -> None: stopped = False class FailingProducer: - def __init__(self, _provider: object, _config: object) -> None: + def __init__(self, _provider: object, _config: object, _frame_callback: object = None) -> None: pass def start(self) -> None: @@ -120,7 +171,7 @@ def test_video_runtime_stops_producer_before_reraising_tick_error(monkeypatch: p stopped = False class FailingProducer: - def __init__(self, _provider: object, _config: object) -> None: + def __init__(self, _provider: object, _config: object, _frame_callback: object = None) -> None: pass def start(self) -> None: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 31487c52..17333348 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -18,7 +18,7 @@ def __init__(self, cfg: object) -> None: captured["robot_cfg"] = cfg class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -28,7 +28,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: captured["obs_cfg"] = cfg @@ -81,14 +81,10 @@ def __init__(self, *args: object, **kwargs: object) -> None: }, "policy_hz": 50, "pd_hz": 1000, - "transition_duration": 1.5, - "pause_resume_transition_duration": 0.75, - "pause_resume_warmup_steps": 3, "keyboard": {"enabled": True}, "retarget_buffer_enabled": True, "retarget_buffer_window_s": 0.75, "retarget_buffer_delay_s": 0.02, - "reference_qpos_smoothing_alpha": 0.4, "reference_steps": [0, 1, -1], "realtime": True, "viewers": ["retarget", "sim2sim"], @@ -108,14 +104,10 @@ def __init__(self, *args: object, **kwargs: object) -> None: loop_cfg = captured["loop_args"][4] assert list(controller_cfg.default_dof_pos) == robot_default_angles assert list(controller_cfg.action_scale) == robot_action_scale - assert loop_cfg["transition_duration"] == pytest.approx(1.5) - assert loop_cfg["pause_resume_transition_duration"] == pytest.approx(0.75) - assert loop_cfg["pause_resume_warmup_steps"] == 3 assert loop_cfg["keyboard"]["enabled"] is True assert loop_cfg["retarget_buffer_enabled"] is True assert loop_cfg["retarget_buffer_window_s"] == pytest.approx(0.75) assert loop_cfg["retarget_buffer_delay_s"] == pytest.approx(0.02) - assert loop_cfg["reference_qpos_smoothing_alpha"] == pytest.approx(0.4) assert list(loop_cfg["reference_steps"]) == [0, 1, -1] assert loop_cfg["realtime"] is True assert captured["loop_kwargs"]["viewers"] == {"retarget", "sim2sim"} @@ -127,7 +119,7 @@ def __init__(self, cfg: object) -> None: pass class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -137,7 +129,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: pass @@ -196,7 +188,7 @@ def __init__(self, cfg: object) -> None: pass class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -206,7 +198,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: pass @@ -276,7 +268,7 @@ def __init__(self, cfg: object) -> None: pass class DummyController: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -286,7 +278,7 @@ def reset(self) -> None: pass class DummyObsBuilder: - total_obs_size = 166 + total_obs_size = 167 def __init__(self, cfg: object) -> None: pass @@ -392,7 +384,7 @@ def _pipeline_cfg(tmp_path: Path) -> dict: @requires_mujoco @_skip_no_xml -def test_pipeline_166d_policy_required(monkeypatch, tmp_path: Path) -> None: +def test_pipeline_policy_dim_mismatch_required(monkeypatch, tmp_path: Path) -> None: class DummyController160: _expected_obs_dim = 160 _multi_input = True @@ -428,7 +420,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: monkeypatch.setattr("teleopit.pipeline.SimulationLoop", DummyLoop) cfg = OmegaConf.create(_pipeline_cfg(tmp_path)) - with pytest.raises(ValueError, match="Only 166D"): + with pytest.raises(ValueError, match="Observation dimension mismatch"): TeleopPipeline(cfg) @@ -436,7 +428,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: @_skip_no_xml def test_pipeline_requires_dual_input_policy(monkeypatch, tmp_path: Path) -> None: class DummyControllerSingle: - _expected_obs_dim = 166 + _expected_obs_dim = 167 _multi_input = False def __init__(self, cfg: object) -> None: @@ -477,8 +469,8 @@ def __init__(self, *args: object, **kwargs: object) -> None: @requires_mujoco @_skip_no_xml def test_pipeline_dim_match_passes(monkeypatch, tmp_path: Path) -> None: - class DummyController166: - _expected_obs_dim = 166 + class DummyController167: + _expected_obs_dim = 167 _multi_input = True def __init__(self, cfg: object) -> None: @@ -505,7 +497,7 @@ class DummyLoop: def __init__(self, *args: object, **kwargs: object) -> None: pass - monkeypatch.setattr("teleopit.pipeline.RLPolicyController", DummyController166) + monkeypatch.setattr("teleopit.pipeline.RLPolicyController", DummyController167) monkeypatch.setattr("teleopit.pipeline.MuJoCoRobot", DummyRobot) monkeypatch.setattr("teleopit.pipeline.BVHInputProvider", DummyInputProvider) monkeypatch.setattr("teleopit.pipeline.RetargetingModule", DummyRetargeter) @@ -513,4 +505,4 @@ def __init__(self, *args: object, **kwargs: object) -> None: cfg = OmegaConf.create(_pipeline_cfg(tmp_path)) pipeline = TeleopPipeline(cfg) - assert pipeline.obs_builder.total_obs_size == 166 + assert pipeline.obs_builder.total_obs_size == 167 diff --git a/tests/test_realtime_utils.py b/tests/test_realtime_utils.py index 99a10c72..29649436 100644 --- a/tests/test_realtime_utils.py +++ b/tests/test_realtime_utils.py @@ -30,8 +30,6 @@ def test_exponential_vec_smoother_blends_and_resets() -> None: def test_realtime_reference_manager_warmup_counts_real_frames() -> None: manager = RealtimeReferenceManager( reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0, 1, 2]), - low_watermark_steps=1, - high_watermark_steps=3, warmup_steps=2, ) @@ -42,15 +40,13 @@ def test_realtime_reference_manager_warmup_counts_real_frames() -> None: assert manager.warmup_done is True -def test_realtime_reference_manager_repeat_pads_future_window() -> None: +def test_realtime_reference_manager_samples_without_padding_or_catchup() -> None: timeline = ReferenceTimeline(window_s=1.0) timeline.append(_qpos(0.0), 0.0) timeline.append(_qpos(1.0), 0.02) manager = RealtimeReferenceManager( reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0, 1, 2]), - low_watermark_steps=1, - high_watermark_steps=2, warmup_steps=0, ) manager.note_realtime_frame() @@ -59,71 +55,7 @@ def test_realtime_reference_manager_repeat_pads_future_window() -> None: window, diagnostics = manager.sample(timeline, 0.02) assert diagnostics.future_horizon_steps == 0 - assert diagnostics.used_repeat_padding is True - assert diagnostics.padding_active is True np.testing.assert_allclose(window.current_sample().qpos[0], 1.0, atol=1e-6) - assert window.samples[1].mode == 'repeat_latest' - assert window.samples[2].mode == 'repeat_latest' - np.testing.assert_allclose(window.samples[1].timestamp_s, 0.04, atol=1e-6) - np.testing.assert_allclose(window.samples[2].timestamp_s, 0.06, atol=1e-6) - - -def test_realtime_reference_manager_defaults_high_watermark_to_effective_low() -> None: - manager = RealtimeReferenceManager( - reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0, 1, 2, 3, 4]), - low_watermark_steps=0, - high_watermark_steps=None, - warmup_steps=0, - ) - - assert manager._low_watermark_steps == 4 - assert manager._high_watermark_steps == 4 - - -def test_realtime_reference_manager_catchup_advances_base_time() -> None: - timeline = ReferenceTimeline(window_s=1.0) - for idx in range(11): - timeline.append(_qpos(float(idx)), idx * 0.02) - - manager = RealtimeReferenceManager( - reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0]), - low_watermark_steps=2, - high_watermark_steps=4, - warmup_steps=0, - catchup_enabled=True, - catchup_trigger_steps=6, - catchup_release_steps=3, - catchup_target_delay_s=0.04, - ) - - window, diagnostics = manager.sample(timeline, 0.00) - - assert diagnostics.used_catchup is True - assert diagnostics.catchup_active is True - np.testing.assert_allclose(diagnostics.effective_base_time_s, 0.16, atol=1e-6) - np.testing.assert_allclose(window.current_sample().qpos[0], 8.0, atol=1e-6) - - -def test_realtime_reference_manager_catchup_releases_with_hysteresis() -> None: - timeline = ReferenceTimeline(window_s=1.0) - for idx in range(11): - timeline.append(_qpos(float(idx)), idx * 0.02) - - manager = RealtimeReferenceManager( - reference_window_builder=ReferenceWindowBuilder(policy_dt_s=0.02, reference_steps=[0]), - low_watermark_steps=2, - high_watermark_steps=4, - warmup_steps=0, - catchup_enabled=True, - catchup_trigger_steps=6, - catchup_release_steps=3, - catchup_target_delay_s=0.04, - ) - - _, first_diag = manager.sample(timeline, 0.00) - _, second_diag = manager.sample(timeline, 0.18) - - assert first_diag.catchup_active is True - assert second_diag.used_catchup is False - assert second_diag.catchup_active is False - np.testing.assert_allclose(second_diag.effective_base_time_s, 0.18, atol=1e-6) + assert window.samples[1].mode == "fallback_latest" + assert window.samples[2].mode == "fallback_latest" + np.testing.assert_allclose(diagnostics.effective_base_time_s, 0.02, atol=1e-6) diff --git a/tests/test_recording.py b/tests/test_recording.py deleted file mode 100644 index 02a65bdb..00000000 --- a/tests/test_recording.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Tests for teleopit.recording.hdf5_recorder — HDF5 write/read, context manager, empty file.""" -import numpy as np -import pytest - -from conftest import requires_h5py - - -@requires_h5py -class TestHDF5RecorderContextManager: - """Context manager and basic write/read.""" - - def test_context_manager_creates_file(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "test.h5" - with HDF5Recorder(path) as rec: - assert rec.file is not None - # After exit, file should be closed - assert rec.file is None - assert path.exists() - - def test_write_and_read_frames(self, tmp_dir): - import h5py - from teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "test.h5" - n_frames = 5 - joint_dim = 29 - - with HDF5Recorder(path) as rec: - for i in range(n_frames): - rec.add_frame({ - "joint_pos": np.full(joint_dim, float(i), dtype=np.float32), - "timestamp": np.array(float(i)), - }) - assert rec.frame_count == n_frames - - # Read back - with h5py.File(path, "r") as f: - assert f["joint_pos"].shape == (n_frames, joint_dim) - assert f["timestamp"].shape == (n_frames,) - assert f.attrs["total_frames"] == n_frames - np.testing.assert_allclose(f["joint_pos"][0], 0.0) - np.testing.assert_allclose(f["joint_pos"][4], 4.0) - - def test_empty_file_has_metadata(self, tmp_dir): - import h5py - from teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "empty.h5" - with HDF5Recorder(path) as rec: - pass # no frames added - - with h5py.File(path, "r") as f: - assert f.attrs["total_frames"] == 0 - assert "recording_time" in f.attrs - - -@requires_h5py -class TestHDF5RecorderErrors: - """Error handling.""" - - def test_add_frame_without_context_raises(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - rec = HDF5Recorder(tmp_dir / "test.h5") - with pytest.raises(RuntimeError, match="not opened"): - rec.add_frame({"x": np.array([1.0])}) - - def test_unexpected_key_raises(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - path = tmp_dir / "test.h5" - with HDF5Recorder(path) as rec: - rec.add_frame({"a": np.array([1.0])}) - with pytest.raises(ValueError, match="Unexpected key"): - rec.add_frame({"b": np.array([2.0])}) - - def test_close_idempotent(self, tmp_dir): - from teleopit.recording.hdf5_recorder import HDF5Recorder - - rec = HDF5Recorder(tmp_dir / "test.h5") - rec.close() # should not raise even if never opened - rec.close() # double close is fine diff --git a/tests/test_retargeting.py b/tests/test_retargeting.py index 0041213f..3d0cd8a1 100644 --- a/tests/test_retargeting.py +++ b/tests/test_retargeting.py @@ -5,12 +5,14 @@ the RetargetingModule.retarget output contract via a mock GMR, and the extract_mimic_obs helper. """ +import json +from pathlib import Path from unittest.mock import MagicMock import numpy as np import pytest -from conftest import requires_mink, requires_mujoco +from conftest import find_g1_xml_path, requires_mink, requires_mujoco class TestRetargetingModuleImport: @@ -123,3 +125,48 @@ def test_init_with_invalid_robot_raises(self): robot_name="nonexistent_robot_xyz", human_format="nonexistent_format", ) + + def test_pico_bridge_g1_ik_frames_exist_in_canonical_xml(self): + import mujoco + + from teleopit.retargeting.gmr.params import IK_CONFIG_DICT, ROBOT_XML_DICT + + robot_name = "unitree_g1" + xml_path = find_g1_xml_path() + if xml_path is None: + pytest.skip("G1 robot XML asset not available; run scripts/setup/download_assets.py --only robots gmr") + + assert Path(xml_path).resolve() == Path(ROBOT_XML_DICT[robot_name]).resolve() + model = mujoco.MjModel.from_xml_path(xml_path) + with open(IK_CONFIG_DICT["pico_bridge"][robot_name], encoding="utf-8") as f: + ik_config = json.load(f) + + object_types = { + "body": mujoco.mjtObj.mjOBJ_BODY, + "geom": mujoco.mjtObj.mjOBJ_GEOM, + "site": mujoco.mjtObj.mjOBJ_SITE, + } + missing = [] + for table_name in ("ik_match_table1", "ik_match_table2"): + for frame_name, entry in ik_config[table_name].items(): + _, pos_weight, rot_weight, _, _, *rest = entry + if pos_weight == 0 and rot_weight == 0: + continue + frame_type = rest[0] if rest else "body" + frame_id = mujoco.mj_name2id(model, object_types[frame_type], frame_name) + if frame_id < 0: + missing.append(f"{table_name}:{frame_type}:{frame_name}") + + assert missing == [] + + def test_pico_bridge_g1_foot_ik_uses_canonical_foot_sites(self): + from teleopit.retargeting.gmr.params import IK_CONFIG_DICT + + with open(IK_CONFIG_DICT["pico_bridge"]["unitree_g1"], encoding="utf-8") as f: + ik_config = json.load(f) + + for table_name in ("ik_match_table1", "ik_match_table2"): + assert ik_config[table_name]["left_foot"][-1] == "site" + assert ik_config[table_name]["left_foot"][0] == "Left_Foot" + assert ik_config[table_name]["right_foot"][-1] == "site" + assert ik_config[table_name]["right_foot"][0] == "Right_Foot" diff --git a/tests/test_review_pipeline.py b/tests/test_review_pipeline.py deleted file mode 100644 index 7c9715cc..00000000 --- a/tests/test_review_pipeline.py +++ /dev/null @@ -1,280 +0,0 @@ -from __future__ import annotations - -import csv -import sys -from pathlib import Path - -# Ensure project root is on sys.path so `scripts.review` is importable -# even when running `pytest` directly (without `python -m pytest`). -_PROJECT_ROOT = str(Path(__file__).resolve().parents[1]) -if _PROJECT_ROOT not in sys.path: - sys.path.insert(0, _PROJECT_ROOT) - -import numpy as np - -from scripts.review import build_dataset_from_review -from scripts.review import export_reviewed_manifest -from scripts.review import init_review_manifest -from train_mimic.data.review_lib import ReviewRow, load_review_state, save_review_state - - -BODY_NAMES = np.array(["pelvis", "torso"], dtype=str) - - -def _write_manifest(path: Path) -> None: - with path.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow( - [ - "clip_id", - "source", - "file_rel", - "num_frames", - "fps", - "resolved_split", - "resolved_npz_path", - "weight", - "clip_index", - ] - ) - writer.writerow( - [ - "src:clip_train", - "src", - "cache/clip_train.npz", - 4, - 24, - "train", - "/tmp/placeholder_train.npz", - 2.5, - -1, - ] - ) - writer.writerow( - [ - "src:clip_val", - "src", - "cache/clip_val.npz", - 5, - 30, - "val", - "/tmp/placeholder_val.npz", - 0.75, - -1, - ] - ) - - -def _write_npz(path: Path, *, num_frames: int, fps: int) -> None: - joint_pos = np.linspace(0.0, 0.2, num_frames * 29, dtype=np.float32).reshape(num_frames, 29) - joint_vel = np.gradient(joint_pos, axis=0).astype(np.float32) - body_pos_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - body_pos_w[:, 0, 2] = np.linspace(0.75, 0.8, num_frames, dtype=np.float32) - body_pos_w[:, 1, 2] = body_pos_w[:, 0, 2] + 0.3 - body_quat_w = np.zeros((num_frames, len(BODY_NAMES), 4), dtype=np.float32) - body_quat_w[..., 0] = 1.0 - body_lin_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - body_ang_vel_w = np.zeros((num_frames, len(BODY_NAMES), 3), dtype=np.float32) - np.savez( - path, - fps=fps, - joint_pos=joint_pos, - joint_vel=joint_vel, - body_pos_w=body_pos_w, - body_quat_w=body_quat_w, - body_lin_vel_w=body_lin_vel_w, - body_ang_vel_w=body_ang_vel_w, - body_names=BODY_NAMES, - ) - - -def test_init_review_manifest_preserves_weight(tmp_path: Path, monkeypatch) -> None: - manifest_path = tmp_path / "manifest_resolved.csv" - review_path = tmp_path / "review_state.csv" - _write_manifest(manifest_path) - - monkeypatch.setattr( - sys, - "argv", - [ - "init_review_manifest.py", - "--dataset", - "demo", - "--manifest", - str(manifest_path), - "--output", - str(review_path), - ], - ) - - init_review_manifest.main() - - rows = load_review_state(review_path) - assert [row.weight for row in rows] == [2.5, 0.75] - assert [row.decision for row in rows] == ["", ""] - - -def test_export_reviewed_manifest_preserves_weight_and_filters_keep( - tmp_path: Path, - monkeypatch, -) -> None: - review_path = tmp_path / "review_state.csv" - output_path = tmp_path / "filtered_manifest.csv" - summary_path = tmp_path / "review_summary.json" - save_review_state( - [ - ReviewRow( - clip_id="src:clip_train", - source="src", - file_rel="cache/clip_train.npz", - resolved_npz_path="cache/clip_train.npz", - resolved_split="train", - num_frames=4, - fps=24, - duration_s=4 / 24, - weight=2.5, - decision="keep", - ), - ReviewRow( - clip_id="src:clip_val", - source="src", - file_rel="cache/clip_val.npz", - resolved_npz_path="cache/clip_val.npz", - resolved_split="val", - num_frames=5, - fps=30, - duration_s=5 / 30, - weight=0.75, - decision="drop", - ), - ], - review_path, - ) - - monkeypatch.setattr( - sys, - "argv", - [ - "export_reviewed_manifest.py", - "--review", - str(review_path), - "--output", - str(output_path), - "--summary", - str(summary_path), - ], - ) - - export_reviewed_manifest.main() - - with output_path.open("r", encoding="utf-8", newline="") as f: - rows = list(csv.DictReader(f)) - - assert len(rows) == 1 - assert rows[0]["clip_id"] == "src:clip_train" - assert float(rows[0]["weight"]) == 2.5 - - -def test_build_dataset_from_review_resamples_mixed_fps_and_preserves_weights( - tmp_path: Path, - monkeypatch, -) -> None: - cache_dir = tmp_path / "cache" - cache_dir.mkdir() - train_npz = cache_dir / "clip_train.npz" - val_npz = cache_dir / "clip_val.npz" - _write_npz(train_npz, num_frames=4, fps=24) - _write_npz(val_npz, num_frames=5, fps=30) - - filtered_manifest = tmp_path / "filtered_manifest.csv" - with filtered_manifest.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow( - [ - "clip_id", - "source", - "file_rel", - "num_frames", - "fps", - "resolved_split", - "resolved_npz_path", - "weight", - "clip_index", - ] - ) - writer.writerow( - [ - "src:clip_train", - "src", - str(train_npz), - 4, - 24, - "train", - str(train_npz), - 2.5, - -1, - ] - ) - writer.writerow( - [ - "src:clip_val", - "src", - str(val_npz), - 5, - 30, - "val", - str(val_npz), - 0.75, - -1, - ] - ) - - output_dir = tmp_path / "twist2_full_cleaned" - monkeypatch.setattr( - sys, - "argv", - [ - "build_dataset_from_review.py", - "--filtered_manifest", - str(filtered_manifest), - "--output_dir", - str(output_dir), - "--target_fps", - "30", - ], - ) - - build_dataset_from_review.main() - - train = np.load(output_dir / "train" / "shard_000.npz", allow_pickle=True) - val = np.load(output_dir / "val" / "shard_000.npz", allow_pickle=True) - assert int(train["fps"]) == 30 - assert int(val["fps"]) == 30 - assert train["clip_weights"].tolist() == [2.5] - assert val["clip_weights"].tolist() == [0.75] - - -def test_init_review_manifest_preserves_weight_metadata(tmp_path: Path, monkeypatch) -> None: - manifest_path = tmp_path / "manifest_resolved.csv" - review_path = tmp_path / "review_state.csv" - _write_manifest(manifest_path) - - monkeypatch.setattr( - sys, - "argv", - [ - "init_review_manifest.py", - "--dataset", - "demo", - "--manifest", - str(manifest_path), - "--output", - str(review_path), - ], - ) - - init_review_manifest.main() - - rows = load_review_state(review_path) - assert rows[0].weight == 2.5 - assert rows[1].weight == 0.75 diff --git a/tests/test_runner_iteration_numbering.py b/tests/test_runner_iteration_numbering.py index adad80e4..9b61ed5f 100644 --- a/tests/test_runner_iteration_numbering.py +++ b/tests/test_runner_iteration_numbering.py @@ -3,6 +3,7 @@ import pytest from train_mimic.tasks.tracking.rl.runner import ( + _format_duration, _one_based_iteration_range, _resolve_total_iterations, ) @@ -36,3 +37,7 @@ def test_resolve_total_iterations_adds_requested_iterations_on_resume() -> None: def test_resolve_total_iterations_rejects_negative_requested_iterations() -> None: with pytest.raises(ValueError, match='non-negative'): _resolve_total_iterations(10, -1) + + +def test_format_duration_keeps_hours_above_one_day() -> None: + assert _format_duration(33 * 3600 + 16 * 60 + 25) == "33:16:25" diff --git a/tests/test_runtime_components.py b/tests/test_runtime_components.py index 7e6e776f..8e10c3ef 100644 --- a/tests/test_runtime_components.py +++ b/tests/test_runtime_components.py @@ -6,7 +6,6 @@ import numpy as np from teleopit.controllers.observation import VelCmdObservationBuilder -from teleopit.controllers.qpos_interpolator import QposInterpolator from teleopit.sim.runtime_components import PolicyStepRunner @@ -39,7 +38,6 @@ def _make_runner( kds=np.ones(29, dtype=np.float32), torque_limits=np.ones(29, dtype=np.float32), default_dof_pos=np.zeros(29, dtype=np.float32), - qpos_interpolator=QposInterpolator(0.0, 50.0), reference_velocity_smoothing_alpha=reference_velocity_smoothing_alpha, ) diff --git a/tests/test_save_onnx.py b/tests/test_save_onnx.py index c351b2ac..d2ff0ad6 100644 --- a/tests/test_save_onnx.py +++ b/tests/test_save_onnx.py @@ -11,7 +11,7 @@ def _build_temporal_actor_model( *, - obs_dim: int = 166, + obs_dim: int = 167, history_length: int = 10, ref_window_dim: int | None = None, ref_window_length: int = 20, @@ -51,7 +51,7 @@ def _build_temporal_actor_model( ) -def _build_temporal_actor_state_dict(obs_dim: int = 166, history_length: int = 10) -> dict[str, torch.Tensor]: +def _build_temporal_actor_state_dict(obs_dim: int = 167, history_length: int = 10) -> dict[str, torch.Tensor]: return _build_temporal_actor_model(obs_dim=obs_dim, history_length=history_length).state_dict() @@ -113,7 +113,7 @@ def fake_export(model, args, output_path, **kwargs): # type: ignore[no-untyped- assert captured["output_path"] == str(tmp_path / "policy.onnx") assert captured["input_names"] == ["obs", "obs_history"] - assert captured["arg_shapes"] == [(1, 166), (1, 10, 166)] + assert captured["arg_shapes"] == [(1, 167), (1, 10, 167)] def test_export_temporal_cnn_multi_group_checkpoint(monkeypatch, tmp_path: Path) -> None: diff --git a/tests/test_sim2real_dim.py b/tests/test_sim2real_dim.py index 51b8fa15..31dc247a 100644 --- a/tests/test_sim2real_dim.py +++ b/tests/test_sim2real_dim.py @@ -1,134 +1,54 @@ -"""Integration tests for Sim2RealController startup validation.""" from __future__ import annotations -from pathlib import Path -from unittest.mock import MagicMock +from types import SimpleNamespace import pytest -from conftest import find_g1_xml_path, requires_mujoco +from teleopit.sim2real.mp.runtime import _RobotControlWorker -_XML_PATH = find_g1_xml_path() -_skip_no_xml = pytest.mark.skipif(_XML_PATH is None, reason="Robot XML not found") - -_DEFAULT_ANGLES_29 = [ - -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, - -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, - 0.0, 0.0, 0.0, - 0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, - 0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, -] - - -def _make_dummy_policy(expected_obs_dim: int, *, multi_input: bool = True) -> MagicMock: - policy = MagicMock() - policy._expected_obs_dim = expected_obs_dim - policy._multi_input = multi_input - return policy - - -def _make_sim2real_cfg(tmp_path: Path) -> dict: - policy_path = tmp_path / "policy.onnx" - policy_path.write_bytes(b"dummy") - bvh_path = tmp_path / "clip.bvh" - bvh_path.write_text("HIERARCHY\n", encoding="utf-8") - +def _cfg() -> dict[str, object]: return { "policy_hz": 50.0, - "real_robot": {"network_interface": "lo"}, - "gamepad": {}, - "mocap_switch": {}, - "input": { - "provider": "bvh", - "bvh_file": str(bvh_path), - "bvh_format": "hc_mocap", - "robot_name": "unitree_g1", - }, - "controller": { - "policy_path": str(policy_path), - "action_scale": [0.5] * 29, - "default_dof_pos": list(_DEFAULT_ANGLES_29), - }, + "input": {"provider": "bvh"}, + "runtime": {}, + "real_robot": {}, + "controller": {}, "robot": { "num_actions": 29, - "xml_path": _XML_PATH or "", - "default_angles": _DEFAULT_ANGLES_29, - "action_scale": [0.5] * 29, - "anchor_body_name": "torso_link", + "default_angles": [0.0] * 29, + "xml_path": "robot.xml", }, } -def _apply_sim2real_mocks(monkeypatch, policy_mock: MagicMock) -> None: - dummy_provider = MagicMock() - dummy_provider.human_format = "hc_mocap" - dummy_provider.fps = 30 - dummy_provider.__len__.return_value = 1 - dummy_provider.get_frame_by_index.return_value = {} - obs_builder = MagicMock(total_obs_size=166) - - def _build_components(*args, **kwargs): - if policy_mock._expected_obs_dim != 166: - raise ValueError( - f"Only 166D velcmd_history ONNX policies are supported here; " - f"obs_builder produces 166D but policy expects {policy_mock._expected_obs_dim}D." - ) - if not policy_mock._multi_input: - raise ValueError( - "Sim2real requires an ONNX policy with dual inputs ('obs' and 'obs_history')." - ) - return MagicMock( - input_provider=dummy_provider, - retargeter=MagicMock(), - controller=policy_mock, - obs_builder=obs_builder, - ) - +def test_robot_worker_requires_dual_input_policy(monkeypatch) -> None: + policy = SimpleNamespace(_multi_input=False) + obs_builder = SimpleNamespace(total_obs_size=167) monkeypatch.setattr( - "teleopit.sim2real.controller.UnitreeG1Robot", - lambda cfg: MagicMock(check_mode=MagicMock(return_value={"name": "mock"})), - ) - monkeypatch.setattr( - "teleopit.sim2real.controller.UnitreeRemote", - MagicMock, - ) - - monkeypatch.setattr( - "teleopit.sim2real.controller.build_sim2real_mocap_components", - _build_components, + "teleopit.sim2real.mp.runtime._build_policy_components", + lambda **_kwargs: (policy, obs_builder), ) + worker = object.__new__(_RobotControlWorker) + worker.cfg = _cfg() -@requires_mujoco -@_skip_no_xml -class TestSim2RealStartupDim: - def test_velcmd_history_startup_builds_166d_obs(self, monkeypatch, tmp_path: Path) -> None: - from teleopit.sim2real.controller import Sim2RealController + with pytest.raises(ValueError, match="dual inputs"): + worker._build_policy_and_obs() - policy_mock = _make_dummy_policy(166, multi_input=True) - _apply_sim2real_mocks(monkeypatch, policy_mock) - cfg = _make_sim2real_cfg(tmp_path) - ctrl = Sim2RealController(cfg) - assert ctrl.obs_builder.total_obs_size == 166 - - def test_non_166_policy_is_rejected(self, monkeypatch, tmp_path: Path) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy_mock = _make_dummy_policy(160, multi_input=True) - _apply_sim2real_mocks(monkeypatch, policy_mock) - cfg = _make_sim2real_cfg(tmp_path) - - with pytest.raises(ValueError, match="Only 166D"): - Sim2RealController(cfg) +def test_robot_worker_accepts_167d_dual_input_policy(monkeypatch) -> None: + policy = SimpleNamespace(_multi_input=True) + obs_builder = SimpleNamespace(total_obs_size=167) + monkeypatch.setattr( + "teleopit.sim2real.mp.runtime._build_policy_components", + lambda **_kwargs: (policy, obs_builder), + ) - def test_single_input_policy_is_rejected(self, monkeypatch, tmp_path: Path) -> None: - from teleopit.sim2real.controller import Sim2RealController + worker = object.__new__(_RobotControlWorker) + worker.cfg = _cfg() - policy_mock = _make_dummy_policy(166, multi_input=False) - _apply_sim2real_mocks(monkeypatch, policy_mock) - cfg = _make_sim2real_cfg(tmp_path) + built_policy, built_obs_builder = worker._build_policy_and_obs() - with pytest.raises(ValueError, match="dual inputs"): - Sim2RealController(cfg) + assert built_policy is policy + assert built_obs_builder is obs_builder diff --git a/tests/test_sim2real_multiprocess.py b/tests/test_sim2real_multiprocess.py new file mode 100644 index 00000000..b59b8e1b --- /dev/null +++ b/tests/test_sim2real_multiprocess.py @@ -0,0 +1,976 @@ +from __future__ import annotations + +import importlib.util +import logging +from pathlib import Path +from types import SimpleNamespace + +import h5py +import numpy as np +import pytest + +from teleopit.runtime.mocap_session import MocapSessionState +from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType +from teleopit.runtime.arm_mocap import compose_arm_reference, compose_arm_reference_window +from teleopit.recording.hdf5 import ( + ACTION_KEY, + FRAME_INDEX_KEY, + HAND_ACTION_KEY, + HDF5_RECORDING_FORMAT, + IMAGE_KEY, + MODE_KEY, + STATE_KEY, + TIMESTAMP_KEY, + build_mode_observation, + build_observation_state, + build_recording_schema, + hdf5_schema, +) +from teleopit.sim2real.mp.ipc import HEALTH_TOPIC, LatestSubscriber, ZmqPublisher +from teleopit.sim2real.mp.messages import HandCommandPacket, ModeStatePacket, RecordStepPacket, ReferencePacket, SharedFrameDescriptor +from teleopit.sim.reference_timeline import ReferenceSample, ReferenceWindow +from teleopit.sim2real.mp.runtime import ( + map_recording_key_to_command, + RobotMode, + Sim2RealRuntime, + _LoopTimingReporter, + _RecordingWorker, + _RobotControlWorker, + _configured_open_hand_pose, + _hand_worker_active_for_mode, + _human_frame_is_valid, +) +from teleopit.sim2real.mp.shm import SharedFrameRingReader, SharedFrameRingWriter + + +def test_loop_timing_reporter_separates_late_sleep_from_work_overrun(caplog) -> None: + reporter = _LoopTimingReporter(target_period_s=0.02, log_interval_s=1.0, deadline_miss_tolerance_s=0.001) + + with caplog.at_level(logging.INFO, logger="teleopit.operator"): + reporter.record(loop_start_s=0.0, work_elapsed_s=0.0004, cycle_elapsed_s=0.02006, pico_age_s=None) + reporter.record(loop_start_s=1.0, work_elapsed_s=0.021, cycle_elapsed_s=0.0212, pico_age_s=None) + + message = caplog.messages[-1] + assert "late_ms" in message + assert "deadline_miss(>1.00ms)=1/2" in message + assert "work_overrun=1/2" in message + assert " overrun=" not in message + + +def test_sim2real_runtime_rejects_legacy_runtime_keys() -> None: + cfg = {"sim2real_runtime": "auto", "input": {"provider": "pico4"}} + with pytest.raises(ValueError, match="Legacy sim2real config keys"): + Sim2RealRuntime(cfg) + + +def test_sim2real_runtime_accepts_bvh_provider() -> None: + cfg = {"input": {"provider": "bvh"}, "runtime": {"shutdown_timeout_s": 0.01}} + runtime = Sim2RealRuntime(cfg) + runtime.shutdown() + + +def test_sim2real_runtime_rejects_hands_without_pico_provider() -> None: + cfg = { + "input": {"provider": "bvh"}, + "runtime": {"shutdown_timeout_s": 0.01}, + "hands": {"enabled": True, "driver": "linkerhand_l6", "mode": "gripper"}, + } + with pytest.raises(ValueError, match="hands.enabled=true requires input.provider=pico4"): + Sim2RealRuntime(cfg) + + +def test_sim2real_runtime_rejects_recording_without_pico_provider() -> None: + cfg = { + "input": {"provider": "bvh"}, + "runtime": {"shutdown_timeout_s": 0.01}, + "recording": {"enabled": True}, + } + with pytest.raises(ValueError, match="recording.enabled=true requires input.provider=pico4"): + Sim2RealRuntime(cfg) + + +def test_sim2real_runtime_rejects_recording_without_input_video() -> None: + cfg = { + "input": {"provider": "pico4", "video": {"enabled": False, "source": "realsense"}}, + "runtime": {"shutdown_timeout_s": 0.01}, + "recording": {"enabled": True}, + } + with pytest.raises(ValueError, match="recording.enabled=true requires input.video.enabled=true"): + Sim2RealRuntime(cfg) + + +def test_shared_frame_ring_roundtrip() -> None: + writer = SharedFrameRingWriter(shape=(2, 3, 1), dtype=np.uint8, slots=2) + reader = SharedFrameRingReader() + try: + frame0 = np.arange(6, dtype=np.uint8).reshape(2, 3, 1) + desc0 = writer.write(frame0, timestamp_s=1.0) + assert isinstance(desc0, SharedFrameDescriptor) + np.testing.assert_array_equal(reader.read(desc0, copy=True), frame0) + + frame1 = np.full((2, 3, 1), 9, dtype=np.uint8) + desc1 = writer.write(frame1, timestamp_s=2.0) + np.testing.assert_array_equal(reader.read(desc1, copy=True), frame1) + assert desc1.slot != desc0.slot + finally: + reader.close() + writer.close(unlink=True) + + +def test_multiprocess_rejects_unsupported_video_source() -> None: + cfg = { + "input": {"provider": "pico4", "video": {"enabled": True, "source": "mujoco"}}, + "runtime": {}, + } + with pytest.raises(ValueError, match="only supports input.video.source=realsense or test-pattern"): + Sim2RealRuntime(cfg) + + +def test_zmq_endpoint_allows_one_publisher_and_subscribers() -> None: + endpoint = "inproc://sim2real-health-test" + import zmq + + context = zmq.Context() + publisher = ZmqPublisher(endpoint, context=context) + subscriber = LatestSubscriber(endpoint, HEALTH_TOPIC, context=context) + try: + with pytest.raises(zmq.ZMQError): + ZmqPublisher(endpoint, context=context) + finally: + subscriber.close() + publisher.close() + context.term() + + +def test_run_sim2real_shutdowns_on_exception(monkeypatch) -> None: + script_path = Path.cwd() / "scripts" / "run" / "run_sim2real.py" + spec = importlib.util.spec_from_file_location("test_run_sim2real", script_path) + assert spec is not None and spec.loader is not None + run_sim2real = importlib.util.module_from_spec(spec) + spec.loader.exec_module(run_sim2real) + + calls: list[str] = [] + + class FailingRuntime: + def __init__(self, _cfg: object) -> None: + calls.append("init") + + def run(self) -> None: + calls.append("run") + raise RuntimeError("boom") + + def shutdown(self) -> None: + calls.append("shutdown") + + cfg = SimpleNamespace( + input={"provider": "bvh"}, + controller=SimpleNamespace(policy_path="policy.onnx"), + ) + monkeypatch.setattr(run_sim2real, "validate_policy_path", lambda *_args, **_kwargs: None) + monkeypatch.setattr(run_sim2real, "Sim2RealRuntime", FailingRuntime) + + with pytest.raises(RuntimeError, match="boom"): + run_sim2real._run_sim2real(cfg) + + assert calls == ["init", "run", "shutdown"] + + +def test_multiprocess_run_cleans_up_after_start_failure(monkeypatch) -> None: + class FakeProcess: + def __init__(self, *, name: str, fail_start: bool = False) -> None: + self.name = name + self.exitcode = None + self.fail_start = fail_start + self.started = False + self.join_calls: list[float | None] = [] + self.terminated = False + + def start(self) -> None: + if self.fail_start: + raise RuntimeError("start failed") + self.started = True + + def is_alive(self) -> bool: + return self.started and not self.terminated + + def join(self, timeout: float | None = None) -> None: + self.join_calls.append(timeout) + + def terminate(self) -> None: + self.terminated = True + + started_process = FakeProcess(name="pico_input") + + def fake_start_processes(self: Sim2RealRuntime) -> None: + started_process.started = True + self._processes.append(started_process) + raise RuntimeError("start failed") + + cfg = { + "input": {"provider": "pico4"}, + "runtime": {"shutdown_timeout_s": 0.01}, + } + controller = Sim2RealRuntime(cfg) + monkeypatch.setattr(controller, "_start_processes", fake_start_processes.__get__(controller)) + + with pytest.raises(RuntimeError, match="start failed"): + controller.run() + + assert started_process.join_calls == [0.01, 1.0] + assert started_process.terminated is True + assert controller._processes == [] + + +def test_pico_video_does_not_spawn_separate_video_worker() -> None: + started_names: list[str] = [] + + class FakeProcess: + def __init__(self, *, name: str, target: object, args: tuple[object, ...]) -> None: + del target, args + self.name = name + self.exitcode = 0 + + def start(self) -> None: + started_names.append(self.name) + + class FakeContext: + def Event(self) -> object: + return SimpleNamespace(set=lambda: None, is_set=lambda: False) + + def Process(self, *, name: str, target: object, args: tuple[object, ...]) -> FakeProcess: + return FakeProcess(name=name, target=target, args=args) + + cfg = { + "input": {"provider": "pico4", "video": {"enabled": True, "source": "realsense"}}, + "runtime": {"shutdown_timeout_s": 0.01}, + } + runtime = Sim2RealRuntime(cfg) + runtime._ctx = FakeContext() # type: ignore[assignment] + + runtime._start_processes() + + assert started_names == ["pico_input", "reference", "robot_control"] + + +def test_recording_enabled_adds_recording_worker(monkeypatch) -> None: + started_names: list[str] = [] + + class FakeProcess: + def __init__(self, *, name: str, target: object, args: tuple[object, ...]) -> None: + del target, args + self.name = name + self.exitcode = 0 + + def start(self) -> None: + started_names.append(self.name) + + class FakeContext: + def Event(self) -> object: + return SimpleNamespace(set=lambda: None, is_set=lambda: False) + + def Process(self, *, name: str, target: object, args: tuple[object, ...]) -> FakeProcess: + return FakeProcess(name=name, target=target, args=args) + + cfg = { + "input": { + "provider": "pico4", + "video": {"enabled": True, "source": "realsense", "width": 640, "height": 480, "fps": 30}, + }, + "runtime": {"shutdown_timeout_s": 0.01}, + "recording": {"enabled": True}, + } + monkeypatch.setattr("teleopit.sim2real.mp.runtime._require_recording_dependencies", lambda: None) + monkeypatch.setattr("sys.stdin.isatty", lambda: True) + runtime = Sim2RealRuntime(cfg) + runtime._ctx = FakeContext() # type: ignore[assignment] + + runtime._start_processes() + + assert started_names == ["pico_input", "reference", "robot_control", "recording_worker"] + + +def test_recording_key_mapping() -> None: + assert map_recording_key_to_command("R") == "record_start" + assert map_recording_key_to_command("s") == "record_save" + assert map_recording_key_to_command("D") == "record_discard" + assert map_recording_key_to_command("q") == "shutdown" + assert map_recording_key_to_command("x") is None + + +def test_hdf5_recording_schema() -> None: + schema = build_recording_schema({"width": 640, "height": 480, "key": IMAGE_KEY}) + sidecar = hdf5_schema(schema) + features = sidecar["features"] + + assert sidecar["format"] == HDF5_RECORDING_FORMAT + assert features[IMAGE_KEY]["type"] == "video" + assert features[IMAGE_KEY]["format"] == "mp4" + assert features[IMAGE_KEY]["shape"] == [480, 640, 3] + assert features[FRAME_INDEX_KEY]["dtype"] == "int64" + assert features[TIMESTAMP_KEY]["dtype"] == "float64" + assert features[STATE_KEY]["shape"] == [68] + assert features[MODE_KEY]["shape"] == [1] + assert features[ACTION_KEY]["shape"] == [36] + assert features[HAND_ACTION_KEY]["shape"] == [12] + assert sidecar["features"][STATE_KEY]["slices"]["joint_pos"] == [0, 29] + assert sidecar["features"][STATE_KEY]["slices"]["projected_gravity"] == [65, 68] + assert sidecar["features"][MODE_KEY]["codes"]["pause"] == 3 + assert sidecar["features"][ACTION_KEY]["slices"]["joint_pos"] == [7, 36] + assert sidecar["features"][HAND_ACTION_KEY]["slices"]["left_pose"] == [0, 6] + assert sidecar["features"][HAND_ACTION_KEY]["slices"]["right_pose"] == [6, 12] + + +def test_hdf5_recorder_mp4_sidecar_writes_sync_metadata(tmp_path: Path) -> None: + from teleopit.recording.hdf5 import MP4VideoConfig, TeleopitHDF5Recorder + + schema = build_recording_schema({"width": 2, "height": 2, "key": IMAGE_KEY}) + recorder = TeleopitHDF5Recorder.create( + output_dir=tmp_path, + task="walk", + fps=30, + schema=schema, + video_config=MP4VideoConfig(quality=5), + ) + + recorder.start_episode() + for idx in range(2): + recorder.add_frame( + image=np.full((2, 2, 3), idx * 64, dtype=np.uint8), + state=np.arange(68, dtype=np.float32), + mode=build_mode_observation("mocap"), + action=np.arange(36, dtype=np.float32), + hand_action=np.arange(12, dtype=np.float32), + task="walk", + ) + recorder.save_episode() + recorder.finalize() + + episodes = sorted((tmp_path / "episodes").glob("*.h5")) + videos = sorted((tmp_path / "videos" / "observation.images.d435i_rgb").glob("*.mp4")) + assert len(episodes) == 1 + assert len(videos) == 1 + assert videos[0].stat().st_size > 0 + assert (tmp_path / "schema.json").exists() + assert not list((tmp_path / ".tmp").glob("*.h5")) + + with h5py.File(episodes[0], "r") as h5: + assert h5.attrs["format"] == HDF5_RECORDING_FORMAT + assert h5.attrs["version"] == 1 + assert h5.attrs["task"] == "walk" + assert h5.attrs["fps"] == 30 + assert h5.attrs["frames"] == 2 + assert h5.attrs["video_path"] == videos[0].relative_to(tmp_path).as_posix() + assert h5.attrs["video_key"] == IMAGE_KEY + assert h5.attrs["video_frames"] == 2 + assert h5.attrs["video_fps"] == 30 + assert IMAGE_KEY not in h5 + assert h5[FRAME_INDEX_KEY].shape == (2,) + assert h5[TIMESTAMP_KEY].shape == (2,) + np.testing.assert_array_equal(h5[FRAME_INDEX_KEY][...], np.array([0, 1], dtype=np.int64)) + np.testing.assert_allclose(h5[TIMESTAMP_KEY][...], np.array([0.0, 1.0 / 30.0], dtype=np.float64)) + assert h5[STATE_KEY].shape == (2, 68) + assert h5[MODE_KEY].shape == (2, 1) + assert h5[ACTION_KEY].shape == (2, 36) + assert h5[HAND_ACTION_KEY].shape == (2, 12) + + +def test_hdf5_recorder_cleans_partial_episode_when_video_writer_fails(tmp_path: Path) -> None: + from teleopit.recording.hdf5 import TeleopitHDF5Recorder + + class FailingVideoRecorder(TeleopitHDF5Recorder): + def _create_video_writer(self, path: Path) -> object: + path.write_bytes(b"partial") + raise RuntimeError("writer failed") + + schema = build_recording_schema({"width": 2, "height": 2, "key": IMAGE_KEY}) + recorder = FailingVideoRecorder.create(output_dir=tmp_path, task="walk", fps=30, schema=schema) + + with pytest.raises(RuntimeError, match="writer failed"): + recorder.start_episode() + + recorder.finalize() + assert not list((tmp_path / ".tmp").glob("*.h5")) + assert not list((tmp_path / ".tmp" / "videos" / "observation.images.d435i_rgb").glob("*.mp4")) + assert not list((tmp_path / "episodes").glob("*.h5")) + assert not list((tmp_path / "videos" / "observation.images.d435i_rgb").glob("*.mp4")) + + +def test_hdf5_recorder_keeps_startup_error_when_partial_cleanup_fails(tmp_path: Path) -> None: + from teleopit.recording.hdf5 import TeleopitHDF5Recorder + + class BrokenWriter: + def close(self) -> None: + raise RuntimeError("cleanup failed") + + class FailingDatasetRecorder(TeleopitHDF5Recorder): + def _create_video_writer(self, path: Path) -> object: + return BrokenWriter() + + def _create_datasets(self, h5: h5py.File) -> dict[str, h5py.Dataset]: + raise RuntimeError("startup failed") + + schema = build_recording_schema({"width": 2, "height": 2, "key": IMAGE_KEY}) + recorder = FailingDatasetRecorder.create(output_dir=tmp_path, task="walk", fps=30, schema=schema) + + with pytest.raises(RuntimeError, match="startup failed"): + recorder.start_episode() + + recorder.finalize() + assert not list((tmp_path / ".tmp").glob("*.h5")) + + +def test_configured_open_hand_pose_matches_linkerhand_l6_parser() -> None: + left, right = _configured_open_hand_pose( + { + "hands": { + "enabled": True, + "driver": "linkerhand_l6", + "mode": "gripper", + "linkerhand_l6": { + "thumb_yaw_center": 42, + "open_pose": [250, 99, 250, 250, 250, 250], + }, + }, + } + ) + + np.testing.assert_allclose(left, np.array([250, 42, 250, 250, 250, 250], dtype=np.float32)) + np.testing.assert_allclose(right, left) + + +def test_configured_open_hand_pose_defaults_without_enabled_hands() -> None: + left, right = _configured_open_hand_pose({}) + + np.testing.assert_allclose(left, np.array([250, 10, 250, 250, 250, 250], dtype=np.float32)) + np.testing.assert_allclose(right, left) + + +def test_record_observation_state_concat_order() -> None: + state = SimpleNamespace( + qpos=np.arange(29, dtype=np.float32), + qvel=np.arange(29, dtype=np.float32) + 100.0, + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.array([1.0, 2.0, 3.0], dtype=np.float32), + ) + + out = build_observation_state(state) + + assert out.shape == (68,) + np.testing.assert_allclose(out[0:29], state.qpos) + np.testing.assert_allclose(out[29:58], state.qvel) + np.testing.assert_allclose(out[58:62], state.quat) + np.testing.assert_allclose(out[62:65], state.ang_vel) + np.testing.assert_allclose(out[65:68], np.array([0.0, 0.0, -1.0], dtype=np.float32)) + + +def test_human_frame_validation_rejects_bad_inputs() -> None: + valid_frame = { + "Pelvis": (np.zeros(3, dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), + } + assert _human_frame_is_valid(valid_frame) + + bad_frame = { + "Pelvis": (np.array([np.nan, 0.0, 0.0], dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)), + } + assert not _human_frame_is_valid(bad_frame) + + +def test_robot_worker_holds_stale_reference_instead_of_damping() -> None: + worker = object.__new__(_RobotControlWorker) + hold_qpos = np.zeros(36, dtype=np.float64) + hold_qpos[3] = 1.0 + hold_qpos[7] = 0.25 + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._latest_reference = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.0, + seq=1, + source_timestamp_s=1.0, + source_seq=1, + frame_valid=True, + ) + worker._reference_age_s = lambda: 0.30 + worker._stale_reference_hold_s = 0.08 + worker._last_mocap_hold_reason = None + worker._last_commanded_motion_qpos = hold_qpos.copy() + worker._resolve_mocap_hold_qpos = lambda: hold_qpos.copy() + held: list[np.ndarray] = [] + worker._run_static_mocap_step = lambda qpos: held.append(np.asarray(qpos, dtype=np.float64).copy()) + worker._enter_damping = lambda: pytest.fail("stale references must not enter damping") + + worker._mocap_step() + + assert len(held) == 1 + np.testing.assert_allclose(held[0], hold_qpos) + + +def test_robot_worker_requires_consecutive_valid_references(monkeypatch) -> None: + worker = object.__new__(_RobotControlWorker) + worker._latest_reference = None + worker._last_reference_seq = -1 + worker._consecutive_valid_references = 0 + worker._check_frames = 2 + worker._max_reference_age_s = 0.25 + worker.provider_kind = "pico4" + worker._reference_age_s = lambda: 0.0 + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._last_commanded_motion_qpos = np.zeros(36, dtype=np.float64) + worker._mode_pub_events: list[tuple[str, object]] = [] + worker._mode_pub = SimpleNamespace( + publish=lambda topic, payload: worker._mode_pub_events.append((topic, payload)) + ) + + valid0 = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.0, + seq=1, + source_timestamp_s=1.0, + source_seq=1, + frame_valid=True, + ) + worker._note_reference_packet(valid0) + assert worker._can_switch_to_mocap() is False + + valid1 = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.1, + seq=2, + source_timestamp_s=1.1, + source_seq=2, + frame_valid=True, + ) + worker._note_reference_packet(valid1) + assert worker._can_switch_to_mocap() is True + + invalid = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.2, + seq=3, + source_timestamp_s=1.2, + source_seq=3, + frame_valid=False, + ) + worker._note_reference_packet(invalid) + assert worker._can_switch_to_mocap() is False + + fresh_packet = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.4, + seq=4, + source_timestamp_s=1.4, + source_seq=4, + frame_valid=True, + ) + worker._note_reference_packet(fresh_packet) + assert worker._latest_reference == fresh_packet + assert worker._consecutive_valid_references == 1 + + +def test_robot_worker_replays_bvh_on_mocap_entry() -> None: + worker = object.__new__(_RobotControlWorker) + commands: list[str] = [] + worker.provider_kind = "bvh" + worker.robot = SimpleNamespace(get_state=lambda: SimpleNamespace()) + worker._standing_qpos = np.zeros(36, dtype=np.float64) + worker._mocap_reentry_armed = True + worker._reset_policy_state = lambda: None + worker._build_resume_alignment_qpos = lambda _hold, _state: np.ones(36, dtype=np.float64) + worker._ref_proc = SimpleNamespace(reset_alignment=lambda **_kwargs: None) + worker._send_reference_command = commands.append + + worker._transition_to_mocap() + + assert worker.mode == RobotMode.MOCAP + assert commands == ["replay_mocap"] + + +def test_robot_worker_standing_reference_uses_default_root_height_without_base_pos() -> None: + worker = object.__new__(_RobotControlWorker) + worker.default_angles = np.zeros(29, dtype=np.float32) + worker.num_actions = 29 + worker._default_root_pos = np.array([0.0, 0.0, 0.76], dtype=np.float64) + worker._standing_qpos = np.zeros(36, dtype=np.float64) + state = SimpleNamespace( + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + qpos=np.zeros(29, dtype=np.float32), + ) + + worker._set_default_standing_reference(state) + + np.testing.assert_allclose(worker._standing_qpos[0:3], np.array([0.0, 0.0, 0.76])) + + +def test_robot_worker_pauses_when_bvh_reference_reports_paused() -> None: + worker = object.__new__(_RobotControlWorker) + worker.provider_kind = "bvh" + worker.mode = RobotMode.MOCAP + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._last_reference_seq = -1 + worker._consecutive_valid_references = 0 + paused: list[str] = [] + + def pause_active_mocap() -> None: + paused.append("pause") + worker._mocap_session.state = MocapSessionState.PAUSED + + worker._pause_active_mocap = pause_active_mocap + packet = ReferencePacket( + qpos=np.zeros(36, dtype=np.float64), + timestamp_s=1.0, + seq=1, + source_timestamp_s=0.0, + source_seq=0, + playback_paused=True, + playback_finished=True, + ) + + worker._note_reference_packet(packet) + + assert paused == ["pause"] + assert worker._latest_reference is packet + + +def test_robot_worker_composes_arm_reference_from_standing_pose() -> None: + standing_qpos = np.arange(36, dtype=np.float64) + retarget = np.full(36, 100.0, dtype=np.float64) + retarget[7 + 15:7 + 29] = np.arange(14, dtype=np.float64) + 200.0 + + composed = compose_arm_reference( + standing_qpos=standing_qpos, + retarget_qpos=retarget, + arm_joint_indices=np.arange(15, 29, dtype=np.int64), + num_actions=29, + ) + + np.testing.assert_allclose(composed[: 7 + 15], standing_qpos[: 7 + 15]) + np.testing.assert_allclose(composed[7 + 15:7 + 29], retarget[7 + 15:7 + 29]) + + +def test_robot_worker_composes_arm_reference_window_samples() -> None: + standing_qpos = np.zeros(36, dtype=np.float64) + qpos0 = np.ones(36, dtype=np.float64) + qpos1 = np.full(36, 2.0, dtype=np.float64) + window = ReferenceWindow( + base_time_s=1.0, + policy_dt_s=0.02, + reference_steps=(0, 1), + samples=( + ReferenceSample(qpos=qpos0, timestamp_s=1.0, mode="a", used_fallback=False, older_timestamp_s=None, newer_timestamp_s=None, alpha=None), + ReferenceSample(qpos=qpos1, timestamp_s=1.02, mode="b", used_fallback=False, older_timestamp_s=None, newer_timestamp_s=None, alpha=None), + ), + ) + + composed = compose_arm_reference_window( + window, + standing_qpos=standing_qpos, + arm_joint_indices=np.arange(15, 29, dtype=np.int64), + num_actions=29, + ) + + assert composed is not None + np.testing.assert_allclose(composed.samples[0].qpos[7 + 15:7 + 29], 1.0) + np.testing.assert_allclose(composed.samples[1].qpos[7 + 15:7 + 29], 2.0) + np.testing.assert_allclose(composed.samples[0].qpos[:7 + 15], 0.0) + np.testing.assert_allclose(composed.samples[1].qpos[:7 + 15], 0.0) + + +def test_robot_worker_pico_arms_event_toggles_mocap_and_arms() -> None: + worker = object.__new__(_RobotControlWorker) + worker.provider_kind = "pico4" + worker.mode = RobotMode.MOCAP + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker.robot = SimpleNamespace(get_state=lambda: SimpleNamespace(base_pos=np.zeros(3), quat=np.array([1.0, 0.0, 0.0, 0.0]), qpos=np.zeros(29))) + worker._last_commanded_motion_qpos = np.zeros(36, dtype=np.float64) + worker._build_resume_alignment_qpos = lambda _hold, _state: np.ones(36, dtype=np.float64) + worker._set_default_standing_reference = lambda _state: None + worker._reset_policy_state = lambda: None + worker._last_retarget_qpos = np.zeros(36, dtype=np.float64) + resets: list[np.ndarray] = [] + ramps: list[str] = [] + worker._ref_proc = SimpleNamespace(reset_alignment=lambda *, target_qpos=None: resets.append(np.asarray(target_qpos).copy())) + worker._standing_return_ramp_duration = 0.5 + worker._standing_return_kp_ramp_floor_ratio = 0.5 + worker._safety = SimpleNamespace(start_kp_ramp=lambda **_kwargs: ramps.append("ramp")) + + event = ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="test") + worker._handle_mocap_control_events((event,)) + assert worker.mode == RobotMode.ARMS + + worker._handle_mocap_control_events((event,)) + assert worker.mode == RobotMode.MOCAP + assert len(resets) == 2 + assert ramps == ["ramp", "ramp"] + + +def test_robot_worker_bvh_ignores_pico_arms_event() -> None: + worker = object.__new__(_RobotControlWorker) + worker.provider_kind = "bvh" + worker.mode = RobotMode.MOCAP + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._handle_mocap_control_events(( + ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="test"), + )) + + assert worker.mode == RobotMode.MOCAP + + +def test_robot_worker_mode_state_marks_arms_as_mocap_active() -> None: + worker = object.__new__(_RobotControlWorker) + worker.mode = RobotMode.ARMS + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._mode_seq = 0 + published: list[object] = [] + worker._mode_pub = SimpleNamespace(publish=lambda _topic, packet: published.append(packet)) + + worker._publish_mode_state() + + packet = published[-1] + assert packet.mode == "arms" + assert packet.mocap_active is True + assert packet.mocap_paused is False + + +@pytest.mark.parametrize( + ("mode", "mocap_active", "mocap_paused"), + [ + ("standing", False, False), + ("mocap", True, False), + ("arms", True, False), + ("mocap", False, True), + ("arms", False, True), + ("damping", False, False), + ], +) +def test_hand_worker_stays_active_in_all_modes(mode: str, mocap_active: bool, mocap_paused: bool) -> None: + packet = ModeStatePacket( + mode=mode, + mocap_active=mocap_active, + mocap_paused=mocap_paused, + timestamp_s=1.0, + seq=1, + ) + + assert _hand_worker_active_for_mode(packet) is True + + +def test_hand_worker_active_state_only_updates_from_mode_packets() -> None: + active = False + mode_packet = None + if isinstance(mode_packet, ModeStatePacket): + active = _hand_worker_active_for_mode(mode_packet) + assert active is False + + mode_packet = ModeStatePacket( + mode="standing", + mocap_active=False, + mocap_paused=False, + timestamp_s=1.0, + seq=1, + ) + if isinstance(mode_packet, ModeStatePacket): + active = _hand_worker_active_for_mode(mode_packet) + assert active is True + + mode_packet = None + if isinstance(mode_packet, ModeStatePacket): + active = _hand_worker_active_for_mode(mode_packet) + assert active is True + + +def test_robot_worker_publish_record_step() -> None: + worker = object.__new__(_RobotControlWorker) + worker.mode = RobotMode.ARMS + worker._mocap_session = SimpleNamespace(state=MocapSessionState.ACTIVE) + worker._mode_seq = 7 + published: list[tuple[str, object]] = [] + worker._record_pub = SimpleNamespace(publish=lambda topic, packet: published.append((topic, packet))) + robot_state = SimpleNamespace( + qpos=np.arange(29, dtype=np.float32), + qvel=np.arange(29, dtype=np.float32) + 10.0, + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.array([0.1, 0.2, 0.3], dtype=np.float32), + ) + reference_qpos = np.arange(36, dtype=np.float64) + + worker._publish_record_step(robot_state=robot_state, reference_qpos=reference_qpos) + + assert len(published) == 1 + packet = published[0][1] + assert isinstance(packet, RecordStepPacket) + assert packet.mode == "arms" + assert packet.mocap_active is True + assert packet.recordable is True + assert packet.observation_state.shape == (68,) + assert packet.observation_mode.shape == (1,) + assert packet.action_reference_qpos.shape == (36,) + np.testing.assert_allclose(packet.observation_mode, build_mode_observation("arms")) + np.testing.assert_allclose(packet.action_reference_qpos, reference_qpos.astype(np.float32)) + + +def test_robot_worker_enter_damping_publishes_non_recordable_packet() -> None: + worker = object.__new__(_RobotControlWorker) + worker.mode = RobotMode.MOCAP + worker._mode_seq = 9 + worker._mocap_reentry_armed = True + worker._last_commanded_motion_qpos = np.ones(36, dtype=np.float64) + worker._last_mocap_hold_reason = "stale" + worker._default_root_pos = np.zeros(3, dtype=np.float64) + worker.num_actions = 29 + worker._mocap_session = SimpleNamespace(reset=lambda: None) + worker._ref_proc = SimpleNamespace(last_reference_qpos=np.zeros(36, dtype=np.float64)) + robot_state = SimpleNamespace( + qpos=np.arange(29, dtype=np.float32), + qvel=np.arange(29, dtype=np.float32) + 10.0, + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.array([0.1, 0.2, 0.3], dtype=np.float32), + base_pos=np.array([0.0, 0.0, 0.8], dtype=np.float32), + ) + worker.robot = SimpleNamespace( + set_damping=lambda: None, + exit_debug_mode=lambda: None, + get_state=lambda: robot_state, + ) + published: list[tuple[str, object]] = [] + worker._record_pub = SimpleNamespace(publish=lambda topic, packet: published.append((topic, packet))) + + worker._enter_damping() + + assert worker.mode == RobotMode.DAMPING + assert published + packet = published[-1][1] + assert isinstance(packet, RecordStepPacket) + assert packet.mode == "damping" + assert packet.recordable is False + assert packet.mocap_active is False + np.testing.assert_allclose(packet.observation_mode, np.array([-1.0], dtype=np.float32)) + + +def test_recording_worker_start_save_discard_with_fake_adapter() -> None: + from teleopit.sim2real.mp.ipc import default_endpoints + + calls: list[str] = [] + frames: list[dict[str, np.ndarray]] = [] + + class FakeRecorder: + def start_episode(self) -> None: + calls.append("start") + + def add_frame( + self, + *, + image: np.ndarray, + state: np.ndarray, + mode: np.ndarray, + action: np.ndarray, + hand_action: np.ndarray, + task: str, + ) -> None: + calls.append(f"frame:{task}") + frames.append( + { + "image": image.copy(), + "state": state.copy(), + "mode": mode.copy(), + "action": action.copy(), + "hand_action": hand_action.copy(), + } + ) + + def save_episode(self) -> None: + calls.append("save") + + def discard_episode(self) -> None: + calls.append("discard") + + def finalize(self) -> None: + calls.append("finalize") + + def fake_factory(**_kwargs: object) -> FakeRecorder: + return FakeRecorder() + + stop_event = SimpleNamespace(is_set=lambda: False, set=lambda: None) + endpoints = default_endpoints(base_port=39850) + worker = _RecordingWorker( + { + "recording": { + "enabled": True, + "task": "walk", + "fps": 30, + "min_episode_seconds": 0.0, + "camera": {"width": 2, "height": 2, "key": IMAGE_KEY}, + } + }, + endpoints, + stop_event, # type: ignore[arg-type] + recorder_factory=fake_factory, + ) + writer = SharedFrameRingWriter(shape=(2, 2, 3), dtype=np.uint8, slots=2) + try: + worker._latest_record = RecordStepPacket( + timestamp_s=1.0, + mode="damping", + mocap_active=False, + recordable=False, + observation_state=np.ones(68, dtype=np.float32), + observation_mode=build_mode_observation("standing"), + action_reference_qpos=np.ones(36, dtype=np.float32), + seq=1, + ) + worker._start_episode() + assert calls == [] + + worker._latest_record = RecordStepPacket( + timestamp_s=2.0, + mode="standing", + mocap_active=False, + recordable=True, + observation_state=np.arange(68, dtype=np.float32), + observation_mode=build_mode_observation("standing"), + action_reference_qpos=np.arange(36, dtype=np.float32), + seq=2, + ) + worker._start_episode() + worker._save_episode() + assert calls == ["start", "discard"] + + worker._start_episode() + worker._latest_hand_command = HandCommandPacket( + timestamp_s=2.05, + driver="linkerhand_l6", + mode="gripper", + active=True, + left_pose=np.arange(6, dtype=np.float32), + right_pose=np.arange(6, 12, dtype=np.float32), + seq=1, + ) + desc = writer.write(np.full((2, 2, 3), 5, dtype=np.uint8), timestamp_s=2.1) + worker._handle_video(desc) + worker._save_episode() + + assert calls == ["start", "discard", "start", "frame:walk", "save"] + assert frames[0]["image"].shape == (2, 2, 3) + np.testing.assert_allclose(frames[0]["state"], np.arange(68, dtype=np.float32)) + np.testing.assert_allclose(frames[0]["mode"], build_mode_observation("standing")) + np.testing.assert_allclose(frames[0]["action"], np.arange(36, dtype=np.float32)) + np.testing.assert_allclose(frames[0]["hand_action"], np.arange(12, dtype=np.float32)) + + worker._latest_record = RecordStepPacket( + timestamp_s=3.0, + mode="pause", + mocap_active=False, + recordable=True, + observation_state=np.zeros(68, dtype=np.float32), + observation_mode=build_mode_observation("pause"), + action_reference_qpos=np.zeros(36, dtype=np.float32), + seq=3, + ) + worker._start_episode() + worker._discard_episode("test") + assert calls[-2:] == ["start", "discard"] + finally: + writer.close(unlink=True) + worker._record_sub.close() + worker._video_sub.close() + worker._hand_command_sub.close() + worker._command_sub.close() + worker._frame_reader.close() diff --git a/tests/test_sim2real_runtime.py b/tests/test_sim2real_runtime.py deleted file mode 100644 index 80963eb1..00000000 --- a/tests/test_sim2real_runtime.py +++ /dev/null @@ -1,589 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace - -import numpy as np -import pytest - -from teleopit.inputs.realtime_packet import ControlEvent, ControlEventType, RealtimeInputPacket - - -class DummyRobot: - def __init__(self, _cfg: object) -> None: - self._state = SimpleNamespace( - quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - qpos=np.zeros(29, dtype=np.float32), - qvel=np.zeros(29, dtype=np.float32), - ang_vel=np.zeros(3, dtype=np.float32), - ) - self.sent_positions: list[np.ndarray] = [] - - def enter_debug_mode(self) -> bool: - return True - - def lock_all_joints(self) -> None: - pass - - def get_state(self) -> SimpleNamespace: - return self._state - - def send_positions(self, target_dof_pos: np.ndarray, kp: np.ndarray | None = None, kd: np.ndarray | None = None) -> None: - self.sent_positions.append(np.asarray(target_dof_pos, dtype=np.float32)) - - def set_damping(self) -> None: - pass - - def exit_debug_mode(self) -> None: - pass - - -class DummyRemote: - def __init__(self) -> None: - self.LB = SimpleNamespace(pressed=False, on_pressed=False) - self.RB = SimpleNamespace(pressed=False, on_pressed=False) - self.start = SimpleNamespace(pressed=False, on_pressed=False) - self.A = SimpleNamespace(pressed=False, on_pressed=False) - self.B = SimpleNamespace(pressed=False, on_pressed=False) - self.Y = SimpleNamespace(pressed=False, on_pressed=False) - self.X = SimpleNamespace(pressed=False, on_pressed=False) - - -def _set_button(button: SimpleNamespace, *, pressed: bool, on_pressed: bool) -> None: - button.pressed = pressed - button.on_pressed = on_pressed - - -class DummyProvider: - def __init__(self) -> None: - self._xrt = SimpleNamespace(is_body_data_available=lambda: True) - self._frame = {"Pelvis": (np.zeros(3, dtype=np.float64), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64))} - self.fps = 30 - self._frame_seq = 0 - self._frame_timestamp = 1.0 - self._control_events: tuple[ControlEvent, ...] = () - - def is_available(self) -> bool: - return True - - def get_frame(self) -> dict[str, tuple[np.ndarray, np.ndarray]]: - return self._frame - - def get_frame_packet(self) -> tuple[dict[str, tuple[np.ndarray, np.ndarray]], float, int]: - return self._frame, self._frame_timestamp, self._frame_seq - - def get_realtime_input_packet(self) -> RealtimeInputPacket[dict[str, tuple[np.ndarray, np.ndarray]]]: - control_events = tuple(self._control_events) - self._control_events = () - return RealtimeInputPacket( - frame=self._frame, - timestamp_s=self._frame_timestamp, - seq=self._frame_seq, - control_events=control_events, - ) - - -class DummyRetargeter: - def __init__(self, qpos: np.ndarray) -> None: - self._qpos = np.asarray(qpos, dtype=np.float64) - - def retarget(self, _frame: object) -> np.ndarray: - return self._qpos.copy() - - def reset(self) -> None: - pass - - -class DummyPolicy: - def __init__(self, expected_obs_dim: int = 166) -> None: - self._expected_obs_dim = expected_obs_dim - self._multi_input = True - self.reset_calls = 0 - - def reset(self) -> None: - self.reset_calls += 1 - - def compute_action(self, _obs: np.ndarray) -> np.ndarray: - return np.zeros(29, dtype=np.float32) - - def get_target_dof_pos(self, _action: np.ndarray) -> np.ndarray: - return np.zeros(29, dtype=np.float32) - - -class DummyVelCmdObservationBuilder: - def __init__(self) -> None: - self.total_obs_size = 166 - self.reset_calls = 0 - self.build_calls: list[dict[str, np.ndarray]] = [] - - def reset(self) -> None: - self.reset_calls += 1 - - def build( - self, - _robot_state: object, - motion_qpos: np.ndarray, - motion_joint_vel: np.ndarray, - _last_action: np.ndarray, - motion_anchor_lin_vel_w: np.ndarray, - motion_anchor_ang_vel_w: np.ndarray, - ) -> np.ndarray: - self.build_calls.append( - { - "motion_qpos": np.asarray(motion_qpos, dtype=np.float32).copy(), - "motion_joint_vel": np.asarray(motion_joint_vel, dtype=np.float32).copy(), - "motion_anchor_lin_vel_w": np.asarray(motion_anchor_lin_vel_w, dtype=np.float32).copy(), - "motion_anchor_ang_vel_w": np.asarray(motion_anchor_ang_vel_w, dtype=np.float32).copy(), - } - ) - return np.zeros(self.total_obs_size, dtype=np.float32) - - -def _make_cfg(transition_duration: float = 1.0) -> dict[str, object]: - return { - "policy_hz": 50.0, - "transition_duration": transition_duration, - "real_robot": {}, - "mocap_switch": {"check_frames": 1}, - "robot": { - "default_angles": [0.0] * 29, - "num_actions": 29, - }, - "controller": {}, - "input": {"provider": "pico4"}, - } - - -def _install_controller_mocks(monkeypatch, *, policy: DummyPolicy, obs_builder: DummyVelCmdObservationBuilder, qpos: np.ndarray) -> None: - import teleopit.sim2real.controller as controller_mod - - monkeypatch.setattr(controller_mod, "UnitreeG1Robot", DummyRobot) - monkeypatch.setattr(controller_mod, "UnitreeRemote", DummyRemote) - monkeypatch.setattr(controller_mod, "VelCmdObservationBuilder", DummyVelCmdObservationBuilder) - monkeypatch.setattr(controller_mod.time, "sleep", lambda _seconds: None) - monkeypatch.setattr( - controller_mod, - "build_sim2real_mocap_components", - lambda *args, **kwargs: SimpleNamespace( - input_provider=DummyProvider(), - retargeter=DummyRetargeter(qpos), - controller=policy, - obs_builder=obs_builder, - ), - ) - - -def test_mode_transitions_reset_stateful_policy(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - ctrl = Sim2RealController(_make_cfg()) - ctrl._enter_standing() - ctrl._transition_to_mocap() - - # Both _enter_standing and _transition_to_mocap now do full episode-reset - assert policy.reset_calls == 2 - assert obs_builder.reset_calls == 2 - - -def test_reset_policy_state_clears_reference_timeline(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - ctrl = Sim2RealController(_make_cfg()) - assert ctrl._reference_timeline is not None - ctrl._reference_timeline.append(np.zeros(36, dtype=np.float64), 1.0) - ctrl._last_live_packet_seq = 7 - - ctrl._reset_policy_state() - - assert len(ctrl._reference_timeline) == 0 - assert ctrl._last_live_packet_seq == -1 - - -def test_sim2real_rejects_nonzero_reference_steps_without_buffer(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - cfg = _make_cfg() - cfg["retarget_buffer_enabled"] = False - cfg["reference_steps"] = [0, 1] - - with pytest.raises(ValueError, match="retarget_buffer_enabled=true"): - Sim2RealController(cfg) - - -def test_sim2real_rejects_reference_horizon_with_insufficient_delay(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - cfg = _make_cfg() - cfg["reference_steps"] = [0, 2] - cfg["retarget_buffer_delay_s"] = 0.01 - - with pytest.raises(ValueError, match="retarget_buffer_delay_s"): - Sim2RealController(cfg) - - -def test_state_machine_allows_mocap_reentry_after_returning_to_standing(monkeypatch) -> None: - from teleopit.sim2real.controller import RobotMode, Sim2RealController - - policy = DummyPolicy(expected_obs_dim=154) - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks( - monkeypatch, - policy=policy, - obs_builder=obs_builder, - qpos=np.zeros(36, dtype=np.float64), - ) - - ctrl = Sim2RealController(_make_cfg()) - - _set_button(ctrl.remote.start, pressed=True, on_pressed=True) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.STANDING - - _set_button(ctrl.remote.start, pressed=False, on_pressed=False) - _set_button(ctrl.remote.Y, pressed=True, on_pressed=True) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.MOCAP - - _set_button(ctrl.remote.Y, pressed=False, on_pressed=False) - _set_button(ctrl.remote.X, pressed=True, on_pressed=True) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.STANDING - - _set_button(ctrl.remote.X, pressed=False, on_pressed=False) - _set_button(ctrl.remote.Y, pressed=True, on_pressed=False) - ctrl._handle_transitions() - assert ctrl.mode == RobotMode.MOCAP - - -def test_can_switch_to_mocap_returns_false_without_blocking_when_realtime_has_no_frame(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks( - monkeypatch, - policy=policy, - obs_builder=obs_builder, - qpos=np.zeros(36, dtype=np.float64), - ) - - ctrl = Sim2RealController(_make_cfg()) - get_frame_calls = 0 - - def blocking_get_frame() -> dict[str, tuple[np.ndarray, np.ndarray]]: - nonlocal get_frame_calls - get_frame_calls += 1 - raise AssertionError("get_frame should not be called before a realtime frame is available") - - ctrl.input_provider.has_frame = lambda: False - ctrl.input_provider.get_frame = blocking_get_frame - - assert ctrl._can_switch_to_mocap() is False - assert get_frame_calls == 0 - - -def test_mocap_step_episode_reset_on_transition(monkeypatch) -> None: - """After _transition_to_mocap (episode-reset), the first mocap step should - produce zero anchor velocities because _last_reference_qpos is None.""" - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.3 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - ctrl = Sim2RealController(_make_cfg(transition_duration=2.0)) - ctrl._transition_to_mocap() - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.array([1.0, 2.0, 3.0], dtype=np.float32), - np.array([4.0, 5.0, 6.0], dtype=np.float32), - ), - ) - - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 1 - # First step after episode-reset: _last_reference_qpos was None on entry, - # so _compute_anchor_velocities returns zeros for the initial call. - # The mock overrides this, but the first call computes joint vel from - # finite diff with _last_retarget_qpos which IS set, so vel ≠ 0. - # Just verify the observation was built. - assert obs_builder.build_calls[0]["motion_qpos"] is not None - - -def test_mocap_step_velcmd_applies_fixed_initial_yaw_alignment(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - ctrl = Sim2RealController(_make_cfg(transition_duration=0.0)) - ctrl.robot._state.quat = np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl._mocap_step() - - np.testing.assert_allclose( - obs_builder.build_calls[0]["motion_qpos"][3:7], - np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32), - atol=1e-6, - ) - - -def test_mocap_step_velcmd_keeps_fixed_yaw_after_start(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3:7] = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - ctrl = Sim2RealController(_make_cfg(transition_duration=0.0)) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl.robot._state.quat = np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32) - ctrl._mocap_step() - ctrl.robot._state.quat = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32) - ctrl._mocap_step() - - np.testing.assert_allclose( - obs_builder.build_calls[1]["motion_qpos"][3:7], - np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32), - atol=1e-6, - ) - - -def test_mocap_step_waits_for_realtime_warmup_before_running_policy(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg(transition_duration=0.0) - cfg['realtime_buffer_warmup_steps'] = 2 - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - 'compute_anchor_velocities', - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - monkeypatch.setattr('teleopit.sim2real.controller.time.monotonic', lambda: 1.1) - - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 0 - assert ctrl.robot.sent_positions == [] - - ctrl.input_provider._frame_seq = 1 - ctrl.input_provider._frame_timestamp = 1.03 - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 1 - assert len(ctrl.robot.sent_positions) == 1 - - -def test_sim2real_allows_future_reference_steps_without_explicit_high_watermark(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=np.zeros(36, dtype=np.float64)) - - cfg = _make_cfg() - cfg["reference_steps"] = [0, 1, 2, 3, 4] - cfg["retarget_buffer_delay_s"] = 0.08 - cfg["retarget_buffer_window_s"] = 0.5 - cfg["realtime_buffer_low_watermark_steps"] = 0 - - ctrl = Sim2RealController(cfg) - - assert ctrl._ref_cfg.realtime_buffer_low_watermark_steps == 0 - assert ctrl._ref_cfg.realtime_buffer_high_watermark_steps is None - - -def test_mocap_step_reference_qpos_smoothing_filters_motion_change(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[3] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg(transition_duration=0.0) - cfg["retarget_buffer_enabled"] = False - cfg["reference_qpos_smoothing_alpha"] = 0.5 - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl._mocap_step() - ctrl.retargeter._qpos[0] = 1.0 - ctrl._mocap_step() - - assert len(obs_builder.build_calls) == 2 - np.testing.assert_allclose(obs_builder.build_calls[0]["motion_qpos"][0], 0.0, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[1]["motion_qpos"][0], 0.5, atol=1e-6) - - -def test_mocap_pause_freezes_reference_and_zeroes_velocities(monkeypatch) -> None: - from teleopit.sim2real.controller import Sim2RealController - from teleopit.runtime.mocap_session import MocapSessionState - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.2 - target_qpos[3] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg(transition_duration=0.0) - cfg["retarget_buffer_enabled"] = False - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - ctrl._mocap_step() - ctrl.input_provider._control_events = ( - ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source="pico4:test", - timestamp_s=1.1, - ), - ) - ctrl.retargeter._qpos[0] = 1.0 - ctrl.input_provider._frame_seq = 1 - ctrl.input_provider._frame_timestamp = 1.1 - ctrl._mocap_step() - - assert ctrl._mocap_session.state == MocapSessionState.PAUSED - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"][0], 0.2, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) - np.testing.assert_allclose( - obs_builder.build_calls[-1]["motion_anchor_lin_vel_w"], - np.zeros(3, dtype=np.float32), - ) - np.testing.assert_allclose( - obs_builder.build_calls[-1]["motion_anchor_ang_vel_w"], - np.zeros(3, dtype=np.float32), - ) - - -def test_mocap_resume_uses_episode_reset_semantics(monkeypatch) -> None: - """Resume does an episode reset and reanchors live mocap root XY.""" - from teleopit.sim2real.controller import Sim2RealController - from teleopit.runtime.mocap_session import MocapSessionState - - policy = DummyPolicy() - obs_builder = DummyVelCmdObservationBuilder() - target_qpos = np.zeros(36, dtype=np.float64) - target_qpos[0] = 0.2 - target_qpos[3] = 1.0 - _install_controller_mocks(monkeypatch, policy=policy, obs_builder=obs_builder, qpos=target_qpos) - - cfg = _make_cfg(transition_duration=0.0) - cfg["retarget_buffer_enabled"] = False - cfg["pause_resume_transition_duration"] = 1.0 - cfg["pause_resume_warmup_steps"] = 0 - ctrl = Sim2RealController(cfg) - monkeypatch.setattr( - ctrl._ref_proc, - "compute_anchor_velocities", - lambda _qpos: ( - np.zeros(3, dtype=np.float32), - np.zeros(3, dtype=np.float32), - ), - ) - - # Run one step, then pause - ctrl._mocap_step() - ctrl.input_provider._control_events = ( - ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source="pico4:test", - timestamp_s=1.1, - ), - ) - ctrl.input_provider._frame_seq = 1 - ctrl.input_provider._frame_timestamp = 1.1 - ctrl._mocap_step() - assert ctrl._mocap_session.state == MocapSessionState.PAUSED - - # Resume: should stay on the ACTIVE/PAUSED state model. - ctrl.retargeter._qpos[0] = 1.0 - ctrl.retargeter._qpos[7] = 1.0 - ctrl.input_provider._control_events = ( - ControlEvent( - event_type=ControlEventType.TOGGLE_PAUSE, - source="pico4:test", - timestamp_s=1.2, - ), - ) - ctrl.input_provider._frame_seq = 2 - ctrl.input_provider._frame_timestamp = 1.2 - ctrl._mocap_step() - - # Episode-reset resume goes straight to ACTIVE - assert ctrl._mocap_session.state == MocapSessionState.ACTIVE - # Policy was reset (last_action zeroed, history cleared) - assert np.allclose(ctrl._last_action, 0.0) - # Retarget reference jumps to the live mocap pose (joint 0), while root XY - # is reanchored to the paused reference because real-robot XY is unobserved. - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"][0], 0.2, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_qpos"][7], 1.0, atol=1e-6) - np.testing.assert_allclose(obs_builder.build_calls[-1]["motion_joint_vel"], np.zeros(29, dtype=np.float32)) diff --git a/tests/test_sim_loop.py b/tests/test_sim_loop.py index e3ca6d31..11cdebd8 100644 --- a/tests/test_sim_loop.py +++ b/tests/test_sim_loop.py @@ -56,6 +56,7 @@ def reset(self) -> None: class _DummyObsBuilder: def __init__(self) -> None: self.mimic_obs_calls: list[np.ndarray] = [] + self.motion_qpos_calls: list[np.ndarray] = [] self._base = _DummyObsBuilderBase() def reset(self) -> None: @@ -72,6 +73,7 @@ def build( ) -> np.ndarray: del state, motion_joint_vel, last_action, motion_anchor_lin_vel_w, motion_anchor_ang_vel_w self.mimic_obs_calls.append(np.asarray(motion_qpos[:1], dtype=np.float32).copy()) + self.motion_qpos_calls.append(np.asarray(motion_qpos, dtype=np.float32).copy()) return np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) @@ -110,6 +112,9 @@ def is_available(self) -> bool: class _DummyRetargeter: + def __init__(self) -> None: + self.reset_calls = 0 + def retarget(self, human_data: dict[str, tuple[np.ndarray, np.ndarray]]) -> tuple[np.ndarray, np.ndarray, np.ndarray]: pelvis_x = float(human_data["Pelvis"][0][0]) return ( @@ -119,19 +124,101 @@ def retarget(self, human_data: dict[str, tuple[np.ndarray, np.ndarray]]) -> tupl ) def reset(self) -> None: - pass + self.reset_calls += 1 + + +def _quat_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: + aw, ax, ay, az = a + bw, bx, by, bz = b + return np.array( + [ + aw * bw - ax * bx - ay * by - az * bz, + aw * bx + ax * bw + ay * bz - az * by, + aw * by - ax * bz + ay * bw + az * bx, + aw * bz + ax * by - ay * bx + az * bw, + ], + dtype=np.float32, + ) -class _DummyRecorder: - def __init__(self) -> None: - self.frames: list[dict[str, object]] = [] +def test_standing_qpos_keeps_yaw_but_drops_roll() -> None: + from teleopit.sim.loop import SimulationLoop + + bus = InProcessBus() + robot = _DummyRobot() + loop = SimulationLoop( + robot=robot, + controller=_DummyController(), + obs_builder=_DummyObsBuilder(), + bus=bus, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, + viewers=set(), + ) + + yaw = np.float32(np.pi / 2.0) + roll = np.float32(np.pi / 6.0) + yaw_quat = np.array([np.cos(yaw / 2.0), 0.0, 0.0, np.sin(yaw / 2.0)], dtype=np.float32) + roll_quat = np.array([np.cos(roll / 2.0), np.sin(roll / 2.0), 0.0, 0.0], dtype=np.float32) + tilted_quat = _quat_mul(yaw_quat, roll_quat) + state = RobotState( + qpos=np.array([0.2, -0.1], dtype=np.float32), + qvel=np.zeros(2, dtype=np.float32), + quat=tilted_quat, + ang_vel=np.zeros(3, dtype=np.float32), + timestamp=0.0, + base_pos=np.array([1.0, 2.0, 0.9], dtype=np.float32), + base_lin_vel=np.zeros(3, dtype=np.float32), + ) + + standing_qpos = loop._build_standing_qpos(state) - def add_frame(self, data: dict[str, object]) -> None: - self.frames.append(data) + np.testing.assert_allclose(standing_qpos[0:3], state.base_pos, atol=1e-6) + np.testing.assert_allclose(standing_qpos[3:7], yaw_quat, atol=1e-6) + np.testing.assert_allclose(standing_qpos[7:9], robot.default_dof_pos, atol=1e-6) + + +def test_standing_reference_is_fixed_after_initialization() -> None: + from teleopit.sim.loop import SimulationLoop + + bus = InProcessBus() + robot = _DummyRobot() + loop = SimulationLoop( + robot=robot, + controller=_DummyController(), + obs_builder=_DummyObsBuilder(), + bus=bus, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, + viewers=set(), + ) + + first_state = RobotState( + qpos=np.zeros(2, dtype=np.float32), + qvel=np.zeros(2, dtype=np.float32), + quat=np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ang_vel=np.zeros(3, dtype=np.float32), + timestamp=0.0, + base_pos=np.array([1.0, 2.0, 0.9], dtype=np.float32), + base_lin_vel=np.zeros(3, dtype=np.float32), + ) + first = loop._set_standing_reference(first_state) + + drifted_state = RobotState( + qpos=np.zeros(2, dtype=np.float32), + qvel=np.zeros(2, dtype=np.float32), + quat=np.array([0.70710677, 0.0, 0.0, 0.70710677], dtype=np.float32), + ang_vel=np.zeros(3, dtype=np.float32), + timestamp=1.0, + base_pos=np.array([5.0, 6.0, 0.9], dtype=np.float32), + base_lin_vel=np.zeros(3, dtype=np.float32), + ) + live = loop._build_standing_qpos(drifted_state) + + np.testing.assert_allclose(loop._standing_qpos, first, atol=1e-6) + assert not np.allclose(live[0:7], first[0:7]) @requires_mujoco -def test_simulation_loop_runs_and_records_without_viewers() -> None: +def test_simulation_loop_runs_without_viewers() -> None: from teleopit.sim.loop import SimulationLoop bus = InProcessBus() @@ -141,20 +228,17 @@ def test_simulation_loop_runs_and_records_without_viewers() -> None: controller=_DummyController(), obs_builder=_DummyObsBuilder(), bus=bus, - cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, "transition_duration": 0.0}, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, viewers=set(), ) - recorder = _DummyRecorder() result = loop.run( input_provider=_DummyInputProvider(), retargeter=_DummyRetargeter(), num_steps=2, - recorder=recorder, ) assert result["steps"] == 2 - assert len(recorder.frames) == 2 latest_action = bus.get_latest(TOPIC_ACTION) latest_mimic = bus.get_latest(TOPIC_MIMIC_OBS) @@ -166,9 +250,6 @@ def test_simulation_loop_runs_and_records_without_viewers() -> None: assert latest_mimic.shape == (35,) assert latest_state is not None - target = np.asarray(recorder.frames[-1]["target_dof_pos"], dtype=np.float32) - np.testing.assert_allclose(target, np.array([0.6, -0.6], dtype=np.float32)) - @requires_mujoco def test_simulation_loop_interpolates_realtime_input_with_one_frame_delay(monkeypatch) -> None: @@ -213,7 +294,7 @@ def get_frame_packet(self): controller=_DummyController(), obs_builder=obs_builder, bus=bus, - cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, "transition_duration": 0.0}, + cfg={"policy_hz": 50.0, "pd_hz": 50.0, "realtime": False}, viewers=set(), ) @@ -243,7 +324,6 @@ def test_simulation_loop_rejects_nonzero_reference_steps_without_realtime_timeli "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": True, "reference_steps": [0, 1], }, @@ -307,7 +387,6 @@ def get_frame_packet(self): 'policy_hz': 50.0, 'pd_hz': 50.0, 'realtime': False, - 'transition_duration': 0.0, 'realtime_buffer_warmup_steps': 2, }, viewers=set(), @@ -326,7 +405,7 @@ def get_frame_packet(self): @requires_mujoco -def test_simulation_loop_allows_future_reference_steps_without_explicit_high_watermark() -> None: +def test_simulation_loop_allows_future_reference_steps() -> None: from teleopit.sim.loop import SimulationLoop bus = InProcessBus() @@ -340,18 +419,13 @@ def test_simulation_loop_allows_future_reference_steps_without_explicit_high_wat "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "reference_steps": [0, 1, 2, 3, 4], "retarget_buffer_delay_s": 0.08, "retarget_buffer_window_s": 0.5, - "realtime_buffer_low_watermark_steps": 0, }, viewers=set(), ) - assert loop._ref_cfg.realtime_buffer_low_watermark_steps == 0 - assert loop._ref_cfg.realtime_buffer_high_watermark_steps is None - class _RealtimeInputProvider: fps = 50 @@ -464,11 +538,8 @@ def get_realtime_input_packet(self): "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, - "pause_resume_transition_duration": 1.0, - "pause_resume_warmup_steps": 0, }, viewers=set(), ) @@ -484,8 +555,10 @@ def get_realtime_input_packet(self): np.testing.assert_allclose(obs_builder.mimic_obs_calls[1], np.array([0.2], dtype=np.float32), atol=1e-6) np.testing.assert_allclose(obs_builder.mimic_obs_calls[2], np.array([0.0], dtype=np.float32), atol=1e-6) np.testing.assert_allclose(obs_builder.mimic_obs_calls[3], np.array([0.0], dtype=np.float32), atol=1e-6) + assert loop._step_runner.last_retarget_qpos is not None +@requires_mujoco @requires_mujoco def test_simulation_loop_realtime_keyboard_mode_transitions(monkeypatch) -> None: from teleopit.sim.loop import SimulationLoop @@ -562,7 +635,6 @@ def close(self) -> None: "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, "keyboard": {"enabled": True}, @@ -582,6 +654,138 @@ def close(self) -> None: np.testing.assert_allclose(obs_builder.mimic_obs_calls[2], np.array([0.0], dtype=np.float32), atol=1e-6) +@requires_mujoco +def test_simulation_loop_pico_arms_mode_composes_standing_body_with_live_arm(monkeypatch) -> None: + from teleopit.sim.loop import SimulationLoop + + class _RealtimeInputProvider: + fps = 1 + + def __init__(self) -> None: + self._packets = [ + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([0.3, 0.0, 0.0], dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=0.0, + seq=0, + control_events=(), + ), + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([0.9, 0.0, 0.0], dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=1.0, + seq=1, + control_events=(), + ), + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([0.9, 0.0, 0.0], dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=2.0, + seq=2, + control_events=(ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="pico4:test"),), + ), + RealtimeInputPacket( + frame={ + "Pelvis": ( + np.array([1.2, 0.0, 0.0], dtype=np.float32), + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + ) + }, + timestamp_s=3.0, + seq=3, + control_events=(ControlEvent(event_type=ControlEventType.TOGGLE_ARMS, source="pico4:test"),), + ), + ] + self._idx = 0 + + def get_realtime_input_packet(self): + packet = self._packets[min(self._idx, len(self._packets) - 1)] + self._idx += 1 + return packet + + class _Retargeter: + def retarget(self, frame: object) -> np.ndarray: + pelvis = np.asarray(frame["Pelvis"][0], dtype=np.float64) + qpos = np.zeros(36, dtype=np.float64) + qpos[0] = pelvis[0] + qpos[3] = 1.0 + qpos[7] = pelvis[0] + qpos[8] = pelvis[0] + 10.0 + return qpos + + class _KeyboardReader: + def __init__(self) -> None: + self._polls = [ + (TerminalKeyEvent("y"),), + (), + (), + (), + ] + self._idx = 0 + + @property + def active(self) -> bool: + return True + + def poll(self) -> tuple[TerminalKeyEvent, ...]: + if self._idx >= len(self._polls): + return () + events = self._polls[self._idx] + self._idx += 1 + return events + + def close(self) -> None: + pass + + monkeypatch.setattr("teleopit.sim.session.TerminalKeyboardReader", _KeyboardReader) + + robot = _DummyRobot() + obs_builder = _DummyObsBuilder() + loop = SimulationLoop( + robot=robot, + controller=_DummyController(), + obs_builder=obs_builder, + bus=InProcessBus(), + cfg={ + "policy_hz": 50.0, + "pd_hz": 50.0, + "realtime": False, + "retarget_buffer_enabled": False, + "realtime_input_delay_s": 0.0, + "keyboard": {"enabled": True}, + "arm_mocap": {"controlled_joint_indices": [1]}, + }, + viewers=set(), + ) + + result = loop.run( + input_provider=_RealtimeInputProvider(), + retargeter=_Retargeter(), + num_steps=4, + ) + + assert result["steps"] == 4 + # Step 2 is ARMS: root/non-arm stays at standing while arm index 1 follows retarget. + np.testing.assert_allclose(obs_builder.motion_qpos_calls[2][0], 0.0, atol=1e-6) + np.testing.assert_allclose(obs_builder.motion_qpos_calls[2][7], 0.5, atol=1e-6) + np.testing.assert_allclose(obs_builder.motion_qpos_calls[2][8], 10.9, atol=1e-6) + # Step 3 toggles back to full-body MOCAP; root XY is reanchored, while non-arm joints follow retarget again. + np.testing.assert_allclose(obs_builder.motion_qpos_calls[3][0], 0.0, atol=1e-6) + np.testing.assert_allclose(obs_builder.motion_qpos_calls[3][7], 1.2, atol=1e-6) + + @requires_mujoco def test_simulation_loop_realtime_keyboard_mode_drains_stale_pause_events(monkeypatch) -> None: from teleopit.sim.loop import SimulationLoop @@ -663,7 +867,6 @@ def close(self) -> None: "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, "keyboard": {"enabled": True}, @@ -731,7 +934,6 @@ def close(self) -> None: "policy_hz": 50.0, "pd_hz": 50.0, "realtime": False, - "transition_duration": 0.0, "retarget_buffer_enabled": False, "realtime_input_delay_s": 0.0, "keyboard": {"enabled": True}, diff --git a/tests/test_task_registry.py b/tests/test_task_registry.py index 43c1cac1..fdae6cc6 100644 --- a/tests/test_task_registry.py +++ b/tests/test_task_registry.py @@ -22,20 +22,91 @@ def test_general_tracking_task_is_registered() -> None: critic_terms = env_cfg.observations["critic"].terms assert DEFAULT_TASK == GENERAL_TRACKING_TASK - for terms in (actor_terms, critic_terms): - assert "projected_gravity" in terms - assert "ref_base_lin_vel_b" in terms - assert "ref_base_ang_vel_b" in terms - assert "ref_projected_gravity_b" in terms - assert "motion_anchor_pos_b" not in actor_terms - assert "base_lin_vel" not in actor_terms + assert list(actor_terms) == [ + "ref_joint_pos", + "ref_joint_vel", + "ref_anchor_ori_b", + "robot_base_ang_vel_b", + "robot_joint_pos_rel", + "robot_joint_vel", + "prev_action", + "robot_projected_gravity_b", + "ref_anchor_lin_vel_b", + "ref_anchor_ang_vel_b", + "ref_projected_gravity_b", + "ref_anchor_height", + ] + assert list(critic_terms) == [ + "ref_joint_pos", + "ref_joint_vel", + "ref_anchor_pos_b", + "ref_anchor_ori_b", + "robot_tracking_body_pos_b", + "robot_tracking_body_ori_b", + "robot_base_lin_vel_b", + "robot_base_ang_vel_b", + "robot_joint_pos_rel", + "robot_joint_vel", + "prev_action", + "robot_projected_gravity_b", + "ref_anchor_lin_vel_b", + "ref_anchor_ang_vel_b", + "ref_projected_gravity_b", + "ref_anchor_height", + ] assert "actor_history" in env_cfg.observations assert "critic_history" in env_cfg.observations - assert env_cfg.commands["motion"].sampling_mode == "uniform" + assert env_cfg.commands["motion"].sampling_mode == "rewind" assert env_cfg.commands["motion"].window_steps == (0,) + assert env_cfg.rewards["motion_global_root_lin_vel"].weight == 1.0 + assert env_cfg.rewards["motion_global_root_lin_vel"].params == { + "command_name": "motion", + "std": 1.0, + } + assert env_cfg.rewards["motion_global_root_ang_vel"].weight == 1.0 + assert env_cfg.rewards["motion_global_root_ang_vel"].params == { + "command_name": "motion", + "std": 3.0, + } + assert env_cfg.rewards["motion_joint_pos"].weight == 1.0 + assert env_cfg.rewards["motion_joint_pos"].params == { + "command_name": "motion", + "std": 0.5, + } + assert env_cfg.rewards["motion_joint_vel"].weight == 0.5 + assert env_cfg.rewards["motion_joint_vel"].params == { + "command_name": "motion", + "std": 3.0, + } + assert env_cfg.rewards["survival"].weight == 3.0 + assert env_cfg.rewards["survival"].params == {} + reward = env_cfg.rewards["self_collisions"] + assert reward.weight == -0.1 + assert reward.params == { + "sensor_name": "self_collision", + "force_threshold": 1.0, + } + assert "undesired_contacts" not in env_cfg.rewards + sensors = {sensor.name: sensor for sensor in env_cfg.scene.sensors} + assert set(sensors) == {"self_collision"} + assert sensors["self_collision"].primary.mode == "body" + assert sensors["self_collision"].primary.pattern == r".*" + assert sensors["self_collision"].primary.exclude == ( + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ) + assert sensors["self_collision"].secondary.mode == "subtree" + assert sensors["self_collision"].secondary.pattern == "pelvis" + assert sensors["self_collision"].reduce == "maxforce" + assert sensors["self_collision"].history_length == 4 + feet_acc = env_cfg.rewards["feet_acc"] + assert feet_acc.weight == -2.5e-6 + assert feet_acc.params["asset_cfg"].name == "robot" + assert feet_acc.params["asset_cfg"].joint_names == r".*ankle.*" + assert "anti_shake_ang_vel" not in env_cfg.rewards rl_cfg = load_rl_cfg(DEFAULT_TASK) assert rl_cfg.experiment_name == GENERAL_TRACKING_EXPERIMENT_NAME - assert rl_cfg.actor.hidden_dims == (1024, 512, 256, 256, 128) + assert rl_cfg.actor.hidden_dims == (2048, 1024, 512, 256, 128) assert load_runner_cls(DEFAULT_TASK) is MotionTrackingOnPolicyRunner diff --git a/tests/test_termination_config.py b/tests/test_termination_config.py new file mode 100644 index 00000000..bf51159d --- /dev/null +++ b/tests/test_termination_config.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from train_mimic.app import DEFAULT_TASK +from train_mimic.tasks.tracking import mdp + + +def test_general_tracking_termination_config_matches_baseline_policy() -> None: + import mjlab.tasks # noqa: F401 + import train_mimic.tasks # noqa: F401 + from mjlab.tasks.registry import load_env_cfg + + env_cfg = load_env_cfg(DEFAULT_TASK) + terminations = env_cfg.terminations + + assert set(terminations) == { + "time_out", + "anchor_pos", + "anchor_ori", + "ee_body_pos", + } + + anchor_pos = terminations["anchor_pos"] + assert anchor_pos.func is mdp.bad_anchor_pos_z_only + assert anchor_pos.params == { + "command_name": "motion", + "threshold": 0.25, + } + + anchor_ori = terminations["anchor_ori"] + assert anchor_ori.func is mdp.bad_anchor_ori + assert anchor_ori.params["threshold"] == 1.0 + + ee_body_pos = terminations["ee_body_pos"] + assert ee_body_pos.func is mdp.bad_motion_body_pos_z_only + assert ee_body_pos.params == { + "command_name": "motion", + "threshold": 0.25, + "body_names": ( + "left_ankle_roll_link", + "right_ankle_roll_link", + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ), + } diff --git a/tests/test_tracking_rewards.py b/tests/test_tracking_rewards.py new file mode 100644 index 00000000..377551ad --- /dev/null +++ b/tests/test_tracking_rewards.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch + +from train_mimic.tasks.tracking.mdp.rewards import ( + motion_joint_position_error_exp, + self_collision_cost, + survival, +) + + +def _env_with_force_history(force_history: torch.Tensor) -> SimpleNamespace: + sensor = SimpleNamespace( + data=SimpleNamespace(force_history=force_history, found=None) + ) + return SimpleNamespace(scene={"self_collision": sensor}) + + +def test_self_collision_cost_counts_contact_slots_not_history_frames() -> None: + force_history = torch.zeros((2, 3, 4, 3), dtype=torch.float32) + force_history[0, 0, 0, 0] = 2.0 + force_history[0, 1, 0, 0] = 3.0 + force_history[0, 2, 2, 2] = 2.0 + force_history[1, 0, 1, 0] = 0.5 + force_history[1, 0, 3, 0] = 1.0 + + penalty = self_collision_cost( + _env_with_force_history(force_history), + sensor_name="self_collision", + force_threshold=1.0, + ) + + torch.testing.assert_close(penalty, torch.tensor([3.0, 0.0])) + + +def test_motion_joint_position_error_exp_uses_mean_squared_error() -> None: + command = SimpleNamespace( + joint_pos=torch.tensor([[0.0, 1.0], [0.5, -0.5]], dtype=torch.float32), + robot_joint_pos=torch.tensor([[0.0, 3.0], [0.5, -0.5]], dtype=torch.float32), + ) + env = SimpleNamespace( + command_manager=SimpleNamespace(get_term=lambda _name: command), + ) + + reward = motion_joint_position_error_exp(env, command_name="motion", std=2.0) + + expected = torch.exp(torch.tensor([-0.5, 0.0])) + torch.testing.assert_close(reward, expected) + + +def test_survival_reward_returns_one_per_env() -> None: + reward = survival(SimpleNamespace(num_envs=3, device=torch.device("cpu"))) + + torch.testing.assert_close(reward, torch.ones(3)) diff --git a/tests/test_train_script.py b/tests/test_train_script.py index 1fdfe9cb..53681c6f 100644 --- a/tests/test_train_script.py +++ b/tests/test_train_script.py @@ -3,11 +3,17 @@ from __future__ import annotations import argparse +from functools import partial +import sys +import types from pathlib import Path +import h5py +import numpy as np import pytest from train_mimic.app import DEFAULT_TASK, validate_checkpoint_path, validate_motion_file +from train_mimic.data.dataset_lib import PRECOMPUTED_MOTION_VERSION from train_mimic.scripts import train from train_mimic.tasks.tracking.config.rl import make_general_tracking_ppo_runner_cfg @@ -27,10 +33,15 @@ def _args(**overrides: object) -> argparse.Namespace: "num_envs": 1024, "max_iterations": 10, "seed": 42, - "wandb_project": None, + "logger": "tensorboard", "experiment_name": None, - "motion_file": "data/datasets/twist2_full/train", + "motion_file": "data/datasets/seed_precomputed", + "robot_xml": None, "resume": None, + "sampling_mode": None, + "rewind_prob": None, + "rewind_min_steps": None, + "rewind_max_steps": None, "device": None, "gpu_ids": None, "master_port": 29500, @@ -43,11 +54,27 @@ def _args(**overrides: object) -> argparse.Namespace: class TestTrainLauncherHelpers: + def test_parse_args_defaults_to_tensorboard_logger(self) -> None: + args = train.parse_args([]) + assert args.logger == "tensorboard" + + def test_parse_args_accepts_logger_choice(self) -> None: + args = train.parse_args(["--logger", "swanlab"]) + assert args.logger == "swanlab" + + def test_parse_args_rejects_removed_wandb_project(self) -> None: + with pytest.raises(SystemExit): + train.parse_args(["--wandb_project", "teleopit"]) + def test_parse_args_with_gpu_ids(self) -> None: args = train.parse_args(["--gpu_ids", "0", "2", "3", "--master_port", "29600"]) assert args.gpu_ids == [0, 2, 3] assert args.master_port == 29600 + def test_parse_args_accepts_robot_xml(self) -> None: + args = train.parse_args(["--robot_xml", "assets/robots/unitree_g1/g1_29dof_dex3.xml"]) + assert args.robot_xml == "assets/robots/unitree_g1/g1_29dof_dex3.xml" + def test_should_launch_multi_gpu(self) -> None: args = _args(gpu_ids=[0, 1, 2, 3]) assert train._should_launch_multi_gpu(args, env={"WORLD_SIZE": "1"}) is True @@ -104,6 +131,100 @@ def test_resolve_device_rejects_wrong_distributed_device(self, monkeypatch: pyte with pytest.raises(ValueError, match="LOCAL_RANK=1"): train._resolve_device(_args(device="cuda:0"), _TorchStub()) + def test_resolve_worker_seed_offsets_by_global_rank(self) -> None: + assert train._resolve_worker_seed(42, env={"WORLD_SIZE": "4", "RANK": "0"}) == 42 + assert train._resolve_worker_seed(42, env={"WORLD_SIZE": "4", "RANK": "3"}) == 300051 + + def test_resolve_worker_seed_defaults_to_base_seed_without_rank(self) -> None: + assert train._resolve_worker_seed(123, env={}) == 123 + + def test_resolve_worker_seed_ignores_rank_outside_distributed_mode(self) -> None: + assert train._resolve_worker_seed(42, env={"WORLD_SIZE": "1", "RANK": "3"}) == 42 + + def test_configure_tensorboard_logger(self) -> None: + agent_cfg = types.SimpleNamespace(logger="wandb", experiment_name="exp") + env_cfg = types.SimpleNamespace() + + active = train._configure_experiment_logger( + logger_name="tensorboard", + agent_cfg=agent_cfg, + env_cfg=env_cfg, + log_dir="/tmp/run", + ) + + assert active is False + assert agent_cfg.logger == "tensorboard" + + def test_configure_wandb_logger_uses_experiment_name_as_project(self) -> None: + agent_cfg = types.SimpleNamespace(logger="tensorboard", experiment_name="exp") + env_cfg = types.SimpleNamespace() + + active = train._configure_experiment_logger( + logger_name="wandb", + agent_cfg=agent_cfg, + env_cfg=env_cfg, + log_dir="/tmp/run", + ) + + assert active is False + assert agent_cfg.logger == "wandb" + assert agent_cfg.wandb_project == "exp" + + def test_configure_swanlab_logger_syncs_tensorboard(self, monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[str, object]] = [] + fake_swanlab = types.SimpleNamespace( + init=lambda **kwargs: calls.append(("init", kwargs)), + sync_tensorboard_torch=lambda **kwargs: calls.append(("sync", kwargs)), + ) + monkeypatch.setitem(sys.modules, "swanlab", fake_swanlab) + monkeypatch.setenv("RANK", "0") + agent_cfg = types.SimpleNamespace(logger="wandb", experiment_name="exp", max_iterations=10) + env_cfg = types.SimpleNamespace( + commands={ + "motion": types.SimpleNamespace( + motion_file="data/train", + sampling_mode="uniform", + rewind_prob=0.8, + rewind_min_steps=25, + rewind_max_steps=75, + ) + }, + scene=types.SimpleNamespace(num_envs=64), + robot_xml="/tmp/g1.xml", + ) + + active = train._configure_experiment_logger( + logger_name="swanlab", + agent_cfg=agent_cfg, + env_cfg=env_cfg, + log_dir="/tmp/2026-01-01_00-00-00", + ) + + assert active is True + assert agent_cfg.logger == "tensorboard" + assert calls == [ + ( + "init", + { + "project": "exp", + "name": "2026-01-01_00-00-00", + "log_dir": "/tmp/2026-01-01_00-00-00", + "config": { + "experiment_name": "exp", + "motion_file": "data/train", + "robot_xml": "/tmp/g1.xml", + "num_envs": 64, + "max_iterations": 10, + "sampling_mode": "uniform", + "rewind_prob": 0.8, + "rewind_min_steps": 25, + "rewind_max_steps": 75, + }, + }, + ), + ("sync", {"types": ["scalar", "scalars", "image", "text"]}), + ] + def test_main_uses_launcher_branch(self, monkeypatch: pytest.MonkeyPatch) -> None: called: dict[str, object] = {} @@ -129,13 +250,53 @@ def test_tracking_runner_configs_disable_model_upload() -> None: assert make_general_tracking_ppo_runner_cfg().upload_model is False +def test_make_g1_training_robot_cfg_uses_requested_xml() -> None: + from train_mimic.tasks.tracking.config.env import make_g1_training_robot_cfg + + xml_path = Path("assets/robots/unitree_g1/g1_29dof_dex3.xml").resolve() + robot_cfg = make_g1_training_robot_cfg(xml_path) + + assert isinstance(robot_cfg.spec_fn, partial) + assert robot_cfg.spec_fn.args == (xml_path,) + spec = robot_cfg.spec_fn() + assert len(spec.actuators) == 0 + assert len(spec.keys) == 0 + + def test_validate_motion_file_accepts_shard_directories(tmp_path: Path) -> None: - (tmp_path / "shard_000.npz").write_bytes(b"placeholder") + num_frames = 3 + shard_path = tmp_path / "shard_000.h5" + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(shard_path, "w") as h5: + h5.attrs["format"] = "teleopit_precomputed_motion_hdf5" + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.create_dataset("body_names", data=np.asarray(["pelvis"], dtype=object), dtype=str_dt) + h5.create_dataset("joint_pos", data=np.zeros((num_frames, 29), dtype=np.float32), chunks=True) + h5.create_dataset("joint_vel", data=np.zeros((num_frames, 29), dtype=np.float32), chunks=True) + h5.create_dataset("body_pos_w", data=np.zeros((num_frames, 1, 3), dtype=np.float32), chunks=True) + h5.create_dataset( + "body_quat_w", + data=np.tile( + np.asarray([[[1.0, 0.0, 0.0, 0.0]]], dtype=np.float32), + (num_frames, 1, 1), + ), + chunks=True, + ) + h5.create_dataset("body_lin_vel_w", data=np.zeros((num_frames, 1, 3), dtype=np.float32), chunks=True) + h5.create_dataset("body_ang_vel_w", data=np.zeros((num_frames, 1, 3), dtype=np.float32), chunks=True) + h5.create_dataset("clip_starts", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("clip_lengths", data=np.asarray([num_frames], dtype=np.int64)) + h5.create_dataset("clip_fps", data=np.asarray([30], dtype=np.int64)) + h5.create_dataset("source_clip_ids", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=np.asarray([0], dtype=np.int64)) + h5.create_dataset("source_clip_lengths", data=np.asarray([num_frames], dtype=np.int64)) + h5.create_dataset("source_clip_fps", data=np.asarray([30], dtype=np.int64)) validate_motion_file(str(tmp_path)) def test_validate_motion_file_rejects_non_shard_paths(tmp_path: Path) -> None: - with pytest.raises(FileNotFoundError, match="Motion shard directory not found"): + with pytest.raises(FileNotFoundError, match="Motion dataset not found"): validate_motion_file(str(tmp_path)) diff --git a/third_party/linkerhand-python-sdk b/third_party/linkerhand-python-sdk new file mode 160000 index 00000000..40dbb8f8 --- /dev/null +++ b/third_party/linkerhand-python-sdk @@ -0,0 +1 @@ +Subproject commit 40dbb8f85a98d636285fc23f391fa083d0a30724 diff --git a/third_party/somehand b/third_party/somehand new file mode 160000 index 00000000..0e9adba4 --- /dev/null +++ b/third_party/somehand @@ -0,0 +1 @@ +Subproject commit 0e9adba4e193540279f8e5803a9339a49666499a diff --git a/train_mimic/__init__.py b/train_mimic/__init__.py index fc8aa849..4796fe89 100644 --- a/train_mimic/__init__.py +++ b/train_mimic/__init__.py @@ -1,6 +1,6 @@ -__version__ = "0.1.0" +__version__ = "0.4.0" import os # TRAIN_MIMIC_ROOT_DIR points to train_mimic/ directory -TRAIN_MIMIC_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) \ No newline at end of file +TRAIN_MIMIC_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/train_mimic/app.py b/train_mimic/app.py index e261086e..9d08250f 100644 --- a/train_mimic/app.py +++ b/train_mimic/app.py @@ -11,19 +11,22 @@ GENERAL_TRACKING_TASK, SUPPORTED_TASKS, ) +from train_mimic.data.dataset_lib import find_precomputed_motion_shards, validate_precomputed_motion_dataset DEFAULT_TASK = GENERAL_TRACKING_TASK def validate_motion_file(motion_file: str) -> None: - p = Path(motion_file) - if p.is_dir() and any(p.glob("*.npz")): - return - raise FileNotFoundError( - f"Motion shard directory not found: {motion_file}. Provide --motion_file " - f"pointing to a directory of shard NPZ files. " - f"Example: {DEFAULT_TRAIN_MOTION_FILE}" - ) + try: + find_precomputed_motion_shards(Path(motion_file)) + except FileNotFoundError as exc: + raise FileNotFoundError( + f"Motion dataset not found: {motion_file}. Provide --motion_file pointing " + "to a precomputed training dataset root produced by " + "train_mimic/scripts/data/precompute_dataset.py. Example: " + f"{DEFAULT_TRAIN_MOTION_FILE}" + ) from exc + validate_precomputed_motion_dataset(Path(motion_file)) def validate_checkpoint_path(checkpoint_path: str) -> None: diff --git a/train_mimic/configs/datasets/lafan1.yaml b/train_mimic/configs/datasets/lafan1.yaml new file mode 100644 index 00000000..19df2d0c --- /dev/null +++ b/train_mimic/configs/datasets/lafan1.yaml @@ -0,0 +1,15 @@ +name: lafan1 +target_fps: 30 +preprocess: + min_frames: 22 + normalize_root_xy: true + ground_align: none + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 +sources: + - name: lafan1 + type: bvh + input: data/lafan1_bvh + bvh_format: lafan1 + exclude_patterns: + - "*obstacle*" diff --git a/train_mimic/configs/datasets/lafan1_v1.yaml b/train_mimic/configs/datasets/lafan1_v1.yaml deleted file mode 100644 index 2c4ffac3..00000000 --- a/train_mimic/configs/datasets/lafan1_v1.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: lafan1_v1 -target_fps: 30 -val_percent: 5 -hash_salt: "" -preprocess: - normalize_root_xy: true - ground_align: clip_min_foot -sources: - - name: lafan1_v1 - type: bvh - input: data/lafan1_bvh - bvh_format: lafan1 diff --git a/train_mimic/configs/datasets/pico_motion.yaml b/train_mimic/configs/datasets/pico_motion.yaml new file mode 100644 index 00000000..11004041 --- /dev/null +++ b/train_mimic/configs/datasets/pico_motion.yaml @@ -0,0 +1,12 @@ +name: pico_record +target_fps: 30 +preprocess: + normalize_root_xy: true + ground_align: first_frame_foot + min_frames: 22 + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 +sources: + - name: pico_clips + type: npz + input: data/pico_motion/clips diff --git a/train_mimic/configs/datasets/seed_v1.yaml b/train_mimic/configs/datasets/seed.yaml similarity index 65% rename from train_mimic/configs/datasets/seed_v1.yaml rename to train_mimic/configs/datasets/seed.yaml index 882a39d8..e5b5a146 100644 --- a/train_mimic/configs/datasets/seed_v1.yaml +++ b/train_mimic/configs/datasets/seed.yaml @@ -1,16 +1,16 @@ -name: seed_v1 +name: seed target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: - normalize_root_xy: true - ground_align: clip_min_foot min_frames: 22 + normalize_root_xy: true + ground_align: none + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 sources: - name: seed_full type: seed_csv input: data/SEED/g1/csv metadata_csv: data/SEED/seed_metadata_v003.csv - weight: 1.0 + seed_filter_preset: groot_strict filters: is_mirror: [false] diff --git a/train_mimic/configs/datasets/seed_v1_smoke.yaml b/train_mimic/configs/datasets/seed_v1_smoke.yaml deleted file mode 100644 index 7b9bbd39..00000000 --- a/train_mimic/configs/datasets/seed_v1_smoke.yaml +++ /dev/null @@ -1,13 +0,0 @@ -name: seed_v1_smoke -target_fps: 30 -val_percent: 5 -hash_salt: "" -preprocess: - normalize_root_xy: true - ground_align: clip_min_foot - min_frames: 22 -sources: - - name: seed_smoke - type: seed_csv - input: data/SEED/g1/csv/221118 - weight: 1.0 diff --git a/train_mimic/configs/datasets/seed_v2_3h.yaml b/train_mimic/configs/datasets/seed_v2_3h.yaml deleted file mode 100644 index 2f275413..00000000 --- a/train_mimic/configs/datasets/seed_v2_3h.yaml +++ /dev/null @@ -1,16 +0,0 @@ -name: seed_v2_3h -target_fps: 30 -val_percent: 5 -hash_salt: "" -preprocess: - normalize_root_xy: true - ground_align: clip_min_foot - min_frames: 22 -sources: - - name: seed_3h_sampled - type: seed_csv - input: data/SEED/g1/csv - metadata_csv: train_mimic/data/seed/seed_metadata_v003_3h.csv - weight: 1.0 - filters: - is_mirror: [false] diff --git a/train_mimic/configs/datasets/twist2_full.yaml b/train_mimic/configs/datasets/twist2.yaml similarity index 79% rename from train_mimic/configs/datasets/twist2_full.yaml rename to train_mimic/configs/datasets/twist2.yaml index 48c14d78..19043b8d 100644 --- a/train_mimic/configs/datasets/twist2_full.yaml +++ b/train_mimic/configs/datasets/twist2.yaml @@ -1,10 +1,11 @@ -name: twist2_full +name: twist2 target_fps: 30 -val_percent: 5 -hash_salt: "" preprocess: + min_frames: 22 normalize_root_xy: true - ground_align: clip_min_foot + ground_align: none + max_all_off_ground_s: 0.8 + off_ground_height: 0.12 sources: - name: OMOMO_g1_GMR type: pkl diff --git a/train_mimic/data/dataset_builder.py b/train_mimic/data/dataset_builder.py index 7ef9a52f..20369814 100644 --- a/train_mimic/data/dataset_builder.py +++ b/train_mimic/data/dataset_builder.py @@ -1,10 +1,13 @@ from __future__ import annotations import csv +import fnmatch +import json import os import shutil import multiprocessing from concurrent.futures import ProcessPoolExecutor, as_completed +from collections import Counter from dataclasses import asdict, dataclass, field from pathlib import Path from tempfile import TemporaryDirectory @@ -16,12 +19,14 @@ import numpy as np from train_mimic.data.dataset_lib import ( - hash_split, + DEFAULT_HDF5_MAX_WINDOW_FRAMES, + DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, + FULL_CLIP_ARRAY_KEYS, inspect_clip_dict, inspect_npz, merge_npz_files, resample_along_time, - utc_now_iso, + write_hdf5_motion_shard, write_json, ) from train_mimic.data.preprocess import ( @@ -38,7 +43,7 @@ from teleopit.retargeting.export_pkl import convert_bvh_to_retarget_pkl, mocap_xml_path PROJECT_ROOT = Path(__file__).resolve().parents[2] -DEFAULT_SPEC_PATH = PROJECT_ROOT / "train_mimic" / "configs" / "datasets" / "twist2_full.yaml" +DEFAULT_SPEC_PATH = PROJECT_ROOT / "train_mimic" / "configs" / "datasets" / "twist2.yaml" DEFAULT_DATASETS_ROOT = PROJECT_ROOT / "data" / "datasets" DEFAULT_FK_SAMPLE_CLIPS = 2 DEFAULT_FK_SAMPLE_FRAMES = 16 @@ -53,16 +58,9 @@ "seed_csv": ".csv", } _DATASET_ROOT_MARKERS = { - "train", - "val", - "train.npz", - "val.npz", - "manifest_resolved.csv", - "build_info.json", + "clips", } -_SPLIT_NAMES = ("train", "val") - _PROCESS_FK_EXTRACTOR: MotionFkExtractor | None = None _PROCESS_BVH_MODEL_CACHE: dict[str, mujoco.MjModel] = {} @@ -80,20 +78,19 @@ class DatasetSourceSpec: name: str type: str input: str - weight: float = 1.0 bvh_format: str | None = None robot_name: str = "unitree_g1" max_frames: int = 0 metadata_csv: str | None = None filters: dict[str, list] | None = None + seed_filter_preset: str | None = None + exclude_patterns: tuple[str, ...] = () @dataclass(frozen=True) class DatasetSpec: name: str target_fps: int - val_percent: int - hash_salt: str sources: list[DatasetSourceSpec] preprocess: DatasetPreprocessSpec = field(default_factory=DatasetPreprocessSpec) @@ -105,10 +102,8 @@ class DatasetClipRow: file_rel: str num_frames: int fps: int - resolved_split: str resolved_npz_path: str - weight: float = 1.0 - clip_index: int = -1 # index into shard NPZ clip_starts/clip_lengths; -1 = standalone clip + clip_index: int = -1 # index into source clip metadata; -1 = standalone clip @dataclass(frozen=True) @@ -148,6 +143,48 @@ class ConversionTask: preprocess: DatasetPreprocessSpec = field(default_factory=DatasetPreprocessSpec) +@dataclass(frozen=True) +class FilteredClipResult: + input_path: str + reason: str + + +_FILTERED_MARKER_SUFFIX = ".filtered.json" + + +@dataclass(frozen=True) +class SeedFilterRule: + columns: tuple[str, ...] + patterns: tuple[str, ...] + label: str + + +_SEED_FILTER_PRESETS: dict[str, tuple[SeedFilterRule, ...]] = { + "groot_strict": ( + SeedFilterRule( + columns=("content_body_position",), + patterns=("sitting", "on all fours", "handstand"), + label="content_body_position", + ), + SeedFilterRule( + columns=("content_type_of_movement",), + patterns=("crawling", "on hands and knees", "rolling", "flipping", "climbing"), + label="content_type_of_movement", + ), + SeedFilterRule( + columns=("content_props",), + patterns=("chair", "crutch", "crutches", "ladder", "box", "table", "bike", "scooter", "bed"), + label="content_props", + ), + SeedFilterRule( + columns=("filename", "move_name"), + patterns=("safety_roll", "cartwheel", "box_jump", "monkey_jump", "walking_on_edge"), + label="filename_or_move_name", + ), + ) +} + + def _display_path(path: Path) -> str: try: return path.relative_to(PROJECT_ROOT).as_posix() @@ -165,6 +202,34 @@ def _validate_source_type(raw_type: object, spec_path: Path, source_name: str) - return source_type +def _validate_seed_filter_preset( + source_type: str, + seed_filter_preset: str | None, + metadata_csv: str | None, + *, + spec_path: Path, + source_name: str, +) -> str | None: + if seed_filter_preset is None: + return None + if source_type != "seed_csv": + raise ValueError( + f"source {source_name!r} uses seed_filter_preset={seed_filter_preset!r}, " + "but seed_filter_preset is supported only for seed_csv sources" + ) + if metadata_csv is None: + raise ValueError( + f"source {source_name!r} uses seed_filter_preset={seed_filter_preset!r} " + f"without metadata_csv: {spec_path}" + ) + if seed_filter_preset not in _SEED_FILTER_PRESETS: + raise ValueError( + f"source {source_name!r} has unknown seed_filter_preset {seed_filter_preset!r} " + f"in {spec_path}. Expected one of {sorted(_SEED_FILTER_PRESETS)}." + ) + return seed_filter_preset + + def _load_preprocess_spec(raw: object, spec_path: Path) -> DatasetPreprocessSpec: if raw is None: return DatasetPreprocessSpec() @@ -193,10 +258,31 @@ def _load_preprocess_spec(raw: object, spec_path: Path) -> DatasetPreprocessSpec else float(raw["max_all_off_ground_s"]) ), off_ground_height=float(raw.get("off_ground_height", 0.2)), + max_feet_off_ground_s=( + None + if raw.get("max_feet_off_ground_s") in (None, "", "null") + else float(raw["max_feet_off_ground_s"]) + ), + foot_off_ground_height=float(raw.get("foot_off_ground_height", 0.08)), ) return validate_preprocess_spec(spec) +def _load_exclude_patterns(raw: object, spec_path: Path, source_name: str) -> tuple[str, ...]: + if raw is None: + return () + if not isinstance(raw, list): + raise ValueError( + f"source {source_name!r} exclude_patterns must be a list in {spec_path}" + ) + patterns = tuple(str(item).strip() for item in raw if str(item).strip()) + if not patterns: + raise ValueError( + f"source {source_name!r} exclude_patterns must contain at least one non-empty pattern" + ) + return patterns + + def load_dataset_spec(path: str | Path) -> DatasetSpec: spec_path = Path(path).expanduser().resolve() if not spec_path.is_file(): @@ -213,11 +299,6 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: if target_fps <= 0: raise ValueError(f"dataset spec target_fps must be > 0: {spec_path}") - val_percent = int(payload.get("val_percent", 0)) - if val_percent <= 0 or val_percent >= 100: - raise ValueError(f"dataset spec val_percent must be in [1, 99]: {spec_path}") - - hash_salt = str(payload.get("hash_salt", "")) preprocess = _load_preprocess_spec(payload.get("preprocess"), spec_path) raw_sources = payload.get("sources") if not isinstance(raw_sources, list) or not raw_sources: @@ -241,10 +322,6 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: if not source_input: raise ValueError(f"source {source_name!r} missing non-empty input: {spec_path}") - source_weight = float(raw.get("weight", 1.0)) - if source_weight <= 0: - raise ValueError(f"source {source_name!r} has non-positive weight: {source_weight}") - bvh_format = raw.get("bvh_format") if source_type == "bvh": bvh_format = str(bvh_format or "").strip() @@ -266,26 +343,40 @@ def load_dataset_spec(path: str | Path) -> DatasetSpec: filters = raw.get("filters") if filters is not None and not isinstance(filters, dict): raise ValueError(f"source {source_name!r} filters must be a mapping: {spec_path}") + seed_filter_preset = raw.get("seed_filter_preset") + if seed_filter_preset is not None: + seed_filter_preset = str(seed_filter_preset).strip() or None + seed_filter_preset = _validate_seed_filter_preset( + source_type, + seed_filter_preset, + metadata_csv, + spec_path=spec_path, + source_name=source_name, + ) + exclude_patterns = _load_exclude_patterns( + raw.get("exclude_patterns"), + spec_path, + source_name, + ) sources.append( DatasetSourceSpec( name=source_name, type=source_type, input=source_input, - weight=source_weight, bvh_format=bvh_format, robot_name=robot_name, max_frames=max_frames, metadata_csv=metadata_csv, filters=filters, + seed_filter_preset=seed_filter_preset, + exclude_patterns=exclude_patterns, ) ) return DatasetSpec( name=name, target_fps=target_fps, - val_percent=val_percent, - hash_salt=hash_salt, sources=sources, preprocess=preprocess, ) @@ -298,14 +389,64 @@ def resolve_dataset_paths(spec: DatasetSpec, *, output_root: str | Path | None = return DatasetPaths(dataset_dir=dataset_dir, clips_root=clips_root) -def split_output_dir(dataset_dir: Path, split: str) -> Path: - if split not in _SPLIT_NAMES: - raise ValueError(f"invalid split {split!r}, expected one of {_SPLIT_NAMES}") - return dataset_dir / split +def shard_output_path(dataset_dir: Path, shard_index: int) -> Path: + return dataset_dir / f"shard_{shard_index:03d}.h5" -def shard_output_path(split_dir: Path, shard_index: int) -> Path: - return split_dir / f"shard_{shard_index:03d}.npz" +def _clear_existing_motion_shards(dataset_dir: Path) -> None: + """Remove stale top-level HDF5 outputs before writing a fresh dataset build.""" + if not dataset_dir.exists(): + return + for pattern in ("shard_*.h5", ".*_chunk_*.h5"): + for path in dataset_dir.glob(pattern): + if path.is_file(): + path.unlink() + + +def _clear_intermediate_clips(clips_root: Path) -> None: + if not clips_root.exists() and not clips_root.is_symlink(): + return + if clips_root.is_dir() and not clips_root.is_symlink(): + shutil.rmtree(clips_root) + else: + clips_root.unlink() + + +def _path_contains(path: Path, candidate: Path) -> bool: + try: + candidate.relative_to(path) + except ValueError: + return False + return True + + +def _source_input_candidate_path(source: DatasetSourceSpec) -> Path: + candidate = Path(source.input).expanduser() + if not candidate.is_absolute(): + candidate = PROJECT_ROOT / candidate + return candidate.resolve(strict=False) + + +def _ensure_source_inputs_do_not_overlap_intermediate_clips( + spec: DatasetSpec, + clips_root: Path, +) -> None: + clips_path = clips_root.resolve(strict=False) + conflicts: list[tuple[str, Path]] = [] + for source in spec.sources: + input_path = _source_input_candidate_path(source) + if _path_contains(clips_path, input_path) or _path_contains(input_path, clips_path): + conflicts.append((source.name, input_path)) + + if not conflicts: + return + + details = ", ".join(f"{name}={path}" for name, path in conflicts) + raise ValueError( + f"source input overlaps the temporary clips directory {clips_path}: {details}. " + "The dataset builder deletes that directory during full builds, so move source clips " + "outside data/datasets//clips or choose a different output root." + ) def resolve_source_input_path(source: DatasetSourceSpec) -> Path: @@ -346,10 +487,29 @@ def _filter_seed_csv_by_metadata( source: DatasetSourceSpec, all_files: list[SourceInputFile], input_dir: Path, -) -> list[SourceInputFile]: + *, + quiet: bool = False, + report: dict[str, Any] | None = None, +) -> tuple[list[SourceInputFile], dict[str, Any]]: """Filter seed_csv files using metadata_csv + filters from the source spec.""" - if source.metadata_csv is None or source.filters is None: - return all_files + if report is None: + report = { + "source": source.name, + "type": source.type, + "metadata_csv": source.metadata_csv, + "seed_filter_preset": source.seed_filter_preset, + "exclude_patterns": list(source.exclude_patterns), + "scanned_files": len(all_files), + "metadata_rows_matched": len(all_files), + "preset_rejected_rows": 0, + "path_rejected_files": 0, + "kept_files": len(all_files), + "filtered_files": 0, + "preset_reject_reasons": {}, + "path_reject_reasons": {}, + } + if source.metadata_csv is None or (source.filters is None and source.seed_filter_preset is None): + return all_files, report meta_path = Path(source.metadata_csv).expanduser() if not meta_path.is_absolute(): @@ -360,24 +520,53 @@ def _filter_seed_csv_by_metadata( with meta_path.open("r", encoding="utf-8") as f: reader = csv.DictReader(f) fieldnames = reader.fieldnames or [] - for col in source.filters: - if col not in fieldnames: - raise ValueError( - f"filter column {col!r} not found in metadata CSV for {source.name}. " - f"Available: {sorted(fieldnames)}" - ) + if source.filters is not None: + for col in source.filters: + if col not in fieldnames: + raise ValueError( + f"filter column {col!r} not found in metadata CSV for {source.name}. " + f"Available: {sorted(fieldnames)}" + ) if "move_g1_path" not in fieldnames: raise ValueError(f"metadata CSV missing move_g1_path column for {source.name}") + if source.seed_filter_preset is not None: + for rule in _SEED_FILTER_PRESETS[source.seed_filter_preset]: + for col in rule.columns: + if col not in fieldnames: + raise ValueError( + f"seed_filter_preset column {col!r} not found in metadata CSV for " + f"{source.name}. Available: {sorted(fieldnames)}" + ) # Normalize filter values to strings for comparison str_filters: dict[str, set[str]] = {} - for col, allowed_values in source.filters.items(): + for col, allowed_values in (source.filters or {}).items(): str_filters[col] = {str(v) for v in allowed_values} rows = [] for row in reader: if all(row.get(col, "") in vals for col, vals in str_filters.items()): rows.append(row) + report["metadata_csv"] = str(meta_path) + report["metadata_rows_matched"] = len(rows) + + if source.seed_filter_preset is not None: + reject_counts: Counter[str] = Counter() + kept_rows = [] + for row in rows: + reasons: list[str] = [] + for rule in _SEED_FILTER_PRESETS[source.seed_filter_preset]: + for pattern in rule.patterns: + if any(pattern in row.get(col, "").lower() for col in rule.columns): + reasons.append(f"{rule.label}:{pattern}") + break + if reasons: + reject_counts.update(reasons) + continue + kept_rows.append(row) + report["preset_rejected_rows"] = len(rows) - len(kept_rows) + report["preset_reject_reasons"] = dict(sorted(reject_counts.items())) + rows = kept_rows # Build set of allowed relative paths (without .csv suffix) allowed_rels: set[str] = set() @@ -395,24 +584,101 @@ def _filter_seed_csv_by_metadata( allowed_rels.add(Path(g1_path).stem) filtered = [f for f in all_files if f.rel_no_suffix.as_posix() in allowed_rels] - print( - f"[FILTER] source={source.name}: {len(filtered)}/{len(all_files)} files " - f"after metadata filtering" - ) - return filtered + report["kept_files"] = len(filtered) + report["filtered_files"] = int(report.get("scanned_files", len(all_files))) - len(filtered) + if not quiet: + print( + f"[FILTER] source={source.name}: {len(filtered)}/{len(all_files)} files " + f"after metadata filtering" + ) + if source.seed_filter_preset is not None and report["preset_rejected_rows"] > 0: + print( + f"[FILTER] source={source.name}: preset={source.seed_filter_preset} " + f"rejected={report['preset_rejected_rows']} " + f"reasons={report['preset_reject_reasons']}" + ) + return filtered, report -def _collect_source_files(source: DatasetSourceSpec) -> tuple[list[SourceInputFile], Path]: +def _collect_source_files_with_report( + source: DatasetSourceSpec, + *, + quiet: bool = False, +) -> tuple[list[SourceInputFile], Path, dict[str, Any]]: input_path = resolve_source_input_path(source) _ensure_not_dataset_root_npz_input(source, input_path) suffix = _SOURCE_SUFFIXES[source.type] + def _base_report(items_count: int) -> dict[str, Any]: + return { + "source": source.name, + "type": source.type, + "metadata_csv": source.metadata_csv, + "seed_filter_preset": source.seed_filter_preset, + "exclude_patterns": list(source.exclude_patterns), + "scanned_files": items_count, + "metadata_rows_matched": items_count, + "preset_rejected_rows": 0, + "path_rejected_files": 0, + "kept_files": items_count, + "filtered_files": 0, + "preset_reject_reasons": {}, + "path_reject_reasons": {}, + } + + def _matches_exclude(item: SourceInputFile) -> str | None: + rel_no_suffix = item.rel_no_suffix.as_posix() + candidates = ( + rel_no_suffix, + f"{rel_no_suffix}{suffix}", + item.path.name, + item.path.stem, + ) + for pattern in source.exclude_patterns: + pat = pattern.lower() + for candidate in candidates: + if fnmatch.fnmatchcase(candidate.lower(), pat): + return pattern + return None + + def _apply_path_excludes( + items: list[SourceInputFile], + report: dict[str, Any], + ) -> list[SourceInputFile]: + if not source.exclude_patterns: + return items + reject_counts: Counter[str] = Counter() + kept: list[SourceInputFile] = [] + for item in items: + reason = _matches_exclude(item) + if reason is None: + kept.append(item) + else: + reject_counts[reason] += 1 + report["path_rejected_files"] = len(items) - len(kept) + report["path_reject_reasons"] = dict(sorted(reject_counts.items())) + report["kept_files"] = len(kept) + report["filtered_files"] = len(items) - len(kept) + if not quiet and report["path_rejected_files"] > 0: + print( + f"[FILTER] source={source.name}: path_excludes rejected=" + f"{report['path_rejected_files']} reasons={report['path_reject_reasons']}" + ) + if not kept: + raise ValueError( + f"no files remain after path exclude filtering for source {source.name}: {input_path}" + ) + return kept + if input_path.is_file(): if input_path.suffix.lower() != suffix: raise ValueError( f"source {source.name} expected {suffix} input, got file {input_path.name}" ) - return [SourceInputFile(path=input_path, rel_no_suffix=Path(input_path.stem))], input_path.parent + items = [SourceInputFile(path=input_path, rel_no_suffix=Path(input_path.stem))] + report = _base_report(len(items)) + items = _apply_path_excludes(items, report) + return items, input_path.parent, report if not input_path.is_dir(): raise FileNotFoundError(f"source input is neither file nor directory: {input_path}") @@ -430,14 +696,28 @@ def _collect_source_files(source: DatasetSourceSpec) -> tuple[list[SourceInputFi for path in files ] + report = _base_report(len(items)) + items = _apply_path_excludes(items, report) + # Apply metadata filtering for seed_csv sources if source.type == "seed_csv" and source.metadata_csv is not None: - items = _filter_seed_csv_by_metadata(source, items, input_path) + items, report = _filter_seed_csv_by_metadata( + source, + items, + input_path, + quiet=quiet, + report=report, + ) if not items: raise ValueError( f"no files remain after metadata filtering for source {source.name}: {input_path}" ) + return items, input_path, report + + +def _collect_source_files(source: DatasetSourceSpec) -> tuple[list[SourceInputFile], Path]: + items, input_path, _report = _collect_source_files_with_report(source, quiet=False) return items, input_path @@ -480,12 +760,102 @@ def build_source_conversion_tasks( return tasks -def _source_has_cached_npz(out_dir: Path) -> bool: - return out_dir.is_dir() and any(out_dir.rglob("*.npz")) +def _filtered_marker_path(output_path: Path) -> Path: + return output_path.with_name(f"{output_path.name}{_FILTERED_MARKER_SUFFIX}") + + +def _conversion_task_signature(task: ConversionTask) -> dict[str, Any]: + input_path = Path(task.input_path) + stat = input_path.stat() + signature = { + "source_name": task.source_name, + "source_type": task.source_type, + "input_path": str(input_path), + "input_size": int(stat.st_size), + "input_mtime_ns": int(stat.st_mtime_ns), + "bvh_format": task.bvh_format, + "robot_name": task.robot_name, + "max_frames": int(task.max_frames), + "mocap_xml": task.mocap_xml, + "preprocess": task.preprocess.to_dict(), + } + return json.loads(json.dumps(signature, sort_keys=True)) + + +def _filtered_marker_matches(task: ConversionTask) -> bool: + marker_path = _filtered_marker_path(Path(task.output_path)) + if not marker_path.is_file(): + return False + try: + payload = json.loads(marker_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + marker_path.unlink(missing_ok=True) + return False + if payload.get("signature") == _conversion_task_signature(task): + return True + marker_path.unlink(missing_ok=True) + return False + + +def _write_filtered_marker(task: ConversionTask, reason: str) -> None: + marker_path = _filtered_marker_path(Path(task.output_path)) + write_json( + marker_path, + { + "filtered": True, + "reason": reason, + "signature": _conversion_task_signature(task), + }, + ) + + +def _clear_filtered_marker(output_path: Path) -> None: + _filtered_marker_path(output_path).unlink(missing_ok=True) + + +def _prune_unexpected_source_outputs(out_dir: Path, tasks: list[ConversionTask]) -> None: + if not out_dir.is_dir(): + return + expected = {Path(task.output_path).resolve() for task in tasks} + for npz_path in sorted(out_dir.rglob("*.npz")): + if npz_path.resolve() not in expected: + npz_path.unlink() + _clear_filtered_marker(npz_path) + for marker_path in sorted(out_dir.rglob(f"*.npz{_FILTERED_MARKER_SUFFIX}")): + output_path = Path(str(marker_path)[: -len(_FILTERED_MARKER_SUFFIX)]) + if output_path.resolve() not in expected: + marker_path.unlink(missing_ok=True) + + +def _current_source_npz_files(source: DatasetSourceSpec, source_dir: Path) -> list[Path]: + items, _scan_root, _report = _collect_source_files_with_report(source, quiet=True) + allowed_rels = {item.rel_no_suffix.as_posix() for item in items} + return [ + npz_path + for npz_path in sorted(source_dir.rglob("*.npz")) + if npz_path.relative_to(source_dir).with_suffix("").as_posix() in allowed_rels + ] def _pending_tasks(tasks: list[ConversionTask]) -> list[ConversionTask]: - return [task for task in tasks if not Path(task.output_path).is_file()] + pending: list[ConversionTask] = [] + for task in tasks: + output_path = Path(task.output_path) + if output_path.is_file(): + _clear_filtered_marker(output_path) + continue + if _filtered_marker_matches(task): + continue + pending.append(task) + return pending + + +def _build_source_filter_reports(spec: DatasetSpec) -> list[dict[str, Any]]: + reports: list[dict[str, Any]] = [] + for source in spec.sources: + _, _, report = _collect_source_files_with_report(source, quiet=True) + reports.append(report) + return reports def _get_fk_extractor() -> MotionFkExtractor: @@ -530,7 +900,7 @@ def _maybe_preprocess_npz_file( np.savez(npz_path, **processed) -def _convert_task(task: ConversionTask) -> str: +def _convert_task(task: ConversionTask) -> str | FilteredClipResult: input_path = Path(task.input_path) output_path = Path(task.output_path) output_path.parent.mkdir(parents=True, exist_ok=True) @@ -539,36 +909,42 @@ def _convert_task(task: ConversionTask) -> str: if task.source_type == "npz": payload = np.load(input_path, allow_pickle=True) clip_dict = {key: payload[key] for key in payload.files} + clip_label = f"{task.source_name}:{input_path.name}" processed = _maybe_preprocess_clip_dict( clip_dict, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) np.savez(output_path, **processed) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) extractor = _get_fk_extractor() if task.source_type == "pkl": convert_pkl_to_npz(str(input_path), str(output_path), extractor=extractor) + clip_label = f"{task.source_name}:{input_path.name}" _maybe_preprocess_npz_file( output_path, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) if task.source_type == "seed_csv": arrays = convert_seed_csv_to_arrays(str(input_path), extractor=extractor) + clip_label = f"{task.source_name}:{input_path.name}" arrays = _maybe_preprocess_clip_dict( arrays, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) output_path.parent.mkdir(parents=True, exist_ok=True) np.savez(str(output_path), **arrays) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) if task.source_type == "bvh": @@ -586,15 +962,26 @@ def _convert_task(task: ConversionTask) -> str: model, ) convert_pkl_to_npz(str(tmp_pkl), str(output_path), extractor=extractor) + clip_label = f"{task.source_name}:{input_path.name}" _maybe_preprocess_npz_file( output_path, preprocess=task.preprocess, - clip_label=f"{task.source_name}:{input_path.name}", + clip_label=clip_label, ) inspect_npz(output_path) + _clear_filtered_marker(output_path) return str(output_path) raise ValueError(f"unsupported source type: {task.source_type}") + except ValueError as exc: + clip_label = f"{task.source_name}:{input_path.name}" + if str(exc).startswith(f"{clip_label}:"): + output_path.unlink(missing_ok=True) + _write_filtered_marker(task, str(exc)) + return FilteredClipResult(input_path=str(input_path), reason=str(exc)) + raise RuntimeError( + f"failed converting source={task.source_name} input={input_path}: {exc}" + ) from exc except Exception as exc: raise RuntimeError( f"failed converting source={task.source_name} input={input_path}: {exc}" @@ -604,7 +991,10 @@ def _convert_task(task: ConversionTask) -> str: def _run_conversion_tasks_serial(tasks: list[ConversionTask]) -> None: total = len(tasks) for idx, task in enumerate(tasks, start=1): - _convert_task(task) + result = _convert_task(task) + if isinstance(result, FilteredClipResult): + print(f"[FILTER] {result.reason}") + continue print(f"[CONVERT] {idx}/{total} source={task.source_name} -> {_display_path(Path(task.output_path))}") @@ -628,12 +1018,15 @@ def run_conversion_tasks(tasks: list[ConversionTask], *, jobs: int = DEFAULT_JOB try: for future in as_completed(future_map): task = future_map[future] - future.result() + result = future.result() completed += 1 - print( - f"[CONVERT] {completed}/{total} " - f"source={task.source_name} -> {_display_path(Path(task.output_path))}" - ) + if isinstance(result, FilteredClipResult): + print(f"[FILTER] {result.reason}") + else: + print( + f"[CONVERT] {completed}/{total} " + f"source={task.source_name} -> {_display_path(Path(task.output_path))}" + ) except Exception: for future in future_map: future.cancel() @@ -654,14 +1047,15 @@ def convert_source_to_npz_clips( if force and output_dir.exists(): shutil.rmtree(output_dir) tasks = build_source_conversion_tasks(source, output_dir, preprocess=preprocess) + _prune_unexpected_source_outputs(output_dir, tasks) pending = _pending_tasks(tasks) - if tasks and not pending and _source_has_cached_npz(output_dir): + if tasks and not pending: print(f"[CACHE] reusing source={source.name} clips: {_display_path(output_dir)}") else: output_dir.mkdir(parents=True, exist_ok=True) - run_conversion_tasks(pending or tasks, jobs=jobs) + run_conversion_tasks(pending, jobs=jobs) - npz_files = sorted(output_dir.rglob("*.npz")) + npz_files = _current_source_npz_files(source, output_dir) if not npz_files: raise ValueError(f"no converted npz clips found for source {source.name}: {output_dir}") return { @@ -681,6 +1075,7 @@ def convert_sources_to_npz( ) -> dict[str, Path]: if jobs <= 0: raise ValueError(f"jobs must be > 0, got {jobs}") + _ensure_source_inputs_do_not_overlap_intermediate_clips(spec, paths.clips_root) if force and paths.dataset_dir.exists(): shutil.rmtree(paths.dataset_dir) paths.clips_root.mkdir(parents=True, exist_ok=True) @@ -690,12 +1085,13 @@ def convert_sources_to_npz( for source in spec.sources: out_dir = source_out_dirs[source.name] tasks = build_source_conversion_tasks(source, out_dir, preprocess=spec.preprocess) + _prune_unexpected_source_outputs(out_dir, tasks) pending = _pending_tasks(tasks) - if tasks and not pending and _source_has_cached_npz(out_dir): + if tasks and not pending: print(f"[CACHE] reusing source={source.name} clips: {_display_path(out_dir)}") continue out_dir.mkdir(parents=True, exist_ok=True) - pending_tasks.extend(pending or tasks) + pending_tasks.extend(pending) run_conversion_tasks(pending_tasks, jobs=jobs) return source_out_dirs @@ -707,7 +1103,7 @@ def collect_clip_rows(spec: DatasetSpec, *, paths: DatasetPaths) -> list[Dataset source_dir = paths.clips_root / source.name if not source_dir.is_dir(): raise FileNotFoundError(f"expected converted npz dir for {source.name}: {source_dir}") - npz_files = sorted(source_dir.rglob("*.npz")) + npz_files = _current_source_npz_files(source, source_dir) if not npz_files: raise ValueError(f"no converted npz clips found for source {source.name}: {source_dir}") for npz_path in npz_files: @@ -720,59 +1116,12 @@ def collect_clip_rows(spec: DatasetSpec, *, paths: DatasetPaths) -> list[Dataset file_rel=_display_path(npz_path), num_frames=meta.num_frames, fps=meta.fps, - resolved_split="", resolved_npz_path=str(npz_path), - weight=source.weight, ) ) - return assign_splits(rows, spec.val_percent, spec.hash_salt) - - -def assign_splits(rows: list[DatasetClipRow], val_percent: int, hash_salt: str) -> list[DatasetClipRow]: if not rows: - raise ValueError("no clip rows to split") - - resolved = [ - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=row.num_frames, - fps=row.fps, - resolved_split=hash_split(row.clip_id, val_percent, hash_salt), - resolved_npz_path=row.resolved_npz_path, - weight=row.weight, - ) - for row in rows - ] - train_count = sum(1 for row in resolved if row.resolved_split == "train") - val_count = sum(1 for row in resolved if row.resolved_split == "val") - if train_count > 0 and val_count > 0: - return resolved - if len(resolved) < 2: - raise ValueError("dataset must contain at least 2 clips to create both train and val splits") - - ordered = sorted(resolved, key=lambda row: row.clip_id) - adjusted: list[DatasetClipRow] = [] - for idx, row in enumerate(ordered): - split = row.resolved_split - if val_count == 0 and idx == 0: - split = "val" - elif train_count == 0 and idx == 0: - split = "train" - adjusted.append( - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=row.num_frames, - fps=row.fps, - resolved_split=split, - resolved_npz_path=row.resolved_npz_path, - weight=row.weight, - ) - ) - return adjusted + raise ValueError("no clip rows collected") + return rows def run_sample_fk_checks( @@ -807,64 +1156,17 @@ def run_sample_fk_checks( return summaries -def write_manifest_resolved(rows: list[DatasetClipRow], dataset_dir: Path) -> Path: - out_path = dataset_dir / "manifest_resolved.csv" - dataset_dir.mkdir(parents=True, exist_ok=True) - with out_path.open("w", encoding="utf-8", newline="") as handle: - writer = csv.writer(handle) - writer.writerow( - [ - "clip_id", - "source", - "file_rel", - "num_frames", - "fps", - "resolved_split", - "resolved_npz_path", - "weight", - "clip_index", - ] - ) - for row in sorted(rows, key=lambda item: item.clip_id): - writer.writerow( - [ - row.clip_id, - row.source, - row.file_rel, - row.num_frames, - row.fps, - row.resolved_split, - row.resolved_npz_path, - row.weight, - row.clip_index, - ] - ) - return out_path - - -def _rows_for_split(rows: list[DatasetClipRow], split: str) -> tuple[list[Path], list[float]]: - selected = [(Path(row.resolved_npz_path), row.weight) for row in rows if row.resolved_split == split] - if not selected: - raise ValueError(f"no clips for split={split}") - files, weights = zip(*selected) - return list(files), list(weights) - - -_ARRAY_KEYS = [ - "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", - "body_lin_vel_w", "body_ang_vel_w", -] +_ARRAY_KEYS = FULL_CLIP_ARRAY_KEYS def _batch_convert_chunk( file_paths: list[str], - weights: list[float], target_fps: int, output_path: str, label: str, preprocess: DatasetPreprocessSpec, ) -> dict[str, Any]: - """Worker: convert a batch of PKL/seed_csv files and write one merged chunk NPZ. + """Worker: convert a batch of PKL/seed_csv files and write one HDF5 shard. Designed to run in a spawned subprocess via ProcessPoolExecutor. """ @@ -874,13 +1176,12 @@ def _batch_convert_chunk( extractor = MotionFkExtractor() acc: dict[str, list[np.ndarray]] = {k: [] for k in _ARRAY_KEYS} clip_lengths: list[int] = [] - clip_weights: list[float] = [] body_names: np.ndarray | None = None total = len(file_paths) filtered = 0 kept_file_paths: list[str] = [] - for i, (file_path, weight) in enumerate(zip(file_paths, weights)): + for i, file_path in enumerate(file_paths): try: if file_path.endswith(".csv"): arrays = convert_seed_csv_to_arrays(file_path, extractor=extractor) @@ -902,6 +1203,8 @@ def _batch_convert_chunk( new_t = max(1, round(old_t * target_fps / fps)) for key in _ARRAY_KEYS: arrays[key] = resample_along_time(arrays[key], new_t) + qn_root = np.linalg.norm(arrays["root_quat_w"], axis=-1, keepdims=True) + arrays["root_quat_w"] = arrays["root_quat_w"] / np.where(qn_root < 1e-8, 1.0, qn_root) qn = np.linalg.norm(arrays["body_quat_w"], axis=-1, keepdims=True) arrays["body_quat_w"] = arrays["body_quat_w"] / np.where(qn < 1e-8, 1.0, qn) @@ -925,7 +1228,6 @@ def _batch_convert_chunk( for key in _ARRAY_KEYS: acc[key].append(np.asarray(clip_dict[key])) clip_lengths.append(int(np.asarray(clip_dict["joint_pos"]).shape[0])) - clip_weights.append(weight) kept_file_paths.append(file_path) if (i + 1) % 500 == 0 or (i + 1) == total: @@ -942,6 +1244,7 @@ def _batch_convert_chunk( "fps": target_fps, "duration_s": 0.0, "clip_lengths": [], + "source_clip_lengths": [], "kept_file_paths": [], } @@ -957,10 +1260,13 @@ def _batch_convert_chunk( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths_arr merged["clip_fps"] = np.full(kept, target_fps, dtype=np.int64) - merged["clip_weights"] = np.array(clip_weights, dtype=np.float64) - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - np.savez(output_path, **merged) + shard_info = write_hdf5_motion_shard( + merged, + Path(output_path), + max_window_frames=DEFAULT_HDF5_MAX_WINDOW_FRAMES, + overlap_frames=DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, + ) total_frames = int(merged["joint_pos"].shape[0]) print( @@ -968,14 +1274,17 @@ def _batch_convert_chunk( f"{filtered} filtered, {total_frames} frames -> {Path(output_path).name}", flush=True, ) + hdf5_windows = int(shard_info["clips"]) return { "output": output_path, - "clips": kept, - "num_clips": kept, + "clips": hdf5_windows, + "num_clips": hdf5_windows, + "source_clips": kept, "frames": total_frames, "fps": target_fps, "duration_s": total_frames / max(target_fps, 1), - "clip_lengths": clip_lengths, + "clip_lengths": list(shard_info["clip_lengths"]), + "source_clip_lengths": clip_lengths, "kept_file_paths": kept_file_paths, } @@ -987,7 +1296,12 @@ def _shard_stats( fps: int, ) -> dict[str, Any]: total_clips = sum(len(info["clip_lengths"]) for info in shard_infos) - total_frames = sum(sum(int(length) for length in info["clip_lengths"]) for info in shard_infos) + total_frames = sum(int(info.get("frames", 0)) for info in shard_infos) + if total_frames <= 0: + total_frames = sum( + sum(int(length) for length in info.get("source_clip_lengths", info["clip_lengths"])) + for info in shard_infos + ) return { "output": str(output_dir), "shards": len(shard_infos), @@ -1000,51 +1314,51 @@ def _shard_stats( def _batch_convert_split( - clips: list[tuple[str, float]], + file_paths: list[str], target_fps: int, output_dir: Path, jobs: int, - split_name: str, + label: str, preprocess: DatasetPreprocessSpec, ) -> tuple[dict[str, Any], list[dict[str, Any]]]: - """Convert clips for one split using parallel chunk workers.""" - if not clips: - raise ValueError(f"no clips for split {split_name}") + """Convert clips for one dataset using parallel chunk workers.""" + if not file_paths: + raise ValueError(f"no clips for dataset {label}") - file_paths = [c[0] for c in clips] - weights = [c[1] for c in clips] - num_workers = min(jobs, len(clips)) + num_workers = min(jobs, len(file_paths)) output_dir.mkdir(parents=True, exist_ok=True) if num_workers <= 1: shard_path = shard_output_path(output_dir, 0) stats = _batch_convert_chunk( - file_paths, weights, target_fps, str(shard_path), split_name, preprocess, + file_paths, target_fps, str(shard_path), label, preprocess, ) if int(stats["clips"]) <= 0: - raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") + raise ValueError(f"no valid clips remain for dataset {label} after preprocessing") shard_infos = [{ "path": shard_path, + "clips": int(stats.get("clips", 0)), + "frames": int(stats.get("frames", 0)), "clip_lengths": list(stats.pop("clip_lengths", [])), + "source_clip_lengths": list(stats.pop("source_clip_lengths", [])), "kept_file_paths": list(stats.pop("kept_file_paths", [])), }] return _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps), shard_infos # Split into chunks, one per worker - chunk_size = (len(clips) + num_workers - 1) // num_workers - chunk_args: list[tuple[list[str], list[float], int, str, str, DatasetPreprocessSpec]] = [] + chunk_size = (len(file_paths) + num_workers - 1) // num_workers + chunk_args: list[tuple[list[str], int, str, str, DatasetPreprocessSpec]] = [] for i in range(num_workers): start = i * chunk_size - end = min(start + chunk_size, len(clips)) - if start >= len(clips): + end = min(start + chunk_size, len(file_paths)) + if start >= len(file_paths): break - chunk_out = str(output_dir / f".{split_name}_chunk_{i}.npz") + chunk_out = str(output_dir / f".{label}_chunk_{i}.h5") chunk_args.append(( file_paths[start:end], - weights[start:end], target_fps, chunk_out, - f"{split_name}[{i}]", + f"{label}[{i}]", preprocess, )) @@ -1053,7 +1367,7 @@ def _batch_convert_split( try: with ProcessPoolExecutor(max_workers=len(chunk_args), mp_context=ctx) as executor: futures = { - executor.submit(_batch_convert_chunk, *args): args[3] + executor.submit(_batch_convert_chunk, *args): args[2] for args in chunk_args } try: @@ -1065,21 +1379,23 @@ def _batch_convert_split( future.cancel() raise except (PermissionError, OSError): - print(f"[WARN] process pool unavailable; falling back to serial for {split_name}") + print(f"[WARN] process pool unavailable; falling back to serial for {label}") shard_path = shard_output_path(output_dir, 0) stats = _batch_convert_chunk( file_paths, - weights, target_fps, str(shard_path), - split_name, + label, preprocess, ) if int(stats["clips"]) <= 0: - raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") + raise ValueError(f"no valid clips remain for dataset {label} after preprocessing") shard_infos = [{ "path": shard_path, + "clips": int(stats.get("clips", 0)), + "frames": int(stats.get("frames", 0)), "clip_lengths": list(stats.pop("clip_lengths", [])), + "source_clip_lengths": list(stats.pop("source_clip_lengths", [])), "kept_file_paths": list(stats.pop("kept_file_paths", [])), }] return _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps), shard_infos @@ -1087,8 +1403,8 @@ def _batch_convert_split( shard_infos: list[dict[str, Any]] = [] shard_index = 0 for args in chunk_args: - tmp_path = Path(args[3]) - chunk_stat = chunk_results.get(args[3], {}) + tmp_path = Path(args[2]) + chunk_stat = chunk_results.get(args[2], {}) if int(chunk_stat.get("clips", 0)) <= 0: tmp_path.unlink(missing_ok=True) continue @@ -1097,16 +1413,19 @@ def _batch_convert_split( tmp_path.replace(final_path) shard_infos.append({ "path": final_path, + "clips": int(chunk_stat.get("clips", 0)), + "frames": int(chunk_stat.get("frames", 0)), "clip_lengths": list(chunk_stat.get("clip_lengths", [])), + "source_clip_lengths": list(chunk_stat.get("source_clip_lengths", [])), "kept_file_paths": list(chunk_stat.get("kept_file_paths", [])), }) if not shard_infos: - raise ValueError(f"no valid clips remain for split {split_name} after preprocessing") + raise ValueError(f"no valid clips remain for dataset {label} after preprocessing") stats = _shard_stats(output_dir=output_dir, shard_infos=shard_infos, fps=target_fps) print( - f"[SHARDS] {split_name}: {stats['shards']} shards, " + f"[SHARDS] {label}: {stats['shards']} shards, " f"{stats['clips']} clips, {stats['frames']} frames ({stats['duration_s']:.1f}s)", flush=True, ) @@ -1122,151 +1441,56 @@ def _build_dataset_batch( skip_validate: bool = False, jobs: int = DEFAULT_JOBS, ) -> dict[str, Any]: - """Batch build: enumerate -> split -> parallel chunk convert -> merge. + """Batch build: enumerate -> parallel chunk convert -> minimal HDF5 shards. - Skips writing individual clip NPZ files. Each worker converts a batch of - PKL files in-memory and writes one merged chunk NPZ. + Skips writing individual clip files. Each worker converts a batch of PKL or + seed CSV files in-memory and writes one minimal HDF5 shard. """ if force and paths.dataset_dir.exists(): shutil.rmtree(paths.dataset_dir) - # 1. Enumerate all source files and pre-compute splits - clip_entries: list[tuple[str, str, str, float, str]] = [] + clip_entries: list[tuple[str, str, str]] = [] + source_filter_reports: list[dict[str, Any]] = [] for source in spec.sources: - items, _ = _collect_source_files(source) + items, _, filter_report = _collect_source_files_with_report(source, quiet=False) + source_filter_reports.append(filter_report) for item in items: clip_id = f"{source.name}:{item.rel_no_suffix.as_posix()}" - split = hash_split(clip_id, spec.val_percent, spec.hash_salt) - clip_entries.append((str(item.path), clip_id, source.name, source.weight, split)) - - if len(clip_entries) < 2: - raise ValueError("dataset must contain at least 2 clips") - - # Ensure both splits are non-empty - train_entries = [e for e in clip_entries if e[4] == "train"] - val_entries = [e for e in clip_entries if e[4] == "val"] - if not train_entries or not val_entries: - ordered = sorted(clip_entries, key=lambda e: e[1]) - target_split = "val" if not val_entries else "train" - first = ordered[0] - clip_entries = [ - (p, cid, src, w, target_split if cid == first[1] else sp) - for p, cid, src, w, sp in clip_entries - ] - train_entries = [e for e in clip_entries if e[4] == "train"] - val_entries = [e for e in clip_entries if e[4] == "val"] + clip_entries.append((str(item.path), clip_id, source.name)) + + if not clip_entries: + raise ValueError("dataset must contain at least 1 clip") print( - f"[DATASET] {spec.name}: {len(clip_entries)} clips " - f"({len(train_entries)} train, {len(val_entries)} val), " - f"jobs={jobs}", + f"[DATASET] {spec.name}: {len(clip_entries)} clips, jobs={jobs}", flush=True, ) - # 2. Process each split with parallel chunk workers + _clear_existing_motion_shards(paths.dataset_dir) paths.dataset_dir.mkdir(parents=True, exist_ok=True) - train_dir = split_output_dir(paths.dataset_dir, "train") - val_dir = split_output_dir(paths.dataset_dir, "val") - - train_stats, train_shards = _batch_convert_split( - [(e[0], e[3]) for e in train_entries], - spec.target_fps, - train_dir, - jobs, - "train", - spec.preprocess, - ) - val_stats, val_shards = _batch_convert_split( - [(e[0], e[3]) for e in val_entries], + stats, shard_infos = _batch_convert_split( + [e[0] for e in clip_entries], spec.target_fps, - val_dir, + paths.dataset_dir, jobs, - "val", + spec.name, spec.preprocess, ) - def _consume_entries_by_path( - entries: list[tuple[str, str, str, float, str]], - ) -> dict[str, list[tuple[str, str, str, float, str]]]: - buckets: dict[str, list[tuple[str, str, str, float, str]]] = {} - for entry in entries: - buckets.setdefault(entry[0], []).append(entry) - return buckets - - def _build_rows_for_shards( - entries: list[tuple[str, str, str, float, str]], - shard_infos: list[dict[str, Any]], - ) -> list[DatasetClipRow]: - entry_buckets = _consume_entries_by_path(entries) - built_rows: list[DatasetClipRow] = [] - for shard in shard_infos: - shard_path = Path(shard["path"]) - kept_paths = list(shard["kept_file_paths"]) - clip_lengths = list(shard["clip_lengths"]) - if len(kept_paths) != len(clip_lengths): - raise ValueError( - f"kept path count mismatch for {shard_path}: {len(kept_paths)} vs {len(clip_lengths)}" - ) - for clip_index, (kept_path, num_frames) in enumerate(zip(kept_paths, clip_lengths)): - candidates = entry_buckets.get(kept_path) - if not candidates: - raise ValueError(f"kept path not found in split entries: {kept_path}") - path, clip_id, source, weight, split = candidates.pop(0) - built_rows.append(DatasetClipRow( - clip_id=clip_id, - source=source, - file_rel=_display_path(Path(path)), - num_frames=int(num_frames), - fps=spec.target_fps, - resolved_split=split, - resolved_npz_path=str(shard_path), - weight=weight, - clip_index=clip_index, - )) - return built_rows - - # 3. Write manifest with correct num_frames and clip_index for kept clips only - rows: list[DatasetClipRow] = [] - train_rows = _build_rows_for_shards(train_entries, train_shards) - val_rows = _build_rows_for_shards(val_entries, val_shards) - rows.extend(train_rows) - rows.extend(val_rows) - manifest_path = write_manifest_resolved(rows, paths.dataset_dir) - - # 5. Build report - report: dict[str, Any] = { + return { "dataset": spec.name, - "built_at_utc": utc_now_iso(), "target_fps": spec.target_fps, - "val_percent": spec.val_percent, - "hash_salt": spec.hash_salt, "dataset_dir": str(paths.dataset_dir), - "build_dir": str(paths.dataset_dir), - "clips_dir": "", - "manifest_resolved": str(manifest_path), "skip_validate": bool(skip_validate), "skip_fk_check": bool(skip_fk_check), "jobs": int(jobs), "preprocess": spec.preprocess.to_dict(), - "sources": [asdict(source) for source in spec.sources], - "splits": { - "train": train_stats, - "val": val_stats, - }, - "clip_counts": { - "total": len(rows), - "train": len(train_rows), - "val": len(val_rows), - }, - "input_clip_counts": { - "total": len(clip_entries), - "train": len(train_entries), - "val": len(val_entries), - }, + "source_filters": source_filter_reports, + "stats": stats, + "shards": [str(info["path"]) for info in shard_infos], + "input_clips": len(clip_entries), "fk_checks": [], } - write_json(paths.dataset_dir / "build_info.json", report) - return report def build_dataset_from_spec( @@ -1292,7 +1516,11 @@ def build_dataset_from_spec( jobs=jobs, ) - # Legacy per-file mode for bvh/npz sources + # Per-file mode for BVH/NPZ sources. Converted clips are temporary build + # inputs and are rebuilt every time; final training data is the minimal + # shard(s) in dataset_dir. + _ensure_source_inputs_do_not_overlap_intermediate_clips(spec, paths.clips_root) + _clear_intermediate_clips(paths.clips_root) convert_sources_to_npz(spec, paths=paths, force=force, jobs=jobs) rows = collect_clip_rows(spec, paths=paths) @@ -1300,101 +1528,37 @@ def build_dataset_from_spec( if not skip_fk_check: fk_checks = run_sample_fk_checks(rows) - train_rows = [row for row in rows if row.resolved_split == "train"] - val_rows = [row for row in rows if row.resolved_split == "val"] - train_files, train_weights = _rows_for_split(rows, "train") - val_files, val_weights = _rows_for_split(rows, "val") + _clear_existing_motion_shards(paths.dataset_dir) paths.dataset_dir.mkdir(parents=True, exist_ok=True) - train_dir = split_output_dir(paths.dataset_dir, "train") - val_dir = split_output_dir(paths.dataset_dir, "val") - train_out = shard_output_path(train_dir, 0) - val_out = shard_output_path(val_dir, 0) - - train_stats = merge_npz_files( - train_files, - train_out, + shard_path = shard_output_path(paths.dataset_dir, 0) + tmp_npz = paths.dataset_dir / ".merged_tmp.npz" + stats = merge_npz_files( + [Path(row.resolved_npz_path) for row in rows], + tmp_npz, target_fps=spec.target_fps, - weights=train_weights, ) - val_stats = merge_npz_files( - val_files, - val_out, - target_fps=spec.target_fps, - weights=val_weights, - ) - - train_stats["output"] = str(train_dir) - train_stats["shards"] = 1 - val_stats["output"] = str(val_dir) - val_stats["shards"] = 1 - - train_clip_lengths = np.load(train_out, allow_pickle=True)["clip_lengths"] - val_clip_lengths = np.load(val_out, allow_pickle=True)["clip_lengths"] - if len(train_rows) != len(train_clip_lengths): - raise ValueError( - f"train row count mismatch: {len(train_rows)} vs {len(train_clip_lengths)}" - ) - if len(val_rows) != len(val_clip_lengths): - raise ValueError( - f"val row count mismatch: {len(val_rows)} vs {len(val_clip_lengths)}" - ) + payload_npz = np.load(tmp_npz, allow_pickle=True) + payload = {key: payload_npz[key] for key in payload_npz.files} + shard_info = write_hdf5_motion_shard(payload, shard_path) + tmp_npz.unlink(missing_ok=True) + _clear_intermediate_clips(paths.clips_root) - updated_rows: list[DatasetClipRow] = [] - for clip_index, (row, num_frames) in enumerate(zip(train_rows, train_clip_lengths)): - updated_rows.append( - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=int(num_frames), - fps=spec.target_fps, - resolved_split=row.resolved_split, - resolved_npz_path=str(train_out), - weight=row.weight, - clip_index=clip_index, - ) - ) - for clip_index, (row, num_frames) in enumerate(zip(val_rows, val_clip_lengths)): - updated_rows.append( - DatasetClipRow( - clip_id=row.clip_id, - source=row.source, - file_rel=row.file_rel, - num_frames=int(num_frames), - fps=spec.target_fps, - resolved_split=row.resolved_split, - resolved_npz_path=str(val_out), - weight=row.weight, - clip_index=clip_index, - ) - ) + stats["output"] = str(paths.dataset_dir) + stats["shards"] = 1 + stats["clips"] = int(shard_info["clips"]) + stats["num_clips"] = int(shard_info["clips"]) - manifest_path = write_manifest_resolved(updated_rows, paths.dataset_dir) - report: dict[str, Any] = { + return { "dataset": spec.name, - "built_at_utc": utc_now_iso(), "target_fps": spec.target_fps, - "val_percent": spec.val_percent, - "hash_salt": spec.hash_salt, "dataset_dir": str(paths.dataset_dir), - "build_dir": str(paths.dataset_dir), "clips_dir": str(paths.clips_root), - "manifest_resolved": str(manifest_path), "skip_validate": bool(skip_validate), "skip_fk_check": bool(skip_fk_check), "jobs": int(jobs), "preprocess": spec.preprocess.to_dict(), - "sources": [asdict(source) for source in spec.sources], - "splits": { - "train": train_stats, - "val": val_stats, - }, - "clip_counts": { - "total": len(updated_rows), - "train": len(train_rows), - "val": len(val_rows), - }, + "stats": stats, + "shards": [str(shard_path)], + "input_clips": len(rows), "fk_checks": [item.to_dict() for item in fk_checks], } - write_json(paths.dataset_dir / "build_info.json", report) - return report diff --git a/train_mimic/data/dataset_lib.py b/train_mimic/data/dataset_lib.py index 71bfaa12..76c9411f 100644 --- a/train_mimic/data/dataset_lib.py +++ b/train_mimic/data/dataset_lib.py @@ -1,19 +1,21 @@ #!/usr/bin/env python3 -"""Shared utilities for the active NPZ-based dataset pipeline.""" +"""Shared utilities for motion dataset build and runtime loading.""" from __future__ import annotations -import hashlib import json from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Any, Mapping, Sequence +import h5py import numpy as np REQUIRED_NPZ_KEYS = [ "fps", + "root_pos", + "root_quat_w", "joint_pos", "joint_vel", "body_pos_w", @@ -24,6 +26,34 @@ ] NUM_ACTIONS = 29 +MOTION_ARRAY_KEYS = [ + "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", + "body_lin_vel_w", "body_ang_vel_w", +] +PRECOMPUTED_MOTION_ARRAY_KEYS = [ + "joint_vel", + "body_pos_w", + "body_quat_w", + "body_lin_vel_w", + "body_ang_vel_w", +] +MINIMAL_MOTION_ARRAY_KEYS = ["root_pos", "root_quat_w", "joint_pos"] +FULL_CLIP_ARRAY_KEYS = [ + "root_pos", + "root_quat_w", + "joint_pos", + "joint_vel", + "body_pos_w", + "body_quat_w", + "body_lin_vel_w", + "body_ang_vel_w", +] +HDF5_DATASET_VERSION = 1 +PRECOMPUTED_MOTION_VERSION = 2 +MINIMAL_HDF5_FORMAT = "teleopit_motion_hdf5" +PRECOMPUTED_HDF5_FORMAT = "teleopit_precomputed_motion_hdf5" +DEFAULT_HDF5_MAX_WINDOW_FRAMES = 512 +DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES = 64 @dataclass(frozen=True) @@ -118,6 +148,8 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> NpzMeta: if fps <= 0: raise ValueError(f"invalid fps={fps}") + root_pos = np.asarray(payload["root_pos"]) + root_quat_w = np.asarray(payload["root_quat_w"]) joint_pos = np.asarray(payload["joint_pos"]) joint_vel = np.asarray(payload["joint_vel"]) body_pos_w = np.asarray(payload["body_pos_w"]) @@ -126,6 +158,10 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> NpzMeta: body_ang_vel_w = np.asarray(payload["body_ang_vel_w"]) body_names = np.asarray(payload["body_names"]) + if root_pos.ndim != 2 or root_pos.shape[1] != 3: + raise ValueError(f"root_pos must be (T,3), got {root_pos.shape}") + if root_quat_w.ndim != 2 or root_quat_w.shape[1] != 4: + raise ValueError(f"root_quat_w must be (T,4), got {root_quat_w.shape}") if joint_pos.ndim != 2 or joint_pos.shape[1] != NUM_ACTIONS: raise ValueError(f"joint_pos must be (T,{NUM_ACTIONS}), got {joint_pos.shape}") if joint_vel.ndim != 2 or joint_vel.shape != joint_pos.shape: @@ -145,6 +181,8 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> NpzMeta: raise ValueError("empty time/body dimension") for name, arr in [ + ("root_pos", root_pos), + ("root_quat_w", root_quat_w), ("joint_pos", joint_pos), ("joint_vel", joint_vel), ("body_pos_w", body_pos_w), @@ -162,6 +200,9 @@ def inspect_clip_dict(payload: Mapping[str, Any]) -> NpzMeta: if body_names.ndim != 1 or body_names.shape[0] != nb: raise ValueError(f"body_names must be (nb,), got {body_names.shape}") + root_quat_norm = np.linalg.norm(root_quat_w, axis=-1) + if np.min(root_quat_norm) < 1e-6: + raise ValueError("root_quat_w contains near-zero norms") quat_norm = np.linalg.norm(body_quat_w, axis=-1) if np.min(quat_norm) < 1e-6: raise ValueError("body_quat_w contains near-zero norms") @@ -180,12 +221,6 @@ def inspect_npz(path: Path) -> NpzMeta: return inspect_clip_dict({key: data[key] for key in data.files}) -def hash_split(clip_id: str, val_percent: int, salt: str = "") -> str: - payload = f"{salt}:{clip_id}".encode("utf-8") - bucket = int(hashlib.md5(payload).hexdigest(), 16) % 100 - return "val" if bucket < val_percent else "train" - - def resample_along_time(arr: np.ndarray, new_t: int) -> np.ndarray: """Resample an array along time axis 0 using linear interpolation.""" old_t = int(arr.shape[0]) @@ -211,29 +246,16 @@ def merge_npz_files( output_path: Path, *, target_fps: int | None = None, - weights: list[float] | None = None, ) -> dict[str, Any]: if not npz_files: raise ValueError("no npz files to merge") - if weights is not None and len(weights) != len(npz_files): - raise ValueError( - f"weights length ({len(weights)}) != npz_files length ({len(npz_files)})" - ) - arrays: dict[str, list[np.ndarray]] = { - "joint_pos": [], - "joint_vel": [], - "body_pos_w": [], - "body_quat_w": [], - "body_lin_vel_w": [], - "body_ang_vel_w": [], - } + arrays: dict[str, list[np.ndarray]] = {key: [] for key in FULL_CLIP_ARRAY_KEYS} fps: int | None = None body_names: np.ndarray[Any, Any] | None = None per_clip_fps: list[int] = [] - per_clip_weights: list[float] = [] - for i, p in enumerate(npz_files): + for p in npz_files: d = np.load(p, allow_pickle=True) cur_fps = int(d["fps"]) if cur_fps <= 0: @@ -242,6 +264,8 @@ def merge_npz_files( if target_fps is not None and target_fps <= 0: raise ValueError(f"target_fps must be > 0, got {target_fps}") + root_pos = np.asarray(d["root_pos"]) + root_quat_w = np.asarray(d["root_quat_w"]) joint_pos = np.asarray(d["joint_pos"]) joint_vel = np.asarray(d["joint_vel"]) body_pos_w = np.asarray(d["body_pos_w"]) @@ -254,6 +278,8 @@ def merge_npz_files( new_t = int(round(old_t * float(target_fps) / float(cur_fps))) new_t = max(new_t, 1) + root_pos = resample_along_time(root_pos, new_t) + root_quat_w = resample_along_time(root_quat_w, new_t) joint_pos = resample_along_time(joint_pos, new_t) joint_vel = resample_along_time(joint_vel, new_t) body_pos_w = resample_along_time(body_pos_w, new_t) @@ -261,6 +287,9 @@ def merge_npz_files( body_lin_vel_w = resample_along_time(body_lin_vel_w, new_t) body_ang_vel_w = resample_along_time(body_ang_vel_w, new_t) + root_quat_norm = np.linalg.norm(root_quat_w, axis=-1, keepdims=True) + root_quat_norm = np.where(root_quat_norm < 1e-8, 1.0, root_quat_norm) + root_quat_w = root_quat_w / root_quat_norm quat_norm = np.linalg.norm(body_quat_w, axis=-1, keepdims=True) quat_norm = np.where(quat_norm < 1e-8, 1.0, quat_norm) body_quat_w = body_quat_w / quat_norm @@ -276,8 +305,9 @@ def merge_npz_files( raise ValueError(f"inconsistent body_names in merge: {p}") per_clip_fps.append(cur_fps) - per_clip_weights.append(weights[i] if weights is not None else 1.0) + arrays["root_pos"].append(root_pos) + arrays["root_quat_w"].append(root_quat_w) arrays["joint_pos"].append(joint_pos) arrays["joint_vel"].append(joint_vel) arrays["body_pos_w"].append(body_pos_w) @@ -299,7 +329,6 @@ def merge_npz_files( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths merged["clip_fps"] = np.array(per_clip_fps, dtype=np.int64) - merged["clip_weights"] = np.array(per_clip_weights, dtype=np.float64) output_path.parent.mkdir(parents=True, exist_ok=True) np.savez(output_path, **merged) @@ -315,65 +344,51 @@ def merge_npz_files( } -def extract_clip_arrays(npz_path: Path, clip_index: int) -> dict[str, Any]: - """Extract a single clip's arrays from a merged NPZ by clip index. +def merge_clip_dicts( + clip_dicts: list[dict[str, Any]], + output_path: Path, + *, + target_fps: int | None = None, +) -> dict[str, Any]: + """Merge a list of in-memory clip array dicts into a single NPZ.""" + merged = merge_clip_dicts_payload( + clip_dicts, + target_fps=target_fps, + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + np.savez(output_path, **merged) - Returns a dict with the same keys as a standalone clip NPZ: - fps, joint_pos, joint_vel, body_pos_w, body_quat_w, - body_lin_vel_w, body_ang_vel_w, body_names. - """ - d = np.load(npz_path, allow_pickle=True) - clip_starts = d["clip_starts"] - clip_lengths = d["clip_lengths"] - if clip_index < 0 or clip_index >= len(clip_starts): - raise IndexError( - f"clip_index {clip_index} out of range [0, {len(clip_starts)}) in {npz_path}" - ) - start = int(clip_starts[clip_index]) - length = int(clip_lengths[clip_index]) - s = slice(start, start + length) + total_frames = int(merged["joint_pos"].shape[0]) return { - "fps": int(d["fps"]), - "joint_pos": np.asarray(d["joint_pos"][s]), - "joint_vel": np.asarray(d["joint_vel"][s]), - "body_pos_w": np.asarray(d["body_pos_w"][s]), - "body_quat_w": np.asarray(d["body_quat_w"][s]), - "body_lin_vel_w": np.asarray(d["body_lin_vel_w"][s]), - "body_ang_vel_w": np.asarray(d["body_ang_vel_w"][s]), - "body_names": np.asarray(d["body_names"]), + "output": str(output_path), + "clips": len(clip_dicts), + "num_clips": len(clip_dicts), + "frames": total_frames, + "fps": int(merged["fps"]), + "duration_s": float(total_frames / max(int(merged["fps"]), 1)), } -def merge_clip_dicts( +def merge_clip_dicts_payload( clip_dicts: list[dict[str, Any]], - output_path: Path, *, target_fps: int | None = None, - weights: list[float] | None = None, ) -> dict[str, Any]: - """Merge a list of in-memory clip array dicts into a single NPZ. + """Merge in-memory clip array dicts and return a flat motion payload. Each dict must have keys: fps, joint_pos, joint_vel, body_pos_w, body_quat_w, body_lin_vel_w, body_ang_vel_w, body_names. """ if not clip_dicts: raise ValueError("no clip dicts to merge") - if weights is not None and len(weights) != len(clip_dicts): - raise ValueError( - f"weights length ({len(weights)}) != clip_dicts length ({len(clip_dicts)})" - ) if target_fps is not None and target_fps <= 0: raise ValueError(f"target_fps must be > 0, got {target_fps}") - array_keys = [ - "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", - "body_lin_vel_w", "body_ang_vel_w", - ] + array_keys = FULL_CLIP_ARRAY_KEYS arrays: dict[str, list[np.ndarray]] = {k: [] for k in array_keys} fps: int | None = None body_names: np.ndarray | None = None per_clip_fps: list[int] = [] - per_clip_weights: list[float] = [] for i, cd in enumerate(clip_dicts): cur_fps = int(cd["fps"]) @@ -386,6 +401,8 @@ def merge_clip_dicts( new_t = max(1, round(old_t * target_fps / cur_fps)) for k in array_keys: clip_arrays[k] = resample_along_time(clip_arrays[k], new_t) + qn_root = np.linalg.norm(clip_arrays["root_quat_w"], axis=-1, keepdims=True) + clip_arrays["root_quat_w"] = clip_arrays["root_quat_w"] / np.where(qn_root < 1e-8, 1.0, qn_root) qn = np.linalg.norm(clip_arrays["body_quat_w"], axis=-1, keepdims=True) clip_arrays["body_quat_w"] = clip_arrays["body_quat_w"] / np.where(qn < 1e-8, 1.0, qn) cur_fps = target_fps @@ -399,7 +416,6 @@ def merge_clip_dicts( raise ValueError(f"inconsistent body_names: clip {i}") per_clip_fps.append(cur_fps) - per_clip_weights.append(weights[i] if weights is not None else 1.0) for k in array_keys: arrays[k].append(clip_arrays[k]) @@ -414,19 +430,539 @@ def merge_clip_dicts( merged["clip_starts"] = clip_starts merged["clip_lengths"] = clip_lengths merged["clip_fps"] = np.array(per_clip_fps, dtype=np.int64) - merged["clip_weights"] = np.array(per_clip_weights, dtype=np.float64) + return merged + + +def _window_clip_ranges( + *, + clip_start: int, + clip_length: int, + max_window_frames: int, + overlap_frames: int, +) -> list[tuple[int, int]]: + if clip_length <= 0: + raise ValueError(f"clip_length must be > 0, got {clip_length}") + if max_window_frames <= 1: + raise ValueError(f"max_window_frames must be > 1, got {max_window_frames}") + if overlap_frames < 0 or overlap_frames >= max_window_frames: + raise ValueError( + "overlap_frames must be in [0, max_window_frames), got " + f"{overlap_frames} for max_window_frames={max_window_frames}" + ) + if clip_length <= max_window_frames: + return [(clip_start, clip_length)] + + stride = max_window_frames - overlap_frames + starts = list(range(0, max(clip_length - max_window_frames + 1, 1), stride)) + tail_start = clip_length - max_window_frames + if starts[-1] != tail_start: + starts.append(tail_start) + return [(clip_start + int(start), max_window_frames) for start in starts] + + +def write_hdf5_motion_shard( + merged: Mapping[str, Any], + output_path: Path, + *, + max_window_frames: int = DEFAULT_HDF5_MAX_WINDOW_FRAMES, + overlap_frames: int = DEFAULT_HDF5_WINDOW_OVERLAP_FRAMES, +) -> dict[str, Any]: + """Write a merged motion payload as one minimal HDF5 shard. + + The frame arrays remain flat in the HDF5 file. ``clip_starts`` and + ``clip_lengths`` describe training windows, not necessarily original clips. + Long clips are split into overlapping windows to bound runtime cache size. + Runtime loading derives joint velocities and body FK/velocities from the + stored root pose and joint positions. + """ + missing = [key for key in [*MINIMAL_MOTION_ARRAY_KEYS, "fps", "body_names", "clip_starts", "clip_lengths", "clip_fps"] if key not in merged] + if missing: + raise ValueError(f"merged payload missing required keys: {missing}") + + fps = int(merged["fps"]) + if fps <= 0: + raise ValueError(f"fps must be > 0, got {fps}") + body_names = np.asarray(merged["body_names"]).astype(str) + original_starts = np.asarray(merged["clip_starts"], dtype=np.int64) + original_lengths = np.asarray(merged["clip_lengths"], dtype=np.int64) + original_fps = np.asarray(merged["clip_fps"], dtype=np.int64) + + window_starts: list[int] = [] + window_lengths: list[int] = [] + window_fps: list[int] = [] + source_clip_ids: list[int] = [] + source_start_frames: list[int] = [] + for source_idx, (clip_start, clip_length) in enumerate(zip(original_starts, original_lengths)): + ranges = _window_clip_ranges( + clip_start=int(clip_start), + clip_length=int(clip_length), + max_window_frames=max_window_frames, + overlap_frames=overlap_frames, + ) + for start, length in ranges: + window_starts.append(int(start)) + window_lengths.append(int(length)) + window_fps.append(int(original_fps[source_idx])) + source_clip_ids.append(int(source_idx)) + source_start_frames.append(int(start - int(clip_start))) output_path.parent.mkdir(parents=True, exist_ok=True) - np.savez(output_path, **merged) + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(output_path, "w") as h5: + h5.attrs["format"] = MINIMAL_HDF5_FORMAT + h5.attrs["version"] = HDF5_DATASET_VERSION + h5.attrs["fps"] = fps + h5.attrs["max_window_frames"] = int(max_window_frames) + h5.attrs["overlap_frames"] = int(overlap_frames) + h5.create_dataset("body_names", data=body_names.astype(object), dtype=str_dt) + for key in MINIMAL_MOTION_ARRAY_KEYS: + arr = np.asarray(merged[key], dtype=np.float32) + h5.create_dataset(key, data=arr, chunks=True) + h5.create_dataset("clip_starts", data=np.asarray(window_starts, dtype=np.int64)) + h5.create_dataset("clip_lengths", data=np.asarray(window_lengths, dtype=np.int64)) + h5.create_dataset("clip_fps", data=np.asarray(window_fps, dtype=np.int64)) + h5.create_dataset("source_clip_ids", data=np.asarray(source_clip_ids, dtype=np.int64)) + h5.create_dataset("source_start_frames", data=np.asarray(source_start_frames, dtype=np.int64)) + h5.create_dataset("source_clip_starts", data=original_starts.astype(np.int64)) + h5.create_dataset("source_clip_lengths", data=original_lengths.astype(np.int64)) + h5.create_dataset("source_clip_fps", data=original_fps.astype(np.int64)) + + total_frames = int(np.asarray(merged["joint_pos"]).shape[0]) + return { + "path": str(output_path), + "clips": len(window_lengths), + "num_clips": len(window_lengths), + "source_clips": len(original_lengths), + "frames": total_frames, + "fps": fps, + "duration_s": float(total_frames / max(fps, 1)), + "clip_lengths": [int(v) for v in window_lengths], + "source_clip_lengths": [int(v) for v in original_lengths], + } + + +def read_hdf5_body_names(path: Path) -> list[str]: + with h5py.File(path, "r") as h5: + return [ + str(name.decode("utf-8") if isinstance(name, bytes) else name) + for name in h5["body_names"][()] + ] + + +def _read_hdf5_string_list(h5: h5py.File, key: str) -> list[str]: + return [ + str(name.decode("utf-8") if isinstance(name, bytes) else name) + for name in h5[key][()] + ] + + +def validate_precomputed_motion_shard( + shard_path: str | Path, +) -> None: + """Validate that a shard is a standalone precomputed training shard.""" + shard = Path(shard_path).expanduser().resolve() + if not shard.is_file(): + raise FileNotFoundError(f"precomputed motion shard not found: {shard}") + + with h5py.File(shard, "r") as h5: + if h5.attrs.get("format") != PRECOMPUTED_HDF5_FORMAT: + raise ValueError( + f"motion dataset must be precomputed before training: {shard}. " + "Run: python train_mimic/scripts/data/precompute_dataset.py " + " --outdir , then pass " + "--motion_file ." + ) + if int(h5.attrs.get("version", 0)) != PRECOMPUTED_MOTION_VERSION: + raise ValueError( + f"unsupported precomputed motion version in {shard}: " + f"{h5.attrs.get('version')}. Regenerate with precompute_dataset.py." + ) + missing = [ + key for key in [ + "joint_pos", + *PRECOMPUTED_MOTION_ARRAY_KEYS, + "body_names", + "clip_starts", + "clip_lengths", + "clip_fps", + "source_clip_ids", + "source_start_frames", + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", + ] + if key not in h5 + ] + if missing: + raise ValueError( + f"precomputed motion shard {shard} missing required datasets: {missing}. " + "Regenerate with precompute_dataset.py." + ) + + body_names = _read_hdf5_string_list(h5, "body_names") + frames = int(h5["joint_pos"].shape[0]) + num_bodies = len(body_names) + expected_shapes = { + "joint_pos": (frames, NUM_ACTIONS), + "joint_vel": (frames, NUM_ACTIONS), + "body_pos_w": (frames, num_bodies, 3), + "body_quat_w": (frames, num_bodies, 4), + "body_lin_vel_w": (frames, num_bodies, 3), + "body_ang_vel_w": (frames, num_bodies, 3), + } + for key, shape in expected_shapes.items(): + if h5[key].shape != shape: + raise ValueError( + f"precomputed motion shard {shard} dataset {key} shape mismatch: " + f"expected {shape}, got {h5[key].shape}. Regenerate with precompute_dataset.py." + ) + windows = int(h5["clip_starts"].shape[0]) + for key in ["clip_lengths", "clip_fps", "source_clip_ids", "source_start_frames"]: + if h5[key].shape != (windows,): + raise ValueError( + f"precomputed motion shard {shard} dataset {key} shape mismatch: " + f"expected {(windows,)}, got {h5[key].shape}. Regenerate with precompute_dataset.py." + ) + source_clips = int(h5["source_clip_starts"].shape[0]) + for key in ["source_clip_lengths", "source_clip_fps"]: + if h5[key].shape != (source_clips,): + raise ValueError( + f"precomputed motion shard {shard} dataset {key} shape mismatch: " + f"expected {(source_clips,)}, got {h5[key].shape}. Regenerate with precompute_dataset.py." + ) + + +def validate_precomputed_motion_dataset(dataset_dir_or_shard: str | Path) -> None: + for shard_path in find_precomputed_motion_shards(dataset_dir_or_shard): + validate_precomputed_motion_shard(shard_path) + + +def write_precomputed_motion_shard( + shard_path: str | Path, + output_path: str | Path, + *, + force: bool = False, + model_path: str | Path | None = None, +) -> dict[str, Any]: + """Write one standalone precomputed training shard from a minimal source shard.""" + from train_mimic.data.motion_fk import ( + MotionFkExtractor, + compute_body_velocities, + finite_diff_velocity, + ) + + source_path = Path(shard_path).expanduser().resolve() + output = Path(output_path).expanduser().resolve() + if output.is_file() and not force: + validate_precomputed_motion_shard(output) + return { + "shard": str(source_path), + "output": str(output), + "status": "existing", + } + + extractor = MotionFkExtractor(model_path) + with h5py.File(source_path, "r") as source: + if source.attrs.get("format") != MINIMAL_HDF5_FORMAT: + raise ValueError( + f"precompute input must be a minimal Teleopit motion shard, got {source_path}" + ) + required = [ + "root_pos", + "root_quat_w", + "joint_pos", + "body_names", + "clip_starts", + "clip_lengths", + "clip_fps", + "source_clip_ids", + "source_start_frames", + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", + ] + missing = [key for key in required if key not in source] + if missing: + raise ValueError( + f"HDF5 shard {source_path} missing required datasets for precompute: {missing}. " + "Rebuild the dataset with the current HDF5 writer." + ) + + root_pos = np.asarray(source["root_pos"], dtype=np.float32) + root_quat_w = np.asarray(source["root_quat_w"], dtype=np.float32) + joint_pos = np.asarray(source["joint_pos"], dtype=np.float32) + body_names = np.asarray(_read_hdf5_string_list(source, "body_names"), dtype=str) + source_starts = np.asarray(source["source_clip_starts"], dtype=np.int64) + source_lengths = np.asarray(source["source_clip_lengths"], dtype=np.int64) + source_fps = np.asarray(source["source_clip_fps"], dtype=np.int64) + + frames = int(joint_pos.shape[0]) + num_bodies = int(body_names.shape[0]) + joint_vel = np.empty_like(joint_pos, dtype=np.float32) + body_pos_w = np.empty((frames, num_bodies, 3), dtype=np.float32) + body_quat_w = np.empty((frames, num_bodies, 4), dtype=np.float32) + body_lin_vel_w = np.empty((frames, num_bodies, 3), dtype=np.float32) + body_ang_vel_w = np.empty((frames, num_bodies, 3), dtype=np.float32) + + for start, length, fps in zip(source_starts, source_lengths, source_fps): + start_i = int(start) + length_i = int(length) + fps_i = int(fps) + if length_i <= 0: + raise ValueError(f"invalid source clip length {length_i} in {source_path}") + if fps_i <= 0: + raise ValueError(f"invalid source clip fps {fps_i} in {source_path}") + sl = slice(start_i, start_i + length_i) + dt = 1.0 / float(fps_i) + cur_body_pos_w, cur_body_quat_w = extractor.extract( + root_pos[sl], + root_quat_w[sl], + joint_pos[sl], + body_names, + ) + cur_body_lin_vel_w, cur_body_ang_vel_w = compute_body_velocities( + cur_body_pos_w, + cur_body_quat_w, + dt, + ) + joint_vel[sl] = finite_diff_velocity(joint_pos[sl], dt) + body_pos_w[sl] = cur_body_pos_w + body_quat_w[sl] = cur_body_quat_w + body_lin_vel_w[sl] = cur_body_lin_vel_w + body_ang_vel_w[sl] = cur_body_ang_vel_w + + clip_starts = np.asarray(source["clip_starts"], dtype=np.int64) + clip_lengths = np.asarray(source["clip_lengths"], dtype=np.int64) + clip_fps = np.asarray(source["clip_fps"], dtype=np.int64) + source_clip_ids = np.asarray(source["source_clip_ids"], dtype=np.int64) + source_start_frames = np.asarray(source["source_start_frames"], dtype=np.int64) + + output.parent.mkdir(parents=True, exist_ok=True) + tmp_path = output.with_suffix(f"{output.suffix}.tmp") + if tmp_path.exists(): + tmp_path.unlink() + str_dt = h5py.string_dtype(encoding="utf-8") + with h5py.File(tmp_path, "w") as h5: + h5.attrs["format"] = PRECOMPUTED_HDF5_FORMAT + h5.attrs["version"] = PRECOMPUTED_MOTION_VERSION + h5.attrs["source_shard"] = str(source_path) + h5.attrs["created_at"] = utc_now_iso() + h5.create_dataset("body_names", data=body_names.astype(object), dtype=str_dt) + h5.create_dataset("joint_pos", data=joint_pos, chunks=True) + h5.create_dataset("joint_vel", data=joint_vel, chunks=True) + h5.create_dataset("body_pos_w", data=body_pos_w, chunks=True) + h5.create_dataset("body_quat_w", data=body_quat_w, chunks=True) + h5.create_dataset("body_lin_vel_w", data=body_lin_vel_w, chunks=True) + h5.create_dataset("body_ang_vel_w", data=body_ang_vel_w, chunks=True) + h5.create_dataset("clip_starts", data=clip_starts) + h5.create_dataset("clip_lengths", data=clip_lengths) + h5.create_dataset("clip_fps", data=clip_fps) + h5.create_dataset("source_clip_ids", data=source_clip_ids) + h5.create_dataset("source_start_frames", data=source_start_frames) + h5.create_dataset("source_clip_starts", data=source_starts) + h5.create_dataset("source_clip_lengths", data=source_lengths) + h5.create_dataset("source_clip_fps", data=source_fps) + tmp_path.replace(output) + validate_precomputed_motion_shard(output) - total_frames = int(merged["joint_pos"].shape[0]) return { - "output": str(output_path), - "clips": len(clip_dicts), - "num_clips": len(clip_dicts), + "shard": str(source_path), + "output": str(output), + "status": "written", + "frames": frames, + "bodies": num_bodies, + "arrays": ["joint_pos", *PRECOMPUTED_MOTION_ARRAY_KEYS], + } + + +def read_motion_clip(path: Path, clip_index: int) -> dict[str, Any]: + """Read one source clip from a current HDF5 motion shard path. + + HDF5 shards use source-clip metadata, so ``clip_index`` indexes original + clips, not bounded training windows. + """ + if path.suffix == ".h5": + return read_hdf5_source_clip(path, clip_index) + raise ValueError( + f"source clip input must be a current HDF5 shard (.h5), got: {path}" + ) + + +def read_hdf5_source_clip(path: Path, clip_index: int) -> dict[str, Any]: + if clip_index < 0: + raise ValueError(f"HDF5 shard rows require clip_index >= 0: {path}") + with h5py.File(path, "r") as h5: + required = [ + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", + "body_names", + *MINIMAL_MOTION_ARRAY_KEYS, + ] + missing = [key for key in required if key not in h5] + if missing: + raise ValueError( + f"HDF5 shard {path} is missing source clip metadata {missing}. " + "Rebuild the dataset with the current HDF5 writer." + ) + starts = np.asarray(h5["source_clip_starts"], dtype=np.int64) + lengths = np.asarray(h5["source_clip_lengths"], dtype=np.int64) + fps = np.asarray(h5["source_clip_fps"], dtype=np.int64) + if clip_index >= len(starts): + raise IndexError( + f"clip_index {clip_index} out of range [0, {len(starts)}) in {path}" + ) + start = int(starts[clip_index]) + length = int(lengths[clip_index]) + sl = slice(start, start + length) + root_pos = np.asarray(h5["root_pos"][sl], dtype=np.float32) + root_quat_w = np.asarray(h5["root_quat_w"][sl], dtype=np.float32) + joint_pos = np.asarray(h5["joint_pos"][sl], dtype=np.float32) + body_names = np.asarray(read_hdf5_body_names(path), dtype=str) + + from train_mimic.data.motion_fk import MotionFkExtractor, compute_body_velocities, finite_diff_velocity + + dt = 1.0 / float(fps[clip_index]) + joint_vel = finite_diff_velocity(joint_pos, dt) + fk = MotionFkExtractor() + body_pos_w, body_quat_w = fk.extract(root_pos, root_quat_w, joint_pos, body_names) + body_lin_vel_w, body_ang_vel_w = compute_body_velocities(body_pos_w, body_quat_w, dt) + return { + "fps": int(fps[clip_index]), + "root_pos": root_pos, + "root_quat_w": root_quat_w, + "joint_pos": joint_pos, + "joint_vel": joint_vel, + "body_pos_w": body_pos_w, + "body_quat_w": body_quat_w, + "body_lin_vel_w": body_lin_vel_w, + "body_ang_vel_w": body_ang_vel_w, + "body_names": body_names, + } + + +def _find_hdf5_shards_by_format( + dataset_dir: str | Path, + *, + expected_format: str, + label: str, +) -> list[Path]: + root = Path(dataset_dir).expanduser().resolve() + if root.is_file(): + candidates = [root] + elif root.is_dir(): + candidates = sorted(root.rglob("*.h5")) + else: + raise FileNotFoundError(f"motion dataset path not found: {dataset_dir}") + + shards: list[Path] = [] + for path in candidates: + try: + with h5py.File(path, "r") as h5: + if h5.attrs.get("format") == expected_format: + shards.append(path) + except OSError: + continue + if not shards: + raise FileNotFoundError(f"no {label} HDF5 motion shards found under {dataset_dir}") + return shards + + +def find_motion_shards(dataset_dir: str | Path) -> list[Path]: + """Recursively find minimal Teleopit HDF5 motion shards under a root directory.""" + return _find_hdf5_shards_by_format( + dataset_dir, + expected_format=MINIMAL_HDF5_FORMAT, + label="minimal Teleopit", + ) + + +def find_precomputed_motion_shards(dataset_dir: str | Path) -> list[Path]: + """Recursively find precomputed Teleopit HDF5 training shards.""" + return _find_hdf5_shards_by_format( + dataset_dir, + expected_format=PRECOMPUTED_HDF5_FORMAT, + label="precomputed Teleopit", + ) + + +def compute_dataset_stats( + dataset_dir: str | Path, + *, + precomputed: bool = False, +) -> dict[str, Any]: + shards = find_precomputed_motion_shards(dataset_dir) if precomputed else find_motion_shards(dataset_dir) + array_keys = MOTION_ARRAY_KEYS if precomputed else MINIMAL_MOTION_ARRAY_KEYS + total_windows = 0 + total_frames = 0 + total_duration_s = 0.0 + total_source_clips = 0 + fps_values: set[int] = set() + body_names_ref: list[str] | None = None + shard_rows: list[dict[str, Any]] = [] + + for shard_path in shards: + with h5py.File(shard_path, "r") as h5: + missing = [ + key for key in [ + *array_keys, + "body_names", + "clip_starts", + "clip_lengths", + "clip_fps", + "source_clip_ids", + "source_clip_starts", + "source_clip_lengths", + "source_clip_fps", + ] + if key not in h5 + ] + if missing: + raise ValueError(f"HDF5 shard {shard_path} missing required datasets: {missing}") + body_names = read_hdf5_body_names(shard_path) + if body_names_ref is None: + body_names_ref = body_names + elif body_names != body_names_ref: + raise ValueError( + f"inconsistent body_names in {shard_path}; all shards under one training root must match" + ) + lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) + fps_arr = np.asarray(h5["clip_fps"], dtype=np.int64) + source_ids = np.asarray(h5["source_clip_ids"], dtype=np.int64) + fps_values.update(int(v) for v in np.unique(fps_arr)) + windows = int(lengths.shape[0]) + frames = int(h5["joint_pos"].shape[0]) + source_clips = int(len(np.unique(source_ids))) + source_lengths = np.asarray(h5["source_clip_lengths"], dtype=np.float64) + source_fps = np.asarray(h5["source_clip_fps"], dtype=np.float64) + shard_duration_s = float(np.sum(source_lengths / np.maximum(source_fps, 1.0))) + total_windows += windows + total_frames += frames + total_duration_s += shard_duration_s + total_source_clips += source_clips + shard_rows.append({ + "path": str(shard_path), + "windows": windows, + "source_clips": source_clips, + "frames": frames, + "duration_s": shard_duration_s, + "min_window_frames": int(lengths.min()) if windows else 0, + "max_window_frames": int(lengths.max()) if windows else 0, + }) + + return { + "format": PRECOMPUTED_HDF5_FORMAT if precomputed else MINIMAL_HDF5_FORMAT, + "version": PRECOMPUTED_MOTION_VERSION if precomputed else HDF5_DATASET_VERSION, + "root": str(Path(dataset_dir).expanduser().resolve()), + "shards": len(shards), + "windows": total_windows, + "source_clips": total_source_clips, "frames": total_frames, - "fps": int(merged["fps"]), - "duration_s": float(total_frames / max(int(merged["fps"]), 1)), + "duration_s": total_duration_s, + "duration_h": total_duration_s / 3600.0, + "fps": sorted(fps_values), + "body_names": body_names_ref or [], + "shard_details": shard_rows, } diff --git a/train_mimic/data/motion_fk.py b/train_mimic/data/motion_fk.py index 3b03c671..730089c4 100644 --- a/train_mimic/data/motion_fk.py +++ b/train_mimic/data/motion_fk.py @@ -9,10 +9,10 @@ import mujoco import numpy as np -from teleopit.runtime.assets import UNITREE_G1_MJLAB_XML, missing_gmr_assets_message +from teleopit.runtime.assets import UNITREE_G1_XML, missing_gmr_assets_message -DEFAULT_G1_XML_PATH = UNITREE_G1_MJLAB_XML +DEFAULT_G1_XML_PATH = UNITREE_G1_XML def quat_xyzw_to_wxyz(q: np.ndarray) -> np.ndarray: @@ -33,6 +33,21 @@ def normalize_quaternion(q: np.ndarray) -> np.ndarray: return arr / norms +def finite_diff_velocity(x: np.ndarray, dt: float) -> np.ndarray: + """Finite-difference velocity with central interior and one-sided edges.""" + arr = np.asarray(x, dtype=np.float32) + vel = np.zeros_like(arr, dtype=np.float32) + if dt <= 0.0: + raise ValueError(f"dt must be > 0, got {dt}") + if arr.shape[0] < 2: + return vel + inv_dt = 1.0 / float(dt) + vel[1:-1] = (arr[2:] - arr[:-2]) * (0.5 * inv_dt) + vel[0] = (arr[1] - arr[0]) * inv_dt + vel[-1] = (arr[-1] - arr[-2]) * inv_dt + return vel + + def quat_multiply(q1: np.ndarray, q2: np.ndarray) -> np.ndarray: """Multiply two wxyz quaternions: q1 * q2.""" q1 = np.asarray(q1, dtype=np.float32) @@ -75,9 +90,19 @@ def quat_rotate_inverse(q: np.ndarray, v: np.ndarray) -> np.ndarray: def quat_to_angular_velocity(q: np.ndarray, dt: float) -> np.ndarray: """Compute angular velocity from quaternion sequence in wxyz convention.""" - q = normalize_quaternion(q) - q_dot = np.gradient(q, dt, axis=0) - product = quat_multiply(q_dot, quat_conjugate(q)) + quat = normalize_quaternion(q) + if quat.shape[0] < 2: + return np.zeros(quat.shape[:-1] + (3,), dtype=np.float32) + + flat = quat.reshape(quat.shape[0], -1, 4) + dots = np.sum(flat[1:] * flat[:-1], axis=-1) + signs = np.where(dots < 0.0, -1.0, 1.0).astype(np.float32) + signs = np.concatenate([np.ones_like(signs[:1]), signs], axis=0) + flat = flat * np.cumprod(signs, axis=0)[..., None] + quat = flat.reshape(quat.shape) + + q_dot = finite_diff_velocity(quat, dt) + product = quat_multiply(q_dot, quat_conjugate(quat)) return (2.0 * product[..., 1:4]).astype(np.float32) @@ -192,9 +217,7 @@ def compute_body_velocities( """Compute linear and angular velocity sequences from FK outputs.""" body_pos_w = np.asarray(body_pos_w, dtype=np.float32) body_quat_w = normalize_quaternion(body_quat_w) - if dt <= 0.0: - raise ValueError(f"dt must be > 0, got {dt}") - body_lin_vel_w = np.gradient(body_pos_w, dt, axis=0).astype(np.float32) + body_lin_vel_w = finite_diff_velocity(body_pos_w, dt) body_ang_vel_w = quat_to_angular_velocity(body_quat_w, dt).astype(np.float32) return body_lin_vel_w, body_ang_vel_w diff --git a/train_mimic/data/preprocess.py b/train_mimic/data/preprocess.py index 4e565024..89cad31f 100644 --- a/train_mimic/data/preprocess.py +++ b/train_mimic/data/preprocess.py @@ -9,7 +9,7 @@ from train_mimic.data.dataset_lib import inspect_clip_dict -GROUND_ALIGN_MODES = {"none", "clip_min_foot"} +GROUND_ALIGN_MODES = {"none", "first_frame_foot"} @dataclass(frozen=True) @@ -23,6 +23,8 @@ class DatasetPreprocessSpec: min_peak_body_height: float | None = None max_all_off_ground_s: float | None = None off_ground_height: float = 0.2 + max_feet_off_ground_s: float | None = None + foot_off_ground_height: float = 0.08 def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -55,6 +57,16 @@ def validate_preprocess_spec(spec: DatasetPreprocessSpec) -> DatasetPreprocessSp raise ValueError( f"preprocess.off_ground_height must be >= 0, got {spec.off_ground_height}" ) + if spec.max_feet_off_ground_s is not None and spec.max_feet_off_ground_s <= 0.0: + raise ValueError( + "preprocess.max_feet_off_ground_s must be > 0, " + f"got {spec.max_feet_off_ground_s}" + ) + if spec.foot_off_ground_height < 0.0: + raise ValueError( + "preprocess.foot_off_ground_height must be >= 0, " + f"got {spec.foot_off_ground_height}" + ) return spec @@ -89,6 +101,8 @@ def preprocess_clip_dict( result = { "fps": int(payload["fps"]), + "root_pos": np.asarray(payload["root_pos"]).copy(), + "root_quat_w": np.asarray(payload["root_quat_w"]).copy(), "joint_pos": np.asarray(payload["joint_pos"]).copy(), "joint_vel": np.asarray(payload["joint_vel"]).copy(), "body_pos_w": np.asarray(payload["body_pos_w"]).copy(), @@ -100,6 +114,7 @@ def preprocess_clip_dict( inspect_clip_dict(result) fps = int(result["fps"]) + root_pos = result["root_pos"] body_pos_w = result["body_pos_w"] body_lin_vel_w = result["body_lin_vel_w"] body_names = np.asarray(result["body_names"]) @@ -117,19 +132,11 @@ def preprocess_clip_dict( if spec.normalize_root_xy: assert root_index is not None offset_xy = body_pos_w[0, root_index, :2].copy() + root_pos[:, 0] -= offset_xy[0] + root_pos[:, 1] -= offset_xy[1] body_pos_w[..., 0] -= offset_xy[0] body_pos_w[..., 1] -= offset_xy[1] - foot_indices: tuple[int, int] | None = None - if spec.ground_align != "none": - foot_indices = ( - _body_index(body_names, spec.foot_body_names[0], label="foot"), - _body_index(body_names, spec.foot_body_names[1], label="foot"), - ) - foot_z = body_pos_w[:, foot_indices, 2] - if spec.ground_align == "clip_min_foot": - body_pos_w[..., 2] -= float(np.min(foot_z)) - if spec.max_root_lin_vel is not None: assert root_index is not None peak_root_lin_vel = float(np.max(np.abs(body_lin_vel_w[:, root_index, :]))) @@ -149,6 +156,31 @@ def preprocess_clip_dict( f"{spec.max_all_off_ground_s:.3f}s" ) + foot_indices: tuple[int, int] | None = None + if spec.ground_align != "none" or spec.max_feet_off_ground_s is not None: + foot_indices = ( + _body_index(body_names, spec.foot_body_names[0], label="foot"), + _body_index(body_names, spec.foot_body_names[1], label="foot"), + ) + + if spec.max_feet_off_ground_s is not None: + assert foot_indices is not None + foot_z = body_pos_w[:, foot_indices, 2] + both_feet_off = np.all(foot_z > spec.foot_off_ground_height, axis=1) + longest_run = _longest_true_run(both_feet_off) + if longest_run > int(round(spec.max_feet_off_ground_s * fps)): + raise ValueError( + f"{clip_label}: both feet off ground for {longest_run / fps:.3f}s, exceeds " + f"{spec.max_feet_off_ground_s:.3f}s" + ) + + if spec.ground_align == "first_frame_foot": + assert foot_indices is not None + foot_z = body_pos_w[:, foot_indices, 2] + dz = float(np.min(foot_z[0])) + root_pos[:, 2] -= dz + body_pos_w[..., 2] -= dz + if spec.min_peak_body_height is not None: peak_height = float(np.max(body_pos_w[:, :, 2])) if peak_height < spec.min_peak_body_height: diff --git a/train_mimic/data/review_lib.py b/train_mimic/data/review_lib.py deleted file mode 100644 index 72880737..00000000 --- a/train_mimic/data/review_lib.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Shared utilities for dataset review: data model, I/O, and statistics.""" - -from __future__ import annotations - -import csv -import os -from dataclasses import dataclass, fields -from datetime import datetime, timezone -from pathlib import Path -from typing import Any - - -REVIEW_COLUMNS = [ - "clip_id", - "source", - "file_rel", - "resolved_npz_path", - "resolved_split", - "num_frames", - "fps", - "duration_s", - "weight", - "clip_index", - "decision", - "difficulty", - "issue_tags", - "note", - "reviewed_at", -] - -VALID_DECISIONS = {"keep", "drop", "skip", ""} -VALID_DIFFICULTIES = {"easy", "medium", "hard", "bad_data", ""} - - -@dataclass -class ReviewRow: - clip_id: str - source: str - file_rel: str - resolved_npz_path: str - resolved_split: str - num_frames: int - fps: int - duration_s: float - weight: float = 1.0 - clip_index: int = -1 # index into merged NPZ clip_starts/clip_lengths; -1 = standalone clip - decision: str = "" - difficulty: str = "" - issue_tags: str = "" - note: str = "" - reviewed_at: str = "" - - -@dataclass(frozen=True) -class ReviewStats: - total: int - reviewed: int - keep_count: int - drop_count: int - skip_count: int - progress_pct: float - kept_duration_s: float - kept_train_duration_s: float - kept_val_duration_s: float - kept_duration_by_source: dict[str, float] - - -def load_review_state(path: Path) -> list[ReviewRow]: - """Load review_state.csv and return a list of ReviewRow.""" - if not path.is_file(): - raise FileNotFoundError(f"review state not found: {path}") - - with path.open("r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - if reader.fieldnames is None: - raise ValueError(f"review state is empty: {path}") - optional = {"clip_index"} - required = [c for c in REVIEW_COLUMNS if c not in optional] - missing = [c for c in required if c not in reader.fieldnames] - if missing: - raise ValueError(f"review state missing columns: {missing}") - has_clip_index = "clip_index" in reader.fieldnames - - rows: list[ReviewRow] = [] - for idx, raw in enumerate(reader, start=2): - try: - row = ReviewRow( - clip_id=raw["clip_id"].strip(), - source=raw["source"].strip(), - file_rel=raw["file_rel"].strip(), - resolved_npz_path=raw["resolved_npz_path"].strip(), - resolved_split=raw["resolved_split"].strip(), - num_frames=int(raw["num_frames"]), - fps=int(raw["fps"]), - duration_s=float(raw["duration_s"]), - weight=float(raw["weight"]), - clip_index=int(raw["clip_index"]) if has_clip_index else -1, - decision=raw["decision"].strip(), - difficulty=raw["difficulty"].strip(), - issue_tags=raw["issue_tags"].strip(), - note=raw["note"].strip(), - reviewed_at=raw["reviewed_at"].strip(), - ) - except Exception as exc: - raise ValueError(f"line {idx}: {exc}") from exc - - if row.decision not in VALID_DECISIONS: - raise ValueError(f"line {idx}: invalid decision '{row.decision}'") - if row.difficulty not in VALID_DIFFICULTIES: - raise ValueError(f"line {idx}: invalid difficulty '{row.difficulty}'") - rows.append(row) - return rows - - -def save_review_state(rows: list[ReviewRow], path: Path) -> None: - """Atomically write review_state.csv (write to .tmp, then os.replace).""" - path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(".csv.tmp") - with tmp_path.open("w", encoding="utf-8", newline="") as f: - writer = csv.writer(f) - writer.writerow(REVIEW_COLUMNS) - for row in rows: - writer.writerow([ - row.clip_id, - row.source, - row.file_rel, - row.resolved_npz_path, - row.resolved_split, - row.num_frames, - row.fps, - f"{row.duration_s:.4f}", - row.weight, - row.clip_index, - row.decision, - row.difficulty, - row.issue_tags, - row.note, - row.reviewed_at, - ]) - os.replace(tmp_path, path) - - -def compute_review_stats(rows: list[ReviewRow]) -> ReviewStats: - """Compute aggregate statistics from review rows.""" - total = len(rows) - reviewed = sum(1 for r in rows if r.decision != "") - keep_count = sum(1 for r in rows if r.decision == "keep") - drop_count = sum(1 for r in rows if r.decision == "drop") - skip_count = sum(1 for r in rows if r.decision == "skip") - - kept_duration_s = 0.0 - kept_train_duration_s = 0.0 - kept_val_duration_s = 0.0 - kept_duration_by_source: dict[str, float] = {} - - for r in rows: - if r.decision != "keep": - continue - kept_duration_s += r.duration_s - if r.resolved_split == "train": - kept_train_duration_s += r.duration_s - elif r.resolved_split == "val": - kept_val_duration_s += r.duration_s - kept_duration_by_source[r.source] = ( - kept_duration_by_source.get(r.source, 0.0) + r.duration_s - ) - - progress_pct = (reviewed / total * 100.0) if total > 0 else 0.0 - - return ReviewStats( - total=total, - reviewed=reviewed, - keep_count=keep_count, - drop_count=drop_count, - skip_count=skip_count, - progress_pct=progress_pct, - kept_duration_s=kept_duration_s, - kept_train_duration_s=kept_train_duration_s, - kept_val_duration_s=kept_val_duration_s, - kept_duration_by_source=kept_duration_by_source, - ) - - -def utc_now_iso() -> str: - return datetime.now(timezone.utc).isoformat() diff --git a/train_mimic/scripts/benchmark.py b/train_mimic/scripts/benchmark.py index f9748e84..50c59f1a 100644 --- a/train_mimic/scripts/benchmark.py +++ b/train_mimic/scripts/benchmark.py @@ -10,7 +10,7 @@ # Benchmark only (no video) python train_mimic/scripts/benchmark.py \ --checkpoint logs/rsl_rl/g1_tracking/.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val \ + --motion_file data/datasets/seed_precomputed \ --num_envs 1 # Single video (one continuous clip) @@ -27,6 +27,7 @@ import os from pathlib import Path +import h5py import numpy as np from tensordict import TensorDictBase @@ -39,6 +40,7 @@ validate_checkpoint_path, validate_motion_file, ) +from train_mimic.data.dataset_lib import find_precomputed_motion_shards from teleopit.debug.rollout_trace import RolloutTraceWriter @@ -142,7 +144,7 @@ def _stats(values: list[float]) -> dict[str, float]: def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Benchmark G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True) - parser.add_argument("--motion_file", type=str, required=True, help="Path to motion shard directory") + parser.add_argument("--motion_file", type=str, required=True, help="Path to precomputed training dataset root containing Teleopit shard_*.h5 files") parser.add_argument("--num_envs", type=int, default=1) parser.add_argument("--num_eval_steps", type=int, default=2000, help="Number of rollout steps for evaluation (default: 2000)") @@ -172,25 +174,25 @@ def parse_args() -> argparse.Namespace: def _load_motion_dir_video_metadata(motion_dir: str) -> tuple[float, int]: - shard_dir = Path(motion_dir) - shard_files = sorted(shard_dir.glob("*.npz")) - if not shard_files: - raise FileNotFoundError(f"no shard NPZ files found in {motion_dir}") - clip_fps: float | None = None max_clip_frames = 0 - for shard_path in shard_files: - motion_data = np.load(shard_path, allow_pickle=True) - cur_fps = float(motion_data["fps"]) - if clip_fps is None: - clip_fps = cur_fps - elif clip_fps != cur_fps: - raise ValueError( - f"inconsistent fps across shards: {shard_path} has {cur_fps}, expected {clip_fps}" - ) - max_clip_frames = max(max_clip_frames, int(np.asarray(motion_data["clip_lengths"]).max())) + for shard_path in find_precomputed_motion_shards(motion_dir): + with h5py.File(shard_path, "r") as h5: + fps_arr = np.asarray(h5["clip_fps"], dtype=np.float32) + if fps_arr.size == 0: + continue + cur_fps = float(fps_arr[0]) + if np.any(fps_arr != cur_fps): + raise ValueError(f"inconsistent fps within HDF5 shard: {shard_path}") + if clip_fps is None: + clip_fps = cur_fps + elif clip_fps != cur_fps: + raise ValueError( + f"inconsistent fps across shards: {shard_path} has {cur_fps}, expected {clip_fps}" + ) + max_clip_frames = max(max_clip_frames, int(np.asarray(h5["clip_lengths"]).max())) if clip_fps is None: - raise ValueError(f"failed reading shard metadata from {motion_dir}") + raise ValueError(f"failed reading HDF5 shard metadata from {motion_dir}") return clip_fps, max_clip_frames @@ -505,9 +507,6 @@ def _open_clip_writer(idx: int): "error_body_ang_vel", "error_joint_pos", "error_joint_vel", - "sampling_entropy", - "sampling_top1_prob", - "sampling_top1_bin", ): if key not in metric_stats: continue diff --git a/train_mimic/scripts/convert_pkl_to_npz.py b/train_mimic/scripts/convert_pkl_to_npz.py index a0afb1d7..382d3744 100644 --- a/train_mimic/scripts/convert_pkl_to_npz.py +++ b/train_mimic/scripts/convert_pkl_to_npz.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 -"""Convert PKL motion files to NPZ format for mjlab MotionCommand. +"""Convert PKL motion files to per-clip NPZ format for dataset building. Reads retargeted PKL files (IsaacGym convention) and converts them to the NPZ -format expected by mjlab's MotionLoader. +clip format consumed by the HDF5 dataset builder. PKL fields: fps : int scalar @@ -23,7 +23,7 @@ body_names : list[str] 30 body names in mjlab G1 robot order IMPORTANT: NPZ body ordering must match mjlab G1 robot body ordering because -mjlab's MotionLoader uses robot body indices to index into body_pos_w. +the dataset builder preserves this order for HDF5 training shards. Usage: # Convert a single file @@ -47,6 +47,7 @@ from train_mimic.data.motion_fk import ( MotionFkExtractor, compute_body_velocities, + finite_diff_velocity, normalize_quaternion, quat_xyzw_to_wxyz, ) @@ -60,8 +61,8 @@ # mjlab G1 robot body ordering (matches robot.body_names from G1_ROBOT_CFG). -# mjlab's MotionLoader uses robot body indices to index into NPZ body_pos_w, -# so NPZ body ordering MUST match this list exactly. +# The dataset builder preserves this order in HDF5 shards, so per-clip NPZ +# body ordering MUST match this list exactly. _MJLAB_G1_BODY_NAMES = [ "pelvis", "left_hip_pitch_link", "left_hip_roll_link", "left_hip_yaw_link", @@ -131,7 +132,7 @@ def convert_pkl_to_arrays( ) # Joint velocity via finite difference - joint_vel = np.gradient(dof_pos, dt, axis=0).astype(np.float32) + joint_vel = finite_diff_velocity(dof_pos, dt) # Convert root quaternion: xyzw -> wxyz root_rot_wxyz = normalize_quaternion(quat_xyzw_to_wxyz(root_rot_xyzw)) @@ -148,6 +149,8 @@ def convert_pkl_to_arrays( return { "fps": fps, + "root_pos": root_pos, + "root_quat_w": root_rot_wxyz, "joint_pos": dof_pos, "joint_vel": joint_vel, "body_pos_w": body_pos_w, @@ -217,7 +220,7 @@ def convert_seed_csv_to_arrays( root_rot_xyzw = np.asarray(pkl_dict["root_rot"], dtype=np.float32) dof_pos = np.asarray(pkl_dict["dof_pos"], dtype=np.float32) - joint_vel = np.gradient(dof_pos, dt, axis=0).astype(np.float32) + joint_vel = finite_diff_velocity(dof_pos, dt) root_rot_wxyz = normalize_quaternion(quat_xyzw_to_wxyz(root_rot_xyzw)) fk_extractor = extractor or MotionFkExtractor() @@ -229,6 +232,8 @@ def convert_seed_csv_to_arrays( return { "fps": fps, + "root_pos": root_pos, + "root_quat_w": root_rot_wxyz, "joint_pos": dof_pos, "joint_vel": joint_vel, "body_pos_w": body_pos_w, diff --git a/train_mimic/scripts/data/build_dataset.py b/train_mimic/scripts/data/build_dataset.py index 74d1739a..0f9c4b2f 100644 --- a/train_mimic/scripts/data/build_dataset.py +++ b/train_mimic/scripts/data/build_dataset.py @@ -41,9 +41,8 @@ def main() -> int: output_root=args.output_root, ) print(f"[DONE] dataset={report['dataset']}") - print(f"[DONE] train={report['splits']['train']['output']}") - print(f"[DONE] val={report['splits']['val']['output']}") - print(f"[DONE] build_info={report['dataset_dir']}/build_info.json") + print(f"[DONE] output={report['dataset_dir']}") + print(f"[DONE] shards={len(report.get('shards', []))}") if args.json: print(json.dumps(report, ensure_ascii=False, indent=2)) return 0 diff --git a/train_mimic/scripts/data/check_motion_npz_fk.py b/train_mimic/scripts/data/check_motion_npz_fk.py index 6e31bc0a..621ac865 100644 --- a/train_mimic/scripts/data/check_motion_npz_fk.py +++ b/train_mimic/scripts/data/check_motion_npz_fk.py @@ -17,7 +17,7 @@ def parse_args() -> argparse.Namespace: "--xml", type=str, default=None, - help="Optional MuJoCo XML path (default: teleopit/retargeting/gmr/assets/unitree_g1/g1_mjlab.xml)", + help="Optional MuJoCo XML path (default: assets/robots/unitree_g1/g1_29dof.xml)", ) parser.add_argument( "--sample_count", diff --git a/train_mimic/scripts/data/ingest_motion.py b/train_mimic/scripts/data/ingest_motion.py index b8953d52..59c6c534 100644 --- a/train_mimic/scripts/data/ingest_motion.py +++ b/train_mimic/scripts/data/ingest_motion.py @@ -26,7 +26,6 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--input", required=True, help="Input file or directory") parser.add_argument("--output", required=True, help="Output clips directory") parser.add_argument("--source", default="source", help="Logical source name used in logs") - parser.add_argument("--weight", type=float, default=1.0, help="Optional source weight metadata") parser.add_argument("--bvh_format", choices=["lafan1", "hc_mocap", "nokov"], default=None) parser.add_argument("--robot_name", default="unitree_g1", help="Robot name for BVH retargeting") parser.add_argument("--max_frames", type=int, default=0, help="Max frames per BVH clip (0 = all)") @@ -42,7 +41,6 @@ def main() -> int: name=args.source, type=args.type, input=args.input, - weight=float(args.weight), bvh_format=args.bvh_format, robot_name=args.robot_name, max_frames=int(args.max_frames), diff --git a/train_mimic/scripts/data/inspect_dataset.py b/train_mimic/scripts/data/inspect_dataset.py new file mode 100644 index 00000000..84caf97c --- /dev/null +++ b/train_mimic/scripts/data/inspect_dataset.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json + +from train_mimic.data.dataset_lib import compute_dataset_stats + + +def _format_duration(seconds: float) -> str: + total_seconds = int(round(seconds)) + hours, rem = divmod(total_seconds, 3600) + minutes, secs = divmod(rem, 60) + if hours: + return f"{hours:d}:{minutes:02d}:{secs:02d} ({seconds / 3600.0:.2f} h)" + return f"{minutes:d}:{secs:02d} ({seconds:.1f} s)" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Inspect a Teleopit motion dataset root.") + parser.add_argument("dataset", type=str, help="Dataset root directory or a single .h5 shard") + parser.add_argument("--json", action="store_true", help="Print full JSON statistics") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + stats = compute_dataset_stats(args.dataset) + if args.json: + print(json.dumps(stats, ensure_ascii=False, indent=2)) + return 0 + + print(f"root: {stats['root']}") + print(f"shards: {stats['shards']}") + print(f"windows: {stats['windows']}") + print(f"source_clips: {stats['source_clips']}") + print(f"frames: {stats['frames']}") + print(f"duration: {_format_duration(float(stats['duration_s']))}") + print(f"fps: {stats['fps']}") + print(f"bodies: {len(stats['body_names'])}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/train_mimic/scripts/data/precompute_dataset.py b/train_mimic/scripts/data/precompute_dataset.py new file mode 100644 index 00000000..323cf009 --- /dev/null +++ b/train_mimic/scripts/data/precompute_dataset.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +"""Precompute derived FK and velocity arrays for minimal Teleopit HDF5 shards.""" + +from __future__ import annotations + +import argparse +import json +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Any + +from train_mimic.data.dataset_lib import ( + find_motion_shards, + write_precomputed_motion_shard, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Precompute FK and velocity arrays for a minimal Teleopit HDF5 dataset." + ) + parser.add_argument("dataset", type=str, help="Dataset root directory or a single .h5 shard") + parser.add_argument( + "--outdir", + type=str, + default=None, + help="Output training dataset directory. Defaults to _precomputed.", + ) + parser.add_argument("--force", action="store_true", help="Overwrite existing precomputed shards") + parser.add_argument("--jobs", type=int, default=1, help="Number of shard-level worker processes") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Optional MuJoCo XML override for FK precompute", + ) + parser.add_argument("--json", action="store_true", help="Print final report JSON") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + if args.jobs <= 0: + raise ValueError(f"--jobs must be positive, got {args.jobs}") + + shard_paths = find_motion_shards(args.dataset) + source_root = Path(args.dataset).expanduser().resolve() + if source_root.is_file(): + default_outdir = source_root.parent.with_name(f"{source_root.parent.name}_precomputed") + else: + default_outdir = source_root.with_name(f"{source_root.name}_precomputed") + outdir = Path(args.outdir).expanduser().resolve() if args.outdir is not None else default_outdir + outdir.mkdir(parents=True, exist_ok=True) + + def output_path_for(source_path: Path) -> Path: + rel = Path(source_path.name) if source_root.is_file() else source_path.relative_to(source_root) + return outdir / rel + + results: list[dict[str, Any]] = [] + + if args.jobs == 1: + for shard_path in shard_paths: + result = write_precomputed_motion_shard( + shard_path, + output_path_for(shard_path), + force=args.force, + model_path=args.model_path, + ) + results.append(result) + print(f"[{result['status'].upper()}] {result['output']}") + else: + with ProcessPoolExecutor(max_workers=args.jobs) as executor: + futures = { + executor.submit( + write_precomputed_motion_shard, + shard_path, + output_path_for(shard_path), + force=args.force, + model_path=args.model_path, + ): shard_path + for shard_path in shard_paths + } + for future in as_completed(futures): + result = future.result() + results.append(result) + print(f"[{result['status'].upper()}] {result['output']}") + + report = { + "dataset": str(Path(args.dataset).expanduser().resolve()), + "outdir": str(outdir), + "shards": len(shard_paths), + "written": sum(1 for item in results if item["status"] == "written"), + "existing": sum(1 for item in results if item["status"] == "existing"), + "results": sorted(results, key=lambda item: item["shard"]), + } + print( + f"[DONE] shards={report['shards']} written={report['written']} " + f"existing={report['existing']}" + ) + if args.json: + print(json.dumps(report, ensure_ascii=True, indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/train_mimic/scripts/data/split_shards.py b/train_mimic/scripts/data/split_shards.py deleted file mode 100644 index 69072263..00000000 --- a/train_mimic/scripts/data/split_shards.py +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python3 -"""Repartition a shard directory into smaller shard NPZ files. - -Each shard is a self-contained merged NPZ (with clip_starts, clip_lengths, -clip_fps, clip_weights, etc.) that can be independently loaded. Splits are -made at clip boundaries — no clip is ever truncated across shards. - -Usage: - python train_mimic/scripts/data/split_shards.py \ - --input data/datasets/seed_v1/train \ - --output data/datasets/seed_v1/train_small_shards \ - --max_size_gb 2 -""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path - -import numpy as np - -MOTION_ARRAY_KEYS = [ - "joint_pos", - "joint_vel", - "body_pos_w", - "body_quat_w", - "body_lin_vel_w", - "body_ang_vel_w", -] - - -def _estimate_frames_per_gb(data: dict) -> float: - """Estimate how many frames fit in 1 GB based on array dtypes/shapes.""" - one_frame_bytes = 0 - for k in MOTION_ARRAY_KEYS: - arr = data[k] - one_frame_bytes += int(np.prod(arr.shape[1:])) * arr.dtype.itemsize - return 1e9 / max(one_frame_bytes, 1) - - -def _load_shard_dir(input_dir: Path) -> dict: - shard_files = sorted(input_dir.glob("*.npz")) - if not shard_files: - raise FileNotFoundError(f"no shard NPZ files found in {input_dir}") - - arrays: dict[str, list[np.ndarray]] = {k: [] for k in MOTION_ARRAY_KEYS} - clip_lengths: list[np.ndarray] = [] - clip_fps: list[np.ndarray] = [] - clip_weights: list[np.ndarray] = [] - clip_sample_starts: list[np.ndarray] = [] - clip_sample_ends: list[np.ndarray] = [] - window_steps: np.ndarray | None = None - fps: int | None = None - body_names: np.ndarray | None = None - - for shard_path in shard_files: - data = np.load(shard_path, allow_pickle=True) - for key in MOTION_ARRAY_KEYS: - arrays[key].append(np.asarray(data[key])) - if fps is None: - fps = int(data["fps"]) - body_names = np.asarray(data["body_names"]) - else: - if int(data["fps"]) != fps: - raise ValueError( - f"inconsistent fps across shards: {shard_path} has {int(data['fps'])}, expected {fps}" - ) - if not np.array_equal(np.asarray(data["body_names"]), body_names): - raise ValueError(f"inconsistent body_names across shards: {shard_path}") - clip_lengths.append(np.asarray(data["clip_lengths"])) - clip_fps.append(np.asarray(data["clip_fps"])) - clip_weights.append(np.asarray(data["clip_weights"])) - if "clip_sample_starts" in data: - clip_sample_starts.append(np.asarray(data["clip_sample_starts"])) - if "clip_sample_ends" in data: - clip_sample_ends.append(np.asarray(data["clip_sample_ends"])) - if "window_steps" in data and window_steps is None: - window_steps = np.asarray(data["window_steps"]) - - merged = {key: np.concatenate(values, axis=0) for key, values in arrays.items()} - merged["fps"] = fps - merged["body_names"] = body_names - merged["clip_lengths"] = np.concatenate(clip_lengths) - merged["clip_fps"] = np.concatenate(clip_fps) - merged["clip_weights"] = np.concatenate(clip_weights) - merged["clip_starts"] = np.zeros(len(merged["clip_lengths"]), dtype=np.int64) - if len(merged["clip_lengths"]) > 1: - merged["clip_starts"][1:] = np.cumsum(merged["clip_lengths"][:-1]) - if clip_sample_starts and sum(len(v) for v in clip_sample_starts) == len(merged["clip_lengths"]): - merged["clip_sample_starts"] = np.concatenate(clip_sample_starts) - if clip_sample_ends and sum(len(v) for v in clip_sample_ends) == len(merged["clip_lengths"]): - merged["clip_sample_ends"] = np.concatenate(clip_sample_ends) - if window_steps is not None: - merged["window_steps"] = window_steps - return merged - - -def split_shards( - input_path: Path, - output_dir: Path, - max_size_gb: float = 2.0, -) -> list[Path]: - """Split a shard directory into smaller shards of approximately *max_size_gb* each.""" - print(f"Loading {input_path} ...") - data = _load_shard_dir(input_path) - - clip_starts = np.asarray(data["clip_starts"]) - clip_lengths = np.asarray(data["clip_lengths"]) - num_clips = len(clip_starts) - total_frames = int(data["joint_pos"].shape[0]) - - has_fps_array = "clip_fps" in data - has_weights = "clip_weights" in data - has_window_steps = "window_steps" in data - has_sample_starts = "clip_sample_starts" in data - has_sample_ends = "clip_sample_ends" in data - fps = data["fps"] - body_names = data["body_names"] - - frames_per_gb = _estimate_frames_per_gb(data) - max_frames_per_shard = int(frames_per_gb * max_size_gb) - - # --- plan shard boundaries (by clip index) --- - shard_ranges: list[tuple[int, int]] = [] # (clip_start_idx, clip_end_idx) - cur_start = 0 - cur_frames = 0 - for i in range(num_clips): - cl = int(clip_lengths[i]) - if cur_frames + cl > max_frames_per_shard and cur_frames > 0: - shard_ranges.append((cur_start, i)) - cur_start = i - cur_frames = 0 - cur_frames += cl - if cur_start < num_clips: - shard_ranges.append((cur_start, num_clips)) - - print( - f" {num_clips} clips, {total_frames} frames -> " - f"{len(shard_ranges)} shards (target ~{max_size_gb} GB each)" - ) - - # --- write shards --- - output_dir.mkdir(parents=True, exist_ok=True) - shard_paths: list[Path] = [] - n_digits = max(3, len(str(len(shard_ranges) - 1))) - - for shard_idx, (c_start, c_end) in enumerate(shard_ranges): - frame_start = int(clip_starts[c_start]) - if c_end < num_clips: - frame_end = int(clip_starts[c_end]) - else: - frame_end = total_frames - - s = slice(frame_start, frame_end) - shard: dict = {} - for k in MOTION_ARRAY_KEYS: - shard[k] = np.asarray(data[k][s]) - - # Rebuild clip metadata relative to this shard - shard_clip_lengths = clip_lengths[c_start:c_end].copy() - shard_clip_starts = np.zeros(c_end - c_start, dtype=np.int64) - if len(shard_clip_lengths) > 1: - shard_clip_starts[1:] = np.cumsum(shard_clip_lengths[:-1]) - - shard["fps"] = fps - shard["body_names"] = body_names - shard["clip_starts"] = shard_clip_starts - shard["clip_lengths"] = shard_clip_lengths - if has_fps_array: - shard["clip_fps"] = np.asarray(data["clip_fps"][c_start:c_end]) - if has_weights: - shard["clip_weights"] = np.asarray(data["clip_weights"][c_start:c_end]) - if has_window_steps: - shard["window_steps"] = np.asarray(data["window_steps"]) - if has_sample_starts: - shard["clip_sample_starts"] = np.asarray(data["clip_sample_starts"][c_start:c_end]) - if has_sample_ends: - shard["clip_sample_ends"] = np.asarray(data["clip_sample_ends"][c_start:c_end]) - - shard_name = f"shard_{shard_idx:0{n_digits}d}.npz" - shard_path = output_dir / shard_name - np.savez(shard_path, **shard) - - shard_frames = frame_end - frame_start - shard_size_mb = shard_path.stat().st_size / 1e6 - print( - f" {shard_name}: {c_end - c_start} clips, " - f"{shard_frames} frames, {shard_size_mb:.0f} MB" - ) - shard_paths.append(shard_path) - - print(f"Done. {len(shard_paths)} shards written to {output_dir}") - return shard_paths - - -def main() -> None: - parser = argparse.ArgumentParser(description="Repartition a shard directory into smaller shards") - parser.add_argument("--input", required=True, type=Path, help="Input shard directory") - parser.add_argument("--output", required=True, type=Path, help="Output shard directory") - parser.add_argument( - "--max_size_gb", - type=float, - default=2.0, - help="Target max size per shard in GB (default: 2.0)", - ) - args = parser.parse_args() - - if not args.input.is_dir(): - print(f"Error: shard directory not found: {args.input}", file=sys.stderr) - sys.exit(1) - - split_shards(args.input, args.output, args.max_size_gb) - - -if __name__ == "__main__": - main() diff --git a/train_mimic/scripts/play.py b/train_mimic/scripts/play.py index 6d2a4b5a..8ad5a9d7 100644 --- a/train_mimic/scripts/play.py +++ b/train_mimic/scripts/play.py @@ -9,18 +9,18 @@ # Native window python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val + --motion_file data/datasets/seed_precomputed # Browser viewer (no display required) python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val \ + --motion_file data/datasets/seed_precomputed \ --viewer viser # Record video instead of interactive viewer python train_mimic/scripts/play.py \ --checkpoint logs/rsl_rl/g1_tracking/2026-.../model_30000.pt \ - --motion_file data/datasets/twist2_full/val \ + --motion_file data/datasets/seed_precomputed \ --video """ @@ -44,7 +44,7 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Play trained G1 tracking policy.") parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") - parser.add_argument("--motion_file", type=str, required=True, help="Path to motion shard directory") + parser.add_argument("--motion_file", type=str, required=True, help="Path to precomputed training dataset root containing Teleopit shard_*.h5 files") parser.add_argument("--num_envs", type=int, default=1) parser.add_argument( "--viewer", type=str, default="native", choices=["native", "viser"], diff --git a/train_mimic/scripts/train.py b/train_mimic/scripts/train.py index f45d5afc..4f8bf7b3 100644 --- a/train_mimic/scripts/train.py +++ b/train_mimic/scripts/train.py @@ -4,24 +4,30 @@ Usage: python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 18000 \ - --motion_file data/datasets/twist2_full/train + --motion_file data/datasets/seed_precomputed # Quick verification python train_mimic/scripts/train.py \ --num_envs 64 --max_iterations 100 \ - --motion_file data/datasets/twist2_full/train + --motion_file data/datasets/seed_precomputed - # With wandb logging + # With W&B logging python train_mimic/scripts/train.py \ --num_envs 4096 --max_iterations 30000 \ - --motion_file data/datasets/twist2_full/train \ - --wandb_project teleopit + --motion_file data/datasets/seed_precomputed \ + --logger wandb + + # With SwanLab logging + python train_mimic/scripts/train.py \ + --num_envs 4096 --max_iterations 30000 \ + --motion_file data/datasets/seed_precomputed \ + --logger swanlab # Resume for additional iterations python train_mimic/scripts/train.py \ --resume logs/rsl_rl/g1_general_tracking//model_12000.pt \ --max_iterations 18000 \ - --motion_file data/datasets/twist2_full/train + --motion_file data/datasets/seed_precomputed """ from __future__ import annotations @@ -44,6 +50,10 @@ validate_motion_file, ) from train_mimic.tasks.tracking.config.constants import DEFAULT_TRAIN_MOTION_FILE +from train_mimic.tasks.tracking.config.env import ( + make_g1_training_robot_cfg, + resolve_g1_training_xml, +) def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: @@ -59,11 +69,25 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ), ) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--wandb_project", type=str, default=None, - help="Enable wandb and set project name (default: tensorboard)") + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb", "swanlab"], + help="Experiment logger backend (default: tensorboard)", + ) parser.add_argument("--experiment_name", type=str, default=None) parser.add_argument("--motion_file", type=str, default=None, - help="Shard directory path containing shard_*.npz files") + help="Precomputed training dataset root containing Teleopit shard_*.h5 files, searched recursively") + parser.add_argument( + "--robot_xml", + type=str, + default=None, + help=( + "MuJoCo XML used for the G1 training robot " + "(default: assets/robots/unitree_g1/g1_29dof.xml)" + ), + ) parser.add_argument( "--resume", type=str, @@ -74,8 +98,14 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: ), ) parser.add_argument("--sampling_mode", type=str, default=None, - choices=["uniform", "adaptive", "adaptive_bin", "start"], + choices=["uniform", "start", "rewind"], help="Motion sampling mode (default: from task config)") + parser.add_argument("--rewind_prob", type=float, default=None, + help="Rewind sampling probability for failed episodes") + parser.add_argument("--rewind_min_steps", type=int, default=None, + help="Minimum policy steps to rewind for rewind sampling") + parser.add_argument("--rewind_max_steps", type=int, default=None, + help="Maximum policy steps to rewind for rewind sampling") parser.add_argument("--device", type=str, default=None) parser.add_argument( "--gpu_ids", @@ -219,6 +249,70 @@ def _resolve_device(args: argparse.Namespace, torch_module: object) -> str: return "cuda:0" if torch_module.cuda.is_available() else "cpu" +def _resolve_worker_seed(base_seed: int, env: dict[str, str] | None = None) -> int: + runtime_env = os.environ if env is None else env + if not _is_distributed_env(runtime_env): + return base_seed + global_rank = int(runtime_env.get("RANK", "0")) + return base_seed + global_rank * 100003 + + +def _is_main_process(env: dict[str, str] | None = None) -> bool: + runtime_env = os.environ if env is None else env + return int(runtime_env.get("RANK", "0")) == 0 + + +def _configure_experiment_logger( + *, + logger_name: str, + agent_cfg: Any, + env_cfg: Any, + log_dir: str, +) -> bool: + """Configure the training logger and return whether SwanLab was started.""" + if logger_name == "tensorboard": + agent_cfg.logger = "tensorboard" + return False + + if logger_name == "wandb": + agent_cfg.logger = "wandb" + agent_cfg.wandb_project = agent_cfg.experiment_name + return False + + if logger_name != "swanlab": + raise ValueError(f"Unsupported logger '{logger_name}'") + + agent_cfg.logger = "tensorboard" + if not _is_main_process(): + return False + + try: + import swanlab + except ModuleNotFoundError: + raise ModuleNotFoundError( + "swanlab package is required for --logger swanlab. Install it with `pip install swanlab`." + ) from None + + swanlab.init( + project=agent_cfg.experiment_name, + name=os.path.basename(log_dir), + log_dir=log_dir, + config={ + "experiment_name": agent_cfg.experiment_name, + "motion_file": env_cfg.commands["motion"].motion_file, + "robot_xml": getattr(env_cfg, "robot_xml", None), + "num_envs": env_cfg.scene.num_envs, + "max_iterations": agent_cfg.max_iterations, + "sampling_mode": env_cfg.commands["motion"].sampling_mode, + "rewind_prob": env_cfg.commands["motion"].rewind_prob, + "rewind_min_steps": env_cfg.commands["motion"].rewind_min_steps, + "rewind_max_steps": env_cfg.commands["motion"].rewind_max_steps, + }, + ) + swanlab.sync_tensorboard_torch(types=["scalar", "scalars", "image", "text"]) + return True + + def _launch_multi_gpu(args: argparse.Namespace, argv: Sequence[str]) -> None: _validate_multi_gpu_args(args) command = _build_torchrun_command(args, argv) @@ -286,12 +380,13 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: load_runner_cls=load_runner_cls, ) - # Default to tensorboard (mjlab defaults to wandb) - if args.wandb_project is None: - agent_cfg.logger = "tensorboard" - # CLI overrides - env_cfg.seed = args.seed + env_cfg.seed = _resolve_worker_seed(args.seed) + robot_xml = resolve_g1_training_xml(args.robot_xml) + if not robot_xml.is_file(): + raise FileNotFoundError(f"G1 training MuJoCo XML not found: {robot_xml}") + env_cfg.scene.entities["robot"] = make_g1_training_robot_cfg(robot_xml) + env_cfg.robot_xml = str(robot_xml) if args.num_envs is not None: env_cfg.scene.num_envs = args.num_envs if args.motion_file is not None: @@ -299,13 +394,16 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: validate_motion_file(env_cfg.commands["motion"].motion_file) if args.sampling_mode is not None: env_cfg.commands["motion"].sampling_mode = args.sampling_mode + if args.rewind_prob is not None: + env_cfg.commands["motion"].rewind_prob = args.rewind_prob + if args.rewind_min_steps is not None: + env_cfg.commands["motion"].rewind_min_steps = args.rewind_min_steps + if args.rewind_max_steps is not None: + env_cfg.commands["motion"].rewind_max_steps = args.rewind_max_steps if args.max_iterations is not None: agent_cfg.max_iterations = args.max_iterations if args.experiment_name is not None: agent_cfg.experiment_name = args.experiment_name - if args.wandb_project is not None: - agent_cfg.logger = "wandb" - agent_cfg.wandb_project = args.wandb_project device = _resolve_device(args, torch) @@ -314,6 +412,12 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: os.makedirs(log_root, exist_ok=True) log_dir = os.path.join(log_root, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) os.makedirs(log_dir, exist_ok=True) + swanlab_active = _configure_experiment_logger( + logger_name=args.logger, + agent_cfg=agent_cfg, + env_cfg=env_cfg, + log_dir=log_dir, + ) # render_mode only needed for video recording render_mode = "rgb_array" if args.video else None @@ -355,6 +459,11 @@ def _handle_shutdown(signum: int, _frame: Any) -> None: if env is not None: with contextlib.suppress(Exception): env.close() + if swanlab_active: + with contextlib.suppress(Exception): + import swanlab + + swanlab.finish() _destroy_process_group(torch) signal.signal(signal.SIGINT, old_sigint) signal.signal(signal.SIGTERM, old_sigterm) diff --git a/train_mimic/tasks/tracking/config/constants.py b/train_mimic/tasks/tracking/config/constants.py index f534929e..17560fa5 100644 --- a/train_mimic/tasks/tracking/config/constants.py +++ b/train_mimic/tasks/tracking/config/constants.py @@ -1,6 +1,6 @@ """Public constants for supported tracking tasks.""" -DEFAULT_TRAIN_MOTION_FILE = "data/datasets/twist2_full/train" +DEFAULT_TRAIN_MOTION_FILE = "data/datasets/seed_precomputed" GENERAL_TRACKING_TASK = "General-Tracking-G1" GENERAL_TRACKING_EXPERIMENT_NAME = "g1_general_tracking" diff --git a/train_mimic/tasks/tracking/config/env.py b/train_mimic/tasks/tracking/config/env.py index d8aa5268..48bc3ecb 100644 --- a/train_mimic/tasks/tracking/config/env.py +++ b/train_mimic/tasks/tracking/config/env.py @@ -3,11 +3,16 @@ from __future__ import annotations from copy import deepcopy +from functools import partial +from pathlib import Path + +import mujoco from mjlab.asset_zoo.robots import G1_ACTION_SCALE, get_g1_robot_cfg from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.utils.noise import UniformNoiseCfg as Unoise @@ -16,6 +21,7 @@ from train_mimic.tasks.tracking.config.constants import DEFAULT_TRAIN_MOTION_FILE from train_mimic.tasks.tracking.mdp import MotionCommandCfg from train_mimic.tasks.tracking.tracking_env_cfg import make_tracking_env_cfg +from teleopit.runtime.assets import UNITREE_G1_XML, missing_gmr_assets_message _TRACKING_BODY_NAMES = ( "pelvis", @@ -34,6 +40,46 @@ "right_wrist_yaw_link", ) +_TRAIN_ONLY_EVENTS = ( + "push_robot", + "base_com", + "add_joint_default_pos", + "physics_material", + "randomize_rigid_body_mass", +) + +def resolve_g1_training_xml(robot_xml: str | Path | None = None) -> Path: + """Resolve the MuJoCo XML used for G1 policy training.""" + if robot_xml is None or str(robot_xml).strip() == "": + return UNITREE_G1_XML.resolve() + + path = Path(robot_xml).expanduser() + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + return path + + +def _get_g1_training_spec(robot_xml: str | Path | None = None) -> mujoco.MjSpec: + xml_path = resolve_g1_training_xml(robot_xml) + if not xml_path.is_file(): + raise FileNotFoundError( + missing_gmr_assets_message(xml_path, label="G1 training MuJoCo XML") + ) + spec = mujoco.MjSpec.from_file(str(xml_path)) + for actuator in list(spec.actuators): + spec.delete(actuator) + for key in list(spec.keys): + spec.delete(key) + return spec + + +def make_g1_training_robot_cfg(robot_xml: str | Path | None = None): + robot_cfg = get_g1_robot_cfg() + robot_cfg.articulation = deepcopy(robot_cfg.articulation) + xml_path = resolve_g1_training_xml(robot_xml) + robot_cfg.spec_fn = partial(_get_g1_training_spec, xml_path) + return robot_cfg + def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: motion_cmd = cfg.commands["motion"] @@ -41,7 +87,8 @@ def _apply_play_mode_overrides(cfg: ManagerBasedRlEnvCfg) -> None: cfg.episode_length_s = int(1e9) cfg.observations["actor"].enable_corruption = False - cfg.events.pop("push_robot", None) + for event_name in _TRAIN_ONLY_EVENTS: + cfg.events.pop(event_name, None) motion_cmd.pose_range = {} motion_cmd.velocity_range = {} motion_cmd.sampling_mode = "start" @@ -67,59 +114,112 @@ def _add_history_obs_groups( _VELCMD_ACTOR_TERMS: dict[str, ObservationTermCfg] = { - "projected_gravity": ObservationTermCfg( + "robot_projected_gravity_b": ObservationTermCfg( func=mdp.projected_gravity, noise=Unoise(n_min=-0.05, n_max=0.05), ), - "ref_base_lin_vel_b": ObservationTermCfg( - func=mdp.ref_base_lin_vel_b, + "ref_anchor_lin_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_lin_vel_b, params={"command_name": "motion"}, ), - "ref_base_ang_vel_b": ObservationTermCfg( - func=mdp.ref_base_ang_vel_b, + "ref_anchor_ang_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_ang_vel_b, params={"command_name": "motion"}, ), "ref_projected_gravity_b": ObservationTermCfg( func=mdp.ref_projected_gravity_b, params={"command_name": "motion"}, ), + "ref_anchor_height": ObservationTermCfg( + func=mdp.ref_anchor_height, + params={"command_name": "motion"}, + ), } _VELCMD_CRITIC_TERMS: dict[str, ObservationTermCfg] = { - "projected_gravity": ObservationTermCfg(func=mdp.projected_gravity), - "ref_base_lin_vel_b": ObservationTermCfg( - func=mdp.ref_base_lin_vel_b, + "robot_projected_gravity_b": ObservationTermCfg(func=mdp.projected_gravity), + "ref_anchor_lin_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_lin_vel_b, params={"command_name": "motion"}, ), - "ref_base_ang_vel_b": ObservationTermCfg( - func=mdp.ref_base_ang_vel_b, + "ref_anchor_ang_vel_b": ObservationTermCfg( + func=mdp.ref_anchor_ang_vel_b, params={"command_name": "motion"}, ), "ref_projected_gravity_b": ObservationTermCfg( func=mdp.ref_projected_gravity_b, params={"command_name": "motion"}, ), + "ref_anchor_height": ObservationTermCfg( + func=mdp.ref_anchor_height, + params={"command_name": "motion"}, + ), } -def make_general_tracking_env_cfg( - *, play: bool = False, -) -> ManagerBasedRlEnvCfg: - """Create the General-Tracking-G1 training env.""" - cfg = make_tracking_env_cfg() - - cfg.scene.entities = {"robot": get_g1_robot_cfg()} +def _configure_self_collision_reward(cfg: ManagerBasedRlEnvCfg) -> None: + excluded_body_names = ( + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ) cfg.scene.sensors = ( + *tuple(getattr(cfg.scene, "sensors", ()) or ()), ContactSensorCfg( name="self_collision", - primary=ContactMatch(mode="subtree", pattern="pelvis", entity="robot"), + # Exclude only primary wrist bodies; wrist vs torso is still caught by torso. + primary=ContactMatch( + mode="body", + pattern=r".*", + entity="robot", + exclude=excluded_body_names, + ), secondary=ContactMatch(mode="subtree", pattern="pelvis", entity="robot"), fields=("found", "force"), - reduce="none", + reduce="maxforce", num_slots=1, history_length=4, ), ) + cfg.rewards["self_collisions"] = RewardTermCfg( + func=mdp.self_collision_cost, + weight=-0.1, + params={ + "sensor_name": "self_collision", + "force_threshold": 1.0, + }, + ) + + +def _configure_feet_acc_reward(cfg: ManagerBasedRlEnvCfg) -> None: + cfg.rewards["feet_acc"] = RewardTermCfg( + func=mdp.joint_acc_l2, + weight=-2.5e-6, + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=r".*ankle.*"), + }, + ) + + +def _configure_additional_wrist_pos_reward(cfg: ManagerBasedRlEnvCfg) -> None: + cfg.rewards["additional_wrist_pos"] = RewardTermCfg( + func=mdp.motion_relative_body_point_position_error_exp, + weight=1.0, + params={ + "command_name": "motion", + "std": 0.12, + "body_names": ("left_wrist_yaw_link", "right_wrist_yaw_link"), + "body_offsets": ((0.18, -0.025, 0.0), (0.18, 0.025, 0.0)), + }, + ) + + +def make_general_tracking_env_cfg( + *, play: bool = False, +) -> ManagerBasedRlEnvCfg: + """Create the General-Tracking-G1 training env.""" + cfg = make_tracking_env_cfg() + + cfg.scene.entities = {"robot": make_g1_training_robot_cfg()} joint_pos_action = cfg.actions["joint_pos"] assert isinstance(joint_pos_action, JointPositionActionCfg) @@ -130,22 +230,32 @@ def make_general_tracking_env_cfg( motion_cmd.anchor_body_name = "torso_link" motion_cmd.body_names = _TRACKING_BODY_NAMES motion_cmd.motion_file = DEFAULT_TRAIN_MOTION_FILE - motion_cmd.sampling_mode = "uniform" + motion_cmd.sampling_mode = "rewind" motion_cmd.window_steps = (0,) - cfg.events["foot_friction"].params[ + cfg.events["physics_material"].params[ "asset_cfg" - ].geom_names = r"^(left|right)_foot[1-7]_collision$" + ].geom_names = r".*_collision$" cfg.events["base_com"].params["asset_cfg"].body_names = ("torso_link",) + cfg.events["randomize_rigid_body_mass"].params[ + "asset_cfg" + ].body_names = ( + "torso_link", + "left_wrist_yaw_link", + "right_wrist_yaw_link", + ) + _configure_self_collision_reward(cfg) + _configure_feet_acc_reward(cfg) + _configure_additional_wrist_pos_reward(cfg) cfg.terminations["ee_body_pos"].params["body_names"] = ( "left_ankle_roll_link", "right_ankle_roll_link", "left_wrist_yaw_link", "right_wrist_yaw_link", ) - cfg.terminations["anchor_pos"].params["threshold"] = 0.4 + cfg.terminations["anchor_pos"].params["threshold"] = 0.25 cfg.terminations["anchor_ori"].params["threshold"] = 1.0 - cfg.terminations["ee_body_pos"].params["threshold"] = 0.4 + cfg.terminations["ee_body_pos"].params["threshold"] = 0.25 cfg.viewer.body_name = "torso_link" cfg.episode_length_s = 10.0 if cfg.sim.njmax < 500: @@ -154,7 +264,7 @@ def make_general_tracking_env_cfg( actor_terms = { key: value for key, value in cfg.observations["actor"].terms.items() - if key not in {"motion_anchor_pos_b", "base_lin_vel"} + if key not in {"ref_anchor_pos_b", "robot_base_lin_vel_b"} } cfg.observations["actor"] = ObservationGroupCfg( terms=actor_terms, diff --git a/train_mimic/tasks/tracking/config/rl.py b/train_mimic/tasks/tracking/config/rl.py index 2c674a36..c35715cc 100644 --- a/train_mimic/tasks/tracking/config/rl.py +++ b/train_mimic/tasks/tracking/config/rl.py @@ -10,7 +10,7 @@ "train_mimic.tasks.tracking.rl.temporal_cnn_model:TemporalCNNModel" ) _CNN_CFG: dict = { - "output_channels": (128, 64, 32), + "output_channels": (256, 128, 64), "kernel_size": 3, "activation": "elu", "global_pool": "avg", @@ -24,7 +24,7 @@ def make_general_tracking_ppo_runner_cfg( return RslRlOnPolicyRunnerCfg( actor=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(1024, 512, 256, 256, 128), + hidden_dims=(2048, 1024, 512, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -36,7 +36,7 @@ def make_general_tracking_ppo_runner_cfg( ), critic=RslRlModelCfg( class_name=_TEMPORAL_CNN_MODEL_CLASS, - hidden_dims=(1024, 512, 256, 256, 128), + hidden_dims=(2048, 1024, 512, 256, 128), activation="elu", obs_normalization=True, cnn_cfg=_CNN_CFG, @@ -48,7 +48,7 @@ def make_general_tracking_ppo_runner_cfg( entropy_coef=0.005, num_learning_epochs=5, num_mini_batches=4, - learning_rate=1.0e-3, + learning_rate=5.0e-4, schedule="adaptive", gamma=0.99, lam=0.95, diff --git a/train_mimic/tasks/tracking/mdp/commands.py b/train_mimic/tasks/tracking/mdp/commands.py index f24e964d..add39de0 100644 --- a/train_mimic/tasks/tracking/mdp/commands.py +++ b/train_mimic/tasks/tracking/mdp/commands.py @@ -2,16 +2,23 @@ import copy import logging -import math from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, Literal +import h5py import mujoco import numpy as np import torch -from train_mimic.data.dataset_lib import compute_clip_sample_ranges, parse_window_steps +from train_mimic.data.dataset_lib import ( + MOTION_ARRAY_KEYS, + compute_clip_sample_ranges, + compute_dataset_stats, + find_precomputed_motion_shards, + parse_window_steps, + validate_precomputed_motion_shard, +) from mjlab.managers import CommandTerm, CommandTermCfg from mjlab.utils.lab_api.math import ( @@ -58,193 +65,166 @@ def _batched_quat_slerp( return result / result.norm(dim=-1, keepdim=True) -def _compute_clip_counts( - motion_ids: torch.Tensor, episode_failed: torch.Tensor, bin_count: int -) -> tuple[torch.Tensor, torch.Tensor]: - """Return (exposure_count, failure_count) per clip bin.""" - exposure = torch.bincount(motion_ids, minlength=bin_count).float() - failure = torch.bincount( - motion_ids[episode_failed], minlength=bin_count - ).float() - return exposure, failure - - -def _compute_clip_failure_rate( - motion_ids: torch.Tensor, episode_failed: torch.Tensor, bin_count: int -) -> torch.Tensor: - """Compute per-clip failure rate for the current adaptive-sampling window.""" - exposure, failure = _compute_clip_counts(motion_ids, episode_failed, bin_count) - rate = torch.zeros(bin_count, dtype=torch.float32, device=motion_ids.device) - valid = exposure > 0 - rate[valid] = failure[valid] / exposure[valid] - return rate - - -def _is_distributed() -> bool: - """Return True when running inside an initialized torch.distributed group.""" - return torch.distributed.is_available() and torch.distributed.is_initialized() +@dataclass +class _MotionBatch: + tensors: dict[str, torch.Tensor] + frame_offsets: torch.Tensor + lengths: torch.Tensor + fps: torch.Tensor + sample_starts: torch.Tensor + sample_ends: torch.Tensor -def _normalize_sampling_probabilities( - sampling_probabilities: torch.Tensor, - *, - adaptive_uniform_ratio: float, - bin_count: int, +def _read_selected_body_array( + h5: h5py.File, + key: str, + body_idx_np: np.ndarray, ) -> torch.Tensor: - """Normalize adaptive sampling weights and fail fast on invalid mass.""" - prob_sum = sampling_probabilities.sum() - if not torch.isfinite(prob_sum) or prob_sum <= 0: - raise ValueError( - "Adaptive sampling produced an invalid probability mass. " - f"sum={prob_sum.item() if torch.isfinite(prob_sum) else prob_sum}, " - f"adaptive_uniform_ratio={adaptive_uniform_ratio}, bin_count={bin_count}. " - "Increase adaptive_uniform_ratio or accumulate failure statistics before " - "using pure adaptive sampling." - ) - - sampling_probabilities = sampling_probabilities / prob_sum - if not torch.isfinite(sampling_probabilities).all(): - raise ValueError("Adaptive sampling probabilities contain NaN or Inf values.") - return sampling_probabilities - - -def _cap_failure_rates(rates: torch.Tensor, beta: float) -> torch.Tensor: - """Clamp per-bin failure rates to ``beta * mean(rates)``. - - When all rates are zero the cap is zero and the returned tensor is all-zero, - which is safe — callers fall back to uniform sampling via the blend term. - """ - f_mean = rates.mean() - return torch.clamp(rates, max=beta * f_mean) - - -def _validate_legacy_adaptive_config(*, adaptive_kernel_size: int, adaptive_lambda: float) -> None: - """Reject legacy adaptive knobs that are no longer implemented.""" - if adaptive_kernel_size != 1: - raise ValueError( - "adaptive_kernel_size is not implemented in the restored adaptive sampler. " - f"Expected 1, got {adaptive_kernel_size}." - ) - if adaptive_lambda != 0.8: - raise ValueError( - "adaptive_lambda is not implemented in the restored adaptive sampler. " - f"Expected 0.8, got {adaptive_lambda}." + """Read only the requested body axis while preserving caller order.""" + dataset = h5[key] + idx = np.asarray(body_idx_np, dtype=np.int64) + if idx.ndim != 1: + raise ValueError(f"body indexes must be 1-D, got {idx.shape}") + if idx.size == 0: + frames = int(dataset.shape[0]) + width = int(dataset.shape[2]) + return torch.empty((frames, 0, width), dtype=torch.float32) + + if np.any(idx < 0) or np.any(idx >= dataset.shape[1]): + raise IndexError( + f"body indexes out of range for {key}: indexes={idx.tolist()}, " + f"num_bodies={dataset.shape[1]}" ) + parts = [ + np.asarray(dataset[:, int(body_idx) : int(body_idx) + 1, :], dtype=np.float32) + for body_idx in idx + ] + arr = np.concatenate(parts, axis=1) + return torch.from_numpy(arr) -def _load_shard_dir(shard_dir: Path) -> dict[str, Any]: - """Load and merge all shard NPZ files from a directory. - - Each shard must be a self-contained merged NPZ with clip metadata. - This function concatenates motion arrays across shards and rebuilds the - clip-level metadata (``clip_starts`` offsets are adjusted). - """ - shard_files = sorted(shard_dir.glob("*.npz")) - if not shard_files: - raise FileNotFoundError(f"No .npz files found in {shard_dir}") - _LOG.info("Loading %d shards from %s ...", len(shard_files), shard_dir) +def _load_all_precomputed_motion( + motion_dir: Path, + *, + body_idx_np: np.ndarray, + device: str, + window_steps: tuple[int, ...], +) -> _MotionBatch: + shard_paths = find_precomputed_motion_shards(motion_dir) + stats = compute_dataset_stats(motion_dir, precomputed=True) + for shard_path in shard_paths: + validate_precomputed_motion_shard(shard_path) + _LOG.info( + "Loading full motion dataset: root=%s shards=%d windows=%d source_clips=%d frames=%d fps=%s", + motion_dir, + stats["shards"], + stats["windows"], + stats["source_clips"], + stats["frames"], + stats["fps"], + ) - motion_keys = [ - "joint_pos", "joint_vel", "body_pos_w", "body_quat_w", - "body_lin_vel_w", "body_ang_vel_w", - ] - arrays: dict[str, list[np.ndarray]] = {k: [] for k in motion_keys} - all_clip_lengths: list[np.ndarray] = [] - all_clip_fps: list[np.ndarray] = [] - all_clip_weights: list[np.ndarray] = [] - all_clip_sample_starts: list[np.ndarray] = [] - all_clip_sample_ends: list[np.ndarray] = [] - window_steps: np.ndarray | None = None - fps = None - body_names: np.ndarray | None = None - - for sf in shard_files: - d = np.load(sf, allow_pickle=True) - for k in motion_keys: - arrays[k].append(np.asarray(d[k])) - - # --- validate metadata consistency across shards --- - cur_fps = d["fps"] - cur_body_names = np.asarray(d["body_names"]) - if fps is None: - fps = cur_fps - body_names = cur_body_names - else: - if int(cur_fps) != int(fps): - raise ValueError( - f"Inconsistent fps across shards: {sf} has {int(cur_fps)}, " - f"expected {int(fps)}" - ) - if not np.array_equal(cur_body_names, body_names): + max_future = max((step for step in window_steps if step > 0), default=0) + max_history = -min((step for step in window_steps if step < 0), default=0) + min_clip_length = max_history + 1 + max_future + 1 # +1 for interpolation + + arrays: dict[str, list[torch.Tensor]] = {key: [] for key in MOTION_ARRAY_KEYS} + lengths_out: list[int] = [] + fps_out: list[int] = [] + sample_starts_out: list[int] = [] + sample_ends_out: list[int] = [] + frame_offsets_out: list[int] = [] + skipped_short = 0 + loaded_frames = 0 + for shard_path in shard_paths: + with h5py.File(shard_path, "r") as h5: + starts = np.asarray(h5["clip_starts"], dtype=np.int64) + lengths = np.asarray(h5["clip_lengths"], dtype=np.int64) + fps = np.asarray(h5["clip_fps"], dtype=np.int64) + frames = int(h5["joint_pos"].shape[0]) + if np.any(starts < 0) or np.any(starts + lengths > frames): raise ValueError( - f"Inconsistent body_names across shards: {sf} differs from first shard" + f"HDF5 shard {shard_path} has clip windows outside joint_pos " + f"frame range: frames={frames}" ) - if "clip_lengths" not in d or "clip_fps" not in d or "clip_weights" not in d: - raise ValueError( - f"Shard {sf} is missing required clip metadata. " - "Expected clip_lengths, clip_fps, and clip_weights." + valid_mask = lengths >= min_clip_length + skipped_short += int(np.count_nonzero(~valid_mask)) + if not np.any(valid_mask): + continue + + shard_frame_base = loaded_frames + arrays["joint_pos"].append( + torch.from_numpy(np.asarray(h5["joint_pos"], dtype=np.float32)) ) - shard_clip_lengths = np.asarray(d["clip_lengths"]) - n_clips = len(shard_clip_lengths) - all_clip_lengths.append(shard_clip_lengths) - all_clip_fps.append(np.asarray(d["clip_fps"])) - all_clip_weights.append(np.asarray(d["clip_weights"])) - - if "clip_sample_starts" in d: - all_clip_sample_starts.append(np.asarray(d["clip_sample_starts"])) - if "clip_sample_ends" in d: - all_clip_sample_ends.append(np.asarray(d["clip_sample_ends"])) - if "window_steps" in d and window_steps is None: - window_steps = np.asarray(d["window_steps"]) - - merged: dict[str, Any] = {k: np.concatenate(v, axis=0) for k, v in arrays.items()} - merged["fps"] = fps - merged["body_names"] = body_names - - clip_lengths = np.concatenate(all_clip_lengths) - clip_starts = np.zeros(len(clip_lengths), dtype=np.int64) - if len(clip_lengths) > 1: - clip_starts[1:] = np.cumsum(clip_lengths[:-1]) - merged["clip_starts"] = clip_starts - merged["clip_lengths"] = clip_lengths - merged["clip_fps"] = np.concatenate(all_clip_fps) - merged["clip_weights"] = np.concatenate(all_clip_weights) - - # Propagate precomputed sample ranges if all shards had them - n_total_clips = len(clip_lengths) - if all_clip_sample_starts and sum(len(a) for a in all_clip_sample_starts) == n_total_clips: - merged["clip_sample_starts"] = np.concatenate(all_clip_sample_starts) - if all_clip_sample_ends and sum(len(a) for a in all_clip_sample_ends) == n_total_clips: - merged["clip_sample_ends"] = np.concatenate(all_clip_sample_ends) - if window_steps is not None: - merged["window_steps"] = window_steps + arrays["joint_vel"].append( + torch.from_numpy(np.asarray(h5["joint_vel"], dtype=np.float32)) + ) + arrays["body_pos_w"].append( + _read_selected_body_array(h5, "body_pos_w", body_idx_np) + ) + arrays["body_quat_w"].append( + _read_selected_body_array(h5, "body_quat_w", body_idx_np) + ) + arrays["body_lin_vel_w"].append( + _read_selected_body_array(h5, "body_lin_vel_w", body_idx_np) + ) + arrays["body_ang_vel_w"].append( + _read_selected_body_array(h5, "body_ang_vel_w", body_idx_np) + ) + loaded_frames += frames + + valid_starts = starts[valid_mask] + valid_lengths = lengths[valid_mask] + valid_fps = fps[valid_mask] + sample_starts, sample_ends = compute_clip_sample_ranges( + valid_lengths, + window_steps=window_steps, + ) + valid_frame_offsets = (shard_frame_base + valid_starts).astype(np.int64) + frame_offsets_out.extend(int(offset) for offset in valid_frame_offsets) + lengths_out.extend(int(length) for length in valid_lengths) + fps_out.extend(int(cur_fps) for cur_fps in valid_fps) + sample_starts_out.extend(int(start) for start in sample_starts) + sample_ends_out.extend(int(end) for end in sample_ends) + if not lengths_out: + raise ValueError(f"HDF5 motion dataset is empty: {motion_dir}") + if skipped_short > 0: + _LOG.warning( + "Ignoring %d HDF5 motion windows shorter than %d frames (window_steps=%s)", + skipped_short, + min_clip_length, + list(window_steps), + ) - _LOG.info( - "Loaded %d shards: %d clips, %d total frames", - len(shard_files), len(clip_lengths), merged["joint_pos"].shape[0], + device_obj = torch.device(device) + lengths = torch.tensor(lengths_out, dtype=torch.long) + frame_offsets = torch.tensor(frame_offsets_out, dtype=torch.long) + return _MotionBatch( + tensors={ + key: torch.cat(values, dim=0).to(device_obj) + for key, values in arrays.items() + }, + frame_offsets=frame_offsets.to(device_obj), + lengths=lengths.to(device_obj), + fps=torch.tensor(fps_out, dtype=torch.float32, device=device_obj), + sample_starts=torch.tensor(sample_starts_out, dtype=torch.long, device=device_obj), + sample_ends=torch.tensor(sample_ends_out, dtype=torch.long, device=device_obj), ) - return merged class MotionLib: """Clip-aware motion library. - Loads a directory of shard NPZ files. Each shard contains flat motion arrays - plus per-clip metadata (``clip_starts``, ``clip_lengths``, ``clip_fps``, - ``clip_weights``). - - Motion data is stored as GPU tensors for fast gather+lerp interpolation. - All indexing, lerp, and slerp run entirely on device with zero CPU - round-trips. Numpy arrays are kept alongside for external consumers - (e.g. runner checkpoint buffers). + Loads all precomputed HDF5 motion windows into memory at startup. """ def __init__( self, motion_file: str, body_indexes: torch.Tensor, + body_names: tuple[str, ...] | list[str] | None = None, device: str = "cpu", window_steps: tuple[int, ...] | list[int] | None = None, ) -> None: @@ -252,122 +232,83 @@ def __init__( self.window_steps = parse_window_steps(window_steps) motion_path = Path(motion_file) - if not motion_path.is_dir(): + if not motion_path.exists(): raise FileNotFoundError( - f"motion_file must be a shard directory, got: {motion_file}" - ) - data = _load_shard_dir(motion_path) - body_idx_np = body_indexes.cpu().numpy() - - self._joint_pos = np.asarray(data["joint_pos"], dtype=np.float32) # (T, 29) - self._joint_vel = np.asarray(data["joint_vel"], dtype=np.float32) # (T, 29) - - # Body arrays: index by selected bodies immediately. Accessing an - # NpzFile key inflates that array from the zip; the intermediate full - # array is released once we slice and discard the reference. - self._body_pos_w = np.asarray( - data["body_pos_w"], dtype=np.float32 - )[:, body_idx_np] - self._body_quat_w = np.asarray( - data["body_quat_w"], dtype=np.float32 - )[:, body_idx_np] - self._body_lin_vel_w = np.asarray( - data["body_lin_vel_w"], dtype=np.float32 - )[:, body_idx_np] - self._body_ang_vel_w = np.asarray( - data["body_ang_vel_w"], dtype=np.float32 - )[:, body_idx_np] - - self.time_step_total = self._joint_pos.shape[0] - - # GPU tensors for fast gather+lerp interpolation (zero CPU round-trips). - self._joint_pos_t = torch.from_numpy(self._joint_pos).to(device) - self._joint_vel_t = torch.from_numpy(self._joint_vel).to(device) - self._body_pos_w_t = torch.from_numpy(self._body_pos_w).to(device) - self._body_quat_w_t = torch.from_numpy(self._body_quat_w).to(device) - self._body_lin_vel_w_t = torch.from_numpy(self._body_lin_vel_w).to(device) - self._body_ang_vel_w_t = torch.from_numpy(self._body_ang_vel_w).to(device) - - # --- clip-aware metadata (small — lives on GPU for sampling) --- - self.clip_starts = torch.tensor(data["clip_starts"], dtype=torch.long, device=device) - self.clip_lengths = torch.tensor(data["clip_lengths"], dtype=torch.long, device=device) - self.clip_weights = torch.tensor( - data["clip_weights"], dtype=torch.float32, device=device - ) - fps_arr = np.asarray(data["clip_fps"]) - if fps_arr.ndim == 0: - self.clip_fps = torch.full( - (len(self.clip_starts),), float(fps_arr), dtype=torch.float32, device=device + f"motion_file must be a dataset root directory or .h5 shard, got: {motion_file}" ) - else: - self.clip_fps = torch.tensor(fps_arr, dtype=torch.float32, device=device) + stats = compute_dataset_stats(motion_path, precomputed=True) - self.num_clips = len(self.clip_starts) - self.clip_dt = 1.0 / self.clip_fps - self.clip_duration_s = (self.clip_lengths.float() - 1.0) * self.clip_dt - - # Compute minimum clip length required by the window - max_future = max((s for s in self.window_steps if s > 0), default=0) - max_history = -min((s for s in self.window_steps if s < 0), default=0) - min_clip_length = max_history + 1 + max_future + 1 # +1 for interpolation - short_mask = self.clip_lengths < min_clip_length - n_short = int(short_mask.sum().item()) - if n_short > 0: - _LOG.warning( - "Disabling %d/%d clips shorter than %d frames (window_steps=%s)", - n_short, self.num_clips, min_clip_length, list(self.window_steps), - ) - self.clip_weights = self.clip_weights.clone() - self.clip_weights[short_mask] = 0.0 - - file_window_steps = parse_window_steps(data["window_steps"]) if "window_steps" in data else (0,) - if ( - "clip_sample_starts" in data - and "clip_sample_ends" in data - and file_window_steps == self.window_steps - ): - self.clip_sample_starts = torch.tensor( - data["clip_sample_starts"], dtype=torch.long, device=device - ) - self.clip_sample_ends = torch.tensor( - data["clip_sample_ends"], dtype=torch.long, device=device - ) + if body_names is None: + body_idx_np = body_indexes.cpu().numpy() else: - # Only compute ranges for clips long enough; short ones get dummy [0,1) - lengths_np = self.clip_lengths.cpu().numpy() - sample_starts = np.zeros(self.num_clips, dtype=np.int64) - sample_ends = np.ones(self.num_clips, dtype=np.int64) - long_mask = ~short_mask.cpu().numpy() - if long_mask.any(): - s, e = compute_clip_sample_ranges( - lengths_np[long_mask], - window_steps=self.window_steps, + dataset_body_names = [str(name) for name in stats["body_names"]] + dataset_body_index_by_name = { + name: index for index, name in enumerate(dataset_body_names) + } + missing_body_names = [ + name for name in body_names if name not in dataset_body_index_by_name + ] + if missing_body_names: + raise ValueError( + "Motion dataset body_names do not contain all requested tracking " + f"bodies. Missing: {missing_body_names}. " + "Rebuild the dataset with the current G1 body metadata or update " + "motion command body_names." ) - sample_starts[long_mask] = s - sample_ends[long_mask] = e - self.clip_sample_starts = torch.tensor( - sample_starts, dtype=torch.long, device=device - ) - self.clip_sample_ends = torch.tensor( - sample_ends, dtype=torch.long, device=device + body_idx_np = np.asarray( + [dataset_body_index_by_name[name] for name in body_names], + dtype=np.int64, ) + + batch = _load_all_precomputed_motion( + motion_path, + body_idx_np=body_idx_np, + device=device, + window_steps=self.window_steps, + ) + self._set_batch(batch) + + def _set_batch(self, batch: _MotionBatch) -> None: + self._batch = batch + self._joint_pos_t = batch.tensors["joint_pos"] + self._joint_vel_t = batch.tensors["joint_vel"] + self._body_pos_w_t = batch.tensors["body_pos_w"] + self._body_quat_w_t = batch.tensors["body_quat_w"] + self._body_lin_vel_w_t = batch.tensors["body_lin_vel_w"] + self._body_ang_vel_w_t = batch.tensors["body_ang_vel_w"] + self.clip_frame_offsets = batch.frame_offsets + + self.clip_lengths = batch.lengths + self.clip_fps = batch.fps + self.num_clips = int(batch.lengths.shape[0]) + self.time_step_total = int(batch.lengths.max().item()) + self.clip_dt = 1.0 / self.clip_fps + self.clip_duration_s = (self.clip_lengths.float() - 1.0) * self.clip_dt + self.clip_sample_starts = batch.sample_starts + self.clip_sample_ends = batch.sample_ends self.clip_sample_start_s = self.clip_sample_starts.float() * self.clip_dt self.clip_sample_end_s = self.clip_sample_ends.float() * self.clip_dt + self.clip_starts = self.clip_frame_offsets + ref_valid_seconds = ( + (self.clip_sample_ends - self.clip_sample_starts).float() / self.clip_fps + ) + if torch.any(ref_valid_seconds <= 0.0): + raise ValueError( + "HDF5 motion dataset contains windows with no valid sample duration " + f"after applying window_steps={list(self.window_steps)}" + ) + self.sample_weights = ref_valid_seconds / torch.sum(ref_valid_seconds) + + def close(self) -> None: + return None # ------------------------------------------------------------------ # Sampling helpers # ------------------------------------------------------------------ def sample_motion_ids(self, n: int) -> torch.Tensor: - """Sample *n* clip indices weighted by ``clip_weights``.""" - total = self.clip_weights.sum() - if total <= 0: - raise ValueError( - "All clip weights are zero — cannot sample. " - "Check that the shard dataset was built with positive weights." - ) - probs = self.clip_weights / total - return torch.multinomial(probs, n, replacement=True) + """Sample *n* clip indices by valid-duration weights.""" + return torch.multinomial(self.sample_weights, int(n), replacement=True) def sample_times(self, motion_ids: torch.Tensor) -> torch.Tensor: """Uniform random time over valid center frames for each motion id.""" @@ -386,69 +327,6 @@ def sample_start_times(self, motion_ids: torch.Tensor) -> torch.Tensor: """Return the earliest valid center time for each motion id.""" return self.clip_sample_start_s[motion_ids] - def build_time_bins( - self, bin_duration_s: float - ) -> dict[str, torch.Tensor | int]: - """Partition clips into fixed-duration time bins. - - Each clip is independently split into bins of ``bin_duration_s`` seconds. - The last bin in a clip may be shorter than ``bin_duration_s``. - - Returns a dict with: - - ``num_bins``: total number of bins across all clips. - - ``bin_clip_id``: (num_bins,) which clip each bin belongs to. - - ``bin_start_s``: (num_bins,) start time (seconds) within the clip. - - ``bin_end_s``: (num_bins,) end time (seconds) within the clip. - - ``bin_duration``: (num_bins,) actual duration of each bin. - - ``clip_bin_offset``: (num_clips,) index of the first bin for each clip - (-1 for zero-weight clips). - """ - clip_ids: list[int] = [] - starts: list[float] = [] - ends: list[float] = [] - clip_bin_offsets = torch.full( - (self.num_clips,), -1, dtype=torch.long, device=self._device - ) - - for i in range(self.num_clips): - w = self.clip_weights[i].item() - if w <= 0: - continue - t0 = self.clip_sample_start_s[i].item() - t1 = self.clip_sample_end_s[i].item() - sampleable = t1 - t0 - if sampleable <= 0: - continue - - clip_bin_offsets[i] = len(clip_ids) - n_bins = math.ceil(sampleable / bin_duration_s) - for j in range(n_bins): - b_start = t0 + j * bin_duration_s - b_end = min(t0 + (j + 1) * bin_duration_s, t1) - clip_ids.append(i) - starts.append(b_start) - ends.append(b_end) - - if not clip_ids: - raise ValueError( - "No valid time bins could be created — all clips have zero " - "weight or zero sampleable range." - ) - - bin_clip_id = torch.tensor(clip_ids, dtype=torch.long, device=self._device) - bin_start_s = torch.tensor(starts, dtype=torch.float32, device=self._device) - bin_end_s = torch.tensor(ends, dtype=torch.float32, device=self._device) - bin_duration = bin_end_s - bin_start_s - - return { - "num_bins": len(clip_ids), - "bin_clip_id": bin_clip_id, - "bin_start_s": bin_start_s, - "bin_end_s": bin_end_s, - "bin_duration": bin_duration, - "clip_bin_offset": clip_bin_offsets, - } - def _compute_interpolation_state( self, motion_ids: torch.Tensor, @@ -456,7 +334,6 @@ def _compute_interpolation_state( steps: tuple[int, ...], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: fps = self.clip_fps[motion_ids] - starts = self.clip_starts[motion_ids] lengths = self.clip_lengths[motion_ids] durations = self.clip_duration_s[motion_ids] @@ -475,10 +352,8 @@ def _compute_interpolation_state( frame_i0 = torch.clamp(frame_i0, min=zero, max=max_frame) frame_i1 = torch.clamp(frame_i1, min=zero, max=max_frame) - idx0 = (starts[:, None] + frame_i0).reshape(-1) - idx1 = (starts[:, None] + frame_i1).reshape(-1) window = len(steps) - return idx0, idx1, alpha.reshape(-1), window + return frame_i0.reshape(-1), frame_i1.reshape(-1), alpha.reshape(-1), window _ALL_KEYS = frozenset(( "joint_pos", "joint_vel", @@ -511,6 +386,9 @@ def get_window_frames( steps, ) batch = motion_ids.shape[0] + frame_offsets = self.clip_frame_offsets[motion_ids] + flat_idx0 = (frame_offsets[:, None] + idx0.reshape(batch, window)).reshape(-1) + flat_idx1 = (frame_offsets[:, None] + idx1.reshape(batch, window)).reshape(-1) want = self._ALL_KEYS if keys is None else keys result: dict[str, torch.Tensor] = {} @@ -520,7 +398,7 @@ def get_window_frames( for key, arr_t in (("joint_pos", self._joint_pos_t), ("joint_vel", self._joint_vel_t)): if key not in want: continue - v0, v1 = arr_t[idx0], arr_t[idx1] + v0, v1 = arr_t[flat_idx0], arr_t[flat_idx1] result[key] = (v0 + a1 * (v1 - v0)).reshape(batch, window, -1) # body arrays: (T, B, D) — GPU gather + lerp, optionally pre-slice bodies @@ -533,20 +411,21 @@ def get_window_frames( if key not in want: continue if body_indices is not None: - v0, v1 = arr_t[idx0][:, body_indices], arr_t[idx1][:, body_indices] + v0 = arr_t[flat_idx0][:, body_indices] + v1 = arr_t[flat_idx1][:, body_indices] else: - v0, v1 = arr_t[idx0], arr_t[idx1] + v0, v1 = arr_t[flat_idx0], arr_t[flat_idx1] interp = v0 + a2 * (v1 - v0) result[key] = interp.reshape(batch, window, *interp.shape[1:]) # body_quat_w: GPU slerp, optionally pre-slice bodies if "body_quat_w" in want: if body_indices is not None: - q0 = self._body_quat_w_t[idx0][:, body_indices] - q1 = self._body_quat_w_t[idx1][:, body_indices] + q0 = self._body_quat_w_t[flat_idx0][:, body_indices] + q1 = self._body_quat_w_t[flat_idx1][:, body_indices] else: - q0 = self._body_quat_w_t[idx0] - q1 = self._body_quat_w_t[idx1] + q0 = self._body_quat_w_t[flat_idx0] + q1 = self._body_quat_w_t[flat_idx1] nb = q0.shape[1] q0_flat = q0.reshape(-1, 4) q1_flat = q1.reshape(-1, 4) @@ -574,9 +453,18 @@ def get_frames( key: value[:, 0] for key, value in windowed.items() } - -# Backward compatibility alias -MotionLoader = MotionLib +def _validate_rewind_sampling_cfg(cfg: Any) -> None: + if cfg.rewind_min_steps < 0: + raise ValueError( + f"rewind_min_steps must be non-negative, got {cfg.rewind_min_steps}" + ) + if cfg.rewind_max_steps < cfg.rewind_min_steps: + raise ValueError( + "rewind_max_steps must be >= rewind_min_steps, got " + f"{cfg.rewind_max_steps} < {cfg.rewind_min_steps}" + ) + if not 0.0 <= cfg.rewind_prob <= 1.0: + raise ValueError(f"rewind_prob must be in [0, 1], got {cfg.rewind_prob}") class MotionCommand(CommandTerm): @@ -585,10 +473,8 @@ class MotionCommand(CommandTerm): def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): super().__init__(cfg, env) - _validate_legacy_adaptive_config( - adaptive_kernel_size=cfg.adaptive_kernel_size, - adaptive_lambda=cfg.adaptive_lambda, - ) + if self.cfg.sampling_mode == "rewind": + _validate_rewind_sampling_cfg(self.cfg) self.robot: Entity = env.scene[cfg.entity_name] self.robot_anchor_body_index = self.robot.body_names.index( @@ -604,6 +490,7 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.motion = MotionLib( self.cfg.motion_file, self.body_indexes, + body_names=self.cfg.body_names, device=self.device, window_steps=self.cfg.window_steps, ) @@ -632,50 +519,6 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.body_lin_vel_b = torch.zeros(self.num_envs, nb, 3, device=self.device) self.body_ang_vel_b = torch.zeros(self.num_envs, nb, 3, device=self.device) - self.bin_count = max(self.motion.num_clips, 1) - self.bin_failed_rate = torch.zeros( - self.bin_count, dtype=torch.float, device=self.device - ) - self._accum_exposure_count = torch.zeros( - self.bin_count, dtype=torch.float, device=self.device - ) - self._accum_failure_count = torch.zeros( - self.bin_count, dtype=torch.float, device=self.device - ) - self.kernel = torch.ones(1, device=self.device) - - # --- adaptive_bin mode: time-bin data structures --- - if self.cfg.sampling_mode == "adaptive_bin": - bin_data = self.motion.build_time_bins(self.cfg.adaptive_bin_duration_s) - self.num_time_bins: int = bin_data["num_bins"] - self.tb_clip_id: torch.Tensor = bin_data["bin_clip_id"] - self.tb_start_s: torch.Tensor = bin_data["bin_start_s"] - self.tb_end_s: torch.Tensor = bin_data["bin_end_s"] - self.tb_duration: torch.Tensor = bin_data["bin_duration"] - self.tb_clip_bin_offset: torch.Tensor = bin_data["clip_bin_offset"] - - self.tb_failed_rate = torch.zeros( - self.num_time_bins, dtype=torch.float, device=self.device - ) - self._tb_accum_exposure = torch.zeros( - self.num_time_bins, dtype=torch.float, device=self.device - ) - self._tb_accum_failure = torch.zeros( - self.num_time_bins, dtype=torch.float, device=self.device - ) - self._env_bin_ids = torch.zeros( - self.num_envs, dtype=torch.long, device=self.device - ) - # Override bin_count so metrics (entropy, top1) use time bins - self.bin_count = self.num_time_bins - _LOG.info( - "adaptive_bin: %d time bins (duration=%.1fs, beta=%.1f, alpha=%.2f)", - self.num_time_bins, - self.cfg.adaptive_bin_duration_s, - self.cfg.adaptive_bin_beta, - self.cfg.adaptive_bin_alpha, - ) - self.metrics["error_anchor_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_anchor_rot"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_anchor_lin_vel"] = torch.zeros( @@ -688,11 +531,7 @@ def __init__(self, cfg: MotionCommandCfg, env: ManagerBasedRlEnv): self.metrics["error_body_rot"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_joint_vel"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_entropy"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_top1_prob"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["sampling_top1_bin"] = torch.zeros(self.num_envs, device=self.device) - # Feet standing state (for feet_air_time_ref rewards) if self.cfg.feet_body_names: self._feet_body_indexes = [ self.cfg.body_names.index(n) for n in self.cfg.feet_body_names @@ -932,110 +771,44 @@ def _update_metrics(self): # Sampling # ------------------------------------------------------------------ - def _adaptive_sampling(self, env_ids: torch.Tensor): - current_motion_ids = self.motion_ids[env_ids] - episode_failed = self._env.termination_manager.terminated[env_ids] - exposure, failure = _compute_clip_counts( - current_motion_ids, episode_failed, self.bin_count - ) - self._accum_exposure_count += exposure - self._accum_failure_count += failure - - sampling_probabilities = ( - self.bin_failed_rate - + self.cfg.adaptive_uniform_ratio / float(self.bin_count) - ) - sampling_probabilities = _normalize_sampling_probabilities( - sampling_probabilities, - adaptive_uniform_ratio=self.cfg.adaptive_uniform_ratio, - bin_count=self.bin_count, - ) - - sampled_clips = torch.multinomial( - sampling_probabilities, len(env_ids), replacement=True - ) - self.motion_ids[env_ids] = sampled_clips - self.motion_times[env_ids] = self.motion.sample_times(sampled_clips) - - entropy = -(sampling_probabilities * (sampling_probabilities + 1e-12).log()).sum() - entropy_norm = entropy / math.log(self.bin_count) if self.bin_count > 1 else 1.0 - top1_prob, top1_bin = sampling_probabilities.max(dim=0) - self.metrics["sampling_entropy"][:] = entropy_norm - self.metrics["sampling_top1_prob"][:] = top1_prob - self.metrics["sampling_top1_bin"][:] = top1_bin.float() / self.bin_count - def _uniform_sampling(self, env_ids: torch.Tensor): self.motion_ids[env_ids] = self.motion.sample_motion_ids(len(env_ids)) self.motion_times[env_ids] = self.motion.sample_times(self.motion_ids[env_ids]) - self.metrics["sampling_entropy"][:] = 1.0 - self.metrics["sampling_top1_prob"][:] = 1.0 / max(self.bin_count, 1) - self.metrics["sampling_top1_bin"][:] = 0.5 - def _time_to_bin_index( - self, motion_ids: torch.Tensor, motion_times: torch.Tensor - ) -> torch.Tensor: - """Map (clip_id, time_in_clip) pairs to time-bin indices.""" - offsets = self.tb_clip_bin_offset[motion_ids] # first bin of each clip - local_time = motion_times - self.tb_start_s[offsets] - bin_within_clip = (local_time / self.cfg.adaptive_bin_duration_s).long() - # Clamp to valid range: last bin of each clip - # Total bins per clip = number of consecutive bins with same clip_id - # Simple approach: clamp so offset + bin_within_clip stays < num_time_bins - # and the clip_id matches - bin_within_clip = torch.clamp(bin_within_clip, min=0) - result = offsets + bin_within_clip - result = torch.clamp(result, max=self.num_time_bins - 1) - return result + def _rewind_sampling(self, env_ids: torch.Tensor) -> None: + _validate_rewind_sampling_cfg(self.cfg) + + previous_motion_ids = self.motion_ids[env_ids].clone() + previous_motion_times = self.motion_times[env_ids].clone() + self._uniform_sampling(env_ids) + if env_ids.numel() == 0 or self.cfg.rewind_prob <= 0.0: + return - def _adaptive_bin_sampling(self, env_ids: torch.Tensor): - """SONIC-inspired bin-based adaptive sampling.""" - # 1. Accumulate failure statistics per time bin - current_bin_ids = self._env_bin_ids[env_ids] episode_failed = self._env.termination_manager.terminated[env_ids] - exposure = torch.bincount( - current_bin_ids, minlength=self.num_time_bins - ).float() - failure = torch.bincount( - current_bin_ids[episode_failed], minlength=self.num_time_bins - ).float() - self._tb_accum_exposure += exposure - self._tb_accum_failure += failure - - # 2. Compute capped, blended probabilities - capped = _cap_failure_rates(self.tb_failed_rate, self.cfg.adaptive_bin_beta) - capped_sum = capped.sum() - if capped_sum > 0: - p_hat = capped / capped_sum - else: - p_hat = torch.ones_like(capped) / self.num_time_bins - - alpha = self.cfg.adaptive_bin_alpha - p_final = alpha * p_hat + (1.0 - alpha) / self.num_time_bins - # Normalize for numerical safety - p_final = p_final / p_final.sum() - - # 3. Sample bins, then sample frames within bins - sampled_bins = torch.multinomial(p_final, len(env_ids), replacement=True) - sampled_clip_ids = self.tb_clip_id[sampled_bins] - bin_starts = self.tb_start_s[sampled_bins] - bin_durs = self.tb_duration[sampled_bins] - sampled_times = bin_starts + torch.rand( - len(env_ids), device=self.device - ) * bin_durs - - self.motion_ids[env_ids] = sampled_clip_ids - self.motion_times[env_ids] = sampled_times - self._env_bin_ids[env_ids] = sampled_bins - - # 4. Update metrics - entropy = -(p_final * (p_final + 1e-12).log()).sum() - entropy_norm = ( - entropy / math.log(self.num_time_bins) if self.num_time_bins > 1 else 1.0 + use_rewind = episode_failed & ( + torch.rand(env_ids.numel(), device=self.device) < self.cfg.rewind_prob + ) + if not torch.any(use_rewind): + return + + rewind_env_ids = env_ids[use_rewind] + rewind_steps = torch.randint( + self.cfg.rewind_min_steps, + self.cfg.rewind_max_steps + 1, + (rewind_env_ids.numel(),), + device=self.device, + dtype=torch.long, ) - top1_prob, top1_bin = p_final.max(dim=0) - self.metrics["sampling_entropy"][:] = entropy_norm - self.metrics["sampling_top1_prob"][:] = top1_prob - self.metrics["sampling_top1_bin"][:] = top1_bin.float() / self.num_time_bins + rewind_s = rewind_steps.to(dtype=self.motion_times.dtype) * float(self._step_dt) + motion_ids = previous_motion_ids[use_rewind] + rewind_times = previous_motion_times[use_rewind] - rewind_s + rewind_times = torch.maximum( + rewind_times, + self.motion.clip_sample_start_s[motion_ids], + ) + + self.motion_ids[rewind_env_ids] = motion_ids + self.motion_times[rewind_env_ids] = rewind_times def _resample_command(self, env_ids: torch.Tensor): if self.cfg.sampling_mode == "start": @@ -1043,12 +816,17 @@ def _resample_command(self, env_ids: torch.Tensor): self.motion_times[env_ids] = self.motion.sample_start_times(self.motion_ids[env_ids]) elif self.cfg.sampling_mode == "uniform": self._uniform_sampling(env_ids) - elif self.cfg.sampling_mode == "adaptive_bin": - self._adaptive_bin_sampling(env_ids) + elif self.cfg.sampling_mode == "rewind": + self._rewind_sampling(env_ids) else: - assert self.cfg.sampling_mode == "adaptive" - self._adaptive_sampling(env_ids) + raise ValueError( + f"Unsupported motion sampling_mode={self.cfg.sampling_mode!r}. " + "Supported modes are 'uniform', 'start', and 'rewind'." + ) + self._reset_envs_to_current_reference(env_ids) + + def _reset_envs_to_current_reference(self, env_ids: torch.Tensor) -> None: if env_ids.numel() == 0: return @@ -1174,54 +952,6 @@ def _update_command(self): self._refresh_body_local_cache() self._update_feet_standing() - if self.cfg.sampling_mode == "adaptive": - # Sync raw counts across ranks for unified statistics - if _is_distributed(): - torch.distributed.all_reduce(self._accum_exposure_count) - torch.distributed.all_reduce(self._accum_failure_count) - - # Only update EMA when new data exists (fixes decay-on-empty-step) - if self._accum_exposure_count.sum() > 0: - valid = self._accum_exposure_count > 0 - global_rate = torch.zeros_like(self.bin_failed_rate) - global_rate[valid] = ( - self._accum_failure_count[valid] - / self._accum_exposure_count[valid] - ) - self.bin_failed_rate = ( - self.cfg.adaptive_alpha * global_rate - + (1 - self.cfg.adaptive_alpha) * self.bin_failed_rate - ) - - self._accum_exposure_count.zero_() - self._accum_failure_count.zero_() - - elif self.cfg.sampling_mode == "adaptive_bin": - # Keep bin ids in sync with current motion_times so that - # failures are attributed to the bin the agent is *in*, not the - # bin it was initially sampled from. - self._env_bin_ids = self._time_to_bin_index( - self.motion_ids, self.motion_times - ) - - if _is_distributed(): - torch.distributed.all_reduce(self._tb_accum_exposure) - torch.distributed.all_reduce(self._tb_accum_failure) - - if self._tb_accum_exposure.sum() > 0: - valid = self._tb_accum_exposure > 0 - global_rate = torch.zeros_like(self.tb_failed_rate) - global_rate[valid] = ( - self._tb_accum_failure[valid] / self._tb_accum_exposure[valid] - ) - ema = self.cfg.adaptive_bin_alpha_ema - self.tb_failed_rate = ( - ema * global_rate + (1 - ema) * self.tb_failed_rate - ) - - self._tb_accum_exposure.zero_() - self._tb_accum_failure.zero_() - # ------------------------------------------------------------------ # Visualization # ------------------------------------------------------------------ @@ -1306,17 +1036,11 @@ class MotionCommandCfg(CommandTermCfg): pose_range: dict[str, tuple[float, float]] = field(default_factory=dict) velocity_range: dict[str, tuple[float, float]] = field(default_factory=dict) joint_position_range: tuple[float, float] = (-0.52, 0.52) - adaptive_kernel_size: int = 1 - adaptive_lambda: float = 0.8 - adaptive_uniform_ratio: float = 0.1 - adaptive_alpha: float = 0.001 - sampling_mode: Literal["adaptive", "adaptive_bin", "uniform", "start"] = "uniform" - # --- adaptive_bin mode params --- - adaptive_bin_duration_s: float = 5.0 - adaptive_bin_beta: float = 5.0 - adaptive_bin_alpha: float = 0.8 - adaptive_bin_alpha_ema: float = 0.001 + sampling_mode: Literal["uniform", "start", "rewind"] = "rewind" window_steps: tuple[int, ...] = (0,) + rewind_prob: float = 0.8 + rewind_min_steps: int = 25 + rewind_max_steps: int = 75 feet_body_names: tuple[str, ...] = () feet_standing_z_threshold: float = 0.18 feet_standing_vxy_threshold: float = 0.2 diff --git a/train_mimic/tasks/tracking/mdp/observations.py b/train_mimic/tasks/tracking/mdp/observations.py index b3180c7f..a7484efa 100644 --- a/train_mimic/tasks/tracking/mdp/observations.py +++ b/train_mimic/tasks/tracking/mdp/observations.py @@ -18,7 +18,17 @@ from mjlab.envs import ManagerBasedRlEnv -def motion_anchor_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_joint_pos(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + return command.joint_pos + + +def ref_joint_vel(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + return command.joint_vel + + +def ref_anchor_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) pos, _ = subtract_frame_transforms( @@ -31,7 +41,7 @@ def motion_anchor_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tens return pos.view(env.num_envs, -1) -def motion_anchor_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) _, ori = subtract_frame_transforms( @@ -44,7 +54,7 @@ def motion_anchor_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tens return mat[..., :2].reshape(mat.shape[0], -1) -def robot_body_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def robot_tracking_body_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) @@ -58,7 +68,7 @@ def robot_body_pos_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: return pos_b.view(env.num_envs, -1) -def robot_body_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def robot_tracking_body_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) @@ -73,18 +83,17 @@ def robot_body_ori_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: # --------------------------------------------------------------------------- -# Velocity-command observation terms: reference velocities and projected -# gravity for the VelCmd task variant. +# Velocity-command observation terms: reference velocities and projected gravity. # --------------------------------------------------------------------------- -def ref_base_lin_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_lin_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: """Reference anchor linear velocity in the robot's body frame. (N, 3)""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) return quat_apply(quat_inv(command.robot_anchor_quat_w), command.anchor_lin_vel_w) -def ref_base_ang_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_ang_vel_b(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: """Reference anchor angular velocity in the robot's body frame. (N, 3)""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) return quat_apply(quat_inv(command.robot_anchor_quat_w), command.anchor_ang_vel_w) @@ -101,19 +110,19 @@ def ref_projected_gravity_b(env: ManagerBasedRlEnv, command_name: str) -> torch. return quat_apply(quat_inv(command.anchor_quat_w), asset.data.gravity_vec_w) -def ref_base_height(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: - """Reference anchor height (z-coordinate). (N, 1) — critic privileged.""" +def ref_anchor_height(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: + """Reference anchor height (z-coordinate). (N, 1)""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) return command.anchor_pos_w[:, 2:3] # --------------------------------------------------------------------------- # Yaw-only variants: use yaw_quat(robot_anchor_quat_w) to decouple -# roll/pitch from the coordinate transform, matching the TWIST2 approach. +# roll/pitch from the coordinate transform. # --------------------------------------------------------------------------- -def motion_anchor_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) pos, _ = subtract_frame_transforms( @@ -126,7 +135,7 @@ def motion_anchor_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch. return pos.view(env.num_envs, -1) -def motion_anchor_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def ref_anchor_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) _, ori = subtract_frame_transforms( @@ -139,7 +148,7 @@ def motion_anchor_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch. return mat[..., :2].reshape(mat.shape[0], -1) -def robot_body_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def robot_tracking_body_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) @@ -154,7 +163,7 @@ def robot_body_pos_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Ten return pos_b.view(env.num_envs, -1) -def robot_body_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: +def robot_tracking_body_ori_b_yaw(env: ManagerBasedRlEnv, command_name: str) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) num_bodies = len(command.cfg.body_names) diff --git a/train_mimic/tasks/tracking/mdp/rewards.py b/train_mimic/tasks/tracking/mdp/rewards.py index 108bf46b..ff1bfe07 100644 --- a/train_mimic/tasks/tracking/mdp/rewards.py +++ b/train_mimic/tasks/tracking/mdp/rewards.py @@ -4,11 +4,8 @@ import torch -from mjlab.entity import Entity -from mjlab.managers.reward_manager import RewardTermCfg -from mjlab.managers.scene_entity_config import SceneEntityCfg -from mjlab.sensor import ContactSensor from mjlab.utils.lab_api.math import ( + quat_apply, quat_error_magnitude, ) @@ -18,9 +15,6 @@ from mjlab.envs import ManagerBasedRlEnv -_DEFAULT_ASSET_CFG = SceneEntityCfg("robot") - - def _get_body_indexes( command: MotionCommand, body_names: tuple[str, ...] | None ) -> list[int]: @@ -31,44 +25,41 @@ def _get_body_indexes( ] -def survival(env: ManagerBasedRlEnv) -> torch.Tensor: - """Match the sibling motion_tracking task's constant alive reward.""" - return torch.ones(env.num_envs, dtype=torch.float32, device=env.device) - - -def joint_pos_tracking_exp( +def motion_global_anchor_position_error_exp( env: ManagerBasedRlEnv, command_name: str, std: float ) -> torch.Tensor: - """Joint position tracking reward using MotionCommand interface.""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) - error = torch.mean(torch.abs(command.joint_pos - command.robot_joint_pos), dim=-1) + error = torch.sum( + torch.square(command.anchor_pos_w - command.robot_anchor_pos_w), dim=-1 + ) return torch.exp(-error / std**2) -def joint_vel_tracking_exp( +def motion_global_anchor_orientation_error_exp( env: ManagerBasedRlEnv, command_name: str, std: float ) -> torch.Tensor: - """Joint velocity tracking reward using MotionCommand interface.""" command = cast(MotionCommand, env.command_manager.get_term(command_name)) - error = torch.mean(torch.abs(command.joint_vel - command.robot_joint_vel), dim=-1) + error = quat_error_magnitude(command.anchor_quat_w, command.robot_anchor_quat_w) ** 2 return torch.exp(-error / std**2) -def motion_global_anchor_position_error_exp( +def motion_global_anchor_linear_velocity_error_exp( env: ManagerBasedRlEnv, command_name: str, std: float ) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) error = torch.sum( - torch.square(command.anchor_pos_w - command.robot_anchor_pos_w), dim=-1 + torch.square(command.anchor_lin_vel_w - command.robot_anchor_lin_vel_w), dim=-1 ) return torch.exp(-error / std**2) -def motion_global_anchor_orientation_error_exp( +def motion_global_anchor_angular_velocity_error_exp( env: ManagerBasedRlEnv, command_name: str, std: float ) -> torch.Tensor: command = cast(MotionCommand, env.command_manager.get_term(command_name)) - error = quat_error_magnitude(command.anchor_quat_w, command.robot_anchor_quat_w) ** 2 + error = torch.sum( + torch.square(command.anchor_ang_vel_w - command.robot_anchor_ang_vel_w), dim=-1 + ) return torch.exp(-error / std**2) @@ -90,6 +81,41 @@ def motion_relative_body_position_error_exp( return torch.exp(-error.mean(-1) / std**2) +def motion_relative_body_point_position_error_exp( + env: ManagerBasedRlEnv, + command_name: str, + std: float, + body_names: tuple[str, ...], + body_offsets: tuple[tuple[float, float, float], ...], +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + body_index_by_name = {name: i for i, name in enumerate(command.cfg.body_names)} + missing_body_names = [name for name in body_names if name not in body_index_by_name] + if missing_body_names: + raise ValueError( + "body_names must exist in the motion command tracking body list: " + f"missing {missing_body_names}" + ) + if len(body_names) != len(body_offsets): + raise ValueError( + "body_offsets must contain one offset for each selected body: " + f"got {len(body_offsets)} offsets for {len(body_names)} bodies" + ) + + body_indexes = [body_index_by_name[name] for name in body_names] + offsets = torch.tensor( + body_offsets, dtype=command.body_pos_relative_w.dtype, device=env.device + ).expand(env.num_envs, -1, -1) + ref_points = command.body_pos_relative_w[:, body_indexes] + quat_apply( + command.body_quat_relative_w[:, body_indexes], offsets + ) + robot_points = command.robot_body_pos_w[:, body_indexes] + quat_apply( + command.robot_body_quat_w[:, body_indexes], offsets + ) + error = torch.sum(torch.square(ref_points - robot_points), dim=-1) + return torch.exp(-error.mean(-1) / std**2) + + def motion_relative_body_orientation_error_exp( env: ManagerBasedRlEnv, command_name: str, @@ -144,364 +170,59 @@ def motion_global_body_angular_velocity_error_exp( return torch.exp(-error.mean(-1) / std**2) -def self_collision_cost( - env: ManagerBasedRlEnv, - sensor_name: str, - force_threshold: float = 10.0, +def motion_joint_position_error_exp( + env: ManagerBasedRlEnv, command_name: str, std: float ) -> torch.Tensor: - """Penalize self-collisions.""" - sensor: ContactSensor = env.scene[sensor_name] - data = sensor.data - if data.force_history is not None: - force_mag = torch.norm(data.force_history, dim=-1) - hit = (force_mag > force_threshold).any(dim=1) - return hit.sum(dim=-1).float() - assert data.found is not None - return data.found.squeeze(-1) - - -class joint_torque_limits: - """Penalize actuator-force limit violations with a configurable soft margin.""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - asset_cfg = cast(SceneEntityCfg, cfg.params.get("asset_cfg", _DEFAULT_ASSET_CFG)) - self.asset: Entity = env.scene[asset_cfg.name] - self.soft_factor = float(cfg.params.get("soft_factor", 0.9)) - if not (0.0 < self.soft_factor <= 1.0): - raise ValueError(f"soft_factor must be in (0, 1], got {self.soft_factor}") - - joint_ids = asset_cfg.joint_ids - if isinstance(joint_ids, slice): - joint_names = list(self.asset.joint_names) - else: - joint_names = [self.asset.joint_names[idx] for idx in joint_ids] - actuator_names = list(self.asset.actuator_names) - name_to_actuator = {name: idx for idx, name in enumerate(actuator_names)} - actuator_ids: list[int] = [] - for joint_name in joint_names: - if joint_name not in name_to_actuator: - raise RuntimeError(f"Actuator for joint '{joint_name}' not found.") - actuator_ids.append(name_to_actuator[joint_name]) - self._actuator_ids = torch.tensor(actuator_ids, device=env.device, dtype=torch.long) - - force_range = torch.as_tensor( - env.sim.model.actuator_forcerange, device=env.device, dtype=torch.float32 - ) - if force_range.ndim == 2: - force_range = force_range.unsqueeze(0).expand(env.num_envs, -1, -1) - elif force_range.ndim != 3: - raise RuntimeError( - f"Unexpected actuator_forcerange shape: {tuple(force_range.shape)}" - ) - ctrl_ids = torch.as_tensor( - self.asset.indexing.ctrl_ids, device=env.device, dtype=torch.long - ) - force_range = force_range.index_select(1, ctrl_ids) - torque_limit = torch.maximum(force_range[..., 0].abs(), force_range[..., 1].abs()) - self._soft_limits = torque_limit.index_select(1, self._actuator_ids) * self.soft_factor - - def __call__( - self, - env: ManagerBasedRlEnv, - asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, - soft_factor: float = 0.9, - ) -> torch.Tensor: - del env, asset_cfg, soft_factor - applied_torque = self.asset.data.actuator_force.index_select(1, self._actuator_ids) - violation = (applied_torque.abs() / self._soft_limits - 1.0).clamp_min(0.0) - return -violation.sum(dim=1) - - -class feet_air_time_ref: - """Reward landing timing that matches the reference contact schedule.""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - self.sensor_name = str(cfg.params["sensor_name"]) - self.command_name = str(cfg.params["command_name"]) - self.threshold = float(cfg.params.get("thres", 0.5)) - self._reward_time = torch.zeros((env.num_envs, 0), dtype=torch.float32, device=env.device) - - def reset(self, env_ids: torch.Tensor | slice | None) -> None: - if env_ids is None: - env_ids = slice(None) - self._reward_time[env_ids] = 0.0 - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - command_name: str, - thres: float = 0.5, - ) -> torch.Tensor: - del sensor_name, command_name, thres - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.found is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'found'") - current_contact = data.found > 0 - first_contact = sensor.compute_first_contact(dt=env.step_dt) - command = cast(object, env.command_manager.get_term(self.command_name)) - target_contact = torch.as_tensor( - cast(object, command).feet_standing, device=env.device, dtype=torch.bool - ) - if self._reward_time.shape != current_contact.shape: - self._reward_time = torch.zeros_like(current_contact, dtype=torch.float32) - - contact_match = current_contact == target_contact - self._reward_time = self._reward_time + torch.where( - contact_match, - torch.full_like(self._reward_time, env.step_dt), - torch.full_like(self._reward_time, -env.step_dt), - ) - reward = torch.sum( - (self._reward_time - self.threshold).clamp_max(0.0) * first_contact.float(), - dim=1, - ) - self._reward_time = self._reward_time * (~current_contact).float() - return reward - - -class feet_air_time_ref_dense: - """Dense contact/height shaping against the reference foot contact state.""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - asset_cfg = cast(SceneEntityCfg, cfg.params.get("asset_cfg", _DEFAULT_ASSET_CFG)) - self.asset: Entity = env.scene[asset_cfg.name] - body_names = cast(list[str], cfg.params["body_names"]) - body_ids, _ = self.asset.find_bodies(body_names, preserve_order=True) - self._body_ids = torch.tensor(body_ids, device=env.device, dtype=torch.long) - - site_names = cast(list[str] | None, cfg.params.get("site_names")) - self._site_ids: torch.Tensor | None = None - if site_names is not None: - site_ids, _ = self.asset.find_sites(site_names, preserve_order=True) - if len(site_ids) != len(body_ids): - raise ValueError( - "site_names must match body_names length for feet_air_time_ref_dense" - ) - self._site_ids = torch.tensor(site_ids, device=env.device, dtype=torch.long) - else: - body2_names = cast(list[str] | None, cfg.params.get("body2_names")) - if body2_names is None: - self._body2_ids = self._body_ids - else: - body2_ids, _ = self.asset.find_bodies(body2_names, preserve_order=True) - if len(body2_ids) != len(body_ids): - raise ValueError( - "body2_names must match body_names length for feet_air_time_ref_dense" - ) - self._body2_ids = torch.tensor(body2_ids, device=env.device, dtype=torch.long) - - self.sensor_name = str(cfg.params["sensor_name"]) - self.command_name = str(cfg.params["command_name"]) - self.air_h_low = float(cfg.params.get("air_h_low", 0.035)) - self.air_h_high = float(cfg.params.get("air_h_high", 0.155)) - self.contact_h_low = float(cfg.params.get("contact_h_low", 0.035)) - self.contact_h_high = float(cfg.params.get("contact_h_high", 0.125)) - self.air_h_span = max(self.air_h_high - self.air_h_low, 1.0e-6) - self.contact_h_span = max(self.contact_h_high - self.contact_h_low, 1.0e-6) - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - command_name: str, - body_names: tuple[str, ...], - body2_names: tuple[str, ...] | None = None, - site_names: tuple[str, ...] | None = None, - air_h_low: float = 0.035, - air_h_high: float = 0.155, - contact_h_low: float = 0.035, - contact_h_high: float = 0.125, - asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, - ) -> torch.Tensor: - del ( - sensor_name, - command_name, - body_names, - body2_names, - site_names, - air_h_low, - air_h_high, - contact_h_low, - contact_h_high, - asset_cfg, - ) - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.found is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'found'") - current_contact = data.found > 0 - command = cast(object, env.command_manager.get_term(self.command_name)) - target_contact = torch.as_tensor( - cast(object, command).feet_standing, device=env.device, dtype=torch.bool - ) + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + error = torch.square(command.joint_pos - command.robot_joint_pos) + return torch.exp(-error.mean(-1) / std**2) - mismatch = current_contact ^ target_contact - both_air = (~current_contact) & (~target_contact) - both_contact = current_contact & target_contact - penalty = torch.zeros_like(target_contact, dtype=torch.float32) - penalty[mismatch] = -1.0 +def motion_joint_velocity_error_exp( + env: ManagerBasedRlEnv, command_name: str, std: float +) -> torch.Tensor: + command = cast(MotionCommand, env.command_manager.get_term(command_name)) + error = torch.square(command.joint_vel - command.robot_joint_vel) + return torch.exp(-error.mean(-1) / std**2) - if self._site_ids is not None: - foot_probe_height = self.asset.data.site_pos_w.index_select(1, self._site_ids)[..., 2] - feet_height_air = foot_probe_height - feet_height_contact = foot_probe_height - else: - feet_height_air = torch.minimum( - self.asset.data.body_link_pos_w.index_select(1, self._body_ids)[..., 2], - self.asset.data.body_link_pos_w.index_select(1, self._body2_ids)[..., 2], - ) - feet_height_contact = torch.maximum( - self.asset.data.body_link_pos_w.index_select(1, self._body_ids)[..., 2], - self.asset.data.body_link_pos_w.index_select(1, self._body2_ids)[..., 2], - ) - air_ratio = ((feet_height_air - self.air_h_low) / self.air_h_span).clamp(0.0, 1.0) - penalty = torch.where(both_air, -(1.0 - air_ratio), penalty) - - contact_ratio = ( - (feet_height_contact - self.contact_h_low) / self.contact_h_span - ).clamp(0.0, 1.0) - penalty = torch.where(both_contact, -contact_ratio, penalty) - return penalty.mean(dim=1) - - -# --------------------------------------------------------------------------- -# TWIST2-style feet rewards (adapted to mjlab API) -# --------------------------------------------------------------------------- - - -class feet_air_time: - """Reward feet air time matching a target duration (TWIST2-style). - - Accumulates air time per foot; on first contact, rewards - ``(air_time - target).clamp(max=0)`` so that too-short air phases are - penalised. Only active when the reference root XY speed > 0.05 m/s. - """ - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - self.sensor_name = str(cfg.params["sensor_name"]) - self.command_name = str(cfg.params["command_name"]) - self.air_time_target = float(cfg.params.get("air_time_target", 0.5)) - self.feet_air_time = torch.zeros( - (env.num_envs, 0), dtype=torch.float32, device=env.device - ) - self._last_contacts = torch.zeros( - (env.num_envs, 0), dtype=torch.bool, device=env.device - ) - def reset(self, env_ids: torch.Tensor | slice | None) -> None: - if env_ids is None: - env_ids = slice(None) - self.feet_air_time[env_ids] = 0.0 - self._last_contacts[env_ids] = False - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - command_name: str, - air_time_target: float = 0.5, - ) -> torch.Tensor: - del sensor_name, command_name, air_time_target - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'force'") - contact = data.force[..., 2].abs() > 5.0 # [B, N] — TWIST2: fz > 5N - - # Lazy init on first call (shape unknown at __init__) - if self.feet_air_time.shape != contact.shape: - self.feet_air_time = torch.zeros_like(contact, dtype=torch.float32) - self._last_contacts = torch.zeros_like(contact, dtype=torch.bool) - - # contact_filt = contact OR last_contacts (same as TWIST2) - contact_filt = contact | self._last_contacts - self._last_contacts = contact - - first_contact = (self.feet_air_time > 0.0) & contact_filt - - self.feet_air_time += env.step_dt - air_time = (self.feet_air_time - self.air_time_target) * first_contact.float() - air_time = air_time.clamp(max=0.0) - self.feet_air_time *= ~contact_filt - - reward = air_time.sum(dim=1) - - # Gate by reference root XY speed > 0.05 - command = cast(MotionCommand, env.command_manager.get_term(self.command_name)) - ref_root_vxy = torch.norm(command.body_lin_vel_w[:, 0, :2], dim=-1) - reward *= (ref_root_vxy > 0.05).float() - - return reward - - -def feet_stumble( +def survival(env: ManagerBasedRlEnv) -> torch.Tensor: + return torch.ones(env.num_envs, device=env.device) + + +def self_collision_cost( env: ManagerBasedRlEnv, - sensor_name: str, + sensor_name: str | tuple[str, ...], + force_threshold: float = 10.0, ) -> torch.Tensor: - """Penalise stumbling: lateral contact force > 4x vertical (TWIST2-style).""" - sensor: ContactSensor = env.scene[sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{sensor_name}' must expose 'force'") - force = data.force # [B, N, 3] - lateral = torch.norm(force[..., :2], dim=-1) # [B, N] - vertical = torch.abs(force[..., 2]) # [B, N] - stumble = torch.any(lateral > 4.0 * vertical, dim=1) - return stumble.float() - - -def feet_contact_forces( + """Penalize self-collision slots above the configured force threshold.""" + hit = _self_collision_hits(env, sensor_name, force_threshold) + return hit.sum(dim=-1).float() + + +def _self_collision_hits( env: ManagerBasedRlEnv, - sensor_name: str, - max_contact_force: float = 350.0, + sensor_name: str | tuple[str, ...], + force_threshold: float, ) -> torch.Tensor: - """Penalise excessive vertical contact forces (TWIST2-style). - - Computes L2 norm of vertical forces across all feet, then penalises the - excess above *max_contact_force*. This matches TWIST2's implementation - where two feet at 300 N each (norm ≈ 424 N) would trigger a penalty. - """ - sensor: ContactSensor = env.scene[sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{sensor_name}' must expose 'force'") - fz = data.force[..., 2] # [B, N] - fz_norm = torch.norm(fz, dim=-1) # [B] - excess = (fz_norm - max_contact_force).clamp(min=0.0) - return excess - - -class feet_slip: - """Penalise horizontal foot velocity while in contact (TWIST2-style).""" - - def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): - asset_cfg = cast(SceneEntityCfg, cfg.params.get("asset_cfg", _DEFAULT_ASSET_CFG)) - self.asset: Entity = env.scene[asset_cfg.name] - body_names = cast(list[str], cfg.params["body_names"]) - body_ids, _ = self.asset.find_bodies(body_names, preserve_order=True) - self._body_ids = torch.tensor(body_ids, device=env.device, dtype=torch.long) - self.sensor_name = str(cfg.params["sensor_name"]) - - def __call__( - self, - env: ManagerBasedRlEnv, - sensor_name: str, - body_names: tuple[str, ...], - asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, - ) -> torch.Tensor: - del sensor_name, body_names, asset_cfg - sensor: ContactSensor = env.scene[self.sensor_name] - data = sensor.data - if data.force is None: - raise RuntimeError(f"Contact sensor '{self.sensor_name}' must expose 'force'") - contact = (data.force[..., 2].abs() > 5.0).float() # [B, N] — TWIST2: fz > 5N - - feet_vel_xy = self.asset.data.body_link_lin_vel_w.index_select( - 1, self._body_ids - )[..., :2] # [B, N, 2] - speed_xy = torch.norm(feet_vel_xy, dim=-1) # [B, N] - slip = torch.sqrt(speed_xy) * contact - return slip.sum(dim=1) + sensor_names = (sensor_name,) if isinstance(sensor_name, str) else sensor_name + force_histories = [] + found_values = [] + for name in sensor_names: + data = env.scene[name].data + if data.force_history is not None: + force_histories.append(data.force_history) + else: + assert data.found is not None + found = data.found + if found.ndim == 1: + found = found.unsqueeze(-1) + found_values.append(found) + + if force_histories: + force_history = torch.cat(force_histories, dim=1) + force_mag = torch.norm(force_history, dim=-1) + return (force_mag > force_threshold).any(dim=2) + + found = torch.cat(found_values, dim=1) + return (found > 0).any(dim=1, keepdim=True) diff --git a/train_mimic/tasks/tracking/rl/runner.py b/train_mimic/tasks/tracking/rl/runner.py index 9bd143b8..4614ce50 100644 --- a/train_mimic/tasks/tracking/rl/runner.py +++ b/train_mimic/tasks/tracking/rl/runner.py @@ -2,16 +2,13 @@ import pathlib import statistics import time -from typing import cast import torch from rsl_rl.env.vec_env import VecEnv -from torch import nn from mjlab.rl import RslRlVecEnvWrapper from mjlab.rl.runner import MjlabOnPolicyRunner from rsl_rl.utils import check_nan -from train_mimic.tasks.tracking.mdp import MotionCommand def _one_based_iteration_range(start_iteration: int, total_iterations: int) -> range: @@ -34,41 +31,12 @@ def _resolve_total_iterations(start_iteration: int, num_learning_iterations: int return start_iteration + num_learning_iterations -class _OnnxMotionModel(nn.Module): - """ONNX-exportable model that wraps the policy and bundles motion reference data.""" - - def __init__(self, actor, motion): - super().__init__() - self.policy = actor.as_onnx(verbose=False) - # torch.from_numpy shares memory with the underlying numpy array - # (zero-copy). The arrays are already writable (loaded from .npz). - self.register_buffer("joint_pos", torch.from_numpy(motion._joint_pos)) - self.register_buffer("joint_vel", torch.from_numpy(motion._joint_vel)) - self.register_buffer("body_pos_w", torch.from_numpy(motion._body_pos_w)) - self.register_buffer("body_quat_w", torch.from_numpy(motion._body_quat_w)) - self.register_buffer("body_lin_vel_w", torch.from_numpy(motion._body_lin_vel_w)) - self.register_buffer("body_ang_vel_w", torch.from_numpy(motion._body_ang_vel_w)) - self.time_step_total: int = self.joint_pos.shape[0] # type: ignore[index] - - def forward(self, *args): - # Last arg is always time_step; preceding args are policy inputs. - *policy_args, time_step = args - time_step_clamped = torch.clamp( - time_step.long().squeeze(-1), max=self.time_step_total - 1 - ) - if len(policy_args) == 1: - policy_out = self.policy(policy_args[0]) - else: - policy_out = self.policy(*policy_args) - return ( - policy_out, - self.joint_pos[time_step_clamped], # type: ignore[index] - self.joint_vel[time_step_clamped], # type: ignore[index] - self.body_pos_w[time_step_clamped], # type: ignore[index] - self.body_quat_w[time_step_clamped], # type: ignore[index] - self.body_lin_vel_w[time_step_clamped], # type: ignore[index] - self.body_ang_vel_w[time_step_clamped], # type: ignore[index] - ) +def _format_duration(seconds: float) -> str: + """Format elapsed/remaining seconds without wrapping after 24 hours.""" + total_seconds = max(0, int(seconds)) + hours, remainder = divmod(total_seconds, 3600) + minutes, secs = divmod(remainder, 60) + return f"{hours:02d}:{minutes:02d}:{secs:02d}" class MotionTrackingOnPolicyRunner(MjlabOnPolicyRunner): @@ -264,9 +232,9 @@ def _log_one_based_iteration( """ f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s """ - f"""{'Time elapsed:':>{pad}} {time.strftime('%H:%M:%S', time.gmtime(logger.tot_time))} + f"""{'Time elapsed:':>{pad}} {_format_duration(logger.tot_time)} """ - f"""{'ETA:':>{pad}} {time.strftime('%H:%M:%S', time.gmtime(eta))} + f"""{'ETA:':>{pad}} {_format_duration(eta)} """ ) print(log_string) @@ -282,41 +250,9 @@ def export_policy_to_onnx( path: str, filename: str = "policy.onnx", verbose: bool = False, - *, - include_motion_labels: bool = False, ) -> None: os.makedirs(path, exist_ok=True) output_path = os.path.join(path, filename) - if include_motion_labels: - cmd = cast(MotionCommand, self.env.unwrapped.command_manager.get_term("motion")) - model = _OnnxMotionModel(self.alg.get_policy(), cmd.motion) - model.to("cpu") - model.eval() - dummy_inputs = model.policy.get_dummy_inputs() - time_step = torch.zeros(1, 1) - input_names = model.policy.input_names + ["time_step"] - torch.onnx.export( - model, - (*dummy_inputs, time_step), - output_path, - export_params=True, - opset_version=18, - verbose=verbose, - input_names=input_names, - output_names=[ - "actions", - "joint_pos", - "joint_vel", - "body_pos_w", - "body_quat_w", - "body_lin_vel_w", - "body_ang_vel_w", - ], - dynamic_axes={}, - dynamo=False, - ) - return - model = self.alg.get_policy().as_onnx(verbose=False) model.to("cpu") model.eval() diff --git a/train_mimic/tasks/tracking/tracking_env_cfg.py b/train_mimic/tasks/tracking/tracking_env_cfg.py index ce2e252b..eb1faaae 100644 --- a/train_mimic/tasks/tracking/tracking_env_cfg.py +++ b/train_mimic/tasks/tracking/tracking_env_cfg.py @@ -1,6 +1,6 @@ """Base motion tracking task configuration. -Copied from mjlab 1.2.0 ``mjlab.tasks.tracking.tracking_env_cfg`` for local +Copied from mjlab 1.4.0 ``mjlab.tasks.tracking.tracking_env_cfg`` for local customisation. All observation / reward / termination / event terms still reference ``mjlab.tasks.tracking.mdp`` — only the *wiring* lives here. """ @@ -42,65 +42,71 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: ## actor_terms = { - "command": ObservationTermCfg( - func=mdp.generated_commands, params={"command_name": "motion"} + "ref_joint_pos": ObservationTermCfg( + func=mdp.ref_joint_pos, params={"command_name": "motion"} ), - "motion_anchor_pos_b": ObservationTermCfg( - func=mdp.motion_anchor_pos_b, + "ref_joint_vel": ObservationTermCfg( + func=mdp.ref_joint_vel, params={"command_name": "motion"} + ), + "ref_anchor_pos_b": ObservationTermCfg( + func=mdp.ref_anchor_pos_b, params={"command_name": "motion"}, noise=Unoise(n_min=-0.25, n_max=0.25), ), - "motion_anchor_ori_b": ObservationTermCfg( - func=mdp.motion_anchor_ori_b, + "ref_anchor_ori_b": ObservationTermCfg( + func=mdp.ref_anchor_ori_b, params={"command_name": "motion"}, noise=Unoise(n_min=-0.05, n_max=0.05), ), - "base_lin_vel": ObservationTermCfg( + "robot_base_lin_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, params={"sensor_name": "robot/imu_lin_vel"}, noise=Unoise(n_min=-0.5, n_max=0.5), ), - "base_ang_vel": ObservationTermCfg( + "robot_base_ang_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, params={"sensor_name": "robot/imu_ang_vel"}, noise=Unoise(n_min=-0.2, n_max=0.2), ), - "joint_pos": ObservationTermCfg( + "robot_joint_pos_rel": ObservationTermCfg( func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01), params={"biased": True}, ), - "joint_vel": ObservationTermCfg( + "robot_joint_vel": ObservationTermCfg( func=mdp.joint_vel_rel, noise=Unoise(n_min=-0.5, n_max=0.5) ), - "actions": ObservationTermCfg(func=mdp.last_action), + "prev_action": ObservationTermCfg(func=mdp.last_action), } critic_terms = { - "command": ObservationTermCfg( - func=mdp.generated_commands, params={"command_name": "motion"} + "ref_joint_pos": ObservationTermCfg( + func=mdp.ref_joint_pos, params={"command_name": "motion"} + ), + "ref_joint_vel": ObservationTermCfg( + func=mdp.ref_joint_vel, params={"command_name": "motion"} ), - "motion_anchor_pos_b": ObservationTermCfg( - func=mdp.motion_anchor_pos_b, params={"command_name": "motion"} + "ref_anchor_pos_b": ObservationTermCfg( + func=mdp.ref_anchor_pos_b, params={"command_name": "motion"} ), - "motion_anchor_ori_b": ObservationTermCfg( - func=mdp.motion_anchor_ori_b, params={"command_name": "motion"} + "ref_anchor_ori_b": ObservationTermCfg( + func=mdp.ref_anchor_ori_b, params={"command_name": "motion"} ), - "body_pos": ObservationTermCfg( - func=mdp.robot_body_pos_b, params={"command_name": "motion"} + "robot_tracking_body_pos_b": ObservationTermCfg( + func=mdp.robot_tracking_body_pos_b, params={"command_name": "motion"} ), - "body_ori": ObservationTermCfg( - func=mdp.robot_body_ori_b, params={"command_name": "motion"} + "robot_tracking_body_ori_b": ObservationTermCfg( + func=mdp.robot_tracking_body_ori_b, params={"command_name": "motion"} ), - "base_lin_vel": ObservationTermCfg( + "robot_base_lin_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, params={"sensor_name": "robot/imu_lin_vel"} ), - "base_ang_vel": ObservationTermCfg( + "robot_base_ang_vel_b": ObservationTermCfg( func=mdp.builtin_sensor, params={"sensor_name": "robot/imu_ang_vel"} ), - "joint_pos": ObservationTermCfg(func=mdp.joint_pos_rel), - "joint_vel": ObservationTermCfg(func=mdp.joint_vel_rel), - "actions": ObservationTermCfg(func=mdp.last_action), + "robot_joint_pos_rel": ObservationTermCfg(func=mdp.joint_pos_rel), + "robot_joint_vel": ObservationTermCfg(func=mdp.joint_vel_rel), + "prev_action": ObservationTermCfg(func=mdp.last_action), } observations = { @@ -163,7 +169,7 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: "push_robot": EventTermCfg( func=mdp.push_by_setting_velocity, mode="interval", - interval_range_s=(1.0, 3.0), + interval_range_s=(4.0, 6.0), params={"velocity_range": VELOCITY_RANGE}, ), "base_com": EventTermCfg( @@ -179,22 +185,30 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: }, }, ), - "encoder_bias": EventTermCfg( + "add_joint_default_pos": EventTermCfg( mode="startup", - func=dr.encoder_bias, + func=dr.joint_default_pos, params={ - "asset_cfg": SceneEntityCfg("robot"), - "bias_range": (-0.01, 0.01), + "asset_cfg": SceneEntityCfg("robot", joint_names=".*"), + "operation": "add", + "ranges": (-0.01, 0.01), }, ), - "foot_friction": EventTermCfg( + "physics_material": EventTermCfg( mode="startup", func=dr.geom_friction, params={ "asset_cfg": SceneEntityCfg("robot", geom_names=()), # Set per-robot. "operation": "abs", - "ranges": (0.3, 1.2), - "shared_random": True, # All foot geoms share the same friction. + "ranges": (0.3, 1.6), + }, + ), + "randomize_rigid_body_mass": EventTermCfg( + mode="startup", + func=dr.pseudo_inertia, + params={ + "asset_cfg": SceneEntityCfg("robot", body_names=()), # Set per-robot. + "alpha_range": (-0.1, 0.45), }, ), } @@ -214,6 +228,16 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: weight=0.5, params={"command_name": "motion", "std": 0.4}, ), + "motion_global_root_lin_vel": RewardTermCfg( + func=mdp.motion_global_anchor_linear_velocity_error_exp, + weight=1.0, + params={"command_name": "motion", "std": 1.0}, + ), + "motion_global_root_ang_vel": RewardTermCfg( + func=mdp.motion_global_anchor_angular_velocity_error_exp, + weight=1.0, + params={"command_name": "motion", "std": 3.0}, + ), "motion_body_pos": RewardTermCfg( func=mdp.motion_relative_body_position_error_exp, weight=1.0, @@ -234,17 +258,23 @@ def make_tracking_env_cfg() -> ManagerBasedRlEnvCfg: weight=1.0, params={"command_name": "motion", "std": 3.14}, ), - "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-1e-1), + "motion_joint_pos": RewardTermCfg( + func=mdp.motion_joint_position_error_exp, + weight=1.0, + params={"command_name": "motion", "std": 0.5}, + ), + "motion_joint_vel": RewardTermCfg( + func=mdp.motion_joint_velocity_error_exp, + weight=0.5, + params={"command_name": "motion", "std": 3.0}, + ), + "survival": RewardTermCfg(func=mdp.survival, weight=3.0), + "action_rate_l2": RewardTermCfg(func=mdp.action_rate_l2, weight=-0.5), "joint_limit": RewardTermCfg( func=mdp.joint_pos_limits, weight=-10.0, params={"asset_cfg": SceneEntityCfg("robot", joint_names=(".*",))}, ), - "self_collisions": RewardTermCfg( - func=mdp.self_collision_cost, - weight=-10.0, - params={"sensor_name": "self_collision", "force_threshold": 10.0}, - ), } ##