tejash300 committed
Commit e3b69f0 · verified · 1 Parent(s): dd43ec8

Update app.py

Files changed (1)
  1. app.py +16 -124
app.py CHANGED
@@ -38,24 +38,24 @@ from spacy.lang.en.stop_words import STOP_WORDS
 # Global cache for analysis results based on file hash
 analysis_cache = {}
 
-# Ensure compatibility with Google Colab
+# Ensure compatibility with Google Colab (if applicable)
 try:
     from google.colab import drive
     drive.mount('/content/drive')
 except Exception:
-    pass  # Not in Colab
+    pass  # Not running in Colab
 
-# Make sure directories exist
+# Ensure required directories exist
 os.makedirs("static", exist_ok=True)
 os.makedirs("temp", exist_ok=True)
 
 # Use GPU if available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# FastAPI setup
+# Initialize FastAPI
 app = FastAPI(title="Legal Document and Video Analyzer")
 
-# CORS
+# Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -64,7 +64,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# In-memory storage
+# In-memory storage for document text and chat history
 document_storage = {}
 chat_history = []
 
@@ -79,14 +79,10 @@ def compute_md5(content: bytes) -> str:
     return hashlib.md5(content).hexdigest()
 
 #############################
-# Fine-tuning on CUAD QA   #
+# Fine-tuning on CUAD QA   #
 #############################
 
 def fine_tune_cuad_model():
-    """
-    Minimal stub for fine-tuning the CUAD QA model.
-    If you have a full fine-tuning script, place it here.
-    """
     from datasets import load_dataset
     from transformers import Trainer, TrainingArguments, AutoModelForQuestionAnswering, AutoTokenizer
 
@@ -161,6 +157,7 @@ def fine_tune_cuad_model():
             tokenized_examples["end_positions"].append(safe_end)
         return tokenized_examples
 
+    print("✅ Tokenizing dataset...")
     train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)
     val_dataset = val_dataset.map(prepare_train_features, batched=True, remove_columns=val_dataset.column_names)
     train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
@@ -201,7 +198,7 @@ def fine_tune_cuad_model():
 #############################
 
 try:
-    # Load spacy
+    # Load spaCy model
     try:
         nlp = spacy.load("en_core_web_sm")
     except Exception:
@@ -209,32 +206,29 @@ try:
         nlp = spacy.load("en_core_web_sm")
     print("✅ Loaded spaCy model.")
 
-    # Summarizer (GPU)
+    # Create summarizer and QA pipelines on GPU
     summarizer = pipeline(
         "summarization",
         model="facebook/bart-large-cnn",
         tokenizer="facebook/bart-large-cnn",
         device=0 if device == "cuda" else -1
     )
-
-    # QA pipeline (GPU)
     qa_model = pipeline(
         "question-answering",
         model="deepset/roberta-base-squad2",
         device=0 if device == "cuda" else -1
     )
 
-    # Embeddings (GPU if available)
+    # Use GPU for sentence embeddings if available
     embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
 
-    # Named Entity Recognition (GPU)
     ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if device == "cuda" else -1)
 
-    # Speech-to-text (GPU if available via device_map="auto")
+    # Speech-to-text pipeline on GPU (if available)
     speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
                               device_map="auto" if device == "cuda" else None)
 
-    # Fine-tuned CUAD QA
+    # Load or fine-tune the CUAD QA model and move to GPU
     if os.path.exists("fine_tuned_legal_qa"):
         print("✅ Loading fine-tuned CUAD QA model from fine_tuned_legal_qa...")
         cuad_tokenizer = AutoTokenizer.from_pretrained("fine_tuned_legal_qa")
@@ -242,11 +236,10 @@ try:
         cuad_model = AutoModelForQuestionAnswering.from_pretrained("fine_tuned_legal_qa")
         cuad_model.to(device)
     else:
