
Commit 06a9eb0

refactor(embed_stream): move to manually maintained files, fix magic numbers
- Move embed_stream() from auto-generated base_client.py to client.py (.fernignore)
- Move StreamedEmbedding and extraction logic to manually_maintained/streaming_embed.py
- Replace magic batch_size=10 with embed_stream_batch_size=96 from config.py (API max)
- Remove overengineered StreamingEmbedParser and ijson dependency
- Remove MEMORY_OPTIMIZATION_PROPOSAL.md
- Revert base_client.py and v2/client.py to Fern baseline
- 9 unit tests, all Fern-safe
1 parent 101d3db commit 06a9eb0

9 files changed

Lines changed: 233 additions & 1058 deletions


MEMORY_OPTIMIZATION_PROPOSAL.md

Lines changed: 0 additions & 145 deletions
This file was deleted.

src/cohere/base_client.py

Lines changed: 0 additions & 104 deletions
@@ -1128,110 +1128,6 @@ def embed(
         )
         return _response.data
 
-    def embed_stream(
-        self,
-        *,
-        texts: typing.Optional[typing.Sequence[str]] = OMIT,
-        model: typing.Optional[str] = OMIT,
-        input_type: typing.Optional[EmbedInputType] = OMIT,
-        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
-        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
-        batch_size: int = 10,
-        request_options: typing.Optional[RequestOptions] = None,
-    ) -> typing.Iterator[typing.Any]:  # Returns Iterator[StreamedEmbedding]
-        """
-        Memory-efficient streaming version of embed that yields embeddings one at a time.
-
-        This method processes texts in batches and yields individual embeddings as they are
-        parsed from the response, without loading all embeddings into memory at once.
-        Ideal for processing large datasets where memory usage is a concern.
-
-        Parameters
-        ----------
-        texts : typing.Optional[typing.Sequence[str]]
-            An array of strings for the model to embed. Will be processed in batches.
-
-        model : typing.Optional[str]
-            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
-
-        input_type : typing.Optional[EmbedInputType]
-            Specifies the type of input passed to the model.
-
-        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
-            Specifies the types of embeddings you want to get back.
-
-        truncate : typing.Optional[EmbedRequestTruncate]
-            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
-
-        batch_size : int
-            Number of texts to process in each batch. Default is 10.
-            Lower values use less memory but may be slower overall.
-
-        request_options : typing.Optional[RequestOptions]
-            Request-specific configuration.
-
-        Yields
-        ------
-        StreamedEmbedding
-            Individual embeddings as they are parsed from the response.
-
-        Examples
-        --------
-        from cohere import Client
-
-        client = Client(
-            client_name="YOUR_CLIENT_NAME",
-            token="YOUR_TOKEN",
-        )
-
-        # Process embeddings one at a time without loading all into memory
-        for embedding in client.embed_stream(
-            texts=["hello", "goodbye", "how are you"],
-            model="embed-v4.0",
-            batch_size=2
-        ):
-            print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
-            # Process/save embedding immediately
-        """
-        # Validate inputs
-        if texts is None or texts is OMIT:
-            return
-        if batch_size < 1:
-            raise ValueError("batch_size must be at least 1")
-
-        from .streaming_utils import StreamingEmbedParser
-
-        # Process texts in batches
-        texts_list = list(texts)
-        if not texts_list:
-            return
-
-        # Track text index separately from embedding index (for multiple embedding types)
-        global_text_index = 0
-
-        for batch_start in range(0, len(texts_list), batch_size):
-            batch_end = min(batch_start + batch_size, len(texts_list))
-            batch_texts = texts_list[batch_start:batch_end]
-
-            # Get response for this batch
-            response = self._raw_client.embed(
-                texts=batch_texts,
-                model=model,
-                input_type=input_type,
-                embedding_types=embedding_types,
-                truncate=truncate,
-                request_options=request_options,
-            )
-
-            # Parse embeddings from response incrementally
-            parser = StreamingEmbedParser(response._response, batch_texts)
-            for embedding in parser.iter_embeddings():
-                # The parser tracks text index per embedding type
-                # Adjust text reference to use batch_texts mapping
-                text_index_in_batch = batch_texts.index(embedding.text) if embedding.text in batch_texts else 0
-                embedding.index = batch_start + text_index_in_batch
-                yield embedding
-
     def rerank(
         self,
         *,
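
A side note on the deleted index remapping: `batch_texts.index(embedding.text)` returns the first occurrence, so a batch containing duplicate texts would assign them all the same index. A minimal standalone illustration (hypothetical input, not SDK code):

    # Hypothetical batch with a duplicate text, as could be passed to embed_stream.
    batch_texts = ["hello", "world", "hello"]

    # list.index always finds the first match, so the duplicate collapses:
    indices = [batch_texts.index(t) for t in batch_texts]
    print(indices)  # [0, 1, 0] -- the third text's embedding is mislabeled as 0

The replacement in client.py sidesteps this by deriving the index positionally from batch_start instead of from the text value.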

src/cohere/client.py

Lines changed: 56 additions & 1 deletion
@@ -12,7 +12,7 @@
 
 from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
 from .base_client import BaseCohere, AsyncBaseCohere, OMIT
-from .config import embed_batch_size
+from .config import embed_batch_size, embed_stream_batch_size
 from .core import RequestOptions
 from .environment import ClientEnvironment
 from .manually_maintained.cache import CacheMixin
@@ -223,6 +223,61 @@ def embed(
 
         return merge_embed_responses(responses)
 
+    def embed_stream(
+        self,
+        *,
+        texts: typing.Sequence[str],
+        model: typing.Optional[str] = OMIT,
+        input_type: typing.Optional[EmbedInputType] = OMIT,
+        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
+        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
+        batch_size: int = embed_stream_batch_size,
+        request_options: typing.Optional[RequestOptions] = None,
+    ) -> typing.Iterator[typing.Any]:
+        """
+        Memory-efficient embed that yields embeddings one batch at a time.
+
+        Processes texts in batches and yields individual StreamedEmbedding objects
+        as they come back, so you can write to a vector store incrementally without
+        holding all embeddings in memory.
+
+        Args:
+            texts: Texts to embed.
+            model: Embedding model ID.
+            input_type: Input type (search_document, search_query, etc.).
+            embedding_types: Types of embeddings to return (float, int8, etc.).
+            truncate: How to handle inputs longer than the max token length.
+            batch_size: Texts per API call. Defaults to 96 (API max).
+            request_options: Request-specific configuration.
+
+        Yields:
+            StreamedEmbedding with index, embedding, embedding_type, and text.
+        """
+        from .manually_maintained.streaming_embed import extract_embeddings_from_response
+
+        if not texts:
+            return
+        if batch_size < 1:
+            raise ValueError("batch_size must be at least 1")
+
+        texts_list = list(texts)
+
+        for batch_start in range(0, len(texts_list), batch_size):
+            batch_texts = texts_list[batch_start : batch_start + batch_size]
+
+            response = BaseCohere.embed(
+                self,
+                texts=batch_texts,
+                model=model,
+                input_type=input_type,
+                embedding_types=embedding_types,
+                truncate=truncate,
+                request_options=request_options,
+            )
+
+            response_data = response.dict() if hasattr(response, "dict") else response.__dict__
+            yield from extract_embeddings_from_response(response_data, batch_texts, batch_start)
+
     """
     The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
     Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
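
manually_maintained/streaming_embed.py itself is not shown in this view. The following is a plausible sketch of it, assuming only what the diff above establishes: the StreamedEmbedding fields documented in the docstring (index, embedding, embedding_type, text), the extract_embeddings_from_response(response_data, batch_texts, batch_start) call signature, and the two dict shapes a Cohere embed response can take (a plain list of vectors for embeddings_floats, a per-type mapping for embeddings_by_type). The committed implementation may differ:

    import dataclasses
    import typing


    @dataclasses.dataclass
    class StreamedEmbedding:
        """One embedding yielded by embed_stream (fields per the docstring above)."""

        index: int  # Global position of the source text across all batches
        embedding: typing.Any  # List[float] for float embeddings, List[int] for int8, etc.
        embedding_type: str  # "float", "int8", ... ("float" for embeddings_floats responses)
        text: str  # The text this embedding was computed for


    def extract_embeddings_from_response(
        response_data: typing.Dict[str, typing.Any],
        batch_texts: typing.Sequence[str],
        batch_start: int,
    ) -> typing.Iterator[StreamedEmbedding]:
        """Yield StreamedEmbedding objects from one embed() response dict."""
        embeddings = response_data.get("embeddings")
        if embeddings is None:
            return
        if isinstance(embeddings, dict):
            # embeddings_by_type: {"float": [[...], ...], "int8": [[...], ...], ...}
            for embedding_type, vectors in embeddings.items():
                if vectors is None:
                    continue
                for i, vector in enumerate(vectors):
                    yield StreamedEmbedding(
                        index=batch_start + i,
                        embedding=vector,
                        embedding_type=embedding_type,
                        text=batch_texts[i],
                    )
        else:
            # embeddings_floats: a plain list of float vectors, one per input text
            for i, vector in enumerate(embeddings):
                yield StreamedEmbedding(
                    index=batch_start + i,
                    embedding=vector,
                    embedding_type="float",
                    text=batch_texts[i],
                )

Deriving index positionally (batch_start + i) keeps it correct even when the same text appears twice, which is exactly where the removed text-based remapping went wrong.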

src/cohere/config.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 embed_batch_size = 96
+embed_stream_batch_size = 96  # Max texts per API request (API limit)
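
End to end, usage matches the example from the removed base_client.py docstring, minus the magic batch size (the 96-text default from config.py now applies; credentials are placeholders):

    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    # Embeddings arrive batch by batch (96 texts per API call by default),
    # so each one can be persisted immediately instead of buffered in memory.
    for embedding in client.embed_stream(
        texts=["hello", "goodbye", "how are you"],
        model="embed-v4.0",
    ):
        print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")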
