Update ui/ui_core.py
Browse files- ui/ui_core.py +31 -29
ui/ui_core.py
CHANGED
|
@@ -5,10 +5,8 @@ import pdfplumber
|
|
| 5 |
import json
|
| 6 |
import gradio as gr
|
| 7 |
from typing import List
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
from PIL import Image
|
| 11 |
-
import torch
|
| 12 |
|
| 13 |
# ✅ Fix: Add src to Python path
|
| 14 |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
|
|
@@ -36,24 +34,20 @@ def clean_final_response(text: str) -> str:
|
|
| 36 |
)
|
| 37 |
return "".join(panels)
|
| 38 |
|
| 39 |
-
def
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
image = Image.open(image_path).convert("RGB")
|
| 44 |
-
encoding = processor(images=image, return_tensors="pt")
|
| 45 |
-
with torch.no_grad():
|
| 46 |
-
outputs = model(**encoding)
|
| 47 |
-
|
| 48 |
-
logits = outputs.logits
|
| 49 |
-
predicted_class = logits.argmax(-1)
|
| 50 |
-
tokens = processor.tokenizer.convert_ids_to_tokens(encoding['input_ids'][0])
|
| 51 |
-
|
| 52 |
-
text = " ".join([tokens[i] for i in range(len(tokens)) if predicted_class[0][i] != -100])
|
| 53 |
-
return json.dumps({"filename": os.path.basename(image_path), "content": text})
|
| 54 |
|
| 55 |
def convert_file_to_json(file_path: str, file_type: str) -> str:
|
| 56 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
if file_type == "csv":
|
| 58 |
df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
|
| 59 |
elif file_type in ["xls", "xlsx"]:
|
|
@@ -62,7 +56,11 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
|
|
| 62 |
except:
|
| 63 |
df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
|
| 64 |
elif file_type == "pdf":
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
else:
|
| 67 |
return json.dumps({"error": f"Unsupported file type: {file_type}"})
|
| 68 |
|
|
@@ -71,7 +69,9 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
|
|
| 71 |
|
| 72 |
df = df.fillna("")
|
| 73 |
content = df.astype(str).values.tolist()
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
except Exception as e:
|
| 76 |
return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})
|
| 77 |
|
|
@@ -133,13 +133,11 @@ def create_ui(agent: TxAgent):
|
|
| 133 |
|
| 134 |
chunks = chunk_text(extracted_text.strip())
|
| 135 |
|
| 136 |
-
|
| 137 |
-
for i, chunk in enumerate(chunks):
|
| 138 |
chunked_prompt = (
|
| 139 |
f"{context}\n\n--- Uploaded File Content (Chunk {i+1}/{len(chunks)}) ---\n\n{chunk}\n\n"
|
| 140 |
f"--- End of Chunk ---\n\nNow begin your analysis:"
|
| 141 |
)
|
| 142 |
-
|
| 143 |
generator = agent.run_gradio_chat(
|
| 144 |
message=chunked_prompt,
|
| 145 |
history=[],
|
|
@@ -151,18 +149,22 @@ def create_ui(agent: TxAgent):
|
|
| 151 |
uploaded_files=uploaded_files,
|
| 152 |
max_round=30
|
| 153 |
)
|
| 154 |
-
|
| 155 |
-
chunk_response = ""
|
| 156 |
for update in generator:
|
| 157 |
if isinstance(update, str):
|
| 158 |
-
|
| 159 |
elif isinstance(update, list):
|
| 160 |
for msg in update:
|
| 161 |
if hasattr(msg, 'content'):
|
| 162 |
-
|
|
|
|
| 163 |
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
| 165 |
|
|
|
|
| 166 |
full_response = clean_final_response(full_response.strip())
|
| 167 |
history[-1] = (message, full_response)
|
| 168 |
yield history
|
|
|
|
| 5 |
import json
|
| 6 |
import gradio as gr
|
| 7 |
from typing import List
|
| 8 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 9 |
+
import hashlib
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# ✅ Fix: Add src to Python path
|
| 12 |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
|
|
|
|
| 34 |
)
|
| 35 |
return "".join(panels)
|
| 36 |
|
| 37 |
+
def file_hash(path):
|
| 38 |
+
with open(path, "rb") as f:
|
| 39 |
+
return hashlib.md5(f.read()).hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
def convert_file_to_json(file_path: str, file_type: str) -> str:
    """Convert an uploaded file into a JSON string, with on-disk caching.

    Supported types: "csv" and "xls"/"xlsx" (serialized as
    {"filename", "rows"}) and "pdf" (serialized as {"filename", "content"}).
    Parsed results are cached under ./cache keyed by the file's content
    hash, so re-uploading an identical file skips re-parsing. Any failure
    is reported as a JSON {"error": ...} string rather than raising, so
    the UI layer never sees an exception from here.
    """
    try:
        cache_dir = "cache"
        os.makedirs(cache_dir, exist_ok=True)
        h = file_hash(file_path)
        cache_path = os.path.join(cache_dir, f"{h}.json")

        # Content-addressed cache hit: return the previously parsed JSON.
        # (Original leaked the handle via open(...).read(); use `with`.)
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()

        if file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None,
                             dtype=str, skip_blank_lines=False, on_bad_lines="skip")
        elif file_type in ["xls", "xlsx"]:
            # NOTE(review): the primary read_excel line fell inside a hidden
            # diff hunk; openpyxl-then-xlrd is the conventional pairing —
            # confirm the engine against the full file.
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                # Was a bare `except:`; narrowed so SystemExit /
                # KeyboardInterrupt are not swallowed. Fall back to the
                # legacy xlrd engine for old-format .xls files.
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
        elif file_type == "pdf":
            with pdfplumber.open(file_path) as pdf:
                # extract_text() can return None for image-only pages;
                # substitute "" so the join never sees None.
                text = "\n".join(page.extract_text() or "" for page in pdf.pages)
            result = json.dumps({"filename": os.path.basename(file_path),
                                 "content": text.strip()})
            with open(cache_path, "w", encoding="utf-8") as f:
                f.write(result)
            return result
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})

        # Tabular branches (csv / xls / xlsx) fall through to here.
        df = df.fillna("")
        content = df.astype(str).values.tolist()
        result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})
