Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This file is used to list changes made in each version of the aws-parallelcluste

**BUG FIXES**
- Fix clustermgtd failing to detect compute node bootstrap timeouts, which prevented the cluster from entering protected mode.
- Fix an issue where compute nodes were incorrectly replaced when launching a large number of nodes, caused by EC2 API eventual consistency.

3.15.0
------
Expand Down
53 changes: 33 additions & 20 deletions src/common/schedulers/slurm_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
# Shell pipeline used to post-process scontrol node output:
# - awk treats blank-line-separated blocks as records and re-emits each one followed by a "######" delimiter;
# - grep -oP keeps only the fields listed below (plus the delimiter). The negative lookbehind on "State="
#   skips "NextState=", and "Reason=.*" captures the rest of the line because reasons may contain spaces.
# NOTE(review): the doubled braces ({{ }}) suggest this string is later passed through str.format — confirm.
SCONTROL_OUTPUT_AWK_PARSER = (
    'awk \'BEGIN{{RS="\\n\\n" ; ORS="######\\n";}} {{print}}\' | '
    + "grep -oP '^(NodeName=\\S+)|(NodeAddr=\\S+)|(NodeHostName=\\S+)|(?<!Next)(State=\\S+)|"
    + "(Partitions=\\S+)|(SlurmdStartTime=\\S+)|(LastBusyTime=\\S+)|(ReservationName=\\S+)"
    + "|(InstanceId=\\S+)|(Reason=.*)|(######)'"
)

