Ali2206 committed
Commit 9c0d5a4 · verified · 1 Parent(s): 9b25f67

Update app.py

Files changed (1):
  1. app.py +53 -120
app.py CHANGED
@@ -4,51 +4,53 @@ import pandas as pd
 import pdfplumber
 import json
 import gradio as gr
-from typing import List, Optional
+from typing import List
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import hashlib
 import shutil
-import time
-from functools import lru_cache
 
-# Environment and path setup
-current_dir = os.path.dirname(os.path.abspath(__file__))
-src_path = os.path.abspath(os.path.join(current_dir, "src"))
-print(f"Adding to path: {src_path}")
-sys.path.insert(0, src_path)
-
-# Configure cache directories
-base_dir = "/data"
-model_cache_dir = os.path.join(base_dir, "txagent_models")
-tool_cache_dir = os.path.join(base_dir, "tool_cache")
-file_cache_dir = os.path.join(base_dir, "cache")
+# Fix: Add src to Python path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
 
+# ✅ Persist model cache to Hugging Face Space's /data directory
+model_cache_dir = "/data/txagent_models"
 os.makedirs(model_cache_dir, exist_ok=True)
-os.makedirs(tool_cache_dir, exist_ok=True)
-os.makedirs(file_cache_dir, exist_ok=True)
-
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 os.environ["HF_HOME"] = model_cache_dir
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
 from txagent.txagent import TxAgent
 
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
 
-def file_hash(path: str) -> str:
+def clean_final_response(text: str) -> str:
+    cleaned = text.replace("[TOOL_CALLS]", "").strip()
+    responses = cleaned.split("[Final Analysis]")
+
+    if len(responses) <= 1:
+        return f"<div style='padding:1em;border:1px solid #ccc;border-radius:12px;color:#fff;background:#353F54;'><p>{cleaned}</p></div>"
+
+    panels = []
+    for i, section in enumerate(responses[1:], 1):
+        final = section.strip()
+        panels.append(
+            f"<div style='background:#2B2B2B;color:#E0E0E0;border-radius:12px;margin-bottom:1em;border:1px solid #888;'>"
+            f"<div style='font-size:1.1em;font-weight:bold;padding:0.75em;background:#3A3A3A;color:#fff;border-radius:12px 12px 0 0;'>🧠 Final Analysis #{i}</div>"
+            f"<div style='padding:1em;line-height:1.6;'>{final.replace(chr(10), '<br>')}</div>"
+            f"</div>"
+        )
+    return "".join(panels)
+
+def file_hash(path):
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
-@lru_cache(maxsize=100)
-def get_cached_response(prompt: str, file_hash: str) -> Optional[str]:
-    return None
-
 def convert_file_to_json(file_path: str, file_type: str) -> str:
     try:
+        cache_dir = "/data/cache"
+        os.makedirs(cache_dir, exist_ok=True)
         h = file_hash(file_path)
-        cache_path = os.path.join(file_cache_dir, f"{h}.json")
+        cache_path = os.path.join(cache_dir, f"{h}.json")
 
         if os.path.exists(cache_path):
             return open(cache_path, "r", encoding="utf-8").read()
@@ -64,8 +66,7 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             with pdfplumber.open(file_path) as pdf:
                 text = "\n".join([page.extract_text() or "" for page in pdf.pages])
             result = json.dumps({"filename": os.path.basename(file_path), "content": text.strip()})
-            with open(cache_path, "w", encoding="utf-8") as f:
-                f.write(result)
+            open(cache_path, "w", encoding="utf-8").write(result)
             return result
         else:
             return json.dumps({"error": f"Unsupported file type: {file_type}"})
@@ -76,49 +77,11 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             df = df.fillna("")
             content = df.astype(str).values.tolist()
             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-            with open(cache_path, "w", encoding="utf-8") as f:
-                f.write(result)
+            open(cache_path, "w", encoding="utf-8").write(result)
             return result
     except Exception as e:
         return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})
 
-def convert_files_to_json_parallel(uploaded_files: list) -> str:
-    extracted_text = []
-    with ThreadPoolExecutor(max_workers=4) as executor:
-        futures = []
-        for file in uploaded_files:
-            if not hasattr(file, 'name'):
-                continue
-            path = file.name
-            ext = path.split(".")[-1].lower()
-            futures.append(executor.submit(convert_file_to_json, path, ext))
-
-        for future in as_completed(futures):
-            extracted_text.append(sanitize_utf8(future.result()))
-    return "\n".join(extracted_text)
-
-def init_agent():
-    default_tool_path = os.path.abspath("data/new_tool.json")
-    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-    if not os.path.exists(target_tool_path):
-        shutil.copy(default_tool_path, target_tool_path)
-
-    model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
-    rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
-
-    agent = TxAgent(
-        model_name=model_name,
-        rag_model_name=rag_model_name,
-        tool_files_dict={"new_tool": target_tool_path},
-        force_finish=True,
-        enable_checker=True,
-        step_rag_num=8,
-        seed=100,
-        additional_default_tools=[]
-    )
-    agent.init_model()
-    return agent
-
 def create_ui(agent: TxAgent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("<h1 style='text-align: center;'>📋 CPS: Clinical Patient Support System</h1>")
@@ -134,61 +97,60 @@ def create_ui(agent: TxAgent):
         conversation_state = gr.State([])
 
         def handle_chat(message: str, history: list, conversation: list, uploaded_files: list, progress=gr.Progress()):
-            start_time = time.time()
             try:
                 history.append({"role": "user", "content": message})
                 history.append({"role": "assistant", "content": "⏳ Processing your request..."})
                 yield history
 
-                file_process_time = time.time()
                 extracted_text = ""
                 if uploaded_files and isinstance(uploaded_files, list):
-                    extracted_text = convert_files_to_json_parallel(uploaded_files)
-                    print(f"File processing took: {time.time() - file_process_time:.2f}s")
+                    for file in uploaded_files:
+                        if not hasattr(file, 'name'):
+                            continue
+                        path = file.name
+                        ext = path.split(".")[-1].lower()
+                        json_text = convert_file_to_json(path, ext)
+                        extracted_text += sanitize_utf8(json_text) + "\n"
 
                 context = (
-                    "You are an expert clinical AI assistant. Review this patient's history, "
-                    "medications, and notes, and ONLY provide a final answer summarizing "
-                    "what the doctor might have missed."
+                    "You are an expert clinical AI assistant. Review this patient's history, medications, and notes, and ONLY provide a final answer summarizing what the doctor might have missed."
                 )
                 chunked_prompt = f"{context}\n\n--- Patient Record ---\n{extracted_text}\n\n[Final Analysis]"
 
-                model_start = time.time()
                 generator = agent.run_gradio_chat(
                     message=chunked_prompt,
                     history=[],
                     temperature=0.3,
-                    max_new_tokens=768,
-                    max_token=4096,
+                    max_new_tokens=1024,
+                    max_token=8192,
                     call_agent=False,
                     conversation=conversation,
                     uploaded_files=uploaded_files,
-                    max_round=10
+                    max_round=30
                 )
 
-                final_response = []
+                final_response = ""
                 for update in generator:
                     if not update:
                         continue
-                    if isinstance(update, str):
-                        final_response.append(update)
-                    elif isinstance(update, list):
-                        final_response.extend(msg.content for msg in update if hasattr(msg, 'content'))
-
-                    if len(final_response) % 3 == 0:
-                        history[-1] = {"role": "assistant", "content": "".join(final_response).strip()}
-                        yield history
-
-                history[-1] = {"role": "assistant", "content": "".join(final_response).strip() or "❌ No response."}
-                print(f"Model processing took: {time.time() - model_start:.2f}s")
+                    if isinstance(update, list):
+                        for msg in update:
+                            if hasattr(msg, "content"):
+                                final_response += msg.content
+                    elif isinstance(update, str):
+                        final_response += update
+
+                    history[-1] = {"role": "assistant", "content": final_response.strip()}
+                    yield history
+
+                cleaned = final_response.strip().replace("[TOOL_CALLS]", "").strip()
+                history[-1] = {"role": "assistant", "content": cleaned or "❌ No response."}
                 yield history
 
             except Exception as chat_error:
                 print(f"Chat handling error: {chat_error}")
                 history[-1] = {"role": "assistant", "content": "❌ An error occurred while processing your request."}
                 yield history
-            finally:
-                print(f"Total request time: {time.time() - start_time:.2f}s")
 
         inputs = [message_input, chatbot, conversation_state, file_upload]
         send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
@@ -201,32 +163,3 @@ def create_ui(agent: TxAgent):
         ], inputs=message_input)
 
     return demo
-
-if __name__ == "__main__":
-    print("Initializing agent...")
-    agent = init_agent()
-
-    print("Performing warm-up call...")
-    try:
-        warm_up = agent.run_gradio_chat(
-            message="Warm up",
-            history=[],
-            temperature=0.1,
-            max_new_tokens=10,
-            max_token=100,
-            call_agent=False,
-            conversation=[]
-        )
-        for _ in warm_up:
-            pass
-    except:
-        pass
-
-    print("Launching interface...")
-    demo = create_ui(agent)
-    demo.queue().launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        show_error=True,
-        share=True
-    )
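
Note: this commit also deletes init_agent() and the `if __name__ == "__main__":` launcher, so app.py as diffed above only defines create_ui() and no longer constructs or launches the agent itself. Below is a minimal sketch of what a separate entry point might look like, assuming the deleted setup still applies; the file name launch.py, the import of create_ui from app, and dropping share=True are assumptions, while every TxAgent argument is copied from the removed code.

# launch.py — hypothetical entry point (not part of this commit); it restores
# the init_agent() and launch logic that this commit removed from app.py.
# All TxAgent arguments below are copied verbatim from the deleted code.
import os
import shutil

from txagent.txagent import TxAgent
from app import create_ui  # assumption: create_ui is importable from app.py

def init_agent() -> TxAgent:
    # Stage the tool file into the persistent cache, as the deleted code did.
    tool_cache_dir = "/data/tool_cache"
    os.makedirs(tool_cache_dir, exist_ok=True)
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), target_tool_path)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100,
        additional_default_tools=[]
    )
    agent.init_model()
    return agent

if __name__ == "__main__":
    demo = create_ui(init_agent())
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)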