Ali2206 committed (verified)
Commit e24be23 · Parent: 9c0d5a4

Update app.py

Files changed (1):
  app.py (+124 -52)
app.py CHANGED
@@ -4,53 +4,51 @@ import pandas as pd
 import pdfplumber
 import json
 import gradio as gr
-from typing import List
+from typing import List, Optional
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import hashlib
 import shutil
+import time
+from functools import lru_cache
 
-# Fix: Add src to Python path
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
-
-# ✅ Persist model cache to Hugging Face Space's /data directory
-model_cache_dir = "/data/txagent_models"
+# Environment and path setup
+current_dir = os.path.dirname(os.path.abspath(__file__))
+src_path = os.path.abspath(os.path.join(current_dir, "src"))
+print(f"Adding to path: {src_path}")
+sys.path.insert(0, src_path)
+
+# Configure cache directories
+base_dir = "/data"
+model_cache_dir = os.path.join(base_dir, "txagent_models")
+tool_cache_dir = os.path.join(base_dir, "tool_cache")
+file_cache_dir = os.path.join(base_dir, "cache")
+
 os.makedirs(model_cache_dir, exist_ok=True)
+os.makedirs(tool_cache_dir, exist_ok=True)
+os.makedirs(file_cache_dir, exist_ok=True)
+
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 os.environ["HF_HOME"] = model_cache_dir
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
 from txagent.txagent import TxAgent
 
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
 
-def clean_final_response(text: str) -> str:
-    cleaned = text.replace("[TOOL_CALLS]", "").strip()
-    responses = cleaned.split("[Final Analysis]")
-
-    if len(responses) <= 1:
-        return f"<div style='padding:1em;border:1px solid #ccc;border-radius:12px;color:#fff;background:#353F54;'><p>{cleaned}</p></div>"
-
-    panels = []
-    for i, section in enumerate(responses[1:], 1):
-        final = section.strip()
-        panels.append(
-            f"<div style='background:#2B2B2B;color:#E0E0E0;border-radius:12px;margin-bottom:1em;border:1px solid #888;'>"
-            f"<div style='font-size:1.1em;font-weight:bold;padding:0.75em;background:#3A3A3A;color:#fff;border-radius:12px 12px 0 0;'>🧠 Final Analysis #{i}</div>"
-            f"<div style='padding:1em;line-height:1.6;'>{final.replace(chr(10), '<br>')}</div>"
-            f"</div>"
-        )
-    return "".join(panels)
-
-def file_hash(path):
+def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
+@lru_cache(maxsize=100)
+def get_cached_response(prompt: str, file_hash: str) -> Optional[str]:
+    return None
+
 def convert_file_to_json(file_path: str, file_type: str) -> str:
     try:
-        cache_dir = "/data/cache"
-        os.makedirs(cache_dir, exist_ok=True)
         h = file_hash(file_path)
-        cache_path = os.path.join(cache_dir, f"{h}.json")
+        cache_path = os.path.join(file_cache_dir, f"{h}.json")
 
         if os.path.exists(cache_path):
            return open(cache_path, "r", encoding="utf-8").read()
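
Note: the cache introduced above keys converted output by the MD5 of the file's raw bytes, so re-uploading identical content skips re-parsing entirely. A minimal self-contained sketch of that pattern, separate from the commit itself (`expensive_convert` is a hypothetical stand-in for the real pdfplumber/pandas conversion, and the chunked read is an assumption, since the commit's `file_hash` reads the whole file into memory at once):

    import hashlib
    import json
    import os
    import tempfile

    cache_dir = tempfile.mkdtemp()  # stand-in for the Space's /data/cache

    def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
        # Hash in 1 MiB chunks so large uploads need not fit in memory.
        h = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    def expensive_convert(path: str) -> str:
        # Hypothetical stand-in for the real PDF/Excel parsing.
        return json.dumps({"filename": os.path.basename(path)})

    def convert_cached(path: str) -> str:
        # Same bytes -> same digest -> cache hit, regardless of filename.
        cache_path = os.path.join(cache_dir, f"{file_md5(path)}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
        result = expensive_convert(path)
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
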
@@ -66,7 +64,8 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
         with pdfplumber.open(file_path) as pdf:
             text = "\n".join([page.extract_text() or "" for page in pdf.pages])
         result = json.dumps({"filename": os.path.basename(file_path), "content": text.strip()})
-        open(cache_path, "w", encoding="utf-8").write(result)
+        with open(cache_path, "w", encoding="utf-8") as f:
+            f.write(result)
         return result
     else:
         return json.dumps({"error": f"Unsupported file type: {file_type}"})
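
Note: the `with open(...)` change above guarantees the cache file is flushed and closed, but the write is still not atomic: a process killed mid-write leaves a truncated JSON file that later requests will happily return as a cache hit. A write-then-rename sketch, offered as a hardening assumption rather than anything the commit does:

    import os
    import tempfile

    def write_cache_atomically(cache_path: str, payload: str) -> None:
        # Write to a temp file in the same directory, then rename over
        # the target; os.replace() is atomic on POSIX filesystems.
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(cache_path) or ".")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(payload)
            os.replace(tmp_path, cache_path)
        except BaseException:
            os.unlink(tmp_path)
            raise
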
@@ -77,11 +76,49 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
         df = df.fillna("")
         content = df.astype(str).values.tolist()
         result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-        open(cache_path, "w", encoding="utf-8").write(result)
+        with open(cache_path, "w", encoding="utf-8") as f:
+            f.write(result)
         return result
     except Exception as e:
         return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})
 
+def convert_files_to_json_parallel(uploaded_files: list) -> str:
+    extracted_text = []
+    with ThreadPoolExecutor(max_workers=4) as executor:
+        futures = []
+        for file in uploaded_files:
+            if not hasattr(file, 'name'):
+                continue
+            path = file.name
+            ext = path.split(".")[-1].lower()
+            futures.append(executor.submit(convert_file_to_json, path, ext))
+
+        for future in as_completed(futures):
+            extracted_text.append(sanitize_utf8(future.result()))
+    return "\n".join(extracted_text)
+
+def init_agent():
+    default_tool_path = os.path.abspath("data/new_tool.json")
+    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+    if not os.path.exists(target_tool_path):
+        shutil.copy(default_tool_path, target_tool_path)
+
+    model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
+    rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
+
+    agent = TxAgent(
+        model_name=model_name,
+        rag_model_name=rag_model_name,
+        tool_files_dict={"new_tool": target_tool_path},
+        force_finish=True,
+        enable_checker=True,
+        step_rag_num=8,
+        seed=100,
+        additional_default_tools=[]
+    )
+    agent.init_model()
+    return agent
+
 def create_ui(agent: TxAgent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("<h1 style='text-align: center;'>📋 CPS: Clinical Patient Support System</h1>")
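
Note: one behavioral detail of the new `convert_files_to_json_parallel`: `as_completed` yields futures in completion order, so the joined records can come back out of upload order. If record order should match upload order, `executor.map` preserves it while still running the conversions concurrently; a sketch under that assumption:

    from concurrent.futures import ThreadPoolExecutor

    def convert_all_in_order(paths, convert):
        # executor.map returns results in input order, even though the
        # underlying calls run concurrently on the worker threads.
        with ThreadPoolExecutor(max_workers=4) as executor:
            return "\n".join(executor.map(convert, paths))

    # Possible wiring to the module's converter:
    # extracted_text = convert_all_in_order(
    #     [f.name for f in uploaded_files if hasattr(f, "name")],
    #     lambda p: convert_file_to_json(p, p.rsplit(".", 1)[-1].lower()),
    # )
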
@@ -97,60 +134,66 @@ def create_ui(agent: TxAgent):
         conversation_state = gr.State([])
 
         def handle_chat(message: str, history: list, conversation: list, uploaded_files: list, progress=gr.Progress()):
+            start_time = time.time()
             try:
                 history.append({"role": "user", "content": message})
                 history.append({"role": "assistant", "content": "⏳ Processing your request..."})
                 yield history
 
+                file_process_time = time.time()
                 extracted_text = ""
                 if uploaded_files and isinstance(uploaded_files, list):
-                    for file in uploaded_files:
-                        if not hasattr(file, 'name'):
-                            continue
-                        path = file.name
-                        ext = path.split(".")[-1].lower()
-                        json_text = convert_file_to_json(path, ext)
-                        extracted_text += sanitize_utf8(json_text) + "\n"
+                    extracted_text = convert_files_to_json_parallel(uploaded_files)
+                    print(f"File processing took: {time.time() - file_process_time:.2f}s")
 
                 context = (
-                    "You are an expert clinical AI assistant. Review this patient's history, medications, and notes, and ONLY provide a final answer summarizing what the doctor might have missed."
+                    "You are an expert clinical AI assistant. Review this patient's history, "
+                    "medications, and notes, and ONLY provide a final answer summarizing "
+                    "what the doctor might have missed."
                 )
                 chunked_prompt = f"{context}\n\n--- Patient Record ---\n{extracted_text}\n\n[Final Analysis]"
 
+                model_start = time.time()
                 generator = agent.run_gradio_chat(
                     message=chunked_prompt,
                     history=[],
                     temperature=0.3,
-                    max_new_tokens=1024,
-                    max_token=8192,
+                    max_new_tokens=768,
+                    max_token=4096,
                     call_agent=False,
                     conversation=conversation,
                     uploaded_files=uploaded_files,
-                    max_round=30
+                    max_round=10
                 )
 
-                final_response = ""
+                final_response = []
                 for update in generator:
                     if not update:
                         continue
-                    if isinstance(update, list):
+                    if isinstance(update, str):
+                        final_response.append(update)
+                    elif isinstance(update, list):
                         for msg in update:
-                            if hasattr(msg, "content"):
-                                final_response += msg.content
-                    elif isinstance(update, str):
-                        final_response += update
-
-                history[-1] = {"role": "assistant", "content": final_response.strip()}
-                yield history
-
-                cleaned = final_response.strip().replace("[TOOL_CALLS]", "").strip()
-                history[-1] = {"role": "assistant", "content": cleaned or "❌ No response."}
+                            if hasattr(msg, 'content'):
+                                final_response.append(msg.content)
+
+                    if len(final_response) % 3 == 0:
+                        content = "".join(final_response).strip().replace("[TOOL_CALLS]", "")
+                        history[-1] = {"role": "assistant", "content": content or "❌ No response."}
+                        yield history
+
+                final_cleaned = "".join(final_response).strip().replace("[TOOL_CALLS]", "")
+                history[-1] = {"role": "assistant", "content": final_cleaned or "❌ No response."}
+                print("Final model response:\n", final_cleaned)
+                print(f"Model processing took: {time.time() - model_start:.2f}s")
                 yield history
 
             except Exception as chat_error:
                 print(f"Chat handling error: {chat_error}")
                 history[-1] = {"role": "assistant", "content": "❌ An error occurred while processing your request."}
                 yield history
+            finally:
+                print(f"Total request time: {time.time() - start_time:.2f}s")
 
         inputs = [message_input, chatbot, conversation_state, file_upload]
         send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
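
Note: `handle_chat` streams because it is a generator: every `yield history` pushes the partial transcript to the Chatbot component, which is why the new code re-yields after every third accumulated chunk. The pattern in isolation, with a word-by-word echo standing in for the model (a sketch assuming a recent Gradio with `type="messages"`, not the app's actual wiring):

    import time
    import gradio as gr

    def streaming_chat(message, history):
        # Same role/content message dicts the app's Chatbot uses.
        history = history + [{"role": "user", "content": message},
                             {"role": "assistant", "content": ""}]
        for word in message.split():  # stand-in for the model's token stream
            history[-1]["content"] += word + " "
            time.sleep(0.05)
            yield history  # each yield repaints the Chatbot

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")
        box = gr.Textbox()
        box.submit(streaming_chat, [box, chatbot], chatbot)

    # demo.queue().launch()  # queueing is what lets generator handlers stream
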
@@ -163,3 +206,32 @@ def create_ui(agent: TxAgent):
         ], inputs=message_input)
 
     return demo
+
+if __name__ == "__main__":
+    print("Initializing agent...")
+    agent = init_agent()
+
+    print("Performing warm-up call...")
+    try:
+        warm_up = agent.run_gradio_chat(
+            message="Warm up",
+            history=[],
+            temperature=0.1,
+            max_new_tokens=10,
+            max_token=100,
+            call_agent=False,
+            conversation=[]
+        )
+        for _ in warm_up:
+            pass
+    except:
+        pass
+
+    print("Launching interface...")
+    demo = create_ui(agent)
+    demo.queue().launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True,
+        share=True
+    )
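
Note: one caveat in the new `__main__` block: the bare `except:` around the warm-up also swallows `KeyboardInterrupt` and `SystemExit`. A narrower variant that still tolerates warm-up failures (a suggestion, assuming the same `agent` returned by `init_agent()`):

    print("Performing warm-up call...")
    try:
        for _ in agent.run_gradio_chat(
            message="Warm up",
            history=[],
            temperature=0.1,
            max_new_tokens=10,
            max_token=100,
            call_agent=False,
            conversation=[]
        ):
            pass
    except Exception as e:  # leaves Ctrl-C and SystemExit alone
        print(f"Warm-up failed, continuing anyway: {e}")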