Ali2206 committed
Commit c5da27e · verified · 1 Parent(s): 26faa43

Update app.py

Files changed (1)
  1. app.py +157 -116
app.py CHANGED
@@ -8,9 +8,10 @@ import hashlib
 import shutil
 import re
 from datetime import datetime
+import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-# Setup directories
+# Configuration and setup
 persistent_dir = "/data/hf_cache"
 os.makedirs(persistent_dir, exist_ok=True)
 
@@ -19,13 +20,16 @@ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
 
-for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
-    os.makedirs(d, exist_ok=True)
+for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
+    os.makedirs(directory, exist_ok=True)
 
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
+current_dir = os.path.dirname(os.path.abspath(__file__))
+src_path = os.path.abspath(os.path.join(current_dir, "src"))
+sys.path.insert(0, src_path)
+
 from txagent.txagent import TxAgent
 
 MAX_MODEL_TOKENS = 32768
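A note on the ordering this hunk preserves: Hugging Face libraries resolve HF_HOME / TRANSFORMERS_CACHE once, when they are first imported, so both variables must be exported before the txagent import (which transitively pulls in transformers) runs. A minimal sketch of the pattern; the cache path here is an example, not the actual value of model_cache_dir:

```python
import os

# Set cache locations BEFORE anything imports transformers/huggingface_hub;
# those libraries read these variables at import time, not per call.
os.environ["HF_HOME"] = "/data/hf_cache/models"             # example path
os.environ["TRANSFORMERS_CACHE"] = "/data/hf_cache/models"  # example path

from txagent.txagent import TxAgent  # transitively imports transformers
```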
@@ -34,6 +38,10 @@ MAX_NEW_TOKENS = 2048
 PROMPT_OVERHEAD = 500
 
 def clean_response(text: str) -> str:
+    try:
+        text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
+    except UnicodeError:
+        text = text.encode('utf-8', 'replace').decode('utf-8')
     text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
     text = re.sub(r"\n{3,}", "\n\n", text)
     text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
@@ -44,28 +52,35 @@ def estimate_tokens(text: str) -> int:
 
 def extract_text_from_excel(file_path: str) -> str:
     all_text = []
-    xls = pd.ExcelFile(file_path)
-    for sheet_name in xls.sheet_names:
-        df = xls.parse(sheet_name).astype(str).fillna("")
-        rows = df.apply(lambda row: " | ".join(row), axis=1)
-        sheet_text = [f"[{sheet_name}] {line}" for line in rows]
-        all_text.extend(sheet_text)
+    try:
+        xls = pd.ExcelFile(file_path)
+        for sheet_name in xls.sheet_names:
+            df = xls.parse(sheet_name)
+            df = df.astype(str).fillna("")
+            rows = df.apply(lambda row: " | ".join(row), axis=1)
+            sheet_text = [f"[{sheet_name}] {line}" for line in rows]
+            all_text.extend(sheet_text)
+    except Exception as e:
+        raise ValueError(f"Failed to extract text from Excel file: {str(e)}")
     return "\n".join(all_text)
 
 def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
-    effective_max = max_tokens - PROMPT_OVERHEAD
-    lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
+    effective_max_tokens = max_tokens - PROMPT_OVERHEAD
+    if effective_max_tokens <= 0:
+        raise ValueError("Effective max tokens must be positive.")
+    lines = text.split("\n")
+    chunks, current_chunk, current_tokens = [], [], 0
     for line in lines:
-        t = estimate_tokens(line)
-        if curr_tokens + t > effective_max:
-            if curr_chunk:
-                chunks.append("\n".join(curr_chunk))
-            curr_chunk, curr_tokens = [line], t
+        line_tokens = estimate_tokens(line)
+        if current_tokens + line_tokens > effective_max_tokens:
+            if current_chunk:
+                chunks.append("\n".join(current_chunk))
+            current_chunk, current_tokens = [line], line_tokens
         else:
-            curr_chunk.append(line)
-            curr_tokens += t
-    if curr_chunk:
-        chunks.append("\n".join(curr_chunk))
+            current_chunk.append(line)
+            current_tokens += line_tokens
+    if current_chunk:
+        chunks.append("\n".join(current_chunk))
     return chunks
 
 def build_prompt_from_text(chunk: str) -> str:
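For context, a usage sketch of the reworked chunker. It assumes app.py imports cleanly, which only holds where /data/hf_cache is writable and txagent is installed, so treat it as illustrative rather than a drop-in test:

```python
from app import split_text_into_chunks, PROMPT_OVERHEAD  # assumed import path

text = "\n".join(f"[Sheet1] patient {i} | obs {i}" for i in range(5000))

# Greedy line packing: lines are never split mid-way, and PROMPT_OVERHEAD
# tokens are reserved out of the budget for the prompt template.
chunks = split_text_into_chunks(text)
print(len(chunks), "chunks")

# The new guard fails fast when the budget cannot fit any content at all,
# where the old code silently produced one chunk per line instead.
try:
    split_text_into_chunks(text, max_tokens=PROMPT_OVERHEAD)
except ValueError as err:
    print("rejected:", err)
```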
@@ -88,132 +103,158 @@ Respond in well-structured bullet points with medical reasoning.
 """
 
 def init_agent():
-    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-    if not os.path.exists(tool_path):
-        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
+    default_tool_path = os.path.abspath("data/new_tool.json")
+    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+    if not os.path.exists(target_tool_path):
+        shutil.copy(default_tool_path, target_tool_path)
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-        tool_files_dict={"new_tool": tool_path},
+        tool_files_dict={"new_tool": target_tool_path},
         force_finish=True,
         enable_checker=True,
         step_rag_num=4,
-        seed=100
+        seed=100,
+        additional_default_tools=[]
     )
     agent.init_model()
     return agent
 
 def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
     messages = chatbot_state if chatbot_state else []
+    report_path = None
+
     if file is None or not hasattr(file, "name"):
-        return messages + [{"role": "assistant", "content": "❌ Please upload a valid Excel file."}], None
-
-    messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
-    text = extract_text_from_excel(file.name)
-    chunks = split_text_into_chunks(text)
-    chunk_responses = [None] * len(chunks)
-
-    def analyze_chunk(i, chunk):
-        prompt = build_prompt_from_text(chunk)
-        response = ""
-        for res in agent.run_gradio_chat(message=prompt, history=[], temperature=0.2, max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS, call_agent=False, conversation=[]):
-            if isinstance(res, str):
-                response += res
-            elif hasattr(res, "content"):
-                response += res.content
-            elif isinstance(res, list):
-                for r in res:
-                    if hasattr(r, "content"):
-                        response += r.content
-        return i, clean_response(response)
-
-    with ThreadPoolExecutor(max_workers=1) as executor:
-        futures = [executor.submit(analyze_chunk, i, c) for i, c in enumerate(chunks)]
-        for f in as_completed(futures):
-            i, result = f.result()
-            chunk_responses[i] = result
-
-    valid = [r for r in chunk_responses if r and not r.startswith("❌")]
-    if not valid:
-        return messages + [{"role": "assistant", "content": "❌ No valid chunk results."}], None
-
-    summary_prompt = f"Summarize this analysis in a final structured report:\n\n" + "\n\n".join(valid)
-    messages.append({"role": "assistant", "content": "📊 Generating final report..."})
-
-    final_report = ""
-    for res in agent.run_gradio_chat(message=summary_prompt, history=[], temperature=0.2, max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS, call_agent=False, conversation=[]):
-        if isinstance(res, str):
-            final_report += res
-        elif hasattr(res, "content"):
-            final_report += res.content
-
-    cleaned = clean_response(final_report)
-    report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
-    with open(report_path, 'w') as f:
-        f.write(f"# 🧠 Final Patient Report\n\n{cleaned}")
-
-    messages.append({"role": "assistant", "content": f"📊 Final Report:\n\n{cleaned}"})
-    messages.append({"role": "assistant", "content": f"✅ Report generated and saved: {os.path.basename(report_path)}"})
+        messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
+        return messages, report_path
+
+    try:
+        messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
+        extracted_text = extract_text_from_excel(file.name)
+        chunks = split_text_into_chunks(extracted_text)
+        chunk_responses = [None] * len(chunks)
+
+        def analyze_chunk(index: int, chunk: str) -> Tuple[int, str]:
+            prompt = build_prompt_from_text(chunk)
+            prompt_tokens = estimate_tokens(prompt)
+            if prompt_tokens > MAX_MODEL_TOKENS:
+                return index, f"❌ Chunk {index+1} prompt too long. Skipping..."
+            response = ""
+            try:
+                for result in agent.run_gradio_chat(
+                    message=prompt,
+                    history=[],
+                    temperature=0.2,
+                    max_new_tokens=MAX_NEW_TOKENS,
+                    max_token=MAX_MODEL_TOKENS,
+                    call_agent=False,
+                    conversation=[],
+                ):
+                    if isinstance(result, str):
+                        response += result
+                    elif isinstance(result, list):
+                        for r in result:
+                            if hasattr(r, "content"):
+                                response += r.content
+                    elif hasattr(result, "content"):
+                        response += result.content
+            except Exception as e:
+                return index, f"❌ Error analyzing chunk {index+1}: {str(e)}"
+            return index, clean_response(response)
+
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            futures = [executor.submit(analyze_chunk, i, chunk) for i, chunk in enumerate(chunks)]
+            for future in as_completed(futures):
+                i, result = future.result()
+                chunk_responses[i] = result
+                if result.startswith("❌"):
+                    messages.append({"role": "assistant", "content": result})
+
+        valid_responses = [res for res in chunk_responses if not res.startswith("❌")]
+        if not valid_responses:
+            messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
+            return messages, report_path
+
+        summary = "\n\n".join(valid_responses)
+        final_prompt = f"Provide a structured, consolidated clinical analysis from these results:\n\n{summary}"
+        messages.append({"role": "assistant", "content": "📊 Generating final report..."})
+
+        final_report_text = ""
+        for result in agent.run_gradio_chat(
+            message=final_prompt,
+            history=[],
+            temperature=0.2,
+            max_new_tokens=MAX_NEW_TOKENS,
+            max_token=MAX_MODEL_TOKENS,
+            call_agent=False,
+            conversation=[],
+        ):
+            if isinstance(result, str):
+                final_report_text += result
+            elif isinstance(result, list):
+                for r in result:
+                    if hasattr(r, "content"):
+                        final_report_text += r.content
+            elif hasattr(result, "content"):
+                final_report_text += result.content
+
+        cleaned = clean_response(final_report_text)
+        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
+        with open(report_path, 'w') as f:
+            f.write(f"# 🧠 Final Patient Report\n\n{cleaned}")
+
+        messages.append({"role": "assistant", "content": f"📊 Final Report:\n\n{cleaned}"})
+        messages.append({"role": "assistant", "content": f"✅ Report generated and saved: {os.path.basename(report_path)}"})
+
+    except Exception as e:
+        messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
+
     return messages, report_path
 
 def create_ui(agent):
     with gr.Blocks(css="""
     html, body, .gradio-container {
         height: 100vh;
-        background-color: #111827;
-        color: #e5e7eb;
+        width: 100vw;
+        padding: 0;
+        margin: 0;
         font-family: 'Inter', sans-serif;
-    }
-    .message-avatar {
-        width: 38px;
-        height: 38px;
-        border-radius: 50%;
-        margin-right: 10px;
-    }
-    .chat-message {
-        display: flex;
-        align-items: flex-start;
-        margin-bottom: 1rem;
-    }
-    .message-bubble {
-        background-color: #1f2937;
-        padding: 12px 16px;
-        border-radius: 12px;
-        max-width: 90%;
-    }
-    .chat-input {
-        background-color: #1f2937;
-        border: 1px solid #374151;
-        border-radius: 8px;
-        color: #e5e7eb;
-        padding: 0.75rem 1rem;
+        background: #ffffff;
     }
     .gr-button.primary {
-        background: #2563eb;
-        color: white;
-        border-radius: 8px;
+        background: #1e88e5;
+        color: #fff;
+        border: none;
+        border-radius: 6px;
         font-weight: 600;
     }
     .gr-button.primary:hover {
-        background: #1e40af;
+        background: #1565c0;
+    }
+    .gr-chatbot {
+        border: 1px solid #e0e0e0;
+        background: #f9f9f9;
+        border-radius: 10px;
+        padding: 1rem;
+        font-size: 15px;
+    }
+    .gr-markdown, .gr-file-upload {
+        background: #ffffff;
+        border-radius: 8px;
+        box-shadow: 0 1px 3px rgba(0,0,0,0.08);
     }
     """) as demo:
-        gr.Markdown("""<h2 style='color:#60a5fa'>🩺 Patient History AI Assistant</h2><p>Upload a clinical Excel file and receive a structured diagnostic summary.</p>""")
+        gr.Markdown("""
+        <h2 style='color:#1e88e5'>🩺 Patient History AI Assistant</h2>
+        <p>Upload a clinical Excel file and receive an advanced diagnostic summary.</p>
+        """)
+
         with gr.Row():
             with gr.Column(scale=3):
-                chatbot = gr.Chatbot(
-                    label="Clinical Assistant",
-                    height=700,
-                    type="messages",
-                    avatar_images=[
-                        "https://ui-avatars.com/api/?name=AI&background=2563eb&color=fff&size=128",
-                        "https://ui-avatars.com/api/?name=You&background=374151&color=fff&size=128"
-                    ]
-                )
+                chatbot = gr.Chatbot(label="Clinical Assistant", height=700, type="messages")
             with gr.Column(scale=1):
-                with gr.Row():
-                    file_upload = gr.File(label="", file_types=[".xlsx"], elem_id="upload-btn")
-                    analyze_btn = gr.Button("🧠 Analyze", variant="primary")
+                file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
+                analyze_btn = gr.Button("🧠 Analyze", variant="primary")
                 report_output = gr.File(label="Download Report", visible=False, interactive=False)
 
         chatbot_state = gr.State(value=[])
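One detail of the rewrite worth calling out: results come back through as_completed in completion order, and each carries its index so chunk order is restored on write-back (which will matter if max_workers is ever raised above 1). A self-contained illustration of the pattern:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Tuple

def analyze(index: int, chunk: str) -> Tuple[int, str]:
    return index, chunk.upper()  # stand-in for the per-chunk model call

chunks = ["alpha", "bravo", "charlie"]
results = [None] * len(chunks)

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(analyze, i, c) for i, c in enumerate(chunks)]
    for future in as_completed(futures):  # yields in completion order...
        i, text = future.result()
        results[i] = text                 # ...but results are stored by index

print(results)  # ['ALPHA', 'BRAVO', 'CHARLIE'] regardless of completion order
```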
 