Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ac14b29
WIP sharding modeling
gilbertmike Feb 26, 2026
42a6cbe
Merge branch 'main' into hops
gilbertmike Feb 28, 2026
91c30a8
Merge branch 'main' into hops
gilbertmike Apr 9, 2026
fcc90d4
Merge branch 'main' into hops
gilbertmike Apr 10, 2026
8b778c5
Merge branch 'main' into hops
gilbertmike Apr 15, 2026
7122835
Merge branch 'main' into hops
gilbertmike Apr 28, 2026
d2885d4
Merge branch 'main' into hops
gilbertmike May 5, 2026
1d45da5
Merge branch 'main' into hops
gilbertmike May 12, 2026
c1ba862
Merge branch 'main' into hops
gilbertmike May 21, 2026
64470bb
[model] Use flattened arch
gilbertmike May 21, 2026
ebcec40
Merge branch 'main' into hops
gilbertmike May 22, 2026
d34e44b
[frontend] Add comment to explain physical_fanout
gilbertmike May 22, 2026
82bc1b9
Merge branch 'main' into hops
gilbertmike May 22, 2026
4b3fcab
[network] Add *untested* distributed model
gilbertmike May 22, 2026
307196d
[network] Tested distributed buffers
gilbertmike May 22, 2026
0453ddc
[model] Read/writes of distributed buffers
gilbertmike May 22, 2026
bcff540
Merge branch 'main' into hops
gilbertmike May 26, 2026
8313202
Model latency and distributed occupancy
gilbertmike May 26, 2026
847fbd0
Merge branch 'main' into hops
gilbertmike May 27, 2026
db5eae4
Implement (almost) proper latency model; waiting for hwcomponents lat…
gilbertmike May 27, 2026
a7c9c72
Merge branch 'main' into hops
gilbertmike May 28, 2026
fc4fced
Merge branch 'main' into hops
gilbertmike Jun 4, 2026
b9a5434
Merge branch 'main' into hops
gilbertmike Jun 4, 2026
7ad792a
[network] Update to latest spec
gilbertmike Jun 5, 2026
488f4b1
[network] Refactor network cost to handle different topologies
gilbertmike Jun 5, 2026
b9c7d4a
[network] WIP review
gilbertmike Jun 5, 2026
dd47699
[network] Clean up Claude output
gilbertmike Jun 5, 2026
3193735
Merge branch 'main' into hops
gilbertmike Jun 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions accelforge/frontend/arch/_flattened_arch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
from typing import TypeVar, Callable


_FIND_SENTINEL = object()

D = TypeVar("D")
T = TypeVar("T")


