Skip to content

Outdoor segmentation question #5

Description

@Cureser

这是我加载了预训练参数concerto_large_outdoor.pth的训练结果,似乎还不错,但是我尝试仿照demo/2_sem_seg.py的方法处理sample2_outdoor.npz无法得到正确的结果,请问可以给我点意见建议吗?
[2025-11-28 00:33:59,486 INFO test.py line 295 17919] Test: 08_004070 [4071/4071]-122346 Batch 9.018 (11.609) Accuracy 0.7338 (0.7363) mIoU 0.4692 (0.6516) [2025-11-28 00:33:59,539 INFO test.py line 312 17919] Syncing ... [2025-11-28 00:33:59,544 INFO test.py line 340 17919] Val result: mIoU/mAcc/allAcc 0.6516/0.7363/0.9043 [2025-11-28 00:33:59,544 INFO test.py line 346 17919] Class_0 - car Result: iou/accuracy 0.9610/0.9833 [2025-11-28 00:33:59,544 INFO test.py line 346 17919] Class_1 - bicycle Result: iou/accuracy 0.4675/0.6021 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_2 - motorcycle Result: iou/accuracy 0.6744/0.7549 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_3 - truck Result: iou/accuracy 0.8655/0.9456 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_4 - other-vehicle Result: iou/accuracy 0.6478/0.7219 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_5 - person Result: iou/accuracy 0.7463/0.8718 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_6 - bicyclist Result: iou/accuracy 0.8761/0.9440 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_7 - motorcyclist Result: iou/accuracy 0.0000/0.0000 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_8 - road Result: iou/accuracy 0.9134/0.9559 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_9 - parking Result: iou/accuracy 0.4679/0.5319 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_10 - sidewalk Result: iou/accuracy 0.7478/0.9012 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_11 - other-ground Result: iou/accuracy 0.1269/0.1930 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_12 - building Result: iou/accuracy 0.8861/0.9633 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_13 - fence Result: iou/accuracy 0.5888/0.7681 [2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_14 - vegetation Result: iou/accuracy 0.8682/0.9265 [2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_15 - trunk Result: iou/accuracy 0.7115/0.7977 [2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_16 - terrain Result: iou/accuracy 0.7053/0.7847 [2025-11-28 00:33:59,548 INFO test.py line 346 17919] Class_17 - pole Result: iou/accuracy 0.6363/0.7702 [2025-11-28 00:33:59,548 INFO test.py line 346 17919] Class_18 - traffic-sign Result: iou/accuracy 0.4894/0.5739 [2025-11-28 00:33:59,548 INFO test.py line 354 17919] <<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<
这是测试代码
`
import numpy as np
import concerto
import torch
import torch.nn as nn
import open3d as o3d
import argparse
import os

try:
import flash_attn
except ImportError:
flash_attn = None
device = "cuda" if torch.cuda.is_available() else "cpu"

KITTI Meta data - 19 classes

KITTI_VALID_CLASS_IDS = (
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
)

KITTI_CLASS_LABELS = (
"car",
"bicycle",
"motorcycle",
"truck",
"other-vehicle",
"person",
"bicyclist",
"motorcyclist",
"road",
"parking",
"sidewalk",
"other-ground",
"building",
"fence",
"vegetation",
"trunk",
"terrain",
"pole",
"traffic-sign",
)

KITTI color map - 为每个类别分配颜色

KITTI_COLOR_MAP = {
0: (255.0, 0.0, 0.0), # car - 红色
1: (0.0, 255.0, 0.0), # bicycle - 绿色
2: (0.0, 0.0, 255.0), # motorcycle - 蓝色
3: (255.0, 255.0, 0.0), # truck - 黄色
4: (255.0, 0.0, 255.0), # other-vehicle - 品红
5: (0.0, 255.0, 255.0), # person - 青色
6: (255.0, 128.0, 0.0), # bicyclist - 橙色
7: (128.0, 0.0, 255.0), # motorcyclist - 紫色
8: (128.0, 128.0, 128.0), # road - 灰色
9: (255.0, 192.0, 203.0), # parking - 粉色
10: (0.0, 128.0, 128.0), # sidewalk - 深青
11: (255.0, 215.0, 0.0), # other-ground - 金色
12: (70.0, 130.0, 180.0), # building - 钢蓝
13: (165.0, 42.0, 42.0), # fence - 棕色
14: (50.0, 205.0, 50.0), # vegetation - 亮绿
15: (255.0, 99.0, 71.0), # trunk - 番茄红
16: (0.0, 100.0, 0.0), # terrain - 深绿
17: (211.0, 211.0, 211.0), # pole - 浅灰
18: (255.0, 255.0, 255.0) # traffic-sign - 白色
}

Get colors for valid classes (1-19)

CLASS_COLOR = [KITTI_COLOR_MAP[id] for id in KITTI_VALID_CLASS_IDS]

class SegHead(nn.Module):
def init(self, backbone_out_channels, num_classes):
super(SegHead, self).init()
self.seg_head = nn.Linear(backbone_out_channels, num_classes)

def forward(self, x):
    return self.seg_head(x)

def visualize_results(coord, pred, class_colors):
"""可视化分割结果"""
# 创建点云对象
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(coord)

# 为每个点分配颜色
colors = np.array([class_colors[p] if p < len(class_colors) else (0, 0, 0) for p in pred])
pcd.colors = o3d.utility.Vector3dVector(colors)

# 可视化
o3d.visualization.draw_geometries([pcd])

if name == "main":
parser = argparse.ArgumentParser()
parser.add_argument(
'--wo_color',
dest='wo_color',
action='store_true',
help="disable the color."
)
parser.add_argument(
'--wo_normal',
dest='wo_normal',
action='store_true',
help="disable the normal."
)
parser.add_argument(
'--model_path',
type=str,
default=None,
help="Path to the model checkpoint file"
)
parser.add_argument(
'--seg_head_path',
type=str,
default=None,
help="Path to the segmentation head checkpoint file"
)
parser.add_argument(
'--data_path',
type=str,
default=None,
help="Path to the input point cloud data file"
)
args = parser.parse_args()

# set random seed
concerto.utils.set_seed(46647087)

# 脚本目录
script_dir = os.path.dirname(os.path.abspath(__file__))

# 加载模型 - 使用本地模型而不是从仓库加载
model_path = args.model_path or os.path.join(script_dir, "../model/concerto_large_outdoor.pth")
print(f"Loading model from: {model_path}")

if flash_attn is not None:
    model = concerto.load(model_path).to(device)
else:
    custom_config = dict(
        enc_patch_size=[1024 for _ in range(5)],  # reduce patch size if necessary
        enable_flash=False,
    )
    model = concerto.load(
        model_path, custom_config=custom_config
    ).to(device)

# 加载分割头 - KITTI配置:1728输入通道,19个类别
seg_head_path = args.seg_head_path or os.path.join(script_dir, "../model/seg_head_kitti.pth")
print(f"Loading segmentation head from: {seg_head_path}")

try:
    # 尝试直接加载checkpoint
    ckpt = concerto.load(seg_head_path, ckpt_only=True,)
    
    # 检查是否包含config和state_dict
    if "config" not in ckpt:
        # 如果没有config,使用KITTI的固定配置
        ckpt["config"] = {
            'backbone_out_channels': 1728,
            'num_classes': 19
        }
        print("Using default KITTI configuration: backbone_out_channels=1728, num_classes=19")
except Exception as e:
    print(f"Error loading checkpoint: {e}")
    # 创建默认配置
    ckpt = {
        "config": {
            'backbone_out_channels': 1728,
            'num_classes': 19
        },
        "state_dict": {}
    }

# 创建并加载分割头
seg_head = SegHead(**ckpt["config"]).to(device)
if "state_dict" in ckpt and ckpt["state_dict"]:
    seg_head.load_state_dict(ckpt["state_dict"])
else:
    print("Warning: No state_dict found in checkpoint, using randomly initialized weights")

# 加载默认数据转换管道
transform = concerto.transform.default()
data_path = os.path.join(script_dir, "../data/sample2_outdoor.npz")
# 加载数据
if args.data_path:
    # 从指定路径加载数据
    print(f"Loading data from: {args.data_path}")
    if args.data_path.endswith('.npz'):
        data = np.load(args.data_path)
        point = {k: data[k] for k in data.files}
    else:
        # 尝试使用open3d加载点云文件
        pcd = o3d.io.read_point_cloud(args.data_path)
        point = {
            "coord": np.asarray(pcd.points),
            "color": np.asarray(pcd.colors) if pcd.has_colors() else np.zeros_like(np.asarray(pcd.points)),
            "normal": np.asarray(pcd.normals) if pcd.has_normals() else np.zeros_like(np.asarray(pcd.points))
        }
else:
    # 使用示例数据
    point = concerto.data.load("sample2_outdoor")

# 处理颜色和法线选项
if args.wo_color:
    point["color"] = np.zeros_like(point["coord"])
if args.wo_normal:
    point["normal"] = np.zeros_like(point["coord"])

# 保存原始坐标用于可视化
original_coord = point["coord"].copy()

# 应用数据转换
point = transform(point)

# 推理
model.eval()
seg_head.eval()
with torch.inference_mode():
    # 将数据移至GPU
    for key in point.keys():
        if isinstance(point[key], torch.Tensor) and device == "cuda":
            point[key] = point[key].cuda(non_blocking=True)
    
    # 模型前向传播
    point = model(point)
    
    # 处理池化父节点(如果存在)
    while "pooling_parent" in point.keys():
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")
        parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
        point = parent
    
    # 获取特征并进行分割预测
    feat = point.feat
    seg_logits = seg_head(feat)
    pred = seg_logits.argmax(dim=-1).data.cpu().numpy()
    color = np.array(CLASS_COLOR)[pred]

    print(f"Segmentation completed. Number of points: {len(pred)}")
    print(f"Predicted classes: {np.unique(pred)}")

# 可视化结果
print("Visualizing results...")
# visualize_results(original_coord, pred, CLASS_COLOR)
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(point.coord.cpu().detach().numpy())
pcd.colors = o3d.utility.Vector3dVector(color / 255.0)    
o3d.visualization.draw_geometries([pcd])`

这是我提取分割头的方式
`# /home/simon/PTv3_ws/extract_seg_head.py
import torch
import os

设置路径

MODEL_PATH = "/root/workspace/workspace/Pointcept/exp/concerto/semseg-ptv3-large-v1m1-test-kitti-lin/model/model_best.pth"
OUTPUT_PATH = "/root/workspace/workspace/Concerto/model/seg_head_kitti.pth"

确保输出目录存在

os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

print(f"Loading model from: {MODEL_PATH}")
try:
# 加载完整模型
checkpoint = torch.load(MODEL_PATH, map_location='cpu')

# 提取模型状态字典
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint  # 如果没有state_dict键,直接使用checkpoint

# 提取分割头参数
seg_head_state_dict = {}
for key, value in state_dict.items():
    if key.startswith('seg_head.'):
        # 保持原始键名,因为SegHead类的state_dict使用相同的命名
        seg_head_state_dict[key] = value
        print(f"Found seg_head parameter: {key}, shape: {value.shape}")
    elif key.startswith('cls.'):
        # 检查是否有cls开头的键(有些模型可能使用cls作为分割头名称)
        # 将cls.转换为seg_head.以匹配SegHead类的期望
        new_key = key.replace('cls.', 'seg_head.')
        seg_head_state_dict[new_key] = value
        print(f"Found cls parameter, converted to: {new_key}, shape: {value.shape}")

# 验证是否找到分割头参数
if not seg_head_state_dict:
    print("Warning: No seg_head or cls parameters found in the model!")
    print("Available keys:")
    for key in list(state_dict.keys())[:20]:  # 只显示前20个键作为示例
        print(f"  {key}")
    if len(state_dict) > 20:
        print(f"  ... and {len(state_dict) - 20} more keys")
else:
    print(f"Successfully extracted {len(seg_head_state_dict)} seg_head parameters")

# 创建符合要求格式的checkpoint文件,包含config和state_dict
output_checkpoint = {
    'config': {
        'backbone_out_channels': 1728,  # KITTI配置
        'num_classes': 19
    },
    'state_dict': seg_head_state_dict
}

# 保存分割头
torch.save(output_checkpoint, OUTPUT_PATH)
print(f"Segmentation head saved to: {OUTPUT_PATH}")
print(f"Checkpoint structure: {list(output_checkpoint.keys())}")
print(f"State dict keys: {list(seg_head_state_dict.keys())}")

except Exception as e:
print(f"Error processing model: {e}")
import traceback
traceback.print_exc()`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions