@@ -98,9 +98,16 @@ def func_prefix(self, expr, abs=False):
9898
9999 def parenthesize (self , item , level , strict = False ):
100100 if isinstance (item , BooleanFunction ):
101- return "(%s)" % self ._print (item )
101+ return f"( { self ._print (item )} )"
102102 return super ().parenthesize (item , level , strict = strict )
103103
104+ def _print_PyCPointerType (self , expr ):
105+ ctype = f'{ self ._print_type (expr ._type_ )} '
106+ if ctype .endswith ('*' ):
107+ return f'{ ctype } *'
108+ else :
109+ return f'{ ctype } *'
110+
104111 def _print_type (self , expr ):
105112 try :
106113 expr = dtype_to_ctype (expr )
@@ -120,7 +127,7 @@ def _print_Function(self, expr):
120127 return super ()._print_Function (expr )
121128
122129 def _print_CondEq (self , expr ):
123- return "%s == %s" % ( self ._print (expr .lhs ), self ._print (expr .rhs ))
130+ return f" { self ._print (expr .lhs )} == { self ._print (expr .rhs )} "
124131
125132 def _print_Indexed (self , expr ):
126133 """
@@ -131,7 +138,7 @@ def _print_Indexed(self, expr):
131138 U[t,x,y,z] -> U[t][x][y][z]
132139 """
133140 inds = '' .join (['[' + self ._print (x ) + ']' for x in expr .indices ])
134- return '%s%s' % ( self ._print (expr .base .label ), inds )
141+ return f' { self ._print (expr .base .label )} { inds } '
135142
136143 def _print_FIndexed (self , expr ):
137144 """
@@ -146,7 +153,7 @@ def _print_FIndexed(self, expr):
146153 label = expr .accessor .label
147154 except AttributeError :
148155 label = expr .base .label
149- return '%s(%s)' % ( self ._print (label ), inds )
156+ return f' { self ._print (label )} ( { inds } )'
150157
151158 def _print_Rational (self , expr ):
152159 """Print a Rational as a C-like float/float division."""
@@ -155,10 +162,8 @@ def _print_Rational(self, expr):
155162 # to be 32-bit floats.
156163 # http://en.cppreference.com/w/cpp/language/floating_literal
157164 p , q = int (expr .p ), int (expr .q )
158- if self .dtype == np .float64 :
159- return '%d.0/%d.0' % (p , q )
160- else :
161- return '%d.0F/%d.0F' % (p , q )
165+ prec = self .prec_literal (expr )
166+ return f'{ p } .0{ prec } /{ q } .0{ prec } '
162167
163168 def _print_math_func (self , expr , nest = False , known = None ):
164169 cls = type (expr )
@@ -208,16 +213,22 @@ def _print_SafeInv(self, expr):
208213
209214 def _print_Mod (self , expr ):
210215 """Print a Mod as a C-like %-based operation."""
211- args = ['(%s)' % self ._print (a ) for a in expr .args ]
216+ args = [f'( { self ._print (a )} )' for a in expr .args ]
212217 return '%' .join (args )
213218
214219 def _print_Mul (self , expr ):
215- term = super ()._print_Mul (expr )
216- # avoid (-1)*...
217- term = term .replace ("(-1)*" , "-" )
218- # Avoid (-1) / ...
219- term = term .replace ("(-1)/" , f"-{ self ._prec (expr )(1 )} /" )
220- return term
220+ args = [a for a in expr .args if a != - 1 ]
221+ neg = (len (expr .args ) - len (args )) % 2
222+
223+ if len (args ) > 1 :
224+ term = super ()._print_Mul (expr .func (* args , evaluate = False ))
225+ else :
226+ term = self .parenthesize (args [0 ], precedence (expr ))
227+
228+ if neg :
229+ return f'-{ term } '
230+ else :
231+ return term
221232
222233 def _print_fmath_func (self , name , expr ):
223234 args = "," .join ([self ._print (i ) for i in expr .args ])
@@ -230,7 +241,7 @@ def _print_Min(self, expr):
230241 expr .func (* expr .args [1 :]),
231242 evaluate = False ))
232243 elif has_integer_args (* expr .args ) and len (expr .args ) == 2 :
233- return "MIN(%s)" % self ._print (expr .args )[1 :- 1 ]
244+ return f "MIN({ self ._print (expr .args )[1 :- 1 ]} )"
234245 else :
235246 return self ._print_fmath_func ('min' , expr )
236247
@@ -240,7 +251,7 @@ def _print_Max(self, expr):
240251 expr .func (* expr .args [1 :]),
241252 evaluate = False ))
242253 elif has_integer_args (* expr .args ) and len (expr .args ) == 2 :
243- return "MAX(%s)" % self ._print (expr .args )[1 :- 1 ]
254+ return f "MAX({ self ._print (expr .args )[1 :- 1 ]} )"
244255 else :
245256 return self ._print_fmath_func ('max' , expr )
246257
@@ -251,7 +262,7 @@ def _print_Abs(self, expr):
251262 # AOMPCC errors with abs, always use fabs
252263 if isinstance (self .compiler , AOMPCompiler ) and \
253264 not np .issubdtype (self ._prec (expr ), np .integer ):
254- return "fabs(%s)" % self ._print (arg )
265+ return f "fabs({ self ._print (arg )} )"
255266 return self ._print_fmath_func ('abs' , expr )
256267
257268 def _print_Add (self , expr , order = None ):
@@ -265,7 +276,7 @@ def _print_Add(self, expr, order=None):
265276 for term in terms :
266277 t = self ._print (term )
267278 if precedence (term ) < PREC :
268- l .extend (["+" , "(%s)" % t ])
279+ l .extend (["+" , f"( { t } )" ])
269280 elif t .startswith ('-' ):
270281 l .extend (["-" , t [1 :]])
271282 else :
@@ -305,44 +316,44 @@ def _print_Float(self, expr):
305316 return f'{ rv } { self .prec_literal (expr )} '
306317
307318 def _print_Differentiable (self , expr ):
308- return "(%s)" % self ._print (expr ._expr )
319+ return f"( { self ._print (expr ._expr )} )"
309320
310321 _print_EvalDerivative = _print_Add
311322
312323 def _print_CallFromPointer (self , expr ):
313324 indices = [self ._print (i ) for i in expr .params ]
314- return "%s->%s(%s)" % ( expr .pointer , expr .call , ', ' .join (indices ))
325+ return f" { expr .pointer } -> { expr .call } ( { ', ' .join (indices )} )"
315326
316327 def _print_CallFromComposite (self , expr ):
317328 indices = [self ._print (i ) for i in expr .params ]
318- return "%s.%s(%s)" % ( expr .pointer , expr .call , ', ' .join (indices ))
329+ return f" { expr .pointer } . { expr .call } ( { ', ' .join (indices )} )"
319330
320331 def _print_FieldFromPointer (self , expr ):
321- return "%s->%s" % ( expr .pointer , expr .field )
332+ return f" { expr .pointer } -> { expr .field } "
322333
323334 def _print_FieldFromComposite (self , expr ):
324- return "%s.%s" % ( expr .pointer , expr .field )
335+ return f" { expr .pointer } . { expr .field } "
325336
326337 def _print_ListInitializer (self , expr ):
327- return "{%s}" % ', ' .join ([ self ._print (i ) for i in expr .params ])
338+ return f"{{ { ', ' .join (self ._print (i ) for i in expr .params ) } }}"
328339
329340 def _print_IndexedPointer (self , expr ):
330- return "%s%s" % ( expr .base , '' .join ('[%s]' % self ._print (i ) for i in expr .index ))
341+ return f" { expr .base } { '' .join (f'[ { self ._print (i )} ]' for i in expr .index )} "
331342
332343 def _print_IntDiv (self , expr ):
333344 lhs = self ._print (expr .lhs )
334345 if not expr .lhs .is_Atom :
335- lhs = '(%s)' % ( lhs )
346+ lhs = f"( { lhs } )"
336347 rhs = self ._print (expr .rhs )
337348 PREC = precedence (expr )
338- return self .parenthesize ("%s / %s" % ( lhs , rhs ) , PREC )
349+ return self .parenthesize (f" { lhs } / { rhs } " , PREC )
339350
340351 def _print_InlineIf (self , expr ):
341352 cond = self ._print (expr .cond )
342353 true_expr = self ._print (expr .true_expr )
343354 false_expr = self ._print (expr .false_expr )
344355 PREC = precedence (expr )
345- return self .parenthesize ("(%s ) ? %s : %s" % ( cond , true_expr , false_expr ) , PREC )
356+ return self .parenthesize (f"( { cond } ) ? { true_expr } : { false_expr } " , PREC )
346357
347358 def _print_UnaryOp (self , expr , op = None , parenthesize = False ):
348359 op = op or expr ._op
@@ -356,20 +367,23 @@ def _print_Cast(self, expr):
356367 return self ._print_UnaryOp (expr , op = cast )
357368
358369 def _print_ComponentAccess (self , expr ):
359- return "%s.%s" % ( self ._print (expr .base ), expr .sindex )
370+ return f" { self ._print (expr .base )} . { expr .sindex } "
360371
361372 def _print_DefFunction (self , expr ):
362373 arguments = [self ._print (i ) for i in expr .arguments ]
363374 if expr .template :
364- template = '<%s>' % ',' .join ([str (i ) for i in expr .template ])
375+ ctemplate = ',' .join ([str (i ) for i in expr .template ])
376+ template = f'<{ ctemplate } >'
365377 else :
366378 template = ''
367- return "%s%s(%s)" % (expr .name , template , ',' .join (arguments ))
379+ args = ',' .join (arguments )
380+ return f"{ expr .name } { template } ({ args } )"
368381
369382 def _print_SizeOf (self , expr ):
370383 return f'sizeof({ self ._print (expr .intype )} { self ._print (expr .stars )} )'
371384
372- _print_MathFunction = _print_DefFunction
385+ def _print_MathFunction (self , expr ):
386+ return f"{ self ._ns } { self ._print_DefFunction (expr )} "
373387
374388 def _print_Fallback (self , expr ):
375389 return expr .__str__ ()
@@ -385,7 +399,7 @@ def _print_Fallback(self, expr):
385399
386400# Lifted from SymPy so that we go through our own `_print_math_func`
387401for k in ('exp log sin cos tan ceiling floor' ).split ():
388- setattr (BasePrinter , '_print_%s' % k , BasePrinter ._print_math_func )
402+ setattr (BasePrinter , f '_print_{ k } ' , BasePrinter ._print_math_func )
389403
390404
391405# Always parenthesize IntDiv and InlineIf within expressions
0 commit comments