@@ -842,7 +842,8 @@ class Device(Platform):
842842
843843 def __init__ (self , name , cores_logical = None , cores_physical = None , isa = 'cpp' ,
844844 max_threads_per_block = 1024 , max_threads_dimx = 1024 ,
845- max_threads_dimy = 1024 , max_threads_dimz = 64 ):
845+ max_threads_dimy = 1024 , max_threads_dimz = 64 ,
846+ max_thread_block_cluster_size = 8 ):
846847 super ().__init__ (name )
847848
848849 cpu_info = get_cpu_info ()
@@ -855,6 +856,7 @@ def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
855856 self .max_threads_dimx = max_threads_dimx
856857 self .max_threads_dimy = max_threads_dimy
857858 self .max_threads_dimz = max_threads_dimz
859+ self .max_thread_block_cluster_size = max_thread_block_cluster_size
858860
859861 @classmethod
860862 def _mro (cls ):
@@ -961,6 +963,10 @@ def supports(self, query, language=None):
961963
962964class Hopper (Ampere ):
963965
def __init__(self, *args, **kwargs):
    """Initialize the platform with a larger default cluster size.

    Unless the caller supplies ``max_thread_block_cluster_size`` by
    keyword, it is set to 16 before all positional and keyword
    arguments are forwarded unchanged to the parent initializer.
    """
    # NOTE(review): a caller passing this value positionally would
    # presumably collide with the injected keyword -- confirm against
    # the parent signature.
    if 'max_thread_block_cluster_size' not in kwargs:
        kwargs['max_thread_block_cluster_size'] = 16
    super().__init__(*args, **kwargs)
969+
964970 def supports (self , query , language = None ):
965971 if query in ('tma' , 'thread-block-cluster' ):
966972 return True
0 commit comments