# Set default timeouts for running different slurm commands.
Expand Down Expand Up @@ -129,6 +130,7 @@ def update_nodes(
nodes,
nodeaddrs=None,
nodehostnames=None,
instance_ids=None,
state=None,
reason=None,
raise_on_error=True,
Expand All @@ -150,8 +152,11 @@ def update_nodes(
For example, suppose updating the state causes a failure but updating nodeaddr does not:
if we run scontrol update state=fail_state nodeaddr=good_addr nodename=name,
the scontrol command will fail but nodeaddr will still be updated to good_addr.

InstanceId is set in the same batched command as NodeAddr so that the node and its backing
instance are associated atomically.
"""
batched_node_info = _batch_node_info(nodes, nodeaddrs, nodehostnames, batch_size=100)
batched_node_info = _batch_node_info(nodes, nodeaddrs, nodehostnames, instance_ids, batch_size=100)

update_cmd = f"{SCONTROL} update"
if state:
Expand All @@ -160,7 +165,7 @@ def update_nodes(
if reason:
validate_subprocess_argument(reason)
update_cmd += f' reason="{reason}"'
for nodenames, addrs, hostnames in batched_node_info:
for nodenames, addrs, hostnames, ids in batched_node_info:
validate_subprocess_argument(nodenames)
node_info = f"nodename={nodenames}"
if addrs:
Expand All @@ -169,6 +174,9 @@ def update_nodes(
if hostnames:
validate_subprocess_argument(hostnames)
node_info += f" nodehostname={hostnames}"
if ids:
validate_subprocess_argument(ids)
node_info += f" instanceid={ids}"
# It's safe to use the function affected by B604 since the command is fully built in this code
run_command( # nosec B604
f"{update_cmd} {node_info}", raise_on_error=raise_on_error, timeout=command_timeout, shell=True
Expand Down Expand Up @@ -223,29 +231,30 @@ def _batch_attribute(attribute, batch_size, expected_length=None):
return [",".join(batch) for batch in grouper(attribute, batch_size)]


def _batch_node_info(nodenames, nodeaddrs, nodehostnames, batch_size):
"""Group nodename, nodeaddrs, nodehostnames into batches."""
def _batch_optional_attribute(attribute, attribute_label, nodenames, default_batch, batch_size):
    """
    Batch an optional per-node attribute so that it lines up with the nodename entries.

    :param attribute: per-node values to batch; when falsy, default_batch is returned as is
    :param attribute_label: name of the attribute, used only in the error log
    :param nodenames: nodename entries whose count the attribute must match
    :param default_batch: value returned when the attribute is not provided
    :param batch_size: batch size forwarded to _batch_attribute
    :raises ValueError: re-raised from _batch_attribute after the mismatch has been logged
    :return: the batched attribute, or default_batch when the attribute is not provided
    """
    if attribute:
        try:
            batched = _batch_attribute(attribute, batch_size, expected_length=len(nodenames))
        except ValueError:
            # Log both sides of the mismatch before propagating, so the failing update can be diagnosed.
            log.error("Nodename %s and %s %s contain different number of entries", nodenames, attribute_label, attribute)
            raise
        return batched
    return default_batch


def _batch_node_info(nodenames, nodeaddrs, nodehostnames, instance_ids, batch_size):
    """
    Group nodenames, nodeaddrs, nodehostnames and instance_ids into aligned batches.

    :param nodenames: list of nodename entries, or a single comma-separated string of them
    :param nodeaddrs: optional NodeAddr values, one per nodename entry
    :param nodehostnames: optional NodeHostName values, one per nodename entry
    :param instance_ids: optional InstanceId values, one per nodename entry
    :param batch_size: batch size forwarded to _batch_attribute
    :raises ValueError: propagated when an optional attribute's entry count does not match the nodenames
    :return: iterator of (nodenames, nodeaddrs, nodehostnames, instance_ids) tuples, one per batch;
        an attribute that was not provided is None in every tuple
    """
    if isinstance(nodenames, str):
        # Only split on , if there is ] before
        # For ex. "node-[1,3,4-5],node-[20,30]" should split into ["node-[1,3,4-5]","node-[20,30]"]
        nodenames = re.split("(?<=]),", nodenames)
    nodename_batch = _batch_attribute(nodenames, batch_size)
    # Placeholder used for every attribute the caller did not provide; it is only read, never mutated.
    default_batch = [None] * len(nodename_batch)
    nodeaddrs_batch = _batch_optional_attribute(nodeaddrs, "NodeAddr", nodenames, default_batch, batch_size)
    nodehostnames_batch = _batch_optional_attribute(nodehostnames, "NodeHostName", nodenames, default_batch, batch_size)
    instance_ids_batch = _batch_optional_attribute(instance_ids, "InstanceId", nodenames, default_batch, batch_size)

    return zip(nodename_batch, nodeaddrs_batch, nodehostnames_batch, instance_ids_batch)


def set_nodes_down(nodes, reason):
Expand Down Expand Up @@ -433,6 +442,7 @@ def _parse_nodes_info(slurm_node_info: str) -> List[SlurmNode]:
"SlurmdStartTime": "slurmdstarttime",
"LastBusyTime": "lastbusytime",
"ReservationName": "reservation_name",
"InstanceId": "instance_id",
}

date_fields = ["SlurmdStartTime", "LastBusyTime"]
Expand All @@ -449,6 +459,9 @@ def _parse_nodes_info(slurm_node_info: str) -> List[SlurmNode]:
value = datetime.strptime(value, "%Y-%m-%dT%H:%M:%S").astimezone(tz=timezone.utc)
else:
value = None
elif key == "InstanceId" and value == "(null)":
# Slurm reports an unset InstanceId as "(null)"
value = None
kwargs[map_slurm_key_to_arg[key]] = value
if lines:
try:
Expand Down
21 changes: 14 additions & 7 deletions src/slurm_plugin/clustermgtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.console_logger import ConsoleLogger
from slurm_plugin.fleet_manager import INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT
from slurm_plugin.instance_manager import InstanceManager
from slurm_plugin.slurm_resources import (
CONFIG_FILE_DIR,
Expand Down Expand Up @@ -147,6 +148,7 @@ class ClustermgtdConfig:
"run_instances_overrides": "/opt/slurm/etc/pcluster/run_instances_overrides.json",
"create_fleet_overrides": "/opt/slurm/etc/pcluster/create_fleet_overrides.json",
"fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json",
"instance_info_retrieval_timeout": INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT,
# Terminate configs
"terminate_max_batch_size": 1000,
# Timeout to wait for node initialization, should be the same as ResumeTimeout
Expand Down Expand Up @@ -257,6 +259,11 @@ def _get_launch_config(self, config):
"clustermgtd", "create_fleet_overrides", fallback=self.DEFAULTS.get("create_fleet_overrides")
)
self.create_fleet_overrides = read_json(create_fleet_overrides_file, default={})
self.instance_info_retrieval_timeout = config.getint(
"clustermgtd",
"instance_info_retrieval_timeout",
fallback=self.DEFAULTS.get("instance_info_retrieval_timeout"),
)

def _get_health_check_config(self, config):
self.disable_ec2_health_check = config.getboolean(
Expand Down Expand Up @@ -452,6 +459,7 @@ def _initialize_instance_manager(config):
run_instances_overrides=config.run_instances_overrides,
create_fleet_overrides=config.create_fleet_overrides,
fleet_config=config.fleet_config,
instance_info_retrieval_timeout=config.instance_info_retrieval_timeout,
)

def _initialize_executor(self, config):
Expand Down Expand Up @@ -1149,15 +1157,14 @@ def _parse_scheduler_nodes_data(nodes):

@staticmethod
def _update_slurm_nodes_with_ec2_info(nodes, cluster_instances):
"""Associate EC2 instances with Slurm nodes by matching on instance ID."""
if cluster_instances:
ip_to_slurm_node_map = {node.nodeaddr: node for node in nodes}
instance_id_to_slurm_node_map = {node.instance_id: node for node in nodes if node.instance_id}
for instance in cluster_instances:
for private_ip in instance.all_private_ips:
if private_ip in ip_to_slurm_node_map:
slurm_node = ip_to_slurm_node_map.get(private_ip)
slurm_node.instance = instance
instance.slurm_node = slurm_node
break
if instance.id in instance_id_to_slurm_node_map:
slurm_node = instance_id_to_slurm_node_map[instance.id]
slurm_node.instance = instance
instance.slurm_node = slurm_node

@staticmethod
def get_instance_id_to_active_node_map(partitions: List[SlurmPartition]) -> Dict:
Expand Down
35 changes: 31 additions & 4 deletions src/slurm_plugin/fleet_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@

logger = logging.getLogger(__name__)

# Retry settings for DescribeInstances after a CreateFleet launch, to tolerate EC2 API eventual consistency.
# See https://docs.aws.amazon.com/ec2/latest/devguide/eventual-consistency.html
# Default total time budget, in seconds, for the retries; used as the fallback when no
# instance_info_retrieval_timeout is configured.
INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT = 90
# Upper bound, in seconds, on the exponential backoff of a single retry attempt (before jitter is added).
INSTANCE_INFO_RETRIEVAL_MAX_BACKOFF = 30


class EC2Instance:
def __init__(self, id, private_ip, hostname, all_private_ips, launch_time):
Expand Down Expand Up @@ -94,6 +100,7 @@ def get_manager(
all_or_nothing,
run_instances_overrides,
create_fleet_overrides,
instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT,
):
try:
queue_config = fleet_config[queue]
Expand All @@ -120,6 +127,7 @@ def get_manager(
compute_resource_config,
all_or_nothing,
create_fleet_overrides.get(queue, {}).get(compute_resource, {}),
instance_info_retrieval_timeout=instance_info_retrieval_timeout,
)
elif api == "run-instances":
return Ec2RunInstancesManager(
Expand Down Expand Up @@ -272,6 +280,7 @@ def __init__(
compute_resource_config,
all_or_nothing,
launch_overrides,
instance_info_retrieval_timeout=INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT,
):
super().__init__(
cluster_name,
Expand All @@ -283,6 +292,7 @@ def __init__(
all_or_nothing,
launch_overrides,
)
self._instance_info_retrieval_timeout = instance_info_retrieval_timeout

def _evaluate_template_overrides(self) -> list:
"""Build and return the list of Launch Template Overrides to be applied in the CreateFleet request.
Expand Down Expand Up @@ -436,22 +446,39 @@ def _get_instances_info(self, instance_ids: list):
"""
Describe instances to retrieve info not available from create-fleet response.

Right after a CreateFleet launch, DescribeInstances may return instances with missing info
(e.g. PrivateIpAddress) or even InvalidInstanceID.NotFound due to EC2 API eventual consistency.
Retry with exponential backoff (capped per attempt) until the configured total timeout is reached,
as recommended at https://docs.aws.amazon.com/ec2/latest/devguide/eventual-consistency.html

:raises ClientError in case of boto3 failure
:return list of instances with complete information and list of IDs for instances with incomplete information
"""
instances = []
partial_instance_ids = instance_ids

retries = 5
attempt_count = 0
# Budget is tracked against the un-jittered backoff; jitter is added only to the actual sleep.
elapsed_backoff = 0
# Wait for instances to be available in EC2
time.sleep(0.1)
while attempt_count < retries and partial_instance_ids:
while partial_instance_ids:
complete_instances, partial_instance_ids = self._retrieve_instances_info_from_ec2(partial_instance_ids)
instances.extend(complete_instances)
if not partial_instance_ids:
break
base_backoff = min(0.3 * 2 ** (attempt_count + 1), INSTANCE_INFO_RETRIEVAL_MAX_BACKOFF)
if elapsed_backoff + base_backoff > self._instance_info_retrieval_timeout:
logger.warning(
"Unable to retrieve complete info for instances %s within %s seconds, giving up after %s attempts.",
print_with_count(partial_instance_ids),
self._instance_info_retrieval_timeout,
attempt_count + 1,
)
break
elapsed_backoff += base_backoff
attempt_count += 1
if attempt_count < retries:
time.sleep(0.3 * 2**attempt_count + (secrets.randbelow(500) / 1000))
time.sleep(base_backoff + (secrets.randbelow(500) / 1000))

return instances, partial_instance_ids

Expand Down
20 changes: 17 additions & 3 deletions src/slurm_plugin/instance_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from common.schedulers.slurm_commands import get_nodes_info, update_nodes
from common.utils import grouper, setup_logging_filter
from slurm_plugin.common import ComputeInstanceDescriptor, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.fleet_manager import EC2Instance, FleetManagerFactory
from slurm_plugin.fleet_manager import (
INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT,
EC2Instance,
FleetManagerFactory,
)
from slurm_plugin.slurm_resources import (
EC2_HEALTH_STATUS_UNHEALTHY_STATES,
EC2_INSTANCE_ALIVE_STATES,
Expand Down Expand Up @@ -85,6 +89,7 @@ def __init__(
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
job_level_scaling: bool = False,
instance_info_retrieval_timeout: int = INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT,
):
"""Initialize InstanceLauncher with required attributes."""
self._region = region
Expand All @@ -101,6 +106,7 @@ def __init__(
self._fleet_config = fleet_config
self._run_instances_overrides = run_instances_overrides or {}
self._create_fleet_overrides = create_fleet_overrides or {}
self._instance_info_retrieval_timeout = instance_info_retrieval_timeout
self._boto3_resource_factory = lambda resource_name: boto3.session.Session().resource(
resource_name, region_name=region, config=boto3_config
)
Expand Down Expand Up @@ -262,7 +268,8 @@ def get_cluster_instances(self, include_head_node=False, alive_states_only=True)
"""
Get instances that are associated with the cluster.

Instances without all the info set are ignored and not returned
Instances with missing info (e.g. PrivateIpAddress due to EC2 eventual consistency) are kept with
empty IP fields so that clustermgtd can still match them to Slurm nodes by instance ID.
"""
ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config)
paginator = ec2_client.get_paginator("describe_instances")
Expand Down Expand Up @@ -290,12 +297,17 @@ def get_cluster_instances(self, include_head_node=False, alive_states_only=True)
)
)
except Exception as e:
# Keep the instance with empty IP info so it can still be matched by instance ID in clustermgtd.
logger.warning(
"Ignoring instance %s because not all EC2 info are available, exception: %s, message: %s",
"Incomplete EC2 info for instance %s, keeping it for instance-ID matching, "
"exception: %s, message: %s",
instance_info["InstanceId"],
type(e).__name__,
e,
)
instances.append(
EC2Instance(instance_info["InstanceId"], "", "", set(), instance_info.get("LaunchTime"))
)

return instances

Expand Down Expand Up @@ -1008,6 +1020,7 @@ def _get_fleet_manager(self, all_or_nothing_batch, compute_resource, queue):
all_or_nothing=all_or_nothing_batch,
run_instances_overrides=self._run_instances_overrides,
create_fleet_overrides=self._create_fleet_overrides,
instance_info_retrieval_timeout=self._instance_info_retrieval_timeout,
)
return fleet_manager

