Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ async def _get_cluster_placement_context(
session=session,
fleet_id=instance_model.fleet_id,
)
for placement_group_model in placement_group_models:
_populate_placement_group_relations(
placement_group_model=placement_group_model,
instance_model=instance_model,
)
placement_group_model = None
if not cluster_context.is_current_instance_master:
# Non-master instances only reuse the placement group chosen by the
Expand All @@ -307,11 +312,6 @@ async def _get_cluster_placement_context(
placement_group_models=placement_group_models,
fleet_id=instance_model.fleet_id,
)
if placement_group_model is not None:
_populate_current_master_placement_group_relations(
placement_group_model=placement_group_model,
instance_model=instance_model,
)
return placement_group_models, placement_group_model


Expand Down Expand Up @@ -358,13 +358,13 @@ def _get_current_master_placement_group_model(
return placement_group_models[0]


def _populate_current_master_placement_group_relations(
def _populate_placement_group_relations(
placement_group_model: PlacementGroupModel,
instance_model: InstanceModel,
) -> None:
# Placement groups are loaded in a separate session from the instance worker.
# Reattach the already-known project/fleet objects so later detached access
# can still build a PlacementGroup value object without lazy loading.
# can build a PlacementGroup value object without lazy loading.
set_committed_value(placement_group_model, "project", instance_model.project)
if instance_model.fleet is not None:
set_committed_value(placement_group_model, "fleet", instance_model.fleet)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.runs import JobProvisioningData
from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker
from dstack._internal.server.models import PlacementGroupModel
from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel
from dstack._internal.server.testing.common import (
ComputeMockSpec,
create_fleet,
Expand All @@ -42,7 +42,9 @@
)


async def _set_current_master_instance(session: AsyncSession, fleet, instance) -> None:
async def _set_current_master_instance(
session: AsyncSession, fleet: FleetModel, instance: InstanceModel
) -> None:
fleet.current_master_instance_id = None if instance is None else instance.id
await session.commit()

Expand Down Expand Up @@ -812,6 +814,63 @@ def create_instance_method(
to_be_deleted_count = sum(pg.fleet_deleted for pg in placement_groups)
assert to_be_deleted_count == 2

async def test_master_reuses_existing_placement_group(
self,
test_db,
session: AsyncSession,
worker: InstanceWorker,
) -> None:
# Regression test for https://github.com/dstackai/dstack/issues/3904
project = await create_project(session=session)
fleet = await create_fleet(
session,
project,
spec=get_fleet_spec(
conf=get_fleet_configuration(
placement=InstanceGroupPlacement.CLUSTER,
nodes=FleetNodesSpec(min=1, target=1, max=1),
)
),
)
master_instance = await create_instance(
session=session,
project=project,
fleet=fleet,
status=InstanceStatus.PENDING,
offer=None,
job_provisioning_data=None,
)
await _set_current_master_instance(session, fleet, master_instance)
preexisting_pg = await create_placement_group(
session=session,
project=project,
fleet=fleet,
)

backend_mock = Mock()
backend_mock.TYPE = BackendType.AWS
backend_mock.compute.return_value = Mock(spec=ComputeMockSpec)
backend_mock.compute.return_value.get_offers.return_value = [
get_instance_offer_with_availability()
]
backend_mock.compute.return_value.is_suitable_placement_group.return_value = True
backend_mock.compute.return_value.create_instance.return_value = (
get_job_provisioning_data()
)
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
m.return_value = [backend_mock]
await process_instance(session, worker, master_instance)

await session.refresh(master_instance)
assert master_instance.status == InstanceStatus.PROVISIONING
assert backend_mock.compute.return_value.create_placement_group.call_count == 0
create_call = backend_mock.compute.return_value.create_instance.call_args
assert create_call is not None
assert create_call.args[2] is not None
assert create_call.args[2].name == preexisting_pg.name
placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all()
assert len(placement_groups) == 1

@pytest.mark.parametrize("err", [NoCapacityError(), RuntimeError()])
async def test_handles_create_placement_group_errors(
self,
Expand Down
Loading