Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,16 @@ def ccode(self):

@property
def view(self):
"""A representation of the IET rooted in ``self``."""
"""A high-level representation of the IET rooted in `self`."""
from devito.ir.iet.visitors import printAST
return printAST(self)

@property
def view_cir(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is like a prettier print I guess

from devito.ir.iet.visitors import CGen
from devito.passes.iet.languages.CIR import CIRPrinter
return str(CGen(printer=CIRPrinter).visit(self))

@property
def children(self):
"""Return the traversable children."""
Expand Down Expand Up @@ -148,7 +154,7 @@ def writes(self):
return ()

def _signature_items(self):
return (str(self),)
return (self.view_cir,)


class ExprStmt:
Expand Down
16 changes: 12 additions & 4 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from devito.ir.support.utils import AccessMode, extrema
from devito.ir.support.vector import LabeledVector, Vector
from devito.symbolics import (compare_ops, retrieve_indexed, retrieve_terminals,
q_constant, q_affine, q_routine, search, uxreplace)
q_constant, q_comp_acc, q_affine, q_routine, search,
uxreplace)
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
flatten, memoized_meth, memoized_generator)
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
Expand Down Expand Up @@ -529,9 +530,16 @@ def __hash__(self):
(self.source, self.sink, self.source.timestamp == self.sink.timestamp)
)

@property
@cached_property
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the motivation for this?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It went from simply being an alias for an attribute of an attribute, to having internal logic etc

def function(self):
return self.source.function
if q_comp_acc(self.source.access) and not q_comp_acc(self.sink.access):
# E.g., `source=ab[x].x` and `sink=ab[x]` -> `a(x)`
return self.source.access.function_access
elif q_comp_acc(self.sink.access) and not q_comp_acc(self.source.access):
# E.g., `source=ab[x]` and `sink=ab[x].y` -> `b(x)`
return self.sink.access.function_access
else:
return self.source.function

@property
def findices(self):
Expand Down Expand Up @@ -955,7 +963,7 @@ def reads_gen(self):
@memoized_generator
def reads_smart_gen(self, f):
"""
Generate all read access to a given function.
Generate all read accesses to a given function.

StencilDimensions, if any, are replaced with their extrema.

Expand Down
31 changes: 21 additions & 10 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ class HaloLabel(Tag):
class HaloSchemeEntry(EnrichedTuple):

__rargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims')
__rkwargs__ = ('bundle',)

def __new__(cls, loc_indices, loc_dirs, halos, dims, getters=None):
def __new__(cls, loc_indices, loc_dirs, halos, dims, bundle=None, getters=None):
getters = cls.__rargs__ + cls.__rkwargs__
items = [frozendict(loc_indices), frozendict(loc_dirs),
frozenset(halos), frozenset(dims)]
kwargs = dict(zip(cls.__rargs__, items))
return super().__new__(cls, *items, getters=cls.__rargs__, **kwargs)
frozenset(halos), frozenset(dims), bundle]
kwargs = dict(zip(getters, items))
return super().__new__(cls, *items, getters=getters, **kwargs)

def __hash__(self):
return hash((self.loc_indices, self.loc_dirs, self.halos, self.dims))
Expand All @@ -47,7 +49,8 @@ def union(self, other):
exception is raised.
"""
if self.loc_indices != other.loc_indices or \
self.loc_dirs != other.loc_dirs:
self.loc_dirs != other.loc_dirs or \
self.bundle is not other.bundle:
raise HaloSchemeException(
"Inconsistency found while building a HaloScheme"
)
Expand All @@ -56,7 +59,7 @@ def union(self, other):
dims = self.dims | other.dims

return HaloSchemeEntry(self.loc_indices, self.loc_dirs, halos, dims,
getters=self.getters)
bundle=self.bundle, getters=self.getters)


Halo = namedtuple('Halo', 'dim side')
Expand Down Expand Up @@ -168,7 +171,7 @@ def union(self, halo_schemes):
elif not v.loc_indices or hse.loc_indices == v.loc_indices:
loc_indices, loc_dirs = hse.loc_indices, hse.loc_dirs
else:
# The `loc_dirs` must match otherwise it'd be a symptom there's
# These must match otherwise it'd be a symptom there's
# something horribly broken elsewhere!
assert hse.loc_dirs == v.loc_dirs
assert list(hse.loc_indices) == list(v.loc_indices)
Expand All @@ -185,7 +188,11 @@ def union(self, halo_schemes):
halos = hse.halos | v.halos
dims = hse.dims | v.dims

fmapper[k] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
assert hse.bundle is v.bundle

fmapper[k] = HaloSchemeEntry(
loc_indices, loc_dirs, halos, dims, bundle=hse.bundle
)

# Compute the `honored` union
for d, v in i.honored.items():
Expand Down Expand Up @@ -641,8 +648,12 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
for i, v in rule.items():
if i is f:
# Yes!
g = v
hse = hse0
if v.is_Bundle:
g = f
hse = hse0._rebuild(bundle=v)
else:
g = v
hse = hse0

