tejash300 committed on
Commit
64af888
·
verified ·
1 Parent(s): c575db1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -23,21 +23,21 @@ import time
23
  import uuid
24
  import subprocess # For running ffmpeg commands
25
 
26
- # Ensure compatibility with Google Colab
27
  try:
28
  from google.colab import drive
29
  drive.mount('/content/drive')
30
- except:
31
  pass # Skip drive mount if not in Google Colab
32
 
33
- # Ensure required directories exist
34
  os.makedirs("static", exist_ok=True)
35
  os.makedirs("temp", exist_ok=True)
36
 
37
- # Ensure GPU usage
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
39
 
40
- # Initialize FastAPI
41
  app = FastAPI(title="Legal Document and Video Analyzer")
42
 
43
  # Add CORS middleware
@@ -49,19 +49,17 @@ app.add_middleware(
49
  allow_headers=["*"],
50
  )
51
 
52
- # Initialize document storage
53
  document_storage = {}
54
- chat_history = [] # Global chat history
55
 
56
- # Function to store document context by task ID
57
  def store_document_context(task_id, text):
58
- """Store document text for retrieval by chatbot."""
59
  document_storage[task_id] = text
60
  return True
61
 
62
- # Function to load document context by task ID
63
  def load_document_context(task_id):
64
- """Retrieve document text for chatbot context."""
65
  return document_storage.get(task_id, "")
66
 
67
  #############################
@@ -72,7 +70,7 @@ def fine_tune_cuad_model():
72
  """
73
  Fine tunes a QA model on the CUAD dataset for clause extraction.
74
  For testing, we use only 50 training examples (and 10 for validation)
75
- and set training arguments for very fast, minimal training.
76
  """
77
  from datasets import load_dataset
78
  import numpy as np
@@ -151,18 +149,19 @@ def fine_tune_cuad_model():
151
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
152
  val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
153
 
154
- # Adjust training arguments for fast testing
155
  training_args = TrainingArguments(
156
  output_dir="./fine_tuned_legal_qa",
 
157
  evaluation_strategy="steps",
158
- eval_steps=10,
159
  learning_rate=2e-5,
160
  per_device_train_batch_size=4,
161
  per_device_eval_batch_size=4,
162
- num_train_epochs=0.1, # Very short training for testing purposes
163
  weight_decay=0.01,
164
- logging_steps=5,
165
- save_steps=10,
166
  load_best_model_at_end=True,
167
  report_to=[] # Disable wandb logging
168
  )
@@ -741,4 +740,3 @@ if __name__ == "__main__":
741
  else:
742
  print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
743
  run()
744
-
 
23
  import uuid
24
  import subprocess # For running ffmpeg commands
25
 
26
+ # Ensure compatibility with Google Colab
27
  try:
28
  from google.colab import drive
29
  drive.mount('/content/drive')
30
+ except Exception:
31
  pass # Skip drive mount if not in Google Colab
32
 
33
+ # Ensure required directories exist
34
  os.makedirs("static", exist_ok=True)
35
  os.makedirs("temp", exist_ok=True)
36
 
37
+ # Ensure GPU usage
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
39
 
40
+ # Initialize FastAPI
41
  app = FastAPI(title="Legal Document and Video Analyzer")
42
 
43
  # Add CORS middleware
 
49
  allow_headers=["*"],
50
  )
51
 
52
+ # Initialize document storage and chat history
53
  document_storage = {}
54
+ chat_history = []
55
 
56
+ # Function to store document context by task ID
57
  def store_document_context(task_id, text):
 
58
  document_storage[task_id] = text
59
  return True
60
 
61
+ # Function to load document context by task ID
62
  def load_document_context(task_id):
 
63
  return document_storage.get(task_id, "")
64
 
65
  #############################
 
70
  """
71
  Fine tunes a QA model on the CUAD dataset for clause extraction.
72
  For testing, we use only 50 training examples (and 10 for validation)
73
+ and restrict training to 10 steps.
74
  """
75
  from datasets import load_dataset
76
  import numpy as np
 
149
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
150
  val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
151
 
152
+ # Set max_steps to 10 for fast testing.
153
  training_args = TrainingArguments(
154
  output_dir="./fine_tuned_legal_qa",
155
+ max_steps=10,
156
  evaluation_strategy="steps",
157
+ eval_steps=5,
158
  learning_rate=2e-5,
159
  per_device_train_batch_size=4,
160
  per_device_eval_batch_size=4,
161
+ num_train_epochs=1,
162
  weight_decay=0.01,
163
+ logging_steps=1,
164
+ save_steps=5,
165
  load_best_model_at_end=True,
166
  report_to=[] # Disable wandb logging
167
  )
 
740
  else:
741
  print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
742
  run()