@@ -80,26 +80,15 @@ def __init_finalize__(self, *args, **kwargs):
8080 self ._space_dimensions = inds
8181
8282 @classmethod
83- def _component_kwargs (cls , * inds , ** kwargs ):
83+ def _component_kwargs (cls , inds , ** kwargs ):
8484 """
8585 Get the kwargs for a single component
8686 from the kwargs of the TensorFunction.
8787 """
8888 kw = {}
8989 for k , v in kwargs .items ():
90- if k in ('staggered' , 'name' , 'dimensions' , 'shape' ):
91- # Standard Function kwargs
92- kw [k ] = v
93- elif isinstance (v , MatrixBase ):
94- if len (inds ) > 1 :
95- kw [k ] = v [inds [0 ], inds [1 ]]
96- else :
97- kw [k ] = v [inds [0 ]]
98- elif isinstance (v , (list , tuple )):
99- if len (inds ) > 1 :
100- kw [k ] = v [inds [0 ]][inds [1 ]]
101- else :
102- kw [k ] = v [inds [0 ]]
90+ if isinstance (v , MatrixBase ):
91+ kw [k ] = v [inds ]
10392 else :
10493 kw [k ] = v
10594 return kw
@@ -135,7 +124,7 @@ def __subfunc_setup__(cls, *args, **kwargs):
135124 for j in range (start , stop ):
136125 staggj = (stagg [i ][j ] if stagg is not None
137126 else (NODE if i == j else (d , dims [j ])))
138- sub_kwargs = cls ._component_kwargs (i , j , ** kwargs )
127+ sub_kwargs = cls ._component_kwargs (( i , j ) , ** kwargs )
139128 sub_kwargs .update ({'name' : f"{ name } _{ d .name } { dims [j ].name } " ,
140129 'staggered' : staggj })
141130 funcs2 [j ] = cls ._sub_type (** sub_kwargs )
0 commit comments