diff --git a/docs/source/guides/batching.rst b/docs/source/guides/batching.rst new file mode 100644 index 00000000..93c10304 --- /dev/null +++ b/docs/source/guides/batching.rst @@ -0,0 +1,152 @@ +.. _batching_guide: + +Batching Strategy and Segmentation +================================= + +The backward induction in ``dcegm`` is solved in batches. This is a computational detail to make array shapes compatible with fast JAX scans while preserving the model logic. + +Why batching exists +------------------- + +The number of feasible state-choice combinations usually changes over the life cycle. However, vectorized scan steps work best with equal leading dimensions. Batching groups state-choice rows into equal-sized chunks so each scan step can run with fixed shapes. + +Two batching modes +------------------ + +``dcegm`` supports two batching modes: + +- ``largest_block``: + + - Finds large dependency-safe batches. + - Typically yields fewer and larger batches. + - Good default for smooth state-choice profiles. + - **Default configuration**: ``dcegm`` uses this batching mode with no segmentation if batching is not configured otherwise. + +- ``period_max``: + + - Uses one batch per period within a segment. + - Pads smaller period batches to the segment-specific maximum number of state choices per period. + - Useful when state-choice counts vary strongly by period. + - **Padding rule**: If a period has fewer state choices than the segment maximum, the batch is padded with a valid dummy state-choice index from the same batch (deterministically the first one). This keeps shapes aligned and does not change the solution logic. + +Segmenting the horizon +---------------------- + +Use ``min_period_batch_segments`` to split the pre-terminal part of the horizon into segments. + +- Without segmentation: + + - ``batch_mode`` must be a single string. 
+ +- With segmentation: + + - ``batch_mode`` can be a string (reused for all segments), or + - ``batch_mode`` can be a list with one entry per segment. + +The number of segments is ``len(min_period_batch_segments) + 1``. + +Valid strings are ``"largest_block"`` and ``"period_max"``. + +Examples +~~~~~~~~ + +No segmentation: + +.. code-block:: python + + model_config = { + "n_periods": 20, + "choices": np.arange(3, dtype=int), + "continuous_states": {"assets_end_of_period": np.linspace(0, 100, 200)}, + "n_quad_points": 5, + "batch_mode": "period_max", + } + +With segmentation: + +.. code-block:: python + + model_config = { + "n_periods": 20, + "choices": np.arange(3, dtype=int), + "continuous_states": {"assets_end_of_period": np.linspace(0, 100, 200)}, + "n_quad_points": 5, + "min_period_batch_segments": [8, 14], + "batch_mode": ["period_max", "largest_block", "period_max"], + } + +Tip: Use ``get_n_state_choices_per_period`` to choose segments +---------------------------------------------------------------- + +To determine sensible segments for batching, inspect the number of state-choice combinations per period. + +.. code-block:: python + + model = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + budget_constraint=budget_constraint, + state_space_functions=state_space_functions, + stochastic_states_transitions=stochastic_states_transitions, + ) + + n_state_choices = model.get_n_state_choices_per_period() + print(n_state_choices) + +This series can be used to detect structural breaks in complexity. Typical heuristics are: + +- Keep periods with similar counts in one segment. +- Split where there are abrupt jumps/drops. +- Use ``period_max`` in highly uneven segments. +- Keep ``largest_block`` in smoother segments. 
+ +Example: experience growth and retirement regimes +------------------------------------------------- + +Consider a model with a discrete experience state where: + +- choice 0: no work, experience unchanged, +- choice 1: regular work, experience increases by 1, +- choice 2: intensive work, experience increases by 2, +- choice 3: retirement. + +Suppose retirement becomes available from period 8, and is mandatory from period 14. + +.. code-block:: python + + def choice_set(period, lagged_choice): + if period >= 14: + return np.array([3], dtype=int) # mandatory retirement + if period >= 8: + return np.array([0, 1, 2, 3], dtype=int) # retirement becomes available + return np.array([0, 1, 2], dtype=int) + + def next_period_deterministic_state(period, choice, experience): + if choice == 1: + experience_next = experience + 1 + elif choice == 2: + experience_next = experience + 2 + else: + experience_next = experience + return { + "period": period + 1, + "lagged_choice": choice, + "experience": experience_next, + } + +In this setup you often see: + +- gradual growth in state-choice counts early on, +- a jump when retirement becomes optional, +- a drop when retirement becomes mandatory. + +This pattern is a good reason to separate segments around the two regime changes: + +.. code-block:: python + + model_config["min_period_batch_segments"] = [8, 14] + model_config["batch_mode"] = ["period_max", "largest_block", "period_max"] + +We suggest testing different segmentation choices to determine the fastest solution for your model. 
diff --git a/docs/source/index.rst b/docs/source/index.rst index 1ea4ebb0..6b7dceb8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,7 @@ Check out our :ref:`guides` to find information on getting sta :hidden: guides/practitioner_guide + guides/batching guides/templates guides/minimal_example.ipynb diff --git a/src/dcegm/pre_processing/batches/batch_creation.py b/src/dcegm/pre_processing/batches/batch_creation.py index 9c9d620a..e9b2e8ab 100644 --- a/src/dcegm/pre_processing/batches/batch_creation.py +++ b/src/dcegm/pre_processing/batches/batch_creation.py @@ -10,16 +10,16 @@ def create_batches_and_information( model_structure, n_periods, min_period_batch_segments=None, + batch_mode="largest_block", ): """Batches are used instead of periods to have chunks of equal sized state choices. - The batch inparams=paramsformation dictionary contains the following arrays - reflecting the. + The returned batch information dictionary contains the following arrays + reflecting steps in the backward induction: - steps in the backward induction: - batches_state_choice_idx: The state choice indexes in each batch to be solved. - To solve the state choices in the egm step, we have to look at the child states - and the corresponding state choice indexes in the child states. For that we save - the following: + To solve the state choices in the egm step, we have to look at the child states + and the corresponding state choice indexes in the child states. For that we save + the following: - child_state_choice_idxs_to_interp: The state choice indexes in we need to interpolate the wealth on. - child_states_idxs: The parent state indexes of the child states, i.e. 
the @@ -64,10 +64,22 @@ def create_batches_and_information( state_choice_space = model_structure["state_choice_space"] bool_state_choices_to_batch = state_choice_space[:, 0] < n_periods - 2 + valid_batch_modes = {"largest_block", "period_max"} + if min_period_batch_segments is None: + if isinstance(batch_mode, list): + raise ValueError( + "If min_period_batch_segments is not supplied, batch_mode must be a string." + ) + if batch_mode not in valid_batch_modes: + raise ValueError( + f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}." + ) single_batch_segment_info = create_single_segment_of_batches( - bool_state_choices_to_batch, model_structure + bool_state_choices_to_batch, + model_structure, + batch_mode=batch_mode, ) segment_infos = { "n_segments": 1, @@ -97,6 +109,24 @@ def create_batches_and_information( "The periods to split the batches have to be increasing and at least two periods apart." ) + if isinstance(batch_mode, str): + if batch_mode not in valid_batch_modes: + raise ValueError( + f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}." + ) + batch_mode = [batch_mode] * n_segments + elif isinstance(batch_mode, list): + if len(batch_mode) != n_segments: + raise ValueError( + "If min_period_batch_segments is supplied, batch_mode must be a list with one entry per segment." + ) + if not all(mode in valid_batch_modes for mode in batch_mode): + raise ValueError( + f"All entries in batch_mode must be one of {valid_batch_modes}." 
+ ) + else: + raise ValueError("batch_mode must be a string or a list of strings.") + segment_infos = { "n_segments": n_segments, } @@ -111,7 +141,9 @@ def create_batches_and_information( bool_state_choices_segment = bool_state_choices_to_batch & (~split_cond) segment_batch_info = create_single_segment_of_batches( - bool_state_choices_segment, model_structure + bool_state_choices_segment, + model_structure, + batch_mode=batch_mode[id_segment], ) segment_infos[f"batches_info_segment_{id_segment}"] = segment_batch_info @@ -119,7 +151,9 @@ def create_batches_and_information( bool_state_choices_to_batch = bool_state_choices_to_batch & split_cond last_segment_batch_info = create_single_segment_of_batches( - bool_state_choices_to_batch, model_structure + bool_state_choices_to_batch, + model_structure, + batch_mode=batch_mode[n_segments - 1], ) # We loop until n_segments - 2 and then add the last segment diff --git a/src/dcegm/pre_processing/batches/single_segment.py b/src/dcegm/pre_processing/batches/single_segment.py index de7db506..de0e0bb7 100644 --- a/src/dcegm/pre_processing/batches/single_segment.py +++ b/src/dcegm/pre_processing/batches/single_segment.py @@ -3,10 +3,14 @@ from dcegm.pre_processing.batches.algo_batch_size import determine_optimal_batch_size -def create_single_segment_of_batches(bool_state_choices_to_batch, model_structure): +def create_single_segment_of_batches( + bool_state_choices_to_batch, + model_structure, + batch_mode="largest_block", +): """Create a single segment of evenly sized batches. - If the last batch is not evenly we correct it. + If the last batch is not evenly sized we correct it. 
""" @@ -24,34 +28,54 @@ def create_single_segment_of_batches(bool_state_choices_to_batch, model_structur ] map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - ( - batches_list, - child_state_choice_idxs_to_interp_list, - child_state_choices_to_aggr_choice_list, - child_states_to_integrate_stochastic_list, - ) = determine_optimal_batch_size( - bool_state_choices_to_batch=bool_state_choices_to_batch, - state_choice_space=state_choice_space, - map_state_choice_to_child_states=map_state_choice_to_child_states, - map_state_choice_to_index=map_state_choice_to_index, - state_space=state_space, - ) + if batch_mode == "largest_block": + ( + batches_list, + child_state_choice_idxs_to_interp_list, + child_state_choices_to_aggr_choice_list, + child_states_to_integrate_stochastic_list, + ) = determine_optimal_batch_size( + bool_state_choices_to_batch=bool_state_choices_to_batch, + state_choice_space=state_choice_space, + map_state_choice_to_child_states=map_state_choice_to_child_states, + map_state_choice_to_index=map_state_choice_to_index, + state_space=state_space, + ) - ( - batches_list, - child_states_to_integrate_stochastic_list, - child_state_choices_to_aggr_choice_list, - child_state_choice_idxs_to_interp_list, - batches_cover_all, - last_batch_info, - ) = correct_for_uneven_last_batch( - batches_list, - child_states_to_integrate_stochastic_list, - child_state_choices_to_aggr_choice_list, - child_state_choice_idxs_to_interp_list, - state_choice_space_dict, - map_state_choice_to_parent_state, - ) + ( + batches_list, + child_states_to_integrate_stochastic_list, + child_state_choices_to_aggr_choice_list, + child_state_choice_idxs_to_interp_list, + batches_cover_all, + last_batch_info, + ) = correct_for_uneven_last_batch( + batches_list, + child_states_to_integrate_stochastic_list, + child_state_choices_to_aggr_choice_list, + child_state_choice_idxs_to_interp_list, + state_choice_space_dict, + map_state_choice_to_parent_state, + ) + elif 
batch_mode == "period_max": + ( + batches_list, + child_state_choice_idxs_to_interp_list, + child_state_choices_to_aggr_choice_list, + child_states_to_integrate_stochastic_list, + ) = determine_period_max_batch_size( + bool_state_choices_to_batch=bool_state_choices_to_batch, + state_choice_space=state_choice_space, + map_state_choice_to_child_states=map_state_choice_to_child_states, + map_state_choice_to_index=map_state_choice_to_index, + state_space=state_space, + ) + batches_cover_all = True + last_batch_info = None + else: + raise ValueError( + f"Unknown batch_mode {batch_mode}. Use 'largest_block' or 'period_max'." + ) single_batch_segment_info = prepare_and_align_batch_arrays( batches_list, @@ -69,6 +93,94 @@ def create_single_segment_of_batches(bool_state_choices_to_batch, model_structur return single_batch_segment_info +def determine_period_max_batch_size( + bool_state_choices_to_batch, + state_choice_space, + map_state_choice_to_child_states, + map_state_choice_to_index, + state_space, +): + invalid_state_idx = np.iinfo(map_state_choice_to_index.dtype).max + out_of_bounds_state_choice_idx = state_choice_space.shape[0] + 1 + + idx_state_choice_raw = np.where(bool_state_choices_to_batch)[0] + if idx_state_choice_raw.size == 0: + raise ValueError("No state choices to batch in segment.") + + periods_to_batch = state_choice_space[idx_state_choice_raw, 0] + periods_unique_desc = np.sort(np.unique(periods_to_batch))[::-1] + + n_state_vars = state_space.shape[1] + + batches_to_check = [] + child_states_to_integrate_exog = [] + child_state_choices_to_aggr_choice = [] + child_state_choice_idxs_to_interpolate = [] + + for period in periods_unique_desc: + batch = idx_state_choice_raw[periods_to_batch == period] + batches_to_check += [batch] + + child_states_idxs = map_state_choice_to_child_states[batch] + unique_child_states, inverse_ids = np.unique( + child_states_idxs, return_index=False, return_inverse=True + ) + child_states_to_integrate_exog += 
[inverse_ids.reshape(child_states_idxs.shape)] + + child_states_batch = np.take(state_space, unique_child_states, axis=0) + child_states_tuple = tuple( + child_states_batch[:, i] for i in range(n_state_vars) + ) + unique_state_choice_idxs_childs = map_state_choice_to_index[child_states_tuple] + + ( + unique_child_state_choice_idxs, + inverse_child_state_choice_ids, + ) = np.unique( + unique_state_choice_idxs_childs, return_index=False, return_inverse=True + ) + + if ( + len(unique_child_state_choice_idxs) > 0 + and unique_child_state_choice_idxs[-1] == invalid_state_idx + ): + unique_child_state_choice_idxs = unique_child_state_choice_idxs[:-1] + inverse_child_state_choice_ids[ + inverse_child_state_choice_ids >= np.max(inverse_child_state_choice_ids) + ] = out_of_bounds_state_choice_idx + + child_state_choices_to_aggr_choice += [ + inverse_child_state_choice_ids.reshape( + unique_state_choice_idxs_childs.shape + ) + ] + child_state_choice_idxs_to_interpolate += [unique_child_state_choice_idxs] + + max_batch_size = max(len(batch) for batch in batches_to_check) + + for id_batch, batch in enumerate(batches_to_check): + n_to_add = max_batch_size - len(batch) + if n_to_add > 0: + pad_state_choice_idx = np.full(n_to_add, batch[0], dtype=batch.dtype) + batches_to_check[id_batch] = np.concatenate([batch, pad_state_choice_idx]) + + first_row = child_states_to_integrate_exog[id_batch][0:1, :] + child_states_to_integrate_exog[id_batch] = np.concatenate( + [ + child_states_to_integrate_exog[id_batch], + np.repeat(first_row, repeats=n_to_add, axis=0), + ], + axis=0, + ) + + return ( + batches_to_check, + child_state_choice_idxs_to_interpolate, + child_state_choices_to_aggr_choice, + child_states_to_integrate_exog, + ) + + def correct_for_uneven_last_batch( batches_list, child_states_to_integrate_stochastic_list, diff --git a/src/dcegm/pre_processing/check_model_config.py b/src/dcegm/pre_processing/check_model_config.py index 57003cb1..f134d546 100644 --- 
a/src/dcegm/pre_processing/check_model_config.py +++ b/src/dcegm/pre_processing/check_model_config.py @@ -205,6 +205,51 @@ def check_model_config_and_process(model_config): else: processed_model_config["min_period_batch_segments"] = None + if "batch_mode" in model_config.keys(): + batch_mode = model_config["batch_mode"] + valid_batch_modes = {"largest_block", "period_max"} + if not isinstance(batch_mode, (str, list)): + raise ValueError("batch_mode must be a string or a list of strings.") + + if isinstance(batch_mode, str): + if batch_mode not in valid_batch_modes: + raise ValueError( + f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}." + ) + else: + if not all(isinstance(mode, str) for mode in batch_mode): + raise ValueError( + "If batch_mode is a list, all entries must be strings." + ) + if not all(mode in valid_batch_modes for mode in batch_mode): + raise ValueError( + f"All entries in batch_mode must be one of {valid_batch_modes}." + ) + + min_period_batch_segments = processed_model_config[ + "min_period_batch_segments" + ] + if min_period_batch_segments is None: + expected_n_segments = 1 + elif isinstance(min_period_batch_segments, int): + expected_n_segments = 2 + elif isinstance(min_period_batch_segments, list): + expected_n_segments = len(min_period_batch_segments) + 1 + else: + raise ValueError( + "min_period_batch_segments must be None, int, or list." + ) + + if len(batch_mode) != expected_n_segments: + raise ValueError( + "If batch_mode is a list, it must have one entry per segment. " + f"Expected {expected_n_segments}, got {len(batch_mode)}." 
+ ) + + processed_model_config["batch_mode"] = batch_mode + else: + processed_model_config["batch_mode"] = "largest_block" + if "stochastic_states" in model_config.keys(): processed_model_config["stochastic_states"] = model_config["stochastic_states"] diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index 0ab2fe4c..8fe37a6c 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -132,6 +132,7 @@ def create_model_dict( model_structure=model_structure, n_periods=model_config_processed["n_periods"], min_period_batch_segments=model_config_processed["min_period_batch_segments"], + batch_mode=model_config_processed["batch_mode"], ) if not debug_info == "all": # Delete large arrays which is not needed. Not if all is requested diff --git a/tests/test_batch_mode_period_max.py b/tests/test_batch_mode_period_max.py new file mode 100644 index 00000000..f4493556 --- /dev/null +++ b/tests/test_batch_mode_period_max.py @@ -0,0 +1,217 @@ +import numpy as np +from numpy.testing import assert_allclose + +import dcegm +from dcegm.toy_models.cons_ret_model_dcegm_paper import ( + inverse_marginal_utility_crra, + marginal_utility_crra, + marginal_utility_final_consume_all, +) +from tests.test_changing_choice_set import ( + budget, + choice_set, + flow_utility, + next_period_state, + prob_health, + prob_partner, + sparsity_condition, + utility_final, +) + + +def _get_model_objects(n_periods): + params = { + "rho": 0.5, + "delta": 1, + "phi": 0.5, + "constant": 1, + "exp": 0.1, + "exp_squared": -0.01, + "pension_per_experience": 0.3, + "unemployment_benefits": 0.4, + "health_costs": 0.5, + "consumption_floor": 0, + "p_bad_health_given_good_health": 0.2, + "p_bad_health_given_bad_health": 1, + "p_partner_given_single": 0.5, + "p_partner_given_partner": 0.9, + } + + model_specs = { + "min_age": 0, + "n_periods": n_periods, + "n_choices": 3, + "n_health_states": 2, + "n_partner_states": 2, + 
"max_experience": n_periods - 1, + "interest_rate": 0.05, + "discount_factor": 0.95, + "taste_shock_scale": 1, + "income_shock_std": 1, + "income_shock_mean": 0.0, + } + + model_config = { + "n_periods": n_periods, + "choices": np.arange(3), + "deterministic_states": { + "experience": np.arange(n_periods), + }, + "continuous_states": { + "assets_end_of_period": np.linspace(0, 500, 100), + }, + "stochastic_states": { + "health": [0, 1], + "partner": [0, 1], + }, + "n_quad_points": 5, + } + + state_space_functions = { + "state_specific_choice_set": choice_set, + "next_period_deterministic_state": next_period_state, + "sparsity_condition": sparsity_condition, + } + + utility_functions = { + "utility": flow_utility, + "marginal_utility": marginal_utility_crra, + "inverse_marginal_utility": inverse_marginal_utility_crra, + } + + utility_functions_final_period = { + "utility": utility_final, + "marginal_utility": marginal_utility_final_consume_all, + } + + exogenous_states_transition = { + "health": prob_health, + "partner": prob_partner, + } + + return ( + params, + model_specs, + model_config, + state_space_functions, + utility_functions, + utility_functions_final_period, + exogenous_states_transition, + ) + + +def _solve_with_config( + model_config, + model_specs, + params, + state_space_functions, + utility_functions, + utility_functions_final_period, + exogenous_states_transition, +): + model = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + budget_constraint=budget, + stochastic_states_transitions=exogenous_states_transition, + ) + return model.solve(params) + + +def test_period_max_equals_largest_block_without_segments(): + ( + params, + model_specs, + model_config, + state_space_functions, + utility_functions, + utility_functions_final_period, + exogenous_states_transition, + ) = 
_get_model_objects(n_periods=8) + + model_config_baseline = dict(model_config) + model_config_baseline["batch_mode"] = "largest_block" + + model_config_period_max = dict(model_config) + model_config_period_max["batch_mode"] = "period_max" + + solved_baseline = _solve_with_config( + model_config=model_config_baseline, + model_specs=model_specs, + params=params, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + solved_period_max = _solve_with_config( + model_config=model_config_period_max, + model_specs=model_specs, + params=params, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + + assert_allclose(solved_baseline.value, solved_period_max.value, equal_nan=True) + assert_allclose(solved_baseline.policy, solved_period_max.policy, equal_nan=True) + assert_allclose( + solved_baseline.endog_grid, solved_period_max.endog_grid, equal_nan=True + ) + + +def test_period_max_equals_largest_block_with_segments(): + ( + params, + model_specs, + model_config, + state_space_functions, + utility_functions, + utility_functions_final_period, + exogenous_states_transition, + ) = _get_model_objects(n_periods=8) + + model_config_baseline = dict(model_config) + model_config_baseline["min_period_batch_segments"] = [2, 3] + model_config_baseline["batch_mode"] = [ + "largest_block", + "largest_block", + "largest_block", + ] + + model_config_period_max = dict(model_config) + model_config_period_max["min_period_batch_segments"] = [2, 3] + model_config_period_max["batch_mode"] = [ + "period_max", + "largest_block", + "period_max", + ] + + solved_baseline = _solve_with_config( + model_config=model_config_baseline, + model_specs=model_specs, + params=params, + 
state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + solved_period_max = _solve_with_config( + model_config=model_config_period_max, + model_specs=model_specs, + params=params, + state_space_functions=state_space_functions, + utility_functions=utility_functions, + utility_functions_final_period=utility_functions_final_period, + exogenous_states_transition=exogenous_states_transition, + ) + + assert_allclose(solved_baseline.value, solved_period_max.value, equal_nan=True) + assert_allclose(solved_baseline.policy, solved_period_max.policy, equal_nan=True) + assert_allclose( + solved_baseline.endog_grid, solved_period_max.endog_grid, equal_nan=True + )