-        print("⚠️ Fine-tuned QA model not found. Fine-tuning now (this may be slow).")
+        print("⚠️ Fine-tuned QA model not found. Fine-tuning now (this may take a while)...")
         cuad_tokenizer, cuad_model = fine_tune_cuad_model()
         cuad_model.to(device)
 
-    # Sentiment (GPU)
     sentiment_pipeline = pipeline(
         "sentiment-analysis",
         model="distilbert-base-uncased-finetuned-sst-2-english",
@@ -281,9 +274,6 @@ def extract_text_from_pdf(pdf_file):
         raise HTTPException(status_code=400, detail=f"PDF extraction failed: {str(e)}")
 
 async def process_video_to_text(video_file_path):
-    """
-    Extracts audio from video and runs speech-to-text.
-    """
     try:
         print(f"Processing video file at {video_file_path}")
         temp_audio_path = os.path.join("temp", "extracted_audio.wav")
@@ -305,9 +295,6 @@ async def process_video_to_text(video_file_path):
         raise HTTPException(status_code=400, detail=f"Video processing failed: {str(e)}")
 
 async def process_audio_to_text(audio_file_path):
-    """
-    Runs speech-to-text on an audio file.
-    """
     try:
         print(f"Processing audio file at {audio_file_path}")
         result = await run_in_threadpool(speech_to_text, audio_file_path)
@@ -319,9 +306,6 @@ async def process_audio_to_text(audio_file_path):
         raise HTTPException(status_code=400, detail=f"Audio processing failed: {str(e)}")
 
 def extract_named_entities(text):
-    """
-    Splits text into manageable chunks, runs spaCy for entity extraction.
-    """
     max_length = 10000
     entities = []
     for i in range(0, len(text), max_length):
@@ -373,11 +357,9 @@ def explain_topics(topics):
                 weight = float(weight_str)
             except:
                 weight = 0.0
-            # Filter out short words & stop words
             if word.lower() not in STOP_WORDS and len(word) > 1:
                 terms.append((weight, word))
         terms.sort(key=lambda x: -x[0])
-        # Heuristic labeling
         if terms:
             if any("liability" in w.lower() for _, w in terms):
                 label = "Liability & Penalty Risk"
@@ -419,20 +401,13 @@ def analyze_risk_enhanced(text):
 #############################
 
 def chunk_text_by_tokens(text, tokenizer, max_chunk_len=384, stride=128):
-    """
-    Convert the entire text into tokens once, then create overlapping chunks
-    of up to `max_chunk_len` tokens with overlap `stride`.
-    """
-    # Encode text once
     encoded = tokenizer(text, add_special_tokens=False)
     input_ids = encoded["input_ids"]
-    # We'll create overlapping windows of tokens
     chunks = []
     idx = 0
     while idx < len(input_ids):
         end = idx + max_chunk_len
         sub_ids = input_ids[idx:end]
-        # Convert back to text
         chunk_text = tokenizer.decode(sub_ids, skip_special_tokens=True)
         chunks.append(chunk_text)
         if end >= len(input_ids):
@@ -443,13 +418,7 @@ def chunk_text_by_tokens(text, tokenizer, max_chunk_len=384, stride=128):
     return chunks
 
 def analyze_contract_clauses(text):
-    """
-    Token-based chunking to avoid partial tokens.
-    Each chunk is fed into the fine-tuned CUAD model on GPU.
-    """
-    # We'll break the text into chunks of up to 384 tokens, with a stride of 128
     text_chunks = chunk_text_by_tokens(text, cuad_tokenizer, max_chunk_len=384, stride=128)
-
     try:
         clause_types = list(cuad_model.config.id2label.values())
     except Exception:
@@ -459,7 +428,6 @@
             "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
             "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
         ]
-
     clauses_detected = []
 
     for chunk in text_chunks:
@@ -467,26 +435,20 @@
         if not chunk:
             continue
         try:
-            # Tokenize the chunk again for the model
             tokenized_inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
