@@ -684,6 +684,12 @@ def limits(self, compiler=None, language=None):
684684 'max-block-dims' : sys .maxsize ,
685685 }
686686
687+ def supports (self , query , language = None ):
688+ """
689+ Return True if the platform supports a given feature, False otherwise.
690+ """
691+ return False
692+
687693
688694class Cpu64 (Platform ):
689695
@@ -897,12 +903,6 @@ def limits(self, compiler=None, language=None):
897903 'max-block-dims' : 3 ,
898904 }
899905
900- def supports (self , query , language = None ):
901- """
902- Check if the device supports a given feature.
903- """
904- return False
905-
906906
907907class IntelDevice (Device ):
908908
@@ -939,7 +939,7 @@ def supports(self, query, language=None):
939939 if query == 'async-loads' and cc >= 80 :
940940 # Asynchronous pipeline loads -- introduced in Ampere
941941 return True
942- elif query == 'tma' and cc >= 90 :
942+ elif query in ( 'tma' , 'thread-block-cluster' ) and cc >= 90 :
943943 # Tensor Memory Accelerator -- introduced in Hopper
944944 return True
945945 else :
@@ -953,25 +953,19 @@ class Volta(NvidiaDevice):
953953class Ampere (Volta ):
954954
955955 def supports (self , query , language = None ):
956- if language != 'cuda' :
957- return False
958-
959956 if query == 'async-loads' :
960957 return True
961-
962- return super ().supports (query , language )
958+ else :
959+ return super ().supports (query , language )
963960
964961
965962class Hopper (Ampere ):
966963
967964 def supports (self , query , language = None ):
968- if language != 'cuda' :
969- return False
970-
971- if query == 'tma' :
965+ if query in ('tma' , 'thread-block-cluster' ):
972966 return True
973-
974- return super ().supports (query , language )
967+ else :
968+ return super ().supports (query , language )
975969
976970
977971class Blackwell (Hopper ):
0 commit comments