22from packaging .version import Version
33
44import cgen as c
5+ import numpy as np
56from sympy import And , Ne , Not
67
78from devito .arch import AMDGPUX , NVIDIAX , INTELGPUX , PVC
89from devito .arch .compiler import GNUCompiler , NvidiaCompiler
10+ from devito .exceptions import InvalidOperator
911from devito .ir import (Call , Conditional , DeviceCall , List , Pragma , Prodder ,
1012 ParallelBlock , PointerCast , While , FindSymbols )
1113from devito .passes .iet .definitions import DataManager , DeviceAwareDataManager
1820from devito .passes .iet .languages .C import CBB
1921from devito .passes .iet .languages .CXX import CXXBB
2022from devito .symbolics import CondEq , DefFunction
23+ from devito .symbolics .extended_sympy import UnaryOp
2124from devito .tools import filter_ordered
2225
2326__all__ = ['SimdOmpizer' , 'Ompizer' , 'OmpIteration' , 'OmpRegion' ,
@@ -113,6 +116,44 @@ def _generate(self):
113116 return self .pragma % (joins (* items ), n )
114117
115118
class RealExt(UnaryOp):

    """
    Unary operator rendered with the ``__real__ `` prefix, used to address
    the real part of a complex-valued expression.
    """

    _op = '__real__ '
122+
123+
class ImagExt(UnaryOp):

    """
    Unary operator rendered with the ``__imag__ `` prefix, used to address
    the imaginary part of a complex-valued expression.
    """

    _op = '__imag__ '
127+
128+
def atomic_add(i, pragmas):
    """
    Attach `pragmas` to the increment carried by `i`, splitting it into
    per-component increments when the left-hand side is complex.

    Parameters
    ----------
    i : node
        Object exposing ``expr`` (with ``lhs``, ``rhs`` and ``_rebuild``) and
        ``_rebuild``; both sides of ``expr`` must expose a ``dtype``.
    pragmas
        Forwarded untouched as the ``pragmas`` argument of ``i._rebuild``.

    Returns
    -------
    The rebuilt node. When both sides are complex, a ``List`` with two
    rebuilt nodes (real part first, imaginary part second).

    Raises
    ------
    InvalidOperator
        If the left-hand side is real while the right-hand side is complex.
    """
    lhs, rhs = i.expr.lhs, i.expr.rhs
    lhs_is_complex = np.issubdtype(lhs.dtype, np.complexfloating)
    rhs_is_complex = np.issubdtype(rhs.dtype, np.complexfloating)

    if not lhs_is_complex:
        if rhs_is_complex:
            # Real i, complex j
            raise InvalidOperator("Atomic add not implemented for real "
                                  "Functions with complex increments")
        # Real i, real j
        # Nothing to split: the increment only needs the pragmas attached
        return i._rebuild(pragmas=pragmas)

    def _component(new_lhs, new_rhs):
        # Same node as `i`, but incrementing `new_lhs` by `new_rhs`
        return i._rebuild(expr=i.expr._rebuild(lhs=new_lhs, rhs=new_rhs),
                          pragmas=pragmas)

    if rhs_is_complex:
        # Complex i, complex j
        # Atomic add real and imaginary parts separately
        real = _component(RealExt(lhs), RealExt(rhs))
        imag = _component(ImagExt(lhs), ImagExt(rhs))
        return List(body=[real, imag])

    # Complex i, real j
    # Atomic add j to real part of i
    return _component(RealExt(lhs), rhs)
155+
156+
116157class AbstractOmpBB (LangBB ):
117158
118159 mapper = {
@@ -134,7 +175,8 @@ class AbstractOmpBB(LangBB):
134175 'simd-for-aligned' : lambda n , * a :
135176 SimdForAligned ('omp simd aligned(%s:%d)' , arguments = (n , * a )),
136177 'atomic' :
137- Pragma ('omp atomic update' )
178+ Pragma ('omp atomic update' ),
179+ 'split-atomic' : lambda i : atomic_add (i , Pragma ('omp atomic update' ))
138180 }
139181
140182 Region = OmpRegion
@@ -241,6 +283,13 @@ def _support_array_reduction(cls, compiler):
241283 else :
242284 return True
243285
@classmethod
def _support_complex_reduction(cls, compiler):
    """
    Return False when `compiler` is a GNUCompiler instance, True otherwise.
    """
    # GCC doesn't support complex reductions
    return not isinstance(compiler, GNUCompiler)
292+
244293
245294class Ompizer (AbstractOmpizer ):
246295 langbb = OmpBB
0 commit comments