Ali2206 commited on
Commit
9a76893
·
verified ·
1 Parent(s): ac2fc78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -106
app.py CHANGED
@@ -1,11 +1,8 @@
1
  import sys
2
  import os
3
  import pandas as pd
4
- import json
5
  import gradio as gr
6
- from typing import List, Tuple, Dict, Any, Union
7
- import hashlib
8
- import shutil
9
  import re
10
  from datetime import datetime
11
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -16,16 +13,15 @@ os.makedirs(persistent_dir, exist_ok=True)
16
 
17
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
18
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
19
- file_cache_dir = os.path.join(persistent_dir, "cache")
20
  report_dir = os.path.join(persistent_dir, "reports")
21
 
22
- for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
23
  os.makedirs(d, exist_ok=True)
24
 
25
  os.environ["HF_HOME"] = model_cache_dir
26
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
27
 
28
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
29
  from txagent.txagent import TxAgent
30
 
31
  MAX_MODEL_TOKENS = 32768
@@ -39,9 +35,6 @@ def clean_response(text: str) -> str:
39
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
40
  return text.strip()
41
 
42
- def estimate_tokens(text: str) -> int:
43
- return len(text) // 3.5 + 1
44
-
45
  def extract_text_from_excel(file_path: str) -> str:
46
  all_text = []
47
  xls = pd.ExcelFile(file_path)
@@ -52,47 +45,41 @@ def extract_text_from_excel(file_path: str) -> str:
52
  all_text.extend(sheet_text)
53
  return "\n".join(all_text)
54
 
55
- def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS, max_chunks: int = 30) -> List[str]:
56
- effective_max = max_tokens - PROMPT_OVERHEAD
57
- lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
 
 
58
  for line in lines:
59
- t = estimate_tokens(line)
60
- if curr_tokens + t > effective_max:
61
  if curr_chunk:
62
  chunks.append("\n".join(curr_chunk))
63
- if len(chunks) >= max_chunks:
64
- break
65
- curr_chunk, curr_tokens = [line], t
66
  else:
67
  curr_chunk.append(line)
68
- curr_tokens += t
69
- if curr_chunk and len(chunks) < max_chunks:
70
  chunks.append("\n".join(curr_chunk))
71
  return chunks
72
 
73
  def build_prompt_from_text(chunk: str) -> str:
74
- return f"""
75
- ### Unstructured Clinical Records
76
-
77
- Analyze the following clinical notes and provide a detailed, concise summary focusing on:
78
- - Diagnostic Patterns
79
- - Medication Issues
80
- - Missed Opportunities
81
  - Inconsistencies
82
- - Follow-up Recommendations
83
 
84
- ---
85
 
86
- {chunk}
87
-
88
- ---
89
- Respond in well-structured bullet points with medical reasoning.
90
- """
91
 
92
  def init_agent():
93
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
94
  if not os.path.exists(tool_path):
95
- shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
 
96
  agent = TxAgent(
97
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
98
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -105,68 +92,62 @@ def init_agent():
105
  agent.init_model()
106
  return agent
107
 
108
- def process_final_report(agent, file, chatbot_state: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], Union[str, None, str]]:
109
- messages = chatbot_state if chatbot_state else []
110
- if file is None or not hasattr(file, "name"):
111
- return messages + [("assistant", "❌ Please upload a valid Excel file.")], None, ""
 
 
112
 
113
  messages.append(("user", f"Processing Excel file: {os.path.basename(file.name)}"))
114
- text = extract_text_from_excel(file.name)
115
- chunks = split_text_into_chunks(text)
116
- chunk_responses = [None] * len(chunks)
117
-
118
- def analyze_chunk(i, chunk):
119
- prompt = build_prompt_from_text(chunk)
120
- response = ""
121
- for res in agent.run_gradio_chat(
122
- message=prompt, history=[], temperature=0.2,
123
- max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
124
- call_agent=False, conversation=[]
125
- ):
126
- if isinstance(res, str):
127
- response += res
128
- elif hasattr(res, "content"):
129
- response += res.content
130
- elif isinstance(res, list):
131
- for r in res:
132
- if hasattr(r, "content"):
133
- response += r.content
134
- return i, clean_response(response)
135
-
136
- with ThreadPoolExecutor(max_workers=1) as executor:
137
- futures = [executor.submit(analyze_chunk, i, c) for i, c in enumerate(chunks)]
138
- for f in as_completed(futures):
139
- i, result = f.result()
140
- chunk_responses[i] = result
141
-
142
- valid = [r for r in chunk_responses if r and not r.startswith("❌")]
143
- if not valid:
144
- return messages + [("assistant", "❌ No valid chunk results.")], None, ""
145
-
146
- summary_prompt = f"Summarize this analysis in a final structured report:\n\n" + "\n\n".join(valid)
147
- messages.append(("assistant", "📊 Generating final report..."))
148
-
149
- final_report = ""
150
- for res in agent.run_gradio_chat(
151
- message=summary_prompt, history=[], temperature=0.2,
152
- max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
153
- call_agent=False, conversation=[]
154
- ):
155
- if isinstance(res, str):
156
- final_report += res
157
- elif hasattr(res, "content"):
158
- final_report += res.content
159
-
160
- cleaned = clean_response(final_report)
161
- report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
162
- with open(report_path, 'w') as f:
163
- f.write(f"# 🧠 Final Patient Report\n\n{cleaned}")
164
-
165
- # Add the report content to the chat messages
166
- messages.append(("assistant", f"✅ Report generated and saved: {os.path.basename(report_path)}"))
167
- messages.append(("assistant", f"## Final Report\n\n{cleaned}"))
168
 
