diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 6f60d325ad..b68572fd2a 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -2774,7 +2774,7 @@ def _deploy_core_endpoint(self, **kwargs): self._deserializer = deserializer data_capture_config = kwargs.get("data_capture_config", None) - volume_size = kwargs.get("volume_size", None) + volume_size = kwargs.get("volume_size", getattr(self, "volume_size", None)) inference_recommendation_id = kwargs.get("inference_recommendation_id", None) explainer_config = kwargs.get("explainer_config", None) endpoint_logging = kwargs.get("endpoint_logging", False) @@ -3620,6 +3620,7 @@ def from_jumpstart_config( "container_startup_health_check_timeout" ) mb_instance.inference_ami_version = deploy_kwargs.get("inference_ami_version") + mb_instance.volume_size = deploy_kwargs.get("volume_size") # Apply network isolation from JumpStart model spec if not set by user via network param if not mb_instance._enable_network_isolation and deploy_kwargs.get( diff --git a/sagemaker-serve/tests/integ/test_jumpstart_network_isolation.py b/sagemaker-serve/tests/integ/test_jumpstart_deploy_parity.py similarity index 55% rename from sagemaker-serve/tests/integ/test_jumpstart_network_isolation.py rename to sagemaker-serve/tests/integ/test_jumpstart_deploy_parity.py index 0dd2f214d2..ab714ad96b 100644 --- a/sagemaker-serve/tests/integ/test_jumpstart_network_isolation.py +++ b/sagemaker-serve/tests/integ/test_jumpstart_deploy_parity.py @@ -65,3 +65,53 @@ def test_jumpstart_build_enables_network_isolation(): finally: core_model.delete() logger.info("Model deleted.") + + +VOLUME_SIZE_MODEL_ID = "meta-textgenerationneuron-llama-2-7b" +VOLUME_SIZE_INSTANCE_TYPE = "ml.inf2.xlarge" + + +@pytest.mark.slow_test +def test_jumpstart_build_sets_volume_size(): + """Integration test verifying volume_size from model specs is propagated. + + JumpStart model specs define inference_volume_size for models that need + large EBS volumes for model weights. This test validates that ModelBuilder + propagates volume_size through both from_jumpstart_config() and build() paths, + matching v2 behavior where VolumeSizeInGB appears in CreateEndpointConfig. + """ + logger.info("Starting JumpStart volume_size integration test...") + + # Test from_jumpstart_config path + compute = Compute(instance_type=VOLUME_SIZE_INSTANCE_TYPE) + jumpstart_config = JumpStartConfig(model_id=VOLUME_SIZE_MODEL_ID, accept_eula=True) + model_builder = ModelBuilder.from_jumpstart_config( + jumpstart_config=jumpstart_config, compute=compute + ) + + assert getattr(model_builder, "volume_size", None) is not None, ( + f"ModelBuilder.volume_size should be set after from_jumpstart_config() " + f"for model {VOLUME_SIZE_MODEL_ID} on {VOLUME_SIZE_INSTANCE_TYPE}, got None" + ) + logger.info(f"from_jumpstart_config set volume_size={model_builder.volume_size}") + + # Test build path (also sets volume_size via _build_for_jumpstart) + unique_id = str(uuid.uuid4())[:8] + core_model = model_builder.build(model_name=f"js-volsize-test-{unique_id}") + logger.info(f"Model created: {core_model.model_name}") + + try: + assert getattr(model_builder, "volume_size", None) is not None, ( + f"ModelBuilder.volume_size should persist after build() " + f"for model {VOLUME_SIZE_MODEL_ID}, got None" + ) + assert model_builder.volume_size >= 256, ( + f"volume_size should be >= 256, " + f"got {model_builder.volume_size}" + ) + logger.info( + f"✅ volume_size={model_builder.volume_size} correctly set" + ) + finally: + core_model.delete() + logger.info("Model deleted.") diff --git a/sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py b/sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py index fafd132547..023a2494d1 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py +++ b/sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py @@ -12,6 +12,9 @@ from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.types import ModelServer from sagemaker.core.training.configs import Compute, Networking +from sagemaker.core.jumpstart.configs import JumpStartConfig +from sagemaker.core.inference_config import AsyncInferenceConfig +from botocore.exceptions import ClientError class TestModelBuilderInit(unittest.TestCase): @@ -189,7 +192,6 @@ class TestBuildDefaultAsyncInferenceConfig(unittest.TestCase): def test_build_default_async_config(self): """Test building default async inference config.""" - from sagemaker.core.inference_config import AsyncInferenceConfig mb = ModelBuilder(model=Mock()) mb.model_name = "test-model" @@ -256,7 +258,6 @@ def test_does_ic_exist_true(self): def test_does_ic_exist_false(self): """Test IC doesn't exist.""" - from botocore.exceptions import ClientError mb = ModelBuilder(model=Mock()) mb.sagemaker_session = Mock() @@ -366,7 +367,6 @@ class TestFromJumpStartConfig(unittest.TestCase): def test_from_jumpstart_config_basic(self): """Test creating ModelBuilder from JumpStart config.""" - from sagemaker.core.jumpstart.configs import JumpStartConfig js_config = JumpStartConfig( model_id="test-model", @@ -384,8 +384,6 @@ def test_from_jumpstart_config_basic(self): @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs") def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwargs): """Test that enable_network_isolation from deploy kwargs is applied.""" - from sagemaker.core.jumpstart.configs import JumpStartConfig - from sagemaker.core.training.configs import Compute mock_deploy_kwargs.return_value = { "model_data_download_timeout": 600, @@ -409,6 +407,132 @@ def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwarg self.assertTrue(mb._enable_network_isolation) + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs") + def test_from_jumpstart_config_applies_volume_size(self, mock_deploy_kwargs): + """Test that volume_size from deploy kwargs is applied.""" + + mock_deploy_kwargs.return_value = { + "model_data_download_timeout": 600, + "volume_size": 256, + } + + js_config = JumpStartConfig( + model_id="meta-textgenerationneuron-llama-2-7b", + model_version="1.0.0" + ) + + mock_session = Mock() + mock_session.boto_region_name = "us-west-2" + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + compute=Compute(instance_type="ml.inf2.xlarge"), + sagemaker_session=mock_session, + ) + + self.assertEqual(mb.volume_size, 256) + + @patch("sagemaker.serve.model_builder.Endpoint.get") + @patch("sagemaker.serve.model_builder.session_helper.production_variant") + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs") + def test_deploy_passes_volume_size_to_production_variant( + self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get + ): + """Test that volume_size kwarg passed to deploy() reaches production_variant.""" + + mock_deploy_kwargs.return_value = {"volume_size": 256} + mock_prod_variant.return_value = {"VariantName": "AllTraffic"} + mock_endpoint_get.return_value = Mock() + + js_config = JumpStartConfig( + model_id="meta-textgenerationneuron-llama-2-7b", + model_version="1.0.0", + ) + + mock_session = Mock() + mock_session.boto_region_name = "us-west-2" + mock_session.endpoint_in_service_or_not = Mock(return_value=False) + mock_session.endpoint_from_production_variants = Mock() + mock_session.sagemaker_config = {} + mock_session.settings = Mock() + mock_session.settings.include_jumpstart_tags = False + mock_session._append_sagemaker_config_tags = Mock(return_value=[]) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + compute=Compute(instance_type="ml.inf2.xlarge"), + sagemaker_session=mock_session, + ) + mb.built_model = Mock() + mb.built_model.model_name = "test-model" + mb.model_server = None + mb.mode = Mode.SAGEMAKER_ENDPOINT + + # Deploy with explicit volume_size=512 overriding spec's 256 + mb.deploy( + endpoint_name="test-ep", + instance_type="ml.inf2.xlarge", + initial_instance_count=1, + volume_size=512, + wait=False, + ) + + # Verify production_variant was called with user's 512, not spec's 256 + mock_prod_variant.assert_called_once() + call_kwargs = mock_prod_variant.call_args[1] + self.assertEqual(call_kwargs["volume_size"], 512) + + @patch("sagemaker.serve.model_builder.Endpoint.get") + @patch("sagemaker.serve.model_builder.session_helper.production_variant") + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs") + def test_deploy_uses_spec_volume_size_when_not_passed( + self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get + ): + """Test that volume_size from spec is used when customer doesn't pass it.""" + + mock_deploy_kwargs.return_value = {"volume_size": 256} + mock_prod_variant.return_value = {"VariantName": "AllTraffic"} + mock_endpoint_get.return_value = Mock() + + js_config = JumpStartConfig( + model_id="meta-textgenerationneuron-llama-2-7b", + model_version="1.0.0", + ) + + mock_session = Mock() + mock_session.boto_region_name = "us-west-2" + mock_session.endpoint_in_service_or_not = Mock(return_value=False) + mock_session.endpoint_from_production_variants = Mock() + mock_session.sagemaker_config = {} + mock_session.settings = Mock() + mock_session.settings.include_jumpstart_tags = False + mock_session._append_sagemaker_config_tags = Mock(return_value=[]) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + compute=Compute(instance_type="ml.inf2.xlarge"), + sagemaker_session=mock_session, + ) + mb.built_model = Mock() + mb.built_model.model_name = "test-model" + mb.model_server = None + mb.mode = Mode.SAGEMAKER_ENDPOINT + + # Deploy WITHOUT passing volume_size — should use spec's 256 + mb.deploy( + endpoint_name="test-ep", + instance_type="ml.inf2.xlarge", + initial_instance_count=1, + wait=False, + ) + + mock_prod_variant.assert_called_once() + call_kwargs = mock_prod_variant.call_args[1] + self.assertEqual(call_kwargs["volume_size"], 256) + if __name__ == "__main__": unittest.main()