Skip to content

Commit 82cdb29

Browse files
Merge pull request #2516 from devitocodes/async-loads-final-2
compiler: Misc improvements to code generation
2 parents f71764a + b8de9ec commit 82cdb29

14 files changed

Lines changed: 333 additions & 131 deletions

File tree

devito/arch/archinfo.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
'POWER8', 'POWER9',
3232
# Generic GPUs
3333
'AMDGPUX', 'NVIDIAX', 'INTELGPUX',
34+
# Nvidia GPUs
35+
'VOLTA', 'AMPERE', 'HOPPER', 'BLACKWELL',
3436
# Intel GPUs
3537
'PVC', 'INTELGPUMAX', 'MAX1100', 'MAX1550']
3638

@@ -867,6 +869,12 @@ def limits(self, compiler=None, language=None):
867869
'max-block-dims': 3,
868870
}
869871

872+
def supports(self, query, language=None):
873+
"""
874+
Check if the device supports a given feature.
875+
"""
876+
return False
877+
870878

871879
class IntelDevice(Device):
872880

@@ -895,6 +903,52 @@ def march(self):
895903
return 'tesla'
896904
return None
897905

906+
def supports(self, query, language=None):
907+
if language != 'cuda':
908+
return False
909+
910+
cc = get_nvidia_cc()
911+
if query == 'async-loads' and cc >= 80:
912+
# Asynchronous pipeline loads -- introduced in Ampere
913+
return True
914+
elif query == 'tma' and cc >= 90:
915+
# Tensor Memory Accelerator -- introduced in Hopper
916+
return True
917+
else:
918+
return False
919+
920+
921+
class Volta(NvidiaDevice):
922+
pass
923+
924+
925+
class Ampere(Volta):
926+
927+
def supports(self, query, language=None):
928+
if language != 'cuda':
929+
return False
930+
931+
if query == 'async-loads':
932+
return True
933+
934+
return super().supports(query, language)
935+
936+
937+
class Hopper(Ampere):
938+
939+
def supports(self, query, language=None):
940+
if language != 'cuda':
941+
return False
942+
943+
if query == 'tma':
944+
return True
945+
946+
return super().supports(query, language)
947+
948+
949+
class Blackwell(Hopper):
950+
pass
951+
898952

899953
class AmdDevice(Device):
900954

@@ -963,6 +1017,10 @@ def march(cls):
9631017
ANYGPU = Cpu64('gpu')
9641018

9651019
NVIDIAX = NvidiaDevice('nvidiaX')
1020+
VOLTA = Volta('volta')
1021+
AMPERE = Ampere('ampere')
1022+
HOPPER = Hopper('hopper')
1023+
BLACKWELL = Blackwell('blackwell')
9661024

9671025
AMDGPUX = AmdDevice('amdgpuX')
9681026

devito/arch/compiler.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from codepy.toolchain import (GCCToolchain,
1414
call_capture_output as _call_capture_output)
1515

