0xrushi committed
Commit 1e0fe77 · verified · 1 Parent(s): 9685f7b

Update tasks/text.py

Files changed (1)
  1. tasks/text.py +58 -31
tasks/text.py CHANGED
@@ -2,30 +2,63 @@ from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
-import random
+import torch
+from transformers import AutoTokenizer, RobertaForSequenceClassification
+from torch.utils.data import Dataset, DataLoader
 
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
-DESCRIPTION = "Random Baseline"
+DESCRIPTION = "RoBERTa Climate Disinformation Classifier"
 ROUTE = "/text"
 
-@router.post(ROUTE, tags=["Text Task"],
-             description=DESCRIPTION)
+class FrugalDataClass(Dataset):
+    def __init__(self, texts, labels, tokenizer, max_len=128):
+        self.texts = texts
+        self.labels = labels
+        self.tokenizer = tokenizer
+        self.max_len = max_len
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        text = str(self.texts[idx])
+        label = self.labels[idx]
+
+        encodings = self.tokenizer(
+            text,
+            max_length=self.max_len,
+            padding='max_length',
+            truncation=True,
+            return_tensors="pt"
+        )
+
+        return {
+            'input_ids': encodings['input_ids'].flatten(),
+            'attention_mask': encodings['attention_mask'].flatten(),
+            'labels': torch.tensor(label, dtype=torch.long)
+        }
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+model = RobertaForSequenceClassification.from_pretrained(
+    "roberta-base",
+    num_labels=8
+)
+model.load_state_dict(torch.load('best_roberta_model.pth', map_location=device))
+model.to(device)
+model.eval()
+
+@router.post(ROUTE, description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
     """
-    Evaluate text classification for climate disinformation detection.
-
-    Current Model: Random Baseline
-    - Makes random predictions from the label space (0-7)
-    - Used as a baseline for comparison
+    Evaluate text classification for climate disinformation detection using RoBERTa.
     """
-    # Get space info
     username, space_url = get_space_info()
 
-    # Define the label mapping
     LABEL_MAPPING = {
         "0_not_relevant": 0,
         "1_not_happening": 1,
@@ -37,41 +70,35 @@ async def evaluate_text(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
-    # Load and prepare the dataset
     dataset = load_dataset(request.dataset_name)
 
-    # Convert string labels to integers
     dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
-    # Split dataset
     train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
     test_dataset = train_test["test"]
 
-    # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
 
-    #--------------------------------------------------------------------------------------------
-    # YOUR MODEL INFERENCE CODE HERE
-    # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
-    #--------------------------------------------------------------------------------------------
-
-    # Make random predictions (placeholder for actual model inference)
+    test_texts = test_dataset["quote"]
     true_labels = test_dataset["label"]
-    predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
-
-    #--------------------------------------------------------------------------------------------
-    # YOUR MODEL INFERENCE STOPS HERE
-    #--------------------------------------------------------------------------------------------
-
 
-    # Stop tracking emissions
+    test_dataset = FrugalDataClass(test_texts, true_labels, tokenizer)
+    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
+
+    predictions = []
+    with torch.no_grad():
+        for batch in test_loader:
+            input_ids = batch['input_ids'].to(device)
+            attention_mask = batch['attention_mask'].to(device)
+            outputs = model(input_ids, attention_mask=attention_mask)
+            preds = torch.argmax(outputs.logits, dim=1).cpu().numpy()
+            predictions.extend(preds)
+
     emissions_data = tracker.stop_task()
 
-    # Calculate accuracy
     accuracy = accuracy_score(true_labels, predictions)
 
-    # Prepare results dictionary
     results = {
         "username": username,
         "space_url": space_url,
@@ -89,4 +116,4 @@ async def evaluate_text(request: TextEvaluationRequest):
         }
     }
 
-    return results
+    return results
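Note on the checkpoint: the updated module loads best_roberta_model.pth at import time, so a fine-tuned checkpoint must already be present in the Space before the route can serve requests. The training code is not part of this commit. Below is a minimal sketch of how such a checkpoint might be produced: the dataset class mirrors FrugalDataClass above, while the optimizer, learning rate, batch size, epoch count, and checkpoint-selection criterion are illustrative assumptions, not taken from this repository.

# train_roberta.py — hypothetical training sketch (not part of this commit).
# Produces the best_roberta_model.pth checkpoint that tasks/text.py loads.
# `texts` (quote strings) and `labels` (ints 0-7) are assumed to come from
# the same dataset/LABEL_MAPPING preparation shown in evaluate_text above.
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, RobertaForSequenceClassification

class QuoteDataset(Dataset):
    """Mirrors FrugalDataClass from tasks/text.py: tokenize one quote per item."""
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(str(self.texts[idx]), max_length=self.max_len,
                             padding='max_length', truncation=True,
                             return_tensors="pt")
        return {'input_ids': enc['input_ids'].flatten(),
                'attention_mask': enc['attention_mask'].flatten(),
                'labels': torch.tensor(self.labels[idx], dtype=torch.long)}

def train(texts, labels, epochs=3, lr=2e-5, batch_size=16):
    # Hyperparameters above are illustrative, not from this repository.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=8)
    model.to(device)
    loader = DataLoader(QuoteDataset(texts, labels, tokenizer),
                        batch_size=batch_size, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=lr)
    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        total = 0.0
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # Passing `labels` makes the HF head compute cross-entropy loss.
            loss = model(**batch).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item()
        # Keep the checkpoint with the lowest average training loss
        # (an assumed criterion; a validation split would be more robust).
        avg = total / len(loader)
        if avg < best_loss:
            best_loss = avg
            torch.save(model.state_dict(), 'best_roberta_model.pth')

Since the handler reads dataset_name, test_size, and test_seed from the request body, a local smoke test of the route might look like the following, assuming the Space mounts `router` on a FastAPI app (the `app` module name and the field values here are placeholders, not taken from this commit):

# Hypothetical smoke test for the /text route.
from fastapi.testclient import TestClient
from app import app  # assumed module that does app.include_router(router)

client = TestClient(app)
response = client.post("/text", json={
    "dataset_name": "<org/dataset>",  # whichever dataset the challenge specifies
    "test_size": 0.2,                 # illustrative split fraction
    "test_seed": 42,                  # illustrative seed
})
print(response.json())  # results dict prepared in evaluate_text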