tosanoob commited on
Commit
e48425a
·
1 Parent(s): ac36f56

fix: change chat-app cache logic

Browse files
Files changed (2) hide show
  1. .streamlit/config.toml +2 -5
  2. pages/chat_app.py +96 -56
.streamlit/config.toml CHANGED
@@ -3,11 +3,10 @@ primaryColor = "#3498db"
3
  backgroundColor = "#f9f9f9"
4
  secondaryBackgroundColor = "#ffffff"
5
  textColor = "#333333"
6
- font = "sans serif"
7
 
8
  [server]
9
  enableCORS = false
10
- enableXsrfProtection = true
11
  runOnSave = true
12
  maxUploadSize = 200
13
  headless = true
@@ -21,6 +20,4 @@ serverPort = 8501
21
  fastReruns = true
22
 
23
  [client]
24
- showErrorDetails = false
25
- caching = true
26
- displayEnabled = true
 
3
  backgroundColor = "#f9f9f9"
4
  secondaryBackgroundColor = "#ffffff"
5
  textColor = "#333333"
 
6
 
7
  [server]
8
  enableCORS = false
9
+ enableXsrfProtection = false
10
  runOnSave = true
11
  maxUploadSize = 200
12
  headless = true
 
20
  fastReruns = true
21
 
22
  [client]
23
+ showErrorDetails = false
 
 
pages/chat_app.py CHANGED
@@ -14,16 +14,11 @@ from datetime import datetime
14
  # --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
15
  load_dotenv()
16
 
17
- # Initialize session state for this page
18
- if "chat_app_initialized" not in st.session_state:
19
- st.session_state.chat_app_initialized = True
20
-
21
  # Set page config consistent with other pages
22
  st.set_page_config(
23
  page_title="AI Financial Dashboard",
24
  page_icon="📊",
25
- layout="wide",
26
- initial_sidebar_state="expanded"
27
  )
28
 
29
  # Configure Gemini API
@@ -82,10 +77,32 @@ def find_and_process_stock(query: str):
82
  elif len(found_data) > 1: return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
83
  else: return {"status": "NO_STOCKS_FOUND"}
84
  def get_smart_time_series(symbol: str, time_period: str):
 
 
 
 
 
 
 
 
 
 
85
  logic_map = {'intraday': {'interval': '15min', 'outputsize': 120}, '1_week': {'interval': '1h', 'outputsize': 40}, '1_month': {'interval': '1day', 'outputsize': 22}, '6_months': {'interval': '1day', 'outputsize': 120}, '1_year': {'interval': '1week', 'outputsize': 52}}
86
  params = logic_map.get(time_period)
87
  if not params: return {"error": f"Time period '{time_period}' is not valid."}
88
- return st.session_state.td_api.get_time_series(symbol=symbol, **params)
 
 
 
 
 
 
 
 
 
 
 
 
89
  def find_conversion_path_bfs(start, end):
90
  if start not in FOREX_GRAPH or end not in FOREX_GRAPH: return None
91
  q = deque([(start, [start])]); visited = {start}
