Ali2206 committed on
Commit 1fa1ea5 · verified · 1 Parent(s): cbc2a44

Update app.py

Files changed (1)
  1. app.py +60 -104
app.py CHANGED
@@ -1,21 +1,10 @@
-#import sys
-import os
-import pandas as pd
-import pdfplumber
-import json
-import gradio as gr
-from typing import List
+import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re, time
 from concurrent.futures import ThreadPoolExecutor, as_completed
-import hashlib
-import shutil
-import time
 from threading import Thread
-import re
-import tempfile
 
-# Setup paths
+# Setup
 current_dir = os.path.dirname(os.path.abspath(__file__))
-src_path = os.path.abspath(os.path.join(current_dir, "src"))
+src_path = os.path.join(current_dir, "src")
 sys.path.insert(0, src_path)
 
 base_dir = "/data"
@@ -28,9 +17,10 @@ vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
 for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
     os.makedirs(d, exist_ok=True)
 
+# Hugging Face & Transformers cache
 os.environ.update({
-    "TRANSFORMERS_CACHE": model_cache_dir,
     "HF_HOME": model_cache_dir,
+    "TRANSFORMERS_CACHE": model_cache_dir,
     "VLLM_CACHE_DIR": vllm_cache_dir,
     "TOKENIZERS_PARALLELISM": "false",
     "CUDA_LAUNCH_BLOCKING": "1"
@@ -38,38 +28,31 @@ os.environ.update({
 
 from txagent.txagent import TxAgent
 
-MEDICAL_KEYWORDS = {
-    'diagnosis', 'assessment', 'plan', 'results', 'medications',
-    'allergies', 'summary', 'impression', 'findings', 'recommendations'
-}
-
-def sanitize_utf8(text: str) -> str:
-    return text.encode("utf-8", "ignore").decode("utf-8")
+MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
+                    'allergies', 'summary', 'impression', 'findings', 'recommendations'}
 
-def file_hash(path: str) -> str:
-    with open(path, "rb") as f:
-        return hashlib.md5(f.read()).hexdigest()
+def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
+def file_hash(path): return hashlib.md5(open(path, "rb").read()).hexdigest()
 
-def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
+def extract_priority_pages(file_path, max_pages=20):
     try:
-        text_chunks = []
         with pdfplumber.open(file_path) as pdf:
+            pages = []
             for i, page in enumerate(pdf.pages[:3]):
-                text_chunks.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
+                pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
             for i, page in enumerate(pdf.pages[3:max_pages], start=4):
-                page_text = page.extract_text() or ""
-                if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
-                    text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
-        return "\n\n".join(text_chunks)
+                text = page.extract_text() or ""
+                if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
+                    pages.append(f"=== Page {i} ===\n{text.strip()}")
+            return "\n\n".join(pages)
     except Exception as e:
         return f"PDF processing error: {str(e)}"
 
-def convert_file_to_json(file_path: str, file_type: str) -> str:
+def convert_file_to_json(file_path, file_type):
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
-        if os.path.exists(cache_path):
-            return open(cache_path, "r", encoding="utf-8").read()
+        if os.path.exists(cache_path): return open(cache_path, "r", encoding="utf-8").read()
 
         if file_type == "pdf":
             text = extract_priority_pages(file_path)
@@ -77,39 +60,32 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             Thread(target=full_pdf_processing, args=(file_path, h)).start()
         elif file_type == "csv":
             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
-            content = df.fillna("").astype(str).values.tolist()
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
        elif file_type in ["xls", "xlsx"]:
             try:
                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
             except:
                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-            content = df.fillna("").astype(str).values.tolist()
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
         else:
             return json.dumps({"error": f"Unsupported file type: {file_type}"})
 
-        with open(cache_path, "w", encoding="utf-8") as f:
-            f.write(result)
+        with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
         return result
-
     except Exception as e:
         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
 
-def full_pdf_processing(file_path: str, file_hash: str):
+def full_pdf_processing(file_path, file_hash_value):
     try:
-        cache_path = os.path.join(file_cache_dir, f"{file_hash}_full.json")
-        if os.path.exists(cache_path):
-            return
+        cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
+        if os.path.exists(cache_path): return
         with pdfplumber.open(file_path) as pdf:
             full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages)])
             result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
-        with open(cache_path, "w", encoding="utf-8") as f:
-            f.write(result)
-        with open(os.path.join(report_dir, f"{file_hash}_report.txt"), "w", encoding="utf-8") as out:
-            out.write(full_text)
+            with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
+            with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out: out.write(full_text)
     except Exception as e:
-        print(f"Background processing failed: {str(e)}")
+        print("PDF processing error:", e)
 
 def init_agent():
     default_tool_path = os.path.abspath("data/new_tool.json")
@@ -124,36 +100,37 @@ def init_agent():
         force_finish=True,
         enable_checker=True,
         step_rag_num=8,
-        seed=100,
-        additional_default_tools=[],
+        seed=100
     )
     agent.init_model()
     return agent
 
-def create_ui(agent: TxAgent):
+# Lazy load agent only on first use
+agent_container = {"agent": None}
+def get_agent():
+    if agent_container["agent"] is None:
+        agent_container["agent"] = init_agent()
+    return agent_container["agent"]
+
+def create_ui(get_agent_func):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("""
-        <h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>
-        <h3 style='text-align: center;'>Identify potential oversights in patient care</h3>
-        """)
+        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1><h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")
 
         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
-        file_upload = gr.File(label="Upload Medical Records", file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-        msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+        msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
         send_btn = gr.Button("Analyze", variant="primary")
-        conversation_state = gr.State([])
-        download_output = gr.File(label="Download Full Report")
+        state = gr.State([])
+        download_output = gr.File(label="Download Report")
 
-        def analyze_potential_oversights(message: str, history: list, conversation: list, files: list):
+        def analyze(message, history, conversation, files):
            try:
-                extracted_data = ""
-                file_hash_value = ""
-
-                if files and isinstance(files, list):
-                    with ThreadPoolExecutor(max_workers=4) as executor:
-                        futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files if hasattr(f, 'name')]
+                extracted_data, file_hash_value = "", ""
+                if files:
+                    with ThreadPoolExecutor(max_workers=4) as pool:
+                        futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                        extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
-                    file_hash_value = file_hash(files[0].name) if files else ""
+                    file_hash_value = file_hash(files[0].name)
 
                 prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
 1. List potential missed diagnoses
@@ -165,8 +142,8 @@ Medical Records:\n{extracted_data[:15000]}
 
 ### Potential Oversights:\n"""
 
-                final_output = ""
-                for chunk in agent.run_gradio_chat(
+                final_response = ""
+                for chunk in get_agent_func().run_gradio_chat(
                     message=prompt,
                     history=[],
                     temperature=0.2,
@@ -176,52 +153,31 @@ Medical Records:\n{extracted_data[:15000]}
                     conversation=conversation
                 ):
                     if isinstance(chunk, str):
-                        final_output += chunk
+                        final_response += chunk
                     elif isinstance(chunk, list):
-                        final_output += "".join([c.content for c in chunk if hasattr(c, 'content')])
+                        final_response += "".join([c.content for c in chunk if hasattr(c, "content")])
 
-                cleaned = final_output.replace("[TOOL_CALLS]", "").strip()
+                cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
                 if not cleaned:
-                    cleaned = "No clear oversights identified. Recommend comprehensive review."
-
-                updated_history = history + [
-                    {"role": "user", "content": message},
-                    {"role": "assistant", "content": cleaned}
-                ]
+                    cleaned = "No oversights found. Consider further review."
 
-                report_path = None
-                if file_hash_value:
-                    possible_report = os.path.join(report_dir, f"{file_hash_value}_report.txt")
-                    if os.path.exists(possible_report):
-                        report_path = possible_report
+                updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": cleaned}]
 
+                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value and os.path.exists(os.path.join(report_dir, f"{file_hash_value}_report.txt")) else None
                 yield updated_history, report_path
-
             except Exception as e:
-                updated_history = history + [{"role": "user", "content": message},
-                                              {"role": "assistant", "content": f"❌ Analysis failed: {str(e)}"}]
+                updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": f"❌ Error: {str(e)}"}]
                 yield updated_history, None
 
-        inputs = [msg_input, chatbot, conversation_state, file_upload]
-        outputs = [chatbot, download_output]
-        send_btn.click(analyze_potential_oversights, inputs=inputs, outputs=outputs)
-        msg_input.submit(analyze_potential_oversights, inputs=inputs, outputs=outputs)
-
-        gr.Examples([
-            ["What might have been missed in this patient's treatment?"],
-            ["Are there any medication conflicts in these records?"],
-            ["What abnormal results require follow-up?"]
-        ], inputs=msg_input)
+        send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
+        msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
 
     return demo
 
 if __name__ == "__main__":
-    print("Initializing medical analysis agent...")
-    agent = init_agent()
-
     print("Launching interface...")
-    demo = create_ui(agent)
-    demo.queue(api_open=False).launch(
+    ui = create_ui(get_agent)
+    ui.queue(api_open=False).launch(
         server_name="0.0.0.0",
         server_port=7860,
         show_error=True,
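
For context, the updated app.py defers TxAgent construction until the first request: the agent is cached in the module-level agent_container dict, built inside get_agent(), and create_ui() receives that function as a callable. A minimal, self-contained sketch of the same lazy-initialization pattern, with a hypothetical HeavyModel standing in for TxAgent, looks roughly like this:

class HeavyModel:
    """Hypothetical stand-in for an expensive-to-build object such as TxAgent."""
    def __init__(self):
        print("loading model weights ...")  # slow work, should run only once

_cache = {"model": None}  # module-level container, mirroring agent_container above

def get_model():
    # Build on first call; every later call returns the cached instance.
    if _cache["model"] is None:
        _cache["model"] = HeavyModel()
    return _cache["model"]

if __name__ == "__main__":
    first = get_model()   # triggers construction
    second = get_model()  # reuses the cached instance
    assert first is second

The effect is that the Space can bring up the Gradio UI immediately and only pays the model-loading cost when a user actually submits an analysis request.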