|
| 77 |
|
|
|
|
| 133 |
|
| 134 |
chunks = chunk_text(extracted_text.strip())
|
| 135 |
|
| 136 |
+
def process_chunk(i, chunk):
|
|
|
|
| 137 |
chunked_prompt = (
|
| 138 |
f"{context}\n\n--- Uploaded File Content (Chunk {i+1}/{len(chunks)}) ---\n\n{chunk}\n\n"
|
| 139 |
f"--- End of Chunk ---\n\nNow begin your analysis:"
|
| 140 |
)
|
|
|
|
| 141 |
generator = agent.run_gradio_chat(
|
| 142 |
message=chunked_prompt,
|
| 143 |
history=[],
|
|
|
|
| 149 |
uploaded_files=uploaded_files,
|
| 150 |
max_round=30
|
| 151 |
)
|
| 152 |
+
result = ""
|
|
|
|
| 153 |
for update in generator:
|
| 154 |
if isinstance(update, str):
|
| 155 |
+
result += update
|
| 156 |
elif isinstance(update, list):
|
| 157 |
for msg in update:
|
| 158 |
if hasattr(msg, 'content'):
|
| 159 |
+
result += msg.content
|
| 160 |
+
return result
|
| 161 |
|
| 162 |
+
# ⏱ Parallel Execution for Speed
|
| 163 |
+
with ThreadPoolExecutor(max_workers=min(8, len(chunks))) as executor:
|
| 164 |
+
futures = [executor.submit(process_chunk, i, chunk) for i, chunk in enumerate(chunks)]
|
| 165 |
+
results = [f.result() for f in as_completed(futures)]
|
| 166 |
|
| 167 |
+
full_response = "\n\n".join(results)
|
| 168 |
full_response = clean_final_response(full_response.strip())
|
| 169 |
history[-1] = (message, full_response)
|
| 170 |
yield history
|