169
- return messages, report_path, cleaned
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  def create_ui(agent):
172
  with gr.Blocks(css="""
@@ -211,24 +192,25 @@ def create_ui(agent):
211
  margin-top: 10px;
212
  border: 1px solid #2c3344;
213
  }
 
 
 
214
  """) as demo:
215
- gr.Markdown("""# 🧠 Clinical Reasoning Assistant
216
  Upload clinical Excel records below and click **Analyze** to generate a medical summary.
217
  """)
218
- chatbot = gr.Chatbot(label="Chatbot", elem_classes="chatbot", type="tuples")
 
219
  file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
220
  analyze_btn = gr.Button("Analyze")
221
  report_output = gr.File(label="Download Report", visible=False)
222
- chatbot_state = gr.State(value=[])
223
-
224
- def update_ui(file, current_state):
225
- messages, report_path, final_text = process_final_report(agent, file, current_state)
226
- return messages, gr.update(visible=report_path is not None, value=report_path), messages
227
-
228
  analyze_btn.click(
229
- fn=update_ui,
230
- inputs=[file_upload, chatbot_state],
231
- outputs=[chatbot, report_output, chatbot_state]
 
232
  )
233
 
234
  return demo
@@ -237,7 +219,12 @@ if __name__ == "__main__":
237
  try:
238
  agent = init_agent()
239
  demo = create_ui(agent)
240
- demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
 
 
 
 
 
241
  except Exception as e:
242
  print(f"Error: {str(e)}")
243
  sys.exit(1)
 
1
  import sys
2
  import os
3
  import pandas as pd
 
4
  import gradio as gr
5
+ from typing import List, Tuple
 
 
6
  import re
7
  from datetime import datetime
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
13
 
14
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
15
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 
16
  report_dir = os.path.join(persistent_dir, "reports")
17
 
18
+ for d in [model_cache_dir, tool_cache_dir, report_dir]:
19
  os.makedirs(d, exist_ok=True)
20
 
21
# Route all HF model downloads into the persistent cache directory.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir

# Make the bundled src/ package importable before pulling in TxAgent.
# NOTE: the previous line was missing a closing parenthesis (SyntaxError).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent

MAX_MODEL_TOKENS = 32768
 
35
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
36
  return text.strip()
37
 
 
 
 
38
  def extract_text_from_excel(file_path: str) -> str:
39
  all_text = []
40
  xls = pd.ExcelFile(file_path)
 
45
  all_text.extend(sheet_text)
46
  return "\n".join(all_text)
47
 
48
def split_text_into_chunks(text: str, max_tokens=None, prompt_overhead=None) -> List[str]:
    """Split *text* into newline-delimited chunks under a token budget.

    Tokens are approximated by whitespace word count. The budget reserves
    room for the prompt template that will wrap each chunk.

    Args:
        text: Raw text to split.
        max_tokens: Per-chunk budget; defaults to MAX_CHUNK_TOKENS.
        prompt_overhead: Tokens reserved for the prompt; defaults to
            PROMPT_OVERHEAD.

    Returns:
        List of chunk strings; lines are never split mid-line.
    """
    if max_tokens is None:
        max_tokens = MAX_CHUNK_TOKENS
    if prompt_overhead is None:
        prompt_overhead = PROMPT_OVERHEAD
    effective_max = max_tokens - prompt_overhead

    chunks = []
    curr_chunk = []
    curr_tokens = 0  # was a no-op sum() over the (always empty) current chunk

    for line in text.split("\n"):
        line_tokens = len(line.split())
        if curr_tokens + line_tokens > effective_max:
            # Flush the current chunk and start a new one with this line.
            if curr_chunk:
                chunks.append("\n".join(curr_chunk))
            curr_chunk, curr_tokens = [line], line_tokens
        else:
            curr_chunk.append(line)
            curr_tokens += line_tokens
    if curr_chunk:
        chunks.append("\n".join(curr_chunk))
    return chunks
