diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 30156f63..118e2fca 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,8 @@ Changes is available (#427) * FIX: Bytecodes of profiled functions now always labeled to prevent confusion with non-profiled "twins" (#425) +* FEAT: Experimental support for profiling child processes with + ``kernprof --prof-child-procs`` (#431) 5.0.2 diff --git a/_line_profiler_hooks.py b/_line_profiler_hooks.py new file mode 100644 index 00000000..5b5c888a --- /dev/null +++ b/_line_profiler_hooks.py @@ -0,0 +1,72 @@ +""" +Additional hooks installed by :py:mod:`line_profiler`. + +Notes: + - This file and its content should be considered an implementation + detail of :py:mod:`line_profiler`; currently we just use this to + set up shop in a child Python process, and extend profiling to + therein. + + - This current implementation writes temporary .pth files to the + site-packages directory, which are executed for all Python + processes referring to the same :path:`lib/`. However, only + processes originating from a parent which set the requisite + environment variables will proceed to the profiling code. + + - Said .pth files always import this module; hence, this file is kept + intentionally lean and separate from the main + :py:mod:`line_profiler` package to reduce overhead; e.g. + imports in this file are deferred to being as late as possible. + + - Inspired by similar code in :py:mod:`coverage.control` and + :py:mod:`pytest_autoprofile.startup_hook`. +""" +import os + + +__all__ = ('load_pth_hook',) + +INHERITED_PID_ENV_VARNAME = ( + 'LINE_PROFILER_PROFILE_CHILD_PROCESSES_CACHE_PID' +) + + +def load_pth_hook(ppid: int) -> None: + """ + Function imported and called by the written .pth file; to reduce + overhead, we immediately return if ``ppid`` doesn't match + :env:`LINE_PROFILER_PROFILE_CHILD_PROCESSES_CACHE_PID`. 
+ """ + try: + env_ppid = int(os.environ[INHERITED_PID_ENV_VARNAME]) + except (KeyError, ValueError): + return + if env_ppid != ppid: + return + + # If we're here, we're most probably in a descendent process of a + # profiled Python process, so we can be more liberal with the + # imports without worrying about overhead + import warnings + from line_profiler._diagnostics import DEBUG, log + from line_profiler._child_process_profiling.cache import LineProfilingCache + + # Note: .pth files may be double-loaded in a virtual environment + # (see https://stackoverflow.com/questions/58807569), so work around + # that; + # also see similar check in `coverage.control.process_startup()` + if getattr(load_pth_hook, 'called', False): + return + try: + cache = LineProfilingCache.load() + cache._setup_in_child_process(True, 'pth') + except Exception as e: # nocover + if DEBUG: + msg = f'{type(e)}: {e}' + # Write log before issuing warning, in case the warning is + # promoted to an exception + log.warning(msg) + warnings.warn(msg) + load_pth_hook.called = True # type: ignore + else: + cache.patch(load_pth_hook, 'called', True) diff --git a/kernprof.py b/kernprof.py index 8a7c4d6a..e9612fad 100755 --- a/kernprof.py +++ b/kernprof.py @@ -79,6 +79,7 @@ def main(): [-s SETUP] [-p {path/to/script | object.dotted.path}[,...]] [--preimports [Y[es] | N[o] | T[rue] | F[alse] | on | off | 1 | 0]] [--prof-imports [Y[es] | N[o] | T[rue] | F[alse] | on | off | 1 | 0]] + [--prof-child-procs [Y[es] | N[o] | T[rue] | F[alse] | on | off | 1 | 0]] [-o OUTFILE] [-v] [-q] [--rich [Y[es] | N[o] | T[rue] | F[alse] | on | off | 1 | 0]] [-u UNIT] @@ -137,6 +138,10 @@ def main(): If the script/module profiled is in `--prof-mod`, autoprofile all its imports. Only works with line profiling (`-l`/`--line- by-line`). (Default: False) + --prof-child-procs [Y[es] | N[o] | T[rue] | F[alse] | on | off | 1 | 0] + Extend profiling into child Python processes. 
Only works with + line profiling (`-l`/`--line-by-line`). (EXPERIMENTAL; + default: False) output options: -o, --outfile OUTFILE @@ -187,7 +192,6 @@ def main(): """ # noqa: E501 import atexit -import builtins import functools import os import sys @@ -198,9 +202,7 @@ def main(): import shutil import tempfile import time -import warnings -from argparse import ArgumentParser -from io import StringIO +from argparse import ArgumentParser, SUPPRESS from operator import methodcaller from runpy import run_module from pathlib import Path @@ -228,12 +230,20 @@ def main(): positive_float, short_string_path, ) +from line_profiler.line_profiler_utils import ( + make_tempfile as _touch_tempfile, # Compatibility +) from line_profiler.profiler_mixin import ByCountProfilerMixin +from line_profiler._child_process_profiling.cache import LineProfilingCache from line_profiler._logger import Logger from line_profiler import _diagnostics as diagnostics DIAGNOSITICS_VERBOSITY = 2 +CLEANUP_PRIORITIES = { # More negative number -> more delayed + 'rm_cache_dir': -1024, + 'gather_logs': -1, +} def execfile(filename, globals=None, locals=None): @@ -330,6 +340,7 @@ def resolve_module_path(mod_name): # type: (str) -> str | None fname = mod_spec.origin # type: str | None if fname and os.path.exists(fname): return fname + return None get_module_path = modname_to_modpath if static else resolve_module_path @@ -681,6 +692,14 @@ def _add_core_parser_arguments(parser): 'Only works with line profiling (`-l`/`--line-by-line`). ' f'(Default: {default.conf_dict["prof_imports"]})', ) + add_argument( + prof_opts, + '--prof-child-procs', + action='store_true', + help='Extend profiling into child Python processes. ' + 'Only works with line profiling (`-l`/`--line-by-line`). 
' + f'(EXPERIMENTAL; default: {default.conf_dict["prof_child_procs"]})', + ) out_opts = parser.add_argument_group('output options') if default.conf_dict['outfile']: def_outfile = repr(default.conf_dict['outfile']) @@ -771,6 +790,8 @@ def _add_core_parser_arguments(parser): 'Minimum value (and the value implied if the bare option ' f'is given) is 1 s. (Default: {def_out_int})', ) + # Hidden option for dumping the debug logs to a designated location + add_argument(out_opts, '--debug-log', help=SUPPRESS) def _build_parsers(args=None): @@ -803,8 +824,8 @@ def _build_parsers(args=None): # We've already consumed the `-m `, so we need a dummy # parser for generating the help text; # but the real parser should not consume the `options.script` - # positional arg, and it it got the `--help` option, it should - # hand off the the dummy parser + # positional arg, and if it got the `--help` option, it should + # hand off to the dummy parser real_parser = ArgumentParser(add_help=False, **parser_kwargs) real_parser.add_argument('-h', '--help', action='store_true') help_parser = ArgumentParser(**parser_kwargs) @@ -994,21 +1015,6 @@ def main(args=None, *, exit_on_error=True): cleanup() -def _touch_tempfile(*args, **kwargs): - """ - Wrapper around :py:func:`tempfile.mkstemp()` which drops and closes - the integer handle (which we don't need and may cause issues on some - platforms). 
- """ - handle, path = tempfile.mkstemp(*args, **kwargs) - try: - os.close(handle) - except Exception: - os.remove(path) - raise - return path - - def _write_tempfile(source, content, options): """ Called by :py:func:`main()` to handle :command:`kernprof -c` and @@ -1043,104 +1049,33 @@ def _write_tempfile(source, content, options): ) -def _gather_preimport_targets(options, exclude): - """ - Used in _write_preimports - """ - from line_profiler.autoprofile.util_static import modpath_to_modname - from line_profiler.autoprofile.eager_preimports import is_dotted_path - - filtered_targets = [] - recurse_targets = [] - invalid_targets = [] - for target in options.prof_mod: - if is_dotted_path(target): - modname = target - else: - # Paths already normalized by - # `_normalize_profiling_targets()` - if not os.path.exists(target): - invalid_targets.append(target) - continue - if any(os.path.samefile(target, excluded) for excluded in exclude): - # Ignore the script to be run in eager importing - # (`line_profiler.autoprofile.autoprofile.run()` will - # handle it) - continue - modname = modpath_to_modname(target, hide_init=False) - if modname is None: # Not import-able - invalid_targets.append(target) - continue - if modname.endswith('.__init__'): - modname = modname.rpartition('.')[0] - filtered_targets.append(modname) - else: - recurse_targets.append(modname) - if invalid_targets: - invalid_targets = sorted(set(invalid_targets)) - msg = ( - '{} profile-on-import target{} cannot be converted to ' - 'dotted-path form: {!r}'.format( - len(invalid_targets), - '' if len(invalid_targets) == 1 else 's', - invalid_targets, - ) - ) - warnings.warn(msg) - diagnostics.log.warning(msg) - - return filtered_targets, recurse_targets - - -def _write_preimports(prof, options, exclude): +def _write_preimports(prof, options, exclude, keep=False): """ Called by :py:func:`main()` to handle eager pre-imports; not to be invoked on its own. 
""" - from line_profiler.autoprofile.eager_preimports import ( - write_eager_import_module, - ) - from line_profiler.autoprofile.autoprofile import ( - _extend_line_profiler_for_profiling_imports as upgrade_profiler, - ) + from line_profiler.curated_profiling import ClassifiedPreimportTargets - filtered_targets, recurse_targets = _gather_preimport_targets( - options, exclude - ) - if not (filtered_targets or recurse_targets): - return # We could've done everything in-memory with `io.StringIO` and `exec()`, # but that results in indecipherable tracebacks should anything goes wrong; # so we write to a tempfile and `execfile()` it - upgrade_profiler(prof) temp_mod_path = _touch_tempfile( dir=options.tmpdir, prefix='kernprof-eager-preimports-', suffix='.py' ) - write_module_kwargs = { - 'dotted_paths': filtered_targets, - 'recurse': recurse_targets, - 'static': options.static, - } - temp_file = open(temp_mod_path, mode='w') - if options.debug: - with StringIO() as sio: - write_eager_import_module(stream=sio, **write_module_kwargs) - code = sio.getvalue() - with temp_file as fobj: - print(code, file=fobj) - diagnostics.log.debug( - 'Wrote temporary module for pre-imports to ' - f'{short_string_path(temp_mod_path)!r}' + with open(temp_mod_path, mode='w') as fobj: + preimports = ClassifiedPreimportTargets.from_targets( + options.prof_mod, exclude, ) - else: - with temp_file as fobj: - write_eager_import_module(stream=fobj, **write_module_kwargs) - if not options.dryrun: + preimports.write_preimport_module( + fobj, debug=options.debug, static=options.static, + ) + if preimports and not options.dryrun: ns = {} # Use a fresh namespace execfile(temp_mod_path, ns, ns) # Delete the tempfile ASAP if its execution succeeded - if not diagnostics.KEEP_TEMPDIRS: + if not (keep or diagnostics.KEEP_TEMPDIRS): _remove(temp_mod_path) + return temp_mod_path def _remove(path, *, recursive=False, missing_ok=False): @@ -1154,9 +1089,20 @@ def _remove(path, *, recursive=False, 
missing_ok=False): path.unlink(missing_ok=missing_ok) -def _dump_filtered_stats(tmpdir, prof, filename): +def _dump_filtered_stats(tmpdir, prof, filename, extra_line_stats=None): import os - import pickle + + if isinstance(prof, ContextualProfile): + # - Not using `line_profiler` + # -> doesn't matter if the source lines can't be retrieved + # -> no need to filter anything + prof.dump_stats(filename) + return + + # Remember to incorporate extra stats where available + line_stats = prof.get_stats() + if extra_line_stats is not None: + line_stats += extra_line_stats # Build list of known temp file paths tempfile_paths = [ @@ -1164,31 +1110,28 @@ def _dump_filtered_stats(tmpdir, prof, filename): for dirpath, _, fnames in os.walk(tmpdir) for fname in fnames ] - - if not tempfile_paths or isinstance(prof, ContextualProfile): + if not tempfile_paths: # - No tempfiles written -> no function lives in tempfiles # -> no need to filter anything - # - Not using `line_profiler` - # -> doesn't matter if the source lines can't be retrieved - # -> no need to filter anything - prof.dump_stats(filename) + line_stats.to_file(filename) return + _dump_filtered_line_stats(line_stats, tempfile_paths, filename) + + +def _dump_filtered_line_stats(stats, exclude, filename): # Filter the filenames to remove data from tempfiles, which will # have been deleted by the time the results are viewed in a # separate process - stats = prof.get_stats() timings = stats.timings for key in set(timings): fname = key[0] try: - if any(os.path.samefile(fname, tmp) for tmp in tempfile_paths): + if any(os.path.samefile(fname, tmp) for tmp in exclude): del timings[key] except OSError: del timings[key] - - with open(filename, 'wb') as f: - pickle.dump(stats, f, protocol=pickle.HIGHEST_PROTOCOL) + stats.to_file(filename) def _format_call_message(func, *args, **kwargs): @@ -1231,13 +1174,87 @@ def _call_with_diagnostics(options, func, *args, **kwargs): return func(*args, **kwargs) -def _pre_profile(options, 
module, exit_on_error): +class _manage_profiler: """ Prepare the environment to execute profiling with requested options. Note: modifies ``options`` with extra attributes. """ + cache: LineProfilingCache + + def __init__(self, options, module, exit_on_error): + self.options = options + self.module = module + self.exit_on_error = exit_on_error + + def __enter__(self): + from line_profiler.curated_profiling import CuratedProfilerContext + + self.prof = _prepare_profiler( + self.options, self.module, self.exit_on_error, + ) + self._ctx = CuratedProfilerContext( + self.prof, insert_builtin=self.options.builtin, + ) + self._ctx.install() + # Keep the generated pre-imports file to be reused in child + # processes + script_file, preimports_file = _prepare_exec_script( + self.options, self.module, self.prof, + exit_on_error=self.exit_on_error, + keep_preimports_file=self.set_up_child_profiling, + ) + if self.set_up_child_profiling: + self.cache = cache = _prepare_child_profiling_cache( + self.options, self._ctx, self.prof, + preimports_file, script_file, + ) + # Add deferred callbacks for gathering debug logfiles + # (should run right before `.cache.cache_dir` is wiped): + # - Write the debug logs to the `._diagnostics` logger + if cache.debug: + self._ctx.add_cleanup_with_priority( + cache._dump_debug_logs, CLEANUP_PRIORITIES['gather_logs'], + ) + # - Write the debug logs to a specific file + if self.options.debug_log: + self._ctx.add_cleanup_with_priority( + self._gather_debug_log, + CLEANUP_PRIORITIES['gather_logs'], + self.options.debug_log, + ) + return self.prof, script_file + + def __exit__(self, *_, **__): + try: + extra_stats = None + if self.set_up_child_profiling: + # Cleaning up here ensures the `multiprocessing` + # fork-server process is rebooted, thus any profiling + # data on it will be properly collected + self.cache.cleanup() + extra_stats = self.cache.gather_stats() + _post_profile(self.options, self.prof, extra_stats) + finally: + self._ctx.uninstall() 
+ + def _gather_debug_log(self, logfile): + with open(logfile, mode='w') as fobj: + for entry in self.cache._gather_debug_log_entries(): + print(entry.to_text(), file=fobj) + + @property + def set_up_child_profiling(self): + return bool( + self.options.line_by_line and self.options.prof_child_procs + ) + + +def _prepare_profiler(options, module, exit_on_error): + """ + Set up the appropriate profiler instance. + """ if not options.outfile: extension = 'lprof' if options.line_by_line else 'prof' options.outfile = f'{os.path.basename(options.script)}.{extension}' @@ -1267,24 +1284,26 @@ def _pre_profile(options, module, exit_on_error): execfile(setup_file, ns, ns) if options.line_by_line: - prof = line_profiler.LineProfiler() options.builtin = True + return line_profiler.LineProfiler() elif Profile.__module__ == 'profile': raise RuntimeError( 'non-line-by-line profiling depends on cProfile, ' 'which is not available on this platform' ) else: - prof = ContextualProfile() - - # Overwrite the explicit decorator - global_profiler = line_profiler.profile - install_profiler = global_profiler._kernprof_overwrite - install_profiler(prof) + return ContextualProfile() - if options.builtin: - builtins.__dict__['profile'] = prof +def _prepare_exec_script( + options, module, prof, + *, + exit_on_error=False, + keep_preimports_file=False, +): + """ + Set up the script to be executed among other things. 
+ """ if module: script_file = find_module_script( options.script, static=options.static, exit_on_error=exit_on_error @@ -1304,6 +1323,8 @@ def _pre_profile(options, module, exit_on_error): options.prof_mod = _normalize_profiling_targets(options.prof_mod) if not options.prof_mod: options.preimports = False + + preimports_file = None if options.line_by_line and options.preimports: # We assume most items in `.prof_mod` to be import-able without # significant side effects, but the same cannot be said if it @@ -1311,10 +1332,10 @@ def _pre_profile(options, module, exit_on_error): # even have a `if __name__ == '__main__': ...` guard. So don't # eager-import it. exclude = set() if module else {script_file} - _write_preimports(prof, options, exclude) + preimports_file = _write_preimports( + prof, options, exclude, keep=keep_preimports_file, + ) - options.global_profiler = global_profiler - options.install_profiler = install_profiler if options.output_interval and not options.dryrun: options.rt = RepeatedTimer( max(options.output_interval, 1), prof.dump_stats, options.outfile @@ -1322,7 +1343,56 @@ def _pre_profile(options, module, exit_on_error): else: options.rt = None options.original_stdout = sys.stdout - return script_file, prof + return script_file, preimports_file + + +def _prepare_child_profiling_cache( + options, ctx, prof, preimports_file, script_file +): + """ + Handle the (line-)profiling of spawned/forked child Python + processes. 
+ """ + # Create the cache dir and cache file here; the cache instance will + # be responsible for managing their lifetimes, while derivative + # instances in child processes will merely inherit and use them + cache = LineProfilingCache( + cache_dir=tempfile.mkdtemp(), + config=options.config, + profiling_targets=options.prof_mod, + rewrite_module=script_file, + profile_imports=options.prof_imports, + preimports_module=preimports_file, + insert_builtin=options.builtin, + debug=bool(options.debug or options.debug_log), + ) + clean_up = functools.partial(cache.add_cleanup, _remove, missing_ok=True) + if not diagnostics.KEEP_TEMPDIRS: + # Defer the scrubbing of the cache dir and let the context + # handle it, ideally speaking the cache dir should survive the + # cache object + ctx.add_cleanup_with_priority( + _remove, CLEANUP_PRIORITIES['rm_cache_dir'], cache.cache_dir, + recursive=True, + ) + clean_up(cache.filename) + + # This file is handed to us at the end of + # `_manage_profiler.__enter__()`; + # normally it is deleted before `.__enter__()` returns, but when + # child-process profiling is used, it is to persist for the lifetime + # of the cache (so that child processes can do the same preimports) + if not (preimports_file is None or diagnostics.KEEP_TEMPDIRS): + clean_up(preimports_file) + + # Handle various setup tasks (see docs thereof) + cache._setup_in_main_process() + cache.profiler = prof + + # Have the context clean up the cache as a failsafe + ctx.add_cleanup(cache.cleanup) + + return cache def _main_profile(options, module=False, exit_on_error=True): @@ -1330,9 +1400,10 @@ def _main_profile(options, module=False, exit_on_error=True): Called by :py:func:`main()` for the actual execution and profiling of code after initial parsing of options; not to be invoked on its own. 
""" - script_file, prof = _pre_profile(options, module, exit_on_error) call = functools.partial(_call_with_diagnostics, options) - try: + with _manage_profiler( + options, module, exit_on_error, + ) as (prof, script_file): rmod = functools.partial( run_module, run_name='__main__', alter_sys=True ) @@ -1383,18 +1454,18 @@ def _main_profile(options, module=False, exit_on_error=True): module_ns, module_ns, ) - finally: - _post_profile(options, prof) -def _post_profile(options, prof): +def _post_profile(options, prof, extra_line_stats=None): """ Cleanup setup after executing a main profile """ if options.rt is not None: options.rt.stop() if not options.dryrun: - _dump_filtered_stats(options.tmpdir, prof, options.outfile) + _dump_filtered_stats( + options.tmpdir, prof, options.outfile, extra_line_stats, + ) short_outfile = short_string_path(options.outfile) diagnostics.log.info( ( @@ -1405,9 +1476,15 @@ def _post_profile(options, prof): + f'to {short_outfile!r}' ) if options.verbose > 0 and not options.dryrun: - kwargs = {} - if not isinstance(prof, ContextualProfile): - kwargs.update( + if isinstance(prof, ContextualProfile): + _call_with_diagnostics(options, prof.print_stats) + else: + stats = prof.get_stats() + if extra_line_stats is not None: + stats += extra_line_stats + _call_with_diagnostics( + options, + stats.print, output_unit=options.unit, stripzeros=options.skip_zero, summarize=options.summarize, @@ -1415,7 +1492,6 @@ def _post_profile(options, prof): stream=options.original_stdout, config=options.config, ) - _call_with_diagnostics(options, prof.print_stats, **kwargs) else: py_exe = _python_command() if isinstance(prof, ContextualProfile): @@ -1427,12 +1503,6 @@ def _post_profile(options, prof): f'{quote(py_exe)} -m {show_mod} ' f'{quote(short_outfile)}' ) - # Fully disable the profiler - for _ in range(prof.enable_count): - prof.disable_by_count() - # Restore the state of the global `@line_profiler.profile` - if options.global_profiler: - 
options.install_profiler(None) if __name__ == '__main__': diff --git a/line_profiler/_child_process_profiling/__init__.py b/line_profiler/_child_process_profiling/__init__.py new file mode 100644 index 00000000..58e0806f --- /dev/null +++ b/line_profiler/_child_process_profiling/__init__.py @@ -0,0 +1,10 @@ +""" +Tooling for profiling child Python processes and gathering their +profiling results. + +Notes: + - THIS IS AN EXPERIMENTAL FEATURE. + + - All contents of this subpackage is to be considered implementation + details. +""" diff --git a/line_profiler/_child_process_profiling/_cache_logging.py b/line_profiler/_child_process_profiling/_cache_logging.py new file mode 100644 index 00000000..992f68d0 --- /dev/null +++ b/line_profiler/_child_process_profiling/_cache_logging.py @@ -0,0 +1,338 @@ +""" +Logging utilities. +""" +from __future__ import annotations + +import os +import re +from collections.abc import Generator +from datetime import datetime +from itertools import pairwise +from pathlib import Path +from string import Formatter as StringParser +from textwrap import dedent +from typing import TYPE_CHECKING, NamedTuple, TextIO, overload +from typing_extensions import Self + +from .. import _diagnostics as diagnostics +from ..line_profiler_utils import block_indent + + +__all__ = ('CacheLoggingEntry',) + + +FILENAME_PATTERN = 'debug_log_{main_pid}_{current_pid}.log' +TIMESTAMP_PATTERN = '[cache-debug-log {timestamp} DEBUG]' +HEADER_PATTERN = 'PID {current_pid} ({main_pid}): Cache {obj_id:#x}' + +TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S' +TIMESTAMP_MICROSECOND_SEP = ',' +TIMESTAMP_MICROSECOND_PLACES = 3 +TIMESTAMP_SPACING = ' ' + +HEADER_SEP = ': ' +HEADER_MAIN_INDICATOR = 'main process' + + +def get_logger_header(current_pid: int, main_pid: int, obj_id: int) -> str: + """ + Returns: + msg_header (str): + Message header, to be prefixed to messages sent to + :py:data:`line_profiler._diagnostics.log`. 
+ """ + return HEADER_PATTERN.format( + current_pid=current_pid, + main_pid=( + HEADER_MAIN_INDICATOR if main_pid == current_pid else main_pid + ), + obj_id=obj_id, + ) + + +def format_timestamp(ts: datetime) -> str: + """ + Replicate the :py:mod:`logging`'s default formatting for timestamps. + + Example: + >>> ts = datetime(2000, 1, 23, 4, 5, 6, 789000) + >>> as_str = format_timestamp(ts) + >>> print(as_str) + 2000-01-23 04:05:06,789 + >>> assert parse_timestamp(as_str) == ts + """ + return '{}{}{:0{}d}'.format( + ts.strftime(TIMESTAMP_FORMAT), + TIMESTAMP_MICROSECOND_SEP, + int(ts.microsecond / 1000), + TIMESTAMP_MICROSECOND_PLACES, + ) + + +def parse_timestamp(ts: str) -> datetime: + """ + Turn a formatted string timestamp back to a + :py:class:`datetime.datetime` object. + """ + assert TIMESTAMP_MICROSECOND_SEP in ts + base, _, fractional = ts.rpartition(TIMESTAMP_MICROSECOND_SEP) + # The microsecond field %f must be 6 digits long + if len(fractional) < 6: + fractional = f'{fractional:<06}' + else: + fractional = fractional[:6] + parse_format = f'{TIMESTAMP_FORMAT}{TIMESTAMP_MICROSECOND_SEP}%f' + ts = f'{base}{TIMESTAMP_MICROSECOND_SEP}{fractional}' + return datetime.strptime(ts, parse_format) + + +def add_timestamp(msg: str, timestamp: datetime | None = None) -> str: + """ + Returns: + msg_with_timestamp (str): + (Block-indented) message with timestamp, to be written to + the :py:attr:`LineProfilingCache._debug_log`. + """ + if timestamp is None: + timestamp = datetime.now() + ts_formatted = TIMESTAMP_PATTERN.format( + timestamp=format_timestamp(timestamp), + ) + return block_indent(msg, ts_formatted + TIMESTAMP_SPACING) + + +def parse_id(uint: str) -> int: + """ + Example: + >>> n = 123456 + >>> for formatter in str, bin, oct, hex: + ... 
assert parse_id(formatter(n)) == n + """ + for prefix, base in ('0b', 2), ('0o', 8), ('0x', 16): + if uint.startswith(prefix): + return int(uint[len(prefix):], base=base) + return int(uint) + + +@overload +def fmt_to_regex(fmt: str, /, *auto_numbered_fields: str) -> str: + ... + + +@overload +def fmt_to_regex(fmt: str, /, **named_fields: str) -> str: + ... + + +def fmt_to_regex( + fmt: str, /, *auto_numbered_fields: str, **named_fields: str +) -> str: + """ + Example: + >>> import re + + Simple case: + + >>> pattern = fmt_to_regex( + ... '{func}({args})', func=r'[_\\w][_\\w\\d]+', args='.*', + ... ) + >>> print(pattern) + (?P<func>[_\\w][_\\w\\d]+)\\((?P<args>.*)\\) + >>> regex = re.compile('^' + pattern, re.MULTILINE) + >>> assert not regex.search('0(1)') + >>> match = regex.search(' \\nint(-1.5)') + >>> assert match.group('func', 'args') == ('int', '-1.5') + + Repeated fields: + + >>> palindrome_5l = re.compile(fmt_to_regex( + ... '{first}{second}{third}{second}{first}', + ... first='.', second='.', third='.', + ... )) + >>> print(palindrome_5l.pattern) + (?P<first>.)(?P<second>.)(?P<third>.)(?P=second)(?P=first) + >>> assert not palindrome_5l.match('abbbe') + >>> match = palindrome_5l.match('aBcBa') + >>> assert match.group('first', 'second', 'third') == ( + ... 'a', 'B', 'c', + ... ) + + Auto-numbered fields: + + >>> print(fmt_to_regex( + ... '[{} {}-{}-{} {}:{}:{},{} {}]', + ... # Logger name + ... '.+', + ... # Date + ... r'\\d\\d', r'\\d\\d', r'\\d\\d', + ... # Time + milliseconds + ... r'\\d\\d', r'\\d\\d', r'\\d\\d', r'\\d\\d\\d', + ... # Category + ... 'DEBUG|INFO|WARNING|ERROR|CRITICAL', + ... 
)) + \\[(.+)\\ (\\d\\d)\\-(\\d\\d)\\-(\\d\\d)\\ \ +(\\d\\d):(\\d\\d):(\\d\\d),(\\d\\d\\d)\\ \ +(DEBUG|INFO|WARNING|ERROR|CRITICAL)\\] + """ + chunks: list[str] = [] + seen_fields: set[str] = set() + for i, (prefix, field, *_) in enumerate(StringParser().parse(fmt)): + chunks.append(re.escape(prefix)) + if field is None: + break # Suffix -> we're done + if field: # Named fields + assert field.isidentifier() + if field in seen_fields: + chunks.append(f'(?P={field})') + else: + chunks.append(f'(?P<{field}>{named_fields[field]})') + seen_fields.add(field) + else: # Auto-numbered fields + chunks.append(f'({auto_numbered_fields[i]})') + return ''.join(chunks) + + +class CacheLoggingEntry(NamedTuple): + """ + Logging entry written to a log file by + :py:meth:`LineProfilingCache._debug_output`. + + Example: + >>> from datetime import datetime + >>> + >>> + >>> entry = CacheLoggingEntry( + ... datetime(1900, 1, 1, 0, 0, 0, 0), + ... 12345, + ... 12345, + ... 12345678, + ... 'This is a log message;\\nit has multiple lines', + ... ) + >>> print(entry.to_text()) + [cache-debug-log 1900-01-01 00:00:00,000 DEBUG] PID 12345 \ +(main process): Cache 0xbc614e: This is a log message; + it has \ +multiple lines + >>> another_entry = CacheLoggingEntry( + ... datetime(2000, 12, 31, 12, 34, 56, 789000), + ... 12345, + ... 54321, + ... 87654321, + ... 'FOO BAR BAZ', + ... ) + >>> print(another_entry.to_text()) + [cache-debug-log 2000-12-31 12:34:56,789 DEBUG] PID 54321 \ +(12345): Cache 0x5397fb1: FOO BAR BAZ + >>> log_text = '\\n'.join([ + ... e.to_text() for e in [entry, another_entry] + ... ]) + >>> assert CacheLoggingEntry.from_text(log_text) == [ + ... entry, another_entry, + ... 
] + """ + timestamp: datetime + main_pid: int + current_pid: int + cache_id: int + msg: str + + def to_text(self) -> str: + return add_timestamp(self._get_header() + self.msg, self.timestamp) + + def _get_header(self) -> str: + return get_logger_header( + self.current_pid, self.main_pid, self.cache_id, + ) + HEADER_SEP + + def write(self, tee: os.PathLike[str] | str | None = None) -> None: + log_msg = self._get_header() + self.msg + diagnostics.log.debug(log_msg) + if tee is None: + return + with Path(tee).open(mode='a') as fobj: + print(add_timestamp(log_msg, self.timestamp), file=fobj) + + @classmethod + def new(cls, main_pid: int, cache_id: int, msg: str) -> Self: + return cls(datetime.now(), main_pid, os.getpid(), cache_id, msg) + + @classmethod + def from_file(cls, file: os.PathLike[str] | str | TextIO) -> list[Self]: + try: + path = Path(file) # type: ignore + except TypeError: # File object + # If we're here, `file` is a file object + if TYPE_CHECKING: + assert isinstance(file, TextIO) + content = file.read() + else: + content = path.read_text() + return cls.from_text(content) + + @classmethod + def from_text(cls, text: str) -> list[Self]: + def gen_timestamps(text: str) -> Generator[re.Match, None, None]: + last_ts_match: re.Match | None = None + while True: + ts_match = timestamp_regex.search( + text, last_ts_match.end() if last_ts_match else 0, + ) + if ts_match: + yield ts_match + last_ts_match = ts_match + else: + return + + def gen_message_blocks(text: str) -> Generator[ + tuple[datetime, re.Match, str], None, None + ]: + timestamps = list(gen_timestamps(text)) + if not timestamps: + return + + # Handle all the entries up till the 2nd-to-last one + for this_match, next_match in pairwise(timestamps): + ts = parse_timestamp(this_match.group('timestamp')) + # The -1 is for stripping the trailing newline + text_block = text[this_match.start():next_match.start() - 1] + yield (ts, this_match, text_block) + # Handle the last entry + last_match = 
timestamps[-1] + yield ( + parse_timestamp(last_match.group('timestamp')), + last_match, + text[last_match.start():], + ) + + def get_entries(text: str) -> Generator[Self, None, None]: + for timestamp, ts_match, text_block in gen_message_blocks(text): + # Strip the block indent + ts_text = ts_match.group(0) + assert text_block.startswith(ts_text) + ts_width = len(ts_text) + text_block = dedent(' ' * ts_width + text_block[ts_width:]) + # Strip the header and parse the relevant info from it + header_match = header_regex.match(text_block) + assert header_match + current_pid = int(header_match.group('current_pid')) + main_pid_ = header_match.group('main_pid') + if main_pid_ == HEADER_MAIN_INDICATOR: + main_pid = current_pid + else: + main_pid = int(main_pid_) + cache_id = parse_id(header_match.group('obj_id')) + # The rest of the block is the message proper + msg = text_block[header_match.end():] + yield cls(timestamp, main_pid, current_pid, cache_id, msg) + + timestamp_pattern = fmt_to_regex( + f'{TIMESTAMP_PATTERN}{TIMESTAMP_SPACING}', timestamp='.+?', + ) + timestamp_regex = re.compile('^' + timestamp_pattern, re.MULTILINE) + header_regex = re.compile(fmt_to_regex( + HEADER_PATTERN + HEADER_SEP, + current_pid=r'\d+', + main_pid=r'\d+|' + re.escape(HEADER_MAIN_INDICATOR), + obj_id='.+?', + )) + return list(get_entries(text)) diff --git a/line_profiler/_child_process_profiling/cache.py b/line_profiler/_child_process_profiling/cache.py new file mode 100644 index 00000000..b658e8e0 --- /dev/null +++ b/line_profiler/_child_process_profiling/cache.py @@ -0,0 +1,892 @@ +""" +A cache object to be used by for propagating profiling down to child +processes. 
def _import_sibling(submodule: str) -> ModuleType:
    """
    Import a module living in the same subpackage as this one.

    Args:
        submodule (str):
            Name of the sibling module, relative to the subpackage

    Returns:
        module (ModuleType):
            The imported module
    """
    qualified_name = f'{_THIS_SUBPACKAGE}.{submodule}'
    return import_module(qualified_name)


# Factory for dataclass fields which are neither `__init__()` arguments
# nor shown in the `repr()`
_private_field = partial(dataclasses.field, repr=False, init=False)
class _DumpStatsHelper(Cleanup):
    """
    Cleanup object whose only callback dumps the stats of a profiler
    to a file.
    """
    def __init__(
        self, prof: LineProfiler, outfile: os.PathLike[str] | str,
    ) -> None:
        super().__init__()
        self._callback = partial(prof.dump_stats, outfile)
        self.add_cleanup(self._callback)

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        callback_repr = _CALLBACK_REPR_HELPER.repr(self._callback)
        return f'<{cls_name} @ {id(self):#x}: {callback_repr}>'

    def __call__(self) -> None:
        """
        Invoke the stat-dumping callback directly.
        """
        self._callback()

    def cleanup(self, *args, force: bool = False, **kwargs) -> None:
        """
        Run the cleanup of the parent class.

        Args:
            force (bool):
                If true and no value in ``._current_context`` is
                true-y, the stat-dumping callback is re-added before
                cleaning up (presumably so that it runs even if it
                has already been consumed)
            *args, **kwargs:
                Passed to the parent-class implementation
        """
        if force:
            nothing_pending = not any(self._current_context.values())
            if nothing_pending:
                self.add_cleanup(self._callback)
        super().cleanup(*args, **kwargs)
def __post_init__(self) -> None:
    # The dataclass-generated `__init__()` doesn't call the parent
    # class's initializer, so do it here
    super().__init__()


def copy(self, /, **replacements) -> Self:
    """
    Create a new instance from this one's initializer arguments.

    Args:
        **replacements (Any):
            Values overriding the corresponding initializer
            arguments; names which are not initializer arguments
            are ignored

    Return:
        inst (LineProfilingCache):
            New instance
    """
    current = self._get_init_args()
    merged: dict[str, Any] = {
        name: replacements.get(name, value)
        for name, value in current.items()
    }
    return type(self)(**merged)


@classmethod
def load(cls) -> Self:
    """
    Reconstruct the instance from the environment variables
    :env:`LINE_PROFILER_PROFILE_CHILD_PROCESSES_CACHE_PID` and
    :env:`LINE_PROFILER_PROFILE_CHILD_PROCESSES_CACHE_DIR_<PID>`.
    These should have been set from an ancestral Python process.

    Note:
        If a previously :py:meth:`~.load`-ed instance exists, it is
        returned instead of a new instance.
    """
    # `ty` needs some help here, even if we've marked the class to be
    # `@final`
    cached = cast(Self | None, cls._loaded_instance)
    if cached is not None:
        return cached

    pid = os.environ[INHERITED_PID_ENV_VARNAME]
    cache_varname = f'{INHERITED_CACHE_ENV_VARNAME_PREFIX}_{pid}'
    cache_dir = os.environ[cache_varname]
    diagnostics.log.debug(
        f'PID {os.getpid()} (from {pid}): '
        f'Loading instance from ${{{cache_varname}}} = {cache_dir}'
    )
    fresh = cls._from_path(cls._get_filename(cache_dir))
    # Remember the instance so that later calls return it
    fresh._replace_loaded_instance(force=True)
    return fresh


def dump(self) -> None:
    """
    Pickle the initializer arguments and the additional data of the
    instance into :py:attr:`~.filename`, so that they can be
    :py:meth:`~.load`-ed by child processes.

    Note:
        Cleanup callbacks are not serialized.
    """
    target = self.filename
    content = {
        'init_args': self._get_init_args(),
        'additional_data': self._additional_data,
    }
    self._debug_output(
        f'Dumping instance data to {self.filename}: {content!r}'
    )
    with open(target, mode='wb') as fobj:
        pickle.dump(content, fobj, protocol=HIGHEST_PROTOCOL)
def gather_stats(
    self,
    exclude_pids: Collection[int] | None = None,
    *,
    on_empty: Literal['error', 'warn', 'ignore'] = 'warn',
    on_defective: Literal['error', 'warn', 'ignore'] = 'warn',
) -> LineStats:
    """
    Gather the profiling output files written by child processes
    into :py:attr:`~.cache_dir`, consolidating them into a single
    :py:class:`LineStats` object.

    Args:
        exclude_pids (Collection[int] | None):
            Exclude all output from child processes with these
            PIDs;
            the default value :py:const:`None` fetches the PIDs
            registered via
            :py:meth:`._warn_possible_lack_of_stats`, and only
            excludes their EMPTY output files.
        on_empty, on_defective (Literal['error', 'warn', 'ignore']):
            Passed to :py:meth:`LineStats.from_files`.

    Returns:
        :py:class:`LineStats` instance
    """
    # NOTE: when the PIDs are looked up dynamically, there is no
    # guarantee that a PID hasn't previously been used for another
    # child process that we DID properly profile and SHOULD include,
    # so we only filter out empty files; user-provided PIDs are
    # taken at face value
    only_drop_empty = exclude_pids is None
    if only_drop_empty:
        exclude_pids = self._get_pids_possibly_lacking_stats()

    selected = set(self._get_profiling_outfiles())
    for pid in exclude_pids:
        for path in self._get_profiling_outfiles(pid):
            if only_drop_empty and path.stat().st_size:
                continue  # Non-empty file, keep it
            selected.discard(path)

    fnames = sorted(selected)
    self._debug_output(
        'Loading results from {} child profiling file(s): {!r}'
        .format(len(fnames), fnames)
    )
    if fnames:
        return LineStats.from_files(
            *fnames, on_empty=on_empty, on_defective=on_defective,
        )
    return LineStats.get_empty_instance()
def _dump_debug_logs(self) -> None:
    """
    Gather the debug logfiles written by child processes and write
    their contents to the logger
    (:py:data:`line_profiler._diagnostics.log`).

    Notes:
        - The content of each child-process log file is not
          re-parsed and is written to the logger as a single
          multi-line message.

        - To be called in the main process.
    """
    own_log = self._debug_log
    for log in sorted(self._get_debug_logfiles()):
        if log == own_log:  # Don't double dip
            continue
        # The stem ends with `_<PID of the child process>`
        child_pid = log.stem.rpartition('_')[-1]
        content = indent(log.read_text(), ' ')
        diagnostics.log.debug(
            'Cache log messages from child process {}:\n{}'
            .format(child_pid, content)
        )


def _gather_debug_log_entries(
    self, chronological: bool = False,
) -> list[CacheLoggingEntry]:
    """
    Gather and return all the entries in the debug logfiles.

    Args:
        chronological (bool):
            If true, sort the entries themselves;
            otherwise, the log files are sorted by filename, and
            the entries retain their order within each file

    Returns:
        entries (list[CacheLoggingEntry]):
            The parsed log entries
    """
    log_files: Iterable[Path] = self._get_debug_logfiles()
    if not chronological:
        # Just sort by filename (entries in each file are still
        # chronological)
        log_files = sorted(log_files)
    entries = [
        entry for log in log_files
        for entry in CacheLoggingEntry.from_file(log)
    ]
    if chronological:  # Sorting on the entries -> chronological
        entries.sort()
    return entries
def _glob(self, *args, **kwargs) -> Iterable[Path]:
    """
    Glob in :py:attr:`~.cache_dir`; the arguments are passed to
    :py:meth:`pathlib.Path.glob`.
    """
    cache_dir = Path(self.cache_dir)
    return cache_dir.glob(*args, **kwargs)


def _get_debug_logfiles(self) -> Iterable[Path]:
    """
    Find the debug logfiles of all the processes in this session.
    """
    pattern = _DEBUG_LOG_FILENAME_PATTERN.format(
        main_pid=self.main_pid, current_pid='?*',
    )
    return self._glob(pattern)


def _get_profiling_outfiles(self, pid: Any = '?*') -> Iterable[Path]:
    """
    Find the profiling output files in this session.

    Args:
        pid (Any):
            PID of the process to get the output files of; the
            default (a glob pattern) matches all processes
    """
    prefix = _PROFILING_OUTPUT_PREFIX_PATTERN.format(
        main_pid=self.main_pid,
        current_pid=pid,
        # We always format the profiler ID with `hex()`, see
        # `._setup_in_child_process()`
        prof='0x?*',
    )
    return self._glob(f'{prefix}?*.lprof')


def inject_env_vars(
    self, env: MutableMapping[str, str] | None = None,
) -> None:
    """
    Inject the :py:attr:`~.environ` variables into ``env`` via
    ``.update_mapping()``.

    Args:
        env (MutableMapping[str, str] | None):
            Dictionary in the format of :py:data:`os.environ`;
            default is to use that
    """
    if env is None:
        target: MutableMapping[str, str] = os.environ
    else:
        target = env
    self.update_mapping(
        target,
        self.environ,
        _format_debug_msg='Injecting env var ${{{1}}}: {2}'.format,
    )


def write_pth_hook(
    self, *,
    prefix: str | None = None,
    suffix: str | None = None,
    dir: os.PathLike[str] | str | None = None,
    # Get rid of the .pth file ASAP so as to be the least disruptive
    priority: float = 1,
    **kwargs
) -> Path:
    """
    Write a .pth file which imports and calls ``load_pth_hook()``
    with :py:attr:`~.main_pid`, so that profiling can be set up in
    child Python processes.

    Args:
        prefix, suffix (str | None):
            Optional filename-stem affixes of the .pth file; default
            is to use the values loaded from :py:attr:`.config`
        dir (os.PathLike[str] | str | None):
            Optional directory to create the .pth file in; default
            is to use ``sysconfig.get_path('purelib')``
        priority, **kwargs:
            Passed to :py:meth:`.make_tempfile`.

    Returns:
        fpath (Path):
            Path to the written .pth file
    """
    def get_pth_config() -> Mapping[str, Any]:
        # Note: the only keys in it should be `prefix` and `suffix`
        source = self._config_source  # Cached
        subconfig = source.get_subconfig('child_processes', 'pth_files')
        return subconfig.conf_dict

    # The hook needs the dumped cache to load from
    if not os.path.exists(self.filename):
        self.dump()
    assert os.path.exists(self.filename)

    # The string casts are failsafes in case inappropriate values
    # (e.g. numbers and booleans) are supplied
    if prefix is None:
        prefix = str(get_pth_config()['prefix'])
    if suffix is None:
        suffix = str(get_pth_config()['suffix'])
    if dir is None:
        dir = sysconfig.get_path('purelib')

    template = 'import {0.__module__}; {0.__module__}.{0.__name__}({1})'
    fpath = self.make_tempfile(
        prefix=prefix, suffix=suffix + '.pth', dir=dir, priority=priority,
        **kwargs,
    )
    try:
        fpath.write_text(template.format(load_pth_hook, self.main_pid))
    except Exception:
        # Don't leave a half-written .pth file lying around
        fpath.unlink(missing_ok=True)
        raise
    return fpath
def _debug_output(self, msg: str) -> None:
    """
    Write a debug message to the logger, and also to the
    :py:attr:`~._debug_log` if there is one.

    Args:
        msg (str):
            Message to write
    """
    try:
        entry = self._make_debug_entry(msg)
        log_file = self._debug_log
        entry.write(log_file)
    except OSError:  # Cache dir may have been rm-ed during cleanup
        pass


def _setup_in_main_process(self, wrap_os_fork: bool = True) -> None:
    """
    Set up shop in the main process so that (line-)profiling can
    extend into child processes.

    Args:
        wrap_os_fork (bool):
            Whether to wrap :py:func:`os.fork` so that profiling is
            set up in forked processes

    Side effects:

        - Instance data written to :py:attr:`~.cache_dir`
          (see :py:meth:`~.dump`)

        - Environment variables injected
          (see :py:meth:`~.inject_env_vars`)

        - A ``.pth`` file written so that child processes
          automatically run setup code
          (see :py:meth:`.write_pth_hook`)

        - :py:func:`os.fork` wrapped (if ``wrap_os_fork=True``) and
          the :py:mod:`multiprocessing` patches applied
          (see :py:meth:`~._setup_common`)

        - Instance to be returned if :py:meth:`~.load` is called
          from now on
    """
    mp_apply_kwargs = {'reboot_forkserver': True}
    self.dump()
    self.inject_env_vars()
    self.write_pth_hook()
    self._setup_common(wrap_os_fork, mp_apply_kwargs)
    self._replace_loaded_instance()
def _setup_in_child_process(
    self,
    wrap_os_fork: bool = False,
    context: str = '',
    prof: LineProfiler | None = None,
) -> bool:
    """
    Set up shop in a forked/spawned child process so that
    (line-)profiling can extend therein.

    Args:
        wrap_os_fork (bool):
            Whether to wrap :py:func:`os.fork` which handles
            profiling; already-forked child processes should set
            this to false
        context (str):
            Optional context from which the function is called, to
            be used in log messages
        prof (LineProfiler | None):
            Optional profiler instance to associate with the cache;
            if not provided, an instance is created

    Returns:
        has_set_up (bool):
            False if the instance has already been set up (i.e.
            :py:attr:`~.profiler` is already set) prior to calling
            this function, true otherwise

    Side effects:

        - :py:attr:`~.profiler` set and installed via a
          :py:class:`CuratedProfilerContext`

        - The preimports module (if any) executed

        - A profiling output file reserved in
          :py:attr:`~.cache_dir`, and a callback registered for
          dumping the profiling stats thereto

        - :py:meth:`~._setup_common` called

        - An :py:mod:`atexit` hook registered
    """
    def wrap_ctx_debug(
        ctx: CuratedProfilerContext, msg: str,
    ) -> None:
        # Route the debug output of the context through this cache,
        # tagging the messages with the ID of the context
        self._debug_output(f' Context {id(ctx):#x}: {msg}')

    if not context:
        context = '...'
    self._debug_output(f'Setting up ({context})...')
    if self.profiler is not None:  # Already set up
        self._debug_output(f'Setup aborted ({context})')
        return False

    # Create a profiler instance and manage it with
    # `CuratedProfilerContext`
    if prof is None:
        prof = LineProfiler()
    self.profiler = prof
    ctx = CuratedProfilerContext(prof, insert_builtin=self.insert_builtin)
    if self.debug:
        # `.__get__()` binds `ctx` as the first argument
        self.patch(ctx, '_debug_output', wrap_ctx_debug.__get__(ctx))
    ctx.install()
    self.add_cleanup(ctx.uninstall)
    self._debug_output(f'Set up `.profiler` at {id(prof):#x}')

    # Do the preimports at `cache.preimports_module` where
    # appropriate
    if self.preimports_module:
        self._debug_output('Loading preimports...')
        with open(self.preimports_module, mode='rb') as fobj:
            code = compile(fobj.read(), self.preimports_module, 'exec')
        exec(code, {})  # Use a fresh, empty namespace

    # - Occupy a tempfile slot in `.cache_dir`
    # - Set the profiler up to write thereto when the process
    #   terminates (with high priority)
    # (Also keep a separate reference to the callback for e.g.
    # dumping stats ASAP when a signal is caught)
    prof_outfile = self.make_tempfile(
        prefix=_PROFILING_OUTPUT_PREFIX_PATTERN.format(
            main_pid=self.main_pid,
            current_pid=os.getpid(),
            # Formatted with `hex()` so that the glob pattern in
            # `._get_profiling_outfiles()` matches
            prof=hex(id(prof)),
        ),
        suffix='.lprof',
        delete=False,
    )
    self._stats_dumper = dumper = _DumpStatsHelper(prof, prof_outfile)
    self.patch(
        # If we call `dumper.cleanup()` instead of `dumper` (e.g.
        # in `._handle_signal()`), the subsequent debug-log messages
        # are attributed to and handled by this cache instance
        dumper, '_debug_output', self._debug_output,
        cleanup=False, name='._stats_dumper',
    )
    self.add_cleanup_with_priority(self._stats_dumper, 1)

    # Various setups
    self._setup_common(wrap_os_fork, {'reboot_forkserver': False})

    # Set `.cleanup()` as an atexit hook to handle everything when
    # the child process is about to terminate
    atexit.register(self._atexit_hook)

    self._debug_output(f'Setup successful ({context})')
    return True
def _setup_common(
    self,
    wrap_os_fork: bool,
    mp_apply_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Do the setup shared between the main and the child processes.

    Args:
        wrap_os_fork (bool):
            Whether to wrap :py:func:`os.fork`
            (see :py:meth:`~._wrap_os_fork`)
        mp_apply_kwargs (dict[str, Any] | None):
            Optional keyword arguments passed to the ``apply()``
            function of the sibling module
            ``multiprocessing_patches``
    """
    if wrap_os_fork:
        self._wrap_os_fork()
    _import_sibling('multiprocessing_patches').apply(
        self, **(mp_apply_kwargs or {}),
    )


def _handle_signal(self, signum: int, *_) -> None:  # nocover
    """
    Signal handler which dumps the profiling stats ASAP, and then
    passes the signal onto the handler which was in place before
    this one was installed.

    Args:
        signum (int):
            Number of the signal caught
        *_:
            Ignored (the stack frame passed by :py:mod:`signal`)

    Raises:
        RuntimeError
            If no previous handler has been recorded for ``signum``
            in ``._sighandlers``

    See also:
        :py:meth:`coverage.control.Coverage._on_sigterm`
    """
    name = self._get_signal_name(signum)
    # Shouldn't happen, but all kinds of weird things happen at the
    # interpreter's EoL...
    state = 'not initiated?!'
    try:
        # Just use the `._stats_dumper` to dump the stats ASAP
        # without running this cache's `.cleanup()` to avoid
        # deadlocks
        if self._stats_dumper is None:
            state = 'unavailable'
        else:
            reason = f'caught `{name}` ({signum})'
            self._stats_dumper.cleanup(force=True, reason=reason)
            # Only report success if something has actually been
            # dumped; setting this in an `else:` clause of the `try`
            # would clobber the 'unavailable' state
            state = 'succeeded'
    except BaseException as e:
        state = f'failed ({self._format_exception(e)})'
        raise
    finally:
        # Whatever happened, reinstate the previous handler and
        # re-raise the signal so that it is handled as if we had
        # never been here
        handler = self._sighandlers.pop(signum, None)
        msg = f'Stat-dumping {state}, passing `{name}` onto {handler!r}...'
        self._debug_output(msg)
        if handler is None:
            msg = 'original handler set from outside of Python'
            raise RuntimeError(msg)
        else:
            signal.signal(signum, handler)
            signal.raise_signal(signum)
def _add_signal_handler(
    self, signum: int = signal.SIGTERM,
) -> None:  # nocover
    """
    Install :py:meth:`~._handle_signal` as the handler of a signal.

    Args:
        signum (int):
            Number of the signal to handle (default: ``SIGTERM``)

    Side effects:
        If on the main thread and not on Windows:

        - :py:func:`signal.signal` called to set
          :py:meth:`~._handle_signal` as the handler

        - The previous handler stored in ``._sighandlers``, to be
          reinstated by :py:meth:`~._handle_signal`

    Note:
        ``SIGTERM`` handling is known to be faulty on Windows; see
        previous discussions at (examples `1`_, `2`_).

    .. _1: https://github.com/coveragepy/coveragepy/blob/main/\
coverage/control.py
    .. _2: https://stackoverflow.com/questions/35772001/
    """
    on_main_thread = current_thread() == main_thread()
    if sys.platform == 'win32' or not on_main_thread:
        return
    name = self._get_signal_name(signum)
    self._debug_output(f'Adding `{name}` handler...')
    previous = signal.signal(signum, self._handle_signal)
    self._sighandlers[signum] = previous


@staticmethod
def _get_signal_name(signum: int) -> str:
    """
    Get the name (e.g. ``'SIGTERM'``) of a signal from its number.
    """
    sig = signal.Signals(signum)
    return sig.name


def _wrap_os_fork(self) -> None:
    """
    Create a wrapper around :py:func:`os.fork` which sets up
    profiling in the forked process.

    Side effects:

        - :py:func:`os.fork` (if available) replaced with the
          wrapper via ``.patch()``
    """
    fork = getattr(os, 'fork', None)
    if fork is None:  # Can't fork on this platform
        return

    @wraps(fork)
    def wrapper() -> int:
        ppid = os.getpid()
        child_pid = fork()
        if child_pid:  # We're still in the parent process
            return child_pid
        # If we're here, we are in the fork
        pid = os.getpid()
        forked = self.copy()  # Ditch inherited cleanups
        forked._debug_output(f'Forked: {ppid} -> {pid}')
        superseded = forked._replace_loaded_instance()
        if superseded:
            forked._debug_output(
                'Superseded cached `.load()`-ed instance in forked process'
            )
        # Note: we can reuse the profiler instance in the fork, but
        # it needs to go through setup so that the separate
        # profiling results are dumped into another output file
        forked._setup_in_child_process(False, 'fork', self.profiler)
        return child_pid

    self.patch(os, 'fork', wrapper, name='os')
def _warn_possible_lack_of_stats(
    self, pids: int | Collection[int],
) -> None:
    """
    Register PID(s) which may have created a profiling stats file
    without writing to it, by appending them (one per line) to the
    registry file of this process; when calling
    :py:meth:`.gather_stats`, empty stats files associated with
    those PIDs are ignored instead of warned against or treated as
    an error.

    Args:
        pids (int | Collection[int]):
            A single PID or a collection of PIDs
    """
    to_register = pids if isinstance(pids, Collection) else (pids,)
    registry = self._empty_stats_pid_registry
    with registry.open(mode='a') as fobj:
        print(*to_register, sep='\n', file=fobj)
def _get_pids_possibly_lacking_stats(self) -> set[int]:
    """
    Collect the PIDs registered in all the registry files in
    :py:attr:`~.cache_dir`.

    Returns:
        pids (set[int]):
            PIDs which may have left empty profiling stats files

    See also
        :py:meth:`._warn_possible_lack_of_stats`
    """
    def parse_pids(lines: Iterable[str]) -> Iterable[int]:
        for line in lines:
            try:
                yield int(line)
            except ValueError:  # Skip the lines which aren't PIDs
                continue

    prefix = _POSSIBLE_EMPTY_STATS_PREFIX_PATTERN.format(
        main_pid=self.main_pid,
        current_pid='?*',  # Gather from all child processes
    )
    all_pids: set[int] = set()
    for registry in self._glob(prefix + '?*.dat'):
        with registry.open() as fobj:
            from_reg = set(parse_pids(fobj))
        if from_reg:
            self._debug_output(
                f'Loaded {len(from_reg)} PID(s) possibly lacking '
                f'profiling output from {registry.name!r}: {from_reg!r}'
            )
        all_pids |= from_reg
    return all_pids


def make_tempfile(self, **kwargs) -> Path:
    """
    Create a fresh tempfile, by default under
    :py:attr:`~.cache_dir`. The arguments are passed to the
    parent-class implementation.

    Returns:
        path (Path):
            Path to the created file.
    """
    defaults = {
        'dir': self.cache_dir,
        '_format_debug_msg': 'Created tempfile: {0.name!r}'.format,
    }
    for key, value in defaults.items():
        kwargs.setdefault(key, value)
    return super().make_tempfile(**kwargs)
def _replace_loaded_instance(self, force: bool = False) -> bool:
    """
    Make this instance the one returned by :py:meth:`~.load`.

    Args:
        force (bool):
            If false, the replacement only happens if
            ``._consistent_with_loaded_instance`` is true

    Returns:
        replaced (bool):
            Whether the replacement has happened
    """
    cls = type(self)
    if force or self._consistent_with_loaded_instance:
        cls._loaded_instance = self
        return True
    return False


@classmethod
def _from_path(cls, fname: os.PathLike[str] | str) -> Self:
    """
    Reconstruct an instance from a file written by
    :py:meth:`~.dump`.
    """
    # NOTE(review): unpickling executes arbitrary code; `fname` is
    # presumed to have been written by `.dump()` in a trusted ancestral
    # process -- confirm that the cache directory cannot be tampered
    # with
    with open(fname, mode='rb') as fobj:
        content = pickle.load(fobj)
    instance = cls(**content['init_args'])
    instance._additional_data.update(content.get('additional_data', {}))
    return instance


def _get_init_args(self) -> dict[str, Any]:
    """
    Get the current values of the dataclass fields which are
    initializer arguments, keyed by the field names.
    """
    init_fields = [
        field_obj.name for field_obj in dataclasses.fields(self)
        if field_obj.init
    ]
    return {name: getattr(self, name) for name in init_fields}


@staticmethod
def _get_filename(cache_dir: os.PathLike[str] | str) -> str:
    """
    Get the path to the pickled cache file in ``cache_dir``.
    """
    return os.path.join(cache_dir, CACHE_FILENAME)


# Overload: called with the `wrapper` -> returns the wrapper-maker
@overload
@classmethod
def _method_wrapper(
    cls,
    wrapper: Callable[Concatenate[Self, Callable[PS, T], PS], T],
    debug: bool | None = None,
) -> Callable[[Callable[PS, T]], Callable[PS, T]]:
    ...


# Overload: called without the `wrapper` -> returns a decorator which
# takes the `wrapper`
@overload
@classmethod
def _method_wrapper(
    cls, wrapper: None = None, debug: bool | None = None,
) -> Callable[
    [Callable[Concatenate[Self, Callable[PS, T], PS], T]],
    Callable[[Callable[PS, T]], Callable[PS, T]]
]:
    ...


@classmethod
def _method_wrapper(
    cls,
    wrapper: (
        Callable[Concatenate[Self, Callable[PS, T], PS], T] | None
    ) = None,
    debug: bool | None = None,
) -> (
    Callable[
        [Callable[Concatenate[Self, Callable[PS, T], PS], T]],
        Callable[[Callable[PS, T]], Callable[PS, T]]
    ]
    | Callable[[Callable[PS, T]], Callable[PS, T]]
):
    """
    Convenience wrapper decorator for functions which use the
    :py:meth:`load`-ed session instance and wrap another callable.

    Args:
        wrapper (Callable[..., T])
            Callable with the call signature
            ``(cache, vanilla_impl, *args, **kwargs) -> retval``;
            ``*args``, ``**kwargs``, and ``retval`` should be
            consistent with that of ``vanilla_impl()``'s.
            If omitted, a decorator taking the ``wrapper`` (with
            ``debug`` already bound) is returned instead.
        debug (bool | None)
            Whether to format and write debug messages before and
            after the call to the ``wrapper`` callable;
            if ``debug`` is not set, it will be taken from the
            session instance.

    Returns:
        inner_wrapper (Callable[[Callable[PS, T]], Callable[PS, T]])
            Wrapper(-maker) which takes the ``vanilla_impl`` and
            return a wrapper around it.
    """
    if wrapper is None:
        # `ty` doesn't quite support `partial` yet, see issue #1536
        return cast(
            Callable[[Callable[PS, T]], Callable[PS, T]],
            partial(cls._method_wrapper, debug=debug),
        )

    def inner_wrapper(vanilla_impl: Callable[PS, T]) -> Callable[PS, T]:
        @wraps(vanilla_impl)
        def wrapped_impl(*args: PS.args, **kwargs: PS.kwargs) -> T:
            # The session instance is looked up at call time
            cache = cls.load()
            write = cache._debug_output
            debug_: bool | None = debug
            call = partial(wrapper, cache, vanilla_impl, *args, **kwargs)

            if debug_ is None:
                debug_ = cache.debug
            if debug_:
                call_fmt = cache._format_call(name, *args, **kwargs)
                write(f'Wrapped call made: {call_fmt}...')
                try:
                    result = call()
                except BaseException as e:
                    # Note: be more defensive than normal and
                    # prepared to deal with `BaseException`; this
                    # decorator is often used for functions invoked
                    # in child processes which don't cleanly
                    # terminate
                    state, outcome = 'failed', cache._format_exception(e)
                    raise e
                else:
                    state = 'succeeded'
                    outcome = _CALLBACK_REPR_HELPER.repr(result)
                    return result
                finally:
                    # Runs on both paths, before the return value or
                    # the exception leaves the function
                    write(f'Wrapped call {state}: {call_fmt} -> {outcome}')
            else:
                return call()

        # Resolved once per wrapped callable, used in the log messages
        name = cls._get_name(vanilla_impl)
        return wrapped_impl

    # Make the wrapper-maker look like the `wrapper` it is made from
    for field in 'name', 'qualname', 'doc':
        dunder = f'__{field}__'
        value = getattr(wrapper, dunder, None)
        if value is not None:
            setattr(inner_wrapper, dunder, value)
    return inner_wrapper
@classmethod
def _format_call(
    cls, func: Callable[..., Any] | str, /, *args, **kwargs,
) -> str:
    """
    Format a call for use in log messages.

    Args:
        func (Callable[..., Any] | str):
            Callable called, or its name; a
            :py:class:`functools.partial` object is unpacked into
            the underlying callable and its bound arguments
        *args, **kwargs:
            Arguments the callable is called with

    Returns:
        formatted (str):
            The name of the callable followed by the formatted
            arguments
    """
    if isinstance(func, partial):
        # Merge the bound arguments with the supplied ones the same
        # way `partial.__call__()` does; they must be unpacked so
        # that they are formatted as arguments rather than as a
        # list and a dict
        return cls._format_call(
            func.func, *func.args, *args, **{**func.keywords, **kwargs},
        )
    call = _CALLBACK_REPR_HELPER.format_call(*args, **kwargs)
    if not isinstance(func, str):
        func = cls._get_name(func)
    return func + call


@staticmethod
def _format_exception(xc: BaseException) -> str:
    """
    Format an exception as ``'<type name>: <message>'``, or just the
    type name if the message is empty.
    """
    formatted = type(xc).__name__
    if str(xc):
        formatted = f'{formatted}: {xc}'
    return formatted


@property
def environ(self) -> dict[str, str]:
    """
    Environment variables to be injected into and inherited by child
    processes: the main PID, and the cache directory (under a
    variable name suffixed with the main PID).
    """
    cache_varname = f'{INHERITED_CACHE_ENV_VARNAME_PREFIX}_{self.main_pid}'
    return {
        INHERITED_PID_ENV_VARNAME: str(self.main_pid),
        cache_varname: str(self.cache_dir),
    }


@property
def filename(self) -> str:
    """
    Path to the file the instance is :py:meth:`~.dump`-ed to.
    """
    return self._get_filename(self.cache_dir)


@property
def _debug_log(self) -> Path | None:
    """
    Path to the debug logfile of the current process in
    :py:attr:`~.cache_dir`, or :py:const:`None` if not
    :py:attr:`~.debug`.
    """
    if not self.debug:
        return None
    fname = _DEBUG_LOG_FILENAME_PATTERN.format(
        main_pid=self.main_pid, current_pid=os.getpid(),
    )
    return Path(self.cache_dir) / fname


@cached_property
def _make_debug_entry(self) -> Callable[[str], CacheLoggingEntry]:
    """
    Factory creating a log entry for this instance from a message.
    """
    return partial(CacheLoggingEntry.new, self.main_pid, id(self))


@cached_property
def _consistent_with_loaded_instance(self) -> bool:
    """
    Whether the initializer arguments of this instance equal those
    of the :py:meth:`~.load`-ed instance (evaluated once).
    """
    return type(self).load()._get_init_args() == self._get_init_args()


@cached_property
def _config_source(self) -> ConfigSource:
    """
    Configuration loaded from :py:attr:`~.config`.
    """
    if self.config is None:
        config: str | None = None
    else:
        config = str(self.config)
    return ConfigSource.from_config(config)


@cached_property
def _empty_stats_pid_registry(self) -> Path:
    """
    File to register the PIDs possibly lacking profiling stats in
    (see :py:meth:`._warn_possible_lack_of_stats`); created on first
    access and not deleted on cleanup.
    """
    prefix = _POSSIBLE_EMPTY_STATS_PREFIX_PATTERN.format(
        main_pid=self.main_pid,
        current_pid=os.getpid(),
    )
    return self.make_tempfile(prefix=prefix, suffix='.dat', delete=False)


@cached_property
def _atexit_hook(self) -> Callable[[], None]:
    """
    Callback for :py:func:`atexit.register` which runs
    :py:meth:`~.cleanup`; cached so that the same object is returned
    on every access.
    """
    return partial(self.cleanup, reason='`atexit` callback')
Callable[[], None]: + return partial(self.cleanup, reason='`atexit` callback') diff --git a/line_profiler/_child_process_profiling/multiprocessing_patches.py b/line_profiler/_child_process_profiling/multiprocessing_patches.py new file mode 100644 index 00000000..b30b1bca --- /dev/null +++ b/line_profiler/_child_process_profiling/multiprocessing_patches.py @@ -0,0 +1,1476 @@ +""" +Patch :py:mod:`multiprocessing` so that profiling extends into processes +it creates. + +Notes: + - Based on the implementations in :py:mod:`coverage.multiproc` and + :py:mod:`pytest_autoprofile._multiprocessing`. + + - Results may vary if the process pool is not properly + :py:meth:`multiprocessing.pool.Pool.close`-d and + :py:meth:`multiprocessing.pool.Pool.join`-ed; + see `this caveat `__. +""" +from __future__ import annotations + +import atexit +import dataclasses +import multiprocessing +import os +import sys +import warnings +from collections.abc import Callable, Collection, Mapping, Sequence, Set +from functools import partial, wraps +from importlib import import_module +from inspect import getattr_static, signature +from multiprocessing.pool import Pool +from multiprocessing.process import BaseProcess +from operator import attrgetter +from queue import SimpleQueue +from time import sleep, monotonic +from types import MappingProxyType as mappingproxy, MethodType, ModuleType +from typing import ( + TYPE_CHECKING, + Any, ClassVar, Generic, Literal, NamedTuple, Protocol, TypeVar, NoReturn, + cast, final, overload, +) +from typing_extensions import Concatenate, ParamSpec, Self + +try: + from multiprocessing import spawn +except ImportError: + _CAN_USE_SPAWN = False +else: + _CAN_USE_SPAWN = True +try: + from multiprocessing import forkserver +except ImportError: + _CAN_USE_FORKSERVER = False +else: + _CAN_USE_FORKSERVER = ( + 'forkserver' in multiprocessing.get_all_start_methods() + ) +try: + from multiprocessing import resource_tracker +except ImportError: + _CAN_USE_RESOURCE_TRACKER 
class _Wrapper(Protocol, Generic[PS, T]):
    """
    Protocol for a decorator which takes a callable and returns a
    callable with the same call signature and return type.
    """
    def __call__(self, func: Callable[PS, T], /) -> Callable[PS, T]:
        ...


class _Queue(Protocol):
    """
    Protocol for methods common to e.g. :py:class:`queue.SimpleQueue`
    and :py:class:`multiprocessing.queues.SimpleQueue`.
    """
    def put(self, obj: Any) -> None:
        ...

    def get(self) -> Any:
        ...
class _Poller:
    """
    Poll a callable until it returns true-y.

    The polling happens when the object is entered as a context
    manager; use :py:meth:`.with_cooldown` and
    :py:meth:`.with_timeout` to get copies which sleep between the
    polls and give up after a while, respectively.

    Example:
        >>> import warnings
        >>> from contextlib import ExitStack
        >>> from functools import partial
        >>> from itertools import count
        >>> from typing import Iterator, Literal

        >>> def count_until(
        ...     limit: int, mode: Literal['until', 'while'] = 'until',
        ... ) -> bool:
        ...     def counter_is_big_enough(
        ...         counter: Iterator[int], limit: int,
        ...     ) -> bool:
        ...         return next(counter) >= limit
        ...
        ...     def counter_is_small_enough(
        ...         counter: Iterator[int], limit: int,
        ...     ) -> bool:
        ...         return next(counter) < limit
        ...
        ...     # The branches are ultimately equal in results, but we
        ...     # want to explicitly test both `.poll_until()` and
        ...     # `.poll_while()`
        ...     if mode == 'until':
        ...         get_poller = partial(
        ...             _Poller.poll_until, counter_is_big_enough,
        ...         )
        ...     else:
        ...         get_poller = partial(
        ...             _Poller.poll_while, counter_is_small_enough,
        ...         )
        ...     return get_poller(count(), limit)

        >>> with count_until(10).with_cooldown(.01).with_timeout(1):
        ...     # Note: we shouldn't really need that much time, but
        ...     # something in CI seems to be slowing down the polling
        ...     # loop...
        ...     print('We counted up to 10')
        We counted up to 10

        >>> with (
        ...     count_until(100)
        ...     .with_cooldown(.01)
        ...     .with_timeout(.5)  # `[on_]timeout` separately supplied
        ...     .with_timeout(on_timeout='ignore')
        ... ):
        ...     print("We probably didn't count up to 100 but whatever")
        We probably didn't count up to 100 but whatever

        >>> with (  # doctest: +NORMALIZE_WHITESPACE
        ...     count_until(30).with_cooldown(.01).with_timeout(.25)
        ... ):
        ...     print('We counted up to 30')
        Traceback (most recent call last):
          ...
        line_profiler..._Poller.Timeout: ...
        timed out (... s >= 0.25 s) waiting for
        callback ...counter_is_big_enough... to return true

        >>> with ExitStack() as stack:  # doctest: +NORMALIZE_WHITESPACE
        ...     enter = stack.enter_context
        ...     enter(warnings.catch_warnings())
        ...     warnings.simplefilter('error', _Poller.TimeoutWarning)
        ...     enter(
        ...         count_until(30, 'while')
        ...         .with_cooldown(.01)
        ...         .with_timeout(.25, 'warn')
        ...     )
        ...     print('We counted up to 30 again')
        Traceback (most recent call last):
          ...
        line_profiler..._Poller.TimeoutWarning: ...
        timed out (... s >= 0.25 s) waiting for
        callback ...counter_is_small_enough... to return true
    """
    def __init__(
        self,
        func: Callable[[], Any],
        cooldown: float = 0,
        timeout: float = 0,
        on_timeout: _OnTimeout = 'error',
    ) -> None:
        # Negative durations are clipped to zero, i.e. no cooldown
        # and no timeout
        self._func: Callable[[], Any] = func
        self._cooldown = max(0, cooldown)
        self._timeout = max(0, timeout)
        self._on_timeout = on_timeout

    def sleep(self):
        """
        Sleep for the cooldown period (if any).
        """
        if self._cooldown > 0:
            sleep(self._cooldown)

    def with_cooldown(self, cooldown: float) -> Self:
        """
        Get a copy which sleeps ``cooldown`` seconds between polls.
        """
        cls = type(self)
        return cls(self._func, cooldown, self._timeout, self._on_timeout)

    def with_timeout(
        self,
        timeout: float | None = None,
        on_timeout: _OnTimeout | None = None,
    ) -> Self:
        """
        Get a copy with the timeout (in seconds) and/or the action
        taken on timeout replaced; :py:const:`None` leaves the
        respective setting as-is.
        """
        new_timeout = self._timeout if timeout is None else timeout
        new_action = self._on_timeout if on_timeout is None else on_timeout
        cls = type(self)
        return cls(self._func, self._cooldown, new_timeout, new_action)

    @classmethod
    def poll_until(
        cls, func: Callable[PS, Any], /, *args: PS.args, **kwargs: PS.kwargs
    ) -> Self:
        """
        Create a poller which waits until ``func(*args, **kwargs)``
        returns true-y.
        """
        has_arguments = bool(args or kwargs)
        target = partial(func, *args, **kwargs) if has_arguments else func
        return cls(target)

    @classmethod
    def poll_while(
        cls, func: Callable[PS, Any], /, *args: PS.args, **kwargs: PS.kwargs
    ) -> Self:
        """
        Create a poller which waits until ``func(*args, **kwargs)``
        returns false-y.
        """
        def negated(
            func: Callable[PS, Any], *a: PS.args, **k: PS.kwargs
        ) -> bool:
            return not func(*a, **k)

        inverted = partial(negated, func, *args, **kwargs)
        return cls(inverted)

    def __enter__(self) -> Self:
        """
        Poll the callable until it returns true-y or the timeout is
        reached; on timeout, either raise :py:class:`.Timeout`,
        issue a :py:class:`.TimeoutWarning`, or do nothing.
        """
        def error(msg: str) -> NoReturn:
            raise type(self).Timeout(msg)

        def warn(msg: str) -> None:
            # Write log before issuing the warning because that may be
            # promoted to an exception
            diagnostics.log.warning(msg)
            warnings.warn(msg, type(self).TimeoutWarning, stacklevel=3)

        handlers: dict[str, Callable[[str], Any]] = {
            'error': error, 'warn': warn, 'ignore': _no_op,
        }
        timeout = self._timeout
        callback = self._func
        handle_timeout = handlers[self._on_timeout]
        fmt = '.3g'
        timeout_msg_header = f'{type(self).__name__} at {id(self):#x}'

        started = monotonic()
        while True:
            if callback():
                break
            elapsed = monotonic() - started
            # A timeout of zero means to poll indefinitely
            if timeout and elapsed >= timeout:
                handle_timeout(
                    f'{timeout_msg_header}: '
                    f'timed out ({elapsed:{fmt}} s >= {timeout:{fmt}} s) '
                    f'waiting for callback {callback!r} to return true'
                )
                break
            self.sleep()
        return self

    def __exit__(self, *_, **__) -> None:
        # Nothing to clean up; exceptions are never suppressed
        pass

    class Timeout(RuntimeError):
        """
        Raised when a :py:class:`_Poller` is timed out when polling.
        """
        pass

    class TimeoutWarning(Timeout, UserWarning):
        """
        Issued when a :py:class:`_Poller` is timed out when polling.
        """
        pass
f'{timeout_msg_header}: ' + f'timed out ({elapsed:{fmt}} s >= {timeout:{fmt}} s) ' + f'waiting for callback {callback!r} to return true' + ) + break + self.sleep() + return self + + def __exit__(self, *_, **__) -> None: + pass + + class Timeout(RuntimeError): + """ + Raised when a :py:class:`_Poller` is timed out when polling. + """ + pass + + class TimeoutWarning(Timeout, UserWarning): + """ + Issued when a :py:class:`_Poller` is timed out when polling. + """ + pass + + +@final +class _PollerArgs(NamedTuple): + cooldown: float + timeout: float + on_timeout: str | None + + @classmethod + def new(cls, cooldown: Any, timeout: Any, on_timeout: Any) -> Self: + try: + cd = max(float(cooldown), 0) + except (TypeError, ValueError): + cd = 0 + try: + to = max(float(timeout), 0) + except (TypeError, ValueError): + to = 0 + try: + ot: str | None = on_timeout.lower() + except Exception: # Fallback (use `_Poller`'s default) + ot = None + return cls(cd, to, ot) + + +@final +@dataclasses.dataclass +class MPConfig: + """ + Consolidate the config options into a structured object. 
+ """ + catch_sigterm: bool + patches: dict[PublicPatch, bool] + polling: _PollerArgs + + def _get_terminate_poller( + self, cache: LineProfilingCache, process: BaseProcess, + ) -> _Poller: + cd, timeout, on_timeout = self.polling + if on_timeout not in ('ignore', 'warn', 'error'): + on_timeout = self.get_defaults().polling.on_timeout + # `_process_has_returned()` takes a `timeout` which it passes to + # `popen.wait()`; said timeout is essentially a limit as to how + # often the function is called, hence our cooldown + poller = _Poller.poll_until( + self._process_has_returned, process, cache, cd, + ) + return poller.with_timeout(timeout, cast(_OnTimeout, on_timeout)) + + @classmethod + def from_config(cls, config: ConfigSource) -> Self: + loaded = ( + config + .get_subconfig('child_processes', 'multiprocessing') + .conf_dict + ) + polling = _PollerArgs.new(**loaded['polling']) + return cls( + catch_sigterm=loaded['catch_sigterm'], + patches=dict(loaded['patches']), + polling=polling, + ) + + @classmethod + def from_cache(cls, cache: LineProfilingCache) -> Self: + key = 'mp_config' + try: + return cache._additional_data[key] + except KeyError: + config = cls.from_config(cache._config_source) + return cache._additional_data.setdefault(key, config) + + @classmethod + def get_defaults(cls) -> Self: + namespace = globals() + name = '_DEFAULT_CONFIG' + try: + return namespace[name] + except KeyError: + defaults = cls.from_config(ConfigSource.from_default(copy=False)) + return namespace.setdefault(name, defaults) + + @staticmethod + def _process_has_returned( + proc: BaseProcess, cache: LineProfilingCache, timeout: float, + ) -> bool: + popen = getattr(proc, '_popen', None) + if popen is None: + msg, result = 'No associated process', True + else: + result = popen.wait(timeout) is not None + if result: + msg = f'Process {popen.pid} has returned' + else: + msg = f'Waiting for process {popen.pid} to return...' 
class _QueuePIDWrapper:
    """
    Proxy for a :py:class:`queue.SimpleQueue` (the queue type used by
    :py:mod:`multiprocessing.dummy`) whose :py:meth:`.put` tags each
    item with the PID of the current process, i.e. the underlying queue
    receives ``(pid, obj)`` instead of ``obj``.

    Notes:
        - Used by the ``child_pids`` patch.

        - The PID carries no useful info with
          :py:mod:`multiprocessing.dummy`, but the tagging is still
          required: the patched
          :py:meth:`multiprocessing.pool.Pool._handle_results` expects
          the queue-item getter to return the tuple
          ``(pid, original_put_value)``.

        - Any attribute not defined here is looked up on the wrapped
          queue.
    """
    def __init__(self, queue: _Queue) -> None:
        self._queue = queue

    def put(self, obj: Any) -> None:
        tagged = os.getpid(), obj
        self._queue.put(tagged)

    def get(self) -> Any:
        item = self._queue.get()
        return item

    def __getattr__(self, attr: str) -> Any:
        # Only reached when normal attribute lookup fails, so this
        # forwards everything else to the wrapped queue
        wrapped = self._queue
        return getattr(wrapped, attr)
class _Patch(Protocol):
    """
    Structural interface which the registered patches conform to.
    """
    def apply(
        self,
        cache: LineProfilingCache,
        *,
        cleanup: bool = True,
        **kwargs
    ) -> Any:
        """
        Put the patch into effect.

        Args:
            cache (LineProfilingCache):
                Cache object of the profiling session
            cleanup (bool):
                If true, ``cache.cleanup()`` is to undo the patch
            **kwargs
                Extra options; each implementation consumes the ones
                it understands and disregards the others.
        """

    @property
    def summary(self) -> Mapping[str, Set[str]]:
        """
        Mapping whose keys are the dotted-path names of the patched
        objects, and whose values are the sets of attribute names
        patched on each of them.
        """
+ """ + submodule: str + targets: dict[ + str, dict[str, Callable[[Any], Any] | Sequence[Callable[[Any], Any]]] + ] = dataclasses.field(default_factory=dict) + package: ClassVar[str] = 'multiprocessing' + + def add_target( + self, + target: str, + patches: Mapping[ + str, Callable[[Any], Any] | Sequence[Callable[[Any], Any]] + ], + ) -> Self: + """ + Convenience method for gradually constructing the patch with a + fluent interface. + + Args: + target (str): + Dotted path to the object in :py:attr:`.submodule` + patches (Mapping[str, Callable[[Any], Any] \ +| Sequence[Callable[[Any], Any]]]): + Mapping from patched attrbute names to the wrappers to + apply thereto; sequences of wrappers are applied in + order + + Returns: + This instance + """ + self.targets.setdefault(target, {}).update(patches) + return self + + def add_method( + self, + target: str, + method: str, + wrapper: Callable[[Any], Any], + methodtype: ( + type[classmethod] | type[staticmethod] + | Literal['class', 'static'] | None + ) = None, + ) -> Self: + """ + Convenience method for gradually constructing the patch with a + fluent interface. 
+ + Args: + target (str): + Dotted path to the object in :py:attr:`.submodule` + method (str): + Name of the (class, static, or instance) method to patch + wrapper (Callable[[Any], Any]): + Wrapping callable which takes the method-implementaion + callable and returns a wrapper thereof + methodtype (type[classmethod] | type[staticmethod] | \ +Literal['class', 'static'] | None): + Optional type of the method if not an instance method; + the strings ``'class'`` and ``'static'`` are respective + shorthands for :py:class:`classmethod` and + :py:class:`staticmethod` + + Returns: + This instance + """ + wrappers: Callable[[Any], Any] | list[Callable[[Any], Any]] + if methodtype is None: + wrappers = wrapper + else: + if methodtype == 'class': + methodtype = classmethod + elif methodtype == 'static': + methodtype = staticmethod + wrappers = [attrgetter('__func__'), wrapper, methodtype] + return self.add_target(target, {method: wrappers}) + + def apply( + self, + cache: LineProfilingCache, + *, + cleanup: bool = True, + static: bool = True, + **_ + ) -> list[str]: + """ + Apply the patch. 
+ + Args: + cache (LineProfilingCache): + Session cache + cleanup (bool): + Whether ``cache.cleanup()`` should reverse the patch + static (bool): + Whether to use :py:func:`inspect.getattr_static` to + retrieve to the attributes to be patched on the patch + targets + + Returns: + replacements (list[str]): + Names of entities replaced + """ + submod_name = f'{self.package}.{self.submodule}' + get_attribute = getattr_static if static else getattr + result: list[str] = [] + try: + mod = self.load_module() + except ImportError: # nocover + return [] + + for target in sorted(self.targets, key=len, reverse=True): + if TYPE_CHECKING: + # See `ty` issue #2572 + assert isinstance(target, str) + if target: + try: + obj: Any = attrgetter(target)(mod) + except AttributeError: # nocover + continue + name = f'{submod_name}.{target}' + else: + obj, name = mod, submod_name + replace = partial(cache.patch, obj, cleanup=cleanup, name=name) + for method, method_wrappers in self.targets[target].items(): + if callable(method_wrappers): + method_wrappers = cast( + Sequence[Callable[[Any], Any]], (method_wrappers,), + ) + try: + impl = get_attribute(obj, method) + except AttributeError: + continue + for wrapper in method_wrappers: + impl = wrapper(impl) + replace(method, impl) + result.append(f'{name}.{method}') + return result + + def load_module(self) -> ModuleType: + """ + Returns: + Module object :py:attr:`.module` points to + """ + return import_module(self.module) + + @staticmethod + def _join(s: str, *strs: str, sep: str = '.') -> str: + return sep.join(string for string in (s, *strs) if string) + + @property + def module(self) -> str: + """ + Module where the patches are applied + """ + return self._join(self.package, self.submodule) + + @property + def summary(self) -> mappingproxy[str, frozenset[str]]: + """ + Summary of the dotted paths to the patched objects and their + patched attributes + """ + add_prefix = partial(self._join, self.module) + return mappingproxy({ + 
@overload
def _register_patch(name: str, patch: Pt) -> Pt:
    ...


@overload
def _register_patch(name: str, patch: None = None) -> _Patch:
    ...


def _register_patch(name: str, patch: _Patch | None = None) -> _Patch:
    """
    Register the ``patch`` under ``name`` and return it as-is.  If
    ``patch`` isn't provided, look for the existing patch registered
    under the name.

    Args:
        name (str):
            Name to register the patch under, or to look up
        patch (_Patch | None):
            Optional patch to register

    Returns:
        The patch registered under ``name``

    Raises:
        ValueError
            If ``patch`` is provided but ``name`` is already taken by
            another patch object
        KeyError
            If ``patch`` is not provided and nothing is registered
            under ``name``

    Note:
        Patches named with leading double underscores are applied no
        matter the user input (e.g. via ``apply(..., patches=...)`` or
        the config file).
    """
    if patch is not None:
        # Re-registering the same object under the same name is a no-op
        existing = _PATCHES.setdefault(name, patch)
        if existing is not patch:
            # Note: both fragments must be f-strings, or the placeholder
            # for the existing patch is emitted verbatim
            raise ValueError(
                f'name = {name!r}, patch = {patch!r}: '
                f'name already in use by {existing}'
            )
    return _PATCHES[name]
@LineProfilingCache._method_wrapper
def wrap_guarded_task_generation(
    _,  # No need to use the cache, but `_method_wrapper` expects it
    vanilla_impl: Callable[Concatenate[Pool, int, Callable[PS1, T1], PS2], T2],
    self: Pool,
    result_job: int,
    func: Callable[PS1, T1],
    *args: PS2.args,
    **kwargs: PS2.kwargs
) -> T2:
    """
    Replacement for :py:meth:`.Pool._guarded_task_generation` which
    substitutes the task callable ``func`` with a
    :py:class:`TaskWrapper` around it, and otherwise defers to the
    original implementation.  The profiling stats are thus written from
    inside the callables sent to the child processes, before control is
    handed back to the parent process.
    """
    stats_writing_task = TaskWrapper(func)
    return vanilla_impl(
        self, result_job, stats_writing_task, *args, **kwargs,
    )
+ """ + # Set a signal handler for SIGTERM to help child processes with + # consistently cleaning up + _setup_in_mp_child(cache) + # Note: using the `cache` itself as the context manager is prone to + # deadlock + with Cleanup() as cleanup: + if isinstance(inqueue.get, MethodType): + get = partial(_wrap_inqueue_get, cache, inqueue.get.__func__) + cleanup.patch( + inqueue, 'get', MethodType(get, inqueue), + name='._inqueue', + ) + return vanilla_impl(inqueue, *args, **kwargs) + + +def _wrap_inqueue_get( # nocover + cache: LineProfilingCache, + vanilla_impl: Callable[Concatenate[_Queue, PS], T], + self: _Queue, + /, + *args: PS.args, + **kwargs: PS.kwargs +) -> T: + """ + Intercept the sentinel value (:py:const:`None`) signifiying the end + of the queue and perform cleanup. + """ + result = vanilla_impl(self, *args, **kwargs) + ntasks: dict[int, int] + ntasks = cache._additional_data.setdefault('mp_queue_ntasks', {}) + queue_id = id(self) + if result is None: + n = ntasks.pop(queue_id, 0) + msg = f'`multiprocessing.pool.worker`: recieved {n} task(s) in total' + cache._debug_output(msg) + # Got sentinel value, process is about to exit + reason = 'ran out of tasks in `multiprocessing.process.worker()`' + if cache.main_pid != os.getpid(): + _dump_stats_quick(cache, debug=True, reason=reason) + else: + ntasks[queue_id] = ntasks.get(queue_id, 0) + 1 + return result + + +_patch_pool = _register_patch('pool', Patch('pool')).add_method +if _CAN_CATCH_SIGTERM: + # Only write profiling output once per process if it can be helped + _patch_pool('', 'worker', wrap_worker_pool) +else: + # Don't have a choice on platform like Windows, the only reliable + # way to ensure that the child survives until profiling output is + # written is to write it for every task before control is returned + # to the parent + _patch_pool('Pool', '_get_tasks', wrap_get_tasks, 'static') + _patch_pool( + 'Pool', '_guarded_task_generation', wrap_guarded_task_generation, + ) + +# ----------- 
@LineProfilingCache._method_wrapper
def wrap_terminate(
    cache: LineProfilingCache,
    vanilla_impl: Callable[[BaseProcess], None],
    self: BaseProcess,
) -> None:
    """
    Replacement for :py:meth:`.BaseProcess.terminate` which first waits
    (subject to the configured cooldown and timeout) for the child
    (OS-level) process to return, so that it gets the chance to clean
    up properly before it is actually killed.

    Note:
        This does amount to polling in a loop, but the cost is
        acceptable: ``.terminate()`` is typically only called on the
        bad path (e.g. when the parallel workload errored out), and
        after the performance-critical part of the code (said workload)
        is over.
    """
    try:
        poller = MPConfig.from_cache(cache)._get_terminate_poller(cache, self)
        # All the waiting happens as the poller is entered, hence the
        # empty body
        with poller:
            pass
    except _Poller.Timeout as timed_out:  # Also handles `~.TimeoutWarning`
        cache._debug_output(f'{type(timed_out).__qualname__}: {timed_out}')
        raise
    finally:  # Always call `Process.terminate()` to avoid orphans
        vanilla_impl(self)
@LineProfilingCache._method_wrapper
def wrap_handle_results(
    cache: LineProfilingCache,
    vanilla_impl: Callable[
        Concatenate[_Queue, Callable[[], tuple[Any, ...] | None], PS],
        None
    ],
    outqueue: _Queue,
    # Since we patched `outqueue.put()` in the child process, the result
    # tuple pushed to the parent has an extra item (the child PID)
    get: Callable[[], tuple[int, tuple[Any, ...]] | None],
    *args: PS.args,
    **kwargs: PS.kwargs
) -> None:
    """
    Replacement for
    :py:meth:`multiprocessing.pool.Pool._handle_results` which copes
    with the extra info (PID of the child process handling the task)
    that :py:func:`.wrap_worker_pid` attaches to each result: the
    original implementation is handed a getter which records the PID
    and strips it off the result.

    Note:
        :py:meth:`.Pool._handle_results` is a static method.
    """
    # Somehow this doesn't type-check with either `mypy` or `ty` when
    # we use a `TypeVar` instead of `Any` with the tuple items...
    # (see `ty` issue #3467)
    vanilla_impl(
        outqueue,
        partial(_wrap_outqueue_quick_get, cache, get),
        *args,
        **kwargs,
    )
def _wrap_process_finalize(
    cache: LineProfilingCache,
    vanilla_impl: Callable[Concatenate[P, PS], None],
    action: str,
) -> Callable[Concatenate[P, PS], None]:
    """
    Create a wrapper around ``vanilla_impl`` which, before delegating
    to it, looks up how many tasks the process has run; if it ran none,
    that is reported to the cache.

    Args:
        cache (LineProfilingCache):
            Session cache
        vanilla_impl (Callable):
            Original (unbound) method implementation; it is always
            invoked, even if the bookkeeping errors out
        action (str):
            Description of what ``vanilla_impl`` does to the process,
            used (capitalized) in the debug log

    Returns:
        Wrapper function

    Note:
        Since the process object is pickled, this method has to directly
        return a function object instead of merely being
        :py:func:`partial`-ed and wrapped in a
        :py:class:`types.MethodType`.
    """
    verb = action.capitalize()

    @wraps(vanilla_impl)
    def finalize(self: P, *args: PS.args, **kwargs: PS.kwargs) -> None:
        log = cache._debug_output
        call = cache._format_call(vanilla_impl, self, *args, **kwargs)
        try:
            log(f'Wrapped call made: {call}')
            pid: int | None = getattr(self, 'pid', None)
            checked_procs = _get_checked_processes(cache)
            identifier = id(self), pid
            # Each process object is only checked once, and only if it
            # has a PID
            if pid is not None and identifier not in checked_procs:
                ntasks = _get_ntasks(cache).pop(pid, 0)
                if not ntasks:
                    cache._warn_possible_lack_of_stats(pid)
                log(f'{verb} process {pid} which ran {ntasks} task(s)...')
                checked_procs.add(cast(tuple[int, int], identifier))
        except BaseException as e:
            log(
                f'Error in bookkeeping ({cache._format_exception(e)}), '
                'invoking base implementation nonetheless...'
            )
            raise e
        finally:
            # The base implementation runs no matter how the
            # bookkeeping went
            try:
                vanilla_impl(self, *args, **kwargs)
            except BaseException as e:
                state = f'failed ({cache._format_exception(e)})'
                raise e
            else:
                state = 'succeeded'
            finally:
                log(f'Wrapped call {call} {state}')

    return finalize
class RebootForkserverPatch:
    """
    Patch which stops the process backing the global
    :py:class:`multiprocessing.forkserver.ForkServer` instance at two
    points in time:

    - On application of the patch, so that the child processes forked
      from the server actually receive the active patches; and

    - On cleanup of the session cache, so that the child processes
      forked from the server are no longer polluted by the patches.

    Since nothing is replaced, :py:attr:`summary` is empty.

    Note:
        This relies on
        :py:meth:`multiprocessing.forkserver.ForkServer._stop()`, which
        is private API; however, it is the same hack used in Python's
        own test suite -- see the comment to said method.
    """
    summary: ClassVar[mappingproxy[str, frozenset[str]]] = mappingproxy({})

    @classmethod
    def apply(cls, cache: LineProfilingCache, **_) -> None:
        if _CAN_USE_FORKSERVER:
            cls.reboot()
            cache.add_cleanup(cls.reboot)

    @staticmethod
    def reboot() -> None:
        server = forkserver._forkserver
        # Appease the type-checker since `._stop()` is not public API
        stop = getattr(server, '_stop', None)
        assert callable(stop)
        stop()
+ """ + if _CAN_USE_RESOURCE_TRACKER: + summary: ClassVar[mappingproxy[str, frozenset[str]]] = mappingproxy({ + 'multiprocessing.resource_tracker': + frozenset({'ensure_running'}), + 'multiprocessing.resource_tracker.ResourceTracker': + frozenset({'ensure_running'}), + }) + else: + summary = mappingproxy({}) + + @staticmethod + @LineProfilingCache._method_wrapper + def wrap_ensure_running( + cache: LineProfilingCache, + vanilla_impl: Callable[['resource_tracker.ResourceTracker'], None], + self: 'resource_tracker.ResourceTracker', + ) -> None: + """ + Wrap around :py:meth:`multiprocessing.resource_tracker\ +.ResourceTracker.ensure_running` + so that the session cache can keep track of the PIDs used by the + resource-tracer server. + """ + maybe_pids: set[int | None] = {getattr(self, '_pid', None)} + try: + vanilla_impl(self) + finally: + maybe_pids.add(getattr(self, '_pid', None)) + pids = cast(set[int], maybe_pids - {None}) + if pids: + cache._warn_possible_lack_of_stats(pids) + + @classmethod + def apply( + cls, cache: LineProfilingCache, *, cleanup: bool = True, **_, + ) -> list[str]: + if _CAN_USE_RESOURCE_TRACKER: + patch = partial(cache.patch, cleanup=cleanup) + # Patch the method on the class + method = resource_tracker.ResourceTracker.ensure_running + method = cls.wrap_ensure_running(method) + patch(resource_tracker.ResourceTracker, 'ensure_running', method) + # Patch the preexisting bound method on the module + instance = resource_tracker._resource_tracker + bound_method = MethodType(method, instance) + patch(resource_tracker, 'ensure_running', bound_method) + return list(cls.summary) + + +class RunpyPatch: + """ + Patch the copy of :py:mod:`runpy` in the + :py:mod:`multiprocessing.spawn` namespace so that subprocesses can + perform rewrite-based profiling as with + :py:func:`line_profiler.autoprofile.autoprofile.run`. 
+ + See also: + :py:mod:`line_profiler._child_process_profiling.runpy_patches` + """ + summary: ClassVar[mappingproxy[str, frozenset[str]]] + if _CAN_USE_SPAWN and hasattr(spawn, 'runpy'): + summary = mappingproxy({'multiprocessing.spawn': frozenset({'runpy'})}) + else: + summary = mappingproxy({}) + + @classmethod + def apply( + cls, cache: LineProfilingCache, *, cleanup: bool = True, **_, + ) -> list[str]: + if cls.summary: + patch = partial(cache.patch, cleanup=cleanup) + patch(spawn, 'runpy', create_runpy_wrapper(cache)) + return list(cls.summary) + + +# See `ty` issue #3429 for why we need the casts +_register_patch('__reboot_forkserver', cast(_Patch, RebootForkserverPatch)) +_register_patch('__resource_tracker', cast(_Patch, ResourceTrackerPatch)) +_register_patch('__spawn_runpy', cast(_Patch, RunpyPatch)) + +# -------------------------- Applying patches -------------------------- + + +def apply( + cache: LineProfilingCache, + reboot_forkserver: bool = True, + patches: Collection[PublicPatch] | None = None, +) -> None: + """ + Set up profiling in :py:mod:`multiprocessing` child processes by + applying patches to the module. + + Args: + cache (LineProfilingCache): + Cache instance governing the profiling run. + reboot_forkserver (bool): + Whether to reboot the global + :py:class:`multiprocessing.forkserver.ForkServer` instance + so as to ensure that profiling happens on processes forked + therefrom (see Note). + patches \ +(Collection[Literal['pool', 'process', 'logging', 'child_pids'] \ +| None]): + Patches to apply to :py:mod:`multiprocessing`; see the + following section for a description of each; + the default is taken from the TOML config file. + + Patches: + ``'pool'``: + On Windows + Patch :py:class:`multiprocessing.pool.Pool`'s + ``._get_tasks()`` and ``._guarded_task_generation()`` + methods so that parallel tasks write profiling output.
+ Else + Patch :py:func:`multiprocessing.pool.worker` so that + profiling output is written as each child process runs + out of tasks. + ``'process'``: + Patch :py:class:`multiprocessing.process.BaseProcess`'s + ``._bootstrap()`` method (and ``.terminate()`` on Windows) + so that child processes write profiling output on exit and + are given enough time for that. + ``'logging'``: + Patch :py:mod:`multiprocessing.util`'s logging methods (e.g. + ``debug()`` and ``info()``) so that their messages are teed + to the cache's debug log. + ``'child_pids'``: + Patch the following components of + :py:mod:`multiprocessing.pool` so that the parent process keeps + track of the workload executed by each child process, + reducing stray warnings about the lack of profiling stats + reported thereby: + + - :py:func:`multiprocessing.pool.worker` + + - :py:meth:`multiprocessing.pool.Pool._handle_results` + + - :py:meth:`multiprocessing.pool.Pool.Process` + + Side effects: + - The aforementioned patches applied + + - If ``reboot_forkserver=True``, fork-server process rebooted: + + - Immediately + + - When ``cache.cleanup()`` is run + + - Cleanup callbacks registered via ``cache.add_cleanup()`` + + Note: + Rebooting the fork server is necessary because its process + statically inherits the environment when it is first spun up + (see :py:func:`multiprocessing.forkserver.ensure_running`). + Thus, without the reboots: + + - If in the same Python process we ever start up two separate + profiling sessions managed by different caches, the child + processes forked from the server will fail to inherit the + updated environment variables injected by the newer cache + instance, leading to the setup code in this subpackage not + being loaded. + + - Since 3.13.8 and 3.14.1, the bug where the ``main_path`` + argument to :py:func:`multiprocessing.forkserver.main` is + unused has been fixed (see ``cpython`` issue `GH-126631`_).
+ This causes ``sys.modules['__main__']`` to be set up in the + fork-server process, meaning that children forked therefrom + will NOT redo the setup. Thus, the fork-server process itself + will also need to be properly set up for profiling. + + .. _GH-126631: https://github.com/python/cpython/issues/126631 + """ + if getattr(multiprocessing, _PATCHED_MARKER, False): + return + if patches is None: + patches_dict = MPConfig.from_cache(cache).patches + patches_: set[str] = {p for p, use in patches_dict.items() if use} + else: + patches_ = {p.lower() for p in patches} + for name, patch in _PATCHES.items(): + if name in patches_: + should_apply = True + elif name.startswith('__'): + should_apply = (name != '__reboot_forkserver' or reboot_forkserver) + else: + should_apply = False + if should_apply: + msg = f'applying `multiprocessing` patch {name!r}' + cache._debug_output(msg.capitalize() + '...') + patch.apply(cache) + cache._debug_output('Done with ' + msg) + # Mark `multiprocessing` as having been patched + cache.patch(multiprocessing, _PATCHED_MARKER, True) diff --git a/line_profiler/_child_process_profiling/runpy_patches.py b/line_profiler/_child_process_profiling/runpy_patches.py new file mode 100644 index 00000000..22a600ea --- /dev/null +++ b/line_profiler/_child_process_profiling/runpy_patches.py @@ -0,0 +1,133 @@ +""" +Patches for :py:mod:`runpy` to be patched into the namespace of +:py:mod:`multiprocessing.spawn`, so that the rewriting of ``__main__`` +can be continued into child processes. 
+""" +from __future__ import annotations + +import os +from collections.abc import Callable +from functools import partial +from importlib.util import find_spec +from types import ModuleType +from typing import cast, TypeVar +from typing_extensions import Concatenate, ParamSpec + +from ..autoprofile.ast_tree_profiler import AstTreeProfiler +from ..autoprofile.run_module import AstTreeModuleProfiler +from ..autoprofile.util_static import modname_to_modpath +from ..cleanup import Cleanup +from .cache import LineProfilingCache + + +__all__ = ('create_runpy_wrapper',) + + +PS = ParamSpec('PS') +T = TypeVar('T') + + +THIS_MODULE = (lambda: None).__module__ + + +def _copy_module(name: str) -> ModuleType: + """ + Returns: + module (ModuleType): + Module object, which is a fresh copy of the module named + ``name`` + """ + spec = find_spec(name) + if spec is None: + raise ModuleNotFoundError(name) + assert spec.loader + assert callable(getattr(spec.loader, 'exec_module', None)) + module = ModuleType(spec.name) + for attr, value in { + '__spec__': spec, + '__name__': spec.name, + '__file__': spec.origin, + '__path__': spec.submodule_search_locations, + }.items(): + if value is not None: + setattr(module, attr, value) + spec.loader.exec_module(module) + return module + + +def _exec( + cache: LineProfilingCache, + CodeWriter: type[AstTreeProfiler], + _code, # This represents the first pos arg to `exec()` (ignored) + /, + *args, **kwargs, +) -> None: + """ + To be monkey-patched into :py:mod:`runpy`'s namespace as `exec()` + so that rewritten and autoprofiled code at ``cache.rewrite_module`` + is always executed. 
+ """ + assert cache.rewrite_module + call = cache._format_call('exec', _code, *args, **kwargs) + cache._debug_output(f'Calling via {THIS_MODULE}: `{call}`') + fname = str(cache.rewrite_module) + code_writer = CodeWriter( + fname, + list(cache.profiling_targets), + cache.profile_imports, + ) + code = compile(code_writer.profile(), fname, 'exec') + exec(code, *args, **kwargs) + + +def _run( + cache: LineProfilingCache, + runpy: ModuleType, + func: Callable[Concatenate[str, PS], T], + name: str, + resolve_target_to_path: Callable[[str], str], + CodeWriter: type[AstTreeProfiler], + target: str, + /, + *args: PS.args, **kwargs: PS.kwargs +) -> T: + call = cache._format_call('runpy.' + name, target, *args, **kwargs) + cache._debug_output(f'Calling via {THIS_MODULE}: `{call}`') + if cache.rewrite_module: + try: + filename = resolve_target_to_path(target) + profile = os.path.samefile(filename, cache.rewrite_module) + except Exception as e: + cache._debug_output( + f'{THIS_MODULE}: Failed to check whether code loaded by ' + f'`runpy.{name}(...)` is to be rewritten ' + f'({type(e).__name__}: {e})' + ) + profile = False + else: + profile = False + # If we are about to run the code to be autoprofiled, monkey-patch + # `exec()` into the `runpy` namespace which just rewrites + # `cache.rewrite_module` and executes it + with Cleanup() as cleanup: + if profile: + cleanup.patch(runpy, 'exec', partial(_exec, cache, CodeWriter)) + return func(target, *args, **kwargs) + + +def create_runpy_wrapper(cache: LineProfilingCache) -> ModuleType: + """ + Create a copy of :py:mod:`runpy` which does code rewriting similar + to :py:func:`line_profiler.autoprofile.autoprofile.run` for the + appropriate file as indicated by ``cache``. 
+ """ + runpy = _copy_module('runpy') + for func, resolver, CodeWriter in [ + ('run_path', str, AstTreeProfiler), + ('run_module', modname_to_modpath, AstTreeModuleProfiler), + ]: + impl = getattr(runpy, func) + res = cast(Callable[[str], str], resolver) # Help `mypy` out + wrapper = partial(_run, cache, runpy, impl, func, res, CodeWriter) + setattr(runpy, func, wrapper) + return runpy diff --git a/line_profiler/_threading_patches.py b/line_profiler/_threading_patches.py new file mode 100644 index 00000000..32f13471 --- /dev/null +++ b/line_profiler/_threading_patches.py @@ -0,0 +1,124 @@ +""" +Patch :py:mod:`threading` so that profiling extends consistently into +threads it creates. +""" +from __future__ import annotations + +import threading +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, Any, TypeVar +from typing_extensions import ParamSpec, Concatenate + +from ._line_profiler import ( # type: ignore + USE_LEGACY_TRACE as SHOULD_PATCH_THREADING, +) +from .line_profiler import LineProfiler +from .cleanup import Cleanup + + +__all__ = ('apply', 'SHOULD_PATCH_THREADING') + + +T = TypeVar('T') +PS = ParamSpec('PS') + +_PATCHED_MARKER = '__line_profiler_patched_threading__' + + +def make_syncing_wrapper( + func: Callable[PS, T], prof: LineProfiler, enable_count: int, +) -> Callable[PS, T]: + """ + Wrap the callable ``func`` so that when we spin up a new thread, we + sync the + :py:attr:`line_profiler.line_profiler.LineProfiler.enable_count` of + the active profiler (stored at the cache instance loaded from + :py:meth:`LineProfilingCache.load`) with ``enable_count``. + + Note: + This only seems to work as intended when using the legacy trace + system...
+ """ + @wraps(func) + def wrapper(*args: PS.args, **kwargs: PS.kwargs) -> T: + if TYPE_CHECKING: + assert hasattr(prof, 'enable_count') + assert isinstance(prof.enable_count, int) + # Note: `prof.enable_count` is most likely to be zero on the new + # thread + thread_enable_count: int = prof.enable_count + for _ in range(enable_count - thread_enable_count): + prof.enable_by_count() + try: + return func(*args, **kwargs) + finally: + # Reset enable counts to avoid problems if the thread id is + # ever reused + for _ in range(prof.enable_count - thread_enable_count): + prof.disable_by_count() + + return wrapper + + +def make_thread_init_wrapper( + prof: LineProfiler, + vanilla_impl: Callable[ + Concatenate[threading.Thread, None, Callable[..., Any] | None, PS], + None + ], +) -> Callable[ + Concatenate[threading.Thread, None, Callable[..., Any] | None, PS], None +]: + """ + Wrap the initializer of :py:class:`threading.Thread` so that the + profiler's :py:attr:`LineProfiler.enable_count` is synced up on + newly spun-up threads. + """ + @wraps(vanilla_impl) + def wrapper( + self: threading.Thread, + group: None = None, + target: Callable[..., Any] | None = None, + *args: PS.args, + **kwargs: PS.kwargs + ) -> None: + enable_count: int | None = getattr(prof, 'enable_count', None) + if target is not None and enable_count: + if TYPE_CHECKING: + assert prof is not None + target = make_syncing_wrapper(target, prof, enable_count) + vanilla_impl(self, group, target, *args, **kwargs) + + return wrapper + + +def apply(cleanup: Cleanup, prof: LineProfiler) -> None: + """ + Set up profiling in threads started by :py:mod:`threading` by + applying patches to the module. 
+ + Args: + cleanup (Cleanup) + Cleanup instance managing the profiling session + + Side effects: + - :py:mod:`threading` marked as having been set up + + - The following methods and functions patched: + + - :py:meth:`threading.Thread.__init__` + + - Cleanup callbacks registered via ``cleanup.add_cleanup()`` + + Note: + This is a no-op when using :py:mod:`sys.monitoring`-based + profiling. + """ + if not SHOULD_PATCH_THREADING: + return + if getattr(threading, _PATCHED_MARKER, False): + return + init_wrapper = make_thread_init_wrapper(prof, threading.Thread.__init__) + cleanup.patch(threading.Thread, '__init__', init_wrapper) + cleanup.patch(threading, _PATCHED_MARKER, True) diff --git a/line_profiler/cleanup.py b/line_profiler/cleanup.py new file mode 100644 index 00000000..11d21890 --- /dev/null +++ b/line_profiler/cleanup.py @@ -0,0 +1,418 @@ +""" +Utilities for cleaning up after ourselves. +""" +from __future__ import annotations + +from collections.abc import ( + Callable, Generator, Iterable, Mapping, MutableMapping, +) +from functools import partial +from inspect import getattr_static +from operator import setitem +from pathlib import Path +from typing import Any, TypeVar, cast +from typing_extensions import Concatenate, ParamSpec, Self + +from .line_profiler_utils import CallbackRepr, make_tempfile +from . import _diagnostics as diagnostics + + +__all__ = ('Cleanup',) + +PS = ParamSpec('PS') +K = TypeVar('K') +V = TypeVar('V') +_Stacks = dict[float, list[Callable[[], Any]]] +_StackContexts = list[_Stacks] + + +_CALLBACK_REPR_HELPER = CallbackRepr(maxother=cast(int, float('inf'))) +_CALLBACK_REPR = _CALLBACK_REPR_HELPER.repr + + +class Cleanup: + """ + Object which holds cleanup callbacks. Also provides convenience + methods for creating tempfiles, updating mappings, and setting + attributes on objects. 
+ """ + def __init__(self, *_, **__) -> None: + self._contexts: _StackContexts = [] + + def __enter__(self) -> Self: + """ + Returns: + The instance + + Note: + This context manager is reentrant; entering the context + creates a new set of cleanup stacks, which is then cleaned up + on :py:meth:`~.__exit__`. + + Example: + >>> strings = [] + >>> add = strings.append + >>> with Cleanup() as cleanup: + ... cleanup.add_cleanup(add, 'one') + ... # Increased priority + ... cleanup.add_cleanup_with_priority(add, 1, 'two') + ... add('three') + ... with cleanup: + ... # Decreased priority + ... cleanup.add_cleanup_with_priority( + ... add, -1, 'four', + ... ) + ... cleanup.add_cleanup(add, 'five') + ... add('six') + ... add('seven') + ... # Increased priority + ... cleanup.add_cleanup_with_priority(add, 1, 'eight') + ... + >>> strings # doctest: +NORMALIZE_WHITESPACE + ['three', 'six', 'five', 'four', 'seven', 'eight', 'two', + 'one'] + """ + self._contexts.append({}) + return self + + def __exit__(self, *_, **__) -> Any: + """ + Call ``~.cleanup(1)``, clearing the level of cleanup stacks we + previously :py:meth:`~.__enter__`-ed into. + """ + self.cleanup(1, reason='context exit') + + # Cleanup methods + + def cleanup( + self, levels: int | None = None, *, reason: str | None = None, + ) -> None: + """ + Pop cleanup callbacks from the internal stacks added via + :py:meth:`~.add_cleanup` etc. and call them in order.
+ + Args: + levels (int | None): + Number of stack levels to clear; passing :py:const`None` + clears the entire stack of callback stacks + reason (str | None): + Optional description of the reason for cleaning up + """ + def pop_all_contexts( + contexts: _StackContexts, + ) -> Generator[_Stacks, None, None]: + while contexts: + yield contexts.pop() + + def pop_n_levels_of_contexts( + contexts: _StackContexts, n: int, + ) -> Generator[_Stacks, None, None]: + for _ in range(n): + try: + yield contexts.pop() + except IndexError: # Ran out of levels + return + + pop_contexts: Iterable[_Stacks] + if levels is None: + pop_contexts = pop_all_contexts(self._contexts) + else: + pop_contexts = pop_n_levels_of_contexts(self._contexts, levels) + cleanup = partial(self._cleanup, self._debug_output, reason=reason) + for stacks in pop_contexts: + cleanup(stacks) + + @staticmethod + def _cleanup( + log: Callable[[str], Any], stacks: _Stacks, reason: str | None, + ) -> None: + ncallbacks_total = sum(len(stack) for stack in stacks.values()) + note = f'{ncallbacks_total} callback(s)' + if reason: + note = f'{reason}; {note}' + if not ncallbacks_total: + log(f'Cleanup aborted ({note})') + return + # Bookend the cleanup loop with log messages to help detect if + # child processes are prematurely terminated + log(f'Starting cleanup ({note})...') + ncallbacks_run = 0 + for priority in sorted(stacks, reverse=True): + callbacks = stacks.pop(priority) + while callbacks: + callback = callbacks.pop() + callback_repr = _CALLBACK_REPR(callback) + ncallbacks_run += 1 + try: + callback() + except Exception as e: + state = 'failed' + msg = f'{callback_repr}: {type(e).__name__}: {e}' + else: + state, msg = 'succeeded', f'{callback_repr}' + log( + f'- Cleanup {state} ' + f'({ncallbacks_run}/{ncallbacks_total}): {msg}', + ) + log(f'... 
cleanup completed ({note})') + + def add_cleanup( + self, callback: Callable[PS, Any], *args: PS.args, **kwargs: PS.kwargs, + ) -> None: + """ + Shorthand for calling :py:meth:`~.add_cleanup_with_priority` + with ``priority=0``, which should be considered the default. + """ + self.add_cleanup_with_priority(callback, 0, *args, **kwargs) + + def add_cleanup_with_priority( + self, callback: Callable[PS, Any], priority: float, /, + *args: PS.args, **kwargs: PS.kwargs, + ) -> None: + """ + Add a cleanup callback to the internal stacks. + + Args: + callback (Callable[..., Any]): + Callback to be called at cleanup + priority (float): + Numeric priority value; callbacks with a HIGHER value + are invoked BEFORE those with lower values + *args, **kwargs: + Arguments ``callback`` should be called with + + Example: + >>> strings = [] + >>> cleanup = Cleanup() + >>> # Default priority + >>> cleanup.add_cleanup(strings.append, 'first') + >>> # Decreased priority + >>> cleanup.add_cleanup_with_priority( + ... strings.append, -1, 'second', + ... ) + >>> # Increased priority + >>> cleanup.add_cleanup_with_priority( + ... strings.append, 1, 'third', + ...
) + >>> cleanup.add_cleanup(strings.append, 'fourth') + >>> assert not strings + >>> cleanup.cleanup() + >>> strings + ['third', 'fourth', 'first', 'second'] + """ + if args or kwargs: + callback = partial(callback, *args, **kwargs) + self._current_context.setdefault(priority, []).append(callback) + header = 'Cleanup callback added' + if priority: + header = f'{header} (priority: {priority})' + self._debug_output(f'{header}: {_CALLBACK_REPR(callback)}') + + # Convenience methods + + def update_mapping( + self, + mapping: MutableMapping[K, V], + updates: Mapping[K, V], + *, + _format_debug_msg: Callable[[Mapping[K, V], K, str], str] = ( + lambda mapping, key, change: 'Update {}[{!r}]: {}'.format( + object.__repr__(mapping), key, change, + ) + ), + ) -> None: + """ + Update a mapping with another and add cleanup callbacks to + reverse them. + + Args: + mapping (MutableMapping[K, V]): + Mapping to be updated + updates (Mapping[K, V]): + Mapping containing the updates + + Example: + >>> d1 = {1: 2, 3: 4} + >>> d2 = d1.copy() + >>> updates = {0: -1, 3: 5} + >>> with Cleanup() as cleanup: + ... cleanup.update_mapping(d1, updates) + ... for key, value in updates.items(): + ... assert d1[key] == value + ... + >>> assert d1 == d2 + """ + for key, value in updates.items(): + try: + old = mapping[key] + except KeyError: + self.add_cleanup(mapping.pop, key, None) + change = f'{value!r} (new)' + else: + self.add_cleanup(setitem, mapping, key, old) + change = f'{old!r} -> {value!r}' + self._debug_output(_format_debug_msg(mapping, key, change)) + mapping[key] = value + + def make_tempfile( + self, *, + delete: bool = True, + priority: float = 0, + _format_debug_msg: Callable[[Path], str] = ( + 'Created tempfile: {}'.format + ), + **kwargs + ) -> Path: + """ + Create a fresh tempfile with :py:func:`tempfile.mkstemp`. 
+ + Args: + delete (bool): + Whether to remove the file on cleanup + priority (float): + Cleanup priority (see + :py:meth:`~.add_cleanup_with_priority`) + **kwargs: + Passed to :py:func:`tempfile.mkstemp` + + Returns: + path (Path): + Path to the created file. + + Example: + >>> prefix, suffix = 'my_file_', '.txt' + >>> with Cleanup() as cleanup: + ... path = cleanup.make_tempfile( + ... prefix=prefix, suffix=suffix, + ... ) + ... assert path.exists() + ... assert path.name.startswith(prefix) + ... assert path.name.endswith(suffix) + ... + >>> assert not path.exists() + """ + path = make_tempfile(**kwargs) + self._debug_output(_format_debug_msg(path)) + if delete: + self.add_cleanup_with_priority( + path.unlink, priority, missing_ok=True, + ) + return path + + def patch( + self, obj: Any, attr: str, value: Any, *, + name: str | None = None, + static: bool = True, + cleanup: bool = True, + priority: float = 0, + ) -> None: + """ + Patch an attribute on an object. + + Args: + obj (Any): + Object to be patched + attr (str): + Name of the attribute + value (Any): + Value to be assigned to said attribute of ``obj`` + name (str | None): + Optional name for ``obj`` to be used in debug messages + static (bool): + Whether to use :py:func:`inspect.getattr_static` to + get the current value of the attribute + cleanup (bool): + Whether to reverse the patch (by resetting or deleting + the attribute) on cleanup + priority (float): + Cleanup priority (see + :py:meth:`~.add_cleanup_with_priority`) + + Example: + >>> class Object(object): + ... pass # Allow setting arbitrary attributes + ... + >>> + >>> obj = Object() + >>> obj.foo = 1 + >>> with Cleanup() as cleanup: + ... cleanup.patch(obj, 'foo', 2) + ... cleanup.patch(obj, 'bar', 3) + ... assert obj.foo == 2 + ... assert obj.bar == 3 + ... 
+ >>> assert obj.foo == 1 + >>> assert not hasattr(obj, 'bar') + """ + if cleanup: + add_cleanup: Callable[ + Concatenate[Callable[..., Any], float, ...], Any + ] = self.add_cleanup_with_priority + else: + # ... yeah gotta disagree with flake8, a lambda makes + # perfect sense here + add_cleanup = lambda *_, **__: None # noqa: E731 + get_attribute = getattr_static if static else getattr + + try: + old = get_attribute(obj, attr) + except AttributeError: + add_cleanup(delattr, priority, obj, attr) + else: + add_cleanup(setattr, priority, obj, attr, old) + setattr(obj, attr, value) + if name is None: + name = self._get_name(obj) + msg = 'Patched `{}.{}` -> `{}`'.format(name, attr, value) + self._debug_output(msg) + + # Helper methods + + @staticmethod + def _get_name(obj: Any, /) -> str: + """ + Get an appropriate name for an arbitrary object. + + Example: + >>> import textwrap + >>> + >>> + >>> Cleanup._get_name(textwrap) + 'textwrap' + >>> Cleanup._get_name(textwrap.dedent) + 'textwrap.dedent' + >>> Cleanup._get_name(str) + 'str' + >>> Cleanup._get_name(print) + 'print' + >>> Cleanup._get_name(object()) # doctest: +ELLIPSIS + '' + """ + if hasattr(obj, '__qualname__'): + name = obj.__qualname__ + elif hasattr(obj, '__name__'): + name = obj.__name__ + else: + return repr(obj) + if hasattr(obj, '__module__'): + if obj.__module__ not in ('builtins', '__builtins__'): + name = f'{obj.__module__}.{name}' + return str(name) + + def _debug_output(self, msg: str, /) -> None: + """ + Write debugging output. + + Note: + This default implementation just writes to the logger. 
+ """ + diagnostics.log.debug(msg) + + @property + def _current_context(self) -> _Stacks: + try: + return self._contexts[-1] + except IndexError: + ctx: _Stacks = {} + self._contexts.append(ctx) + return ctx diff --git a/line_profiler/curated_profiling.py b/line_profiler/curated_profiling.py new file mode 100644 index 00000000..e87dec60 --- /dev/null +++ b/line_profiler/curated_profiling.py @@ -0,0 +1,230 @@ +""" +Tools for setting up profiling in a curated environment (e.g. with +the use of :py:mod:`kernprof`). +""" +from __future__ import annotations + +import builtins +import dataclasses +import os +import warnings +from collections.abc import Collection +from io import StringIO +from textwrap import indent +from typing import Any, TextIO, cast +from typing_extensions import Self + +from . import _diagnostics as diagnostics, profile as _global_profiler +from ._threading_patches import apply as apply_threading_patches +from .autoprofile.autoprofile import ( + _extend_line_profiler_for_profiling_imports as upgrade_profiler, +) +from .autoprofile.util_static import modpath_to_modname +from .autoprofile.eager_preimports import ( + is_dotted_path, write_eager_import_module, +) +from .cleanup import Cleanup +from .cli_utils import short_string_path +from .line_profiler import LineProfiler +from .profiler_mixin import ByCountProfilerMixin + + +__all__ = ('ClassifiedPreimportTargets', 'CuratedProfilerContext') + + +@dataclasses.dataclass +class ClassifiedPreimportTargets: + """ + Pre-import targets classified into three bins: ``regular`` targets, + targets to ``recurse`` into, and ``invalid`` targets + """ + regular: list[str] = dataclasses.field(default_factory=list) + recurse: list[str] = dataclasses.field(default_factory=list) + invalid: list[str] = dataclasses.field(default_factory=list) + + def __bool__(self) -> bool: + return bool(self.regular or self.recurse) + + def write_preimport_module( + self, fobj: TextIO, *, debug: bool | None = None, **kwargs + ) -> None: 
+ """ + Convenience interface with + :py:func:`~.write_eager_import_module`, writing a module which + when imported sets up profiling of the targets. + + Args: + fobj (TextIO): + File object to write said module to. + debug (Optional[bool]): + Whether to generate debugging outputs. + kwargs: + Passed to :py:func:`~.write_eager_import_module`. + """ + if self.invalid: + invalid_targets = sorted(set(self.invalid)) + msg = ( + '{} profile-on-import target{} cannot be converted to ' + 'dotted-path form: {!r}'.format( + len(invalid_targets), + '' if len(invalid_targets) == 1 else 's', + invalid_targets, + ) + ) + # Log before warn in case the warning is raised + diagnostics.log.warning(msg) + warnings.warn(msg, stacklevel=2) + + if not self: + return None + # Note: `ty` (but not `mypy`) keeps complaining about our + # splatting this dict; explicitly use `Any` to tell it to shut + # up. + write_module_kwargs: dict[str, Any] = { + 'dotted_paths': self.regular, + 'recurse': self.recurse, + **kwargs, + } + if diagnostics.DEBUG if debug is None else debug: + with StringIO() as sio: + write_eager_import_module(stream=sio, **write_module_kwargs) + code = sio.getvalue() + print(code, file=fobj) + if hasattr(fobj, 'name'): + fobj_repr = repr(short_string_path(str(fobj.name))) + else: + fobj_repr = repr(fobj) # Fall back + diagnostics.log.debug( + f'Wrote temporary module for pre-imports to {fobj_repr}:\n' + + indent(code, ' ') + ) + else: + write_eager_import_module(stream=fobj, **write_module_kwargs) + + @classmethod + def from_targets( + cls, + targets: Collection[str], + exclude: Collection[os.PathLike[str] | str] = (), + ) -> Self: + """ + Create an instance based on a collection of targets + (like what is supplied to :cmd:`kernprof --prof-mod=...`). + + Args: + targets (Collection[str]) + Collection of dotted paths and filenames to profile. + exclude (Collection[str]) + Collections of filenames which are explicitly excluded + from being profiled.
+ + Return: + New instance. + """ + filtered_targets = [] + recurse_targets = [] + invalid_targets = [] + for target in targets: + if is_dotted_path(target): + modname = target + else: + # Paths already normalized by + # `_normalize_profiling_targets()` + if not os.path.exists(target): + invalid_targets.append(target) + continue + if any( + os.path.samefile(target, excluded) for excluded in exclude + ): + # Ignore the script to be run in eager importing + # (`line_profiler.autoprofile.autoprofile.run()` + # will handle it) + continue + modname = modpath_to_modname(target, hide_init=False) + if modname is None: # Not import-able + invalid_targets.append(target) + continue + if modname.endswith('.__init__'): + modname = modname.rpartition('.')[0] + filtered_targets.append(modname) + else: + recurse_targets.append(modname) + return cls(filtered_targets, recurse_targets, invalid_targets) + + +class CuratedProfilerContext(Cleanup): + """ + Context manager for handling various bookkeeping tasks when setting + up and tearing down profiling: + + - Slipping ``prof`` into the builtin namespace (if + ``insert_builtin`` is true) and :py::deco:`~.profile` + - Patch :py:class:`threading.Thread` so that line-profiling is + enabled on new threads if it is on the spawning threads + - At exit, clearing the ``enable_count`` of ``prof``, properly + disabling it + + Notes: + + - The attributes on this object are to be considered + implementation details, but not its methods and their + signatures. + + - In contrast to the base class (:py:class:`Cleanup`), while + this context manager is still reentrant, reentering in nested + `with: ...` statements is a no-op. 
+ """ + def __init__( + self, + prof: ByCountProfilerMixin, + insert_builtin: bool = False, + builtin_loc: str = 'profile', + ) -> None: + super().__init__() + self.prof = prof + self.insert_builtin = insert_builtin + self.builtin_loc = builtin_loc + self._installed = False + self._kpo = _global_profiler._kernprof_overwrite + + def _global_install(self, prof: ByCountProfilerMixin | None) -> None: + # Wrapper to convince type-checkers it is okay to pass these + # stuff to `._kernprof_overwrite()`. We don't want to patch + # that method's signature because passing non `LineProfiler` + # objects to it should be the exception, not the norm. + self._kpo(cast(LineProfiler, prof)) + + @staticmethod + def _disable_profiler(prof: ByCountProfilerMixin) -> None: + for _ in range(getattr(prof, 'enable_count', 0)): + prof.disable_by_count() + + def install(self) -> None: + if self._installed: + return + # Equip the profiler instance with the + # `.add_imported_function_or_module()` pseudo-method + upgrade_profiler(self.prof) + # Overwrite the explicit profiler (`@line_profiler.profile`) + self._global_install(self.prof) + self.add_cleanup(self._global_install, None) + # Patch `threading` + if isinstance(self.prof, LineProfiler): + apply_threading_patches(self, self.prof) + # Set up hooks to deal with inserting `.prof` as a builtin name + if self.insert_builtin: + self.patch(builtins, self.builtin_loc, self.prof) + # Disable the profiler + self.add_cleanup(self._disable_profiler, self.prof) + + self.patch(self, '_installed', True) + + def uninstall(self) -> None: + self.cleanup(reason='uninstalling profiling context') + + def __enter__(self) -> Self: + self.install() + return self + + def __exit__(self, *_, **__) -> None: + self.uninstall() diff --git a/line_profiler/line_profiler.py b/line_profiler/line_profiler.py index bede6f6c..539f8703 100755 --- a/line_profiler/line_profiler.py +++ b/line_profiler/line_profiler.py @@ -18,21 +18,12 @@ import tempfile import types import 
tokenize +import warnings from argparse import ArgumentParser +from collections.abc import Callable, Collection, Mapping, Sequence from datetime import datetime from os import PathLike -from typing import ( - TYPE_CHECKING, - IO, - Callable, - Literal, - Mapping, - Protocol, - Sequence, - TypeVar, - cast, - Tuple, -) +from typing import TYPE_CHECKING, IO, Any, Literal, Protocol, TypeVar, cast try: from ._line_profiler import ( @@ -62,7 +53,7 @@ class _IPythonLike(Protocol): def register_magics(self, magics: type) -> None: ... PS = ParamSpec('PS') - _TimingsMap = Mapping[Tuple[str, int, str], list[Tuple[int, int, int]]] + _TimingsMap = Mapping[tuple[str, int, str], list[tuple[int, int, int]]] T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) @@ -226,6 +217,15 @@ def tokeneater( return super().tokeneater(type, token, srowcol, erowcol, line) +class _EmptyFileError(OSError): + """ + Error raised when trying to read profiling data from an empty file. + """ + def __init__(self, file: PathLike[str] | str) -> None: + super().__init__(str(file)) + self.file = file + + class _WrapperInfo: """ Helper object for holding the state of a wrapper function. @@ -264,8 +264,8 @@ def __eq__(self, other: object) -> bool: Example: >>> from copy import deepcopy >>> stats1 = LineStats( - ... {('foo', 1, 'spam.py'): [(2, 10, 300)], - ... ('bar', 10, 'spam.py'): + ... {('spam.py', 1, 'foo'): [(2, 10, 300)], + ... ('spam.py', 10, 'bar'): ... [(11, 2, 1000), (12, 1, 500)]}, ... 1E-6) >>> stats2 = deepcopy(stats1) @@ -274,7 +274,7 @@ def __eq__(self, other: object) -> bool: >>> assert stats2 != stats1 >>> stats3 = deepcopy(stats1) >>> assert stats1 == stats3 is not stats1 - >>> stats3.timings['foo', 1, 'spam.py'][:] = [(2, 11, 330)] + >>> stats3.timings['spam.py', 1, 'foo'][:] = [(2, 11, 330)] >>> assert stats3 != stats1 """ for attr in 'timings', 'unit': @@ -290,20 +290,20 @@ def __add__(self, other: _StatsLike) -> Self: """ Example: >>> stats1 = LineStats( - ... 
{('foo', 1, 'spam.py'): [(2, 10, 300)], - ... ('bar', 10, 'spam.py'): + ... {('spam.py', 1, 'foo'): [(2, 10, 300)], + ... ('spam.py', 10, 'bar'): ... [(11, 2, 1000), (12, 1, 500)]}, ... 1E-6) >>> stats2 = LineStats( - ... {('bar', 10, 'spam.py'): + ... {('spam.py', 10, 'bar'): ... [(11, 10, 20000), (12, 5, 1000)], - ... ('baz', 5, 'eggs.py'): [(5, 2, 5000)]}, + ... ('eggs.py', 5, 'baz'): [(5, 2, 5000)]}, ... 1E-7) >>> stats_sum = LineStats( - ... {('foo', 1, 'spam.py'): [(2, 10, 300)], - ... ('bar', 10, 'spam.py'): + ... {('spam.py', 1, 'foo'): [(2, 10, 300)], + ... ('spam.py', 10, 'bar'): ... [(11, 12, 3000), (12, 6, 600)], - ... ('baz', 5, 'eggs.py'): [(5, 2, 500)]}, + ... ('eggs.py', 5, 'baz'): [(5, 2, 500)]}, ... 1E-6) >>> assert stats1 + stats2 == stats2 + stats1 == stats_sum """ @@ -314,20 +314,20 @@ def __iadd__(self, other: _StatsLike) -> Self: """ Example: >>> stats1 = LineStats( - ... {('foo', 1, 'spam.py'): [(2, 10, 300)], - ... ('bar', 10, 'spam.py'): + ... {('spam.py', 1, 'foo'): [(2, 10, 300)], + ... ('spam.py', 10, 'bar'): ... [(11, 2, 1000), (12, 1, 500)]}, ... 1E-6) >>> stats2 = LineStats( - ... {('bar', 10, 'spam.py'): + ... {('spam.py', 10, 'bar'): ... [(11, 10, 20000), (12, 5, 1000)], - ... ('baz', 5, 'eggs.py'): [(5, 2, 5000)]}, + ... ('eggs.py', 5, 'baz'): [(5, 2, 5000)]}, ... 1E-7) >>> stats_sum = LineStats( - ... {('foo', 1, 'spam.py'): [(2, 10, 300)], - ... ('bar', 10, 'spam.py'): + ... {('spam.py', 1, 'foo'): [(2, 10, 300)], + ... ('spam.py', 10, 'bar'): ... [(11, 12, 3000), (12, 6, 600)], - ... ('baz', 5, 'eggs.py'): [(5, 2, 500)]}, + ... ('eggs.py', 5, 'baz'): [(5, 2, 500)]}, ... 1E-6) >>> address = id(stats2) >>> stats2 += stats1 @@ -367,17 +367,85 @@ def to_file(self, filename: PathLike[str] | str) -> None: with open(filename, 'wb') as f: pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) + @classmethod + def get_empty_instance(cls) -> Self: + """ + Returns: + instance (LineStats): + New instance without any profiling data. 
@classmethod
def from_files(
    cls,
    file: PathLike[str] | str,
    /,
    *files: PathLike[str] | str,
    on_empty: Literal['ignore', 'warn', 'error'] = 'warn',
    on_defective: Literal['ignore', 'warn', 'error'] = 'error',
) -> Self:
    """
    Utility function to load an instance from the given filenames.

    Args:
        file (PathLike[str] | str):
            File to load profiling data from
        *files (PathLike[str] | str):
            Ditto above
        on_empty, on_defective (Literal['ignore', 'warn', 'error']):
            What to do if some files are empty (resp. otherwise fail
            to open or load): ``'ignore'`` those files, skip them
            but with a ``'warn'``-ing, or raise the ``'error'`` as
            soon as one is encountered

    Returns:
        instance (LineStats):
            New instance combining the data of all the files that
            loaded; an empty instance if none did

    Raises:
        _EmptyFileError:
            If a file is empty and ``on_empty`` is ``'error'``
        Exception:
            Whatever opening or unpickling a file raised, if
            ``on_defective`` is ``'error'``

    Note:
        The files are read with :py:func:`pickle.load`, which can
        execute arbitrary code; only load files from trusted
        sources.
    """
    all_files = [file, *files]
    stats_objs = []
    # Both are insertion-ordered so that the messages below are
    # deterministic
    empty_files: dict[str, None] = {}
    failures: dict[str, str] = {}

    for path in all_files:
        try:
            # Opening happens inside the `try`, so that unreadable
            # or missing files are also subject to `on_defective`
            with open(path, 'rb') as f:
                # Check the size on the handle we actually read
                # from, instead of looking the path up again
                if not os.fstat(f.fileno()).st_size:
                    raise _EmptyFileError(path)
                stats_objs.append(pickle.load(f))
        except _EmptyFileError as e:
            if on_empty == 'error':
                raise
            empty_files[str(e.file)] = None
        except Exception as e:
            if on_defective == 'error':
                raise
            failures[str(path)] = f'{type(e).__name__}: {e}'

    problems: Collection[Any]
    for problems, description, behavior in [
        (list(empty_files), 'is/are empty and thus skipped', on_empty),
        (failures, 'failed to load and is/are skipped', on_defective),
    ]:
        if not problems:
            continue
        msg = '{} file(s) out of {} {}: {!r}'.format(
            len(problems), len(all_files), description, problems,
        )
        if behavior == 'warn':
            # Log before warning because warnings may be promoted to
            # errors
            diagnostics.log.warning(msg)
            warnings.warn(msg, stacklevel=2)
        else:  # 'ignore'
            diagnostics.log.debug(msg)

    if not stats_objs:
        return cls.get_empty_instance()
    return cls.from_stats_objects(*stats_objs)
unit (float): diff --git a/line_profiler/line_profiler_utils.py b/line_profiler/line_profiler_utils.py index 887cdd55..eae5d27f 100644 --- a/line_profiler/line_profiler_utils.py +++ b/line_profiler/line_profiler_utils.py @@ -5,10 +5,23 @@ from __future__ import annotations import enum -import typing +import os +import sys +from collections.abc import Callable, Collection, Mapping, Sequence +from functools import partial +from pathlib import Path +from reprlib import Repr +from tempfile import mkstemp +from textwrap import indent +from types import MethodType +from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from typing_extensions import Self, Unpack -if typing.TYPE_CHECKING: - from typing_extensions import Self + +__all__ = ('StringEnum', 'CallbackRepr', 'block_indent', 'make_tempfile') + +# Note: `typing.AnyStr` deprecated since 3.13 +AnyStr = TypeVar('AnyStr', str, bytes) class _StrEnumBase(str, enum.Enum): @@ -49,7 +62,7 @@ def __str__(self) -> str: try: from enum import StrEnum as _StrEnum except ImportError: - if not typing.TYPE_CHECKING: # Don't confuse the typechecker + if not TYPE_CHECKING: # Don't confuse the typechecker _StrEnum = _StrEnumBase @@ -89,3 +102,251 @@ def _missing_(cls, value: object) -> Self | None: for name, instance in cls.__members__.items() } return members.get(value.casefold()) + + +class _ReprAttributes(TypedDict, total=False): + """ + Note: + We use this typed dict instead of directly supplying them in the + :py:meth:`CallbackRepr.__init__()` signature, because we don't + want to bother with the default values there. + """ + maxlevel: int + maxtuple: int + maxlist: int + maxarray: int + maxdict: int + maxset: int + maxfrozenset: int + maxdeque: int + maxstring: int + maxlog: int + maxother: int + fillvalue: str + indent: str | int | None + + +class CallbackRepr(Repr): + """ + :py:class:`reprlib.Repr` subclass to help with representing cleanup + callbacks, special-casing certain relevant object types (see + examples below). 
+ + Example: + >>> from functools import partial + >>> from sys import version_info + + >>> class MyEnviron(dict): + ... def some_method(self) -> None: + ... ... + ... + >>> + >>> class MyRepr(CallbackRepr): + ... # Since we can't instantiate a new `os._Environ`, test + ... # the relevant method with a mock + ... repr_MyEnviron = CallbackRepr.repr__Environ + ... + >>> + >>> r = MyRepr(maxenv=3, maxargs=4, maxstring=15) + + Environ-dict formatting: + + >>> my_env = MyEnviron( + ... foo='1', + ... bar='2', + ... this_varname_is_long_but_isnt_truncated=( + ... "THIS VALUE IS TRUNCATED BECAUSE IT'S TOO LONG" + ... ), + ... baz='4', + ... ) + >>> print(r.repr(my_env)) + environ({'foo': '1', 'bar': '2', \ +'this_varname_is_long_but_isnt_truncated': 'THIS ... LONG', ...}) + + Bound-method formatting: + + >>> r.maxenv = 0 + >>> print(r.repr(my_env.some_method)) + + + Partial-object formatting: + + >>> r.maxargs = 0 + >>> callback_1 = partial(int, base=8) + >>> print(r.repr(callback_1)) + functools.partial(, ...) + + Indentation (Python 3.12+): + + >>> if version_info < (3, 12): + ... from pytest import skip + ... + ... skip( + ... '`Repr.indent` not available on {}.{}.{}' + ... .format(*sys.version_info) + ...
def repr__Environ(self, env: os._Environ[AnyStr], level: int) -> str:
    """
    Format :py:data:`os.environ` or :py:data:`os.environb`.

    Args:
        env (os._Environ):
            Environment mapping to represent
        level (int):
            Current recursion level; the values are represented at
            one level below it

    Returns:
        String of the form ``environ({...})``, showing at most
        ``.maxenv`` variables
    """
    inner_level = level - 1
    # Only the values go through `.repr1()` (and are thus subject to
    # truncation); the names are always shown in full
    entries = [
        f'{name!r}: {self.repr1(value, level=inner_level)}'
        for name, value in env.items()
    ]
    return self._format_items(entries, ('environ({', '})'), self.maxenv)
def block_indent(string: str, prefix: str, fill_char: str = ' ') -> str:
    r"""
    Prepend ``prefix`` to the first line of ``string``, and align all
    subsequent (non-blank) lines with it by padding them with
    ``fill_char``.

    Args:
        string (str):
            Possibly multi-line string to indent
        prefix (str):
            String to put before the first line
        fill_char (str):
            Padding for the other lines, repeated once for each
            character in ``prefix``

    Returns:
        The indented string

    Example:
        >>> string = 'foo\nbar\nbaz'
        >>> print(string)
        foo
        bar
        baz
        >>> print(block_indent(string, '++++', '-'))
        ++++foo
        ----bar
        ----baz

        The prefix is kept intact even if the first line is blank:

        >>> block_indent('\nbar', '++', '-')
        '++\n--bar'
    """
    # Handle the first line separately instead of indenting the whole
    # string and then chopping the padding off its start: blank lines
    # are left alone by `textwrap.indent()`, so there may not be any
    # padding to chop off
    head, *tail = string.splitlines(keepends=True) or ['']
    filler = fill_char * len(prefix)
    return prefix + head + indent(''.join(tail), filler)
+prefix = '_line_profiler-profiling-hook-' +suffix = '' + +[tool.line_profiler.child_processes.multiprocessing] + +# - `multiprocessing.catch_sigterm` (bool): +# Whether to set a hook for the child-process session caches to +# `.cleanup()` on SIGTERM +catch_sigterm = true + +# List of individual patches to apply to `multiprocessing` + +[tool.line_profiler.child_processes.multiprocessing.patches] + +# - `multiprocessing.patches.pool` (bool): +# Whether to patch `multiprocessing.pool`: +# - POSIX: +# Patch `multiprocessing.pool.worker()` so that each child process +# writes profiling output as it runs out of tasks +# - Windows: +# Patch `multiprocessing.pool.Pool` so that each task writes +# profiling output before pushing the result back to the parent +# process +pool = true +# - `multiprocessing.patches.process` (bool): +# Whether to patch `multiprocessing.process.BaseProcess`, so that each +# child process writes profiling output before exiting +process = true +# NOTE: for the best result, stick to the default and have both applied + +# - `multiprocessing.patches.logging` (bool): +# Whether to patch logging functions in `multiprocessing.util`, so +# that the internal logs of `multiprocessing` are teed to the session +# cache's debug logs +logging = false +# - `multiprocessing.patches.child_pids` (bool): +# Whether to patch `multiprocessing.pool.worker()` and +# `multiprocessing.pool.Pool` so that the parent process keeps +# track of the workload sent to each child, suppressing stray warnings +# about empty output files where appropriate +child_pids = true + +# Polling controls, used for the `pool` patch + +[tool.line_profiler.child_processes.multiprocessing.polling] + +# - `multiprocessing.polling.cooldown` (float): +# Cooldown time (seconds) before successive polls on child processes +# (set to <= 0 to disable cooldowns) +cooldown = 0.03125 # 1/32-nd of a second +# - `multiprocessing.polling.timeout` (float): +# Time (seconds) before the main process disregards
the alive-ness of +# child processes, and unblocks calls to `.terminate()` a (most +# probably errored-out) child anyway (set to <= 0 to disable timeouts) +timeout = 0.25 +# - `multiprocessing.polling.on_timeout` +# (Literal['error', 'warn', 'ignore']): +# What to do when the above timeout is exhausted, before actually +# `.terminate()`-ing the child process: +# - 'error': raise an error +# - `warn`: issue a warning +# - `ignore`: nothing +on_timeout = 'warn' + +# XXX: --- End of implementation details --- diff --git a/line_profiler/toml_config.py b/line_profiler/toml_config.py index 781dc60d..236e86da 100644 --- a/line_profiler/toml_config.py +++ b/line_profiler/toml_config.py @@ -106,7 +106,10 @@ def get_subconfig( get_subtable(self.conf_dict, headers, allow_absence=allow_absence), ) new_subtable = [*self.subtable, *headers] - return type(self)(new_dict, self.path, new_subtable) + new_instance = type(self)(new_dict, self.path, new_subtable) + if copy: + new_instance = new_instance.copy() + return new_instance @classmethod def from_default(cls, *, copy: bool = True) -> ConfigSource: @@ -355,7 +358,8 @@ def iter_configs(dir_path): def get_subtable( - table: Mapping[K, Mapping], keys: Sequence[K], *, allow_absence: bool = True + table: Mapping[K, Mapping], keys: Sequence[K], *, + allow_absence: bool = True, ) -> Mapping: """ Arguments: diff --git a/setup.py b/setup.py index 739d5d96..a0810cc6 100755 --- a/setup.py +++ b/setup.py @@ -314,7 +314,9 @@ def run_cythonize(force=False): setupkw['long_description_content_type'] = 'text/x-rst' setupkw['license'] = 'BSD' setupkw['packages'] = list(setuptools.find_packages()) - setupkw['py_modules'] = ['kernprof', 'line_profiler'] + setupkw['py_modules'] = [ + 'kernprof', 'line_profiler', '_line_profiler_hooks', + ] setupkw['python_requires'] = '>=3.10' setupkw['license_files'] = ['LICENSE.txt', 'LICENSE_Python.txt'] setupkw['package_data'] = {'line_profiler': ['py.typed', '*.pyi', '*.toml']} diff --git a/tests/conftest.py 
class _PyfuncCallImpl(Protocol):
    """
    Call signature of an implementation of the ``pytest_pyfunc_call``
    hook: the test item is passed as the ``pyfuncitem`` keyword
    argument.
    """
    def __call__(self, *, pyfuncitem: pytest.Function) -> Any:
        ...
@dataclasses.dataclass
class _RetryEntry:
    """
    Record of a test which has been retried.

    Attributes:
        func (pytest.Function):
            Test function item
        retries (int):
            Number of retries recorded for the test
        status (Literal['passed', 'failed', 'skipped']):
            Final outcome of the test
    """
    func: pytest.Function
    retries: int
    status: _Status

    def _get_name(self, with_params: bool) -> str:
        """
        Build the name ``[<path>::][<prefix>.]<name>`` for the test,
        using its parametrized name if ``with_params`` is true and
        its original name otherwise.
        """
        path, prefix = self._name_prefixes
        if with_params:
            name = self.func.name
        else:
            name = self.func.originalname
        if prefix:
            name = f'{prefix}.{name}'
        if path:
            name = f'{path}::{name}'
        return name

    @classmethod
    def add_entry(cls, func: pytest.Function, *args, **kwargs) -> Self:
        """
        Create an entry for ``func`` (``*args`` and ``**kwargs`` are
        passed on to the constructor), and append it to the list of
        entries stashed on the config object of ``func``.

        Returns:
            The new entry
        """
        assert func.config is not None
        entry = cls(func, *args, **kwargs)
        entry.get_entries(func.config).append(entry)
        return entry

    @staticmethod
    def get_entries(config: pytest.Config) -> list[_RetryEntry]:
        """
        Get the list of entries stashed on ``config``, creating it
        if necessary.
        """
        return config.stash.setdefault(_RETRY_ENTRIES_KEY, [])

    @property
    def full_name(self) -> str:
        """
        Name of the test, including parametrizations.
        """
        return self._get_name(True)

    @property
    def full_original_name(self) -> str:
        """
        Name of the test, without parametrizations.
        """
        return self._get_name(False)

    @cached_property
    def _name_prefixes(self) -> tuple[str, str]:
        """
        Walk up the ancestors of the test function to work out what
        precedes its name.

        Returns:
            path (str):
                Path of the closest enclosing module or package
                (relative to the current directory where possible),
                or its :py:func:`repr` if it has no path; empty if
                no such node is found
            prefix (str):
                Dotted names of the nodes (e.g. test classes) between
                that module or package and the function, outermost
                first
        """
        chunks: list[str] = []
        node: Node | None = self.func.parent
        path = ''
        # Guard against (pathological) cycles in the parent chain
        seen: set[int] = set()
        while node is not None and id(node) not in seen:
            seen.add(id(node))
            if isinstance(node, (pytest.Module, pytest.Package)):
                if node.path:
                    npath = node.path
                    try:
                        npath = npath.relative_to(Path.cwd())
                    except ValueError:  # Not a subpath
                        pass
                    path = str(npath)
                else:
                    path = repr(node)
                break
            name: str | None = getattr(node, 'name', None)
            if not name:
                break
            chunks.append(name)
            # Move on to the enclosing node; without this the walk
            # never gets past the immediate parent, and tests defined
            # in classes lose their module path
            node = getattr(node, 'parent', None)
        return path, '.'.join(reversed(chunks))
def __post_init__(self) -> None:
    """
    Normalize and validate the fields after initialization.

    Raises:
        TypeError:
            If ``.retries`` cannot be converted to an integer, if
            ``.exceptions`` is not an exception type or a tuple
            thereof, or if ``.require`` is neither ``'any'`` nor
            ``'all'``
        ValueError:
            If ``.condition`` is a string but not a valid expression
    """
    # `.retries`: coerce to a non-negative integer
    try:
        self.retries = max(0, int(self.retries))
    except Exception as e:
        raise TypeError(
            f'.retries = {self.retries!r}'
            f': not a valid number {_format_exception(e)}'
        ).with_traceback(e.__traceback__)

    # `.exceptions`: every candidate must be an exception class
    candidates: tuple[type[Exception], ...] = (
        self.exceptions if isinstance(self.exceptions, tuple)
        else (self.exceptions,)
    )
    for candidate in candidates:
        if isinstance(candidate, type) and issubclass(candidate, Exception):
            continue
        raise TypeError(
            f'.exceptions = {self.exceptions!r}: '
            'expected an exception type or a tuple thereof'
        )

    # `.condition`: strings are evaluated later, so they must at least
    # parse as an expression
    if isinstance(self.condition, str):
        try:
            ast.parse(self.condition, mode='eval')
        except Exception as e:
            raise ValueError(
                f'.condition = {self.condition!r}'
                f': not a valid expression ({_format_exception(e)})'
            ).with_traceback(e.__traceback__)

    # `.require`
    if self.require in ('all', 'any'):
        return
    raise TypeError(
        f'.require = {self.require!r}: expected \'any\' or \'all\''
    )
@staticmethod
def _check_condition(
    condition: str | bool | None, func: pytest.Function,
) -> tuple[Any, Literal[False]] | tuple[Exception, Literal[True]]:
    """
    Work out whether a test should be retried.

    Args:
        condition (str | bool | None):
            Condition to check; strings are :py:func:`eval`-ed with
            the globals of the test function (where available) and
            its ``.funcargs`` as the locals
        func (pytest.Function):
            Test function item

    Returns:
        ``(value, False)`` where ``value`` is the (truthiness of
        which is the) verdict if the condition can be resolved, or
        ``(error, True)`` if evaluating it raised ``error``
    """
    # Constant conditions: true and `None` mean "always retry", false
    # means "never retry"
    for constants, verdict in ((True, None), True), ((False,), False):
        if condition in constants:
            return (verdict, False)

    if TYPE_CHECKING:  # Help narrowing
        assert isinstance(condition, str)
    # The test object may not be a `types.FunctionType`, in which case
    # there are no globals to be had
    global_ns: dict[str, Any] | None = getattr(
        getattr(func, 'obj', None), '__globals__', None,
    )
    local_ns = func.funcargs
    try:
        outcome = eval(condition, global_ns, local_ns)
    except Exception as e:
        return (e, True)
    return (outcome, False)
+ """ + def cleanup_fixture(fdef: pytest.FixtureDef[Any]) -> None: + if not ( + fdef.scope == 'function' + and should_reset(fdef.argname) + and getattr(fdef, 'cached_result', None) is not None + ): + return + fdef.cached_result = None + finalize(fdef) + + def finalize(fdef: pytest.FixtureDef[Any]) -> None: + assert fdef.scope == 'function' + + # Plagiarized code from + # `FixtureRequest._get_active_fixture_def()` + try: + callspec = func.callspec + except AttributeError: + callspec = None + if callspec is not None and fdef.argname in callspec.params: + value = callspec.params[fdef.argname] + index = callspec.indices[fdef.argname] + else: + value, index = NOTSET, 0 + + with warnings.catch_warnings(): + warnings.simplefilter( + 'ignore', pytest.PytestDeprecationWarning, + ) + fdef.finish(SubRequest( + request=func._request, + scope=_FUNCTION_SCOPE, + param=value, + param_index=index, + fixturedef=fdef, + )) + + def unique( + items: Iterable[T], key: Callable[[T], Hashable] = id, + ) -> Generator[T, None, None]: + seen: set[Hashable] = set() + for item in items: + hashed = key(item) + if hashed in seen: + continue + seen.add(hashed) + yield item + + def iter_all_fixture_defs( + func: pytest.Function, + ) -> Generator[pytest.FixtureDef[Any], None, None]: + fdef_mapping: Mapping[Any, Iterable[pytest.FixtureDef[Any]]] + # Somehow `mypy` doesn't trust the below but `ty` does... 
@classmethod
@overload
def from_arguments(cls, args: Iterable[_RetryMarkerArgs] = (), /) -> Self:
    ...

@classmethod
@overload
def from_arguments(cls, *args: _RetryMarkerArgs) -> Self:
    ...

@classmethod
def from_arguments(cls, *args) -> Self:
    """
    Combine a stack of marker arguments into a single instance.

    Invocations:
        (args_dict, ...) -> RetryMarker
        ([args_dict, ...]) -> RetryMarker

    Args:
        *args (Mapping | Iterable[Mapping]):
            Either the argument dictionaries of the markers (each
            with some of the keys ``'retries'``, ``'exceptions'``,
            ``'reset_fixtures'``, ``'condition'``, and
            ``'require'``), or a single iterable thereof

    Returns:
        New instance

    Note:
        - ``'retries'`` are summed (1 for each marker not specifying
          it), and ``'exceptions'`` are unioned.
        - For ``'reset_fixtures'``, collections of fixture names are
          unioned with one another, while booleans override (and
          are overridden by) whatever came before; a bare string is
          taken to be a single fixture name.
        - For ``'condition'`` and ``'require'``, the last marker
          specifying the value wins.

    Examples:
        >>> empty = RetryMarker.from_arguments()
        >>> assert not empty.retries
        >>> assert empty == RetryMarker.from_arguments([])

        >>> default = RetryMarker.from_arguments({})
        >>> assert default.retries == 1
        >>> assert default.exceptions in ((Exception,), Exception)
        >>> assert default.reset_fixtures == True
        >>> assert default.condition is None
        >>> assert default.require == 'any'
        >>> assert default.is_active
        >>> assert default == RetryMarker.from_arguments([{}])

        Some arguments result in inactive instances (i.e. no
        retries):

        >>> bad_xc = RetryMarker.from_arguments({'exceptions': ()})
        >>> assert not bad_xc.exceptions
        >>> assert not bad_xc.is_active

        >>> bad_retries = RetryMarker.from_arguments(
        ...     {}, {}, {'retries': -5},
        ... )
        >>> assert not bad_retries.retries
        >>> assert not bad_retries.is_active

        >>> bad_cond = RetryMarker.from_arguments(
        ...     {'condition': False},
        ... )
        >>> assert bad_cond.condition == False
        >>> assert not bad_cond.is_active

        Congruent values are unioned:

        >>> stacked_xcs = RetryMarker.from_arguments([
        ...     {'exceptions': ()},
        ...     {'exceptions': ValueError},
        ...     {'retries': 3, 'exceptions': (TypeError, OSError)},
        ... ])
        >>> assert stacked_xcs.retries == 5
        >>> assert set(stacked_xcs.exceptions) == {
        ...     ValueError, TypeError, OSError,
        ... }

        >>> stacked_resets_1 = RetryMarker.from_arguments(
        ...     {'reset_fixtures': ['foo', 'bar']},
        ...     {'reset_fixtures': ['baz']},
        ... )
        >>> sorted(stacked_resets_1.reset_fixtures)
        ['bar', 'baz', 'foo']

        A bare string is a single fixture name:

        >>> stacked_resets_3 = RetryMarker.from_arguments(
        ...     {'reset_fixtures': 'foo'},
        ...     {'reset_fixtures': ['bar']},
        ... )
        >>> sorted(stacked_resets_3.reset_fixtures)
        ['bar', 'foo']

        Incongruent values override one another:

        >>> stacked_resets_2 = RetryMarker.from_arguments(
        ...     {'reset_fixtures': ['foo', 'bar']},
        ...     {'reset_fixtures': False},
        ... )
        >>> assert stacked_resets_2.reset_fixtures == False

        >>> stacked_conditions = RetryMarker.from_arguments(
        ...     {'condition': 'foo==bar'},
        ...     {'condition': False},
        ... )
        >>> assert stacked_conditions.condition == False
        >>> assert not stacked_conditions.is_active

        >>> stacked_requires = RetryMarker.from_arguments(
        ...     {'require': 'any'},
        ...     {'require': 'all'},
        ... )
        >>> assert stacked_requires.require == 'all'
    """
    retries: int = 0
    xc: set[type[Exception]] | None = None
    reset_fixtures: bool | set[str] = True
    condition: bool | str | None = None
    require: _Require | None = None

    # Tell the two invocations apart by the first argument
    if args and not isinstance(args[0], Mapping):
        assert len(args) == 1
        iter_args = args[0]
    else:
        iter_args = args

    # `ty` needs some help here... hence the `cast()` (to a string
    # type so that nothing has to be resolved at runtime)
    for bound_args in cast('Iterable[_RetryMarkerArgs]', iter_args):
        retries += bound_args.get('retries', 1)
        if 'exceptions' in bound_args:
            xc_new = bound_args['exceptions']
            if xc is None:
                xc = set()
            if isinstance(xc_new, type):
                xc.add(xc_new)
            else:
                xc.update(xc_new)
        if 'reset_fixtures' in bound_args:
            rf_new = bound_args['reset_fixtures']
            if isinstance(rf_new, str):
                # Strings are also collections (of characters), but
                # what is meant is certainly a single fixture name
                rf_new = [rf_new]
            if isinstance(rf_new, Collection):
                if isinstance(reset_fixtures, set):
                    reset_fixtures.update(rf_new)
                else:
                    reset_fixtures = set(rf_new)
            else:
                reset_fixtures = bool(rf_new)
        if 'condition' in bound_args:
            condition = bound_args['condition']
        if 'require' in bound_args:
            require = bound_args['require']

    # Only pass on what has been explicitly set, so that the defaults
    # of the constructor apply to everything else
    kwargs: _RetryMarkerArgs = {
        'retries': retries, 'reset_fixtures': reset_fixtures,
    }
    if xc is not None:
        kwargs['exceptions'] = tuple(xc)
    if condition is not None:
        kwargs['condition'] = condition
    if require is not None:
        kwargs['require'] = require
    return cls(**kwargs)
def is_active(self) -> bool: + """ + Whether the instance should possibly attempt retries in any + condition + """ + if not self.retries: + return False + if self.exceptions == (): + return False + if self.condition in (False,): + return False + return True + + +def _pluralize(noun: str, count: int, plural: str | None = None) -> str: + if plural is None: + plural = noun + 's' + return f'{count} {noun if count == 1 else plural}' + + +def _format_exception(xc: Exception) -> str: + msg = type(xc).__name__ + if str(xc): + return f'{msg}: {xc}' + return msg + + +@lru_cache() +def _find_module_path(module: str) -> str | None: + spec = find_spec(module) + if spec is None or spec.origin is None: + return None + file = Path(spec.origin) + if not file.exists(): + return None + if file.name == '__init__.py': # Package + file = file.parent + return str(file) + + +def pytest_addhooks(pluginmanager: pytest.PytestPluginManager) -> None: + """ + Register :py:class:`RetryMarker` as a plugin so that its + :py:meth:`pytest_pyfunc_call` method can safely call other + implementations without recursing to itself. + """ + pluginmanager.register(RetryMarker) + + +def pytest_configure(config: pytest.Config) -> None: + """ + Register the :py:deco:`pytest.mark.retry` marker. + """ + help_text = ' '.join(""" + retry(retries=1, *, exceptions=Exception, \ +reset_fixtures=True, condition=None, require='any'): + + mark the test for retrying upon failure. 
+ + Args (all optional): + retries (int): + Max number of retries for the (sub-)test; + exceptions (type[Exception] | tuple[type[Exception], ...]): + Error types which trigger a retry when caught; + reset_fixtures (bool | Collection[str]): + Names of function-scoped fixtures to reset between retries, + `True` reset all such fixtures, + `False` none thereof; + condition (bool | str | None): + Optional condition for retry: + if a boolean, only retry if true; + if a string, only retry if it `eval()`s true + (w/globals of the test module and locals from the fixtures + and parametrizations); + require (Literal['any', 'all']): + If 'any', stop retrying and record a test pass if ANY + attempt passes; + if 'all', only record a test pass if ALL attempts pass + """.split()) + config.addinivalue_line('markers', help_text) + + +def pytest_terminal_summary( + terminalreporter: TerminalReporter, config: pytest.Config, +) -> None: + """ + Write a summary section about rerun tests. + """ + def get_summary(status: _Status, entries: list[_RetryEntry]) -> str: + return f'{_pluralize("test", len(entries))} {status} with retries' + + def group_subtests( + entries: list[_RetryEntry] + ) -> dict[str, list[_RetryEntry]]: + result: dict[str, list[_RetryEntry]] = {} + for entry in entries: + result.setdefault(entry.full_original_name, []).append(entry) + return result + + formatting = {'yellow': True} + write_line: Callable[[str], None] = partial( + terminalreporter.write_line, **formatting # type: ignore + ) + write_header = partial( + terminalreporter.write_sep, '=', 'retries summary', **formatting + ) + write_newline = partial(write_line, '') + try: + verbosity: int = config.get_verbosity() # type: ignore + except AttributeError: # pytest < 8.0 + verbosity = int(config.option.verbose) + retry_entries: dict[_Status, list[_RetryEntry]] = {} + for entry in _RetryEntry.get_entries(config): + retry_entries.setdefault(entry.status, []).append(entry) + if not retry_entries: + return + if 
verbosity > 0: + write_newline() + write_header() + write_newline() + for status, entries in retry_entries.items(): + write_line(get_summary(status, entries) + ':') + for entry in entries: + write_line( + f' {entry.full_name}: ' + f'retried {_pluralize("time", entry.retries)}' + ) + write_newline() + else: + write_header() + for status, entries in retry_entries.items(): + tests: list[str] = [] + for name, children in group_subtests(entries).items(): + if len(children) == 1: + tests.append(children[0].full_name) + else: + msg = f'{name} ({_pluralize("subtest", len(children))})' + tests.append(msg) + write_line(f'{get_summary(status, entries)}: {", ".join(tests)}') diff --git a/tests/test_child_procs.py b/tests/test_child_procs.py index 10be3f78..86033b3d 100644 --- a/tests/test_child_procs.py +++ b/tests/test_child_procs.py @@ -1,33 +1,127 @@ from __future__ import annotations +import dataclasses +import enum +import inspect +import itertools +import multiprocessing.pool +import operator import os +import pickle +import re +import shlex import subprocess import sys -from collections.abc import Callable, Generator, Mapping +import sysconfig +import threading +import warnings +from abc import ABC, abstractmethod +from collections.abc import ( + Callable, + Collection, Generator, Iterable, Iterator, Mapping, Sequence, Set, +) +from contextlib import ExitStack +from functools import lru_cache, partial, wraps +from io import BytesIO, StringIO +from importlib import import_module +from multiprocessing.pool import ( # type: ignore + ExceptionWithTraceback as ExceptionHelper, +) +from numbers import Real from pathlib import Path +from runpy import run_path from tempfile import TemporaryDirectory from textwrap import dedent, indent +from time import monotonic +from types import MappingProxyType, ModuleType, TracebackType +from typing import ( + TYPE_CHECKING, Any, Generic, IO, Literal, Protocol, TypeVar, + cast, final, overload, +) +from typing_extensions import Self, 
ParamSpec +from uuid import uuid4 import pytest import ubelt as ub +from _line_profiler_hooks import load_pth_hook +from line_profiler._child_process_profiling.cache import LineProfilingCache +from line_profiler._child_process_profiling.runpy_patches import ( + create_runpy_wrapper, +) +from line_profiler._child_process_profiling.multiprocessing_patches import ( + _Poller, MPConfig, _PATCHED_MARKER, _PATCHES as MP_PATCHES, +) +from line_profiler.autoprofile.util_static import modpath_to_modname +from line_profiler.curated_profiling import ( + CuratedProfilerContext, ClassifiedPreimportTargets, +) +from line_profiler.line_profiler import LineProfiler, LineStats +from line_profiler.toml_config import ConfigSource + + +T = TypeVar('T') +T1 = TypeVar('T1') +T2 = TypeVar('T2') +TCtx_ = TypeVar('TCtx_') +PS = ParamSpec('PS') +C = TypeVar('C', bound=Callable[..., Any]) NUM_NUMBERS = 100 NUM_PROCS = 4 -TEST_MODULE_BODY = dedent(f""" +START_METHODS = set(multiprocessing.get_all_start_methods()) + +_TEST_TIMEOUT = 5 # Seconds +_DEBUG = True +_WINDOWS = sys.platform == 'win32' + + +def strip(s: str) -> str: + return dedent(s).strip('\n') + + +EXTERNAL_MODULE_BODY = strip(""" from __future__ import annotations -from argparse import ArgumentParser -from multiprocessing import Pool -def my_sum(x: list[int]) -> int: - result: int = 0 +def my_external_sum(x: list[int], fail: bool = False) -> int: + result: int = 0 # GREP_MARKER[EXT-INVOCATION] for item in x: - result += item + result += item # GREP_MARKER[EXT-LOOP] + if fail: + raise RuntimeError('forced failure') + return result +""") + +TEST_MODULE_TEMPLATE = strip(""" +from __future__ import annotations + +from argparse import ArgumentParser +from collections.abc import Callable +from multiprocessing import dummy, get_context, Pool +from typing import Literal + +from {EXT_MODULE} import my_external_sum + + +def my_local_sum(x: list[int], fail: bool = False) -> int: + result: int = 0 # GREP_MARKER[LOCAL-INVOCATION] + # The 
reversing is to prevent bytecode aliasing with + # `my_external_sum()` (see issue #424, PR #425) + for item in reversed(x): + result += item # GREP_MARKER[LOCAL-LOOP] + if fail: + raise RuntimeError('forced failure') return result -def sum_in_child_procs(length: int, n: int) -> int: +def sum_in_child_procs( + length: int, n: int, my_sum: Callable[[list[int]], int], + start_method: Literal[ + 'fork', 'forkserver', 'spawn', 'dummy' + ] | None = None, + fail: bool = False, +) -> int: my_list: list[int] = list(range(1, length + 1)) sublists: list[list[int]] = [] subsums: list[int] @@ -37,197 +131,3080 @@ def sum_in_child_procs(length: int, n: int) -> int: while my_list: sublist, my_list = my_list[:sublength], my_list[sublength:] sublists.append(sublist) - with Pool(n) as pool: - subsums = pool.map(my_sum, sublists) + if start_method == 'dummy': + pool = dummy.Pool(n) + elif start_method: + pool = get_context(start_method).Pool(n) + else: + pool = Pool(n) + with pool: + subsums = pool.starmap(my_sum, [(sl, fail) for sl in sublists]) pool.close() pool.join() - return my_sum(subsums) + return my_sum(subsums, fail) def main(args: list[str] | None = None) -> None: parser = ArgumentParser() parser.add_argument('-l', '--length', type=int, default={NUM_NUMBERS}) parser.add_argument('-n', type=int, default={NUM_PROCS}) + parser.add_argument( + '-s', '--start-method', + choices=['fork', 'forkserver', 'spawn'], default=None, + ) + parser.add_argument('-f', '--force-failure', action='store_true') + parser.add_argument( + '--local', + action='store_const', + dest='my_sum', + default=my_external_sum, + const=my_local_sum, + ) options = parser.parse_args(args) - print(sum_in_child_procs(options.length, options.n)) + print(sum_in_child_procs( + options.length, options.n, options.my_sum, + start_method=options.start_method, + fail=options.force_failure, + )) if __name__ == '__main__': main() -""").strip('\n') +""") + + +# ============================== Fixtures 
============================== + + +@dataclasses.dataclass +class _ModuleFixture: + """ + Convenience wrapper around a Python source file which represents an + importable module. + """ + path: Path + monkeypatch: pytest.MonkeyPatch + dependencies: Collection[_ModuleFixture] = () + + def install( + self, *, + local: bool = False, children: bool = False, deps_only: bool = False, + ) -> None: + """ + Set the module at :py:attr:`~.path` up to be importable. + + Args: + local (bool): + Make it importable for the CURRENT process (via + :py:data:`sys.path`). + children (bool): + Make it importable for CHILD processes (via + ``os.environ['PYTHONPATH']``). + deps_only (bool): + If true, only does the equivalent setup for + dependencies. + """ + for dep in self.dependencies: + dep.install(local=local, children=children) + if deps_only: + return + path = str(self.path.parent) + if local: + self.monkeypatch.syspath_prepend(path) + if children: + self.monkeypatch.setenv('PYTHONPATH', path, prepend=os.pathsep) + + def _import_module_helper(self) -> Generator[ModuleType, None, None]: + def iter_module_names( + module: _ModuleFixture, + ) -> Generator[str, None, None]: + yield module.name + for dep in module.dependencies: + yield from iter_module_names(dep) + + self.install(local=True, children=True) + try: + yield import_module(self.name) + finally: + for name in set(iter_module_names(self)): + sys.modules.pop(name, None) + + @staticmethod + def propose_name(prefix: str) -> Generator[str, None, None]: + """ + Propose a valid module name that isn't already occupied. 
+ """ + while True: + name = '_'.join([prefix] + str(uuid4()).split('-')) + if name not in sys.modules: + assert name.isidentifier() + yield name + + @property + def name(self) -> str: + return self.path.stem + + +# Only write the files once per test session @pytest.fixture(scope='session') -def test_module() -> Generator[Path, None, None]: +def _ext_module() -> Generator[Path, None, None]: + name = next(_ModuleFixture.propose_name('my_ext_module')) with TemporaryDirectory() as mydir_str: my_dir = Path(mydir_str) my_dir.mkdir(exist_ok=True) - my_module = my_dir / 'my_test_module.py' - with my_module.open('w') as fobj: - fobj.write(TEST_MODULE_BODY + '\n') + my_module = my_dir / f'{name}.py' + my_module.write_text(EXTERNAL_MODULE_BODY) yield my_module -@pytest.mark.parametrize('as_module', [True, False]) -@pytest.mark.parametrize( - ('nnums', 'nprocs'), [(None, None), (None, 3), (200, None)], -) -def test_multiproc_script_sanity_check( - test_module: Path, +@pytest.fixture(scope='session') +def _test_module(_ext_module: Path) -> Generator[Path, None, None]: + name = next(_ModuleFixture.propose_name('my_test_module')) + body = TEST_MODULE_TEMPLATE.format( + EXT_MODULE=_ext_module.stem, + NUM_NUMBERS=NUM_NUMBERS, + NUM_PROCS=NUM_PROCS, + ) + with TemporaryDirectory() as mydir_str: + my_dir = Path(mydir_str) + my_dir.mkdir(exist_ok=True) + my_module = my_dir / f'{name}.py' + my_module.write_text(body) + yield my_module + + +@pytest.fixture +def ext_module( + _ext_module: Path, monkeypatch: pytest.MonkeyPatch, +) -> Generator[_ModuleFixture, None, None]: + """ + Yields: + :py:class:`_ModuleFixture` helper object containing the code at + :py:data:`EXTERNAL_MODULE_BODY` + """ + yield _ModuleFixture(_ext_module, monkeypatch) + + +@pytest.fixture +def test_module( + _test_module: Path, + ext_module: _ModuleFixture, + monkeypatch: pytest.MonkeyPatch, +) -> Generator[_ModuleFixture, None, None]: + """ + Yields: + :py:class:`_ModuleFixture` helper object containing the code at 
+ :py:data:`TEST_MODULE_TEMPLATE` + """ + yield _ModuleFixture(_test_module, monkeypatch, [ext_module]) + + +@pytest.fixture +def test_module_clone( tmp_path_factory: pytest.TempPathFactory, - nnums: int, - nprocs: int, - as_module: bool, -) -> None: + monkeypatch: pytest.MonkeyPatch, + _test_module: Path, + ext_module: _ModuleFixture, +) -> Generator[_ModuleFixture, None, None]: """ - Sanity check that the test module functions as expected when run - with vanilla Python. + Yields: + :py:class:`_ModuleFixture` helper object containing the same + code as :py:data:`test_module` """ - _run_test_module( - _run_as_module if as_module else _run_as_script, - test_module, tmp_path_factory, [sys.executable], None, False, - nnums=nnums, nprocs=nprocs, - ) + tmpdir = tmp_path_factory.mktemp('my_path') + name = next(_ModuleFixture.propose_name('my_cloned_module')) + path = tmpdir / f'{name}.py' + path.write_text(_test_module.read_text()) + yield _ModuleFixture(path, monkeypatch, [ext_module]) -# Note: -# Currently code execution in child processes is not properly profiled; -# these tests are just for checking that `kernprof` doesn't impair the -# proper execution of `multiprocessing` code +@pytest.fixture +def ext_module_object( + ext_module: _ModuleFixture, +) -> Generator[ModuleType, None, None]: + """ + Yields: + :py:class:`ModuleType` object containing the code at + :py:data:`EXTERNAL_MODULE_BODY`, and is torn down at the end of + the test + """ + yield from ext_module._import_module_helper() -fuzz_invocations = pytest.mark.parametrize( - ('runner', 'outfile', 'profile', - 'label'), # Dummy argument to make `pytest` output more legible - [ - (['kernprof', '-q'], 'out.prof', False, 'cProfile'), - # Run with `line_profiler` with and w/o profiling targets - (['kernprof', '-q', '-l'], 'out.lprof', False, - 'line_profiler-inactive'), - (['kernprof', '-q', '-l'], 'out.lprof', True, - 'line_profiler-active'), - ], -) +@pytest.fixture +def test_module_object( + test_module: 
_ModuleFixture, ext_module_object: ModuleType, +) -> Generator[ModuleType, None, None]: + """ + Yields: + :py:class:`ModuleType` object containing the code at + :py:data:`TEST_MODULE_TEMPLATE`, and is torn down at the end of + the test + """ + yield from test_module._import_module_helper() -@fuzz_invocations -def test_running_multiproc_script( - test_module: Path, +@pytest.fixture +def create_cache( tmp_path_factory: pytest.TempPathFactory, - runner: str | list[str], - outfile: str | None, - profile: bool, - label: str, -) -> None: + request: pytest.FixtureRequest, +) -> Generator[Callable[..., LineProfilingCache], None, None]: """ - Check that `kernprof` can run the test module as a script - (`kernprof [...] `). + Wrapper around the :py:class:`LineProfilingCache` instantiator + which: + + - Automatically creates a tempdir and provides it as the + :py:attr:`LineProfilingCache.cache_dir`, + + - Extends the argument ``preimports_module`` to allow for taking + boolean values: + + - ``True``: a temporary preimports module is automatically written + based on ``profiling_targets`` and supplied to the base + constructor. + + - ``False``: equivalent to ``None``. + + - Unless the argument ``_use_curated_profiler: bool = True`` is set + to :py:const:`False`, automatically creates an instance of + :py:class:`LineProfiler` that is curated by a + :py:class:`CuratedProfilerContext` and provides it as the + :py:attr:`LineProfilingCache.profiler`, and + + - At teardown: + + - Removes tempdirs and tempfiles generated. + + - Restores the value of the class' internal reference to the + :py:meth:`LineProfilingCache.load`-ed instance. + + - Calls the `.cleanup()` method of each instance created. 
+ + - Prints these diagnostics for each instance: + + - The stats on the ``.profiler`` associated with each instance + (if any) + + - The stats gathered by + :py:meth:`LineProfilingCache.gather_stats()` + + - The debug logs (if ``.debug`` is true) """ - _run_test_module( - _run_as_script, - test_module, tmp_path_factory, runner, outfile, profile, - ) + def instantiate( + *, + profiling_targets: Collection[str] = (), + preimports_module: os.PathLike[str] | str | bool | None = None, + _use_curated_profiler: bool = True, + **kwargs + ) -> LineProfilingCache: + tmpdir = tmp_path_factory.mktemp('my_cache_dir') + pim: os.PathLike[str] | str | None + if preimports_module in (True, False): + if preimports_module: + targets = ( + ClassifiedPreimportTargets.from_targets(profiling_targets) + ) + if targets: + pim = tmpdir / 'preimports.py' + with pim.open(mode='w') as fobj: + targets.write_preimport_module(fobj) + else: + pim = None + else: + pim = None + else: + # The type checker needs some convincing... 
+ assert not isinstance(preimports_module, bool) + pim = preimports_module + cache = LineProfilingCache( + tmpdir, + profiling_targets=profiling_targets, + preimports_module=pim, + **kwargs, + ) + if _use_curated_profiler: + cache.profiler = request.getfixturevalue('curated_profiler') + instances.append(cache) + return cache + def print_result( + cache: LineProfilingCache, topic: str, result: str, *notes: str, + ) -> None: + header = '{} ({}):'.format( + topic, '; '.join([f'cache instance {id(cache):#x}', *notes]), + ) + print(header, indent(result, ' '), sep='\n') -@fuzz_invocations -def test_running_multiproc_module( - test_module: Path, - tmp_path_factory: pytest.TempPathFactory, - runner: str | list[str], - outfile: str | None, - profile: bool, - label: str, -) -> None: + def print_profiler_stats(cache: LineProfilingCache) -> None: + if cache.profiler is None: + result = '' + notes = [] + else: + with StringIO() as sio: + cache.profiler.print_stats(sio) + result = sio.getvalue() + notes = [f'profiler instance {id(cache.profiler):#x}'] + print_result(cache, 'Native profiler stats', result, *notes) + + def print_gathered_stats(cache: LineProfilingCache) -> None: + with StringIO() as sio: + cache.gather_stats().print(sio) + result = sio.getvalue() + print_result(cache, 'Gathered profiler stats', result) + + def print_debug_logs(cache: LineProfilingCache) -> None: + if cache.debug: + result = '\n'.join( + entry.to_text() for entry in cache._gather_debug_log_entries() + ) + else: + result = '' + print_result(cache, 'Gathered debug logs', result) + + instances: list[LineProfilingCache] = [] + handlers: list[Callable[[LineProfilingCache], None]] + handlers = [print_profiler_stats, print_gathered_stats, print_debug_logs] + try: + with _preserve_obj_attributes( + LineProfilingCache, ['_loaded_instance'], + ): + yield instantiate + finally: + for cache in instances: + callbacks: list[Callable[[], Any]] = [cache.cleanup] + callbacks.extend(partial(func, cache) for func in 
handlers) + for callback in callbacks: + try: + callback() + except Exception: + pass + + +@pytest.fixture +def curated_profiler() -> Generator[LineProfiler, None, None]: """ - Check that `kernprof` can run the test module as a module - (`kernprof [...] -m `). + Yields: + Fresh instance of :py:class:`LineProfiler` that is managed by a + :py:class:`CuratedProfilerContext` """ - _run_test_module( - _run_as_module, - test_module, tmp_path_factory, runner, outfile, profile, - ) + prof = LineProfiler() + with CuratedProfilerContext(prof, insert_builtin=True): + yield prof -def _run_as_script( - runner_args: list[str], test_args: list[str], test_module: Path, **kwargs -) -> subprocess.CompletedProcess: - cmd = runner_args + [str(test_module)] + test_args - return subprocess.run(cmd, **kwargs) +@pytest.fixture +def another_pid() -> int: + """ + Get a PID which is distinct from the current one. + """ + curr_pid = os.getpid() + pid = (curr_pid - 42) % (2 * 16) + assert pid != curr_pid + return pid -def _run_as_module( - runner_args: list[str], - test_args: list[str], - test_module: Path, - *, - env: Mapping[str, str] | None = None, - **kwargs -) -> subprocess.CompletedProcess: - cmd = runner_args + ['-m', test_module.stem] + test_args - env_dict = {**os.environ, **(env or {})} - python_path = env_dict.pop('PYTHONPATH', '') - if python_path: - env_dict['PYTHONPATH'] = '{}:{}'.format( - test_module.parent, python_path, +@pytest.fixture(autouse=True) +def _trim_mismatch_traceback(pytestconfig: pytest.Config) -> None: + """ + Truncate the traceback of raised :py:class:`ResultMismatch` for more + useful error attribution. 
+ """ + try: + pytestconfig.pluginmanager.register(ResultMismatch) + except ValueError: # Already registered + pass + + +# ========================== Helper functions ========================== + + +class _NotSupplied(enum.Enum): + NOT_SUPPLIED = enum.auto() + + +class _GetAttr(Protocol): + """ + Function signature for functions that behave like + :py:func:`getattr`. + """ + @overload + def __call__(self, obj: Any, attr: str, /) -> Any: + ... + + @overload + def __call__(self, obj: Any, attr: str, default: Any, /) -> Any: + ... + + def __call__(self, *args): + ... + + +@final +class ResultMismatch(ValueError): + def __init__( + self, + expected: Any, + actual: Any | _NotSupplied = _NotSupplied.NOT_SUPPLIED, + _trunc_tb: int = 0, + ) -> None: + if actual == _NotSupplied.NOT_SUPPLIED: + msg = f'expected: {expected}' + else: + msg = f'expected {expected}, got {actual}' + super().__init__(msg) + self.expected = expected + self.actual = actual + self._trunc_tb = max(0, _trunc_tb) + + @classmethod + def compare( + cls, expected_: T1, actual_: T2, /, *, + comparator: Callable[[T1, T2], bool] = operator.eq, + expected: str | None = None, + actual: str | None = None, + ) -> None: + if comparator(expected_, actual_): + return + raise cls( + expected_ if expected is None else expected, + actual_ if actual is None else actual, + _trunc_tb=1, ) - else: - env_dict['PYTHONPATH'] = str(test_module.parent) - return subprocess.run(cmd, env=env_dict, **kwargs) + @classmethod + def pytest_runtest_makereport( + cls, item: pytest.Item, call: pytest.CallInfo, + ) -> Any: + """ + Truncate the tracebacks of instances so that pytest outputs are + more useful and actually stop at the frame where the comparands + are shown. 
+ """ + impl: Callable[..., Any] + impl = item.config.pluginmanager.subset_hook_caller( + 'pytest_runtest_makereport', [cls], + ) + make_report = partial(impl, item=item, call=call) -def _run_test_module( - run_helper: Callable[..., subprocess.CompletedProcess], - test_module: Path, - tmp_path_factory: pytest.TempPathFactory, - runner: str | list[str], - outfile: str | None, - profile: bool, - *, - nnums: int | None = None, - nprocs: int | None = None, - check: bool = True, -) -> tuple[subprocess.CompletedProcess, Path | None]: + xc = call.excinfo + if xc is None: + return make_report() + if not (isinstance(xc.value, cls) and xc.value._trunc_tb): + return make_report() + + tb_stack: list[TracebackType] = [xc.tb] + while tb_stack[-1].tb_next: + tb_stack.append(tb_stack[-1].tb_next) + if len(tb_stack) <= xc.value._trunc_tb: + return make_report() + tb_stack[-(xc.value._trunc_tb + 1)].tb_next = None + + del tb_stack # Help the GC + call.excinfo = xc.from_exception(xc.value.with_traceback(xc.tb)) + return make_report(call=call) + + @property + def rich_message(self) -> str: + msg = '{}: {}'.format(type(self).__name__, self.args[0]) + if self.__traceback__ is not None: + tb = self.__traceback__ + msg = '{}:{}: {}'.format( + tb.tb_frame.f_code.co_filename, tb.tb_lineno, msg, + ) + return msg + + +class _TestTimeout(RuntimeError): """ - Return - ------ - `(process_running_the_test_module, path_to_profiling_output | None)` + Error raised by the :py:func:`_timeout` decorator. 
""" - if isinstance(runner, str): - runner_args: list[str] = [runner] - else: - runner_args = list(runner) - if profile: - runner_args.extend(['--prof-mod', str(test_module)]) + pass - test_args: list[str] = [] - if nnums is None: - nnums = NUM_NUMBERS - else: - test_args.extend(['-l', str(nnums)]) - if nprocs is not None: - test_args.extend(['-n', str(nprocs)]) - with ub.ChDir(tmp_path_factory.mktemp('mytemp')): - if outfile is not None: - runner_args.extend(['--outfile', outfile]) - proc = run_helper( - runner_args, test_args, test_module, - text=True, capture_output=True, +@final +@dataclasses.dataclass +class _Params: + """ + Convenience wrapper around :py:func:`pytest.mark.parametrize`. + """ + params: tuple[str, ...] + values: list[tuple[Any, ...]] + defaults: tuple[Any, ...] + + def __post_init__(self) -> None: + n = len(self.params) + assert all(p.isidentifier() for p in self.params) # Validity + assert len(set(self.params)) == n # Uniqueness + assert len(self.defaults) == n # Consistency + self.values = list(self._unique(self.values)) + assert all(len(v) == n for v in self.values) + + def __mul__(self, other: Self) -> Self: + """ + Form a Cartesian product between the two instances with disjoint + :py:attr:`~.params`, like stacking the + :py:func:`pytest.mark.parametrize `decorators. + + Example: + >>> p1 = _Params.new(('a', 'b'), [(0, 0), (1, 2), (3, 4)], + ... 
defaults=(1, 2)) + >>> p2 = _Params.new('c', [0, 5, 6]) + >>> p1 * p2 # doctest: +NORMALIZE_WHITESPACE + _Params(params=('a', 'b', 'c'), + values=[(0, 0, 0), (0, 0, 5), (0, 0, 6), + (1, 2, 0), (1, 2, 5), (1, 2, 6), + (3, 4, 0), (3, 4, 5), (3, 4, 6)], + defaults=(1, 2, 0)) + """ + assert not set(self.params) & set(other.params) + return type(self)( + self.params + other.params, + [sv + ov for sv in self.values for ov in other.values], + self.defaults + other.defaults, ) - try: - if check: - proc.check_returncode() - finally: - print(f'stdout:\n{indent(proc.stdout, " ")}') - print(f'stderr:\n{indent(proc.stderr, " ")}', file=sys.stderr) - assert proc.stdout == f'{nnums * (nnums + 1) // 2}\n' + def __add__(self, other: Self) -> Self: + """ + Concatenate two instances: + + - For parameters appearing in both, their lists of values are + concatenated. + + - For parameters appearing in either instance, the missing + values are taken from the other instance's + :py:attr:`~.defaults`. + + Note: + In the case of clashes, the :py:attr:`~.defaults` and the + order of the :py:attr:`~.params` of ``self`` (the left + operand) take precedence. + + Example: + >>> p1 = _Params.new(('a', 'b', 'c'), + ... [(0, 0, 0), # defaults + ... (1, 2, 3), (4, 5, 6)]) + >>> p2 = _Params.new(('c', 'd'), [(7, 8), (9, 10)], + ... 
defaults=(-1, -1)) + >>> p1 + p2 # doctest: +NORMALIZE_WHITESPACE + _Params(params=('a', 'b', 'c', 'd'), + values=[(0, 0, 0, -1), + (1, 2, 3, -1), + (4, 5, 6, -1), + (0, 0, 7, 8), + (0, 0, 9, 10)], + defaults=(0, 0, 0, -1)) + """ + self_defaults = dict(zip(self.params, self.defaults)) + other_defaults = dict(zip(other.params, other.defaults)) + new_params = tuple(self._unique(self.params + other.params)) + + defaults = {**other_defaults, **self_defaults} + new_defaults_tuple = tuple(defaults[p] for p in new_params) + + new_values: list[tuple[Any, ...]] = [] + for old_values, old_params in [ + (self.values, self.params), (other.values, other.params), + ]: + indices: list[ + tuple[Literal[True], int] | tuple[Literal[False], str] + ] = [ + (True, old_params.index(p)) if p in old_params else (False, p) + for p in new_params + ] + new_values.extend( + tuple( + ( + value[cast(int, index)] + if available else + defaults[cast(str, index)] + ) for available, index in indices + ) + for value in old_values + ) + return type(self)(new_params, new_values, new_defaults_tuple) + + def sorted( + self, + *, + sort_by: Sequence[str] | None = None, + sortable_types: type[Any] | tuple[type[Any], ...] = (Real, str, bytes), + ) -> Self: + """ + Sort by parametrization values. + + Args: + sort_by (Sequence[str] | None): + Column names to sort by; default is to sort by all + sortable params. + sortable_types (type[Any] | tuple[type[Any], ...]): + Type(s) where if a param has all its values being + instances thereof (excl. :py:const:`None`s), said param + is considered sortable. 
@overload
def split_on_params(
    self, params: tuple[str, ...], *, drop_split_params: bool = True,
) -> dict[tuple[Any, ...], Self]:
    ...


@overload
def split_on_params(
    self, params: str, *, drop_split_params: bool = True,
) -> dict[Any, Self]:
    ...


def split_on_params(
    self, params: tuple[str, ...] | str, *, drop_split_params: bool = True,
) -> dict[tuple[Any, ...], Self] | dict[Any, Self]:
    """
    Group the value rows by the values they hold for the named
    ``params`` and return one new instance per group.

    Args:
        params (tuple[str, ...] | str):
            Param name(s) to split on; when a single string is
            given, the keys of the returned dictionary are the bare
            values instead of 1-tuples.
        drop_split_params (bool):
            Whether the params split upon are removed (via
            :py:meth:`drop_params`) from the returned instances.

    Returns:
        Dictionary from the split values to the new instances, in
        the order in which the values are first seen.

    Raises:
        ValueError:
            If any of ``params`` is not in :py:attr:`.params`.

    Example:
        >>> p = _Params.new(('a', 'b', 'c'),
        ...                 [(1, 2, True),
        ...                  (1, 2, False),
        ...                  (3, 4, True)])

        >>> p.split_on_params('a')  # doctest: +NORMALIZE_WHITESPACE
        {1: _Params(params=('b', 'c'),
                    values=[(2, True), (2, False)],
                    defaults=(2, True)),
         3: _Params(params=('b', 'c'),
                    values=[(4, True)],
                    defaults=(2, True))}

        >>> p.split_on_params(  # doctest: +NORMALIZE_WHITESPACE
        ...     ('a', 'b'),
        ... )
        {(1, 2): _Params(params=('c',),
                         values=[(True,), (False,)],
                         defaults=(True,)),
         (3, 4): _Params(params=('c',),
                         values=[(True,)],
                         defaults=(True,))}

        >>> p.split_on_params(  # doctest: +NORMALIZE_WHITESPACE
        ...     'a', drop_split_params=False,
        ... )
        {1: _Params(params=('a', 'b', 'c'),
                    values=[(1, 2, True), (1, 2, False)],
                    defaults=(1, 2, True)),
         3: _Params(params=('a', 'b', 'c'),
                    values=[(3, 4, True)],
                    defaults=(1, 2, True))}

        >>> p.split_on_params(  # doctest: +NORMALIZE_WHITESPACE
        ...     ('c', 'd'),
        ... )
        Traceback (most recent call last):
          ...
        ValueError: params = ('c', 'd'):
        these params not found: ['d']
    """
    unpack = isinstance(params, str)
    if unpack:
        params = (params,)
    nonexistent = sorted(set(params) - set(self.params))
    if nonexistent:
        raise ValueError(
            f'params = {params!r}: these params not found: {nonexistent!r}'
        )

    # Group the rows by the values in the columns split upon
    indices = tuple(self.params.index(p) for p in params)
    grouped: dict[tuple[Any, ...], list[tuple[Any, ...]]] = {}
    for pvalues in self.values:
        key = tuple(pvalues[i] for i in indices)
        grouped.setdefault(key, []).append(pvalues)

    # Build each instance (and its final key) in a single pass
    instances: dict[Any, Self] = {}
    for key, values in grouped.items():
        instance = type(self)(
            params=self.params, values=values, defaults=self.defaults,
        )
        if drop_split_params:
            instance = instance.drop_params(params)
        instances[key[0] if unpack else key] = instance
    return instances
@overload
@classmethod
def new(
    cls,
    params: str,
    values: Sequence[Any],
    defaults: Any | _NotSupplied = _NotSupplied.NOT_SUPPLIED,
) -> Self:
    ...


@classmethod
def new(
    cls,
    params: Sequence[str] | str,
    values: Sequence[Sequence[Any]] | Sequence[Any],
    defaults: (
        Sequence[Any] | Any | _NotSupplied
    ) = _NotSupplied.NOT_SUPPLIED,
) -> Self:
    """
    Instantiator more akin to :py:func:`pytest.mark.parametrize`:

    - ``params`` can be provided as a comma-separated string

    - Single parameters can be unpacked (singular param-name string
      and param-value sequences)

    - If ``defaults`` are not given, it is implicitly set to the
      FIRST item in ``values``.

    Args:
        params (Sequence[str] | str):
            Param names; a string is split on commas and each name
            is stripped of surrounding whitespace.
        values (Sequence[Sequence[Any]] | Sequence[Any]):
            One entry per parametrization; if ``params`` is a string
            naming a single param, each entry is the bare value,
            otherwise it is a sequence with one value per param.
        defaults (Sequence[Any] | Any):
            Default value(s), in the same (un-)packed form as the
            entries of ``values``.

    Raises:
        ValueError:
            If ``defaults`` is not supplied and ``values`` is empty.
    """
    if isinstance(params, str):
        param_list: tuple[str, ...] = tuple(
            p.strip() for p in params.split(',')
        )
        unpacked = len(param_list) == 1
    else:
        param_list = tuple(params)
        unpacked = False

    # Materialize once, so that one-shot iterables are not exhausted
    # by the extraction of the implicit defaults below
    rows = list(values)
    # Note: identity check on the sentinel -- `==` would invoke the
    # `.__eq__()` of arbitrary user-supplied defaults (e.g. arrays)
    if defaults is _NotSupplied.NOT_SUPPLIED:
        if not rows:
            raise ValueError(
                '`values` is empty and no `defaults` are supplied'
            )
        defaults = rows[0]

    if unpacked:
        default_values: tuple[Any, ...] = (defaults,)
        value_tuple_list: list[tuple[Any, ...]] = [(v,) for v in rows]
    else:
        default_values = tuple(defaults)  # type: ignore[arg-type]
        value_tuple_list = [tuple(v) for v in rows]
    return cls(param_list, value_tuple_list, default_values)
class _preserve_obj_attributes(_CallableContextManager[dict[str, Any]]):
    """
    Context manager which snapshots the named attributes of ``obj`` at
    entry, and at exit puts them back: attributes which existed are
    re-set to their old values, and attributes which didn't are
    deleted.

    Args:
        obj (Any):
            Object whose attributes are to be preserved.
        attrs (Collection[str]):
            Names of the attributes.
        static (bool):
            If true, old values are looked up with
            :py:func:`inspect.getattr_static`, else with
            :py:func:`getattr`.
        debug (bool):
            Whether to write messages via :py:meth:`_debug`.

    Note:
        Restoration is best-effort: errors raised by the individual
        restore steps are suppressed.
        The instance can be re-entered and reused; each exit only
        undoes what the matching entry has recorded.
    """
    def __init__(
        self, obj: Any, attrs: Collection[str], *,
        static: bool = True, debug: bool = _DEBUG,
    ) -> None:
        self.obj = obj
        self.attrs = set(attrs)
        self._callbacks: list[Callable[[], None]] = []
        # Offsets into `._callbacks`, one per active `.__enter__()`
        self._marks: list[int] = []
        self.debug = debug
        self.static = static

    def __enter__(self) -> dict[str, Any]:
        """
        Returns:
            Dictionary from the attribute names to the old values
            (:py:data:`_NotSupplied.NOT_SUPPLIED` for those absent).
        """
        get_attribute: _GetAttr
        if self.static:
            get_attribute = inspect.getattr_static
        else:
            get_attribute = getattr

        def get_repr(attr: str) -> str:
            # Note: both lookup functions raise `AttributeError` for
            # missing attributes; this must not abort the restoration
            try:
                value = get_attribute(self.obj, attr)
            except (AttributeError, ValueError):
                return ''
            return repr(value)

        def delete(attr: str) -> None:
            if self.debug:  # Don't look attributes up unless needed
                self._debug('Deleted attr `.{} = {}` on `{!r}`'.format(
                    attr, get_repr(attr), self.obj,
                ))
            try:
                delattr(self.obj, attr)
            except AttributeError:  # Already gone
                pass

        def reset(attr: str, value: Any) -> None:
            if self.debug:
                self._debug('Reset attr `.{} = {} -> {!r}` on `{!r}`'.format(
                    attr, get_repr(attr), value, self.obj,
                ))
            setattr(self.obj, attr, value)

        self._marks.append(len(self._callbacks))
        result: dict[str, Any] = {}
        for attr in self.attrs:
            old = get_attribute(self.obj, attr, _NotSupplied.NOT_SUPPLIED)
            if old is _NotSupplied.NOT_SUPPLIED:
                callback = partial(delete, attr)
            else:
                callback = partial(reset, attr, old)
            result[attr] = old
            self._callbacks.append(callback)
        return result

    def __exit__(self, *_, **__) -> None:
        # Only consume the callbacks registered by the matching entry,
        # so that they are not replayed when the instance is reused
        start = self._marks.pop() if self._marks else 0
        pending = self._callbacks[start:]
        del self._callbacks[start:]
        for callback in reversed(pending):
            try:
                callback()
            except Exception:  # Best-effort: go on with the others
                pass
def __enter__(self) -> dict[str, dict[str, Any]]:
    """
    Enter one :py:class:`_preserve_obj_attributes` context per target
    in :py:attr:`targets`.

    Returns:
        Dictionary from each target to what the corresponding
        :py:class:`_preserve_obj_attributes` context yields at entry.

    Note:
        If resolving or entering any target fails, the contexts
        already entered are exited again and nothing is left behind
        in :py:attr:`_stacks`.
    """
    result: dict[str, dict[str, Any]] = {}
    with ExitStack() as stack:
        for target, attrs in self.targets.items():
            result[target] = stack.enter_context(_preserve_obj_attributes(
                _import_target(target), attrs,
                debug=self.debug, static=self.static,
            ))
        # Everything has been entered successfully: transfer the
        # callbacks to a new stack, to be closed by `.__exit__()`
        self._stacks.append(stack.pop_all())
    return result
@classmethod
def compare_with_current_values(
    cls,
    old: Mapping[str, Mapping[str, Any]],
    comparator: Callable[[Any, Any], bool] = operator.is_,
    assert_true: bool | Mapping[str, Mapping[str, bool]] = True,
    static: bool = True,
) -> dict[str, dict[str, bool]]:
    """
    Compare previously-fetched attribute values against the current
    ones (as gathered by :py:meth:`fetch_current_values`), printing
    one message per attribute.

    Args:
        old (Mapping[str, Mapping[str, Any]]):
            Mapping from targets to mappings from attribute names to
            old values.
        comparator (Callable[[Any, Any], bool]):
            Called as ``comparator(new_value, old_value)``.
        assert_true (bool | Mapping[str, Mapping[str, bool]]):
            If falsy, results are only reported;
            if a true boolean, each comparison is expected to return
            true;
            if a (non-empty) mapping, ``assert_true[target][attr]``
            is the expected result of each comparison.
        static (bool):
            Passed on to :py:meth:`fetch_current_values`.

    Returns:
        Mapping from targets to mappings from attribute names to
        comparison results.

    Raises:
        AssertionError:
            If any comparison result differs from the expected one;
            raised after all the comparisons have been made.
    """
    if isinstance(assert_true, Mapping):
        expectations = assert_true

        def get_expected(target: str, attr: str) -> bool:
            return expectations[target][attr]
    else:
        def get_expected(target: str, attr: str) -> bool:
            return True

    current = cls.fetch_current_values(old, static)
    outcome: dict[str, dict[str, bool]] = {}
    failures: list[str] = []
    for target, old_values in old.items():
        new_values = current[target]
        per_target: dict[str, bool] = {}
        outcome[target] = per_target
        for attr, old_value in old_values.items():
            print(f'Checking: {target}.{attr}')
            new_value = new_values[attr]
            cmp_result = comparator(new_value, old_value)
            per_target[attr] = cmp_result
            subject = (
                f'Compared `{target}.{attr}` '
                f'(old: {old_value!r} @ {id(old_value):#x}; '
                f'new: {new_value!r} @ {id(new_value):#x})'
            )
            expected_result = get_expected(target, attr)
            if not assert_true:
                detail = f'comparison result with {comparator!r}: {cmp_result}'
                failed = False
            elif cmp_result == expected_result:
                detail = (
                    f'comparison result with {comparator!r} is '
                    f'{cmp_result} (as expected)'
                )
                failed = False
            else:
                detail = (
                    f'expected comparison with {comparator!r} to '
                    f'return {expected_result}, got {cmp_result}'
                )
                failed = True
            message = '{}: {}'.format(subject, detail)
            if failed:
                failures.append(message)
            print(message)
    assert (not failures), '\n'.join(failures)
    return outcome
@classmethod
def get_pth_files(cls, name_only: bool = True) -> frozenset[str]:
    """
    List the ``*.pth`` files directly under the directory returned by
    :py:meth:`_get_path`.

    Args:
        name_only (bool):
            If true, return the file names; else return the paths
            as strings.

    Returns:
        Frozen set of file names or path strings.
    """
    as_text = operator.attrgetter('name') if name_only else str
    return frozenset(map(as_text, cls._get_path().glob('*.pth')))
@dataclasses.dataclass
class _WarningMatcher:
    """
    Predicate on captured warning records.

    Each criterion left as :py:const:`None` is not checked; all the
    others must be satisfied for :py:meth:`match` to return true.

    Attributes:
        message (str | None):
            Regex to be matched (with :py:meth:`re.Pattern.match`)
            against the stringified warning message.
        category (type[Warning] | None):
            Class of which the category of the warning must be a
            subclass.
        module (str | None):
            Regex to be matched against the module name derived from
            the file name of the warning by ``modpath_to_modname()``.
        lineno (int | None):
            Line number which the warning must be issued at.
    """
    message: str | None = None
    category: type[Warning] | None = None
    module: str | None = None
    lineno: int | None = None
    # Checks derived from the above in `.__post_init__()`, keyed by
    # the name of the attribute on the warning record to be checked;
    # excluded from comparisons because the `partial` objects only
    # compare by identity, which would render equivalent matchers
    # unequal
    _filters: dict[str, Callable[[Any], Any]] = dataclasses.field(
        repr=False, init=False, compare=False, default_factory=dict,
    )

    def __post_init__(self) -> None:
        if self.message is not None:
            self._filters['message'] = partial(
                self._check_message, re.compile(self.message),
            )
        if self.category is not None:
            self._filters['category'] = partial(
                self._check_category, self.category,
            )
        if self.module is not None:
            self._filters['filename'] = partial(
                self._check_module, re.compile(self.module),
            )
        if self.lineno is not None:
            self._filters['lineno'] = partial(operator.eq, self.lineno)

    def __repr__(self) -> str:
        # Only show the criteria that are actually set
        fields: dict[str, Any] = {
            field.name: getattr(self, field.name, None)
            for field in dataclasses.fields(self)
            if field.repr
        }
        return '{}({})'.format(
            type(self).__name__,
            ', '.join(
                f'{k}={v!r}' for k, v in fields.items() if v is not None
            ),
        )

    def match(self, info: _WarningInfo) -> bool:
        """
        Returns:
            Whether the warning record ``info`` satisfies all the
            criteria set (trivially true if none is).
        """
        return all(
            check(getattr(info, field))
            for field, check in self._filters.items()
        )

    @staticmethod
    def _check_message(
        msg_regex: re.Pattern, msg: str | Warning,
    ) -> re.Match | None:
        if not isinstance(msg, str):
            msg = str(msg)
        return msg_regex.match(msg)

    @staticmethod
    def _check_category(parent: type[Any], maybe_child: type[Any]) -> bool:
        try:
            return issubclass(maybe_child, parent)
        except Exception:  # E.g. `maybe_child` isn't a class
            return False

    @staticmethod
    def _check_module(
        module_regex: re.Pattern, filename: str,
    ) -> re.Match | None:
        module = modpath_to_modname(filename, hide_main=False, hide_init=False)
        return module_regex.match(module)
def check(self, warnings: Sequence[_WarningInfo]) -> None:
    """
    Verify the recorded ``warnings`` against the registered
    :py:attr:`checks`, in the order of registration.

    Raises:
        ResultMismatch:
            At the first check that fails, i.e. if a forbidden
            matcher matches any of the ``warnings``, or if an
            expected matcher matches none of them.
    """
    for matcher, required in self.checks:
        hits = [record for record in warnings if matcher.match(record)]
        if hits:
            if required:
                continue
            raise ResultMismatch(
                expected=f'no warnings matching {matcher!r}',
                actual=f'{len(hits)} ({hits!r})',
            )
        if required:
            raise ResultMismatch(
                expected=f'warnings matching {matcher!r}',
                actual=f'none out of {len(warnings)} ({warnings!r})',
            )
def __exit__(self, *args, **kwargs) -> None:
    """
    Leave the innermost active context: run the deferred checks on
    the warnings it has recorded, and close the underlying
    :py:class:`warnings.catch_warnings` context whether or not the
    checks pass.
    """
    context, captured = self._contexts.pop()
    recorder = context.catch_warnings
    # Keep the records available for sequence access after the exit
    self._last_captured = captured
    try:
        context.check(captured)
    finally:
        recorder.__exit__(*args, **kwargs)
def _import_target(target: str) -> Any:
    """
    Resolve a dotted path to an object.

    Args:
        target (str):
            Either the dotted name of a module, or the dotted name
            of a module followed by the name of ONE attribute
            thereof.

    Returns:
        The module if ``target`` is importable as such;
        else the attribute named by the last component of ``target``
        on the module named by the preceding components.

    Raises:
        ImportError:
            If ``target`` can't be imported and either has no dot to
            split at, or the part before the last dot can't be
            imported either.
        AttributeError:
            If the module before the last dot doesn't have the
            attribute named after it.
    """
    try:
        return import_module(target)
    except ImportError:  # Not a module
        module, dot, attr = target.rpartition('.')
        if not dot:
            # Nothing to fall back to; report it as the import failure
            # that it is (instead of with an `assert`, which results
            # in another exception type and is stripped with `-O`)
            raise
        return getattr(import_module(module), attr)
def _run_subproc(
    cmd: Sequence[str] | str,
    /,
    *args,
    check: bool = False,
    env: Mapping[str, str] | None = None,
    **kwargs
) -> subprocess.CompletedProcess:
    """
    Wrapper around :py:func:`subprocess.run` which writes debugging
    output: the command line, the differences between ``env`` and
    :py:data:`os.environ`, the captured output streams (if any), the
    time elapsed, and the return status.

    Args:
        cmd (Sequence[str] | str):
            Command to run.
        *args, **kwargs:
            Passed to :py:func:`subprocess.run`.
        check (bool):
            Whether to raise on a nonzero return code; if any output
            stream is captured, the check is deferred to after
            :py:func:`subprocess.run` has returned, so that the
            captured streams can still be written out.
        env (Mapping[str, str] | None):
            Environment for the subprocess.

    Returns:
        The completed process.

    Raises:
        subprocess.CalledProcessError:
            If ``check`` is true and the return code is nonzero.
    """
    if isinstance(cmd, str):
        cmd_str = cmd
    else:
        cmd_str = concat_command_line(cmd)

    # If we're capturing outputs, it may be for the best to wait until
    # we've processed the output streams to check the return code...
    check_rc_in_run = check
    for arg in 'stdout', 'stderr':
        if kwargs.get(arg) not in (None, subprocess.DEVNULL):
            check_rc_in_run = False
    if kwargs.get('capture_output'):
        check_rc_in_run = False

    print('Command:', cmd_str)
    if env is not None:
        diff: list[str] = []
        # Sorted, so that the debugging output is reproducible
        for key in sorted(set(os.environ).union(env)):
            old = os.environ.get(key)
            new = env.get(key)
            if old == new:
                continue
            if new is None:
                item = f'{old!r} -> (deleted)'
            elif old is None:
                item = f'{new!r} (added)'
            else:
                item = f'{old!r} -> {new!r}'
            diff.append(f'${{{key}}}: {item}')
        if diff:
            print('Env:', indent('\n'.join(diff), ' '), sep='\n')
    print('-- Process start --')
    # Note: somehow `mypy` doesn't agree with simply unpacking the
    # `*args` into `subprocess.run()`...
    status: int | str = '???'
    proc: subprocess.CompletedProcess | None = None
    started = monotonic()
    try:
        proc = subprocess.run(  # type: ignore[call-overload]
            cmd, *args, env=env, check=check_rc_in_run, **kwargs,
        )
    except subprocess.CalledProcessError as e:
        status = e.returncode
        raise
    except Exception:
        status = 'error'
        raise
    else:
        assert proc is not None
        # Record the status first, so that it is reported even when
        # the check below fails
        status = proc.returncode
        if check and not check_rc_in_run:  # Perform missing check
            proc.check_returncode()
        return proc
    finally:
        elapsed = monotonic() - started
        if proc is not None:
            captured: str | bytes | None
            for name, captured, stream in [
                ('stdout', proc.stdout, sys.stdout),
                ('stderr', proc.stderr, sys.stderr),
            ]:
                if captured is None:
                    continue
                if isinstance(captured, bytes):  # `text=False`
                    # This is only for display; don't let undecodable
                    # bytes mask the actual outcome
                    captured = captured.decode(errors='replace')
                print(f'{name}:\n{indent(captured, " ")}', file=stream)
        print(
            f'-- Process end (time elapsed: {elapsed:.2f} s / '
            f'return status: {status})--'
        )
not None: + # We need `kernprof` to write the profliing results immediately + # to preserve data from tempfiles (see note below) + runner_args.append('--view') + + test_args: list[str] = [] + if use_local_func: + test_args.append('--local') + if fail: + test_args.append('--force-failure') + if start_method: + if start_method in START_METHODS: + test_args.extend(['-s', start_method]) + else: + pytest.skip( + f'`multiprocessing` start method {start_method!r} ' + 'not available on the platform' + ) + if nnums is None: + nnums = NUM_NUMBERS + else: + test_args.extend(['-l', str(nnums)]) + if nprocs is not None: + test_args.extend(['-n', str(nprocs)]) + + with ub.ChDir(tmp_path_factory.mktemp('mytemp')): + if outfile is not None: + runner_args.extend(['--outfile', outfile]) + if debug_log: + runner_args.extend(['--debug-log', debug_log]) + old_pth_files = _preserve_pth_files.get_pth_files() + try: + proc = run_helper( + runner_args, test_args, test_module, + text=True, capture_output=True, check=(check and not fail), + **kwargs + ) + # Checks: + if fail: + # - The process has failed as expected + if check: + assert proc.returncode + else: + # - The result is correctly calculated + expected = nnums * (nnums + 1) // 2 + output_lines = proc.stdout.splitlines() + ResultMismatch.compare(str(expected), output_lines[0]) + # - Temporary `.pth` file(s) created by + # `LineProfilingCache.write_pth_hook()` has been cleaned up + assert _preserve_pth_files.get_pth_files() == old_pth_files + # - Profiling results are written to the specified file + prof_result: LineStats | None = None + if outfile is None: + assert not list(Path.cwd().iterdir()) + else: + assert os.path.exists(outfile) + assert os.stat(outfile).st_size + if profile: + prof_result = LineStats.from_files(outfile) + # - If we're keeping track, the function is called the + # expected number of times and has run the expected # of + # loops (Note: we do it by parsing the output of + # `kernprof -v` instead of reading the 
def _check_output(output: str, tag: str, nhits: int) -> None:
    """
    Check that the lines in ``output`` ending with the comment
    ``# GREP_MARKER[<tag>]`` add up to ``nhits`` hits.

    Args:
        output (str):
            Text to search; the tagged lines are taken to be
            prefixed with 5 numbers: lineno, nhits, time,
            time-per-hit, and % time.
        tag (str):
            Tag in the marker comment.
        nhits (int):
            Expected total of the hit counts.

    Note:
        Tagged lines without the 5 leading fields, or whose second
        field is not an integer, are skipped.
        The comparison (and the error on mismatch) is delegated to
        ``ResultMismatch.compare()``.
    """
    marker = f'# GREP_MARKER[{tag}]'
    actual_nhits = 0
    for line in output.splitlines():
        if not line.endswith(marker):
            continue
        fields = line.split()
        if len(fields) < 5:
            continue
        try:
            actual_nhits += int(fields[1])
        except ValueError:
            continue
    ResultMismatch.compare(
        nhits, actual_nhits,
        expected=f'{nhits} hit(s) on line(s) tagged with {tag!r}',
    )
return list(range(n)) + + Normal execution: + + >>> my_func(3, 0) + [3] + [0, 1, 2, 3] + + Erroring out: + + >>> my_func(3, 0, error=True) + Traceback (most recent call last): + ... + RuntimeError: my error message + + Timing out: + + >>> my_func(4, delay=5) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + test_child_procs._TestTimeout: + my_func(4, delay=5): timed out after 0.5 s + """ + if func is None: + return cast( + Callable[[Callable[PS, T]], Callable[PS, T]], + partial(_timeout, timeout=timeout), + ) + + @wraps(func) + def worker( + fobj: IO[bytes], /, *args: PS.args, **kwargs: PS.kwargs + ) -> None: + try: + result = True, func(*args, **kwargs) + except Exception as e: + result = False, ExceptionHelper(e, e.__traceback__) + # Do this instead of directly using `pickle.dump(..., fobj)` so + # that pickling errors and file-handle-related errors are + # handled separately + serialized = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + try: + fobj.write(serialized) + fobj.flush() + except Exception: + # Since this is run in a daemon thread, by the time this + # write happens the main thread could've already timed out + # and destroyed `fobj`... 
def get_patched_attributes(
    applied_mp_patches: Collection[str] | None = None,
) -> MappingProxyType[str, frozenset[str]]:
    """
    Look up the attributes touched by the given ``multiprocessing``
    patches (delegating to the cached
    :py:func:`_get_patched_attributes`).

    Args:
        applied_mp_patches (Collection[str] | None):
            Names of the patches applied; if not given, they are the
            keys with true values in the
            ``child_processes.multiprocessing.patches`` subtable of
            the config loaded by ``ConfigSource.from_config()``.

    Returns:
        Read-only mapping from patch targets to attribute names.
    """
    if applied_mp_patches is not None:
        return _get_patched_attributes(frozenset(applied_mp_patches))
    patch_flags = (
        ConfigSource.from_config()
        .get_subconfig('child_processes', 'multiprocessing', 'patches')
        .conf_dict
    )
    enabled = frozenset(
        name for name, applied in patch_flags.items() if applied
    )
    return _get_patched_attributes(enabled)
frozenset(attrs) + for target, attrs in _filter_patches(patches).items() + }) + + +def _get_toml_patches_section(mp_patches: Collection[str]) -> str: + mp_patches_as_dict = { + name: name in mp_patches for name in MP_PATCHES + if not _mp_patch_is_internal(name) + } + return ( + '[tool.line_profiler.child_processes.multiprocessing.patches]\n' + + '\n'.join( + f'{patch} = {str(applied).lower()}' + for patch, applied in mp_patches_as_dict.items() + ) + ) + + +def _get_patch_summary_union( + *summaries: _PatchSummary, +) -> dict[str, frozenset[str]]: + result: dict[str, frozenset[str]] = {} + for summary in summaries: + for target, attrs in summary.items(): + result[target] = result.get(target, frozenset()) | frozenset(attrs) + return result + + +def _summarize_patches( + summaries: Collection[tuple[bool, _PatchSummary]] +) -> dict[str, dict[str, bool]]: + """ + Example: + >>> _summarize_patches([(False, {'foo': {'bar'}})]) + {'foo': {'bar': False}} + >>> _summarize_patches([ # doctest: +NORMALIZE_WHITESPACE + ... (False, {'foo': {'bar', 'baz'}}), + ... (True, {'foo': {'baz', 'foobar'}, 'spam': {'ham'}}), + ... (False, {'foo': {'baz'}, 'spam': {'eggs'}}) + ... 
]) + {'foo': {'bar': False, 'baz': True, 'foobar': True}, + 'spam': {'eggs': False, 'ham': True}} + """ + def get_all_mentioned(s: Iterable[_PatchSummary]) -> dict[str, set[str]]: + all_items: dict[str, set[str]] = {} + for summary in s: + for target, attrs in summary.items(): + all_items.setdefault(target, set()).update(attrs) + return all_items + + all_items = get_all_mentioned(s for _, s in summaries) + all_patched = get_all_mentioned(s for applied, s in summaries if applied) + result: dict[str, dict[str, bool]] = { + target: + {attr: attr in all_patched.get(target, set()) for attr in attrs} + for target, attrs in all_items.items() + } + # Normalize the order for convenience + return { + target: dict(sorted(attrs.items())) + for target, attrs in sorted(result.items()) + } + + +def _filter_patches(summary: _PatchSummary) -> dict[str, set[str]]: + result: dict[str, set[str]] = {} + for target, attrs in summary.items(): + try: + obj = _import_target(target) + except ImportError: + continue + present_attrs = {a for a in attrs if hasattr(obj, a)} + # Drop if none of the attributes is present + if present_attrs: + result[target] = present_attrs + return result + + +_GLOBAL_MINIMAL_PATCHES = { + 'multiprocessing': frozenset({_PATCHED_MARKER}), +} +# Get patches that are dynamically resolved: while these patches are +# always applied, some of the patch targets are +# platform-/Python-version-specific and may not always exist +_dynamically_resolved_patch_summaries: Iterable[_PatchSummary] = ( + patch.summary for name, patch in MP_PATCHES.items() + # Basic `multiprocessing` patches are always applied + if _mp_patch_is_internal(name) +) +_dynamically_resolved_patch_summaries = itertools.chain( + _dynamically_resolved_patch_summaries, + # some platforms e.g. 
Windows don't have `fork()` + [{'os': frozenset({'fork'})}], +) +_dynamically_resolved_patch_summaries = cast( # See `ty` issue #3428 + Iterable[_PatchSummary], + map(_filter_patches, _dynamically_resolved_patch_summaries), +) +_GLOBAL_MINIMAL_PATCHES = _get_patch_summary_union( + _GLOBAL_MINIMAL_PATCHES, *_dynamically_resolved_patch_summaries, +) + +# This is only patched if we called +# `_line_profiler_hooks.load_pth_hook()` +_HOOK_PATCHES = { + f'{load_pth_hook.__module__}.{load_pth_hook.__qualname__}': + frozenset({'called'}), +} +# Upper limit of what we could've patched +_GLOBAL_PATCHES = _get_patch_summary_union( + _GLOBAL_MINIMAL_PATCHES, + _HOOK_PATCHES, + get_patched_attributes([ + name for name in MP_PATCHES if not _mp_patch_is_internal(name) + ]), +) +# Actual patches using the default config +DEFAULT_GLOBAL_PATCHES = _get_patch_summary_union( + _GLOBAL_MINIMAL_PATCHES, get_patched_attributes(), +) + +_DEFAULT_MP_CONFIG = MPConfig.from_config(ConfigSource.from_default()) + + +@pytest.mark.parametrize(('run_profiled_code', 'label1'), + [(True, 'run-profiled'), (False, 'run-unrelated')]) +@pytest.mark.parametrize(('as_module', 'label2'), + [(True, 'run_module'), (False, 'run_path')]) +@pytest.mark.parametrize(('debug', 'label3'), + [(True, 'with-debug'), (False, 'no-debug')]) +def test_runpy_patches( + capsys: pytest.CaptureFixture[str], + ext_module: _ModuleFixture, + test_module: _ModuleFixture, + test_module_clone: _ModuleFixture, + create_cache: Callable[..., LineProfilingCache], + run_profiled_code: bool, + as_module: bool, + debug: bool, + label1: str, label2: str, label3: str, +) -> None: + """ + Test that the :py:mod:`runpy` clone created by + :py:func:`line_profiler._child_process_profiling\ +.create_runpy_wrapper` + correctly sets up profiling when its ``run_*()`` functions are + called. 
+ """ + class restore_argv: + def __enter__(self) -> None: + self.argv = list(sys.argv) + + def __exit__(self, *_, **__) -> None: + sys.argv[:] = self.argv + + cache = create_cache( + rewrite_module=test_module.path, + profiling_targets=[str(ext_module.path)], + profile_imports=True, + debug=debug, + ) + assert cache.profiler is not None + runpy = create_runpy_wrapper(cache) + + nnums = 42 + nprocs = 2 + # If we're running some unrelated code, the profiler should not be + # involved + if run_profiled_code: + module = test_module + num_invocations, num_loops = 1, nprocs + expected_funcs: list[str] = ['my_external_sum'] + else: + module = test_module_clone + num_invocations, num_loops = 0, 0 + expected_funcs = [] + if as_module: + first_arg = module.name + runner = partial(runpy.run_module, alter_sys=True) + called_func = 'run_module' + else: + first_arg = str(module.path) + runner = runpy.run_path + called_func = 'run_path' + + # Check that the code is run + module.install(local=True, deps_only=not as_module) + with restore_argv(): + sys.argv[:] = [first_arg, f'--length={nnums}', '-n', str(nprocs)] + runner(first_arg, run_name='__main__') + stdout = capsys.readouterr().out + assert stdout.rstrip('\n') == str(nnums * (nnums + 1) // 2) + + # Check that profiler has received the appropriate targets + funcs = [func.__name__ for func in getattr(cache.profiler, 'functions')] + assert funcs == expected_funcs + + # Check that calls in the current process are profiled iif the + # correct file is executed + with StringIO() as sio: + cache.profiler.print_stats(sio) + stats = sio.getvalue() + _check_output(stats, 'EXT-INVOCATION', num_invocations) + _check_output(stats, 'EXT-LOOP', num_loops) + + # Check the debug-log entries are correctly gathered + _search_cache_logs( + cache, + debug, + { + rf'calling .*{called_func}\(': True, + r'calling .*exec\(': run_profiled_code, + }, + match_individual_messages=True, + flags=re.IGNORECASE, + ) + + +def test_cache_dump_load( + 
create_cache: Callable[..., LineProfilingCache], +) -> None: + """ + Test that: + + - We can round-trip the cache via :py:meth:`LineProfilingCache.dump` + and :py:meth:`LineProfilingCache.load` + + - The same instance is :py:meth:`LineProfilingCache.load`-ed in + subsequent calls + """ + original = create_cache( + profiling_targets=['foo', 'bar', 'baz'], main_pid=123456, + ) + cache_instances: list[LineProfilingCache] = [original] + envvars: set[str] = set(os.environ) + try: + original.inject_env_vars() # Needed for `.load()` + # Also test slipping stuff into the `._additional_data` + original._additional_data['foo'] = [1, 'string', None] + try: + # Env vars should be inserted + assert set(os.environ) == envvars.union(original.environ) > envvars + original.dump() + loaded = original.load() + cache_instances.append(loaded) + reloaded = original.load() + cache_instances.append(reloaded) + assert original is not loaded is reloaded + # Compare init fields + for field in dataclasses.fields(LineProfilingCache): + if not field.init: + continue + assert ( + getattr(original, field.name) + == getattr(loaded, field.name) + ) + # Compare `._additional_data` + assert original._additional_data == loaded._additional_data + finally: # Explicitly cleanup + for cache in cache_instances: + cache.cleanup() + finally: # Env vars restored after cleanup + assert set(os.environ) == envvars + + +@(_Params.new(('wrap_os_fork', 'label1'), + [(True, 'with-wrap-fork'), (False, 'no-wrap-fork')]) + + _Params.new(('debug', 'label2'), + [(True, 'with-debug'), (False, 'no-debug')]) + + _Params.new(('patch_pool', 'patch_process', 'intercept_logs', 'label3'), + [(True, True, True, 'all-patches'), + (True, True, False, 'pool-and-process'), + (True, False, True, 'pool-and-logging'), + (True, False, False, 'pool-only'), + (False, True, True, 'process-and-logging'), + (False, True, False, 'process-only'), + (False, False, True, 'logging-only'), + (False, False, False, 'no-patches')])).sorted() +def 
test_cache_setup_main_process( + tmp_path_factory: pytest.TempPathFactory, + create_cache: Callable[..., LineProfilingCache], + wrap_os_fork: bool, + debug: bool, + patch_pool: bool, + patch_process: bool, + intercept_logs: bool, + label1: str, label2: str, label3: str, +) -> None: + """ + Test that :py:meth:`LineProfilingCache._setup_in_main_process` works + as expected. + """ + mp_patches: set[str] = set() + if patch_pool: + mp_patches.add('pool') + if patch_process: + mp_patches.add('process') + if intercept_logs: + mp_patches.add('logging') + + config = tmp_path_factory.mktemp('myconfig') / 'mytoml.toml' + config.write_text(_get_toml_patches_section(mp_patches)) + cache = create_cache(debug=debug, config=config) + + # Check that only the requested patches are applied + patches = _summarize_patches([ + (True, _GLOBAL_MINIMAL_PATCHES), + *( + (name in mp_patches, _filter_patches(patch.summary)) + for name, patch in MP_PATCHES.items() + if not _mp_patch_is_internal(name) + ), + ]) + try: + patches['os']['fork'] = wrap_os_fork + except KeyError: + # `os.fork()` pruned because it doesn't exist on e.g. 
Windows + assert not hasattr(os, 'fork') + + with ExitStack() as stack: + patched = stack.enter_context(_preserve_attributes(patches)) + compare_patched = partial( + _preserve_attributes.compare_with_current_values, patched, + ) + original_pths = stack.enter_context(_preserve_pth_files()) + cache._setup_in_main_process(wrap_os_fork=wrap_os_fork) + # There should be exactly one extra `.pth` file + new_pth_hook, = _preserve_pth_files.get_pth_files() - original_pths + # Check whether the patches are applied + compare_patched(operator.is_not, assert_true=patches) + # Check whether the patches are reversed + cache.cleanup() + compare_patched() + # Check that the instance is set as the `.load()`-ed one + assert cache is cache.load() + + # Check the debug-log output + patterns: dict[str, bool] = dict.fromkeys( + [ + r'\(main process\)', + r'Injecting env var.*\$\{LINE_PROFILER_\w+\}', + re.escape(new_pth_hook), + ], + True, + ) + for target, maybe_patches in patches.items(): + patterns.update( + ('Patched.*' + re.escape(f'{target}.{attr}'), is_patched) + for attr, is_patched in maybe_patches.items() + ) + _search_cache_logs(cache, debug, patterns) + + +@pytest.mark.parametrize(('wrap_os_fork', 'label1'), + [(True, 'with-wrap-fork'), (False, 'no-wrap-fork')]) +@pytest.mark.parametrize(('preimports', 'label2'), + [(True, 'with-preimports'), (False, 'no-preimports')]) +@pytest.mark.parametrize(('new_profiler', 'label3'), + [(True, 'no-profiler'), (False, 'with-profiler')]) +@pytest.mark.parametrize(('debug', 'label4'), + [(True, 'with-debug'), (False, 'no-debug')]) +@pytest.mark.parametrize('n', [100]) +@_preserve_attributes(_GLOBAL_PATCHES) +def test_cache_setup_child( + create_cache: Callable[..., LineProfilingCache], + ext_module_object: ModuleType, + another_pid: int, + wrap_os_fork: bool, + preimports: bool, + new_profiler: bool, + debug: bool, + n: int, + label1: str, label2: str, label3: str, label4: str, +) -> None: + """ + Test that 
:py:meth:`LineProfilingCache._setup_in_child_process` + works as expected. + """ + def list_profiled_funcs() -> list[str]: + return [ + f'{func.__module__}.{func.__qualname__}' + for func in getattr(cache.profiler, 'functions', []) + ] + + func = ext_module_object.my_external_sum + cache = create_cache( + profiling_targets=[f'{func.__module__}.{func.__qualname__}'], + preimports_module=preimports, + _use_curated_profiler=not new_profiler, + main_pid=another_pid, + debug=debug, + ) + assert (cache.profiler is None) == new_profiler + + seen_funcs = list_profiled_funcs() + if preimports: + preimport_targets = list(cache.profiling_targets) + else: + preimport_targets = [] + + with _preserve_obj_attributes(os, ['fork']) as preserved: + old_fork = preserved['fork'] + # Check that we're only setting up if there isn't already a + # profiler + assert cache._setup_in_child_process( + wrap_os_fork=wrap_os_fork, context='test_cache_setup_child', + ) == new_profiler + assert cache.profiler + if not new_profiler: + return + + # Check that the profiler has been presented with the profiling + # target + assert list_profiled_funcs() == (seen_funcs + preimport_targets) + + # Check that on cache cleanup: + # - Profiling data is collected + # - `os.fork()` is restored + # - The warning for empty profiling files is only issued when + # expected + assert func(range(1, n + 1)) == n * (n + 1) // 2 + stats = cache.profiler.get_stats() + for callback, has_nonempty_file, has_stats, fork_patched in [ + (lambda: None, False, False, wrap_os_fork), + (cache.cleanup, True, preimports, False), + ]: + callback() + with _check_warnings() as cw: + if has_nonempty_file: + check_warning = cw.forbid_warnings + else: + check_warning = cw.expect_warnings + check_warning(r'.* file\(s\) .* empty', module='line_profiler') + gathered = cache.gather_stats() + assert any(gathered.timings.values()) == has_stats, gathered + if hasattr(os, 'fork'): + assert (os.fork is not old_fork) == fork_patched + else: # E.g. 
Windows + assert old_fork == _NotSupplied.NOT_SUPPLIED + # Check that after cleaning up the profiler has been disabled + assert not getattr(cache.profiler, 'enable_count', 0) + + # Check that profiling results have been written to the cache + # directory + stats_file, = Path(cache.cache_dir).glob('*.lprof') + assert LineStats.from_files(stats_file) == stats == gathered + + # Check the debug-log output + patterns = { + f'Set up .*profiler.* {id(cache.profiler):#x}': True, + 'Loading preimports': preimports, + 'Created .*' + re.escape(stats_file.name): True, + 'Cleanup succeeded.*: .*dump_stats': True, + 'Loading results .*' + re.escape(stats_file.name): True, + } + _search_cache_logs(cache, debug, patterns) + + +@pytest.mark.parametrize('ppid_should_match', [True, False, None]) +@_preserve_attributes(_GLOBAL_PATCHES) +def test_load_pth_hook( + create_cache: Callable[..., LineProfilingCache], + another_pid: int, + ppid_should_match: bool | None, +) -> None: + """ + Simulate calling :py:func:`_line_profiler_hooks.load_pth_hook()` in + a child process. + + Notes: + + - The function is CALLED in the .pth file, but we don't actually + NEED a .pth file to call and test it. + + - The counterpart :py:meth:`line_profiler\ +._child_process_profiling.cache.LineProfilingCache.write_pth_hook()` + is implicitly tested in + :py:func:`test_cache_setup_main_process()`. + """ + # This test is mostly here to hack coverage; since the function is + # only to be called in child processes, `coverage` seems to have + # trouble getting data on it... + + # We basically only need this cache instance to set up the + # environment variables and the requisite files... 
+ cache = create_cache(main_pid=another_pid) + if ppid_should_match is not None: + cache.inject_env_vars() + if ppid_should_match: + call_ppid = another_pid + else: # On a PPID mismatch, the function bails after checking + call_ppid = another_pid + 10 + else: + # Without the requisite envvars, the hook should bail very + # quickly (due to the `environ` lookup erroring out), regardless + # of the provided PPID + call_ppid = 0 + cache.dump() + + compare = _preserve_attributes.compare_with_current_values + patches = {**DEFAULT_GLOBAL_PATCHES, **_HOOK_PATCHES} + with _preserve_attributes(patches) as patched: + try: + # NOTE: this creates a cache instance that isn't + # automatically cleaned up by the `create_cache()` + # fixture!!! Hence the try-finally + load_pth_hook(call_ppid) + # Check that the patches are applied where appropriate + assert ( + getattr(load_pth_hook, 'called', False) + == bool(ppid_should_match) + ) + if ppid_should_match: + compare(patched, operator.is_not) + else: # no-op + compare(patched) + return + # Check that calling `load_pth_hook()` again is a no-op + with _preserve_attributes(patches) as re_patched: + load_pth_hook(call_ppid) + compare(re_patched) + finally: + try: + current_cache = LineProfilingCache.load() + except Exception: + pass + else: + current_cache.cleanup() + # Check that the patches are reversed + compare(patched) + + +@_preserve_pth_files() +@_preserve_attributes(_GLOBAL_PATCHES) +def _test_apply_mp_patches_inner( + tmp_path_factory: pytest.TempPathFactory, + create_cache: Callable[..., LineProfilingCache], + ext_module_object: ModuleType, + test_module_object: ModuleType, + start_method: Literal['fork', 'forkserver', 'spawn', 'dummy'], + mp_patches: Collection[str], + fail: bool, + n: int, + nprocs: int, +) -> None: + def is_valid_stats_file(path: os.PathLike[str] | str) -> bool: + try: + LineStats.from_files(path, on_empty='error', on_defective='error') + except Exception: + return False + return True + + def 
get_lineno(path: os.PathLike[str] | str, query: str) -> int: + with Path(path).open() as fobj: + for i, line in enumerate(fobj): + if query in line: + return 1 + i + raise RuntimeError( + f'Did not find line containing {query!r} in {path!r}', + ) + + config = tmp_path_factory.mktemp('myconfig') / 'mytoml.toml' + intercept_logs = 'logging' in mp_patches + patch_process = 'process' in mp_patches + cfg_chunks: list[str] = [ + _get_toml_patches_section(mp_patches), + # This is easier to debug than `ResultMismatch` + '[tool.line_profiler.child_processes.multiprocessing.polling]\n' + 'on_timeout = "error"', + ] + config.write_text('\n\n'.join(cfg_chunks)) + + # Note: no need to test the case for `my_local_sum()` separately, + # with `preimports_module=True`, both are just imported and added + # to the profiler, so the code paths are the same + profiled_func = ext_module_object.my_external_sum + # Note: it would have been more intuitive to just apply `@_timeout` + # to this whole function and run it all in a new thread, but that + # seem to interact adversely with patch application and prof-data + # collection... so just timeout the function invoking + # `multiprocessing` + sum_with_timeout = _timeout(test_module_object.sum_in_child_procs) + called_func = partial( + sum_with_timeout, + n=nprocs, + my_sum=profiled_func, + start_method=start_method, + fail=fail, + ) + + func_name = f'{profiled_func.__module__}.{profiled_func.__qualname__}' + cache = create_cache( + profiling_targets=[func_name], + preimports_module=True, + config=config, + debug=True, + ) + # Note: + # - The reversibility of the patches have already been tested in + # `test_cache_setup_main_process()`, so we just actually test the + # patched-in components themselves here. + # - `._setup_in_main_process()` doesn't include actually doing the + # preimports. To may the results more consistent between + # `start_method='dummy'` and the others, manually do them below. 
+ cache._setup_in_main_process() # This calls `apply()` + assert cache.profiler is not None + assert cache.preimports_module is not None + run_path(str(cache.preimports_module), {'profile': cache.profiler}) + + timing_key = ( + inspect.getfile(profiled_func), + inspect.getsourcelines(profiled_func)[1], + profiled_func.__qualname__, + ) + assert ext_module_object.__file__ + loop_line = get_lineno(ext_module_object.__file__, 'EXT-LOOP') + + nloops_expected = n + if not fail: + # Counts from the one final sum over the parallel results + nloops_expected += nprocs + + if start_method not in ('dummy', *START_METHODS): + pytest.skip( + f'`multiprocessing` start method {start_method!r} ' + 'not available on the platform' + ) + + # Note: manually handle the error here instead of using + # `pytest.raises()` since we want certain `RuntimeError`s to be + # propagated and handled by `@pytest.mark.retry` + fail_msg = 'forced failure' + try: + result = called_func(n) + except RuntimeError as e: + if not (fail and str(e) == fail_msg): + raise + else: + if fail: + msg = f"expected `RuntimeError({fail_msg!r})`, no error raised" + raise ValueError(msg) + else: # Check correctness of the results + assert result == n * (n + 1) // 2 + + # Check that calls in children are traced + cache.cleanup() + stats = cache.profiler.get_stats() + stats += cache.gather_stats() + entries = stats.timings[timing_key] + nloops = sum(nhits for ln, nhits, _ in entries if ln == loop_line) + ResultMismatch.compare(nloops_expected, nloops) + + # Check the debug logs to see if we have done everything right, esp. 
+ # the logging interception part not covered by other tests + patterns: dict[str, bool] = {} + if patch_process: + # Note: if we're not using `Process`-based patch, there is no + # guaratee that the profiling result is written via cleanup + iter_stats: Iterable[Path] = Path(cache.cache_dir).glob('*.lprof') + iter_stats = cast( # See `ty` issue #3428 + Iterable[Path], filter(is_valid_stats_file, iter_stats), + ) + pat = 'Cleanup succeeded.*: .*dump_stats.*{}' + patterns.update({ + pat.format(re.escape(path.name)): True for path in iter_stats + }) + patterns[re.escape('`multiprocessing` logging (debug)')] = intercept_logs + _search_cache_logs(cache, True, patterns) + + +def _test_apply_mp_patches( + patch_pool: bool | None = None, + patch_process: bool | None = None, + intercept_logs: bool | None = None, + trace_pids: bool | None = None, + **kwargs +) -> None: + patches = cast(dict[str, bool], _DEFAULT_MP_CONFIG.patches.copy()) + for name, applied in { + 'pool': patch_pool, 'process': patch_process, + 'logging': intercept_logs, 'child_pids': trace_pids, + }.items(): + if applied is not None: + patches[name] = applied + mp_patches = [name for name, applied in patches.items() if applied] + with _check_warnings() as cw: + if 'child_pids' in mp_patches: + # With PID bookkeeping we should be able to weed out all + # the child processes which didn't perform any work + cw.forbid_warnings(category=UserWarning, module='line_profiler') + cw.forbid_warnings(module='multiprocessing') + _test_apply_mp_patches_inner(mp_patches=mp_patches, **kwargs) + + +@(_Params.new('start_method', + ['fork', 'forkserver', 'spawn', 'dummy'], + defaults='dummy') + # We only need to check if `intercept_logs = logging` work, the other + # parametrizations don't matter + + _Params.new(('intercept_logs', 'label1'), + [(True, 'with-logging'), (False, 'no-logging')], + defaults=(None, 'default-logging')) + # Same deal with `trace_pids = child_pids` + + _Params.new(('trace_pids', 'label2'), + [(True, 
'with-child_pids'), (False, 'no-child-pids')], + defaults=(None, 'default-child-pids'))).sorted() +@pytest.mark.parametrize(('patch_pool', 'patch_process', 'label3'), + [(True, True, 'pool-and-process'), + (True, False, 'pool-only'), + (False, True, 'process-only')]) +@pytest.mark.parametrize(('n', 'nprocs'), [(100, 2)]) +def test_apply_mp_patches_success( + tmp_path_factory: pytest.TempPathFactory, + create_cache: Callable[..., LineProfilingCache], + ext_module_object: ModuleType, + test_module_object: ModuleType, + start_method: Literal['fork', 'forkserver', 'spawn', 'dummy'], + patch_pool: bool, + patch_process: bool, + intercept_logs: bool | None, + trace_pids: bool | None, + n: int, + nprocs: int, + label1: str, + label2: str, + label3: str, +) -> None: + """ + Test that :py:func:`line_profiler._child_process_profiling\ +.multiprocessing_patches.apply` + works as expected when the parallel workload does not error out. + + See also: + :py:func:`test_apply_mp_patches_failure` + """ + _test_apply_mp_patches( + patch_pool, + patch_process, + intercept_logs, + trace_pids, + tmp_path_factory=tmp_path_factory, + create_cache=create_cache, + ext_module_object=ext_module_object, + test_module_object=test_module_object, + start_method=start_method, + fail=False, + n=n, + nprocs=nprocs, + ) + + +# XXX: on POSIX child processes can hang around for long enough for +# profiling-stats collection to occur somewhat robustly, thanks to +# signal handling. But unfortunately on Windows: +# - When `patch_pool` is true, we wrap the task callables so that they +# always write profiling stats before returning/erroring out. This +# incurs extra overhead, but effectively prevents the relinquishing of +# control back to the parent process before the stats are ready. +# - However, when `patch_pool` is false, we can only try to block/delay +# child-process termination. 
A timeout is used to prevent indefinite +# waits for them to finish, and there's always the off chance that the +# end-of-process cleanup still haven't finished at the end. +# Hence the conditional need for retries... +@pytest.mark.retry( + retries=2, + condition='_WINDOWS and not patch_pool', + exceptions=(ResultMismatch, _Poller.Timeout), +) +@pytest.mark.parametrize('start_method', + ['fork', 'forkserver', 'spawn', 'dummy']) +@pytest.mark.parametrize(('patch_pool', 'patch_process', 'label'), + [(True, True, 'pool-and-process'), + (True, False, 'pool-only'), + (False, True, 'process-only')]) +@pytest.mark.parametrize(('n', 'nprocs'), [(100, 2)]) +def test_apply_mp_patches_failure( + tmp_path_factory: pytest.TempPathFactory, + create_cache: Callable[..., LineProfilingCache], + ext_module_object: ModuleType, + test_module_object: ModuleType, + start_method: Literal['fork', 'forkserver', 'spawn', 'dummy'], + patch_pool: bool, + patch_process: bool, + n: int, + nprocs: int, + label: str, +) -> None: + """ + Test that :py:func:`line_profiler._child_process_profiling\ +.multiprocessing_patches.apply` + works as expected when the parallel workload errors out. 
+ + See also: + :py:func:`test_apply_mp_patches_success` + """ + _test_apply_mp_patches( + patch_pool, + patch_process, + tmp_path_factory=tmp_path_factory, + create_cache=create_cache, + ext_module_object=ext_module_object, + test_module_object=test_module_object, + start_method=start_method, + fail=True, + n=n, + nprocs=nprocs, + ) + + +# XXX: End of tests for implementation details + +# ========================= Integration tests ========================== + + +def _get_mp_start_method_fuzzer(label_name: str | None) -> _Params: + """ + Returns: + :py:class:`_Params` object which does a full Cartesian-product + fuzz between ``fail`` (true or false) and ``start_method`` + ('fork', 'forkserver', and 'spawn'; default :py:const:`None`) + """ + if label_name is None: + label_name, drop_label = '_', True + else: + drop_label = False + fuzz_fail = _Params.new(('fail', label_name), + [(True, 'failure'), (False, 'success')], + defaults=(False, 'success')) + if drop_label: + fuzz_fail = fuzz_fail.drop_params(label_name) + fuzz_start = _Params.new('start_method', ['fork', 'forkserver', 'spawn'], + defaults=None) + return fuzz_fail * fuzz_start + + +@(_Params.new(('run_func', 'label1'), + [(run_module, 'module'), (run_script, 'script')]) + * _Params.new(('use_local_func', 'label2'), + [(True, 'local'), (False, 'ext')]) + # Python can't pickle things unless they resided in a retrievable + # location (so not the script supplied by `python -c`) + + _Params.new(('run_func', 'label1', 'use_local_func', 'label2'), + [(run_literal_code, 'literal-code', False, 'ext')]) + # Also fuzz the parallelization-related stuff, esp. 
check what + # happens if an exception is raised inside the parallelly-run func + + _get_mp_start_method_fuzzer('label3') + + _Params.new(('nnums', 'nprocs'), [(200, None), (None, 3)], + defaults=(None, None))).sorted() +def test_multiproc_script_sanity_check( + run_func: Callable[..., subprocess.CompletedProcess], + test_module: _ModuleFixture, + tmp_path_factory: pytest.TempPathFactory, + use_local_func: bool, + fail: bool, + start_method: Literal['fork', 'forkserver', 'spawn'] | None, + nnums: int | None, + nprocs: int | None, + # Dummy arguments to make `pytest` output more legible + label1: str, label2: str, label3: str, +) -> None: + """ + Sanity check that the test module functions as expected when run + with vanilla Python. + """ + run_func( + test_module, tmp_path_factory, + runner=sys.executable, profile=False, + fail=fail, + use_local_func=use_local_func, + start_method=start_method, + nnums=nnums, nprocs=nprocs, + ) + + +@pytest.mark.parametrize( + ('run_func', 'label1'), + [(run_module, 'module'), + (run_script, 'script'), + (run_literal_code, 'literal-code')] +) +@pytest.mark.parametrize( + ('runner', 'outfile', 'profile', + 'label2'), # Dummy argument to make `pytest` output more legible + # This is essentially a no-op since it doesn't actually do + # line-profiling, but we check that code path for completeness + [(['kernprof', '-q', '--no-line'], 'out.prof', False, 'cProfile')] + # Run line profiling with and w/o profiling targets + + [(['kernprof', '-q', '-l'], 'out.lprof', False, + 'line_profiler-inactive'), + (['kernprof', '-q', '-l'], 'out.lprof', True, + 'line_profiler-active')], +) +def test_running_multiproc_script( + run_func: Callable[..., subprocess.CompletedProcess], + test_module: _ModuleFixture, + tmp_path_factory: pytest.TempPathFactory, + runner: str | list[str], + outfile: str | None, + profile: bool, + # Dummy arguments to make `pytest` output more legible + label1: str, label2: str, +) -> None: + """ + Check that `kernprof` can RUN 
the test module in various contexts + (`kernprof [...] `, `kernprof [...] -m `, and + `kernprof [...] -c "code"`). + + Notes: + - See issue #422 for the original motivation. + + - This test does not test the actual profiling, just the + execution of the code and presence of profiling data + thereafter. + """ + run_func(test_module, tmp_path_factory, runner, outfile, profile) + + +_fuzz_prof_mp_run_func = _Params.new(('run_func', 'label1'), + [(run_module, 'module'), + (run_script, 'script'), + (run_literal_code, 'literal-code')], + defaults=(run_script, 'script')) +_fuzz_prof_mp_markers = ( + (_fuzz_prof_mp_run_func + + _Params.new(('prof_child_procs', 'label2'), + [(True, 'with-child-prof'), (False, 'no-child-prof')]) + + _get_mp_start_method_fuzzer(None)) + # Test all `multiproc` start methods with both locally- and + # externally-defined profiling targets + * (_Params.new(('preimports', 'label3'), [(False, 'no-preimports')]) + + _Params.new(('use_local_func', 'label4'), + [(True, 'local'), (False, 'external')], + defaults=(False, 'external'))) + # The 'with-preimports' case is already tested rather thoroughly in + # `test_apply_mp_patches()`, so exclude these from the above "main" + # param matrix and just test the different `kernprof` modes via the + # `run_func()`s + + (_fuzz_prof_mp_run_func + + _Params.new(('preimports', 'label3'), [(True, 'with-preimports')])) +).sorted().split_on_params('fail') + + +def _test_profiling_multiproc_script( + run_func: Callable[..., subprocess.CompletedProcess], + test_module: _ModuleFixture, + ext_module: _ModuleFixture, + tmp_path_factory: pytest.TempPathFactory, + prof_child_procs: bool, + preimports: bool, + use_local_func: bool, + fail: bool, + start_method: Literal['fork', 'forkserver', 'spawn'] | None, + nnums: int, + nprocs: int, +) -> None: + # How many calls do we expect? 
+ nhits = dict.fromkeys( + ['EXT-INVOCATION', 'EXT-LOOP', 'LOCAL-INVOCATION', 'LOCAL-LOOP'], 0, + ) + # Make sure we're profiling the right function + tag = 'LOCAL' if use_local_func else 'EXT' + tag_call = tag + '-INVOCATION' + tag_loop = tag + '-LOOP' + if not fail: + # The final sum in the parent process should always be profiled + # unless the child processes failed and we never returned from + # `Pool.starmap()` + nhits[tag_call] += 1 + nhits[tag_loop] += nprocs + if prof_child_procs: + # When profiling extends into child processes, each of them + # invokes the sum function once and when combined they loop thru + # all the items + nhits[tag_call] += nprocs + nhits[tag_loop] += nnums + + runner = ['kernprof', '-l'] + runner.extend([ + '--{}prof-child-procs'.format('' if prof_child_procs else 'no-'), + '--{}preimports'.format('' if preimports else 'no-'), + ]) + if not use_local_func: + # Also make sure to include the external module in `--prof-mod` + runner.append(f'--prof-mod={ext_module.name}') + run_func( + test_module, tmp_path_factory, + runner=runner, + outfile='out.lprof', + profile=True, + use_local_func=use_local_func, + fail=fail, + start_method=start_method, + nhits=nhits, + nnums=nnums, + nprocs=nprocs, + timeout=_TEST_TIMEOUT, + debug_log=( + 'debug.log' if prof_child_procs and _DEBUG else None + ), + ) + + +@(_fuzz_prof_mp_markers[False]) +@pytest.mark.parametrize( + # XXX: should we explicitly test the single-proc case? We already + # have quite a lot of subtests tho... 
+ ('nnums', 'nprocs'), [(2000, 3)], +) +def test_profiling_multiproc_script_success( + run_func: Callable[..., subprocess.CompletedProcess], + test_module: _ModuleFixture, + ext_module: _ModuleFixture, + tmp_path_factory: pytest.TempPathFactory, + prof_child_procs: bool, + preimports: bool, + use_local_func: bool, + start_method: Literal['fork', 'forkserver', 'spawn'] | None, + nnums: int, + nprocs: int, + # Dummy arguments to make `pytest` output more legible + label1: str, label2: str, label3: str, label4: str, +) -> None: + """ + Check that `kernprof` can PROFILE the test module in various + contexts when the parallel workload runs without errors, optionally + extending profiling into child processes. + + Note: + This test function is heavily parametrized. Here is why that is + necessary: + + - ``run_func`` tests the different :cmd:`kernprof` modes (see + :py:func:`~.test_running_multiproc_script`). + + - ``preimports`` tests that both mechanisms for setting up + profiling targets work: + + - :py:const:`True`: child processes import the module + generated by + :py:mod:`line_profiler.autoprofile.eager_preimports`, like + the main :py:mod:`kernprof` process does. + + - :py:const:`False`: child processes rewrite the executed code + before passing it to :py:mod:`runpy`, similar to what + :py:mod:`line_profiler.autoprofile.autoprofile` does. + + These code paths go through different + :py:mod:`multiprocessing` components that we have patched and + thus needs separate testing. + + - ``use_local_func`` tests that we can consistently set up + profiling in both functions locally-defined in the profiled + code and imported by it. + + - ``fail`` tests that our patches and hook doesn't choke when + exceptions occur in child processes, and profiling data can + still be collected. + + - ``start_method`` tests whether all available + :py:mod:`multiprocessing` start methods are covered. 
+ + - ``prof_child_procs`` of course toggles whether to do the + patches to set up profiling in child processes. + + See also: + :py:func:`test_profiling_multiproc_script_failure` + """ + _test_profiling_multiproc_script( + run_func=run_func, + test_module=test_module, + ext_module=ext_module, + tmp_path_factory=tmp_path_factory, + prof_child_procs=prof_child_procs, + preimports=preimports, + use_local_func=use_local_func, + fail=False, + start_method=start_method, + nnums=nnums, + nprocs=nprocs, + ) + + +@(_fuzz_prof_mp_markers[True]) +@pytest.mark.parametrize(('nnums', 'nprocs'), [(2000, 3)]) +def test_profiling_multiproc_script_failure( + run_func: Callable[..., subprocess.CompletedProcess], + test_module: _ModuleFixture, + ext_module: _ModuleFixture, + tmp_path_factory: pytest.TempPathFactory, + prof_child_procs: bool, + preimports: bool, + use_local_func: bool, + start_method: Literal['fork', 'forkserver', 'spawn'] | None, + nnums: int, + nprocs: int, + # Dummy arguments to make `pytest` output more legible + label1: str, label2: str, label3: str, label4: str, +) -> None: + """ + Check that `kernprof` can PROFILE the test module in various + contexts when the parallel workload errors out, optionally + extending profiling into child processes. 
def _test_profiling_bare_python(
    tmp_path_factory: pytest.TempPathFactory,
    ext_module: _ModuleFixture,
    use_subprocess: bool,
    prof_child_procs: bool,
    fail: bool,
    n: int,
) -> None:
    """
    Shared implementation of the "bare Python child process" tests.

    Writes a small script which calls ``my_external_sum()`` from the
    external module, then runs `kernprof -lv ... -c <code>` where
    ``<code>`` launches that script in another Python process, and
    finally checks the return code and the profiling output.

    Args:
        tmp_path_factory:
            Used to create the directory holding the script, the
            profiling output, and the debug log
        ext_module:
            Fixture for the external module; its name is passed to
            `--prof-mod`
        use_subprocess:
            If true, the child is launched with
            :py:func:`subprocess.run`; otherwise with
            :py:func:`os.system`
        prof_child_procs:
            Toggles `--[no-]prof-child-procs`; when false, no hits are
            expected on the tagged lines
        fail:
            Forwarded to ``my_external_sum()``; the `kernprof` process
            is expected to exit with a nonzero code iff this is true
        n:
            The script sums the integers from 1 to ``n`` (inclusive)
    """
    # NOTE(review): presumably `children=True` makes the module
    # importable from child processes too -- confirm against
    # `_ModuleFixture.install()`
    ext_module.install(children=True)
    temp_dir = tmp_path_factory.mktemp('mytemp')

    # The script executed by the (grand)child Python process
    script_path = temp_dir / 'my-script.py'
    script_content = strip("""
    from {EXT_MODULE} import my_external_sum


    if __name__ == '__main__':
        numbers = list(range(1, 1 + {N}))
        result = my_external_sum(numbers, {FAIL})
    """.format(
        EXT_MODULE=ext_module.name,
        N=n,
        FAIL=fail,
    ))
    script_path.write_text(script_content)

    out_file = temp_dir / 'out.lprof'
    debug_log_file = temp_dir / 'debug.log'
    # Debug logs are only requested when child-process profiling is on
    write_debug = _DEBUG and prof_child_procs
    cmd = [
        'kernprof', '-lv', '--preimports',
        f'--prof-mod={ext_module.name}',
        f'--outfile={out_file}',
        '--{}prof-child-procs'.format('' if prof_child_procs else 'no-'),
    ]
    if write_debug:
        cmd.append(f'--debug-log={debug_log_file}')
    # The code run by `kernprof -c` does nothing but spawn the script in
    # another Python process and error out if said process failed
    sub_cmd = [sys.executable, str(script_path)]
    if use_subprocess:
        code = strip(f"""
        import subprocess


        subprocess.run({sub_cmd!r}, check=True)
        """)
    else:
        code = strip("""
        import os


        if os.system({!r}):
            raise RuntimeError('called process failed')
        """.format(concat_command_line(sub_cmd)))
    cmd.extend(['-c', code])
    proc = _run_subproc(
        cmd, text=True, capture_output=True, timeout=_TEST_TIMEOUT,
    )

    # The target function is only ever called in the child process, so
    # without child-process profiling we expect no hits at all
    nhits = {'EXT-INVOCATION': 1, 'EXT-LOOP': n}
    if not prof_child_procs:
        for k in nhits:
            nhits[k] = 0

    try:
        # Check that the code errors out when expected
        assert bool(fail) == bool(proc.returncode)
        # Check that the profiling output is as expected
        for tag, num in nhits.items():
            _check_output(proc.stdout, tag, num)
    finally:
        # Dump the debug log regardless of the test outcome to help with
        # diagnosing failures
        if write_debug:
            print('-- Combined debug logs --', file=sys.stderr)
            print(
                indent(debug_log_file.read_text(), '  '),
                end='', file=sys.stderr,
            )
            print('-- End of debug logs --', file=sys.stderr)
+ + See also: + :py:func:`test_profiling_bare_python_success` + """ + _test_profiling_bare_python( + tmp_path_factory=tmp_path_factory, + ext_module=ext_module, + use_subprocess=use_subprocess, + prof_child_procs=prof_child_procs, + fail=True, + n=n, + ) diff --git a/tests/test_retry_tests.py b/tests/test_retry_tests.py new file mode 100644 index 00000000..a3251671 --- /dev/null +++ b/tests/test_retry_tests.py @@ -0,0 +1,1039 @@ +""" +Tests to make sure that our :py:deco:`pytest.mark.retry` decorator +works. + +Notes: + This test module is written to work both: + + - When :py:mod:`pytest_mark_retry` (`link`_) is installed from + source along with this file and the rest of the test suite, or + + - In a test directory containing (among other things): + + - This file as a standalone test module, and + + - A ``conftest.py`` containing the content of single-file module + ``pytest_mark_retry.py``. + +.. _link: https://gitlab.com/TTsangSC/pytest-mark-retry +""" +from __future__ import annotations + +import re +import pprint +import textwrap +from collections.abc import Collection, Iterable, Sequence +from dataclasses import dataclass +from functools import cached_property, partial +from importlib.util import find_spec +from operator import attrgetter +from pathlib import Path +from shutil import rmtree +from typing import Any, Literal, cast +from typing_extensions import Self + +import pytest + + +pytest_plugins = ('pytester',) + +_Status = Literal['passed', 'failed', 'skipped'] +_RunPytestMethod = Literal[ + 'runpytest', 'runpytest_inprocess', 'runpytest_subprocess', +] + +PROJECT_MODULE = 'pytest_mark_retry' + +TEST_COUNTERS = """ +from __future__ import annotations +from itertools import count +from typing import Literal + +import pytest + + +@pytest.fixture +def func_scoped_counter() -> count: + return count() + + +@pytest.fixture(scope='module') +def module_scoped_counter() -> count: + return count() + + +@pytest.mark.parametrize( + ('scope', 'n'), + [('func', 0), # 
This passes + ('func', 2), # This passes with 2 retries + ('func', 6), # This fails with 3 retries + ('module', 4), # This fails with 3 retries (counter now at 3) + ('module', 5)] # This passes with 1 retry (counter now at 5) +) +@pytest.mark.retry(3, reset_fixtures=False) +def test_dynamic_fixtures_persisted( + request: pytest.FixtureRequest, scope: Literal['func', 'module'], n: int, +) -> None: + ''' + Test counter fixtures that are requested dynamically via the + ``request`` fixture; function-scoped fixtures persist between + test retries. + ''' + counter = request.getfixturevalue(scope + '_scoped_counter') + assert next(counter) >= n + + +@pytest.mark.parametrize( + ('scope', 'n'), + [('func', 3), # This passes with 3 retries + ('func', 4), # This fails with 3 retries + ('module', 4), # This passes (counter now at 6) + ('module', 9)] # This passes with 2 retries (counter now at 9) +) +@pytest.mark.retry(3, reset_fixtures=False) +def test_static_fixtures_persisted( + func_scoped_counter: Iterable[int], + module_scoped_counter: Iterable[int], + scope: Literal['func', 'module'], + n: int, +) -> None: + ''' + Test counter fixtures that are requested by name; function-scoped + fixtures persist between test retries. + ''' + if scope == 'func': + counter = func_scoped_counter + else: + counter = module_scoped_counter + assert next(counter) >= n + + +@pytest.mark.parametrize( + ('scope', 'n'), + [('func', 0), # This passes + ('func', 1), # This fails with 1 retry + ('module', 11)] # This passes with 1 retry (counter now at 11) +) +@pytest.mark.retry # Counters reset between retries +def test_dynamic_fixtures_reset( + request: pytest.FixtureRequest, scope: Literal['func', 'module'], n: int, +) -> None: + ''' + Test counter fixtures that are requested dynamically via the + ``request`` fixture; function-scoped fixtures are reset between + test retries. 
+ ''' + counter = request.getfixturevalue(scope + '_scoped_counter') + assert next(counter) >= n + + +@pytest.mark.parametrize( + ('scope', 'n'), + [('func', 0), # This passes + ('func', 1), # This fails with 2 retries + ('module', 14)] # This passes with 2 retries (counter now at 14) +) +@pytest.mark.retry(2) # Ditto above +def test_static_fixtures_reset( + func_scoped_counter: Iterable[int], + module_scoped_counter: Iterable[int], + scope: Literal['func', 'module'], + n: int, +) -> None: + ''' + Test counter fixtures that are requested by name; function-scoped + fixtures are reset between test retries. + ''' + if scope == 'func': + counter = func_scoped_counter + else: + counter = module_scoped_counter + assert next(counter) >= n +""" +TEST_TEARDOWN = """ +from __future__ import annotations + +import os +import tempfile +from collections.abc import Callable, Generator +from functools import partial +from pathlib import Path + +import pytest + + +@pytest.fixture(scope='module') +def my_temp_dir(pytestconfig: pytest.Config) -> Generator[Path, None, None]: + path: Path | None = getattr(pytestconfig.option, 'my_temp_dir', None) + if path is None: + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + else: + yield path + +@pytest.fixture(scope='module') +def my_log(pytestconfig: pytest.Config) -> Path | None: + path: Path | None = getattr(pytestconfig.option, 'my_log', None) + return path + + +def _tempfile(*args, **kwargs) -> Path: + handle, path = tempfile.mkstemp(*args, **kwargs) + try: + return Path(path) + finally: + os.close(handle) + + +@pytest.fixture +def maketemp( + my_temp_dir: Path, my_log: Path | None, +) -> Generator[Callable[..., Path], None, None]: + paths: list[Path] = [] + + def _maketemp(*args, **kwargs) -> Path: + path = _tempfile(*args, **kwargs) + paths.append(path) + log(f'created tempfile {path}') + return path + + log = partial(_log, _maketemp, my_log) + try: + yield _maketemp + finally: + for path in paths: + 
path.unlink(missing_ok=True) + log(f'removed tempfile {path}') + + +def _log(maketemp: Any, my_log: Path | None, msg: str) -> None: + chunks: list[str] = [ + os.environ['PYTEST_CURRENT_TEST'], + f'maketemp() @ {id(maketemp):#x}', + msg, + ] + msg = ': '.join(chunks) + print(msg) + if my_log is None: + return + with my_log.open(mode='a') as fobj: + print(msg, file=fobj) + + +@pytest.mark.retry(reset_fixtures=True) +def test_with_fixture_reset( + my_temp_dir: Path, maketemp: Callable[..., Path], +) -> None: + path = maketemp(dir=my_temp_dir) + assert False + + +@pytest.mark.retry(2, reset_fixtures=False) +def test_no_fixture_reset( + my_temp_dir: Path, maketemp: Callable[..., Path], +) -> None: + path = maketemp(dir=my_temp_dir) + assert False +""" +TEST_EXCEPTIONS = """ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import pytest + + +@pytest.fixture +def items() -> Iterable[Any]: + return iter(['1', None, '', '-1']) + + +@pytest.mark.retry(3, reset_fixtures=('foo',)) # Not resetting `items` +def test_all_xc_types(items: Iterable[Any]) -> None: + ''' + This should pass after 3 retries because the last item fulfills the + criterion. + ''' + assert int(next(items)) < 0 + + +@pytest.mark.retry(3, exceptions=AssertionError, reset_fixtures=()) +def test_one_xc_type(items: Iterable[Any]) -> None: + ''' + This should fail after 1 retry because the second item triggers a + :py:class:`TypeError`. + ''' + assert int(next(items)) < 0 + + +@pytest.mark.retry(reset_fixtures=False) +@pytest.mark.retry(exceptions=TypeError) +@pytest.mark.retry(exceptions=AssertionError) +def test_two_xc_types(items: Iterable[Any]) -> None: + ''' + This should fail after 2 retries because the third item triggers a + :py:class:`ValueError`. + + Note: + The three decorators stack to give 3 retries and to accept both + :py:class:`AssertionError` and :py:class:`TypeError`. 
+ ''' + assert int(next(items)) < 0 + + +@pytest.mark.retry( + 3, + exceptions=(AssertionError, TypeError, ValueError), + reset_fixtures=False, +) +def test_three_xc_types(items: Iterable[Any]) -> None: + ''' + This should pass after 3 retries because the last item fulfills the + criterion, and the preceding errors are all included in the + ``exceptions`` argument to the wrapper. + ''' + assert int(next(items)) < 0 +""" +TEST_CONDITIONS = """ +from __future__ import annotations + +from sys import version_info + +import pytest + + +@pytest.mark.retry(2, condition=(11 % 2)) +def test_concrete_positive_condition() -> None: + ''' + This should fail after 2 retries because its condition is true. + ''' + raise RuntimeError + + +@pytest.mark.retry(condition=('a' in 'foo')) +def test_concrete_negative_condition() -> None: + ''' + This should fail without retries because its condition is false. + ''' + raise RuntimeError + + +@pytest.mark.retry(condition='version_info.major >= 3') +def test_dynamic_positive_condition_test_module_globals() -> None: + ''' + This should fail after 1 retry because the condition evaluates to + true on the test module's ``globals()``. + ''' + raise RuntimeError + + +@pytest.mark.retry(condition='version_info.major < 3') +def test_dynamic_negative_condition_test_module_globals() -> None: + ''' + This should fail without retries because the condition evaluates to + false on the test module's ``globals()``. + ''' + raise RuntimeError + + +@pytest.mark.retry(condition='foo == 1') +def test_bad_dynamic_condition() -> None: + ''' + This should fail without retries because the condition cannot be + evaluated (``NameError: name 'foo' is not defined``). + ''' + raise RuntimeError('bar') + + +@pytest.mark.retry(condition='n % 2') +@pytest.mark.parametrize('n', [0, 1, 2]) +def test_dynamic_condition_test_params(n: int) -> None: + ''' + Subtests ``[0]`` and ``[2]`` (resp. subtest ``[1]``) should fail + without retries (resp. 
with 1 retry) because the condition evaluates + to false (resp. true) on the test's parametrization. + ''' + raise RuntimeError +""" +TEST_BAD_MARKERS = """ +from __future__ import annotations + +import pytest + + +@pytest.mark.retry(1, 2) # `exceptions` cannot be 2 +def test_passing_bad_exceptions() -> None: + ''' + This test passes with a warning because its retry marker has an + invalid :py:attr:`RetryMarker.exceptions`. + ''' + pass + + +@pytest.mark.retry(foo=1) # No argument named `foo` +def test_passing_stray_arg() -> None: + ''' + This test also passes with a warning because its retry marker has am + stray argument ``foo`` + ''' + pass + + +@pytest.mark.retry(condition='') # Syntax error +def test_failing_bad_condition() -> None: + ''' + This test fails with a warning and without retries, because its + retry marker got a bad :py:attr:`RetryMarker.condition`. + ''' + assert False +""" +TEST_REQUIRE = """ +from __future__ import annotations + +import itertools + +import pytest + + +@pytest.fixture(scope='module') +def counter() -> itertools.count: + return itertools.count() + + +@pytest.fixture +def index(counter: itertools.count) -> int: + return next(counter) + + +@pytest.mark.retry(3) +def test_passing_retry_require_any(index: int) -> None: + ''' + This passes with two retries and leave ``index`` at 2. + ''' + assert index >= 2 + + +@pytest.mark.retry(3, require='any') +def test_failing_retry_require_any(index: int) -> None: + ''' + This fails with three retries and leave ``index`` at 6. + ''' + assert index < 3 + + +@pytest.mark.retry(3, require='all') +def test_failing_retry_require_all(index: int) -> None: + ''' + This fails with zero retries and leave ``index`` at 7. + ''' + # Fails right out the gate, no need to continue retrying + assert index > 7 + + +@pytest.mark.retry(3, require='all') +def test_passing_retry_require_all(index: int) -> None: + ''' + This passes with three retries and leave ``index`` at 11. 
+ ''' + # All attempts pass, but we are instructed to exhaust the retries + assert index > 0 +""" + + +@dataclass +class _TestOutcome: + name: str = '' + status: _Status = 'passed' + retries: int = 0 + + def subtest( + self, + *params: str, + status: _Status | None = None, + retries: int | None = None, + ) -> Self: + if status is None: + status = self.status + if retries is None: + retries = self.retries + name = f'{self.name}[{"-".join(params)}]' + return type(self)(name, status, retries) + + +@dataclass +class _TestModule: + """ + Helper object for running a test module. + """ + name: str + content: str + expected_outcomes: dict[str, list[_TestOutcome]] + pytester: pytest.Pytester + conftest: str | None = None + + def __post_init__(self) -> None: + self.content = self._strip(self.content) + if self.conftest: + self.conftest = self._strip(self.conftest) + + def run( + self, + *args: str, + check_results: bool = False, + check_summary: Literal['verbose', 'concise'] | None = None, + check_warnings: int | None = None, + runner: _RunPytestMethod = 'runpytest', + additional_stdout_lines: Collection[str] = (), + additional_stderr_lines: Collection[str] = (), + ) -> pytest.RunResult: + """ + Args: + *args (str): + Passed to :py:meth:`pytester.Pytester.runpytest` + check_results (bool): + If true, check that the test outcomes are as expected + using :py:meth:`pytester.Pytester.assert_outcomes` + check_summary (bool): + If true, check that the 'retries summary' report section + is written with the expected content indicating test + results and number of retries + check_warnings (int | None): + If an integer and if ``check_results`` is true, also + check that the number of captured warnings match + runner (Literal['runpytest', 'runpytest_inprocess', \ +'runpytest_subprocess']): + The :py:class:`pytest.Pytester` method used to run the + test module + additional_stdout_lines, additional_stderr_lines \ +(Collection[str]): + Additional regex patterns (other than the + 
automatically-generated ones) to match against the + output streams + + Returns: + :py:class:`pytest.RunResult` object returned by the + :py:class:`pytest.Pytester` method + """ + tempfiles: list[Path] = [] + tempdirs: list[Path] = [] + try: + conftests: list[str] = [] + if not self.marker_plugin_globally_installed: + # If we don't do this the project will be loaded twice + # as a plugin, leading to a clash + conftests.append(self.marker_plugin_path.read_text()) + if self.conftest: + conftests.append(self.conftest) + # Create separate conftest.py in nested subdirs to avoid + # hook-func implementations stepping oer one another + path = self.pytester.path + for i, conftest in enumerate(conftests): + if i: + path /= 'nested' + path.mkdir() + tempdirs.append(path) + conftest_file = path / 'conftest.py' + conftest_file.write_text(conftest) + tempfiles.append(conftest_file) + module = path / f'{self.name}.py' + module.write_text(self.content) + tempfiles.append(module) + result = getattr(self.pytester, runner)(*args, str(module)) + if check_results: + self.check_results(result, check_warnings) + if check_summary is not None: + if check_summary == 'verbose': + checker = self.check_verbose_summary + else: + checker = self.check_concise_summary + checker( + result, + stdout=additional_stdout_lines, + stderr=additional_stderr_lines, + ) + return result + finally: + for path in tempfiles: + try: + path.unlink(missing_ok=True) + except OSError: + pass + else: + print('Removed temppath', path) + for path in reversed(tempdirs): + try: + rmtree(path) + except OSError: + pass + else: + print('Removed tempdir', path) + + def check_results( + self, result: pytest.RunResult, warnings: int | None = None, + ) -> None: + counts: dict[_Status, int] = {} + for outcomes in self.expected_outcomes.values(): + for outcome in outcomes: + counts[outcome.status] = counts.get(outcome.status, 0) + 1 + result.assert_outcomes( + warnings=warnings, **cast(dict[str, int], counts), + ) + + def 
check_verbose_summary( + self, + result: pytest.RunResult, + stdout: Collection[str] = (), + stderr: Collection[str] = (), + ) -> None: + lines: list[str] = [] + counts: dict[_Status, int] = {} + for outcomes in self.expected_outcomes.values(): + for outcome in outcomes: + lines.append( + f'.*::{re.escape(outcome.name)} +{outcome.status.upper()}', + ) + if not outcome.retries: + continue + counts[outcome.status] = counts.get(outcome.status, 0) + 1 + lines.append(r'.*{}.*retried {} time{}'.format( + re.escape(outcome.name), + outcome.retries, + '' if outcome.retries == 1 else 's', + )) + lines.extend( + self._format_header(status, n) for status, n in counts.items() + ) + self._check_lines(result, [*lines, *stdout], stderr) + + def check_concise_summary( + self, + result: pytest.RunResult, + stdout: Collection[str] = (), + stderr: Collection[str] = (), + ) -> None: + lines: list[str] = [] + counts: dict[_Status, int] = {} + test_names: dict[_Status, dict[str, set[str]]] = {} + consolidated_names: dict[_Status, set[str]] = {} + for parent_test, outcomes in self.expected_outcomes.items(): + for outcome in outcomes: + if outcome.status == 'failed': + lines.append( + f'{outcome.status.upper()} +' + f'.*::{re.escape(outcome.name)}', + ) + if not outcome.retries: + continue + counts[outcome.status] = counts.get(outcome.status, 0) + 1 + ( + test_names + .setdefault(outcome.status, {}) + .setdefault(parent_test, set()) + .add(outcome.name) + ) + + for status, tests in test_names.items(): + for parent_test, subtests in tests.items(): + names = consolidated_names.setdefault(status, set()) + n = len(subtests) + if n == 1: + names.add(*subtests) + else: + names.add('{} ({} subtest{})'.format( + parent_test, n, '' if n == 1 else 's', + )) + + self._check_lines(result, [*lines, *stdout], stderr) + + for status, n in counts.items(): + header = self._format_header(status, n) + names = consolidated_names[status] + print(f'Expecting line in the output: "{header}: <...>"...') + 
print(f'Expecting these names in said line: {names!r}...') + line = self._find_line(header + ':', str(result.stdout)) + for test_name in names: + assert test_name in line + + @property + def marker_plugin_path(self) -> Path: + return self._source[0] + + @property + def marker_plugin_globally_installed(self) -> bool: + return self._source[1] + + @cached_property + def _source(self) -> tuple[Path, bool]: + sources = { + f'module `{PROJECT_MODULE}`': (self._get_proj_module_path, True), + repr('conftest.py'): (self._get_proj_conftest, False), + } + for src, (get_path, retry_globally_installed) in sources.items(): + try: + path = get_path() + assert 'class RetryMarker' in path.read_text() + print(f'Loaded project source from {src}: {str(path)!r}') + return path, retry_globally_installed + except Exception: + pass + raise RuntimeError( + f'Failed to load the project source from any of: {sources!r}', + ) + + @staticmethod + def _check_lines( + result: pytest.RunResult, + stdout: Collection[str], + stderr: Collection[str], + ) -> None: + for stream, lines in { + 'stdout': list(stdout), 'stderr': list(stderr), + }.items(): + if not lines: + continue + print(f'Expecting these lines in the {stream}: {lines!r}...') + getattr(result, stream).re_match_lines_random(lines) + + @staticmethod + def _find_line(pattern: str, text: str) -> str: + pattern = f'^.*{pattern}.*' + maybe_match = re.search(pattern, text, re.MULTILINE) + if not maybe_match: + raise ValueError(f'Cannot find {pattern!r} in {text!r}') + return maybe_match.group() + + @staticmethod + def _format_header(status: _Status, n: int) -> str: + return '{} test{} {} with retries'.format( + n, '' if n == 1 else 's', status, + ) + + @staticmethod + def _get_proj_conftest() -> Path: # If installed as the conftest + return Path(__file__).parent / 'conftest.py' + + @staticmethod + def _get_proj_module_path() -> Path: # If installed as a module + spec = find_spec(PROJECT_MODULE) + assert spec and spec.origin + return 
Path(spec.origin) + + @staticmethod + def _strip(text: str) -> str: + return textwrap.dedent(text).strip('\n') + + +def _identical_items_are_adjacent(items: Iterable[Any]) -> bool: + """ + Example: + >>> _identical_items_are_adjacent([]) + True + >>> _identical_items_are_adjacent([1]) + True + >>> _identical_items_are_adjacent([1, 10]) + True + >>> _identical_items_are_adjacent([1, 10, 1]) + False + >>> _identical_items_are_adjacent('AAcCb') + True + >>> _identical_items_are_adjacent('AcCAb') + False + """ + past: set[Any] = set() + sentinel = object() + last: Any = sentinel + for item in items: + if last is not sentinel and last != item: + past.add(last) + if item in past: + return False + last = item + return True + + +def _outcomes_to_outcome_dict( + outcomes: Iterable[_TestOutcome], +) -> dict[str, list[_TestOutcome]]: + """ + Example: + >>> o0 = _TestOutcome('foo', 'passed', 0) + >>> o1 = _TestOutcome('bar[1-2-3]', 'failed', 1) + >>> o2 = _TestOutcome('bar[4-5-6]', 'passed', 2) + >>> outcomes = {'foo': [o0], 'bar': [o1, o2]} + >>> assert _outcomes_to_outcome_dict([o1, o0, o2]) == outcomes + """ + result: dict[str, list[_TestOutcome]] = {} + for outcome in outcomes: + name = outcome.name + if name.endswith(']') and '[' in name: # Subtest + base_name, *_ = name.partition('[') + else: + base_name = name + result.setdefault(base_name, []).append(outcome) + return result + + +@pytest.fixture +def counters_module(pytester: pytest.Pytester) -> _TestModule: + dynamic_p = _TestOutcome('test_dynamic_fixtures_persisted').subtest + static_p = _TestOutcome('test_static_fixtures_persisted').subtest + dynamic_r = _TestOutcome('test_dynamic_fixtures_reset').subtest + static_r = _TestOutcome('test_static_fixtures_reset').subtest + outcomes = _outcomes_to_outcome_dict([ + dynamic_p('func-0'), + dynamic_p('func-2', retries=2), + dynamic_p('func-6', status='failed', retries=3), + dynamic_p('module-4', status='failed', retries=3), + dynamic_p('module-5', retries=1), + 
@pytest.fixture
def exceptions_module(pytester: pytest.Pytester) -> _TestModule:
    """
    Test-module helper wrapping :py:data:`TEST_EXCEPTIONS`, along with
    the outcomes expected of the tests therein.
    """
    expected = [
        _TestOutcome(name='test_all_xc_types', retries=3),
        _TestOutcome(name='test_one_xc_type', status='failed', retries=1),
        _TestOutcome(name='test_two_xc_types', status='failed', retries=2),
        _TestOutcome(name='test_three_xc_types', retries=3),
    ]
    return _TestModule(
        name='test_exceptions',
        content=TEST_EXCEPTIONS,
        expected_outcomes=_outcomes_to_outcome_dict(expected),
        pytester=pytester,
    )
test('test_dynamic_positive_condition_test_module_globals', retries=1), + test('test_dynamic_negative_condition_test_module_globals'), + test('test_bad_dynamic_condition'), + param_test('0'), + param_test('1', retries=1), + param_test('2'), + ]) + return _TestModule('test_conditions', TEST_CONDITIONS, outcomes, pytester) + + +@pytest.fixture +def bad_markers_module(pytester: pytest.Pytester) -> _TestModule: + outcomes = _outcomes_to_outcome_dict([ + _TestOutcome('test_passing_bad_exceptions'), + _TestOutcome('test_passing_stray_arg'), + _TestOutcome('test_failing_bad_condition', 'failed'), + ]) + return _TestModule('test_bad', TEST_BAD_MARKERS, outcomes, pytester) + + +@pytest.fixture +def require_module(pytester: pytest.Pytester) -> _TestModule: + outcomes = _outcomes_to_outcome_dict([ + _TestOutcome('test_passing_retry_require_any', retries=2), + _TestOutcome('test_failing_retry_require_any', 'failed', 3), + _TestOutcome('test_failing_retry_require_all', 'failed'), + _TestOutcome('test_passing_retry_require_all', retries=3), + ]) + return _TestModule('test_require', TEST_REQUIRE, outcomes, pytester) + + +@pytest.mark.parametrize('verbose', [True, False]) +def test_fixture_scoping(counters_module: _TestModule, verbose: bool) -> None: + """ + Test that the decorator correctly handles scoped fixtures. + """ + run = partial(counters_module.run, check_results=True, check_warnings=0) + if verbose: + run('--verbose', check_summary='verbose') + else: + run(check_summary='concise') + + +def test_fixture_teardown( + tmp_path_factory: pytest.TempPathFactory, teardown_module: _TestModule, +) -> None: + """ + Test that the decorator correctly handles teardown for additional + fixture copies incurred by retries; in particular, superseded + function-scoped fixtures should be torn down before their + replacements are set up. 
+ """ + Stage = Literal['setup', 'call', 'teardown'] + + @dataclass + class LogEntry: + test: str + stage: Stage + fixture_id: int + msg: str + + @classmethod + def parse_line(cls, line: str) -> Self: + test, ident, *remainder = line.split(': ') + msg = ': '.join(remainder) + test_match = re.fullmatch( + r'(.+) +\((setup|call|teardown)\)', test, + ) + assert test_match + test, stage = test_match.group(1, 2) + assert stage in ('setup', 'call', 'teardown') + ident_match = re.fullmatch( + r'maketemp\(\) @ 0x([0-9a-f]+)', ident, + ) + assert ident_match + fixture_id = int(ident_match.group(1), base=16) + return cls(test, cast(Stage, stage), fixture_id, msg) + + tempdir = tmp_path_factory.mktemp('my_temp') + log = tempdir / 'tempfiles.log' + teardown_module.run( + '--verbose', f'--my-temp-dir={tempdir}', f'--my-log={log}', + check_results=True, check_summary='verbose', check_warnings=0, + ) + + # Check that all the tempfiles have been wiped + files = {path.name for path in tempdir.iterdir()} + assert not (files - {log.name}) + + # Check that tempfiles are deleted as soon as the fixture value + # that created them went obsolete, before the next rerun; + # we can verify that by checking that the ids of the `makefile()` + # fixtures appear in contiguous blocks + + # Note: there seems to be a weird corner case where neighboring + # tests may reuse the same fixture id (see `line_profiler` failing + # job 73520441960 in pipeline 25091142386); probably has to do with + # object lifetime.
+ # So instead of just checking the `fixture_id`, also consult + # `test`; it suffices to see that WITHIN THE SAME TEST we don't have + # fixture values stepping over one another + with log.open() as fobj: + entries = [LogEntry.parse_line(line.rstrip('\n')) for line in fobj] + pprint.pprint(entries) + for fields in ('test', 'stage'), ('test', 'fixture_id'): + getter = attrgetter(*fields) + values = [getter(entry) for entry in entries] + assert _identical_items_are_adjacent(values), ( + f'Inconsistency in {fields} order: {values!r}' + ) + + +def test_exception_restrictions(exceptions_module: _TestModule) -> None: + """ + Test that the decorator correctly handles failures owing to + different exception classes. + """ + exceptions_module.run( + '--verbose', + check_results=True, check_summary='verbose', check_warnings=0, + ) + + +def test_retry_conditions(conditions_module: _TestModule) -> None: + """ + Test that the decorator correctly handles retry conditions. + """ + # `test_bad_dynamic_condition()` should have failed with a + # `RetryConditionFailure`, listing the error encountered in the last + # trial and the error encountered when `eval()`-ing the condition + # (Note: grepping for the error details in the short test summary is + # fragile since it may be elided or even entirely omitted; so we + # just use a separate pattern to grep it from the tracebacks) + lines = [ + r'.*RetryConditionFailure: \(RuntimeError: bar\) ' + r"-> \(condition: 'foo == 1' -> NameError: .*'foo'.*\)", + ] + conditions_module.run( + '--verbose', + check_results=True, check_summary='verbose', check_warnings=0, + additional_stdout_lines=lines, + ) + + +def test_bad_markers(bad_markers_module: _TestModule) -> None: + """ + Test that the decorator gracefully handles incorrect constructions. 
+ """ + stdout = bad_markers_module.run( + '--verbose', + check_results=True, check_summary='verbose', check_warnings=3, + ).stdout + # Check the warnings emitted + # (Since we want to match across multiple lines we can't use + # `additional_stdout_lines`) + errors: str | Sequence[str] + pattern = ( + '{0}\n' + r'.*RetryMarkerWarning: .*{0}.*: disregarding .* marker: \(.*{1}.*\)' + ) + for test, errors in { + 'test_passing_bad_exceptions': [ + r'TypeError: \.exceptions = .*2.*: expected .*exception', + r'TypeError: .*not iterable', + r'TypeError: too many positional arguments', + r'TypeError: .*takes 1 positional argument but' + ], + 'test_passing_stray_arg': + r'TypeError: .*unexpected keyword argument \'foo\'', + 'test_failing_bad_condition': + r'ValueError: \.condition = \'\': not a valid expression ' + r'\(SyntaxError.*\)', + }.items(): + if isinstance(errors, str): + errors = [errors] + messages = [pattern.format(re.escape(test), error) for error in errors] + if not any(re.search(msg, str(stdout)) for msg in messages): + msg = f'none of the patterns {messages!r} matched {stdout!r}' + raise AssertionError(msg) + + +def test_requirement(require_module: _TestModule) -> None: + """ + Test that the decorator correctly handles requirements that all + trials should pass (via ``require='all'``). + """ + require_module.run( + '--verbose', check_results=True, check_summary='verbose', + )