Ali2206 committed (verified) · commit 4b4b32b · 1 parent: 5caebdc

Update app.py

Files changed (1):
  1. app.py  +107 -62
app.py CHANGED
@@ -1,14 +1,16 @@
 import sys
 import os
 import pandas as pd
-import json
 import gradio as gr
 from typing import List, Tuple, Dict, Any, Union
-import hashlib
 import shutil
 import re
 from datetime import datetime
 import time
+from transformers import AutoTokenizer
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 # Configuration and setup
 persistent_dir = "/data/hf_cache"
@@ -32,10 +34,22 @@ sys.path.insert(0, src_path)
 from txagent.txagent import TxAgent
 
 # Constants
-MAX_MODEL_TOKENS = 32768 # Model's maximum sequence length
-MAX_CHUNK_TOKENS = 8192 # Chunk size aligned with max_num_batched_tokens
-MAX_NEW_TOKENS = 2048 # Maximum tokens for generation
-PROMPT_OVERHEAD = 500 # Estimated tokens for prompt template overhead
+MAX_MODEL_TOKENS = 131072 # TxAgent's max token limit
+MAX_CHUNK_TOKENS = 32768 # Larger chunks to reduce number of chunks
+MAX_NEW_TOKENS = 512 # Optimized for fast generation
+PROMPT_OVERHEAD = 500 # Estimated tokens for prompt template
+MAX_CONCURRENT = 8 # High concurrency for A100 80GB
+
+# Initialize tokenizer for precise token counting
+try:
+    tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
+except Exception as e:
+    print(f"Warning: Could not load tokenizer, falling back to heuristic: {str(e)}")
+    tokenizer = None
+
+# Setup logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 
 def clean_response(text: str) -> str:
     try:
@@ -48,8 +62,10 @@ def clean_response(text: str) -> str:
     return text.strip()
 
 def estimate_tokens(text: str) -> int:
-    """Estimate the number of tokens based on character length."""
-    return len(text) // 3.5 + 1 # Add 1 to avoid zero estimates
+    """Estimate tokens using tokenizer if available, else fall back to heuristic."""
+    if tokenizer:
+        return len(tokenizer.encode(text, add_special_tokens=False))
+    return len(text) // 3.5 + 1
 
 def extract_text_from_excel(file_path: str) -> str:
     """Extract text from all sheets in an Excel file."""
@@ -67,10 +83,7 @@ def extract_text_from_excel(file_path: str) -> str:
     return "\n".join(all_text)
 
 def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
-    """
-    Split text into chunks, ensuring each chunk is within token limits,
-    accounting for prompt overhead.
-    """
+    """Split text into chunks within token limits, accounting for prompt overhead."""
     effective_max_tokens = max_tokens - PROMPT_OVERHEAD
     if effective_max_tokens <= 0:
         raise ValueError(f"Effective max tokens ({effective_max_tokens}) must be positive.")
@@ -83,7 +96,7 @@ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
     for line in lines:
         line_tokens = estimate_tokens(line)
         if current_tokens + line_tokens > effective_max_tokens:
-            if current_chunk: # Save the current chunk if it's not empty
+            if current_chunk:
                 chunks.append("\n".join(current_chunk))
             current_chunk = [line]
             current_tokens = line_tokens
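Note on the chunking strategy touched by this hunk: split_text_into_chunks packs whole lines into a chunk until the estimated token count would exceed the chunk budget minus PROMPT_OVERHEAD, then starts a new chunk. A minimal, self-contained sketch of that packing behavior, using the character heuristic as the estimator and deliberately small limits so the example produces several chunks (the constants and sample data here are illustrative, not the commit's values):

# Illustrative sketch of the line-packing approach; values are examples only.
from typing import List

PROMPT_OVERHEAD = 500      # reserve room for the prompt template
MAX_CHUNK_TOKENS = 1000    # small budget so the sample text splits into several chunks

def estimate_tokens(text: str) -> int:
    # Rough heuristic used as the fallback in the diff: roughly 3.5 characters per token.
    return int(len(text) / 3.5) + 1

def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
    effective_max = max_tokens - PROMPT_OVERHEAD
    if effective_max <= 0:
        raise ValueError("max_tokens must exceed PROMPT_OVERHEAD")
    chunks, current, current_tokens = [], [], 0
    for line in text.split("\n"):
        line_tokens = estimate_tokens(line)
        if current_tokens + line_tokens > effective_max:
            if current:                      # flush the chunk built so far
                chunks.append("\n".join(current))
            current, current_tokens = [line], line_tokens
        else:
            current.append(line)
            current_tokens += line_tokens
    if current:
        chunks.append("\n".join(current))
    return chunks

sample = "\n".join(f"row {i}: value {i}" for i in range(500))
print([estimate_tokens(c) for c in split_text_into_chunks(sample)])

Packing whole lines keeps each extracted spreadsheet row intact, at the cost of slightly uneven chunk sizes.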
@@ -118,7 +131,7 @@ Please analyze the above and provide:
 """
 
 def init_agent():
-    """Initialize the TxAgent with model and tool configurations."""
+    """Initialize the TxAgent with optimized vLLM settings for A100 80GB."""
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
 
@@ -138,8 +151,47 @@ def init_agent():
     agent.init_model()
     return agent
 
-def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
-    """Process the Excel file and generate a final report."""
+async def process_chunk(agent, chunk: str, chunk_index: int, total_chunks: int) -> Tuple[int, str, str]:
+    """Process a single chunk and return index, response, and status message."""
+    logger.info(f"Processing chunk {chunk_index+1}/{total_chunks}")
+    prompt = build_prompt_from_text(chunk)
+    prompt_tokens = estimate_tokens(prompt)
+
+    if prompt_tokens > MAX_MODEL_TOKENS:
+        error_msg = f"❌ Chunk {chunk_index+1} prompt too long ({prompt_tokens} tokens). Skipping..."
+        logger.warning(error_msg)
+        return chunk_index, "", error_msg
+
+    response = ""
+    try:
+        for result in agent.run_gradio_chat(
+            message=prompt,
+            history=[],
+            temperature=0.2,
+            max_new_tokens=MAX_NEW_TOKENS,
+            max_token=MAX_MODEL_TOKENS,
+            call_agent=False,
+            conversation=[],
+        ):
+            if isinstance(result, str):
+                response += result
+            elif hasattr(result, "content"):
+                response += result.content
+            elif isinstance(result, list):
+                for r in result:
+                    if hasattr(r, "content"):
+                        response += r.content
+        status = f"✅ Chunk {chunk_index+1} analysis complete"
+        logger.info(status)
+    except Exception as e:
+        status = f"❌ Error analyzing chunk {chunk_index+1}: {str(e)}"
+        logger.error(status)
+        response = ""
+
+    return chunk_index, clean_response(response), status
+
+async def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
+    """Process the Excel file and generate a final report with asynchronous updates."""
     messages = chatbot_state if chatbot_state else []
     report_path = None
 
@@ -152,57 +204,43 @@ def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
         messages.append({"role": "assistant", "content": "⏳ Extracting and analyzing data..."})
 
         # Extract text and split into chunks
+        start_time = time.time()
         extracted_text = extract_text_from_excel(file.name)
         chunks = split_text_into_chunks(extracted_text, max_tokens=MAX_CHUNK_TOKENS)
-        chunk_responses = []
-
-        # Process each chunk
-        for i, chunk in enumerate(chunks):
-            messages.append({"role": "assistant", "content": f"🔍 Analyzing chunk {i+1}/{len(chunks)}..."})
-
-            prompt = build_prompt_from_text(chunk)
-            prompt_tokens = estimate_tokens(prompt)
-            if prompt_tokens > MAX_MODEL_TOKENS:
-                messages.append({"role": "assistant", "content": f" Chunk {i+1} prompt too long ({prompt_tokens} tokens). Skipping..."})
-                continue
-
-            response = ""
-            try:
-                for result in agent.run_gradio_chat(
-                    message=prompt,
-                    history=[],
-                    temperature=0.2,
-                    max_new_tokens=MAX_NEW_TOKENS,
-                    max_token=MAX_MODEL_TOKENS,
-                    call_agent=False,
-                    conversation=[],
-                ):
-                    if isinstance(result, str):
-                        response += result
-                    elif hasattr(result, "content"):
-                        response += result.content
-                    elif isinstance(result, list):
-                        for r in result:
-                            if hasattr(r, "content"):
-                                response += r.content
-            except Exception as e:
-                messages.append({"role": "assistant", "content": f"❌ Error analyzing chunk {i+1}: {str(e)}"})
-                continue
-
-            chunk_responses.append(clean_response(response))
-            messages.append({"role": "assistant", "content": f"✅ Chunk {i+1} analysis complete"})
-
+        logger.info(f"Extracted text and split into {len(chunks)} chunks in {time.time() - start_time:.2f} seconds")
+
+        chunk_responses = [None] * len(chunks)
+        batch_size = MAX_CONCURRENT
+
+        # Process chunks in batches
+        for batch_start in range(0, len(chunks), batch_size):
+            batch_chunks = chunks[batch_start:batch_start + batch_size]
+            batch_indices = list(range(batch_start, min(batch_start + batch_size, len(chunks))))
+            logger.info(f"Processing batch {batch_start//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size}")
+
+            with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
+                futures = [
+                    executor.submit(lambda c, i: asyncio.run(process_chunk(agent, c, i, len(chunks))), chunk, i)
+                    for i, chunk in zip(batch_indices, batch_chunks)
+                ]
+                for future in as_completed(futures):
+                    chunk_index, response, status = future.result()
+                    chunk_responses[chunk_index] = response
+                    messages.append({"role": "assistant", "content": status})
+                    yield messages, None
+
+        # Filter out empty responses
+        chunk_responses = [r for r in chunk_responses if r]
         if not chunk_responses:
             messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
             return messages, report_path
 
-        # Summarize chunk responses incrementally to avoid token limit
+        # Summarize chunk responses incrementally
        summary = ""
        current_summary_tokens = 0
        for i, response in enumerate(chunk_responses):
            response_tokens = estimate_tokens(response)
            if current_summary_tokens + response_tokens > MAX_MODEL_TOKENS - PROMPT_OVERHEAD - MAX_NEW_TOKENS:
-                # Summarize current summary
                summary_prompt = f"Summarize the following analysis:\n\n{summary}\n\nProvide a concise summary."
                summary_response = ""
                try:
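The batching added in this hunk dispatches up to MAX_CONCURRENT chunks at a time: each worker thread calls asyncio.run() to drive one process_chunk coroutine to completion, and as_completed lets the status message for each chunk be yielded as soon as it finishes rather than in submission order. A stripped-down sketch of that dispatch pattern, with a stub coroutine standing in for the agent call (names, sizes, and the sleep are placeholders):

# Simplified stand-in for the batch dispatch pattern; process_chunk is stubbed out.
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed

MAX_CONCURRENT = 4

async def process_chunk(chunk: str, index: int, total: int) -> tuple:
    # Stand-in for the agent call: pretend each chunk takes a moment to analyze.
    await asyncio.sleep(0.1)
    return index, f"analysis of chunk {index + 1}", f"chunk {index + 1}/{total} done"

chunks = [f"chunk-{i}" for i in range(10)]
responses = [None] * len(chunks)

for batch_start in range(0, len(chunks), MAX_CONCURRENT):
    batch = list(enumerate(chunks))[batch_start:batch_start + MAX_CONCURRENT]
    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
        # Each worker thread runs its own event loop just long enough for one coroutine.
        futures = [
            executor.submit(lambda c, i: asyncio.run(process_chunk(c, i, len(chunks))), c, i)
            for i, c in batch
        ]
        for future in as_completed(futures):
            index, response, status = future.result()
            responses[index] = response
            print(status)  # in the app this becomes a chat message and a yield

print(sum(r is not None for r in responses), "chunks analyzed")

Because results arrive out of order, the sketch writes each response back by its original index, mirroring chunk_responses[chunk_index] in the commit.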
@@ -270,13 +308,15 @@ def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
             f.write(final_report)
 
         messages.append({"role": "assistant", "content": f"✅ Report generated and saved: report_{timestamp}.md"})
+        logger.info(f"Total processing time: {time.time() - start_time:.2f} seconds")
 
     except Exception as e:
         messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
+        logger.error(f"Processing failed: {str(e)}")
 
     return messages, report_path
 
-def create_ui(agent):
+async def create_ui(agent):
     """Create the Gradio UI for the patient history analysis tool."""
     with gr.Blocks(title="Patient History Chat", css=".gradio-container {max-width: 900px !important}") as demo:
         gr.Markdown("## 🏥 Patient History Analysis Tool")
@@ -312,10 +352,15 @@ def create_ui(agent):
         # State to maintain chatbot messages
         chatbot_state = gr.State(value=[])
 
-        def update_ui(file, current_state):
-            messages, report_path = process_final_report(agent, file, current_state)
-            report_update = gr.update(visible=report_path is not None, value=report_path)
-            return messages, report_update, messages
+        async def update_ui(file, current_state):
+            messages = current_state if current_state else []
+            report_path = None
+            async for new_messages, new_report_path in process_final_report(agent, file, messages):
+                messages = new_messages
+                report_path = new_report_path
+                report_update = gr.update(visible=report_path is not None, value=report_path)
+                yield messages, report_update, messages
+            yield messages, gr.update(visible=report_path is not None, value=report_path), messages
 
         analyze_btn.click(
             fn=update_ui,
@@ -329,7 +374,7 @@ def create_ui(agent):
 if __name__ == "__main__":
     try:
         agent = init_agent()
-        demo = create_ui(agent)
+        demo = asyncio.run(create_ui(agent))
         demo.launch(
             server_name="0.0.0.0",
             server_port=7860,
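For reference, the incremental summarization that runs after the chunk analysis keeps the accumulated summary inside a fixed token budget before asking the model to compress it. With this commit's constants that budget is 131072 - 500 - 512 = 130060 tokens. A small sketch of just the roll-up trigger (the summarization call itself is elided, and the sample responses are fabricated placeholders):

# Worked example of the summary budget check; the model call is omitted.
MAX_MODEL_TOKENS = 131072
MAX_NEW_TOKENS = 512
PROMPT_OVERHEAD = 500

summary_budget = MAX_MODEL_TOKENS - PROMPT_OVERHEAD - MAX_NEW_TOKENS
print(summary_budget)  # 130060 tokens available for the accumulated summary

def estimate_tokens(text: str) -> int:
    return int(len(text) / 3.5) + 1  # fallback heuristic from the diff

summary, summary_tokens = "", 0
chunk_responses = ["finding " * 20000, "finding " * 20000, "finding " * 20000]

for response in chunk_responses:
    response_tokens = estimate_tokens(response)
    if summary_tokens + response_tokens > summary_budget:
        # In the app, a "Summarize the following analysis" prompt compresses `summary` at this point.
        print(f"budget exceeded at {summary_tokens + response_tokens} tokens -> roll up summary")
        summary, summary_tokens = "", 0
    summary += response + "\n"
    summary_tokens += response_tokens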
 