From 6c77ce089d3408f3b4835c5d4b82bac303255884 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 23 Apr 2026 23:20:09 +0200 Subject: [PATCH 1/7] ALmost --- .../pre_processing/batches/batch_creation.py | 41 ++++- .../pre_processing/batches/single_segment.py | 168 +++++++++++++++--- .../pre_processing/check_model_config.py | 12 ++ src/dcegm/pre_processing/setup_model.py | 1 + 4 files changed, 191 insertions(+), 31 deletions(-) diff --git a/src/dcegm/pre_processing/batches/batch_creation.py b/src/dcegm/pre_processing/batches/batch_creation.py index 9c9d620a..4d6af9f1 100644 --- a/src/dcegm/pre_processing/batches/batch_creation.py +++ b/src/dcegm/pre_processing/batches/batch_creation.py @@ -10,6 +10,7 @@ def create_batches_and_information( model_structure, n_periods, min_period_batch_segments=None, + batch_mode="largest_block", ): """Batches are used instead of periods to have chunks of equal sized state choices. The batch inparams=paramsformation dictionary contains the following arrays @@ -64,10 +65,22 @@ def create_batches_and_information( state_choice_space = model_structure["state_choice_space"] bool_state_choices_to_batch = state_choice_space[:, 0] < n_periods - 2 + valid_batch_modes = {"largest_block", "period_max"} + if min_period_batch_segments is None: + if isinstance(batch_mode, list): + raise ValueError( + "If min_period_batch_segments is not supplied, batch_mode must be a string." + ) + if batch_mode not in valid_batch_modes: + raise ValueError( + f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}." + ) single_batch_segment_info = create_single_segment_of_batches( - bool_state_choices_to_batch, model_structure + bool_state_choices_to_batch, + model_structure, + batch_mode=batch_mode, ) segment_infos = { "n_segments": 1, @@ -97,6 +110,24 @@ def create_batches_and_information( "The periods to split the batches have to be increasing and at least two periods apart." 
) + if isinstance(batch_mode, str): + if batch_mode not in valid_batch_modes: + raise ValueError( + f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}." + ) + batch_mode = [batch_mode] * n_segments + elif isinstance(batch_mode, list): + if len(batch_mode) != n_segments: + raise ValueError( + "If min_period_batch_segments is supplied, batch_mode must be a list with one entry per segment." + ) + if not all(mode in valid_batch_modes for mode in batch_mode): + raise ValueError( + f"All entries in batch_mode must be one of {valid_batch_modes}." + ) + else: + raise ValueError("batch_mode must be a string or a list of strings.") + segment_infos = { "n_segments": n_segments, } @@ -111,7 +142,9 @@ def create_batches_and_information( bool_state_choices_segment = bool_state_choices_to_batch & (~split_cond) segment_batch_info = create_single_segment_of_batches( - bool_state_choices_segment, model_structure + bool_state_choices_segment, + model_structure, + batch_mode=batch_mode[id_segment], ) segment_infos[f"batches_info_segment_{id_segment}"] = segment_batch_info @@ -119,7 +152,9 @@ def create_batches_and_information( bool_state_choices_to_batch = bool_state_choices_to_batch & split_cond last_segment_batch_info = create_single_segment_of_batches( - bool_state_choices_to_batch, model_structure + bool_state_choices_to_batch, + model_structure, + batch_mode=batch_mode[n_segments - 1], ) # We loop until n_segments - 2 and then add the last segment diff --git a/src/dcegm/pre_processing/batches/single_segment.py b/src/dcegm/pre_processing/batches/single_segment.py index de7db506..ff519c22 100644 --- a/src/dcegm/pre_processing/batches/single_segment.py +++ b/src/dcegm/pre_processing/batches/single_segment.py @@ -3,7 +3,11 @@ from dcegm.pre_processing.batches.algo_batch_size import determine_optimal_batch_size -def create_single_segment_of_batches(bool_state_choices_to_batch, model_structure): +def create_single_segment_of_batches( + bool_state_choices_to_batch, + 
model_structure, + batch_mode="largest_block", +): """Create a single segment of evenly sized batches. If the last batch is not evenly we correct it. @@ -24,34 +28,54 @@ def create_single_segment_of_batches(bool_state_choices_to_batch, model_structur ] map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - ( - batches_list, - child_state_choice_idxs_to_interp_list, - child_state_choices_to_aggr_choice_list, - child_states_to_integrate_stochastic_list, - ) = determine_optimal_batch_size( - bool_state_choices_to_batch=bool_state_choices_to_batch, - state_choice_space=state_choice_space, - map_state_choice_to_child_states=map_state_choice_to_child_states, - map_state_choice_to_index=map_state_choice_to_index, - state_space=state_space, - ) + if batch_mode == "largest_block": + ( + batches_list, + child_state_choice_idxs_to_interp_list, + child_state_choices_to_aggr_choice_list, + child_states_to_integrate_stochastic_list, + ) = determine_optimal_batch_size( + bool_state_choices_to_batch=bool_state_choices_to_batch, + state_choice_space=state_choice_space, + map_state_choice_to_child_states=map_state_choice_to_child_states, + map_state_choice_to_index=map_state_choice_to_index, + state_space=state_space, + ) - ( - batches_list, - child_states_to_integrate_stochastic_list, - child_state_choices_to_aggr_choice_list, - child_state_choice_idxs_to_interp_list, - batches_cover_all, - last_batch_info, - ) = correct_for_uneven_last_batch( - batches_list, - child_states_to_integrate_stochastic_list, - child_state_choices_to_aggr_choice_list, - child_state_choice_idxs_to_interp_list, - state_choice_space_dict, - map_state_choice_to_parent_state, - ) + ( + batches_list, + child_states_to_integrate_stochastic_list, + child_state_choices_to_aggr_choice_list, + child_state_choice_idxs_to_interp_list, + batches_cover_all, + last_batch_info, + ) = correct_for_uneven_last_batch( + batches_list, + child_states_to_integrate_stochastic_list, + 
child_state_choices_to_aggr_choice_list, + child_state_choice_idxs_to_interp_list, + state_choice_space_dict, + map_state_choice_to_parent_state, + ) + elif batch_mode == "period_max": + ( + batches_list, + child_state_choice_idxs_to_interp_list, + child_state_choices_to_aggr_choice_list, + child_states_to_integrate_stochastic_list, + ) = determine_period_max_batch_size( + bool_state_choices_to_batch=bool_state_choices_to_batch, + state_choice_space=state_choice_space, + map_state_choice_to_child_states=map_state_choice_to_child_states, + map_state_choice_to_index=map_state_choice_to_index, + state_space=state_space, + ) + batches_cover_all = True + last_batch_info = None + else: + raise ValueError( + f"Unknown batch_mode {batch_mode}. Use 'largest_block' or 'period_max'." + ) single_batch_segment_info = prepare_and_align_batch_arrays( batches_list, @@ -69,6 +93,94 @@ def create_single_segment_of_batches(bool_state_choices_to_batch, model_structur return single_batch_segment_info +def determine_period_max_batch_size( + bool_state_choices_to_batch, + state_choice_space, + map_state_choice_to_child_states, + map_state_choice_to_index, + state_space, +): + invalid_state_idx = np.iinfo(map_state_choice_to_index.dtype).max + out_of_bounds_state_choice_idx = state_choice_space.shape[0] + 1 + + idx_state_choice_raw = np.where(bool_state_choices_to_batch)[0] + if idx_state_choice_raw.size == 0: + raise ValueError("No state choices to batch in segment.") + + periods_to_batch = state_choice_space[idx_state_choice_raw, 0] + periods_unique_desc = np.sort(np.unique(periods_to_batch))[::-1] + + n_state_vars = state_space.shape[1] + + batches_to_check = [] + child_states_to_integrate_exog = [] + child_state_choices_to_aggr_choice = [] + child_state_choice_idxs_to_interpolate = [] + + for period in periods_unique_desc: + batch = idx_state_choice_raw[periods_to_batch == period] + batches_to_check += [batch] + + child_states_idxs = map_state_choice_to_child_states[batch] + 
unique_child_states, inverse_ids = np.unique( + child_states_idxs, return_index=False, return_inverse=True + ) + child_states_to_integrate_exog += [inverse_ids.reshape(child_states_idxs.shape)] + + child_states_batch = np.take(state_space, unique_child_states, axis=0) + child_states_tuple = tuple( + child_states_batch[:, i] for i in range(n_state_vars) + ) + unique_state_choice_idxs_childs = map_state_choice_to_index[child_states_tuple] + + ( + unique_child_state_choice_idxs, + inverse_child_state_choice_ids, + ) = np.unique( + unique_state_choice_idxs_childs, return_index=False, return_inverse=True + ) + + if ( + len(unique_child_state_choice_idxs) > 0 + and unique_child_state_choice_idxs[-1] == invalid_state_idx + ): + unique_child_state_choice_idxs = unique_child_state_choice_idxs[:-1] + inverse_child_state_choice_ids[ + inverse_child_state_choice_ids >= np.max(inverse_child_state_choice_ids) + ] = out_of_bounds_state_choice_idx + + child_state_choices_to_aggr_choice += [ + inverse_child_state_choice_ids.reshape( + unique_state_choice_idxs_childs.shape + ) + ] + child_state_choice_idxs_to_interpolate += [unique_child_state_choice_idxs] + + max_batch_size = max(len(batch) for batch in batches_to_check) + + for id_batch, batch in enumerate(batches_to_check): + n_to_add = max_batch_size - len(batch) + if n_to_add > 0: + pad_state_choice_idx = np.full(n_to_add, batch[0], dtype=batch.dtype) + batches_to_check[id_batch] = np.concatenate([batch, pad_state_choice_idx]) + + first_row = child_states_to_integrate_exog[id_batch][0:1, :] + child_states_to_integrate_exog[id_batch] = np.concatenate( + [ + child_states_to_integrate_exog[id_batch], + np.repeat(first_row, repeats=n_to_add, axis=0), + ], + axis=0, + ) + + return ( + batches_to_check, + child_state_choice_idxs_to_interpolate, + child_state_choices_to_aggr_choice, + child_states_to_integrate_exog, + ) + + def correct_for_uneven_last_batch( batches_list, child_states_to_integrate_stochastic_list, diff --git 
a/src/dcegm/pre_processing/check_model_config.py b/src/dcegm/pre_processing/check_model_config.py index 57003cb1..1b6c0a79 100644 --- a/src/dcegm/pre_processing/check_model_config.py +++ b/src/dcegm/pre_processing/check_model_config.py @@ -205,6 +205,18 @@ def check_model_config_and_process(model_config): else: processed_model_config["min_period_batch_segments"] = None + if "batch_mode" in model_config.keys(): + batch_mode = model_config["batch_mode"] + if not isinstance(batch_mode, (str, list)): + raise ValueError("batch_mode must be a string or a list of strings.") + if isinstance(batch_mode, list) and not all( + isinstance(mode, str) for mode in batch_mode + ): + raise ValueError("If batch_mode is a list, all entries must be strings.") + processed_model_config["batch_mode"] = batch_mode + else: + processed_model_config["batch_mode"] = "largest_block" + if "stochastic_states" in model_config.keys(): processed_model_config["stochastic_states"] = model_config["stochastic_states"] diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index 0ab2fe4c..8fe37a6c 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -132,6 +132,7 @@ def create_model_dict( model_structure=model_structure, n_periods=model_config_processed["n_periods"], min_period_batch_segments=model_config_processed["min_period_batch_segments"], + batch_mode=model_config_processed["batch_mode"], ) if not debug_info == "all": # Delete large arrays which is not needed. 
Not if all is requested From 7aaf23d6b05b39408d12f4adc9f6ce0ed3e45fdb Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 24 Apr 2026 10:29:58 +0200 Subject: [PATCH 2/7] Guide --- docs/source/guides/batching.rst | 163 +++++++++++++ docs/source/index.rst | 1 + .../pre_processing/check_model_config.py | 41 +++- tests/test_batch_mode_period_max.py | 217 ++++++++++++++++++ 4 files changed, 418 insertions(+), 4 deletions(-) create mode 100644 docs/source/guides/batching.rst create mode 100644 tests/test_batch_mode_period_max.py diff --git a/docs/source/guides/batching.rst b/docs/source/guides/batching.rst new file mode 100644 index 00000000..c5bf0232 --- /dev/null +++ b/docs/source/guides/batching.rst @@ -0,0 +1,163 @@ +.. _batching_guide: + +Batching Strategy and Segmentation +================================= + +The backward induction in ``dcegm`` is solved in batches. This is a computational detail to make array shapes compatible with fast JAX scans while preserving the model logic. + +Why batching exists +------------------- + +The number of feasible state-choice combinations usually changes over the life cycle. However, vectorized scan steps work best with equal leading dimensions. Batching groups state-choice rows into equal-sized chunks so each scan step can run with fixed shapes. + +Two batching modes +------------------ + +``dcegm`` supports two batching modes: + +- ``largest_block``: + + - Finds large dependency-safe batches. + - Typically yields fewer and larger batches. + - Good default for smooth state-choice profiles. + +- ``period_max``: + + - Uses one batch per period within a segment. + - Pads smaller period batches to the segment-specific maximum number of state choices per period. + - Useful when state-choice counts vary strongly by period. 
+ +Padding rule in ``period_max`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If a period has fewer state choices than the segment maximum, the batch is padded with a valid state-choice index from the same batch (deterministically the first one). This keeps shapes aligned and does not change the solution logic. + +Segmenting the horizon +---------------------- + +Use ``min_period_batch_segments`` to split the pre-terminal part of the horizon into segments. + +- Without segmentation: + + - ``batch_mode`` must be a single string. + +- With segmentation: + + - ``batch_mode`` can be a string (reused for all segments), or + - ``batch_mode`` can be a list with one entry per segment. + +The number of segments is ``len(min_period_batch_segments) + 1``. + +Valid strings are ``"largest_block"`` and ``"period_max"``. + +Examples +~~~~~~~~ + +No segmentation: + +.. code-block:: python + + model_config = { + "n_periods": 20, + "choices": np.arange(3, dtype=int), + "continuous_states": {"assets_end_of_period": np.linspace(0, 100, 200)}, + "n_quad_points": 5, + "batch_mode": "period_max", + } + +With segmentation: + +.. code-block:: python + + model_config = { + "n_periods": 20, + "choices": np.arange(3, dtype=int), + "continuous_states": {"assets_end_of_period": np.linspace(0, 100, 200)}, + "n_quad_points": 5, + "min_period_batch_segments": [8, 14], + "batch_mode": ["period_max", "largest_block", "period_max"], + } + +Using ``get_n_state_choices_per_period`` to choose segments +------------------------------------------------------------ + +After model setup, inspect how many state-choice combinations exist in each period: + +.. 
code-block:: python + + model = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + budget_constraint=budget_constraint, + state_space_functions=state_space_functions, + stochastic_states_transitions=stochastic_states_transitions, + ) + + n_state_choices = model.get_n_state_choices_per_period() + print(n_state_choices) + +Use this series to detect structural breaks in complexity. Typical heuristics: + +- Keep periods with similar counts in one segment. +- Split where there are abrupt jumps/drops. +- Use ``period_max`` in highly uneven segments. +- Keep ``largest_block`` in smoother segments. + +Worked example: experience growth and retirement regimes +-------------------------------------------------------- + +Consider a model with a discrete experience state where: + +- choice 0: no work, experience unchanged, +- choice 1: regular work, experience increases by 1, +- choice 2: intensive work, experience increases by 2, +- choice 3: retirement. + +Suppose retirement becomes available from period 8, and is mandatory from period 14. + +.. code-block:: python + + def choice_set(period, lagged_choice): + if period >= 14: + return np.array([3], dtype=int) # mandatory retirement + if period >= 8: + return np.array([0, 1, 2, 3], dtype=int) # retirement becomes available + return np.array([0, 1, 2], dtype=int) + + def next_period_deterministic_state(period, choice, experience): + if choice == 1: + experience_next = experience + 1 + elif choice == 2: + experience_next = experience + 2 + else: + experience_next = experience + return { + "period": period + 1, + "lagged_choice": choice, + "experience": experience_next, + } + +In this setup you often see: + +- gradual growth in state-choice counts early on, +- a jump when retirement becomes optional, +- a drop when retirement becomes mandatory. 
+ +This pattern is a good reason to separate segments around the two regime changes: + +.. code-block:: python + + model_config["min_period_batch_segments"] = [8, 14] + model_config["batch_mode"] = ["period_max", "largest_block", "period_max"] + +Validation recommendation +------------------------- + +When changing batching setup, compare solutions across configurations (for the same model and parameters): + +- baseline ``largest_block`` everywhere, +- your segmented/mixed ``batch_mode`` setup. + +Then compare ``value``, ``policy``, and ``endog_grid`` arrays with ``assert_allclose(..., equal_nan=True)``. diff --git a/docs/source/index.rst b/docs/source/index.rst index 1ea4ebb0..6b7dceb8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,7 @@ Check out our :ref:`guides` to find information on getting sta :hidden: guides/practitioner_guide + guides/batching guides/templates guides/minimal_example.ipynb diff --git a/src/dcegm/pre_processing/check_model_config.py b/src/dcegm/pre_processing/check_model_config.py index 1b6c0a79..f134d546 100644 --- a/src/dcegm/pre_processing/check_model_config.py +++ b/src/dcegm/pre_processing/check_model_config.py @@ -207,12 +207,45 @@ def check_model_config_and_process(model_config): if "batch_mode" in model_config.keys(): batch_mode = model_config["batch_mode"] + valid_batch_modes = {"largest_block", "period_max"} if not isinstance(batch_mode, (str, list)): raise ValueError("batch_mode must be a string or a list of strings.") - if isinstance(batch_mode, list) and not all( - isinstance(mode, str) for mode in batch_mode - ): - raise ValueError("If batch_mode is a list, all entries must be strings.") + + if isinstance(batch_mode, str): + if batch_mode not in valid_batch_modes: + raise ValueError( + f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}." + ) + else: + if not all(isinstance(mode, str) for mode in batch_mode): + raise ValueError( + "If batch_mode is a list, all entries must be strings." 
+ ) + if not all(mode in valid_batch_modes for mode in batch_mode): + raise ValueError( + f"All entries in batch_mode must be one of {valid_batch_modes}." + ) + + min_period_batch_segments = processed_model_config[ + "min_period_batch_segments" + ] + if min_period_batch_segments is None: + expected_n_segments = 1 + elif isinstance(min_period_batch_segments, int): + expected_n_segments = 2 + elif isinstance(min_period_batch_segments, list): + expected_n_segments = len(min_period_batch_segments) + 1 + else: + raise ValueError( + "min_period_batch_segments must be None, int, or list." + ) + + if len(batch_mode) != expected_n_segments: + raise ValueError( + "If batch_mode is a list, it must have one entry per segment. " + f"Expected {expected_n_segments}, got {len(batch_mode)}." + ) + processed_model_config["batch_mode"] = batch_mode else: processed_model_config["batch_mode"] = "largest_block" diff --git a/tests/test_batch_mode_period_max.py b/tests/test_batch_mode_period_max.py new file mode 100644 index 00000000..f4493556 --- /dev/null +++ b/tests/test_batch_mode_period_max.py @@ -0,0 +1,217 @@ +import numpy as np +from numpy.testing import assert_allclose + +import dcegm +from dcegm.toy_models.cons_ret_model_dcegm_paper import ( + inverse_marginal_utility_crra, + marginal_utility_crra, + marginal_utility_final_consume_all, +) +from tests.test_changing_choice_set import ( + budget, + choice_set, + flow_utility, + next_period_state, + prob_health, + prob_partner, + sparsity_condition, + utility_final, +) + + +def _get_model_objects(n_periods): + params = { + "rho": 0.5, + "delta": 1, + "phi": 0.5, + "constant": 1, + "exp": 0.1, + "exp_squared": -0.01, + "pension_per_experience": 0.3, + "unemployment_benefits": 0.4, + "health_costs": 0.5, + "consumption_floor": 0, + "p_bad_health_given_good_health": 0.2, + "p_bad_health_given_bad_health": 1, + "p_partner_given_single": 0.5, + "p_partner_given_partner": 0.9, + } + + model_specs = { + "min_age": 0, + "n_periods": 
n_periods, + "n_choices": 3, + "n_health_states": 2, + "n_partner_states": 2, + "max_experience": n_periods - 1, + "interest_rate": 0.05, + "discount_factor": 0.95, + "taste_shock_scale": 1, + "income_shock_std": 1, + "income_shock_mean": 0.0, + } + + model_config = { + "n_periods": n_periods, + "choices": np.arange(3), + "deterministic_states": { + "experience": np.arange(n_periods), + }, + "continuous_states": { + "assets_end_of_period": np.linspace(0, 500, 100), + }, + "stochastic_states": { + "health": [0, 1], + "partner": [0, 1], + }, + "n_quad_points": 5, + } + + state_space_functions = { + "state_specific_choice_set": choice_set, + "next_period_deterministic_state": next_period_state, + "sparsity_condition": sparsity_condition, + } + + utility_functions = { + "utility": flow_utility, + "marginal_utility": marginal_utility_crra, + "inverse_marginal_utility": inverse_marginal_utility_crra, + } + + utility_functions_final_period = { + "utility": utility_final, + "marginal_utility": marginal_utility_final_consume_all, + } + + exogenous_states_transition = { + "health": prob_health, + "partner": prob_partner, + } + + return ( + params, + model_specs, + model_config, + state_space_functions, + utility_functions, + utility_functions_final_period, + exogenous_states_transition, + ) + + +def _solve_with_config( + model_config, + model_specs, + params, + state_space_functions, + utility_functions, + utility_functions_final_period, + exogenous_states_transition, +): + model = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + budget_constraint=budget, + stochastic_states_transitions=exogenous_states_transition, + ) + return model.solve(params) + + +def test_period_max_equals_largest_block_without_segments(): + ( + params, + model_specs, + model_config, + state_space_functions, + 
utility_functions, + utility_functions_final_period, + exogenous_states_transition, + ) = _get_model_objects(n_periods=8) + + model_config_baseline = dict(model_config) + model_config_baseline["batch_mode"] = "largest_block" + + model_config_period_max = dict(model_config) + model_config_period_max["batch_mode"] = "period_max" + + solved_baseline = _solve_with_config( + model_config=model_config_baseline, + model_specs=model_specs, + params=params, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + solved_period_max = _solve_with_config( + model_config=model_config_period_max, + model_specs=model_specs, + params=params, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + + assert_allclose(solved_baseline.value, solved_period_max.value, equal_nan=True) + assert_allclose(solved_baseline.policy, solved_period_max.policy, equal_nan=True) + assert_allclose( + solved_baseline.endog_grid, solved_period_max.endog_grid, equal_nan=True + ) + + +def test_period_max_equals_largest_block_with_segments(): + ( + params, + model_specs, + model_config, + state_space_functions, + utility_functions, + utility_functions_final_period, + exogenous_states_transition, + ) = _get_model_objects(n_periods=8) + + model_config_baseline = dict(model_config) + model_config_baseline["min_period_batch_segments"] = [2, 3] + model_config_baseline["batch_mode"] = [ + "largest_block", + "largest_block", + "largest_block", + ] + + model_config_period_max = dict(model_config) + model_config_period_max["min_period_batch_segments"] = [2, 3] + model_config_period_max["batch_mode"] = [ + "period_max", + "largest_block", + "period_max", + ] + + solved_baseline = _solve_with_config( + 
model_config=model_config_baseline, + model_specs=model_specs, + params=params, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + solved_period_max = _solve_with_config( + model_config=model_config_period_max, + model_specs=model_specs, + params=params, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + + assert_allclose(solved_baseline.value, solved_period_max.value, equal_nan=True) + assert_allclose(solved_baseline.policy, solved_period_max.policy, equal_nan=True) + assert_allclose( + solved_baseline.endog_grid, solved_period_max.endog_grid, equal_nan=True + ) From 2374c0d11677d35d7ca17f8b6c19583bb5ad5d33 Mon Sep 17 00:00:00 2001 From: Annica Gehlen Date: Sat, 25 Apr 2026 21:06:53 +0200 Subject: [PATCH 3/7] Minor edits to guide.@ --- docs/source/guides/batching.rst | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/docs/source/guides/batching.rst b/docs/source/guides/batching.rst index c5bf0232..c92d3f50 100644 --- a/docs/source/guides/batching.rst +++ b/docs/source/guides/batching.rst @@ -26,11 +26,7 @@ Two batching modes - Uses one batch per period within a segment. - Pads smaller period batches to the segment-specific maximum number of state choices per period. - Useful when state-choice counts vary strongly by period. - -Padding rule in ``period_max`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If a period has fewer state choices than the segment maximum, the batch is padded with a valid state-choice index from the same batch (deterministically the first one). This keeps shapes aligned and does not change the solution logic. 
+ - **Padding rule**: If a period has fewer state choices than the segment maximum, the batch is padded with a valid dummy state-choice index from the same batch (deterministically the first one). This keeps shapes aligned and does not change the solution logic. Segmenting the horizon ---------------------- @@ -78,10 +74,10 @@ With segmentation: "batch_mode": ["period_max", "largest_block", "period_max"], } -Using ``get_n_state_choices_per_period`` to choose segments ------------------------------------------------------------- +Tip: Use ``get_n_state_choices_per_period`` to choose segments +---------------------------------------------------------------- -After model setup, inspect how many state-choice combinations exist in each period: +To determine sensible segments for batching, inspect the number of state-choice combinations per period. .. code-block:: python @@ -98,15 +94,15 @@ After model setup, inspect how many state-choice combinations exist in each peri n_state_choices = model.get_n_state_choices_per_period() print(n_state_choices) -Use this series to detect structural breaks in complexity. Typical heuristics: +This series can be used to detect structural breaks in complexity. Typical heuristics are: - Keep periods with similar counts in one segment. - Split where there are abrupt jumps/drops. - Use ``period_max`` in highly uneven segments. - Keep ``largest_block`` in smoother segments. 
-Worked example: experience growth and retirement regimes --------------------------------------------------------- +Example: experience growth and retirement regimes +------------------------------------------------- Consider a model with a discrete experience state where: @@ -152,12 +148,4 @@ This pattern is a good reason to separate segments around the two regime changes model_config["min_period_batch_segments"] = [8, 14] model_config["batch_mode"] = ["period_max", "largest_block", "period_max"] -Validation recommendation -------------------------- - -When changing batching setup, compare solutions across configurations (for the same model and parameters): - -- baseline ``largest_block`` everywhere, -- your segmented/mixed ``batch_mode`` setup. - -Then compare ``value``, ``policy``, and ``endog_grid`` arrays with ``assert_allclose(..., equal_nan=True)``. +We suggest testing different segmentation choices to determine the fastest solution for your model. From f9baa9c049710c7922e2e6053cda508b706d3034 Mon Sep 17 00:00:00 2001 From: Annica Gehlen Date: Sat, 25 Apr 2026 21:22:38 +0200 Subject: [PATCH 4/7] Add default to notebook. --- docs/source/guides/batching.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/guides/batching.rst b/docs/source/guides/batching.rst index c92d3f50..93c10304 100644 --- a/docs/source/guides/batching.rst +++ b/docs/source/guides/batching.rst @@ -20,6 +20,7 @@ Two batching modes - Finds large dependency-safe batches. - Typically yields fewer and larger batches. - Good default for smooth state-choice profiles. + - **Default configuration**: ``dcegm`` uses this batching mode with no segmentation if batching is not configured otherwise. 
- ``period_max``: From 27e18d93e7310c644ba901102e0e9da512da808e Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 7 May 2026 13:03:06 +0200 Subject: [PATCH 5/7] Formatting docstrings --- src/dcegm/pre_processing/batches/batch_creation.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/dcegm/pre_processing/batches/batch_creation.py b/src/dcegm/pre_processing/batches/batch_creation.py index 4d6af9f1..f8e2772f 100644 --- a/src/dcegm/pre_processing/batches/batch_creation.py +++ b/src/dcegm/pre_processing/batches/batch_creation.py @@ -13,14 +13,12 @@ def create_batches_and_information( batch_mode="largest_block", ): """Batches are used instead of periods to have chunks of equal sized state choices. - The batch inparams=paramsformation dictionary contains the following arrays - reflecting the. - - steps in the backward induction: - - batches_state_choice_idx: The state choice indexes in each batch to be solved. - To solve the state choices in the egm step, we have to look at the child states - and the corresponding state choice indexes in the child states. For that we save - the following: + The returned batch information dictionary contains the following arrays + reflecting steps in the backward induction: + + - batches_state_choice_idx: The state choice indexes in each batch to be solved. To solve the state choices in the egm step, we have to look at the child states + and the corresponding state choice indexes in the child states. For that we save + the following: - child_state_choice_idxs_to_interp: The state choice indexes in we need to interpolate the wealth on. - child_states_idxs: The parent state indexes of the child states, i.e. 
the From 45fa4e59c4e43df5b6a715e73e5fb1a59efb8fec Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 7 May 2026 13:03:40 +0200 Subject: [PATCH 6/7] Formatting docstrings --- src/dcegm/pre_processing/batches/batch_creation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dcegm/pre_processing/batches/batch_creation.py b/src/dcegm/pre_processing/batches/batch_creation.py index f8e2772f..e9b2e8ab 100644 --- a/src/dcegm/pre_processing/batches/batch_creation.py +++ b/src/dcegm/pre_processing/batches/batch_creation.py @@ -16,7 +16,8 @@ def create_batches_and_information( The returned batch information dictionary contains the following arrays reflecting steps in the backward induction: - - batches_state_choice_idx: The state choice indexes in each batch to be solved. To solve the state choices in the egm step, we have to look at the child states + - batches_state_choice_idx: The state choice indexes in each batch to be solved. + To solve the state choices in the egm step, we have to look at the child states and the corresponding state choice indexes in the child states. For that we save the following: - child_state_choice_idxs_to_interp: The state choice indexes in we need to From 00aec016da28c32b1432e71fcc3d9da7e1976898 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 7 May 2026 13:05:25 +0200 Subject: [PATCH 7/7] Fix typo --- src/dcegm/pre_processing/batches/single_segment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dcegm/pre_processing/batches/single_segment.py b/src/dcegm/pre_processing/batches/single_segment.py index ff519c22..de0e0bb7 100644 --- a/src/dcegm/pre_processing/batches/single_segment.py +++ b/src/dcegm/pre_processing/batches/single_segment.py @@ -10,7 +10,7 @@ def create_single_segment_of_batches( ): """Create a single segment of evenly sized batches. - If the last batch is not evenly we correct it. + If the last batch is not evenly sized we correct it. """