Spaces:
Sleeping
Sleeping
revert to reference code
Browse files
- tasks/text.py +15 -19
tasks/text.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
from fastapi import APIRouter
|
3 |
from datetime import datetime
|
4 |
import time
|
@@ -14,7 +13,7 @@ from huggingface_hub import login
|
|
14 |
from dotenv import load_dotenv
|
15 |
|
16 |
from .utils.evaluation import TextEvaluationRequest
|
17 |
-
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
18 |
|
19 |
# Load environment variables
|
20 |
load_dotenv()
|
@@ -29,7 +28,7 @@ os.environ["TORCH_COMPILE_DISABLE"] = "1"
|
|
29 |
|
30 |
router = APIRouter()
|
31 |
|
32 |
-
DESCRIPTION = "Climate
|
33 |
ROUTE = "/text"
|
34 |
|
35 |
class TextClassifier:
|
@@ -43,13 +42,15 @@ class TextClassifier:
|
|
43 |
# Load config
|
44 |
self.config = AutoConfig.from_pretrained(
|
45 |
model_name,
|
|
|
|
|
46 |
trust_remote_code=True
|
47 |
)
|
48 |
|
49 |
# Initialize tokenizer
|
50 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
51 |
model_name,
|
52 |
-
model_max_length=
|
53 |
padding_side='right',
|
54 |
truncation_side='right',
|
55 |
trust_remote_code=True
|
@@ -60,15 +61,11 @@ class TextClassifier:
|
|
60 |
model_name,
|
61 |
config=self.config,
|
62 |
trust_remote_code=True,
|
63 |
-
torch_dtype=torch.float32
|
64 |
-
device_map="auto",
|
65 |
-
low_cpu_mem_usage=True
|
66 |
)
|
67 |
|
68 |
-
#
|
69 |
-
|
70 |
-
self.model = self.model.cpu()
|
71 |
-
|
72 |
self.model.eval()
|
73 |
print("Model initialized successfully")
|
74 |
break
|
@@ -84,12 +81,12 @@ class TextClassifier:
|
|
84 |
try:
|
85 |
print(f"Processing batch {batch_idx} with {len(batch)} items")
|
86 |
|
87 |
-
# Tokenize
|
88 |
inputs = self.tokenizer(
|
89 |
batch,
|
90 |
return_tensors="pt",
|
91 |
truncation=True,
|
92 |
-
max_length=512,
|
93 |
padding=True
|
94 |
)
|
95 |
|
@@ -129,14 +126,14 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
129 |
"2_not_human": 2,
|
130 |
"3_not_bad": 3,
|
131 |
"4_solutions_harmful_unnecessary": 4,
|
132 |
-
"
|
133 |
"6_proponents_biased": 6,
|
134 |
"7_fossil_fuels_needed": 7
|
135 |
}
|
136 |
|
137 |
try:
|
138 |
# Load and prepare the dataset
|
139 |
-
dataset = load_dataset(
|
140 |
|
141 |
# Convert string labels to integers
|
142 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
@@ -154,7 +151,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
154 |
classifier = TextClassifier()
|
155 |
|
156 |
# Prepare batches
|
157 |
-
batch_size =
|
158 |
quotes = test_dataset["quote"]
|
159 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
160 |
batches = [
|
@@ -166,7 +163,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
166 |
batch_results = [[] for _ in range(num_batches)]
|
167 |
|
168 |
# Process batches in parallel
|
169 |
-
max_workers = min(os.cpu_count(),
|
170 |
print(f"Processing with {max_workers} workers")
|
171 |
|
172 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
@@ -222,5 +219,4 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
222 |
|
223 |
except Exception as e:
|
224 |
print(f"Error in evaluate_text: {str(e)}")
|
225 |
-
raise Exception(f"Failed to process request: {str(e)}")
|
226 |
-
|
|
|
|
|
1 |
from fastapi import APIRouter
|
2 |
from datetime import datetime
|
3 |
import time
|
|
|
13 |
from dotenv import load_dotenv
|
14 |
|
15 |
from .utils.evaluation import TextEvaluationRequest
|
16 |
+
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
17 |
|
18 |
# Load environment variables
|
19 |
load_dotenv()
|
|
|
28 |
|
29 |
router = APIRouter()
|
30 |
|
31 |
+
DESCRIPTION = "ModernBERT Climate Claims Classifier"
|
32 |
ROUTE = "/text"
|
33 |
|
34 |
class TextClassifier:
|
|
|
42 |
# Load config
|
43 |
self.config = AutoConfig.from_pretrained(
|
44 |
model_name,
|
45 |
+
num_labels=8,
|
46 |
+
problem_type="single_label_classification",
|
47 |
trust_remote_code=True
|
48 |
)
|
49 |
|
50 |
# Initialize tokenizer
|
51 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
52 |
model_name,
|
53 |
+
model_max_length=8192,
|
54 |
padding_side='right',
|
55 |
truncation_side='right',
|
56 |
trust_remote_code=True
|
|
|
61 |
model_name,
|
62 |
config=self.config,
|
63 |
trust_remote_code=True,
|
64 |
+
torch_dtype=torch.float32
|
|
|
|
|
65 |
)
|
66 |
|
67 |
+
# Move model to appropriate device
|
68 |
+
self.model = self.model.to(self.device)
|
|
|
|
|
69 |
self.model.eval()
|
70 |
print("Model initialized successfully")
|
71 |
break
|
|
|
81 |
try:
|
82 |
print(f"Processing batch {batch_idx} with {len(batch)} items")
|
83 |
|
84 |
+
# Tokenize
|
85 |
inputs = self.tokenizer(
|
86 |
batch,
|
87 |
return_tensors="pt",
|
88 |
truncation=True,
|
89 |
+
max_length=512,
|
90 |
padding=True
|
91 |
)
|
92 |
|
|
|
126 |
"2_not_human": 2,
|
127 |
"3_not_bad": 3,
|
128 |
"4_solutions_harmful_unnecessary": 4,
|
129 |
+
"5_science_is_unreliable": 5,
|
130 |
"6_proponents_biased": 6,
|
131 |
"7_fossil_fuels_needed": 7
|
132 |
}
|
133 |
|
134 |
try:
|
135 |
# Load and prepare the dataset
|
136 |
+
dataset = load_dataset(request.dataset_name)
|
137 |
|
138 |
# Convert string labels to integers
|
139 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
|
|
151 |
classifier = TextClassifier()
|
152 |
|
153 |
# Prepare batches
|
154 |
+
batch_size = 16 # Reduced batch size
|
155 |
quotes = test_dataset["quote"]
|
156 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
157 |
batches = [
|
|
|
163 |
batch_results = [[] for _ in range(num_batches)]
|
164 |
|
165 |
# Process batches in parallel
|
166 |
+
max_workers = min(os.cpu_count(), 2) # Reduced workers
|
167 |
print(f"Processing with {max_workers} workers")
|
168 |
|
169 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
219 |
|
220 |
except Exception as e:
|
221 |
print(f"Error in evaluate_text: {str(e)}")
|
222 |
+
raise Exception(f"Failed to process request: {str(e)}")
|
|