ypesk committed · verified
Commit 435ab1a · Parent(s): 80cfa93

Update tasks/text.py

Files changed (1): tasks/text.py (+184 −22)
tasks/text.py CHANGED
@@ -30,7 +30,13 @@ else:
     device = torch.device("cpu")


-MODEL = "ct" #mlp, ct, modern
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+
+MODEL = "modern-large" #mlp, ct, modern-base, modern-large, gte-base, gte-large

 class ConspiracyClassification(
     nn.Module,
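For reference, the added device block is behavior-equivalent to the usual one-liner; a minimal sketch, not part of the commit:

import torch

# pick the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")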
@@ -65,26 +71,90 @@ class ConspiracyClassification(

         return outputs

-class CovidTwitterBertClassifier(
+class CTBERT(
     nn.Module,
     PyTorchModelHubMixin,
     # optionally, you can add metadata which gets pushed to the model card
 ):
     def __init__(self, num_classes):
         super().__init__()
-        self.n_classes = num_classes
         self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
         self.bert.cls.seq_relationship = nn.Linear(1024, num_classes)
-
-        self.sigmoid = nn.Sigmoid()

-    def forward(self, input_ids, token_type_ids, input_mask):
+    def forward(self, input_ids, input_mask, token_type_ids):
         outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=input_mask)
-
         logits = outputs[1]

         return logits
+
+class conspiracyModelBase(
+    nn.Module,
+    PyTorchModelHubMixin,
+    # optionally, you can add metadata which gets pushed to the model card
+):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.n_classes = num_classes
+        self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-base', num_labels=num_classes)
+
+    def forward(self, input_ids, input_mask):
+        outputs = self.bert(input_ids=input_ids, attention_mask=input_mask)
+
+        return outputs.logits
+
+class conspiracyModelLarge(
+    nn.Module,
+    PyTorchModelHubMixin,
+    # optionally, you can add metadata which gets pushed to the model card
+):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.n_classes = num_classes
+        self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
+
+    def forward(self, input_ids, input_mask):
+        outputs = self.bert(input_ids=input_ids, attention_mask=input_mask)
+
+        return outputs.logits
+
+class gteModelLarge(
+    nn.Module,
+    PyTorchModelHubMixin,
+    # optionally, you can add metadata which gets pushed to the model card
+):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.n_classes = num_classes
+        #self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
+        self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
+        #self.cls = nn.Linear(768, num_classes)
+        self.cls = nn.Linear(1024, num_classes)
+
+    def forward(self, input_ids, input_mask, input_type_ids):
+        outputs = self.gte(input_ids=input_ids, attention_mask=input_mask, token_type_ids=input_type_ids)
+        embeddings = outputs.last_hidden_state[:, 0]
+        logits = self.cls(embeddings)
+        return logits
+
+class gteModel(
+    nn.Module,
+    PyTorchModelHubMixin,
+    # optionally, you can add metadata which gets pushed to the model card
+):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.n_classes = num_classes
+        #self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
+        self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)
+        self.cls = nn.Linear(768, num_classes)
+        #self.cls = nn.Linear(1024, num_classes)
+
+    def forward(self, input_ids, input_mask, input_type_ids):
+        outputs = self.gte(input_ids=input_ids, attention_mask=input_mask, token_type_ids=input_type_ids)
+        embeddings = outputs.last_hidden_state[:, 0]
+        logits = self.cls(embeddings)
+        return logits
+

 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
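All of these wrappers inherit from_pretrained, save_pretrained, and push_to_hub from PyTorchModelHubMixin, which is how the ypesk/frugal-ai-* checkpoints below are produced and reloaded. A minimal round-trip sketch, assuming the classes above are importable from tasks.text; the repo id and class count here are hypothetical:

from tasks.text import conspiracyModelLarge

# init kwargs are serialized to config.json by the mixin
model = conspiracyModelLarge(num_classes=2)          # hypothetical class count
model.save_pretrained("modern-large-ckpt")           # writes weights + config
model.push_to_hub("your-username/your-repo")         # hypothetical repo id
reloaded = conspiracyModelLarge.from_pretrained("your-username/your-repo")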
@@ -120,28 +190,20 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Split dataset
     train_test = dataset["train"]
     test_dataset = dataset["test"]
-
-    # Start tracking emissions
-    tracker.start()
-    tracker.start_task("inference")

-    #--------------------------------------------------------------------------------------------
-    # YOUR MODEL INFERENCE CODE HERE
-    # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
-    #--------------------------------------------------------------------------------------------
     if MODEL == "mlp":
         model = ConspiracyClassification.from_pretrained("ypesk/frugal-ai-mlp-baseline")
         model = model.to(device)
         emb_model = SentenceTransformer("paraphrase-MiniLM-L3-v2")
         batch_size = 6

         test_texts = torch.Tensor(emb_model.encode([t['quote'] for t in test_dataset]))
         test_data = TensorDataset(test_texts)
         test_sampler = SequentialSampler(test_data)
         test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

     elif MODEL == "ct":
-        model = CovidTwitterBertClassifier.from_pretrained("ypesk/ct-baseline")
+        model = CTBERT.from_pretrained("ypesk/frugal-ai-ct-bert-baseline")
         model = model.to(device)
         tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert')
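Each transformer branch below builds its DataLoader the same way: tokenize to a fixed length, wrap the tensors, and batch sequentially. A condensed, self-contained sketch of the shared pattern (constants taken from the diff; the one-item test_dataset is a stand-in for the real split):

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
test_dataset = [{"quote": "example tweet"}]          # stand-in for the real split

# tokenize to fixed length, then wrap the tensors for sequential batching
enc = tokenizer([t['quote'] for t in test_dataset],
                max_length=256, padding='max_length', truncation=True)
test_data = TensorDataset(torch.tensor(enc['input_ids']),
                          torch.tensor(enc['attention_mask']))
test_dataloader = DataLoader(test_data,
                             sampler=SequentialSampler(test_data),
                             batch_size=12)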
 
@@ -161,18 +223,118 @@ async def evaluate_text(request: TextEvaluationRequest):

         test_sampler = SequentialSampler(test_data)
         test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
+
+    elif MODEL == "modern-base":
+        model = conspiracyModelBase.from_pretrained("ypesk/frugal-ai-modern-base-baseline")
+        model = model.to(device)
+        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+        test_texts = [t['quote'] for t in test_dataset]
+
+        MAX_LEN = 256 #1024 # < m some tweets will be truncated
+
+        tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
+        test_input_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['attention_mask']
+
+        test_input_ids = torch.tensor(test_input_ids)
+        test_attention_mask = torch.tensor(test_attention_mask)
+
+        batch_size = 12
+        test_data = TensorDataset(test_input_ids, test_attention_mask)
+
+        test_sampler = SequentialSampler(test_data)
+        test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
+
+    elif MODEL == "modern-large":
+        model = conspiracyModelLarge.from_pretrained("ypesk/frugal-ai-modern-large-baseline")
+        model = model.to(device)
+        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")
+
+        test_texts = [t['quote'] for t in test_dataset]
+
+        MAX_LEN = 256 #1024 # < m some tweets will be truncated
+
+        tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
+        test_input_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['attention_mask']
+
+        test_input_ids = torch.tensor(test_input_ids)
+        test_attention_mask = torch.tensor(test_attention_mask)
+
+        batch_size = 12
+        test_data = TensorDataset(test_input_ids, test_attention_mask)
+
+        test_sampler = SequentialSampler(test_data)
+        test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
+
+    elif MODEL == "gte-base":
+        model = gteModel.from_pretrained("ypesk/frugal-ai-gte-base-baseline")
+        model = model.to(device)
+        tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-base-en-v1.5')
+
+        test_texts = [t['quote'] for t in test_dataset]
+
+        MAX_LEN = 256 #1024 # < m some tweets will be truncated
+
+        tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
+        test_input_ids, test_attention_mask, test_token_type_ids = tokenized_test['input_ids'], tokenized_test['attention_mask'], tokenized_test['token_type_ids']
+
+        test_input_ids = torch.tensor(test_input_ids)
+        test_attention_mask = torch.tensor(test_attention_mask)
+        test_token_type_ids = torch.tensor(test_token_type_ids)
+
+        batch_size = 12
+        test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
+
+        test_sampler = SequentialSampler(test_data)
+        test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
+
+    elif MODEL == "gte-large":
+        model = gteModel.from_pretrained("ypesk/frugal-ai-gte-large-baseline")
+        model = model.to(device)
+        tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-large-en-v1.5')
+
+        test_texts = [t['quote'] for t in test_dataset]
+
+        MAX_LEN = 256 #1024 # < m some tweets will be truncated
+
+        tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
+        test_input_ids, test_attention_mask, test_token_type_ids = tokenized_test['input_ids'], tokenized_test['attention_mask'], tokenized_test['token_type_ids']
+
+        test_input_ids = torch.tensor(test_input_ids)
+        test_attention_mask = torch.tensor(test_attention_mask)
+        test_token_type_ids = torch.tensor(test_token_type_ids)
+
+        batch_size = 12
+        test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
+
+        test_sampler = SequentialSampler(test_data)
+        test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
+
+
+    # Start tracking emissions
+    tracker.start()
+    tracker.start_task("inference")

+    #--------------------------------------------------------------------------------------------
+    # YOUR MODEL INFERENCE CODE HERE
+    # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
+    #--------------------------------------------------------------------------------------------
+
     model.eval()
-    predictions = []
     for batch in tqdm(test_dataloader):
         batch = tuple(t.to(device) for t in batch)
         with torch.no_grad():
             if MODEL == "mlp":
                 b_texts = batch[0]
                 logits = model(b_texts)
-            elif MODEL == "ct":
+            elif MODEL == "modern-base" or MODEL == "modern-large":
+                b_input_ids, b_input_mask = batch
+                logits = model(b_input_ids, b_input_mask)
+            elif MODEL == "gte-base" or MODEL == "gte-large" or MODEL == "ct":
                 b_input_ids, b_input_mask, b_token_type_ids = batch
-                logits = model(b_input_ids, b_token_type_ids, b_input_mask)
+                logits = model(b_input_ids, b_input_mask, b_token_type_ids)

             logits = logits.detach().cpu().numpy()
             predictions.extend(logits.argmax(1))
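The batch-unpacking branches mirror each forward() signature: mlp takes precomputed embeddings, modern-* take (input_ids, input_mask), and gte-*/ct also take token type ids. A standalone sketch of the same dispatch using membership tests instead of chained equality; not the commit's code:

import torch

def run_batch(model, batch, model_name):
    # dispatch a batch tuple to the matching forward() signature
    with torch.no_grad():
        if model_name == "mlp":
            return model(batch[0])                    # precomputed sentence embeddings
        if model_name in ("modern-base", "modern-large"):
            return model(*batch)                      # (input_ids, input_mask)
        if model_name in ("gte-base", "gte-large", "ct"):
            return model(*batch)                      # (input_ids, input_mask, token_type_ids)
        raise ValueError(f"unknown model: {model_name}")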
 
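The tracker calls now run after model and dataloader setup, so only the measured forward passes count toward the reported emissions. A minimal sketch of that ordering with codecarbon's task API; the surrounding names are assumptions, not code from this commit:

from codecarbon import EmissionsTracker

tracker = EmissionsTracker()
tracker.start()
tracker.start_task("inference")
# ... model.eval(); batched torch.no_grad() forward passes ...
inference_data = tracker.stop_task()   # per-task measurements
tracker.stop()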