Ali2206 committed on
Commit 78b3332 · verified · 1 Parent(s): ea2488a

Update app.py

Files changed (1)
  1. app.py +402 -133
app.py CHANGED
@@ -1,155 +1,424 @@
- # ───────────────────────────────────────────────────────── app.py ─────────
- import os, sys, json, re, gc, time, hashlib, logging, shutil, subprocess
- from typing import List, Any
  from concurrent.futures import ThreadPoolExecutor, as_completed
-
- import torch, gradio as gr, psutil
  from diskcache import Cache

- # ---------- CONFIG ----------
- MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
- RAG_MODEL = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
- PROMPT_MAX = 512
- GPU_UTIL = 0.90  # leave a little head‑room
-
- PERSIST = "/data/hf_cache"
- MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
- TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
- FILE_CACHE = os.path.join(PERSIST, "preprocessed")
- REPORT_DIR = os.path.join(PERSIST, "reports")
- for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR):
-     os.makedirs(d, exist_ok=True)
-
- os.environ.update(
-     HF_HOME = MODEL_CACHE,
-     TRANSFORMERS_CACHE = MODEL_CACHE,
-     VLLM_CACHE_DIR = os.path.join(PERSIST, "vllm_cache"),
-     TOKENIZERS_PARALLELISM = "false",
- )
-
- ROOT = os.path.dirname(os.path.abspath(__file__))
- sys.path.insert(0, os.path.join(ROOT, "src"))
-
- from txagent.txagent import TxAgent  # noqa: E402
-
- logging.basicConfig(
-     level = logging.INFO,
-     format="%(asctime)s %(levelname)s %(name)s — %(message)s")
- log = logging.getLogger("app")
-
- cache = Cache(FILE_CACHE, size_limit=20 * 1024**3)  # 20 GB
-
-
- # ---------- GPU / CPU helpers ----------
- def _gpu_ok() -> bool:
-     return torch.cuda.is_available() and torch.cuda.device_count() > 0
-
- def _sys(tag=""):
-     cpu = psutil.cpu_percent()
-     ram = psutil.virtual_memory()
-     log.info("[%s] CPU %.1f%% — RAM %.1f / %.1f GB",
-              tag, cpu, ram.used/1e9, ram.total/1e9)
-
- # ---------- AGENT LOADER ----------
- def _init_vllm() -> TxAgent:
-     from vllm import LLM  # local import avoids import‑time CUDA checks
-     agent = TxAgent(
-         model_name = MODEL_NAME,
-         rag_model_name = RAG_MODEL,
-         step_rag_num = 4,
-         force_finish = True,
-         enable_checker = False,
-         seed = 42,
-     )
-     # monkey‑patch TxAgent.load_models to use enforced kwargs
-     def _load():
-         agent.model = LLM(
-             model = MODEL_NAME,
-             dtype = "half",
-             gpu_memory_utilization = GPU_UTIL,
-             enforce_eager = True,  # avoids CUDAGraph crashes
-         )
-     agent.load_models = _load  # type: ignore
-     agent.init_model()
-     return agent

- def _init_cpu_pipe():
-     from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-     tok = AutoTokenizer.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO")
-     mdl = AutoModelForCausalLM.from_pretrained(
-         "NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
-         torch_dtype = (torch.float16 if _gpu_ok() else torch.float32),
-         device_map = ("auto" if _gpu_ok() else None),
-     )
-     return pipeline("text-generation", model=mdl, tokenizer=tok,
-                     max_new_tokens=PROMPT_MAX, device=0 if _gpu_ok() else -1)

- def init_agent():
-     _sys("before‑load")
      try:
-         agent = _init_vllm()
-         log.info("✅ vLLM loaded on GPU")
-         agent.generator = None  # mark as vLLM path
      except Exception as e:
-         log.warning("⚠️ vLLM path failed (%s) → falling back to HF pipeline", e)
-         pipe = _init_cpu_pipe()
-         agent = TxAgent(dummy=True)  # bare object; we'll store pipe on it
-         agent.generator = pipe
-     _sys("after‑load")
-     return agent

- AGENT = init_agent()

- # ---------- LLM utility ----------
- def run_llm(prompt: str) -> str:
-     """Unified call for either vLLM or HF pipeline"""
-     if AGENT.generator is None:  # vLLM path
-         out = list(AGENT.run_gradio_chat(prompt, [], 0.2,
-                                          PROMPT_MAX, 2048, False, []))[-1]
-         return out.content if hasattr(out, "content") else str(out)
-     # HF pipeline path
-     return AGENT.generator(prompt)[0]["generated_text"]

- # ---------- (dummy) IO helpers ----------
- def md5(path: str) -> str:
-     h = hashlib.md5()
-     with open(path, "rb") as f:
-         for chunk in iter(lambda: f.read(1 << 20), b""):
-             h.update(chunk)
-     return h.hexdigest()

- # ---------- GRADIO ----------
- def analyze(q, hist, _files):
-     hist.append({"role": "user", "content": q})
-     yield hist, None, ""

-     # (File‑parsing code omitted here for brevity — keep your fast PDF/CSV parts)

-     answer = run_llm("Summarise missed diagnoses only:\n\n" + q)
-     hist.append({"role": "assistant", "content": answer})
-     yield hist, None, answer

- def ui():
      with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("<h1 style='text-align:center'>🩺 Clinical Oversight Assistant</h1>")
-         chat = gr.Chatbot(height=600, type="messages")
-         summ = gr.Markdown()
-         ask = gr.Textbox(placeholder="Ask…", show_label=False)
-         btn = gr.Button("Analyze", variant="primary")
-
-         btn.click(analyze, [ask, gr.State([]), gr.State([])], [chat, gr.State(None), summ])
-         ask.submit(analyze, [ask, gr.State([]), gr.State([])], [chat, gr.State(None), summ])
      return demo

  if __name__ == "__main__":
-     ui().queue(api_open=False).launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         allowed_paths=[REPORT_DIR],
-         show_error=True,
-     )
- # ──────────────────────────────────────────────────────────────────────────

+ import sys
+ import os
+ import pandas as pd
+ import pdfplumber
+ import json
+ import gradio as gr
+ from typing import List, Tuple, Optional, Generator
  from concurrent.futures import ThreadPoolExecutor, as_completed
+ import hashlib
+ import shutil
+ import re
+ import psutil
+ import subprocess
+ import logging
+ import torch
+ import gc
  from diskcache import Cache
+ import time
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ import pyarrow.csv as pc
+ import numpy as np
+ from functools import partial
+ from itertools import islice
+ import io

+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ # Persistent directory
+ persistent_dir = "/data/hf_cache"
+ os.makedirs(persistent_dir, exist_ok=True)

+ model_cache_dir = os.path.join(persistent_dir, "txagent_models")
+ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
+ file_cache_dir = os.path.join(persistent_dir, "cache")
+ report_dir = os.path.join(persistent_dir, "reports")
+ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

