-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
171 lines (146 loc) · 7.42 KB
/
app.py
File metadata and controls
171 lines (146 loc) · 7.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import time
import asyncio
import chainlit as cl
from dotenv import load_dotenv
from typing import Literal, Dict, Optional
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import MessagesState
from langgraph.graph import END, StateGraph, START
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, AIMessageChunk
from langgraph.store.redis.aio import AsyncRedisStore
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
from src.graphs.reviewer_graph import create_reviewer_graph
load_dotenv()

# Redis connection URL for the store and checkpointer: read from the REDIS_URL
# environment variable (load_dotenv() above loads a .env file into the
# environment), falling back to a local Redis instance for development.
_DEFAULT_REDIS_URL = "redis://localhost:6379"
redis_url = os.getenv("REDIS_URL", _DEFAULT_REDIS_URL)
async def run_app():
    """Connect to Redis, build the reviewer graph and register the Chainlit handlers.

    An ``AsyncRedisStore`` and an ``AsyncRedisSaver`` are opened on ``redis_url``
    and initialised (``setup()`` / ``asetup()``). The graph is then built with
    ``create_reviewer_graph``. Every Chainlit handler is defined inside the
    ``async with`` body because the handlers close over ``graph`` and
    ``checkpointer``.

    NOTE(review): this coroutine returns as soon as the handlers are
    registered. The ``async with`` therefore exits (presumably closing the
    Redis connections), and ``asyncio.run`` then closes its event loop. The
    handlers still use ``graph``/``checkpointer`` afterwards, from whatever
    loop Chainlit runs them on. Confirm this works with the installed
    langgraph-checkpoint-redis version, or move the Redis setup into
    Chainlit's own event loop.
    """
    async with (
        AsyncRedisStore.from_conn_string(redis_url) as store,
        AsyncRedisSaver.from_conn_string(redis_url) as checkpointer,
    ):
        await store.setup()
        await checkpointer.asetup()
        print("Building graph...")
        graph = create_reviewer_graph(
            store=store,
            checkpointer=checkpointer,
        )
        print("Graph Built.")

        def _current_thread_id() -> str:
            """Return the conversation thread id, i.e. the signed-in user's identifier."""
            return cl.user_session.get("user").identifier

        def _reset_chat_button():
            """Return a fresh instance of the custom Reset Chat button element."""
            return cl.CustomElement(name="ResetChatButton")

        @cl.oauth_callback
        def oauth_callback(
            provider_id: str,
            token: str,
            raw_user_data: Dict[str, str],
            default_user: cl.User,
        ) -> Optional[cl.User]:
            """
            OAuth callback handler for authentication.

            Accepts any login that comes with a non-empty token and returns
            Chainlit's ``default_user``; returns ``None`` (login rejected) when
            the token is empty. ``raw_user_data`` is not inspected.

            Note (from the original author): authentication is used for session
            management; the user's identifier keys their conversation thread.
            """
            # Log the authentication attempt (without sensitive details).
            print(f"Authentication attempt from provider: {provider_id}")
            # Basic sanity check: reject callbacks that carry no token.
            if not token:
                print("Warning: Empty token in OAuth callback")
                return None
            return default_user

        @cl.action_callback("reset_chat")
        async def reset_chat(action):
            """Delete the user's conversation thread when Reset Chat is pressed.

            The ``action`` payload is not used. The status message is updated
            with the outcome, so a failure is visible to the user and not only
            printed to stdout.
            """
            status = cl.Message(content="Resetting chat... Please wait.", author="system")
            await status.send()
            try:
                await checkpointer.adelete_thread(thread_id=_current_thread_id())
            except Exception as e:
                # Best-effort reset: report the failure instead of crashing the handler.
                print(f"Error resetting chat thread: {e}")
                status.content = "Sorry, the chat could not be reset. Please try again."
            else:
                print("Chat thread has been reset successfully")
                status.content = "Chat history has been reset."
            await status.update()

        @cl.set_starters
        async def set_starters():
            """Return the starter prompts shown on an empty chat.

            Icon paths are absolute (``/public/...``) so they resolve the same
            way regardless of the current page URL.
            """
            return [
                cl.Starter(
                    label="Unemployment rates in Ireland",
                    message="Show me detailed statistics on unemployment rates in Ireland for the last 5 years.",
                    icon="/public/starters/jobs.svg",
                ),
                cl.Starter(
                    label="Population growth trends",
                    message="What's the population growth trend in Dublin compared to other cities?",
                    icon="/public/starters/population.svg",
                ),
                cl.Starter(
                    label="Ireland renewable-energy mix (shares & totals)",
                    message="Give me the breakdown of renewable energy resources as a percentage share and absolute numbers in recent years in Ireland.",
                    icon="/public/starters/energy.svg",
                    command="code",
                ),
                cl.Starter(
                    label="Impact of inflation on spending",
                    message="How has inflation affected consumer spending in the past year?",
                    icon="/public/starters/inflation.svg",
                ),
            ]

        @cl.on_chat_start
        async def on_chat_start():
            """Replay a returning user's saved conversation into the new chat session.

            If the checkpointed state for the user's thread has messages, an
            introductory message carrying the Reset Chat button is sent first.
            Then every human message and every non-empty AI message is re-sent
            in order.
            """
            config = {"configurable": {"thread_id": _current_thread_id()}}
            existing_state = await graph.aget_state(config)
            past_messages = existing_state.values.get("messages") if existing_state else None
            if not past_messages:
                # New conversation: nothing to restore.
                return

            await cl.Message(
                content="*Restoring previous conversation...*",
                elements=[_reset_chat_button()],
                author="Assistant",
            ).send()
            for past_msg in past_messages:
                if isinstance(past_msg, HumanMessage):
                    await cl.Message(content=past_msg.content, type="user_message").send()
                elif isinstance(past_msg, AIMessage) and past_msg.content:
                    # Skip AI messages with empty content (presumably tool-call-only turns).
                    await cl.Message(content=past_msg.content, type="assistant_message").send()

        @cl.on_message
        async def on_message(message: cl.Message):
            """Run the graph on the user's message and stream the reviewer agent's reply.

            A "please wait" placeholder is shown until the first token from the
            ``reviewer_agent`` node arrives. It is then replaced by a
            "Thought for N seconds" note, and tokens are streamed into the
            reply. The placeholder is removed even if nothing is streamed or
            the graph raises. On the thread's first message, a message holding
            the Reset Chat button is also sent.
            """
            start = time.perf_counter()  # monotonic clock for elapsed-time reporting
            reply = cl.Message(content="")
            pls_wait_msg = cl.Message(
                content="*Please wait while I process your request...*",
                author="Assistant",
            )
            await pls_wait_msg.send()
            streaming_started = False
            try:
                config = {"configurable": {"thread_id": _current_thread_id()}}
                # Add the Reset Chat button if this is the first message of the thread.
                existing_state = await graph.aget_state(config)
                if existing_state and not existing_state.values.get("messages"):
                    await cl.Message(
                        content="",
                        elements=[_reset_chat_button()],
                        author="Assistant",
                    ).send()

                async for chunk, metadata in graph.astream(
                    input={"messages": [HumanMessage(content=message.content, name="user")]},
                    stream_mode="messages",
                    config=config,
                ):
                    # Only stream token chunks produced by the reviewer agent node.
                    if (
                        isinstance(chunk, AIMessageChunk)
                        and chunk.content
                        and metadata.get("langgraph_node") == "reviewer_agent"
                    ):
                        if not streaming_started:
                            await pls_wait_msg.remove()
                            await cl.Message(
                                f"*Thought for {round(time.perf_counter() - start, 2)} seconds...*"
                            ).send()
                            streaming_started = True
                        await reply.stream_token(chunk.content)
            finally:
                # Never leave the placeholder behind if no reply token was streamed.
                if not streaming_started:
                    await pls_wait_msg.remove()
            await reply.send()
# Run run_app() at module load: it connects to Redis, builds the graph and
# registers the Chainlit handlers defined inside it.
# NOTE(review): there is no `if __name__ == "__main__":` guard. Presumably
# `chainlit run` loads this file as a module rather than as __main__, so a
# guard would skip handler registration. Confirm before adding one.
asyncio.run(run_app())