Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions backend/src/vmarker/temp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
[PROTOCOL]: 变更时更新此头部,然后检查 CLAUDE.md
"""

import re
import shutil
import time
import uuid
Expand All @@ -20,6 +21,12 @@

BASE_DIR = Path(gettempdir()) / "vmarker"
DEFAULT_MAX_AGE_HOURS = 24
SESSION_ID_PATTERN = re.compile(r"^[0-9a-f]{12}$")


def is_valid_session_id(session_id: str) -> bool:
"""校验会话 ID,拒绝路径字符和非预期格式"""
return bool(SESSION_ID_PATTERN.fullmatch(session_id))


# =============================================================================
Expand All @@ -39,16 +46,33 @@ class TempSession:
5. 处理完成 / 异常 → 清理
"""

def __init__(self, session_id: str | None = None):
def __init__(self, session_id: str | None = None, create: bool = True):
"""
初始化会话

Args:
session_id: 会话 ID,不提供则自动生成
create: 是否创建会话目录,默认创建
"""
self.session_id = session_id or uuid.uuid4().hex[:12]
if session_id is None:
self.session_id = uuid.uuid4().hex[:12]
else:
if not is_valid_session_id(session_id):
raise ValueError(f"invalid session_id: {session_id!r}")
self.session_id = session_id
self.session_dir = BASE_DIR / self.session_id
self.session_dir.mkdir(parents=True, exist_ok=True)
if create:
self.session_dir.mkdir(parents=True, exist_ok=True)

@classmethod
def from_existing(cls, session_id: str) -> "TempSession | None":
"""打开现有会话,不隐式创建目录"""
if not is_valid_session_id(session_id):
return None
session_dir = BASE_DIR / session_id
if not session_dir.is_dir():
return None
return cls(session_id=session_id, create=False)

def save_upload(self, filename: str, content: bytes) -> Path:
"""
Expand Down Expand Up @@ -196,10 +220,11 @@ def get_session(session_id: str) -> TempSession | None:
Returns:
TempSession 实例,如果不存在返回 None
"""
session = TempSession(session_id)
return session if session.is_valid else None
return TempSession.from_existing(session_id)


def session_exists(session_id: str) -> bool:
"""检查会话是否存在"""
return (BASE_DIR / session_id).exists()
if not is_valid_session_id(session_id):
return False
return (BASE_DIR / session_id).is_dir()
53 changes: 53 additions & 0 deletions backend/tests/test_temp_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
[INPUT]: 依赖 pytest, vmarker.temp_manager
[OUTPUT]: temp_manager 模块测试
[POS]: tests/ 的临时会话管理测试
[PROTOCOL]: 变更时更新此头部,然后检查 CLAUDE.md
"""

from vmarker import temp_manager


def test_get_session_missing_returns_none_without_creating_directory(tmp_path, monkeypatch):
"""不存在的 session 不应被隐式创建"""
monkeypatch.setattr(temp_manager, "BASE_DIR", tmp_path)

session_id = "missing-session"

session = temp_manager.get_session(session_id)

assert session is None
assert not (tmp_path / session_id).exists()


def test_get_session_existing_returns_session(tmp_path, monkeypatch):
"""已存在的 session 应正常返回"""
monkeypatch.setattr(temp_manager, "BASE_DIR", tmp_path)

created = temp_manager.TempSession("abcdef123456")

session = temp_manager.get_session(created.session_id)

assert session is not None
assert session.session_id == created.session_id
assert session.session_dir == created.session_dir


def test_get_session_rejects_path_traversal_session_id(tmp_path, monkeypatch):
"""非法 session_id 不应访问 BASE_DIR 之外的路径"""
monkeypatch.setattr(temp_manager, "BASE_DIR", tmp_path)

session = temp_manager.get_session("../../../etc")

assert session is None
assert list(tmp_path.iterdir()) == []


def test_temp_session_rejects_invalid_session_id():
"""显式传入非法 session_id 时应拒绝"""
try:
temp_manager.TempSession("../../../etc", create=False)
except ValueError as exc:
assert "invalid session_id" in str(exc)
else:
raise AssertionError("TempSession should reject invalid session_id")