tejash300 committed on
Commit
4e897df
·
verified ·
1 Parent(s): d5c52ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -406,17 +406,32 @@ def analyze_contract_clauses(text):
406
  "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
407
  "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
408
  ]
409
- chunks = [text[i:i+max_length] for i in range(0, len(text), step) if i+step < len(text)]
410
- for chunk in chunks:
411
- # Move each tensor individually to the device
412
- tokenized_inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
413
- inputs = {k: v.to(device) for k, v in tokenized_inputs.items()}
414
- with torch.no_grad():
415
- outputs = cuad_model(**inputs)
416
- predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
417
- for idx, confidence in enumerate(predictions):
418
- if confidence > 0.5 and idx < len(clause_types):
419
- clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  aggregated_clauses = {}
421
  for clause in clauses_detected:
422
  clause_type = clause["type"]
 
406
  "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
407
  "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
408
  ]
409
+ # Process text in chunks of 'max_length' with a step size 'step'
410
+ for i in range(0, len(text), step):
411
+ chunk = text[i:i+max_length]
412
+ if not chunk.strip():
413
+ continue # Skip empty chunks
414
+ try:
415
+ tokenized_inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
416
+ inputs = {k: v.to(device) for k, v in tokenized_inputs.items()}
417
+ # Check that token IDs are within vocabulary bounds
418
+ max_token = inputs["input_ids"].max().item()
419
+ if max_token >= cuad_model.config.vocab_size:
420
+ print(f"Skipping chunk due to invalid token id: {max_token}")
421
+ continue
422
+ with torch.no_grad():
423
+ outputs = cuad_model(**inputs)
424
+ # Optional: verify shape consistency
425
+ if outputs.start_logits.shape[1] != inputs["input_ids"].shape[1]:
426
+ print("Mismatch in logits shape, skipping chunk")
427
+ continue
428
+ predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
429
+ for idx, confidence in enumerate(predictions):
430
+ if confidence > 0.5 and idx < len(clause_types):
431
+ clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
432
+ except Exception as e:
433
+ print(f"Error processing chunk: {e}")
434
+ continue
435
  aggregated_clauses = {}
436
  for clause in clauses_detected:
437
  clause_type = clause["type"]