diff --git a/ai/select-algorithm-go/.env.example b/ai/select-algorithm-go/.env.example new file mode 100644 index 0000000..2f283ab --- /dev/null +++ b/ai/select-algorithm-go/.env.example @@ -0,0 +1,26 @@ +# Azure DocumentDB cluster name (find in Azure Portal > DocumentDB > Overview) +MONGO_CLUSTER_NAME=your-cluster-name + +# Azure OpenAI embedding endpoint (find in Azure Portal > Azure OpenAI > Keys and Endpoint) +AZURE_OPENAI_EMBEDDING_ENDPOINT=https://your-resource.openai.azure.com + +# Azure OpenAI embedding model deployment name +AZURE_OPENAI_EMBEDDING_MODEL=text-embedding-3-small + +# Database name (default: Hotels) +AZURE_DOCUMENTDB_DATABASENAME=Hotels + +# Path to pre-computed vectors JSON file (default: ../data/Hotels_Vector.json) +DATA_FILE_WITH_VECTORS=../data/Hotels_Vector.json + +# Embedding dimensions (default: 1536) +EMBEDDING_DIMENSIONS=1536 + +# Field name containing embeddings in the data file +EMBEDDED_FIELD=contentVector + +# Algorithm to test: all, diskann, hnsw, ivf (default: all) +ALGORITHM=all + +# Similarity to test: COS, L2, IP (default: COS) +SIMILARITY=COS diff --git a/ai/select-algorithm-go/.gitignore b/ai/select-algorithm-go/.gitignore new file mode 100644 index 0000000..594492e --- /dev/null +++ b/ai/select-algorithm-go/.gitignore @@ -0,0 +1,19 @@ +# Build output +build/ +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary +*.test + +# Output +*.out + +# Vendor +vendor/ + +# Environment +.env diff --git a/ai/select-algorithm-go/README.md b/ai/select-algorithm-go/README.md new file mode 100644 index 0000000..9a67bda --- /dev/null +++ b/ai/select-algorithm-go/README.md @@ -0,0 +1,158 @@ +# Select algorithm - Go + +Compare DiskANN, HNSW, and IVF vector index algorithms across COS, L2, and IP similarity metrics using Azure DocumentDB. + +## Prerequisites + +- Go 1.22+ +- Azure DocumentDB cluster +- Azure OpenAI resource with `text-embedding-3-small` deployment + +## Setup + +1. Copy `.env.example` to `.env` in this directory and fill in your values. +2. Source the `.env` file into your shell (the app reads `os.Getenv` directly): + +```bash +export $(grep -v '^#' .env | xargs) # Linux / macOS +``` + +```powershell +Get-Content .env | ForEach-Object { if ($_ -match '^\s*([^#][^=]+)=(.*)') { [System.Environment]::SetEnvironmentVariable($Matches[1].Trim(), $Matches[2].Trim()) } } +``` + +3. Install dependencies: + +```bash +go mod download +``` + +## Usage + +### Compare all algorithms (default: COS similarity) + +```bash +go run src/main.go +``` + +Set `ALGORITHM` and `SIMILARITY` env vars in `.env` to control which collections are queried: + +| ALGORITHM | SIMILARITY | Collections queried | +|-----------|------------|---------------------| +| `all` | `COS` | 3 (one per algorithm, COS) | +| `all` | `all` | 9 (all combinations) | +| `diskann` | `COS` | 1 (hotels_diskann_cos) | +| `diskann` | `all` | 3 (diskann × all similarities) | + +## Architecture + +Creates collections per algorithm and similarity combination (3 algorithms × 3 distance metrics): + +| Algorithm | COS | L2 | IP | +|-----------|-----|----|----| +| DiskANN | `hotels_diskann_cos` | `hotels_diskann_l2` | `hotels_diskann_ip` | +| HNSW | `hotels_hnsw_cos` | `hotels_hnsw_l2` | `hotels_hnsw_ip` | +| IVF | `hotels_ivf_cos` | `hotels_ivf_l2` | `hotels_ivf_ip` | + +Each collection gets its own vector index created via `db.RunCommand()` and data inserted via `InsertMany()`. The main script runs `$search` aggregation queries and prints a comparison table. + +## Expected output + +When you run the sample with `ALGORITHM=all` and `SIMILARITY=COS`, the console prints a comparison table similar to the following (exact timings and scores vary per run): + +``` +Vector Algorithm Comparison + Database: hotels + Algorithms: all + Similarity: COS + Collections to query: hotels_diskann_cos, hotels_hnsw_cos, hotels_ivf_cos + Search query: "luxury hotel with ocean view" + +Initializing MongoDB and Azure OpenAI clients... +Loading data from data/hotels.json... +Loaded 10 documents +Generating query embedding... + +Processing collection: hotels_diskann_cos + Creating collection... + Creating vector index (diskann / COS)... + Inserting 10 documents... + Running vector search... +[OK] 3 results, 42ms + +Processing collection: hotels_hnsw_cos + Creating collection... + Creating vector index (hnsw / COS)... + Inserting 10 documents... + Running vector search... +[OK] 3 results, 38ms + +Processing collection: hotels_ivf_cos + Creating collection... + Creating vector index (ivf / COS)... + Inserting 10 documents... + Running vector search... +[OK] 3 results, 35ms + ++-------------------+-----------+--------+----------------------------------+ +| Collection | Algorithm | Latency| Top Result | ++-------------------+-----------+--------+----------------------------------+ +| hotels_diskann_cos| diskann | 42 ms | Oceanfront Resort (score: 0.87) | +| hotels_hnsw_cos | hnsw | 38 ms | Oceanfront Resort (score: 0.87) | +| hotels_ivf_cos | ivf | 35 ms | Oceanfront Resort (score: 0.86) | ++-------------------+-----------+--------+----------------------------------+ + +Done. +``` + +> Results vary based on cluster size, region latency, and data. The table format and column names are stable; the values are not. + +## Notes + +- Uses passwordless authentication with `DefaultAzureCredential` for both Azure OpenAI and DocumentDB +- Environment variables are read directly via `os.Getenv`; source your `.env` file before running +- Algorithm-specific index parameters: + - DiskANN: maxDegree=32, lBuild=50 + - HNSW: m=16, efConstruction=64 + - IVF: numLists=1 +- Algorithm-specific search parameters: + - DiskANN: lSearch=100 + - HNSW: efSearch=80 + - IVF: nProbes=1 + +## Troubleshooting + +### Authentication with DefaultAzureCredential + +This sample uses `DefaultAzureCredential` for passwordless authentication to both Azure OpenAI and DocumentDB. Before running, sign in with the Azure CLI: + +```bash +az login +``` + +`DefaultAzureCredential` tries multiple credential sources in order (environment variables, managed identity, Azure CLI, and others). For local development, the Azure CLI credential is typically used. + +### RBAC role requirements + +Your Azure identity needs the following roles: + +- **Cognitive Services OpenAI User** on the Azure OpenAI resource (for embedding generation). +- A DocumentDB/Cosmos DB data-plane role that permits read and write operations on the target database. Consult your cluster's access control settings. + +### MongoDB OIDC connection errors + +If you see `MONGODB-OIDC` authentication failures: + +- Confirm the `MONGO_CLUSTER_NAME` environment variable is set correctly (cluster name only, not the full URI). +- Verify your Azure identity has been granted access to the DocumentDB cluster. +- Check that the token resource (`https://ossrdbms-aad.database.windows.net`) is correct for your cluster type. +- Ensure network connectivity to the cluster (firewall rules, VNet configuration). + +### Common error codes + +| Error | Cause | Fix | +|-------|-------|-----| +| `failed to create Azure credential` | No valid Azure credential found | Run `az login` or configure a service principal | +| `failed to connect to MongoDB` | Network or auth issue | Check cluster name, firewall rules, and RBAC | +| `failed to generate embedding` | Azure OpenAI call failed | Verify endpoint URL, deployment name, and RBAC role | +| `invalid ALGORITHM` / `invalid SIMILARITY` | Bad env var value | Use one of: `all`, `diskann`, `hnsw`, `ivf` / `all`, `COS`, `L2`, `IP` | diff --git a/ai/select-algorithm-go/go.mod b/ai/select-algorithm-go/go.mod new file mode 100644 index 0000000..3888faa --- /dev/null +++ b/ai/select-algorithm-go/go.mod @@ -0,0 +1,35 @@ +module documentdb-vector-samples + +go 1.22 + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 + github.com/openai/openai-go/v3 v3.12.0 + go.mongodb.org/mongo-driver v1.17.6 +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/golang/snappy v1.0.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/montanaflynn/stats v0.7.1 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.2.0 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + golang.org/x/crypto v0.46.0 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect +) diff --git a/ai/select-algorithm-go/go.sum b/ai/select-algorithm-go/go.sum new file mode 100644 index 0000000..c301141 --- /dev/null +++ b/ai/select-algorithm-go/go.sum @@ -0,0 +1,96 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= +github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/openai/openai-go/v3 v3.12.0 h1:NkrImaglFQeDycc/n/fEmpFV8kKr8snl9/8X2x4eHOg= +github.com/openai/openai-go/v3 v3.12.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= +github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss= +go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ai/select-algorithm-go/src/main.go b/ai/select-algorithm-go/src/main.go new file mode 100644 index 0000000..e2f6a28 --- /dev/null +++ b/ai/select-algorithm-go/src/main.go @@ -0,0 +1,543 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/azure" + "github.com/openai/openai-go/v3/option" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +type Algorithm string +type Similarity string + +const ( + DiskANN Algorithm = "diskann" + HNSW Algorithm = "hnsw" + IVF Algorithm = "ivf" +) + +const ( + COS Similarity = "COS" + L2 Similarity = "L2" + IP Similarity = "IP" +) + +var ( + AllAlgorithms = []Algorithm{DiskANN, HNSW, IVF} + AllSimilarities = []Similarity{COS, L2, IP} +) + +var AlgorithmLabels = map[Algorithm]string{ + DiskANN: "DiskANN", + HNSW: "HNSW", + IVF: "IVF", +} + +type CollectionTarget struct { + CollectionName string + Algorithm Algorithm + Similarity Similarity +} + +type SearchResult struct { + Document interface{} `bson:"document"` + Score float64 `bson:"score"` +} + +type ComparisonResult struct { + CollectionName string + Algorithm string + Similarity string + SearchResults []SearchResult + LatencyMs int64 +} + +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getTargetCollections(algorithmEnv, similarityEnv string) ([]CollectionTarget, error) { + algorithmEnv = strings.ToLower(strings.TrimSpace(algorithmEnv)) + similarityEnv = strings.ToUpper(strings.TrimSpace(similarityEnv)) + + algorithms := []Algorithm{} + if algorithmEnv == "all" { + algorithms = AllAlgorithms + } else { + algorithms = []Algorithm{Algorithm(algorithmEnv)} + } + + similarities := []Similarity{} + if similarityEnv == "all" { + similarities = AllSimilarities + } else { + similarities = []Similarity{Similarity(similarityEnv)} + } + + targets := []CollectionTarget{} + for _, alg := range algorithms { + validAlg := false + for _, validAlgorithm := range AllAlgorithms { + if alg == validAlgorithm { + validAlg = true + break + } + } + if !validAlg { + return nil, fmt.Errorf("invalid ALGORITHM '%s'. Must be one of: all, diskann, hnsw, ivf", alg) + } + + for _, sim := range similarities { + validSim := false + for _, validSimilarity := range AllSimilarities { + if sim == validSimilarity { + validSim = true + break + } + } + if !validSim { + return nil, fmt.Errorf("invalid SIMILARITY '%s'. Must be one of: all, COS, L2, IP", sim) + } + + targets = append(targets, CollectionTarget{ + CollectionName: fmt.Sprintf("hotels_%s_%s", alg, strings.ToLower(string(sim))), + Algorithm: alg, + Similarity: sim, + }) + } + } + + return targets, nil +} + +func getIndexOptions(collectionName, indexName, embeddedField string, dimensions int, algorithm Algorithm, similarity Similarity) bson.D { + cosmosSearchOptions := bson.D{ + {"dimensions", dimensions}, + {"similarity", string(similarity)}, + } + + switch algorithm { + case DiskANN: + cosmosSearchOptions = append(bson.D{{"kind", "vector-diskann"}}, cosmosSearchOptions...) + cosmosSearchOptions = append(cosmosSearchOptions, bson.E{"maxDegree", 32}) + cosmosSearchOptions = append(cosmosSearchOptions, bson.E{"lBuild", 50}) + case HNSW: + cosmosSearchOptions = append(bson.D{{"kind", "vector-hnsw"}}, cosmosSearchOptions...) + cosmosSearchOptions = append(cosmosSearchOptions, bson.E{"m", 16}) + cosmosSearchOptions = append(cosmosSearchOptions, bson.E{"efConstruction", 64}) + case IVF: + cosmosSearchOptions = append(bson.D{{"kind", "vector-ivf"}}, cosmosSearchOptions...) + cosmosSearchOptions = append(cosmosSearchOptions, bson.E{"numLists", 1}) + } + + return bson.D{ + {"createIndexes", collectionName}, + {"indexes", []bson.D{ + { + {"name", indexName}, + {"key", bson.D{{embeddedField, "cosmosSearch"}}}, + {"cosmosSearchOptions", cosmosSearchOptions}, + }, + }}, + } +} + +func getSearchPipeline(queryEmbedding []float64, embeddedField string, k int, algorithm Algorithm) []bson.M { + cosmosSearch := bson.M{ + "vector": queryEmbedding, + "path": embeddedField, + "k": k, + } + + switch algorithm { + case DiskANN: + cosmosSearch["lSearch"] = 100 + case HNSW: + cosmosSearch["efSearch"] = 80 + case IVF: + cosmosSearch["nProbes"] = 1 + } + + return []bson.M{ + {"$search": bson.M{"cosmosSearch": cosmosSearch}}, + {"$project": bson.M{ + "score": bson.M{"$meta": "searchScore"}, + "document": "$$ROOT", + }}, + } +} + +func getClientsPasswordless() (*mongo.Client, openai.Client, error) { + ctx := context.Background() + + clusterName := os.Getenv("MONGO_CLUSTER_NAME") + if clusterName == "" { + return nil, openai.Client{}, fmt.Errorf("MONGO_CLUSTER_NAME environment variable is required") + } + + credential, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, openai.Client{}, fmt.Errorf("failed to create Azure credential: %w", err) + } + + mongoURI := fmt.Sprintf("mongodb+srv://%s.mongocluster.cosmos.azure.com/", clusterName) + + oidcCallback := func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + scope := "https://ossrdbms-aad.database.windows.net/.default" + token, err := credential.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{scope}, + }) + if err != nil { + return nil, fmt.Errorf("failed to get token with scope %s: %w", scope, err) + } + + return &options.OIDCCredential{ + AccessToken: token.Token, + }, nil + } + + clientOptions := options.Client(). + ApplyURI(mongoURI). + SetConnectTimeout(120 * time.Second). + SetServerSelectionTimeout(120 * time.Second). + SetRetryWrites(false). + SetAuth(options.Credential{ + AuthMechanism: "MONGODB-OIDC", + AuthMechanismProperties: map[string]string{ + "TOKEN_RESOURCE": "https://ossrdbms-aad.database.windows.net", + }, + OIDCMachineCallback: oidcCallback, + }) + + mongoClient, err := mongo.Connect(ctx, clientOptions) + if err != nil { + return nil, openai.Client{}, fmt.Errorf("failed to connect to MongoDB: %w", err) + } + + azureOpenAIEndpoint := os.Getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT") + if azureOpenAIEndpoint == "" { + return nil, openai.Client{}, fmt.Errorf("AZURE_OPENAI_EMBEDDING_ENDPOINT environment variable is required") + } + + openAIClient := openai.NewClient( + option.WithBaseURL(fmt.Sprintf("%s/openai/v1", azureOpenAIEndpoint)), + azure.WithTokenCredential(credential)) + + return mongoClient, openAIClient, nil +} + +func readFileReturnJSON(filePath string) ([]map[string]interface{}, error) { + file, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("error reading file '%s': %w", filePath, err) + } + + var data []map[string]interface{} + err = json.Unmarshal(file, &data) + if err != nil { + return nil, fmt.Errorf("error parsing JSON in file '%s': %w", filePath, err) + } + + return data, nil +} + +func insertData(ctx context.Context, collection *mongo.Collection, data []map[string]interface{}, batchSize int) (int, int, error) { + totalDocuments := len(data) + insertedCount := 0 + failedCount := 0 + + for i := 0; i < totalDocuments; i += batchSize { + end := i + batchSize + if end > totalDocuments { + end = totalDocuments + } + + batch := data[i:end] + + documents := make([]interface{}, len(batch)) + for j, doc := range batch { + documents[j] = doc + } + + result, err := collection.InsertMany(ctx, documents, options.InsertMany().SetOrdered(false)) + if err != nil { + var bulkErr mongo.BulkWriteException + if errors.As(err, &bulkErr) { + inserted := len(batch) - len(bulkErr.WriteErrors) + insertedCount += inserted + failedCount += len(bulkErr.WriteErrors) + } else { + failedCount += len(batch) + } + } else { + insertedCount += len(result.InsertedIDs) + } + + if i+batchSize < totalDocuments { + time.Sleep(100 * time.Millisecond) + } + } + + indexColumns := []string{"HotelId", "Category", "Description", "Description_fr"} + for _, col := range indexColumns { + indexModel := mongo.IndexModel{ + Keys: bson.D{{Key: col, Value: 1}}, + } + _, err := collection.Indexes().CreateOne(ctx, indexModel) + if err != nil { + fmt.Printf("Warning: Could not create index on %s: %v\n", col, err) + } + } + + return insertedCount, failedCount, nil +} + +func generateEmbedding(ctx context.Context, client openai.Client, text, modelName string) ([]float64, error) { + resp, err := client.Embeddings.New(ctx, openai.EmbeddingNewParams{ + Input: openai.EmbeddingNewParamsInputUnion{ + OfString: openai.String(text), + }, + Model: modelName, + }) + if err != nil { + return nil, fmt.Errorf("failed to generate embedding: %w", err) + } + + if len(resp.Data) == 0 { + return nil, fmt.Errorf("no embedding data received") + } + + embedding := make([]float64, len(resp.Data[0].Embedding)) + for i, v := range resp.Data[0].Embedding { + embedding[i] = float64(v) + } + + return embedding, nil +} + +func printComparisonTable(results []ComparisonResult) { + fmt.Println("\n╔══════════════════════════════════════════════════════════════════════════════════╗") + fmt.Println("║ Vector Algorithm Comparison Results ║") + fmt.Println("╠══════════════════════════════════════════════════════════════════════════════════╣") + + fmt.Printf("║ %-12s%-14s%-24s%-12s%-14s║\n", "Algorithm", "Similarity", "Top Result", "Score", "Latency(ms)") + fmt.Println("╠══════════════════════════════════════════════════════════════════════════════════╣") + + for _, r := range results { + topName := "N/A" + topScore := "N/A" + + if len(r.SearchResults) > 0 { + topResult := r.SearchResults[0] + doc := topResult.Document.(bson.D) + for _, elem := range doc { + if elem.Key == "HotelName" { + hotelName := fmt.Sprintf("%v", elem.Value) + if len(hotelName) > 22 { + hotelName = hotelName[:22] + } + topName = hotelName + break + } + } + topScore = fmt.Sprintf("%.4f", topResult.Score) + } + + fmt.Printf("║ %-12s%-14s%-24s%-12s%-14s║\n", + r.Algorithm, + r.Similarity, + topName, + topScore, + fmt.Sprintf("%d", r.LatencyMs)) + } + + fmt.Println("╚══════════════════════════════════════════════════════════════════════════════════╝") + + for _, r := range results { + fmt.Printf("\n--- %s / %s (%s) ---\n", r.Algorithm, r.Similarity, r.CollectionName) + if len(r.SearchResults) == 0 { + fmt.Println(" No results.") + continue + } + for i, item := range r.SearchResults { + doc := item.Document.(bson.D) + var hotelName string + for _, elem := range doc { + if elem.Key == "HotelName" { + hotelName = fmt.Sprintf("%v", elem.Value) + break + } + } + fmt.Printf(" %d. %s, Score: %.4f\n", i+1, hotelName, item.Score) + } + fmt.Printf(" Latency: %dms\n", r.LatencyMs) + } +} + +func main() { + // Set environment variables before running, or source your .env file manually: + // export $(grep -v '^#' .env | xargs) # Linux/macOS + // Get-Content .env | ForEach-Object { if ($_ -match '^\s*([^#][^=]+)=(.*)') { [System.Environment]::SetEnvironmentVariable($Matches[1].Trim(), $Matches[2].Trim()) } } # PowerShell + + ctx := context.Background() + + dbName := getEnvOrDefault("AZURE_DOCUMENTDB_DATABASENAME", "Hotels") + embeddedField := getEnvOrDefault("EMBEDDED_FIELD", "DescriptionVector") + embeddingDimensions, err := strconv.Atoi(getEnvOrDefault("EMBEDDING_DIMENSIONS", "1536")) + if err != nil { + log.Fatalf("Invalid value for EMBEDDING_DIMENSIONS: %v", err) + } + dataFile := getEnvOrDefault("DATA_FILE_WITH_VECTORS", "../data/Hotels_Vector.json") + deployment := os.Getenv("AZURE_OPENAI_EMBEDDING_MODEL") + if deployment == "" { + log.Fatal("AZURE_OPENAI_EMBEDDING_MODEL environment variable is required") + } + batchSize, err := strconv.Atoi(getEnvOrDefault("LOAD_SIZE_BATCH", "100")) + if err != nil { + log.Fatalf("Invalid value for LOAD_SIZE_BATCH: %v", err) + } + algorithmEnv := getEnvOrDefault("ALGORITHM", "all") + similarityEnv := getEnvOrDefault("SIMILARITY", "COS") + searchQuery := "quintessential lodging near running trails, eateries, retail" + + targets, err := getTargetCollections(algorithmEnv, similarityEnv) + if err != nil { + log.Fatal(err) + } + + collectionNames := []string{} + for _, t := range targets { + collectionNames = append(collectionNames, t.CollectionName) + } + + fmt.Println("\nVector Algorithm Comparison") + fmt.Printf(" Database: %s\n", dbName) + fmt.Printf(" Algorithms: %s\n", algorithmEnv) + fmt.Printf(" Similarity: %s\n", similarityEnv) + fmt.Printf(" Collections to query: %s\n", strings.Join(collectionNames, ", ")) + fmt.Printf(" Search query: \"%s\"\n\n", searchQuery) + + fmt.Println("Initializing MongoDB and Azure OpenAI clients...") + mongoClient, azureOpenAIClient, err := getClientsPasswordless() + if err != nil { + log.Fatalf("Failed to initialize clients: %v", err) + } + defer mongoClient.Disconnect(context.Background()) + + db := mongoClient.Database(dbName) + + fmt.Printf("Loading data from %s...\n", dataFile) + data, err := readFileReturnJSON(dataFile) + if err != nil { + log.Fatalf("Failed to load data: %v", err) + } + fmt.Printf("Loaded %d documents\n", len(data)) + + fmt.Println("Generating query embedding...") + queryEmbedding, err := generateEmbedding(ctx, azureOpenAIClient, searchQuery, deployment) + if err != nil { + log.Fatalf("Failed to generate embedding: %v", err) + } + if len(queryEmbedding) != embeddingDimensions { + log.Fatalf("Embedding dimension mismatch: expected %d, got %d. Verify EMBEDDING_DIMENSIONS matches your model.", embeddingDimensions, len(queryEmbedding)) + } + fmt.Printf("Query embedding: %d dimensions\n\n", len(queryEmbedding)) + + comparisonResults := []ComparisonResult{} + + for _, target := range targets { + fmt.Printf("\n━━━ %s / %s ━━━\n", AlgorithmLabels[target.Algorithm], target.Similarity) + fmt.Printf("Collection: %s\n", target.CollectionName) + + if err := db.Collection(target.CollectionName).Drop(ctx); err != nil { + log.Printf("Warning: failed to drop collection %s: %v (may not exist)", target.CollectionName, err) + } + + collection := db.Collection(target.CollectionName) + fmt.Printf("Created collection: %s\n", target.CollectionName) + + inserted, failed, err := insertData(ctx, collection, data, batchSize) + if err != nil { + fmt.Printf("Error inserting data: %v\n", err) + continue + } + fmt.Printf("Inserted: %d/%d\n", inserted, len(data)) + if failed > 0 { + fmt.Printf("Failed: %d\n", failed) + } + + indexName := fmt.Sprintf("vectorIndex_%s_%s", target.Algorithm, strings.ToLower(string(target.Similarity))) + indexOptions := getIndexOptions( + target.CollectionName, + indexName, + embeddedField, + embeddingDimensions, + target.Algorithm, + target.Similarity, + ) + + var result bson.M + err = db.RunCommand(ctx, indexOptions).Decode(&result) + if err != nil { + fmt.Printf("Error creating vector index: %v\n", err) + continue + } + fmt.Printf("Created vector index: %s\n", indexName) + + fmt.Println("Executing vector search...") + startTime := time.Now() + + pipeline := getSearchPipeline(queryEmbedding, embeddedField, 5, target.Algorithm) + cursor, err := collection.Aggregate(ctx, pipeline) + if err != nil { + fmt.Printf("Error performing vector search: %v\n", err) + continue + } + + var searchResults []SearchResult + for cursor.Next(ctx) { + var result SearchResult + if err := cursor.Decode(&result); err != nil { + fmt.Printf("Warning: Could not decode result: %v\n", err) + continue + } + searchResults = append(searchResults, result) + } + cursor.Close(ctx) + + latencyMs := time.Since(startTime).Milliseconds() + + comparisonResults = append(comparisonResults, ComparisonResult{ + CollectionName: target.CollectionName, + Algorithm: AlgorithmLabels[target.Algorithm], + Similarity: string(target.Similarity), + SearchResults: searchResults, + LatencyMs: latencyMs, + }) + + fmt.Printf("[OK] %d results, %dms\n", len(searchResults), latencyMs) + } + + if len(comparisonResults) > 0 { + printComparisonTable(comparisonResults) + } + + fmt.Println("\nDone.") +} diff --git a/ai/vector-search-go/src/create_embeddings.go b/ai/vector-search-go/src/create_embeddings.go index 4550a01..8f4700a 100644 --- a/ai/vector-search-go/src/create_embeddings.go +++ b/ai/vector-search-go/src/create_embeddings.go @@ -41,7 +41,7 @@ func CreateEmbeddings(ctx context.Context, texts []string, openAIClient openai.C }) if err != nil { - return nil, fmt.Errorf("error generating embeddings: %v", err) + return nil, fmt.Errorf("error generating embeddings: %w", err) } // Extract embedding vectors from the API response @@ -87,7 +87,7 @@ func ProcessEmbeddingBatch(ctx context.Context, dataBatch []map[string]interface if len(textsToEmbed) > 0 { embeddings, err := CreateEmbeddings(ctx, textsToEmbed, openAIClient, modelName) if err != nil { - return fmt.Errorf("failed to create embeddings: %v", err) + return fmt.Errorf("failed to create embeddings: %w", err) } // Add embeddings back to the original documents @@ -118,7 +118,7 @@ func LoadEmbeddingConfig() *EmbeddingConfig { // Load environment variables from .env file err := godotenv.Load() if err != nil { - log.Printf("Warning: Error loading .env file: %v", err) + log.Printf("Warning: Error loading .env file: %w", err) } batchSize, _ := strconv.Atoi(getEnvOrDefault("EMBEDDING_SIZE_BATCH", "16")) @@ -141,7 +141,8 @@ func LoadEmbeddingConfig() *EmbeddingConfig { // 3. Processes data in batches to generate embeddings // 4. Saves the enhanced data with embeddings func main() { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() fmt.Println("Starting embedding creation process...") @@ -158,9 +159,9 @@ func main() { // Initialize clients for MongoDB and Azure OpenAI fmt.Println("\nInitializing Azure OpenAI client...") - mongoClient, azureOpenAIClient, err := GetClientsPasswordless() + mongoClient, azureOpenAIClient, err := GetClientsPasswordless(ctx) if err != nil { - log.Fatalf("Failed to initialize clients: %v", err) + log.Fatalf("Failed to initialize clients: %w", err) } defer func() { if mongoClient != nil { @@ -172,7 +173,7 @@ func main() { fmt.Printf("\nReading input data from %s...\n", config.DataWithoutVectors) data, err := ReadFileReturnJSON(config.DataWithoutVectors) if err != nil { - log.Fatalf("Failed to read input file: %v", err) + log.Fatalf("Failed to read input file: %w", err) } fmt.Printf("Loaded %d documents\n", len(data)) @@ -215,7 +216,7 @@ func main() { fmt.Printf("\nSaving enhanced data to %s...\n", config.DataWithVectors) err = WriteFileJSON(data, config.DataWithVectors) if err != nil { - log.Fatalf("Failed to save output file: %v", err) + log.Fatalf("Failed to save output file: %w", err) } fmt.Println("\nEmbedding creation completed successfully!") diff --git a/ai/vector-search-go/src/diskann.go b/ai/vector-search-go/src/diskann.go index 8991f58..85bdccd 100644 --- a/ai/vector-search-go/src/diskann.go +++ b/ai/vector-search-go/src/diskann.go @@ -67,7 +67,7 @@ func CreateDiskANNVectorIndex(ctx context.Context, collection *mongo.Collection, fmt.Println(" • Use HNSW instead: go run src/hnsw.go") fmt.Println(" • Use IVF instead: go run src/ivf.go") } - return fmt.Errorf("error creating DiskANN vector index: %v", err) + return fmt.Errorf("error creating DiskANN vector index: %w", err) } fmt.Println("DiskANN vector index created successfully") @@ -81,7 +81,7 @@ func PerformDiskANNVectorSearch(ctx context.Context, collection *mongo.Collectio // Generate embedding for the query text queryEmbedding, err := GenerateEmbedding(ctx, openAIClient, queryText, modelName) if err != nil { - return nil, fmt.Errorf("error generating embedding: %v", err) + return nil, fmt.Errorf("error generating embedding: %w", err) } // Construct the aggregation pipeline for vector search @@ -115,7 +115,7 @@ func PerformDiskANNVectorSearch(ctx context.Context, collection *mongo.Collectio // Execute the aggregation pipeline cursor, err := collection.Aggregate(ctx, pipeline) if err != nil { - return nil, fmt.Errorf("error performing DiskANN vector search: %v", err) + return nil, fmt.Errorf("error performing DiskANN vector search: %w", err) } defer cursor.Close(ctx) @@ -130,7 +130,7 @@ func PerformDiskANNVectorSearch(ctx context.Context, collection *mongo.Collectio } if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %v", err) + return nil, fmt.Errorf("cursor error: %w", err) } return results, nil @@ -138,15 +138,16 @@ func PerformDiskANNVectorSearch(ctx context.Context, collection *mongo.Collectio // main function demonstrates DiskANN vector search functionality func main() { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() // Load configuration from environment variables config := LoadConfig() fmt.Println("\nInitializing MongoDB and Azure OpenAI clients...") - mongoClient, azureOpenAIClient, err := GetClientsPasswordless() + mongoClient, azureOpenAIClient, err := GetClientsPasswordless(ctx) if err != nil { - log.Fatalf("Failed to initialize clients: %v", err) + log.Fatalf("Failed to initialize clients: %w", err) } defer mongoClient.Disconnect(ctx) @@ -158,7 +159,7 @@ func main() { fmt.Printf("\nLoading data from %s...\n", config.DataFile) data, err := ReadFileReturnJSON(config.DataFile) if err != nil { - log.Fatalf("Failed to load data: %v", err) + log.Fatalf("Failed to load data: %w", err) } fmt.Printf("Loaded %d documents\n", len(data)) @@ -180,7 +181,7 @@ func main() { // Clear existing data to ensure clean state deleteResult, err := collection.DeleteMany(ctx, bson.M{}) if err != nil { - log.Fatalf("Failed to clear existing data: %v", err) + log.Fatalf("Failed to clear existing data: %w", err) } if deleteResult.DeletedCount > 0 { fmt.Printf("Cleared %d existing documents from collection\n", deleteResult.DeletedCount) @@ -189,7 +190,7 @@ func main() { // Insert the hotel data stats, err := InsertData(ctx, collection, documentsWithEmbeddings, config.BatchSize, nil) if err != nil { - log.Fatalf("Failed to insert data: %v", err) + log.Fatalf("Failed to insert data: %w", err) } if stats.Inserted == 0 { @@ -201,7 +202,7 @@ func main() { // Create DiskANN vector index err = CreateDiskANNVectorIndex(ctx, collection, config.VectorField, config.Dimensions) if err != nil { - log.Fatalf("Failed to create DiskANN vector index: %v", err) + log.Fatalf("Failed to create DiskANN vector index: %w", err) } // Wait briefly for index to be ready @@ -221,7 +222,7 @@ func main() { 5, ) if err != nil { - log.Fatalf("Failed to perform vector search: %v", err) + log.Fatalf("Failed to perform vector search: %w", err) } // Display results diff --git a/ai/vector-search-go/src/hnsw.go b/ai/vector-search-go/src/hnsw.go index ab6977c..c7d57e7 100644 --- a/ai/vector-search-go/src/hnsw.go +++ b/ai/vector-search-go/src/hnsw.go @@ -66,7 +66,7 @@ func CreateHNSWVectorIndex(ctx context.Context, collection *mongo.Collection, ve fmt.Println(" • Use IVF instead: go run src/ivf.go") fmt.Println(" • Use DiskANN instead: go run src/diskann.go") } - return fmt.Errorf("error creating HNSW vector index: %v", err) + return fmt.Errorf("error creating HNSW vector index: %w", err) } fmt.Println("HNSW vector index created successfully") @@ -80,7 +80,7 @@ func PerformHNSWVectorSearch(ctx context.Context, collection *mongo.Collection, // Convert query text to embedding vector queryEmbedding, err := GenerateEmbedding(ctx, openAIClient, queryText, modelName) if err != nil { - return nil, fmt.Errorf("error generating embedding: %v", err) + return nil, fmt.Errorf("error generating embedding: %w", err) } // Build aggregation pipeline for HNSW vector search @@ -114,7 +114,7 @@ func PerformHNSWVectorSearch(ctx context.Context, collection *mongo.Collection, // Execute the search pipeline cursor, err := collection.Aggregate(ctx, pipeline) if err != nil { - return nil, fmt.Errorf("error performing HNSW vector search: %v", err) + return nil, fmt.Errorf("error performing HNSW vector search: %w", err) } defer cursor.Close(ctx) @@ -129,7 +129,7 @@ func PerformHNSWVectorSearch(ctx context.Context, collection *mongo.Collection, } if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %v", err) + return nil, fmt.Errorf("cursor error: %w", err) } return results, nil @@ -139,15 +139,16 @@ func PerformHNSWVectorSearch(ctx context.Context, collection *mongo.Collection, func main() { fmt.Println("Starting HNSW vector search demonstration...") - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() // Load configuration from environment variables config := LoadConfig() fmt.Println("\nInitializing MongoDB and Azure OpenAI clients...") - mongoClient, azureOpenAIClient, err := GetClientsPasswordless() + mongoClient, azureOpenAIClient, err := GetClientsPasswordless(ctx) if err != nil { - log.Fatalf("Failed to initialize clients: %v", err) + log.Fatalf("Failed to initialize clients: %w", err) } defer mongoClient.Disconnect(ctx) @@ -159,7 +160,7 @@ func main() { fmt.Printf("\nLoading data from %s...\n", config.DataFile) data, err := ReadFileReturnJSON(config.DataFile) if err != nil { - log.Fatalf("Failed to load data: %v", err) + log.Fatalf("Failed to load data: %w", err) } fmt.Printf("Loaded %d documents\n", len(data)) @@ -181,7 +182,7 @@ func main() { // Clear any existing data to start fresh deleteResult, err := collection.DeleteMany(ctx, bson.M{}) if err != nil { - log.Fatalf("Failed to clear existing data: %v", err) + log.Fatalf("Failed to clear existing data: %w", err) } if deleteResult.DeletedCount > 0 { fmt.Printf("Cleared %d existing documents from collection\n", deleteResult.DeletedCount) @@ -190,7 +191,7 @@ func main() { // Insert hotel data with embeddings stats, err := InsertData(ctx, collection, documentsWithEmbeddings, config.BatchSize, nil) if err != nil { - log.Fatalf("Failed to insert data: %v", err) + log.Fatalf("Failed to insert data: %w", err) } if stats.Inserted == 0 { @@ -203,7 +204,7 @@ func main() { fmt.Println("\nCreating HNSW vector index...") err = CreateHNSWVectorIndex(ctx, collection, config.VectorField, config.Dimensions) if err != nil { - log.Fatalf("Failed to create HNSW vector index: %v", err) + log.Fatalf("Failed to create HNSW vector index: %w", err) } // Allow time for index to become ready @@ -224,7 +225,7 @@ func main() { 16, // efSearch (not used directly in DocumentDB but kept for API consistency) ) if err != nil { - log.Fatalf("Failed to perform HNSW vector search: %v", err) + log.Fatalf("Failed to perform HNSW vector search: %w", err) } // Display the search results diff --git a/ai/vector-search-go/src/ivf.go b/ai/vector-search-go/src/ivf.go index 2aeddd8..306c28c 100644 --- a/ai/vector-search-go/src/ivf.go +++ b/ai/vector-search-go/src/ivf.go @@ -63,7 +63,7 @@ func CreateIVFVectorIndex(ctx context.Context, collection *mongo.Collection, vec fmt.Println(" • Use HNSW instead: go run src/hnsw.go") fmt.Println(" • Use DiskANN instead: go run src/diskann.go") } - return fmt.Errorf("error creating IVF vector index: %v", err) + return fmt.Errorf("error creating IVF vector index: %w", err) } fmt.Println("IVF vector index created successfully") @@ -77,7 +77,7 @@ func PerformIVFVectorSearch(ctx context.Context, collection *mongo.Collection, o // Generate embedding vector for the search query queryEmbedding, err := GenerateEmbedding(ctx, openaAIClient, queryText, modelName) if err != nil { - return nil, fmt.Errorf("error generating embedding: %v", err) + return nil, fmt.Errorf("error generating embedding: %w", err) } // Construct aggregation pipeline for IVF vector search @@ -111,7 +111,7 @@ func PerformIVFVectorSearch(ctx context.Context, collection *mongo.Collection, o // Execute the aggregation pipeline cursor, err := collection.Aggregate(ctx, pipeline) if err != nil { - return nil, fmt.Errorf("error performing IVF vector search: %v", err) + return nil, fmt.Errorf("error performing IVF vector search: %w", err) } defer cursor.Close(ctx) @@ -126,7 +126,7 @@ func PerformIVFVectorSearch(ctx context.Context, collection *mongo.Collection, o } if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %v", err) + return nil, fmt.Errorf("cursor error: %w", err) } return results, nil @@ -136,15 +136,16 @@ func PerformIVFVectorSearch(ctx context.Context, collection *mongo.Collection, o func main() { fmt.Println("Starting IVF vector search demonstration...") - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() // Load configuration from environment variables config := LoadConfig() fmt.Println("\nInitializing MongoDB and Azure OpenAI clients...") - mongoClient, azureOpenAIClient, err := GetClientsPasswordless() + mongoClient, azureOpenAIClient, err := GetClientsPasswordless(ctx) if err != nil { - log.Fatalf("Failed to initialize clients: %v", err) + log.Fatalf("Failed to initialize clients: %w", err) } defer mongoClient.Disconnect(ctx) @@ -156,7 +157,7 @@ func main() { fmt.Printf("\nLoading data from %s...\n", config.DataFile) data, err := ReadFileReturnJSON(config.DataFile) if err != nil { - log.Fatalf("Failed to load data: %v", err) + log.Fatalf("Failed to load data: %w", err) } fmt.Printf("Loaded %d documents\n", len(data)) @@ -178,7 +179,7 @@ func main() { // Remove any existing data for clean state deleteResult, err := collection.DeleteMany(ctx, bson.M{}) if err != nil { - log.Fatalf("Failed to clear existing data: %v", err) + log.Fatalf("Failed to clear existing data: %w", err) } if deleteResult.DeletedCount > 0 { fmt.Printf("Cleared %d existing documents from collection\n", deleteResult.DeletedCount) @@ -187,7 +188,7 @@ func main() { // Insert hotel data with embeddings stats, err := InsertData(ctx, collection, documentsWithEmbeddings, config.BatchSize, nil) if err != nil { - log.Fatalf("Failed to insert data: %v", err) + log.Fatalf("Failed to insert data: %w", err) } if stats.Inserted == 0 { @@ -200,7 +201,7 @@ func main() { fmt.Println("\nCreating IVF vector index...") err = CreateIVFVectorIndex(ctx, collection, config.VectorField, config.Dimensions) if err != nil { - log.Fatalf("Failed to create IVF vector index: %v", err) + log.Fatalf("Failed to create IVF vector index: %w", err) } // Wait for index to be built and ready @@ -221,7 +222,7 @@ func main() { 1, // numProbes (not used in DocumentDB but kept for API consistency) ) if err != nil { - log.Fatalf("Failed to perform IVF vector search: %v", err) + log.Fatalf("Failed to perform IVF vector search: %w", err) } // Display the search results diff --git a/ai/vector-search-go/src/show_indexes.go b/ai/vector-search-go/src/show_indexes.go index 00e758e..9c33d69 100644 --- a/ai/vector-search-go/src/show_indexes.go +++ b/ai/vector-search-go/src/show_indexes.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "strings" + "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -138,7 +139,7 @@ func showCollectionIndexes(ctx context.Context, collection *mongo.Collection, co var indexes []IndexInfo if err := cursor.All(ctx, &indexes); err != nil { - return fmt.Errorf("error decoding indexes: %v", err) + return fmt.Errorf("error decoding indexes: %w", err) } if len(indexes) == 0 { @@ -172,7 +173,7 @@ func showDatabaseCollectionsAndIndexes(ctx context.Context, database *mongo.Data // Get list of all collections in the database collectionNames, err := database.ListCollectionNames(ctx, bson.M{}) if err != nil { - return fmt.Errorf("error accessing database '%s': %v", databaseName, err) + return fmt.Errorf("error accessing database '%s': %w", databaseName, err) } if len(collectionNames) == 0 { @@ -208,7 +209,8 @@ func showDatabaseCollectionsAndIndexes(ctx context.Context, database *mongo.Data // main function displays vector indexes and collection information func main() { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() fmt.Println("Vector Index Information Display") fmt.Printf("%s\n", strings.Repeat("=", 50)) @@ -221,9 +223,9 @@ func main() { // Initialize MongoDB client fmt.Println("\nConnecting to MongoDB...") - mongoClient, _, err := GetClientsPasswordless() + mongoClient, _, err := GetClientsPasswordless(ctx) if err != nil { - log.Fatalf("Failed to initialize MongoDB client: %v", err) + log.Fatalf("Failed to initialize MongoDB client: %w", err) } defer mongoClient.Disconnect(ctx) diff --git a/ai/vector-search-go/src/utils.go b/ai/vector-search-go/src/utils.go index cd02a5b..ab9948a 100644 --- a/ai/vector-search-go/src/utils.go +++ b/ai/vector-search-go/src/utils.go @@ -61,7 +61,7 @@ func LoadConfig() *Config { // services instead of .env files. For development/demo purposes only. err := godotenv.Load() if err != nil { - log.Printf("Warning: Error loading .env file: %v", err) + log.Printf("Warning: Error loading .env file: %w", err) } dimensions, _ := strconv.Atoi(getEnvOrDefault("EMBEDDING_DIMENSIONS", "1536")) @@ -88,8 +88,7 @@ func getEnvOrDefault(key, defaultValue string) string { } // GetClients creates MongoDB and Azure OpenAI clients with connection string authentication -func GetClients() (*mongo.Client, openai.Client, error) { - ctx := context.Background() +func GetClients(ctx context.Context) (*mongo.Client, openai.Client, error) { // Get MongoDB connection string mongoConnectionString := os.Getenv("MONGO_CONNECTION_STRING") @@ -109,13 +108,13 @@ func GetClients() (*mongo.Client, openai.Client, error) { mongoClient, err := mongo.Connect(ctx, clientOptions) if err != nil { - return nil, openai.Client{}, fmt.Errorf("failed to connect to MongoDB: %v", err) + return nil, openai.Client{}, fmt.Errorf("failed to connect to MongoDB: %w", err) } // Test the connection err = mongoClient.Ping(ctx, nil) if err != nil { - return nil, openai.Client{}, fmt.Errorf("failed to ping MongoDB: %v", err) + return nil, openai.Client{}, fmt.Errorf("failed to ping MongoDB: %w", err) } // Get Azure OpenAI configuration @@ -135,8 +134,7 @@ func GetClients() (*mongo.Client, openai.Client, error) { } // GetClientsPasswordless creates MongoDB and Azure OpenAI clients with passwordless authentication -func GetClientsPasswordless() (*mongo.Client, openai.Client, error) { - ctx := context.Background() +func GetClientsPasswordless(ctx context.Context) (*mongo.Client, openai.Client, error) { // Get MongoDB cluster name clusterName := os.Getenv("MONGO_CLUSTER_NAME") @@ -147,7 +145,7 @@ func GetClientsPasswordless() (*mongo.Client, openai.Client, error) { // Create Azure credential credential, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { - return nil, openai.Client{}, fmt.Errorf("failed to create Azure credential: %v", err) + return nil, openai.Client{}, fmt.Errorf("failed to create Azure credential: %w", err) } // Attempt OIDC authentication @@ -156,7 +154,7 @@ func GetClientsPasswordless() (*mongo.Client, openai.Client, error) { fmt.Println("Attempting OIDC authentication...") mongoClient, err := connectWithOIDC(ctx, mongoURI, credential) if err != nil { - return nil, openai.Client{}, fmt.Errorf("OIDC authentication failed: %v", err) + return nil, openai.Client{}, fmt.Errorf("OIDC authentication failed: %w", err) } fmt.Println("OIDC authentication successful!") @@ -184,7 +182,7 @@ func connectWithOIDC(ctx context.Context, mongoURI string, credential *azidentit Scopes: []string{scope}, }) if err != nil { - return nil, fmt.Errorf("failed to get token with scope %s: %v", scope, err) + return nil, fmt.Errorf("failed to get token with scope %s: %w", scope, err) } fmt.Printf("Successfully obtained token") @@ -238,13 +236,13 @@ func connectWithConnectionString(ctx context.Context, connectionString string) ( func ReadFileReturnJSON(filePath string) ([]map[string]interface{}, error) { file, err := os.ReadFile(filePath) if err != nil { - return nil, fmt.Errorf("error reading file '%s': %v", filePath, err) + return nil, fmt.Errorf("error reading file '%s': %w", filePath, err) } var data []map[string]interface{} err = json.Unmarshal(file, &data) if err != nil { - return nil, fmt.Errorf("error parsing JSON in file '%s': %v", filePath, err) + return nil, fmt.Errorf("error parsing JSON in file '%s': %w", filePath, err) } return data, nil @@ -254,12 +252,12 @@ func ReadFileReturnJSON(filePath string) ([]map[string]interface{}, error) { func WriteFileJSON(data []map[string]interface{}, filePath string) error { jsonData, err := json.MarshalIndent(data, "", " ") if err != nil { - return fmt.Errorf("error marshalling data to JSON: %v", err) + return fmt.Errorf("error marshalling data to JSON: %w", err) } err = os.WriteFile(filePath, jsonData, 0644) if err != nil { - return fmt.Errorf("error writing to file '%s': %v", filePath, err) + return fmt.Errorf("error writing to file '%s': %w", filePath, err) } fmt.Printf("Data successfully written to '%s'\n", filePath) @@ -346,7 +344,7 @@ func DropVectorIndexes(ctx context.Context, collection *mongo.Collection, vector // Get all indexes for the collection cursor, err := collection.Indexes().List(ctx) if err != nil { - return fmt.Errorf("could not list indexes: %v", err) + return fmt.Errorf("could not list indexes: %w", err) } defer cursor.Close(ctx) @@ -432,7 +430,7 @@ func GenerateEmbedding(ctx context.Context, client openai.Client, text, modelNam Model: modelName, }) if err != nil { - return nil, fmt.Errorf("failed to generate embedding: %v", err) + return nil, fmt.Errorf("failed to generate embedding: %w", err) } if len(resp.Data) == 0 {