From a0da02f79e2cfc0e9ae612a96125950c45dc296b Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 10:42:03 -0700 Subject: [PATCH 1/4] feat: add import job polling and provisioned throughput for Bedrock OSS deployments - deploy() for non-Nova models now waits for import job completion and returns job details (model ready for on-demand inference). - New public method: create_provisioned_throughput() with polling. - New private methods: _wait_for_import_job_complete(), _wait_for_provisioned_throughput_in_service(). - Added unit tests and integ tests (serial to avoid concurrent quota issues). - Mark bedrock integ tests as serial to avoid concurrent import job quota issues. X-AI-Prompt: add import polling and PT for bedrock OSS deployments X-AI-Tool: kiro-cli --- .gitignore | 1 + .../sagemaker/serve/bedrock_model_builder.py | 164 ++++++++- .../test_bedrock_provisioned_throughput.py | 313 +++++++++++++++++ .../test_model_customization_deployment.py | 1 + .../tests/unit/test_bedrock_model_builder.py | 314 +++++++++++++----- 5 files changed, 696 insertions(+), 97 deletions(-) create mode 100644 sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py diff --git a/.gitignore b/.gitignore index 811e8b5905..378048cdf0 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ sagemaker_train/src/**/container_drivers/distributed.json docs/api/generated/ .hypothesis .kiro +bedrock_api_logs/ diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index fc269343d4..786cea18b2 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -157,15 +157,15 @@ def deploy( For Nova models, also creates a custom model deployment for inference. Args: - job_name: Name for the model import job (non-Nova models only). - imported_model_name: Name for the imported model (non-Nova models only). 
+ job_name: Name for the model import job (OSS models only). + imported_model_name: Name for the imported model (OSS models only). custom_model_name: Name for the custom model (Nova models only). role_arn: IAM role ARN with permissions for Bedrock operations. - job_tags: Tags for the import job (non-Nova models only). - imported_model_tags: Tags for the imported model (non-Nova models only). + job_tags: Tags for the import job (OSS models only). + imported_model_tags: Tags for the imported model (OSS models only). model_tags: Tags for the custom model (Nova models only). - client_request_token: Unique token for idempotency (non-Nova models only). - imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only). + client_request_token: Unique token for idempotency (OSS models only). + imported_model_kms_key_id: KMS key ID for encryption (OSS models only). deployment_name: Name for the deployment (Nova models only). If not provided, defaults to custom_model_name suffixed with '-deployment'. 
@@ -238,15 +238,23 @@ def deploy( } params = {k: v for k, v in params.items() if v is not None} - logger.info("Creating model import job for non-Nova deployment") + logger.info("Creating model import job for OSS model deployment") print(f"[BedrockModelBuilder] Resolved S3 artifacts path: {self.s3_model_artifacts}") print(f"[BedrockModelBuilder] create_model_import_job params: {params}") - response = self._get_bedrock_client().create_model_import_job(**params) + import_response = self._get_bedrock_client().create_model_import_job(**params) logger.warning( - "Bedrock create_model_import_job request: %s, response: %s", params, response + "Bedrock create_model_import_job request: %s, response: %s", params, import_response ) - _log_bedrock_api_call("create_model_import_job", params, response) - return response + _log_bedrock_api_call("create_model_import_job", params, import_response) + + job_arn = import_response.get("jobArn") + self._wait_for_import_job_complete(job_arn) + + # Return the completed job details + job_details = self._get_bedrock_client().get_model_import_job( + jobIdentifier=job_arn + ) + return job_details def create_deployment( self, @@ -303,6 +311,140 @@ def create_deployment( return response + def create_provisioned_throughput( + self, + model_id: str, + provisioned_model_name: str, + model_units: int = 1, + commitment_duration: Optional[str] = None, + tags: Optional[list] = None, + poll_interval: int = 60, + max_wait: int = 3600, + ) -> Dict[str, Any]: + """Create provisioned throughput for an imported model on Bedrock. + + Calls CreateProvisionedModelThroughput and polls until the provisioned + throughput reaches InService status. + + Args: + model_id: ARN or ID of the imported model. + provisioned_model_name: Name for the provisioned throughput resource. + model_units: Number of model units to provision. Defaults to 1. + commitment_duration: Commitment duration. Valid values: 'OneMonth', + 'SixMonths'. 
If not provided, a no-commitment Provisioned Throughput is requested (billed hourly). + tags: Tags for the provisioned throughput resource. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Returns: + Response from Bedrock create_provisioned_model_throughput API. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + ValueError: If model_id or provisioned_model_name is not provided. + """ + if not model_id: + raise ValueError("model_id is required for create_provisioned_throughput.") + if not provisioned_model_name: + raise ValueError( + "provisioned_model_name is required for create_provisioned_throughput." + ) + + params = { + "modelId": model_id, + "provisionedModelName": provisioned_model_name, + "modelUnits": model_units, + } + if commitment_duration: + params["commitmentDuration"] = commitment_duration + if tags: + params["tags"] = tags + + logger.info( + "Creating provisioned throughput '%s' for model %s with %d model units", + provisioned_model_name, + model_id, + model_units, + ) + response = self._get_bedrock_client().create_provisioned_model_throughput(**params) + + provisioned_model_arn = response.get("provisionedModelArn") + if provisioned_model_arn: + self._wait_for_provisioned_throughput_in_service( + provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait + ) + + return response + + def _wait_for_import_job_complete( + self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until the model import job reaches Completed status. + + Args: + job_arn: ARN of the model import job. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the import job fails or times out. 
+ """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn) + status = resp.get("status") + logger.info("Import job status: %s (elapsed %ds)", status, elapsed) + if status == "Completed": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Model import job {job_arn} failed. Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. " + f"Last status: {status}" + ) + + def _wait_for_provisioned_throughput_in_service( + self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until provisioned throughput reaches InService status. + + Args: + provisioned_model_arn: ARN of the provisioned model throughput. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_provisioned_model_throughput( + provisionedModelId=provisioned_model_arn + ) + status = resp.get("status") + logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed) + if status == "InService": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Provisioned throughput {provisioned_model_arn} failed. " + f"Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for provisioned throughput " + f"{provisioned_model_arn} to become InService. 
Last status: {status}" + ) + def _wait_for_model_active( self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600 ): diff --git a/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py new file mode 100644 index 0000000000..aa4ee03cb8 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py @@ -0,0 +1,313 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration tests for BedrockModelBuilder import job polling and provisioned throughput.""" +from __future__ import absolute_import + +import json +import time +import random +import logging +from urllib.parse import urlparse + +import boto3 +import pytest + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core.resources import TrainingJob +from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder + +logger = logging.getLogger(__name__) + +AWS_REGION = "us-west-2" + + +@pytest.fixture(scope="module") +def training_job_name(): + """Training job name for testing (OSS model).""" + return "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445" + + +@pytest.fixture(scope="module") +def role_arn(): + """IAM role ARN with Bedrock permissions.""" + return get_execution_role() + + +@pytest.fixture(scope="module") +def bedrock_client(): + """Create Bedrock client.""" + return boto3.client("bedrock", region_name=AWS_REGION) + + 
+@pytest.fixture(scope="module") +def s3_client(): + """Create S3 client.""" + return boto3.client("s3", region_name=AWS_REGION) + + +@pytest.fixture(scope="module") +def training_job(training_job_name): + """Get the training job.""" + return TrainingJob.get( + training_job_name=training_job_name, region=AWS_REGION + ) + + +def _setup_model_files(s3_artifacts_uri, s3_client): + """Setup required model files for Bedrock deployment. + + Bedrock model import requires HuggingFace-format files (config.json, + tokenizer.json, etc.) at the root of the S3 model artifacts path. + Training jobs often store these under checkpoints/hf_merged/, so we + copy them to the expected location. + + Args: + s3_artifacts_uri: The S3 URI that BedrockModelBuilder will use for import. + s3_client: boto3 S3 client. + """ + parsed = urlparse(s3_artifacts_uri) + bucket = parsed.netloc + base_prefix = parsed.path.lstrip("/").rstrip("/") + + hf_merged_prefix = f"{base_prefix}/checkpoints/hf_merged/" + root_prefix = f"{base_prefix}/" + + files_to_copy = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "model.safetensors", + ] + + for file in files_to_copy: + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + file) + logger.info("File already exists: s3://%s/%s%s", bucket, root_prefix, file) + except Exception: + try: + s3_client.copy_object( + Bucket=bucket, + CopySource={"Bucket": bucket, "Key": hf_merged_prefix + file}, + Key=root_prefix + file, + ) + logger.info("Copied %s to root", file) + except Exception as e: + logger.warning("Could not copy %s: %s", file, e) + + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + "added_tokens.json") + except Exception: + try: + s3_client.put_object( + Bucket=bucket, + Key=root_prefix + "added_tokens.json", + Body=json.dumps({}), + ContentType="application/json", + ) + logger.info("Created added_tokens.json") + except Exception as e: + logger.warning("Could not create added_tokens.json: %s", e) + + 
+@pytest.mark.serial +class TestBedrockImportJobPolling: + """Test import job polling for OSS models (Option C: deploy only waits for import).""" + + @pytest.fixture(autouse=True) + def _setup(self, bedrock_client): + """Store bedrock client and track resources for cleanup.""" + self._bedrock_client = bedrock_client + self._imported_model_arn = None + yield + self._cleanup() + + def _cleanup(self): + """Clean up Bedrock resources created during the test.""" + if self._imported_model_arn: + try: + logger.info("Deleting imported model: %s", self._imported_model_arn) + self._bedrock_client.delete_imported_model( + modelIdentifier=self._imported_model_arn + ) + except Exception as e: + logger.warning("Failed to delete imported model: %s", e) + + @pytest.mark.slow + def test_deploy_oss_model_waits_for_import_completion( + self, training_job, role_arn, bedrock_client, s3_client + ): + """Test that deploy() waits for import job to complete and returns job details. + + This test verifies that BedrockModelBuilder.deploy() for OSS models: + 1. Creates a model import job + 2. Polls until the import job reaches Completed status + 3. Returns the completed job details (model is ready for on-demand invoke) + 4. 
Does NOT create provisioned throughput + """ + builder = BedrockModelBuilder(model=training_job) + assert builder.s3_model_artifacts is not None + + _setup_model_files(builder.s3_model_artifacts, s3_client) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + job_name = f"test-import-poll-{suffix}" + imported_model_name = f"test-import-model-{suffix}" + + result = builder.deploy( + job_name=job_name, + imported_model_name=imported_model_name, + role_arn=role_arn, + ) + + # Verify the result is the completed job details + assert result["status"] == "Completed", ( + f"Expected Completed, got {result.get('status')}" + ) + assert "importedModelName" in result + assert "importedModelArn" in result or "jobArn" in result + + # Track for cleanup + self._imported_model_arn = result.get("importedModelArn") + + # Verify model can be found (it exists and is ready) + models = bedrock_client.list_imported_models() + model_names = [m["modelName"] for m in models.get("modelSummaries", [])] + assert imported_model_name in model_names + + +@pytest.mark.serial +class TestBedrockProvisionedThroughput: + """Test create_provisioned_throughput as a standalone method. + + Uses a pre-existing Bedrock custom model (fine-tuned Llama 3.1 8B) to test + provisioned throughput creation and polling. The custom model was created via + Bedrock CreateModelCustomizationJob and persists in the CI account. + + Prerequisites: + - Account 729646638167, us-west-2 + - PT MU quota for Llama 3.1 8B (requested via Matador/Bedrock team) + - A pre-existing custom model (see below for how to recreate) + + How to recreate the custom model if it gets deleted: + + 1. 
Ensure training data exists at: + s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl + + If not, create it (minimal JSONL with prompt/completion pairs): + echo '{"prompt":"What is ML?","completion":"ML is a subset of AI."}' > /tmp/train.jsonl + aws s3 cp /tmp/train.jsonl s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl + + 2. Create the fine-tuning job: + aws bedrock create-model-customization-job \\ + --job-name test-llama31-8b-pt-integ \\ + --custom-model-name test-llama31-8b-pt-model \\ + --role-arn arn:aws:iam::729646638167:role/Admin \\ + --base-model-identifier meta.llama3-1-8b-instruct-v1:0:128k \\ + --customization-type FINE_TUNING \\ + --training-data-config '{"s3Uri":"s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl"}' \\ + --output-data-config '{"s3Uri":"s3://mc-flows-sdk-testing/pt-test-output/"}' \\ + --hyper-parameters '{"epochCount":"1","batchSize":"1","learningRate":"0.00001"}' \\ + --region us-west-2 + + 3. Wait for the job to complete (~2-4 hours for 8B model): + aws bedrock get-model-customization-job \\ + --job-identifier <job-arn> --region us-west-2 \\ + --query "status" + + 4. Update CUSTOM_MODEL_ARN below with the outputModelArn from the job. + """ + + # Pre-existing custom model created via Bedrock fine-tuning. + # Base model: meta.llama3-1-8b-instruct-v1:0:128k + # This model must exist in account 729646638167, us-west-2. 
+ CUSTOM_MODEL_ARN = ( + "arn:aws:bedrock:us-west-2:729646638167:custom-model/" + "meta.llama3-1-8b-instruct-v1:0:128k/k2mjykwgn62p" + ) + CUSTOM_MODEL_NAME = "test-llama31-8b-pt-model" + + @pytest.fixture(autouse=True) + def _setup(self, bedrock_client): + """Store bedrock client and track resources for cleanup.""" + self._bedrock_client = bedrock_client + self._provisioned_model_arn = None + yield + # Always clean up PT, even if test fails + self._cleanup() + + def _cleanup(self): + """Clean up provisioned throughput created during the test.""" + if self._provisioned_model_arn: + try: + logger.info("Deleting provisioned throughput: %s", self._provisioned_model_arn) + self._bedrock_client.delete_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + logger.info("Provisioned throughput deleted successfully.") + except Exception as e: + logger.warning("Failed to delete provisioned throughput: %s", e) + + @pytest.mark.slow + def test_create_provisioned_throughput(self, bedrock_client): + """Test create_provisioned_throughput() with a pre-existing custom model. + + This test verifies: + 1. Calls CreateProvisionedModelThroughput with a custom model ARN + 2. Polls until provisioned throughput reaches InService + 3. Returns the provisioned throughput response + 4. Cleans up the PT after the test + """ + # Check if the pre-existing custom model exists + try: + bedrock_client.get_custom_model(modelIdentifier=self.CUSTOM_MODEL_ARN) + except Exception: + pytest.skip( + f"Pre-existing custom model not found: {self.CUSTOM_MODEL_ARN}. 
" + f"Recreate it with: aws bedrock create-model-customization-job " + f"--job-name test-llama31-8b-pt-integ " + f"--custom-model-name {self.CUSTOM_MODEL_NAME} " + f"--role-arn " + f"--base-model-identifier meta.llama3-1-8b-instruct-v1:0:128k " + f"--customization-type FINE_TUNING " + f"--training-data-config '{{\"s3Uri\":\"s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl\"}}' " + f"--output-data-config '{{\"s3Uri\":\"s3://mc-flows-sdk-testing/pt-test-output/\"}}' " + f"--hyper-parameters '{{\"epochCount\":\"1\",\"batchSize\":\"1\",\"learningRate\":\"0.00001\"}}' " + f"--region us-west-2" + ) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + provisioned_model_name = f"test-pt-integ-{suffix}" + + builder = BedrockModelBuilder(model=None) + + # Create provisioned throughput + pt_result = builder.create_provisioned_throughput( + model_id=self.CUSTOM_MODEL_ARN, + provisioned_model_name=provisioned_model_name, + model_units=1, + ) + + # Verify result contains provisioned model ARN + assert "provisionedModelArn" in pt_result, ( + f"Expected 'provisionedModelArn' in result, got keys: {list(pt_result.keys())}" + ) + self._provisioned_model_arn = pt_result["provisionedModelArn"] + + # Verify provisioned throughput is InService (create_provisioned_throughput + # already polls until InService, but double-check) + pt_response = bedrock_client.get_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + assert pt_response["status"] == "InService", ( + f"Expected InService, got {pt_response['status']}" + ) diff --git a/sagemaker-serve/tests/integ/test_model_customization_deployment.py b/sagemaker-serve/tests/integ/test_model_customization_deployment.py index b38ca249c7..5b22c16851 100644 --- a/sagemaker-serve/tests/integ/test_model_customization_deployment.py +++ b/sagemaker-serve/tests/integ/test_model_customization_deployment.py @@ -307,6 +307,7 @@ def test_dpo_trainer_build(self, training_job_name, sagemaker_session): from 
sagemaker.serve.bedrock_model_builder import BedrockModelBuilder +@pytest.mark.serial class TestModelCustomizationDeployment: """Test suite for deploying fine-tuned models to Bedrock.""" diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 3fdfaa01a3..6a5d54ca18 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -69,7 +69,7 @@ def test_nova_via_recipe_name(self): def test_nova_via_hub_content_name(self): assert _is_nova_model(_make_container(hub_content_name="amazon-nova-lite")) is True - def test_non_nova(self): + def test_oss(self): assert _is_nova_model(_make_container(recipe_name="llama-3-8b", hub_content_name="llama")) is False def test_no_base_model(self): @@ -95,8 +95,7 @@ def test_none_model(self): def test_with_model(self): m = Mock() with patch.object(BedrockModelBuilder, "_fetch_model_package", return_value=Mock()), \ - patch.object(BedrockModelBuilder, "_get_s3_artifacts", return_value="s3://b/k"), \ - patch(f"{MODULE}.is_restricted_model_package", return_value=False): + patch.object(BedrockModelBuilder, "_get_s3_artifacts", return_value="s3://b/k"): b = BedrockModelBuilder(model=m) assert b.model is m assert b.s3_model_artifacts == "s3://b/k" @@ -210,13 +209,13 @@ def test_none_when_no_model_package(self): b.model_package = None assert b._get_s3_artifacts() is None - def test_non_nova_returns_s3_uri(self): + def test_oss_returns_s3_uri(self): c = _make_container(recipe_name="llama", hub_content_name="llama", s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) assert b._get_s3_artifacts() == "s3://b/m.tar.gz" - def test_non_nova_no_data_source(self): + def test_oss_no_data_source(self): c = _make_container(recipe_name="llama", hub_content_name="llama") b = _builder() b.model_package = _make_model_package(c) @@ -470,15 +469,48 @@ def test_timeout_raises(self): 
class TestDeploy: - def test_non_nova(self): + def test_oss_waits_for_import_and_returns_job_details(self): + """OSS deploy: import job → wait → return job details.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) b.s3_model_artifacts = "s3://b/m.tar.gz" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} - result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") - assert result == {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "my-imported-model", + "importedModelArn": "arn:aws:bedrock:us-west-2:123:imported-model/abc", + } + + with patch(f"{MODULE}.time.sleep"): + result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") + + b._bedrock_client.create_model_import_job.assert_called_once() + b._bedrock_client.get_model_import_job.assert_called() + # Should NOT call create_provisioned_model_throughput + b._bedrock_client.create_provisioned_model_throughput.assert_not_called() + assert result["status"] == "Completed" + assert result["importedModelName"] == "my-imported-model" + + def test_oss_does_not_create_provisioned_throughput(self): + """deploy() for OSS models should never call CreateProvisionedModelThroughput.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "m", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="m", role_arn="r") + + b._bedrock_client.create_provisioned_model_throughput.assert_not_called() + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() def 
test_nova_full_chain(self): c = _make_container(recipe_name="nova-micro", hub_content_name="nova") @@ -573,116 +605,226 @@ def test_nova_missing_role_arn_raises(self): with pytest.raises(ValueError, match="role_arn is required"): b.deploy(custom_model_name="m") - def test_non_nova_strips_none_params(self): + def test_oss_strips_none_params(self): c = _make_container() b = _builder() b.model_package = _make_model_package(c) b.s3_model_artifacts = "s3://b/k" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} - b.deploy(job_name="j", imported_model_name="m", role_arn="r") + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "m", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="m", role_arn="r") + kw = b._bedrock_client.create_model_import_job.call_args[1] assert "importedModelKmsKeyId" not in kw assert "clientRequestToken" not in kw - def test_nova_rmp_uses_model_package_arn_data_source(self): - """When model package is RMP, use customModelDataSource.""" - c = _make_container(recipe_name="nova-lite") + +# ── _wait_for_import_job_complete ─────────────────────────────────────────── + + +class TestWaitForImportJobComplete: + def test_immediate_completed(self): b = _builder() - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/my-rmp/1" - pkg.managed_storage_type = "Restricted" - b.model_package = pkg - b._is_rmp = True - b.s3_model_artifacts = None b._bedrock_client = Mock() - b._bedrock_client.create_custom_model.return_value = {"modelArn": "arn:m"} - b._bedrock_client.get_custom_model.return_value = {"modelStatus": "Active"} - b._bedrock_client.create_custom_model_deployment.return_value = { - "customModelDeploymentArn": "arn:dep" + b._bedrock_client.get_model_import_job.return_value = {"status": "Completed"} + b._wait_for_import_job_complete("arn:job") + 
b._bedrock_client.get_model_import_job.assert_called_once_with( + jobIdentifier="arn:job" + ) + + def test_polls_then_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.side_effect = [ + {"status": "InProgress"}, + {"status": "InProgress"}, + {"status": "Completed"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=10) + assert b._bedrock_client.get_model_import_job.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = { + "status": "Failed", + "failureMessage": "Invalid model format", } - b._bedrock_client.get_custom_model_deployment.return_value = {"status": "Active"} + with pytest.raises(RuntimeError, match="Invalid model format"): + b._wait_for_import_job_complete("arn:job") - b.deploy(custom_model_name="rmp-test", role_arn="r") - kw = b._bedrock_client.create_custom_model.call_args[1] - assert "customModelDataSource" in kw - assert kw["customModelDataSource"]["modelPackageArnDataSource"]["modelPackageArn"] == ( - "arn:aws:sagemaker:us-east-1:123456789012:model-package/my-rmp/1" + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Failed"} + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_import_job_complete("arn:job") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "InProgress"} + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=2) + + +# ── create_provisioned_throughput ─────────────────────────────────────────── + + +class TestCreateProvisionedThroughput: + def test_creates_and_polls(self): + b = _builder() + 
b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + result = b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="my-pt" + ) + + b._bedrock_client.create_provisioned_model_throughput.assert_called_once_with( + modelId="arn:model", + provisionedModelName="my-pt", + modelUnits=1, ) - assert "modelSourceConfig" not in kw + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + assert result["provisionedModelArn"] == "arn:pt" - def test_nova_s3_uri_uses_model_source_config(self): - """When model package is not RMP, use modelSourceConfig (existing path).""" - c = _make_container(recipe_name="nova-lite", s3_uri="s3://bucket/checkpoint/step_10/") + def test_passes_commitment_duration(self): b = _builder() - pkg = _make_model_package(c) - pkg.managed_storage_type = None - b.model_package = pkg - b.s3_model_artifacts = "s3://bucket/checkpoint/step_10/" b._bedrock_client = Mock() - b._bedrock_client.create_custom_model.return_value = {"modelArn": "arn:m"} - b._bedrock_client.get_custom_model.return_value = {"modelStatus": "Active"} - b._bedrock_client.create_custom_model_deployment.return_value = { - "customModelDeploymentArn": "arn:dep" + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" } - b._bedrock_client.get_custom_model_deployment.return_value = {"status": "Active"} - b.deploy(custom_model_name="s3-test", role_arn="r") - kw = b._bedrock_client.create_custom_model.call_args[1] - assert "modelSourceConfig" in kw - assert kw["modelSourceConfig"]["s3DataSource"]["s3Uri"] == "s3://bucket/checkpoint/step_10/" - assert "customModelDataSource" not in kw + b.create_provisioned_throughput( + 
model_id="arn:model", + provisioned_model_name="pt", + model_units=5, + commitment_duration="OneMonth", + ) + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 5 + assert kw["commitmentDuration"] == "OneMonth" -# ── _get_s3_artifacts RMP detection ─────────────────────────────────────── + def test_passes_tags(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + tags = [{"Key": "team", "Value": "ml"}] + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt", tags=tags + ) -class TestGetS3ArtifactsRMP: - def test_nova_rmp_returns_none(self): - """When model package is RMP (s3_uri is None), return None.""" - c = _make_container(recipe_name="nova-lite") - s3_data = Mock() - s3_data.s3_uri = None - data_source = Mock() - data_source.s3_data_source = s3_data - c.model_data_source = data_source + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["tags"] == tags - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/rmp/1" - pkg.managed_storage_type = "Restricted" + def test_skips_polling_when_no_arn_in_response(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = {} + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt" + ) + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() + + def test_empty_model_id_raises(self): b = _builder() - b.model = "not-a-training-job" - b.model_package = pkg - result = b._get_s3_artifacts() - assert result is None + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id="", 
provisioned_model_name="pt") + + def test_none_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id=None, provisioned_model_name="pt") + + def test_empty_provisioned_model_name_raises(self): + b = _builder() + with pytest.raises(ValueError, match="provisioned_model_name is required"): + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="" + ) + - def test_nova_rmp_no_data_source_returns_none(self): - """When model_data_source is None and managed_storage_type is Restricted, return None.""" - c = _make_container(recipe_name="nova-lite") - c.model_data_source = None +# ── _wait_for_provisioned_throughput_in_service ───────────────────────────── - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/rmp/2" - pkg.managed_storage_type = "Restricted" +class TestWaitForProvisionedThroughputInService: + def test_immediate_in_service(self): b = _builder() - b.model = "not-a-training-job" - b.model_package = pkg - result = b._get_s3_artifacts() - assert result is None + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + b._wait_for_provisioned_throughput_in_service("arn:pt") + b._bedrock_client.get_provisioned_model_throughput.assert_called_once_with( + provisionedModelId="arn:pt" + ) - def test_non_nova_rmp_returns_none(self): - """Non-Nova RMP models should also return None.""" - c = _make_container(recipe_name="llama", hub_content_name="llama") - c.model_data_source = None + def test_polls_then_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.side_effect = [ + {"status": "Creating"}, + {"status": "Creating"}, + {"status": "InService"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, 
max_wait=10 + ) + assert b._bedrock_client.get_provisioned_model_throughput.call_count == 3 - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/rmp/1" - pkg.managed_storage_type = "Restricted" + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed", + "failureMessage": "Insufficient capacity", + } + with pytest.raises(RuntimeError, match="Insufficient capacity"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + def test_failed_unknown_reason(self): b = _builder() - b.model = "not-a-training-job" - b.model_package = pkg - result = b._get_s3_artifacts() - assert result is None + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed" + } + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Creating" + } + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=2 + ) From f0293cff4ac10d21b54883ab0b2124cc4c00ac55 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 11:07:52 -0700 Subject: [PATCH 2/4] improve docstring --- .../src/sagemaker/serve/bedrock_model_builder.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 786cea18b2..fea478dbb8 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -153,8 +153,11 @@ def deploy( """Deploy
the model to Bedrock. Automatically detects if the model is a Nova model and uses the appropriate - Bedrock API (create_custom_model for Nova, create_model_import_job for others). - For Nova models, also creates a custom model deployment for inference. + Bedrock API (create_custom_model for Nova, create_model_import_job for OSS). + For Nova models, creates a custom model deployment and polls until active. + For OSS models, creates a model import job and polls until complete. Once + deploy() returns, the model is ready for on-demand inference. For provisioned + throughput, use the separate create_provisioned_throughput() method. Args: job_name: Name for the model import job (OSS models only). @@ -170,12 +173,12 @@ def deploy( defaults to custom_model_name suffixed with '-deployment'. Returns: - Response from Bedrock API. For Nova models, returns the - create_custom_model_deployment response. For others, returns - the create_model_import_job response. + For Nova models: the create_custom_model_deployment response. + For OSS models: the completed get_model_import_job response. Raises: ValueError: If model_package is not set or required parameters are missing. + RuntimeError: If the import job or deployment fails or times out. 
""" if not self.model_package: raise ValueError( From 35a8e024df633145851ebb7a39a7ce44b939e953 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Thu, 4 Jun 2026 10:26:15 -0700 Subject: [PATCH 3/4] feat: make model_id optional in create_provisioned_throughput - Store imported model ID after deploy() completes - create_provisioned_throughput() now falls back to the stored model ID if model_id is not explicitly passed - Added unit tests for fallback and explicit override behavior --- .../sagemaker/serve/bedrock_model_builder.py | 26 ++++++++----- .../tests/unit/test_bedrock_model_builder.py | 37 +++++++++++++++++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index fea478dbb8..c8d70fd75d 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -101,6 +101,7 @@ def __init__( self.model = model self._bedrock_client = None self._sagemaker_client = None + self._imported_model_id = None self.boto_session = Session().boto_session self.model_package = self._fetch_model_package() if model else None self._is_rmp = is_restricted_model_package(self.model_package) @@ -253,10 +254,11 @@ def deploy( job_arn = import_response.get("jobArn") self._wait_for_import_job_complete(job_arn) - # Return the completed job details + # Return the completed job details and store imported model ID job_details = self._get_bedrock_client().get_model_import_job( jobIdentifier=job_arn ) + self._imported_model_id = job_details.get("importedModelName") return job_details def create_deployment( @@ -316,8 +318,8 @@ def create_deployment( def create_provisioned_throughput( self, - model_id: str, - provisioned_model_name: str, + model_id: Optional[str] = None, + provisioned_model_name: str = None, model_units: int = 1, commitment_duration: Optional[str] = None, tags: Optional[list] = 
None, @@ -330,7 +332,8 @@ def create_provisioned_throughput( throughput reaches InService status. Args: - model_id: ARN or ID of the imported model. + model_id: ARN or name of the model. If not provided, uses the model + ID from the most recent deploy() call. provisioned_model_name: Name for the provisioned throughput resource. model_units: Number of model units to provision. Defaults to 1. commitment_duration: Commitment duration. Valid values: 'OneMonth', @@ -344,17 +347,22 @@ def create_provisioned_throughput( Raises: RuntimeError: If the provisioned throughput fails or times out. - ValueError: If model_id or provisioned_model_name is not provided. + ValueError: If model_id cannot be determined or provisioned_model_name + is not provided. """ - if not model_id: - raise ValueError("model_id is required for create_provisioned_throughput.") + resolved_model_id = model_id or self._imported_model_id + if not resolved_model_id: + raise ValueError( + "model_id is required for create_provisioned_throughput. " + "Either pass it explicitly or call deploy() first." + ) if not provisioned_model_name: raise ValueError( "provisioned_model_name is required for create_provisioned_throughput." 
) params = { - "modelId": model_id, + "modelId": resolved_model_id, "provisionedModelName": provisioned_model_name, "modelUnits": model_units, } @@ -366,7 +374,7 @@ def create_provisioned_throughput( logger.info( "Creating provisioned throughput '%s' for model %s with %d model units", provisioned_model_name, - model_id, + resolved_model_id, model_units, ) response = self._get_bedrock_client().create_provisioned_model_throughput(**params) diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 6a5d54ca18..6fb0e8bfb1 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -768,6 +768,43 @@ def test_empty_provisioned_model_name_raises(self): model_id="arn:model", provisioned_model_name="" ) + def test_uses_imported_model_id_from_deploy(self): + """model_id falls back to _imported_model_id set by deploy().""" + b = _builder() + b._imported_model_id = "my-deployed-model" + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + result = b.create_provisioned_throughput(provisioned_model_name="my-pt") + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelId"] == "my-deployed-model" + assert result["provisionedModelArn"] == "arn:pt" + + def test_explicit_model_id_overrides_stored(self): + """Explicit model_id takes precedence over _imported_model_id.""" + b = _builder() + b._imported_model_id = "stored-model" + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + b.create_provisioned_throughput( + 
model_id="explicit-model", provisioned_model_name="my-pt" + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelId"] == "explicit-model" + # ── _wait_for_provisioned_throughput_in_service ───────────────────────────── From 3686e0a28d2c4778a66ba2992fb48b84b9d0e8e8 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Thu, 4 Jun 2026 10:30:21 -0700 Subject: [PATCH 4/4] docs: update example notebooks for new OSS deploy polling behavior - bedrock-modelbuilder-deployment.ipynb: deploy() now waits for import completion, removed manual polling cell, added PT usage example - mtrl_finetuning_example_notebook_v3_prod.ipynb: removed manual polling loop, updated description to reflect automatic waiting, added optional create_provisioned_throughput() example --- .../bedrock-modelbuilder-deployment.ipynb | 63 ++++++++----------- ..._finetuning_example_notebook_v3_prod.ipynb | 42 ++++++------- 2 files changed, 46 insertions(+), 59 deletions(-) diff --git a/v3-examples/model-customization-examples/bedrock-modelbuilder-deployment.ipynb b/v3-examples/model-customization-examples/bedrock-modelbuilder-deployment.ipynb index 013fb8002a..7ccb42cd8d 100644 --- a/v3-examples/model-customization-examples/bedrock-modelbuilder-deployment.ipynb +++ b/v3-examples/model-customization-examples/bedrock-modelbuilder-deployment.ipynb @@ -78,9 +78,9 @@ "for file in required_files:\n", " try:\n", " s3_client.head_object(Bucket=BUCKET, Key=model_prefix + file)\n", - " print(f\"✅ {file}\")\n", + " print(f\"\u2705 {file}\")\n", " except:\n", - " print(f\"❌ {file} - MISSING\")" + " print(f\"\u274c {file} - MISSING\")" ] }, { @@ -94,7 +94,7 @@ " # Create added_tokens.json (usually empty for Llama)\n", " try:\n", " s3_client.head_object(Bucket=BUCKET, Key=model_prefix + 'added_tokens.json')\n", - " print(\"✅ added_tokens.json exists\")\n", + " print(\"\u2705 added_tokens.json exists\")\n", " except:\n", " s3_client.put_object(\n", " Bucket=BUCKET,\n", @@ -102,7 +102,7 
@@ " Body=json.dumps({}),\n", " ContentType='application/json'\n", " )\n", - " print(\"✅ Created added_tokens.json\")\n", + " print(\"\u2705 Created added_tokens.json\")\n", "\n", "ensure_tokenizer_files()" ] @@ -154,9 +154,9 @@ " CopySource={'Bucket': BUCKET, 'Key': source_key},\n", " Key=dest_key\n", " )\n", - " print(f\"✅ Copied {file_name}\")\n", + " print(f\"\u2705 Copied {file_name}\")\n", " except Exception as e:\n", - " print(f\"❌ Failed to copy {file_name}: {e}\")\n", + " print(f\"\u274c Failed to copy {file_name}: {e}\")\n", " else:\n", " print(\"No files found in hf_merged directory\")\n", "except Exception as e:\n", @@ -173,20 +173,22 @@ "job_name = f\"bedrock-import-{random.randint(1000, 9999)}-{int(time.time())}\"\n", "print(f\"Job name: {job_name}\")\n", "\n", - "# Create builder with correct model path\n", + "# Create builder\n", "bedrock_builder = BedrockModelBuilder(\n", " model=training_job\n", ")\n", "\n", - "# Deploy to Bedrock\n", - "deployment_result = bedrock_builder.deploy(\n", + "# Deploy to Bedrock - this will create the import job and poll until complete.\n", + "# When deploy() returns, the model is ready for on-demand inference.\n", + "deploy_result = bedrock_builder.deploy(\n", " job_name=job_name,\n", " imported_model_name=job_name,\n", " role_arn=ROLE_ARN\n", ")\n", "\n", - "job_arn = deployment_result['jobArn']\n", - "print(f\"Import job started: {job_arn}\")" + "print(f\"Import complete! 
Model: {deploy_result.get('importedModelName')}\")\n", + "print(f\"Model ARN: {deploy_result.get('importedModelArn')}\")\n", + "print(f\"Status: {deploy_result.get('status')}\")" ] }, { @@ -195,27 +197,16 @@ "metadata": {}, "outputs": [], "source": [ - "# Step 5: Wait for import to complete\n", - "bedrock_client = boto3.client('bedrock', region_name=REGION)\n", + "# Note: Manual polling is no longer needed!\n", + "# deploy() now waits for the import job to complete before returning.\n", + "# The model is ready for on-demand inference immediately after deploy().\n", "\n", - "print(\"Waiting for import to complete...\")\n", - "while True:\n", - " response = bedrock_client.get_model_import_job(jobIdentifier=job_arn)\n", - " status = response['status']\n", - " print(f\"Status: {status}\")\n", - " \n", - " if status == 'Completed':\n", - " imported_model_arn = response['importedModelArn']\n", - " print(f\"✅ Import completed!\")\n", - " print(f\"Model ARN: {imported_model_arn}\")\n", - " break\n", - " elif status in ['Failed', 'Stopped']:\n", - " print(f\"❌ Import failed: {status}\")\n", - " if 'failureMessage' in response:\n", - " print(f\"Error: {response['failureMessage']}\")\n", - " break\n", - " \n", - " time.sleep(30)" + "# If you need provisioned throughput (dedicated capacity), use:\n", + "# pt_result = bedrock_builder.create_provisioned_throughput(\n", + "# provisioned_model_name=\"my-provisioned-model\",\n", + "# model_units=1,\n", + "# )\n", + "# Note: model_id is automatically inferred from the previous deploy() call." 
] }, { @@ -242,7 +233,7 @@ " )\n", " \n", " result = json.loads(response['body'].read().decode())\n", - " print(\"\\n🎉 Inference successful (ChatCompletion format)!\")\n", + " print(\"\\n\ud83c\udf89 Inference successful (ChatCompletion format)!\")\n", " print(f\"Response: {result}\")\n", " \n", " except Exception as e1:\n", @@ -261,14 +252,14 @@ " )\n", " \n", " result = json.loads(response['body'].read().decode())\n", - " print(\"\\n🎉 Inference successful (BedrockMeta format)!\")\n", + " print(\"\\n\ud83c\udf89 Inference successful (BedrockMeta format)!\")\n", " print(f\"Response: {result}\")\n", " \n", " except Exception as e2:\n", " print(f\"BedrockMeta failed: {e2}\")\n", - " print(\"❌ Both formats failed. Check model documentation for correct format.\")\n", + " print(\"\u274c Both formats failed. Check model documentation for correct format.\")\n", "else:\n", - " print(\"❌ Import failed, cannot test inference\")" + " print(\"\u274c Import failed, cannot test inference\")" ] }, { @@ -360,4 +351,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/v3-examples/model-customization-examples/mtrl_finetuning_example_notebook_v3_prod.ipynb b/v3-examples/model-customization-examples/mtrl_finetuning_example_notebook_v3_prod.ipynb index d6977ad85d..564fade811 100644 --- a/v3-examples/model-customization-examples/mtrl_finetuning_example_notebook_v3_prod.ipynb +++ b/v3-examples/model-customization-examples/mtrl_finetuning_example_notebook_v3_prod.ipynb @@ -796,7 +796,12 @@ "source": [ "### 13. Deploy to Amazon Bedrock\n", "\n", - "Alternatively, deploy the fine-tuned model to Amazon Bedrock as an imported model with provisioned throughput. This provides a fully managed inference experience without managing endpoints." + "Deploy the fine-tuned model to Amazon Bedrock as an imported model. 
The `deploy()` method\n", + "creates the import job and polls until complete \u2014 when it returns, the model is ready for\n", + "on-demand inference.\n", + "\n", + "If you need dedicated throughput, you can optionally call `create_provisioned_throughput()`\n", + "as a separate step after deploy." ] }, { @@ -809,21 +814,23 @@ "\n", "bedrock_builder = BedrockModelBuilder(model=trainer)\n", "\n", + "# deploy() waits for import to complete. Model is ready for on-demand inference after this.\n", "bedrock_response = bedrock_builder.deploy(\n", " imported_model_name=\"mtrl-finetuned-model\",\n", " role_arn=\"arn:aws:iam::123456789012:role/BedrockRole\",\n", - " deployment_name=\"mtrl-deployment\",\n", ")\n", - "print(f\"Bedrock deployment response: {bedrock_response}\")" + "print(f\"Import complete! Model: {bedrock_response.get('importedModelName')}\")\n", + "print(f\"Status: {bedrock_response.get('status')}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Wait for Bedrock Deployment\n", + "#### Optional: Create Provisioned Throughput\n", "\n", - "Poll until the provisioned model throughput reaches `InService` status before invoking." + "If you need dedicated capacity with guaranteed throughput, create provisioned throughput.\n", + "This is optional \u2014 on-demand inference works immediately after `deploy()` returns." 
] }, { @@ -832,25 +839,14 @@ "outputs": [], "execution_count": null, "source": [ - "import time\n", - "import boto3\n", - "\n", - "bedrock_client = boto3.client(\"bedrock\", region_name=REGION)\n", - "deployment_name = \"mtrl-deployment\"\n", + "# Optional: Create provisioned throughput for dedicated capacity\n", + "# model_id is automatically inferred from the previous deploy() call.\n", "\n", - "print(f\"Waiting for Bedrock deployment '{deployment_name}'...\")\n", - "while True:\n", - " status_response = bedrock_client.get_provisioned_model_throughput(\n", - " provisionedModelId=deployment_name\n", - " )\n", - " status = status_response[\"status\"]\n", - " print(f\" Status: {status}\")\n", - " if status == \"InService\":\n", - " print(\"Deployment successful!\")\n", - " break\n", - " elif status in (\"Failed\", \"Cancelled\"):\n", - " raise RuntimeError(f\"Bedrock deployment failed with status: {status}\")\n", - " time.sleep(30)" + "# pt_result = bedrock_builder.create_provisioned_throughput(\n", + "# provisioned_model_name=\"mtrl-provisioned\",\n", + "# model_units=1,\n", + "# )\n", + "# print(f\"Provisioned throughput ARN: {pt_result.get('provisionedModelArn')}\")" ] }, {