ypesk committed on
Commit
4d8b8b9
·
verified ·
1 Parent(s): a71804f

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +19 -12
tasks/text.py CHANGED
@@ -3,6 +3,14 @@ from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import random
 
 
 
 
 
 
 
 
6
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
@@ -55,13 +63,16 @@ async def evaluate_text(request: TextEvaluationRequest):
55
  # YOUR MODEL INFERENCE CODE HERE
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
58
- class CovidTwitterBertClassifier(nn.Module):
59
-
60
- def __init__(self, n_classes):
 
 
 
61
  super().__init__()
62
- self.n_classes = n_classes
63
  self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
64
- self.bert.cls.seq_relationship = nn.Linear(1024, n_classes)
65
 
66
  self.sigmoid = nn.Sigmoid()
67
 
@@ -71,11 +82,7 @@ async def evaluate_text(request: TextEvaluationRequest):
71
  logits = outputs[1]
72
 
73
  return logits
74
-
75
- model = CovidTwitterBertClassifier(8)
76
-
77
- model.to(device)
78
- model.load_state_dict(torch.load('ypesk/ct_baseline/CTBert_128_e15_0.692.pth'))
79
  model.eval()
80
 
81
 
@@ -83,7 +90,7 @@ async def evaluate_text(request: TextEvaluationRequest):
83
 
84
  test_texts = [t['quote'] for t in data_test]
85
 
86
- MAX_LEN = 128 #1024 # < m some tweets will be truncated
87
 
88
  tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
89
  test_input_ids, test_token_type_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['token_type_ids'], tokenized_test['attention_mask']
@@ -92,7 +99,7 @@ async def evaluate_text(request: TextEvaluationRequest):
92
  test_input_ids = torch.tensor(test_input_ids)
93
  test_attention_mask = torch.tensor(test_attention_mask)
94
 
95
- batch_size = 8 #
96
  test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
97
 
98
  test_sampler = SequentialSampler(test_data)
 
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import random
6
+ import numpy as np
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
12
+ from transformers import BertForPreTraining, BertModel, AutoTokenizer, BertForSequenceClassification, RobertaForSequenceClassification
13
+
14
 
15
  from .utils.evaluation import TextEvaluationRequest
16
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
63
  # YOUR MODEL INFERENCE CODE HERE
64
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
65
  #--------------------------------------------------------------------------------------------
66
+ class CovidTwitterBertClassifier(
67
+ nn.Module,
68
+ PyTorchModelHubMixin,
69
+ # optionally, you can add metadata which gets pushed to the model card
70
+ ):
71
+ def __init__(self, num_classes):
72
  super().__init__()
73
+ self.n_classes = num_classes
74
  self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
75
+ self.bert.cls.seq_relationship = nn.Linear(1024, num_classes)
76
 
77
  self.sigmoid = nn.Sigmoid()
78
 
 
82
  logits = outputs[1]
83
 
84
  return logits
85
+ model = CovidTwitterBertClassifier.from_pretrained("ypesk/ct-baseline")
 
 
 
 
86
  model.eval()
87
 
88
 
 
90
 
91
  test_texts = [t['quote'] for t in data_test]
92
 
93
+ MAX_LEN = 256 #1024 # < m some tweets will be truncated
94
 
95
  tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
96
  test_input_ids, test_token_type_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['token_type_ids'], tokenized_test['attention_mask']
 
99
  test_input_ids = torch.tensor(test_input_ids)
100
  test_attention_mask = torch.tensor(test_attention_mask)
101
 
102
+ batch_size = 12 #
103
  test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
104
 
105
  test_sampler = SequentialSampler(test_data)