diff --git a/tests/unit/vertex_rag/test_rag_constants.py b/tests/unit/vertex_rag/test_rag_constants.py index 65459cec2a..19fa4d3970 100644 --- a/tests/unit/vertex_rag/test_rag_constants.py +++ b/tests/unit/vertex_rag/test_rag_constants.py @@ -82,6 +82,7 @@ TEST_CORPUS_DISPLAY_NAME = "my-corpus-1" TEST_CORPUS_DISCRIPTION = "My first corpus." TEST_RAG_CORPUS_ID = "generate-123" +TEST_RAG_CORPUS_NUMERIC_ID = "1234567890" TEST_API_ENDPOINT = "us-central1-" + aiplatform.constants.base.API_BASE_PATH TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}" @@ -244,6 +245,7 @@ TEST_API_ENDPOINT, TEST_PROJECT_NUMBER, TEST_REGION, TEST_RAG_CORPUS_ID ) TEST_RAG_FILE_ID = "generate-456" +TEST_RAG_FILE_NUMERIC_ID = "9876543210" TEST_RAG_FILE_RESOURCE_NAME = ( TEST_RAG_CORPUS_RESOURCE_NAME + f"/ragFiles/{TEST_RAG_FILE_ID}" ) diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index 0c0f3c810c..e41950b1a5 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -91,6 +91,7 @@ TEST_CORPUS_DISPLAY_NAME = "my-corpus-1" TEST_CORPUS_DISCRIPTION = "My first corpus." TEST_RAG_CORPUS_ID = "generate-123" +TEST_RAG_CORPUS_NUMERIC_ID = "1234567890" TEST_API_ENDPOINT = "us-central1-" + aiplatform.constants.base.API_BASE_PATH TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}" @@ -489,6 +490,7 @@ TEST_API_ENDPOINT, TEST_PROJECT_NUMBER, TEST_REGION, TEST_RAG_CORPUS_ID ) TEST_RAG_FILE_ID = "generate-456" +TEST_RAG_FILE_NUMERIC_ID = "9876543210" TEST_RAG_FILE_RESOURCE_NAME = ( TEST_RAG_CORPUS_RESOURCE_NAME + f"/ragFiles/{TEST_RAG_FILE_ID}" ) diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index 94f6c35bf9..adc67c70c2 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -663,6 +663,26 @@ def test_get_corpus_id_success(self): rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_ID) rag_corpus_eq(rag_corpus, test_rag_constants.TEST_RAG_CORPUS) + def test_get_corpus_numeric_id_success(self): + """Bare numeric IDs must pass the regex and be expanded to full resource names.""" + with mock.patch.object( + rag.utils._gapic_utils, "create_rag_data_service_client" + ) as mock_client_factory: + api_client_mock = mock.Mock(spec=VertexRagDataServiceClient) + api_client_mock.parse_rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.parse_rag_corpus_path + ) + api_client_mock.rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.rag_corpus_path + ) + api_client_mock.get_rag_corpus.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_CORPUS + ) + mock_client_factory.return_value = api_client_mock + + rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_NUMERIC_ID) + rag_corpus_eq(rag_corpus, test_rag_constants.TEST_RAG_CORPUS) + @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_get_corpus_failure(self): with pytest.raises(RuntimeError) as e: @@ -883,6 +903,35 @@ def test_get_file_id_success(self): ) rag_file_eq(rag_file, test_rag_constants.TEST_RAG_FILE) + def test_get_file_numeric_id_success(self): + """Bare numeric IDs must pass the regex and be expanded to full resource names.""" + with mock.patch.object( + rag.utils._gapic_utils, "create_rag_data_service_client" + ) as mock_client_factory: + api_client_mock = mock.Mock(spec=VertexRagDataServiceClient) + api_client_mock.parse_rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.parse_rag_corpus_path + ) + api_client_mock.parse_rag_file_path.side_effect = ( + VertexRagDataServiceClient.parse_rag_file_path + ) + api_client_mock.rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.rag_corpus_path + ) + api_client_mock.rag_file_path.side_effect = ( + VertexRagDataServiceClient.rag_file_path + ) + api_client_mock.get_rag_file.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_FILE + ) + mock_client_factory.return_value = api_client_mock + + rag_file = rag.get_file( + name=test_rag_constants.TEST_RAG_FILE_NUMERIC_ID, + corpus_name=test_rag_constants.TEST_RAG_CORPUS_NUMERIC_ID, + ) + rag_file_eq(rag_file, test_rag_constants.TEST_RAG_FILE) + @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_get_file_failure(self): with pytest.raises(RuntimeError) as e: diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index b1e7d4c3b0..2acbdb87ff 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -1315,6 +1315,28 @@ def test_get_corpus_id_success(self): rag_corpus = rag.get_corpus(test_rag_constants_preview.TEST_RAG_CORPUS_ID) rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS) + def test_get_corpus_numeric_id_success(self): + """Bare numeric IDs must pass the regex and be expanded to full resource names.""" + with mock.patch.object( + rag.utils._gapic_utils, "create_rag_data_service_client" + ) as mock_client_factory: + api_client_mock = mock.Mock(spec=VertexRagDataServiceClient) + api_client_mock.parse_rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.parse_rag_corpus_path + ) + api_client_mock.rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.rag_corpus_path + ) + api_client_mock.get_rag_corpus.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS + ) + mock_client_factory.return_value = api_client_mock + + rag_corpus = rag.get_corpus( + test_rag_constants_preview.TEST_RAG_CORPUS_NUMERIC_ID + ) + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS) + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") def test_get_corpus_failure(self): with pytest.raises(RuntimeError) as e: @@ -1454,6 +1476,35 @@ def test_get_file_id_success(self): ) rag_file_eq(rag_file, test_rag_constants_preview.TEST_RAG_FILE) + def test_get_file_numeric_id_success(self): + """Bare numeric IDs must pass the regex and be expanded to full resource names.""" + with mock.patch.object( + rag.utils._gapic_utils, "create_rag_data_service_client" + ) as mock_client_factory: + api_client_mock = mock.Mock(spec=VertexRagDataServiceClient) + api_client_mock.parse_rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.parse_rag_corpus_path + ) + api_client_mock.parse_rag_file_path.side_effect = ( + VertexRagDataServiceClient.parse_rag_file_path + ) + api_client_mock.rag_corpus_path.side_effect = ( + VertexRagDataServiceClient.rag_corpus_path + ) + api_client_mock.rag_file_path.side_effect = ( + VertexRagDataServiceClient.rag_file_path + ) + api_client_mock.get_rag_file.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_FILE + ) + mock_client_factory.return_value = api_client_mock + + rag_file = rag.get_file( + name=test_rag_constants_preview.TEST_RAG_FILE_NUMERIC_ID, + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_NUMERIC_ID, + ) + rag_file_eq(rag_file, test_rag_constants_preview.TEST_RAG_FILE) + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") def test_get_file_failure(self): with pytest.raises(RuntimeError) as e: diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index 4150111c60..5bcb1e9a8e 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -78,7 +78,8 @@ ) -_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}" +# Allows numeric resource IDs (e.g. "1234567890") as bare names. +_VALID_RESOURCE_NAME_REGEX = "[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}" _VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX = ( r"projects/[^/]+/locations/[^/]+/processors/[^/]+(?:/processorVersions/[^/]+)?" ) diff --git a/vertexai/rag/utils/_gapic_utils.py b/vertexai/rag/utils/_gapic_utils.py index 3ee39a7a0f..6535767b71 100644 --- a/vertexai/rag/utils/_gapic_utils.py +++ b/vertexai/rag/utils/_gapic_utils.py @@ -67,7 +67,8 @@ ) -_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}" +# Allows numeric resource IDs (e.g. "1234567890") as bare names. +_VALID_RESOURCE_NAME_REGEX = "[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}" _VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX = ( r"projects/[^/]+/locations/[^/]+/processors/[^/]+(?:/processorVersions/[^/]+)?" )