Ali2206 committed (verified)
Commit 828effe · 1 Parent(s): a71a831

Update app.py

Files changed (1)
  1. app.py +103 -99
app.py CHANGED
@@ -5,23 +5,37 @@ import pdfplumber
  import json
  import gradio as gr
  from typing import List, Dict, Optional, Generator
- from concurrent.futures import ThreadPoolExecutor, as_completed
  import hashlib
  import shutil
  import re
  import psutil
- import subprocess
  import logging
  import torch
  import gc
  from diskcache import Cache
  import time
  from transformers import AutoTokenizer

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  # Persistent directory
  persistent_dir = "/data/hf_cache"
  os.makedirs(persistent_dir, exist_ok=True)
@@ -61,13 +75,17 @@ def file_hash(path: str) -> str:
  return hashlib.md5(f.read()).hexdigest()

  def extract_all_pages(file_path: str, progress_callback=None) -> str:
  try:
  with pdfplumber.open(file_path) as pdf:
  total_pages = len(pdf.pages)
  if total_pages == 0:
  return ""

- batch_size = 10
  batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
  text_chunks = [""] * total_pages
  processed_pages = 0
@@ -77,11 +95,11 @@ def extract_all_pages(file_path: str, progress_callback=None) -> str:
  with pdfplumber.open(file_path) as pdf:
  for page in pdf.pages[start:end]:
  page_num = start + pdf.pages.index(page)
- page_text = page.extract_text() or ""
  results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
  return results

- with ThreadPoolExecutor(max_workers=6) as executor:
  futures = [executor.submit(extract_batch, start, end) for start, end in batches]
  for future in as_completed(futures):
  for page_num, text in future.result():
@@ -90,62 +108,54 @@ def extract_all_pages(file_path: str, progress_callback=None) -> str:
  if progress_callback:
  progress_callback(min(processed_pages, total_pages), total_pages)

- return "\n\n".join(filter(None, text_chunks))
  except Exception as e:
  logger.error("PDF processing error: %s", e)
  return f"PDF processing error: {str(e)}"

  def excel_to_json(file_path: str) -> List[Dict]:
- """Convert Excel file to JSON with optimized processing"""
  try:
- # First try with openpyxl (faster for xlsx)
- try:
- df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
- except Exception:
- # Fall back to xlrd if needed
- df = pd.read_excel(file_path, engine='xlrd', header=None, dtype=str)
-
- # Convert to list of lists with null handling
  content = df.where(pd.notnull(df), "").astype(str).values.tolist()
-
- return [{
  "filename": os.path.basename(file_path),
  "rows": content,
  "type": "excel"
  }]
  except Exception as e:
  logger.error(f"Error processing Excel file: {e}")
  return [{"error": f"Error processing Excel file: {str(e)}"}]

  def csv_to_json(file_path: str) -> List[Dict]:
- """Convert CSV file to JSON with optimized processing"""
  try:
- # Read CSV in chunks if large
- chunks = []
- for chunk in pd.read_csv(
- file_path,
- header=None,
- dtype=str,
- encoding_errors='replace',
- on_bad_lines='skip',
- chunksize=10000
- ):
- chunks.append(chunk)
-
- df = pd.concat(chunks) if chunks else pd.DataFrame()
  content = df.where(pd.notnull(df), "").astype(str).values.tolist()
-
- return [{
  "filename": os.path.basename(file_path),
  "rows": content,
  "type": "csv"
  }]
  except Exception as e:
  logger.error(f"Error processing CSV file: {e}")
  return [{"error": f"Error processing CSV file: {str(e)}"}]

  def process_file(file_path: str, file_type: str) -> List[Dict]:
- """Process file based on type and return JSON data"""
  try:
  if file_type == "pdf":
  text = extract_all_pages(file_path)
@@ -165,18 +175,22 @@ def process_file(file_path: str, file_type: str) -> List[Dict]:
  logger.error("Error processing %s: %s", os.path.basename(file_path), e)
  return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}]

- def tokenize_and_chunk(text: str, max_tokens: int = 1800) -> List[str]:
- """Split text into chunks based on token count"""
- tokens = tokenizer.encode(text)
  chunks = []
  for i in range(0, len(tokens), max_tokens):
  chunk_tokens = tokens[i:i + max_tokens]
- chunks.append(tokenizer.decode(chunk_tokens))
  return chunks

  def log_system_usage(tag=""):
  try:
- cpu = psutil.cpu_percent(interval=1)
  mem = psutil.virtual_memory()
  logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
  result = subprocess.run(
@@ -261,27 +275,27 @@ def init_agent():
  if not os.path.exists(target_tool_path):
  shutil.copy(default_tool_path, target_tool_path)

- agent = TxAgent(
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
- tool_files_dict={"new_tool": target_tool_path},
- force_finish=True,
- enable_checker=False,
- step_rag_num=4,
- seed=100,
- additional_default_tools=[],
  )
- agent.init_model()
  log_system_usage("After Load")
  logger.info("Agent Ready")
- return agent

- def create_ui(agent):
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
  gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
  chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
  final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
- file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
  msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
  send_btn = gr.Button("Analyze", variant="primary")
  download_output = gr.File(label="Download Full Report")
@@ -293,7 +307,10 @@ Patient Record Excerpt (Chunk {0} of {1}):
  {chunk}
  """

- def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
  history.append({"role": "user", "content": message})
  yield history, None, ""

@@ -301,8 +318,7 @@ Patient Record Excerpt (Chunk {0} of {1}):
  file_hash_value = ""

  if files:
- # Process files in parallel
- with ThreadPoolExecutor(max_workers=4) as executor:
  futures = []
  for f in files:
  file_type = f.name.split(".")[-1].lower()
@@ -323,14 +339,11 @@ Patient Record Excerpt (Chunk {0} of {1}):
  history.append({"role": "assistant", "content": "✅ File processing complete"})
  yield history, None, ""

- # Convert extracted data to JSON text
  text_content = "\n".join(json.dumps(item) for item in extracted)
-
- # Tokenize and chunk the content properly
  chunks = tokenize_and_chunk(text_content)
  combined_response = ""
- batch_size = 2 # Reduced batch size to prevent token overflow
-
  try:
  for batch_idx in range(0, len(chunks), batch_size):
  batch_chunks = chunks[batch_idx:batch_idx + batch_size]
@@ -338,7 +351,7 @@ Patient Record Excerpt (Chunk {0} of {1}):
  prompt_template.format(
  batch_idx + i + 1,
  len(chunks),
- chunk=chunk[:1800] # Conservative chunk size
  )
  for i, chunk in enumerate(batch_chunks)
  ]
@@ -346,63 +359,54 @@ Patient Record Excerpt (Chunk {0} of {1}):
  progress((batch_idx) / len(chunks),
  desc=f"Analyzing batch {(batch_idx // batch_size) + 1}/{(len(chunks) + batch_size - 1) // batch_size}")

- # Process batch in parallel
- with ThreadPoolExecutor(max_workers=len(batch_prompts)) as executor:
- future_to_prompt = {
- executor.submit(
- agent.run_gradio_chat,
- prompt, [], 0.2, 512, 2048, False, []
- ): prompt
- for prompt in batch_prompts
- }
-
- for future in as_completed(future_to_prompt):
  chunk_response = ""
- for chunk_output in future.result():
- if chunk_output is None:
- continue
- if isinstance(chunk_output, list):
- for m in chunk_output:
- if hasattr(m, 'content') and m.content:
- cleaned = clean_response(m.content)
- if cleaned:
- chunk_response += cleaned + " "
- elif isinstance(chunk_output, str) and chunk_output.strip():
- cleaned = clean_response(chunk_output)
- if cleaned:
- chunk_response += cleaned + " "
-
- combined_response += f"--- Analysis for Chunk {batch_idx + 1} ---\n{chunk_response.strip()}\n"
- history[-1] = {"role": "assistant", "content": combined_response.strip()}
- yield history, None, ""
-
- # Clean up memory
- torch.cuda.empty_cache()
- gc.collect()
-
- # Generate final summary
  summary = summarize_findings(combined_response)
  report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
  if report_path:
  with open(report_path, "w", encoding="utf-8") as f:
  f.write(combined_response + "\n\n" + summary)

  yield history, report_path if report_path and os.path.exists(report_path) else None, summary

  except Exception as e:
  logger.error("Analysis error: %s", e)
  history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
  yield history, None, f"Error occurred during analysis: {str(e)}"

- send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
- msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
  return demo

  if __name__ == "__main__":
  try:
  logger.info("Launching app...")
- agent = init_agent()
- demo = create_ui(agent)
  demo.queue(api_open=False).launch(
  server_name="0.0.0.0",
  server_port=7860,
 
  import json
  import gradio as gr
  from typing import List, Dict, Optional, Generator
+ from concurrent.futures import ProcessPoolExecutor, as_completed
  import hashlib
  import shutil
  import re
  import psutil
+ import subprocess
  import logging
  import torch
  import gc
  from diskcache import Cache
  import time
  from transformers import AutoTokenizer
+ import pyarrow as pa
+ import pyarrow.csv as pc
+ import pyarrow.parquet as pq
+ from vllm import LLM, SamplingParams
+ import asyncio
+ import threading

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ # File handler for response logging
+ response_log_file = os.path.join("/data/hf_cache", "response_log.txt")
+ response_logger = logging.getLogger("ResponseLogger")
+ response_handler = logging.FileHandler(response_log_file, mode="a")
+ response_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
+ response_logger.addHandler(response_handler)
+ response_logger.setLevel(logging.INFO)
+
  # Persistent directory
  persistent_dir = "/data/hf_cache"
  os.makedirs(persistent_dir, exist_ok=True)
 
  return hashlib.md5(f.read()).hexdigest()

  def extract_all_pages(file_path: str, progress_callback=None) -> str:
+ cache_key = f"pdf_{file_hash(file_path)}"
+ if cache_key in cache:
+ return cache[cache_key]
+
  try:
  with pdfplumber.open(file_path) as pdf:
  total_pages = len(pdf.pages)
  if total_pages == 0:
  return ""

+ batch_size = 5
  batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
  text_chunks = [""] * total_pages
  processed_pages = 0

  with pdfplumber.open(file_path) as pdf:
  for page in pdf.pages[start:end]:
  page_num = start + pdf.pages.index(page)
+ page_text = page.extract_text_simple() or ""
  results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
  return results

+ with ProcessPoolExecutor(max_workers=4) as executor:
  futures = [executor.submit(extract_batch, start, end) for start, end in batches]
  for future in as_completed(futures):
  for page_num, text in future.result():

  if progress_callback:
  progress_callback(min(processed_pages, total_pages), total_pages)

+ result = "\n\n".join(filter(None, text_chunks))
+ cache[cache_key] = result
+ return result
  except Exception as e:
  logger.error("PDF processing error: %s", e)
  return f"PDF processing error: {str(e)}"
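
The new helpers short-circuit on a content-hash key held in diskcache. For reference, a minimal self-contained sketch of that memoization pattern, assuming a module-level cache like the diskcache.Cache instance app.py already uses (the cache path and helper names below are illustrative only):

    from diskcache import Cache
    import hashlib, os, tempfile

    cache = Cache(os.path.join(tempfile.gettempdir(), "demo_cache"))  # illustrative path, not app.py's

    def file_md5(path: str) -> str:
        with open(path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()

    def cached_extract(path: str) -> str:
        # Key on file content rather than filename, so re-uploads of an identical file hit the cache.
        key = f"pdf_{file_md5(path)}"
        if key in cache:
            return cache[key]
        text = open(path, "rb").read().decode("utf-8", errors="replace")  # stand-in for real PDF parsing
        cache[key] = text
        return text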
 
  def excel_to_json(file_path: str) -> List[Dict]:
+ cache_key = f"excel_{file_hash(file_path)}"
+ if cache_key in cache:
+ return cache[cache_key]
+
  try:
+ table = pq.read_table(file_path)
+ df = table.to_pandas(use_threads=True, split_blocks=True)
  content = df.where(pd.notnull(df), "").astype(str).values.tolist()
+ result = [{
  "filename": os.path.basename(file_path),
  "rows": content,
  "type": "excel"
  }]
+ cache[cache_key] = result
+ return result
  except Exception as e:
  logger.error(f"Error processing Excel file: {e}")
  return [{"error": f"Error processing Excel file: {str(e)}"}]

  def csv_to_json(file_path: str) -> List[Dict]:
+ cache_key = f"csv_{file_hash(file_path)}"
+ if cache_key in cache:
+ return cache[cache_key]
+
  try:
+ table = pc.read_csv(file_path, parse_options=pc.ParseOptions(invalid_row_handler=lambda x: "skip"))
+ df = table.to_pandas(use_threads=True, split_blocks=True)
  content = df.where(pd.notnull(df), "").astype(str).values.tolist()
+ result = [{
  "filename": os.path.basename(file_path),
  "rows": content,
  "type": "csv"
  }]
+ cache[cache_key] = result
+ return result
  except Exception as e:
  logger.error(f"Error processing CSV file: {e}")
  return [{"error": f"Error processing CSV file: {str(e)}"}]
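
One thing worth double-checking in excel_to_json: pyarrow.parquet.read_table parses Parquet files, not .xls/.xlsx workbooks, so a plain Excel upload would likely still need a pandas fallback. A hedged sketch of such a fallback (assuming openpyxl is available; this is not what the commit itself does):

    import pandas as pd
    import pyarrow.parquet as pq

    def load_table(path: str) -> pd.DataFrame:
        if path.endswith(".parquet"):
            return pq.read_table(path).to_pandas()  # columnar fast path for Parquet
        # Excel workbooks are not Parquet; read them with pandas + openpyxl instead.
        return pd.read_excel(path, engine="openpyxl", header=None, dtype=str)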
 
  def process_file(file_path: str, file_type: str) -> List[Dict]:
  try:
  if file_type == "pdf":
  text = extract_all_pages(file_path)

  logger.error("Error processing %s: %s", os.path.basename(file_path), e)
  return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}]

+ def tokenize_and_chunk(text: str, max_tokens: int = 800) -> List[str]:
+ cache_key = f"tokens_{hashlib.md5(text.encode()).hexdigest()}"
+ if cache_key in cache:
+ return cache[cache_key]
+
+ tokens = tokenizer.encode(text, add_special_tokens=False)
  chunks = []
  for i in range(0, len(tokens), max_tokens):
  chunk_tokens = tokens[i:i + max_tokens]
+ chunks.append(tokenizer.decode(chunk_tokens, skip_special_tokens=True))
+ cache[cache_key] = chunks
  return chunks
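
For reference, a minimal sketch of the fixed-window token chunking that tokenize_and_chunk now performs; the tokenizer name below is only an example, not the model this app loads:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer for illustration

    def chunk_by_tokens(text: str, max_tokens: int = 800) -> list:
        ids = tok.encode(text, add_special_tokens=False)
        # Slice the token-id stream into fixed windows and decode each window back to text.
        return [
            tok.decode(ids[i:i + max_tokens], skip_special_tokens=True)
            for i in range(0, len(ids), max_tokens)
        ]

    print(len(chunk_by_tokens("word " * 3000)))  # roughly 4 windows of up to 800 tokens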
 
  def log_system_usage(tag=""):
  try:
+ cpu = psutil.cpu_percent(interval=0.1)
  mem = psutil.virtual_memory()
  logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
  result = subprocess.run(
 
  if not os.path.exists(target_tool_path):
  shutil.copy(default_tool_path, target_tool_path)

+ llm = LLM(
+ model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+ gpu_memory_utilization=0.8,
+ max_model_len=2048,
+ tensor_parallel_size=1,
+ )
+ sampling_params = SamplingParams(
+ temperature=0.2,
+ max_tokens=256, # Reduced for faster streaming
+ stop=["</s>", "[INST]"],
  )
  log_system_usage("After Load")
  logger.info("Agent Ready")
+ return llm, sampling_params
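
For context, this is roughly how an (llm, sampling_params) pair from vLLM's offline API is consumed; the prompt below is a placeholder:

    from vllm import LLM, SamplingParams

    llm = LLM(model="mims-harvard/TxAgent-T1-Llama-3.1-8B", max_model_len=2048)
    params = SamplingParams(temperature=0.2, max_tokens=256, stop=["</s>"])

    outputs = llm.generate(["Patient record excerpt ... list potential missed diagnoses."], params)
    for out in outputs:               # one RequestOutput per input prompt
        print(out.outputs[0].text)    # highest-ranked completion for that prompt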
 
+ async def create_ui(llm, sampling_params):
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
  gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
  chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
  final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
+ file_upload = gr.File(file_types=["pdf", "csv", "xls", "xlsx"], file_count="multiple")
  msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
  send_btn = gr.Button("Analyze", variant="primary")
  download_output = gr.File(label="Download Full Report")
 
  {chunk}
  """

+ def log_response_partial(text: str):
+ response_logger.info(text)
+
+ async def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
  history.append({"role": "user", "content": message})
  yield history, None, ""

  file_hash_value = ""

  if files:
+ with ProcessPoolExecutor(max_workers=4) as executor:
  futures = []
  for f in files:
  file_type = f.name.split(".")[-1].lower()

  history.append({"role": "assistant", "content": "✅ File processing complete"})
  yield history, None, ""

  text_content = "\n".join(json.dumps(item) for item in extracted)
  chunks = tokenize_and_chunk(text_content)
  combined_response = ""
+ batch_size = 1
+
  try:
  for batch_idx in range(0, len(chunks), batch_size):
  batch_chunks = chunks[batch_idx:batch_idx + batch_size]

  prompt_template.format(
  batch_idx + i + 1,
  len(chunks),
+ chunk=chunk[:800]
  )
  for i, chunk in enumerate(batch_chunks)
  ]

  progress((batch_idx) / len(chunks),
  desc=f"Analyzing batch {(batch_idx // batch_size) + 1}/{(len(chunks) + batch_size - 1) // batch_size}")

+ with torch.no_grad():
+ for prompt in batch_prompts:
  chunk_response = ""
+ current_response = ""
+ stream = llm.generate([prompt], sampling_params, use_tqdm=False)
+ for output in stream:
+ for request_output in output:
+ new_text = request_output.outputs[0].text[len(current_response):]
+ if new_text:
+ current_response += new_text
+ cleaned = clean_response(current_response)
+ if cleaned and cleaned != chunk_response:
+ chunk_response = cleaned
+ history[-1] = {"role": "assistant", "content": chunk_response}
+ threading.Thread(target=log_response_partial, args=(chunk_response,)).start()
+ yield history, None, ""
+ await asyncio.sleep(0.01) # Prevent UI blocking
+
+ if chunk_response:
+ combined_response += f"--- Analysis for Chunk {batch_idx + 1} ---\n{chunk_response}\n"
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
  summary = summarize_findings(combined_response)
  report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
  if report_path:
  with open(report_path, "w", encoding="utf-8") as f:
  f.write(combined_response + "\n\n" + summary)
+ threading.Thread(target=log_response_partial, args=(summary,)).start()

  yield history, report_path if report_path and os.path.exists(report_path) else None, summary

  except Exception as e:
  logger.error("Analysis error: %s", e)
  history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
+ threading.Thread(target=log_response_partial, args=(f"Error: {str(e)}",)).start()
  yield history, None, f"Error occurred during analysis: {str(e)}"

+ send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary], _js="() => {return {streaming: true}}")
+ msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary], _js="() => {return {streaming: true}}")
  return demo

  if __name__ == "__main__":
  try:
  logger.info("Launching app...")
+ llm, sampling_params = init_agent()
+ demo = asyncio.run(create_ui(llm, sampling_params))
  demo.queue(api_open=False).launch(
  server_name="0.0.0.0",
  server_port=7860,
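
A closing note on the streaming loop in analyze: vLLM's offline LLM.generate blocks until each request finishes and returns completed RequestOutput objects, so the inner loop sees finished text rather than token-by-token deltas. Incremental streaming is usually done through the async engine; a minimal sketch, assuming a recent vLLM release (exact signatures vary by version):

    import asyncio
    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def stream_once(prompt: str) -> None:
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="mims-harvard/TxAgent-T1-Llama-3.1-8B", max_model_len=2048)
        )
        params = SamplingParams(temperature=0.2, max_tokens=256)
        seen = ""
        # engine.generate yields partial RequestOutputs as new tokens arrive.
        async for out in engine.generate(prompt, params, request_id="req-1"):
            text = out.outputs[0].text
            print(text[len(seen):], end="", flush=True)  # print only the newly generated piece
            seen = text

    asyncio.run(stream_once("Patient record excerpt ... list potential missed diagnoses."))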