class FlattenedArch:
"""
A flattened arch is an architecture spec that has been
Expand Down Expand Up @@ -52,3 +61,124 @@ def is_above(self, name_a: str, name_b: str):
idx_a = self.index(name_a)
idx_b = self.index(name_b)
return idx_a < idx_b

def find_first_of_type_between(
    self,
    node_type: type[T],
    name_lower: str,
    name_upper: str,
    default: D = _FIND_SENTINEL,
    top_bottom: bool = True,
) -> T | D:
    """
    Return the first node of type `node_type` strictly between two named nodes.

    Only nodes positioned below `name_upper` and above `name_lower` are
    considered; the two named nodes themselves are excluded.

    Args:
        node_type: Class (or anything accepted by `isinstance`) to match.
        name_lower: Name of the node bounding the search from below.
        name_upper: Name of the node bounding the search from above.
        default: Value returned when nothing matches. If omitted, a
            `ValueError` is raised instead.
        top_bottom: If True, scan from the top downwards, so the match
            closest to `name_upper` wins. If False, scan from the bottom
            upwards, so the match closest to `name_lower` wins.

    Both names are resolved with `self.index`; whatever it raises for an
    unknown name propagates to the caller.

    Raises:
        ValueError: If no matching node is in range and no `default` was given.
    """
    upper_idx = self.index(name_upper)
    lower_idx = self.index(name_lower)

    # Pair every node with its position *before* reversing, so the bound
    # checks below remain valid for bottom-up scans as well.
    indexed_nodes = list(enumerate(self.nodes))
    if not top_bottom:
        indexed_nodes.reverse()

    for i, node in indexed_nodes:
        if upper_idx < i < lower_idx and isinstance(node, node_type):
            return node

    if default is not _FIND_SENTINEL:
        return default
    raise ValueError(f"node with type {node_type} between {name_upper} and {name_lower} not found")

def find_first_of_type_above(
    self,
    node_type: type[T],
    name_lower: str,
    default: D = _FIND_SENTINEL,
    top_bottom: bool = True,
) -> T | D:
    """
    Return the first node of type `node_type` positioned above `name_lower`.

    The node named `name_lower` itself is excluded from the search.

    Args:
        node_type: Class (or anything accepted by `isinstance`) to match.
        name_lower: Name of the node bounding the search from below.
        default: Value returned when nothing matches. If omitted, a
            `ValueError` is raised instead.
        top_bottom: If True, scan from the top downwards, so the topmost
            match wins. If False, scan from the bottom upwards, so the match
            closest to `name_lower` wins.

    `name_lower` is resolved with `self.index`; whatever it raises for an
    unknown name propagates to the caller.

    Raises:
        ValueError: If no matching node is found and no `default` was given.
    """
    lower_idx = self.index(name_lower)

    # Pair every node with its position *before* reversing, so the bound
    # check below remains valid for bottom-up scans as well.
    indexed_nodes = list(enumerate(self.nodes))
    if not top_bottom:
        indexed_nodes.reverse()

    for i, node in indexed_nodes:
        if i < lower_idx and isinstance(node, node_type):
            return node

    if default is not _FIND_SENTINEL:
        return default
    raise ValueError(f"node with type {node_type} above {name_lower} not found")

def find_first_of_type_below(
    self,
    node_type: type[T],
    name_upper: str,
    default: D = _FIND_SENTINEL,
    top_bottom: bool = True,
) -> T | D:
    """
    Return the first node of type `node_type` positioned below `name_upper`.

    The node named `name_upper` itself is excluded from the search.

    Args:
        node_type: Class (or anything accepted by `isinstance`) to match.
        name_upper: Name of the node bounding the search from above.
        default: Value returned when nothing matches. If omitted, a
            `ValueError` is raised instead.
        top_bottom: If True, scan from the top downwards, so the match
            closest to `name_upper` wins. If False, scan from the bottom
            upwards, so the bottommost match wins.

    `name_upper` is resolved with `self.index`; whatever it raises for an
    unknown name propagates to the caller.

    Raises:
        ValueError: If no matching node is found and no `default` was given.
    """
    upper_idx = self.index(name_upper)

    # Pair every node with its position *before* reversing, so the bound
    # check below remains valid for bottom-up scans as well.
    indexed_nodes = list(enumerate(self.nodes))
    if not top_bottom:
        indexed_nodes.reverse()

    for i, node in indexed_nodes:
        if i > upper_idx and isinstance(node, node_type):
            return node

    if default is not _FIND_SENTINEL:
        return default
    raise ValueError(f"node with type {node_type} below {name_upper} not found")

def first_below(
    self,
    name: str,
    filter: Callable = None,
    default: D = _FIND_SENTINEL,
) -> T | D:
    """
    Return the first node positioned below `name` that is accepted by `filter`.

    Nodes are scanned from the top downwards; the node called `name` itself
    is never returned.

    Args:
        name: Name of the node bounding the search from above. It is
            resolved with `self.index`; whatever that raises for an unknown
            name propagates to the caller.
        filter: Predicate called with a node; a truthy result accepts it.
            When None, every node is accepted.
        default: Value returned when nothing matches. If omitted, a
            `ValueError` is raised instead.

    Raises:
        ValueError: If no accepted node is found and no `default` was given.
    """
    idx = self.index(name)

    accepts = (lambda x: True) if filter is None else filter

    for position, candidate in enumerate(self.nodes):
        # The predicate is evaluated first (and for every node), then the
        # position bound is applied.
        if accepts(candidate) and position > idx:
            return candidate

    if default is _FIND_SENTINEL:
        raise ValueError(f"node below {name} not found")
    return default
31 changes: 18 additions & 13 deletions accelforge/frontend/arch/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _set_n_calls(self, value: int | float) -> None:
@classmethod
def _deprecate_latency_fields(cls, data):
if isinstance(data, dict):
if "latency" in data:
if "latency" in data and not "throughput" in data:
l = data.pop("latency")
warnings.warn(
f"Setting `latency` on `{cls.__name__}` is deprecated; use "
Expand All @@ -155,16 +155,11 @@ def _deprecate_latency_fields(cls, data):
DeprecationWarning,
stacklevel=2,
)
if "throughput" in data:
raise ValueError(
f"Cannot specify both `latency` and `throughput` on "
f"`{cls.__name__}`. Drop the deprecated `latency` field."
)
l = str(l).strip()
data["throughput"] = (
f"1 / ({l}) if ({l}) != 0 else float('inf')"
)
if "latency_scale" in data:
if "latency_scale" in data and not "throughput_scale" in data:
ls = data.pop("latency_scale")
warnings.warn(
f"Setting `latency_scale` on `{cls.__name__}` is deprecated; use "
Expand All @@ -174,11 +169,6 @@ def _deprecate_latency_fields(cls, data):
DeprecationWarning,
stacklevel=2,
)
if "throughput_scale" in data:
raise ValueError(
f"Cannot specify both `latency_scale` and `throughput_scale` "
f"on `{cls.__name__}`. Drop the deprecated `latency_scale`."
)
ls = str(ls).strip()
data["throughput_scale"] = (
f"1 / ({ls}) if ({ls}) != 0 else float('inf')"
Expand Down Expand Up @@ -1304,8 +1294,9 @@ def _render_node_color(self) -> str:
return "#E0EEFF"


class TopologySpec(str, enum.Enum):
class TopologySpec(enum.StrEnum):
MESH = "mesh"
ALL_TO_ALL = "all_to_all"


class Network(Component, Leaf):
Expand All @@ -1316,6 +1307,20 @@ class Network(Component, Leaf):
of the spatial nodes from top to bottom.
"""