elif i.is_Indexed and i.function is f and v.is_Indexed:
# Yes, but through an Indexed, hence the `loc_indices` may now
Expand Down
38 changes: 27 additions & 11 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast, subs_op_args)
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize,
flatten, generator, is_integer, split)
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
CompositeObject, CustomDimension)
flatten, generator, is_integer)
from devito.types import (Array, Bag, BundleView, Dimension, Eq, Symbol,
LocalObject, CompositeObject, CustomDimension)

__all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry']

Expand Down Expand Up @@ -292,19 +292,28 @@ def _make_bundles(self, hs):

mapper = as_mapper(halo_scheme.fmapper, lambda i: halo_scheme.fmapper[i])
for hse, components in mapper.items():
# We recast everything as Bags for simplicity -- worst case scenario
# all Bags only have one component. Existing Bundles are preserved
halo_scheme = halo_scheme.drop(components)
bundles, candidates = split(tuple(components), lambda i: i.is_Bundle)
for b in bundles:
halo_scheme = halo_scheme.add(b, hse)

# Existing Bundles are preserved
if hse.bundle:
if set(components) == set(hse.bundle.components):
halo_scheme = halo_scheme.add(hse.bundle, hse)
else:
name = f'bundleview_{hse.bundle.name}'
bundle_view = BundleView(
name=name, components=components, parent=hse.bundle
)
halo_scheme = halo_scheme.add(bundle_view, hse)
continue

# We recast everything else as Bags for simplicity -- worst case
# scenario all Bags only have one component.
try:
name = "bag_%s" % "".join(f.name for f in candidates)
bag = Bag(name=name, components=candidates)
name = "bag_%s" % "".join(f.name for f in components)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fstring? Might be a rare occasion where it's less neat tbh

bag = Bag(name=name, components=components)
halo_scheme = halo_scheme.add(bag, hse)
except ValueError:
for i in candidates:
for i in components:
name = "bag_%s" % i.name
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fstring

bag = Bag(name=name, components=i)
halo_scheme = halo_scheme.add(bag, hse)
Expand Down Expand Up @@ -363,10 +372,17 @@ def _make_copy(self, f, hse, key, swap=False):
else:
swap = lambda i, j: (j, i)
name = 'scatter%s' % key
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto


if isinstance(f, Bag):
for i, c in enumerate(f.components):
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
elif isinstance(f, BundleView):
assert f.parent is hse.bundle
for i, c in enumerate(f.components):
indices = [f.parent.components.index(c), *findices]
eqns.append(Eq(*swap(buf[[i] + bdims], f.parent[indices])))
else:
assert f.is_Bundle
for i in range(f.ncomp):
eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices])))

Expand Down
12 changes: 6 additions & 6 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
from devito.data import default_allocator
from devito.exceptions import (CompilationError, ExecutionError, InvalidArgument,
InvalidOperator)
from devito.logger import debug, info, perf, warning, is_log_enabled_for, switch_log_level
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
switch_log_level)
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
from devito.ir.clusters import ClusterGroup, clusterize
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, MetaCall,
derive_parameters, iet_build)
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols,
MetaCall, derive_parameters, iet_build)
from devito.ir.support import AccessMode, SymbolRegistry
from devito.ir.stree import stree_build
from devito.operator.profiling import create_profile
Expand All @@ -26,8 +27,7 @@
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
from devito.passes.iet.dtypes import lower_dtypes
error_mapper, is_on_device, lower_dtypes)
from devito.symbolics import estimate_cost, subs_op_args
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
flatten, filter_sorted, frozendict, is_integer,
Expand Down Expand Up @@ -488,7 +488,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
# Extract the necessary macros from the symbolic objects
generate_macros(graph, **kwargs)

# Add type specific metadata
# Target-specific lowering
lower_dtypes(graph, **kwargs)

# Target-independent optimizations
Expand Down
9 changes: 5 additions & 4 deletions devito/passes/clusters/buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ class InjectBuffers(Queue):
def __init__(self, mapper, sregistry, options):
super().__init__()

# Sort the mapper so that we always process the same Function in the
# same order, hence we get deterministic code generation
self.mapper = {i: mapper[i] for i in sorted(mapper, key=lambda i: i.name)}
self.mapper = mapper

self.sregistry = sregistry
self.options = options
Expand Down Expand Up @@ -302,6 +300,9 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
# {candidate buffered Function -> [Clusters that access it]}
bfmap = map_buffered_functions(clusters, key)

# Sort for deterministic code generation
bfmap = {i: bfmap[i] for i in sorted(bfmap, key=lambda i: i.name)}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this sorting go into map_buffered_functions instead?


# {buffered Function -> Buffer}
xds = {}
mapper = {}
Expand Down Expand Up @@ -718,7 +719,7 @@ def offset_from_centre(d, indices):
# `time/factor` -- the starting pointing at time_m or time_M
v = indices[0]
try:
p = sum(v.args[1:])
p = v.func(*[i for i in v.args if not is_integer(i)])
if not ((p - v).is_Integer or (p - v).is_Symbol):
raise ValueError
except (IndexError, ValueError):
Expand Down
Loading
Loading