Skip to content

Commit 8b157fe

Browse files
committed
arch: Add Device.max_thread_block_cluster_size
1 parent 15e8d35 commit 8b157fe

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

devito/arch/archinfo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,8 @@ class Device(Platform):
842842

843843
def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
844844
max_threads_per_block=1024, max_threads_dimx=1024,
845-
max_threads_dimy=1024, max_threads_dimz=64):
845+
max_threads_dimy=1024, max_threads_dimz=64,
846+
max_thread_block_cluster_size=8):
846847
super().__init__(name)
847848

848849
cpu_info = get_cpu_info()
@@ -855,6 +856,7 @@ def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
855856
self.max_threads_dimx = max_threads_dimx
856857
self.max_threads_dimy = max_threads_dimy
857858
self.max_threads_dimz = max_threads_dimz
859+
self.max_thread_block_cluster_size = max_thread_block_cluster_size
858860

859861
@classmethod
860862
def _mro(cls):
@@ -961,6 +963,10 @@ def supports(self, query, language=None):
961963

962964
class Hopper(Ampere):
963965

966+
def __init__(self, *args, **kwargs):
967+
kwargs.setdefault('max_thread_block_cluster_size', 16)
968+
super().__init__(*args, **kwargs)
969+
964970
def supports(self, query, language=None):
965971
if query in ('tma', 'thread-block-cluster'):
966972
return True

0 commit comments

Comments
 (0)