Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files- tasks/text.py +52 -21
tasks/text.py
CHANGED
@@ -2,14 +2,16 @@ from fastapi import APIRouter
|
|
2 |
from datetime import datetime
|
3 |
from datasets import load_dataset
|
4 |
from sklearn.metrics import accuracy_score
|
5 |
-
import
|
|
|
|
|
6 |
|
7 |
from .utils.evaluation import TextEvaluationRequest
|
8 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
9 |
|
10 |
router = APIRouter()
|
11 |
|
12 |
-
DESCRIPTION = "
|
13 |
ROUTE = "/text"
|
14 |
|
15 |
@router.post(ROUTE, tags=["Text Task"],
|
@@ -18,9 +20,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
18 |
"""
|
19 |
Evaluate text classification for climate disinformation detection.
|
20 |
|
21 |
-
Current Model:
|
22 |
-
-
|
23 |
-
- Used as a baseline for comparison
|
24 |
"""
|
25 |
# Get space info
|
26 |
username, space_url = get_space_info()
|
@@ -45,32 +46,62 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
45 |
|
46 |
# Split dataset
|
47 |
train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
|
|
|
48 |
test_dataset = train_test["test"]
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# Start tracking emissions
|
51 |
tracker.start()
|
52 |
tracker.start_task("inference")
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
# Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
|
57 |
-
#--------------------------------------------------------------------------------------------
|
58 |
-
|
59 |
-
# Make random predictions (placeholder for actual model inference)
|
60 |
-
true_labels = test_dataset["label"]
|
61 |
-
predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
-
|
68 |
# Stop tracking emissions
|
69 |
emissions_data = tracker.stop_task()
|
70 |
-
|
71 |
# Calculate accuracy
|
72 |
accuracy = accuracy_score(true_labels, predictions)
|
73 |
-
|
74 |
# Prepare results dictionary
|
75 |
results = {
|
76 |
"username": username,
|
@@ -89,4 +120,4 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
89 |
}
|
90 |
}
|
91 |
|
92 |
-
return results
|
|
|
2 |
from datetime import datetime
|
3 |
from datasets import load_dataset
|
4 |
from sklearn.metrics import accuracy_score
|
5 |
+
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
|
6 |
+
from transformers import Trainer, TrainingArguments
|
7 |
+
import torch
|
8 |
|
9 |
from .utils.evaluation import TextEvaluationRequest
|
10 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
11 |
|
12 |
router = APIRouter()
|
13 |
|
14 |
+
DESCRIPTION = "DistilBERT Baseline"
|
15 |
ROUTE = "/text"
|
16 |
|
17 |
@router.post(ROUTE, tags=["Text Task"],
|
|
|
20 |
"""
|
21 |
Evaluate text classification for climate disinformation detection.
|
22 |
|
23 |
+
Current Model: DistilBERT
|
24 |
+
- Fine-tunes and evaluates a DistilBERT model on the given dataset
|
|
|
25 |
"""
|
26 |
# Get space info
|
27 |
username, space_url = get_space_info()
|
|
|
46 |
|
47 |
# Split dataset
|
48 |
train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
|
49 |
+
train_dataset = train_test["train"]
|
50 |
test_dataset = train_test["test"]
|
51 |
+
|
52 |
+
# Tokenizer and model
|
53 |
+
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
|
54 |
+
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
|
55 |
+
|
56 |
+
# Tokenize datasets
|
57 |
+
def preprocess(examples):
|
58 |
+
return tokenizer(examples["text"], truncation=True, padding=True, max_length=512)
|
59 |
+
|
60 |
+
train_dataset = train_dataset.map(preprocess, batched=True)
|
61 |
+
test_dataset = test_dataset.map(preprocess, batched=True)
|
62 |
+
|
63 |
+
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
64 |
+
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
65 |
+
|
66 |
+
# Training arguments
|
67 |
+
training_args = TrainingArguments(
|
68 |
+
output_dir="./results",
|
69 |
+
evaluation_strategy="epoch",
|
70 |
+
learning_rate=5e-5,
|
71 |
+
per_device_train_batch_size=16,
|
72 |
+
per_device_eval_batch_size=16,
|
73 |
+
num_train_epochs=3,
|
74 |
+
weight_decay=0.01,
|
75 |
+
logging_dir="./logs",
|
76 |
+
logging_steps=10,
|
77 |
+
)
|
78 |
+
|
79 |
+
trainer = Trainer(
|
80 |
+
model=model,
|
81 |
+
args=training_args,
|
82 |
+
train_dataset=train_dataset,
|
83 |
+
eval_dataset=test_dataset,
|
84 |
+
tokenizer=tokenizer,
|
85 |
+
)
|
86 |
+
|
87 |
# Start tracking emissions
|
88 |
tracker.start()
|
89 |
tracker.start_task("inference")
|
90 |
|
91 |
+
# Train and evaluate the model
|
92 |
+
trainer.train()
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
+
# Perform inference
|
95 |
+
predictions = trainer.predict(test_dataset).predictions
|
96 |
+
predictions = torch.argmax(torch.tensor(predictions), axis=1).tolist()
|
97 |
+
true_labels = test_dataset["label"]
|
98 |
|
|
|
99 |
# Stop tracking emissions
|
100 |
emissions_data = tracker.stop_task()
|
101 |
+
|
102 |
# Calculate accuracy
|
103 |
accuracy = accuracy_score(true_labels, predictions)
|
104 |
+
|
105 |
# Prepare results dictionary
|
106 |
results = {
|
107 |
"username": username,
|
|
|
120 |
}
|
121 |
}
|
122 |
|
123 |
+
return results
|