Skip to content

Commit 4b97ddd

Browse files
FabioLuporini authored and mloubout committed
compiler: Tweak Dereference
1 parent daecec4 commit 4b97ddd

3 files changed

Lines changed: 29 additions & 21 deletions

File tree

devito/ir/iet/nodes.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1047,9 +1047,8 @@ class Dereference(ExprStmt, Node):
10471047
A node encapsulating a dereference from a `pointer` to a `pointee`.
10481048
The following cases are supported:
10491049
1050-
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
1051-
* `pointer` is an ArrayObject representing a pointer to a C struct, and
1052-
`pointee` is a field in `pointer`.
1050+
* `pointer` is an AbstractFunction, and `pointee` is an Array.
1051+
* `pointer` is an AbstractObject, and `pointee` is an Array.
10531052
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
10541053
`pointee` is a Symbol representing the dereferenced value.
10551054
"""
@@ -1075,20 +1074,23 @@ def expr_symbols(self):
10751074
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
10761075
"Scalar dereference must have a pointer ctype"
10771076
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
1078-
elif self.pointer.is_PointerArray or self.pointer.is_TempFunction:
1077+
elif self.pointer.is_AbstractFunction:
10791078
ret.extend([self.pointer.indexed, self.pointee.indexed])
10801079
ret.extend(flatten(i.free_symbols
10811080
for i in self.pointee.symbolic_shape[1:]))
10821081
ret.extend(self.pointer.free_symbols)
1082+
elif self.pointer.is_AbstractObject:
1083+
ret.extend([self.pointer, self.pointee.indexed])
1084+
ret.extend(flatten(i.free_symbols
1085+
for i in self.pointee.symbolic_shape[1:]))
10831086
else:
1084-
ret.extend([self.pointer.indexed, self.pointee._C_symbol])
1087+
assert False, f"Unexpected pointer type {type(self.pointer)}"
1088+
10851089
return tuple(filter_ordered(ret))
10861090

10871091
@property
10881092
def defines(self):
1089-
if self.pointer.is_PointerArray or \
1090-
self.pointer.is_TempFunction or \
1091-
self.pointee._mem_stack:
1093+
if self.pointer.is_AbstractFunction or self.pointee._mem_stack:
10921094
return (self.pointee.indexed, self.pointee)
10931095
else:
10941096
return (self.pointee,)

devito/ir/iet/visitors.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -506,24 +506,30 @@ def visit_PointerCast(self, o):
506506

507507
def visit_Dereference(self, o):
508508
a0, a1 = o.functions
509-
if a1.is_PointerArray or a1.is_TempFunction:
510-
i = a1.indexed
511-
cstr = self.ccode(i._C_typedata)
509+
if a0.is_AbstractFunction:
510+
cstr = self.ccode(a0.indexed._C_typedata)
511+
512+
try:
513+
# Special AbstractFunctions such as PointerArray or TempFunction
514+
cdim = f'[{a1.dim.name}]'
515+
except AttributeError:
516+
cdim = ''
517+
512518
if o.flat is None:
513519
shape = ''.join(f"[{self.ccode(i)}]" for i in a0.symbolic_shape[1:])
514-
rvalue = f'({cstr} (*){shape}) {a1.name}[{a1.dim.name}]'
520+
rvalue = f'({cstr} (*){shape}) {a1.name}{cdim}'
515521
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {a0.name}){shape}')
516522
else:
517-
rvalue = f'({cstr} *) {a1.name}[{a1.dim.name}]'
523+
rvalue = f'({cstr} *) {a1.name}{cdim}'
518524
lvalue = c.Value(cstr, f'*{self._restrict_keyword} {a0.name}')
519-
if a0._data_alignment:
520-
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
525+
521526
else:
522527
if a1.is_Symbol:
523528
rvalue = f'*{a1.name}'
524529
else:
525530
rvalue = f'{a1.name}->{a0._C_name}'
526531
lvalue = self._gen_value(a0, 0)
532+
527533
return c.Initializer(lvalue, rvalue)
528534

529535
def visit_Block(self, o):

examples/performance/00_overview.ipynb

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -701,7 +701,7 @@
701701
" #pragma omp parallel num_threads(nthreads)\n",
702702
" {\n",
703703
" const int tid = omp_get_thread_num();\n",
704-
" float (*restrict r0)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr0[tid];\n",
704+
" float (*restrict r0)[z_size] = (float (*)[z_size]) pr0[tid];\n",
705705
"\n",
706706
" #pragma omp for schedule(dynamic,1)\n",
707707
" for (int x = x_m; x <= x_M; x += 1)\n",
@@ -836,7 +836,7 @@
836836
" #pragma omp parallel num_threads(nthreads)\n",
837837
" {\n",
838838
" const int tid = omp_get_thread_num();\n",
839-
" float (*restrict r1)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr1[tid];\n",
839+
" float (*restrict r1)[z_size] = (float (*)[z_size]) pr1[tid];\n",
840840
"\n",
841841
" #pragma omp for schedule(dynamic,1)\n",
842842
" for (int x = x_m; x <= x_M; x += 1)\n",
@@ -968,7 +968,7 @@
968968
" #pragma omp parallel num_threads(nthreads)\n",
969969
" {\n",
970970
" const int tid = omp_get_thread_num();\n",
971-
" float (*restrict r0)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr0[tid];\n",
971+
" float (*restrict r0)[z_size] = (float (*)[z_size]) pr0[tid];\n",
972972
"\n",
973973
" #pragma omp for schedule(dynamic,1)\n",
974974
" for (int x = x_m; x <= x_M; x += 1)\n",
@@ -1236,7 +1236,7 @@
12361236
" #pragma omp parallel num_threads(nthreads)\n",
12371237
" {\n",
12381238
" const int tid = omp_get_thread_num();\n",
1239-
" float (*restrict r2)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr2[tid];\n",
1239+
" float (*restrict r2)[z_size] = (float (*)[z_size]) pr2[tid];\n",
12401240
"\n",
12411241
" #pragma omp for collapse(2) schedule(dynamic,1)\n",
12421242
" for (int x0_blk0 = x_m; x0_blk0 <= x_M; x0_blk0 += x0_blk0_size)\n",
@@ -1348,7 +1348,7 @@
13481348
" #pragma omp parallel num_threads(nthreads)\n",
13491349
" {\n",
13501350
" const int tid = omp_get_thread_num();\n",
1351-
" float (*restrict r2)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr2[tid];\n",
1351+
" float (*restrict r2)[z_size] = (float (*)[z_size]) pr2[tid];\n",
13521352
"\n",
13531353
" #pragma omp for collapse(2) schedule(dynamic,1)\n",
13541354
" for (int x0_blk0 = x_m; x0_blk0 <= x_M; x0_blk0 += x0_blk0_size)\n",
@@ -1527,7 +1527,7 @@
15271527
" #pragma omp parallel num_threads(nthreads)\n",
15281528
" {\n",
15291529
" const int tid = omp_get_thread_num();\n",
1530-
" float (*restrict r2)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr2[tid];\n",
1530+
" float (*restrict r2)[z_size] = (float (*)[z_size]) pr2[tid];\n",
15311531
"\n",
15321532
" #pragma omp for schedule(dynamic,1)\n",
15331533
" for (int x = x_m; x <= x_M; x += 1)\n",

0 commit comments

Comments
 (0)