Ali2206 committed
Commit c9b3ae0 · verified · 1 Parent(s): b6e9667

Update app.py

Files changed (1)
  1. app.py +131 -113
app.py CHANGED
@@ -1,27 +1,28 @@
import sys
import os
import pdfplumber
import json
import gradio as gr
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import re
import psutil
import subprocess
- from collections import defaultdict
- from vllm import LLM, SamplingParams

# Persistent directory
- persistent_dir = os.getenv("HF_HOME", "/data/hf_cache")
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

- for directory in [model_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

os.environ["HF_HOME"] = model_cache_dir
@@ -29,13 +30,15 @@ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
- os.environ["VLLM_NO_TORCH_COMPILE"] = "1"

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

- from txagent.txagent import TxAgent, clean_response  # MODIFIED: Import clean_response

def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")
@@ -44,14 +47,15 @@ def file_hash(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

- def extract_all_pages(file_path: str) -> str:
    try:
        text_chunks = []
        with pdfplumber.open(file_path) as pdf:
-             for page in pdf.pages:
                page_text = page.extract_text() or ""
-                 text_chunks.append(page_text.strip())
-         return "\n".join(text_chunks)
    except Exception as e:
        return f"PDF processing error: {str(e)}"
@@ -62,9 +66,22 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            return f.read()
    if file_type == "pdf":
-         text = extract_all_pages(file_path)
        result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
    else:
        result = json.dumps({"error": f"Unsupported file type: {file_type}"})
    with open(cache_path, "w", encoding="utf-8") as f:
@@ -88,61 +105,47 @@ def log_system_usage(tag=""):
    except Exception as e:
        print(f"[{tag}] GPU/CPU monitor failed: {e}")

- def normalize_text(text: str) -> str:
-     return re.sub(r"\s+", " ", text.lower().strip())
-
- def consolidate_findings(responses: List[str]) -> str:
-     findings = defaultdict(set)
-     headings = ["Missed Diagnoses", "Medication Conflicts", "Incomplete Assessments", "Urgent Follow-up"]
-
-     for response in responses:
-         if not response:
-             continue
-         current_heading = None
-         for line in response.split("\n"):
-             line = line.strip()
-             if not line:
-                 continue
-             if line.lower().startswith(tuple(h.lower() + ":" for h in headings)):
-                 current_heading = next(h for h in headings if line.lower().startswith(h.lower() + ":"))
-             elif current_heading and line.startswith("-"):
-                 findings[current_heading].add(normalize_text(line))
-
-     output = []
-     for heading in headings:
-         if findings[heading]:
-             output.append(f"**{heading}**:")
-             original_lines = {normalize_text(r): r for r in sum([r.split("\n") for r in responses], []) if r.startswith("-")}
-             output.extend(sorted(original_lines.get(n, "- " + n) for n in findings[heading]))
-     return "\n".join(output).strip() if output else "No oversights identified."

def init_agent():
    print("🔍 Initializing model...")
    log_system_usage("Before Load")
-     model = LLM(
-         model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-         max_model_len=4096,  # MODIFIED: Enforce low VRAM
-         enforce_eager=True,
-         enable_chunked_prefill=True,
-         max_num_batched_tokens=8192,
-         gpu_memory_utilization=0.5,  # MODIFIED: Limit VRAM
    )
    log_system_usage("After Load")
-     print("✅ Model Ready")
-     return model

- def create_ui(model):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
        chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
-         file_upload = gr.File(file_types=[".pdf"], file_count="multiple")
        msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
        send_btn = gr.Button("Analyze", variant="primary")
-         download_output = gr.File(label="Download Report")

        def analyze(message: str, history: List[dict], files: List):
            history.append({"role": "user", "content": message})
-             history.append({"role": "assistant", "content": "🔄 Analyzing..."})
            yield history, None

            extracted = ""
@@ -151,82 +154,97 @@ def create_ui(model):
                with ThreadPoolExecutor(max_workers=6) as executor:
                    futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                    results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
-                     extracted = "\n".join([json.loads(r).get("content", "") for r in results if "content" in json.loads(r)])
            file_hash_value = file_hash(files[0].name) if files else ""

-             chunk_size = 800  # MODIFIED: Enforce correct size
            chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
-             chunk_responses = []
-             batch_size = 4  # MODIFIED: Lower for VRAM
-             total_chunks = len(chunks)

-             prompt_template = """
- Strictly output oversights under these exact headings, one point per line, starting with "-". No other text, reasoning, or tools.
-
- **Missed Diagnoses**:
- **Medication Conflicts**:
- **Incomplete Assessments**:
- **Urgent Follow-up**:
-
- Records:
- {chunk}
- """  # MODIFIED: Stronger instructions

-             sampling_params = SamplingParams(
-                 temperature=0.3,  # MODIFIED: Improve output quality
-                 max_tokens=64,  # MODIFIED: Allow full responses
-                 seed=100,
-             )

-             try:
-                 findings = defaultdict(list)  # MODIFIED: Track per batch
-                 for i in range(0, len(chunks), batch_size):
-                     batch = chunks[i:i + batch_size]
-                     prompts = [prompt_template.format(chunk=chunk) for chunk in batch]
-                     log_system_usage(f"Batch {i//batch_size + 1}")
-                     outputs = model.generate(prompts, sampling_params, use_tqdm=True)  # MODIFIED: Stream progress
-                     batch_responses = []
-                     with ThreadPoolExecutor(max_workers=4) as executor:
-                         futures = [executor.submit(clean_response, output.outputs[0].text) for output in outputs]
-                         batch_responses.extend(f.result() for f in as_completed(futures))
-
-                     processed = min(i + len(batch), total_chunks)
-                     batch_output = []
-                     for response in batch_responses:
-                         if response:
-                             chunk_responses.append(response)
-                             current_heading = None
-                             for line in response.split("\n"):
-                                 line = line.strip()
-                                 if line.lower().startswith(tuple(h.lower() + ":" for h in ["missed diagnoses", "medication conflicts", "incomplete assessments", "urgent follow-up"])):
-                                     current_heading = line[:-1]
-                                     if current_heading not in batch_output:
-                                         batch_output.append(current_heading + ":")
-                                 elif current_heading and line.startswith("-"):
-                                     findings[current_heading].append(line)
-                                     batch_output.append(line)
-
-                     # MODIFIED: Stream partial results
-                     if batch_output:
-                         history[-1]["content"] = "\n".join(batch_output) + f"\n\n🔄 Processing chunk {processed}/{total_chunks}..."
-                     else:
-                         history[-1]["content"] = f"🔄 Processing chunk {processed}/{total_chunks}..."
-                     yield history, None

-                 # MODIFIED: Final consolidation
-                 final_response = consolidate_findings(chunk_responses)
-                 history[-1]["content"] = final_response
-                 yield history, None

                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
-                 if report_path and final_response != "No oversights identified.":
                    with open(report_path, "w", encoding="utf-8") as f:
-                         f.write(final_response)
                yield history, report_path if report_path and os.path.exists(report_path) else None

            except Exception as e:
                print("🚨 ERROR:", e)
-                 history[-1]["content"] = f"❌ Error: {str(e)}"
                yield history, None

        send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
@@ -235,8 +253,8 @@ def create_ui(model):

if __name__ == "__main__":
    print("🚀 Launching app...")
-     model = init_agent()
-     demo = create_ui(model)
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,

app.py (updated file)

import sys
import os
+ import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
+ import shutil
import re
import psutil
import subprocess

# Persistent directory
+ persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
+ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

+ for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

os.environ["HF_HOME"] = model_cache_dir

os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

+ from txagent.txagent import TxAgent
+
+ MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
+                     'allergies', 'summary', 'impression', 'findings', 'recommendations'}

def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")

    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

+ def extract_priority_pages(file_path: str) -> str:
    try:
        text_chunks = []
        with pdfplumber.open(file_path) as pdf:
+             for i, page in enumerate(pdf.pages):
                page_text = page.extract_text() or ""
+                 if i < 3 or any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
+                     text_chunks.append(f"=== Page {i+1} ===\n{page_text.strip()}")
+         return "\n\n".join(text_chunks)
    except Exception as e:
        return f"PDF processing error: {str(e)}"

    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            return f.read()
+
    if file_type == "pdf":
+         text = extract_priority_pages(file_path)
        result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
+     elif file_type == "csv":
+         df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
+                          skip_blank_lines=False, on_bad_lines="skip")
+         content = df.fillna("").astype(str).values.tolist()
+         result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+     elif file_type in ["xls", "xlsx"]:
+         try:
+             df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+         except Exception:
+             df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+         content = df.fillna("").astype(str).values.tolist()
+         result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
    else:
        result = json.dumps({"error": f"Unsupported file type: {file_type}"})
    with open(cache_path, "w", encoding="utf-8") as f:

    except Exception as e:
        print(f"[{tag}] GPU/CPU monitor failed: {e}")

+ def clean_response(text: str) -> str:
+     text = sanitize_utf8(text)
+     text = re.sub(r"\[TOOL_CALLS\].*", "", text, flags=re.DOTALL)
+     text = re.sub(r"\n{3,}", "\n\n", text).strip()
+     return text

def init_agent():
    print("🔍 Initializing model...")
    log_system_usage("Before Load")
+     default_tool_path = os.path.abspath("data/new_tool.json")
+     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+     if not os.path.exists(target_tool_path):
+         shutil.copy(default_tool_path, target_tool_path)
+
+     agent = TxAgent(
+         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+         tool_files_dict={"new_tool": target_tool_path},
+         force_finish=True,
+         enable_checker=True,
+         step_rag_num=4,
+         seed=100,
+         additional_default_tools=[],
    )
+     agent.init_model()
    log_system_usage("After Load")
+     print("✅ Agent Ready")
+     return agent

+ def create_ui(agent):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
        chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
+         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
        send_btn = gr.Button("Analyze", variant="primary")
+         download_output = gr.File(label="Download Full Report")

        def analyze(message: str, history: List[dict], files: List):
            history.append({"role": "user", "content": message})
+             history.append({"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."})
            yield history, None

            extracted = ""

                with ThreadPoolExecutor(max_workers=6) as executor:
                    futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                    results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
+                     extracted = "\n".join(results)
            file_hash_value = file_hash(files[0].name) if files else ""

+             # Split extracted text into chunks of ~6,000 characters
+             chunk_size = 6000
            chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
+             combined_response = ""

+             prompt_template = """
+ Analyze the medical records for clinical oversights. Provide a concise, evidence-based summary under these headings:
+ 1. **Missed Diagnoses**:
+ - Identify inconsistencies in history, symptoms, or tests.
+ - Consider psychiatric, neurological, infectious, autoimmune, genetic conditions, family history, trauma, and developmental factors.
+ 2. **Medication Conflicts**:
+ - Check for contraindications, interactions, or unjustified off-label use.
+ - Assess if medications worsen diagnoses or cause adverse effects.
+ 3. **Incomplete Assessments**:
+ - Note missing or superficial cognitive, psychiatric, social, or family assessments.
+ - Highlight gaps in medical history, substance use, or lab/imaging documentation.
+ 4. **Urgent Follow-up**:
+ - Flag abnormal lab results, imaging, behaviors, or legal history needing immediate reassessment or referral.
+ Medical Records (Chunk {0} of {1}):
+ {chunk}
+ Begin analysis:
+ """

+             try:
+                 if history and history[-1]["content"].startswith("⏳"):
+                     history.pop()

+                 # Process each chunk and stream results in real-time
+                 for chunk_idx, chunk in enumerate(chunks, 1):
+                     # Update UI with progress
+                     history.append({"role": "assistant", "content": f"🔄 Processing Chunk {chunk_idx} of {len(chunks)}..."})
+                     yield history, None

+                     prompt = prompt_template.format(chunk_idx, len(chunks), chunk=chunk)
+                     chunk_response = ""
+                     for chunk_output in agent.run_gradio_chat(
+                         message=prompt,
+                         history=[],
+                         temperature=0.2,
+                         max_new_tokens=1024,
+                         max_token=4096,
+                         call_agent=False,
+                         conversation=[],
+                     ):
+                         if chunk_output is None:
+                             continue
+                         if isinstance(chunk_output, list):
+                             for m in chunk_output:
+                                 if hasattr(m, 'content') and m.content:
+                                     cleaned = clean_response(m.content)
+                                     if cleaned:
+                                         chunk_response += cleaned + "\n"
+                                         # Update UI with partial response
+                                         if history[-1]["content"].startswith("🔄"):
+                                             history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
+                                         else:
+                                             history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
+                                         yield history, None
+                         elif isinstance(chunk_output, str) and chunk_output.strip():
+                             cleaned = clean_response(chunk_output)
+                             if cleaned:
+                                 chunk_response += cleaned + "\n"
+                                 # Update UI with partial response
+                                 if history[-1]["content"].startswith("🔄"):
+                                     history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
+                                 else:
+                                     history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
+                                 yield history, None

+                     # Append completed chunk response to combined response
+                     combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"

+                 # Finalize UI with complete response
+                 if combined_response:
+                     history[-1]["content"] = combined_response.strip()
+                 else:
+                     history.append({"role": "assistant", "content": "No oversights identified."})

+                 # Generate report file
                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
+                 if report_path:
                    with open(report_path, "w", encoding="utf-8") as f:
+                         f.write(combined_response)
                yield history, report_path if report_path and os.path.exists(report_path) else None

            except Exception as e:
                print("🚨 ERROR:", e)
+                 history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
                yield history, None

        send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])

if __name__ == "__main__":
    print("🚀 Launching app...")
+     agent = init_agent()
+     demo = create_ui(agent)
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
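
For readers who want to sanity-check the chunking and prompt substitution introduced in this commit outside the Space, here is a minimal standalone sketch. The template is abbreviated from the one added above, and RECORD_TEXT is a placeholder for the text that convert_file_to_json() would produce; neither is part of app.py.

# Standalone sketch of the per-chunk prompt flow (illustrative only).
# RECORD_TEXT stands in for the extracted record text.
RECORD_TEXT = "=== Page 1 ===\nAssessment: stable. Plan: follow up.\n" * 400

chunk_size = 6000
chunks = [RECORD_TEXT[i:i + chunk_size] for i in range(0, len(RECORD_TEXT), chunk_size)]

# Abbreviated template: {0} and {1} are filled positionally, {chunk} by keyword,
# matching the call prompt_template.format(chunk_idx, len(chunks), chunk=chunk).
prompt_template = """Medical Records (Chunk {0} of {1}):
{chunk}
Begin analysis:"""

for chunk_idx, chunk in enumerate(chunks, 1):
    prompt = prompt_template.format(chunk_idx, len(chunks), chunk=chunk)
    print(prompt.splitlines()[0])  # prints the header line for each chunk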