Ali2206 commited on
Commit
2639902
Β·
verified Β·
1 Parent(s): 936692d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -7,9 +7,8 @@ import gradio as gr
7
  # Constants
8
  MAX_MODEL_TOKENS = 131072
9
  MAX_NEW_TOKENS = 4096
10
- MAX_CHUNK_TOKENS = 8192
11
  PROMPT_OVERHEAD = 300
12
- BATCH_SIZE = 10 # Bigger batch for faster processing
13
 
14
  # Paths
15
  persistent_dir = "/data/hf_cache"
@@ -47,14 +46,13 @@ def extract_text_from_excel(path: str) -> str:
47
  try:
48
  df = xls.parse(sheet_name).astype(str).fillna("")
49
  except Exception:
50
- continue # Skip sheet if unreadable
51
 
52
  for idx, row in df.iterrows():
53
- # If the row has at least 2 non-empty values and is not totally empty
54
  non_empty = [cell.strip() for cell in row if cell.strip() != ""]
55
  if len(non_empty) >= 2:
56
  text_line = " | ".join(non_empty)
57
- if len(text_line) > 15: # Ignore very small lines
58
  all_text.append(f"[{sheet_name}] {text_line}")
59
 
60
  return "\n".join(all_text)
@@ -94,13 +92,12 @@ def init_agent() -> TxAgent:
94
  agent.init_model()
95
  return agent
96
 
97
- # Serial analyze (safe for vLLM)
98
- def analyze_serial(agent, batch_chunks: List[List[str]]) -> List[str]:
99
  results = []
100
- for idx, batch in enumerate(batch_chunks):
101
- prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
102
  if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
103
- results.append(f"❌ Batch {idx+1} too long. Skipped.")
104
  continue
105
  response = ""
106
  try:
@@ -123,7 +120,7 @@ def analyze_serial(agent, batch_chunks: List[List[str]]) -> List[str]:
123
  response += r.content
124
  results.append(clean_response(response))
125
  except Exception as e:
126
- results.append(f"❌ Error in batch {idx+1}: {str(e)}")
127
  gc.collect()
128
  return results
129
 
@@ -158,14 +155,13 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
158
  try:
159
  extracted = extract_text_from_excel(file.name)
160
  chunks = split_text(extracted)
161
- batch_chunks = [chunks[i:i+BATCH_SIZE] for i in range(0, len(chunks), BATCH_SIZE)]
162
- messages.append({"role": "assistant", "content": f"πŸ” Split into {len(batch_chunks)} batches. Analyzing..."})
163
 
164
- chunk_results = analyze_serial(agent, batch_chunks)
165
  valid = [res for res in chunk_results if not res.startswith("❌")]
166
 
167
  if not valid:
168
- messages.append({"role": "assistant", "content": "❌ No valid batch outputs."})
169
  return messages, None
170
 
171
  summary = generate_final_summary(agent, "\n\n".join(valid))
 
7
  # Constants
8
  MAX_MODEL_TOKENS = 131072
9
  MAX_NEW_TOKENS = 4096
10
+ MAX_CHUNK_TOKENS = 8192 # IMPORTANT: Split input into 8k tokens chunks
11
  PROMPT_OVERHEAD = 300
 
12
 
13
  # Paths
14
  persistent_dir = "/data/hf_cache"
 
46
  try:
47
  df = xls.parse(sheet_name).astype(str).fillna("")
48
  except Exception:
49
+ continue
50
 
51
  for idx, row in df.iterrows():
 
52
  non_empty = [cell.strip() for cell in row if cell.strip() != ""]
53
  if len(non_empty) >= 2:
54
  text_line = " | ".join(non_empty)
55
+ if len(text_line) > 15:
56
  all_text.append(f"[{sheet_name}] {text_line}")
57
 
58
  return "\n".join(all_text)
 
92
  agent.init_model()
93
  return agent
94
 
95
+ def analyze_serial(agent, chunks: List[str]) -> List[str]:
 
96
  results = []
97
+ for idx, chunk in enumerate(chunks):
98
+ prompt = build_prompt(chunk)
99
  if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
100
+ results.append(f"❌ Chunk {idx+1} too long. Skipped.")
101
  continue
102
  response = ""
103
  try:
 
120
  response += r.content
121
  results.append(clean_response(response))
122
  except Exception as e:
123
+ results.append(f"❌ Error in chunk {idx+1}: {str(e)}")
124
  gc.collect()
125
  return results
126
 
 
155
  try:
156
  extracted = extract_text_from_excel(file.name)
157
  chunks = split_text(extracted)
158
+ messages.append({"role": "assistant", "content": f"πŸ” Split into {len(chunks)} chunks. Analyzing..."})
 
159
 
160
+ chunk_results = analyze_serial(agent, chunks)
161
  valid = [res for res in chunk_results if not res.startswith("❌")]
162
 
163
  if not valid:
164
+ messages.append({"role": "assistant", "content": "❌ No valid chunk outputs."})
165
  return messages, None
166
 
167
  summary = generate_final_summary(agent, "\n\n".join(valid))