tejash300 commited on
Commit
74d93ff
·
verified ·
1 Parent(s): 5916467

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -45,7 +45,7 @@ except Exception:
45
  os.makedirs("static", exist_ok=True)
46
  os.makedirs("temp", exist_ok=True)
47
 
48
- # Ensure GPU usage
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
50
 
51
  # Initialize FastAPI
@@ -64,16 +64,13 @@ app.add_middleware(
64
  document_storage = {}
65
  chat_history = []
66
 
67
- # Function to store document context by task ID
68
  def store_document_context(task_id, text):
69
  document_storage[task_id] = text
70
  return True
71
 
72
- # Function to load document context by task ID
73
  def load_document_context(task_id):
74
  return document_storage.get(task_id, "")
75
 
76
- # Utility to compute MD5 hash from file content
77
  def compute_md5(content: bytes) -> str:
78
  return hashlib.md5(content).hexdigest()
79
 
@@ -196,17 +193,14 @@ try:
196
  spacy.cli.download("en_core_web_sm")
197
  nlp = spacy.load("en_core_web_sm")
198
  print("✅ Loading NLP models...")
199
- # Use Facebook's bart-large-cnn for summarization
200
  summarizer = pipeline(
201
  "summarization",
202
- model="facebook/bart-large-cnn",
203
- tokenizer="facebook/bart-large-cnn",
204
  device=0 if torch.cuda.is_available() else -1
205
  )
206
- # Removed FP16 conversion for summarizer to avoid CUDA errors
207
- # if device == "cuda":
208
- # summarizer.model.half()
209
-
210
  embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
211
  ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if torch.cuda.is_available() else -1)
212
  speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
@@ -373,7 +367,10 @@ async def analyze_legal_document(file: UploadFile = File(...)):
373
  if not text:
374
  return {"status": "error", "message": "No valid text found in the document."}
375
  summary_text = text[:4096] if len(text) > 4096 else text
376
- summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Document too short for meaningful summarization."
 
 
 
377
  entities = extract_named_entities(text)
378
  risk_analysis = analyze_risk_enhanced(text)
379
  clauses = analyze_contract_clauses(text)
@@ -411,7 +408,10 @@ async def analyze_legal_video(file: UploadFile = File(...), background_tasks: Ba
411
  with open(transcript_path, "w") as f:
412
  f.write(text)
413
  summary_text = text[:4096] if len(text) > 4096 else text
414
- summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
 
 
 
415
  entities = extract_named_entities(text)
416
  risk_analysis = analyze_risk_enhanced(text)
417
  clauses = analyze_contract_clauses(text)
@@ -451,7 +451,10 @@ async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: Ba
451
  with open(transcript_path, "w") as f:
452
  f.write(text)
453
  summary_text = text[:4096] if len(text) > 4096 else text
454
- summary = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text'] if len(text) > 100 else "Transcript too short for meaningful summarization."
 
 
 
455
  entities = extract_named_entities(text)
456
  risk_analysis = analyze_risk_enhanced(text)
457
  clauses = analyze_contract_clauses(text)
 
45
  os.makedirs("static", exist_ok=True)
46
  os.makedirs("temp", exist_ok=True)
47
 
48
+ # Set device to GPU if available
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
50
 
51
  # Initialize FastAPI
 
64
  document_storage = {}
65
  chat_history = []
66
 
 
67
  def store_document_context(task_id, text):
68
  document_storage[task_id] = text
69
  return True
70
 
 
71
  def load_document_context(task_id):
72
  return document_storage.get(task_id, "")
73
 
 
74
  def compute_md5(content: bytes) -> str:
75
  return hashlib.md5(content).hexdigest()
76
 
 
193
  spacy.cli.download("en_core_web_sm")
194
  nlp = spacy.load("en_core_web_sm")
195
  print("✅ Loading NLP models...")
196
+ # Use T5-base for summarization and run it on GPU (device=0)
197
  summarizer = pipeline(
198
  "summarization",
199
+ model="t5-base",
200
+ tokenizer="t5-base",
201
  device=0 if torch.cuda.is_available() else -1
202
  )
203
+ # Do NOT convert the summarizer model to FP16 to reduce risk of CUDA errors
 
 
 
204
  embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
205
  ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if torch.cuda.is_available() else -1)
206
  speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
 
367
  if not text:
368
  return {"status": "error", "message": "No valid text found in the document."}
369
  summary_text = text[:4096] if len(text) > 4096 else text
370
+ summary_result = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)
371
+ summary = summary_result[0].get("summary_text", "")
372
+ if not summary:
373
+ summary = "Summary not generated. Please check the input text."
374
  entities = extract_named_entities(text)
375
  risk_analysis = analyze_risk_enhanced(text)
376
  clauses = analyze_contract_clauses(text)
 
408
  with open(transcript_path, "w") as f:
409
  f.write(text)
410
  summary_text = text[:4096] if len(text) > 4096 else text
411
+ summary_result = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)
412
+ summary = summary_result[0].get("summary_text", "")
413
+ if not summary:
414
+ summary = "Summary not generated. Please check the input transcript."
415
  entities = extract_named_entities(text)
416
  risk_analysis = analyze_risk_enhanced(text)
417
  clauses = analyze_contract_clauses(text)
 
451
  with open(transcript_path, "w") as f:
452
  f.write(text)
453
  summary_text = text[:4096] if len(text) > 4096 else text
454
+ summary_result = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)
455
+ summary = summary_result[0].get("summary_text", "")
456
+ if not summary:
457
+ summary = "Summary not generated. Please check the input transcript."
458
  entities = extract_named_entities(text)
459
  risk_analysis = analyze_risk_enhanced(text)
460
  clauses = analyze_contract_clauses(text)