Skip to content

Commit 913b3b4

Browse files
committed
compiler: Support increment indexing
1 parent 91564c0 commit 913b3b4

7 files changed

Lines changed: 54 additions & 35 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from devito.arch.compiler import AOMPCompiler
1818
from devito.symbolics.inspection import has_integer_args, sympy_dtype
1919
from devito.symbolics.queries import q_leaf
20-
from devito.types.basic import AbstractFunction
20+
from devito.types.basic import AbstractFunction, PostIncrementIndex
2121
from devito.tools import ctypes_to_cstr, dtype_to_ctype, ctypes_vector_mapper
2222

2323
__all__ = ['BasePrinter', 'ccode']
@@ -148,8 +148,13 @@ def _print_Indexed(self, expr):
148148
--------
149149
U[t,x,y,z] -> U[t][x][y][z]
150150
"""
151-
inds = ''.join(['[' + self._print(x) + ']' for x in expr.indices])
152-
return f'{self._print(expr.base.label)}{inds}'
151+
inds = []
152+
for i in expr.indices:
153+
if isinstance(i, PostIncrementIndex):
154+
inds.append(f"[{self._print(i)}++]")
155+
else:
156+
inds.append(f"[{self._print(i)}]")
157+
return f"{self._print(expr.base.label)}{''.join(inds)}"
153158

154159
def _print_FIndexed(self, expr):
155160
"""
@@ -165,7 +170,19 @@ def _print_FIndexed(self, expr):
165170
except AttributeError:
166171
label = expr.base.label
167172
return f'{self._print(label)}({inds})'
168-
173+
174+
# def _print_PostIncrementIndexed(self, expr):
175+
# """
176+
# Print an Indexed as a ...
177+
178+
# Examples
179+
# --------
180+
# U[k] -> U[k++]
181+
# """
182+
# # from IPython import embed; embed()
183+
# inds = ''.join(['[' + self._print(x) + '++' + ']' for x in expr.indices])
184+
# return f'{self._print(expr.base.label)}{inds}'
185+
169186
def _print_Rational(self, expr):
170187
"""Print a Rational as a C-like float/float division."""
171188
# This method and _print_Float below forcefully add a F to any

devito/petsc/iet/callbacks.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
BlankLine, Callable, Iteration, PointerCast, Definition
66
)
77
from devito.symbolics import (
8-
Byref, FieldFromPointer, IntDiv, Deref, Mod, String, Null, VOID
8+
Byref, FieldFromPointer, IntDiv, Deref, Mod, String, Null, VOID, Cast
99
)
1010
from devito.symbolics.unevaluation import Mul
1111
from devito.types.basic import AbstractFunction
@@ -17,7 +17,7 @@
1717
from devito.petsc.iet.type_builder import objs
1818
from devito.petsc.types.macros import petsc_func_begin_user
1919
from devito.petsc.types.modes import InsertMode
20-
from devito.petsc.types.object import TempSymb
20+
from devito.petsc.types.object import Counter
2121

2222

2323
class BaseCallbackBuilder:
@@ -693,10 +693,6 @@ def _create_count_bc_body(self, body):
693693
ctx = objs['dummyctx']
694694

695695

696-
dm_get_local_info = petsc_call(
697-
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
698-
)
699-
700696
body = self.time_dependence.uxreplace_time(body)
701697

702698
fields = get_user_struct_fields(body)
@@ -710,21 +706,24 @@ def _create_count_bc_body(self, body):
710706

711707
# body = body._rebuild(body=body.body)
712708

713-
body = body._rebuild(body.body)
714-
715-
stacks = (
716-
dm_get_local_info,
717-
)
718-
719709
# Dereference function data in struct
720-
derefs = dereference_funcs(ctx, fields)
710+
# derefs = dereference_funcs(ctx, fields)
711+
712+
# OBVS change names
713+
deref_ptr = DummyExpr(Counter, Deref(sobjs['numBCPtr']))
714+
move_ptr = DummyExpr(Deref(sobjs['numBCPtr']), Counter)
715+
716+
# from IPython import embed; embed()
721717

722718
# Force the struct definition to appear at the very start, since
723719
# stacks, allocs etc may rely on its information
724720
struct_definition = [Definition(ctx), dm_get_app_context]
725721

722+
723+
body = body._rebuild(body.body + (move_ptr,))
724+
726725
body = self._make_callable_body(
727-
body, standalones=struct_definition, stacks=stacks+derefs
726+
body, standalones=struct_definition, stacks=(deref_ptr,)
728727
)
729728

730729
# Replace non-function data with pointer to data in struct
@@ -734,7 +733,7 @@ def _create_count_bc_body(self, body):
734733
# subs[]
735734
# subs[self.target] = sobjs['numBC']
736735

737-
subs[TempSymb._C_symbol] = sobjs['numBCPtr']._C_symbol
736+
# subs[Counter._C_symbol] = Cast(Deref(sobjs['numBCPtr']._C_symbol))
738737

739738
# from IPython import embed; embed()
740739

