Spaces:
Sleeping
Sleeping
switch model loading technique
Browse files
- tasks/text.py +56 -32
tasks/text.py
CHANGED
@@ -31,11 +31,6 @@ router = APIRouter()
|
|
31 |
DESCRIPTION = "Climate Guard Toxic Agent Classifier"
|
32 |
ROUTE = "/text"
|
33 |
|
34 |
-
# Custom LayerNorm that ignores bias parameter
|
35 |
-
class CustomLayerNorm(nn.LayerNorm):
|
36 |
-
def __init__(self, normalized_shape, eps=1e-5, **kwargs):
|
37 |
-
super().__init__(normalized_shape, eps=eps)
|
38 |
-
|
39 |
class TextClassifier:
|
40 |
def __init__(self):
|
41 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -45,33 +40,34 @@ class TextClassifier:
|
|
45 |
for attempt in range(max_retries):
|
46 |
try:
|
47 |
# Load config
|
48 |
-
self.config = AutoConfig.from_pretrained(
|
|
|
|
|
|
|
49 |
|
50 |
# Initialize tokenizer
|
51 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
52 |
model_name,
|
53 |
model_max_length=2048,
|
54 |
padding_side='right',
|
55 |
-
truncation_side='right'
|
|
|
56 |
)
|
57 |
|
58 |
-
#
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
self.model =
|
65 |
-
model_name,
|
66 |
-
config=self.config,
|
67 |
-
ignore_mismatched_sizes=True,
|
68 |
-
low_cpu_mem_usage=True
|
69 |
-
)
|
70 |
-
finally:
|
71 |
-
# Restore original LayerNorm
|
72 |
-
nn.LayerNorm = original_layernorm
|
73 |
|
74 |
-
self.model.to(self.device)
|
75 |
self.model.eval()
|
76 |
print("Model initialized successfully")
|
77 |
break
|
@@ -87,14 +83,17 @@ class TextClassifier:
|
|
87 |
try:
|
88 |
print(f"Processing batch {batch_idx} with {len(batch)} items")
|
89 |
|
90 |
-
# Tokenize
|
91 |
inputs = self.tokenizer(
|
92 |
batch,
|
93 |
return_tensors="pt",
|
94 |
truncation=True,
|
95 |
-
max_length=
|
96 |
-
padding=
|
97 |
-
)
|
|
|
|
|
|
|
98 |
|
99 |
# Get predictions
|
100 |
with torch.no_grad():
|
@@ -108,6 +107,13 @@ class TextClassifier:
|
|
108 |
print(f"Error in batch {batch_idx}: {str(e)}")
|
109 |
return [0] * len(batch), batch_idx
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
|
112 |
async def evaluate_text(request: TextEvaluationRequest):
|
113 |
"""Evaluate text classification for climate disinformation detection."""
|
@@ -128,8 +134,21 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
128 |
}
|
129 |
|
130 |
try:
|
131 |
-
# Load and prepare the dataset
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
# Convert string labels to integers
|
135 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
@@ -146,8 +165,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
146 |
# Initialize the model once
|
147 |
classifier = TextClassifier()
|
148 |
|
149 |
-
# Prepare batches
|
150 |
-
batch_size =
|
151 |
quotes = test_dataset["quote"]
|
152 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
153 |
batches = [
|
@@ -158,8 +177,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
158 |
# Initialize batch_results
|
159 |
batch_results = [[] for _ in range(num_batches)]
|
160 |
|
161 |
-
# Process batches in parallel
|
162 |
-
max_workers = min(os.cpu_count(),
|
163 |
print(f"Processing with {max_workers} workers")
|
164 |
|
165 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
@@ -192,6 +211,11 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
192 |
accuracy = accuracy_score(true_labels, predictions)
|
193 |
print("accuracy:", accuracy)
|
194 |
|
|
|
|
|
|
|
|
|
|
|
195 |
# Prepare results
|
196 |
results = {
|
197 |
"username": username,
|
|
|
31 |
DESCRIPTION = "Climate Guard Toxic Agent Classifier"
|
32 |
ROUTE = "/text"
|
33 |
|
|
|
|
|
|
|
|
|
|
|
34 |
class TextClassifier:
|
35 |
def __init__(self):
|
36 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
40 |
for attempt in range(max_retries):
|
41 |
try:
|
42 |
# Load config
|
43 |
+
self.config = AutoConfig.from_pretrained(
|
44 |
+
model_name,
|
45 |
+
trust_remote_code=True
|
46 |
+
)
|
47 |
|
48 |
# Initialize tokenizer
|
49 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
50 |
model_name,
|
51 |
model_max_length=2048,
|
52 |
padding_side='right',
|
53 |
+
truncation_side='right',
|
54 |
+
trust_remote_code=True
|
55 |
)
|
56 |
|
57 |
+
# Initialize model
|
58 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
59 |
+
model_name,
|
60 |
+
config=self.config,
|
61 |
+
trust_remote_code=True,
|
62 |
+
torch_dtype=torch.float32,
|
63 |
+
device_map="auto",
|
64 |
+
low_cpu_mem_usage=True
|
65 |
+
)
|
66 |
|
67 |
+
# Force model to CPU if CUDA is not available
|
68 |
+
if not torch.cuda.is_available():
|
69 |
+
self.model = self.model.cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
|
|
71 |
self.model.eval()
|
72 |
print("Model initialized successfully")
|
73 |
break
|
|
|
83 |
try:
|
84 |
print(f"Processing batch {batch_idx} with {len(batch)} items")
|
85 |
|
86 |
+
# Tokenize with smaller max length
|
87 |
inputs = self.tokenizer(
|
88 |
batch,
|
89 |
return_tensors="pt",
|
90 |
truncation=True,
|
91 |
+
max_length=512, # Reduced max length
|
92 |
+
padding=True
|
93 |
+
)
|
94 |
+
|
95 |
+
# Move inputs to device
|
96 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
97 |
|
98 |
# Get predictions
|
99 |
with torch.no_grad():
|
|
|
107 |
print(f"Error in batch {batch_idx}: {str(e)}")
|
108 |
return [0] * len(batch), batch_idx
|
109 |
|
110 |
+
def __del__(self):
|
111 |
+
# Clean up CUDA memory
|
112 |
+
if hasattr(self, 'model'):
|
113 |
+
del self.model
|
114 |
+
if torch.cuda.is_available():
|
115 |
+
torch.cuda.empty_cache()
|
116 |
+
|
117 |
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
|
118 |
async def evaluate_text(request: TextEvaluationRequest):
|
119 |
"""Evaluate text classification for climate disinformation detection."""
|
|
|
134 |
}
|
135 |
|
136 |
try:
|
137 |
+
# Load and prepare the dataset with retry mechanism
|
138 |
+
max_retries = 3
|
139 |
+
for attempt in range(max_retries):
|
140 |
+
try:
|
141 |
+
dataset = load_dataset(
|
142 |
+
"QuotaClimat/frugalaichallenge-text-train",
|
143 |
+
token=HF_TOKEN,
|
144 |
+
trust_remote_code=True
|
145 |
+
)
|
146 |
+
break
|
147 |
+
except Exception as e:
|
148 |
+
if attempt == max_retries - 1:
|
149 |
+
raise Exception(f"Failed to load dataset after {max_retries} attempts: {str(e)}")
|
150 |
+
print(f"Dataset loading attempt {attempt + 1} failed, retrying... Error: {str(e)}")
|
151 |
+
time.sleep(2)
|
152 |
|
153 |
# Convert string labels to integers
|
154 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
|
|
165 |
# Initialize the model once
|
166 |
classifier = TextClassifier()
|
167 |
|
168 |
+
# Prepare batches with smaller batch size
|
169 |
+
batch_size = 16 # Reduced batch size
|
170 |
quotes = test_dataset["quote"]
|
171 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
172 |
batches = [
|
|
|
177 |
# Initialize batch_results
|
178 |
batch_results = [[] for _ in range(num_batches)]
|
179 |
|
180 |
+
# Process batches in parallel with fewer workers
|
181 |
+
max_workers = min(os.cpu_count(), 2) # Reduced number of workers
|
182 |
print(f"Processing with {max_workers} workers")
|
183 |
|
184 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
211 |
accuracy = accuracy_score(true_labels, predictions)
|
212 |
print("accuracy:", accuracy)
|
213 |
|
214 |
+
# Clean up
|
215 |
+
del classifier
|
216 |
+
if torch.cuda.is_available():
|
217 |
+
torch.cuda.empty_cache()
|
218 |
+
|
219 |
# Prepare results
|
220 |
results = {
|
221 |
"username": username,
|