65
 
66
def build_prompt_from_text(chunk: str) -> str:
    """Wrap one chunk of clinical-note text in the analysis prompt template."""
    header = (
        "Analyze these clinical notes and provide:\n"
        "- Diagnostic patterns\n"
        "- Medication issues\n"
        "- Missed opportunities\n"
        "- Inconsistencies\n"
        "- Follow-up recommendations\n"
        "\n"
        "Respond with clear bullet points:\n"
        "\n"
    )
    return header + chunk
 
 
 
 
77
 
78
  def init_agent():
79
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
80
  if not os.path.exists(tool_path):
81
+ import shutil
82
+ shutil.copy("data/new_tool.json", tool_path)
83
  agent = TxAgent(
84
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
85
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
92
  agent.init_model()
93
  return agent
94
 
95
def process_final_report(agent, file, chatbot_state: List[Tuple[str, str]]):
    """Generator: analyze an uploaded Excel file chunk-by-chunk with *agent*.

    Yields (chat_messages, report_path) pairs so the Gradio UI can stream
    progress. report_path is None until the final report has been written.

    Args:
        agent: TxAgent instance used to run the analysis.
        file: Gradio file object (must expose .name) or None.
        chatbot_state: Prior chat history as (role, text) tuples.
    """
    messages = chatbot_state.copy() if chatbot_state else []

    if file is None or not hasattr(file, "name"):
        messages.append(("assistant", "❌ Please upload a valid Excel file."))
        # This is a generator: a `return value` would be swallowed by Gradio,
        # so the error must be *yielded* before bailing out.
        yield messages, None
        return

    messages.append(("user", f"Processing Excel file: {os.path.basename(file.name)}"))
    yield messages, None

    try:
        text = extract_text_from_excel(file.name)
        chunks = split_text_into_chunks(text)

        messages.append(("assistant", "🔍 Analyzing clinical data..."))
        yield messages, None

        full_report = []
        for i, chunk in enumerate(chunks, 1):
            prompt = build_prompt_from_text(chunk)
            response = ""

            # run_gradio_chat streams partial results; accept both plain
            # strings and objects carrying a .content attribute.
            for res in agent.run_gradio_chat(
                message=prompt, history=[], temperature=0.2,
                max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
                call_agent=False, conversation=[]
            ):
                if isinstance(res, str):
                    response += res
                elif hasattr(res, "content"):
                    response += res.content

            full_report.append(clean_response(response))

            # Update the progress message in place so the chat is not
            # flooded with one message per chunk.
            progress_msg = f"✅ Analyzed section {i}/{len(chunks)}"
            if len(messages) > 2 and "Analyzed section" in messages[-1][1]:
                messages[-1] = ("assistant", progress_msg)
            else:
                messages.append(("assistant", progress_msg))
            yield messages, None

        # Assemble and persist the final report (UTF-8: content contains
        # emoji, so the platform-default encoding could crash the write).
        final_report = "## 🧠 Final Clinical Report\n\n" + "\n\n".join(full_report)
        report_path = os.path.join(
            report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
        )
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(final_report)

        messages.append(("assistant", f"✅ Report generated and saved: {os.path.basename(report_path)}"))
        messages.append(("assistant", final_report))
        yield messages, report_path

    except Exception as e:
        messages.append(("assistant", f"❌ Error: {str(e)}"))
        yield messages, None
151
 
152
  def create_ui(agent):
153
  with gr.Blocks(css="""
 
192
  margin-top: 10px;
193
  border: 1px solid #2c3344;
194
  }
195
+ .bullet-points {
196
+ margin-left: 20px;
197
+ }
198
  """) as demo:
199
+ gr.Markdown("""# Clinical Reasoning Assistant
200
  Upload clinical Excel records below and click **Analyze** to generate a medical summary.
201
  """)
202
+
203
+ chatbot = gr.Chatbot(label="Chatbot", elem_classes="chatbot")
204
  file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
205
  analyze_btn = gr.Button("Analyze")
206
  report_output = gr.File(label="Download Report", visible=False)
207
+ chatbot_state = gr.State([])
208
+
 
 
 
 
209
  analyze_btn.click(
210
+ fn=process_final_report,
211
+ inputs=[file_upload, chatbot_state, gr.State(agent)],
212
+ outputs=[chatbot, report_output],
213
+ show_progress="hidden"
214
  )
215
 
216
  return demo
 
219
  try:
220
  agent = init_agent()
221
  demo = create_ui(agent)
222
+ demo.launch(
223
+ server_name="0.0.0.0",
224
+ server_port=7860,
225
+ allowed_paths=["/data/hf_cache/reports"],
226
+ share=False
227
+ )
228
  except Exception as e:
229
  print(f"Error: {str(e)}")
230
  sys.exit(1)