Ali2206 committed
Commit ff76617 · verified · 1 Parent(s): 76162fc

Update app.py

Files changed (1)
  1. app.py +98 -266
app.py CHANGED
@@ -2,23 +2,18 @@ import sys
 import os
 import pandas as pd
 import pdfplumber
-import json
 import gradio as gr
-from typing import List, Dict, Optional, Generator
+from typing import List, Dict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import hashlib
 import shutil
 import re
-import psutil
-import subprocess
 import logging
 import torch
 import gc
 from diskcache import Cache
-import time
 from transformers import AutoTokenizer
 from functools import lru_cache
-import numpy as np
 from difflib import SequenceMatcher

 # Configure logging
@@ -32,6 +27,7 @@ MAX_WORKERS = 2
 CHUNK_SIZE = 5
 MODEL_MAX_TOKENS = 131072
 MAX_TEXT_LENGTH = 500000
+MAX_ROWS_TO_PROCESS = 1000  # Limit for Excel/CSV rows

 # Persistent directory setup
 persistent_dir = "/data/hf_cache"
@@ -41,17 +37,11 @@ model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
-vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
-
-for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
-    os.makedirs(directory, exist_ok=True)
+os.makedirs(report_dir, exist_ok=True)

 os.environ.update({
     "HF_HOME": model_cache_dir,
-    "TRANSFORMERS_CACHE": model_cache_dir,
-    "VLLM_CACHE_DIR": vllm_cache_dir,
     "TOKENIZERS_PARALLELISM": "false",
-    "CUDA_LAUNCH_BLOCKING": "1"
 })

 current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -60,7 +50,7 @@ sys.path.insert(0, src_path)

 from txagent.txagent import TxAgent

-# Initialize cache with 10GB limit
+# Initialize cache
 cache = Cache(file_cache_dir, size_limit=10 * 1024**3)

 @lru_cache(maxsize=1)
@@ -99,70 +89,47 @@ def extract_pdf_page(page, tokenizer, max_tokens=MAX_TOKENS) -> List[str]:
                 current_length += 1
             if current_chunk:
                 chunks.append(tokenizer.decode(current_chunk))
-            return [f"=== Page {page.page_number} ===\n{c}" for c in chunks]
-        return [f"=== Page {page.page_number} ===\n{text}"]
+            return chunks
+        return [text]
     except Exception as e:
         logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
         return []

-def extract_all_pages(file_path: str, progress_callback=None) -> List[str]:
+def extract_all_pages(file_path: str) -> List[str]:
     try:
         tokenizer = get_tokenizer()
         with pdfplumber.open(file_path) as pdf:
             total_pages = len(pdf.pages)
             if total_pages == 0:
-                logger.error("PDF has 0 pages - may be corrupted or empty")
-                return []
+                return ["PDF appears to be empty"]

             results = []
-            total_tokens = 0
-            for chunk_start in range(0, total_pages, CHUNK_SIZE):
-                chunk_end = min(chunk_start + CHUNK_SIZE, total_pages)
-
-                with pdfplumber.open(file_path) as pdf:
-                    with ThreadPoolExecutor(max_workers=min(CHUNK_SIZE, 2)) as executor:
-                        futures = [executor.submit(extract_pdf_page, pdf.pages[i], tokenizer)
-                                   for i in range(chunk_start, chunk_end)]
-
-                        for future in as_completed(futures):
-                            page_chunks = future.result()
-                            for chunk in page_chunks:
-                                chunk_tokens = len(tokenizer.encode(chunk, add_special_tokens=False))
-                                if total_tokens + chunk_tokens > MODEL_MAX_TOKENS:
-                                    logger.warning("Total tokens exceed model limit. Stopping.")
-                                    return results
-                                results.append(chunk)
-                                total_tokens += chunk_tokens
-
-                if progress_callback:
-                    progress_callback(min(chunk_end, total_pages), total_pages)
-
-                del pdf
-                gc.collect()
-
-            if not results:
-                logger.error("No content extracted from PDF - may be scanned or encrypted")
-                return ["PDF appears to be empty or unreadable"]
+            for i in range(0, min(total_pages, 50)):  # Limit to first 50 pages
+                try:
+                    page = pdf.pages[i]
+                    chunks = extract_pdf_page(page, tokenizer)
+                    for chunk in chunks:
+                        results.append(f"=== Page {i+1} ===\n{chunk}")
+                except Exception as e:
+                    logger.warning(f"Error processing page {i+1}: {str(e)}")
+                    continue

