Skip to content

Commit 049f17a

Browse files
committed
compiler: convert printer to f-string
1 parent 91f2018 commit 049f17a

6 files changed

Lines changed: 74 additions & 41 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,16 @@ def func_prefix(self, expr, abs=False):
9898

9999
def parenthesize(self, item, level, strict=False):
100100
if isinstance(item, BooleanFunction):
101-
return "(%s)" % self._print(item)
101+
return f"({self._print(item)})"
102102
return super().parenthesize(item, level, strict=strict)
103103

104+
def _print_PyCPointerType(self, expr):
105+
ctype = f'{self._print_type(expr._type_)}'
106+
if ctype.endswith('*'):
107+
return f'{ctype}*'
108+
else:
109+
return f'{ctype} *'
110+
104111
def _print_type(self, expr):
105112
try:
106113
expr = dtype_to_ctype(expr)
@@ -120,7 +127,7 @@ def _print_Function(self, expr):
120127
return super()._print_Function(expr)
121128

122129
def _print_CondEq(self, expr):
123-
return "%s == %s" % (self._print(expr.lhs), self._print(expr.rhs))
130+
return f"{self._print(expr.lhs)} == {self._print(expr.rhs)}"
124131

125132
def _print_Indexed(self, expr):
126133
"""
@@ -131,7 +138,7 @@ def _print_Indexed(self, expr):
131138
U[t,x,y,z] -> U[t][x][y][z]
132139
"""
133140
inds = ''.join(['[' + self._print(x) + ']' for x in expr.indices])
134-
return '%s%s' % (self._print(expr.base.label), inds)
141+
return f'{self._print(expr.base.label)}{inds}'
135142

