Ali2206 committed on
Commit
3cdcbc4
·
verified ·
1 Parent(s): 6af3907

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -32
app.py CHANGED
@@ -14,36 +14,42 @@ import re
14
  import tempfile
15
  import threading
16
 
17
- # Environment setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  current_dir = os.path.dirname(os.path.abspath(__file__))
19
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
20
  sys.path.insert(0, src_path)
21
 
22
- # Cache directories
23
- base_dir = "/data"
24
- os.makedirs(base_dir, exist_ok=True)
25
- model_cache_dir = os.path.join(base_dir, "txagent_models")
26
- tool_cache_dir = os.path.join(base_dir, "tool_cache")
27
- file_cache_dir = os.path.join(base_dir, "cache")
28
- report_dir = "/data/reports"
29
- vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
30
-
31
- os.makedirs(model_cache_dir, exist_ok=True)
32
- os.makedirs(tool_cache_dir, exist_ok=True)
33
- os.makedirs(file_cache_dir, exist_ok=True)
34
- os.makedirs(report_dir, exist_ok=True)
35
- os.makedirs(vllm_cache_dir, exist_ok=True)
36
-
37
- os.environ.update({
38
- "TRANSFORMERS_CACHE": model_cache_dir,
39
- "HF_HOME": model_cache_dir,
40
- "VLLM_CACHE_DIR": vllm_cache_dir,
41
- "TOKENIZERS_PARALLELISM": "false",
42
- "CUDA_LAUNCH_BLOCKING": "1"
43
- })
44
-
45
  from txagent.txagent import TxAgent
