Ali2206 committed
Commit dc9cc58 · verified · 1 Parent(s): 34915cc

Update app.py

Files changed (1):
  1. app.py +184 -104
app.py CHANGED
@@ -2,11 +2,12 @@ import sys
 import os
 import pandas as pd
 import gradio as gr
- from typing import List, Tuple, Dict, Any, Union, Generator
+ from typing import List, Tuple, Dict, Any, Union
 import shutil
 import re
 from datetime import datetime
 import time
+ from transformers import AutoTokenizer
 import asyncio
 import logging
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -23,7 +24,7 @@ report_dir = os.path.join(persistent_dir, "reports")
 for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
     os.makedirs(directory, exist_ok=True)

- os.environ["HF_HOME"] = model_cache_dir  # Using HF_HOME instead of TRANSFORMERS_CACHE
+ os.environ["HF_HOME"] = model_cache_dir  # Using HF_HOME as specified

 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
@@ -31,15 +32,22 @@ sys.path.insert(0, src_path)

 from txagent.txagent import TxAgent

- # Updated token limits as specified
+ # Constants
 MAX_MODEL_TOKENS = 131072  # TxAgent's max token limit
 MAX_CHUNK_TOKENS = 32768  # Larger chunks to reduce number of chunks
 MAX_NEW_TOKENS = 512  # Optimized for fast generation
 PROMPT_OVERHEAD = 500  # Estimated tokens for prompt template
 MAX_CONCURRENT = 8  # High concurrency for A100 80GB

+ # Initialize tokenizer for precise token counting
+ try:
+     tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
+ except Exception as e:
+     print(f"Warning: Could not load tokenizer, falling back to heuristic: {str(e)}")
+     tokenizer = None
+
 # Setup logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)

 def clean_response(text: str) -> str:
@@ -53,9 +61,13 @@ def clean_response(text: str) -> str:
     return text.strip()

 def estimate_tokens(text: str) -> int:
-     return len(text) // 3.5 + 1  # More conservative estimate
+     """Estimate tokens using tokenizer if available, else fall back to heuristic."""
+     if tokenizer:
+         return len(tokenizer.encode(text, add_special_tokens=False))
+     return len(text) // 3.5 + 1  # Consistent with your heuristic

 def extract_text_from_excel(file_path: str) -> str:
+     """Extract text from all sheets in an Excel file."""
     all_text = []
     try:
         xls = pd.ExcelFile(file_path)
@@ -70,12 +82,12 @@ def extract_text_from_excel(file_path: str) -> str:
         raise ValueError(f"Failed to process Excel file: {str(e)}")
     return "\n".join(all_text)

- def split_text_into_chunks(text: str) -> List[str]:
-     """Split text into chunks respecting MAX_CHUNK_TOKENS and PROMPT_OVERHEAD"""
-     effective_max = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
-     if effective_max <= 0:
-         raise ValueError("Effective max tokens must be positive")
-
+ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
+     """Split text into chunks respecting MAX_CHUNK_TOKENS and PROMPT_OVERHEAD."""
+     effective_max_tokens = max_tokens - PROMPT_OVERHEAD
+     if effective_max_tokens <= 0:
+         raise ValueError(f"Effective max tokens ({effective_max_tokens}) must be positive.")
+
     lines = text.split("\n")
     chunks = []
     current_chunk = []
@@ -83,7 +95,7 @@ def split_text_into_chunks(text: str) -> List[str]:

     for line in lines:
         line_tokens = estimate_tokens(line)
-         if current_tokens + line_tokens > effective_max:
+         if current_tokens + line_tokens > effective_max_tokens:
             if current_chunk:
                 chunks.append("\n".join(current_chunk))
             current_chunk = [line]
@@ -94,11 +106,12 @@ def split_text_into_chunks(text: str) -> List[str]:

     if current_chunk:
         chunks.append("\n".join(current_chunk))
-
+
     logger.info(f"Split text into {len(chunks)} chunks")
     return chunks

 def build_prompt_from_text(chunk: str) -> str:
+     """Build a prompt for analyzing a chunk of clinical data."""
     return f"""
 ### Unstructured Clinical Records

@@ -119,6 +132,7 @@ Please analyze the above and provide concise responses (max {MAX_NEW_TOKENS} tok
 """

 def init_agent():
+     """Initialize the TxAgent with optimized vLLM settings for A100 80GB."""
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")

@@ -138,17 +152,19 @@ def init_agent():
     agent.init_model()
     return agent

- async def process_chunk(agent: TxAgent, chunk: str, chunk_idx: int) -> Tuple[int, str]:
-     """Process a single chunk with error handling"""
+ async def process_chunk(agent, chunk: str, chunk_index: int, total_chunks: int) -> Tuple[int, str, str]:
+     """Process a single chunk and return index, response, and status message."""
+     logger.info(f"Processing chunk {chunk_index+1}/{total_chunks}")
+     prompt = build_prompt_from_text(chunk)
+     prompt_tokens = estimate_tokens(prompt)
+
+     if prompt_tokens > MAX_MODEL_TOKENS:
+         error_msg = f"❌ Chunk {chunk_index+1} prompt too long ({prompt_tokens} tokens). Skipping..."
+         logger.warning(error_msg)
+         return chunk_index, "", error_msg
+
+     response = ""
     try:
-         prompt = build_prompt_from_text(chunk)
-         prompt_tokens = estimate_tokens(prompt)
-
-         if prompt_tokens > MAX_MODEL_TOKENS:
-             logger.warning(f"Chunk {chunk_idx} prompt too long ({prompt_tokens} tokens)")
-             return chunk_idx, ""
-
-         response = ""
         for result in agent.run_gradio_chat(
             message=prompt,
             history=[],
@@ -166,95 +182,143 @@ async def process_chunk(agent: TxAgent, chunk: str, chunk_idx: int) -> Tuple[int
                 for r in result:
                     if hasattr(r, "content"):
                         response += r.content
-
-         return chunk_idx, clean_response(response)
-
+         status = f"✅ Chunk {chunk_index+1} analysis complete"
+         logger.info(status)
     except Exception as e:
-         logger.error(f"Error processing chunk {chunk_idx}: {str(e)}")
-         return chunk_idx, ""
+         status = f"❌ Error analyzing chunk {chunk_index+1}: {str(e)}"
+         logger.error(status)
+         response = ""
+
+     return chunk_index, clean_response(response), status

- async def process_file(agent: TxAgent, file_path: str) -> Generator[Tuple[List[Dict[str, str]], Union[str, None]], None, None]:
-     """Process the entire file and yield progress updates"""
-     messages = []
+ async def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
+     """Process the Excel file and generate a final report."""
+     messages = chatbot_state if chatbot_state else []
     report_path = None
-
+
+     if file is None or not hasattr(file, "name"):
+         messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
+         return messages, report_path
+
     try:
-         # Initial messages
-         messages.append({"role": "user", "content": f"Processing file: {os.path.basename(file_path)}"})
-         messages.append({"role": "assistant", "content": "⏳ Extracting data from Excel..."})
-         yield messages, None
-
-         # Extract and chunk text
+         messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
+         messages.append({"role": "assistant", "content": "⏳ Extracting and analyzing data..."})
+
+         # Extract text and split into chunks
         start_time = time.time()
-         text = extract_text_from_excel(file_path)
-         chunks = split_text_into_chunks(text)
-         messages.append({"role": "assistant", "content": f"✅ Extracted {len(chunks)} chunks in {time.time()-start_time:.1f}s"})
-         yield messages, None
-
-         # Process chunks in parallel
+         extracted_text = extract_text_from_excel(file.name)
+         chunks = split_text_into_chunks(extracted_text, max_tokens=MAX_CHUNK_TOKENS)
+         logger.info(f"Extracted text and split into {len(chunks)} chunks in {time.time() - start_time:.2f} seconds")
+
         chunk_responses = [None] * len(chunks)
-         with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
-             futures = []
-             for idx, chunk in enumerate(chunks):
-                 future = executor.submit(
-                     lambda c, i: asyncio.run(process_chunk(agent, c, i)),
-                     chunk, idx
-                 )
-                 futures.append(future)
-                 messages.append({"role": "assistant", "content": f"🔍 Processing chunk {idx+1}/{len(chunks)}..."})
-                 yield messages, None
-
-             for future in as_completed(futures):
-                 idx, response = future.result()
-                 chunk_responses[idx] = response
-                 messages.append({"role": "assistant", "content": f"✅ Chunk {idx+1} processed"})
-                 yield messages, None
-
-         # Combine and summarize
-         combined = "\n\n".join([r for r in chunk_responses if r])
+         batch_size = MAX_CONCURRENT
+
+         # Process chunks in batches
+         for batch_start in range(0, len(chunks), batch_size):
+             batch_chunks = chunks[batch_start:batch_start + batch_size]
+             batch_indices = list(range(batch_start, min(batch_start + batch_size, len(chunks))))
+             logger.info(f"Processing batch {batch_start//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size}")
+
+             with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
+                 futures = [
+                     executor.submit(lambda c, i: asyncio.run(process_chunk(agent, c, i, len(chunks))), chunk, i)
+                     for i, chunk in zip(batch_indices, batch_chunks)
+                 ]
+                 for future in as_completed(futures):
+                     chunk_index, response, status = future.result()
+                     chunk_responses[chunk_index] = response
+                     messages.append({"role": "assistant", "content": status})
+
+         # Filter out empty responses
+         chunk_responses = [r for r in chunk_responses if r]
+         if not chunk_responses:
+             messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
+             return messages, report_path
+
+         # Summarize chunk responses incrementally
+         summary = ""
+         current_summary_tokens = 0
+         for i, response in enumerate(chunk_responses):
+             response_tokens = estimate_tokens(response)
+             if current_summary_tokens + response_tokens > MAX_MODEL_TOKENS - PROMPT_OVERHEAD - MAX_NEW_TOKENS:
+                 summary_prompt = f"Summarize the following analysis:\n\n{summary}\n\nProvide a concise summary."
+                 summary_response = ""
+                 try:
+                     for result in agent.run_gradio_chat(
+                         message=summary_prompt,
+                         history=[],
+                         temperature=0.2,
+                         max_new_tokens=MAX_NEW_TOKENS,
+                         max_token=MAX_MODEL_TOKENS,
+                         call_agent=False,
+                         conversation=[],
+                     ):
+                         if isinstance(result, str):
+                             summary_response += result
+                         elif hasattr(result, "content"):
+                             summary_response += result.content
+                         elif isinstance(result, list):
+                             for r in result:
+                                 if hasattr(r, "content"):
+                                     summary_response += r.content
+                     summary = clean_response(summary_response)
+                     current_summary_tokens = estimate_tokens(summary)
+                 except Exception as e:
+                     messages.append({"role": "assistant", "content": f"❌ Error summarizing intermediate results: {str(e)}"})
+                     return messages, report_path
+
+             summary += f"\n\n### Chunk {i+1} Analysis\n{response}"
+             current_summary_tokens += response_tokens
+
+         # Final summarization
+         final_prompt = f"Summarize the key findings from the following analyses:\n\n{summary}"
         messages.append({"role": "assistant", "content": "📊 Generating final report..."})
-         yield messages, None
-
-         final_response = ""
-         for result in agent.run_gradio_chat(
-             message=f"Summarize these clinical findings:\n\n{combined}",
-             history=[],
-             temperature=0.2,
-             max_new_tokens=MAX_NEW_TOKENS*2,  # Allow more tokens for summary
-             max_token=MAX_MODEL_TOKENS,
-             call_agent=False,
-             conversation=[],
-         ):
-             if isinstance(result, str):
-                 final_response += result
-             elif hasattr(result, "content"):
-                 final_response += result.content
-             elif isinstance(result, list):
-                 for r in result:
-                     if hasattr(r, "content"):
-                         final_response += r.content
-
-         messages[-1]["content"] = f"📊 Generating final report...\n\n{clean_response(final_response)}"
-         yield messages, None
-
-         # Save report
-         final_report = f"# Final Clinical Report\n\n{clean_response(final_response)}"
+
+         final_report_text = ""
+         try:
+             for result in agent.run_gradio_chat(
+                 message=final_prompt,
+                 history=[],
+                 temperature=0.2,
+                 max_new_tokens=MAX_NEW_TOKENS * 2,  # Allow more tokens for summary, as in your code
+                 max_token=MAX_MODEL_TOKENS,
+                 call_agent=False,
+                 conversation=[],
+             ):
+                 if isinstance(result, str):
+                     final_report_text += result
+                 elif hasattr(result, "content"):
+                     final_report_text += result.content
+                 elif isinstance(result, list):
+                     for r in result:
+                         if hasattr(r, "content"):
+                             final_report_text += r.content
+         except Exception as e:
+             messages.append({"role": "assistant", "content": f"❌ Error generating final report: {str(e)}"})
+             return messages, report_path
+
+         final_report = f"# Final Clinical Report\n\n{clean_response(final_report_text)}"
+         messages[-1]["content"] = f"📊 Final Report:\n\n{clean_response(final_report_text)}"
+
+         # Save the report
         timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
         report_path = os.path.join(report_dir, f"report_{timestamp}.md")

         with open(report_path, 'w') as f:
             f.write(final_report)
-
-         messages.append({"role": "assistant", "content": f"✅ Report saved: report_{timestamp}.md"})
-         yield messages, report_path
-
+
+         messages.append({"role": "assistant", "content": f"✅ Report generated and saved: report_{timestamp}.md"})
+         logger.info(f"Total processing time: {time.time() - start_time:.2f} seconds")
+
+         return messages, report_path
+
     except Exception as e:
+         messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
         logger.error(f"Processing failed: {str(e)}")
-         messages.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
-         yield messages, None
+         return messages, report_path

- def create_ui(agent: TxAgent):
-     """Create the Gradio interface"""
+ def create_ui(agent):
+     """Create the Gradio interface."""
     with gr.Blocks(title="Clinical Analysis", css=".gradio-container {max-width: 900px}") as demo:
         gr.Markdown("## 🏥 Clinical Data Analysis (TxAgent)")

@@ -278,15 +342,29 @@ def create_ui(agent: TxAgent):
         )
         report_output = gr.File(
             label="Download Report",
-             visible=False
+             visible=False,
+             interactive=False
         )
-
+
+         # State to maintain chatbot messages
+         chatbot_state = gr.State(value=[])
+
+         async def update_ui(file, current_state):
+             if file is None or not hasattr(file, "name"):
+                 messages = current_state if current_state else []
+                 messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
+                 return messages, None
+             messages, report_path = await process_final_report(agent, file, current_state)
+             report_update = gr.update(visible=report_path is not None, value=report_path)
+             return messages, report_update
+
         analyze_btn.click(
-             fn=lambda file: process_file(agent, file.name) if file else ([{"role": "assistant", "content": "❌ Please upload a file"}], None),
-             inputs=[file_input],
-             outputs=[chatbot, report_output]
+             fn=update_ui,
+             inputs=[file_input, chatbot_state],
+             outputs=[chatbot, report_output],
+             api_name="analyze"
         )
-
+
     return demo

 if __name__ == "__main__":
@@ -298,7 +376,9 @@ if __name__ == "__main__":
             server_port=7860,
             show_error=True,
             allowed_paths=[report_dir],
-             share=False
+             share=False,
+             inline=False,
+             max_threads=40
         )
     except Exception as e:
         logger.error(f"Application failed: {str(e)}")
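For quick sanity-checking outside the Space, here is a minimal standalone sketch of the token budgeting this commit introduces. It mirrors `estimate_tokens` and `split_text_into_chunks` but stands alone (no TxAgent, tokenizer passed as an optional argument), and it wraps the heuristic in `int()` because `len(text) // 3.5` in app.py still yields a float. Treat it as an illustration under those assumptions, not the app's code.

```python
from typing import List, Optional

# Constants copied from this commit
MAX_CHUNK_TOKENS = 32768
PROMPT_OVERHEAD = 500

def estimate_tokens(text: str, tokenizer: Optional[object] = None) -> int:
    """Exact count when a tokenizer is available, otherwise the ~3.5 chars/token heuristic."""
    if tokenizer is not None:
        return len(tokenizer.encode(text, add_special_tokens=False))
    # app.py returns len(text) // 3.5 + 1 (a float); int() keeps the declared return type honest.
    return int(len(text) / 3.5) + 1

def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
    """Greedy line-wise packing under a (max_tokens - PROMPT_OVERHEAD) budget."""
    budget = max_tokens - PROMPT_OVERHEAD
    if budget <= 0:
        raise ValueError(f"Effective max tokens ({budget}) must be positive.")
    chunks: List[str] = []
    current: List[str] = []
    used = 0
    for line in text.split("\n"):
        cost = estimate_tokens(line)
        if used + cost > budget and current:
            chunks.append("\n".join(current))
            current, used = [], 0
        current.append(line)
        used += cost
    if current:
        chunks.append("\n".join(current))
    return chunks

if __name__ == "__main__":
    rows = "\n".join(f"patient {i}: hb {10 + i % 5}" for i in range(5000))
    print(len(split_text_into_chunks(rows, max_tokens=2000)))
```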
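The new `process_final_report` fans chunks out in batches of `MAX_CONCURRENT`, driving the async `process_chunk` through a `ThreadPoolExecutor` with one `asyncio.run` per worker thread. The toy sketch below reproduces only that pattern with a stand-in coroutine (`fake_chunk_analysis` is hypothetical, not part of app.py), which can help when exercising the batching independently of the model.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple

MAX_CONCURRENT = 8  # same constant app.py uses for both batch size and pool width

async def fake_chunk_analysis(chunk: str, index: int, total: int) -> Tuple[int, str, str]:
    # Placeholder for the agent.run_gradio_chat() call; just echoes the chunk.
    await asyncio.sleep(0.01)
    return index, chunk.upper(), f"chunk {index + 1}/{total} done"

def run_in_batches(chunks: List[str]) -> List[str]:
    results: List[str] = [""] * len(chunks)
    for start in range(0, len(chunks), MAX_CONCURRENT):
        batch = list(enumerate(chunks))[start:start + MAX_CONCURRENT]
        with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as pool:
            futures = [
                # One event loop per worker thread, mirroring the commit's
                # executor.submit(lambda c, i: asyncio.run(process_chunk(...)), chunk, i)
                pool.submit(lambda c, i: asyncio.run(fake_chunk_analysis(c, i, len(chunks))), c, i)
                for i, c in batch
            ]
            for future in as_completed(futures):
                index, text, status = future.result()
                results[index] = text
                print(status)
    return results

if __name__ == "__main__":
    print(run_in_batches([f"chunk {i}" for i in range(20)])[:3])
```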