Repro: #5890
```shell
$ mpirun -np 1 -x NVFUSER_DUMP=pre_segmenter_logging pytest tests/python/multidevice/test_alphafold3.py -k outgoing --only-mpi -vs
```
The code of interest:
```python
match direction:
    case Direction.OUTGOING:
        # z_out = einsum("bikc,bjkc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
    case Direction.INCOMING:
        # z_out = einsum("bkic,bkjc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
z = fd.ops.matmul(a, b)  # [b, c, i, j]
z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
```
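As a sanity check on the permutes above, here is a small NumPy sketch showing that permute → matmul → permute reproduces each einsum. NumPy's `transpose`/`matmul` stand in for `fd.ops.permute`/`fd.ops.matmul` here, and the sizes are arbitrary (chosen distinct so a wrong permutation would fail); none of this is nvFuser code.

```python
import numpy as np

# Arbitrary, distinct sizes so that a wrong permutation fails loudly.
B, I, J, K, C = 2, 3, 4, 5, 6
rng = np.random.default_rng(0)


def outgoing(a, b):
    # a: [b, i, k, c], b: [b, j, k, c]
    a = np.transpose(a, (0, 3, 1, 2))     # [b, c, i, k]
    b = np.transpose(b, (0, 3, 2, 1))     # [b, c, k, j]
    z = np.matmul(a, b)                   # [b, c, i, j]
    return np.transpose(z, (0, 2, 3, 1))  # [b, i, j, c]


def incoming(a, b):
    # a: [b, k, i, c], b: [b, k, j, c]
    a = np.transpose(a, (0, 3, 2, 1))     # [b, c, i, k]
    b = np.transpose(b, (0, 3, 1, 2))     # [b, c, k, j]
    z = np.matmul(a, b)                   # [b, c, i, j]
    return np.transpose(z, (0, 2, 3, 1))  # [b, i, j, c]


a_out, b_out = rng.standard_normal((B, I, K, C)), rng.standard_normal((B, J, K, C))
a_in, b_in = rng.standard_normal((B, K, I, C)), rng.standard_normal((B, K, J, C))

print(np.allclose(outgoing(a_out, b_out), np.einsum("bikc,bjkc->bijc", a_out, b_out)))
print(np.allclose(incoming(a_in, b_in), np.einsum("bkic,bkjc->bijc", a_in, b_in)))
```

Both checks print `True`, so the two directions differ only in which operand axes end up as `i`/`j`, which is exactly what the sharding heuristic below is sensitive to.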
The current heuristic for forward sharding propagation through matmul prefers the second input (usually the weight). Therefore, in the einsum output, `j` is sharded by `DIDy`, not `DIDx`. This breaks backward propagation from `z_in` (`i` sharded by `DIDy`, `j` by `DIDx`) to the einsum output, because `z_in` wants `j` to be sharded by `DIDx` instead.

By the way, this is not a problem for the `INCOMING` direction: following the same heuristic, the einsum output does have `j` sharded on `DIDx`.
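To illustrate why the two directions differ, here is a toy model in plain Python. It is a hypothetical sketch, not nvFuser's propagation code: each tensor is reduced to a list of per-axis parallel types, and `matmul_prefer_second` is an invented stand-in for "prefer the second input" (axes inherited from `b` keep their sharding; `a`'s `M` axis loses its sharding if that parallel type is already taken). It also assumes that `a` and `b` both enter in `z_in`'s `[b, i, j, c]` layout with `i` on `DIDy` and `j` on `DIDx`; the actual input shardings in the test may differ.

```python
from typing import List, Optional

# Hypothetical toy model, not nvFuser code: a tensor is reduced to one
# parallel type per axis ("DIDx", "DIDy", or None for unsharded).
Sharding = List[Optional[str]]


def permute(s: Sharding, dims: List[int]) -> Sharding:
    # Output axis d takes the sharding of input axis dims[d], like a tensor permute.
    return [s[d] for d in dims]


def matmul_prefer_second(a: Sharding, b: Sharding) -> Sharding:
    # a: [..., M, K], b: [..., K, N] -> out: [..., M, N].
    # Batch axes and N come from b (the preferred operand). M comes from a
    # unless its parallel type is already taken by an axis inherited from b.
    batch, n = b[:-2], b[-1]
    taken = {p for p in batch + [n] if p is not None}
    m = a[-2] if a[-2] not in taken else None
    return batch + [m, n]


# Assumption: a and b both arrive in z_in's [b, i, j, c] layout,
# with i sharded on DIDy and j on DIDx.
Z_IN: Sharding = [None, "DIDy", "DIDx", None]


def einsum_output(direction: str) -> Sharding:
    a, b = list(Z_IN), list(Z_IN)
    if direction == "outgoing":
        a = permute(a, [0, 3, 1, 2])  # [b, c, i, k]
        b = permute(b, [0, 3, 2, 1])  # [b, c, k, j]
    elif direction == "incoming":
        a = permute(a, [0, 3, 2, 1])  # [b, c, i, k]
        b = permute(b, [0, 3, 1, 2])  # [b, c, k, j]
    else:
        raise ValueError(direction)
    z = matmul_prefer_second(a, b)    # [b, c, i, j]
    return permute(z, [0, 2, 3, 1])   # [b, i, j, c]


for d in ("outgoing", "incoming"):
    print(d, einsum_output(d))
```

Under these assumptions, `einsum_output("outgoing")` returns `[None, None, 'DIDy', None]` (`j` on `DIDy`, conflicting with `z_in`), while `einsum_output("incoming")` returns `[None, None, 'DIDx', None]` (`j` on `DIDx`, as `z_in` wants). In the outgoing case, `b`'s `j` axis is `z_in`'s `i` axis, so it carries `DIDy`.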
The snippet above is from `Fuser/tests/python/direct/test_alphafold3.py`, lines 123 to 133 at commit 77abd29.

cc @DejunL