+ for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
+     os.makedirs(directory, exist_ok=True)
+
+ os.environ["HF_HOME"] = model_cache_dir
+ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+ os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ src_path = os.path.abspath(os.path.join(current_dir, "src"))
+ sys.path.insert(0, src_path)
+
+ from txagent.txagent import TxAgent
+
+ # Initialize cache with 10GB limit
+ cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
+
+ def sanitize_utf8(text: str) -> str:
+     return text.encode("utf-8", "ignore").decode("utf-8")
+
+ def file_hash(path: str) -> str:
+     with open(path, "rb") as f:
+         return hashlib.md5(f.read()).hexdigest()
+
+ def extract_all_pages(file_path: str, progress_callback=None) -> str:
+     try:
+         with pdfplumber.open(file_path) as pdf:
+             total_pages = len(pdf.pages)
+             if total_pages == 0:
+                 return ""
+
+             batch_size = 10
+             batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
+             text_chunks = [""] * total_pages
+             processed_pages = 0
+
+             def extract_batch(start: int, end: int) -> List[tuple]:
+                 results = []
+                 with pdfplumber.open(file_path) as pdf:
+                     for page in pdf.pages[start:end]:
+                         page_num = start + pdf.pages.index(page)
+                         page_text = page.extract_text() or ""
+                         results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
+                 return results
+
+             with ThreadPoolExecutor(max_workers=6) as executor:
+                 futures = [executor.submit(extract_batch, start, end) for start, end in batches]
+                 for future in as_completed(futures):
+                     for page_num, text in future.result():
+                         text_chunks[page_num] = text
+                     processed_pages += batch_size
+                     if progress_callback:
+                         progress_callback(min(processed_pages, total_pages), total_pages)
+
+             return "\n\n".join(filter(None, text_chunks))
+     except Exception as e:
+         logger.error("PDF processing error: %s", e)
+         return f"PDF processing error: {str(e)}"
+
+ def excel_to_ndjson(file_path: str) -> Generator[str, None, None]:
+     """Stream Excel file as NDJSON for maximum performance"""
      try:
+         # pd.read_excel has no chunksize argument, so each sheet is read
+         # once with openpyxl and its rows are then streamed as NDJSON
+         with pd.ExcelFile(file_path, engine='openpyxl') as xls:
+             for sheet_name in xls.sheet_names:
+                 sheet = pd.read_excel(
+                     xls,
+                     sheet_name=sheet_name,
+                     header=None,
+                     dtype=str,
+                 )
+                 for _, row in sheet.iterrows():
+                     yield json.dumps({
+                         "sheet": sheet_name,
+                         "row": row.fillna("").astype(str).tolist()
+                     }) + "\n"
      except Exception as e:
+         logger.error(f"Error streaming Excel: {e}")
+         raise

+ def csv_to_ndjson(file_path: str) -> Generator[str, None, None]:
+     """Stream CSV file as NDJSON for maximum performance"""
+     try:
+         for chunk in pd.read_csv(
+             file_path,
+             header=None,
+             dtype=str,
+             chunksize=1000,
+             encoding_errors='replace',
+             on_bad_lines='skip'
+         ):
+             for _, row in chunk.iterrows():
+                 yield json.dumps({
+                     "row": row.fillna("").astype(str).tolist()
+                 }) + "\n"
+     except Exception as e:
+         logger.error(f"Error streaming CSV: {e}")
+         raise
+
+ def stream_file_to_json(file_path: str, file_type: str) -> Generator[str, None, None]:
+     """Stream file content as JSON chunks"""
+     try:
+         if file_type == "pdf":
+             text = extract_all_pages(file_path)
+             yield json.dumps({
+                 "filename": os.path.basename(file_path),
+                 "content": text,
+                 "status": "initial"
+             })
+         elif file_type in ["csv", "xls", "xlsx"]:
+             # Stream the file content
+             yield json.dumps({
+                 "filename": os.path.basename(file_path),
+                 "streaming": True,
+                 "type": file_type
+             })
+
+             if file_type == "csv":
+                 stream_gen = csv_to_ndjson(file_path)
+             else:
+                 stream_gen = excel_to_ndjson(file_path)
+
+             for chunk in stream_gen:
+                 yield chunk
+         else:
+             yield json.dumps({"error": f"Unsupported file type: {file_type}"})
+     except Exception as e:
+         logger.error("Error processing %s: %s", os.path.basename(file_path), e)
+         yield json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})

+ def log_system_usage(tag=""):
+     try:
+         cpu = psutil.cpu_percent(interval=1)
+         mem = psutil.virtual_memory()
+         logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
+         result = subprocess.run(
+             ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
+             capture_output=True, text=True
+         )
+         if result.returncode == 0:
+             used, total, util = result.stdout.strip().split(", ")
+             logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
+     except Exception as e:
+         logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)