@@ -127,6 +144,8 @@ SYSTEM_INSTRUCTION = """You are the AI brain controlling an Interactive Financia
127
  GOLDEN RULES:
128
  1. **UNDERSTAND FIRST, CALL LATER:**
129
  * **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.
 
 
130
  * **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.
131
  2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.
132
  * **Found 1 symbol:** "I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart."
@@ -233,58 +252,79 @@ def render_currency_tab():
233
  else: st.error(f"Error: {res.get('error', 'Unknown')}")
234
 
235
  # --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
236
- # Use container to avoid UI stacking issues
237
- main_container = st.container()
238
-
239
- with main_container:
240
- st.title("📈 AI Financial Dashboard")
241
 
242
- col1, col2 = st.columns([1, 1])
 
243
 
244
- with col2:
245
- right_column_container = st.container(height=600)
246
- with right_column_container:
247
- tab_names = ['Stock Watchlist', 'Time Charts', 'Currency Converter']
248
- try: default_index = tab_names.index(st.session_state.active_tab)
249
- except ValueError: default_index = 0
250
- st.session_state.active_tab = tab_names[default_index]
251
-
252
- tab1, tab2, tab3 = st.tabs(tab_names)
253
- with tab1: render_watchlist_tab()
254
- with tab2: render_timeseries_tab()
255
- with tab3: render_currency_tab()
256
 
257
- with col1:
258
- chat_container = st.container(height=600)
259
- with chat_container:
260
- for message in st.session_state.chat_history:
261
- with st.chat_message(message["role"]):
262
- st.markdown(message["parts"])
 
 
 
 
 
 
 
263
 
264
- user_prompt = st.chat_input("Ask AI to control the dashboard...")
265
- if user_prompt:
266
- st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
267
- st.rerun()
 
268
 
269
- if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
270
- last_user_prompt = st.session_state.chat_history[-1]["parts"]
271
-
272
- with chat_container:
273
- with st.chat_message("model"):
274
- with st.spinner("🤖 AI executing command..."):
275
- response = st.session_state.chat_session.send_message(last_user_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
277
- while tool_calls:
278
- tool_responses = []
279
- for call in tool_calls:
280
- func_name = call.name; func_args = {k: v for k, v in call.args.items()}
281
- if func_name in AVAILABLE_FUNCTIONS:
282
- tool_result = AVAILABLE_FUNCTIONS[func_name](**func_args)
283
- tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'result': tool_result})))
284
- else:
285
- tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'error': f"Function '{func_name}' not found."})))
286
- response = st.session_state.chat_session.send_message(glm.Content(parts=tool_responses))
287
- tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
288
-
289
- st.session_state.chat_history.append({"role": "model", "parts": response.text})
290
- st.rerun()
 
 
 
 
 
 
 
 
 
14
  # --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
15
  load_dotenv()
16
 
 
 
 
 
17
  # Set page config consistent with other pages
18
  st.set_page_config(
19
  page_title="AI Financial Dashboard",
20
  page_icon="📊",
21
+ layout="wide"
 
22
  )
23
 
24
  # Configure Gemini API
 
77
  elif len(found_data) > 1: return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
78
  else: return {"status": "NO_STOCKS_FOUND"}
79
  def get_smart_time_series(symbol: str, time_period: str):
80
+ # Kiểm tra nếu symbol chưa có trong watchlist thì thêm vào trước
81
+ if symbol not in st.session_state.stock_watchlist:
82
+ # Tìm thông tin cổ phiếu và thêm vào watchlist
83
+ results = st.session_state.td_api.get_stocks(symbol=symbol)
84
+ found_data = results.get('data', [])
85
+ if found_data:
86
+ stock_info = found_data[0]
87
+ st.session_state.stock_watchlist[symbol] = stock_info
88
+ st.session_state.active_tab = 'Time Charts'
89
+
90
  logic_map = {'intraday': {'interval': '15min', 'outputsize': 120}, '1_week': {'interval': '1h', 'outputsize': 40}, '1_month': {'interval': '1day', 'outputsize': 22}, '6_months': {'interval': '1day', 'outputsize': 120}, '1_year': {'interval': '1week', 'outputsize': 52}}
91
  params = logic_map.get(time_period)
92
  if not params: return {"error": f"Time period '{time_period}' is not valid."}
93
+
94
+ result = st.session_state.td_api.get_time_series(symbol=symbol, **params)
95
+
96
+ # Nếu kết quả thành công, cập nhật cache
97
+ if 'values' in result:
98
+ df = pd.DataFrame(result['values'])
99
+ df['datetime'] = pd.to_datetime(df['datetime'])
100
+ df['close'] = pd.to_numeric(df['close'])
101
+ if symbol not in st.session_state.timeseries_cache:
102
+ st.session_state.timeseries_cache[symbol] = {}
103
+ st.session_state.timeseries_cache[symbol][time_period] = df.sort_values('datetime').set_index('datetime')
104
+
105
+ return result
106
  def find_conversion_path_bfs(start, end):
107
  if start not in FOREX_GRAPH or end not in FOREX_GRAPH: return None
108
  q = deque([(start, [start])]); visited = {start}
 
144
  GOLDEN RULES:
145
  1. **UNDERSTAND FIRST, CALL LATER:**
146
  * **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.
147
+ * **Stock Symbol:** When a user directly provides a stock symbol (e.g., "AAPL", "VNM"), use `find_and_process_stock` first to confirm and add it to the watchlist.
148
+ * **Time Period Request:** When user asks for specific time period (e.g., "last year", "last month"), first make sure the stock symbol is processed with `find_and_process_stock`, then use `get_smart_time_series` with appropriate time_period.
149
  * **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.
150
  2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.
151
  * **Found 1 symbol:** "I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart."
 
252
  else: st.error(f"Error: {res.get('error', 'Unknown')}")
253
 
254
  # --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
255
+ st.title("📈 AI Financial Dashboard")
 
 
 
 
256
 
257
+ # Chia bố cục thành hai cột
258
+ col1, col2 = st.columns([1, 1])
259
 
260
+ # Cột bên trái cho chat với AI
261
+ with col1:
262
+ chat_container = st.container(height=600)
263
+ with chat_container:
264
+ for message in st.session_state.chat_history:
265
+ with st.chat_message(message["role"]):
266
+ st.markdown(message["parts"])
 
 
 
 
 
267
 
268
+ # Cột bên phải cho tab biểu đồ và dữ liệu
269
+ with col2:
270
+ right_column_container = st.container(height=600)
271
+ with right_column_container:
272
+ tab_names = ['Stock Watchlist', 'Time Charts', 'Currency Converter']
273
+ try: default_index = tab_names.index(st.session_state.active_tab)
274
+ except ValueError: default_index = 0
275
+ st.session_state.active_tab = tab_names[default_index]
276
+
277
+ tab1, tab2, tab3 = st.tabs(tab_names)
278
+ with tab1: render_watchlist_tab()
279
+ with tab2: render_timeseries_tab()
280
+ with tab3: render_currency_tab()
281
 
282
+ # Input chat nằm dưới cùng
283
+ user_prompt = st.chat_input("Ask AI to control the dashboard...")
284
+ if user_prompt:
285
+ st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
286
+ st.rerun()
287
 
288
+ # Xử câu hỏi của người dùng và hiển thị phản hồi AI
289
+ if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
290
+ last_user_prompt = st.session_state.chat_history[-1]["parts"]
291
+
292
+ with chat_container:
293
+ with st.chat_message("model"):
294
+ with st.spinner("🤖 AI executing command..."):
295
+ response = st.session_state.chat_session.send_message(last_user_prompt)
296
+ tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
297
+
298
+ while tool_calls:
299
+ tool_responses = []
300
+ for call in tool_calls:
301
+ func_name = call.name; func_args = {k: v for k, v in call.args.items()}
302
+ if func_name in AVAILABLE_FUNCTIONS:
303
+ tool_result = AVAILABLE_FUNCTIONS[func_name](**func_args)
304
+ tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'result': tool_result})))
305
+ else:
306
+ tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'error': f"Function '{func_name}' not found."})))
307
+ response = st.session_state.chat_session.send_message(glm.Content(parts=tool_responses))
308
  tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
309
+
310
+ # Tìm kiếm từ khóa thời gian trong prompt của người dùng
311
+ old_period = st.session_state.active_timeseries_period
312
+ if last_user_prompt and "last year" in last_user_prompt.lower():
313
+ st.session_state.active_timeseries_period = "1_year"
314
+ elif last_user_prompt and "last 6 months" in last_user_prompt.lower():
315
+ st.session_state.active_timeseries_period = "6_months"
316
+ elif last_user_prompt and "last month" in last_user_prompt.lower():
317
+ st.session_state.active_timeseries_period = "1_month"
318
+ elif last_user_prompt and "last week" in last_user_prompt.lower():
319
+ st.session_state.active_timeseries_period = "1_week"
320
+
321
+ # Nếu thời gian thay đổi và có cổ phiếu trong watchlist, cập nhật dữ liệu
322
+ if old_period != st.session_state.active_timeseries_period and st.session_state.stock_watchlist:
323
+ new_period = st.session_state.active_timeseries_period
324
+ for symbol in st.session_state.stock_watchlist.keys():
325
+ if symbol not in st.session_state.timeseries_cache or new_period not in st.session_state.timeseries_cache[symbol]:
326
+ ts_data = get_smart_time_series(symbol, new_period)
327
+ st.session_state.active_tab = 'Time Charts'
328
+
329
+ st.session_state.chat_history.append({"role": "model", "parts": response.text})
330
+ st.rerun()