Ali2206 committed on
Commit 9ef8abc · verified · 1 Parent(s): ecb8e1d

Update app.py

Files changed (1): app.py (+104 -108)
app.py CHANGED
@@ -1,36 +1,30 @@
- # Optimized app.py with lazy loading and preloading thread, fixed chatbot format and startup error handling
-
  import os
  import gradio as gr
- from typing import List
  import hashlib
  import time
  import json
- import re
  from concurrent.futures import ThreadPoolExecutor, as_completed
- from threading import Thread
  import pandas as pd
  import pdfplumber

- # Optimized environment setup
  os.environ.update({
      "HF_HOME": "/data/hf_cache",
-     "VLLM_CACHE_DIR": "/data/vllm_cache",
-     "TOKENIZERS_PARALLELISM": "false",
-     "CUDA_LAUNCH_BLOCKING": "1"
  })

- # Create cache directories if they don't exist
  os.makedirs("/data/hf_cache", exist_ok=True)
- os.makedirs("/data/tool_cache", exist_ok=True)
  os.makedirs("/data/file_cache", exist_ok=True)
  os.makedirs("/data/reports", exist_ok=True)
- os.makedirs("/data/vllm_cache", exist_ok=True)

- # Lazy loading of heavy dependencies
- def lazy_load_agent():
-     from txagent.txagent import TxAgent

      agent = TxAgent(
          model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
          rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -38,159 +32,161 @@ def lazy_load_agent():
          force_finish=True,
          enable_checker=True,
          step_rag_num=8,
-         seed=100,
-         additional_default_tools=[],
      )
      agent.init_model()
-     return agent
-
- # Pre-load the agent in a separate thread
- agent = None
- def preload_agent():
-     global agent
-     agent = lazy_load_agent()

- Thread(target=preload_agent).start()
-
- # File processing functions
  def file_hash(path: str) -> str:
      with open(path, "rb") as f:
          return hashlib.md5(f.read()).hexdigest()

- def extract_priority_pages(file_path: str, max_pages: int = 10) -> str:
      try:
          with pdfplumber.open(file_path) as pdf:
-             return "\n\n".join(
-                 f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}"
                  for i, page in enumerate(pdf.pages[:max_pages])
              )
      except Exception as e:
-         return f"PDF processing error: {str(e)}"

  def process_file(file_path: str, file_type: str) -> str:
      try:
-         h = file_hash(file_path)
-         cache_path = f"/data/file_cache/{h}.json"
-
          if os.path.exists(cache_path):
-             with open(cache_path, "r", encoding="utf-8") as f:
                  return f.read()
-
          if file_type == "pdf":
-             content = extract_priority_pages(file_path)
-             result = json.dumps({"filename": os.path.basename(file_path), "content": content})
          elif file_type == "csv":
-             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").values.tolist()})
          elif file_type in ["xls", "xlsx"]:
-             df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").values.tolist()})
          else:
-             return json.dumps({"error": f"Unsupported file type: {file_type}"})

-         with open(cache_path, "w", encoding="utf-8") as f:
              f.write(result)
          return result
-
      except Exception as e:
          return json.dumps({"error": str(e)})

  def format_response(response: str) -> str:
      response = response.replace("[TOOL_CALLS]", "").strip()
-     if "Based on the medical records provided" in response:
-         parts = response.split("Based on the medical records provided")
-         response = "Based on the medical records provided" + parts[-1]
-
-     replacements = {
-         "1. **Missed Diagnoses**:": "### 🔍 Missed Diagnoses",
-         "2. **Medication Conflicts**:": "\n### 💊 Medication Conflicts",
-         "3. **Incomplete Assessments**:": "\n### 📋 Incomplete Assessments",
-         "4. **Abnormal Results Needing Follow-up**:": "\n### ⚠️ Abnormal Results Needing Follow-up",
-         "Overall, the patient's medical records": "\n### 📝 Overall Assessment"
      }
-
-     for old, new in replacements.items():
-         response = response.replace(old, new)
-
      return response

- def analyze_files(message: str, history: List, files: List):
      try:
-         while agent is None:
-             time.sleep(0.1)
-
-         history.append([message, None])
-         yield history, None
-
          extracted_data = ""
          if files:
-             with ThreadPoolExecutor(max_workers=4) as executor:
-                 futures = [executor.submit(process_file, f.name, f.name.split(".")[-1].lower())
-                            for f in files if hasattr(f, 'name')]
                  extracted_data = "\n".join(f.result() for f in as_completed(futures))
-
          prompt = f"""Review these medical records:
  {extracted_data[:10000]}

- Identify:
- 1. Potential missed diagnoses
- 2. Medication conflicts
  3. Incomplete assessments
  4. Abnormal results needing follow-up

  Analysis:"""
-
          response = ""
          for chunk in agent.run_gradio_chat(
              message=prompt,
              history=[],
              temperature=0.2,
-             max_new_tokens=800,
-             max_token=3000
          ):
              if isinstance(chunk, str):
                  response += chunk
              elif isinstance(chunk, list):
                  response += "".join(getattr(c, 'content', '') for c in chunk)
-
-             formatted = format_response(response)
-             if formatted.strip():
-                 history[-1][1] = formatted
-                 yield history, None
-
-         final_output = format_response(response) or "No clear oversights identified."
-         history[-1][1] = final_output
          yield history, None
-
      except Exception as e:
-         history[-1][1] = f"❌ Error: {str(e)}"
          yield history, None

- # UI definition
- with gr.Blocks(title="Clinical Oversight Assistant") as demo:
-     gr.Markdown("""
-     <div style='text-align: center;'>
-         <h1>🩺 Clinical Oversight Assistant</h1>
-         <p>Upload medical records to analyze for potential oversights in patient care</p>
-     </div>
-     """)
-
      with gr.Row():
          with gr.Column(scale=1):
-             file_upload = gr.File(label="Upload Medical Records", file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-             query = gr.Textbox(label="Your Query", placeholder="Ask about potential oversights...", lines=3)
              submit = gr.Button("Analyze", variant="primary")
-             gr.Examples([
-                 ["What potential diagnoses might have been missed?"],
-                 ["Are there any medication conflicts I should be aware of?"],
-                 ["What assessments appear incomplete in these records?"]
-             ], inputs=query)
-
          with gr.Column(scale=2):
-             chatbot = gr.Chatbot(label="Analysis Results", height=600, type="messages")
-
-     submit.click(analyze_files, inputs=[query, chatbot, file_upload], outputs=[chatbot, gr.File(visible=False)])
-     query.submit(analyze_files, inputs=[query, chatbot, file_upload], outputs=[chatbot, gr.File(visible=False)])

  if __name__ == "__main__":
-     demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)
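The startup path deleted above loaded the model lazily: a preload thread filled the global `agent`, and every request busy-waited on `while agent is None: time.sleep(0.1)`. For reference, that hand-off is conventionally written with a `threading.Event`, which blocks without polling. A minimal runnable sketch of the idiom, not code from this commit (`DummyAgent` and the `sleep` are hypothetical stand-ins for the TxAgent construction and `init_model()`):

    import time
    from threading import Event, Thread

    class DummyAgent:
        # Hypothetical stand-in for the TxAgent built in lazy_load_agent().
        def run(self, prompt: str) -> str:
            return f"analysis of: {prompt}"

    agent = None
    agent_ready = Event()  # set exactly once, when the model is usable

    def preload_agent():
        global agent
        time.sleep(2)  # simulate the slow init_model() call
        agent = DummyAgent()
        agent_ready.set()  # wake every waiting request

    Thread(target=preload_agent, daemon=True).start()

    def handle_request(prompt: str) -> str:
        agent_ready.wait()  # blocks without a 0.1 s polling loop
        return agent.run(prompt)

    print(handle_request("flag medication conflicts"))

`Event.wait()` also accepts a timeout, so a request could fail fast rather than hang if initialization never completes. The rewritten app.py, with additions marked: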
 
 
 
 
 
+ import sys
  import os
  import gradio as gr
  import hashlib
  import time
  import json
  from concurrent.futures import ThreadPoolExecutor, as_completed
  import pandas as pd
  import pdfplumber

+ # Set up environment
  os.environ.update({
      "HF_HOME": "/data/hf_cache",
+     "TOKENIZERS_PARALLELISM": "false"
  })

+ # Create cache directories
  os.makedirs("/data/hf_cache", exist_ok=True)
  os.makedirs("/data/file_cache", exist_ok=True)
  os.makedirs("/data/reports", exist_ok=True)

+ # Import TxAgent after setting up environment
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
+ from txagent.txagent import TxAgent

+ # Initialize agent with error handling
+ try:
      agent = TxAgent(
          model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
          rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
          force_finish=True,
          enable_checker=True,
          step_rag_num=8,
+         seed=100
      )
      agent.init_model()
+ except Exception as e:
+     print(f"Failed to initialize agent: {str(e)}")
+     agent = None

  def file_hash(path: str) -> str:
      with open(path, "rb") as f:
          return hashlib.md5(f.read()).hexdigest()

+ def extract_text_from_pdf(file_path: str, max_pages: int = 10) -> str:
      try:
          with pdfplumber.open(file_path) as pdf:
+             return "\n".join(
+                 f"Page {i+1}:\n{(page.extract_text() or '').strip()}\n"
                  for i, page in enumerate(pdf.pages[:max_pages])
              )
      except Exception as e:
+         return f"PDF error: {str(e)}"

  def process_file(file_path: str, file_type: str) -> str:
      try:
+         cache_path = f"/data/file_cache/{file_hash(file_path)}.json"
          if os.path.exists(cache_path):
+             with open(cache_path, "r") as f:
                  return f.read()
+
          if file_type == "pdf":
+             content = extract_text_from_pdf(file_path)
          elif file_type == "csv":
+             df = pd.read_csv(file_path, header=None, dtype=str, on_bad_lines="skip")
+             content = df.fillna("").to_string()
          elif file_type in ["xls", "xlsx"]:
+             df = pd.read_excel(file_path, header=None, dtype=str)
+             content = df.fillna("").to_string()
          else:
+             return json.dumps({"error": "Unsupported file type"})

+         result = json.dumps({"filename": os.path.basename(file_path), "content": content})
+         with open(cache_path, "w") as f:
              f.write(result)
          return result
      except Exception as e:
          return json.dumps({"error": str(e)})

  def format_response(response: str) -> str:
      response = response.replace("[TOOL_CALLS]", "").strip()
+     sections = {
+         "1. **Missed Diagnoses**:": "🔍 Missed Diagnoses",
+         "2. **Medication Conflicts**:": "💊 Medication Conflicts",
+         "3. **Incomplete Assessments**:": "📋 Incomplete Assessments",
+         "4. **Abnormal Results Needing Follow-up**:": "⚠️ Abnormal Results"
      }
+     for old, new in sections.items():
+         response = response.replace(old, f"\n### {new}\n")
      return response

+ def analyze(message: str, history: list, files: list):
+     if agent is None:
+         yield history + [(message, "Agent initialization failed. Please try again later.")], None
+         return
+
+     history.append((message, None))
+     yield history, None
+
      try:
          extracted_data = ""
          if files:
+             with ThreadPoolExecutor() as executor:
+                 futures = [executor.submit(process_file, f.name, f.name.split(".")[-1])
+                            for f in files if hasattr(f, 'name')]
                  extracted_data = "\n".join(f.result() for f in as_completed(futures))
+
          prompt = f"""Review these medical records:
  {extracted_data[:10000]}

+ Identify potential issues:
+ 1. Missed diagnoses
+ 2. Medication conflicts
  3. Incomplete assessments
  4. Abnormal results needing follow-up

  Analysis:"""
+
          response = ""
          for chunk in agent.run_gradio_chat(
              message=prompt,
              history=[],
              temperature=0.2,
+             max_new_tokens=800
          ):
              if isinstance(chunk, str):
                  response += chunk
              elif isinstance(chunk, list):
                  response += "".join(getattr(c, 'content', '') for c in chunk)
+
+             history[-1] = (message, format_response(response))
+             yield history, None
+
+         history[-1] = (message, format_response(response))
          yield history, None
+
      except Exception as e:
+         history[-1] = (message, f"❌ Error: {str(e)}")
          yield history, None

+ # Create the interface
+ with gr.Blocks(
+     title="Clinical Oversight Assistant",
+     css="""
+     .gradio-container {
+         max-width: 1000px;
+         margin: auto;
+     }
+     .chatbot {
+         min-height: 500px;
+     }
+     """
+ ) as demo:
+     gr.Markdown("# 🩺 Clinical Oversight Assistant")
+
      with gr.Row():
          with gr.Column(scale=1):
+             files = gr.File(
+                 label="Upload Medical Records",
+                 file_types=[".pdf", ".csv", ".xlsx"],
+                 file_count="multiple"
+             )
+             query = gr.Textbox(
+                 label="Your Query",
+                 placeholder="Ask about potential oversights..."
+             )
              submit = gr.Button("Analyze", variant="primary")
+
          with gr.Column(scale=2):
+             chatbot = gr.Chatbot(
+                 label="Analysis Results",
+                 show_copy_button=True
+             )
+
+     submit.click(
+         analyze,
+         inputs=[query, chatbot, files],
+         outputs=[chatbot, gr.File(visible=False)]
+     )
+     query.submit(
+         analyze,
+         inputs=[query, chatbot, files],
+         outputs=[chatbot, gr.File(visible=False)]
+     )

  if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_error=True
+     )
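The rewritten `process_file` keys its cache on an MD5 digest of the raw file bytes, so re-uploading an identical record returns the previously extracted text instead of re-parsing the PDF or spreadsheet. A self-contained sketch of that content-addressed caching idiom, under the assumption that a temp directory stands in for `/data/file_cache` and a plain-text read stands in for the pdfplumber/pandas extraction:

    import hashlib
    import json
    import os
    import tempfile

    cache_dir = tempfile.mkdtemp()  # hypothetical stand-in for /data/file_cache

    def file_hash(path: str) -> str:
        # Same digest the commit uses: MD5 over the raw file bytes.
        with open(path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()

    def process_file(path: str) -> str:
        cache_path = os.path.join(cache_dir, f"{file_hash(path)}.json")
        if os.path.exists(cache_path):  # hit: identical bytes, identical key
            with open(cache_path) as f:
                return f.read()
        with open(path) as f:  # miss: "extract" once, then store
            content = f.read()
        result = json.dumps({"filename": os.path.basename(path), "content": content})
        with open(cache_path, "w") as f:
            f.write(result)
        return result

    # Demo: the second call is served from the cache file.
    sample = os.path.join(cache_dir, "note.txt")
    with open(sample, "w") as f:
        f.write("BP 150/95, started lisinopril 10 mg")
    assert process_file(sample) == process_file(sample)

Because the key is derived from content, an edited file automatically gets a fresh entry; the trade-off is that entries for superseded versions are never evicted.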