Spaces:
Sleeping
Sleeping
fix: change chat-app cache logic
Browse files- .streamlit/config.toml +2 -5
- pages/chat_app.py +96 -56
.streamlit/config.toml
CHANGED
@@ -3,11 +3,10 @@ primaryColor = "#3498db"
|
|
3 |
backgroundColor = "#f9f9f9"
|
4 |
secondaryBackgroundColor = "#ffffff"
|
5 |
textColor = "#333333"
|
6 |
-
font = "sans serif"
|
7 |
|
8 |
[server]
|
9 |
enableCORS = false
|
10 |
-
enableXsrfProtection =
|
11 |
runOnSave = true
|
12 |
maxUploadSize = 200
|
13 |
headless = true
|
@@ -21,6 +20,4 @@ serverPort = 8501
|
|
21 |
fastReruns = true
|
22 |
|
23 |
[client]
|
24 |
-
showErrorDetails = false
|
25 |
-
caching = true
|
26 |
-
displayEnabled = true
|
|
|
3 |
backgroundColor = "#f9f9f9"
|
4 |
secondaryBackgroundColor = "#ffffff"
|
5 |
textColor = "#333333"
|
|
|
6 |
|
7 |
[server]
|
8 |
enableCORS = false
|
9 |
+
enableXsrfProtection = false
|
10 |
runOnSave = true
|
11 |
maxUploadSize = 200
|
12 |
headless = true
|
|
|
20 |
fastReruns = true
|
21 |
|
22 |
[client]
|
23 |
+
showErrorDetails = false
|
|
|
|
pages/chat_app.py
CHANGED
@@ -14,16 +14,11 @@ from datetime import datetime
|
|
14 |
# --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
|
15 |
load_dotenv()
|
16 |
|
17 |
-
# Initialize session state for this page
|
18 |
-
if "chat_app_initialized" not in st.session_state:
|
19 |
-
st.session_state.chat_app_initialized = True
|
20 |
-
|
21 |
# Set page config consistent with other pages
|
22 |
st.set_page_config(
|
23 |
page_title="AI Financial Dashboard",
|
24 |
page_icon="📊",
|
25 |
-
layout="wide"
|
26 |
-
initial_sidebar_state="expanded"
|
27 |
)
|
28 |
|
29 |
# Configure Gemini API
|
@@ -82,10 +77,32 @@ def find_and_process_stock(query: str):
|
|
82 |
elif len(found_data) > 1: return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
|
83 |
else: return {"status": "NO_STOCKS_FOUND"}
|
84 |
def get_smart_time_series(symbol: str, time_period: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
logic_map = {'intraday': {'interval': '15min', 'outputsize': 120}, '1_week': {'interval': '1h', 'outputsize': 40}, '1_month': {'interval': '1day', 'outputsize': 22}, '6_months': {'interval': '1day', 'outputsize': 120}, '1_year': {'interval': '1week', 'outputsize': 52}}
|
86 |
params = logic_map.get(time_period)
|
87 |
if not params: return {"error": f"Time period '{time_period}' is not valid."}
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
def find_conversion_path_bfs(start, end):
|
90 |
if start not in FOREX_GRAPH or end not in FOREX_GRAPH: return None
|
91 |
q = deque([(start, [start])]); visited = {start}
|
@@ -127,6 +144,8 @@ SYSTEM_INSTRUCTION = """You are the AI brain controlling an Interactive Financia
|
|
127 |
GOLDEN RULES:
|
128 |
1. **UNDERSTAND FIRST, CALL LATER:**
|
129 |
* **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.
|
|
|
|
|
130 |
* **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.
|
131 |
2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.
|
132 |
* **Found 1 symbol:** "I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart."
|
@@ -233,58 +252,79 @@ def render_currency_tab():
|
|
233 |
else: st.error(f"Error: {res.get('error', 'Unknown')}")
|
234 |
|
235 |
# --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
|
236 |
-
|
237 |
-
main_container = st.container()
|
238 |
-
|
239 |
-
with main_container:
|
240 |
-
st.title("📈 AI Financial Dashboard")
|
241 |
|
242 |
-
|
|
|
243 |
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
tab1, tab2, tab3 = st.tabs(tab_names)
|
253 |
-
with tab1: render_watchlist_tab()
|
254 |
-
with tab2: render_timeseries_tab()
|
255 |
-
with tab3: render_currency_tab()
|
256 |
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
268 |
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
|
15 |
load_dotenv()
|
16 |
|
|
|
|
|
|
|
|
|
17 |
# Set page config consistent with other pages
|
18 |
st.set_page_config(
|
19 |
page_title="AI Financial Dashboard",
|
20 |
page_icon="📊",
|
21 |
+
layout="wide"
|
|
|
22 |
)
|
23 |
|
24 |
# Configure Gemini API
|
|
|
77 |
elif len(found_data) > 1: return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
|
78 |
else: return {"status": "NO_STOCKS_FOUND"}
|
79 |
def get_smart_time_series(symbol: str, time_period: str):
|
80 |
+
# Kiểm tra nếu symbol chưa có trong watchlist thì thêm vào trước
|
81 |
+
if symbol not in st.session_state.stock_watchlist:
|
82 |
+
# Tìm thông tin cổ phiếu và thêm vào watchlist
|
83 |
+
results = st.session_state.td_api.get_stocks(symbol=symbol)
|
84 |
+
found_data = results.get('data', [])
|
85 |
+
if found_data:
|
86 |
+
stock_info = found_data[0]
|
87 |
+
st.session_state.stock_watchlist[symbol] = stock_info
|
88 |
+
st.session_state.active_tab = 'Time Charts'
|
89 |
+
|
90 |
logic_map = {'intraday': {'interval': '15min', 'outputsize': 120}, '1_week': {'interval': '1h', 'outputsize': 40}, '1_month': {'interval': '1day', 'outputsize': 22}, '6_months': {'interval': '1day', 'outputsize': 120}, '1_year': {'interval': '1week', 'outputsize': 52}}
|
91 |
params = logic_map.get(time_period)
|
92 |
if not params: return {"error": f"Time period '{time_period}' is not valid."}
|
93 |
+
|
94 |
+
result = st.session_state.td_api.get_time_series(symbol=symbol, **params)
|
95 |
+
|
96 |
+
# Nếu kết quả thành công, cập nhật cache
|
97 |
+
if 'values' in result:
|
98 |
+
df = pd.DataFrame(result['values'])
|
99 |
+
df['datetime'] = pd.to_datetime(df['datetime'])
|
100 |
+
df['close'] = pd.to_numeric(df['close'])
|
101 |
+
if symbol not in st.session_state.timeseries_cache:
|
102 |
+
st.session_state.timeseries_cache[symbol] = {}
|
103 |
+
st.session_state.timeseries_cache[symbol][time_period] = df.sort_values('datetime').set_index('datetime')
|
104 |
+
|
105 |
+
return result
|
106 |
def find_conversion_path_bfs(start, end):
|
107 |
if start not in FOREX_GRAPH or end not in FOREX_GRAPH: return None
|
108 |
q = deque([(start, [start])]); visited = {start}
|
|
|
144 |
GOLDEN RULES:
|
145 |
1. **UNDERSTAND FIRST, CALL LATER:**
|
146 |
* **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.
|
147 |
+
* **Stock Symbol:** When a user directly provides a stock symbol (e.g., "AAPL", "VNM"), use `find_and_process_stock` first to confirm and add it to the watchlist.
|
148 |
+
* **Time Period Request:** When user asks for specific time period (e.g., "last year", "last month"), first make sure the stock symbol is processed with `find_and_process_stock`, then use `get_smart_time_series` with appropriate time_period.
|
149 |
* **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.
|
150 |
2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.
|
151 |
* **Found 1 symbol:** "I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart."
|
|
|
252 |
else: st.error(f"Error: {res.get('error', 'Unknown')}")
|
253 |
|
254 |
# --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
|
255 |
+
st.title("📈 AI Financial Dashboard")
|
|
|
|
|
|
|
|
|
256 |
|
257 |
+
# Chia bố cục thành hai cột
|
258 |
+
col1, col2 = st.columns([1, 1])
|
259 |
|
260 |
+
# Cột bên trái cho chat với AI
|
261 |
+
with col1:
|
262 |
+
chat_container = st.container(height=600)
|
263 |
+
with chat_container:
|
264 |
+
for message in st.session_state.chat_history:
|
265 |
+
with st.chat_message(message["role"]):
|
266 |
+
st.markdown(message["parts"])
|
|
|
|
|
|
|
|
|
|
|
267 |
|
268 |
+
# Cột bên phải cho tab biểu đồ và dữ liệu
|
269 |
+
with col2:
|
270 |
+
right_column_container = st.container(height=600)
|
271 |
+
with right_column_container:
|
272 |
+
tab_names = ['Stock Watchlist', 'Time Charts', 'Currency Converter']
|
273 |
+
try: default_index = tab_names.index(st.session_state.active_tab)
|
274 |
+
except ValueError: default_index = 0
|
275 |
+
st.session_state.active_tab = tab_names[default_index]
|
276 |
+
|
277 |
+
tab1, tab2, tab3 = st.tabs(tab_names)
|
278 |
+
with tab1: render_watchlist_tab()
|
279 |
+
with tab2: render_timeseries_tab()
|
280 |
+
with tab3: render_currency_tab()
|
281 |
|
282 |
+
# Input chat nằm dưới cùng
|
283 |
+
user_prompt = st.chat_input("Ask AI to control the dashboard...")
|
284 |
+
if user_prompt:
|
285 |
+
st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
|
286 |
+
st.rerun()
|
287 |
|
288 |
+
# Xử lý câu hỏi của người dùng và hiển thị phản hồi AI
|
289 |
+
if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
|
290 |
+
last_user_prompt = st.session_state.chat_history[-1]["parts"]
|
291 |
+
|
292 |
+
with chat_container:
|
293 |
+
with st.chat_message("model"):
|
294 |
+
with st.spinner("🤖 AI executing command..."):
|
295 |
+
response = st.session_state.chat_session.send_message(last_user_prompt)
|
296 |
+
tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
|
297 |
+
|
298 |
+
while tool_calls:
|
299 |
+
tool_responses = []
|
300 |
+
for call in tool_calls:
|
301 |
+
func_name = call.name; func_args = {k: v for k, v in call.args.items()}
|
302 |
+
if func_name in AVAILABLE_FUNCTIONS:
|
303 |
+
tool_result = AVAILABLE_FUNCTIONS[func_name](**func_args)
|
304 |
+
tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'result': tool_result})))
|
305 |
+
else:
|
306 |
+
tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'error': f"Function '{func_name}' not found."})))
|
307 |
+
response = st.session_state.chat_session.send_message(glm.Content(parts=tool_responses))
|
308 |
tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
|
309 |
+
|
310 |
+
# Tìm kiếm từ khóa thời gian trong prompt của người dùng
|
311 |
+
old_period = st.session_state.active_timeseries_period
|
312 |
+
if last_user_prompt and "last year" in last_user_prompt.lower():
|
313 |
+
st.session_state.active_timeseries_period = "1_year"
|
314 |
+
elif last_user_prompt and "last 6 months" in last_user_prompt.lower():
|
315 |
+
st.session_state.active_timeseries_period = "6_months"
|
316 |
+
elif last_user_prompt and "last month" in last_user_prompt.lower():
|
317 |
+
st.session_state.active_timeseries_period = "1_month"
|
318 |
+
elif last_user_prompt and "last week" in last_user_prompt.lower():
|
319 |
+
st.session_state.active_timeseries_period = "1_week"
|
320 |
+
|
321 |
+
# Nếu thời gian thay đổi và có cổ phiếu trong watchlist, cập nhật dữ liệu
|
322 |
+
if old_period != st.session_state.active_timeseries_period and st.session_state.stock_watchlist:
|
323 |
+
new_period = st.session_state.active_timeseries_period
|
324 |
+
for symbol in st.session_state.stock_watchlist.keys():
|
325 |
+
if symbol not in st.session_state.timeseries_cache or new_period not in st.session_state.timeseries_cache[symbol]:
|
326 |
+
ts_data = get_smart_time_series(symbol, new_period)
|
327 |
+
st.session_state.active_tab = 'Time Charts'
|
328 |
+
|
329 |
+
st.session_state.chat_history.append({"role": "model", "parts": response.text})
|
330 |
+
st.rerun()
|