Ali2206 commited on
Commit
5601da1
·
verified ·
1 Parent(s): 5d37db7

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +21 -1
src/txagent/txagent.py CHANGED
@@ -81,7 +81,13 @@ class TxAgent:
81
  self.model_name = model_name
82
 
83
  try:
84
- self.model = LLM(model=self.model_name, dtype="float16", max_model_len=131072)
 
 
 
 
 
 
85
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
86
  self.tokenizer = self.model.get_tokenizer()
87
  logger.info("Model %s loaded successfully", self.model_name)
@@ -419,6 +425,10 @@ class TxAgent:
419
  if output_begin_string:
420
  prompt += output_begin_string
421
 
 
 
 
 
422
  if check_token_status and max_token:
423
  num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
424
  if num_input_tokens > max_token:
@@ -430,6 +440,7 @@ class TxAgent:
430
  return None, True
431
 
432
  try:
 
433
  output = model.generate(prompt, sampling_params=sampling_params)
434
  output = output[0].outputs[0].text
435
  logger.debug("Inference output: %s", output[:100])
@@ -438,6 +449,7 @@ class TxAgent:
438
  return None, True
439
 
440
  torch.cuda.empty_cache()
 
441
  if check_token_status:
442
  return output, False
443
  return output
@@ -445,6 +457,10 @@ class TxAgent:
445
  def run_quick_summary(self, message: str, temperature: float = 0.1, max_new_tokens: int = 256, max_token: int = 1024):
446
  """Generate a fast, concise summary of potential missed diagnoses without tool calls"""
447
  logger.debug("Starting quick summary for message: %s", message[:100])
 
 
 
 
448
  prompt = """
449
  Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results. Provide a concise summary in ONE paragraph without headings or bullet points. ALWAYS treat medications or psychiatric evaluations as potential missed diagnoses, specifying their implications (e.g., 'use of Seroquel may indicate untreated psychosis'). Recommend urgent review for identified findings. Do NOT use external tools or repeat non-clinical data (e.g., name, date of birth). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence.
450
  Patient Record Excerpt:
@@ -473,6 +489,10 @@ Patient Record Excerpt:
473
  call_agent_level: int, report_path: str):
474
  """Run detailed report generation in the background and save to file"""
475
  logger.debug("Starting background report for message: %s", message[:100])
 
 
 
 
476
  combined_response = ""
477
  history_copy = history.copy()
478
 
 
81
  self.model_name = model_name
82
 
83
  try:
84
+ torch.cuda.empty_cache()
85
+ self.model = LLM(
86
+ model=self.model_name,
87
+ dtype="float16",
88
+ max_model_len=131072,
89
+ enforce_eager=True # Avoid graph compilation issues
90
+ )
91
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
92
  self.tokenizer = self.model.get_tokenizer()
93
  logger.info("Model %s loaded successfully", self.model_name)
 
425
  if output_begin_string:
426
  prompt += output_begin_string
427
 
428
+ if len(prompt) > 100000: # Early text length check
429
+ logger.error(f"Prompt length ({len(prompt)}) exceeds limit (100000).")
430
+ return None, True
431
+
432
  if check_token_status and max_token:
433
  num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
434
  if num_input_tokens > max_token:
 
440
  return None, True
441
 
442
  try:
443
+ torch.cuda.empty_cache()
444
  output = model.generate(prompt, sampling_params=sampling_params)
445
  output = output[0].outputs[0].text
446
  logger.debug("Inference output: %s", output[:100])
 
449
  return None, True
450
 
451
  torch.cuda.empty_cache()
452
+ gc.collect()
453
  if check_token_status:
454
  return output, False
455
  return output
 
457
  def run_quick_summary(self, message: str, temperature: float = 0.1, max_new_tokens: int = 256, max_token: int = 1024):
458
  """Generate a fast, concise summary of potential missed diagnoses without tool calls"""
459
  logger.debug("Starting quick summary for message: %s", message[:100])
460
+ if len(message) > 50000:
461
+ logger.warning(f"Message length ({len(message)}) exceeds limit (50000). Truncating.")
462
+ message = message[:50000]
463
+
464
  prompt = """
465
  Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results. Provide a concise summary in ONE paragraph without headings or bullet points. ALWAYS treat medications or psychiatric evaluations as potential missed diagnoses, specifying their implications (e.g., 'use of Seroquel may indicate untreated psychosis'). Recommend urgent review for identified findings. Do NOT use external tools or repeat non-clinical data (e.g., name, date of birth). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence.
466
  Patient Record Excerpt:
 
489
  call_agent_level: int, report_path: str):
490
  """Run detailed report generation in the background and save to file"""
491
  logger.debug("Starting background report for message: %s", message[:100])
492
+ if len(message) > 50000:
493
+ logger.warning(f"Message length ({len(message)}) exceeds limit (50000). Truncating.")
494
+ message = message[:50000]
495
+
496
  combined_response = ""
497
  history_copy = history.copy()
498