16-
from devito.arch import (AMDGPUX, Cpu64, AppleArm, NVIDIAX, POWER8, POWER9, Graviton,
17-
IntelDevice, get_nvidia_cc, check_cuda_runtime,
16+
from devito.arch import (AMDGPUX, Cpu64, AppleArm, NvidiaDevice, POWER8, POWER9,
17+
Graviton, IntelDevice, get_nvidia_cc, check_cuda_runtime,
1818
get_m1_llvm_path)
1919
from devito.exceptions import CompilationError
2020
from devito.logger import debug, warning
@@ -487,7 +487,7 @@ def __init_finalize__(self, **kwargs):
487487
language = kwargs.pop('language', configuration['language'])
488488
platform = kwargs.pop('platform', configuration['platform'])
489489

490-
if platform is NVIDIAX:
490+
if isinstance(platform, NvidiaDevice):
491491
self.cflags.remove('-std=c99')
492492
# Add flags for OpenMP offloading
493493
if language in ['C', 'openmp']:
@@ -555,7 +555,7 @@ def __init_finalize__(self, **kwargs):
555555
if not configuration['safe-math']:
556556
self.cflags.append('-ffast-math')
557557

558-
if platform is NVIDIAX:
558+
if isinstance(platform, NvidiaDevice):
559559
self.cflags.remove('-std=c99')
560560
elif platform is AMDGPUX:
561561
self.cflags.remove('-std=c99')
@@ -607,7 +607,7 @@ def __init_finalize__(self, **kwargs):
607607
language = kwargs.pop('language', configuration['language'])
608608
platform = kwargs.pop('platform', configuration['platform'])
609609

610-
if platform is NVIDIAX:
610+
if isinstance(platform, NvidiaDevice):
611611
if self.version >= Version("24.9"):
612612
self.cflags.append('-gpu=mem:separate:pinnedalloc')
613613
else:
@@ -843,7 +843,7 @@ def __init_finalize__(self, **kwargs):
843843
self.ldflags.remove('-qopenmp')
844844
self.ldflags.append('-fopenmp')
845845

846-
if platform is NVIDIAX:
846+
if isinstance(platform, NvidiaDevice):
847847
self.cflags.append('-fopenmp-targets=nvptx64-cuda')
848848
elif isinstance(platform, IntelDevice):
849849
self.cflags.append('-fiopenmp')
@@ -900,7 +900,7 @@ def __init_finalize__(self, **kwargs):
900900

901901
if isinstance(platform, Cpu64):
902902
pass
903-
elif platform is NVIDIAX:
903+
elif isinstance(platform, NvidiaDevice):
904904
self.cflags.append('-fsycl-targets=nvptx64-cuda')
905905
elif isinstance(platform, IntelDevice):
906906
self.cflags.append('-fsycl-targets=spir64')
@@ -931,7 +931,7 @@ def __new__(cls, *args, **kwargs):
931931
_base = ClangCompiler
932932
elif isinstance(platform, IntelDevice):
933933
_base = OneapiCompiler
934-
elif platform is NVIDIAX:
934+
elif isinstance(platform, NvidiaDevice):
935935
if language == 'cuda':
936936
_base = CudaCompiler
937937
else:

devito/finite_differences/differentiable.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,12 @@ def __init_finalize__(self, *args, **kwargs):
749749

750750
super().__init_finalize__(*args, **kwargs)
751751

752+
@classmethod
753+
def class_key(cls):
754+
# Ensure Weights appear before any other AbstractFunction
755+
p, v, _ = Array.class_key()
756+
return p, v - 1, cls.__name__
757+
752758
def __eq__(self, other):
753759
return (isinstance(other, Weights) and
754760
self.name == other.name and
@@ -838,7 +844,8 @@ def compare(self, other):
838844
n1 = self.__class__
839845
n2 = other.__class__
840846
if n1.__name__ == n2.__name__:
841-
return self.base.compare(other.base)
847+
return (self.weights.compare(other.weights) or
848+
self.base.compare(other.base))
842849
else:
843850
return super().compare(other)
844851

devito/ir/clusters/cluster.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
import numpy as np
55

66
from devito.ir.equations import ClusterizedEq
7-
from devito.ir.support import (PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext,
8-
Forward, Interval, IntervalGroup, IterationSpace,
9-
DataSpace, Guards, Properties, Scope, WaitLock,
10-
WithLock, PrefetchUpdate, detect_accesses, detect_io,
11-
normalize_properties, normalize_syncs, minimum,
12-
maximum, null_ispace)
7+
from devito.ir.support import (
8+
PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext, Forward, Interval, IntervalGroup,
9+
IterationSpace, DataSpace, Guards, Properties, Scope, WaitLock, WithLock,
10+
PrefetchUpdate, detect_accesses, detect_io, normalize_properties,
11+
tailor_properties, update_properties, normalize_syncs, minimum, maximum,
12+
null_ispace
13+
)
1314
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
1415
from devito.mpi.reduction_scheme import DistReduce
1516
from devito.symbolics import estimate_cost
16-
from devito.tools import as_tuple, flatten, infer_dtype
17+
from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype
1718
from devito.types import Fence, WeakFence, CriticalRegion
1819

1920
__all__ = ["Cluster", "ClusterGroup"]
@@ -52,7 +53,8 @@ def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None,
5253
self._syncs = normalize_syncs(syncs or {})
5354

5455
properties = Properties(properties or {})
55-
self._properties = tailor_properties(properties, ispace)
56+
properties = tailor_properties(properties, ispace)
57+
self._properties = update_properties(properties, self.exprs)
5658

5759
self._halo_scheme = halo_scheme
5860

@@ -482,15 +484,17 @@ def properties(self):
482484

483485
@cached_property
484486
def guards(self):
485-
"""The guards of each Cluster in self."""
486-
return tuple(i.guards for i in self)
487+
"""
488+
A view of the ClusterGroup's guards.
489+
"""
490+
return tuple(filter_ordered(i.guards for i in self))
487491

488492
@cached_property
489493
def syncs(self):
490494
"""
491495
A view of the ClusterGroup's synchronization operations.
492496
"""
493-
return normalize_syncs(*[c.syncs for c in self])
497+
return normalize_syncs(*[c.syncs for c in self], strict=False)
494498

495499
@cached_property
496500
def dspace(self):
@@ -540,19 +544,3 @@ def reduce_properties(clusters):
540544
properties[d] = normalize_properties(properties.get(d, v), v)
541545

542546
return Properties(properties)
543-
544-
545-
def tailor_properties(properties, ispace):
546-
"""
547-
Create a new Properties object off `properties` that retains all and only
548-
the iteration dimensions in `ispace`.
549-
"""
550-
for i in properties:
551-
for d in as_tuple(i):
552-
if d not in ispace.itdims:
553-
properties = properties.drop(d)
554-
555-
for d in ispace.itdims:
556-
properties = properties.add(d)
557-
558-
return properties

devito/ir/support/properties.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ def __init__(self, name, val=None):
8686
"""
8787

8888
PREFETCHABLE = Property('prefetchable')
89+
"""
90+
A Dimension along which prefetching is feasible and beneficial.
91+
"""
92+
93+
PREFETCHABLE_SHM = Property('prefetchable-shm')
94+
"""
95+
A Dimension along which shared-memory prefetching is feasible and beneficial.
96+
"""
8997

9098

9199
# Bundles
@@ -129,6 +137,62 @@ def relax_properties(properties):
129137
return frozenset(properties - {PARALLEL_INDEP})
130138

131139

140+
def tailor_properties(properties, ispace):
141+
"""
142+
Create a new Properties object off `properties` that retains all and only
143+
the iteration dimensions in `ispace`.
144+
"""
145+
for i in properties:
146+
for d in as_tuple(i):
147+
if d not in ispace.itdims:
148+
properties = properties.drop(d)
149+
150+
for d in ispace.itdims:
151+
properties = properties.add(d)
152+
153+
return properties
154+
155+
156+
def update_properties(properties, exprs):
157+
"""
158+
Create a new Properties object off `properties` augmented with properties
159+
discovered from `exprs` or with properties removed if they are incompatible
160+
with `exprs`.
161+
"""
162+
exprs = as_tuple(exprs)
163+
164+
if not exprs:
165+
return properties
166+
167+
# Auto-detect prefetchable Dimensions
168+
dims = set()
169+
flag = False
170+
for e in as_tuple(exprs):
171+
w, r = e.args
172+
173+
# Ensure it's in the form `Indexed = Indexed`
174+
try:
175+
wf, rf = w.function, r.function
176+
except AttributeError:
177+
break
178+
179+
if not wf._mem_shared:
180+
break
181+
dims.update({d.parent for d in wf.dimensions if d.parent in properties})
182+
183+
if not rf._mem_heap:
184+
break
185+
else:
186+
flag = True
187+
188+
if flag:
189+
properties = properties.prefetchable_shm(dims)
190+
else:
191+
properties = properties.drop(properties=PREFETCHABLE_SHM)
192+
193+
return properties
194+
195+
132196
class Properties(frozendict):
133197

134198
"""
@@ -183,12 +247,15 @@ def sequentialize(self, dims=None):
183247
m[d] = normalize_properties(set(self.get(d, [])), {SEQUENTIAL})
184248
return Properties(m)
185249

186-
def prefetchable(self, dims):
250+
def prefetchable(self, dims, v=PREFETCHABLE):
187251
m = dict(self)
188252
for d in as_tuple(dims):
189-
m[d] = self.get(d, set()) | {PREFETCHABLE}
253+
m[d] = self.get(d, set()) | {v}
190254
return Properties(m)
191255

256+
def prefetchable_shm(self, dims):
257+
return self.prefetchable(dims, PREFETCHABLE_SHM)
258+
192259
def block(self, dims, kind='default'):
193260
if kind == 'default':
194261
p = TILABLE
@@ -232,8 +299,13 @@ def is_blockable(self, d):
232299
def is_blockable_small(self, d):
233300
return TILABLE_SMALL in self.get(d, set())
234301

235-
def is_prefetchable(self, dims):
236-
return any(PREFETCHABLE in self.get(d, set()) for d in as_tuple(dims))
302+
def is_prefetchable(self, dims=None, v=PREFETCHABLE):
303+
if dims is None:
304+
dims = list(self)
305+
return any(v in self.get(d, set()) for d in as_tuple(dims))
306+
307+
def is_prefetchable_shm(self, dims=None):
308+
return self.is_prefetchable(dims, PREFETCHABLE_SHM)
237309

238310
@property
239311
def nblockable(self):

devito/ir/support/syncs.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def update(self, ops):
164164
return Ops(m)
165165

166166

167-
def normalize_syncs(*args):
167+
def normalize_syncs(*args, strict=True):
168168
if not args:
169169
return {}
170170

@@ -175,12 +175,13 @@ def normalize_syncs(*args):
175175

176176
syncs = {k: tuple(filter_ordered(v)) for k, v in syncs.items()}
177177

178-
for v in syncs.values():
179-
waitlocks = [s for s in v if isinstance(s, WaitLock)]
180-
withlocks = [s for s in v if isinstance(s, WithLock)]
178+
if strict:
179+
for v in syncs.values():
180+
waitlocks = [s for s in v if isinstance(s, WaitLock)]
181+
withlocks = [s for s in v if isinstance(s, WithLock)]
181182

182-
if waitlocks and withlocks:
183-
# We do not allow mixing up WaitLock and WithLock ops
184-
raise ValueError("Incompatible SyncOps")
183+
if waitlocks and withlocks:
184+
# We do not allow mixing up WaitLock and WithLock ops
185+
raise ValueError("Incompatible SyncOps")
185186

186187
return Ops(syncs)

0 commit comments

Comments
 (0)