Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -406,17 +406,32 @@ def analyze_contract_clauses(text):
|
|
406 |
"Assignment", "Warranty", "Limitation of Liability", "Arbitration",
|
407 |
"IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
|
408 |
]
|
409 |
-
|
410 |
-
for
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
aggregated_clauses = {}
|
421 |
for clause in clauses_detected:
|
422 |
clause_type = clause["type"]
|
|
|
406 |
"Assignment", "Warranty", "Limitation of Liability", "Arbitration",
|
407 |
"IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
|
408 |
]
|
409 |
+
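# Slide a window of 'max_length' characters over the text, advancing 'step' characters per iteration (the tokenizer call below separately truncates each chunk to 512 tokens)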
|
410 |
+
for i in range(0, len(text), step):
|
411 |
+
chunk = text[i:i+max_length]
|
412 |
+
if not chunk.strip():
|
413 |
+
continue # Skip empty chunks
|
414 |
+
try:
|
415 |
+
tokenized_inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
|
416 |
+
inputs = {k: v.to(device) for k, v in tokenized_inputs.items()}
|
417 |
+
# Check that token IDs are within vocabulary bounds
|
418 |
+
max_token = inputs["input_ids"].max().item()
|
419 |
+
if max_token >= cuad_model.config.vocab_size:
|
420 |
+
print(f"Skipping chunk due to invalid token id: {max_token}")
|
421 |
+
continue
|
422 |
+
with torch.no_grad():
|
423 |
+
outputs = cuad_model(**inputs)
|
424 |
+
# Sanity check: start_logits must have one entry per input token; skip the chunk if the lengths differ
|
425 |
+
if outputs.start_logits.shape[1] != inputs["input_ids"].shape[1]:
|
426 |
+
print("Mismatch in logits shape, skipping chunk")
|
427 |
+
continue
|
428 |
+
predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
|
429 |
+
for idx, confidence in enumerate(predictions):
|
430 |
+
if confidence > 0.5 and idx < len(clause_types):
|
431 |
+
clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
|
432 |
+
except Exception as e:
|
433 |
+
print(f"Error processing chunk: {e}")
|
434 |
+
continue
|
435 |
aggregated_clauses = {}
|
436 |
for clause in clauses_detected:
|
437 |
clause_type = clause["type"]
|