Update app.py
app.py
CHANGED
@@ -23,21 +23,21 @@ import time
 import uuid
 import subprocess  # For running ffmpeg commands

-#
+# Ensure compatibility with Google Colab
 try:
     from google.colab import drive
     drive.mount('/content/drive')
-except:
+except Exception:
     pass  # Skip drive mount if not in Google Colab

-#
+# Ensure required directories exist
 os.makedirs("static", exist_ok=True)
 os.makedirs("temp", exist_ok=True)

-#
+# Ensure GPU usage
 device = "cuda" if torch.cuda.is_available() else "cpu"

-#
+# Initialize FastAPI
 app = FastAPI(title="Legal Document and Video Analyzer")

 # Add CORS middleware
@@ -49,19 +49,17 @@ app.add_middleware(
     allow_headers=["*"],
 )

-#
+# Initialize document storage and chat history
 document_storage = {}
-chat_history = []
+chat_history = []

-#
+# Function to store document context by task ID
 def store_document_context(task_id, text):
-    """Store document text for retrieval by chatbot."""
     document_storage[task_id] = text
     return True

-#
+# Function to load document context by task ID
 def load_document_context(task_id):
-    """Retrieve document text for chatbot context."""
     return document_storage.get(task_id, "")

 #############################
@@ -72,7 +70,7 @@ def fine_tune_cuad_model():
     """
     Fine tunes a QA model on the CUAD dataset for clause extraction.
     For testing, we use only 50 training examples (and 10 for validation)
-    and
+    and restrict training to 10 steps.
     """
     from datasets import load_dataset
     import numpy as np
@@ -151,18 +149,19 @@ def fine_tune_cuad_model():
     train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
     val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])

-    #
+    # Set max_steps to 10 for fast testing.
     training_args = TrainingArguments(
         output_dir="./fine_tuned_legal_qa",
+        max_steps=10,
         evaluation_strategy="steps",
-        eval_steps=
+        eval_steps=5,
         learning_rate=2e-5,
         per_device_train_batch_size=4,
         per_device_eval_batch_size=4,
-        num_train_epochs=
+        num_train_epochs=1,
         weight_decay=0.01,
-        logging_steps=
-        save_steps=
+        logging_steps=1,
+        save_steps=5,
         load_best_model_at_end=True,
         report_to=[]  # Disable wandb logging
     )
@@ -741,4 +740,3 @@ if __name__ == "__main__":
     else:
         print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
     run()
-
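Note on the document storage helpers: the new comments describe document_storage as a simple in-memory map from task_id to extracted text, with chat_history as a running log. The sketch below shows one plausible way these helpers are exercised from FastAPI endpoints. The /upload and /chat routes, their parameters, and the placeholder text extraction are illustrative assumptions, not endpoints taken from app.py.

# Minimal usage sketch (hypothetical endpoints; only the helpers mirror app.py).
import uuid
from fastapi import FastAPI, UploadFile

app = FastAPI(title="Legal Document and Video Analyzer")

document_storage = {}  # task_id -> extracted document text
chat_history = []      # running list of question/answer pairs

def store_document_context(task_id, text):
    document_storage[task_id] = text
    return True

def load_document_context(task_id):
    return document_storage.get(task_id, "")

@app.post("/upload")
async def upload(file: UploadFile):
    # Placeholder extraction: real code would parse PDFs, transcribe video, etc.
    text = (await file.read()).decode("utf-8", errors="ignore")
    task_id = str(uuid.uuid4())
    store_document_context(task_id, text)  # make the text retrievable by the chatbot
    return {"task_id": task_id}

@app.post("/chat")
async def chat(task_id: str, question: str):
    context = load_document_context(task_id)  # empty string if the id is unknown
    answer = f"({len(context)} chars of context) {question}"  # stand-in for the real QA model
    chat_history.append({"question": question, "answer": answer})
    return {"answer": answer}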
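Note on the quick-test training config: the updated TrainingArguments cap the run at max_steps=10 with evaluation and checkpointing every 5 steps, so num_train_epochs=1 is effectively overridden by the step cap. Below is a minimal sketch of how such a configuration typically plugs into a Hugging Face Trainer. The "deepset/roberta-base-squad2" checkpoint is an assumption (the actual model name is not visible in this hunk), and train_dataset / val_dataset stand for the tokenized CUAD subsets prepared earlier in fine_tune_cuad_model.

# Sketch of wiring the quick-test arguments into a Trainer.
# Assumption: the QA checkpoint name; the datasets come from fine_tune_cuad_model.
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "deepset/roberta-base-squad2"  # assumed checkpoint, not from app.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

training_args = TrainingArguments(
    output_dir="./fine_tuned_legal_qa",
    max_steps=10,                  # hard cap: stop after 10 optimizer steps
    evaluation_strategy="steps",
    eval_steps=5,                  # evaluate twice during the 10-step run
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,            # ignored once max_steps is set; kept for clarity
    weight_decay=0.01,
    logging_steps=1,
    save_steps=5,                  # aligns with eval_steps so the best checkpoint can be reloaded
    load_best_model_at_end=True,
    report_to=[],                  # disable wandb logging
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,   # tokenized CUAD subset (50 examples)
    eval_dataset=val_dataset,      # tokenized CUAD subset (10 examples)
)

trainer.train()
trainer.save_model("./fine_tuned_legal_qa")
tokenizer.save_pretrained("./fine_tuned_legal_qa")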