Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 96 additions & 5 deletions libensemble/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import warnings

import numpy.typing as npt

from libensemble._deprecation import LibEnsembleDeprecationWarning
from libensemble.executors import Executor
from libensemble.libE import libE
from libensemble.specs import AllocSpecs, ExitCriteria, GenSpecs, LibeSpecs, SimSpecs
Expand All @@ -19,6 +21,13 @@
OVERWRITE_COMMS_WARN = "Cannot reset 'comms' if 'ensemble.libE_specs.comms' is already set."
CHANGED_COMMS_WARN = "New 'comms' method detected following initialization of Ensemble. Exiting."

EXIT_CRITERIA_DEPRECATION = (
"ExitCriteria as a standalone parameter is deprecated as of libEnsemble 2.0 "
"and will be removed in 2.1. Pass exit criteria directly to run() instead: "
"ensemble.run(sim_max=100) or ensemble.run(sim_max=100, wallclock_max=3600). "
"See https://libensemble.readthedocs.io/... for migration guidance."
)

CORRESPONDING_CLASSES = {
"sim_specs": SimSpecs,
"gen_specs": GenSpecs,
Expand Down Expand Up @@ -159,7 +168,7 @@ def __init__(
self,
sim_specs: SimSpecs = SimSpecs(),
gen_specs: GenSpecs = GenSpecs(),
exit_criteria: ExitCriteria = ExitCriteria(),
exit_criteria: ExitCriteria | None = None,
libE_specs: LibeSpecs = LibeSpecs(),
alloc_specs: AllocSpecs = AllocSpecs(),
persis_info: dict = {},
Expand All @@ -169,7 +178,11 @@ def __init__(
):
self.sim_specs = sim_specs
self.gen_specs = gen_specs
self.exit_criteria = exit_criteria
self._exit_criteria = ExitCriteria()
if exit_criteria is not None:
if isinstance(exit_criteria, ExitCriteria):
warnings.warn(EXIT_CRITERIA_DEPRECATION, LibEnsembleDeprecationWarning, stacklevel=2)
self._exit_criteria = exit_criteria
self._libE_specs: LibeSpecs = libE_specs
self.alloc_specs = alloc_specs
self.persis_info = persis_info
Expand All @@ -180,6 +193,7 @@ def __init__(
self.is_manager = False
self.parsed = False
self._known_comms: str = ""
self._has_run_n_evals = False

if parse_args:
self._parse_args()
Expand Down Expand Up @@ -254,7 +268,9 @@ def ready(self) -> tuple[bool, list[str]]:
):
issues.append(
"exit_criteria has no stop condition: set at least one of "
"'sim_max', 'gen_max', 'wallclock_max', or 'stop_val'."
"'sim_max', 'gen_max', 'wallclock_max', or 'stop_val' "
"either on an ExitCriteria object or directly via "
"ensemble.run(sim_max=..., gen_max=..., ...)."
)

# --- workers: must be determinable ---
Expand Down Expand Up @@ -308,13 +324,44 @@ def libE_specs(self, new_specs):

self._libE_specs.__dict__.update(**new_specs)

@property
def exit_criteria(self) -> ExitCriteria:
return self._exit_criteria

@exit_criteria.setter
def exit_criteria(self, value: ExitCriteria | None):
if isinstance(value, ExitCriteria):
warnings.warn(EXIT_CRITERIA_DEPRECATION, LibEnsembleDeprecationWarning, stacklevel=2)
self._exit_criteria = value or ExitCriteria()

def _refresh_executor(self):
Executor.executor = self.executor or Executor.executor

def run(self) -> tuple[npt.NDArray, dict, int]:
def run(
self,
sim_max: int | None = None,
gen_max: int | None = None,
wallclock_max: float | None = None,
stop_val: tuple[str, float] | None = None,
) -> tuple[npt.NDArray, dict, int]:
"""
Initializes libEnsemble.

Parameters
----------
sim_max: int, Optional
Maximum number of new simulation evaluations for this run.
Overrides ``exit_criteria.sim_max`` for this call only.
gen_max: int, Optional
Maximum number of new generator calls for this run.
Overrides ``exit_criteria.gen_max`` for this call only.
wallclock_max: float, Optional
Wallclock timeout in seconds for this run.
Overrides ``exit_criteria.wallclock_max`` for this call only.
stop_val: tuple[str, float], Optional
Stop criterion ``(field, value)`` for this run.
Overrides ``exit_criteria.stop_val`` for this call only.

.. dropdown:: MPI/comms Notes

Manager--worker intercommunications are parsed from the ``comms`` key of
Expand All @@ -325,6 +372,25 @@ def run(self) -> tuple[npt.NDArray, dict, int]:
will initiate on a **duplicate** of that communicator.
Otherwise, a duplicate of ``COMM_WORLD`` will be used.

.. dropdown:: Substeps / multi-step usage

Pass exit-criteria kwargs to run a subset of an ensemble at a time.
The ensemble history (``H0``) is automatically chained across calls::

sampling = Ensemble(...)
sampling.sim_specs = SimSpecs(...)
sampling.gen_specs = GenSpecs(...)

# Run in three substeps
sampling.run(sim_max=30)
# ... adjust generator hyperparameters ...
sampling.run(sim_max=30)
sampling.run(sim_max=40)

When ``sim_max`` is used (from kwargs or ``exit_criteria``),
``libE_specs.final_gen_send`` and ``libE_specs.reuse_output_dir`` are
automatically set to ``True`` to support persistent generators across runs.

Returns
-------

Expand Down Expand Up @@ -355,16 +421,41 @@ def run(self) -> tuple[npt.NDArray, dict, int]:
raise ValueError(CHANGED_COMMS_WARN)

assert self._libE_specs is not None

# Merge kwargs into effective exit criteria for this run
run_kwargs = {
k: v
for k, v in {
"sim_max": sim_max,
"gen_max": gen_max,
"wallclock_max": wallclock_max,
"stop_val": stop_val,
}.items()
if v is not None
}
if run_kwargs:
effective_exit = self._exit_criteria.model_copy(update=run_kwargs)
self._has_run_n_evals = True
else:
effective_exit = self._exit_criteria

if sim_max is not None or getattr(self._exit_criteria, "sim_max", None) is not None:
self._libE_specs.final_gen_send = True
self._libE_specs.reuse_output_dir = True

self.H, self.persis_info, self.flag = libE(
self.sim_specs,
self.gen_specs,
self.exit_criteria,
effective_exit,
persis_info=self.persis_info,
alloc_specs=self.alloc_specs,
libE_specs=self._libE_specs,
H0=self.H0,
)

# Chain history for next call
self.H0 = self.H

return self.H, self.persis_info, self.flag

@property
Expand Down
179 changes: 179 additions & 0 deletions libensemble/tests/unit_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,178 @@ def test_ready_happy_path():
assert issues == [], f"Issues should be empty but got: {issues}"


# --- run() kwargs / substep tests ---


# --- run() kwargs / substep tests ---


def test_run_sim_max_kwarg():
"""run(sim_max=10) should evaluate exactly 10 simulations."""
from libensemble.alloc_funcs.give_sim_work_first import give_sim_work_first
from libensemble.ensemble import Ensemble
from libensemble.gen_funcs.sampling import latin_hypercube_sample
from libensemble.sim_funcs.simple_sim import norm_eval
from libensemble.specs import AllocSpecs, GenSpecs, LibeSpecs, SimSpecs

ens = Ensemble(
libE_specs=LibeSpecs(comms="local", nworkers=4),
sim_specs=SimSpecs(sim_f=norm_eval, inputs=["x"], outputs=[("f", float)]),
gen_specs=GenSpecs(
gen_f=latin_hypercube_sample,
outputs=[("x", float, (1,))],
persis_in=["f"],
batch_size=5,
user={"lb": np.array([-3]), "ub": np.array([3])},
),
alloc_specs=AllocSpecs(alloc_f=give_sim_work_first),
)
ens.run(sim_max=10)
if ens.is_manager:
sim_count = int(np.sum(ens.H["sim_ended"]))
assert sim_count == 10, f"Expected 10 sims but got {sim_count}"


def test_run_chaining():
"""Two run(sim_max=N) calls should chain H0, doubling total."""
from libensemble.alloc_funcs.give_sim_work_first import give_sim_work_first
from libensemble.ensemble import Ensemble
from libensemble.gen_funcs.sampling import latin_hypercube_sample
from libensemble.sim_funcs.simple_sim import norm_eval
from libensemble.specs import AllocSpecs, GenSpecs, LibeSpecs, SimSpecs

ens = Ensemble(
libE_specs=LibeSpecs(comms="local", nworkers=4),
sim_specs=SimSpecs(sim_f=norm_eval, inputs=["x"], outputs=[("f", float)]),
gen_specs=GenSpecs(
gen_f=latin_hypercube_sample,
outputs=[("x", float, (1,))],
persis_in=["f"],
batch_size=5,
user={"lb": np.array([-3]), "ub": np.array([3])},
),
alloc_specs=AllocSpecs(alloc_f=give_sim_work_first),
)
ens.run(sim_max=10)
h1_ended = int(np.sum(ens.H["sim_ended"])) if ens.is_manager else 0
ens.run(sim_max=10)
if ens.is_manager:
total_ended = int(np.sum(ens.H["sim_ended"]))
assert total_ended == h1_ended + 10, f"Expected {h1_ended + 10} sims ended but got {total_ended}"
assert ens.H0 is ens.H, "H0 should reference the latest H"


def test_run_sim_max_merge():
"""run() kwargs should merge with existing exit_criteria, not replace."""
from libensemble.alloc_funcs.give_sim_work_first import give_sim_work_first
from libensemble.ensemble import Ensemble
from libensemble.gen_funcs.sampling import latin_hypercube_sample
from libensemble.sim_funcs.simple_sim import norm_eval
from libensemble.specs import AllocSpecs, ExitCriteria, GenSpecs, LibeSpecs, SimSpecs

# Must have full sim/gen specs so run() actually works
ens = Ensemble(
libE_specs=LibeSpecs(comms="local", nworkers=4),
sim_specs=SimSpecs(sim_f=norm_eval, inputs=["x"], outputs=[("f", float)]),
gen_specs=GenSpecs(
gen_f=latin_hypercube_sample,
outputs=[("x", float, (1,))],
persis_in=["f"],
batch_size=5,
user={"lb": np.array([-3]), "ub": np.array([3])},
),
exit_criteria=ExitCriteria(sim_max=100),
alloc_specs=AllocSpecs(alloc_f=give_sim_work_first),
)
ens.run(sim_max=10)
# stored exit_criteria should still have sim_max=100
assert ens.exit_criteria.sim_max == 100, f"Expected sim_max=100 but got {ens.exit_criteria.sim_max}"


def test_exit_criteria_deprecation_init():
"""Passing ExitCriteria to Ensemble() should emit a deprecation warning."""
import warnings

from libensemble._deprecation import LibEnsembleDeprecationWarning
from libensemble.ensemble import Ensemble
from libensemble.specs import ExitCriteria

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
Ensemble(exit_criteria=ExitCriteria(sim_max=10))
deprecations = [x for x in w if issubclass(x.category, LibEnsembleDeprecationWarning)]
assert len(deprecations) >= 1, "Expected at least one LibEnsembleDeprecationWarning"


def test_exit_criteria_deprecation_setter():
"""Setting ensemble.exit_criteria = ExitCriteria(...) should emit a deprecation warning."""
import warnings

from libensemble._deprecation import LibEnsembleDeprecationWarning
from libensemble.ensemble import Ensemble
from libensemble.specs import ExitCriteria, LibeSpecs

ens = Ensemble(libE_specs=LibeSpecs(comms="local", nworkers=4))
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
ens.exit_criteria = ExitCriteria(sim_max=10)
deprecations = [x for x in w if issubclass(x.category, LibEnsembleDeprecationWarning)]
assert len(deprecations) >= 1, "Expected at least one LibEnsembleDeprecationWarning"


def test_run_auto_settings():
"""run(sim_max=...) should auto-set final_gen_send and reuse_output_dir."""
from libensemble.alloc_funcs.give_sim_work_first import give_sim_work_first
from libensemble.ensemble import Ensemble
from libensemble.gen_funcs.sampling import latin_hypercube_sample
from libensemble.sim_funcs.simple_sim import norm_eval
from libensemble.specs import AllocSpecs, GenSpecs, LibeSpecs, SimSpecs

ens = Ensemble(
libE_specs=LibeSpecs(comms="local", nworkers=4),
sim_specs=SimSpecs(sim_f=norm_eval, inputs=["x"], outputs=[("f", float)]),
gen_specs=GenSpecs(
gen_f=latin_hypercube_sample,
outputs=[("x", float, (1,))],
persis_in=["f"],
batch_size=5,
user={"lb": np.array([-3]), "ub": np.array([3])},
),
alloc_specs=AllocSpecs(alloc_f=give_sim_work_first),
)
ens.run(sim_max=10)
assert ens.libE_specs.final_gen_send is True
assert ens.libE_specs.reuse_output_dir is True


def test_h0_chaining_plain_run():
"""H0 should be updated to H after a plain run() call."""
from libensemble.alloc_funcs.give_sim_work_first import give_sim_work_first
from libensemble.ensemble import Ensemble
from libensemble.gen_funcs.sampling import latin_hypercube_sample
from libensemble.sim_funcs.simple_sim import norm_eval
from libensemble.specs import AllocSpecs, GenSpecs, LibeSpecs, SimSpecs

ens = Ensemble(
libE_specs=LibeSpecs(comms="local", nworkers=4),
sim_specs=SimSpecs(sim_f=norm_eval, inputs=["x"], outputs=[("f", float)]),
gen_specs=GenSpecs(
gen_f=latin_hypercube_sample,
outputs=[("x", float, (1,))],
persis_in=["f"],
batch_size=5,
user={"lb": np.array([-3]), "ub": np.array([3])},
),
alloc_specs=AllocSpecs(alloc_f=give_sim_work_first),
)
assert ens.H0 is None, "H0 should be None before first run"
ens.run(sim_max=5)
if ens.is_manager:
assert ens.H0 is not None, "H0 should be set after run"
sim_count = int(np.sum(ens.H0["sim_ended"]))
assert sim_count == 5, f"Expected H0 sim_ended count 5 but got {sim_count}"


if __name__ == "__main__":
test_ensemble_init()
test_ensemble_parse_args_false()
Expand All @@ -283,3 +455,10 @@ def test_ready_happy_path():
test_ready_missing_nworkers_local()
test_ready_field_mismatch()
test_ready_happy_path()
test_run_sim_max_kwarg()
test_run_chaining()
test_run_sim_max_merge()
test_exit_criteria_deprecation_init()
test_exit_criteria_deprecation_setter()
test_run_auto_settings()
test_h0_chaining_plain_run()