Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions pyisolate/_internal/model_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@

import logging
import os
import sys
from typing import TYPE_CHECKING, Any

from .serialization_registry import SerializerRegistry
from .torch_gate import get_torch_optional

_cuda_ipc_enabled = sys.platform == "linux" and os.environ.get("PYISOLATE_ENABLE_CUDA_IPC") == "1"

if TYPE_CHECKING: # pragma: no cover - typing aids
pass # type: ignore[import-not-found]

Expand All @@ -39,17 +36,28 @@ def _serialize_for_isolation_impl(
if isinstance(handle, remote_handle_type):
return handle

serializer = registry.get_serializer(type_name)
if serializer is not None:
return serializer(data)

# Handle torch tensors BEFORE the registry's mode-bound "Tensor" serializer so the
# per-channel transport mode (JSONSocketTransport._tensor_transport) decides the wire
# format -- not the process-global registry mode. Otherwise a host running a
# shared_memory (share_torch) extension alongside a json (sealed/conda) extension emits
# a shared-memory TensorRef onto the json channel, which a torch-free sealed worker
# cannot decode (KeyError 'data'). Returning the tensor here defers encoding to the
# transport, which already serializes per channel via serialize_tensor(mode=...).
if torch_module is not None and isinstance(data, torch_module.Tensor):
if data.is_cuda:
if _cuda_ipc_enabled:
# Read the CUDA IPC env at call time, not import time: the host sets
# PYISOLATE_ENABLE_CUDA_IPC during _initialize_process, after this module is
# imported, so an import-time snapshot would be stale and downgrade configured
# CUDA tensors to CPU. Matches rpc_serialization._prepare_for_rpc_impl.
if os.environ.get("PYISOLATE_ENABLE_CUDA_IPC") == "1":
return data
return data.cpu()
Comment on lines +52 to 54
return data

serializer = registry.get_serializer(type_name)
if serializer is not None:
return serializer(data)

if isinstance(data, dict):
return {
k: _serialize_for_isolation_impl(
Expand Down
62 changes: 61 additions & 1 deletion pyisolate/_internal/rpc_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import queue
import threading
import uuid
import warnings
from collections.abc import Callable
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -162,7 +163,33 @@ def __init__(

self.lock = threading.Lock()
self.pending: dict[int, RPCPendingRequest] = {}
self.default_loop = asyncio.get_event_loop()
# Set only when the last-resort branch below creates a loop we own, so run()/
# update_event_loop() can close it if a real loop later supersedes it (avoids a
# leaked, never-run event loop -> ResourceWarning: unclosed event loop).
self._created_default_loop: asyncio.AbstractEventLoop | None = None
# Acquire the loop without raising when constructed outside a running loop.
# Implicit main-thread event loop creation by asyncio.get_event_loop() has been
# deprecated since Python 3.10 and was removed in Python 3.14, so an eager
# asyncio.get_event_loop() can raise here in sync construction paths. Preserve the
# historical get_event_loop() semantics: prefer the running
# loop, then the thread's installed loop (set via asyncio.set_event_loop, e.g. a
# synchronous caller that constructs the RPC before running its own loop), and only
# create+install a new loop as a last resort. update_event_loop() may replace it,
# and run() rebinds to the running loop when started inside one.
try:
self.default_loop = asyncio.get_running_loop()
except RuntimeError:
# No running loop. Reuse the thread's installed loop if present, else install a
# fresh one. asyncio.get_event_loop() returns an installed loop, and only emits
# the "no current event loop" DeprecationWarning (Python >=3.12) when none is
# installed -- we handle creation explicitly, so that one warning is silenced.
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
try:
self.default_loop = asyncio.get_event_loop()
except RuntimeError:
self.default_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.default_loop)
self._created_default_loop = self.default_loop
self._loop_lock = threading.Lock() # Protects default_loop updates
Comment on lines +170 to 193
self.callees: dict[str, object] = {}
self.callbacks: dict[str, Any] = {}
Expand Down Expand Up @@ -209,9 +236,25 @@ def update_event_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> No
with self._loop_lock:
if loop is None:
loop = asyncio.get_event_loop()
self._release_created_loop_if_superseded(loop)
self.default_loop = loop
logger.debug(f"RPC {self.id}: Updated default_loop to {loop}")

def _release_created_loop_if_superseded(self, new_loop: asyncio.AbstractEventLoop) -> None:
    """Close the fallback loop ``__init__`` created once a different loop replaces it.

    Caller must hold ``self._loop_lock``.

    ``self._created_default_loop`` is only set when ``__init__`` had to create and
    install its own event loop as a last resort. When ``run()`` or
    ``update_event_loop()`` later binds dispatch to a different loop, that created
    loop would otherwise stay open and unused (a leaked loop that emits
    ``ResourceWarning``). This object owns it, so it is closed on supersession.

    A loop that is currently running is never closed here: ``loop.close()`` raises
    ``RuntimeError`` on a running loop, which would abort the rebind that the caller
    is in the middle of. Ownership is kept in that case so a later supersession can
    still close the loop once it has stopped.

    Args:
        new_loop: The loop that is about to become ``self.default_loop``.
    """
    owned = self._created_default_loop
    if owned is None or owned is new_loop:
        # Nothing we own, or the owned loop is the one being (re)bound: keep it open.
        return
    if owned.is_closed():
        # Already closed elsewhere; just drop the stale ownership marker.
        self._created_default_loop = None
        return
    if owned.is_running():
        # Closing a running loop raises RuntimeError; defer the cleanup instead of
        # failing the caller's loop rebind.
        return
    owned.close()
    self._created_default_loop = None

def register_callback(self, func: Any) -> str:
callback_id = str(uuid.uuid4())
with self.lock:
Expand Down Expand Up @@ -349,6 +392,23 @@ def _fail_pending_requests(self, error_msg: str) -> None:
)

def run(self) -> None:
# Bind dispatch to the loop that actually services it. _recv_thread
# dispatches inbound calls via run_coroutine_threadsafe(default_loop),
# which executes only on a *running* loop. run() starts those threads, so
# when invoked from within a running loop -- the supported pattern
# (ensure_process_started()/run_until_stopped() under asyncio.run) -- that
# running loop is the authoritative dispatch target and supersedes any loop
# __init__ acquired before the loop existed. Mirrors _get_valid_loop()'s
# running-loop capture; a fully synchronous caller with no running loop must
# call update_event_loop() once its loop starts.
try:
running_loop = asyncio.get_running_loop()
except RuntimeError:
running_loop = None
if running_loop is not None:
with self._loop_lock:
self._release_created_loop_if_superseded(running_loop)
self.default_loop = running_loop
Comment on lines +404 to +411
self.blocking_future = self.default_loop.create_future()
self._threads = [
threading.Thread(target=self._recv_thread, daemon=True),
Expand Down
14 changes: 11 additions & 3 deletions pyisolate/_internal/rpc_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,14 @@ def _prepare_for_rpc_impl(
torch_module: Any,
) -> Any:
obj_type = type(obj)
serializer = _resolve_serializer_for_type(registry, obj_type)
if serializer is not None:
return serializer(obj)

# Handle torch tensors BEFORE the registry's mode-bound "Tensor" serializer so the
# per-channel transport mode (JSONSocketTransport._tensor_transport) decides the wire
# format -- not the process-global registry mode. Otherwise a host running a
# shared_memory (share_torch) extension alongside a json (sealed/conda) extension emits
# a shared-memory TensorRef onto the json channel, which a torch-free sealed worker
# cannot decode (KeyError 'data'). Returning the tensor here defers encoding to the
# transport, whose _json_default serializes per channel via serialize_tensor(mode=...).
if torch_module is not None and isinstance(obj, torch_module.Tensor):
if obj.is_cuda:
if os.environ.get("PYISOLATE_ENABLE_CUDA_IPC") == "1":
Expand All @@ -277,6 +281,10 @@ def _prepare_for_rpc_impl(
return obj.cpu()
Comment on lines 277 to 281
return obj

serializer = _resolve_serializer_for_type(registry, obj_type)
if serializer is not None:
return serializer(obj)

if isinstance(obj, dict):
return {
k: _prepare_for_rpc_impl(v, registry=registry, torch_module=torch_module) for k, v in obj.items()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pyisolate"
version = "0.10.2"
version = "0.10.3rc1"
description = "A Python library for dividing execution across multiple virtual environments"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
10 changes: 5 additions & 5 deletions tests/test_event_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def handler(payload: Any) -> None:
received.append(payload)

bridge.register_handler("progress", handler)
asyncio.get_event_loop().run_until_complete(bridge.dispatch("progress", {"value": 5, "total": 10}))
asyncio.run(bridge.dispatch("progress", {"value": 5, "total": 10}))

assert len(received) == 1
assert received[0] == {"value": 5, "total": 10}
Expand All @@ -37,7 +37,7 @@ def test_emit_unregistered_event_raises(self) -> None:
bridge = _EventBridge()

with pytest.raises(ValueError, match="No handler registered for event 'unknown_event'"):
asyncio.get_event_loop().run_until_complete(bridge.dispatch("unknown_event", {}))
asyncio.run(bridge.dispatch("unknown_event", {}))

def test_emit_event_rejects_non_json_payload(self) -> None:
"""emit_event with non-JSON-serializable payload raises immediately."""
Expand All @@ -62,7 +62,7 @@ async def async_handler(payload: Any) -> None:
received.append(payload)

bridge.register_handler("test", async_handler)
asyncio.get_event_loop().run_until_complete(bridge.dispatch("test", {"key": "value"}))
asyncio.run(bridge.dispatch("test", {"key": "value"}))

assert received == [{"key": "value"}]

Expand All @@ -75,8 +75,8 @@ def test_multiple_events_independent(self) -> None:
bridge.register_handler("progress", lambda p: progress_calls.append(p))
bridge.register_handler("preview", lambda p: preview_calls.append(p))

asyncio.get_event_loop().run_until_complete(bridge.dispatch("progress", {"value": 1}))
asyncio.get_event_loop().run_until_complete(bridge.dispatch("preview", {"image": "data"}))
asyncio.run(bridge.dispatch("progress", {"value": 1}))
asyncio.run(bridge.dispatch("preview", {"image": "data"}))

assert progress_calls == [{"value": 1}]
assert preview_calls == [{"image": "data"}]
Expand Down
130 changes: 130 additions & 0 deletions tests/test_rpc_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,136 @@ def test_singleton_survives_loop_recreation(self) -> None:
elif not loop1.is_closed():
loop1.close()

def test_asyncrpc_constructs_without_current_event_loop(self) -> None:
    """AsyncRPC can be built when the thread has no current event loop.

    Extensions are launched from a synchronous host path
    (host._launch_with_uds), so AsyncRPC is constructed outside any running
    loop. An eager asyncio.get_event_loop() in __init__ used to fail there
    with "There is no current event loop"; this test pins the fix.
    """
    import queue

    from pyisolate._internal.rpc_protocol import AsyncRPC

    # Remember whatever loop the thread had so it can be restored afterwards.
    try:
        saved_loop = asyncio.get_event_loop_policy().get_event_loop()
    except RuntimeError:
        saved_loop = None

    asyncio.set_event_loop(None)
    rpc = None
    try:
        inbound = cast(Any, queue.Queue())
        outbound = cast(Any, queue.Queue())
        rpc = AsyncRPC(recv_queue=inbound, send_queue=outbound)

        loop_in_use = rpc.default_loop
        assert isinstance(loop_in_use, asyncio.AbstractEventLoop)
        assert not loop_in_use.is_closed()
    finally:
        if rpc is not None:
            fallback_loop = rpc.default_loop
        else:
            fallback_loop = None
        asyncio.set_event_loop(saved_loop)
        # Close only a loop the constructor made, never the one we are restoring.
        if not (fallback_loop is None or fallback_loop is saved_loop):
            fallback_loop.close()

def test_asyncrpc_reuses_preset_thread_loop(self) -> None:
    """AsyncRPC adopts a loop that is installed on the thread but not yet running.

    A synchronous caller may create a loop, install it with
    asyncio.set_event_loop(), build AsyncRPC, and only then drive the loop.
    __init__ has to pick up that installed loop (the historical
    asyncio.get_event_loop() behavior) rather than creating a second loop
    that dispatch would be scheduled on but nobody ever runs.
    """
    import queue

    from pyisolate._internal.rpc_protocol import AsyncRPC

    # Remember whatever loop the thread had so it can be restored afterwards.
    try:
        saved_loop = asyncio.get_event_loop_policy().get_event_loop()
    except RuntimeError:
        saved_loop = None

    preset_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(preset_loop)
    try:
        inbound = cast(Any, queue.Queue())
        outbound = cast(Any, queue.Queue())
        rpc = AsyncRPC(recv_queue=inbound, send_queue=outbound)
        assert rpc.default_loop is preset_loop
    finally:
        asyncio.set_event_loop(saved_loop)
        preset_loop.close()

def test_asyncrpc_construction_emits_no_deprecation_warning(self) -> None:
    """Building AsyncRPC must not surface a 'no current event loop' DeprecationWarning.

    The workaround for the get_event_loop() failure must not itself emit the
    deprecation it exists to avoid. Construction happens with no installed
    loop while DeprecationWarning is promoted to an error.
    """
    import queue
    import warnings

    from pyisolate._internal.rpc_protocol import AsyncRPC

    # Remember whatever loop the thread had so it can be restored afterwards.
    try:
        saved_loop = asyncio.get_event_loop_policy().get_event_loop()
    except RuntimeError:
        saved_loop = None

    asyncio.set_event_loop(None)
    rpc = None
    try:
        with warnings.catch_warnings():
            # Any DeprecationWarning raised during construction fails the test.
            warnings.simplefilter("error", DeprecationWarning)
            inbound = cast(Any, queue.Queue())
            outbound = cast(Any, queue.Queue())
            rpc = AsyncRPC(recv_queue=inbound, send_queue=outbound)
            assert isinstance(rpc.default_loop, asyncio.AbstractEventLoop)
    finally:
        if rpc is not None:
            fallback_loop = rpc.default_loop
        else:
            fallback_loop = None
        asyncio.set_event_loop(saved_loop)
        # Close only a loop the constructor made, never the one we are restoring.
        if not (fallback_loop is None or fallback_loop is saved_loop):
            fallback_loop.close()

def test_run_rebinds_dispatch_to_running_loop(self) -> None:
    """run() binds default_loop to the running loop that services dispatch.

    AsyncRPC may be constructed before the loop that will run it exists (the
    synchronous host launch path), so __init__ can only install a placeholder
    loop. _recv_thread dispatches inbound calls via
    run_coroutine_threadsafe(default_loop), which executes only on a *running*
    loop; a placeholder nobody runs would hang every inbound child->host call.
    run() therefore adopts the running loop before starting the dispatch
    threads. Guards the regression where a never-run fallback loop is used as
    the dispatch target on the synchronous host path.
    """
    import queue

    from pyisolate._internal.rpc_protocol import AsyncRPC

    try:
        previous_loop = asyncio.get_event_loop_policy().get_event_loop()
    except RuntimeError:
        previous_loop = None

    rpc = None
    placeholder = None

    # Everything that mutates thread-global loop state runs inside the try so the
    # previous loop is restored even when construction itself fails.
    asyncio.set_event_loop(None)
    try:
        # Construct with no running/installed loop so default_loop is a placeholder
        # distinct from the loop run() will later execute under.
        recv_q: queue.Queue[Any] = queue.Queue()
        recv_q.put(None)  # makes _recv_thread exit cleanly right after run()
        rpc = AsyncRPC(recv_queue=cast(Any, recv_q), send_queue=cast(Any, queue.Queue()))
        placeholder = rpc.default_loop

        async def _run_inside_loop() -> asyncio.AbstractEventLoop:
            rpc.run()
            return asyncio.get_running_loop()

        running = asyncio.run(_run_inside_loop())
        assert rpc.default_loop is running
        assert rpc.default_loop is not placeholder
        # The created fallback loop is closed when superseded -- no leaked loop.
        assert placeholder.is_closed()
    finally:
        if rpc is not None:
            rpc.shutdown()
        asyncio.set_event_loop(previous_loop)
        if placeholder is not None and not placeholder.is_closed():
            placeholder.close()

def test_singleton_data_persists_across_loops(self) -> None:
"""Data stored in singleton persists across event loops."""
try:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_tensor_shared_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import pytest

torch = pytest.importorskip("torch")

from pyisolate._internal.tensor_serializer import ( # noqa: E402
_reset_shm_check,
deserialize_tensor,
serialize_tensor,
)


def test_cpu_torch_share_roundtrip_is_zero_copy() -> None:
    """CPU tensors sent in shared_memory mode stay shared, not copied by value.

    The test is single-process so it also runs on Windows, where the
    multi-process torch_share RPC tests are skipped (extension loading needs
    Unix sockets). It guards against a /dev/shm gate degrading CPU sharing to
    a file-based value copy: such a copy would still satisfy ``torch.equal``
    but would break the shared-storage contract callers rely on for
    zero-copy transfer.
    """
    _reset_shm_check()
    sent = torch.arange(25, dtype=torch.float32).reshape(5, 5)

    ref = serialize_tensor(sent, mode="shared_memory")

    # The wire payload must be a reference to shared storage on the CPU.
    assert ref["__type__"] == "TensorRef", "CPU tensor degraded to a value copy"
    assert ref["device"] == "cpu"
    allowed_strategies = ("file_system", "file_system_borrowed")
    assert ref["strategy"] in allowed_strategies

    received = deserialize_tensor(ref, mode="shared_memory")
    assert torch.equal(received, sent)

    # A write on the sender side must be visible through the rebuilt tensor.
    sent[0, 0] = 999.0
    observed = float(received[0, 0])
    assert observed == 999.0, "receiver did not observe sender mutation"
Loading
Loading