tejash300 commited on
Commit
dd43ec8
·
verified ·
1 Parent(s): 4e897df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -76
app.py CHANGED
@@ -28,11 +28,11 @@ import hashlib # For caching file results
28
  # For asynchronous blocking calls
29
  from starlette.concurrency import run_in_threadpool
30
 
31
- # Import gensim for topic modeling
32
  import gensim
33
  from gensim import corpora, models
34
 
35
- # Import spacy stop words
36
  from spacy.lang.en.stop_words import STOP_WORDS
37
 
38
  # Global cache for analysis results based on file hash
@@ -43,19 +43,19 @@ try:
43
  from google.colab import drive
44
  drive.mount('/content/drive')
45
  except Exception:
46
- pass # Skip drive mount if not in Google Colab
47
 
48
- # Ensure required directories exist
49
  os.makedirs("static", exist_ok=True)
50
  os.makedirs("temp", exist_ok=True)
51
 
52
- # Ensure GPU usage
53
  device = "cuda" if torch.cuda.is_available() else "cpu"
54
 
55
- # Initialize FastAPI
56
  app = FastAPI(title="Legal Document and Video Analyzer")
57
 
58
- # Add CORS middleware
59
  app.add_middleware(
60
  CORSMiddleware,
61
  allow_origins=["*"],
@@ -64,31 +64,31 @@ app.add_middleware(
64
  allow_headers=["*"],
65
  )
66
 
67
- # In-memory storage for document text and chat history
68
  document_storage = {}
69
  chat_history = []
70
 
71
- # Function to store document context by task ID
72
  def store_document_context(task_id, text):
73
  document_storage[task_id] = text
74
  return True
75
 
76
- # Function to load document context by task ID
77
  def load_document_context(task_id):
78
  return document_storage.get(task_id, "")
79
 
80
- # Utility to compute MD5 hash from file content
81
  def compute_md5(content: bytes) -> str:
82
  return hashlib.md5(content).hexdigest()
83
 
84
  #############################
85
- # Fine-tuning on CUAD QA #
86
  #############################
87
 
88
  def fine_tune_cuad_model():
 
 
 
 
89
  from datasets import load_dataset
90
- import numpy as np
91
- from transformers import Trainer, TrainingArguments, AutoModelForQuestionAnswering
92
 
93
  print("✅ Loading CUAD dataset for fine tuning...")
94
  dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
@@ -144,7 +144,6 @@ def fine_tune_cuad_model():
144
  tokenized_end_index = len(input_ids) - 1
145
  while tokenized_end_index >= 0 and sequence_ids[tokenized_end_index] != 1:
146
  tokenized_end_index -= 1
147
- # Safety check: if indices are not found, default to cls_index
148
  if tokenized_start_index >= len(offsets) or tokenized_end_index < 0:
149
  tokenized_examples["start_positions"].append(cls_index)
150
  tokenized_examples["end_positions"].append(cls_index)
@@ -152,19 +151,16 @@ def fine_tune_cuad_model():
152
  tokenized_examples["start_positions"].append(cls_index)
153
  tokenized_examples["end_positions"].append(cls_index)
154
  else:
155
- # Move tokenized_start_index to the first token after start_char
156
  while tokenized_start_index < len(offsets) and offsets[tokenized_start_index][0] <= start_char:
157
  tokenized_start_index += 1
158
  safe_start = tokenized_start_index - 1 if tokenized_start_index > 0 else cls_index
159
  tokenized_examples["start_positions"].append(safe_start)
160
- # Move tokenized_end_index backwards to the last token before end_char
161
  while tokenized_end_index >= 0 and offsets[tokenized_end_index][1] >= end_char:
162
  tokenized_end_index -= 1
163
  safe_end = tokenized_end_index + 1 if tokenized_end_index < len(offsets) - 1 else cls_index
164
  tokenized_examples["end_positions"].append(safe_end)
165
  return tokenized_examples
166
 
167
- print("✅ Tokenizing dataset...")
168
  train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)
169
  val_dataset = val_dataset.map(prepare_train_features, batched=True, remove_columns=val_dataset.column_names)
170
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
@@ -205,57 +201,74 @@ def fine_tune_cuad_model():
205
  #############################
206
 
207
  try:
 
208
  try:
209
  nlp = spacy.load("en_core_web_sm")
210
  except Exception:
211
  spacy.cli.download("en_core_web_sm")
212
  nlp = spacy.load("en_core_web_sm")
213
- print("✅ Loading NLP models...")
214
 
215
- # Update summarizer to use facebook/bart-large-cnn for summarization
216
  summarizer = pipeline(
217
  "summarization",
218
  model="facebook/bart-large-cnn",
219
  tokenizer="facebook/bart-large-cnn",
220
- device=0 if torch.cuda.is_available() else -1
221
  )