-            return results
+            return results if results else ["Could not extract text from PDF"]
     except Exception as e:
         logger.error(f"PDF processing error: {e}")
         return [f"PDF processing error: {str(e)}"]

 def excel_to_json(file_path: str) -> List[Dict]:
-    """Enhanced Excel processing with multiple engine support"""
-    engines = ['openpyxl', 'xlrd', 'odf']
-    last_error = None
-
+    engines = ['openpyxl', 'xlrd']
     for engine in engines:
         try:
             with pd.ExcelFile(file_path, engine=engine) as excel_file:
                 sheets = excel_file.sheet_names
                 if not sheets:
-                    return [{"error": "No sheets found in Excel file"}]
+                    return [{"error": "No sheets found"}]

                 results = []
-                for sheet_name in sheets:
+                for sheet_name in sheets[:3]:  # Limit to first 3 sheets
                     try:
                         df = pd.read_excel(
                             excel_file,
@@ -170,99 +137,70 @@ def excel_to_json(file_path: str) -> List[Dict]:
                             header=None,
                             dtype=str,
                             na_filter=False,
-                            engine=engine
+                            nrows=MAX_ROWS_TO_PROCESS  # Limit rows
                         )
                         if not df.empty:
-                            # Convert all cells to string and clean
-                            df = df.applymap(lambda x: str(x).strip() if pd.notna(x) else "")
+                            rows = df.head(MAX_ROWS_TO_PROCESS).values.tolist()
                             results.append({
-                                "filename": f"{os.path.basename(file_path)} - {sheet_name}",
-                                "rows": df.values.tolist(),
-                                "type": "excel",
+                                "filename": os.path.basename(file_path),
                                 "sheet": sheet_name,
-                                "dimensions": f"{len(df)} rows x {len(df.columns)} cols"
+                                "rows": rows,
+                                "type": "excel"
                             })
-                    except Exception as sheet_error:
-                        logger.warning(f"Error processing sheet {sheet_name}: {sheet_error}")
+                    except Exception as e:
+                        logger.warning(f"Error processing sheet {sheet_name}: {str(e)}")
                         continue

-                if results:
-                    logger.info(f"Successfully processed Excel file with {engine} engine")
-                    return results
-        except Exception as engine_error:
-            last_error = engine_error
+                return results if results else [{"error": "No readable data found"}]
+        except Exception as e:
+            logger.warning(f"Excel engine {engine} failed: {str(e)}")
             continue

-    return [{"error": f"Failed to process Excel file with all engines. Last error: {str(last_error)}"}]
+    return [{"error": "Could not process Excel file with any engine"}]

 def csv_to_json(file_path: str) -> List[Dict]:
     try:
-        chunks = []
-        for chunk in pd.read_csv(
+        df = pd.read_csv(
             file_path,
             header=None,
             dtype=str,
             encoding_errors='replace',
             on_bad_lines='skip',
-            chunksize=10000,
-            na_filter=False
-        ):
-            chunks.append(chunk)
-
-        df = pd.concat(chunks) if chunks else pd.DataFrame()
+            nrows=MAX_ROWS_TO_PROCESS  # Limit rows
+        )
         if df.empty:
-            return [{"error": "CSV file is empty or could not be read"}]
+            return [{"error": "CSV file is empty"}]

         return [{
             "filename": os.path.basename(file_path),
             "rows": df.values.tolist(),
-            "type": "csv",
-            "dimensions": f"{len(df)} rows x {len(df.columns)} cols"
+            "type": "csv"
         }]
     except Exception as e:
         logger.error(f"CSV processing error: {e}")
         return [{"error": f"CSV processing error: {str(e)}"}]

