Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2774,7 +2774,7 @@ def _deploy_core_endpoint(self, **kwargs):
self._deserializer = deserializer

data_capture_config = kwargs.get("data_capture_config", None)
volume_size = kwargs.get("volume_size", None)
volume_size = kwargs.get("volume_size", getattr(self, "volume_size", None))
inference_recommendation_id = kwargs.get("inference_recommendation_id", None)
explainer_config = kwargs.get("explainer_config", None)
endpoint_logging = kwargs.get("endpoint_logging", False)
Expand Down Expand Up @@ -3620,6 +3620,7 @@ def from_jumpstart_config(
"container_startup_health_check_timeout"
)
mb_instance.inference_ami_version = deploy_kwargs.get("inference_ami_version")
mb_instance.volume_size = deploy_kwargs.get("volume_size")

# Apply network isolation from JumpStart model spec if not set by user via network param
if not mb_instance._enable_network_isolation and deploy_kwargs.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,53 @@ def test_jumpstart_build_enables_network_isolation():
finally:
core_model.delete()
logger.info("Model deleted.")


# JumpStart model id and instance type used by the volume_size integration test below.
VOLUME_SIZE_MODEL_ID = "meta-textgenerationneuron-llama-2-7b"
VOLUME_SIZE_INSTANCE_TYPE = "ml.inf2.xlarge"


@pytest.mark.slow_test
def test_jumpstart_build_sets_volume_size():
    """Verify that ModelBuilder carries a volume_size for a JumpStart model.

    The test exercises two stages for VOLUME_SIZE_MODEL_ID on
    VOLUME_SIZE_INSTANCE_TYPE:

    1. ``ModelBuilder.from_jumpstart_config()`` must leave a non-None
       ``volume_size`` attribute on the builder.
    2. After ``build()``, the attribute must still be non-None and be at
       least 256.

    The model created by ``build()`` is always deleted, even when one of the
    post-build assertions fails.
    """
    logger.info("Starting JumpStart volume_size integration test...")

    # Stage 1: from_jumpstart_config path.
    instance_compute = Compute(instance_type=VOLUME_SIZE_INSTANCE_TYPE)
    js_config = JumpStartConfig(model_id=VOLUME_SIZE_MODEL_ID, accept_eula=True)
    model_builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=js_config,
        compute=instance_compute,
    )

    configured_size = getattr(model_builder, "volume_size", None)
    assert configured_size is not None, (
        f"ModelBuilder.volume_size should be set after from_jumpstart_config() "
        f"for model {VOLUME_SIZE_MODEL_ID} on {VOLUME_SIZE_INSTANCE_TYPE}, got None"
    )
    logger.info(f"from_jumpstart_config set volume_size={configured_size}")

    # Stage 2: build path. A short random suffix keeps the model name unique.
    name_suffix = str(uuid.uuid4())[:8]
    core_model = model_builder.build(model_name=f"js-volsize-test-{name_suffix}")
    logger.info(f"Model created: {core_model.model_name}")

    try:
        built_size = getattr(model_builder, "volume_size", None)
        assert built_size is not None, (
            f"ModelBuilder.volume_size should persist after build() "
            f"for model {VOLUME_SIZE_MODEL_ID}, got None"
        )
        assert built_size >= 256, (
            f"volume_size should be >= 256, "
            f"got {built_size}"
        )
        logger.info(f"✅ volume_size={built_size} correctly set")
    finally:
        # Clean up the created model regardless of the assertion outcome.
        core_model.delete()
        logger.info("Model deleted.")
134 changes: 129 additions & 5 deletions sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.types import ModelServer
from sagemaker.core.training.configs import Compute, Networking
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.core.inference_config import AsyncInferenceConfig
from botocore.exceptions import ClientError


class TestModelBuilderInit(unittest.TestCase):
Expand Down Expand Up @@ -189,7 +192,6 @@ class TestBuildDefaultAsyncInferenceConfig(unittest.TestCase):

def test_build_default_async_config(self):
"""Test building default async inference config."""
from sagemaker.core.inference_config import AsyncInferenceConfig

mb = ModelBuilder(model=Mock())
mb.model_name = "test-model"
Expand Down Expand Up @@ -256,7 +258,6 @@ def test_does_ic_exist_true(self):

def test_does_ic_exist_false(self):
"""Test IC doesn't exist."""
from botocore.exceptions import ClientError

mb = ModelBuilder(model=Mock())
mb.sagemaker_session = Mock()
Expand Down Expand Up @@ -366,7 +367,6 @@ class TestFromJumpStartConfig(unittest.TestCase):

def test_from_jumpstart_config_basic(self):
"""Test creating ModelBuilder from JumpStart config."""
from sagemaker.core.jumpstart.configs import JumpStartConfig

js_config = JumpStartConfig(
model_id="test-model",
Expand All @@ -384,8 +384,6 @@ def test_from_jumpstart_config_basic(self):
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwargs):
"""Test that enable_network_isolation from deploy kwargs is applied."""
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.core.training.configs import Compute

