Tonic committed
Commit ece5856 · unverified · 1 Parent(s): 0ae53cb

add model inference code

Files changed (1)
  1. tasks/text.py +54 -8
tasks/text.py CHANGED
@@ -3,6 +3,9 @@ from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
 import random
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from torch.utils.data import Dataset, DataLoader
 
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
@@ -51,14 +54,57 @@ async def evaluate_text(request: TextEvaluationRequest):
     tracker.start()
     tracker.start_task("inference")
 
-    #--------------------------------------------------------------------------------------------
-    # YOUR MODEL INFERENCE CODE HERE
-    # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
-    #--------------------------------------------------------------------------------------------
+    # Load the model and tokenizer
+    model_name = "Tonic/climate-guard-toxic-agent"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSequenceClassification.from_pretrained(model_name)
+
+    class TextDataset(Dataset):
+        def __init__(self, texts, labels, tokenizer, max_len=128):
+            self.texts = texts
+            self.labels = labels
+            self.tokenizer = tokenizer
+            self.max_len = max_len
+
+        def __len__(self):
+            return len(self.texts)
+
+        def __getitem__(self, idx):
+            text = self.texts[idx]
+            label = self.labels[idx]
+            encodings = self.tokenizer(
+                text,
+                max_length=self.max_len,
+                padding='max_length',
+                truncation=True,
+                return_tensors="pt"
+            )
+            return {
+                'input_ids': encodings['input_ids'].squeeze(0),
+                'attention_mask': encodings['attention_mask'].squeeze(0),
+                'labels': torch.tensor(label, dtype=torch.long)
+            }
+
+    # Create dataset and dataloader
+    test_dataset = TextDataset(texts, labels, tokenizer)
+    test_loader = DataLoader(test_dataset, batch_size=16)
+
+    # Model inference
+    model.eval()
+    predictions = []
+    ground_truth = []
+    DEVICE = 'cpu'
+    with torch.no_grad():
+        for batch in test_loader:
+            input_ids = batch['input_ids'].to(DEVICE)
+            attention_mask = batch['attention_mask'].to(DEVICE)
+            labels = batch['labels'].to(DEVICE)
+
+            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+            _, predicted = torch.max(outputs.logits, 1)
 
-    # Make random predictions (placeholder for actual model inference)
-    true_labels = test_dataset["label"]
-    predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
+            predictions.extend(predicted.cpu().numpy())
+            ground_truth.extend(labels.cpu().numpy())
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
@@ -69,7 +115,7 @@ async def evaluate_text(request: TextEvaluationRequest):
     emissions_data = tracker.stop_task()
 
     # Calculate accuracy
-    accuracy = accuracy_score(true_labels, predictions)
+    accuracy = accuracy_score(ground_truth, predictions)
 
     # Prepare results dictionary
     results = {
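
Note on the new inference block: TextDataset(texts, labels, tokenizer) relies on names texts and labels that are not defined in any hunk of this commit, so they must come from the unchanged part of evaluate_text. A minimal sketch of that preparation, assuming the evaluation split is still bound to test_dataset at that point; the "label" column appears in the removed baseline, but the text column name ("quote" below) is an assumption, not shown in this diff:

# Sketch only: where texts and labels could come from before TextDataset is built.
# "label" matches the removed baseline (test_dataset["label"]); "quote" is an
# assumed text column name, not confirmed by the hunks in this commit.
texts = test_dataset["quote"]
labels = test_dataset["label"]

# The committed code then rebinds test_dataset to the torch wrapper:
test_dataset = TextDataset(texts, labels, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=16)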
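
The loop also hardcodes DEVICE = 'cpu' and never calls model.to(DEVICE), which works only because the weights and the batch tensors both stay on the CPU. If the Space were given a GPU, a device-aware variant of the same loop could look like the sketch below (a deployment assumption, not part of this commit); model and test_loader are the objects created in the diff above:

import torch

# Sketch: GPU-aware variant of the committed inference loop (not in this commit).
# Chooses CUDA when available and moves the model once so that the weights and
# the batch tensors live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

predictions, ground_truth = [], []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predicted = outputs.logits.argmax(dim=1)
        predictions.extend(predicted.cpu().numpy())
        ground_truth.extend(batch['labels'].numpy())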
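
For a quick smoke test of the checkpoint outside the evaluation harness and the emissions tracker, the transformers pipeline API can run the same model on a single example. The quote below is invented for illustration, and the returned label names depend on the model's id2label config, which this diff does not show:

from transformers import pipeline

# Sketch: standalone check of the checkpoint referenced in the commit.
# The input sentence is an invented example for illustration only.
classifier = pipeline(
    "text-classification",
    model="Tonic/climate-guard-toxic-agent",
)
print(classifier("Climate change is just a natural cycle, not caused by humans."))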