Expand Down Expand Up @@ -1077,6 +1090,7 @@ def _update_slurm_node_addrs(self, slurm_nodes: List[str], launched_instances: L
slurm_nodes,
nodeaddrs=[instance.private_ip for instance in launched_instances],
nodehostnames=node_hostnames,
instance_ids=[instance.id for instance in launched_instances],
)
logger.info(
"Nodes are now configured with instances %s",
Expand Down
8 changes: 8 additions & 0 deletions src/slurm_plugin/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from common.utils import read_json
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import ScalingStrategy, is_clustermgtd_heartbeat_valid, print_with_count
from slurm_plugin.fleet_manager import INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT
from slurm_plugin.instance_manager import InstanceManager
from slurm_plugin.slurm_resources import CONFIG_FILE_DIR

Expand All @@ -47,6 +48,7 @@ class SlurmResumeConfig:
"fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json",
"job_level_scaling": True,
"scaling_strategy": "all-or-nothing",
"instance_info_retrieval_timeout": INSTANCE_INFO_RETRIEVAL_TIMEOUT_DEFAULT,
}

def __init__(self, config_file_path):
Expand Down Expand Up @@ -96,6 +98,11 @@ def _get_config(self, config_file_path):
self.job_level_scaling = config.getboolean(
"slurm_resume", "job_level_scaling", fallback=self.DEFAULTS.get("job_level_scaling")
)
self.instance_info_retrieval_timeout = config.getint(
"slurm_resume",
"instance_info_retrieval_timeout",
fallback=self.DEFAULTS.get("instance_info_retrieval_timeout"),
)
fleet_config_file = config.get(
"slurm_resume", "fleet_config_file", fallback=self.DEFAULTS.get("fleet_config_file")
)
Expand Down Expand Up @@ -206,6 +213,7 @@ def _resume(arg_nodes, resume_config, slurm_resume):
run_instances_overrides=resume_config.run_instances_overrides,
create_fleet_overrides=resume_config.create_fleet_overrides,
job_level_scaling=resume_config.job_level_scaling,
instance_info_retrieval_timeout=resume_config.instance_info_retrieval_timeout,
)
instance_manager.add_instances(
slurm_resume=slurm_resume,
Expand Down
Loading
Loading