Ali2206 committed
Commit fb2ccc1 · verified · 1 Parent(s): 44280bd

Update app.py

Files changed (1): app.py +151 -78
app.py CHANGED
@@ -12,6 +12,7 @@ import re
 import psutil
 import subprocess
 from datetime import datetime
+import tiktoken

 # Persistent directory setup
 persistent_dir = "/data/hf_cache"
@@ -44,8 +45,10 @@ MEDICAL_KEYWORDS = {
     'allergies', 'summary', 'impression', 'findings', 'recommendations',
     'conclusion', 'history', 'examination', 'progress', 'discharge'
 }
-CHUNK_SIZE = 10000  # Increased chunk size for better context
-MAX_TOKENS = 12000  # Maximum tokens for analysis
+TOKENIZER = "cl100k_base"  # Matches Llama 3's tokenizer
+MAX_MODEL_LEN = 8000  # Conservative estimate for model context
+CHUNK_TOKEN_SIZE = MAX_MODEL_LEN // 2  # Target chunk size
+MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="

 def sanitize_utf8(text: str) -> str:
     """Ensure text is UTF-8 clean."""
@@ -56,14 +59,21 @@ def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()

-def extract_all_pages(file_path: str) -> Tuple[str, int]:
+def count_tokens(text: str) -> int:
+    """Count tokens using the same method as the model"""
+    encoding = tiktoken.get_encoding(TOKENIZER)
+    return len(encoding.encode(text))
+
+def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
     """
-    Extract all pages from PDF with smart prioritization of medical sections.
-    Returns (extracted_text, total_pages)
+    Extract all pages from PDF with token counting.
+    Returns (extracted_text, total_pages, total_tokens)
     """
     try:
         text_chunks = []
         total_pages = 0
+        total_tokens = 0
+
         with pdfplumber.open(file_path) as pdf:
             total_pages = len(pdf.pages)

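For context (an illustrative sketch, not part of the commit): the counter relies on tiktoken's cl100k_base encoding, which is an OpenAI tokenizer, so the counts are a close proxy rather than an exact match for a Llama-family model. Minimal usage, assuming tiktoken is installed:

import tiktoken

# cl100k_base is the encoding named by the TOKENIZER constant above
encoding = tiktoken.get_encoding("cl100k_base")
sample = "Patient presents with chest pain and shortness of breath."
print(len(encoding.encode(sample)))  # token count for the sample sentence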
@@ -71,18 +81,22 @@ def extract_all_pages(file_path: str) -> Tuple[str, int]:
                 page_text = page.extract_text() or ""
                 lower_text = page_text.lower()

-                # Include all pages but mark sections with medical keywords
+                # Mark medical sections
                 if any(re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS):
-                    text_chunks.append(f"=== MEDICAL SECTION (Page {i+1}) ===\n{page_text.strip()}")
+                    section_header = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n"
+                    text_chunks.append(section_header + page_text.strip())
+                    total_tokens += count_tokens(section_header)
                 else:
-                    text_chunks.append(f"=== Page {i+1} ===\n{page_text.strip()}")
+                    text_chunks.append(f"\n=== Page {i+1} ===\n{page_text.strip()}")
+
+                total_tokens += count_tokens(page_text)

-        return "\n\n".join(text_chunks), total_pages
+        return "\n".join(text_chunks), total_pages, total_tokens
     except Exception as e:
-        return f"PDF processing error: {str(e)}", 0
+        return f"PDF processing error: {str(e)}", 0, 0

 def convert_file_to_json(file_path: str, file_type: str) -> str:
-    """Convert file to JSON format with caching, processing all content."""
+    """Convert file to JSON format with caching and token counting."""
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
@@ -92,11 +106,12 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
                 return f.read()

         if file_type == "pdf":
-            text, total_pages = extract_all_pages(file_path)
+            text, total_pages, total_tokens = extract_all_pages_with_token_count(file_path)
             result = json.dumps({
                 "filename": os.path.basename(file_path),
                 "content": text,
                 "total_pages": total_pages,
+                "total_tokens": total_tokens,
                 "status": "complete"
             })
         elif file_type == "csv":
@@ -106,15 +121,22 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
                                       skip_blank_lines=False, on_bad_lines="skip", chunksize=1000):
                 chunks.append(chunk.fillna("").astype(str).values.tolist())
             content = [item for sublist in chunks for item in sublist]
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+            result = json.dumps({
+                "filename": os.path.basename(file_path),
+                "rows": content,
+                "total_tokens": count_tokens(str(content))
+            })
         elif file_type in ["xls", "xlsx"]:
             try:
-                # Read Excel in chunks if possible
                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
             except Exception:
                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
             content = df.fillna("").astype(str).values.tolist()
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+            result = json.dumps({
+                "filename": os.path.basename(file_path),
+                "rows": content,
+                "total_tokens": count_tokens(str(content))
+            })
         else:
             result = json.dumps({"error": f"Unsupported file type: {file_type}"})

@@ -204,6 +226,40 @@ def format_final_report(analysis_results: List[str], filename: str) -> str:

     return "\n".join(report)

+def split_content_by_tokens(content: str, max_tokens: int = CHUNK_TOKEN_SIZE) -> List[str]:
+    """Split content into chunks that fit within token limits"""
+    paragraphs = re.split(r"\n\s*\n", content)
+    chunks = []
+    current_chunk = []
+    current_tokens = 0
+
+    for para in paragraphs:
+        para_tokens = count_tokens(para)
+        if para_tokens > max_tokens:
+            # Handle very long paragraphs by splitting sentences
+            sentences = re.split(r'(?<=[.!?])\s+', para)
+            for sent in sentences:
+                sent_tokens = count_tokens(sent)
+                if current_tokens + sent_tokens > max_tokens:
+                    chunks.append("\n\n".join(current_chunk))
+                    current_chunk = [sent]
+                    current_tokens = sent_tokens
+                else:
+                    current_chunk.append(sent)
+                    current_tokens += sent_tokens
+        elif current_tokens + para_tokens > max_tokens:
+            chunks.append("\n\n".join(current_chunk))
+            current_chunk = [para]
+            current_tokens = para_tokens
+        else:
+            current_chunk.append(para)
+            current_tokens += para_tokens
+
+    if current_chunk:
+        chunks.append("\n\n".join(current_chunk))
+
+    return chunks
+
 def init_agent():
     """Initialize the TxAgent with proper configuration."""
     print("🔁 Initializing model...")
@@ -229,72 +285,74 @@ def init_agent():
     print("✅ Agent Ready")
     return agent

-def analyze_large_document(content: str, filename: str, agent: TxAgent) -> str:
-    """Analyze large documents by processing in logical sections."""
-    # Split content into logical sections
-    sections = re.split(r"(=== MEDICAL SECTION|=== Page \d+ ===)", content)
-    sections = [s.strip() for s in sections if s.strip()]
-
+def analyze_complete_document(content: str, filename: str, agent: TxAgent) -> str:
+    """Analyze complete document with proper chunking and token management"""
+    chunks = split_content_by_tokens(content)
     analysis_results = []
-    current_chunk = ""
-
-    for section in sections:
-        # If adding this section would exceed chunk size, analyze current chunk
-        if len(current_chunk) + len(section) > CHUNK_SIZE and current_chunk:
-            analysis_results.append(process_chunk(current_chunk, filename, agent))
-            current_chunk = section
-        else:
-            current_chunk += "\n\n" + section

-    # Process the last chunk
-    if current_chunk:
-        analysis_results.append(process_chunk(current_chunk, filename, agent))
-
-    return format_final_report(analysis_results, filename)
+    for i, chunk in enumerate(chunks):
+        try:
+            # Create context-aware prompt
+            prompt = f"""
+Analyze this section ({i+1}/{len(chunks)}) of medical records for clinical oversights.
+Focus on factual evidence from the content only.

-def process_chunk(chunk: str, filename: str, agent: TxAgent) -> str:
-    """Process a single chunk of the document."""
-    prompt = f"""
-Analyze this section of medical records for clinical oversights. Focus on:
-1. Critical findings needing immediate attention
-2. Potential missed diagnoses
-3. Medication conflicts
-4. Assessment gaps
-5. Follow-up recommendations
+**File:** {filename}
+**Content:**
+{chunk}

-File: {filename}
-Content:
-{chunk[:CHUNK_SIZE]}
+Provide concise findings under these headings:
+1. CRITICAL FINDINGS (urgent issues)
+2. MISSED DIAGNOSES (with supporting evidence)
+3. MEDICATION ISSUES (specific conflicts)
+4. ASSESSMENT GAPS (missing evaluations)
+5. FOLLOW-UP RECOMMENDATIONS (specific actions)

-Provide concise findings in bullet points under relevant headings.
-Focus on factual evidence from the content.
+Be concise and evidence-based:
 """
-
-    full_response = ""
-    for output in agent.run_gradio_chat(
-        message=prompt,
-        history=[],
-        temperature=0.1,  # Lower temperature for more factual responses
-        max_new_tokens=1024,
-        max_token=MAX_TOKENS,
-        call_agent=False,
-        conversation=[],
-    ):
-        if output is None:
+            # Ensure prompt + chunk doesn't exceed model limits
+            prompt_tokens = count_tokens(prompt)
+            chunk_tokens = count_tokens(chunk)
+
+            if prompt_tokens + chunk_tokens > MAX_MODEL_LEN - 1024:  # Leave room for response
+                # Dynamically adjust chunk size
+                max_chunk_tokens = MAX_MODEL_LEN - prompt_tokens - 1024
+                adjusted_chunk = ""
+                tokens_used = 0
+                for para in re.split(r"\n\s*\n", chunk):
+                    para_tokens = count_tokens(para)
+                    if tokens_used + para_tokens <= max_chunk_tokens:
+                        adjusted_chunk += "\n\n" + para
+                        tokens_used += para_tokens
+                    else:
+                        break
+                chunk = adjusted_chunk.strip()
+
+            response = ""
+            for output in agent.run_gradio_chat(
+                message=prompt,
+                history=[],
+                temperature=0.1,
+                max_new_tokens=1024,
+                max_token=MAX_MODEL_LEN,
+                call_agent=False,
+                conversation=[],
+            ):
+                if output:
+                    if isinstance(output, list):
+                        for m in output:
+                            if hasattr(m, 'content'):
+                                response += clean_response(m.content)
+                    elif isinstance(output, str):
+                        response += clean_response(output)
+
+            if response:
+                analysis_results.append(response)
+        except Exception as e:
+            print(f"Error processing chunk {i}: {str(e)}")
             continue
-
-        if isinstance(output, list):
-            for m in output:
-                if hasattr(m, 'content') and m.content:
-                    cleaned = clean_response(m.content)
-                    if cleaned:
-                        full_response += cleaned + "\n"
-        elif isinstance(output, str) and output.strip():
-            cleaned = clean_response(output)
-            if cleaned:
-                full_response += cleaned + "\n"

-    return full_response
+    return format_final_report(analysis_results, filename)

 def create_ui(agent):
     """Create the Gradio interface."""
@@ -316,7 +374,7 @@ def create_ui(agent):
                 label="Analysis Focus"
             )
             with gr.Row():
-                send_btn = gr.Button("Analyze Full Document", variant="primary")
+                send_btn = gr.Button("Analyze Complete Documents", variant="primary")
                 clear_btn = gr.Button("Clear")
             status = gr.Textbox(label="Status", interactive=False)

@@ -338,11 +396,12 @@ def create_ui(agent):
             yield "", None, "⚠️ Please upload at least one file to analyze."
             return

-        yield "", None, "⏳ Processing documents..."
+        yield "", None, "⏳ Processing documents (this may take several minutes for large files)..."

         # Process all files completely
         file_contents = []
         filenames = []
+        total_tokens = 0

         with ThreadPoolExecutor(max_workers=4) as executor:
             futures = []
@@ -356,7 +415,14 @@ def create_ui(agent):

             results = []
             for future in as_completed(futures):
-                results.append(sanitize_utf8(future.result()))
+                result = sanitize_utf8(future.result())
+                results.append(result)
+                try:
+                    data = json.loads(result)
+                    if "total_tokens" in data:
+                        total_tokens += data["total_tokens"]
+                except:
+                    pass

             file_contents = results

@@ -367,11 +433,11 @@ def create_ui(agent):
                 for fc in file_contents
             ])

-            yield "", None, "🔍 Analyzing content..."
+            yield "", None, f"🔍 Analyzing content ({total_tokens//1000}k tokens)..."

             try:
                 # Process the complete document
-                full_report = analyze_large_document(
+                full_report = analyze_complete_document(
                     combined_content,
                     combined_filename,
                     agent
@@ -408,6 +474,13 @@

 if __name__ == "__main__":
     print("🚀 Launching app...")
+    # Install tiktoken if not available
+    try:
+        import tiktoken
+    except ImportError:
+        print("Installing tiktoken...")
+        subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"])
+
     agent = init_agent()
     demo = create_ui(agent)
     demo.queue(
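Taken together, the commit's new flow is: extract with token counts, split by token budget, then analyze chunk by chunk. A simplified driver sketch (hypothetical file path; TxAgent wiring omitted):

text, pages, tokens = extract_all_pages_with_token_count("records.pdf")
print(f"{pages} pages, ~{tokens} tokens")
for i, chunk in enumerate(split_content_by_tokens(text)):
    print(f"chunk {i}: {count_tokens(chunk)} tokens")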