mock_deploy_kwargs.return_value = {
"model_data_download_timeout": 600,
Expand All @@ -409,6 +407,132 @@ def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwarg

self.assertTrue(mb._enable_network_isolation)

@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_from_jumpstart_config_applies_volume_size(self, mock_deploy_kwargs):
    """Check that from_jumpstart_config() exposes the spec's volume_size.

    ``_retrieve_model_deploy_kwargs`` is patched to report ``volume_size=256``;
    the builder returned by ``from_jumpstart_config()`` is expected to carry
    that same value in its ``volume_size`` attribute.
    """
    mock_deploy_kwargs.return_value = {
        "model_data_download_timeout": 600,
        "volume_size": 256,
    }

    # Session stub carrying only the region attribute this test configures.
    session = Mock(boto_region_name="us-west-2")

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=JumpStartConfig(
            model_id="meta-textgenerationneuron-llama-2-7b",
            model_version="1.0.0",
        ),
        role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
        compute=Compute(instance_type="ml.inf2.xlarge"),
        sagemaker_session=session,
    )

    self.assertEqual(builder.volume_size, 256)

@patch("sagemaker.serve.model_builder.Endpoint.get")
@patch("sagemaker.serve.model_builder.session_helper.production_variant")
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_deploy_passes_volume_size_to_production_variant(
    self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get
):
    """Check that an explicit deploy(volume_size=...) reaches production_variant.

    The patched deploy kwargs report ``volume_size=256`` while ``deploy()`` is
    called with ``volume_size=512``; the single ``production_variant`` call is
    expected to receive 512.
    """
    mock_deploy_kwargs.return_value = {"volume_size": 256}
    mock_prod_variant.return_value = {"VariantName": "AllTraffic"}
    mock_endpoint_get.return_value = Mock()

    # Session stub configured in one step with the attributes this test sets.
    session = Mock(
        boto_region_name="us-west-2",
        endpoint_in_service_or_not=Mock(return_value=False),
        endpoint_from_production_variants=Mock(),
        sagemaker_config={},
        settings=Mock(include_jumpstart_tags=False),
        _append_sagemaker_config_tags=Mock(return_value=[]),
    )

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=JumpStartConfig(
            model_id="meta-textgenerationneuron-llama-2-7b",
            model_version="1.0.0",
        ),
        role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
        compute=Compute(instance_type="ml.inf2.xlarge"),
        sagemaker_session=session,
    )
    # Stand in for a completed build so deploy() can be called directly.
    builder.built_model = Mock(model_name="test-model")
    builder.model_server = None
    builder.mode = Mode.SAGEMAKER_ENDPOINT

    # Deploy with explicit volume_size=512 overriding spec's 256
    deploy_args = {
        "endpoint_name": "test-ep",
        "instance_type": "ml.inf2.xlarge",
        "initial_instance_count": 1,
        "volume_size": 512,
        "wait": False,
    }
    builder.deploy(**deploy_args)

    # Verify production_variant was called with user's 512, not spec's 256
    mock_prod_variant.assert_called_once()
    _, variant_kwargs = mock_prod_variant.call_args
    self.assertEqual(variant_kwargs["volume_size"], 512)

@patch("sagemaker.serve.model_builder.Endpoint.get")
@patch("sagemaker.serve.model_builder.session_helper.production_variant")
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_deploy_uses_spec_volume_size_when_not_passed(
    self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get
):
    """Check that deploy() falls back to the spec's volume_size.

    The patched deploy kwargs report ``volume_size=256`` and ``deploy()`` is
    called without a ``volume_size`` argument; the single
    ``production_variant`` call is expected to receive 256.
    """
    mock_deploy_kwargs.return_value = {"volume_size": 256}
    mock_prod_variant.return_value = {"VariantName": "AllTraffic"}
    mock_endpoint_get.return_value = Mock()

    # Session stub configured in one step with the attributes this test sets.
    session = Mock(
        boto_region_name="us-west-2",
        endpoint_in_service_or_not=Mock(return_value=False),
        endpoint_from_production_variants=Mock(),
        sagemaker_config={},
        settings=Mock(include_jumpstart_tags=False),
        _append_sagemaker_config_tags=Mock(return_value=[]),
    )

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=JumpStartConfig(
            model_id="meta-textgenerationneuron-llama-2-7b",
            model_version="1.0.0",
        ),
        role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
        compute=Compute(instance_type="ml.inf2.xlarge"),
        sagemaker_session=session,
    )
    # Stand in for a completed build so deploy() can be called directly.
    builder.built_model = Mock(model_name="test-model")
    builder.model_server = None
    builder.mode = Mode.SAGEMAKER_ENDPOINT

    # Deploy WITHOUT passing volume_size — should use spec's 256
    deploy_args = {
        "endpoint_name": "test-ep",
        "instance_type": "ml.inf2.xlarge",
        "initial_instance_count": 1,
        "wait": False,
    }
    builder.deploy(**deploy_args)

    mock_prod_variant.assert_called_once()
    _, variant_kwargs = mock_prod_variant.call_args
    self.assertEqual(variant_kwargs["volume_size"], 256)


# Run the unittest runner when this module is executed as a script.
if __name__ == "__main__":
unittest.main()
Loading