@@ -785,7 +784,7 @@ def _create_set_point_bc_body(self, body):
785784
i in fields if not isinstance(i.function, AbstractFunction)}
786785

787786

788-
subs[TempSymb._C_symbol] = sobjs['bcPointsArr'].indexed[sobjs['k_iter']]
787+
subs[Counter._C_symbol] = sobjs['bcPointsArr'].indexed[sobjs['k_iter']]
789788

790789
return Uxreplace(subs).visit(body)
791790

devito/petsc/iet/type_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
PointerIS, PointerDM, VecScatter, JacobianStruct, SubMatrixStruct, CallbackDM,
1010
PetscMPIInt, PetscErrorCode, PointerMat, MatReuse, CallbackPointerDM,
1111
CallbackPointerIS, CallbackMat, DummyArg, NofSubMats, PetscSectionGlobal, PetscSectionLocal, PetscSF,
12-
PetscIntPtr, CallbackPetscInt, CallbackPointerPetscInt
12+
PetscIntPtr, CallbackPetscInt, CallbackPointerPetscInt, PostIncrementIndex
1313
)
1414

1515

@@ -222,7 +222,7 @@ def _extend_build(self, base_dict):
222222
base_dict['bcPointsArr'] = CallbackPointerPetscInt(
223223
name=sreg.make_name(prefix='bcPointsArr')
224224
)
225-
base_dict['k_iter'] = PetscInt(
225+
base_dict['k_iter'] = PostIncrementIndex(
226226
name='k_iter', initvalue=0
227227
)
228228
return base_dict

devito/petsc/types/metadata.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from devito.operations.solve import eval_time_derivatives
1111

1212
from devito.petsc.config import petsc_variables
13-
from devito.petsc.types.object import PetscInt, TempSymb
13+
from devito.petsc.types.object import PetscInt, Counter
1414
from devito.petsc.types.equation import (
1515
EssentialBC, ZeroRow, ZeroColumn, NoOfEssentialBC, PointEssentialBC
1616
)
@@ -762,12 +762,8 @@ def _make_increment_expr(self, expr):
762762
"""
763763
if isinstance(expr, EssentialBC):
764764
assert expr.lhs == self.target
765-
# return NoOfEssentialBC(
766-
# TempSymb, expr.rhs,
767-
# subdomain=expr.subdomain,
768-
# )
769765
return NoOfEssentialBC(
770-
TempSymb, 1,
766+
Counter, 1,
771767
subdomain=expr.subdomain,
772768
implicit_dims=expr.subdomain.dimensions
773769
)
@@ -790,11 +786,11 @@ def _make_point_bc_expr(self, expr):
790786
Make the Eq that is used to increment the number of essential
791787
boundary nodes in the generated ccode.
792788
"""
793-
numBC = PetscInt(name='numBC2')
789+
# numBC = PetscInt(name='numBC2')
794790
if isinstance(expr, EssentialBC):
795791
assert expr.lhs == self.target
796792
return PointEssentialBC(
797-
TempSymb, expr.rhs,
793+
Counter, expr.rhs,
798794
subdomain=expr.subdomain
799795
)
800796
else:

devito/petsc/types/object.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from ctypes import POINTER, c_char
2+
from functools import cached_property
23

34
from devito.tools import CustomDtype, dtype_to_ctype, as_tuple, CustomIntType
45
from devito.types import (
56
LocalObject, LocalCompositeObject, ModuloDimension, TimeDimension, ArrayObject,
67
CustomDimension, Scalar
78
)
89
from devito.symbolics import Byref, cast
9-
from devito.types.basic import DataSymbol, LocalType
10+
from devito.types.basic import DataSymbol, LocalType, PostIncrementIndex
1011

1112
from devito.petsc.iet.nodes import petsc_call
1213

@@ -319,6 +320,10 @@ def dtype(self):
319320
return CustomDtype('IS', modifier=' *')
320321

321322

323+
class PetscPostIncrementIndex(PostIncrementIndex):
324+
pass
325+
326+
322327
class CallbackPointerPetscInt(PETScArrayObject):
323328
"""
324329
"""
@@ -373,7 +378,8 @@ class NofSubMats(Scalar, LocalType):
373378
pass
374379

375380

376-
TempSymb = PetscInt(name='numBC2')
381+
# Can this be attached to the consrain bc object in metadata maybe? probs shoulnd't be here
382+
Counter = PetscInt(name='count')
377383

378384

379385
FREE_PRIORITY = {

devito/symbolics/extended_sympy.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,3 @@ def rfunc(func, item, *args):
947947
min: Min,
948948
max: Max,
949949
}
950-
951-
952-
Null = Macro('NULL')

devito/types/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,3 +1927,7 @@ def _mem_internal_lazy(self):
19271927
to impose pass-by-reference semantics.
19281928
"""
19291929
_C_modifier = None
1930+
1931+
1932+
class PostIncrementIndex(DataSymbol):
1933+
pass

0 commit comments

Comments
 (0)