Update src/txagent/txagent.py

src/txagent/txagent.py (+32 -19)

@@ -80,11 +80,15 @@ class TxAgent:
         if model_name:
             self.model_name = model_name

+        try:
+            self.model = LLM(model=self.model_name, dtype="float16", max_model_len=131072)
+            self.chat_template = Template(self.model.get_tokenizer().chat_template)
+            self.tokenizer = self.model.get_tokenizer()
+            logger.info("Model %s loaded successfully", self.model_name)
+            return f"Model {self.model_name} loaded successfully."
+        except Exception as e:
+            logger.error(f"Model loading error: {e}")
+            raise

     def load_tooluniverse(self):
         self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
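
For context, a minimal standalone sketch of the loading pattern this hunk introduces, assuming vllm and jinja2 are installed and a GPU large enough for the chosen model; load_model is a hypothetical helper name, not part of txagent.py.

# Minimal sketch of the vLLM loading pattern above (assumes vllm and jinja2 are
# installed and a GPU with enough memory; load_model is a hypothetical helper).
import logging

from jinja2 import Template
from vllm import LLM

logger = logging.getLogger(__name__)


def load_model(model_name: str):
    """Load a model with vLLM and return (llm, tokenizer, chat_template)."""
    try:
        llm = LLM(model=model_name, dtype="float16", max_model_len=131072)
        tokenizer = llm.get_tokenizer()
        # tokenizer.chat_template is the Jinja template string shipped with the
        # tokenizer config; it can be None for models without a chat template,
        # in which case Template() raises and the error is logged and re-raised.
        chat_template = Template(tokenizer.chat_template)
        logger.info("Model %s loaded successfully", model_name)
        return llm, tokenizer, chat_template
    except Exception as e:
        logger.error("Model loading error: %s", e)
        raise
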
@@ -416,17 +420,23 @@ class TxAgent:
         prompt += output_begin_string

         if check_token_status and max_token:
+            num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
             if num_input_tokens > max_token:
+                logger.warning(f"Input tokens ({num_input_tokens}) exceed max_token ({max_token}). Truncating.")
+                prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)[:max_token]
+                prompt = self.tokenizer.decode(prompt_tokens)
+                if num_input_tokens > 131072:
+                    logger.error(f"Input tokens ({num_input_tokens}) exceed model limit (131072).")
                 return None, True
-            logger.debug("Input tokens: %d", num_input_tokens)

+        try:
+            output = model.generate(prompt, sampling_params=sampling_params)
+            output = output[0].outputs[0].text
+            logger.debug("Inference output: %s", output[:100])
+        except Exception as e:
+            logger.error(f"Inference error: {e}")
+            return None, True
+
         torch.cuda.empty_cache()
         if check_token_status:
             return output, False
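
A condensed, standalone sketch of the changed path in this hunk: count and truncate the prompt, then run vLLM generation behind a try/except. It assumes the llm/tokenizer pair from the loader sketch above; infer_with_budget and MODEL_CONTEXT_LIMIT are illustrative names, not identifiers from txagent.py.

# Illustrative helper mirroring the hunk above: token budgeting followed by a
# guarded vLLM generate call. Names (infer_with_budget, MODEL_CONTEXT_LIMIT) are mine.
import logging

from vllm import SamplingParams

logger = logging.getLogger(__name__)

MODEL_CONTEXT_LIMIT = 131072  # mirrors max_model_len passed to LLM() at load time


def infer_with_budget(llm, tokenizer, prompt: str, max_token: int,
                      temperature: float = 0.1, max_new_tokens: int = 1024):
    """Return generated text, or None when the prompt cannot be served."""
    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(token_ids) > max_token:
        logger.warning("Input tokens (%d) exceed max_token (%d). Truncating.",
                       len(token_ids), max_token)
        prompt = tokenizer.decode(token_ids[:max_token])
        if len(token_ids) > MODEL_CONTEXT_LIMIT:
            logger.error("Input tokens (%d) exceed model limit (%d).",
                         len(token_ids), MODEL_CONTEXT_LIMIT)
            return None
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens)
    try:
        outputs = llm.generate(prompt, sampling_params=sampling_params)
        return outputs[0].outputs[0].text
    except Exception as e:
        logger.error("Inference error: %s", e)
        return None
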
@@ -442,17 +452,20 @@ Patient Record Excerpt:
 """
         conversation = self.set_system_prompt([], prompt.format(chunk=message))
         conversation.append({"role": "user", "content": message})
-        output = self.llm_infer(
+        output, token_overflow = self.llm_infer(
             messages=conversation,
             temperature=temperature,
             max_new_tokens=max_new_tokens,
             max_token=max_token,
             tools=[]  # No tools
         )
+        if token_overflow:
+            logger.error("Token overflow in quick summary")
+            return "Error: Input too large for quick summary."
+        if output and '[FinalAnswer]' in output:
             output = output.split('[FinalAnswer]')[-1].strip()
-        logger.debug("Quick summary output: %s", output[:100])
-        return output
+        logger.debug("Quick summary output: %s", output[:100] if output else "None")
+        return output or "No missed diagnoses identified"

     def run_background_report(self, message: str, history: list, temperature: float,
                               max_new_tokens: int, max_token: int, call_agent: bool,
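
A tiny illustration of the '[FinalAnswer]' extraction and the fallbacks this hunk adds; extract_final_answer is a stand-in name and the sample strings are made up.

# Stand-alone version of the marker handling above; sample strings are invented.
def extract_final_answer(output):
    if output and '[FinalAnswer]' in output:
        output = output.split('[FinalAnswer]')[-1].strip()
    return output or "No missed diagnoses identified"


assert extract_final_answer("...reasoning... [FinalAnswer] Consider anemia.") == "Consider anemia."
assert extract_final_answer(None) == "No missed diagnoses identified"
assert extract_final_answer("no marker in this reply") == "no marker in this reply"
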
@@ -539,10 +552,10 @@ Patient Record Excerpt:
                     f.write(combined_response)
                 logger.info("Detailed report saved to %s", report_path)
             except Exception as e:
+                logger.error(f"Failed to save report: {e}")

         except Exception as e:
+            logger.error(f"Background report error: {e}")
             combined_response += f"Error: {e}\n"
             with open(report_path, "w", encoding="utf-8") as f:
                 f.write(combined_response)
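
Finally, a minimal sketch of the nested save-with-fallback pattern in this hunk: the inner except only logs a failed write, while the outer except appends the error to the report text and still tries to persist it. write_report and build_report are placeholders, not functions from txagent.py.

# Hypothetical condensed form of the nested error handling above; write_report and
# build_report are placeholders, not part of txagent.py.
import logging

logger = logging.getLogger(__name__)


def write_report(report_path: str, build_report) -> None:
    combined_response = ""
    try:
        combined_response = build_report()
        try:
            with open(report_path, "w", encoding="utf-8") as f:
                f.write(combined_response)
            logger.info("Detailed report saved to %s", report_path)
        except Exception as e:
            logger.error("Failed to save report: %s", e)
    except Exception as e:
        logger.error("Background report error: %s", e)
        combined_response += f"Error: {e}\n"
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(combined_response)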