Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions apps/api/app/agents/rag/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,22 @@ def _mmr_filter(chunks: list[dict], threshold: float = MMR_SIMILARITY_THRESHOLD)
# DB search functions
# ---------------------------------------------------------------------------

def _uuid_or_none(value: UUID | str | None) -> UUID | None:
if value is None:
return None
return value if isinstance(value, UUID) else UUID(str(value))


async def _vector_search(
query_vec: list[float],
notebook_id: str | None,
db: AsyncSession,
top_k: int,
global_search: bool,
user_id: UUID | None,
*,
exclude_notebook_id: UUID | None = None,
source_id: UUID | None = None,
) -> list[dict]:
"""Vector-only search. Returns top candidates with updated_at for recency scoring."""
candidate_k = min(200, max(RERANK_CANDIDATE_K * 2, top_k * 4))
Expand Down Expand Up @@ -133,6 +142,11 @@ async def _vector_search(
Chunk.notebook_id == UUID(notebook_id) if notebook_id else text("1=1")
)

if exclude_notebook_id is not None:
stmt = stmt.where(Chunk.notebook_id != exclude_notebook_id)
if source_id is not None:
stmt = stmt.where(Chunk.source_id == source_id)

result = await db.execute(stmt)
return [
{
Expand Down Expand Up @@ -160,6 +174,9 @@ async def _fts_search(
top_k: int,
global_search: bool,
user_id: UUID | None,
*,
exclude_notebook_id: UUID | None = None,
source_id: UUID | None = None,
) -> list[dict]:
"""
PostgreSQL FTS search via plainto_tsquery (handles Chinese via 'simple' config).
Expand Down Expand Up @@ -192,6 +209,11 @@ async def _fts_search(
Chunk.notebook_id == UUID(notebook_id) if notebook_id else text("1=1")
)

if exclude_notebook_id is not None:
stmt = stmt.where(Chunk.notebook_id != exclude_notebook_id)
if source_id is not None:
stmt = stmt.where(Chunk.source_id == source_id)

result = await db.execute(stmt)
rows = result.all()
if not rows:
Expand Down Expand Up @@ -412,6 +434,8 @@ async def retrieve_chunks(
user_id: UUID | None = None,
history: list[dict] | None = None,
_precomputed_variants: list[str] | None = None,
exclude_notebook_id: UUID | str | None = None,
source_id: UUID | str | None = None,
) -> list[dict]:
"""
Full RAG retrieval pipeline:
Expand All @@ -427,6 +451,8 @@ async def retrieve_chunks(

- global_search=False: restrict to chunks in `notebook_id`
- global_search=True: search across ALL notebooks owned by `user_id`
- exclude_notebook_id: with global_search, drop chunks in this notebook
- source_id: only chunks from this source
"""
import asyncio

Expand All @@ -439,6 +465,8 @@ async def retrieve_chunks(
top_k=top_k,
global_search=global_search,
user_id=user_id,
exclude_notebook_id=_uuid_or_none(exclude_notebook_id),
source_id=_uuid_or_none(source_id),
)

# ── Step 1 & 2: Variant generation + primary embed IN PARALLEL ───────
Expand Down
9 changes: 9 additions & 0 deletions apps/api/app/agents/writing/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,20 @@ async def compose_answer(
user_memories: list[dict] | None = None,
notebook_summary: dict | None = None,
db: "AsyncSession | None" = None,
*,
extra_graph_context: str | None = None,
) -> tuple[str, list[dict]]:
"""Non-streaming: return (answer_text, citations)."""
from app.providers.llm import chat

context, citations = _build_context(chunks)
eg = (extra_graph_context or "").strip()
if eg:
context = (
f"## 结构化知识关联(图谱)\n{eg}\n\n---\n\n{context}"
if context
else f"## 结构化知识关联(图谱)\n{eg}"
)
messages = await _build_messages(query, context, history, user_memories, notebook_summary, db)
answer = await chat(messages)
return answer, citations
Expand Down
78 changes: 43 additions & 35 deletions apps/api/app/domains/ai/routers/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from uuid import UUID

from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from sqlalchemy import select

from app.dependencies import CurrentUser, DbDep
from app.domains.ai.schemas import CrossNotebookChunk, CrossNotebookOut
from app.models import Chunk, Notebook, Source, NotebookSummary
from app.models import Chunk, Notebook, NotebookSummary
from app.schemas.response import ApiResponse, success

router = APIRouter()
Expand All @@ -18,8 +18,17 @@
response_model=ApiResponse[CrossNotebookOut],
)
async def get_related_knowledge(notebook_id: UUID, current_user: CurrentUser, db: DbDep):
"""Find related content in other notebooks based on the current notebook's summary."""
from app.providers.embedding import embed_query
"""Find related content in other notebooks based on the current notebook's hybrid RAG search."""
from app.agents.rag.retrieval import retrieve_chunks

own = await db.execute(
select(Notebook.id).where(
Notebook.id == notebook_id,
Notebook.user_id == current_user.id,
)
)
if own.scalar_one_or_none() is None:
raise HTTPException(status_code=404, detail="Notebook not found")

summary_result = await db.execute(
select(NotebookSummary.summary_md).where(NotebookSummary.notebook_id == notebook_id)
Expand All @@ -28,43 +37,42 @@ async def get_related_knowledge(notebook_id: UUID, current_user: CurrentUser, db
if not summary_md or len(summary_md.strip()) < 20:
return success(CrossNotebookOut(chunks=[]))

query_vec = await embed_query(summary_md[:300])
q = summary_md[:500].strip()
raw = await retrieve_chunks(
q,
None,
db,
top_k=5,
global_search=True,
user_id=current_user.id,
exclude_notebook_id=notebook_id,
_precomputed_variants=[q],
)
if not raw:
return success(CrossNotebookOut(chunks=[]))

stmt = (
select(
Chunk.id,
Chunk.content,
Chunk.source_id,
Chunk.notebook_id,
Source.title.label("source_title"),
Notebook.title.label("notebook_title"),
(1 - Chunk.embedding.cosine_distance(query_vec)).label("score"),
)
.outerjoin(Source, Chunk.source_id == Source.id)
ids = [UUID(c["chunk_id"]) for c in raw]
meta_result = await db.execute(
select(Chunk.id, Chunk.notebook_id, Notebook.title.label("notebook_title"))
.join(Notebook, Chunk.notebook_id == Notebook.id)
.where(
Notebook.user_id == current_user.id,
Chunk.notebook_id != notebook_id,
((Source.status == "indexed") | (Chunk.source_type == "note")),
)
.order_by(Chunk.embedding.cosine_distance(query_vec))
.limit(10)
.where(Chunk.id.in_(ids))
)

result = await db.execute(stmt)
rows = result.all()
meta = {
str(r.id): (str(r.notebook_id), r.notebook_title or "未命名笔记本")
for r in meta_result.all()
}

chunks = [
CrossNotebookChunk(
notebook_title=row.notebook_title or "未命名笔记本",
source_title=row.source_title or "📝 笔记",
excerpt=row.content[:300],
score=round(float(row.score), 3),
chunk_id=str(row.id),
notebook_id=str(row.notebook_id),
notebook_title=meta.get(c["chunk_id"], (None, "未命名笔记本"))[1],
source_title=c.get("source_title") or "📝 笔记",
excerpt=(c.get("excerpt") or c.get("content") or "")[:300],
score=round(float(c.get("score") or 0), 3),
chunk_id=str(c.get("chunk_id", "")),
notebook_id=meta.get(c["chunk_id"], ("",))[0] or "",
)
for row in rows
if float(row.score) >= 0.35
][:5]
for c in raw
if c.get("chunk_id") in meta
]

return success(CrossNotebookOut(chunks=chunks))
33 changes: 26 additions & 7 deletions apps/api/app/domains/ai/routers/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ async def get_context_greeting(notebook_id: UUID, current_user: CurrentUser, db:
@router.get("/sources/{source_id}/suggestions", response_model=ApiResponse[SourceSuggestionsOut])
async def get_source_suggestions(source_id: UUID, current_user: CurrentUser, db: DbDep):
"""Generate suggested questions for a newly indexed source."""
source_result = await db.execute(select(Source).where(Source.id == source_id))
source_result = await db.execute(
select(Source)
.join(Notebook, Source.notebook_id == Notebook.id)
.where(Source.id == source_id, Notebook.user_id == current_user.id)
)
source = source_result.scalar_one_or_none()
if source is None or source.status != "indexed":
return success(SourceSuggestionsOut(summary=None, questions=[]))
Expand All @@ -175,13 +179,28 @@ async def get_source_suggestions(source_id: UUID, current_user: CurrentUser, db:
questions=source.metadata_["suggestions"],
))

chunks_result = await db.execute(
select(Chunk.content)
.where(Chunk.source_id == source_id)
.order_by(Chunk.chunk_index)
.limit(3)
from app.agents.rag.retrieval import retrieve_chunks

q = f"{source.title or ''}\n{source.summary or ''}".strip()[:800] or "资料要点"
chunk_dicts = await retrieve_chunks(
q,
str(source.notebook_id),
db,
top_k=5,
user_id=current_user.id,
source_id=source.id,
_precomputed_variants=[q],
)
context = "\n".join(row[0][:500] for row in chunks_result.all())
if chunk_dicts:
context = "\n".join((c.get("content") or "")[:500] for c in chunk_dicts)
else:
chunks_result = await db.execute(
select(Chunk.content)
.where(Chunk.source_id == source_id)
.order_by(Chunk.chunk_index)
.limit(3)
)
context = "\n".join(row[0][:500] for row in chunks_result.all())

client = get_utility_client()
try:
Expand Down
56 changes: 26 additions & 30 deletions apps/api/app/domains/ai/routers/writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from uuid import UUID

from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy import select

Expand All @@ -14,9 +14,9 @@
WritingContextOut,
WritingContextRequest,
)
from app.models import Chunk, Source
from app.models import Notebook
from app.schemas.response import ApiResponse, success
from app.providers.llm import get_client, get_utility_model, get_utility_client
from app.providers.llm import get_utility_model, get_utility_client

router = APIRouter()

Expand Down Expand Up @@ -64,43 +64,39 @@ async def generate():
@router.post("/ai/writing-context", response_model=ApiResponse[WritingContextOut])
async def get_writing_context(body: WritingContextRequest, current_user: CurrentUser, db: DbDep):
"""Return top-3 related knowledge chunks based on what the user is currently writing."""
from app.providers.embedding import embed_query
from app.agents.rag.retrieval import retrieve_chunks

text = body.text_around_cursor[:500]
if len(text.strip()) < 20:
return success(WritingContextOut(chunks=[]))

query_vec = await embed_query(text)

stmt = (
select(
Chunk.id,
Chunk.content,
Chunk.source_id,
Source.title.label("source_title"),
(1 - Chunk.embedding.cosine_distance(query_vec)).label("score"),
)
.join(Source, Chunk.source_id == Source.id)
.where(
Source.status == "indexed",
Chunk.notebook_id == UUID(body.notebook_id),
nb_row = await db.execute(
select(Notebook.id).where(
Notebook.id == UUID(body.notebook_id),
Notebook.user_id == current_user.id,
)
.order_by(Chunk.embedding.cosine_distance(query_vec))
.limit(12)
)

result = await db.execute(stmt)
rows = result.all()
if nb_row.scalar_one_or_none() is None:
raise HTTPException(status_code=404, detail="Notebook not found")

q = text.strip()
rows = await retrieve_chunks(
q,
body.notebook_id,
db,
top_k=3,
user_id=current_user.id,
_precomputed_variants=[q],
)

chunks = [
WritingContextChunk(
source_title=row.source_title or "未知来源",
excerpt=row.content[:300],
score=round(float(row.score), 3),
chunk_id=str(row.id),
source_title=r.get("source_title") or "未知来源",
excerpt=(r.get("excerpt") or r.get("content") or "")[:300],
score=round(float(r.get("score") or 0), 3),
chunk_id=str(r.get("chunk_id", "")),
)
for row in rows
if float(row.score) >= 0.35
][:3]
for r in rows
]

return success(WritingContextOut(chunks=chunks))
27 changes: 25 additions & 2 deletions apps/api/app/services/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,40 @@ async def send_message(self, conversation_id: UUID, content: str) -> Message:

history = await self._load_history(conversation_id)

from app.agents.rag.graph_retrieval import graph_augmented_context
from app.agents.rag.retrieval import retrieve_chunks
from app.agents.writing.composer import compose_answer
from app.agents.memory import get_user_memories, get_notebook_summary

user_memories = await get_user_memories(self.user_id, self.db)
notebook_summary = await get_notebook_summary(conv.notebook_id, self.db)
chunks = await retrieve_chunks(content, str(conv.notebook_id), self.db)
if conv.notebook_id:
chunks, graph_ctx = await asyncio.gather(
retrieve_chunks(
content,
str(conv.notebook_id),
self.db,
user_id=self.user_id,
),
graph_augmented_context(content, str(conv.notebook_id), self.db),
)
else:
chunks = await retrieve_chunks(
content,
None,
self.db,
global_search=True,
user_id=self.user_id,
)
graph_ctx = ""
answer, citations = await compose_answer(
content, chunks, history,
content,
chunks,
history,
user_memories=user_memories,
notebook_summary=notebook_summary,
db=self.db,
extra_graph_context=graph_ctx or None,
)

assistant_msg = Message(
Expand Down
Loading
Loading