222
- # Commenting out FP16 conversion to avoid potential issues
223
- # if device == "cuda":
224
- # try:
225
- # summarizer.model.half()
226
- # except Exception as e:
227
- # print("FP16 conversion failed:", e)
228
 
 
 
 
 
 
 
 
 
229
  embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
230
- ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if torch.cuda.is_available() else -1)
 
 
 
 
231
  speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
232
- device_map="auto" if torch.cuda.is_available() else "cpu")
 
 
233
  if os.path.exists("fine_tuned_legal_qa"):
234
  print("✅ Loading fine-tuned CUAD QA model from fine_tuned_legal_qa...")
235
  cuad_tokenizer = AutoTokenizer.from_pretrained("fine_tuned_legal_qa")
236
  from transformers import AutoModelForQuestionAnswering
237
  cuad_model = AutoModelForQuestionAnswering.from_pretrained("fine_tuned_legal_qa")
238
  cuad_model.to(device)
239
- # Commenting out FP16 conversion for cuad_model as well
240
- # if device == "cuda":
241
- # cuad_model.half()
242
  else:
243
- print("⚠️ Fine-tuned QA model not found. Starting fine tuning on CUAD QA dataset. This may take a while...")
244
  cuad_tokenizer, cuad_model = fine_tune_cuad_model()
245
  cuad_model.to(device)
246
- print("✅ All models loaded successfully")
 
 
 
 
 
 
 
 
247
  except Exception as e:
248
  print(f"⚠️ Error loading models: {str(e)}")
249
  raise RuntimeError(f"Error loading models: {str(e)}")
250
 
251
- from transformers import pipeline
252
- qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
253
- sentiment_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=0 if torch.cuda.is_available() else -1)
254
 
255
  def legal_chatbot(user_input, context):
256
  global chat_history
257
  chat_history.append({"role": "user", "content": user_input})
258
- response = qa_model(question=user_input, context=context)["answer"]
 
 
 
259
  chat_history.append({"role": "assistant", "content": response})
260
  return response
261
 
@@ -268,6 +281,9 @@ def extract_text_from_pdf(pdf_file):
268
  raise HTTPException(status_code=400, detail=f"PDF extraction failed: {str(e)}")
269
 
270
  async def process_video_to_text(video_file_path):
 
 
 
271
  try:
272
  print(f"Processing video file at {video_file_path}")
273
  temp_audio_path = os.path.join("temp", "extracted_audio.wav")
@@ -289,6 +305,9 @@ async def process_video_to_text(video_file_path):
289
  raise HTTPException(status_code=400, detail=f"Video processing failed: {str(e)}")
290
 
291
  async def process_audio_to_text(audio_file_path):
 
 
 
292
  try:
293
  print(f"Processing audio file at {audio_file_path}")
294
  result = await run_in_threadpool(speech_to_text, audio_file_path)
@@ -300,6 +319,9 @@ async def process_audio_to_text(audio_file_path):
300
  raise HTTPException(status_code=400, detail=f"Audio processing failed: {str(e)}")
301
 
302
  def extract_named_entities(text):
 
 
 
303
  max_length = 10000
304
  entities = []
305
  for i in range(0, len(text), max_length):
@@ -308,9 +330,9 @@ def extract_named_entities(text):
308
  entities.extend([{"entity": ent.text, "label": ent.label_} for ent in doc.ents])
309
  return entities
310
 
311
- # -----------------------------
312
- # Enhanced Risk Analysis Functions
313
- # -----------------------------
314
 
315
  def analyze_sentiment(text):
316
  sentences = [sent.text for sent in nlp(text).sents]
@@ -337,11 +359,9 @@ def get_enhanced_context_info(text):
337
  enhanced["topics"] = analyze_topics(text, num_topics=5)
338
  return enhanced
339
 
340
- # New function to create a detailed, dynamic explanation for each topic
341
  def explain_topics(topics):
342
  explanation = {}
343
  for topic_idx, topic_str in topics:
344
- # Split topic string into individual weighted terms
345
  parts = topic_str.split('+')
346
  terms = []
347
  for part in parts:
@@ -353,22 +373,23 @@ def explain_topics(topics):
353
  weight = float(weight_str)
354
  except:
355
  weight = 0.0
356
- # Filter out common stop words
357
  if word.lower() not in STOP_WORDS and len(word) > 1:
358
  terms.append((weight, word))
359
  terms.sort(key=lambda x: -x[0])
360
- # Create a plain language label based on dominant words
361
  if terms:
362
- if any("liability" in word.lower() for weight, word in terms):
363
  label = "Liability & Penalty Risk"
364
- elif any("termination" in word.lower() for weight, word in terms):
365
  label = "Termination & Refund Risk"
366
- elif any("compliance" in word.lower() for weight, word in terms):
367
  label = "Compliance & Regulatory Risk"
368
  else:
369
  label = "General Risk Language"
370
  else:
371
  label = "General Risk Language"
 
372
  explanation_text = (
373
  f"Topic {topic_idx} ({label}) is characterized by dominant terms: " +
374
  ", ".join([f"'{word}' ({weight:.3f})" for weight, word in terms[:5]])
@@ -393,10 +414,42 @@ def analyze_risk_enhanced(text):
393
  "topics_explanation": topics_explanation
394
  }
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  def analyze_contract_clauses(text):
397
- max_length = 512
398
- step = 256
399
- clauses_detected = []
 
 
 
 
400
  try:
401
  clause_types = list(cuad_model.config.id2label.values())
402
  except Exception:
@@ -406,60 +459,105 @@ def analyze_contract_clauses(text):
406
  "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
407
  "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
408
  ]
409
- # Process text in chunks of 'max_length' with a step size 'step'
410
- for i in range(0, len(text), step):
411
- chunk = text[i:i+max_length]
412
- if not chunk.strip():
413
- continue # Skip empty chunks
 
 
414
  try:
 
415
  tokenized_inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
416
  inputs = {k: v.to(device) for k, v in tokenized_inputs.items()}
417
- # Check that token IDs are within vocabulary bounds
418
- max_token = inputs["input_ids"].max().item()
419
- if max_token >= cuad_model.config.vocab_size:
420
- print(f"Skipping chunk due to invalid token id: {max_token}")
421
  continue
 
422
  with torch.no_grad():
423
  outputs = cuad_model(**inputs)
424
- # Optional: verify shape consistency
 
 
 
 
425
  if outputs.start_logits.shape[1] != inputs["input_ids"].shape[1]:
426
- print("Mismatch in logits shape, skipping chunk")
427
  continue
 
 
428
  predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
429
  for idx, confidence in enumerate(predictions):
430
  if confidence > 0.5 and idx < len(clause_types):
431
- clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
 
 
 
 
432
  except Exception as e:
433
  print(f"Error processing chunk: {e}")
 
 
 
434
  continue
 
 
435
  aggregated_clauses = {}
436
  for clause in clauses_detected:
437
- clause_type = clause["type"]
438
- if clause_type not in aggregated_clauses or clause["confidence"] > aggregated_clauses[clause_type]["confidence"]:
439
- aggregated_clauses[clause_type] = clause
 
440
  return list(aggregated_clauses.values())
441
 
442
- # -----------------------------
443
- # Endpoints
444
- # -----------------------------
445
 
446
  @app.post("/analyze_legal_document")
447
  async def analyze_legal_document(file: UploadFile = File(...)):
 
 
 
 
448
  try:
449
  content = await file.read()
450
  file_hash = compute_md5(content)
 
 
451
  if file_hash in analysis_cache:
452
  return analysis_cache[file_hash]
 
 
453
  text = await run_in_threadpool(extract_text_from_pdf, io.BytesIO(content))
454
  if not text:
455
  return {"status": "error", "message": "No valid text found in the document."}
 
 
456
  summary_text = text[:4096] if len(text) > 4096 else text
457
- summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Document too short for meaningful summarization."
 
 
 
 
 
 
 
 
 
458
  entities = extract_named_entities(text)
 
 
459
  risk_analysis = analyze_risk_enhanced(text)
 
 
460
  clauses = analyze_contract_clauses(text)
 
 
461
  generated_task_id = str(uuid.uuid4())
462
  store_document_context(generated_task_id, text)
 
463
  result = {
464
  "status": "success",
465
  "task_id": generated_task_id,
@@ -468,36 +566,65 @@ async def analyze_legal_document(file: UploadFile = File(...)):
468
  "risk_analysis": risk_analysis,
469
  "clauses_detected": clauses
470
  }
 
 
471
  analysis_cache[file_hash] = result
472
  return result
 
473
  except Exception as e:
474
  return {"status": "error", "message": str(e)}
475
 
476
  @app.post("/analyze_legal_video")