+ def clean_response(text: str) -> str:
+     text = sanitize_utf8(text)
+     text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
+     diagnoses = []
+     lines = text.splitlines()
+     in_diagnoses_section = False
+     for line in lines:
+         line = line.strip()
+         if not line:
+             continue
+         if re.match(r"###\s*Missed Diagnoses", line):
+             in_diagnoses_section = True
+             continue
+         if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
+             in_diagnoses_section = False
+             continue
+         if in_diagnoses_section and re.match(r"-\s*.+", line):
+             diagnosis = re.sub(r"^\-\s*", "", line).strip()
+             if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
+                 diagnoses.append(diagnosis)
+     text = " ".join(diagnoses)
+     text = re.sub(r"\s+", " ", text).strip()
+     text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
+     return text if text else ""

+ def summarize_findings(combined_response: str) -> str:
+     chunks = combined_response.split("--- Analysis for Chunk")
+     diagnoses = []
+     for chunk in chunks:
+         chunk = chunk.strip()
+         if not chunk or "No oversights identified" in chunk:
+             continue
+         lines = chunk.splitlines()
+         in_diagnoses_section = False
+         for line in lines:
+             line = line.strip()
+             if not line:
+                 continue
+             if re.match(r"###\s*Missed Diagnoses", line):
+                 in_diagnoses_section = True
+                 continue
+             if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
+                 in_diagnoses_section = False
+                 continue
+             if in_diagnoses_section and re.match(r"-\s*.+", line):
+                 diagnosis = re.sub(r"^\-\s*", "", line).strip()
+                 if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
+                     diagnoses.append(diagnosis)

+     seen = set()
+     unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
+
+     if not unique_diagnoses:
+         return "No missed diagnoses were identified in the provided records."

+     summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
+     if len(unique_diagnoses) > 1:
+         summary += f", and {unique_diagnoses[-1]}"
+     elif len(unique_diagnoses) == 1:
+         summary = "Missed diagnoses include " + unique_diagnoses[0]
+     summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
+
+     return summary.strip()

+ def init_agent():
+     logger.info("Initializing model...")
+     log_system_usage("Before Load")
+     default_tool_path = os.path.abspath("data/new_tool.json")
+     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+     if not os.path.exists(target_tool_path):
+         shutil.copy(default_tool_path, target_tool_path)

+     agent = TxAgent(
+         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+         tool_files_dict={"new_tool": target_tool_path},
+         force_finish=True,
+         enable_checker=False,
+         step_rag_num=4,
+         seed=100,
+         additional_default_tools=[],
+     )
+     agent.init_model()
+     log_system_usage("After Load")
+     logger.info("Agent Ready")
+     return agent

+ def batched(iterable, n):
+     """Batch data into tuples of length n. The last batch may be shorter."""
+     it = iter(iterable)
+     while True:
+         batch = list(islice(it, n))
+         if not batch:
+             return
+         yield batch

+ def create_ui(agent):
      with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