total_latency: str | int | float = "max(max_hops*actions['hops'].latency, max_link_traffic/actions['hops'].throughput)"
"""
Models latency as either:
- *Latency-bound*, which means that the latency of the route with the most number of
hops dominate the overall communication latency.
- *Bandwidth-bound*, which means that the traffic over the most congested link
dominates the overall communication latency.

Keywords:
- `max_hops` returns the number of hops in the longest route.
- `max_link_traffic` returns the amount of traffic (in bits) over the most congested
link.
"""

bits_per_value: EvalsTo[dict] = {}
"""
Sets the bits per value for tensors in this `TensorHolder`. Keys are evaluated as
Expand Down
30 changes: 29 additions & 1 deletion accelforge/frontend/arch/spatialable.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ def _eval_expressions(self, *args, **kwargs):
return super(self.__class__, self)._eval_expressions(*args, **kwargs)


class PhysicalSpatial(EvalableModel):
    """
    One dimension of a node's physical spatial fanout: the dimension name, the
    fanout size along it, and the stride between fanout coordinates.
    """

    name: str
    """
    The name of the dimension over which this spatial fanout is occurring (e.g., X or Y).
    """

    fanout: EvalsTo[int]
    """ The size of this fanout. """

    stride: EvalsTo[int]
    """ The number of array coordinates between each spatial fanout coordinate."""


class Spatialable(EvalableModel):
"""Something that can be duplicated to create an array of."""

Expand All @@ -107,7 +120,7 @@ class Spatialable(EvalableModel):
specified at this level also apply to lower-level `Leaf` nodes in the architecture.
"""