477
  async def analyze_legal_video(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
 
 
 
478
  try:
479
  content = await file.read()
480
  file_hash = compute_md5(content)
481
  if file_hash in analysis_cache:
482
  return analysis_cache[file_hash]
 
 
483
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
484
  temp_file.write(content)
485
  temp_file_path = temp_file.name
 
 
486
  text = await process_video_to_text(temp_file_path)
 
 
487
  if os.path.exists(temp_file_path):
488
  os.remove(temp_file_path)
 
489
  if not text:
490
  return {"status": "error", "message": "No speech could be transcribed from the video."}
 
 
491
  transcript_path = os.path.join("static", f"transcript_{int(time.time())}.txt")
492
  with open(transcript_path, "w") as f:
493
  f.write(text)
 
 
494
  summary_text = text[:4096] if len(text) > 4096 else text
495
- summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
 
 
 
 
 
 
 
 
 
496
  entities = extract_named_entities(text)
497
  risk_analysis = analyze_risk_enhanced(text)
498
  clauses = analyze_contract_clauses(text)
 
 
499
  generated_task_id = str(uuid.uuid4())
500
  store_document_context(generated_task_id, text)
 
501
  result = {
502
  "status": "success",
503
  "task_id": generated_task_id,
@@ -515,29 +642,55 @@ async def analyze_legal_video(file: UploadFile = File(...), background_tasks: Ba
515
 
516
  @app.post("/analyze_legal_audio")
517
  async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
 
 
 
518
  try:
519
  content = await file.read()
520
  file_hash = compute_md5(content)
521
  if file_hash in analysis_cache:
522
  return analysis_cache[file_hash]
 
 
523
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
524
  temp_file.write(content)
525
  temp_file_path = temp_file.name
 
 
526
  text = await process_audio_to_text(temp_file_path)
 
 
527
  if os.path.exists(temp_file_path):
528
  os.remove(temp_file_path)
 
529
  if not text:
530
  return {"status": "error", "message": "No speech could be transcribed from the audio."}
 
 
531
  transcript_path = os.path.join("static", f"transcript_{int(time.time())}.txt")
532
  with open(transcript_path, "w") as f:
533
  f.write(text)
 
 
534
  summary_text = text[:4096] if len(text) > 4096 else text
535
- summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
 
 
 
 
 
 
 
 
 
536
  entities = extract_named_entities(text)
537
  risk_analysis = analyze_risk_enhanced(text)
538
  clauses = analyze_contract_clauses(text)
 
 
539
  generated_task_id = str(uuid.uuid4())
540
  store_document_context(generated_task_id, text)
 
541
  result = {
542
  "status": "success",
543
  "task_id": generated_task_id,
@@ -563,6 +716,9 @@ async def get_transcript(transcript_id: str):
563
 
564
  @app.post("/legal_chatbot")
565
  async def legal_chatbot_api(query: str = Form(...), task_id: str = Form(...)):
 
 
 
566
  document_context = load_document_context(task_id)
567
  if not document_context:
568
  return {"response": "⚠️ No relevant document found for this task ID."}
@@ -606,10 +762,7 @@ def setup_ngrok():
606
  print(f"⚠️ Ngrok setup error: {e}")
607
  return None
608
 
609
- # ------------------------------
610
- # Clause Visualization Endpoints
611
- # ------------------------------
612
-
613
  @app.get("/download_clause_bar_chart")
614
  async def download_clause_bar_chart(task_id: str):
615
  try:
@@ -673,6 +826,7 @@ async def download_clause_radar_chart(task_id: str):
673
  raise HTTPException(status_code=404, detail="No clauses detected.")
674
  labels = [c["type"] for c in clauses]
675
  values = [c["confidence"] for c in clauses]
 
676
  labels += labels[:1]
677
  values += values[:1]
678
  angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
@@ -700,3 +854,4 @@ if __name__ == "__main__":
700
  else:
701
  print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
702
  run()
 
 
28
  # For asynchronous blocking calls
29
  from starlette.concurrency import run_in_threadpool
30
 
31
+ # Gensim for topic modeling
32
  import gensim
33
  from gensim import corpora, models
34
 
35
+ # Spacy stop words
36
  from spacy.lang.en.stop_words import STOP_WORDS
37
 
38
  # Global cache for analysis results based on file hash
 
43
  from google.colab import drive
44
  drive.mount('/content/drive')
45
  except Exception:
46
+ pass # Not in Colab
47
 
48
+ # Make sure directories exist
49
  os.makedirs("static", exist_ok=True)
50
  os.makedirs("temp", exist_ok=True)
51
 
52
+ # Use GPU if available
53
  device = "cuda" if torch.cuda.is_available() else "cpu"
54
 
55
+ # FastAPI setup
56
  app = FastAPI(title="Legal Document and Video Analyzer")
57
 
58
+ # CORS
59
  app.add_middleware(
60
  CORSMiddleware,
61
  allow_origins=["*"],
 
64
  allow_headers=["*"],
65
  )
66
 
67
+ # In-memory storage
68
  document_storage = {}
69
  chat_history = []
70
 
 
71
  def store_document_context(task_id, text):
72
  document_storage[task_id] = text
73
  return True
74
 
 
75
  def load_document_context(task_id):
76
  return document_storage.get(task_id, "")
77
 
 
78
  def compute_md5(content: bytes) -> str:
79
  return hashlib.md5(content).hexdigest()
80
 
81
  #############################
82
+ # Fine-tuning on CUAD QA #
83
  #############################
84
 
85
  def fine_tune_cuad_model():
86
+ """
87
+ Minimal stub for fine-tuning the CUAD QA model.
88
+ If you have a full fine-tuning script, place it here.
89
+ """
90
  from datasets import load_dataset
91
+ from transformers import Trainer, TrainingArguments, AutoModelForQuestionAnswering, AutoTokenizer
 
92
 
93
  print("✅ Loading CUAD dataset for fine tuning...")
94
  dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
 
144
  tokenized_end_index = len(input_ids) - 1
145
  while tokenized_end_index >= 0 and sequence_ids[tokenized_end_index] != 1:
146
  tokenized_end_index -= 1
 
147
  if tokenized_start_index >= len(offsets) or tokenized_end_index < 0:
148
  tokenized_examples["start_positions"].append(cls_index)
149
  tokenized_examples["end_positions"].append(cls_index)
 
151
  tokenized_examples["start_positions"].append(cls_index)
152
  tokenized_examples["end_positions"].append(cls_index)
153
  else:
 
154
  while tokenized_start_index < len(offsets) and offsets[tokenized_start_index][0] <= start_char:
155
  tokenized_start_index += 1
156
  safe_start = tokenized_start_index - 1 if tokenized_start_index > 0 else cls_index
157
  tokenized_examples["start_positions"].append(safe_start)
 
158
  while tokenized_end_index >= 0 and offsets[tokenized_end_index][1] >= end_char:
159
  tokenized_end_index -= 1
160
  safe_end = tokenized_end_index + 1 if tokenized_end_index < len(offsets) - 1 else cls_index
161
  tokenized_examples["end_positions"].append(safe_end)
162
  return tokenized_examples
163
 
 
164
  train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)
165
  val_dataset = val_dataset.map(prepare_train_features, batched=True, remove_columns=val_dataset.column_names)
166
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
 
201
  #############################
202
 
203
  try:
204
+ # Load spacy
205
  try:
206
  nlp = spacy.load("en_core_web_sm")
207
  except Exception:
208
  spacy.cli.download("en_core_web_sm")
209
  nlp = spacy.load("en_core_web_sm")
210
+ print("✅ Loaded spaCy model.")
211
 
212
+ # Summarizer (GPU)
213
  summarizer = pipeline(
214
  "summarization",
215
  model="facebook/bart-large-cnn",
216
  tokenizer="facebook/bart-large-cnn",
217
+ device=0 if device == "cuda" else -1
218
  )
 
 
 
 
 
 
219
 
220
+ # QA pipeline (GPU)
221
+ qa_model = pipeline(
222
+ "question-answering",
223
+ model="deepset/roberta-base-squad2",
224
+ device=0 if device == "cuda" else -1
225
+ )
226
+
227
+ # Embeddings (GPU if available)
228
  embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
229
+
230
+ # Named Entity Recognition (GPU)
231
+ ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if device == "cuda" else -1)
232
+
233
+ # Speech-to-text (GPU if available via device_map="auto")
234
  speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
235
+ device_map="auto" if device == "cuda" else None)
236
+
237
+ # Fine-tuned CUAD QA
238
  if os.path.exists("fine_tuned_legal_qa"):
239
  print("✅ Loading fine-tuned CUAD QA model from fine_tuned_legal_qa...")
240
  cuad_tokenizer = AutoTokenizer.from_pretrained("fine_tuned_legal_qa")
241
  from transformers import AutoModelForQuestionAnswering
242
  cuad_model = AutoModelForQuestionAnswering.from_pretrained("fine_tuned_legal_qa")
243
  cuad_model.to(device)
 
 
 
244
  else:
245
+ print("⚠️ Fine-tuned QA model not found. Fine-tuning now (this may be slow).")
246
  cuad_tokenizer, cuad_model = fine_tune_cuad_model()
247
  cuad_model.to(device)
248
+
249
+ # Sentiment (GPU)
250
+ sentiment_pipeline = pipeline(
251
+ "sentiment-analysis",
252
+ model="distilbert-base-uncased-finetuned-sst-2-english",
253
+ device=0 if device == "cuda" else -1
254
+ )
255
+
256
+ print("✅ All models loaded successfully.")
257
  except Exception as e:
258
  print(f"⚠️ Error loading models: {str(e)}")
259
  raise RuntimeError(f"Error loading models: {str(e)}")
260
 
261
+ #############################
262
+ # Helper Functions #
263
+ #############################
264
 
265
  def legal_chatbot(user_input, context):
266
  global chat_history
267
  chat_history.append({"role": "user", "content": user_input})
268
+ try:
269
+ response = qa_model(question=user_input, context=context)["answer"]
270
+ except Exception as e:
271
+ response = f"Error processing query: {e}"
272
  chat_history.append({"role": "assistant", "content": response})
273
  return response
274
 
 
281
  raise HTTPException(status_code=400, detail=f"PDF extraction failed: {str(e)}")
282
 
283
  async def process_video_to_text(video_file_path):
284
+ """
285
+ Extracts audio from video and runs speech-to-text.
286
+ """
287
  try:
288
  print(f"Processing video file at {video_file_path}")
289
  temp_audio_path = os.path.join("temp", "extracted_audio.wav")
 
305
  raise HTTPException(status_code=400, detail=f"Video processing failed: {str(e)}")
306
 
307
  async def process_audio_to_text(audio_file_path):
308
+ """
309
+ Runs speech-to-text on an audio file.
310
+ """
311
  try:
312
  print(f"Processing audio file at {audio_file_path}")
313
  result = await run_in_threadpool(speech_to_text, audio_file_path)
 
319
  raise HTTPException(status_code=400, detail=f"Audio processing failed: {str(e)}")
320
 
321
  def extract_named_entities(text):
322
+ """
323
+ Splits text into manageable chunks, runs spaCy for entity extraction.
324
+ """
325
  max_length = 10000
326
  entities = []
327
  for i in range(0, len(text), max_length):
 
330
  entities.extend([{"entity": ent.text, "label": ent.label_} for ent in doc.ents])
331
  return entities
332
 
333
+ #############################
334
+ # Risk & Topic Analysis #
335
+ #############################
336
 
337
  def analyze_sentiment(text):
338
  sentences = [sent.text for sent in nlp(text).sents]
 
359
  enhanced["topics"] = analyze_topics(text, num_topics=5)
360
  return enhanced
361
 
 
362
  def explain_topics(topics):
363
  explanation = {}
364
  for topic_idx, topic_str in topics:
 
365
  parts = topic_str.split('+')
366
  terms = []
367
  for part in parts:
 
373
  weight = float(weight_str)
374
  except:
375
  weight = 0.0
376
+ # Filter out short words & stop words
377
  if word.lower() not in STOP_WORDS and len(word) > 1:
378
  terms.append((weight, word))
379
  terms.sort(key=lambda x: -x[0])
380
+ # Heuristic labeling
381
  if terms:
382
+ if any("liability" in w.lower() for _, w in terms):
383
  label = "Liability & Penalty Risk"
384
+ elif any("termination" in w.lower() for _, w in terms):
385
  label = "Termination & Refund Risk"
386
+ elif any("compliance" in w.lower() for _, w in terms):
387
  label = "Compliance & Regulatory Risk"
388
  else:
389
  label = "General Risk Language"
390
  else:
391
  label = "General Risk Language"
392
+
393
  explanation_text = (
394
  f"Topic {topic_idx} ({label}) is characterized by dominant terms: " +
395
  ", ".join([f"'{word}' ({weight:.3f})" for weight, word in terms[:5]])
 
414
  "topics_explanation": topics_explanation
415
  }
416
 
417
+ #############################
418
+ # Clause Detection (GPU) #
419
+ #############################
420
+
421
+ def chunk_text_by_tokens(text, tokenizer, max_chunk_len=384, stride=128):
422
+ """
423
+ Convert the entire text into tokens once, then create overlapping chunks
424
+ of up to `max_chunk_len` tokens with overlap `stride`.
425
+ """
426
+ # Encode text once
427
+ encoded = tokenizer(text, add_special_tokens=False)
428
+ input_ids = encoded["input_ids"]
429
+ # We'll create overlapping windows of tokens
430
+ chunks = []
431
+ idx = 0
432
+ while idx < len(input_ids):
433
+ end = idx + max_chunk_len
434
+ sub_ids = input_ids[idx:end]
435
+ # Convert back to text
436
+ chunk_text = tokenizer.decode(sub_ids, skip_special_tokens=True)
437
+ chunks.append(chunk_text)
438
+ if end >= len(input_ids):
439
+ break
440
+ idx = end - stride
441
+ if idx < 0:
442
+ idx = 0
443
+ return chunks
444
+
445
  def analyze_contract_clauses(text):
446
+ """
447
+ Token-based chunking to avoid partial tokens.
448
+ Each chunk is fed into the fine-tuned CUAD model on GPU.
449
+ """
450
+ # We'll break the text into chunks of up to 384 tokens, with a stride of 128
451
+ text_chunks = chunk_text_by_tokens(text, cuad_tokenizer, max_chunk_len=384, stride=128)
452
+
453
  try:
454
  clause_types = list(cuad_model.config.id2label.values())
455
  except Exception:
 
459
  "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
460
  "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
461
  ]
