diff --git a/CHANGELOG.md b/CHANGELOG.md index 10fc97e57..4734402a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This file is used to list changes made in each version of the aws-parallelcluste **BUG FIXES** - Fix clustermgtd failing to detect compute node bootstrap timeouts, which prevented the cluster from entering protected mode. +- Fix an issue where compute nodes are incorrectly replaced when launching a large number of nodes due to eventual consistency. 3.15.0 ------ diff --git a/src/common/schedulers/slurm_commands.py b/src/common/schedulers/slurm_commands.py index 38a0fc513..3404e1aca 100644 --- a/src/common/schedulers/slurm_commands.py +++ b/src/common/schedulers/slurm_commands.py @@ -63,7 +63,8 @@ SCONTROL_OUTPUT_AWK_PARSER = ( 'awk \'BEGIN{{RS="\\n\\n" ; ORS="######\\n";}} {{print}}\' | ' + "grep -oP '^(NodeName=\\S+)|(NodeAddr=\\S+)|(NodeHostName=\\S+)|(? List[SlurmNode]: "SlurmdStartTime": "slurmdstarttime", "LastBusyTime": "lastbusytime", "ReservationName": "reservation_name", + "InstanceId": "instance_id", } date_fields = ["SlurmdStartTime", "LastBusyTime"] @@ -449,6 +459,9 @@ def _parse_nodes_info(slurm_node_info: str) -> List[SlurmNode]: value = datetime.strptime(value, "%Y-%m-%dT%H:%M:%S").astimezone(tz=timezone.utc) else: value = None + elif key == "InstanceId" and value == "(null)": + # Slurm reports an unset InstanceId as "(null)" + value = None kwargs[map_slurm_key_to_arg[key]] = value if lines: try: diff --git a/src/slurm_plugin/clustermgtd.py b/src/slurm_plugin/clustermgtd.py index 7fcd4ace9..a09924afb 100644 --- a/src/slurm_plugin/clustermgtd.py +++ b/src/slurm_plugin/clustermgtd.py @@ -44,6 +44,7 @@ from slurm_plugin.cluster_event_publisher import ClusterEventPublisher from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, log_exception, print_with_count from slurm_plugin.console_logger import ConsoleLogger +from slurm_plugin.fleet_manager import INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT from slurm_plugin.instance_manager import InstanceManager from slurm_plugin.slurm_resources import ( CONFIG_FILE_DIR, @@ -147,6 +148,7 @@ class ClustermgtdConfig: "run_instances_overrides": "/opt/slurm/etc/pcluster/run_instances_overrides.json", "create_fleet_overrides": "/opt/slurm/etc/pcluster/create_fleet_overrides.json", "fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json", + "instance_info_retrieval_timeout": INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, # Terminate configs "terminate_max_batch_size": 1000, # Timeout to wait for node initialization, should be the same as ResumeTimeout @@ -257,6 +259,11 @@ def _get_launch_config(self, config): "clustermgtd", "create_fleet_overrides", fallback=self.DEFAULTS.get("create_fleet_overrides") ) self.create_fleet_overrides = read_json(create_fleet_overrides_file, default={}) + self.instance_info_retrieval_timeout = config.getint( + "clustermgtd", + "instance_info_retrieval_timeout", + fallback=self.DEFAULTS.get("instance_info_retrieval_timeout"), + ) def _get_health_check_config(self, config): self.disable_ec2_health_check = config.getboolean( @@ -452,6 +459,7 @@ def _initialize_instance_manager(config): run_instances_overrides=config.run_instances_overrides, create_fleet_overrides=config.create_fleet_overrides, fleet_config=config.fleet_config, + instance_info_retrieval_timeout=config.instance_info_retrieval_timeout, ) def _initialize_executor(self, config): @@ -1149,15 +1157,14 @@ def _parse_scheduler_nodes_data(nodes): @staticmethod def _update_slurm_nodes_with_ec2_info(nodes, cluster_instances): + """Associate EC2 instances with Slurm nodes by matching on instance ID.""" if cluster_instances: - ip_to_slurm_node_map = {node.nodeaddr: node for node in nodes} + instance_id_to_slurm_node_map = {node.instance_id: node for node in nodes if node.instance_id} for instance in cluster_instances: - for private_ip in instance.all_private_ips: - if private_ip in ip_to_slurm_node_map: - slurm_node = ip_to_slurm_node_map.get(private_ip) - slurm_node.instance = instance - instance.slurm_node = slurm_node - break + if instance.id in instance_id_to_slurm_node_map: + slurm_node = instance_id_to_slurm_node_map[instance.id] + slurm_node.instance = instance + instance.slurm_node = slurm_node @staticmethod def get_instance_id_to_active_node_map(partitions: List[SlurmPartition]) -> Dict: diff --git a/src/slurm_plugin/fleet_manager.py b/src/slurm_plugin/fleet_manager.py index 8fc932192..37fb37abb 100644 --- a/src/slurm_plugin/fleet_manager.py +++ b/src/slurm_plugin/fleet_manager.py @@ -24,6 +24,12 @@ logger = logging.getLogger(__name__) +# Total time budget (seconds) and per-attempt backoff cap for retrying DescribeInstances after a CreateFleet +# launch, to tolerate EC2 API eventual consistency. +# See https://docs.aws.amazon.com/ec2/latest/devguide/eventual-consistency.html +INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT = 90 +INSTANCE_INFO_RETRIEVAL_MAX_BACKOFF = 30 + class EC2Instance: def __init__(self, id, private_ip, hostname, all_private_ips, launch_time): @@ -94,6 +100,7 @@ def get_manager( all_or_nothing, run_instances_overrides, create_fleet_overrides, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, ): try: queue_config = fleet_config[queue] @@ -120,6 +127,7 @@ def get_manager( compute_resource_config, all_or_nothing, create_fleet_overrides.get(queue, {}).get(compute_resource, {}), + instance_info_retrieval_timeout=instance_info_retrieval_timeout, ) elif api == "run-instances": return Ec2RunInstancesManager( @@ -272,6 +280,7 @@ def __init__( compute_resource_config, all_or_nothing, launch_overrides, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, ): super().__init__( cluster_name, @@ -283,6 +292,7 @@ def __init__( all_or_nothing, launch_overrides, ) + self._instance_info_retrieval_timeout = instance_info_retrieval_timeout def _evaluate_template_overrides(self) -> list: """Build and return the list of Launch Template Overrides to be applied in the CreateFleet request. @@ -436,22 +446,39 @@ def _get_instances_info(self, instance_ids: list): """ Describe instances to retrieve info not available from create-fleet response. + Right after a CreateFleet launch, DescribeInstances may return instances with missing info + (e.g. PrivateIpAddress) or even InvalidInstanceID.NotFound due to EC2 API eventual consistency. + Retry with exponential backoff (capped per attempt) until the configured total timeout is reached, + as recommended at https://docs.aws.amazon.com/ec2/latest/devguide/eventual-consistency.html + :raises ClientError in case of boto3 failure :return list of instances with complete information and list of IDs for instances with incomplete information """ instances = [] partial_instance_ids = instance_ids - retries = 5 attempt_count = 0 + # Budget is tracked against the un-jittered backoff; jitter is added only to the actual sleep. + elapsed_backoff = 0 # Wait for instances to be available in EC2 time.sleep(0.1) - while attempt_count < retries and partial_instance_ids: + while partial_instance_ids: complete_instances, partial_instance_ids = self._retrieve_instances_info_from_ec2(partial_instance_ids) instances.extend(complete_instances) + if not partial_instance_ids: + break + base_backoff = min(0.3 * 2 ** (attempt_count + 1), INSTANCE_INFO_RETRIEVAL_MAX_BACKOFF) + if elapsed_backoff + base_backoff > self._instance_info_retrieval_timeout: + logger.warning( + "Unable to retrieve complete info for instances %s within %s seconds, giving up after %s attempts.", + print_with_count(partial_instance_ids), + self._instance_info_retrieval_timeout, + attempt_count + 1, + ) + break + elapsed_backoff += base_backoff attempt_count += 1 - if attempt_count < retries: - time.sleep(0.3 * 2**attempt_count + (secrets.randbelow(500) / 1000)) + time.sleep(base_backoff + (secrets.randbelow(500) / 1000)) return instances, partial_instance_ids diff --git a/src/slurm_plugin/instance_manager.py b/src/slurm_plugin/instance_manager.py index bd60ec579..6d9c32541 100644 --- a/src/slurm_plugin/instance_manager.py +++ b/src/slurm_plugin/instance_manager.py @@ -25,7 +25,11 @@ from common.schedulers.slurm_commands import get_nodes_info, update_nodes from common.utils import grouper, setup_logging_filter from slurm_plugin.common import ComputeInstanceDescriptor, ScalingStrategy, log_exception, print_with_count -from slurm_plugin.fleet_manager import EC2Instance, FleetManagerFactory +from slurm_plugin.fleet_manager import ( + INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, + EC2Instance, + FleetManagerFactory, +) from slurm_plugin.slurm_resources import ( EC2_HEALTH_STATUS_UNHEALTHY_STATES, EC2_INSTANCE_ALIVE_STATES, @@ -85,6 +89,7 @@ def __init__( run_instances_overrides: dict = None, create_fleet_overrides: dict = None, job_level_scaling: bool = False, + instance_info_retrieval_timeout: int = INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, ): """Initialize InstanceLauncher with required attributes.""" self._region = region @@ -101,6 +106,7 @@ def __init__( self._fleet_config = fleet_config self._run_instances_overrides = run_instances_overrides or {} self._create_fleet_overrides = create_fleet_overrides or {} + self._instance_info_retrieval_timeout = instance_info_retrieval_timeout self._boto3_resource_factory = lambda resource_name: boto3.session.Session().resource( resource_name, region_name=region, config=boto3_config ) @@ -262,7 +268,8 @@ def get_cluster_instances(self, include_head_node=False, alive_states_only=True) """ Get instances that are associated with the cluster. - Instances without all the info set are ignored and not returned + Instances with missing info (e.g. PrivateIpAddress due to EC2 eventual consistency) are kept with + empty IP fields so that clustermgtd can still match them to Slurm nodes by instance ID. """ ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config) paginator = ec2_client.get_paginator("describe_instances") @@ -290,12 +297,17 @@ def get_cluster_instances(self, include_head_node=False, alive_states_only=True) ) ) except Exception as e: + # Keep the instance with empty IP info so it can still be matched by instance ID in clustermgtd. logger.warning( - "Ignoring instance %s because not all EC2 info are available, exception: %s, message: %s", + "Incomplete EC2 info for instance %s, keeping it for instance-ID matching, " + "exception: %s, message: %s", instance_info["InstanceId"], type(e).__name__, e, ) + instances.append( + EC2Instance(instance_info["InstanceId"], "", "", set(), instance_info.get("LaunchTime")) + ) return instances @@ -1008,6 +1020,7 @@ def _get_fleet_manager(self, all_or_nothing_batch, compute_resource, queue): all_or_nothing=all_or_nothing_batch, run_instances_overrides=self._run_instances_overrides, create_fleet_overrides=self._create_fleet_overrides, + instance_info_retrieval_timeout=self._instance_info_retrieval_timeout, ) return fleet_manager @@ -1077,6 +1090,7 @@ def _update_slurm_node_addrs(self, slurm_nodes: List[str], launched_instances: L slurm_nodes, nodeaddrs=[instance.private_ip for instance in launched_instances], nodehostnames=node_hostnames, + instance_ids=[instance.id for instance in launched_instances], ) logger.info( "Nodes are now configured with instances %s", diff --git a/src/slurm_plugin/resume.py b/src/slurm_plugin/resume.py index cb9b22e7c..98e1ad1a9 100644 --- a/src/slurm_plugin/resume.py +++ b/src/slurm_plugin/resume.py @@ -22,6 +22,7 @@ from common.utils import read_json from slurm_plugin.cluster_event_publisher import ClusterEventPublisher from slurm_plugin.common import ScalingStrategy, is_clustermgtd_heartbeat_valid, print_with_count +from slurm_plugin.fleet_manager import INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT from slurm_plugin.instance_manager import InstanceManager from slurm_plugin.slurm_resources import CONFIG_FILE_DIR @@ -47,6 +48,7 @@ class SlurmResumeConfig: "fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json", "job_level_scaling": True, "scaling_strategy": "all-or-nothing", + "instance_info_retrieval_timeout": INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, } def __init__(self, config_file_path): @@ -96,6 +98,11 @@ def _get_config(self, config_file_path): self.job_level_scaling = config.getboolean( "slurm_resume", "job_level_scaling", fallback=self.DEFAULTS.get("job_level_scaling") ) + self.instance_info_retrieval_timeout = config.getint( + "slurm_resume", + "instance_info_retrieval_timeout", + fallback=self.DEFAULTS.get("instance_info_retrieval_timeout"), + ) fleet_config_file = config.get( "slurm_resume", "fleet_config_file", fallback=self.DEFAULTS.get("fleet_config_file") ) @@ -206,6 +213,7 @@ def _resume(arg_nodes, resume_config, slurm_resume): run_instances_overrides=resume_config.run_instances_overrides, create_fleet_overrides=resume_config.create_fleet_overrides, job_level_scaling=resume_config.job_level_scaling, + instance_info_retrieval_timeout=resume_config.instance_info_retrieval_timeout, ) instance_manager.add_instances( slurm_resume=slurm_resume, diff --git a/src/slurm_plugin/slurm_resources.py b/src/slurm_plugin/slurm_resources.py index 896b8e8f5..842e6a53a 100644 --- a/src/slurm_plugin/slurm_resources.py +++ b/src/slurm_plugin/slurm_resources.py @@ -243,6 +243,7 @@ def __init__( slurmdstarttime: datetime = None, lastbusytime: datetime = None, reservation_name: str = None, + instance_id: str = None, ): """Initialize slurm node with attributes.""" self.name = name @@ -253,6 +254,7 @@ def __init__( self.partitions = partitions.strip().split(",") if partitions else None self.reason = reason self.instance = instance + self.instance_id = instance_id self.slurmdstarttime = slurmdstarttime self.lastbusytime = lastbusytime self.reservation_name = reservation_name @@ -533,6 +535,7 @@ def __init__( slurmdstarttime=None, lastbusytime=None, reservation_name=None, + instance_id=None, ): """Initialize slurm node with attributes.""" super().__init__( @@ -546,6 +549,7 @@ def __init__( slurmdstarttime, lastbusytime=lastbusytime, reservation_name=reservation_name, + instance_id=instance_id, ) def is_healthy( @@ -667,6 +671,7 @@ def __init__( slurmdstarttime=None, lastbusytime=None, reservation_name=None, + instance_id=None, ): """Initialize slurm node with attributes.""" super().__init__( @@ -680,6 +685,7 @@ def __init__( slurmdstarttime, lastbusytime=lastbusytime, reservation_name=reservation_name, + instance_id=instance_id, ) def is_state_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True): diff --git a/tests/common/schedulers/test_slurm_commands.py b/tests/common/schedulers/test_slurm_commands.py index 6b6fa0d80..083026f01 100644 --- a/tests/common/schedulers/test_slurm_commands.py +++ b/tests/common/schedulers/test_slurm_commands.py @@ -245,6 +245,46 @@ def test_is_static_node(nodename, expected_is_static): ], True, ), + # Test case: InstanceId is parsed from scontrol show nodes output; "(null)" is normalized to None + ( + "NodeName=queue1-st-c5xlarge-1\n" + "NodeAddr=10.0.1.1\n" + "NodeHostName=queue1-st-c5xlarge-1\n" + "State=IDLE+CLOUD\n" + "Partitions=queue1\n" + "SlurmdStartTime=2023-01-23T17:57:07\n" + "InstanceId=i-0abc123def456\n" + "######\n" + "NodeName=queue1-dy-c5xlarge-2\n" + "NodeAddr=queue1-dy-c5xlarge-2\n" + "NodeHostName=queue1-dy-c5xlarge-2\n" + "State=IDLE+CLOUD+POWER\n" + "Partitions=queue1\n" + "SlurmdStartTime=None\n" + "InstanceId=(null)\n" + "######\n", + [ + StaticNode( + "queue1-st-c5xlarge-1", + "10.0.1.1", + "queue1-st-c5xlarge-1", + "IDLE+CLOUD", + "queue1", + slurmdstarttime=datetime(2023, 1, 23, 17, 57, 7).astimezone(tz=timezone.utc), + instance_id="i-0abc123def456", + ), + DynamicNode( + "queue1-dy-c5xlarge-2", + "queue1-dy-c5xlarge-2", + "queue1-dy-c5xlarge-2", + "IDLE+CLOUD+POWER", + "queue1", + slurmdstarttime=None, + instance_id=None, + ), + ], + False, + ), ], ) def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, caplog): @@ -255,14 +295,15 @@ def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, @pytest.mark.parametrize( - "nodenames, nodeaddrs, hostnames, batch_size, expected_result", + "nodenames, nodeaddrs, hostnames, instance_ids, batch_size, expected_result", [ ( "queue1-st-c5xlarge-1,queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", None, None, + None, 2, - [("queue1-st-c5xlarge-1,queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", None, None)], + [("queue1-st-c5xlarge-1,queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", None, None, None)], ), ( # Only split on commas after bucket @@ -270,12 +311,14 @@ def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, "queue1-st-c5xlarge-[1-2],queue1-st-c5xlarge-2,queue1-st-c5xlarge-3,queue1-st-c5xlarge-[4,6]", "nodeaddr-[1-2],nodeaddr-2,nodeaddr-3,nodeaddr-[4,6]", None, + None, 2, [ ( "queue1-st-c5xlarge-[1-2],queue1-st-c5xlarge-2,queue1-st-c5xlarge-3,queue1-st-c5xlarge-[4,6]", "nodeaddr-[1-2],nodeaddr-2,nodeaddr-3,nodeaddr-[4,6]", None, + None, ) ], ), @@ -283,21 +326,45 @@ def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, "queue1-st-c5xlarge-[1-2],queue1-st-c5xlarge-2,queue1-st-c5xlarge-[3],queue1-st-c5xlarge-[4,6]", "nodeaddr-[1-2],nodeaddr-2,nodeaddr-[3],nodeaddr-[4,6]", "nodehostname-[1-2],nodehostname-2,nodehostname-[3],nodehostname-[4,6]", + None, 2, [ ( "queue1-st-c5xlarge-[1-2],queue1-st-c5xlarge-2,queue1-st-c5xlarge-[3]", "nodeaddr-[1-2],nodeaddr-2,nodeaddr-[3]", "nodehostname-[1-2],nodehostname-2,nodehostname-[3]", + None, ), - ("queue1-st-c5xlarge-[4,6]", "nodeaddr-[4,6]", "nodehostname-[4,6]"), + ("queue1-st-c5xlarge-[4,6]", "nodeaddr-[4,6]", "nodehostname-[4,6]", None), ], ), - ("queue1-st-c5xlarge-1,queue1-st-c5xlarge-[2],queue1-st-c5xlarge-3", ["nodeaddr-1"], None, 2, ValueError), + ( + # nodeaddr and instanceid are batched together, distributed across the nodes in each batch + ["queue1-st-c5xlarge-1", "queue1-st-c5xlarge-2", "queue1-st-c5xlarge-3"], + ["nodeaddr-1", "nodeaddr-2", "nodeaddr-3"], + None, + ["i-1", "i-2", "i-3"], + 2, + [ + ("queue1-st-c5xlarge-1,queue1-st-c5xlarge-2", "nodeaddr-1,nodeaddr-2", None, "i-1,i-2"), + ("queue1-st-c5xlarge-3", "nodeaddr-3", None, "i-3"), + ], + ), + ("queue1-st-c5xlarge-1,queue1-st-c5xlarge-[2],queue1-st-c5xlarge-3", ["nodeaddr-1"], None, None, 2, ValueError), ( "queue1-st-c5xlarge-1,queue1-st-c5xlarge-[2],queue1-st-c5xlarge-3", None, ["nodehostname-1"], + None, + 2, + ValueError, + ), + ( + # instance_ids count does not match nodenames count + "queue1-st-c5xlarge-1,queue1-st-c5xlarge-[2],queue1-st-c5xlarge-3", + None, + None, + ["i-1"], 2, ValueError, ), @@ -305,6 +372,7 @@ def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, "queue1-st-c5xlarge-1,queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", ["nodeaddr-1", "nodeaddr-2"], "nodehostname-1,nodehostname-2,nodehostname-3", + None, 2, ValueError, ), @@ -312,14 +380,16 @@ def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, ["queue1-st-c5xlarge-1", "queue1-st-c5xlarge-2", "queue1-st-c5xlarge-3"], "nodeaddr-[1],nodeaddr-[2],nodeaddr-3", ["nodehostname-1", "nodehostname-2", "nodehostname-3"], + None, 2, [ ( "queue1-st-c5xlarge-1,queue1-st-c5xlarge-2", "nodeaddr-[1],nodeaddr-[2]", "nodehostname-1,nodehostname-2", + None, ), - ("queue1-st-c5xlarge-3", "nodeaddr-3", "nodehostname-3"), + ("queue1-st-c5xlarge-3", "nodeaddr-3", "nodehostname-3", None), ], ), ( @@ -327,6 +397,7 @@ def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, "queue1-st-c5xlarge-[1-fillerr],queue1-st-c5xlarge-[2-fillerr],queue1-st-c5xlarge-[3-filler]", "nodeaddr-1,nodeaddr-2,nodeaddr-3", ["nodehostname-1", "nodehostname-2", "nodehostname-3"], + None, 2, ValueError, ), @@ -335,19 +406,23 @@ def test_parse_nodes_info(node_info, expected_parsed_nodes_output, invalid_name, "nodename_only", "name+addr", "name+addr+hostname", + "name+addr+instanceid", "incorrect_addr1", "incorrect_hostname1", + "incorrect_instanceid", "incorrect_addr2", "mixed_format", "same_length_string", ], ) -def test_batch_node_info(nodenames, nodeaddrs, hostnames, batch_size, expected_result): +def test_batch_node_info(nodenames, nodeaddrs, hostnames, instance_ids, batch_size, expected_result): if expected_result is not ValueError: - assert_that(list(_batch_node_info(nodenames, nodeaddrs, hostnames, batch_size))).is_equal_to(expected_result) + assert_that(list(_batch_node_info(nodenames, nodeaddrs, hostnames, instance_ids, batch_size))).is_equal_to( + expected_result + ) else: try: - _batch_node_info(nodenames, nodeaddrs, hostnames, batch_size) + _batch_node_info(nodenames, nodeaddrs, hostnames, instance_ids, batch_size) except Exception as e: assert_that(e).is_instance_of(ValueError) else: @@ -481,7 +556,10 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) "batch_node_info, state, reason, raise_on_error, run_command_calls, expected_exception", [ ( - [("queue1-st-c5xlarge-1", None, None), ("queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", None, None)], + [ + ("queue1-st-c5xlarge-1", None, None, None), + ("queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", None, None, None), + ], None, None, False, @@ -503,8 +581,8 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) ), ( [ - ("queue1-st-c5xlarge-1", None, "hostname-1"), - ("queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", "addr-2,addr-3", None), + ("queue1-st-c5xlarge-1", None, "hostname-1", None), + ("queue1-st-c5xlarge-2,queue1-st-c5xlarge-3", "addr-2,addr-3", None, None), ], "power_down", None, @@ -529,8 +607,8 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) ), ( [ - ("queue1-st-c5xlarge-1", None, "hostname-1"), - ("queue1-st-c5xlarge-[3-6]", "addr-[3-6]", "hostname-[3-6]"), + ("queue1-st-c5xlarge-1", None, "hostname-1", None), + ("queue1-st-c5xlarge-[3-6]", "addr-[3-6]", "hostname-[3-6]", None), ], "down", "debugging", @@ -557,9 +635,28 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) ], None, ), + ( + # InstanceId is set in the same batched command as NodeAddr (Slurm >= 25.11.6) + [ + ("queue1-st-c5xlarge-[1-2]", "addr-1,addr-2", None, "i-111,i-222"), + ], + None, + None, + True, + [ + call( + "sudo /opt/slurm/bin/scontrol update " + "nodename=queue1-st-c5xlarge-[1-2] nodeaddr=addr-1,addr-2 instanceid=i-111,i-222", + raise_on_error=True, + timeout=60, + shell=True, + ), + ], + None, + ), ( [ - ("queue1-st-c5xlarge-1 & rm -rf /", None, "hostname-1"), + ("queue1-st-c5xlarge-1 & rm -rf /", None, "hostname-1", None), ], "down", "debugging", @@ -569,7 +666,7 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) ), ( [ - ("queue1-st-c5xlarge-1", " & rm -rf /", "hostname-1"), + ("queue1-st-c5xlarge-1", " & rm -rf /", "hostname-1", None), ], "down", "debugging", @@ -579,7 +676,7 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) ), ( [ - ("queue1-st-c5xlarge-1", None, " & rm -rf /"), + ("queue1-st-c5xlarge-1", None, " & rm -rf /", None), ], "down", "debugging", @@ -589,7 +686,17 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) ), ( [ - ("queue1-st-c5xlarge-1", None, "hostname-1"), + ("queue1-st-c5xlarge-1", None, None, " & rm -rf /"), + ], + None, + None, + None, + None, + ValueError, + ), + ( + [ + ("queue1-st-c5xlarge-1", None, "hostname-1", None), ], " & rm -rf /", "debugging", @@ -599,7 +706,7 @@ def test_set_nodes_drain(nodes, reason, reset_addrs, update_call_kwargs, mocker) ), ( [ - ("queue1-st-c5xlarge-1", None, "hostname-1"), + ("queue1-st-c5xlarge-1", None, "hostname-1", None), ], "down", " & rm -rf /", @@ -613,13 +720,81 @@ def test_update_nodes(batch_node_info, state, reason, raise_on_error, run_comman mocker.patch("common.schedulers.slurm_commands._batch_node_info", return_value=batch_node_info, autospec=True) if expected_exception is ValueError: with pytest.raises(ValueError): - update_nodes(batch_node_info, "some_nodeaddrs", "some_hostnames", state, reason, raise_on_error) + update_nodes( + batch_node_info, + "some_nodeaddrs", + "some_hostnames", + state=state, + reason=reason, + raise_on_error=raise_on_error, + ) else: cmd_mock = mocker.patch("common.schedulers.slurm_commands.run_command", autospec=True) - update_nodes(batch_node_info, "some_nodeaddrs", "some_hostnames", state, reason, raise_on_error) + update_nodes( + batch_node_info, + "some_nodeaddrs", + "some_hostnames", + state=state, + reason=reason, + raise_on_error=raise_on_error, + ) cmd_mock.assert_has_calls(run_command_calls) +@pytest.mark.parametrize( + "nodes, nodeaddrs, instance_ids, expected_run_command_calls", + [ + ( + # InstanceId and NodeAddr are set together in a single batched scontrol update command, + # distributed across the nodes in the range (requires Slurm >= 25.11.6). + ["queue1-st-c5xlarge-1", "queue1-st-c5xlarge-2"], + ["ip-1", "ip-2"], + ["i-111", "i-222"], + [ + call( + "sudo /opt/slurm/bin/scontrol update " + "nodename=queue1-st-c5xlarge-1,queue1-st-c5xlarge-2 nodeaddr=ip-1,ip-2 instanceid=i-111,i-222", + raise_on_error=True, + timeout=60, + shell=True, + ), + ], + ), + ( + # Batches larger than 100 nodes are split; each batch keeps its own nodeaddr/instanceid slice. + [f"queue1-st-c5xlarge-{i}" for i in range(1, 102)], + [f"ip-{i}" for i in range(1, 102)], + [f"i-{i}" for i in range(1, 102)], + [ + call( + "sudo /opt/slurm/bin/scontrol update " + f"nodename={','.join(f'queue1-st-c5xlarge-{i}' for i in range(1, 101))} " + f"nodeaddr={','.join(f'ip-{i}' for i in range(1, 101))} " + f"instanceid={','.join(f'i-{i}' for i in range(1, 101))}", + raise_on_error=True, + timeout=60, + shell=True, + ), + call( + "sudo /opt/slurm/bin/scontrol update " + "nodename=queue1-st-c5xlarge-101 nodeaddr=ip-101 instanceid=i-101", + raise_on_error=True, + timeout=60, + shell=True, + ), + ], + ), + ], + ids=["single_batch", "split_batches"], +) +def test_update_nodes_with_instance_ids(nodes, nodeaddrs, instance_ids, expected_run_command_calls, mocker): + """Verify InstanceId is set in the same batched scontrol update command as NodeAddr.""" + cmd_mock = mocker.patch("common.schedulers.slurm_commands.run_command", autospec=True) + update_nodes(nodes, nodeaddrs=nodeaddrs, instance_ids=instance_ids) + cmd_mock.assert_has_calls(expected_run_command_calls) + assert_that(cmd_mock.call_count).is_equal_to(len(expected_run_command_calls)) + + @pytest.mark.parametrize( "partitions, state, run_command_calls, run_command_side_effects, expected_succeeded_partitions", [ diff --git a/tests/slurm_plugin/test_clustermgtd.py b/tests/slurm_plugin/test_clustermgtd.py index 95ffb0a54..3a7b7d2bd 100644 --- a/tests/slurm_plugin/test_clustermgtd.py +++ b/tests/slurm_plugin/test_clustermgtd.py @@ -23,7 +23,7 @@ from slurm_plugin.clustermgtd import ClusterManager, ClustermgtdConfig, ComputeFleetStatus, ComputeFleetStatusManager from slurm_plugin.common import ScalingStrategy from slurm_plugin.console_logger import ConsoleLogger -from slurm_plugin.fleet_manager import EC2Instance +from slurm_plugin.fleet_manager import INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, EC2Instance from slurm_plugin.slurm_resources import ( EC2_HEALTH_STATUS_UNHEALTHY_STATES, EC2_INSTANCE_ALIVE_STATES, @@ -67,6 +67,7 @@ class TestClustermgtdConfig: # launch configs "update_node_address": True, "launch_max_batch_size": 500, + "instance_info_retrieval_timeout": INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, # terminate configs "terminate_max_batch_size": 1000, "node_replacement_timeout": 1800, @@ -107,6 +108,7 @@ class TestClustermgtdConfig: # launch configs "update_node_address": False, "launch_max_batch_size": 1, + "instance_info_retrieval_timeout": 200, # terminate configs "terminate_max_batch_size": 500, "node_replacement_timeout": 10, @@ -412,6 +414,7 @@ def test_get_ec2_instances(mocker): use_private_hostname=False, run_instances_overrides={}, create_fleet_overrides={}, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, insufficient_capacity_timeout=600, fleet_config=FLEET_CONFIG, head_node_instance_id="i-instance-id", @@ -636,6 +639,7 @@ def test_perform_health_check_actions( use_private_hostname=False, run_instances_overrides={}, create_fleet_overrides={}, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, fleet_config=FLEET_CONFIG, insufficient_capacity_timeout=600, head_node_instance_id="i-instance-id", @@ -1169,6 +1173,7 @@ def test_handle_unhealthy_static_nodes( insufficient_capacity_timeout=600, run_instances_overrides={}, create_fleet_overrides={}, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, compute_console_logging_enabled=output_enabled, compute_console_logging_max_sample_size=sample_size, compute_console_wait_time=1, @@ -1492,6 +1497,7 @@ def test_terminate_orphaned_instances( node_replacement_timeout=1800, run_instances_overrides={}, create_fleet_overrides={}, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, fleet_config=FLEET_CONFIG, head_node_instance_id="i-instance-id", ) @@ -1509,6 +1515,38 @@ def test_terminate_orphaned_instances( ) +def test_update_slurm_nodes_with_ec2_info_instance_id_matching(): + """Test that _update_slurm_nodes_with_ec2_info matches by instance ID instead of IP.""" + # Nodes with instance_id set (as would be after our change) + node1 = StaticNode( + "queue1-st-c5xlarge-1", "10.0.1.1", "queue1-st-c5xlarge-1", "IDLE+CLOUD", "queue1", instance_id="i-aaa111" + ) + node2 = DynamicNode( + "queue1-dy-c5xlarge-2", "10.0.1.2", "queue1-dy-c5xlarge-2", "IDLE+CLOUD", "queue1", instance_id="i-bbb222" + ) + # Node without instance_id (powered down, not yet assigned) + node3 = DynamicNode( + "queue1-dy-c5xlarge-3", "queue1-dy-c5xlarge-3", "queue1-dy-c5xlarge-3", "IDLE+CLOUD+POWER", "queue1" + ) + + # EC2 instances - one with full IP, one with missing IP (eventual consistency) + instance1 = EC2Instance("i-aaa111", "10.0.1.1", "hostname-1", {"10.0.1.1"}, "launch_time_1") + instance2 = EC2Instance("i-bbb222", "", "", set(), "launch_time_2") # missing IP + + nodes = [node1, node2, node3] + cluster_instances = [instance1, instance2] + + ClusterManager._update_slurm_nodes_with_ec2_info(nodes, cluster_instances) + + # Both instances should be matched by instance ID + assert_that(node1.instance).is_equal_to(instance1) + assert_that(instance1.slurm_node).is_equal_to(node1) + assert_that(node2.instance).is_equal_to(instance2) + assert_that(instance2.slurm_node).is_equal_to(node2) + # Node3 has no instance_id, should not be matched + assert_that(node3.instance).is_none() + + @pytest.mark.parametrize( "disable_cluster_management, disable_health_check, mock_cluster_instances, nodes, partitions, status, " "queue_compute_resource_nodes_map", @@ -1732,18 +1770,18 @@ def test_manage_cluster( "default.conf", [ # This node fail scheduler state check and corresponding instance will be terminated and replaced - StaticNode("queue-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+DRAIN", "queue1"), + StaticNode("queue-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+DRAIN", "queue1", instance_id="i-1"), # This node fail scheduler state check and node will be power_down - DynamicNode("queue-dy-c5xlarge-2", "ip-2", "hostname", "DOWN+CLOUD", "queue1"), + DynamicNode("queue-dy-c5xlarge-2", "ip-2", "hostname", "DOWN+CLOUD", "queue1", instance_id="i-2"), # This node is good and should not be touched by clustermgtd - DynamicNode("queue-dy-c5xlarge-3", "ip-3", "hostname", "IDLE+CLOUD", "queue1"), + DynamicNode("queue-dy-c5xlarge-3", "ip-3", "hostname", "IDLE+CLOUD", "queue1", instance_id="i-3"), # This node is in power_saving state but still has running backing instance, it should be terminated DynamicNode("queue-dy-c5xlarge-6", "ip-6", "hostname", "IDLE+CLOUD+POWER", "queue1"), # This node is in powering_down but still has no valid backing instance, no boto3 call DynamicNode("queue-dy-c5xlarge-8", "ip-8", "hostname", "IDLE+CLOUD+POWERING_DOWN", "queue1"), ], [ - StaticNode("queue-st-c5xlarge-4", "ip-4", "hostname", "IDLE+CLOUD", "queue2"), + StaticNode("queue-st-c5xlarge-4", "ip-4", "hostname", "IDLE+CLOUD", "queue2", instance_id="i-4"), DynamicNode("queue-dy-c5xlarge-5", "ip-5", "hostname", "DOWN+CLOUD", "queue2"), ], [ @@ -1944,6 +1982,7 @@ def test_manage_cluster( "DOWN+CLOUD", "queue1", slurmdstarttime=datetime(2020, 1, 1, tzinfo=timezone.utc), + instance_id="i-1", ), DynamicNode( "queue-dy-c5xlarge-2", @@ -1952,6 +1991,7 @@ def test_manage_cluster( "DOWN+CLOUD", "queue1", slurmdstarttime=datetime(2020, 1, 1, tzinfo=timezone.utc), + instance_id="i-2", ), DynamicNode( "queue-dy-c5xlarge-3", @@ -1960,6 +2000,7 @@ def test_manage_cluster( "IDLE+CLOUD", "queue1", slurmdstarttime=datetime(2020, 1, 1, tzinfo=timezone.utc), + instance_id="i-3", ), ], [ @@ -1970,6 +2011,7 @@ def test_manage_cluster( "IDLE+CLOUD", "queue2", slurmdstarttime=datetime(2020, 1, 1, tzinfo=timezone.utc), + instance_id="i-4", ), DynamicNode( "queue-dy-c5xlarge-5", @@ -2389,6 +2431,7 @@ def test_handle_successfully_launched_nodes( terminate_down_nodes=True, run_instances_overrides={}, create_fleet_overrides={}, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, fleet_config=FLEET_CONFIG, head_node_instance_id="i-instance-id", ec2_instance_missing_max_count=0, diff --git a/tests/slurm_plugin/test_clustermgtd/TestClustermgtdConfig/test_config_parsing/all_options.conf b/tests/slurm_plugin/test_clustermgtd/TestClustermgtdConfig/test_config_parsing/all_options.conf index 49c468071..4cea0d70c 100644 --- a/tests/slurm_plugin/test_clustermgtd/TestClustermgtdConfig/test_config_parsing/all_options.conf +++ b/tests/slurm_plugin/test_clustermgtd/TestClustermgtdConfig/test_config_parsing/all_options.conf @@ -9,6 +9,7 @@ proxy = https://fake.proxy logging_config = /my/logging/config update_node_address = false launch_max_batch_size = 1 +instance_info_retrieval_timeout = 200 terminate_max_batch_size = 500 node_replacement_timeout = 10 terminate_drain_nodes = false diff --git a/tests/slurm_plugin/test_fleet_manager.py b/tests/slurm_plugin/test_fleet_manager.py index 9db222bd3..988f3f21e 100644 --- a/tests/slurm_plugin/test_fleet_manager.py +++ b/tests/slurm_plugin/test_fleet_manager.py @@ -17,6 +17,8 @@ from assertpy import assert_that from botocore.exceptions import ClientError from slurm_plugin.fleet_manager import ( + INSTANCE_INFO_RETRIEVAL_MAX_BACKOFF, + INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, Ec2CreateFleetManager, EC2Instance, Ec2RunInstancesManager, @@ -27,6 +29,19 @@ from tests.common import FLEET_CONFIG, MockedBoto3Request +def _expected_describe_attempts(timeout): + """Compute DescribeInstances attempts for a never-converging instance, mirroring _get_instances_info.""" + attempts = 0 + elapsed_backoff = 0 + while True: + attempts += 1 + base_backoff = min(0.3 * 2**attempts, INSTANCE_INFO_RETRIEVAL_MAX_BACKOFF) + if elapsed_backoff + base_backoff > timeout: + break + elapsed_backoff += base_backoff + return attempts + + @pytest.fixture() def boto3_stubber_path(): # we need to set the region in the environment because the Boto3ClientFactory requires it. @@ -1129,13 +1144,88 @@ def test_get_instances_info( # Note: some tests cases are covered by test_launc mocker.patch("time.sleep") boto3_stubber("ec2", mocked_boto3_request) # run test + # A 10s retrieval timeout bounds the never-converging cases to exactly 5 DescribeInstances attempts, + # matching the number of mocked responses, while leaving room for the converging cases to succeed. fleet_manager = FleetManagerFactory.get_manager( - "hit", "region", "boto3_config", FLEET_CONFIG, "queue2", "fleet-ondemand", True, {}, {} + "hit", + "region", + "boto3_config", + FLEET_CONFIG, + "queue2", + "fleet-ondemand", + True, + {}, + {}, + instance_info_retrieval_timeout=10, ) complete_instances, partial_instance_ids = fleet_manager._get_instances_info(instance_ids) assert_that(expected_result).is_equal_to((complete_instances, partial_instance_ids)) + def test_instance_info_retrieval_timeout_default(self): + # Default timeout is wired through the factory into the CreateFleet manager + fleet_manager = FleetManagerFactory.get_manager( + "hit", "region", "boto3_config", FLEET_CONFIG, "queue2", "fleet-ondemand", True, {}, {} + ) + assert_that(fleet_manager._instance_info_retrieval_timeout).is_equal_to(INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT) + + def test_instance_info_retrieval_timeout_override(self): + # A custom timeout is propagated through the factory into the CreateFleet manager + fleet_manager = FleetManagerFactory.get_manager( + "hit", + "region", + "boto3_config", + FLEET_CONFIG, + "queue2", + "fleet-ondemand", + True, + {}, + {}, + instance_info_retrieval_timeout=240, + ) + assert_that(fleet_manager._instance_info_retrieval_timeout).is_equal_to(240) + + @pytest.mark.parametrize( + ("instance_info_retrieval_timeout", "expected_describe_calls"), + [ + # never-converging instance -> attempts bounded by the timeout budget (capped per-attempt backoff) + (10, _expected_describe_attempts(10)), + (1, _expected_describe_attempts(1)), + ( + INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, + _expected_describe_attempts(INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT), + ), + ], + ids=["timeout_10s", "timeout_1s", "timeout_default"], + ) + def test_get_instances_info_retry_count_scales_with_timeout( + self, mocker, instance_info_retrieval_timeout, expected_describe_calls + ): + # Patch sleep so the test runs instantly and stub the EC2 describe to always return incomplete info. + mocker.patch("time.sleep") + fleet_manager = FleetManagerFactory.get_manager( + "hit", + "region", + "boto3_config", + FLEET_CONFIG, + "queue2", + "fleet-ondemand", + True, + {}, + {}, + instance_info_retrieval_timeout=instance_info_retrieval_timeout, + ) + # Always-incomplete response keeps the instance in partial_instance_ids, forcing retries until timeout. + retrieve_mock = mocker.patch.object( + fleet_manager, "_retrieve_instances_info_from_ec2", return_value=([], ["i-12345"]) + ) + + instances, partial_instance_ids = fleet_manager._get_instances_info(["i-12345"]) + + assert_that(instances).is_empty() + assert_that(partial_instance_ids).is_equal_to(["i-12345"]) + assert_that(retrieve_mock.call_count).is_equal_to(expected_describe_calls) + @pytest.mark.parametrize( ("instance_ids", "mocked_boto3_request", "expected_result"), [ diff --git a/tests/slurm_plugin/test_instance_manager.py b/tests/slurm_plugin/test_instance_manager.py index 28f809b0e..d7ae88348 100644 --- a/tests/slurm_plugin/test_instance_manager.py +++ b/tests/slurm_plugin/test_instance_manager.py @@ -907,6 +907,7 @@ def get_unhealthy_cluster_instance_status( generate_error=False, ), [ + EC2Instance("i-1", "", "", set(), datetime(2020, 1, 1, tzinfo=timezone.utc)), EC2Instance("i-2", "ip-2", "hostname", {"ip-2"}, datetime(2020, 1, 1, tzinfo=timezone.utc)), ], False, @@ -3106,7 +3107,7 @@ def test_assign_instances_to_nodes( [], False, None, - call(["queue1-st-c5xlarge-1"], nodeaddrs=[], nodehostnames=None), + call(["queue1-st-c5xlarge-1"], nodeaddrs=[], nodehostnames=None, instance_ids=[]), None, ), ( @@ -3114,7 +3115,7 @@ def test_assign_instances_to_nodes( [EC2Instance("id-1", "ip-1", "hostname-1", {"ip-1"}, "some_launch_time")], False, None, - call(["queue1-st-c5xlarge-1"], nodeaddrs=["ip-1"], nodehostnames=None), + call(["queue1-st-c5xlarge-1"], nodeaddrs=["ip-1"], nodehostnames=None, instance_ids=["id-1"]), None, ), ( @@ -3122,7 +3123,7 @@ def test_assign_instances_to_nodes( [EC2Instance("id-1", "ip-1", "hostname-1", {"ip-1"}, "some_launch_time")], True, None, - call(["queue1-st-c5xlarge-1"], nodeaddrs=["ip-1"], nodehostnames=["hostname-1"]), + call(["queue1-st-c5xlarge-1"], nodeaddrs=["ip-1"], nodehostnames=["hostname-1"], instance_ids=["id-1"]), None, ), ( @@ -3130,7 +3131,7 @@ def test_assign_instances_to_nodes( [EC2Instance("id-1", "ip-1", "hostname-1", {"ip-1"}, "some_launch_time")], True, subprocess.CalledProcessError(1, "command"), - call(["queue1-st-c5xlarge-1"], nodeaddrs=["ip-1"], nodehostnames=["hostname-1"]), + call(["queue1-st-c5xlarge-1"], nodeaddrs=["ip-1"], nodehostnames=["hostname-1"], instance_ids=["id-1"]), NodeAddrUpdateError(), ), ( @@ -3141,7 +3142,12 @@ def test_assign_instances_to_nodes( ], False, None, - call(["queue1-st-c5xlarge-1", "queue1-st-c5xlarge-2"], nodeaddrs=["ip-1", "ip-2"], nodehostnames=None), + call( + ["queue1-st-c5xlarge-1", "queue1-st-c5xlarge-2"], + nodeaddrs=["ip-1", "ip-2"], + nodehostnames=None, + instance_ids=["id-1", "id-2"], + ), None, ), ], diff --git a/tests/slurm_plugin/test_resume.py b/tests/slurm_plugin/test_resume.py index 5601e864d..9370ea270 100644 --- a/tests/slurm_plugin/test_resume.py +++ b/tests/slurm_plugin/test_resume.py @@ -20,7 +20,7 @@ import pytest import slurm_plugin from assertpy import assert_that -from slurm_plugin.fleet_manager import EC2Instance +from slurm_plugin.fleet_manager import INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, EC2Instance from slurm_plugin.resume import SlurmResumeConfig, _get_slurm_resume, _handle_failed_nodes, _resume from src.slurm_plugin.common import ScalingStrategy @@ -57,6 +57,7 @@ def boto3_stubber_path(): "job_level_scaling": True, "assign_node_max_batch_size": 500, "terminate_max_batch_size": 1000, + "instance_info_retrieval_timeout": INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, }, ), ( @@ -77,6 +78,7 @@ def boto3_stubber_path(): "job_level_scaling": False, "assign_node_max_batch_size": 400, "terminate_max_batch_size": 600, + "instance_info_retrieval_timeout": 200, }, ), ], @@ -280,7 +282,7 @@ def test_resume_config(config_file, expected_attributes, test_datadir, mocker): "ServiceUnavailable": {"queue1-st-c5xlarge-2"}, "LimitedInstanceCapacity": {"queue1-dy-c5xlarge-2", "queue1-st-c5xlarge-1"}, }, - [call(["queue1-dy-c5xlarge-1"], nodeaddrs=["ip.1.0.0.1"], nodehostnames=None)], + [call(["queue1-dy-c5xlarge-1"], nodeaddrs=["ip.1.0.0.1"], nodehostnames=None, instance_ids=["i-11111"])], dict( zip( ["queue1-dy-c5xlarge-1"], @@ -332,7 +334,7 @@ def test_resume_config(config_file, expected_attributes, test_datadir, mocker): client_error("InsufficientReservedInstanceCapacity"), ], {"InsufficientReservedInstanceCapacity": {"queue1-st-c5xlarge-2"}}, - [call(["queue1-dy-c5xlarge-1"], nodeaddrs=["ip.1.0.0.1"], nodehostnames=None)], + [call(["queue1-dy-c5xlarge-1"], nodeaddrs=["ip.1.0.0.1"], nodehostnames=None, instance_ids=["i-11111"])], dict( zip( ["queue1-dy-c5xlarge-1"], @@ -406,6 +408,7 @@ def test_resume_launch( job_level_scaling=job_level_scaling, assign_node_max_batch_size=500, terminate_max_batch_size=1000, + instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT, ) mocker.patch("slurm_plugin.resume.is_clustermgtd_heartbeat_valid", autospec=True, return_value=is_heartbeat_valid) mock_handle_failed_nodes = mocker.patch("slurm_plugin.resume._handle_failed_nodes", autospec=True) diff --git a/tests/slurm_plugin/test_resume/test_resume_config/all_options.conf b/tests/slurm_plugin/test_resume/test_resume_config/all_options.conf index e11fde544..d6da4b972 100644 --- a/tests/slurm_plugin/test_resume/test_resume_config/all_options.conf +++ b/tests/slurm_plugin/test_resume/test_resume_config/all_options.conf @@ -19,3 +19,4 @@ clustermgtd_timeout = 5 job_level_scaling = False assign_node_max_batch_size = 400 terminate_max_batch_size = 600 +instance_info_retrieval_timeout = 200