From 159c367fde3f2d787692d0a8e2f83c4fec903570 Mon Sep 17 00:00:00 2001 From: atharrva01 Date: Wed, 25 Mar 2026 14:47:07 +0530 Subject: [PATCH] Fix temporal supervision shift in NDP-HNN training loop Use birth_times[c] <= t instead of <= (t+1) so the XYZ loss supervises predictions against cells alive at the current snapshot, not one step ahead. Signed-off-by: atharrva01 --- NDP-HNN/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/NDP-HNN/train.py b/NDP-HNN/train.py index d8f12db..6bae995 100644 --- a/NDP-HNN/train.py +++ b/NDP-HNN/train.py @@ -31,10 +31,10 @@ def train_model(model, #--- forward one snapshot state, pred_xyz, inc_logits = model(data, state) - #--- mask nodes that are alive at next time step (t+1) + #--- mask nodes that are alive at current time step (t) t = int(data.t[0].item()) mask_next = torch.tensor( - [birth_times[c] <= (t + 1) for c in cells], + [birth_times[c] <= t for c in cells], dtype=torch.bool, device=device ) target_xyz = torch.tensor(birth_feat[:, :3], device=device)[mask_next]