462
+
463
+ clauses_detected = []
464
+
465
+ for chunk in text_chunks:
466
+ chunk = chunk.strip()
467
+ if not chunk:
468
+ continue
469
  try:
470
+ # Tokenize the chunk again for the model
471
  tokenized_inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
472
  inputs = {k: v.to(device) for k, v in tokenized_inputs.items()}
473
+ # Check for invalid token IDs
474
+ if torch.any(inputs["input_ids"] >= cuad_model.config.vocab_size):
475
+ print("Invalid token id found; skipping chunk")
 
476
  continue
477
+
478
  with torch.no_grad():
479
  outputs = cuad_model(**inputs)
480
+ # Force synchronization so that if there's a device error, we catch it here
481
+ if device == "cuda":
482
+ torch.cuda.synchronize()
483
+
484
+ # Shape check
485
  if outputs.start_logits.shape[1] != inputs["input_ids"].shape[1]:
486
+ print("Mismatch in logits shape; skipping chunk")
487
  continue
488
+
489
+ # For demonstration, we just apply a threshold to the start_logits
490
  predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
491
  for idx, confidence in enumerate(predictions):
492
  if confidence > 0.5 and idx < len(clause_types):
493
+ clauses_detected.append({
494
+ "type": clause_types[idx],
495
+ "confidence": float(confidence)
496
+ })
497
+
498
  except Exception as e:
