@@ -453,7 +453,7 @@ def callback(self, clusters, prefix, seen=None):
453453 not d ._defines & hs .distributed_aindices :
454454 continue
455455
456- if any (halo_write (ci , hs ) for ci in clusters [:i ]):
456+ if any (halo_write (ci , hs , prefix ) for ci in clusters [:i ]):
457457 # If there's a halo write before `c`, then we cannot inject the HaloTouch
458458 continue
459459
@@ -472,14 +472,17 @@ def callback(self, clusters, prefix, seen=None):
472472 # the args is important because that's what search functions honor!
473473 points = sorted (points , key = str )
474474
475- # Construct the HaloTouch Cluster
476- expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
477-
478475 key0 = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices # noqa: B023
479476 key1 = lambda i : i not in hs .distributed_defined # noqa: B023
480477 key = lambda i : key0 (i ) and key1 (i ) # noqa: B023
481478 ispace = c .ispace .project (key )
482479
480+ # Reconstruct the HaloScheme with the new IterationSpace
481+ hs = HaloScheme (c .exprs , ispace )
482+
483+ # Construct the HaloTouch Cluster
484+ expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
485+
483486 properties = c .properties .sequentialize ()
484487
485488 halo_touch = c .rebuild (exprs = expr , ispace = ispace , properties = properties )
@@ -787,12 +790,15 @@ def normalize_reductions_sparse(cluster, sregistry):
787790 return cluster .rebuild (processed )
788791
789792
790- def halo_write (c , hs ):
793+ def halo_write (c , hs , prefix ):
791794 loc_vals = hs .loc_values
795+ hsdims = hs .dimensions & set (prefix .itdims )
792796
793797 for f in hs .fmapper :
794798 for a in c .scope .getwrites (f ):
795- if set (a .access .indices ) & loc_vals :
799+ is_write = set (a .access .indices ) & loc_vals
800+ is_dist = any (c .grid .is_distributed (d ) for d in hsdims )
801+ if is_write and is_dist :
796802 return True
797803
798804 return False
0 commit comments