Skip to content

Commit 8209107

Browse files
authored
Merge pull request #2906 from devitocodes/tweak-deviceid
arch: make deviceid moe felxible with cuda/rocm env vars
2 parents 6c41641 + bfeebc4 commit 8209107

2 files changed

Lines changed: 31 additions & 12 deletions

File tree

devito/operator/operator.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,15 +1396,21 @@ def _physical_deviceid(self):
13961396
visible_device_var, visible_devices = get_visible_devices()
13971397
if visible_devices is None:
13981398
return logical_deviceid
1399+
elif len(visible_devices) == 1:
1400+
# Only one visible device, it's clearly the one we want
1401+
return visible_devices[0]
1402+
elif logical_deviceid <= len(visible_devices):
1403+
# Map the logical device ID to the physical one
1404+
return visible_devices[logical_deviceid]
13991405
else:
1400-
try:
1401-
return visible_devices[logical_deviceid]
1402-
except IndexError as e:
1403-
errmsg = (f"A deviceid value of {logical_deviceid} is not valid "
1404-
f"with {visible_device_var}={visible_devices}. Note that "
1405-
"deviceid corresponds to the logical index within the "
1406-
"visible devices, not the physical device index.")
1407-
raise ValueError(errmsg) from e
1406+
# Logical device ID is out of bounds, likely from oversubscription
1407+
# Print a warning and map modulo the number of visible devices
1408+
deviceid = visible_devices[logical_deviceid % len(visible_devices)]
1409+
warning(f"Logical device ID {logical_deviceid} is out of bounds "
1410+
f"for {len(visible_devices)} visible devices"
1411+
f" in {visible_device_var}."
1412+
f"Mapping to device ID {deviceid} instead.")
1413+
return visible_devices[logical_deviceid % len(visible_devices)]
14081414
else:
14091415
return None
14101416

tests/test_gpu_common.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,13 @@ def test_visible_devices(self, env_variables):
108108
assert argmap2._physical_deviceid == 0
109109

110110
@pytest.mark.parallel(mode=2)
111-
@pytest.mark.parametrize('visible_devices', ["1,2", "1,0", "0,2,3"])
111+
@pytest.mark.parametrize('visible_devices', [
112+
"1,2", "1,0", "0,2,3",
113+
# Per rank VISIBLE_DEVICE
114+
("1", "0"),
115+
# Oversubscribed
116+
"1",
117+
])
112118
def test_visible_devices_mpi(self, visible_devices, mode):
113119
"""
114120
Test that physical device IDs used for querying memory on a device via
@@ -122,11 +128,18 @@ def test_visible_devices_mpi(self, visible_devices, mode):
122128

123129
eq = Eq(u, u+1)
124130

125-
with switchenv({'CUDA_VISIBLE_DEVICES': visible_devices}):
131+
if isinstance(visible_devices, tuple):
132+
cu_device = visible_devices[rank]
133+
expected = int(cu_device)
134+
else:
135+
cu_device = visible_devices
136+
devices = visible_devices.split(',')
137+
expected = int(devices[rank % len(devices)])
138+
139+
with switchenv({'CUDA_VISIBLE_DEVICES': cu_device}):
126140
op1 = Operator(eq)
127141
argmap1 = op1.arguments()
128-
devices = [int(i) for i in visible_devices.split(',')]
129-
assert argmap1._physical_deviceid == devices[rank]
142+
assert argmap1._physical_deviceid == expected
130143

131144
# In default case, physical deviceid will equal rank
132145
op2 = Operator(eq)

0 commit comments

Comments
 (0)