Update src/txagent/txagent.py
Browse files- src/txagent/txagent.py +21 -1
src/txagent/txagent.py
CHANGED
@@ -81,7 +81,13 @@ class TxAgent:
|
|
81 |
self.model_name = model_name
|
82 |
|
83 |
try:
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
self.chat_template = Template(self.model.get_tokenizer().chat_template)
|
86 |
self.tokenizer = self.model.get_tokenizer()
|
87 |
logger.info("Model %s loaded successfully", self.model_name)
|
@@ -419,6 +425,10 @@ class TxAgent:
|
|
419 |
if output_begin_string:
|
420 |
prompt += output_begin_string
|
421 |
|
|
|
|
|
|
|
|
|
422 |
if check_token_status and max_token:
|
423 |
num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
|
424 |
if num_input_tokens > max_token:
|
@@ -430,6 +440,7 @@ class TxAgent:
|
|
430 |
return None, True
|
431 |
|
432 |
try:
|
|
|
433 |
output = model.generate(prompt, sampling_params=sampling_params)
|
434 |
output = output[0].outputs[0].text
|
435 |
logger.debug("Inference output: %s", output[:100])
|
@@ -438,6 +449,7 @@ class TxAgent:
|
|
438 |
return None, True
|
439 |
|
440 |
torch.cuda.empty_cache()
|
|
|
441 |
if check_token_status:
|
442 |
return output, False
|
443 |
return output
|
@@ -445,6 +457,10 @@ class TxAgent:
|
|
445 |
def run_quick_summary(self, message: str, temperature: float = 0.1, max_new_tokens: int = 256, max_token: int = 1024):
|
446 |
"""Generate a fast, concise summary of potential missed diagnoses without tool calls"""
|
447 |
logger.debug("Starting quick summary for message: %s", message[:100])
|
|
|
|
|
|
|
|
|
448 |
prompt = """
|
449 |
Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results. Provide a concise summary in ONE paragraph without headings or bullet points. ALWAYS treat medications or psychiatric evaluations as potential missed diagnoses, specifying their implications (e.g., 'use of Seroquel may indicate untreated psychosis'). Recommend urgent review for identified findings. Do NOT use external tools or repeat non-clinical data (e.g., name, date of birth). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence.
|
450 |
Patient Record Excerpt:
|
@@ -473,6 +489,10 @@ Patient Record Excerpt:
|
|
473 |
call_agent_level: int, report_path: str):
|
474 |
"""Run detailed report generation in the background and save to file"""
|
475 |
logger.debug("Starting background report for message: %s", message[:100])
|
|
|
|
|
|
|
|
|
476 |
combined_response = ""
|
477 |
history_copy = history.copy()
|
478 |
|
|
|
81 |
self.model_name = model_name
|
82 |
|
83 |
try:
|
84 |
+
torch.cuda.empty_cache()
|
85 |
+
self.model = LLM(
|
86 |
+
model=self.model_name,
|
87 |
+
dtype="float16",
|
88 |
+
max_model_len=131072,
|
89 |
+
enforce_eager=True # Avoid graph compilation issues
|
90 |
+
)
|
91 |
self.chat_template = Template(self.model.get_tokenizer().chat_template)
|
92 |
self.tokenizer = self.model.get_tokenizer()
|
93 |
logger.info("Model %s loaded successfully", self.model_name)
|
|
|
425 |
if output_begin_string:
|
426 |
prompt += output_begin_string
|
427 |
|
428 |
+
if len(prompt) > 100000: # Early text length check
|
429 |
+
logger.error(f"Prompt length ({len(prompt)}) exceeds limit (100000).")
|
430 |
+
return None, True
|
431 |
+
|
432 |
if check_token_status and max_token:
|
433 |
num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
|
434 |
if num_input_tokens > max_token:
|
|
|
440 |
return None, True
|
441 |
|
442 |
try:
|
443 |
+
torch.cuda.empty_cache()
|
444 |
output = model.generate(prompt, sampling_params=sampling_params)
|
445 |
output = output[0].outputs[0].text
|
446 |
logger.debug("Inference output: %s", output[:100])
|
|
|
449 |
return None, True
|
450 |
|
451 |
torch.cuda.empty_cache()
|
452 |
+
gc.collect()
|
453 |
if check_token_status:
|
454 |
return output, False
|
455 |
return output
|
|
|
457 |
def run_quick_summary(self, message: str, temperature: float = 0.1, max_new_tokens: int = 256, max_token: int = 1024):
|
458 |
"""Generate a fast, concise summary of potential missed diagnoses without tool calls"""
|
459 |
logger.debug("Starting quick summary for message: %s", message[:100])
|
460 |
+
if len(message) > 50000:
|
461 |
+
logger.warning(f"Message length ({len(message)}) exceeds limit (50000). Truncating.")
|
462 |
+
message = message[:50000]
|
463 |
+
|
464 |
prompt = """
|
465 |
Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results. Provide a concise summary in ONE paragraph without headings or bullet points. ALWAYS treat medications or psychiatric evaluations as potential missed diagnoses, specifying their implications (e.g., 'use of Seroquel may indicate untreated psychosis'). Recommend urgent review for identified findings. Do NOT use external tools or repeat non-clinical data (e.g., name, date of birth). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence.
|
466 |
Patient Record Excerpt:
|
|
|
489 |
call_agent_level: int, report_path: str):
|
490 |
"""Run detailed report generation in the background and save to file"""
|
491 |
logger.debug("Starting background report for message: %s", message[:100])
|
492 |
+
if len(message) > 50000:
|
493 |
+
logger.warning(f"Message length ({len(message)}) exceeds limit (50000). Truncating.")
|
494 |
+
message = message[:50000]
|
495 |
+
|
496 |
combined_response = ""
|
497 |
history_copy = history.copy()
|
498 |
|