499
  print(f"Error processing chunk: {e}")
500
+ # Clear GPU cache if there's an error
501
+ if device == "cuda":
502
+ torch.cuda.empty_cache()
503
  continue
504
+
505
+ # Aggregate clauses by their highest confidence
506
  aggregated_clauses = {}
507
  for clause in clauses_detected:
508
+ ctype = clause["type"]
509
+ if ctype not in aggregated_clauses or clause["confidence"] > aggregated_clauses[ctype]["confidence"]:
510
+ aggregated_clauses[ctype] = clause
511
+
512
  return list(aggregated_clauses.values())
513
 
514
+ #############################
515
+ # Endpoints #
516
+ #############################
517
 
518
  @app.post("/analyze_legal_document")
519
  async def analyze_legal_document(file: UploadFile = File(...)):
520
+ """
521
+ Analyze a legal document (PDF). Extract text, summarize, detect entities,
522
+ do risk analysis, detect clauses, and store context for chat.
523
+ """
524
  try:
525
  content = await file.read()
526
  file_hash = compute_md5(content)
527
+
528
+ # Return cached result if we've already processed this file
529
  if file_hash in analysis_cache:
530
  return analysis_cache[file_hash]
531
+
532
+ # Extract text
533
  text = await run_in_threadpool(extract_text_from_pdf, io.BytesIO(content))