+         chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
+         final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
+         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+         send_btn = gr.Button("Analyze", variant="primary")
+         download_output = gr.File(label="Download Full Report")
+         progress_bar = gr.Progress()
+
+         prompt_template = """
+ Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
+ Patient Record Excerpt (Chunk {0} of {1}):
+ {chunk}
+ """
+
+         def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
+             history.append({"role": "user", "content": message})
+             yield history, None, ""
+
+             extracted = []
+             file_hash_value = ""
+
+             if files:
+                 # Process files in parallel with streaming
+                 with ThreadPoolExecutor(max_workers=4) as executor:
+                     futures = []
+                     for f in files:
+                         file_type = f.name.split(".")[-1].lower()
+                         # bind file_type per file; a bare closure would see only
+                         # the value from the last loop iteration
+                         futures.append(executor.submit(
+                             lambda f, ft=file_type: list(stream_file_to_json(f.name, ft)),
+                             f
+                         ))
+
+                     for future in as_completed(futures):
+                         try:
+                             extracted.extend(future.result())
+                         except Exception as e:
+                             logger.error(f"File processing error: {e}")
+                             extracted.append(json.dumps({
+                                 "error": f"Error processing file: {str(e)}"
+                             }))
+
+                 file_hash_value = file_hash(files[0].name) if files else ""
+                 history.append({"role": "assistant", "content": "✅ File processing complete"})
+                 yield history, None, ""
+
+             # Process chunks in parallel with dynamic batching
+             chunk_size = 8000  # Larger chunks reduce overhead
+             combined_response = ""
+
+             try:
+                 # Convert extracted data to text chunks
+                 text_content = "\n".join(extracted)
+                 chunks = [text_content[i:i+chunk_size] for i in range(0, len(text_content), chunk_size)]
+
+                 # Process chunks in parallel batches
+                 batch_size = 4  # Optimal for most GPUs
+                 total_chunks = len(chunks)
+
+                 for batch_idx, batch_chunks in enumerate(batched(chunks, batch_size)):
+                     batch_prompts = [
+                         prompt_template.format(
+                             batch_idx * batch_size + i + 1,
+                             total_chunks,
+                             chunk=chunk[:6000]  # Slightly larger context
+                         )
+                         for i, chunk in enumerate(batch_chunks)
+                     ]
+
+                     progress((batch_idx * batch_size) / total_chunks,
+                              desc=f"Analyzing batch {batch_idx + 1}/{(total_chunks + batch_size - 1) // batch_size}")
+
+                     # Process batch in parallel
+                     with ThreadPoolExecutor(max_workers=len(batch_prompts)) as executor:
+                         future_to_prompt = {
+                             executor.submit(
+                                 agent.run_gradio_chat,
+                                 prompt, [], 0.2, 512, 2048, False, []
+                             ): prompt
+                             for prompt in batch_prompts
+                         }
+
+                         for future in as_completed(future_to_prompt):
+                             chunk_response = ""
+                             for chunk_output in future.result():
+                                 if chunk_output is None:
+                                     continue
+                                 if isinstance(chunk_output, list):
+                                     for m in chunk_output:
+                                         if hasattr(m, 'content') and m.content:
+                                             cleaned = clean_response(m.content)
+                                             if cleaned:
+                                                 chunk_response += cleaned + " "
+                                 elif isinstance(chunk_output, str) and chunk_output.strip():
+                                     cleaned = clean_response(chunk_output)
+                                     if cleaned:
+                                         chunk_response += cleaned + " "
+
+                             combined_response += f"--- Analysis for Chunk {batch_idx * batch_size + 1} ---\n{chunk_response.strip()}\n"
+                             history[-1] = {"role": "assistant", "content": combined_response.strip()}
+                             yield history, None, ""
+
+                     # Clean up memory
+                     torch.cuda.empty_cache()
+                     gc.collect()
+
+                 # Generate final summary
+                 summary = summarize_findings(combined_response)
+                 report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
+                 if report_path:
+                     with open(report_path, "w", encoding="utf-8") as f:
+                         f.write(combined_response + "\n\n" + summary)
+
+                 yield history, report_path if report_path and os.path.exists(report_path) else None, summary
+
+             except Exception as e:
+                 logger.error("Analysis error: %s", e)
+                 history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
+                 yield history, None, f"Error occurred during analysis: {str(e)}"
+
+         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
+         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
      return demo

  if __name__ == "__main__":
+     try:
+         logger.info("Launching app...")
+         agent = init_agent()
+         demo = create_ui(agent)
+         demo.queue(api_open=False).launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             show_error=True,
+             allowed_paths=[report_dir],
+             share=False
+         )
+     finally:
+         if torch.distributed.is_initialized():
+             torch.distributed.destroy_process_group()