-@lru_cache(maxsize=100)
 def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
-    """Enhanced file processing with detailed logging"""
     try:
-        logger.info(f"Processing file: {file_path} (type: {file_type})")
+        logger.info(f"Processing {file_type} file: {os.path.basename(file_path)}")

         if file_type == "pdf":
             chunks = extract_all_pages(file_path)
-            if not chunks or (len(chunks) == 1 and "error" in chunks[0]):
-                return [{"error": chunks[0] if chunks else "PDF appears to be empty"}]
             return [{
                 "filename": os.path.basename(file_path),
                 "content": chunk,
-                "status": "initial",
-                "type": "pdf",
-                "page": i+1
-            } for i, chunk in enumerate(chunks)]
+                "type": "pdf"
+            } for chunk in chunks]

         elif file_type in ["xls", "xlsx"]:
-            result = excel_to_json(file_path)
-            if "error" in result[0]:
-                logger.error(f"Excel processing failed: {result[0]['error']}")
-            else:
-                logger.info(f"Excel processing successful - found {len(result)} sheets")
-            return result
+            return excel_to_json(file_path)

         elif file_type == "csv":
-            result = csv_to_json(file_path)
-            if "error" in result[0]:
-                logger.error(f"CSV processing failed: {result[0]['error']}")
-            else:
-                logger.info(f"CSV processing successful - found {len(result[0]['rows'])} rows")
-            return result
-
-        else:
-            logger.warning(f"Unsupported file type: {file_type}")
-            return [{"error": f"Unsupported file type: {file_type}"}]
+            return csv_to_json(file_path)

+        return [{"error": f"Unsupported file type: {file_type}"}]
     except Exception as e:
-        logger.error(f"Error processing {file_path}: {str(e)}", exc_info=True)
+        logger.error(f"Error processing file: {e}")
         return [{"error": f"Error processing file: {str(e)}"}]

 def clean_response(text: str) -> str:
@@ -272,49 +210,25 @@ def clean_response(text: str) -> str:
     patterns = [
         (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
         (re.compile(r"\s+"), " "),
-        (re.compile(r"[^\w\s\.\,\(\)\-]"), ""),
     ]

     for pattern, repl in patterns:
         text = pattern.sub(repl, text)

-    sentences = text.split(". ")
-    unique_sentences = []
-    seen = set()
-
-    for s in sentences:
-        if not s:
-            continue
-        is_unique = True
-        for seen_s in seen:
-            if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
-                is_unique = False
-                break
-        if is_unique:
-            unique_sentences.append(s)
-            seen.add(s)
-
-    text = ". ".join(unique_sentences).strip()
-    return text if text else "No missed diagnoses identified."
+    return text.strip()

 @lru_cache(maxsize=1)
 def init_agent():
     logger.info("Initializing model...")

-    default_tool_path = os.path.abspath("data/new_tool.json")
-    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-    if not os.path.exists(target_tool_path):
-        shutil.copy(default_tool_path, target_tool_path)
-
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-        tool_files_dict={"new_tool": target_tool_path},
+        tool_files_dict={"new_tool": os.path.join(tool_cache_dir, "new_tool.json")},
         force_finish=True,
         enable_checker=False,
         step_rag_num=4,
         seed=100,
-        additional_default_tools=[],
     )
     agent.init_model()
     logger.info("Agent Ready")
@@ -322,8 +236,7 @@

 def create_ui(agent):
     PROMPT_TEMPLATE = """
-Analyze the patient record excerpt for missed diagnoses. Provide detailed, evidence-based analysis.
-Patient Record Excerpt (Chunk {0} of {1}):
+Analyze this patient record excerpt for missed diagnoses (limit response to 500 tokens):
 {chunk}
 """

@@ -332,170 +245,89 @@ Patient Record Excerpt (Chunk {0} of {1}):

         with gr.Row():
             with gr.Column(scale=3):
-                chatbot = gr.Chatbot(label="Analysis Summary", height=600, value=[])
+                chatbot = gr.Chatbot(label="Analysis", height=500)
                 msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
                 send_btn = gr.Button("Analyze", variant="primary")