534
  if not text:
535
  return {"status": "error", "message": "No valid text found in the document."}
536
+
537
+ # Summarize (handle short documents gracefully)
538
  summary_text = text[:4096] if len(text) > 4096 else text
539
+ try:
540
+ if len(text) > 100:
541
+ summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text']
542
+ else:
543
+ summary = "Document too short for a meaningful summary."
544
+ except Exception as e:
545
+ summary = "Summarization failed due to an error."
546
+ print(f"Summarization error: {e}")
547
+
548
+ # Extract named entities
549
  entities = extract_named_entities(text)
550
+
551
+ # Analyze risk
552
  risk_analysis = analyze_risk_enhanced(text)
553
+
554
+ # Detect clauses
555
  clauses = analyze_contract_clauses(text)
556
+
557
+ # Store the document context for chatbot
558
  generated_task_id = str(uuid.uuid4())
559
  store_document_context(generated_task_id, text)
560
+
561
  result = {
562
  "status": "success",
563
  "task_id": generated_task_id,
 
566
  "risk_analysis": risk_analysis,
567
  "clauses_detected": clauses
568
  }
569
+
570
+ # Cache it
571
  analysis_cache[file_hash] = result
572
  return result
573
+
574
  except Exception as e:
575
  return {"status": "error", "message": str(e)}
576
 
577
  @app.post("/analyze_legal_video")
