Update app.py
Browse files
app.py
CHANGED
@@ -3,14 +3,13 @@ import pandas as pd
|
|
3 |
from datetime import datetime
|
4 |
from typing import List, Tuple, Dict, Union
|
5 |
import gradio as gr
|
6 |
-
from concurrent.futures import ThreadPoolExecutor
|
7 |
|
8 |
# Constants
|
9 |
MAX_MODEL_TOKENS = 131072
|
10 |
MAX_NEW_TOKENS = 4096
|
11 |
MAX_CHUNK_TOKENS = 8192
|
12 |
PROMPT_OVERHEAD = 300
|
13 |
-
BATCH_SIZE = 3 #
|
14 |
|
15 |
# Paths
|
16 |
persistent_dir = "/data/hf_cache"
|
@@ -84,14 +83,14 @@ def init_agent() -> TxAgent:
|
|
84 |
agent.init_model()
|
85 |
return agent
|
86 |
|
87 |
-
#
|
88 |
-
def
|
89 |
-
results = [
|
90 |
-
|
91 |
-
def worker(index, batch):
|
92 |
prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
|
93 |
if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
|
94 |
-
|
|
|
95 |
response = ""
|
96 |
try:
|
97 |
for r in agent.run_gradio_chat(
|
@@ -111,19 +110,9 @@ def analyze_parallel(agent, batch_chunks: List[List[str]], max_workers: int = 3)
|
|
111 |
response += m.content
|
112 |
elif hasattr(r, "content"):
|
113 |
response += r.content
|
114 |
-
|
115 |
except Exception as e:
|
116 |
-
|
117 |
-
|
118 |
-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
119 |
-
futures = {executor.submit(worker, idx, batch): idx for idx, batch in enumerate(batch_chunks)}
|
120 |
-
for future in futures:
|
121 |
-
idx = futures[future]
|
122 |
-
try:
|
123 |
-
results[idx] = future.result()
|
124 |
-
except Exception as e:
|
125 |
-
results[idx] = f"β Error in batch {idx+1}: {str(e)}"
|
126 |
-
|
127 |
gc.collect()
|
128 |
return results
|
129 |
|
@@ -161,7 +150,7 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
|
|
161 |
batch_chunks = [chunks[i:i+BATCH_SIZE] for i in range(0, len(chunks), BATCH_SIZE)]
|
162 |
messages.append({"role": "assistant", "content": f"π Split into {len(batch_chunks)} batches. Analyzing..."})
|
163 |
|
164 |
-
chunk_results =
|
165 |
valid = [res for res in chunk_results if not res.startswith("β")]
|
166 |
|
167 |
if not valid:
|
@@ -181,15 +170,14 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
|
|
181 |
messages.append({"role": "assistant", "content": f"β Error: {str(e)}"})
|
182 |
return messages, None
|
183 |
|
184 |
-
|
185 |
def create_ui(agent):
|
186 |
with gr.Blocks(css="""
|
187 |
html, body, .gradio-container {
|
188 |
background-color: #0e1621;
|
189 |
color: #e0e0e0;
|
190 |
font-family: 'Inter', sans-serif;
|
191 |
-
margin: 0;
|
192 |
padding: 0;
|
|
|
193 |
}
|
194 |
h2, h3, h4 {
|
195 |
color: #89b4fa;
|
@@ -215,7 +203,6 @@ def create_ui(agent):
|
|
215 |
}
|
216 |
.gr-chatbot .message {
|
217 |
font-size: 16px;
|
218 |
-
line-height: 1.6;
|
219 |
padding: 12px 16px;
|
220 |
border-radius: 18px;
|
221 |
margin: 8px 0;
|
@@ -242,10 +229,8 @@ def create_ui(agent):
|
|
242 |
<h2>π CPS: Clinical Patient Support System</h2>
|
243 |
<p>CPS Assistant helps you analyze and summarize unstructured medical files using AI.</p>
|
244 |
""")
|
245 |
-
|
246 |
with gr.Column():
|
247 |
chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
|
248 |
-
|
249 |
upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
|
250 |
analyze = gr.Button("π§ Analyze", variant="primary")
|
251 |
download = gr.File(label="Download Report", visible=False, interactive=False)
|
@@ -268,4 +253,3 @@ if __name__ == "__main__":
|
|
268 |
except Exception as err:
|
269 |
print(f"Startup failed: {err}")
|
270 |
sys.exit(1)
|
271 |
-
|
|
|
3 |
from datetime import datetime
|
4 |
from typing import List, Tuple, Dict, Union
|
5 |
import gradio as gr
|
|
|
6 |
|
7 |
# Constants
|
8 |
MAX_MODEL_TOKENS = 131072
|
9 |
MAX_NEW_TOKENS = 4096
|
10 |
MAX_CHUNK_TOKENS = 8192
|
11 |
PROMPT_OVERHEAD = 300
|
12 |
+
BATCH_SIZE = 3 # group 3 chunks together
|
13 |
|
14 |
# Paths
|
15 |
persistent_dir = "/data/hf_cache"
|
|
|
83 |
agent.init_model()
|
84 |
return agent
|
85 |
|
86 |
+
# Serial processing (safe for vLLM)
|
87 |
+
def analyze_serial(agent, batch_chunks: List[List[str]]) -> List[str]:
|
88 |
+
results = []
|
89 |
+
for idx, batch in enumerate(batch_chunks):
|
|
|
90 |
prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
|
91 |
if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
|
92 |
+
results.append(f"β Batch {idx+1} too long. Skipped.")
|
93 |
+
continue
|
94 |
response = ""
|
95 |
try:
|
96 |
for r in agent.run_gradio_chat(
|
|
|
110 |
response += m.content
|
111 |
elif hasattr(r, "content"):
|
112 |
response += r.content
|
113 |
+
results.append(clean_response(response))
|
114 |
except Exception as e:
|
115 |
+
results.append(f"β Error in batch {idx+1}: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
gc.collect()
|
117 |
return results
|
118 |
|
|
|
150 |
batch_chunks = [chunks[i:i+BATCH_SIZE] for i in range(0, len(chunks), BATCH_SIZE)]
|
151 |
messages.append({"role": "assistant", "content": f"π Split into {len(batch_chunks)} batches. Analyzing..."})
|
152 |
|
153 |
+
chunk_results = analyze_serial(agent, batch_chunks)
|
154 |
valid = [res for res in chunk_results if not res.startswith("β")]
|
155 |
|
156 |
if not valid:
|
|
|
170 |
messages.append({"role": "assistant", "content": f"β Error: {str(e)}"})
|
171 |
return messages, None
|
172 |
|
|
|
173 |
def create_ui(agent):
|
174 |
with gr.Blocks(css="""
|
175 |
html, body, .gradio-container {
|
176 |
background-color: #0e1621;
|
177 |
color: #e0e0e0;
|
178 |
font-family: 'Inter', sans-serif;
|
|
|
179 |
padding: 0;
|
180 |
+
margin: 0;
|
181 |
}
|
182 |
h2, h3, h4 {
|
183 |
color: #89b4fa;
|
|
|
203 |
}
|
204 |
.gr-chatbot .message {
|
205 |
font-size: 16px;
|
|
|
206 |
padding: 12px 16px;
|
207 |
border-radius: 18px;
|
208 |
margin: 8px 0;
|
|
|
229 |
<h2>π CPS: Clinical Patient Support System</h2>
|
230 |
<p>CPS Assistant helps you analyze and summarize unstructured medical files using AI.</p>
|
231 |
""")
|
|
|
232 |
with gr.Column():
|
233 |
chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
|
|
|
234 |
upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
|
235 |
analyze = gr.Button("π§ Analyze", variant="primary")
|
236 |
download = gr.File(label="Download Report", visible=False, interactive=False)
|
|
|
253 |
except Exception as err:
|
254 |
print(f"Startup failed: {err}")
|
255 |
sys.exit(1)
|
|