136143
def _print_FIndexed(self, expr):
137144
"""
@@ -146,7 +153,7 @@ def _print_FIndexed(self, expr):
146153
label = expr.accessor.label
147154
except AttributeError:
148155
label = expr.base.label
149-
return '%s(%s)' % (self._print(label), inds)
156+
return f'{self._print(label)}({inds})'
150157

151158
def _print_Rational(self, expr):
152159
"""Print a Rational as a C-like float/float division."""
@@ -155,10 +162,8 @@ def _print_Rational(self, expr):
155162
# to be 32-bit floats.
156163
# http://en.cppreference.com/w/cpp/language/floating_literal
157164
p, q = int(expr.p), int(expr.q)
158-
if self.dtype == np.float64:
159-
return '%d.0/%d.0' % (p, q)
160-
else:
161-
return '%d.0F/%d.0F' % (p, q)
165+
prec = self.prec_literal(expr)
166+
return f'{p}.0{prec}/{q}.0{prec}'
162167

163168
def _print_math_func(self, expr, nest=False, known=None):
164169
cls = type(expr)
@@ -208,16 +213,22 @@ def _print_SafeInv(self, expr):
208213

209214
def _print_Mod(self, expr):
210215
"""Print a Mod as a C-like %-based operation."""
211-
args = ['(%s)' % self._print(a) for a in expr.args]
216+
args = [f'({self._print(a)})' for a in expr.args]
212217
return '%'.join(args)
213218

214219
def _print_Mul(self, expr):
215-
term = super()._print_Mul(expr)
216-
# avoid (-1)*...
217-
term = term.replace("(-1)*", "-")
218-
# Avoid (-1) / ...
219-
term = term.replace("(-1)/", f"-{self._prec(expr)(1)}/")
220-
return term
220+
args = [a for a in expr.args if a != -1]
221+
neg = (len(expr.args) - len(args)) % 2
222+
223+
if len(args) > 1:
224+
term = super()._print_Mul(expr.func(*args, evaluate=False))
225+
else:
226+
term = self.parenthesize(args[0], precedence(expr))
227+
228+
if neg:
229+
return f'-{term}'
230+
else:
231+
return term
221232

222233
def _print_fmath_func(self, name, expr):
223234
args = ",".join([self._print(i) for i in expr.args])
@@ -230,7 +241,7 @@ def _print_Min(self, expr):
230241
expr.func(*expr.args[1:]),
231242
evaluate=False))
232243
elif has_integer_args(*expr.args) and len(expr.args) == 2:
233-
return "MIN(%s)" % self._print(expr.args)[1:-1]
244+
return f"MIN({self._print(expr.args)[1:-1]})"
234245
else:
235246
return self._print_fmath_func('min', expr)
236247

@@ -240,7 +251,7 @@ def _print_Max(self, expr):
240251
expr.func(*expr.args[1:]),
241252
evaluate=False))
242253
elif has_integer_args(*expr.args) and len(expr.args) == 2:
243-
return "MAX(%s)" % self._print(expr.args)[1:-1]
254+
return f"MAX({self._print(expr.args)[1:-1]})"
244255
else:
245256
return self._print_fmath_func('max', expr)
246257

@@ -251,7 +262,7 @@ def _print_Abs(self, expr):
251262
# AOMPCC errors with abs, always use fabs
252263
if isinstance(self.compiler, AOMPCompiler) and \
253264
not np.issubdtype(self._prec(expr), np.integer):
254-
return "fabs(%s)" % self._print(arg)
265+
return f"fabs({self._print(arg)})"
255266
return self._print_fmath_func('abs', expr)
256267

257268
def _print_Add(self, expr, order=None):
@@ -265,7 +276,7 @@ def _print_Add(self, expr, order=None):
265276
for term in terms:
266277
t = self._print(term)
267278
if precedence(term) < PREC:
268-
l.extend(["+", "(%s)" % t])
279+
l.extend(["+", f"({t})"])
269280
elif t.startswith('-'):
270281
l.extend(["-", t[1:]])
271282
else:
@@ -305,44 +316,44 @@ def _print_Float(self, expr):
305316
return f'{rv}{self.prec_literal(expr)}'
306317

307318
def _print_Differentiable(self, expr):
308-
return "(%s)" % self._print(expr._expr)
319+
return f"({self._print(expr._expr)})"
309320

310321
_print_EvalDerivative = _print_Add
311322

312323
def _print_CallFromPointer(self, expr):
313324
indices = [self._print(i) for i in expr.params]
314-
return "%s->%s(%s)" % (expr.pointer, expr.call, ', '.join(indices))
325+
return f"{expr.pointer}->{expr.call}({', '.join(indices)})"
315326

316327
def _print_CallFromComposite(self, expr):
317328
indices = [self._print(i) for i in expr.params]
318-
return "%s.%s(%s)" % (expr.pointer, expr.call, ', '.join(indices))
329+
return f"{expr.pointer}.{expr.call}({', '.join(indices)})"
319330

320331
def _print_FieldFromPointer(self, expr):
321-
return "%s->%s" % (expr.pointer, expr.field)
332+
return f"{expr.pointer}->{expr.field}"
322333

323334
def _print_FieldFromComposite(self, expr):
324-
return "%s.%s" % (expr.pointer, expr.field)
335+
return f"{expr.pointer}.{expr.field}"
325336

326337
def _print_ListInitializer(self, expr):
327-
return "{%s}" % ', '.join([self._print(i) for i in expr.params])
338+
return f"{{{', '.join(self._print(i) for i in expr.params)}}}"
328339

329340
def _print_IndexedPointer(self, expr):
330-
return "%s%s" % (expr.base, ''.join('[%s]' % self._print(i) for i in expr.index))
341+
return f"{expr.base}{''.join(f'[{self._print(i)}]' for i in expr.index)}"
331342

332343
def _print_IntDiv(self, expr):
333344
lhs = self._print(expr.lhs)
334345
if not expr.lhs.is_Atom:
335-
lhs = '(%s)' % (lhs)
346+
lhs = f"({lhs})"
336347
rhs = self._print(expr.rhs)
337348
PREC = precedence(expr)
338-
return self.parenthesize("%s / %s" % (lhs, rhs), PREC)
349+
return self.parenthesize(f"{lhs} / {rhs}", PREC)
339350

340351
def _print_InlineIf(self, expr):
341352
cond = self._print(expr.cond)
342353
true_expr = self._print(expr.true_expr)
343354
false_expr = self._print(expr.false_expr)
344355
PREC = precedence(expr)
345-
return self.parenthesize("(%s) ? %s : %s" % (cond, true_expr, false_expr), PREC)
356+
return self.parenthesize(f"({cond}) ? {true_expr} : {false_expr}", PREC)
346357

347358
def _print_UnaryOp(self, expr, op=None, parenthesize=False):
348359
op = op or expr._op
@@ -356,20 +367,23 @@ def _print_Cast(self, expr):
356367
return self._print_UnaryOp(expr, op=cast)
357368

358369
def _print_ComponentAccess(self, expr):
359-
return "%s.%s" % (self._print(expr.base), expr.sindex)
370+
return f"{self._print(expr.base)}.{expr.sindex}"
360371

361372
def _print_DefFunction(self, expr):
362373
arguments = [self._print(i) for i in expr.arguments]
363374
if expr.template:
364-
template = '<%s>' % ','.join([str(i) for i in expr.template])
375+
ctemplate = ','.join([str(i) for i in expr.template])
376+
template = f'<{ctemplate}>'
365377
else:
366378
template = ''
367-
return "%s%s(%s)" % (expr.name, template, ','.join(arguments))
379+
args = ','.join(arguments)
380+
return f"{expr.name}{template}({args})"
368381

369382
def _print_SizeOf(self, expr):
370383
return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})'
371384

372-
_print_MathFunction = _print_DefFunction
385+
def _print_MathFunction(self, expr):
386+
return f"{self._ns}{self._print_DefFunction(expr)}"
373387

374388
def _print_Fallback(self, expr):
375389
return expr.__str__()
@@ -385,7 +399,7 @@ def _print_Fallback(self, expr):
385399

386400
# Lifted from SymPy so that we go through our own `_print_math_func`
387401
for k in ('exp log sin cos tan ceiling floor').split():
388-
setattr(BasePrinter, '_print_%s' % k, BasePrinter._print_math_func)
402+
setattr(BasePrinter, f'_print_{k}', BasePrinter._print_math_func)
389403

390404

391405
# Always parenthesize IntDiv and InlineIf within expressions

devito/ir/iet/nodes.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration',
3131
'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma',
3232
'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace',
33-
'CallableBody', 'Transfer']
33+
'Using', 'CallableBody', 'Transfer']
3434

3535
# First-class IET nodes
3636

@@ -1217,6 +1217,19 @@ def periodic(self):
12171217
return self._periodic
12181218

12191219

1220+
class Using(Node):
1221+
1222+
"""
1223+
A C++ using directive.
1224+
"""
1225+
1226+
def __init__(self, name):
1227+
self.name = name
1228+
1229+
def __repr__(self):
1230+
return "<Using(%s)>" % self.name
1231+
1232+
12201233
class UsingNamespace(Node):
12211234

12221235
"""