578
  async def analyze_legal_video(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
579
+ """
580
+ Analyze a legal video: transcribe, summarize, detect entities, risk analysis, etc.
581
+ """
582
  try:
583
  content = await file.read()
584
  file_hash = compute_md5(content)
585
  if file_hash in analysis_cache:
586
  return analysis_cache[file_hash]
587
+
588
+ # Save video temporarily
589
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
590
  temp_file.write(content)
591
  temp_file_path = temp_file.name
592
+
593
+ # Transcribe
594
  text = await process_video_to_text(temp_file_path)
595
+
596
+ # Cleanup
597
  if os.path.exists(temp_file_path):
598
  os.remove(temp_file_path)
599
+
600
  if not text:
601
  return {"status": "error", "message": "No speech could be transcribed from the video."}
602
+
603
+ # Save transcript
604
  transcript_path = os.path.join("static", f"transcript_{int(time.time())}.txt")
605
  with open(transcript_path, "w") as f:
606
  f.write(text)
607
+
608
+ # Summarize
609
  summary_text = text[:4096] if len(text) > 4096 else text
610
+ try:
611
+ if len(text) > 100:
612
+ summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text']
613
+ else:
614
+ summary = "Transcript too short for meaningful summarization."
615
+ except Exception as e:
616
+ summary = "Summarization failed due to an error."
617
+ print(f"Summarization error: {e}")
618
+
619
+ # Entities, risk, clauses
620
  entities = extract_named_entities(text)
621
  risk_analysis = analyze_risk_enhanced(text)
622
  clauses = analyze_contract_clauses(text)
623
+
624
+ # Store context
625
  generated_task_id = str(uuid.uuid4())
626
  store_document_context(generated_task_id, text)
627
+
628
  result = {
629
  "status": "success",
630
  "task_id": generated_task_id,
 
642
 
643
  @app.post("/analyze_legal_audio")
644
  async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
645
+ """
646
+ Analyze an audio file: transcribe, summarize, detect entities, risk analysis, etc.
647
+ """
648
  try:
649
  content = await file.read()
650
  file_hash = compute_md5(content)
651
  if file_hash in analysis_cache:
652
  return analysis_cache[file_hash]
653
+
654
+ # Save audio temporarily
655
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
656
  temp_file.write(content)
657
  temp_file_path = temp_file.name
658
+
659
+ # Transcribe
660
  text = await process_audio_to_text(temp_file_path)
661
+
662
+ # Cleanup
663
  if os.path.exists(temp_file_path):
664
  os.remove(temp_file_path)
665
+
666
  if not text:
667
  return {"status": "error", "message": "No speech could be transcribed from the audio."}
668
+
669
+ # Save transcript
670
  transcript_path = os.path.join("static", f"transcript_{int(time.time())}.txt")
671
  with open(transcript_path, "w") as f:
672
  f.write(text)
673
+
674
+ # Summarize
675
  summary_text = text[:4096] if len(text) > 4096 else text
676
+ try:
677
+ if len(text) > 100:
678
+ summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text']
679
+ else:
680
+ summary = "Transcript too short for meaningful summarization."
681
+ except Exception as e:
682
+ summary = "Summarization failed due to an error."
683
+ print(f"Summarization error: {e}")
684
+
685
+ # Entities, risk, clauses
686
  entities = extract_named_entities(text)
687
  risk_analysis = analyze_risk_enhanced(text)
688
  clauses = analyze_contract_clauses(text)
689
+
690
+ # Store context
691
  generated_task_id = str(uuid.uuid4())
692
  store_document_context(generated_task_id, text)
693
+
694
  result = {
695
  "status": "success",
696
  "task_id": generated_task_id,
 
716
 
717
  @app.post("/legal_chatbot")
718
  async def legal_chatbot_api(query: str = Form(...), task_id: str = Form(...)):
719
+ """
720
+ Simple QA pipeline on the stored document context.
721
+ """
722
  document_context = load_document_context(task_id)
723
  if not document_context:
724
  return {"response": "⚠️ No relevant document found for this task ID."}
 
762
  print(f"⚠️ Ngrok setup error: {e}")
763
  return None
764
 
765
+ # Visualization endpoints
 
 
 
766
  @app.get("/download_clause_bar_chart")
767
  async def download_clause_bar_chart(task_id: str):
768
  try:
 
826
  raise HTTPException(status_code=404, detail="No clauses detected.")
827
  labels = [c["type"] for c in clauses]
828
  values = [c["confidence"] for c in clauses]
829
+ # close the loop for radar
830
  labels += labels[:1]
831
  values += values[:1]
832
  angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
 
854
  else:
855
  print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
856
  run()
857
+