JenetGhumman committed on
Commit
5d2f9b2
·
verified ·
1 Parent(s): 9685f7b

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +52 -21
tasks/text.py CHANGED
@@ -2,14 +2,16 @@ from fastapi import APIRouter
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
- import random
 
 
6
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Random Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
@@ -18,9 +20,8 @@ async def evaluate_text(request: TextEvaluationRequest):
18
  """
19
  Evaluate text classification for climate disinformation detection.
20
 
21
- Current Model: Random Baseline
22
- - Makes random predictions from the label space (0-7)
23
- - Used as a baseline for comparison
24
  """
25
  # Get space info
26
  username, space_url = get_space_info()
@@ -45,32 +46,62 @@ async def evaluate_text(request: TextEvaluationRequest):
45
 
46
  # Split dataset
47
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
 
48
  test_dataset = train_test["test"]
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Start tracking emissions
51
  tracker.start()
52
  tracker.start_task("inference")
53
 
54
- #--------------------------------------------------------------------------------------------
55
- # YOUR MODEL INFERENCE CODE HERE
56
- # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
- #--------------------------------------------------------------------------------------------
58
-
59
- # Make random predictions (placeholder for actual model inference)
60
- true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
- #--------------------------------------------------------------------------------------------
64
- # YOUR MODEL INFERENCE STOPS HERE
65
- #--------------------------------------------------------------------------------------------
 
66
 
67
-
68
  # Stop tracking emissions
69
  emissions_data = tracker.stop_task()
70
-
71
  # Calculate accuracy
72
  accuracy = accuracy_score(true_labels, predictions)
73
-
74
  # Prepare results dictionary
75
  results = {
76
  "username": username,
@@ -89,4 +120,4 @@ async def evaluate_text(request: TextEvaluationRequest):
89
  }
90
  }
91
 
92
- return results
 
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
6
+ from transformers import Trainer, TrainingArguments
7
+ import torch
8
 
9
  from .utils.evaluation import TextEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
12
  router = APIRouter()
13
 
14
+ DESCRIPTION = "DistilBERT Baseline"
15
  ROUTE = "/text"
16
 
17
  @router.post(ROUTE, tags=["Text Task"],
 
20
  """
21
  Evaluate text classification for climate disinformation detection.
22
 
23
+ Current Model: DistilBERT
24
+ - Fine-tunes and evaluates a DistilBERT model on the given dataset
 
25
  """
26
  # Get space info
27
  username, space_url = get_space_info()
 
46
 
47
  # Split dataset
48
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
49
+ train_dataset = train_test["train"]
50
  test_dataset = train_test["test"]
51
+
52
+ # Tokenizer and model
53
+ tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
54
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
55
+
56
+ # Tokenize datasets
57
+ def preprocess(examples):
58
+ return tokenizer(examples["text"], truncation=True, padding=True, max_length=512)
59
+
60
+ train_dataset = train_dataset.map(preprocess, batched=True)
61
+ test_dataset = test_dataset.map(preprocess, batched=True)
62
+
63
+ train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
64
+ test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
65
+
66
+ # Training arguments
67
+ training_args = TrainingArguments(
68
+ output_dir="./results",
69
+ evaluation_strategy="epoch",
70
+ learning_rate=5e-5,
71
+ per_device_train_batch_size=16,
72
+ per_device_eval_batch_size=16,
73
+ num_train_epochs=3,
74
+ weight_decay=0.01,
75
+ logging_dir="./logs",
76
+ logging_steps=10,
77
+ )
78
+
79
+ trainer = Trainer(
80
+ model=model,
81
+ args=training_args,
82
+ train_dataset=train_dataset,
83
+ eval_dataset=test_dataset,
84
+ tokenizer=tokenizer,
85
+ )
86
+
87
  # Start tracking emissions
88
  tracker.start()
89
  tracker.start_task("inference")
90
 
91
+ # Train and evaluate the model
92
+ trainer.train()
 
 
 
 
 
 
93
 
94
+ # Perform inference
95
+ predictions = trainer.predict(test_dataset).predictions
96
+ predictions = torch.argmax(torch.tensor(predictions), axis=1).tolist()
97
+ true_labels = test_dataset["label"]
98
 
 
99
  # Stop tracking emissions
100
  emissions_data = tracker.stop_task()
101
+
102
  # Calculate accuracy
103
  accuracy = accuracy_score(true_labels, predictions)
104
+
105
  # Prepare results dictionary
106
  results = {
107
  "username": username,
 
120
  }
121
  }
122
 
123
+ return results