Ali2206 committed (verified)
Commit 8ce9243 · 1 Parent(s): 9569e68

Update src/txagent/txagent.py

Files changed (1): src/txagent/txagent.py (+32 -19)
src/txagent/txagent.py CHANGED
@@ -80,11 +80,15 @@ class TxAgent:
         if model_name:
             self.model_name = model_name
 
-        self.model = LLM(model=self.model_name, dtype="float16")
-        self.chat_template = Template(self.model.get_tokenizer().chat_template)
-        self.tokenizer = self.model.get_tokenizer()
-        logger.info("Model %s loaded successfully", self.model_name)
-        return f"Model {self.model_name} loaded successfully."
+        try:
+            self.model = LLM(model=self.model_name, dtype="float16", max_model_len=131072)
+            self.chat_template = Template(self.model.get_tokenizer().chat_template)
+            self.tokenizer = self.model.get_tokenizer()
+            logger.info("Model %s loaded successfully", self.model_name)
+            return f"Model {self.model_name} loaded successfully."
+        except Exception as e:
+            logger.error(f"Model loading error: {e}")
+            raise
 
     def load_tooluniverse(self):
         self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
@@ -416,17 +420,23 @@ class TxAgent:
         prompt += output_begin_string
 
         if check_token_status and max_token:
-            num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
+            num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
             if num_input_tokens > max_token:
-                torch.cuda.empty_cache()
-                gc.collect()
-                logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
+                logger.warning(f"Input tokens ({num_input_tokens}) exceed max_token ({max_token}). Truncating.")
+                prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)[:max_token]
+                prompt = self.tokenizer.decode(prompt_tokens)
+                if num_input_tokens > 131072:
+                    logger.error(f"Input tokens ({num_input_tokens}) exceed model limit (131072).")
                 return None, True
-            logger.debug("Input tokens: %d", num_input_tokens)
 
-        output = model.generate(prompt, sampling_params=sampling_params)
-        output = output[0].outputs[0].text
-        logger.debug("Inference output: %s", output[:100])
+        try:
+            output = model.generate(prompt, sampling_params=sampling_params)
+            output = output[0].outputs[0].text
+            logger.debug("Inference output: %s", output[:100])
+        except Exception as e:
+            logger.error(f"Inference error: {e}")
+            return None, True
+
         torch.cuda.empty_cache()
         if check_token_status:
             return output, False
@@ -442,17 +452,20 @@ Patient Record Excerpt:
         """
         conversation = self.set_system_prompt([], prompt.format(chunk=message))
         conversation.append({"role": "user", "content": message})
-        output = self.llm_infer(
+        output, token_overflow = self.llm_infer(
            messages=conversation,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_token,
            tools=[] # No tools
        )
-        if '[FinalAnswer]' in output:
+        if token_overflow:
+            logger.error("Token overflow in quick summary")
+            return "Error: Input too large for quick summary."
+        if output and '[FinalAnswer]' in output:
            output = output.split('[FinalAnswer]')[-1].strip()
-        logger.debug("Quick summary output: %s", output[:100])
-        return output
+        logger.debug("Quick summary output: %s", output[:100] if output else "None")
+        return output or "No missed diagnoses identified"
 
     def run_background_report(self, message: str, history: list, temperature: float,
                               max_new_tokens: int, max_token: int, call_agent: bool,
@@ -539,10 +552,10 @@ Patient Record Excerpt:
                     f.write(combined_response)
                 logger.info("Detailed report saved to %s", report_path)
             except Exception as e:
-                logger.error("Failed to save report: %s", e)
+                logger.error(f"Failed to save report: {e}")
 
         except Exception as e:
-            logger.error("Background report error: %s", e)
+            logger.error(f"Background report error: {e}")
             combined_response += f"Error: {e}\n"
             with open(report_path, "w", encoding="utf-8") as f:
                 f.write(combined_response)
 
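The first hunk wraps the vLLM engine construction in a try/except and caps the context window with max_model_len. Below is a minimal standalone sketch of that pattern, assuming the same vLLM and Jinja2 APIs the commit uses; the function name load_model and the logger setup are illustrative, not part of the commit.

```python
import logging

from jinja2 import Template
from vllm import LLM

logger = logging.getLogger(__name__)


def load_model(model_name: str):
    """Build a vLLM engine plus its tokenizer and chat template, failing loudly."""
    try:
        # Cap the context window at 131072 tokens, matching the commit.
        llm = LLM(model=model_name, dtype="float16", max_model_len=131072)
        tokenizer = llm.get_tokenizer()
        chat_template = Template(tokenizer.chat_template)
        logger.info("Model %s loaded successfully", model_name)
        return llm, tokenizer, chat_template
    except Exception as e:
        # Re-raise so the caller sees the failure instead of getting a half-initialised agent.
        logger.error("Model loading error: %s", e)
        raise
```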
 
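The second hunk swaps the return_tensors="pt" token count for a plain add_special_tokens=False encode and truncates the prompt to the budget rather than only bailing out. An illustrative helper showing the same encode/decode round trip with a Hugging Face tokenizer; the helper name and the gpt2 checkpoint in the demo are assumptions, not taken from the commit.

```python
from transformers import AutoTokenizer


def truncate_prompt(tokenizer, prompt: str, max_token: int) -> tuple[str, bool]:
    """Return (possibly truncated prompt, overflowed) using an encode/decode round trip."""
    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(token_ids) <= max_token:
        return prompt, False
    # Keep only the first max_token ids, then decode back to text.
    truncated = tokenizer.decode(token_ids[:max_token], skip_special_tokens=True)
    return truncated, True


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works for the demo
    text, overflowed = truncate_prompt(tok, "word " * 5000, max_token=1024)
    print(overflowed, len(tok.encode(text, add_special_tokens=False)))
```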
 
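The third hunk updates the quick-summary caller to unpack the (output, token_overflow) pair that llm_infer returns when token checking is enabled, and to guard against a None or empty output. A hedged sketch of that calling convention; the wrapper function, its parameters, and the assumption that check_token_status is active on this path are illustrative.

```python
def quick_summary(agent, conversation, temperature, max_new_tokens, max_token):
    """Mirror of the updated caller: unpack both return values and guard empty output."""
    output, token_overflow = agent.llm_infer(
        messages=conversation,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        max_token=max_token,
        tools=[]  # No tools
    )
    if token_overflow:
        # llm_infer returned (None, True): the prompt exceeded the token budget.
        return "Error: Input too large for quick summary."
    if output and '[FinalAnswer]' in output:
        output = output.split('[FinalAnswer]')[-1].strip()
    return output or "No missed diagnoses identified"
```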
 
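The last hunk only changes logging style, from lazy %-formatting to f-strings. Both emit the same message; the %-style arguments are formatted only if the record is actually emitted, while an f-string is evaluated eagerly regardless of log level. A small standalone demo, not part of the commit:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("txagent.demo")

try:
    raise IOError("disk full")  # stand-in for a failed report write
except Exception as e:
    logger.error("Failed to save report: %s", e)   # old style: formatted lazily by the logging module
    logger.error(f"Failed to save report: {e}")    # new style used in this commit: formatted eagerly
```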