_physical_spatial: NoParse[Spatial] = EvalableList()
_physical_spatial: NoParse[PhysicalSpatial] = EvalableList()
"""
The physical spatial fanout of this node. Should only have a value for a
flattened arch. Otherwise, the `spatial` attribute is authoritative.
Expand All @@ -123,14 +136,29 @@ def get_fanout_along(self, dim_name: str, default: int = 1) -> int:
return s.fanout
return default

def _has_physical_dim(self, dim_name: str) -> bool:
    """Return True if any physical spatial entry is named `dim_name`."""
    return any(entry.name == dim_name for entry in self._physical_spatial)

def _get_physical_fanout_along(self, dim_name: str, default: int = 1) -> int:
    """
    Return the fanout of the first physical spatial entry named `dim_name`,
    or `default` when no entry has that name.
    """
    matching_fanouts = (
        entry.fanout
        for entry in self._physical_spatial
        if entry.name == dim_name
    )
    return next(matching_fanouts, default)

def _get_physical_stride_along(self, dim_name: str) -> int:
    """
    Return the stride of the first physical spatial entry named `dim_name`.

    Raises:
        ValueError: If no physical spatial entry is named `dim_name`.
    """
    match = next(
        (entry for entry in self._physical_spatial if entry.name == dim_name),
        None,
    )
    if match is None:
        raise ValueError(f"dimension {dim_name} not found")
    return match.stride

def _spatial_str(self, include_newline=True) -> str:
    """
    Render `self.spatial` as a comma-separated "<fanout>× <name>" list.

    Returns an empty string when `self.spatial` is empty. With
    `include_newline` set, the list is wrapped in square brackets and
    prefixed with a newline; otherwise the bare list is returned.
    """
    if not self.spatial:
        return ""
    pieces = [f"{s.fanout}× {s.name}" for s in self.spatial]
    joined = ", ".join(pieces)
    if include_newline:
        return f"\n[{joined}]"
    return joined

def _is_distributed(self):
    """Return True if any physical spatial entry has a fanout greater than 1."""
    for entry in self._physical_spatial:
        if entry.fanout > 1:
            return True
    return False
15 changes: 13 additions & 2 deletions accelforge/frontend/arch/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from accelforge.util.exceptions import EvaluationError

from accelforge.frontend.arch.spatialable import Spatialable
from accelforge.frontend.arch.spatialable import Spatialable, PhysicalSpatial
from accelforge.frontend.arch._flattened_arch import FlattenedArch

from pydantic import Discriminator
Expand Down Expand Up @@ -334,6 +334,10 @@ def _flatten(

nodes = []

# Nodes inside an array are flattened to fit into a hierarchical
# model in order to map.
# However, we will keep information about how these nodes are
# arranged for modeling.
for node in self.nodes:
try:
if isinstance(node, Branch):
Expand All @@ -342,7 +346,14 @@ def _flatten(
if isinstance(node, Spatialable):
fanout *= node.get_fanout()
node = deepcopy(node)
node._physical_spatial = node.spatial
node._physical_spatial = [
PhysicalSpatial(
name=s.name,
fanout=s.fanout,
stride=self.get_fanout_along(s.name)/s.fanout
)
for s in node.spatial
]
node.spatial = EvalableList()
nodes.append(node)
else:
Expand Down
39 changes: 36 additions & 3 deletions accelforge/model/_looptree/latency/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from accelforge.model._looptree.reuse.symbolic import BuffetStats
from accelforge.util._eval_expressions import MATH_FUNCS, eval_expression
from accelforge.util._sympy.broadcast_max import Max, Min
from accelforge.util._sympy.broadcast_max import Max, Min, MaxGeqZero
from accelforge.util._basetypes import EvalableList
import symengine as se

Expand Down Expand Up @@ -71,6 +71,10 @@ def component_latency(
component_to_actions: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(lambda: 0)
)
# Holds "keywords" that do not map neatly to actions, e.g., max_hops for a network
component_to_keywords: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(lambda: 0)
)
name2component: dict[str, Component] = {node.name: node for node in flattened_arch}

compute_obj = flattened_arch[-1]
Expand Down Expand Up @@ -103,6 +107,30 @@ def component_latency(
f"Component {component} is not a TensorHolder or Compute"
)

network_to_max_link_traffic = defaultdict(lambda: defaultdict(lambda: 0))
network_to_max_hops = defaultdict(lambda: [])
# Aggregates across tensors
for network, network_stats in looptree_results.network_stats.items():
component = network.component
if component not in name2component:
raise ValueError(f"Component {component} found in mapping but not arch")

dim_traffic = network_to_max_link_traffic[component]
for dim, max_traffic_in_dim in network_stats.max_traffic.items():
dim_traffic[dim] += max_traffic_in_dim

network_to_max_hops[component].append(network_stats.max_hops)

for network, network_stats in looptree_results.network_stats.items():
component = network.component
keywords = component_to_keywords[component]
keywords["max_link_traffic"] = MaxGeqZero(
*network_to_max_link_traffic[component].values()
)
keywords["max_hops"] = MaxGeqZero(
*network_to_max_hops[component]
)

longest_compute_latency = Max(
0, *[s.max_latency for s in looptree_results.compute_stats.values()]
)
Expand Down Expand Up @@ -138,13 +166,18 @@ def component_latency(
"sum": _sum,
}

for component in component_to_actions:
for component in name2component:
if component not in component_to_actions and component not in component_to_keywords:
continue
component_obj = name2component[component]
dump = component_obj.shallow_model_dump(include_None=True)
# Replace serialized `actions` dump with local Action copies that carry
# the correct n_calls for this job, so formulas can access `a.n_calls`,
# `a.throughput`, etc. without mutating the shared spec state.
dump["actions"] = component_to_actions[component]
if component in component_to_actions:
dump["actions"] = component_to_actions[component]
if component in component_to_keywords:
dump |= component_to_keywords[component]
symbol_table = {**symbol_table_base, **dump}
if component_obj.total_latency is not None:
component_latency[component] = eval_expression(
Expand Down
Loading
Loading