Ali2206 committed (verified)
Commit d2cced3 · 1 Parent(s): 50abd96

Update app.py

Files changed (1):
  1. app.py +18 -56
app.py CHANGED
@@ -11,17 +11,12 @@ import shutil
 import time
 from functools import lru_cache
 
-# Environment and path setup
-current_dir = os.path.dirname(__file__)
-src_path = os.path.abspath(os.path.join(current_dir, "src"))
-
-print(">> Adding to path:", src_path)
+# Add src to Python path
+src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
+print(f"Adding to path: {src_path}")
 sys.path.insert(0, src_path)
 
-# Now import
-
-
-# Configure cache directories
+# Configure Hugging Face and cache dirs
 base_dir = "/data"
 model_cache_dir = os.path.join(base_dir, "txagent_models")
 tool_cache_dir = os.path.join(base_dir, "tool_cache")
@@ -31,14 +26,14 @@ os.makedirs(model_cache_dir, exist_ok=True)
 os.makedirs(tool_cache_dir, exist_ok=True)
 os.makedirs(file_cache_dir, exist_ok=True)
 
-os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 os.environ["HF_HOME"] = model_cache_dir
+os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
 from txagent.txagent import TxAgent
 
-# Utility functions
+# Utils
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
 
@@ -48,8 +43,7 @@ def file_hash(path: str) -> str:
 
 @lru_cache(maxsize=100)
 def get_cached_response(prompt: str, file_hash: str) -> Optional[str]:
-    """Cache for frequent queries"""
-    return None  # Implement actual cache lookup if needed
+    return None
 
 def convert_file_to_json(file_path: str, file_type: str) -> str:
     try:
@@ -90,7 +84,6 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
         return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})
 
 def convert_files_to_json_parallel(uploaded_files: list) -> str:
-    """Process files in parallel using ThreadPool"""
     extracted_text = []
     with ThreadPoolExecutor(max_workers=4) as executor:
         futures = []
@@ -100,14 +93,12 @@ def convert_files_to_json_parallel(uploaded_files: list) -> str:
             path = file.name
             ext = path.split(".")[-1].lower()
             futures.append(executor.submit(convert_file_to_json, path, ext))
-
+
         for future in as_completed(futures):
             extracted_text.append(sanitize_utf8(future.result()))
     return "\n".join(extracted_text)
 
 def init_agent():
-    """Initialize the TxAgent with optimized settings"""
-    # Copy default tool file if needed
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(target_tool_path):
@@ -115,20 +106,16 @@ def init_agent():
 
     model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
     rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
-
+
     agent = TxAgent(
         model_name=model_name,
         rag_model_name=rag_model_name,
         tool_files_dict={"new_tool": target_tool_path},
         force_finish=True,
         enable_checker=True,
-        step_rag_num=8,  # Reduced from 10
+        step_rag_num=8,
         seed=100,
-        additional_default_tools=[],
-        torch_dtype="auto",
-        device_map="auto",
-        load_in_4bit=False,
-        load_in_8bit=False
+        additional_default_tools=[]
     )
     agent.init_model()
     return agent
@@ -154,12 +141,9 @@ def create_ui(agent: TxAgent):
             history.append({"role": "assistant", "content": "⏳ Processing your request..."})
             yield history
 
-            # File processing with timing
-            file_process_time = time.time()
             extracted_text = ""
             if uploaded_files and isinstance(uploaded_files, list):
                 extracted_text = convert_files_to_json_parallel(uploaded_files)
-            print(f"File processing took: {time.time() - file_process_time:.2f}s")
 
             context = (
                 "You are an expert clinical AI assistant. Review this patient's history, "
@@ -168,18 +152,16 @@
             )
             chunked_prompt = f"{context}\n\n--- Patient Record ---\n{extracted_text}\n\n[Final Analysis]"
 
-            # Model processing with timing
-            model_start = time.time()
             generator = agent.run_gradio_chat(
                 message=chunked_prompt,
                 history=[],
                 temperature=0.3,
-                max_new_tokens=768,  # Reduced from 1024
-                max_token=4096,  # Reduced from 8192
+                max_new_tokens=768,
+                max_token=4096,
                 call_agent=False,
                 conversation=conversation,
                 uploaded_files=uploaded_files,
-                max_round=10  # Reduced from 30
+                max_round=10
             )
 
             final_response = []
@@ -190,14 +172,12 @@
                     final_response.append(update)
                 elif isinstance(update, list):
                     final_response.extend(msg.content for msg in update if hasattr(msg, 'content'))
-
-                # Yield intermediate results periodically
-                if len(final_response) % 3 == 0:  # More frequent updates
+
+                if len(final_response) % 3 == 0:
                     history[-1] = {"role": "assistant", "content": "".join(final_response).strip()}
                     yield history
 
             history[-1] = {"role": "assistant", "content": "".join(final_response).strip() or "❌ No response."}
-            print(f"Model processing took: {time.time() - model_start:.2f}s")
             yield history
 
         except Exception as chat_error:
@@ -220,27 +200,9 @@
     return demo
 
 if __name__ == "__main__":
-    # Initialize agent and warm it up
     print("Initializing agent...")
     agent = init_agent()
-
-    # Warm-up call
-    print("Performing warm-up call...")
-    try:
-        warm_up = agent.run_gradio_chat(
-            message="Warm up",
-            history=[],
-            temperature=0.1,
-            max_new_tokens=10,
-            max_token=100,
-            call_agent=False
-        )
-        for _ in warm_up:
-            pass
-    except:
-        pass
-
-    # Launch Gradio interface
+
     print("Launching interface...")
     demo = create_ui(agent)
     demo.queue(concurrency_count=3).launch(
@@ -248,4 +210,4 @@ if __name__ == "__main__":
         server_port=7860,
         show_error=True,
         share=True
-    )
+    )
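
For quick reference, this is roughly how the path and cache setup at the top of app.py reads once the commit is applied. It is a sketch assembled from the added and unchanged lines in the first two hunks, not a verbatim excerpt; the full import block and the file_cache_dir handling sit outside the hunks shown, so the two imports below are reconstructed and file_cache_dir is omitted.

import os
import sys

# Add src to Python path (resolved one level up from app.py after this commit)
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
print(f"Adding to path: {src_path}")
sys.path.insert(0, src_path)

# Configure Hugging Face and cache dirs under the persistent /data volume
base_dir = "/data"
model_cache_dir = os.path.join(base_dir, "txagent_models")
tool_cache_dir = os.path.join(base_dir, "tool_cache")
os.makedirs(model_cache_dir, exist_ok=True)
os.makedirs(tool_cache_dir, exist_ok=True)

# HF_HOME is set first; TRANSFORMERS_CACHE (re-added by this commit) points at the same directory
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"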