Ali2206 committed (verified)
Commit dae38a2 · Parent: 5eac763

Update app.py

Files changed (1): app.py (+232, -7)
app.py CHANGED
@@ -1,17 +1,242 @@
- # app.py (Gradio UI)
- import os
  import sys
+ import os
+ import pandas as pd
+ import pdfplumber
+ import json
  import gradio as gr
- from multiprocessing import freeze_support
+ from typing import List, Optional
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import hashlib
+ import shutil
+ import time
+ from functools import lru_cache
+
+ # Environment and path setup
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
+
+ # Configure cache directories
+ base_dir = "/data"
+ model_cache_dir = os.path.join(base_dir, "txagent_models")
+ tool_cache_dir = os.path.join(base_dir, "tool_cache")
+ file_cache_dir = os.path.join(base_dir, "cache")
+
+ os.makedirs(model_cache_dir, exist_ok=True)
+ os.makedirs(tool_cache_dir, exist_ok=True)
+ os.makedirs(file_cache_dir, exist_ok=True)
+
+ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+ os.environ["HF_HOME"] = model_cache_dir
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+ from txagent.txagent import TxAgent
+
+ # Utility functions
+ def sanitize_utf8(text: str) -> str:
+     return text.encode("utf-8", "ignore").decode("utf-8")
+
+ def file_hash(path: str) -> str:
+     with open(path, "rb") as f:
+         return hashlib.md5(f.read()).hexdigest()
+
+ @lru_cache(maxsize=100)
+ def get_cached_response(prompt: str, file_hash: str) -> Optional[str]:
+     """Cache for frequent queries"""
+     return None  # Implement actual cache lookup if needed
+
+ def convert_file_to_json(file_path: str, file_type: str) -> str:
+     try:
+         h = file_hash(file_path)
+         cache_path = os.path.join(file_cache_dir, f"{h}.json")
+
+         if os.path.exists(cache_path):
+             return open(cache_path, "r", encoding="utf-8").read()
+
+         if file_type == "csv":
+             df = pd.read_csv(file_path, encoding_errors="replace", header=None,
+                              dtype=str, skip_blank_lines=False, on_bad_lines="skip")
+         elif file_type in ["xls", "xlsx"]:
+             try:
+                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+             except:
+                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+         elif file_type == "pdf":
+             with pdfplumber.open(file_path) as pdf:
+                 text = "\n".join([page.extract_text() or "" for page in pdf.pages])
+             result = json.dumps({"filename": os.path.basename(file_path), "content": text.strip()})
+             with open(cache_path, "w", encoding="utf-8") as f:
+                 f.write(result)
+             return result
+         else:
+             return json.dumps({"error": f"Unsupported file type: {file_type}"})
+
+         if df is None or df.empty:
+             return json.dumps({"warning": f"No data extracted from: {file_path}"})
+
+         df = df.fillna("")
+         content = df.astype(str).values.tolist()
+         result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+         with open(cache_path, "w", encoding="utf-8") as f:
+             f.write(result)
+         return result
+     except Exception as e:
+         return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})
+
+ def convert_files_to_json_parallel(uploaded_files: list) -> str:
+     """Process files in parallel using ThreadPool"""
+     extracted_text = []
+     with ThreadPoolExecutor(max_workers=4) as executor:
+         futures = []
+         for file in uploaded_files:
+             if not hasattr(file, 'name'):
+                 continue
+             path = file.name
+             ext = path.split(".")[-1].lower()
+             futures.append(executor.submit(convert_file_to_json, path, ext))
+
+         for future in as_completed(futures):
+             extracted_text.append(sanitize_utf8(future.result()))
+     return "\n".join(extracted_text)
+
+ def init_agent():
+     """Initialize the TxAgent with optimized settings"""
+     # Copy default tool file if needed
+     default_tool_path = os.path.abspath("data/new_tool.json")
+     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+     if not os.path.exists(target_tool_path):
+         shutil.copy(default_tool_path, target_tool_path)
+
+     model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
+     rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
+
+     agent = TxAgent(
+         model_name=model_name,
+         rag_model_name=rag_model_name,
+         tool_files_dict={"new_tool": target_tool_path},
+         force_finish=True,
+         enable_checker=True,
+         step_rag_num=8,  # Reduced from 10
+         seed=100,
+         additional_default_tools=[],
+         torch_dtype="auto",
+         device_map="auto",
+         load_in_4bit=False,
+         load_in_8bit=False
+     )
+     agent.init_model()
+     return agent
+
+ def create_ui(agent: TxAgent):
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("<h1 style='text-align: center;'>📋 CPS: Clinical Patient Support System</h1>")
+
+         chatbot = gr.Chatbot(label="CPS Assistant", height=600, type="messages")
+         file_upload = gr.File(
+             label="Upload Medical File",
+             file_types=[".pdf", ".txt", ".docx", ".jpg", ".png", ".csv", ".xls", ".xlsx"],
+             file_count="multiple"
+         )
+         message_input = gr.Textbox(placeholder="Ask a biomedical question or just upload the files...", show_label=False)
+         send_button = gr.Button("Send", variant="primary")
+         conversation_state = gr.State([])
+
+         def handle_chat(message: str, history: list, conversation: list, uploaded_files: list, progress=gr.Progress()):
+             start_time = time.time()
+             try:
+                 history.append({"role": "user", "content": message})
+                 history.append({"role": "assistant", "content": "⏳ Processing your request..."})
+                 yield history
+
+                 # File processing with timing
+                 file_process_time = time.time()
+                 extracted_text = ""
+                 if uploaded_files and isinstance(uploaded_files, list):
+                     extracted_text = convert_files_to_json_parallel(uploaded_files)
+                     print(f"File processing took: {time.time() - file_process_time:.2f}s")
+
+                 context = (
+                     "You are an expert clinical AI assistant. Review this patient's history, "
+                     "medications, and notes, and ONLY provide a final answer summarizing "
+                     "what the doctor might have missed."
+                 )
+                 chunked_prompt = f"{context}\n\n--- Patient Record ---\n{extracted_text}\n\n[Final Analysis]"
+
+                 # Model processing with timing
+                 model_start = time.time()
+                 generator = agent.run_gradio_chat(
+                     message=chunked_prompt,
+                     history=[],
+                     temperature=0.3,
+                     max_new_tokens=768,  # Reduced from 1024
+                     max_token=4096,  # Reduced from 8192
+                     call_agent=False,
+                     conversation=conversation,
+                     uploaded_files=uploaded_files,
+                     max_round=10  # Reduced from 30
+                 )
+
+                 final_response = []
+                 for update in generator:
+                     if not update:
+                         continue
+                     if isinstance(update, str):
+                         final_response.append(update)
+                     elif isinstance(update, list):
+                         final_response.extend(msg.content for msg in update if hasattr(msg, 'content'))
+
+                     # Yield intermediate results periodically
+                     if len(final_response) % 3 == 0:  # More frequent updates
+                         history[-1] = {"role": "assistant", "content": "".join(final_response).strip()}
+                         yield history
+
+                 history[-1] = {"role": "assistant", "content": "".join(final_response).strip() or "❌ No response."}
+                 print(f"Model processing took: {time.time() - model_start:.2f}s")
+                 yield history
+
+             except Exception as chat_error:
+                 print(f"Chat handling error: {chat_error}")
+                 history[-1] = {"role": "assistant", "content": "❌ An error occurred while processing your request."}
+                 yield history
+             finally:
+                 print(f"Total request time: {time.time() - start_time:.2f}s")
+
+         inputs = [message_input, chatbot, conversation_state, file_upload]
+         send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
+         message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)
+
+         gr.Examples([
+             ["Upload your medical form and ask what the doctor might've missed."],
+             ["This patient was treated with antibiotics for UTI. What else should we check?"],
+             ["Is there anything abnormal in the attached blood work report?"]
+         ], inputs=message_input)

- from ui.ui_core import create_ui
- from backend.agent_instance import init_agent
+     return demo

  if __name__ == "__main__":
-     freeze_support()
+     # Initialize agent and warm it up
+     print("Initializing agent...")
      agent = init_agent()
+
+     # Warm-up call
+     print("Performing warm-up call...")
+     try:
+         warm_up = agent.run_gradio_chat(
+             message="Warm up",
+             history=[],
+             temperature=0.1,
+             max_new_tokens=10,
+             max_token=100,
+             call_agent=False
+         )
+         for _ in warm_up:
+             pass
+     except:
+         pass
+
+     # Launch Gradio interface
+     print("Launching interface...")
      demo = create_ui(agent)
-     demo.queue().launch(
+     demo.queue(concurrency_count=3).launch(
          server_name="0.0.0.0",
          server_port=7860,
          show_error=True,
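
Note that the new `get_cached_response` is still a stub: it always returns `None`, with a comment to implement the actual lookup later. If response caching is picked up in a follow-up commit, a minimal disk-backed sketch could look like the code below. The `responses/` subfolder and the helper names `_response_cache_path` and `set_cached_response` are illustrative assumptions, not part of this commit, and the sketch assumes the same writable `/data` volume the rest of the file relies on.

```python
import hashlib
import os

# Illustrative location only; the commit itself writes extracted-file JSON to /data/cache.
RESPONSE_CACHE_DIR = "/data/cache/responses"
os.makedirs(RESPONSE_CACHE_DIR, exist_ok=True)

def _response_cache_path(prompt: str, file_digest: str) -> str:
    # Key on both the prompt text and the uploaded file's md5 digest,
    # mirroring the (prompt, file_hash) signature of the stub.
    key = hashlib.md5(f"{prompt}|{file_digest}".encode("utf-8")).hexdigest()
    return os.path.join(RESPONSE_CACHE_DIR, f"{key}.txt")

def get_cached_response(prompt: str, file_digest: str):
    """Return a previously stored response, or None on a cache miss."""
    path = _response_cache_path(prompt, file_digest)
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    return None

def set_cached_response(prompt: str, file_digest: str, response: str) -> None:
    """Persist a finished response so identical future requests can skip the model."""
    with open(_response_cache_path(prompt, file_digest), "w", encoding="utf-8") as f:
        f.write(response)
```

With something along these lines, `handle_chat` could check the cache (using `file_hash` of the uploaded files) before calling `agent.run_gradio_chat` and write the assembled answer back after the final yield.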