Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24,034 changes: 10,409 additions & 13,625 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
"LoopyOptional": "obj:loopy.tools.Optional",
"ShapeType": "obj:loopy.typing.ShapeType",
"ToLoopyTypeConvertible": "obj:loopy.types.ToLoopyTypeConvertible",
"NoSyncScope": "obj:loopy.kernel.instruction.NoSyncScope",
"BarrierKind": "obj:loopy.kernel.instruction.BarrierKind",
"InsnId": "obj:loopy.typing.InsnId",
}


Expand Down
7 changes: 7 additions & 0 deletions doc/ref_kernel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ Instructions
.. {{{

.. autoclass:: HappensAfter
.. currentmodule:: loopy.kernel.instruction
.. autoclass:: NoSyncScope
.. currentmodule:: loopy
Comment thread
inducer marked this conversation as resolved.
.. autoclass:: InstructionBase

.. _assignments:
Expand Down Expand Up @@ -500,6 +503,10 @@ No-Op Instruction
Barrier Instructions
^^^^^^^^^^^^^^^^^^^^

.. currentmodule:: loopy.kernel.instruction
.. autoclass:: BarrierKind
.. currentmodule:: loopy

Comment thread
inducer marked this conversation as resolved.
.. autoclass:: BarrierInstruction

Instruction Tags
Expand Down
6 changes: 3 additions & 3 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
from pymbolic import ArithmeticExpression, Expression

from loopy.kernel.function_interface import InKernelCallable
from loopy.kernel.instruction import InstructionBase
from loopy.kernel.instruction import BarrierKind, InstructionBase
from loopy.kernel.tools import SetOperationCacheManager
from loopy.options import Options
from loopy.schedule import ScheduleItem
Expand Down Expand Up @@ -747,7 +747,7 @@ def get_written_variables(self) -> AbstractSet[str]:

@memoize_method
def get_temporary_to_base_storage_map(self):
result = {}
result: dict[str, str] = {}
for tv in self.temporary_variables.values():
if tv.base_storage:
result[tv.name] = tv.base_storage
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def local_mem_use(self):
# {{{ nosync sets

@memoize_method
def get_nosync_set(self, insn_id, scope):
def get_nosync_set(self, insn_id: InsnId, scope: BarrierKind):
assert scope in ("local", "global")

return frozenset(
Expand Down
7 changes: 5 additions & 2 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from typing_extensions import Self

from pymbolic import Expression
from pymbolic.primitives import is_arithmetic_expression
from pymbolic.primitives import Variable, is_arithmetic_expression
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable

Expand Down Expand Up @@ -715,7 +715,7 @@ class ArrayBase(ImmutableRecord, Taggable):
"""See :ref:`data-dim-tags`.
"""

offset: Expression | str | None
offset: Expression | None
"""Offset from the beginning of the buffer to the point from
which the strides are counted, in units of the :attr:`dtype`.
May be one of
Expand Down Expand Up @@ -951,6 +951,9 @@ def __init__(self, name, dtype=None, shape=None, dim_tags=None, offset=0,
if tags is None:
tags = frozenset()

if isinstance(offset, str):
offset = Variable(offset)

ImmutableRecord.__init__(self,
name=name,
dtype=dtype,
Expand Down
5 changes: 3 additions & 2 deletions loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
CallInstruction,
InstructionBase,
MultiAssignmentBase,
to_barrier_kind,
)
from loopy.symbolic import (
IdentityMapper,
Expand Down Expand Up @@ -651,10 +652,10 @@ def parse_special_insn(

if special_insn_kind == "gbarrier":
cls: type[InstructionBase] = BarrierInstruction
kwargs["synchronization_kind"] = "global"
kwargs["synchronization_kind"] = to_barrier_kind("global")
elif special_insn_kind == "lbarrier":
cls = BarrierInstruction
kwargs["synchronization_kind"] = "local"
kwargs["synchronization_kind"] = to_barrier_kind("local")
elif special_insn_kind == "nop":
cls = NoOpInstruction
else:
Expand Down
30 changes: 24 additions & 6 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TypeAlias,
cast,
)
Expand All @@ -60,7 +61,7 @@
from pymbolic import Expression

from loopy.kernel import LoopKernel
from loopy.typing import InameStr
from loopy.typing import InameStr, InsnId


Assignable: TypeAlias = (
Expand Down Expand Up @@ -153,6 +154,9 @@ class HappensAfter:

# {{{ instructions: base class

# Scope strings that may be paired with an instruction id in
# ``InstructionBase.no_sync_with`` entries.
NoSyncScope: TypeAlias = Literal["global", "local", "any"]


class InstructionBase(ImmutableRecord, Taggable):
"""A base class for all types of instruction that can occur in
a kernel.
Expand Down Expand Up @@ -281,7 +285,7 @@ class InstructionBase(ImmutableRecord, Taggable):
depends_on_is_final: bool
groups: frozenset[str]
conflicts_with_groups: frozenset[str]
no_sync_with: frozenset[tuple[str, str]]
no_sync_with: frozenset[tuple[InsnId, NoSyncScope]]
predicates: frozenset[Expression]
within_inames: frozenset[InameStr]
within_inames_is_final: bool
Expand All @@ -300,7 +304,7 @@ def __init__(self,
depends_on_is_final: bool | None,
groups: frozenset[str] | None,
conflicts_with_groups: frozenset[str] | None,
no_sync_with: frozenset[tuple[str, str]] | None,
no_sync_with: frozenset[tuple[InsnId, NoSyncScope]] | None,
within_inames_is_final: bool | None,
within_inames: frozenset[str] | None,
priority: int | None,
Expand Down Expand Up @@ -942,7 +946,7 @@ def __init__(self,
depends_on_is_final: bool | None = None,
groups: frozenset[str] | None = None,
conflicts_with_groups: frozenset[str] | None = None,
no_sync_with: frozenset[tuple[str, str]] | None = None,
no_sync_with: frozenset[tuple[InsnId, NoSyncScope]] | None = None,
within_inames_is_final: bool | None = None,
within_inames: frozenset[str] | None = None,
priority: int | None = None,
Expand Down Expand Up @@ -1651,6 +1655,16 @@ def __str__(self):

# {{{ barrier instruction

BarrierKind: TypeAlias = Literal["local", "global"]


def to_barrier_kind(s: str) -> BarrierKind:
if s == "local" or s == "global":
return s
else:
raise ValueError(f"Invalid barrier kind: {s!r}")


class BarrierInstruction(_DataObliviousInstruction):
"""An instruction that requires synchronization with all
concurrent work items of :attr:`synchronization_kind`.
Expand All @@ -1674,6 +1688,9 @@ class BarrierInstruction(_DataObliviousInstruction):
... lbarrier {mem_kind=global}
"""

synchronization_kind: BarrierKind
mem_kind: BarrierKind

fields = _DataObliviousInstruction.fields | {"synchronization_kind",
"mem_kind"}

Expand All @@ -1682,8 +1699,9 @@ def __init__(self, id, happens_after=None, depends_on_is_final=None,
no_sync_with=None,
within_inames_is_final=None, within_inames=None,
priority=None,
predicates=None, tags=None, synchronization_kind="global",
mem_kind="local",
predicates=None, tags=None,
synchronization_kind: BarrierKind = "global",
mem_kind: BarrierKind = "local",
depends_on=None):

if predicates:
Expand Down
Loading
Loading