Ali2206 commited on
Commit
cf5094d
·
verified ·
1 Parent(s): 833a580

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -80
app.py CHANGED
@@ -1,12 +1,10 @@
1
- #import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re, time
 
 
2
  from concurrent.futures import ThreadPoolExecutor, as_completed
3
  from threading import Thread
4
 
5
- # Setup
6
- current_dir = os.path.dirname(os.path.abspath(__file__))
7
- src_path = os.path.join(current_dir, "src")
8
- sys.path.insert(0, src_path)
9
-
10
  base_dir = "/data"
11
  model_cache_dir = os.path.join(base_dir, "txagent_models")
12
  tool_cache_dir = os.path.join(base_dir, "tool_cache")
@@ -17,7 +15,7 @@ vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
17
  for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
18
  os.makedirs(d, exist_ok=True)
19
 
20
- # Hugging Face & Transformers cache
21
  os.environ.update({
22
  "HF_HOME": model_cache_dir,
23
  "TRANSFORMERS_CACHE": model_cache_dir,
@@ -26,13 +24,21 @@ os.environ.update({
26
  "CUDA_LAUNCH_BLOCKING": "1"
27
  })
28
 
29
- from txagent.txagent import TxAgent
 
 
30
 
31
- MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
32
- 'allergies', 'summary', 'impression', 'findings', 'recommendations'}
 
 
 
 
 
33
 
34
- def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
35
  def file_hash(path): return hashlib.md5(open(path, "rb").read()).hexdigest()
 
 
36
 
37
  def extract_priority_pages(file_path, max_pages=20):
38
  try:
@@ -42,7 +48,7 @@ def extract_priority_pages(file_path, max_pages=20):
42
  pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
43
  for i, page in enumerate(pdf.pages[3:max_pages], start=4):
44
  text = page.extract_text() or ""
45
- if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
46
  pages.append(f"=== Page {i} ===\n{text.strip()}")
47
  return "\n\n".join(pages)
48
  except Exception as e:
@@ -59,43 +65,36 @@ def convert_file_to_json(file_path, file_type):
59
  result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
60
  Thread(target=full_pdf_processing, args=(file_path, h)).start()
61
  elif file_type == "csv":
62
- df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
63
- result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
64
  elif file_type in ["xls", "xlsx"]:
65
- try:
66
- df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
67
- except:
68
- df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
69
- result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
70
  else:
71
  return json.dumps({"error": f"Unsupported file type: {file_type}"})
72
 
73
  with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
74
  return result
75
  except Exception as e:
76
- return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
77
 
78
- def full_pdf_processing(file_path, file_hash_value):
79
  try:
80
- cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
81
  if os.path.exists(cache_path): return
82
  with pdfplumber.open(file_path) as pdf:
83
- full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages)])
84
- result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
85
- with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
86
- with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out: out.write(full_text)
87
- except Exception as e:
88
- print("PDF processing error:", e)
89
 
90
  def init_agent():
91
- default_tool_path = os.path.abspath("data/new_tool.json")
92
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
93
  if not os.path.exists(target_tool_path):
94
- shutil.copy(default_tool_path, target_tool_path)
95
 