-                file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+                file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="single")

             with gr.Column(scale=1):
-                final_summary = gr.Markdown(label="Missed Diagnoses Summary")
-                download_output = gr.File(label="Download Detailed Report")
-                progress_bar = gr.Progress()
+                final_summary = gr.Markdown("## Summary")
+                status = gr.Textbox(label="Status", interactive=False)

-        def analyze(message: str, history: List[List[str]], files: List, progress=gr.Progress()):
-            """Enhanced analysis with detailed file processing feedback"""
+        def analyze(message: str, history: List[List[str]], files: List):
             try:
-                if history is None:
-                    history = []
-
-                history.append([message, None])
-                yield history, None, ""
-
                 if not files:
-                    history[-1][1] = "Please upload a file to analyze"
-                    yield history, None, "No files uploaded"
-                    return
-
-                extracted = []
-                file_hash_value = ""
+                    return history, "Please upload a file first", "No file uploaded"

-                for f in files:
-                    file_type = f.name.split(".")[-1].lower()
-                    logger.info(f"Processing file: {f.name} (type: {file_type})")
-
-                    cache_key = f"{file_hash(f.name)}_{file_type}"
-                    if cache_key in cache:
-                        cached_data = cache[cache_key]
-                        if isinstance(cached_data, list) and len(cached_data) > 0:
-                            extracted.extend(cached_data)
-                            history[-1][1] = f"✅ Using cached data for {os.path.basename(f.name)}"
-                            yield history, None, ""
-                            continue
-
-                    try:
-                        result = process_file_cached(f.name, file_type)
-                        if "error" in result[0]:
-                            history[-1][1] = f"❌ Error processing {os.path.basename(f.name)}: {result[0]['error']}"
-                            yield history, None, result[0]['error']
-                            return
-
-                        cache[cache_key] = result
-                        extracted.extend(result)
-                        history[-1][1] = f"✅ Processed {os.path.basename(f.name)}"
-                        yield history, None, ""
-                    except Exception as e:
-                        logger.error(f"File processing error: {e}", exc_info=True)
-                        history[-1][1] = f"❌ Critical error processing {os.path.basename(f.name)}"
-                        yield history, None, str(e)
-                        return
+                file = files[0]
+                file_type = file.name.split(".")[-1].lower()
+                history.append([message, None])

-                file_hash_value = file_hash(files[0].name) if files else ""
+                # Process file
+                processed = process_file_cached(file.name, file_type)
+                if "error" in processed[0]:
+                    return history, processed[0]["error"], "File processing failed"

-                # Debug extracted content
-                logger.info(f"Extracted content summary:")
-                for item in extracted:
-                    if "content" in item:
-                        logger.info(f"- {item['filename']}: {len(item['content'])} chars")
-                    elif "rows" in item:
-                        logger.info(f"- {item['filename']}: {len(item['rows'])} rows")
-
-                if not extracted:
-                    history[-1][1] = "❌ No valid content extracted from files"
-                    yield history, None, "No valid content extracted"
-                    return
-
+                # Prepare chunks
                 chunks = []
-                for item in extracted:
+                for item in processed:
                     if "content" in item:
                         chunks.append(item["content"])
                     elif "rows" in item:
-                        # Convert Excel/CSV rows to text
-                        rows_text = "\n".join([", ".join(map(str, row)) for row in item["rows"]])
-                        chunks.append(f"=== {item['filename']} ===\n{rows_text}")
+                        rows_text = "\n".join([", ".join(map(str, row)) for row in item["rows"][:100]])  # Limit rows
+                        chunks.append(f"=== {item.get('sheet', 'Data')} ===\n{rows_text}")

                 if not chunks:
-                    history[-1][1] = "No processable content found in files"
-                    yield history, None, "No processable content found"
-                    return
-
-                combined_response = ""
-                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
+                    return history, "No processable content found", "Content extraction failed"
+
+                # Process chunks
+                responses = []
+                for i, chunk in enumerate(chunks[:5]):  # Limit to 5 chunks
+                    try:
+                        prompt = PROMPT_TEMPLATE.format(chunk=chunk[:5000])  # Limit chunk size
+                        response = agent.run_quick_summary(prompt, 0.2, 256, 500)  # Limit tokens
+                        cleaned = clean_response(response)
+                        if cleaned:
+                            responses.append(f"Analysis {i+1}:\n{cleaned}")
+                    except Exception as e:
+                        logger.warning(f"Error processing chunk {i+1}: {str(e)}")
+                        continue
+
+                if not responses:
+                    return history, "No valid analysis generated", "Analysis failed"
+
+                summary = "\n\n".join(responses)
+                history[-1][1] = summary
+                return history, "Analysis completed", "Success"

-                try:
-                    for batch_idx in range(0, len(chunks), BATCH_SIZE):
-                        batch_chunks = chunks[batch_idx:batch_idx + BATCH_SIZE]
-
-                        progress(batch_idx / len(chunks),
-                                 desc=f"Processing batch {(batch_idx // BATCH_SIZE) + 1}/{(len(chunks) + BATCH_SIZE - 1) // BATCH_SIZE}")
-
-                        with ThreadPoolExecutor(max_workers=min(BATCH_SIZE, MAX_WORKERS)) as executor:
-                            futures = {
-                                executor.submit(
-                                    agent.run_quick_summary,
-                                    chunk, 0.2, 256, 1024
-                                ): idx
-                                for idx, chunk in enumerate(batch_chunks)
-                            }
-
-                            for future in as_completed(futures):
-                                chunk_idx = futures[future]
-                                try:
-                                    response = clean_response(future.result())
-                                    if response:
-                                        combined_response += f"\n--- Analysis for Chunk {batch_idx + chunk_idx + 1} ---\n{response}\n"
-                                        history[-1][1] = combined_response.strip()
-                                        yield history, None, ""
-                                except Exception as e:
-                                    logger.error(f"Chunk processing error: {e}")
-                                    history[-1][1] = f"Error processing chunk: {str(e)}"
-                                    yield history, None, ""
-                                finally:
-                                    del future
-                                    torch.cuda.empty_cache()
-                                    gc.collect()
-
-                    summary = "Analysis complete. " + ("Download full report below." if report_path and os.path.exists(report_path) else "")
-                    history.append(["Analysis completed", None])
-                    history[-1][1] = summary
-                    yield history, report_path, summary
-
-                except Exception as e:
-                    logger.error(f"Analysis error: {e}")
-                    history.append(["Analysis failed", None])
-                    history[-1][1] = f"❌ Error occurred: {str(e)}"
-                    yield history, None, f"Error occurred: {str(e)}"
-                finally:
-                    torch.cuda.empty_cache()
-                    gc.collect()
-
             except Exception as e:
-                logger.error(f"Unexpected error in analysis: {e}")
-                history.append(["System error", None])
-                history[-1][1] = f"❌ System error occurred: {str(e)}"
-                yield history, None, f"System error: {str(e)}"
+                logger.error(f"Analysis error: {e}")
+                return history, f"Error: {str(e)}", "Failed"
+            finally:
+                torch.cuda.empty_cache()
+                gc.collect()

         send_btn.click(
-            analyze,
-            inputs=[msg_input, gr.State([]), file_upload],
-            outputs=[chatbot, download_output, final_summary]
+            analyze,
+            inputs=[msg_input, chatbot, file_upload],
+            outputs=[chatbot, final_summary, status]
         )
         msg_input.submit(
-            analyze,
-            inputs=[msg_input, gr.State([]), file_upload],
-            outputs=[chatbot, download_output, final_summary]
+            analyze,
+            inputs=[msg_input, chatbot, file_upload],
+            outputs=[chatbot, final_summary, status]
         )

     return demo

 if __name__ == "__main__":
     try:
-        logger.info("Launching app...")
         agent = init_agent()
         demo = create_ui(agent)
-        demo.queue().launch(
+        demo.launch(
             server_name="0.0.0.0",
             server_port=7860,
-            show_error=True
+            share=False
         )
     except Exception as e:
         logger.error(f"Fatal error: {e}")