Ali2206 commited on
Commit
1244d40
·
verified ·
1 Parent(s): cbf903d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -7
app.py CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
4
  import gradio as gr
5
  import re
6
  import hashlib
 
7
  from datetime import datetime
8
  from collections import defaultdict
9
  from typing import List, Dict, Tuple
@@ -13,10 +14,38 @@ WORKING_DIR = os.getcwd()
13
  REPORT_DIR = os.path.join(WORKING_DIR, "reports")
14
  os.makedirs(REPORT_DIR, exist_ok=True)
15
 
 
 
 
 
 
 
 
 
 
 
16
  class PatientHistoryAnalyzer:
17
  def __init__(self):
18
  self.max_token_length = 2000
19
  self.max_text_length = 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def clean_text(self, text: str) -> str:
22
  """Clean and normalize text fields"""
@@ -185,6 +214,30 @@ class PatientHistoryAnalyzer:
185
  "### Recommendations"
186
  ])
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  def generate_report(self, analysis_results: List[str]) -> Tuple[str, str]:
189
  """Combine analysis results into final report"""
190
  report = [
@@ -226,14 +279,13 @@ class PatientHistoryAnalyzer:
226
  patient_data = self.process_excel(file_path)
227
  prompts = self.generate_analysis_prompt(patient_data)
228
 
229
- # Simulate LLM responses
230
- simulated_responses = [
231
- "### Summary of Current Status\nPatient shows improvement in blood pressure control...",
232
- "### Historical Patterns\nChronic back pain has been a consistent issue...",
233
- "### Medication Summary\nCurrent regimen includes 4 medications..."
234
- ]
235
 
236
- return self.generate_report(simulated_responses)
237
 
238
  except Exception as e:
239
  return f"Error during analysis: {str(e)}", ""
 
4
  import gradio as gr
5
  import re
6
  import hashlib
7
+ import shutil
8
  from datetime import datetime
9
  from collections import defaultdict
10
  from typing import List, Dict, Tuple
 
14
  REPORT_DIR = os.path.join(WORKING_DIR, "reports")
15
  os.makedirs(REPORT_DIR, exist_ok=True)
16
 
17
+ # Model configuration
18
+ MODEL_CACHE_DIR = os.path.join(WORKING_DIR, "model_cache")
19
+ os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
20
+ os.environ["HF_HOME"] = MODEL_CACHE_DIR
21
+ os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE_DIR
22
+
23
+ # Import TxAgent after setting up environment
24
+ sys.path.append(os.path.join(WORKING_DIR, "src"))
25
+ from txagent.txagent import TxAgent
26
+
27
  class PatientHistoryAnalyzer:
28
  def __init__(self):
29
  self.max_token_length = 2000
30
  self.max_text_length = 500
31
+ self.agent = self._initialize_agent()
32
+
33
+ def _initialize_agent(self):
34
+ """Initialize the TxAgent with proper configuration"""
35
+ tool_path = os.path.join(WORKING_DIR, "data", "new_tool.json")
36
+ if not os.path.exists(tool_path):
37
+ raise FileNotFoundError(f"Tool file not found at {tool_path}")
38
+
39
+ return TxAgent(
40
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
41
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
42
+ tool_files_dict={"new_tool": tool_path},
43
+ force_finish=True,
44
+ enable_checker=True,
45
+ step_rag_num=4,
46
+ seed=100,
47
+ additional_default_tools=[],
48
+ )
49
 
50
  def clean_text(self, text: str) -> str:
51
  """Clean and normalize text fields"""
 
214
  "### Recommendations"
215
  ])
216
 
217
+ def _call_agent(self, prompt: str) -> str:
218
+ """Call TxAgent with proper error handling"""
219
+ try:
220
+ response = ""
221
+ for result in self.agent.run_gradio_chat(
222
+ message=prompt,
223
+ history=[],
224
+ temperature=0.2,
225
+ max_new_tokens=1024,
226
+ max_token=2048,
227
+ call_agent=False,
228
+ conversation=[],
229
+ ):
230
+ if isinstance(result, list):
231
+ for r in result:
232
+ if hasattr(r, 'content') and r.content:
233
+ response += r.content + "\n"
234
+ elif isinstance(result, str):
235
+ response += result + "\n"
236
+
237
+ return response.strip()
238
+ except Exception as e:
239
+ return f"Error in model response: {str(e)}"
240
+
241
  def generate_report(self, analysis_results: List[str]) -> Tuple[str, str]:
242
  """Combine analysis results into final report"""
243
  report = [
 
279
  patient_data = self.process_excel(file_path)
280
  prompts = self.generate_analysis_prompt(patient_data)
281
 
282
+ # Call TxAgent for each prompt
283
+ analysis_results = []
284
+ for prompt in prompts:
285
+ response = self._call_agent(prompt['content'])
286
+ analysis_results.append(response)
 
287
 
288
+ return self.generate_report(analysis_results)
289
 
290
  except Exception as e:
291
  return f"Error during analysis: {str(e)}", ""