1616from devito .petsc .types import PETScArray , PetscBundle
1717from devito .petsc .iet .nodes import (PETScCallable , FormFunctionCallback ,
1818 MatShellSetOp , PetscMetaData )
19- from devito .petsc .iet .utils import petsc_call , petsc_struct
19+ from devito .petsc .iet .utils import petsc_call , petsc_struct , zero_vector
2020from devito .petsc .utils import solver_mapper
2121from devito .petsc .types import (DM , Mat , CallbackVec , Vec , KSP , PC , SNES ,
2222 PetscInt , StartPtr , PointerIS , PointerDM , VecScatter ,
@@ -99,13 +99,6 @@ def initialguesses(self):
9999 def user_struct_callback (self ):
100100 return self ._user_struct_callback
101101
102- @property
103- def zero_memory (self ):
104- """Indicates whether the memory of the output
105- vector should be set to zero before the computation
106- in the callback."""
107- return True
108-
109102 @property
110103 def fielddata (self ):
111104 return self .injectsolve .expr .rhs .fielddata
@@ -169,7 +162,7 @@ def _create_matvec_body(self, body, jacobian):
169162 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
170163 )
171164
172- zero_y_memory = self . zero_vector (objs ['Y' ])
165+ zero_y_memory = zero_vector (objs ['Y' ]) if jacobian . zero_memory else None
173166
174167 dm_get_local_xvec = petsc_call (
175168 'DMGetLocalVector' , [dmda , Byref (xlocal )]
@@ -188,9 +181,7 @@ def _create_matvec_body(self, body, jacobian):
188181 'DMGetLocalVector' , [dmda , Byref (ylocal )]
189182 )
190183
191- zero_ylocal_memory = petsc_call (
192- 'VecSet' , [ylocal , 0.0 ]
193- )
184+ zero_ylocal_memory = zero_vector (ylocal )
194185
195186 vec_get_array_y = petsc_call (
196187 'VecGetArray' , [ylocal , Byref (y_matvec ._C_symbol )]
@@ -320,7 +311,7 @@ def _create_formfunc_body(self, body):
320311 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
321312 )
322313
323- zero_f_memory = self . zero_vector (objs ['F' ])
314+ zero_f_memory = zero_vector (objs ['F' ])
324315
325316 dm_get_local_xvec = petsc_call (
326317 'DMGetLocalVector' , [dmda , Byref (objs ['xloc' ])]
@@ -652,12 +643,6 @@ def _uxreplace_efuncs(self):
652643 mapper .update ({k : visitor .visit (v )})
653644 return mapper
654645
655- def zero_vector (self , vec ):
656- """
657- Zeros the memory of the output vector before computation
658- """
659- return petsc_call ('VecSet' , [vec , 0.0 ]) if self .zero_memory else None
660-
661646 def dereference_funcs (self , struct , fields ):
662647 return tuple (
663648 [Dereference (i , struct ) for i in
@@ -691,13 +676,6 @@ def main_matvec_callback(self):
691676 def main_formfunc_callback (self ):
692677 return self ._main_formfunc_callback
693678
694- @property
695- def zero_memory (self ):
696- """Indicates whether the memory of the output
697- vector should be set to zero before the computation
698- in the callback."""
699- return False
700-
701679 def _make_core (self ):
702680 for sm in self .fielddata .jacobian .nonzero_submatrices :
703681 self ._make_matvec (sm , prefix = f'{ sm .name } _MatMult' )
@@ -730,9 +708,7 @@ def _whole_matvec_body(self):
730708
731709 nonzero_submats = self .jacobian .nonzero_submatrices
732710
733- zero_y_memory = petsc_call (
734- 'VecSet' , [objs ['Y' ], 0.0 ]
735- )
711+ zero_y_memory = zero_vector (objs ['Y' ])
736712
737713 calls = ()
738714 for sm in nonzero_submats :
@@ -815,9 +791,7 @@ def _whole_formfunc_body(self, body):
815791 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
816792 )
817793
818- zero_f_memory = petsc_call (
819- 'VecSet' , [objs ['F' ], 0.0 ]
820- )
794+ zero_f_memory = zero_vector (objs ['F' ])
821795
822796 dm_get_local_xvec = petsc_call (
823797 'DMGetLocalVector' , [dmda , Byref (objs ['xloc' ])]
0 commit comments