96
  agent = TxAgent(
97
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
98
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
99
  tool_files_dict={"new_tool": target_tool_path},
100
  force_finish=True,
101
  enable_checker=True,
@@ -105,82 +104,59 @@ def init_agent():
105
  agent.init_model()
106
  return agent
107
 
108
- # Lazy load agent only on first use
109
  agent_container = {"agent": None}
110
  def get_agent():
111
  if agent_container["agent"] is None:
112
  agent_container["agent"] = init_agent()
113
  return agent_container["agent"]
114
 
115
- def create_ui(get_agent_func):
116
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
117
- gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1><h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")
118
-
119
- chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
120
  file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
121
- msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
122
  send_btn = gr.Button("Analyze", variant="primary")
123
  state = gr.State([])
124
- download_output = gr.File(label="Download Report")
125
 
126
  def analyze(message, history, conversation, files):
127
  try:
128
- extracted_data, file_hash_value = "", ""
129
  if files:
130
- with ThreadPoolExecutor(max_workers=4) as pool:
131
  futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
132
- extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
133
- file_hash_value = file_hash(files[0].name)
134
-
135
- prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
136
- 1. List potential missed diagnoses
137
- 2. Flag any medication conflicts
138
- 3. Note incomplete assessments
139
- 4. Highlight abnormal results needing follow-up
140
 
141
- Medical Records:\n{extracted_data[:15000]}
142
-
143
- ### Potential Oversights:\n"""
 
 
144
 
 
 
145
  final_response = ""
146
- for chunk in get_agent_func().run_gradio_chat(
147
- message=prompt,
148
- history=[],
149
- temperature=0.2,
150
- max_new_tokens=1024,
151
- max_token=4096,
152
- call_agent=False,
153
- conversation=conversation
154
- ):
155
- if isinstance(chunk, str):
156
- final_response += chunk
157
- elif isinstance(chunk, list):
158
- final_response += "".join([c.content for c in chunk if hasattr(c, "content")])
159
-
160
  cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
161
- if not cleaned:
162
- cleaned = "No oversights found. Consider further review."
163
-
164
  updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": cleaned}]
165
-
166
- report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value and os.path.exists(os.path.join(report_dir, f"{file_hash_value}_report.txt")) else None
167
- yield updated_history, report_path
168
  except Exception as e:
169
- updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": f"❌ Error: {str(e)}"}]
170
- yield updated_history, None
171
-
172
- send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
173
- msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
174
 
 
 
175
  return demo
176
 
177
  if __name__ == "__main__":
178
- print("Launching interface...")
179
- ui = create_ui(get_agent)
180
  ui.queue(api_open=False).launch(
181
  server_name="0.0.0.0",
182
  server_port=7860,
183
  show_error=True,
184
  allowed_paths=["/data/reports"],
185
  share=False
186
- )
 
1
+ # Fully optimized app.py for Hugging Face Space with persistent 150GB storage
2
+
3
+ import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re, time
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from threading import Thread
6
 
7
+ # Use /data for persistent HF storage
 
 
 
 
8
  base_dir = "/data"
9
  model_cache_dir = os.path.join(base_dir, "txagent_models")
10
  tool_cache_dir = os.path.join(base_dir, "tool_cache")
 
15
  for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
16
  os.makedirs(d, exist_ok=True)
17
 
18
+ # Set persistent HF + VLLM cache
19
  os.environ.update({
20
  "HF_HOME": model_cache_dir,
21
  "TRANSFORMERS_CACHE": model_cache_dir,
 
24
  "CUDA_LAUNCH_BLOCKING": "1"
25
  })
26
 
27
+ # Force local loading only
28
+ LOCAL_TXAGENT_PATH = os.path.join(model_cache_dir, "mims-harvard", "TxAgent-T1-Llama-3.1-8B")
29
+ LOCAL_RAG_PATH = os.path.join(model_cache_dir, "mims-harvard", "ToolRAG-T1-GTE-Qwen2-1.5B")
30
 
31
+ # Manual download using snapshot_download (only if needed)
32
+ # from huggingface_hub import snapshot_download
33
+ # snapshot_download("mims-harvard/TxAgent-T1-Llama-3.1-8B", local_dir=LOCAL_TXAGENT_PATH, local_dir_use_symlinks=False)
34
+ # snapshot_download("mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B", local_dir=LOCAL_RAG_PATH, local_dir_use_symlinks=False)
35
+
36
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
37
+ from txagent.txagent import TxAgent
38
 
 
39
  def file_hash(path): return hashlib.md5(open(path, "rb").read()).hexdigest()
40
+ def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
41
+ MEDICAL_KEYWORDS = {"diagnosis", "assessment", "plan", "results", "medications", "summary", "findings"}
42
 
43
  def extract_priority_pages(file_path, max_pages=20):
44
  try:
 
48
  pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
49
  for i, page in enumerate(pdf.pages[3:max_pages], start=4):
50
  text = page.extract_text() or ""
51
+ if any(re.search(rf'\\b{kw}\\b', text.lower()) for kw in MEDICAL_KEYWORDS):
52
  pages.append(f"=== Page {i} ===\n{text.strip()}")
53
  return "\n\n".join(pages)
54
  except Exception as e:
 
65
  result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
66
  Thread(target=full_pdf_processing, args=(file_path, h)).start()
67
  elif file_type == "csv":
68
+ df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
69
+ result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
70
  elif file_type in ["xls", "xlsx"]:
71
+ df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
72
+ result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
 
 
 
73
  else:
74
  return json.dumps({"error": f"Unsupported file type: {file_type}"})
75
 
76
  with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
77
  return result
78
  except Exception as e:
79
+ return json.dumps({"error": str(e)})
80
 
81
+ def full_pdf_processing(file_path, h):
82
  try:
83
+ cache_path = os.path.join(file_cache_dir, f"{h}_full.json")
84
  if os.path.exists(cache_path): return
85
  with pdfplumber.open(file_path) as pdf:
86
+ full_text = "\n".join([f"=== Page {i+1} ===\n{(p.extract_text() or '').strip()}" for i, p in enumerate(pdf.pages)])
87
+ with open(cache_path, "w", encoding="utf-8") as f: f.write(json.dumps({"content": full_text}))
88
+ except: pass
 
 
 
89
 
90
  def init_agent():
 
91
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
92
  if not os.path.exists(target_tool_path):
93
+ shutil.copy(os.path.abspath("data/new_tool.json"), target_tool_path)
94
 
95
  agent = TxAgent(
96
+ model_name=LOCAL_TXAGENT_PATH,
97
+ rag_model_name=LOCAL_RAG_PATH,
98
  tool_files_dict={"new_tool": target_tool_path},
99
  force_finish=True,
100
  enable_checker=True,
 
104
  agent.init_model()
105
  return agent
106
 
107
+ # Lazy load
108
  agent_container = {"agent": None}
109
  def get_agent():
110
  if agent_container["agent"] is None:
111
  agent_container["agent"] = init_agent()
112
  return agent_container["agent"]
113
 
114
+ def create_ui():
115
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
116
+ gr.Markdown("""<h1 style='text-align:center;'>🩺 Clinical Oversight Assistant</h1>""")
117
+ chatbot = gr.Chatbot(label="Analysis", height=600)
118
+ msg_input = gr.Textbox(placeholder="Ask a question about the patient...")
119
  file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
 
120
  send_btn = gr.Button("Analyze", variant="primary")
121
  state = gr.State([])
 
122
 
123
  def analyze(message, history, conversation, files):
124
  try:
125
+ extracted, hval = "", ""
126
  if files:
127
+ with ThreadPoolExecutor(max_workers=3) as pool:
128
  futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
129
+ extracted = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
130
+ hval = file_hash(files[0].name)
 
 
 
 
 
 
131
 
132
+ prompt = f"""Review these medical records and identify exactly what might have been missed:
133
+ 1. Missed diagnoses
134
+ 2. Medication conflicts
135
+ 3. Incomplete assessments
136
+ 4. Abnormal results needing follow-up
137
 
138
+ Medical Records:\n{extracted[:15000]}
139
+ """
140
  final_response = ""
141
+ for chunk in get_agent().run_gradio_chat(prompt, history=[], temperature=0.2, max_new_tokens=1024, max_token=4096, call_agent=False, conversation=conversation):
142
+ if isinstance(chunk, str): final_response += chunk
143
+ elif isinstance(chunk, list): final_response += "".join([c.content for c in chunk if hasattr(c, 'content')])
 
 
 
 
 
 
 
 
 
 
 
144
  cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
 
 
 
145
  updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": cleaned}]
146
+ return updated_history, None
 
 
147
  except Exception as e:
148
+ return history + [{"role": "user", "content": message}, {"role": "assistant", "content": f"❌ Error: {str(e)}"}], None
 
 
 
 
149
 
150
+ send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, gr.File()])
151
+ msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, gr.File()])
152
  return demo
153
 
154
  if __name__ == "__main__":
155
+ ui = create_ui()
 
156
  ui.queue(api_open=False).launch(
157
  server_name="0.0.0.0",
158
  server_port=7860,
159
  show_error=True,
160
  allowed_paths=["/data/reports"],
161
  share=False
162
+ )