46
 
 
 
 
47
  MEDICAL_KEYWORDS = {
48
  'diagnosis', 'assessment', 'plan', 'results', 'medications',
49
  'allergies', 'summary', 'impression', 'findings', 'recommendations'
@@ -60,11 +66,11 @@ def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
60
  try:
61
  text_chunks = []
62
  with pdfplumber.open(file_path) as pdf:
63
- # Process first three pages
64
  for i, page in enumerate(pdf.pages[:3]):
65
  text = page.extract_text() or ""
66
  text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
67
- # Check for keywords on later pages and add if found
68
  for i, page in enumerate(pdf.pages[3:max_pages], start=4):
69
  page_text = page.extract_text() or ""
70
  if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
@@ -121,7 +127,9 @@ def full_pdf_processing(file_path: str, file_hash_value: str):
121
  except Exception as e:
122
  print(f"Background processing failed: {str(e)}")
123
 
124
- # Global agent and a lock for safe multi-threaded access
 
 
125
  agent = None
126
  agent_lock = Lock()
127
 
@@ -147,13 +155,16 @@ def load_agent_in_background():
147
  global agent
148
  with agent_lock:
149
  if agent is None:
150
- print("Initializing agent in background...")
151
  agent = init_agent()
152
  print("Agent initialization complete.")
153
 
154
  # Start background agent loading at startup
155
  threading.Thread(target=load_agent_in_background, daemon=True).start()
156
 
 
 
 
157
  def create_ui():
158
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
159
  gr.Markdown("""
@@ -197,7 +208,7 @@ def create_ui():
197
  extracted_data = "\n".join(results)
198
  file_hash_value = file_hash(files[0].name) if hasattr(files[0], 'name') else ""
199
 
200
- # Truncate the extracted data to avoid token overflows
201
  max_extracted_chars = 12000
202
  truncated_data = extracted_data[:max_extracted_chars]
203
 
@@ -236,8 +247,7 @@ Medical Records:
236
  history[-1] = {"role": "assistant", "content": cleaned}
237
  yield history, None
238
  except Exception as agent_error:
239
- history[-1] = {"role": "assistant",
240
- "content": f"❌ Analysis failed during processing: {str(agent_error)}"}
241
  yield history, None
242
  return
243
 
@@ -275,6 +285,6 @@ if __name__ == "__main__":
275
  server_name="0.0.0.0",
276
  server_port=7860,
277
  show_error=True,
278
- allowed_paths=["/data/reports"],
279
  share=False
280
  )
 
14
  import tempfile
15
  import threading
16
 
17
+ # ---------------------------------------------------------------------------------------
18
+ # Setup persistent directories for Hugging Face Spaces
19
+ # ---------------------------------------------------------------------------------------
20
+ # Use a persistent cache directory (adjust the path as needed based on your HF Space settings)
21
+ persistent_dir = "/workspace/hf_cache"
22
+ os.makedirs(persistent_dir, exist_ok=True)
23
+
24
+ model_cache_dir = os.path.join(persistent_dir, "txagent_models")
25
+ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
26
+ file_cache_dir = os.path.join(persistent_dir, "cache")
27
+ report_dir = os.path.join(persistent_dir, "reports")
28
+ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
29
+
30
+ for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
31
+ os.makedirs(directory, exist_ok=True)
32
+
33
+ # Set environment variables so that model and transformers caches point to persistent storage.
34
+ os.environ["HF_HOME"] = model_cache_dir
35
+ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
36
+ os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
37
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
38
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
39
+
40
+ # Append the local source path if needed
41
  current_dir = os.path.dirname(os.path.abspath(__file__))
42
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
43
  sys.path.insert(0, src_path)
44
 
45
+ # ---------------------------------------------------------------------------------------
46
+ # Import the TxAgent from your tool package
47
+ # ---------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  from txagent.txagent import TxAgent
49
 
50
+ # ---------------------------------------------------------------------------------------
51
+ # Define constants and helper functions
52
+ # ---------------------------------------------------------------------------------------
53
  MEDICAL_KEYWORDS = {
54
  'diagnosis', 'assessment', 'plan', 'results', 'medications',
55
  'allergies', 'summary', 'impression', 'findings', 'recommendations'
 
66
  try:
67
  text_chunks = []
68
  with pdfplumber.open(file_path) as pdf:
69
+ # Process first three pages always
70
  for i, page in enumerate(pdf.pages[:3]):
71
  text = page.extract_text() or ""
72
  text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
73
+ # Process subsequent pages only if they contain key medical keywords
74
  for i, page in enumerate(pdf.pages[3:max_pages], start=4):
75
  page_text = page.extract_text() or ""
76
  if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
 
127
  except Exception as e:
128
  print(f"Background processing failed: {str(e)}")
129
 
130
+ # ---------------------------------------------------------------------------------------
131
+ # Global agent variable and thread-safe lock for background model loading
132
+ # ---------------------------------------------------------------------------------------
133
  agent = None
134
  agent_lock = Lock()
135
 
 
155
  global agent
156
  with agent_lock:
157
  if agent is None:
158
+ print("Initializing agent in background (this may take a while)...")
159
  agent = init_agent()
160
  print("Agent initialization complete.")
161
 
162
  # Start background agent loading at startup
163
  threading.Thread(target=load_agent_in_background, daemon=True).start()
164
 
165
+ # ---------------------------------------------------------------------------------------
166
+ # Define the Gradio UI
167
+ # ---------------------------------------------------------------------------------------
168
  def create_ui():
169
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
  gr.Markdown("""
 
208
  extracted_data = "\n".join(results)
209
  file_hash_value = file_hash(files[0].name) if hasattr(files[0], 'name') else ""
210
 
211
+ # Truncate extracted data to avoid token overflow
212
  max_extracted_chars = 12000
213
  truncated_data = extracted_data[:max_extracted_chars]
214
 
 
247
  history[-1] = {"role": "assistant", "content": cleaned}
248
  yield history, None
249
  except Exception as agent_error:
250
+ history[-1] = {"role": "assistant", "content": f"❌ Analysis failed during processing: {str(agent_error)}"}
 
251
  yield history, None
252
  return
253
 
 
285
  server_name="0.0.0.0",
286
  server_port=7860,
287
  show_error=True,
288
+ allowed_paths=[report_dir],
289
  share=False
290
  )