@@ -176,8 +176,9 @@ class CGen(Visitor):
176176 Return a representation of the Iteration/Expression tree as a :module:`cgen` tree.
177177 """
178178
179- def __init__ (self , * args , ** kwargs ):
179+ def __init__ (self , * args , language = None , ** kwargs ):
180180 super ().__init__ (* args , ** kwargs )
181+ self .language = language
181182
182183 # The following mappers may be customized by subclasses (that is,
183184 # backend-specific CGen-erators)
@@ -189,6 +190,9 @@ def __init__(self, *args, **kwargs):
189190 }
190191 _restrict_keyword = 'restrict'
191192
193+ def ccode (self , expr , ** kwargs ):
194+ return ccode (expr , language = self .language , ** kwargs )
195+
192196 def _gen_struct_decl (self , obj , masked = ()):
193197 """
194198 Convert ctypes.Struct -> cgen.Structure.
@@ -222,7 +226,7 @@ def _gen_struct_decl(self, obj, masked=()):
222226 try :
223227 entries .append (self ._gen_value (i , 0 , masked = ('const' ,)))
224228 except AttributeError :
225- cstr = ccode (ct )
229+ cstr = self . ccode (ct )
226230 if ct is c_restrict_void_p :
227231 cstr = '%srestrict' % cstr
228232 entries .append (c .Value (cstr , n ))
@@ -244,10 +248,10 @@ def _gen_value(self, obj, mode=1, masked=()):
244248 if getattr (obj .function , k , False ) and v not in masked ]
245249
246250 if (obj ._mem_stack or obj ._mem_constant ) and mode == 1 :
247- strtype = ccode (obj ._C_typedata )
248- strshape = '' .join ('[%s]' % ccode (i ) for i in obj .symbolic_shape )
251+ strtype = self . ccode (obj ._C_typedata )
252+ strshape = '' .join ('[%s]' % self . ccode (i ) for i in obj .symbolic_shape )
249253 else :
250- strtype = ccode (obj ._C_ctype )
254+ strtype = self . ccode (obj ._C_ctype )
251255 strshape = ''
252256 if isinstance (obj , (AbstractFunction , IndexedData )) and mode >= 1 :
253257 if not obj ._mem_stack :
@@ -261,7 +265,7 @@ def _gen_value(self, obj, mode=1, masked=()):
261265 strobj = '%s%s' % (strname , strshape )
262266
263267 if obj .is_LocalObject and obj .cargs and mode == 1 :
264- arguments = [ccode (i ) for i in obj .cargs ]
268+ arguments = [self . ccode (i ) for i in obj .cargs ]
265269 strobj = MultilineCall (strobj , arguments , True )
266270
267271 value = c .Value (strtype , strobj )
@@ -275,9 +279,9 @@ def _gen_value(self, obj, mode=1, masked=()):
275279 if obj .is_Array and obj .initvalue is not None and mode == 1 :
276280 init = ListInitializer (obj .initvalue )
277281 if not obj ._mem_constant or init .is_numeric :
278- value = c .Initializer (value , ccode (init ))
282+ value = c .Initializer (value , self . ccode (init ))
279283 elif obj .is_LocalObject and obj .initvalue is not None and mode == 1 :
280- value = c .Initializer (value , ccode (obj .initvalue ))
284+ value = c .Initializer (value , self . ccode (obj .initvalue ))
281285
282286 return value
283287
@@ -311,7 +315,7 @@ def _args_call(self, args):
311315 else :
312316 ret .append (i ._C_name )
313317 except AttributeError :
314- ret .append (ccode (i ))
318+ ret .append (self . ccode (i ))
315319 return ret
316320
317321 def _gen_signature (self , o , is_declaration = False ):
@@ -388,7 +392,7 @@ def visit_tuple(self, o):
388392 def visit_PointerCast (self , o ):
389393 f = o .function
390394 i = f .indexed
391- cstr = ccode (i ._C_typedata )
395+ cstr = self . ccode (i ._C_typedata )
392396
393397 if f .is_PointerArray :
394398 # lvalue
@@ -410,7 +414,7 @@ def visit_PointerCast(self, o):
410414 else :
411415 v = f .name
412416 if o .flat is None :
413- shape = '' .join ("[%s]" % ccode (i ) for i in o .castshape )
417+ shape = '' .join ("[%s]" % self . ccode (i ) for i in o .castshape )
414418 rshape = '(*)%s' % shape
415419 lvalue = c .Value (cstr , '(*restrict %s)%s' % (v , shape ))
416420 else :
@@ -443,9 +447,9 @@ def visit_Dereference(self, o):
443447 a0 , a1 = o .functions
444448 if a1 .is_PointerArray or a1 .is_TempFunction :
445449 i = a1 .indexed
446- cstr = ccode (i ._C_typedata )
450+ cstr = self . ccode (i ._C_typedata )
447451 if o .flat is None :
448- shape = '' .join ("[%s]" % ccode (i ) for i in a0 .symbolic_shape [1 :])
452+ shape = '' .join ("[%s]" % self . ccode (i ) for i in a0 .symbolic_shape [1 :])
449453 rvalue = '(%s (*)%s) %s[%s]' % (cstr , shape , a1 .name ,
450454 a1 .dim .name )
451455 lvalue = c .Value (cstr , '(*restrict %s)%s' % (a0 .name , shape ))
@@ -484,8 +488,8 @@ def visit_Definition(self, o):
484488 return self ._gen_value (o .function )
485489
486490 def visit_Expression (self , o ):
487- lhs = ccode (o .expr .lhs , dtype = o .dtype )
488- rhs = ccode (o .expr .rhs , dtype = o .dtype )
491+ lhs = self . ccode (o .expr .lhs , dtype = o .dtype )
492+ rhs = self . ccode (o .expr .rhs , dtype = o .dtype )
489493
490494 if o .init :
491495 code = c .Initializer (self ._gen_value (o .expr .lhs , 0 ), rhs )
@@ -498,8 +502,8 @@ def visit_Expression(self, o):
498502 return code
499503
500504 def visit_AugmentedExpression (self , o ):
501- c_lhs = ccode (o .expr .lhs , dtype = o .dtype )
502- c_rhs = ccode (o .expr .rhs , dtype = o .dtype )
505+ c_lhs = self . ccode (o .expr .lhs , dtype = o .dtype )
506+ c_rhs = self . ccode (o .expr .rhs , dtype = o .dtype )
503507 code = c .Statement ("%s %s= %s" % (c_lhs , o .op , c_rhs ))
504508 if o .pragmas :
505509 code = c .Module (self ._visit (o .pragmas ) + (code ,))
@@ -518,7 +522,7 @@ def visit_Call(self, o, nested_call=False):
518522 o .templates )
519523 if retobj .is_Indexed or \
520524 isinstance (retobj , (FieldFromComposite , FieldFromPointer )):
521- return c .Assign (ccode (retobj ), call )
525+ return c .Assign (self . ccode (retobj ), call )
522526 else :
523527 return c .Initializer (c .Value (rettype , retobj ._C_name ), call )
524528
@@ -532,9 +536,9 @@ def visit_Conditional(self, o):
532536 then_body = c .Block (self ._visit (then_body ))
533537 if else_body :
534538 else_body = c .Block (self ._visit (else_body ))
535- return c .If (ccode (o .condition ), then_body , else_body )
539+ return c .If (self . ccode (o .condition ), then_body , else_body )
536540 else :
537- return c .If (ccode (o .condition ), then_body )
541+ return c .If (self . ccode (o .condition ), then_body )
538542
539543 def visit_Iteration (self , o ):
540544 body = flatten (self ._visit (i ) for i in self ._blankline_logic (o .children ))
@@ -544,23 +548,23 @@ def visit_Iteration(self, o):
544548
545549 # For backward direction flip loop bounds
546550 if o .direction == Backward :
547- loop_init = 'int %s = %s' % (o .index , ccode (_max ))
548- loop_cond = '%s >= %s' % (o .index , ccode (_min ))
551+ loop_init = 'int %s = %s' % (o .index , self . ccode (_max ))
552+ loop_cond = '%s >= %s' % (o .index , self . ccode (_min ))
549553 loop_inc = '%s -= %s' % (o .index , o .limits [2 ])
550554 else :
551- loop_init = 'int %s = %s' % (o .index , ccode (_min ))
552- loop_cond = '%s <= %s' % (o .index , ccode (_max ))
555+ loop_init = 'int %s = %s' % (o .index , self . ccode (_min ))
556+ loop_cond = '%s <= %s' % (o .index , self . ccode (_max ))
553557 loop_inc = '%s += %s' % (o .index , o .limits [2 ])
554558
555559 # Append unbounded indices, if any
556560 if o .uindices :
557- uinit = ['%s = %s' % (i .name , ccode (i .symbolic_min )) for i in o .uindices ]
561+ uinit = ['%s = %s' % (i .name , self . ccode (i .symbolic_min )) for i in o .uindices ]
558562 loop_init = c .Line (', ' .join ([loop_init ] + uinit ))
559563
560564 ustep = []
561565 for i in o .uindices :
562566 op = '=' if i .is_Modulo else '+='
563- ustep .append ('%s %s %s' % (i .name , op , ccode (i .symbolic_incr )))
567+ ustep .append ('%s %s %s' % (i .name , op , self . ccode (i .symbolic_incr )))
564568 loop_inc = c .Line (', ' .join ([loop_inc ] + ustep ))
565569
566570 # Create For header+body
@@ -577,7 +581,7 @@ def visit_Pragma(self, o):
577581 return c .Pragma (o ._generate )
578582
579583 def visit_While (self , o ):
580- condition = ccode (o .condition )
584+ condition = self . ccode (o .condition )
581585 if o .body :
582586 body = flatten (self ._visit (i ) for i in o .children )
583587 return c .While (condition , c .Block (body ))
0 commit comments