Ali2206 committed on
Commit f2a9805 · verified · 1 Parent(s): a046927

Update app.py

Files changed (1)
  1. app.py +59 -83
app.py CHANGED
@@ -1,29 +1,17 @@
-# Optimized app.py for A100 GPU (safe parallel batching + no stuck + max performance)
-
 import sys
 import os
 import json
 import shutil
 import re
-import time
 import gc
-import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import List, Tuple, Dict, Union
+import time
 from datetime import datetime
+from typing import List, Tuple, Dict, Union
 import pandas as pd
 import gradio as gr
+import torch
 
-# Constants
-MAX_MODEL_TOKENS = 131072
-MAX_NEW_TOKENS = 4096
-MAX_CHUNK_TOKENS = 8192
-PROMPT_OVERHEAD = 300
-BATCH_SIZE = 2  # Safer for vLLM
-MAX_PARALLEL_JOBS = 2  # Max threads launched in parallel
-SLEEP_BETWEEN_JOBS = 0.5  # Seconds
-
-# Paths
+# === Configuration ===
 persistent_dir = "/data/hf_cache"
 model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
@@ -42,14 +30,21 @@ sys.path.insert(0, src_path)
 
 from txagent.txagent import TxAgent
 
-# Utility functions
+# === Constants ===
+MAX_MODEL_TOKENS = 131072
+MAX_NEW_TOKENS = 4096
+MAX_CHUNK_TOKENS = 8192
+BATCH_SIZE = 2
+PROMPT_OVERHEAD = 300
+SAFE_SLEEP = 0.5  # seconds between batches
+
+# === Utility Functions ===
 def estimate_tokens(text: str) -> int:
     return len(text) // 4 + 1
 
 def clean_response(text: str) -> str:
     text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
     text = re.sub(r"\n{3,}", "\n\n", text)
-    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
     return text.strip()
 
 def extract_text_from_excel(path: str) -> str:
@@ -84,7 +79,7 @@ def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
     chunks.append("\n".join(current))
     return chunks
 
-def batch_chunks(chunks: List[str], batch_size: int = 2) -> List[List[str]]:
+def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
     return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]
 
 def build_prompt(chunk: str) -> str:
@@ -106,48 +101,42 @@ def init_agent() -> TxAgent:
     agent.init_model()
     return agent
 
-def process_single_batch(agent, batch: List[str]) -> str:
-    prompts = [build_prompt(chunk) for chunk in batch]
-    joined_prompt = "\n\n".join(prompts)
-    response = ""
-    try:
-        for r in agent.run_gradio_chat(
-            message=joined_prompt,
-            history=[],
-            temperature=0.0,
-            max_new_tokens=MAX_NEW_TOKENS,
-            max_token=MAX_MODEL_TOKENS,
-            call_agent=False,
-            conversation=[]
-        ):
-            if isinstance(r, str):
-                response += r
-            elif isinstance(r, list):
-                for m in r:
-                    if hasattr(m, "content"):
-                        response += m.content
-            elif hasattr(r, "content"):
-                response += r.content
-        return clean_response(response)
-    except Exception as e:
-        return f"❌ Error: {str(e)}"
-
-def analyze_batches_parallel(agent, batches: List[List[str]]) -> List[str]:
+# === Main Processing ===
+def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
     results = []
-    with ThreadPoolExecutor(max_workers=MAX_PARALLEL_JOBS) as executor:
-        futures = []
-        for batch in batches:
-            futures.append(executor.submit(process_single_batch, agent, batch))
-            time.sleep(SLEEP_BETWEEN_JOBS)
-        for future in as_completed(futures):
-            results.append(future.result())
+    for batch in batches:
+        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
+        try:
+            batch_response = ""
+            for r in agent.run_gradio_chat(
+                message=prompt,
+                history=[],
+                temperature=0.0,
+                max_new_tokens=MAX_NEW_TOKENS,
+                max_token=MAX_MODEL_TOKENS,
+                call_agent=False,
+                conversation=[]
+            ):
+                if isinstance(r, str):
+                    batch_response += r
+                elif isinstance(r, list):
+                    for m in r:
+                        if hasattr(m, "content"):
+                            batch_response += m.content
+                elif hasattr(r, "content"):
+                    batch_response += r.content
+            results.append(clean_response(batch_response))
+            time.sleep(SAFE_SLEEP)
+        except Exception as e:
+            results.append(f"❌ Batch failed: {str(e)}")
+            time.sleep(SAFE_SLEEP * 2)  # longer sleep on error
     torch.cuda.empty_cache()
     gc.collect()
     return results
 
 def generate_final_summary(agent, combined: str) -> str:
-    final_prompt = f"""Provide a structured medical report based on the following summaries:\n\n{combined}\n\nRespond in detailed medical bullet points."""
-    full_report = ""
+    final_prompt = f"Provide a structured medical report based on the following summaries:\n\n{combined}\n\nRespond in detailed medical bullet points."
+    final_response = ""
    for r in agent.run_gradio_chat(
        message=final_prompt,
        history=[],
@@ -158,14 +147,14 @@ def generate_final_summary(agent, combined: str) -> str:
         conversation=[]
     ):
         if isinstance(r, str):
-            full_report += r
+            final_response += r
         elif isinstance(r, list):
             for m in r:
                 if hasattr(m, "content"):
-                    full_report += m.content
+                    final_response += m.content
         elif hasattr(r, "content"):
-            full_report += r.content
-    return clean_response(full_report)
+            final_response += r.content
+    return clean_response(final_response)
 
 def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
     if not file or not hasattr(file, "name"):
@@ -177,9 +166,9 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
     extracted = extract_text_from_excel(file.name)
     chunks = split_text(extracted)
     batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
-    messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Parallel analyzing..."})
+    messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Analyzing..."})
 
-    batch_results = analyze_batches_parallel(agent, batches)
+    batch_results = analyze_batches(agent, batches)
     valid = [res for res in batch_results if not res.startswith("❌")]
 
     if not valid:
@@ -200,20 +189,11 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
     return messages, None
 
 def create_ui(agent):
-    with gr.Blocks(css="""
-    html, body, .gradio-container {
-        background-color: #0e1621;
-        color: #e0e0e0;
-        font-family: 'Inter', sans-serif;
-    }
-    h2, h3, h4 { color: #89b4fa; font-weight: 600; }
-    button.gr-button-primary {
-        background-color: #007bff !important;
-        color: white !important;
-        font-weight: bold;
-    }
-    """) as demo:
-        gr.Markdown("""<h2>📄 CPS: Clinical Patient Support System</h2>""")
+    with gr.Blocks(css="""html, body, .gradio-container {background: #0e1621; color: #e0e0e0;}""") as demo:
+        gr.Markdown("""
+        <h2>📄 CPS: Clinical Patient Support System</h2>
+        <p>Analyze and summarize unstructured medical files using AI (optimized for A100 GPU).</p>
+        """)
         with gr.Column():
             chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
             upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
@@ -227,14 +207,10 @@ def create_ui(agent):
             return messages, gr.update(visible=bool(report_path), value=report_path), messages
 
         analyze.click(fn=handle_analysis, inputs=[upload, state], outputs=[chatbot, download, state])
-
     return demo
 
+# === Main ===
 if __name__ == "__main__":
-    try:
-        agent = init_agent()
-        ui = create_ui(agent)
-        ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
-    except Exception as err:
-        print(f"Startup failed: {err}")
-        sys.exit(1)
+    agent = init_agent()
+    ui = create_ui(agent)
+    ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
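
The commit swaps the ThreadPoolExecutor-based parallel batching for sequential, throttled batches. A minimal, self-contained sketch of that pattern follows; run_model and the sample batches are hypothetical stand-ins for agent.run_gradio_chat and real chunked input, not part of the repository:

import time

SAFE_SLEEP = 0.5  # pause between batches, mirroring the constant in app.py

def run_model(prompt: str) -> str:
    # Hypothetical stand-in for agent.run_gradio_chat(...)
    return f"summary of: {prompt[:20]}"

def analyze_sequentially(batches: list) -> list:
    results = []
    for batch in batches:
        prompt = "\n\n".join(batch)
        try:
            results.append(run_model(prompt))
            time.sleep(SAFE_SLEEP)      # throttle between successful batches
        except Exception as e:
            results.append(f"❌ Batch failed: {e}")
            time.sleep(SAFE_SLEEP * 2)  # back off longer after an error
    return results

print(analyze_sequentially([["chunk A", "chunk B"], ["chunk C"]]))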