tejash300 committed
Commit 08be412 · verified · 1 Parent(s): 9e88a99

Update app.py

Files changed (1): app.py (+170, −259)
app.py CHANGED
@@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import json
 import tempfile
-from fastapi import FastAPI, UploadFile, File, HTTPException, Form
+from fastapi import FastAPI, UploadFile, File, HTTPException, Form, BackgroundTasks
 from fastapi.responses import FileResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
@@ -22,6 +22,17 @@ from threading import Thread
 import time
 import uuid
 import subprocess # For running ffmpeg commands
+import hashlib # For caching file results
+
+# For asynchronous blocking calls
+from starlette.concurrency import run_in_threadpool
+
+# Import gensim for topic modeling
+import gensim
+from gensim import corpora, models
+
+# Global cache for analysis results based on file hash
+analysis_cache = {}
 
 # Ensure compatibility with Google Colab
 try:
@@ -49,7 +60,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Initialize document storage and chat history
+# In-memory storage for document text and chat history
 document_storage = {}
 chat_history = []
 
@@ -62,16 +73,15 @@ def store_document_context(task_id, text):
 def load_document_context(task_id):
     return document_storage.get(task_id, "")
 
+# Utility to compute MD5 hash from file content
+def compute_md5(content: bytes) -> str:
+    return hashlib.md5(content).hexdigest()
+
 #############################
 # Fine-tuning on CUAD QA #
 #############################
 
 def fine_tune_cuad_model():
-    """
-    Fine tunes a QA model on the CUAD dataset for clause extraction.
-    For testing, we use only 50 training examples (and 10 for validation)
-    and restrict training to 1 step with evaluation disabled.
-    """
     from datasets import load_dataset
     import numpy as np
     from transformers import Trainer, TrainingArguments, AutoModelForQuestionAnswering
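The compute_md5 helper added above is the key generator for the new analysis_cache: each upload endpoint hashes the raw bytes and reuses the stored result when the same file is posted again. A minimal sketch of that flow outside FastAPI, with a placeholder analysis step standing in for the real summarization/NER/clause pipeline:

import hashlib

analysis_cache = {}

def compute_md5(content: bytes) -> str:
    return hashlib.md5(content).hexdigest()

def analyze_bytes(content: bytes) -> dict:
    # Placeholder for the real document analysis
    return {"status": "success", "length": len(content)}

def analyze_with_cache(content: bytes) -> dict:
    file_hash = compute_md5(content)      # identical bytes produce the same key
    if file_hash in analysis_cache:       # cache hit: skip the heavy work
        return analysis_cache[file_hash]
    result = analyze_bytes(content)
    analysis_cache[file_hash] = result    # remember the result for next time
    return result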
@@ -80,10 +90,8 @@ def fine_tune_cuad_model():
     dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
 
     if "train" in dataset:
-        # Use only 50 examples for training
         train_dataset = dataset["train"].select(range(50))
         if "validation" in dataset:
-            # Use 10 examples for validation
             val_dataset = dataset["validation"].select(range(10))
         else:
             split = train_dataset.train_test_split(test_size=0.2)
@@ -93,7 +101,6 @@ def fine_tune_cuad_model():
         raise ValueError("CUAD dataset does not have a train split")
 
     print("✅ Preparing training features...")
-
     tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
     model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
 
@@ -145,15 +152,13 @@ def fine_tune_cuad_model():
     print("✅ Tokenizing dataset...")
     train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)
     val_dataset = val_dataset.map(prepare_train_features, batched=True, remove_columns=val_dataset.column_names)
-
     train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
     val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
 
-    # Set max_steps to 1 for very fast testing and disable evaluation
     training_args = TrainingArguments(
         output_dir="./fine_tuned_legal_qa",
-        max_steps=1, # Only one training step
-        evaluation_strategy="no", # Disable evaluation during training
+        max_steps=1,
+        evaluation_strategy="no",
         learning_rate=2e-5,
         per_device_train_batch_size=4,
         per_device_eval_batch_size=4,
@@ -162,7 +167,7 @@ def fine_tune_cuad_model():
         logging_steps=1,
         save_steps=1,
         load_best_model_at_end=False,
-        report_to=[] # Disable wandb logging
+        report_to=[]
     )
 
     print("✅ Starting fine tuning on CUAD QA dataset...")
@@ -174,13 +179,10 @@ def fine_tune_cuad_model():
         eval_dataset=val_dataset,
         tokenizer=tokenizer,
     )
-
     trainer.train()
     print("✅ Fine tuning completed. Saving model...")
-
    model.save_pretrained("./fine_tuned_legal_qa")
    tokenizer.save_pretrained("./fine_tuned_legal_qa")
-
    return tokenizer, model
 
 #############################
@@ -194,8 +196,6 @@ try:
     spacy.cli.download("en_core_web_sm")
     nlp = spacy.load("en_core_web_sm")
     print("✅ Loading NLP models...")
-
-    # Use the slow PegasusTokenizer explicitly
     from transformers import PegasusTokenizer
     summarizer = pipeline(
         "summarization",
@@ -203,25 +203,27 @@ try:
         tokenizer=PegasusTokenizer.from_pretrained("nsi319/legal-pegasus", use_fast=False),
         device=0 if torch.cuda.is_available() else -1
     )
-
+    # Optionally convert summarizer model to FP16 for faster inference on GPU
+    if device == "cuda":
+        summarizer.model.half()
+
     embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
     ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if torch.cuda.is_available() else -1)
     speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
                               device_map="auto" if torch.cuda.is_available() else "cpu")
-
     if os.path.exists("fine_tuned_legal_qa"):
         print("✅ Loading fine-tuned CUAD QA model from fine_tuned_legal_qa...")
         cuad_tokenizer = AutoTokenizer.from_pretrained("fine_tuned_legal_qa")
         from transformers import AutoModelForQuestionAnswering
         cuad_model = AutoModelForQuestionAnswering.from_pretrained("fine_tuned_legal_qa")
         cuad_model.to(device)
+        if device == "cuda":
+            cuad_model.half()
     else:
         print("⚠️ Fine-tuned QA model not found. Starting fine tuning on CUAD QA dataset. This may take a while...")
         cuad_tokenizer, cuad_model = fine_tune_cuad_model()
         cuad_model.to(device)
-
     print("✅ All models loaded successfully")
-
 except Exception as e:
     print(f"⚠️ Error loading models: {str(e)}")
     raise RuntimeError(f"Error loading models: {str(e)}")
@@ -229,8 +231,10 @@ except Exception as e:
 from transformers import pipeline
 qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
 
+# Initialize sentiment-analysis pipeline
+sentiment_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=0 if torch.cuda.is_available() else -1)
+
 def legal_chatbot(user_input, context):
-    """Uses a real NLP model for legal Q&A."""
     global chat_history
     chat_history.append({"role": "user", "content": user_input})
     response = qa_model(question=user_input, context=context)["answer"]
@@ -238,7 +242,6 @@ def legal_chatbot(user_input, context):
     return response
 
 def extract_text_from_pdf(pdf_file):
-    """Extracts text from a PDF file using pdfplumber."""
     try:
         with pdfplumber.open(pdf_file) as pdf:
             text = "\n".join([page.extract_text() or "" for page in pdf.pages])
@@ -246,8 +249,7 @@ def extract_text_from_pdf(pdf_file):
     except Exception as e:
         raise HTTPException(status_code=400, detail=f"PDF extraction failed: {str(e)}")
 
-def process_video_to_text(video_file_path):
-    """Extracts audio from video using ffmpeg and converts to text."""
+async def process_video_to_text(video_file_path):
     try:
         print(f"Processing video file at {video_file_path}")
         temp_audio_path = os.path.join("temp", "extracted_audio.wav")
@@ -256,9 +258,11 @@ def process_video_to_text(video_file_path):
             "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2",
             temp_audio_path, "-y"
         ]
-        subprocess.run(cmd, check=True)
+        # Run ffmpeg in a separate thread
+        await run_in_threadpool(subprocess.run, cmd, check=True)
         print(f"Audio extracted to {temp_audio_path}")
-        result = speech_to_text(temp_audio_path)
+        # Run speech-to-text in threadpool
+        result = await run_in_threadpool(speech_to_text, temp_audio_path)
         transcript = result["text"]
         print(f"Transcription completed: {len(transcript)} characters")
         if os.path.exists(temp_audio_path):
@@ -268,11 +272,10 @@ def process_video_to_text(video_file_path):
         print(f"Error in video processing: {str(e)}")
         raise HTTPException(status_code=400, detail=f"Video processing failed: {str(e)}")
 
-def process_audio_to_text(audio_file_path):
-    """Processes an audio file and converts it to text."""
+async def process_audio_to_text(audio_file_path):
     try:
         print(f"Processing audio file at {audio_file_path}")
-        result = speech_to_text(audio_file_path)
+        result = await run_in_threadpool(speech_to_text, audio_file_path)
         transcript = result["text"]
         print(f"Transcription completed: {len(transcript)} characters")
         return transcript
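Both media helpers are now async and push their blocking work (ffmpeg via subprocess.run, and the Whisper pipeline) through Starlette's run_in_threadpool, which awaits the callable on a worker thread so the event loop stays responsive. A minimal sketch of the same pattern with a stand-in blocking function (the sleep is an assumption used purely for illustration):

import asyncio
import time

from starlette.concurrency import run_in_threadpool

def blocking_transcribe(path: str) -> dict:
    # Stands in for ffmpeg extraction plus Whisper inference
    time.sleep(1)
    return {"text": f"transcript of {path}"}

async def main():
    # The coroutine suspends here, but the loop can keep serving other requests
    result = await run_in_threadpool(blocking_transcribe, "extracted_audio.wav")
    print(result["text"])

asyncio.run(main())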
@@ -281,7 +284,6 @@ def process_audio_to_text(audio_file_path):
         raise HTTPException(status_code=400, detail=f"Audio processing failed: {str(e)}")
 
 def extract_named_entities(text):
-    """Extracts named entities from legal text."""
     max_length = 10000
     entities = []
     for i in range(0, len(text), max_length):
@@ -290,120 +292,48 @@ def extract_named_entities(text):
         entities.extend([{"entity": ent.text, "label": ent.label_} for ent in doc.ents])
     return entities
 
-def analyze_risk(text):
-    """Analyzes legal risk in the document using keyword-based analysis."""
-    risk_keywords = {
-        "Liability": ["liability", "responsible", "responsibility", "legal obligation"],
-        "Termination": ["termination", "breach", "contract end", "default"],
-        "Indemnification": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"],
-        "Payment Risk": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"],
-        "Insurance": ["insurance", "coverage", "policy", "claims"],
-    }
-    risk_scores = {category: 0 for category in risk_keywords}
-    lower_text = text.lower()
-    for category, keywords in risk_keywords.items():
-        for keyword in keywords:
-            risk_scores[category] += lower_text.count(keyword.lower())
-    return risk_scores
-
-def extract_context_for_risk_terms(text, risk_keywords, window=1):
-    """
-    Extracts and summarizes the context around risk terms.
-    """
-    doc = nlp(text)
-    sentences = list(doc.sents)
-    risk_contexts = {category: [] for category in risk_keywords}
-    for i, sent in enumerate(sentences):
-        sent_text_lower = sent.text.lower()
-        for category, details in risk_keywords.items():
-            for keyword in details["keywords"]:
-                if keyword.lower() in sent_text_lower:
-                    start_idx = max(0, i - window)
-                    end_idx = min(len(sentences), i + window + 1)
-                    context_chunk = " ".join([s.text for s in sentences[start_idx:end_idx]])
-                    risk_contexts[category].append(context_chunk)
-    summarized_contexts = {}
-    for category, contexts in risk_contexts.items():
-        if contexts:
-            combined_context = " ".join(contexts)
-            try:
-                summary_result = summarizer(combined_context, max_length=100, min_length=30, do_sample=False)
-                summary = summary_result[0]['summary_text']
-            except Exception as e:
-                summary = "Context summarization failed."
-            summarized_contexts[category] = summary
-        else:
-            summarized_contexts[category] = "No contextual details found."
-    return summarized_contexts
-
-def get_detailed_risk_info(text):
-    """
-    Returns detailed risk information by merging risk scores with descriptive details
-    and contextual summaries from the document.
-    """
-    risk_details = {
-        "Liability": {
-            "description": "Liability refers to the legal responsibility for losses or damages.",
-            "common_concerns": "Broad liability clauses may expose parties to unforeseen risks.",
-            "recommendations": "Review and negotiate clear limits on liability.",
-            "example": "E.g., 'The party shall be liable for direct damages due to negligence.'"
-        },
-        "Termination": {
-            "description": "Termination involves conditions under which a contract can be ended.",
-            "common_concerns": "Unilateral termination rights or ambiguous conditions can be risky.",
-            "recommendations": "Ensure termination clauses are balanced and include notice periods.",
-            "example": "E.g., 'Either party may terminate the agreement with 30 days notice.'"
-        },
-        "Indemnification": {
-            "description": "Indemnification requires one party to compensate for losses incurred by the other.",
-            "common_concerns": "Overly broad indemnification can shift significant risk.",
-            "recommendations": "Negotiate clear limits and carve-outs where necessary.",
-            "example": "E.g., 'The seller shall indemnify the buyer against claims from product defects.'"
-        },
-        "Payment Risk": {
-            "description": "Payment risk pertains to terms regarding fees, schedules, and reimbursements.",
-            "common_concerns": "Vague payment terms or hidden charges increase risk.",
-            "recommendations": "Clarify payment conditions and include penalties for delays.",
-            "example": "E.g., 'Payments must be made within 30 days, with a 2% late fee thereafter.'"
-        },
-        "Insurance": {
-            "description": "Insurance risk covers the adequacy and scope of required coverage.",
-            "common_concerns": "Insufficient insurance can leave parties exposed in unexpected events.",
-            "recommendations": "Review insurance requirements to ensure they meet the risk profile.",
-            "example": "E.g., 'The contractor must maintain liability insurance with at least $1M coverage.'"
-        }
-    }
-    risk_scores = analyze_risk(text)
-    risk_keywords_context = {
-        "Liability": {"keywords": ["liability", "responsible", "responsibility", "legal obligation"]},
-        "Termination": {"keywords": ["termination", "breach", "contract end", "default"]},
-        "Indemnification": {"keywords": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"]},
-        "Payment Risk": {"keywords": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"]},
-        "Insurance": {"keywords": ["insurance", "coverage", "policy", "claims"]}
-    }
-    risk_contexts = extract_context_for_risk_terms(text, risk_keywords_context, window=1)
-    detailed_info = {}
-    for risk_term, score in risk_scores.items():
-        if score > 0:
-            info = risk_details.get(risk_term, {"description": "No details available."})
-            detailed_info[risk_term] = {
-                "score": score,
-                "description": info.get("description", ""),
-                "common_concerns": info.get("common_concerns", ""),
-                "recommendations": info.get("recommendations", ""),
-                "example": info.get("example", ""),
-                "context_summary": risk_contexts.get(risk_term, "No context available.")
-            }
-    return detailed_info
+# -----------------------------
+# Enhanced Risk Analysis Functions
+# -----------------------------
+
+def analyze_sentiment(text):
+    sentences = [sent.text for sent in nlp(text).sents]
+    if not sentences:
+        return 0
+    results = sentiment_pipeline(sentences, batch_size=16)
+    scores = [res["score"] if res["label"] == "POSITIVE" else -res["score"] for res in results]
+    avg_sentiment = sum(scores) / len(scores) if scores else 0
+    return avg_sentiment
+
+def analyze_topics(text, num_topics=3):
+    tokens = gensim.utils.simple_preprocess(text, deacc=True)
+    if not tokens:
+        return []
+    dictionary = corpora.Dictionary([tokens])
+    corpus = [dictionary.doc2bow(tokens)]
+    lda_model = models.LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=10)
+    topics = lda_model.print_topics(num_topics=num_topics)
+    return topics
+
+def get_enhanced_context_info(text):
+    enhanced = {}
+    enhanced["average_sentiment"] = analyze_sentiment(text)
+    enhanced["topics"] = analyze_topics(text, num_topics=5)
+    return enhanced
+
+def analyze_risk_enhanced(text):
+    enhanced = get_enhanced_context_info(text)
+    avg_sentiment = enhanced["average_sentiment"]
+    risk_score = abs(avg_sentiment) if avg_sentiment < 0 else 0
+    return {"risk_score": risk_score, "average_sentiment": avg_sentiment, "topics": enhanced["topics"]}
 
 def analyze_contract_clauses(text):
-    """Analyzes contract clauses using the fine-tuned CUAD QA model."""
     max_length = 512
     step = 256
     clauses_detected = []
     try:
         clause_types = list(cuad_model.config.id2label.values())
-    except Exception as e:
+    except Exception:
         clause_types = [
             "Obligations of Seller", "Governing Law", "Termination", "Indemnification",
             "Confidentiality", "Insurance", "Non-Compete", "Change of Control",
@@ -426,50 +356,52 @@ def analyze_contract_clauses(text):
             aggregated_clauses[clause_type] = clause
     return list(aggregated_clauses.values())
 
+# -----------------------------
+# Endpoints
+# -----------------------------
+
 @app.post("/analyze_legal_document")
 async def analyze_legal_document(file: UploadFile = File(...)):
-    """Analyzes a legal document for clause detection and compliance risks."""
     try:
-        print(f"Processing file: {file.filename}")
         content = await file.read()
-        text = extract_text_from_pdf(io.BytesIO(content))
+        file_hash = compute_md5(content)
+        # Return cached result if available
+        if file_hash in analysis_cache:
+            return analysis_cache[file_hash]
+        text = await run_in_threadpool(extract_text_from_pdf, io.BytesIO(content))
         if not text:
             return {"status": "error", "message": "No valid text found in the document."}
         summary_text = text[:4096] if len(text) > 4096 else text
         summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Document too short for meaningful summarization."
-        print("Extracting named entities...")
         entities = extract_named_entities(text)
-        print("Analyzing risk...")
-        risk_scores = analyze_risk(text)
-        detailed_risk = get_detailed_risk_info(text)
-        print("Analyzing contract clauses...")
+        risk_analysis = analyze_risk_enhanced(text)
         clauses = analyze_contract_clauses(text)
         generated_task_id = str(uuid.uuid4())
         store_document_context(generated_task_id, text)
-        return {
+        result = {
             "status": "success",
             "task_id": generated_task_id,
             "summary": summary,
             "named_entities": entities,
-            "risk_scores": risk_scores,
-            "detailed_risk": detailed_risk,
+            "risk_analysis": risk_analysis,
             "clauses_detected": clauses
         }
+        analysis_cache[file_hash] = result
+        return result
     except Exception as e:
-        print(f"Error processing document: {str(e)}")
         return {"status": "error", "message": str(e)}
 
 @app.post("/analyze_legal_video")
-async def analyze_legal_video(file: UploadFile = File(...)):
-    """Analyzes a legal video by transcribing audio and analyzing the transcript."""
+async def analyze_legal_video(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
     try:
-        print(f"Processing video file: {file.filename}")
         content = await file.read()
+        file_hash = compute_md5(content)
+        if file_hash in analysis_cache:
+            return analysis_cache[file_hash]
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
             temp_file.write(content)
             temp_file_path = temp_file.name
-        print(f"Temporary file saved at: {temp_file_path}")
-        text = process_video_to_text(temp_file_path)
+        text = await process_video_to_text(temp_file_path)
         if os.path.exists(temp_file_path):
             os.remove(temp_file_path)
         if not text:
@@ -479,41 +411,37 @@ async def analyze_legal_video(file: UploadFile = File(...)):
             f.write(text)
         summary_text = text[:4096] if len(text) > 4096 else text
         summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
-        print("Extracting named entities from transcript...")
         entities = extract_named_entities(text)
-        print("Analyzing risk from transcript...")
-        risk_scores = analyze_risk(text)
-        detailed_risk = get_detailed_risk_info(text)
-        print("Analyzing legal clauses from transcript...")
+        risk_analysis = analyze_risk_enhanced(text)
         clauses = analyze_contract_clauses(text)
         generated_task_id = str(uuid.uuid4())
         store_document_context(generated_task_id, text)
-        return {
+        result = {
            "status": "success",
            "task_id": generated_task_id,
            "transcript": text,
            "transcript_path": transcript_path,
            "summary": summary,
            "named_entities": entities,
-            "risk_scores": risk_scores,
-            "detailed_risk": detailed_risk,
+            "risk_analysis": risk_analysis,
            "clauses_detected": clauses
         }
+        analysis_cache[file_hash] = result
+        return result
     except Exception as e:
-        print(f"Error processing video: {str(e)}")
         return {"status": "error", "message": str(e)}
 
 @app.post("/analyze_legal_audio")
-async def analyze_legal_audio(file: UploadFile = File(...)):
-    """Analyzes legal audio by transcribing and analyzing the transcript."""
+async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
     try:
-        print(f"Processing audio file: {file.filename}")
         content = await file.read()
+        file_hash = compute_md5(content)
+        if file_hash in analysis_cache:
+            return analysis_cache[file_hash]
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
             temp_file.write(content)
             temp_file_path = temp_file.name
-        print(f"Temporary file saved at: {temp_file_path}")
-        text = process_audio_to_text(temp_file_path)
+        text = await process_audio_to_text(temp_file_path)
         if os.path.exists(temp_file_path):
             os.remove(temp_file_path)
         if not text:
@@ -523,33 +451,28 @@ async def analyze_legal_audio(file: UploadFile = File(...)):
             f.write(text)
         summary_text = text[:4096] if len(text) > 4096 else text
         summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
-        print("Extracting named entities from transcript...")
         entities = extract_named_entities(text)
-        print("Analyzing risk from transcript...")
-        risk_scores = analyze_risk(text)
-        detailed_risk = get_detailed_risk_info(text)
-        print("Analyzing legal clauses from transcript...")
+        risk_analysis = analyze_risk_enhanced(text)
         clauses = analyze_contract_clauses(text)
         generated_task_id = str(uuid.uuid4())
         store_document_context(generated_task_id, text)
-        return {
+        result = {
            "status": "success",
            "task_id": generated_task_id,
            "transcript": text,
            "transcript_path": transcript_path,
            "summary": summary,
            "named_entities": entities,
-            "risk_scores": risk_scores,
-            "detailed_risk": detailed_risk,
+            "risk_analysis": risk_analysis,
            "clauses_detected": clauses
         }
+        analysis_cache[file_hash] = result
+        return result
     except Exception as e:
-        print(f"Error processing audio: {str(e)}")
         return {"status": "error", "message": str(e)}
 
 @app.get("/transcript/{transcript_id}")
 async def get_transcript(transcript_id: str):
-    """Retrieves a previously generated transcript."""
     transcript_path = os.path.join("static", f"transcript_{transcript_id}.txt")
     if os.path.exists(transcript_path):
         return FileResponse(transcript_path)
@@ -558,7 +481,6 @@ async def get_transcript(transcript_id: str):
 
 @app.post("/legal_chatbot")
 async def legal_chatbot_api(query: str = Form(...), task_id: str = Form(...)):
-    """Handles legal Q&A using chat history and document context."""
     document_context = load_document_context(task_id)
     if not document_context:
         return {"response": "⚠️ No relevant document found for this task ID."}
@@ -576,7 +498,6 @@ async def health_check():
     }
 
 def setup_ngrok():
-    """Sets up ngrok tunnel for Google Colab."""
     try:
         auth_token = os.environ.get("NGROK_AUTH_TOKEN")
         if auth_token:
@@ -603,65 +524,59 @@ def setup_ngrok():
         print(f"⚠️ Ngrok setup error: {e}")
         return None
 
-from fastapi.responses import FileResponse
+# ------------------------------
+# Dynamic Visualization Endpoints
+# ------------------------------
 
 @app.get("/download_risk_chart")
-async def download_risk_chart():
-    """Generate and return a risk assessment chart as an image file."""
+async def download_risk_chart(task_id: str):
     try:
-        os.makedirs("static", exist_ok=True)
-        risk_scores = {
-            "Liability": 11,
-            "Termination": 12,
-            "Indemnification": 10,
-            "Payment Risk": 41,
-            "Insurance": 71
-        }
+        text = load_document_context(task_id)
+        if not text:
+            raise HTTPException(status_code=404, detail="Document context not found")
+        risk_analysis = analyze_risk_enhanced(text)
         plt.figure(figsize=(8, 5))
-        plt.bar(risk_scores.keys(), risk_scores.values(), color='red')
-        plt.xlabel("Risk Categories")
+        plt.bar(["Risk Score"], [risk_analysis["risk_score"]], color='red')
         plt.ylabel("Risk Score")
-        plt.title("Legal Risk Assessment")
-        plt.xticks(rotation=30)
-        risk_chart_path = "static/risk_chart.png"
+        plt.title("Legal Risk Assessment (Enhanced)")
+        risk_chart_path = os.path.join("static", f"risk_chart_{task_id}.png")
         plt.savefig(risk_chart_path)
         plt.close()
-        return FileResponse(risk_chart_path, media_type="image/png", filename="risk_chart.png")
+        return FileResponse(risk_chart_path, media_type="image/png", filename=f"risk_chart_{task_id}.png")
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error generating risk chart: {str(e)}")
 
 @app.get("/download_risk_pie_chart")
-async def download_risk_pie_chart():
+async def download_risk_pie_chart(task_id: str):
     try:
-        risk_scores = {
-            "Liability": 11,
-            "Termination": 12,
-            "Indemnification": 10,
-            "Payment Risk": 41,
-            "Insurance": 71
-        }
+        text = load_document_context(task_id)
+        if not text:
+            raise HTTPException(status_code=404, detail="Document context not found")
+        risk_analysis = analyze_risk_enhanced(text)
+        labels = ["Risk", "No Risk"]
+        # Ensure the values are within [0,1]
+        risk_value = risk_analysis["risk_score"]
+        risk_value = min(max(risk_value, 0), 1)
+        values = [risk_value, 1 - risk_value]
         plt.figure(figsize=(6, 6))
-        plt.pie(risk_scores.values(), labels=risk_scores.keys(), autopct='%1.1f%%', startangle=90)
-        plt.title("Legal Risk Distribution")
-        pie_chart_path = "static/risk_pie_chart.png"
+        plt.pie(values, labels=labels, autopct='%1.1f%%', startangle=90)
+        plt.title("Legal Risk Distribution (Enhanced)")
+        pie_chart_path = os.path.join("static", f"risk_pie_chart_{task_id}.png")
        plt.savefig(pie_chart_path)
        plt.close()
-        return FileResponse(pie_chart_path, media_type="image/png", filename="risk_pie_chart.png")
+        return FileResponse(pie_chart_path, media_type="image/png", filename=f"risk_pie_chart_{task_id}.png")
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error generating pie chart: {str(e)}")
 
 @app.get("/download_risk_radar_chart")
-async def download_risk_radar_chart():
+async def download_risk_radar_chart(task_id: str):
     try:
-        risk_scores = {
-            "Liability": 11,
-            "Termination": 12,
-            "Indemnification": 10,
-            "Payment Risk": 41,
-            "Insurance": 71
-        }
-        categories = list(risk_scores.keys())
-        values = list(risk_scores.values())
+        text = load_document_context(task_id)
+        if not text:
+            raise HTTPException(status_code=404, detail="Document context not found")
+        risk_analysis = analyze_risk_enhanced(text)
+        categories = ["Average Sentiment", "Risk Score"]
+        values = [risk_analysis["average_sentiment"], risk_analysis["risk_score"]]
         categories += categories[:1]
         values += values[:1]
         angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
@@ -669,66 +584,61 @@ async def download_risk_radar_chart():
         fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
         ax.plot(angles, values, 'o-', linewidth=2)
         ax.fill(angles, values, alpha=0.25)
-        ax.set_thetagrids(np.degrees(angles[:-1]), categories)
-        ax.set_title("Legal Risk Radar Chart", y=1.1)
-        radar_chart_path = "static/risk_radar_chart.png"
+        ax.set_thetagrids(np.degrees(angles[:-1]), ["Sentiment", "Risk"])
+        ax.set_title("Legal Risk Radar Chart (Enhanced)", y=1.1)
+        radar_chart_path = os.path.join("static", f"risk_radar_chart_{task_id}.png")
         plt.savefig(radar_chart_path)
         plt.close()
-        return FileResponse(radar_chart_path, media_type="image/png", filename="risk_radar_chart.png")
+        return FileResponse(radar_chart_path, media_type="image/png", filename=f"risk_radar_chart_{task_id}.png")
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error generating radar chart: {str(e)}")
 
 @app.get("/download_risk_trend_chart")
-async def download_risk_trend_chart():
+async def download_risk_trend_chart(task_id: str):
     try:
-        dates = ["2025-01-01", "2025-02-01", "2025-03-01", "2025-04-01"]
-        risk_history = {
-            "Liability": [10, 12, 11, 13],
-            "Termination": [12, 15, 14, 13],
-            "Indemnification": [9, 10, 11, 10],
-            "Payment Risk": [40, 42, 41, 43],
-            "Insurance": [70, 69, 71, 72]
-        }
+        text = load_document_context(task_id)
+        if not text:
+            raise HTTPException(status_code=404, detail="Document context not found")
+        words = text.split()
+        segments = np.array_split(words, 4)
+        segment_texts = [" ".join(segment) for segment in segments]
+        trend_scores = []
+        for segment in segment_texts:
+            risk = analyze_risk_enhanced(segment)
+            trend_scores.append(risk["risk_score"])
+        segments_labels = [f"Segment {i+1}" for i in range(len(segment_texts))]
         plt.figure(figsize=(10, 6))
-        for category, scores in risk_history.items():
-            plt.plot(dates, scores, marker='o', label=category)
-        plt.xlabel("Date")
+        plt.plot(segments_labels, trend_scores, marker='o')
+        plt.xlabel("Document Segments")
         plt.ylabel("Risk Score")
-        plt.title("Historical Legal Risk Trends")
+        plt.title("Dynamic Legal Risk Trends (Enhanced)")
         plt.xticks(rotation=45)
-        plt.legend()
-        trend_chart_path = "static/risk_trend_chart.png"
+        trend_chart_path = os.path.join("static", f"risk_trend_chart_{task_id}.png")
         plt.savefig(trend_chart_path, bbox_inches="tight")
         plt.close()
-        return FileResponse(trend_chart_path, media_type="image/png", filename="risk_trend_chart.png")
+        return FileResponse(trend_chart_path, media_type="image/png", filename=f"risk_trend_chart_{task_id}.png")
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error generating trend chart: {str(e)}")
 
-import pandas as pd
-import plotly.express as px
-from fastapi.responses import HTMLResponse
-
 @app.get("/interactive_risk_chart", response_class=HTMLResponse)
-async def interactive_risk_chart():
+async def interactive_risk_chart(task_id: str):
     try:
-        risk_scores = {
-            "Liability": 11,
-            "Termination": 12,
-            "Indemnification": 10,
-            "Payment Risk": 41,
-            "Insurance": 71
-        }
+        import pandas as pd
+        import plotly.express as px
+        text = load_document_context(task_id)
+        if not text:
+            raise HTTPException(status_code=404, detail="Document context not found")
+        risk_analysis = analyze_risk_enhanced(text)
         df = pd.DataFrame({
-            "Risk Category": list(risk_scores.keys()),
-            "Risk Score": list(risk_scores.values())
+            "Metric": ["Average Sentiment", "Risk Score"],
+            "Value": [risk_analysis["average_sentiment"], risk_analysis["risk_score"]]
         })
-        fig = px.bar(df, x="Risk Category", y="Risk Score", title="Interactive Legal Risk Assessment")
+        fig = px.bar(df, x="Metric", y="Value", title="Interactive Enhanced Legal Risk Assessment")
         return fig.to_html()
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error generating interactive chart: {str(e)}")
 
 def run():
-    """Starts the FastAPI server."""
     print("Starting FastAPI server...")
     uvicorn.run(app, host="0.0.0.0", port=8500, timeout_keep_alive=600)
 
@@ -739,3 +649,4 @@ if __name__ == "__main__":
     else:
         print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
     run()
+
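Since every chart endpoint is now keyed by task_id, the intended flow is: upload a document once, keep the task_id from the response, then request charts for that id. A hedged end-to-end sketch using requests (the file name and base URL are assumptions):

import requests

BASE_URL = "http://localhost:8500"  # assumed local deployment on the port passed to uvicorn.run

# 1. Analyze a document and keep its task_id
with open("contract.pdf", "rb") as fh:
    analysis = requests.post(f"{BASE_URL}/analyze_legal_document", files={"file": fh}).json()
task_id = analysis["task_id"]

# 2. Request a chart generated from that document's stored context
chart = requests.get(f"{BASE_URL}/download_risk_chart", params={"task_id": task_id})
with open("risk_chart.png", "wb") as out:
    out.write(chart.content)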
 
 