Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files - tasks/text.py +19 -12
tasks/text.py
CHANGED
@@ -3,6 +3,14 @@ from datetime import datetime
|
|
3 |
from datasets import load_dataset
|
4 |
from sklearn.metrics import accuracy_score
|
5 |
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
from .utils.evaluation import TextEvaluationRequest
|
8 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
@@ -55,13 +63,16 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
55 |
# YOUR MODEL INFERENCE CODE HERE
|
56 |
# Update the code below to replace the random baseline with your model inference, inside the inference pass where the energy consumption and emissions are tracked.
|
57 |
#--------------------------------------------------------------------------------------------
|
58 |
-
class CovidTwitterBertClassifier(
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
61 |
super().__init__()
|
62 |
-
self.n_classes =
|
63 |
self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
|
64 |
-
self.bert.cls.seq_relationship = nn.Linear(1024,
|
65 |
|
66 |
self.sigmoid = nn.Sigmoid()
|
67 |
|
@@ -71,11 +82,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
71 |
logits = outputs[1]
|
72 |
|
73 |
return logits
|
74 |
-
|
75 |
-
model = CovidTwitterBertClassifier(8)
|
76 |
-
|
77 |
-
model.to(device)
|
78 |
-
model.load_state_dict(torch.load('ypesk/ct_baseline/CTBert_128_e15_0.692.pth'))
|
79 |
model.eval()
|
80 |
|
81 |
|
@@ -83,7 +90,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
83 |
|
84 |
test_texts = [t['quote'] for t in data_test]
|
85 |
|
86 |
-
MAX_LEN =
|
87 |
|
88 |
tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
|
89 |
test_input_ids, test_token_type_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['token_type_ids'], tokenized_test['attention_mask']
|
@@ -92,7 +99,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
92 |
test_input_ids = torch.tensor(test_input_ids)
|
93 |
test_attention_mask = torch.tensor(test_attention_mask)
|
94 |
|
95 |
-
batch_size =
|
96 |
test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
|
97 |
|
98 |
test_sampler = SequentialSampler(test_data)
|
|
|
3 |
from datasets import load_dataset
|
4 |
from sklearn.metrics import accuracy_score
|
5 |
import random
|
6 |
+
import numpy as np
|
7 |
+
from huggingface_hub import PyTorchModelHubMixin
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
12 |
+
from transformers import BertForPreTraining, BertModel, AutoTokenizer, BertForSequenceClassification, RobertaForSequenceClassification
|
13 |
+
|
14 |
|
15 |
from .utils.evaluation import TextEvaluationRequest
|
16 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
|
63 |
# YOUR MODEL INFERENCE CODE HERE
|
64 |
# Update the code below to replace the random baseline with your model inference, inside the inference pass where the energy consumption and emissions are tracked.
|
65 |
#--------------------------------------------------------------------------------------------
|
66 |
+
class CovidTwitterBertClassifier(
|
67 |
+
nn.Module,
|
68 |
+
PyTorchModelHubMixin,
|
69 |
+
# optionally, you can add metadata which gets pushed to the model card
|
70 |
+
):
|
71 |
+
def __init__(self, num_classes):
|
72 |
super().__init__()
|
73 |
+
self.n_classes = num_classes
|
74 |
self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
|
75 |
+
self.bert.cls.seq_relationship = nn.Linear(1024, num_classes)
|
76 |
|
77 |
self.sigmoid = nn.Sigmoid()
|
78 |
|
|
|
82 |
logits = outputs[1]
|
83 |
|
84 |
return logits
|
85 |
+
model = CovidTwitterBertClassifier.from_pretrained("ypesk/ct-baseline")
|
|
|
|
|
|
|
|
|
86 |
model.eval()
|
87 |
|
88 |
|
|
|
90 |
|
91 |
test_texts = [t['quote'] for t in data_test]
|
92 |
|
93 |
+
MAX_LEN = 256  # alternative: 1024; inputs longer than MAX_LEN tokens will be truncated by the tokenizer
|
94 |
|
95 |
tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
|
96 |
test_input_ids, test_token_type_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['token_type_ids'], tokenized_test['attention_mask']
|
|
|
99 |
test_input_ids = torch.tensor(test_input_ids)
|
100 |
test_attention_mask = torch.tensor(test_attention_mask)
|
101 |
|
102 |
+
batch_size = 12 #
|
103 |
test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
|
104 |
|
105 |
test_sampler = SequentialSampler(test_data)
|