+            # Move to GPU and clamp token IDs to ensure they are within valid range
             inputs = {k: v.to(device) for k, v in tokenized_inputs.items()}
-            # Check for invalid token IDs
+            inputs["input_ids"] = torch.clamp(inputs["input_ids"], max=cuad_model.config.vocab_size - 1)
             if torch.any(inputs["input_ids"] >= cuad_model.config.vocab_size):
                 print("Invalid token id found; skipping chunk")
                 continue
-
             with torch.no_grad():
                 outputs = cuad_model(**inputs)
-            # Force synchronization so that if there's a device error, we catch it here
             if device == "cuda":
                 torch.cuda.synchronize()
-
-            # Shape check
             if outputs.start_logits.shape[1] != inputs["input_ids"].shape[1]:
                 print("Mismatch in logits shape; skipping chunk")
                 continue
-
-            # For demonstration, we just apply a threshold to the start_logits
             predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
             for idx, confidence in enumerate(predictions):
                 if confidence > 0.5 and idx < len(clause_types):
@@ -494,21 +456,17 @@
                         "type": clause_types[idx],
                         "confidence": float(confidence)
                     })
-
         except Exception as e:
             print(f"Error processing chunk: {e}")
-            # Clear GPU cache if there's an error
             if device == "cuda":
                 torch.cuda.empty_cache()
             continue
 
-    # Aggregate clauses by their highest confidence
     aggregated_clauses = {}
     for clause in clauses_detected:
         ctype = clause["type"]
         if ctype not in aggregated_clauses or clause["confidence"] > aggregated_clauses[ctype]["confidence"]:
             aggregated_clauses[ctype] = clause
-
     return list(aggregated_clauses.values())
 
 #############################
@@ -517,24 +475,14 @@ def analyze_contract_clauses(text):
 
 @app.post("/analyze_legal_document")
 async def analyze_legal_document(file: UploadFile = File(...)):
-    """
-    Analyze a legal document (PDF). Extract text, summarize, detect entities,
-    do risk analysis, detect clauses, and store context for chat.
-    """
     try:
         content = await file.read()
         file_hash = compute_md5(content)
-
-        # Return cached result if we've already processed this file
         if file_hash in analysis_cache:
             return analysis_cache[file_hash]
-
-        # Extract text
         text = await run_in_threadpool(extract_text_from_pdf, io.BytesIO(content))
         if not text:
             return {"status": "error", "message": "No valid text found in the document."}
-
-        # Summarize (handle short documents gracefully)
         summary_text = text[:4096] if len(text) > 4096 else text
         try:
             if len(text) > 100:
@@ -544,20 +492,11 @@ async def analyze_legal_document(file: UploadFile = File(...)):
         except Exception as e:
             summary = "Summarization failed due to an error."
             print(f"Summarization error: {e}")
-
-        # Extract named entities
         entities = extract_named_entities(text)
-
-        # Analyze risk
         risk_analysis = analyze_risk_enhanced(text)
-
-        # Detect clauses
         clauses = analyze_contract_clauses(text)
-
-        # Store the document context for chatbot
         generated_task_id = str(uuid.uuid4())
         store_document_context(generated_task_id, text)
-
         result = {
             "status": "success",
             "task_id": generated_task_id,
@@ -566,46 +505,29 @@
             "risk_analysis": risk_analysis,
             "clauses_detected": clauses
         }
-
-        # Cache it
         analysis_cache[file_hash] = result
         return result
