Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions accelforge/frontend/mapping/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,10 @@ class Reservation(MappingNode):
""" Tensors for which this reservation is reserving the tensor's backing storage.
"""

_component_object: "arch.Component | None" = None
""" The arch component backing the reserved resource, taken from the Storage node
this Reservation was created from. Used internally by the Mapper; do not set. """

persistent: bool = False
"""
Whether this reservation is persistent. Persistent reservations can't be tiled and
Expand Down
4 changes: 4 additions & 0 deletions accelforge/frontend/renames.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
Rank: TypeAlias = str
EinsumName: TypeAlias = str

# The "don't care" rank is useful when specifying in a binding that a loop can
# be over any rank.
RANK_DONT_CARE: Rank = "DONT_CARE"


class Rename(EvalableModel):
"""
Expand Down
74 changes: 70 additions & 4 deletions accelforge/mapper/FFM/_join_pmappings/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
TilePattern,
Loop as MappingLoop,
)
from accelforge.frontend.renames import Rank, RankVariable, TensorName
from accelforge.frontend.renames import Rank, RankVariable, TensorName, RANK_DONT_CARE
from accelforge.frontend.workload import Einsum
from accelforge.mapper.FFM._pareto_df.df_convention import (
is_fused_loop_col,
Expand All @@ -31,6 +31,8 @@
# 1. Each tensor is stored above some loop index. 0 is the outermost loop, 1 the
# next-innermost...
# 2. All loops above any shared tensor are co-tiled and must match between PmappingGroups.
# 3. Spatial loops *below* a physically-distributed storage (i.e., the data binding)
# must match. These are in TensorReservations.physical_spatial_loops.

T = TypeVar("T", bound="Updatable")

Expand Down Expand Up @@ -58,11 +60,14 @@ class Loop(Updatable):
rank_name: Rank
tile_pattern: TilePattern | None
is_spatial: bool
# The architecture spatial dimension (e.g. "X", "Y") this loop fans out over.
spatial_dim: str | None = None

def __post_init__(self):
assert isinstance(self.rank_name, Rank)
assert isinstance(self.tile_pattern, Number | TilePattern | str | None)
assert isinstance(self.is_spatial, bool)
assert isinstance(self.spatial_dim, str | None)
assert isinstance(
self.tile_pattern.initial_tile_shape,
Number | str | None,
Expand Down Expand Up @@ -162,10 +167,16 @@ class TensorReservation(Updatable):
name: TensorName
resource_name: str
persistent: bool = False
# Spatial loops *below* this storage that distribute the tensor across physical
# instances
physical_spatial_loops: tuple[Loop] = ()

def __post_init__(self):
if self.persistent:
assert len(self.loops) == 0, "Persistent tensors be above all loops"
assert all(
isinstance(l, Loop) and l.is_spatial for l in self.physical_spatial_loops
), "physical_spatial_loops must all be spatial Loops"

@property
def above_loop_index(self) -> int:
Expand All @@ -175,7 +186,12 @@ def __str__(self):
return f"[{self.resource_name}] {self.name} below {self.loops}"

def __repr__(self):
return f"Reservation({repr(self.name)}, {repr(self.loops)}, {repr(self.resource_name)})"
phys = (
f", physical_spatial_loops={repr(self.physical_spatial_loops)}"
if self.physical_spatial_loops
else ""
)
return f"Reservation({repr(self.name)}, {repr(self.loops)}, {repr(self.resource_name)}{phys})"

def pydot_str(self):
return f"[{self.resource_name}] {self.name}"
Expand Down Expand Up @@ -216,6 +232,9 @@ def _prepend_symbols(self, prepend: str) -> "TensorReservation":
def clear_symbolic_tile_patterns(self) -> "TensorReservation":
return self.update(
loops=tuple(l.clear_symbolic_tile_patterns() for l in self.loops),
physical_spatial_loops=tuple(
l.clear_symbolic_tile_patterns() for l in self.physical_spatial_loops
),
)

def make_fused_loop_symbols(
Expand Down Expand Up @@ -243,7 +262,20 @@ def _rename_to_match(
l_mine, new_renames = l_mine._rename_to_match(l_other)
_update_rename_dict(renames, new_renames)
new_loops.append(l_mine)
return self.update(loops=tuple(new_loops)), renames
new_physical = []
for l_mine, l_other in zip(
self.physical_spatial_loops, other.physical_spatial_loops
):
l_mine, new_renames = l_mine._rename_to_match(l_other)
_update_rename_dict(renames, new_renames)
new_physical.append(l_mine)
return (
self.update(
loops=tuple(new_loops),
physical_spatial_loops=tuple(new_physical),
),
renames,
)


class SplitKind(Enum):
Expand Down Expand Up @@ -544,7 +576,6 @@ def from_mapping(
t.name: t.rank_variable2ranks for t in einsum.tensor_accesses
}

# TODO: update compatibility to handle spatial-for loop per-tensor update
tensor_indices = []
split_above_loop_indices = []
reservation_indices = []
Expand Down Expand Up @@ -577,6 +608,11 @@ def from_mapping(
), f"Tensors {backing_remaining} not found in mapping"

def get_rank(rank_variable, tensor):
"""
Return rank in tensor indexed by rank_variable or
Rank("NO RANK.RECOMPUTED") if rank not in tensor.
"""
# TODO: shouldn't this whole logic use relevancy from workload?
rv = rank_variable_to_ranks[tensor].get(rank_variable, oset())
assert (
len(rv) <= 1
Expand All @@ -597,13 +633,43 @@ def make_loops(above_index: int, tensor_name: TensorName) -> list[MappingLoop]:
]
return tuple(loops)

def make_physical_spatial_loops(
above_index: int, tensor_name: TensorName, storage
) -> tuple[Loop]:
"""Make data binding of physically distributed storages."""
if storage is None or not storage._is_distributed():
return ()
out = []
for n in mapping.nodes[above_index + 1 :]:
# Stop at the next storage level: loops below it belong to that storage.
if isinstance(n, (MappingReservation, TensorHolder)):
break
if not isinstance(n, Spatial):
continue
rank = get_rank(n.rank_variable, tensor_name)
# If the rank is irrelevant, the binding could be any rank
if rank == Rank("NO RANK. RECOMPUTED."):
rank = RANK_DONT_CARE
out.append(
Loop(
rank_name=rank,
tile_pattern=n.tile_pattern._symbol2str(),
is_spatial=True,
spatial_dim=n.name,
)
)
return tuple(out)

return cls(
tensors=fzs(
TensorReservation(
name=mapping.nodes[i].purpose,
loops=make_loops(i, mapping.nodes[i].purpose),
resource_name=mapping.nodes[i].resource,
persistent=mapping.nodes[i].persistent,
physical_spatial_loops=make_physical_spatial_loops(
i, mapping.nodes[i].purpose, mapping.nodes[i]._component_object
),
)
for i in tensor_indices
),
Expand Down
1 change: 1 addition & 0 deletions accelforge/model/_looptree/reuse/symbolic/_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def insert_reservation_nodes(
node = Reservation(purposes=[buffet.tensor], resource=buffet.level)
node.persistent = tracker.node.persistent
node._backing = tracker.node._backing
node._component_object = tracker.node.component_object

if (
buffet.tensor not in info.tensor_to_reservation_backer_id
Expand Down
Loading