Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install -e ".[tests,tracking-client,graphviz]"
python -m pip install -e ".[tests,tracking-client,tracking-server,graphviz]"

- name: Run tests
run: |
Expand Down
12 changes: 6 additions & 6 deletions burr/tracking/server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def _load_project_annotations(self, project_id: str):
annotations_path = self._get_annotation_path(project_id)
annotations = []
if os.path.exists(annotations_path):
async with aiofiles.open(annotations_path) as f:
async with aiofiles.open(annotations_path, encoding="utf-8") as f:
for line in await f.readlines():
annotations.append(AnnotationOut.parse_raw(line))
return annotations
Expand Down Expand Up @@ -348,7 +348,7 @@ async def create_annotation(
**annotation.dict(),
)
annotations_path = self._get_annotation_path(project_id)
async with aiofiles.open(annotations_path, "a") as f:
async with aiofiles.open(annotations_path, "a", encoding="utf-8") as f:
await f.write(annotation_out.json() + "\n")
return annotation_out

Expand Down Expand Up @@ -381,7 +381,7 @@ async def update_annotation(
detail=f"Annotation: {annotation_id} from project: {project_id} not found",
)
annotations_path = self._get_annotation_path(project_id)
async with aiofiles.open(annotations_path, "w") as f:
async with aiofiles.open(annotations_path, "w", encoding="utf-8") as f:
for a in all_annotations:
await f.write(a.json() + "\n")
return annotation_out
Expand All @@ -407,7 +407,7 @@ async def get_annotations(
if not os.path.exists(annotation_path):
return []
annotations = []
async with aiofiles.open(annotation_path) as f:
async with aiofiles.open(annotation_path, encoding="utf-8") as f:
for line in await f.readlines():
parsed = AnnotationOut.parse_raw(line)
if (
Expand Down Expand Up @@ -521,7 +521,7 @@ async def get_application_logs(
status_code=404,
detail=f"Graph file for app: {app_id} from project: {project_id} not found",
)
async with aiofiles.open(graph_file) as f:
async with aiofiles.open(graph_file, encoding="utf-8") as f:
str_graph = await f.read()
collections.defaultdict(list)
if os.path.exists(log_file):
Expand All @@ -530,7 +530,7 @@ async def get_application_logs(
steps = Step.from_logs(lines)
children = []
if os.path.exists(children_file):
async with aiofiles.open(children_file) as f:
async with aiofiles.open(children_file, encoding="utf-8") as f:
str_children = await f.readlines()
children = [
ChildApplicationModel.parse_obj(json.loads(item)) for item in str_children
Expand Down
104 changes: 104 additions & 0 deletions tests/tracking/test_local_tracking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import asyncio
import json
import os
import uuid
from datetime import datetime
from typing import Literal, Optional, Tuple

import pytest
Expand All @@ -37,6 +39,7 @@
ChildApplicationModel,
EndEntryModel,
EndSpanModel,
PointerModel,
)
from burr.visibility import TracerFactory

Expand Down Expand Up @@ -309,6 +312,107 @@ def test_fork_children_have_correct_partition_key(tmpdir):
assert child.event_type == "fork"


def test_local_backend_reads_utf8_annotations_graph_and_children(tmpdir, monkeypatch):
    """Round-trip non-ASCII data through ``LocalBackend`` with a UTF-8 guard installed.

    The backend's ``aiofiles.open`` is wrapped so that any text-mode open
    lacking an explicit ``encoding="utf-8"`` keyword fails the test, whatever
    the host's default encoding is. The test then writes and reads an
    annotation, and reads back a graph file and a children file, all of which
    contain accented Latin and CJK characters.
    """
    # aiofiles ships with the tracking-server extra; skip (rather than error at
    # import time) when only the tracking-client extra is installed.
    pytest.importorskip("aiofiles")
    import burr.tracking.server.backend as backend_module
    from burr.tracking.server.backend import LocalBackend
    from burr.tracking.server.schema import (
        AnnotationCreate,
        AnnotationDataPointer,
        AnnotationObservation,
    )

    # A plain round-trip would also succeed on a host whose default encoding is
    # already UTF-8, hiding a call site that forgot the keyword. Intercepting
    # the open call makes the check host-independent. Binary modes (any mode
    # string containing "b") are deliberately left unchecked.
    original_open = backend_module.aiofiles.open

    def _require_utf8_open(file, mode="r", *args, **kwargs):
        opens_as_text = "b" not in mode
        if opens_as_text:
            assert (
                kwargs.get("encoding") == "utf-8"
            ), f"text-mode open of {file} (mode={mode!r}) must pass encoding='utf-8'"
        return original_open(file, mode, *args, **kwargs)

    monkeypatch.setattr(backend_module.aiofiles, "open", _require_utf8_open)

    # Fixture values mixing accented Latin and CJK characters.
    project_name = "test_local_backend_utf8"
    app_id = "app-unicode"
    partition_key = "partici\u00f3n-ni\u00f1a"
    unicode_step_name = "an\u00e1lisis caf\u00e9 \u65e5\u672c\u8a9e"
    unicode_tag = "ni\u00f1o"
    unicode_note = "acci\u00f3n termin\u00f3 con \u00e9xito \u4f60\u597d"
    unicode_entrypoint = "inicio-caf\u00e9-\u4e16\u754c"
    unicode_child_app_id = "hijo-ni\u00f1o-\u6f22\u5b57"
    unicode_child_partition_key = "clave-ni\u00f1a"

    # On-disk layout: <tmpdir>/tracking/<project>/<app>/
    tracking_root = os.path.join(tmpdir, "tracking")
    app_dir = os.path.join(tracking_root, project_name, app_id)
    os.makedirs(app_dir)

    backend = LocalBackend(path=tracking_root)

    # --- annotations: written by the backend, then read back ---
    state_pointer = AnnotationDataPointer(
        type="state_field",
        field_name="resultado_final",
        span_id=None,
    )
    observation = AnnotationObservation(
        data_fields={"note": unicode_note},
        thumbs_up_thumbs_down=True,
        data_pointers=[state_pointer],
    )
    annotation_request = AnnotationCreate(
        span_id="span-1",
        step_name=unicode_step_name,
        tags=["revision", unicode_tag],
        observations=[observation],
    )

    created = asyncio.run(
        backend.create_annotation(annotation_request, project_name, partition_key, app_id, 1)
    )
    fetched = asyncio.run(
        backend.get_annotations(project_name, partition_key, app_id, step_sequence_id=1)
    )

    assert created.step_name == unicode_step_name
    first_fetched = fetched[0]
    assert first_fetched.tags == ["revision", unicode_tag]
    assert first_fetched.observations[0].data_fields["note"] == unicode_note

    # --- graph / log / children: written directly to disk, read by the backend ---
    graph_model = ApplicationModel(entrypoint=unicode_entrypoint, actions=[], transitions=[])
    child_pointer = PointerModel(
        app_id=unicode_child_app_id,
        sequence_id=2,
        partition_key=unicode_child_partition_key,
    )
    child_model = ChildApplicationModel(
        child=child_pointer,
        event_time=datetime.now(),
        event_type="fork",
        sequence_id=1,
    )

    def _write_app_file(filename, text):
        # Create or truncate a file inside the application directory as UTF-8.
        with open(os.path.join(app_dir, filename), "w", encoding="utf-8") as handle:
            handle.write(text)

    _write_app_file(LocalTrackingClient.GRAPH_FILENAME, graph_model.model_dump_json())
    # The log file only needs to exist; it is left empty.
    _write_app_file(LocalTrackingClient.LOG_FILENAME, "")
    _write_app_file(LocalTrackingClient.CHILDREN_FILENAME, child_model.model_dump_json() + "\n")

    logs = asyncio.run(backend.get_application_logs(None, project_name, app_id, partition_key))

    assert logs.application.entrypoint == unicode_entrypoint
    first_child = logs.children[0].child
    assert first_child.app_id == unicode_child_app_id
    assert first_child.partition_key == unicode_child_partition_key


def test_multi_fork_tracking_client(tmpdir):
"""This is more of an end-to-end test. We shoudl probably break it out
into smaller tests but the local tracking client being used as a persister is
Expand Down
Loading