Update app.py
app.py (CHANGED)
@@ -4,6 +4,7 @@ import pandas as pd
 import gradio as gr
 import re
 import hashlib
+import shutil
 from datetime import datetime
 from collections import defaultdict
 from typing import List, Dict, Tuple
@@ -13,10 +14,38 @@ WORKING_DIR = os.getcwd()
 REPORT_DIR = os.path.join(WORKING_DIR, "reports")
 os.makedirs(REPORT_DIR, exist_ok=True)
 
+# Model configuration
+MODEL_CACHE_DIR = os.path.join(WORKING_DIR, "model_cache")
+os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
+os.environ["HF_HOME"] = MODEL_CACHE_DIR
+os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE_DIR
+
+# Import TxAgent after setting up environment
+sys.path.append(os.path.join(WORKING_DIR, "src"))
+from txagent.txagent import TxAgent
+
 class PatientHistoryAnalyzer:
     def __init__(self):
         self.max_token_length = 2000
         self.max_text_length = 500
+        self.agent = self._initialize_agent()
+
+    def _initialize_agent(self):
+        """Initialize the TxAgent with proper configuration"""
+        tool_path = os.path.join(WORKING_DIR, "data", "new_tool.json")
+        if not os.path.exists(tool_path):
+            raise FileNotFoundError(f"Tool file not found at {tool_path}")
+
+        return TxAgent(
+            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+            tool_files_dict={"new_tool": tool_path},
+            force_finish=True,
+            enable_checker=True,
+            step_rag_num=4,
+            seed=100,
+            additional_default_tools=[],
+        )
 
     def clean_text(self, text: str) -> str:
         """Clean and normalize text fields"""
@@ -185,6 +214,30 @@ class PatientHistoryAnalyzer:
            "### Recommendations"
        ])
 
+    def _call_agent(self, prompt: str) -> str:
+        """Call TxAgent with proper error handling"""
+        try:
+            response = ""
+            for result in self.agent.run_gradio_chat(
+                message=prompt,
+                history=[],
+                temperature=0.2,
+                max_new_tokens=1024,
+                max_token=2048,
+                call_agent=False,
+                conversation=[],
+            ):
+                if isinstance(result, list):
+                    for r in result:
+                        if hasattr(r, 'content') and r.content:
+                            response += r.content + "\n"
+                elif isinstance(result, str):
+                    response += result + "\n"
+
+            return response.strip()
+        except Exception as e:
+            return f"Error in model response: {str(e)}"
+
     def generate_report(self, analysis_results: List[str]) -> Tuple[str, str]:
         """Combine analysis results into final report"""
         report = [
@@ -226,14 +279,13 @@ class PatientHistoryAnalyzer:
             patient_data = self.process_excel(file_path)
             prompts = self.generate_analysis_prompt(patient_data)
 
-            #
-
-
-
-
-            ]
+            # Call TxAgent for each prompt
+            analysis_results = []
+            for prompt in prompts:
+                response = self._call_agent(prompt['content'])
+                analysis_results.append(response)
 
-            return self.generate_report(
+            return self.generate_report(analysis_results)
 
         except Exception as e:
             return f"Error during analysis: {str(e)}", ""
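A note on the cache configuration added in the second hunk: Hugging Face libraries resolve their cache directory when they are first imported, which is why HF_HOME and TRANSFORMERS_CACHE are exported before "from txagent.txagent import TxAgent". Recent transformers releases also warn that TRANSFORMERS_CACHE is deprecated in favor of HF_HOME, so setting both is a reasonable belt-and-braces choice. A quick way to confirm the redirection, assuming huggingface_hub is installed (this scratch snippet is illustrative only and not part of app.py):

import os

# Must run before the first import of huggingface_hub / transformers.
os.environ["HF_HOME"] = os.path.join(os.getcwd(), "model_cache")

from huggingface_hub import constants  # cache paths are resolved at import time
print(constants.HF_HOME)               # should now point at ./model_cache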
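The _call_agent helper added in the third hunk flattens a streamed response: run_gradio_chat apparently yields either plain strings or lists of message-like objects carrying a .content attribute, and everything is concatenated into one string. The stand-alone sketch below reproduces just that accumulation pattern so it can be exercised without loading the model; Msg and collect_stream are illustrative names, not part of TxAgent or of app.py.

from dataclasses import dataclass
from typing import Iterable, List, Union

@dataclass
class Msg:
    # Stand-in for the message objects the agent yields in list form.
    content: str

def collect_stream(chunks: Iterable[Union[str, List[Msg]]]) -> str:
    """Flatten a mixed stream of strings and message lists into one string."""
    response = ""
    for chunk in chunks:
        if isinstance(chunk, list):
            for m in chunk:
                if getattr(m, "content", None):
                    response += m.content + "\n"
        elif isinstance(chunk, str):
            response += chunk + "\n"
    return response.strip()

# A mixed stream similar to what run_gradio_chat appears to produce:
print(collect_stream(["partial text", [Msg("tool result"), Msg("final answer")]]))
# prints: partial text / tool result / final answer (one item per line)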
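Taken together, the new pieces give PatientHistoryAnalyzer an end-to-end path: process_excel, then generate_analysis_prompt, then _call_agent for each prompt, then generate_report. A hypothetical smoke test is sketched below; the method name analyze and the sample file path are assumptions, since the actual entry point sits in the portion of app.py outside this diff.

# smoke_test.py (hypothetical): run the analyzer once against a sample workbook.
from app import PatientHistoryAnalyzer

analyzer = PatientHistoryAnalyzer()  # builds TxAgent via _initialize_agent()

# analyze() is an assumed name for the method containing the try/except shown
# in the last hunk; it appears to return (report text, saved report path).
report_md, report_path = analyzer.analyze("samples/patient_history.xlsx")

print(report_path)      # report written under ./reports
print(report_md[:500])  # preview of the generated report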