devito/ir/iet/visitors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,9 @@ def visit_MultiTraversable(self, o):
614614
body.extend(as_tuple(v))
615615
return c.Collection(body)
616616

617+
def visit_Using(self, o):
618+
return c.Statement(f'using {str(o.name)}')
619+
617620
def visit_UsingNamespace(self, o):
618621
return c.Statement(f'using namespace {str(o.namespace)}')
619622

devito/operator/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1139,7 +1139,7 @@ def __setstate__(self, state):
11391139
self._lib.name = soname
11401140

11411141
self._allocator = default_allocator(
1142-
'%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform)
1142+
'%s.%s.%s' % (type(self._compiler).__name__, self._language, self._platform)
11431143
)
11441144

11451145

devito/passes/iet/errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
List, Break, Return, FindNodes, FindSymbols, Transformer,
77
make_callable)
88
from devito.passes.iet.engine import iet_pass
9-
from devito.symbolics import CondEq, DefFunction
9+
from devito.symbolics import CondEq, MathFunction
1010
from devito.tools import dtype_to_ctype
1111
from devito.types import Eq, Inc, LocalObject, Symbol
1212

@@ -58,7 +58,7 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
5858
irs, byproduct = rcompile(eqns)
5959

6060
name = sregistry.make_name(prefix='is_finite')
61-
retval = Return(DefFunction('isfinite', accumulator))
61+
retval = Return(MathFunction('isfinite', accumulator))
6262
body = irs.iet.body.body + (retval,)
6363
efunc = make_callable(name, body, retval='int')
6464

devito/symbolics/extended_sympy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,12 @@ def __new__(cls, intype, stars=None, **kwargs):
776776
stars = stars or ''
777777
if not isinstance(intype, (str, ReservedWord)):
778778
ctype = dtype_to_ctype(intype)
779-
if ctype in ctypes_vector_mapper.values():
780-
idx = list(ctypes_vector_mapper.values()).index(ctype)
781-
intype = list(ctypes_vector_mapper.keys())[idx]
779+
for k, v in ctypes_vector_mapper.items():
780+
if ctype is v:
781+
intype = k
782+
break
783+
else:
784+
intype = ctypes_to_cstr(ctype)
782785

783786
newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs)
784787
newobj.stars = stars

0 commit comments

Comments
 (0)