@@ -108,7 +108,13 @@ def test_visible_devices(self, env_variables):
108108 assert argmap2 ._physical_deviceid == 0
109109
110110 @pytest .mark .parallel (mode = 2 )
111- @pytest .mark .parametrize ('visible_devices' , ["1,2" , "1,0" , "0,2,3" ])
111+ @pytest .mark .parametrize ('visible_devices' , [
112+ "1,2" , "1,0" , "0,2,3" ,
113+ # Per rank VISIBLE_DEVICE
114+ ("2" , "1" ),
115+ # Oversubscribed
116+ "1" ,
117+ ])
112118 def test_visible_devices_mpi (self , visible_devices , mode ):
113119 """
114120 Test that physical device IDs used for querying memory on a device via
@@ -122,11 +128,17 @@ def test_visible_devices_mpi(self, visible_devices, mode):
122128
123129 eq = Eq (u , u + 1 )
124130
125- with switchenv ({'CUDA_VISIBLE_DEVICES' : visible_devices }):
131+ if isinstance (visible_devices , tuple ):
132+ cu_device = visible_devices [rank ]
133+ expected = cu_device
134+ else :
135+ cu_device = visible_devices
136+ expected = int (visible_devices .split (',' )[rank ])
137+
138+ with switchenv ({'CUDA_VISIBLE_DEVICES' : cu_device }):
126139 op1 = Operator (eq )
127140 argmap1 = op1 .arguments ()
128- devices = [int (i ) for i in visible_devices .split (',' )]
129- assert argmap1 ._physical_deviceid == devices [rank ]
141+ assert argmap1 ._physical_deviceid == expected
130142
131143 # In default case, physical deviceid will equal rank
132144 op2 = Operator (eq )
0 commit comments