From 6401b07a1978ea7b57bac6985be6fdc7fc9229f1 Mon Sep 17 00:00:00 2001 From: "Dina Berry (She/her)" Date: Fri, 24 Apr 2026 07:07:19 -0700 Subject: [PATCH 1/4] Add Java select-algorithm sample Adds a Java sample demonstrating how to choose and configure vector search algorithms (IVF, HNSW, DiskANN) in Azure DocumentDB. Uses DefaultAzureCredential for passwordless auth and Azure OpenAI for embeddings. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ai/select-algorithm-java/.gitignore | 12 ++ ai/select-algorithm-java/README.md | 107 ++++++++++ ai/select-algorithm-java/pom.xml | 78 +++++++ .../selectalgorithm/SelectAlgorithm.java | 196 +++++++++++++++++ .../documentdb/selectalgorithm/Utils.java | 199 ++++++++++++++++++ 5 files changed, 592 insertions(+) create mode 100644 ai/select-algorithm-java/.gitignore create mode 100644 ai/select-algorithm-java/README.md create mode 100644 ai/select-algorithm-java/pom.xml create mode 100644 ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java create mode 100644 ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java diff --git a/ai/select-algorithm-java/.gitignore b/ai/select-algorithm-java/.gitignore new file mode 100644 index 0000000..bebdf5e --- /dev/null +++ b/ai/select-algorithm-java/.gitignore @@ -0,0 +1,12 @@ +# Build output +target/ + +# IDE +.idea/ +*.iml +.classpath +.project +.settings/ + +# Environment +.env diff --git a/ai/select-algorithm-java/README.md b/ai/select-algorithm-java/README.md new file mode 100644 index 0000000..7c4d2c4 --- /dev/null +++ b/ai/select-algorithm-java/README.md @@ -0,0 +1,107 @@ +# DocumentDB Vector Index Algorithm Comparison (Java) + +This sample compares DocumentDB vector index algorithms (DiskANN, HNSW, IVF) across similarity functions (COS, L2, IP) to help you choose the best configuration for your use case. + +## Overview + +Vector indexes improve search performance by organizing vectors for efficient similarity searches. This sample: + +- Creates collections per algorithm/similarity combination +- Configures algorithm-specific index parameters +- Measures query latency for each configuration +- Displays a comparison table to guide your selection + +## Prerequisites + +- [Java 21 or higher](https://learn.microsoft.com/java/openjdk/download) +- [Maven 3.6 or higher](https://maven.apache.org/download.cgi) +- [Azure CLI](https://learn.microsoft.com/cli/azure/install-azure-cli) +- Azure subscription with: + - Azure DocumentDB (MongoDB vCore) cluster + - Azure OpenAI with text-embedding-3-small model + - Managed identity configured for passwordless auth + +## Setup + +1. Copy `.env.example` to `.env`: + ```bash + cp .env.example .env + ``` + +2. Update `.env` with your Azure resource values. The sample uses the + [dotenv-java](https://github.com/cdimascio/dotenv-java) library to load + variables from `.env` at startup, falling back to system environment + variables when the file is absent. + +3. Sign in to Azure for passwordless authentication: + ```bash + az login + ``` + +4. Compile the project: + ```bash + mvn clean compile + ``` + +> **Note:** This sample does not include a Maven Wrapper (`mvnw`). Install +> Maven 3.6+ from and ensure `mvn` is +> on your PATH. + +## Usage + +Run the comparison for specific or all algorithms and similarity functions: + +```bash +# Compare all algorithms with cosine similarity +mvn exec:java -Dexec.mainClass="com.azure.documentdb.selectalgorithm.SelectAlgorithm" + +# Compare only DiskANN with all similarity functions +ALGORITHM=diskann SIMILARITY=all mvn exec:java -Dexec.mainClass="com.azure.documentdb.selectalgorithm.SelectAlgorithm" + +# Compare HNSW with L2 (Euclidean) distance +ALGORITHM=hnsw SIMILARITY=L2 mvn exec:java -Dexec.mainClass="com.azure.documentdb.selectalgorithm.SelectAlgorithm" + +# Compare all algorithms and all similarity functions +ALGORITHM=all SIMILARITY=all mvn exec:java -Dexec.mainClass="com.azure.documentdb.selectalgorithm.SelectAlgorithm" +``` + +### Environment Variables + +- `ALGORITHM`: Which algorithm(s) to test + - `all` (default): Test DiskANN, HNSW, and IVF + - `diskann`: Test only DiskANN + - `hnsw`: Test only HNSW + - `ivf`: Test only IVF + +- `SIMILARITY`: Which similarity function(s) to test + - `COS` (default): Cosine similarity + - `L2`: Euclidean distance + - `IP`: Inner product + - `all`: Test all similarity functions + +## Algorithm Characteristics + +### DiskANN +- Disk-based for large datasets +- Good balance of speed and accuracy +- Parameters: maxDegree=32, lBuild=50, lSearch=100 + +### HNSW +- Memory-based hierarchical graph +- Excellent for real-time applications +- Parameters: m=16, efConstruction=64, efSearch=80 + +### IVF +- Cluster-based partitioning +- Fast search via centroids +- Parameters: numLists=1, nProbes=1 + +## Output + +The sample prints a comparison table showing latency per query for each algorithm/similarity combination, helping you make an informed choice. + +## Further Resources + +- [Azure DocumentDB Documentation](https://learn.microsoft.com/azure/documentdb/) +- [Vector Search in DocumentDB](https://learn.microsoft.com/azure/documentdb/vector-search) +- [MongoDB Java Driver Documentation](https://mongodb.github.io/mongo-java-driver/) diff --git a/ai/select-algorithm-java/pom.xml b/ai/select-algorithm-java/pom.xml new file mode 100644 index 0000000..b396af3 --- /dev/null +++ b/ai/select-algorithm-java/pom.xml @@ -0,0 +1,78 @@ + + 4.0.0 + + com.azure.documentdb.samples + select-algorithm-java + 1.0-SNAPSHOT + Azure DocumentDB Vector Algorithm Comparison + Compare DocumentDB vector index algorithms (DiskANN, HNSW, IVF) using Java SDK + + + 21 + 21 + 21 + UTF-8 + + + + + + com.azure + azure-sdk-bom + 1.2.29 + pom + import + + + + + + + org.mongodb + mongodb-driver-sync + 5.6.2 + + + com.azure + azure-identity + + + com.azure + azure-ai-openai + + + com.fasterxml.jackson.core + jackson-databind + 2.18.2 + + + io.github.cdimascio + dotenv-java + 3.0.2 + + + org.slf4j + slf4j-simple + 2.0.17 + runtime + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + + 21 + + -Xlint:all + + + + + + diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java new file mode 100644 index 0000000..abfa2e3 --- /dev/null +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java @@ -0,0 +1,196 @@ +package com.azure.documentdb.selectalgorithm; + +import com.azure.ai.openai.OpenAIClient; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.Indexes; +import org.bson.Document; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class SelectAlgorithm { + private static final String SAMPLE_QUERY = "quintessential lodging near running trails, eateries, retail"; + private static final String DATABASE_NAME = "Hotels"; + private static final int NUM_QUERIES = 5; + + public static void main(String[] args) { + Utils.loadEnv(); + new SelectAlgorithm().run(); + System.exit(0); + } + + public void run() { + try (var mongoClient = Utils.createMongoClient()) { + var openAIClient = Utils.createOpenAIClient(); + + var algorithmParam = Utils.getEnv("ALGORITHM", "all").toLowerCase(); + var similarityParam = Utils.getEnv("SIMILARITY", "COS").toUpperCase(); + + var algorithms = getAlgorithms(algorithmParam); + var similarities = getSimilarities(similarityParam); + + System.out.println("Testing algorithms: " + algorithms); + System.out.println("Testing similarity functions: " + similarities); + System.out.println(); + + var results = new ArrayList>(); + var database = mongoClient.getDatabase(DATABASE_NAME); + + for (var algorithm : algorithms) { + for (var similarity : similarities) { + var result = testConfiguration(database, openAIClient, algorithm, similarity); + results.add(result); + } + } + + Utils.printComparisonTable(results); + + } catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + e.printStackTrace(); + } + } + + private List getAlgorithms(String param) { + if ("all".equals(param)) { + return List.of("diskann", "hnsw", "ivf"); + } + return List.of(param); + } + + private List getSimilarities(String param) { + if ("all".equalsIgnoreCase(param)) { + return List.of("COS", "L2", "IP"); + } + return List.of(param); + } + + private Map testConfiguration(MongoDatabase database, OpenAIClient openAIClient, + String algorithm, String similarity) { + System.out.println("Testing " + algorithm.toUpperCase() + " with " + similarity + " similarity..."); + + var collectionName = "hotels_" + algorithm.toLowerCase() + "_" + similarity.toLowerCase(); + var vectorIndexName = "vectorIndex_" + algorithm.toLowerCase() + "_" + similarity.toLowerCase(); + + try { + var collection = database.getCollection(collectionName, Document.class); + collection.drop(); + database.createCollection(collectionName); + System.out.println(" Created collection: " + collectionName); + + var hotelData = Utils.loadHotelData(); + insertDataInBatches(collection, hotelData); + + createStandardIndexes(collection); + createVectorIndex(database, collectionName, vectorIndexName, algorithm, similarity); + + var queryEmbedding = Utils.createEmbedding(openAIClient, SAMPLE_QUERY); + var avgLatency = measureSearchLatency(collection, queryEmbedding, algorithm); + + System.out.println(" Average latency: " + String.format("%.2f", avgLatency) + " ms"); + System.out.println(); + + var result = new HashMap(); + result.put("algorithm", algorithm.toUpperCase()); + result.put("similarity", similarity); + result.put("latency", avgLatency); + return result; + + } catch (Exception e) { + System.err.println(" Error testing " + algorithm + " with " + similarity + ": " + e.getMessage()); + var result = new HashMap(); + result.put("algorithm", algorithm.toUpperCase()); + result.put("similarity", similarity); + result.put("latency", -1.0); + return result; + } + } + + private void insertDataInBatches(MongoCollection collection, List> hotelData) { + var batchSizeStr = Utils.getEnv("LOAD_SIZE_BATCH"); + var batchSize = batchSizeStr != null ? Integer.parseInt(batchSizeStr) : 100; + var batches = Utils.partitionList(hotelData, batchSize); + + System.out.println(" Loading data in batches of " + batchSize + "..."); + + for (int i = 0; i < batches.size(); i++) { + var batch = batches.get(i); + var documents = batch.stream() + .map(Document::new) + .toList(); + + collection.insertMany(documents); + if ((i + 1) % 10 == 0 || (i + 1) == batches.size()) { + System.out.println(" Loaded " + ((i + 1) * batchSize) + " documents"); + } + } + } + + private void createStandardIndexes(MongoCollection collection) { + collection.createIndex(Indexes.ascending("HotelId")); + collection.createIndex(Indexes.ascending("Category")); + collection.createIndex(Indexes.ascending("Description")); + collection.createIndex(Indexes.ascending("Description_fr")); + } + + private void createVectorIndex(MongoDatabase database, String collectionName, String indexName, + String algorithm, String similarity) { + var embeddedField = Utils.getEnv("EMBEDDED_FIELD"); + var cosmosSearchOptions = Utils.createVectorIndexOptions(algorithm, similarity); + + var indexDefinition = new Document() + .append("createIndexes", collectionName) + .append("indexes", List.of( + new Document() + .append("name", indexName) + .append("key", new Document(embeddedField, "cosmosSearch")) + .append("cosmosSearchOptions", cosmosSearchOptions) + )); + + database.runCommand(indexDefinition); + System.out.println(" Created vector index: " + indexName); + } + + private double measureSearchLatency(MongoCollection collection, List queryEmbedding, + String algorithm) { + var embeddedField = Utils.getEnv("EMBEDDED_FIELD"); + var searchOptions = Utils.createSearchOptions(algorithm); + + var totalLatency = 0.0; + + for (int i = 0; i < NUM_QUERIES; i++) { + var cosmosSearch = new Document() + .append("vector", queryEmbedding) + .append("path", embeddedField) + .append("k", 5); + + if (!searchOptions.isEmpty()) { + cosmosSearch.putAll(searchOptions); + } + + var searchStage = new Document("$search", new Document() + .append("cosmosSearch", cosmosSearch) + ); + + var projectStage = new Document("$project", new Document() + .append("score", new Document("$meta", "searchScore")) + .append("HotelName", 1) + ); + + var pipeline = List.of(searchStage, projectStage); + + var startTime = System.nanoTime(); + var results = collection.aggregate(pipeline); + results.first(); + var endTime = System.nanoTime(); + + totalLatency += (endTime - startTime) / 1_000_000.0; + } + + return totalLatency / NUM_QUERIES; + } +} diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java new file mode 100644 index 0000000..61fe025 --- /dev/null +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java @@ -0,0 +1,199 @@ +package com.azure.documentdb.selectalgorithm; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.models.EmbeddingsOptions; +import com.azure.core.http.policy.ExponentialBackoffOptions; +import com.azure.core.http.policy.RetryOptions; +import com.azure.identity.DefaultAzureCredentialBuilder; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import io.github.cdimascio.dotenv.Dotenv; +import org.bson.Document; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class Utils { + private static Dotenv dotenv; + private static final ObjectMapper objectMapper = new ObjectMapper(); + + public static void loadEnv() { + try { + dotenv = Dotenv.configure() + .ignoreIfMissing() + .load(); + } catch (Exception e) { + System.err.println("Warning: Could not load .env file, using system environment variables"); + } + } + + public static String getEnv(String key) { + if (dotenv != null) { + String value = dotenv.get(key); + if (value != null) return value; + } + return System.getenv(key); + } + + public static String getEnv(String key, String defaultValue) { + String value = getEnv(key); + return value != null ? value : defaultValue; + } + + public static MongoClient createMongoClient() { + var clusterName = getEnv("MONGO_CLUSTER_NAME"); + var managedIdentityPrincipalId = getEnv("AZURE_MANAGED_IDENTITY_PRINCIPAL_ID"); + var azureCredential = new DefaultAzureCredentialBuilder().build(); + + MongoCredential.OidcCallback callback = (MongoCredential.OidcCallbackContext context) -> { + var token = azureCredential.getToken( + new com.azure.core.credential.TokenRequestContext() + .addScopes("https://ossrdbms-aad.database.windows.net/.default") + ).block(); + + if (token == null) { + throw new RuntimeException("Failed to obtain Azure AD token"); + } + + return new MongoCredential.OidcCallbackResult(token.getToken()); + }; + + var credential = MongoCredential.createOidcCredential(null) + .withMechanismProperty("OIDC_CALLBACK", callback); + + var connectionString = new ConnectionString( + String.format("mongodb+srv://%s@%s.mongocluster.cosmos.azure.com/?authMechanism=MONGODB-OIDC&tls=true&retrywrites=false&maxIdleTimeMS=120000", + managedIdentityPrincipalId, clusterName) + ); + + var settings = MongoClientSettings.builder() + .applyConnectionString(connectionString) + .credential(credential) + .retryWrites(true) + .retryReads(true) + .build(); + + return MongoClients.create(settings); + } + + public static OpenAIClient createOpenAIClient() { + var endpoint = getEnv("AZURE_OPENAI_EMBEDDING_ENDPOINT"); + var credential = new DefaultAzureCredentialBuilder().build(); + + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(credential) + .retryOptions(new RetryOptions( + new ExponentialBackoffOptions() + .setMaxRetries(3) + .setBaseDelay(Duration.ofSeconds(1)) + .setMaxDelay(Duration.ofSeconds(30)) + )) + .buildClient(); + } + + public static List> loadHotelData() throws IOException { + var dataFile = getEnv("DATA_FILE_WITH_VECTORS"); + var filePath = Path.of(dataFile); + + System.out.println("Reading JSON file from " + filePath.toAbsolutePath()); + var jsonContent = Files.readString(filePath); + + return objectMapper.readValue(jsonContent, new TypeReference>>() {}); + } + + public static List createEmbedding(OpenAIClient openAIClient, String text) { + var model = getEnv("AZURE_OPENAI_EMBEDDING_MODEL"); + var options = new EmbeddingsOptions(List.of(text)); + + var response = openAIClient.getEmbeddings(model, options); + return response.getData().get(0).getEmbedding().stream() + .map(Float::doubleValue) + .toList(); + } + + public static Document createVectorIndexOptions(String algorithm, String similarity) { + var embeddedField = getEnv("EMBEDDED_FIELD"); + var dimensionsStr = getEnv("EMBEDDING_DIMENSIONS"); + var dimensions = dimensionsStr != null ? Integer.parseInt(dimensionsStr) : 1536; + + var options = new Document() + .append("kind", getVectorKind(algorithm)) + .append("dimensions", dimensions) + .append("similarity", similarity); + + switch (algorithm.toLowerCase()) { + case "diskann": + options.append("maxDegree", 32) + .append("lBuild", 50); + break; + case "hnsw": + options.append("m", 16) + .append("efConstruction", 64); + break; + case "ivf": + options.append("numLists", 1); + break; + } + + return options; + } + + public static Document createSearchOptions(String algorithm) { + var options = new Document(); + + switch (algorithm.toLowerCase()) { + case "diskann": + options.append("lSearch", 100); + break; + case "hnsw": + options.append("efSearch", 80); + break; + case "ivf": + options.append("nProbes", 1); + break; + } + + return options; + } + + private static String getVectorKind(String algorithm) { + return "vector-" + algorithm.toLowerCase(); + } + + public static List> partitionList(List list, int batchSize) { + var partitions = new ArrayList>(); + for (int i = 0; i < list.size(); i += batchSize) { + partitions.add(list.subList(i, Math.min(i + batchSize, list.size()))); + } + return partitions; + } + + public static void printComparisonTable(List> results) { + System.out.println("\n" + "=".repeat(80)); + System.out.println("Vector Index Algorithm Comparison Results"); + System.out.println("=".repeat(80)); + System.out.printf("%-15s %-15s %-20s%n", "Algorithm", "Similarity", "Avg Latency (ms)"); + System.out.println("-".repeat(80)); + + for (var result : results) { + System.out.printf("%-15s %-15s %-20.2f%n", + result.get("algorithm"), + result.get("similarity"), + result.get("latency")); + } + + System.out.println("=".repeat(80)); + } +} From 6db5f29e819e7c3afaf21132c26d9f57b745c283 Mon Sep 17 00:00:00 2001 From: "Dina Berry (She/her)" Date: Mon, 27 Apr 2026 11:50:25 -0700 Subject: [PATCH 2/4] Fix single-query search and align timeout - Replace 5-iteration latency averaging with single query execution (matches all other language samples) - Return all k=5 results instead of just first() - Print search results with scores (matches TS/Python/Go/.NET output) - Add connectTimeoutMS=120000 to connection string (matches other samples) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../selectalgorithm/SelectAlgorithm.java | 72 ++++++++++--------- .../documentdb/selectalgorithm/Utils.java | 2 +- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java index abfa2e3..5783e43 100644 --- a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java @@ -15,7 +15,7 @@ public class SelectAlgorithm { private static final String SAMPLE_QUERY = "quintessential lodging near running trails, eateries, retail"; private static final String DATABASE_NAME = "Hotels"; - private static final int NUM_QUERIES = 5; + public static void main(String[] args) { Utils.loadEnv(); @@ -89,15 +89,19 @@ private Map testConfiguration(MongoDatabase database, OpenAIClie createVectorIndex(database, collectionName, vectorIndexName, algorithm, similarity); var queryEmbedding = Utils.createEmbedding(openAIClient, SAMPLE_QUERY); - var avgLatency = measureSearchLatency(collection, queryEmbedding, algorithm); + var searchResult = executeVectorSearch(collection, queryEmbedding, algorithm); - System.out.println(" Average latency: " + String.format("%.2f", avgLatency) + " ms"); + System.out.println(" Latency: " + String.format("%.2f", searchResult.latencyMs) + " ms"); + System.out.println(" Results: " + searchResult.results.size()); + for (var doc : searchResult.results) { + System.out.println(" - " + doc.getString("HotelName") + " (score: " + String.format("%.4f", doc.getDouble("score")) + ")"); + } System.out.println(); var result = new HashMap(); result.put("algorithm", algorithm.toUpperCase()); result.put("similarity", similarity); - result.put("latency", avgLatency); + result.put("latency", searchResult.latencyMs); return result; } catch (Exception e) { @@ -155,42 +159,46 @@ private void createVectorIndex(MongoDatabase database, String collectionName, St System.out.println(" Created vector index: " + indexName); } - private double measureSearchLatency(MongoCollection collection, List queryEmbedding, - String algorithm) { - var embeddedField = Utils.getEnv("EMBEDDED_FIELD"); - var searchOptions = Utils.createSearchOptions(algorithm); + private static class SearchResult { + double latencyMs; + List results; - var totalLatency = 0.0; + SearchResult(double latencyMs, List results) { + this.latencyMs = latencyMs; + this.results = results; + } + } - for (int i = 0; i < NUM_QUERIES; i++) { - var cosmosSearch = new Document() - .append("vector", queryEmbedding) - .append("path", embeddedField) - .append("k", 5); + private SearchResult executeVectorSearch(MongoCollection collection, List queryEmbedding, + String algorithm) { + var embeddedField = Utils.getEnv("EMBEDDED_FIELD"); + var searchOptions = Utils.createSearchOptions(algorithm); - if (!searchOptions.isEmpty()) { - cosmosSearch.putAll(searchOptions); - } + var cosmosSearch = new Document() + .append("vector", queryEmbedding) + .append("path", embeddedField) + .append("k", 5); - var searchStage = new Document("$search", new Document() - .append("cosmosSearch", cosmosSearch) - ); + if (!searchOptions.isEmpty()) { + cosmosSearch.putAll(searchOptions); + } - var projectStage = new Document("$project", new Document() - .append("score", new Document("$meta", "searchScore")) - .append("HotelName", 1) - ); + var searchStage = new Document("$search", new Document() + .append("cosmosSearch", cosmosSearch) + ); - var pipeline = List.of(searchStage, projectStage); + var projectStage = new Document("$project", new Document() + .append("score", new Document("$meta", "searchScore")) + .append("HotelName", 1) + ); - var startTime = System.nanoTime(); - var results = collection.aggregate(pipeline); - results.first(); - var endTime = System.nanoTime(); + var pipeline = List.of(searchStage, projectStage); - totalLatency += (endTime - startTime) / 1_000_000.0; - } + var startTime = System.nanoTime(); + var results = collection.aggregate(pipeline).into(new java.util.ArrayList<>()); + var endTime = System.nanoTime(); - return totalLatency / NUM_QUERIES; + var latencyMs = (endTime - startTime) / 1_000_000.0; + return new SearchResult(latencyMs, results); } } diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java index 61fe025..6f4a0bd 100644 --- a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java @@ -73,7 +73,7 @@ public static MongoClient createMongoClient() { .withMechanismProperty("OIDC_CALLBACK", callback); var connectionString = new ConnectionString( - String.format("mongodb+srv://%s@%s.mongocluster.cosmos.azure.com/?authMechanism=MONGODB-OIDC&tls=true&retrywrites=false&maxIdleTimeMS=120000", + String.format("mongodb+srv://%s@%s.mongocluster.cosmos.azure.com/?authMechanism=MONGODB-OIDC&tls=true&retrywrites=false&maxIdleTimeMS=120000&connectTimeoutMS=120000", managedIdentityPrincipalId, clusterName) ); From a50a655f581dbd87443fd1034631fd9501736ca0 Mon Sep 17 00:00:00 2001 From: "Dina Berry (She/her)" Date: Mon, 27 Apr 2026 12:38:37 -0700 Subject: [PATCH 3/4] Add .env.example, dimension validation, env validation - Create .env.example with all required environment variables - Add embedding dimension validation after OpenAI API call - Add startup validation for required environment variables Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../selectalgorithm/SelectAlgorithm.java | 18 ++++++++++++++++++ .../documentdb/selectalgorithm/Utils.java | 11 ++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java index 5783e43..f6cc222 100644 --- a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java @@ -23,7 +23,25 @@ public static void main(String[] args) { System.exit(0); } + private void validateEnvironment() { + List missing = new ArrayList<>(); + String[] required = {"MONGO_CLUSTER_NAME", "AZURE_MANAGED_IDENTITY_PRINCIPAL_ID", + "AZURE_OPENAI_EMBEDDING_ENDPOINT", "DATA_FILE_WITH_VECTORS"}; + for (String var : required) { + String value = Utils.getEnv(var); + if (value == null || value.isBlank()) { + missing.add(var); + } + } + if (!missing.isEmpty()) { + throw new IllegalStateException( + "Missing required environment variables: " + String.join(", ", missing) + + "\nSee .env.example for required values. Copy to .env and fill in your Azure resource details."); + } + } + public void run() { + validateEnvironment(); try (var mongoClient = Utils.createMongoClient()) { var openAIClient = Utils.createOpenAIClient(); diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java index 6f4a0bd..fb499c1 100644 --- a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java @@ -118,9 +118,18 @@ public static List createEmbedding(OpenAIClient openAIClient, String tex var options = new EmbeddingsOptions(List.of(text)); var response = openAIClient.getEmbeddings(model, options); - return response.getData().get(0).getEmbedding().stream() + var embedding = response.getData().get(0).getEmbedding().stream() .map(Float::doubleValue) .toList(); + + var expectedDimensions = Integer.parseInt(getEnv("EMBEDDING_DIMENSIONS", "1536")); + if (embedding.size() != expectedDimensions) { + throw new IllegalStateException(String.format( + "Embedding dimension mismatch: expected %d, got %d. Verify EMBEDDING_DIMENSIONS matches your model.", + expectedDimensions, embedding.size())); + } + + return embedding; } public static Document createVectorIndexOptions(String algorithm, String similarity) { From a3b3265b56eee29280ad5befb60c27bb09d3d17b Mon Sep 17 00:00:00 2001 From: "Dina Berry (She/her)" Date: Mon, 27 Apr 2026 12:54:44 -0700 Subject: [PATCH 4/4] Cache credential, specific exceptions, resource cleanup - Create DefaultAzureCredential once and reuse across methods - Replace catch-all Exception with specific types (MongoException, IOException) - Ensure MongoClient is properly closed after execution Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../selectalgorithm/SelectAlgorithm.java | 19 +++++++++++-------- .../documentdb/selectalgorithm/Utils.java | 10 +++++----- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java index f6cc222..3ebf38a 100644 --- a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/SelectAlgorithm.java @@ -1,12 +1,14 @@ package com.azure.documentdb.selectalgorithm; import com.azure.ai.openai.OpenAIClient; +import com.mongodb.MongoException; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoDatabase; import com.mongodb.client.model.Indexes; import org.bson.Document; +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -16,10 +18,15 @@ public class SelectAlgorithm { private static final String SAMPLE_QUERY = "quintessential lodging near running trails, eateries, retail"; private static final String DATABASE_NAME = "Hotels"; - public static void main(String[] args) { Utils.loadEnv(); - new SelectAlgorithm().run(); + try { + new SelectAlgorithm().run(); + } catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + e.printStackTrace(); + System.exit(1); + } System.exit(0); } @@ -40,7 +47,7 @@ private void validateEnvironment() { } } - public void run() { + public void run() throws IOException { validateEnvironment(); try (var mongoClient = Utils.createMongoClient()) { var openAIClient = Utils.createOpenAIClient(); @@ -66,10 +73,6 @@ public void run() { } Utils.printComparisonTable(results); - - } catch (Exception e) { - System.err.println("Error: " + e.getMessage()); - e.printStackTrace(); } } @@ -122,7 +125,7 @@ private Map testConfiguration(MongoDatabase database, OpenAIClie result.put("latency", searchResult.latencyMs); return result; - } catch (Exception e) { + } catch (MongoException | IOException e) { System.err.println(" Error testing " + algorithm + " with " + similarity + ": " + e.getMessage()); var result = new HashMap(); result.put("algorithm", algorithm.toUpperCase()); diff --git a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java index fb499c1..8a419a8 100644 --- a/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java +++ b/ai/select-algorithm-java/src/main/java/com/azure/documentdb/selectalgorithm/Utils.java @@ -5,6 +5,7 @@ import com.azure.ai.openai.models.EmbeddingsOptions; import com.azure.core.http.policy.ExponentialBackoffOptions; import com.azure.core.http.policy.RetryOptions; +import com.azure.identity.DefaultAzureCredential; import com.azure.identity.DefaultAzureCredentialBuilder; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -27,13 +28,14 @@ public class Utils { private static Dotenv dotenv; private static final ObjectMapper objectMapper = new ObjectMapper(); + private static final DefaultAzureCredential CREDENTIAL = new DefaultAzureCredentialBuilder().build(); public static void loadEnv() { try { dotenv = Dotenv.configure() .ignoreIfMissing() .load(); - } catch (Exception e) { + } catch (RuntimeException e) { System.err.println("Warning: Could not load .env file, using system environment variables"); } } @@ -54,10 +56,9 @@ public static String getEnv(String key, String defaultValue) { public static MongoClient createMongoClient() { var clusterName = getEnv("MONGO_CLUSTER_NAME"); var managedIdentityPrincipalId = getEnv("AZURE_MANAGED_IDENTITY_PRINCIPAL_ID"); - var azureCredential = new DefaultAzureCredentialBuilder().build(); MongoCredential.OidcCallback callback = (MongoCredential.OidcCallbackContext context) -> { - var token = azureCredential.getToken( + var token = CREDENTIAL.getToken( new com.azure.core.credential.TokenRequestContext() .addScopes("https://ossrdbms-aad.database.windows.net/.default") ).block(); @@ -89,11 +90,10 @@ public static MongoClient createMongoClient() { public static OpenAIClient createOpenAIClient() { var endpoint = getEnv("AZURE_OPENAI_EMBEDDING_ENDPOINT"); - var credential = new DefaultAzureCredentialBuilder().build(); return new OpenAIClientBuilder() .endpoint(endpoint) - .credential(credential) + .credential(CREDENTIAL) .retryOptions(new RetryOptions( new ExponentialBackoffOptions() .setMaxRetries(3)