@@ -79,6 +79,31 @@ def __init_finalize__(self, *args, **kwargs):
7979 inds , _ = Function .__indices_setup__ (grid = grid , dimensions = dimensions )
8080 self ._space_dimensions = inds
8181
82+ @classmethod
83+ def _component_kwargs (cls , * inds , ** kwargs ):
84+ """
85+ Get the kwargs for a single component
86+ from the kwargs of the TensorFunction.
87+ """
88+ kw = {}
89+ for k , v in kwargs .items ():
90+ if k in ('staggered' , 'name' , 'dimensions' , 'shape' ):
91+ # Standard Function kwargs
92+ kw [k ] = v
93+ elif isinstance (v , MatrixBase ):
94+ if len (inds ) > 1 :
95+ kw [k ] = v [inds [0 ], inds [1 ]]
96+ else :
97+ kw [k ] = v [inds [0 ]]
98+ elif isinstance (v , (list , tuple )):
99+ if len (inds ) > 1 :
100+ kw [k ] = v [inds [0 ]][inds [1 ]]
101+ else :
102+ kw [k ] = v [inds [0 ]]
103+ else :
104+ kw [k ] = v
105+ return kw
106+
82107 @classmethod
83108 def __subfunc_setup__ (cls , * args , ** kwargs ):
84109 """
@@ -110,18 +135,9 @@ def __subfunc_setup__(cls, *args, **kwargs):
110135 for j in range (start , stop ):
111136 staggj = (stagg [i ][j ] if stagg is not None
112137 else (NODE if i == j else (d , dims [j ])))
113- # Setup kwargs for subfunction
114- # Through rebuilding or user input, the kwargs could be
115- # Tensors as well from a per-component property
116- sub_kwargs = {'name' : f"{ name } _{ d .name } { dims [j ].name } " ,
117- 'staggered' : staggj }
118- for k , v in kwargs .items ():
119- if isinstance (v , MatrixBase ):
120- sub_kwargs [k ] = v [i , j ]
121- elif isinstance (v , (list , tuple )):
122- sub_kwargs [k ] = v [i ][j ]
123- else :
124- sub_kwargs [k ] = v
138+ sub_kwargs = cls ._component_kwargs (i , j , ** kwargs )
139+ sub_kwargs .update ({'name' : f"{ name } _{ d .name } { dims [j ].name } " ,
140+ 'staggered' : staggj })
125141 funcs2 [j ] = cls ._sub_type (** sub_kwargs )
126142 funcs .append (funcs2 )
127143
@@ -335,17 +351,9 @@ def __subfunc_setup__(cls, *args, **kwargs):
335351 stagg = kwargs .get ("staggered" , None )
336352 name = kwargs .get ("name" )
337353 for i , d in enumerate (dims ):
338- kwargs ["name" ] = "%s_%s" % (name , d .name )
339- kwargs ["staggered" ] = stagg [i ] if stagg is not None else d
340- # Setup kwargs for subfunction
341- # Through rebuilding or user input, the kwargs could be
342- # Tensors as well from a per-component property
343- sub_kwargs = {}
344- for k , v in kwargs .items ():
345- if isinstance (v , (list , tuple , MatrixBase )):
346- sub_kwargs [k ] = v [i ]
347- else :
348- sub_kwargs [k ] = v
354+ sub_kwargs = cls ._component_kwargs (i , ** kwargs )
355+ sub_kwargs .update ({'name' : f"{ name } _{ d .name } " ,
356+ 'staggered' : stagg [i ] if stagg is not None else d })
349357 funcs .append (cls ._sub_type (** sub_kwargs ))
350358
351359 return funcs
0 commit comments