
Use symmetric memory in matmul+allreduce reference implementation #5837

Draft
wujingyue wants to merge 4 commits into main from wjy/symm

Conversation

Collaborator

@wujingyue wujingyue commented Jan 17, 2026

The biggest performance win has come from limiting the number of streams (cb4d494). The wall time went down from 3.31ms to 3.09ms.

On top of that, using symmetric memory gives a slight speedup -- from 3.09ms (screenshot 1) to 3.07ms (screenshot 2). As a sanity check, I do see ncclSymk in the allreduce kernel name (screenshot 3), and it uses only 5 thread blocks instead of 24.

Repro:

me @ viking-prod-232 : dev | /opt/pytorch/nvfuser (wjy/symm)
$ mpirun -np 2 pytest tests/python/multidevice/test_overlap.py::'test_row_parallel_linear_forward_reference_benchmark' --only-mpi -vs

Screenshot 1: [image]

Screenshot 2: [image]

Screenshot 3: [image]

@wujingyue changed the title from "Use symmetric memory" to "Use symmetric memory in reference model" on Jan 17, 2026

github-actions Bot commented Jan 17, 2026

Review updated until commit 3a6025c

Description

  • Enable symmetric memory backend in distributed test setup

  • Refactor row-parallel linear reference to use symmetric memory output

  • Update benchmark parameters from 2 to 4 chunks for improved performance

  • Add new test file with symmetric memory allgather and reduce_scatter tests

Changes walkthrough

Relevant files

Enhancement

tests/python/multidevice/conftest.py (+11/-2): Enable symmetric memory backend in test setup

  • Import the torch.distributed._symmetric_memory module
  • Configure the NCCL process group with the zero CTA policy
  • Set the symmetric memory backend to "NCCL" and enable it for the world group
  • Simplify the device_id parameter from torch.device to an integer

tests/python/multidevice/test_overlap.py (+21/-10): Refactor the reference implementation to use symmetric memory

  • Import the torch.distributed._symmetric_memory module
  • Modify row_parallel_linear_forward_reference to accept an output tensor parameter
  • Change stream pool indexing to modulo 2 for stream reuse
  • Replace torch.empty with symm_mem.empty and add a symm_mem.rendezvous call (sketched after this walkthrough)
  • Update the benchmark num_chunks parameter from 2 to 4

Tests

tests/python/multidevice/test_symmetric_memory.py (+95/-0): Add symmetric memory integration tests

  • Add test_allgather_linear_symmetric_memory: allgather the input into a symmetric buffer, then apply a linear operation
  • Add test_linear_reducescatter_symmetric_memory: partial linear per rank, then reduce_scatter into a symmetric buffer
  • Include multicast support checks and appropriate test skipping
  • Validate correctness against reference implementations
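
As referenced above, the core change in the reference implementation is allocating the allreduce output from the symmetric-memory allocator and registering it with the group. A minimal sketch, assuming an output of shape (t, h) as in the tests below and an already-initialized process group from conftest.py; torch.distributed._symmetric_memory is a private PyTorch API and may change between versions:

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    t, h = 6, 2  # illustrative sizes, matching the tests below

    # Allocate the matmul+allreduce output from the symmetric-memory allocator
    # instead of torch.empty, then rendezvous so every rank registers the buffer.
    out = symm_mem.empty(t, h, device="cuda", dtype=torch.bfloat16)
    symm_mem.rendezvous(out, group=dist.group.WORLD)

    # With the NCCL symmetric-memory backend enabled (see conftest.py), the
    # allreduce on this buffer can dispatch to the ncclSymk kernels.
    dist.all_reduce(out)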

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Stream Pool Logic

The stream allocation logic was changed from stream_pool.get(i) to stream_pool.get(i % 2), limiting the pool to 2 streams. This appears intentional for performance, but the change should be validated to ensure it doesn't introduce race conditions or incorrect synchronization across the per-chunk matmul operations.

    worker_stream = stream_pool.get(i % 2)
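
For context, a minimal sketch of what the modulo-2 reuse could look like in the chunked matmul loop; the stream pool, chunking, and tensor names here are illustrative stand-ins, not the actual reference code:

    import torch

    # Hypothetical stand-in for the stream pool: two reusable CUDA streams.
    streams = [torch.cuda.Stream() for _ in range(2)]

    def chunked_matmul(inp_chunks, weight, out_chunks):
        main = torch.cuda.current_stream()
        for i, (x, out) in enumerate(zip(inp_chunks, out_chunks)):
            worker = streams[i % 2]   # reuse 2 streams instead of one per chunk
            worker.wait_stream(main)  # the chunk's input must be ready
            with torch.cuda.stream(worker):
                torch.matmul(x, weight.T, out=out)
        for s in streams:
            main.wait_stream(s)       # rejoin before the subsequent collective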

Process Group Configuration

The NCCL process group is now initialized with cta_policy = NCCL_CTA_POLICY_ZERO and additional parameters. This configuration change should be validated for compatibility across NCCL versions and hardware configurations, as it may affect the behavior of collective operations.

    opts = dist.ProcessGroupNCCL.Options()
    opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO
    dist.init_process_group(
        backend="nccl",
        pg_options=opts,
        # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
        init_method="tcp://localhost:29500",
        world_size=world_size,
        rank=rank,
        device_id=local_rank,
    )
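
The walkthrough also mentions setting the symmetric memory backend to "NCCL" and enabling it for the world group. A rough sketch of that part of the setup; torch.distributed._symmetric_memory is a private API, and set_backend / enable_symm_mem_for_group may differ across PyTorch versions:

    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # Route symmetric-memory allocations through the NCCL backend and register
    # the default (world) group for symmetric-memory collectives.
    symm_mem.set_backend("NCCL")
    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)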

Symmetric Memory Tests

New tests for symmetric memory functionality were added, including multicast support checks and tensor distribution validation. Verify that they cover edge cases and provide adequate coverage of the new symmetric memory features.

    def test_allgather_linear_symmetric_memory(setup_default_process_group):
        """Allgather input into symmetric buffer, then linear. Same sizes as row_parallel_linear_forward_reference."""
        if not _SymmetricMemory.has_multicast_support(
            DeviceType.CUDA, torch.cuda.current_device()
        ):
            pytest.skip("multicast not supported on this GPU")
    
        h, s, t = 2, 3, 6
        d = dist.get_world_size()
        if (h * 4) % d != 0:
            pytest.skip(
                f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
            )
        assert t % s == 0
    
        torch.manual_seed(0)
        inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to(
            torch.bfloat16
        )
        weight_ref = torch.testing.make_tensor(
            h, h * 4, dtype=torch.int32, device="cpu"
        ).to(torch.bfloat16)
        out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight_ref.cuda()).cpu()
    
        mesh = dist.device_mesh.init_device_mesh("cuda", [d])
        inp = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
        weight = weight_ref.cuda()
    
        # Symmetric buffer for allgathered input: (t, h*4). multimem_all_gather_out requires
        # multicast support (e.g. NVLink SHARP) on the GPU; skip the test otherwise.
        allgathered_symm = symm_mem.empty(t, h * 4, device="cuda", dtype=inp.dtype)
        symm_mem.rendezvous(allgathered_symm, group=dist.group.WORLD)
    
        group_name = dist.group.WORLD.group_name
        torch.ops.symm_mem.multimem_all_gather_out(inp, group_name, allgathered_symm)
        out = torch.nn.functional.linear(allgathered_symm, weight)
    
        torch.testing.assert_close(out.cpu(), out_ref)
    
    
    @pytest.mark.mpi
    def test_linear_reducescatter_symmetric_memory(setup_default_process_group):
        """Partial linear per rank, then reduce_scatter into symmetric buffer. Same sizes as row_parallel_linear_forward_reference."""
        if not _SymmetricMemory.has_multicast_support(
            DeviceType.CUDA, torch.cuda.current_device()
        ):
            pytest.skip("multicast not supported on this GPU")
    
        h, s, t = 2, 3, 6
        d = dist.get_world_size()
        if (h * 4) % d != 0:
            pytest.skip(
                f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
            )
        if h % d != 0:
            pytest.skip(
                f"Linear+reducescatter requires h={h} to be divisible by world size {d}."
            )
        assert t % s == 0
    
        torch.manual_seed(0)
        inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to(
            torch.bfloat16
        )
        weight_ref = torch.testing.make_tensor(
            h, h * 4, dtype=torch.int32, device="cpu"
        ).to(torch.bfloat16)
    
        mesh = dist.device_mesh.init_device_mesh("cuda", [d])
        inp = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
        weight = distribute_tensor(weight_ref, mesh, placements=[Shard(-1)]).to_local()
    
        group_name = dist.group.WORLD.group_name
        out = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            inp, weight.T, "sum", 1, group_name
        )
    
        out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight_ref.cuda()).cpu()
        torch.testing.assert_close(
            out, distribute_tensor(out_ref, mesh, placements=[Shard(-1)]).to_local()
        )
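
Presumably the new tests can be run the same way as the overlap benchmark in the PR description; for example, adapting the repro command above (not verified here):

    mpirun -np 2 pytest tests/python/multidevice/test_symmetric_memory.py --only-mpi -vs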

@wujingyue changed the title from "Use symmetric memory in reference model" to "Use symmetric memory in matmul+allreduce reference implementation" on Jan 17, 2026
@wujingyue (Collaborator, Author)

cc @kwen2501, let me know if you have any suggestions! In case you'd like to run the microbenchmark yourself, the reference implementation above doesn't depend on nvFuser at all -- you should be able to simply git clone and mpirun.

kwen2501 commented Jan 17, 2026

Hi @wujingyue, I was thinking of all-gather because it can use the CE (copy engine) while all-reduce still needs SMs (the reduction still requires computation).

Also, matmul + all-gather overlap could be more common these days due to FSDP. In Tensor Parallel, all-reduce is more likely to be sequential to the matmul due to the data dependency.

@wujingyue (Collaborator, Author)

> Hi @wujingyue, I was thinking of all-gather because it can use the CE (copy engine) while all-reduce still needs SMs (the reduction still requires computation).
>
> Also, matmul + all-gather overlap could be more common these days due to FSDP. In Tensor Parallel, all-reduce is more likely to be sequential to the matmul due to the data dependency.

Sure -- I'll update you when I have an allgather example.

@kwen2501

Oops, sorry, I didn't see your comment.

I pushed a benchmark here:
https://github.com/pytorch/pytorch/pull/172714
(I don't have access to Fuser, so I pushed it to a PyTorch branch.)

I ran it in three modes, on 8 x H100s:

• Sequential: 2.96 ms
• Overlap, w/o CE: 2.02 ms
• Overlap, with CE: 1.77 ms

I used the default command generated by Claude in the test file:

    torchrun --nproc_per_node=8 benchmarks/distributed/bench_overlapped_matmul_allgather.py \
        --m 8192 --n 8192 --k 8192 --ag-mb 64 --dtype fp16 --iters 200 --warmup 50

(i.e., the all-gather is 64 MiB)

To enable CE, we can add this option: --nccl-cta-policy-zero