-
     except Exception as e:
         return {"status": "error", "message": str(e)}
 
 @app.post("/analyze_legal_video")
 async def analyze_legal_video(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
-    """
-    Analyze a legal video: transcribe, summarize, detect entities, risk analysis, etc.
-    """
     try:
         content = await file.read()
         file_hash = compute_md5(content)
         if file_hash in analysis_cache:
             return analysis_cache[file_hash]
-
-        # Save video temporarily
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
             temp_file.write(content)
             temp_file_path = temp_file.name
-
-        # Transcribe
         text = await process_video_to_text(temp_file_path)
-
-        # Cleanup
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
-
        if not text:
            return {"status": "error", "message": "No speech could be transcribed from the video."}
-
-        # Save transcript
        transcript_path = os.path.join("static", f"transcript_{int(time.time())}.txt")
        with open(transcript_path, "w") as f:
            f.write(text)
-
-        # Summarize
        summary_text = text[:4096] if len(text) > 4096 else text
        try:
            if len(text) > 100:
@@ -615,16 +537,11 @@ async def analyze_legal_video(file: UploadFile = File(...), background_tasks: Ba
         except Exception as e:
             summary = "Summarization failed due to an error."
             print(f"Summarization error: {e}")
-
-        # Entities, risk, clauses
         entities = extract_named_entities(text)
         risk_analysis = analyze_risk_enhanced(text)
         clauses = analyze_contract_clauses(text)
-
-        # Store context
         generated_task_id = str(uuid.uuid4())
         store_document_context(generated_task_id, text)
-
         result = {
             "status": "success",
             "task_id": generated_task_id,
@@ -642,36 +559,22 @@
 
 @app.post("/analyze_legal_audio")
 async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
-    """
-    Analyze an audio file: transcribe, summarize, detect entities, risk analysis, etc.
-    """
     try:
         content = await file.read()
         file_hash = compute_md5(content)
         if file_hash in analysis_cache:
             return analysis_cache[file_hash]
-
-        # Save audio temporarily
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
             temp_file.write(content)
             temp_file_path = temp_file.name
-
-        # Transcribe
         text = await process_audio_to_text(temp_file_path)
-
-        # Cleanup
         if os.path.exists(temp_file_path):
             os.remove(temp_file_path)
-
         if not text:
             return {"status": "error", "message": "No speech could be transcribed from the audio."}
-
-        # Save transcript
         transcript_path = os.path.join("static", f"transcript_{int(time.time())}.txt")
         with open(transcript_path, "w") as f:
             f.write(text)
-
-        # Summarize
         summary_text = text[:4096] if len(text) > 4096 else text
         try:
             if len(text) > 100:
@@ -681,16 +584,11 @@ async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: Ba
         except Exception as e:
             summary = "Summarization failed due to an error."
             print(f"Summarization error: {e}")
-
-        # Entities, risk, clauses
         entities = extract_named_entities(text)
         risk_analysis = analyze_risk_enhanced(text)
         clauses = analyze_contract_clauses(text)
-
-        # Store context
         generated_task_id = str(uuid.uuid4())
         store_document_context(generated_task_id, text)
-
         result = {
             "status": "success",
             "task_id": generated_task_id,
@@ -716,9 +614,6 @@ async def get_transcript(transcript_id: str):
 
 @app.post("/legal_chatbot")
 async def legal_chatbot_api(query: str = Form(...), task_id: str = Form(...)):
-    """
-    Simple QA pipeline on the stored document context.
-    """
     document_context = load_document_context(task_id)
     if not document_context:
         return {"response": "⚠️ No relevant document found for this task ID."}
@@ -762,7 +657,6 @@ def setup_ngrok():
         print(f"⚠️ Ngrok setup error: {e}")
         return None
 
-# Visualization endpoints
 @app.get("/download_clause_bar_chart")
 async def download_clause_bar_chart(task_id: str):
     try:
@@ -826,7 +720,6 @@ async def download_clause_radar_chart(task_id: str):
             raise HTTPException(status_code=404, detail="No clauses detected.")
         labels = [c["type"] for c in clauses]
         values = [c["confidence"] for c in clauses]
-        # close the loop for radar
         labels += labels[:1]
         values += values[:1]
         angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
@@ -854,4 +747,3 @@ if __name__ == "__main__":
     else:
         print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
     run()
-
 