diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 58fab20495..c0bb6145a6 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1396,15 +1396,21 @@ def _physical_deviceid(self): visible_device_var, visible_devices = get_visible_devices() if visible_devices is None: return logical_deviceid + elif len(visible_devices) == 1: + # Only one visible device, it's clearly the one we want + return visible_devices[0] + elif logical_deviceid <= len(visible_devices): + # Map the logical device ID to the physical one + return visible_devices[logical_deviceid] else: - try: - return visible_devices[logical_deviceid] - except IndexError as e: - errmsg = (f"A deviceid value of {logical_deviceid} is not valid " - f"with {visible_device_var}={visible_devices}. Note that " - "deviceid corresponds to the logical index within the " - "visible devices, not the physical device index.") - raise ValueError(errmsg) from e + # Logical device ID is out of bounds, likely from oversubscription + # Print a warning and map modulo the number of visible devices + deviceid = visible_devices[logical_deviceid % len(visible_devices)] + warning(f"Logical device ID {logical_deviceid} is out of bounds " + f"for {len(visible_devices)} visible devices" + f" in {visible_device_var}." + f"Mapping to device ID {deviceid} instead.") + return visible_devices[logical_deviceid % len(visible_devices)] else: return None diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index a639c5bbe6..e84d0df5d8 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -108,7 +108,13 @@ def test_visible_devices(self, env_variables): assert argmap2._physical_deviceid == 0 @pytest.mark.parallel(mode=2) - @pytest.mark.parametrize('visible_devices', ["1,2", "1,0", "0,2,3"]) + @pytest.mark.parametrize('visible_devices', [ + "1,2", "1,0", "0,2,3", + # Per rank VISIBLE_DEVICE + ("1", "0"), + # Oversubscribed + "1", + ]) def test_visible_devices_mpi(self, visible_devices, mode): """ Test that physical device IDs used for querying memory on a device via @@ -122,11 +128,18 @@ def test_visible_devices_mpi(self, visible_devices, mode): eq = Eq(u, u+1) - with switchenv({'CUDA_VISIBLE_DEVICES': visible_devices}): + if isinstance(visible_devices, tuple): + cu_device = visible_devices[rank] + expected = int(cu_device) + else: + cu_device = visible_devices + devices = visible_devices.split(',') + expected = int(devices[rank % len(devices)]) + + with switchenv({'CUDA_VISIBLE_DEVICES': cu_device}): op1 = Operator(eq) argmap1 = op1.arguments() - devices = [int(i) for i in visible_devices.split(',')] - assert argmap1._physical_deviceid == devices[rank] + assert argmap1._physical_deviceid == expected # In default case, physical deviceid will equal rank op2 = Operator(eq)