diff --git a/Pipfile b/Pipfile index 4207275c..dd60baf0 100644 --- a/Pipfile +++ b/Pipfile @@ -14,6 +14,9 @@ pytest-postgresql = {path = ".", editable = true} [dev-packages] towncrier = "==25.8.0" psycopg-binary = {version = "==3.3.4", markers="implementation_name == 'cpython'"} +pytest-asyncio = ">=1.4" +aiofiles = ">=23.0" +types-aiofiles = ">=23.0" coverage = ">=7.14.1" pytest-xdist = "==3.8.0" mock = "==5.2.0" diff --git a/README.rst b/README.rst index 02c57c56..8b56f0f8 100644 --- a/README.rst +++ b/README.rst @@ -38,6 +38,24 @@ Quick Start You will also need to install ``psycopg`` (version 3). See `its installation instructions `_. + For async tests with ``psycopg.AsyncConnection``, install the optional async extra: + + .. code-block:: sh + + pip install pytest-postgresql[async] + + This installs: + + * ``pytest-asyncio`` (>= 1.4) — required for ``@pytest.mark.asyncio`` and + ``postgresql_async`` fixtures. + * ``aiofiles`` (>= 23.0) — required only when loading SQL files via the + async loader (``sql_async``). + + On Windows, the plugin configures a ``SelectorEventLoop`` automatically (required + by ``psycopg`` async). With ``pytest-asyncio`` >= 1.4, this is done via the + loop-factory hook on all supported Python versions. On Python 3.14+, the legacy + ``asyncio`` policy fallback is not used because that API is deprecated. + .. note:: While this plugin requires ``psycopg`` 3 to manage the database, your application code can still use ``psycopg`` 2. @@ -54,6 +72,21 @@ Quick Start cur.execute("CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);") postgresql.commit() + For async code, use ``postgresql_async`` with ``pytest.mark.asyncio``: + + .. code-block:: python + + import pytest + + @pytest.mark.asyncio + async def test_example_async(postgresql_async): + """Check main async postgresql fixture.""" + async with postgresql_async.cursor() as cur: + await cur.execute( + "CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);" + ) + await postgresql_async.commit() + How to use ========== @@ -75,6 +108,22 @@ The plugin provides two main types of fixtures: * **postgresql** - A function-scoped fixture. It returns a connected ``psycopg.Connection``. After each test, it terminates leftover connections and drops the test database to ensure isolation. + * **postgresql_async** - The async counterpart. It returns a connected ``psycopg.AsyncConnection``. + Requires ``pytest-postgresql[async]`` (``pytest-asyncio`` >= 1.4), and each test must be + marked with ``@pytest.mark.asyncio``. + +**Async fixtures** + ``postgresql_async`` and custom factories created with ``factories.postgresql_async`` are + async generator fixtures using ``pytest_asyncio.fixture``. + + Minimum versions when installing manually instead of via ``[async]``: + + .. code-block:: text + + pytest-asyncio >= 1.4 + aiofiles >= 23.0 # only for async SQL file loading + + If ``pytest-asyncio`` is missing, fixture setup raises ``ImportError``. **2. Process Fixtures** These manage the PostgreSQL server lifecycle. @@ -98,6 +147,9 @@ You can create additional fixtures using factories: # Create a client fixture that uses the custom process postgresql_my = factories.postgresql('postgresql_my_proc') + # Async client fixture (requires pytest-postgresql[async], pytest-asyncio >= 1.4) + postgresql_my_async = factories.postgresql_async('postgresql_my_proc') + .. note:: Each process fixture can be configured independently through factory arguments. diff --git a/newsfragments/1235.feature.rst b/newsfragments/1235.feature.rst new file mode 100644 index 00000000..1a355d0a --- /dev/null +++ b/newsfragments/1235.feature.rst @@ -0,0 +1,2 @@ +Added async PostgreSQL fixture support via ``postgresql_async`` factory and ``AsyncDatabaseJanitor``. +Added optional ``async`` extra (``pip install pytest-postgresql[async]``) providing ``pytest-asyncio`` (>= 1.4) and ``aiofiles`` dependencies. diff --git a/newsfragments/1295.bugfix.rst b/newsfragments/1295.bugfix.rst new file mode 100644 index 00000000..962f8919 --- /dev/null +++ b/newsfragments/1295.bugfix.rst @@ -0,0 +1 @@ +Fixed ``DeprecationWarning`` on Python 3.14 from deprecated asyncio event-loop policy usage on Windows; use pytest-asyncio loop factories instead. diff --git a/pyproject.toml b/pyproject.toml index 60f6a0bf..0bcddb93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,12 @@ dependencies = [ ] requires-python = ">= 3.10" +[project.optional-dependencies] +async = [ + "pytest-asyncio >= 1.4", + "aiofiles >= 23.0" +] + [project.urls] "Source" = "https://github.com/dbfixtures/pytest-postgresql" "Bug Tracker" = "https://github.com/dbfixtures/pytest-postgresql/issues" diff --git a/pytest_postgresql/factories/__init__.py b/pytest_postgresql/factories/__init__.py index d6bd2f64..002304cb 100644 --- a/pytest_postgresql/factories/__init__.py +++ b/pytest_postgresql/factories/__init__.py @@ -17,8 +17,8 @@ # along with pytest-postgresql. If not, see . """Fixture factories for postgresql fixtures.""" -from pytest_postgresql.factories.client import postgresql +from pytest_postgresql.factories.client import postgresql, postgresql_async from pytest_postgresql.factories.noprocess import postgresql_noproc from pytest_postgresql.factories.process import PortType, postgresql_proc -__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "PortType") +__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "postgresql_async", "PortType") diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index 5fb5a5be..b49a31ea 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -17,17 +17,25 @@ # along with pytest-postgresql. If not, see . """Fixture factory for postgresql client.""" -from typing import Callable, Iterator +from typing import Any, AsyncIterator, Callable, Iterator, cast import psycopg import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection from pytest import FixtureRequest from pytest_postgresql.config import get_config from pytest_postgresql.executor import PostgreSQLExecutor from pytest_postgresql.executor_noop import NoopExecutor -from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor + +pytest_asyncio: Any = None +try: + import pytest_asyncio as _pytest_asyncio_module + + pytest_asyncio = _pytest_asyncio_module +except ImportError: + pass def postgresql( @@ -46,7 +54,7 @@ def postgresql( @pytest.fixture def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: - """Fixture factory for PostgreSQL. + """Fixture connection factory for PostgreSQL. :param request: fixture request object :returns: postgresql client @@ -85,3 +93,81 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: db_connection.close() return postgresql_factory + + +def postgresql_async( + process_fixture_name: str, + dbname: str | None = None, + isolation_level: "psycopg.IsolationLevel | None" = None, +) -> Callable[[FixtureRequest], AsyncIterator[AsyncConnection]]: + """Return async connection fixture factory for PostgreSQL. + + Requires ``pytest-asyncio`` >= 1.4 (install via ``pip install pytest-postgresql[async]``). + + :param process_fixture_name: name of the process fixture + :param dbname: database name + :param isolation_level: optional postgresql isolation level + defaults to server's default + :returns: function which makes an async connection to postgresql + """ + if pytest_asyncio is None: + + @pytest.fixture + def postgresql_async_stub(request: FixtureRequest) -> None: + """Sync stub that raises ImportError when pytest-asyncio is absent.""" + raise ImportError( + "pytest-asyncio >= 1.4 is required for async fixtures. " + "Install it with: pip install pytest-postgresql[async]" + ) + + return cast( + Callable[[FixtureRequest], AsyncIterator[AsyncConnection]], + postgresql_async_stub, + ) + + assert pytest_asyncio is not None + + @pytest_asyncio.fixture # type: ignore[untyped-decorator] + async def postgresql_async_factory(request: FixtureRequest) -> AsyncIterator[AsyncConnection]: + """Async connection fixture factory for PostgreSQL. + + :param request: fixture request object + :returns: postgresql async client + """ + proc_fixture: PostgreSQLExecutor | NoopExecutor = request.getfixturevalue(process_fixture_name) + config = get_config(request) + + pg_host = proc_fixture.host + pg_port = proc_fixture.port + pg_user = proc_fixture.user + pg_password = proc_fixture.password + pg_options = proc_fixture.options + pg_db = dbname or proc_fixture.dbname + janitor = AsyncDatabaseJanitor( + user=pg_user, + host=pg_host, + port=pg_port, + dbname=pg_db, + template_dbname=proc_fixture.template_dbname, + version=proc_fixture.version, + password=pg_password, + isolation_level=isolation_level, + ) + if config.drop_test_database: + await janitor.drop() + async with janitor: + db_connection: AsyncConnection = await AsyncConnection.connect( + dbname=pg_db, + user=pg_user, + password=pg_password, + host=pg_host, + port=pg_port, + options=pg_options, + ) + yield db_connection + await db_connection.close() + + return cast( + Callable[[FixtureRequest], AsyncIterator[AsyncConnection]], + postgresql_async_factory, + ) diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 27fab57f..31c86148 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -146,40 +146,43 @@ def postgresql_proc_fixture( ) n += 1 - tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.fixturename}") - datadir, logfile_path = _prepare_dir(tmpdir, str(pg_port)) - - postgresql_executor = PostgreSQLExecutor( - executable=postgresql_ctl, - host=host or config.host, - port=pg_port, - user=user or config.user, - password=password or config.password, - dbname=pg_dbname, - options=options or config.options, - datadir=str(datadir), - unixsocketdir=unixsocketdir or config.unixsocketdir, - logfile=str(logfile_path), - startparams=startparams or config.startparams, - postgres_options=postgres_options or config.postgres_options, - ) - # start server - with postgresql_executor: - postgresql_executor.wait_for_postgres() - janitor = DatabaseJanitor( - user=postgresql_executor.user, - host=postgresql_executor.host, - port=postgresql_executor.port, - dbname=postgresql_executor.template_dbname, - as_template=True, - version=postgresql_executor.version, - password=postgresql_executor.password, + try: + tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.fixturename}") + datadir, logfile_path = _prepare_dir(tmpdir, str(pg_port)) + + postgresql_executor = PostgreSQLExecutor( + executable=postgresql_ctl, + host=host or config.host, + port=pg_port, + user=user or config.user, + password=password or config.password, + dbname=pg_dbname, + options=options or config.options, + datadir=str(datadir), + unixsocketdir=unixsocketdir or config.unixsocketdir, + logfile=str(logfile_path), + startparams=startparams or config.startparams, + postgres_options=postgres_options or config.postgres_options, ) - if config.drop_test_database: - janitor.drop() - with janitor: - for load_element in pg_load: - janitor.load(load_element) - yield postgresql_executor + # start server + with postgresql_executor: + postgresql_executor.wait_for_postgres() + janitor = DatabaseJanitor( + user=postgresql_executor.user, + host=postgresql_executor.host, + port=postgresql_executor.port, + dbname=postgresql_executor.template_dbname, + as_template=True, + version=postgresql_executor.version, + password=postgresql_executor.password, + ) + if config.drop_test_database: + janitor.drop() + with janitor: + for load_element in pg_load: + janitor.load(load_element) + yield postgresql_executor + finally: + port_filename_path.unlink(missing_ok=True) return postgresql_proc_fixture diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index f602372e..be3d631c 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -1,21 +1,24 @@ """Database Janitor.""" -from contextlib import contextmanager +import inspect +from contextlib import asynccontextmanager, contextmanager from pathlib import Path from types import TracebackType -from typing import Callable, Iterator, Type, TypeVar +from typing import AsyncIterator, Callable, Iterator, Type, TypeVar import psycopg +import psycopg.sql as sql from packaging.version import parse -from psycopg import Connection, Cursor +from psycopg import AsyncCursor, Connection, Cursor -from pytest_postgresql.loader import build_loader -from pytest_postgresql.retry import retry +from pytest_postgresql.loader import build_loader, build_loader_async +from pytest_postgresql.retry import retry, retry_async Version = type(parse("1")) DatabaseJanitorType = TypeVar("DatabaseJanitorType", bound="DatabaseJanitor") +AsyncDatabaseJanitorType = TypeVar("AsyncDatabaseJanitorType", bound="AsyncDatabaseJanitor") class DatabaseJanitor: @@ -67,18 +70,17 @@ def __init__( def init(self) -> None: """Create database in postgresql.""" with self.cursor() as cur: + query = sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.dbname)) if self.template_dbname: # And make sure no-one is left connected to the template database. # Otherwise, Creating database from template will fail self._terminate_connection(cur, self.template_dbname) - query = f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}"' - else: - query = f'CREATE DATABASE "{self.dbname}"' + query = query + sql.SQL(" TEMPLATE {}").format(sql.Identifier(self.template_dbname)) if self.as_template: - query += " IS_TEMPLATE = true" + query = query + sql.SQL(" IS_TEMPLATE = true") - cur.execute(f"{query};") + cur.execute(query) def is_template(self) -> bool: """Determine whether the DatabaseJanitor maintains template or database.""" @@ -92,17 +94,17 @@ def drop(self) -> None: self._dont_datallowconn(cur, self.dbname) self._terminate_connection(cur, self.dbname) if self.as_template: - cur.execute(f'ALTER DATABASE "{self.dbname}" with is_template false;') - cur.execute(f'DROP DATABASE IF EXISTS "{self.dbname}";') + cur.execute(sql.SQL("ALTER DATABASE {} WITH is_template false").format(sql.Identifier(self.dbname))) + cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(self.dbname))) @staticmethod def _dont_datallowconn(cur: Cursor, dbname: str) -> None: - cur.execute(f'ALTER DATABASE "{dbname}" with allow_connections false;') + cur.execute(sql.SQL("ALTER DATABASE {} WITH allow_connections false").format(sql.Identifier(dbname))) @staticmethod def _terminate_connection(cur: Cursor, dbname: str) -> None: cur.execute( - "SELECT pg_terminate_backend(pg_stat_activity.pid)" + "SELECT pg_terminate_backend(pg_stat_activity.pid) " "FROM pg_stat_activity " "WHERE pg_stat_activity.datname = %s;", (dbname,), @@ -164,3 +166,153 @@ def __exit__( ) -> None: """Exit from Database janitor context cleaning after itself.""" self.drop() + + +class AsyncDatabaseJanitor: + """Manage database state asynchronously for specific tasks.""" + + def __init__( + self, + *, + user: str, + host: str, + port: str | int, + version: str | float | Version, # type: ignore[valid-type] + dbname: str, + template_dbname: str | None = None, + as_template: bool = False, + password: str | None = None, + isolation_level: "psycopg.IsolationLevel | None" = None, + connection_timeout: int = 60, + ) -> None: + """Initialize async janitor. + + :param user: postgresql username + :param host: postgresql host + :param port: postgresql port + :param dbname: database name + :param template_dbname: template database name to clone from + :param as_template: whether to mark the database as a template + :param version: postgresql version number + :param password: optional postgresql password + :param isolation_level: optional postgresql isolation level + defaults to server's default + :param connection_timeout: how long to retry connection before + raising a TimeoutError + """ + self.user = user + self.password = password + self.host = host + self.port = port + self.dbname = dbname + self.template_dbname = template_dbname + self.as_template = as_template + self._connection_timeout = connection_timeout + self.isolation_level = isolation_level + if not isinstance(version, Version): + self.version = parse(str(version)) + else: + self.version = version + + async def init(self) -> None: + """Create database in postgresql.""" + async with self.cursor() as cur: + query = sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.dbname)) + if self.template_dbname: + # And make sure no-one is left connected to the template database. + # Otherwise, Creating database from template will fail + await self._terminate_connection(cur, self.template_dbname) + query = query + sql.SQL(" TEMPLATE {}").format(sql.Identifier(self.template_dbname)) + + if self.as_template: + query = query + sql.SQL(" IS_TEMPLATE = true") + + await cur.execute(query) + + def is_template(self) -> bool: + """Determine whether the AsyncDatabaseJanitor maintains template or database.""" + return self.as_template + + async def drop(self) -> None: + """Drop database in postgresql.""" + # We cannot drop the database while there are connections to it, so we + # terminate all connections first while not allowing new connections. + async with self.cursor() as cur: + await self._dont_datallowconn(cur, self.dbname) + await self._terminate_connection(cur, self.dbname) + if self.as_template: + await cur.execute( + sql.SQL("ALTER DATABASE {} WITH is_template false").format(sql.Identifier(self.dbname)) + ) + await cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(self.dbname))) + + @staticmethod + async def _dont_datallowconn(cur: AsyncCursor, dbname: str) -> None: + await cur.execute(sql.SQL("ALTER DATABASE {} WITH allow_connections false").format(sql.Identifier(dbname))) + + @staticmethod + async def _terminate_connection(cur: AsyncCursor, dbname: str) -> None: + await cur.execute( + "SELECT pg_terminate_backend(pg_stat_activity.pid) " + "FROM pg_stat_activity " + "WHERE pg_stat_activity.datname = %s;", + (dbname,), + ) + + async def load(self, load: Callable | str | Path) -> None: + """Load data into a database. + + Expects: + + * a Path to sql file, that'll be loaded + * an import path to import callable + * a callable that expects: host, port, user, dbname and password arguments. + + """ + _loader = build_loader_async(load) + result = _loader( + host=self.host, + port=self.port, + user=self.user, + dbname=self.dbname, + password=self.password, + ) + if inspect.isawaitable(result): + await result + + @asynccontextmanager + async def cursor(self, dbname: str = "postgres") -> AsyncIterator[AsyncCursor]: + """Return postgresql async cursor.""" + + async def connect() -> psycopg.AsyncConnection: + return await psycopg.AsyncConnection.connect( + dbname=dbname, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + ) + + conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError) + try: + await conn.set_isolation_level(self.isolation_level) + await conn.set_autocommit(True) + # We must not run a transaction since we create a database. + async with conn.cursor() as cur: + yield cur + finally: + await conn.close() + + async def __aenter__(self: AsyncDatabaseJanitorType) -> AsyncDatabaseJanitorType: + """Initialize Async Database Janitor.""" + await self.init() + return self + + async def __aexit__( + self: AsyncDatabaseJanitorType, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit from Async Database Janitor context cleaning after itself.""" + await self.drop() diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py index c9b28cbd..e73e2852 100644 --- a/pytest_postgresql/loader.py +++ b/pytest_postgresql/loader.py @@ -1,5 +1,6 @@ """Loader helper functions.""" +import importlib import re from functools import partial from pathlib import Path @@ -7,16 +8,21 @@ import psycopg +try: + import aiofiles +except ImportError: + aiofiles = None # type: ignore[assignment] + def build_loader(load: Callable | str | Path) -> Callable: """Build a loader callable.""" if isinstance(load, Path): return partial(sql, load) elif isinstance(load, str): - loader_parts = re.split("[.:]", load, maxsplit=2) + loader_parts = re.split("[.:]", load) import_path = ".".join(loader_parts[:-1]) loader_name = loader_parts[-1] - _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name]) + _temp_import = importlib.import_module(import_path) _loader: Callable = getattr(_temp_import, loader_name) return _loader else: @@ -30,3 +36,35 @@ def sql(sql_filename: Path, **kwargs: Any) -> None: with db_connection.cursor() as cur: cur.execute(_fd.read()) db_connection.commit() + + +def build_loader_async(load: Callable | str | Path) -> Callable: + """Build an async loader callable.""" + if isinstance(load, Path): + return partial(sql_async, load) + elif isinstance(load, str): + loader_parts = re.split("[.:]", load) + import_path = ".".join(loader_parts[:-1]) + loader_name = loader_parts[-1] + _temp_import = importlib.import_module(import_path) + _loader: Callable = getattr(_temp_import, loader_name) + return _loader + else: + return load + + +async def sql_async(sql_filename: Path, **kwargs: Any) -> None: + """Async database loader for sql files. + + Requires the optional ``async`` extra: ``pip install pytest-postgresql[async]``. + """ + if aiofiles is None: + raise ImportError( + "aiofiles is required for async SQL loading. Install it with: pip install pytest-postgresql[async]" + ) + + async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection: + async with db_connection.cursor() as cur: + async with aiofiles.open(sql_filename, "r") as _fd: + await cur.execute(await _fd.read()) + await db_connection.commit() diff --git a/pytest_postgresql/plugin.py b/pytest_postgresql/plugin.py index 612e408a..56b27065 100644 --- a/pytest_postgresql/plugin.py +++ b/pytest_postgresql/plugin.py @@ -17,12 +17,27 @@ # along with pytest-postgresql. If not, see . """Plugin module of pytest-postgresql.""" +import asyncio +import platform +import selectors +from collections.abc import Callable from tempfile import gettempdir +from typing import Any +import pytest from _pytest.config.argparsing import Parser +from packaging.version import Version, parse from pytest_postgresql import factories +pytest_asyncio: Any = None +try: + import pytest_asyncio as _pytest_asyncio_module + + pytest_asyncio = _pytest_asyncio_module +except ImportError: + pass + _help_executable = "Path to PostgreSQL executable" _help_host = "Host at which PostgreSQL will accept connections" _help_port = "Port at which PostgreSQL will accept connections" @@ -42,6 +57,44 @@ ) +def _windows_selector_event_loop() -> asyncio.AbstractEventLoop: + return asyncio.SelectorEventLoop(selectors.SelectSelector()) + + +def _pytest_asyncio_supports_loop_factories() -> bool: + if pytest_asyncio is None: + return False + return parse(pytest_asyncio.__version__) >= parse("1.4.0") + + +def _is_windows() -> bool: + return platform.system() == "Windows" + + +def _uses_deprecated_asyncio_policy_on_windows() -> bool: + return Version(platform.python_version()) < Version("3.14") and not _pytest_asyncio_supports_loop_factories() + + +def pytest_configure(config: pytest.Config) -> None: + """Configure pytest-postgresql plugin.""" + if not _is_windows() or not config.pluginmanager.has_plugin("asyncio"): + return + if not _uses_deprecated_asyncio_policy_on_windows(): + return + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore[attr-defined] + + +if _is_windows(): + + @pytest.hookimpl(optionalhook=True) + def pytest_asyncio_loop_factories( + config: pytest.Config, + item: pytest.Item, + ) -> dict[str, Callable[[], asyncio.AbstractEventLoop]]: + """Use SelectorEventLoop on Windows for psycopg async compatibility.""" + return {"selector": _windows_selector_event_loop} + + def pytest_addoption(parser: Parser) -> None: """Configure options for pytest-postgresql.""" parser.addini(name="postgresql_exec", help=_help_executable, default="/usr/lib/postgresql/13/bin/pg_ctl") @@ -135,3 +188,4 @@ def pytest_addoption(parser: Parser) -> None: postgresql_proc = factories.postgresql_proc() postgresql_noproc = factories.postgresql_noproc() postgresql = factories.postgresql("postgresql_proc") +postgresql_async = factories.postgresql_async("postgresql_proc") diff --git a/pytest_postgresql/retry.py b/pytest_postgresql/retry.py index ea25fa2e..078db5bc 100644 --- a/pytest_postgresql/retry.py +++ b/pytest_postgresql/retry.py @@ -1,9 +1,10 @@ """Small retry callable in case of specific error occurred.""" +import asyncio import datetime import sys from time import sleep -from typing import Callable, Type, TypeVar +from typing import Awaitable, Callable, Type, TypeVar T = TypeVar("T") @@ -29,11 +30,41 @@ def retry( i += 1 try: res = func() - return res except possible_exception as e: if time + timeout_diff < get_current_datetime(): raise TimeoutError(f"Failed after {i} attempts") from e sleep(1) + else: + return res + + +async def retry_async( + func: Callable[[], Awaitable[T]], + timeout: int = 60, + possible_exception: Type[Exception] = Exception, +) -> T: + """Attempt to retry the async function for timeout time. + + Most often used for connecting to postgresql database as, + especially on macos on github-actions, first few tries fails + with this message: + + ... :: + FATAL: the database system is starting up + """ + time: datetime.datetime = get_current_datetime() + timeout_diff: datetime.timedelta = datetime.timedelta(seconds=timeout) + i = 0 + while True: + i += 1 + try: + res = await func() + except possible_exception as e: + if time + timeout_diff < get_current_datetime(): + raise TimeoutError(f"Failed after {i} attempts") from e + await asyncio.sleep(1) + else: + return res def get_current_datetime() -> datetime.datetime: diff --git a/tests/conftest.py b/tests/conftest.py index 10b5f39d..e83c485f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,3 +16,5 @@ postgresql_proc2 = factories.postgresql_proc(port=None, load=[TEST_SQL_FILE, TEST_SQL_FILE2]) postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db") postgresql_load_1 = factories.postgresql("postgresql_proc2") +postgresql2_async = factories.postgresql_async("postgresql_proc2", dbname="test-db") +postgresql_load_1_async = factories.postgresql_async("postgresql_proc2") diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py index ae25307a..1d0fbb73 100644 --- a/tests/docker/test_noproc_docker.py +++ b/tests/docker/test_noproc_docker.py @@ -3,7 +3,7 @@ import pathlib import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection import pytest_postgresql.factories.client import pytest_postgresql.factories.noprocess @@ -14,12 +14,17 @@ ) postgres_with_schema = pytest_postgresql.factories.client.postgresql("postgresql_my_proc") +async_postgres_with_schema = pytest_postgresql.factories.client.postgresql_async("postgresql_my_proc") + postgresql_my_proc_template = pytest_postgresql.factories.noprocess.postgresql_noproc( dbname="stories_templated", load=[load_database] ) postgres_with_template = pytest_postgresql.factories.client.postgresql( "postgresql_my_proc_template", dbname="stories_templated" ) +async_postgres_with_template = pytest_postgresql.factories.client.postgresql_async( + "postgresql_my_proc_template", dbname="stories_templated" +) def test_postgres_docker_load(postgres_with_schema: Connection) -> None: @@ -32,6 +37,14 @@ def test_postgres_docker_load(postgres_with_schema: Connection) -> None: print(cur.fetchall()) +@pytest.mark.asyncio +async def test_postgres_docker_load_async(async_postgres_with_schema: AsyncConnection) -> None: + """Async check main postgres fixture.""" + async with async_postgres_with_schema.cursor() as cur: + await cur.execute("select * from public.tokens") + print(await cur.fetchall()) + + @pytest.mark.parametrize("_", range(5)) def test_template_database(postgres_with_template: Connection, _: int) -> None: """Check that the database structure gets recreated out of a template.""" @@ -43,3 +56,17 @@ def test_template_database(postgres_with_template: Connection, _: int) -> None: cur.execute("SELECT * FROM stories") res = cur.fetchall() assert len(res) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(5)) +async def test_template_database_async(async_postgres_with_template: AsyncConnection, _: int) -> None: + """Async check that the database structure gets recreated out of a template.""" + async with async_postgres_with_template.cursor() as cur: + await cur.execute("SELECT * FROM stories") + rows = await cur.fetchall() + assert len(rows) == 4 + await cur.execute("TRUNCATE stories") + await cur.execute("SELECT * FROM stories") + rows = await cur.fetchall() + assert len(rows) == 0 diff --git a/tests/examples/test_drop_test_database_async.py b/tests/examples/test_drop_test_database_async.py new file mode 100644 index 00000000..88210ae0 --- /dev/null +++ b/tests/examples/test_drop_test_database_async.py @@ -0,0 +1,17 @@ +"""Async tests for pytest-postgresql drop-test-database behaviour.""" + +import pytest +from psycopg import AsyncConnection + +from pytest_postgresql import factories + +postgresql_async = factories.postgresql_async("postgresql_noproc") + + +@pytest.mark.asyncio +async def test_postgres_load_override_async(postgresql_async: AsyncConnection) -> None: + """Check postgresql_async can load one file and override a pre-existing database.""" + async with postgresql_async.cursor() as cur: + await cur.execute("SELECT * FROM test;") + results = await cur.fetchall() + assert len(results) == 1 diff --git a/tests/test_executor.py b/tests/test_executor.py index c6944204..c9b5ce83 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -2,7 +2,7 @@ import platform from typing import Any -from unittest.mock import patch +from unittest.mock import MagicMock, patch import psycopg import pytest @@ -90,7 +90,7 @@ def test_executor_init_with_password( config = get_config(request) monkeypatch.setenv("LC_ALL", locale) pg_exe = process._pg_exe(None, config) - port = process._pg_port(-1, config, []) + port = process._pg_port(None, config, []) tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.node.name}") datadir, logfile_path = process._prepare_dir(tmpdir, port) executor = PostgreSQLExecutor( @@ -114,7 +114,7 @@ def test_executor_init_bad_tmp_path( r"""Test init with \ and space chars in the path.""" config = get_config(request) pg_exe = process._pg_exe(None, config) - port = process._pg_port(-1, config, []) + port = process._pg_port(None, config, []) tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.node.name}") / r"a bad\path/" tmpdir.mkdir(parents=True, exist_ok=True) datadir, logfile_path = process._prepare_dir(tmpdir, port) @@ -199,7 +199,7 @@ def test_executor_with_special_chars_in_all_paths( """ config = get_config(request) pg_exe = process._pg_exe(None, config) - port = process._pg_port(-1, config, []) + port = process._pg_port(None, config, []) # Create a tmpdir with spaces in the name tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.node.name}") / "my test dir" tmpdir.mkdir(exist_ok=True) @@ -293,6 +293,92 @@ def test_custom_isolation_level(postgres_isolation_level: Connection) -> None: assert cur.fetchone() == (1,) +def test_postgresql_proc_removes_port_lock_on_teardown( + request: FixtureRequest, + tmp_path_factory: pytest.TempPathFactory, +) -> None: + """Port sentinel file is removed when the process fixture tears down.""" + fixture_func = postgresql_proc(port=None) + raw_func = getattr(fixture_func, "__wrapped__", fixture_func) + + port_path = tmp_path_factory.getbasetemp() + if hasattr(request.config, "workerinput"): + port_path = tmp_path_factory.getbasetemp().parent + pg_port = 54321 + + executor_mock = MagicMock() + executor_mock.__enter__ = MagicMock(return_value=executor_mock) + executor_mock.__exit__ = MagicMock(return_value=False) + executor_mock.user = "postgres" + executor_mock.host = "127.0.0.1" + executor_mock.port = pg_port + executor_mock.template_dbname = "template_tests" + executor_mock.version = 14 + executor_mock.password = None + executor_mock.wait_for_postgres = MagicMock() + + janitor_mock = MagicMock() + janitor_mock.__enter__ = MagicMock(return_value=janitor_mock) + janitor_mock.__exit__ = MagicMock(return_value=False) + + with ( + patch("pytest_postgresql.factories.process._pg_exe", return_value="/usr/bin/pg_ctl"), + patch("pytest_postgresql.factories.process._pg_port", return_value=pg_port), + patch("pytest_postgresql.factories.process.PostgreSQLExecutor", return_value=executor_mock), + patch("pytest_postgresql.factories.process.DatabaseJanitor", return_value=janitor_mock), + patch("pytest_postgresql.factories.process.get_config") as get_config_mock, + ): + config_mock = MagicMock() + config_mock.dbname = "tests" + config_mock.load = [] + config_mock.drop_test_database = False + config_mock.port_search_count = 5 + get_config_mock.return_value = config_mock + + gen = raw_func(request, tmp_path_factory) + next(gen) + port_file = port_path / f"postgresql-{pg_port}.port" + assert port_file.exists() + with pytest.raises(StopIteration): + next(gen) + + assert not port_file.exists() + + +def test_postgresql_proc_removes_port_lock_on_setup_failure( + request: FixtureRequest, + tmp_path_factory: pytest.TempPathFactory, +) -> None: + """Port sentinel file is removed when fixture setup fails after claiming a port.""" + fixture_func = postgresql_proc(port=None) + raw_func = getattr(fixture_func, "__wrapped__", fixture_func) + + port_path = tmp_path_factory.getbasetemp() + if hasattr(request.config, "workerinput"): + port_path = tmp_path_factory.getbasetemp().parent + pg_port = 54322 + + with ( + patch("pytest_postgresql.factories.process._pg_exe", return_value="/usr/bin/pg_ctl"), + patch("pytest_postgresql.factories.process._pg_port", return_value=pg_port), + patch("pytest_postgresql.factories.process.get_config") as get_config_mock, + patch.object(tmp_path_factory, "mktemp", side_effect=OSError("setup failed")), + ): + config_mock = MagicMock() + config_mock.dbname = "tests" + config_mock.load = [] + config_mock.drop_test_database = False + config_mock.port_search_count = 5 + get_config_mock.return_value = config_mock + + gen = raw_func(request, tmp_path_factory) + with pytest.raises(OSError, match="setup failed"): + next(gen) + + port_file = port_path / f"postgresql-{pg_port}.port" + assert not port_file.exists() + + @pytest.mark.skipif(platform.system() != "Windows", reason="Windows-specific test") def test_actual_postgresql_start_windows( request: FixtureRequest, @@ -305,7 +391,7 @@ def test_actual_postgresql_start_windows( """ config = get_config(request) pg_exe = process._pg_exe(None, config) - port = process._pg_port(-1, config, []) + port = process._pg_port(None, config, []) tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.node.name}") datadir, logfile_path = process._prepare_dir(tmpdir, port) @@ -344,7 +430,7 @@ def test_actual_postgresql_start_unix( """ config = get_config(request) pg_exe = process._pg_exe(None, config) - port = process._pg_port(-1, config, []) + port = process._pg_port(None, config, []) tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.node.name}") datadir, logfile_path = process._prepare_dir(tmpdir, port) @@ -380,7 +466,7 @@ def test_actual_postgresql_start_darwin( """ config = get_config(request) pg_exe = process._pg_exe(None, config) - port = process._pg_port(-1, config, []) + port = process._pg_port(None, config, []) tmpdir = tmp_path_factory.mktemp(f"pytest-postgresql-{request.node.name}") datadir, logfile_path = process._prepare_dir(tmpdir, port) diff --git a/tests/test_factory_errors.py b/tests/test_factory_errors.py new file mode 100644 index 00000000..da12ae8a --- /dev/null +++ b/tests/test_factory_errors.py @@ -0,0 +1,33 @@ +"""Tests for factory error paths (missing optional dependencies).""" + +from unittest.mock import patch + +import pytest + +from pytest_postgresql.factories.client import postgresql_async + + +def test_postgresql_async_factory_creation_succeeds_without_pytest_asyncio() -> None: + """postgresql_async() must not raise at factory-creation time when pytest-asyncio is absent. + + The plugin registers ``postgresql_async`` at load time (plugin.py), so raising here + would break all users — including those who only use synchronous fixtures. + """ + with patch("pytest_postgresql.factories.client.pytest_asyncio", None): + fixture_func = postgresql_async("some_proc_fixture") + assert callable(fixture_func) + + +def test_postgresql_async_raises_on_use_without_pytest_asyncio() -> None: + """When pytest-asyncio is absent, the registered stub is synchronous and raises ImportError. + + A synchronous stub avoids the "coroutine was never awaited" warning that would + result from registering an async def with plain pytest.fixture. + """ + with patch("pytest_postgresql.factories.client.pytest_asyncio", None): + fixture_func = postgresql_async("some_proc_fixture") + # pytest 8+ wraps fixtures to prevent direct calls; unwrap first. + raw_func = getattr(fixture_func, "__wrapped__", fixture_func) + assert not hasattr(raw_func, "__await__"), "stub must be a sync function, not a coroutine" + with pytest.raises(ImportError, match=r"pytest-asyncio >= 1\.4"): + raw_func(None) # type: ignore[arg-type] diff --git a/tests/test_janitor.py b/tests/test_janitor.py index fd1fca2a..98217c2a 100644 --- a/tests/test_janitor.py +++ b/tests/test_janitor.py @@ -1,15 +1,23 @@ """Database Janitor tests.""" import sys +from pathlib import Path from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch +import psycopg import pytest from packaging.version import parse +from psycopg import AsyncCursor -from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.executor import PostgreSQLExecutor +from pytest_postgresql.factories.noprocess import xdistify_dbname +from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor + +TEST_SQL_FILE = Path(__file__).resolve().parent / "test_sql" / "test.sql" VERSION = parse("10") +TEST_PASSWORD = "some_password" # noqa: S105 @pytest.mark.parametrize("version", (VERSION, 10, "10")) @@ -19,6 +27,14 @@ def test_version_cast(version: Any) -> None: assert janitor.version == VERSION +@pytest.mark.parametrize("version", (VERSION, 10, "10")) +@pytest.mark.asyncio +async def test_version_cast_async(version: Any) -> None: + """Async test that version is cast to Version object.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version) + assert janitor.version == VERSION + + @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: """Test that the cursor requests the postgres database.""" @@ -27,6 +43,19 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: connect_mock.assert_called_once_with(dbname="postgres", user="user", password=None, host="host", port="1234") +@pytest.mark.asyncio +async def test_cursor_selects_postgres_database_async() -> None: + """Async test that the cursor requests the postgres database.""" + conn_mock = _make_async_conn_mock() + connect_mock = AsyncMock(return_value=conn_mock) + with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock): + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10) + async with janitor.cursor(): + connect_mock.assert_called_once_with( + dbname="postgres", user="user", password=None, host="host", port="1234" + ) + + @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: """Test that the cursor requests the postgres database.""" @@ -36,12 +65,45 @@ def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: port="1234", dbname="database_name", version=10, - password="some_password", + password=TEST_PASSWORD, ) with janitor.cursor(): connect_mock.assert_called_once_with( - dbname="postgres", user="user", password="some_password", host="host", port="1234" + dbname="postgres", user="user", password=TEST_PASSWORD, host="host", port="1234" + ) + + +@pytest.mark.asyncio +async def test_cursor_connects_with_password_async() -> None: + """Async test that the cursor requests the postgres database with password.""" + conn_mock = _make_async_conn_mock() + connect_mock = AsyncMock(return_value=conn_mock) + with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock): + janitor = AsyncDatabaseJanitor( + user="user", + host="host", + port="1234", + dbname="database_name", + version=10, + password=TEST_PASSWORD, ) + async with janitor.cursor(): + connect_mock.assert_called_once_with( + dbname="postgres", user="user", password=TEST_PASSWORD, host="host", port="1234" + ) + + +@pytest.mark.asyncio +async def test_cursor_custom_dbname_async() -> None: + """Test that a custom dbname is forwarded to the connection in AsyncDatabaseJanitor.cursor.""" + conn_mock = _make_async_conn_mock() + connect_mock = AsyncMock(return_value=conn_mock) + with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock): + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10) + async with janitor.cursor(dbname="custom_db"): + connect_mock.assert_called_once_with( + dbname="custom_db", user="user", password=None, host="host", port="1234" + ) @pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8") @@ -57,9 +119,252 @@ def test_janitor_populate(connect_mock: MagicMock, load_database: str) -> None: "port": "1234", "user": "user", "dbname": "database_name", - "password": "some_password", + "password": TEST_PASSWORD, } janitor = DatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type] janitor.load(load_database) assert connect_mock.called assert connect_mock.call_args.kwargs == call_kwargs + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8") +@pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database")) +@patch("tests.loader.psycopg.connect") +@pytest.mark.asyncio +async def test_janitor_populate_async(connect_mock: MagicMock, load_database: str) -> None: + """Async test that the cursor requests the postgres database and populates. + + load_database (synchronous) uses psycopg.connect, so we mock that. + """ + call_kwargs = { + "host": "host", + "port": "1234", + "user": "user", + "dbname": "database_name", + "password": TEST_PASSWORD, + } + janitor = AsyncDatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type] + await janitor.load(load_database) + assert connect_mock.called + assert connect_mock.call_args.kwargs == call_kwargs + + +@pytest.mark.asyncio +async def test_janitor_populate_async_awaitable_loader() -> None: + """AsyncDatabaseJanitor.load awaits async loader callables.""" + call_kwargs = { + "host": "host", + "port": "1234", + "user": "user", + "dbname": "database_name", + "password": TEST_PASSWORD, + } + loader_mock = AsyncMock() + + async def async_loader(**kwargs: object) -> None: + await loader_mock(**kwargs) + + janitor = AsyncDatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type] + await janitor.load(async_loader) + loader_mock.assert_awaited_once_with(**call_kwargs) + + +@pytest.mark.asyncio +async def test_janitor_populate_async_sql_path(postgresql_proc: PostgreSQLExecutor) -> None: + """AsyncDatabaseJanitor.load executes SQL from a Path via sql_async against live PostgreSQL.""" + dbname = xdistify_dbname("sql_async_load") + janitor = AsyncDatabaseJanitor( + user=postgresql_proc.user, + host=postgresql_proc.host, + port=postgresql_proc.port, + dbname=dbname, + version=postgresql_proc.version, + password=postgresql_proc.password, + connection_timeout=5, + ) + async with janitor: + await janitor.load(TEST_SQL_FILE) + async with await psycopg.AsyncConnection.connect( + dbname=dbname, + user=postgresql_proc.user, + password=postgresql_proc.password, + host=postgresql_proc.host, + port=postgresql_proc.port, + ) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT * FROM test_load") + rows = await cur.fetchall() + assert len(rows) == 1 + + +# --------------------------------------------------------------------------- +# AsyncDatabaseJanitor -- init() / drop() integration tests +# --------------------------------------------------------------------------- + + +async def _database_exists(proc: PostgreSQLExecutor, dbname: str) -> bool: + async with await psycopg.AsyncConnection.connect( + dbname="postgres", + user=proc.user, + password=proc.password, + host=proc.host, + port=proc.port, + ) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (dbname,)) + return await cur.fetchone() is not None + + +async def _database_is_template(proc: PostgreSQLExecutor, dbname: str) -> bool: + async with await psycopg.AsyncConnection.connect( + dbname="postgres", + user=proc.user, + password=proc.password, + host=proc.host, + port=proc.port, + ) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT datistemplate FROM pg_database WHERE datname = %s", (dbname,)) + row = await cur.fetchone() + return bool(row and row[0]) + + +@pytest.mark.asyncio +async def test_async_janitor_init_and_drop(postgresql_proc: PostgreSQLExecutor) -> None: + """init() creates a database and drop() removes it against live PostgreSQL.""" + dbname = xdistify_dbname("async_janitor_lifecycle") + janitor = AsyncDatabaseJanitor( + user=postgresql_proc.user, + host=postgresql_proc.host, + port=postgresql_proc.port, + dbname=dbname, + version=postgresql_proc.version, + password=postgresql_proc.password, + connection_timeout=5, + ) + await janitor.init() + assert await _database_exists(postgresql_proc, dbname) + await janitor.drop() + assert not await _database_exists(postgresql_proc, dbname) + + +@pytest.mark.asyncio +async def test_async_janitor_template_flag_and_context_manager(postgresql_proc: PostgreSQLExecutor) -> None: + """as_template marks the database as a template and async with drops it cleanly.""" + dbname = xdistify_dbname("async_janitor_tmpl") + janitor = AsyncDatabaseJanitor( + user=postgresql_proc.user, + host=postgresql_proc.host, + port=postgresql_proc.port, + dbname=dbname, + version=postgresql_proc.version, + password=postgresql_proc.password, + as_template=True, + connection_timeout=5, + ) + async with janitor: + assert await _database_is_template(postgresql_proc, dbname) + assert not await _database_exists(postgresql_proc, dbname) + + +@pytest.mark.asyncio +async def test_async_janitor_creates_database_from_template(postgresql_proc: PostgreSQLExecutor) -> None: + """init() clones schema and data from a template database.""" + base_dbname = xdistify_dbname("async_janitor_tmpl_base") + clone_dbname = xdistify_dbname("async_janitor_tmpl_clone") + base_janitor = AsyncDatabaseJanitor( + user=postgresql_proc.user, + host=postgresql_proc.host, + port=postgresql_proc.port, + dbname=base_dbname, + version=postgresql_proc.version, + password=postgresql_proc.password, + as_template=True, + connection_timeout=5, + ) + clone_janitor = AsyncDatabaseJanitor( + user=postgresql_proc.user, + host=postgresql_proc.host, + port=postgresql_proc.port, + dbname=clone_dbname, + template_dbname=base_dbname, + version=postgresql_proc.version, + password=postgresql_proc.password, + connection_timeout=5, + ) + try: + await base_janitor.init() + await base_janitor.load(TEST_SQL_FILE) + await clone_janitor.init() + async with await psycopg.AsyncConnection.connect( + dbname=clone_dbname, + user=postgresql_proc.user, + password=postgresql_proc.password, + host=postgresql_proc.host, + port=postgresql_proc.port, + ) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT * FROM test_load") + rows = await cur.fetchall() + assert len(rows) == 1 + finally: + await clone_janitor.drop() + await base_janitor.drop() + + assert not await _database_exists(postgresql_proc, clone_dbname) + assert not await _database_exists(postgresql_proc, base_dbname) + + +# --------------------------------------------------------------------------- +# AsyncDatabaseJanitor -- lightweight unit tests +# --------------------------------------------------------------------------- + + +def test_async_janitor_is_template_false() -> None: + """is_template() returns False when as_template is not set.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10) + assert janitor.is_template() is False + + +def test_async_janitor_is_template_true() -> None: + """is_template() returns True when as_template=True.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", as_template=True, version=10) + assert janitor.is_template() is True + + +@pytest.mark.asyncio +async def test_async_janitor_context_manager_calls_init_and_drop() -> None: + """__aenter__ calls init() and __aexit__ calls drop().""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10) + init_mock = AsyncMock() + drop_mock = AsyncMock() + with patch.object(AsyncDatabaseJanitor, "init", init_mock), patch.object(AsyncDatabaseJanitor, "drop", drop_mock): + async with janitor: + init_mock.assert_called_once() + drop_mock.assert_not_called() + drop_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_janitor_terminate_connection_sql() -> None: + """_terminate_connection() executes pg_terminate_backend query with correct dbname.""" + cur = AsyncMock(spec=AsyncCursor) + await AsyncDatabaseJanitor._terminate_connection(cur, "target_db") + + cur.execute.assert_called_once() + sql_str, params = cur.execute.call_args.args + assert "pg_terminate_backend" in sql_str + assert params == ("target_db",) + + +def _make_async_conn_mock() -> MagicMock: + """Create a MagicMock that behaves like a psycopg3 AsyncConnection.""" + conn = MagicMock() + conn.close = AsyncMock() + conn.set_isolation_level = AsyncMock() + conn.set_autocommit = AsyncMock() + cursor_mock = MagicMock() + cursor_mock.__aenter__ = AsyncMock(return_value=MagicMock()) + cursor_mock.__aexit__ = AsyncMock(return_value=False) + conn.cursor = MagicMock(return_value=cursor_mock) + return conn diff --git a/tests/test_loader.py b/tests/test_loader.py index c03f8a55..ce931b8e 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,8 +1,12 @@ """Tests for the `build_loader` function.""" from pathlib import Path +from types import ModuleType +from unittest.mock import patch -from pytest_postgresql.loader import build_loader, sql +import pytest + +from pytest_postgresql.loader import build_loader, build_loader_async, sql, sql_async from tests.loader import load_database @@ -12,9 +16,71 @@ def test_loader_callables() -> None: assert load_database == build_loader("tests.loader:load_database") +def test_loader_callables_dot_separator() -> None: + """Test dot-separated import path resolves the same callable as colon-separated.""" + assert build_loader("tests.loader.load_database") == load_database + + +def test_loader_deeply_nested_import_path() -> None: + """All path segments before the final delimiter are joined as the import path.""" + sentinel = object() + fake_module = ModuleType("fake_module") + fake_module.my_loader = sentinel # type: ignore[attr-defined] + with patch("pytest_postgresql.loader.importlib.import_module", return_value=fake_module) as import_mock: + result = build_loader("a.b.c.d:my_loader") + import_mock.assert_called_once_with("a.b.c.d") + assert result is sentinel + + +@pytest.mark.asyncio +async def test_loader_callables_async() -> None: + """Async test handling callables in build_loader_async.""" + assert load_database == build_loader_async(load_database) + assert load_database == build_loader_async("tests.loader:load_database") + + async def afun(*_args: object, **_kwargs: object) -> int: + return 0 + + assert afun == build_loader_async(afun) + + +@pytest.mark.asyncio +async def test_loader_callables_async_dot_separator() -> None: + """Dot-separated import path is resolved identically by build_loader_async.""" + assert build_loader_async("tests.loader.load_database") == load_database + + +def test_loader_async_deeply_nested_import_path() -> None: + """build_loader_async splits all path segments before the final loader name.""" + sentinel = object() + fake_module = ModuleType("fake_module") + fake_module.my_loader = sentinel # type: ignore[attr-defined] + with patch("pytest_postgresql.loader.importlib.import_module", return_value=fake_module) as import_mock: + result = build_loader_async("a.b.c.d:my_loader") + import_mock.assert_called_once_with("a.b.c.d") + assert result is sentinel + + def test_loader_sql() -> None: """Test returning partial running sql for the sql file path.""" sql_path = Path("test_sql/eidastats.sql") loader_func = build_loader(sql_path) assert loader_func.args == (sql_path,) # type: ignore assert loader_func.func == sql # type: ignore + + +@pytest.mark.asyncio +async def test_loader_sql_async() -> None: + """Async test returning partial running sql_async for the sql file path.""" + sql_path = Path("test_sql/eidastats.sql") + loader_func = build_loader_async(sql_path) + assert loader_func.args == (sql_path,) # type: ignore + assert loader_func.func == sql_async # type: ignore + + +@pytest.mark.asyncio +async def test_sql_async_raises_without_aiofiles() -> None: + """sql_async raises ImportError with a helpful message when aiofiles is not installed.""" + with patch("pytest_postgresql.loader.aiofiles", None): + with pytest.raises(ImportError, match="aiofiles"): + await sql_async(Path("dummy.sql"), host="h", port=5432, user="u", dbname="d") diff --git a/tests/test_plugin_asyncio.py b/tests/test_plugin_asyncio.py new file mode 100644 index 00000000..290a499f --- /dev/null +++ b/tests/test_plugin_asyncio.py @@ -0,0 +1,57 @@ +"""Tests for Windows asyncio loop configuration in the plugin.""" + +import asyncio +import sys +from unittest.mock import MagicMock, patch + +import pytest + +import pytest_postgresql.plugin as plugin_module +from pytest_postgresql.plugin import _windows_selector_event_loop, pytest_configure + + +@pytest.mark.skipif(sys.version_info < (3, 14), reason="Deprecation applies from Python 3.14") +def test_pytest_configure_skips_deprecated_policy_on_python_314() -> None: + """pytest_configure must not call deprecated asyncio policy APIs on Python 3.14+.""" + config = MagicMock() + config.pluginmanager.has_plugin.return_value = True + + with ( + patch("pytest_postgresql.plugin.platform.system", return_value="Windows"), + patch.object(asyncio, "set_event_loop_policy") as set_policy, + ): + pytest_configure(config) + + set_policy.assert_not_called() + + +@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific loop factory") +def test_windows_selector_loop_factory() -> None: + """Windows selector loop factory returns a SelectorEventLoop instance.""" + loop = _windows_selector_event_loop() + try: + assert isinstance(loop, asyncio.SelectorEventLoop) + finally: + loop.close() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Windows registers loop factory hook at import") +def test_loop_factory_hook_not_registered_on_non_windows() -> None: + """Non-Windows platforms must not register pytest_asyncio_loop_factories.""" + assert not hasattr(plugin_module, "pytest_asyncio_loop_factories") + + +@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific loop factory hook") +def test_pytest_asyncio_loop_factories_on_windows() -> None: + """Windows configures a single selector loop factory for pytest-asyncio.""" + config = MagicMock() + item = MagicMock() + factories = plugin_module.pytest_asyncio_loop_factories(config, item) + + assert factories is not None + assert set(factories) == {"selector"} + loop = factories["selector"]() + try: + assert isinstance(loop, asyncio.SelectorEventLoop) + finally: + loop.close() diff --git a/tests/test_postgres_options_plugin.py b/tests/test_postgres_options_plugin.py index 754dc3b6..297eb64f 100644 --- a/tests/test_postgres_options_plugin.py +++ b/tests/test_postgres_options_plugin.py @@ -144,3 +144,67 @@ def test_postgres_drop_test_database( pass assert hasattr(excinfo.value, "__cause__") assert f'FATAL: database "{template_janitor.dbname}" does not exist' in str(excinfo.value.__cause__) + + +def test_postgres_drop_test_database_async( + postgresql_proc_to_override: PostgreSQLExecutor, + pointed_pytester: Pytester, +) -> None: + """Check that async client fixture drops the database when --postgresql-drop-test-database is set. + + Mirrors ``test_postgres_drop_test_database`` but runs an async subprocess test that uses + ``postgresql_async`` against the same live PostgreSQL process. + """ + dbname = xdistify_dbname("override") + template_dbname = dbname + "_tmpl" + template_janitor = DatabaseJanitor( + user=postgresql_proc_to_override.user, + host=postgresql_proc_to_override.host, + port=postgresql_proc_to_override.port, + dbname=template_dbname, + as_template=True, + version=postgresql_proc_to_override.version, + password=postgresql_proc_to_override.password, + connection_timeout=5, + ) + template_janitor.init() + template_janitor.load(load_database) + assert template_janitor.dbname + janitor = DatabaseJanitor( + user=postgresql_proc_to_override.user, + host=postgresql_proc_to_override.host, + port=postgresql_proc_to_override.port, + dbname=dbname, + template_dbname=template_janitor.dbname, + version=postgresql_proc_to_override.version, + password=postgresql_proc_to_override.password, + connection_timeout=5, + ) + janitor.init() + assert janitor.dbname + with janitor.cursor(janitor.dbname) as cur: + cur.execute("SELECT * FROM stories") + res = cur.fetchall() + assert len(res) == 4 + + pointed_pytester.copy_example("test_drop_test_database_async.py") + test_sql_path = pointed_pytester.copy_example("test.sql") + ret = pointed_pytester.runpytest( + f"--postgresql-load={test_sql_path}", + f"--postgresql-port={postgresql_proc_to_override.port}", + "--postgresql-dbname=override", + "--postgresql-drop-test-database", + "test_drop_test_database_async.py", + ) + ret.assert_outcomes(passed=1) + + with pytest.raises(TimeoutError) as excinfo: + with janitor.cursor(janitor.dbname): + pass + assert hasattr(excinfo.value, "__cause__") + assert f'FATAL: database "{janitor.dbname}" does not exist' in str(excinfo.value.__cause__) + with pytest.raises(TimeoutError) as excinfo: + with template_janitor.cursor(template_janitor.dbname): + pass + assert hasattr(excinfo.value, "__cause__") + assert f'FATAL: database "{template_janitor.dbname}" does not exist' in str(excinfo.value.__cause__) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 1b86beaf..1461694d 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,11 +1,14 @@ """All tests for pytest-postgresql.""" +import decimal + import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection from psycopg.pq import ConnStatus from pytest_postgresql.executor import PostgreSQLExecutor -from pytest_postgresql.retry import retry +from pytest_postgresql.retry import retry, retry_async +from tests.conftest import POSTGRESQL_VERSION MAKE_Q = "CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);" SELECT_Q = "SELECT * FROM test_load;" @@ -66,3 +69,60 @@ def check_if_one_connection() -> None: assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" retry(check_if_one_connection, timeout=120, possible_exception=AssertionError) + + +@pytest.mark.asyncio +async def test_main_postgres_async(postgresql_async: AsyncConnection) -> None: + """Async check main postgresql fixture.""" + async with postgresql_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql_async.commit() + + +@pytest.mark.asyncio +async def test_two_postgreses_async(postgresql_async: AsyncConnection, postgresql2_async: AsyncConnection) -> None: + """Async check two postgresql fixtures on one test.""" + async with postgresql_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql_async.commit() + + async with postgresql2_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql2_async.commit() + + +@pytest.mark.asyncio +async def test_postgres_load_two_files_async(postgresql_load_1_async: AsyncConnection) -> None: + """Async check postgresql fixture can load two files.""" + async with postgresql_load_1_async.cursor() as cur: + await cur.execute(SELECT_Q) + results = await cur.fetchall() + assert len(results) == 2 + + +@pytest.mark.asyncio +async def test_rand_postgres_port_async(postgresql2_async: AsyncConnection) -> None: + """Async check if postgres fixture can be started on random port.""" + assert postgresql2_async.info.status == ConnStatus.OK + + +@pytest.mark.skipif( + decimal.Decimal(POSTGRESQL_VERSION) < 10, + reason="Test query not supported in those postgresql versions, and soon will not be supported.", +) +@pytest.mark.xdist_group(name="terminate_connection") +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(2)) +async def test_postgres_terminate_connection_async(postgresql2_async: AsyncConnection, _: int) -> None: + """Async test that connections are terminated between tests. + + And check that only one exists at a time. + """ + async with postgresql2_async.cursor() as cur: + + async def check_if_one_connection() -> None: + await cur.execute("SELECT * FROM pg_stat_activity WHERE backend_type = 'client backend';") + existing_connections = await cur.fetchall() + assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" + + await retry_async(check_if_one_connection, timeout=120, possible_exception=AssertionError) diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 00000000..613f4eba --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,76 @@ +"""Unit tests for retry and retry_async.""" + +import datetime +from unittest.mock import AsyncMock, patch + +import pytest + +from pytest_postgresql.retry import retry_async + + +@pytest.mark.asyncio +async def test_retry_async_immediate_success() -> None: + """Test that retry_async returns immediately when function succeeds on first call.""" + + async def ok() -> int: + return 42 + + assert await retry_async(ok, timeout=5) == 42 + + +@pytest.mark.asyncio +async def test_retry_async_succeeds_after_failures() -> None: + """Test that retry_async retries on the expected exception and returns on success.""" + attempts = 0 + + async def flaky() -> str: + nonlocal attempts + attempts += 1 + if attempts < 3: + raise ConnectionError("transient") + return "ok" + + sleep_mock = AsyncMock() + with patch("pytest_postgresql.retry.asyncio.sleep", sleep_mock): + result = await retry_async(flaky, timeout=10, possible_exception=ConnectionError) + + assert result == "ok" + assert attempts == 3 + assert sleep_mock.call_count == 2 + + +@pytest.mark.asyncio +async def test_retry_async_timeout() -> None: + """Test that retry_async raises TimeoutError after the timeout elapses.""" + always_fail_mock = AsyncMock(side_effect=ValueError("boom")) + sleep_mock = AsyncMock() + base = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc) + call_count = 0 + + def advancing_clock() -> datetime.datetime: + nonlocal call_count + call_count += 1 + # First call captures starting time; all subsequent calls report past the timeout. + return base if call_count == 1 else base + datetime.timedelta(seconds=10) + + with ( + patch("pytest_postgresql.retry.asyncio.sleep", sleep_mock), + patch("pytest_postgresql.retry.get_current_datetime", advancing_clock), + ): + with pytest.raises(TimeoutError, match="Failed after"): + await retry_async(always_fail_mock, timeout=1, possible_exception=ValueError) + + sleep_mock.assert_not_awaited() + assert always_fail_mock.await_count == 1 + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_retry_async_unmatched_exception_propagates() -> None: + """Test that an exception not matching possible_exception propagates immediately.""" + + async def wrong_exc() -> None: + raise TypeError("unexpected") + + with pytest.raises(TypeError, match="unexpected"): + await retry_async(wrong_exc, timeout=5, possible_exception=ValueError) diff --git a/tests/test_template_database.py b/tests/test_template_database.py index 64631779..7c4fa283 100644 --- a/tests/test_template_database.py +++ b/tests/test_template_database.py @@ -1,9 +1,9 @@ """Template database tests.""" import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection -from pytest_postgresql.factories import postgresql, postgresql_proc +from pytest_postgresql.factories import postgresql, postgresql_async, postgresql_proc from tests.loader import load_database postgresql_proc_with_template = postgresql_proc( @@ -17,6 +17,11 @@ dbname="stories_templated", ) +async_postgresql_template = postgresql_async( + "postgresql_proc_with_template", + dbname="stories_templated", +) + @pytest.mark.xdist_group(name="template_database") @pytest.mark.parametrize("_", range(5)) @@ -30,3 +35,18 @@ def test_template_database(postgresql_template: Connection, _: int) -> None: cur.execute("SELECT * FROM stories") res = cur.fetchall() assert len(res) == 0 + + +@pytest.mark.xdist_group(name="template_database") +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(5)) +async def test_template_database_async(async_postgresql_template: AsyncConnection, _: int) -> None: + """Async check that the database structure gets recreated out of a template.""" + async with async_postgresql_template.cursor() as cur: + await cur.execute("SELECT * FROM stories") + res = await cur.fetchall() + assert len(res) == 4 + await cur.execute("TRUNCATE stories") + await cur.execute("SELECT * FROM stories") + res = await cur.fetchall() + assert len(res) == 0