tejash300 committed on
Commit 57ab60a · verified · 1 Parent(s): e3b69f0

Update app.py

Files changed (1):
app.py +154 -252
app.py CHANGED
@@ -1,6 +1,5 @@
  import os
  os.environ["TRANSFORMERS_NO_FAST"] = "1"  # Force use of slow tokenizers
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

  import io
  import torch
@@ -14,7 +13,7 @@ import numpy as np
  import json
  import tempfile
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form, BackgroundTasks
- from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
+ from fastapi.responses import FileResponse, JSONResponse, HTMLResponse  # Added HTMLResponse
  from fastapi.middleware.cors import CORSMiddleware
  from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
  from sentence_transformers import SentenceTransformer
@@ -28,17 +27,14 @@ import hashlib  # For caching file results
  # For asynchronous blocking calls
  from starlette.concurrency import run_in_threadpool

- # Gensim for topic modeling
+ # Import gensim for topic modeling
  import gensim
  from gensim import corpora, models

- # Spacy stop words
- from spacy.lang.en.stop_words import STOP_WORDS
-
  # Global cache for analysis results based on file hash
  analysis_cache = {}

- # Ensure compatibility with Google Colab (if applicable)
+ # Ensure compatibility with Google Colab
  try:
      from google.colab import drive
      drive.mount('/content/drive')
@@ -49,7 +45,7 @@ except Exception:
  os.makedirs("static", exist_ok=True)
  os.makedirs("temp", exist_ok=True)

- # Use GPU if available
+ # Ensure GPU usage
  device = "cuda" if torch.cuda.is_available() else "cpu"

  # Initialize FastAPI
@@ -68,13 +64,16 @@ app.add_middleware(
  document_storage = {}
  chat_history = []

+ # Function to store document context by task ID
  def store_document_context(task_id, text):
      document_storage[task_id] = text
      return True

+ # Function to load document context by task ID
  def load_document_context(task_id):
      return document_storage.get(task_id, "")

+ # Utility to compute MD5 hash from file content
  def compute_md5(content: bytes) -> str:
      return hashlib.md5(content).hexdigest()
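The `compute_md5` helper is presumably what keys the `analysis_cache` declared near the top: hashing the uploaded bytes gives a stable fingerprint, so a re-uploaded file can reuse its earlier analysis. A minimal sketch of that caching pattern, with `analyze` as a hypothetical stand-in for the app's real pipeline:

import hashlib

analysis_cache = {}

def compute_md5(content: bytes) -> str:
    # Stable fingerprint of the raw upload, used as the cache key
    return hashlib.md5(content).hexdigest()

def cached_analysis(content: bytes, analyze):
    # 'analyze' is hypothetical here; the app would pass its document pipeline
    key = compute_md5(content)
    if key not in analysis_cache:
        analysis_cache[key] = analyze(content)
    return analysis_cache[key]

MD5 is fine for a cache key like this; it would only be unsuitable where an adversary must not be able to force collisions.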
@@ -84,6 +83,7 @@

  def fine_tune_cuad_model():
      from datasets import load_dataset
+     import numpy as np
      from transformers import Trainer, TrainingArguments, AutoModelForQuestionAnswering, AutoTokenizer

      print("✅ Loading CUAD dataset for fine tuning...")
@@ -121,10 +121,7 @@ def fine_tune_cuad_model():
          tokenized_examples["end_positions"] = []
          for i, offsets in enumerate(offset_mapping):
              input_ids = tokenized_examples["input_ids"][i]
-             try:
-                 cls_index = input_ids.index(tokenizer.cls_token_id)
-             except ValueError:
-                 cls_index = 0
+             cls_index = input_ids.index(tokenizer.cls_token_id)
              sequence_ids = tokenized_examples.sequence_ids(i)
              sample_index = sample_mapping[i]
              answers = examples["answers"][sample_index]
@@ -135,26 +132,21 @@ def fine_tune_cuad_model():
                  start_char = answers["answer_start"][0]
                  end_char = start_char + len(answers["text"][0])
                  tokenized_start_index = 0
-                 while tokenized_start_index < len(sequence_ids) and sequence_ids[tokenized_start_index] != 1:
+                 while sequence_ids[tokenized_start_index] != 1:
                      tokenized_start_index += 1
                  tokenized_end_index = len(input_ids) - 1
-                 while tokenized_end_index >= 0 and sequence_ids[tokenized_end_index] != 1:
+                 while sequence_ids[tokenized_end_index] != 1:
                      tokenized_end_index -= 1
-                 if tokenized_start_index >= len(offsets) or tokenized_end_index < 0:
-                     tokenized_examples["start_positions"].append(cls_index)
-                     tokenized_examples["end_positions"].append(cls_index)
-                 elif not (offsets[tokenized_start_index][0] <= start_char and offsets[tokenized_end_index][1] >= end_char):
+                 if not (offsets[tokenized_start_index][0] <= start_char and offsets[tokenized_end_index][1] >= end_char):
                      tokenized_examples["start_positions"].append(cls_index)
                      tokenized_examples["end_positions"].append(cls_index)
                  else:
                      while tokenized_start_index < len(offsets) and offsets[tokenized_start_index][0] <= start_char:
                          tokenized_start_index += 1
-                     safe_start = tokenized_start_index - 1 if tokenized_start_index > 0 else cls_index
-                     tokenized_examples["start_positions"].append(safe_start)
-                     while tokenized_end_index >= 0 and offsets[tokenized_end_index][1] >= end_char:
+                     tokenized_examples["start_positions"].append(tokenized_start_index - 1)
+                     while offsets[tokenized_end_index][1] >= end_char:
                          tokenized_end_index -= 1
-                     safe_end = tokenized_end_index + 1 if tokenized_end_index < len(offsets) - 1 else cls_index
-                     tokenized_examples["end_positions"].append(safe_end)
+                     tokenized_examples["end_positions"].append(tokenized_end_index + 1)
          return tokenized_examples

      print("✅ Tokenizing dataset...")
@@ -198,70 +190,53 @@
  #############################

  try:
-     # Load spaCy model
      try:
          nlp = spacy.load("en_core_web_sm")
      except Exception:
          spacy.cli.download("en_core_web_sm")
          nlp = spacy.load("en_core_web_sm")
-     print("✅ Loaded spaCy model.")
-
-     # Create summarizer and QA pipelines on GPU
+     print("✅ Loading NLP models...")
+     from transformers import PegasusTokenizer
      summarizer = pipeline(
          "summarization",
-         model="facebook/bart-large-cnn",
-         tokenizer="facebook/bart-large-cnn",
-         device=0 if device == "cuda" else -1
-     )
-     qa_model = pipeline(
-         "question-answering",
-         model="deepset/roberta-base-squad2",
-         device=0 if device == "cuda" else -1
+         model="nsi319/legal-pegasus",
+         tokenizer=PegasusTokenizer.from_pretrained("nsi319/legal-pegasus", use_fast=False),
+         device=0 if torch.cuda.is_available() else -1
      )
+     # Optionally convert summarizer model to FP16 for faster inference on GPU
+     if device == "cuda":
+         summarizer.model.half()

-     # Use GPU for sentence embeddings if available
      embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
-
-     ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if device == "cuda" else -1)
-
-     # Speech-to-text pipeline on GPU (if available)
+     ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if torch.cuda.is_available() else -1)
      speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
-                               device_map="auto" if device == "cuda" else None)
-
-     # Load or fine-tune the CUAD QA model and move to GPU
+                               device_map="auto" if torch.cuda.is_available() else "cpu")
      if os.path.exists("fine_tuned_legal_qa"):
          print("✅ Loading fine-tuned CUAD QA model from fine_tuned_legal_qa...")
          cuad_tokenizer = AutoTokenizer.from_pretrained("fine_tuned_legal_qa")
          from transformers import AutoModelForQuestionAnswering
          cuad_model = AutoModelForQuestionAnswering.from_pretrained("fine_tuned_legal_qa")
          cuad_model.to(device)
+         if device == "cuda":
+             cuad_model.half()
      else:
-         print("⚠️ Fine-tuned QA model not found. Fine-tuning now (this may take a while)...")
+         print("Fine-tuned QA model not found. Starting fine tuning on CUAD QA dataset. This may take a while...")
          cuad_tokenizer, cuad_model = fine_tune_cuad_model()
          cuad_model.to(device)
-
-     sentiment_pipeline = pipeline(
-         "sentiment-analysis",
-         model="distilbert-base-uncased-finetuned-sst-2-english",
-         device=0 if device == "cuda" else -1
-     )
-
-     print("✅ All models loaded successfully.")
+     print("✅ All models loaded successfully")
  except Exception as e:
-     print(f"⚠️ Error loading models: {str(e)}")
+     print(f"Error loading models: {str(e)}")
      raise RuntimeError(f"Error loading models: {str(e)}")

- #############################
- # Helper Functions #
- #############################
+ from transformers import pipeline
+ qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
+
+ sentiment_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=0 if torch.cuda.is_available() else -1)

  def legal_chatbot(user_input, context):
      global chat_history
      chat_history.append({"role": "user", "content": user_input})
-     try:
-         response = qa_model(question=user_input, context=context)["answer"]
-     except Exception as e:
-         response = f"Error processing query: {e}"
+     response = qa_model(question=user_input, context=context)["answer"]
      chat_history.append({"role": "assistant", "content": response})
      return response
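Two notes on the new model setup. `summarizer.model.half()` casts the weights to FP16, roughly halving GPU memory use and usually speeding inference on CUDA; gating it on `device == "cuda"` matters because FP16 is slow or unsupported on CPU. Also, `qa_model` is now built without a `device` argument, so it presumably runs on CPU (transformers pipelines default to device -1). A sketch of the gating pattern, assuming a CUDA-capable torch build:

import torch
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
qa = pipeline("question-answering", model="deepset/roberta-base-squad2",
              device=0 if device == "cuda" else -1)
if device == "cuda":
    # FP16 halves weight memory; leave FP32 on CPU where half() does not help
    qa.model.half()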
@@ -314,9 +289,9 @@ def extract_named_entities(text):
      entities.extend([{"entity": ent.text, "label": ent.label_} for ent in doc.ents])
      return entities

- #############################
- # Risk & Topic Analysis #
- #############################
+ # -----------------------------
+ # Enhanced Risk Analysis Functions
+ # -----------------------------

  def analyze_sentiment(text):
      sentences = [sent.text for sent in nlp(text).sents]
@@ -343,82 +318,20 @@ def get_enhanced_context_info(text):
      enhanced["topics"] = analyze_topics(text, num_topics=5)
      return enhanced

- def explain_topics(topics):
-     explanation = {}
-     for topic_idx, topic_str in topics:
-         parts = topic_str.split('+')
-         terms = []
-         for part in parts:
-             part = part.strip()
-             if '*' in part:
-                 weight_str, word = part.split('*', 1)
-                 word = word.strip().strip('\"').strip('\'')
-                 try:
-                     weight = float(weight_str)
-                 except:
-                     weight = 0.0
-                 if word.lower() not in STOP_WORDS and len(word) > 1:
-                     terms.append((weight, word))
-         terms.sort(key=lambda x: -x[0])
-         if terms:
-             if any("liability" in w.lower() for _, w in terms):
-                 label = "Liability & Penalty Risk"
-             elif any("termination" in w.lower() for _, w in terms):
-                 label = "Termination & Refund Risk"
-             elif any("compliance" in w.lower() for _, w in terms):
-                 label = "Compliance & Regulatory Risk"
-             else:
-                 label = "General Risk Language"
-         else:
-             label = "General Risk Language"
-
-         explanation_text = (
-             f"Topic {topic_idx} ({label}) is characterized by dominant terms: " +
-             ", ".join([f"'{word}' ({weight:.3f})" for weight, word in terms[:5]])
-         )
-         explanation[topic_idx] = {
-             "label": label,
-             "explanation": explanation_text,
-             "terms": terms
-         }
-     return explanation
-
  def analyze_risk_enhanced(text):
      enhanced = get_enhanced_context_info(text)
      avg_sentiment = enhanced["average_sentiment"]
      risk_score = abs(avg_sentiment) if avg_sentiment < 0 else 0
-     topics_raw = enhanced["topics"]
-     topics_explanation = explain_topics(topics_raw)
-     return {
-         "risk_score": risk_score,
-         "average_sentiment": avg_sentiment,
-         "topics": topics_raw,
-         "topics_explanation": topics_explanation
-     }
+     return {"risk_score": risk_score, "average_sentiment": avg_sentiment, "topics": enhanced["topics"]}

- #############################
- # Clause Detection (GPU) #
- #############################
+ # -----------------------------
+ # Clause Detection (Chunk-Based)
+ # -----------------------------

- def chunk_text_by_tokens(text, tokenizer, max_chunk_len=384, stride=128):
-     encoded = tokenizer(text, add_special_tokens=False)
-     input_ids = encoded["input_ids"]
-     chunks = []
-     idx = 0
-     while idx < len(input_ids):
-         end = idx + max_chunk_len
-         sub_ids = input_ids[idx:end]
-         chunk_text = tokenizer.decode(sub_ids, skip_special_tokens=True)
-         chunks.append(chunk_text)
-         if end >= len(input_ids):
-             break
-         idx = end - stride
-         if idx < 0:
-             idx = 0
-     return chunks
-
  def analyze_contract_clauses(text):
-     text_chunks = chunk_text_by_tokens(text, cuad_tokenizer, max_chunk_len=384, stride=128)
+     max_length = 512
+     step = 256
+     clauses_detected = []
      try:
          clause_types = list(cuad_model.config.id2label.values())
      except Exception:
@@ -428,50 +341,26 @@ def analyze_contract_clauses(text):
          "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
          "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
      ]
-     clauses_detected = []
-
-     for chunk in text_chunks:
-         chunk = chunk.strip()
-         if not chunk:
-             continue
-         try:
-             tokenized_inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
-             # Move to GPU and clamp token IDs to ensure they are within valid range
-             inputs = {k: v.to(device) for k, v in tokenized_inputs.items()}
-             inputs["input_ids"] = torch.clamp(inputs["input_ids"], max=cuad_model.config.vocab_size - 1)
-             if torch.any(inputs["input_ids"] >= cuad_model.config.vocab_size):
-                 print("Invalid token id found; skipping chunk")
-                 continue
-             with torch.no_grad():
-                 outputs = cuad_model(**inputs)
-             if device == "cuda":
-                 torch.cuda.synchronize()
-             if outputs.start_logits.shape[1] != inputs["input_ids"].shape[1]:
-                 print("Mismatch in logits shape; skipping chunk")
-                 continue
-             predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
-             for idx, confidence in enumerate(predictions):
-                 if confidence > 0.5 and idx < len(clause_types):
-                     clauses_detected.append({
-                         "type": clause_types[idx],
-                         "confidence": float(confidence)
-                     })
-         except Exception as e:
-             print(f"Error processing chunk: {e}")
-             if device == "cuda":
-                 torch.cuda.empty_cache()
-             continue
-
+     # Create chunks of the text
+     chunks = [text[i:i+max_length] for i in range(0, len(text), step) if i+step < len(text)]
+     for chunk in chunks:
+         inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512).to(device)
+         with torch.no_grad():
+             outputs = cuad_model(**inputs)
+         predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
+         for idx, confidence in enumerate(predictions):
+             if confidence > 0.5 and idx < len(clause_types):
+                 clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
      aggregated_clauses = {}
      for clause in clauses_detected:
-         ctype = clause["type"]
-         if ctype not in aggregated_clauses or clause["confidence"] > aggregated_clauses[ctype]["confidence"]:
-             aggregated_clauses[ctype] = clause
+         clause_type = clause["type"]
+         if clause_type not in aggregated_clauses or clause["confidence"] > aggregated_clauses[clause_type]["confidence"]:
+             aggregated_clauses[clause_type] = clause
      return list(aggregated_clauses.values())

- #############################
- # Endpoints #
- #############################
+ # -----------------------------
+ # Endpoints
+ # -----------------------------

  @app.post("/analyze_legal_document")
  async def analyze_legal_document(file: UploadFile = File(...)):
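The rewritten `analyze_contract_clauses` slices raw characters rather than tokens: with `max_length = 512` and `step = 256`, consecutive windows overlap by half, and since 512 characters usually tokenize to far fewer than 512 tokens, each window under-fills the model's context. One edge case worth noting: the guard `if i+step < len(text)` yields no chunks at all for texts of `step` characters or fewer. A sketch of an overlap chunker that always emits at least one chunk, under the same character-based assumption:

def chunk_with_overlap(text: str, max_length: int = 512, step: int = 256):
    # Overlapping character windows; short texts still produce one chunk
    chunks = []
    for i in range(0, max(len(text) - max_length + step, 1), step):
        chunk = text[i:i + max_length]
        if chunk:
            chunks.append(chunk)
    return chunks

Separately, running `start_logits` through a sigmoid and reading the values as per-clause confidences repurposes a QA head as a multi-label classifier, which only makes sense if the fine-tuned head was actually trained that way.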
@@ -484,14 +373,7 @@ async def analyze_legal_document(file: UploadFile = File(...)):
      if not text:
          return {"status": "error", "message": "No valid text found in the document."}
      summary_text = text[:4096] if len(text) > 4096 else text
-     try:
-         if len(text) > 100:
-             summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text']
-         else:
-             summary = "Document too short for a meaningful summary."
-     except Exception as e:
-         summary = "Summarization failed due to an error."
-         print(f"Summarization error: {e}")
+     summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Document too short for meaningful summarization."
      entities = extract_named_entities(text)
      risk_analysis = analyze_risk_enhanced(text)
      clauses = analyze_contract_clauses(text)
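All three analysis endpoints (document, video, audio below) collapse the summarization try/except into a one-line conditional, so any summarizer failure now surfaces as an unhandled exception rather than the old fallback string. If the previous behavior is wanted without repeating the block three times, a small hypothetical helper would restore it:

def safe_summarize(summarizer, text: str, short_msg: str) -> str:
    # Hypothetical wrapper mirroring the removed per-endpoint try/except
    if len(text) <= 100:
        return short_msg
    try:
        snippet = text[:4096]
        return summarizer(snippet, max_length=200, min_length=50,
                          do_sample=False)[0]["summary_text"]
    except Exception as e:
        print(f"Summarization error: {e}")
        return "Summarization failed due to an error."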
@@ -529,14 +411,7 @@ async def analyze_legal_video(file: UploadFile = File(...), background_tasks: BackgroundTasks):
      with open(transcript_path, "w") as f:
          f.write(text)
      summary_text = text[:4096] if len(text) > 4096 else text
-     try:
-         if len(text) > 100:
-             summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text']
-         else:
-             summary = "Transcript too short for meaningful summarization."
-     except Exception as e:
-         summary = "Summarization failed due to an error."
-         print(f"Summarization error: {e}")
+     summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
      entities = extract_named_entities(text)
      risk_analysis = analyze_risk_enhanced(text)
      clauses = analyze_contract_clauses(text)
@@ -576,14 +451,7 @@ async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: BackgroundTasks):
      with open(transcript_path, "w") as f:
          f.write(text)
      summary_text = text[:4096] if len(text) > 4096 else text
-     try:
-         if len(text) > 100:
-             summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text']
-         else:
-             summary = "Transcript too short for meaningful summarization."
-     except Exception as e:
-         summary = "Summarization failed due to an error."
-         print(f"Summarization error: {e}")
+     summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
      entities = extract_named_entities(text)
      risk_analysis = analyze_risk_enhanced(text)
      clauses = analyze_contract_clauses(text)
@@ -616,7 +484,7 @@ async def get_transcript(transcript_id: str):
  async def legal_chatbot_api(query: str = Form(...), task_id: str = Form(...)):
      document_context = load_document_context(task_id)
      if not document_context:
-         return {"response": "⚠️ No relevant document found for this task ID."}
+         return {"response": "No relevant document found for this task ID."}
      response = legal_chatbot(query, document_context)
      return {"response": response, "chat_history": chat_history[-5:]}

@@ -646,95 +514,129 @@ def setup_ngrok():
          try:
              tunnels = ngrok.get_tunnels()
              if not tunnels:
-                 print("⚠️ Ngrok tunnel closed. Reconnecting...")
+                 print("Ngrok tunnel closed. Reconnecting...")
                  ngrok_tunnel = ngrok.connect(8500, "http")
                  print(f"✅ Reconnected. New URL: {ngrok_tunnel.public_url}")
          except Exception as e:
-             print(f"⚠️ Ngrok error: {e}")
+             print(f"Ngrok error: {e}")
      Thread(target=keep_alive, daemon=True).start()
      return public_url
  except Exception as e:
-     print(f"⚠️ Ngrok setup error: {e}")
+     print(f"Ngrok setup error: {e}")
      return None

- @app.get("/download_clause_bar_chart")
- async def download_clause_bar_chart(task_id: str):
+ # ------------------------------
+ # Dynamic Visualization Endpoints
+ # ------------------------------
+
+ @app.get("/download_risk_chart")
+ async def download_risk_chart(task_id: str):
      try:
          text = load_document_context(task_id)
          if not text:
              raise HTTPException(status_code=404, detail="Document context not found")
-         clauses = analyze_contract_clauses(text)
-         if not clauses:
-             raise HTTPException(status_code=404, detail="No clauses detected.")
-         clause_types = [c["type"] for c in clauses]
-         confidences = [c["confidence"] for c in clauses]
-         plt.figure(figsize=(10, 6))
-         plt.bar(clause_types, confidences, color='blue')
-         plt.xlabel("Clause Type")
-         plt.ylabel("Confidence Score")
-         plt.title("Extracted Legal Clause Confidence Scores")
-         plt.xticks(rotation=45, ha="right")
-         plt.tight_layout()
-         bar_chart_path = os.path.join("static", f"clause_bar_chart_{task_id}.png")
-         plt.savefig(bar_chart_path)
+         risk_analysis = analyze_risk_enhanced(text)
+         plt.figure(figsize=(8, 5))
+         plt.bar(["Risk Score"], [risk_analysis["risk_score"]], color='red')
+         plt.ylabel("Risk Score")
+         plt.title("Legal Risk Assessment (Enhanced)")
+         risk_chart_path = os.path.join("static", f"risk_chart_{task_id}.png")
+         plt.savefig(risk_chart_path)
          plt.close()
-         return FileResponse(bar_chart_path, media_type="image/png", filename=f"clause_bar_chart_{task_id}.png")
+         return FileResponse(risk_chart_path, media_type="image/png", filename=f"risk_chart_{task_id}.png")
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error generating clause bar chart: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error generating risk chart: {str(e)}")

- @app.get("/download_clause_donut_chart")
- async def download_clause_donut_chart(task_id: str):
+ @app.get("/download_risk_pie_chart")
+ async def download_risk_pie_chart(task_id: str):
      try:
          text = load_document_context(task_id)
          if not text:
              raise HTTPException(status_code=404, detail="Document context not found")
-         clauses = analyze_contract_clauses(text)
-         if not clauses:
-             raise HTTPException(status_code=404, detail="No clauses detected.")
-         from collections import Counter
-         clause_counter = Counter([c["type"] for c in clauses])
-         labels = list(clause_counter.keys())
-         sizes = list(clause_counter.values())
+         risk_analysis = analyze_risk_enhanced(text)
+         labels = ["Risk", "No Risk"]
+         risk_value = risk_analysis["risk_score"]
+         risk_value = min(max(risk_value, 0), 1)
+         values = [risk_value, 1 - risk_value]
          plt.figure(figsize=(6, 6))
-         wedges, texts, autotexts = plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90)
-         centre_circle = plt.Circle((0, 0), 0.70, fc='white')
-         fig = plt.gcf()
-         fig.gca().add_artist(centre_circle)
-         plt.title("Clause Type Distribution")
-         plt.tight_layout()
-         donut_chart_path = os.path.join("static", f"clause_donut_chart_{task_id}.png")
-         plt.savefig(donut_chart_path)
+         plt.pie(values, labels=labels, autopct='%1.1f%%', startangle=90)
+         plt.title("Legal Risk Distribution (Enhanced)")
+         pie_chart_path = os.path.join("static", f"risk_pie_chart_{task_id}.png")
+         plt.savefig(pie_chart_path)
          plt.close()
-         return FileResponse(donut_chart_path, media_type="image/png", filename=f"clause_donut_chart_{task_id}.png")
+         return FileResponse(pie_chart_path, media_type="image/png", filename=f"risk_pie_chart_{task_id}.png")
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error generating clause donut chart: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error generating pie chart: {str(e)}")

- @app.get("/download_clause_radar_chart")
- async def download_clause_radar_chart(task_id: str):
+ @app.get("/download_risk_radar_chart")
+ async def download_risk_radar_chart(task_id: str):
      try:
          text = load_document_context(task_id)
          if not text:
              raise HTTPException(status_code=404, detail="Document context not found")
-         clauses = analyze_contract_clauses(text)
-         if not clauses:
-             raise HTTPException(status_code=404, detail="No clauses detected.")
-         labels = [c["type"] for c in clauses]
-         values = [c["confidence"] for c in clauses]
-         labels += labels[:1]
+         risk_analysis = analyze_risk_enhanced(text)
+         categories = ["Average Sentiment", "Risk Score"]
+         values = [risk_analysis["average_sentiment"], risk_analysis["risk_score"]]
+         categories += categories[:1]
          values += values[:1]
-         angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
+         angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
          angles += angles[:1]
          fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
          ax.plot(angles, values, 'o-', linewidth=2)
          ax.fill(angles, values, alpha=0.25)
-         ax.set_thetagrids(np.degrees(angles[:-1]), labels[:-1])
-         ax.set_title("Legal Clause Radar Chart", y=1.1)
-         radar_chart_path = os.path.join("static", f"clause_radar_chart_{task_id}.png")
+         ax.set_thetagrids(np.degrees(angles[:-1]), ["Sentiment", "Risk"])
+         ax.set_title("Legal Risk Radar Chart (Enhanced)", y=1.1)
+         radar_chart_path = os.path.join("static", f"risk_radar_chart_{task_id}.png")
          plt.savefig(radar_chart_path)
          plt.close()
-         return FileResponse(radar_chart_path, media_type="image/png", filename=f"clause_radar_chart_{task_id}.png")
+         return FileResponse(radar_chart_path, media_type="image/png", filename=f"risk_radar_chart_{task_id}.png")
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error generating clause radar chart: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error generating radar chart: {str(e)}")
+
+ @app.get("/download_risk_trend_chart")
+ async def download_risk_trend_chart(task_id: str):
+     try:
+         text = load_document_context(task_id)
+         if not text:
+             raise HTTPException(status_code=404, detail="Document context not found")
+         words = text.split()
+         segments = np.array_split(words, 4)
+         segment_texts = [" ".join(segment) for segment in segments]
+         trend_scores = []
+         for segment in segment_texts:
+             risk = analyze_risk_enhanced(segment)
+             trend_scores.append(risk["risk_score"])
+         segments_labels = [f"Segment {i+1}" for i in range(len(segment_texts))]
+         plt.figure(figsize=(10, 6))
+         plt.plot(segments_labels, trend_scores, marker='o')
+         plt.xlabel("Document Segments")
+         plt.ylabel("Risk Score")
+         plt.title("Dynamic Legal Risk Trends (Enhanced)")
+         plt.xticks(rotation=45)
+         trend_chart_path = os.path.join("static", f"risk_trend_chart_{task_id}.png")
+         plt.savefig(trend_chart_path, bbox_inches="tight")
+         plt.close()
+         return FileResponse(trend_chart_path, media_type="image/png", filename=f"risk_trend_chart_{task_id}.png")
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error generating trend chart: {str(e)}")
+
+ @app.get("/interactive_risk_chart", response_class=HTMLResponse)
+ async def interactive_risk_chart(task_id: str):
+     try:
+         import pandas as pd
+         import plotly.express as px
+         text = load_document_context(task_id)
+         if not text:
+             raise HTTPException(status_code=404, detail="Document context not found")
+         risk_analysis = analyze_risk_enhanced(text)
+         df = pd.DataFrame({
+             "Metric": ["Average Sentiment", "Risk Score"],
+             "Value": [risk_analysis["average_sentiment"], risk_analysis["risk_score"]]
+         })
+         fig = px.bar(df, x="Metric", y="Value", title="Interactive Enhanced Legal Risk Assessment")
+         return fig.to_html()
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error generating interactive chart: {str(e)}")

  def run():
      print("Starting FastAPI server...")
@@ -745,5 +647,5 @@ if __name__ == "__main__":
      if public_url:
          print(f"\n✅ Your API is publicly available at: {public_url}/docs\n")
      else:
-         print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
+         print("\nNgrok setup failed. API will only be available locally.\n")
      run()