ypesk commited on
Commit
8ff37b6
·
verified ·
1 Parent(s): 830a067

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +9 -2
tasks/text.py CHANGED
@@ -24,6 +24,12 @@ DESCRIPTION = "First Baseline"
24
  ROUTE = "/text"
25
 
26
 
 
 
 
 
 
 
27
  MODEL = "mlp" #mlp, ct, modern
28
 
29
  class ConspiracyClassification(
@@ -125,7 +131,7 @@ async def evaluate_text(request: TextEvaluationRequest):
125
  #--------------------------------------------------------------------------------------------
126
  if MODEL =="mlp":
127
  model = ConspiracyClassification.from_pretrained("ypesk/frugal-ai-mlp-baseline")
128
-
129
  emb_model = SentenceTransformer("paraphrase-MiniLM-L3-v2")
130
  batch_size = 6
131
 
@@ -136,6 +142,7 @@ async def evaluate_text(request: TextEvaluationRequest):
136
 
137
  elif MODEL == "ct":
138
  model = CovidTwitterBertClassifier.from_pretrained("ypesk/ct-baseline")
 
139
  tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert')
140
 
141
  test_texts = [t['quote'] for t in test_dataset]
@@ -158,7 +165,7 @@ async def evaluate_text(request: TextEvaluationRequest):
158
  model.eval()
159
  predictions = []
160
  for batch in tqdm(test_dataloader):
161
-
162
  with torch.no_grad():
163
  if MODEL =="mlp":
164
  b_texts = batch[0]
 
24
  ROUTE = "/text"
25
 
26
 
27
+ if torch.cuda.is_available():
28
+ device = torch.device("cuda")
29
+ else:
30
+ device = torch.device("cpu")
31
+
32
+
33
  MODEL = "mlp" #mlp, ct, modern
34
 
35
  class ConspiracyClassification(
 
131
  #--------------------------------------------------------------------------------------------
132
  if MODEL =="mlp":
133
  model = ConspiracyClassification.from_pretrained("ypesk/frugal-ai-mlp-baseline")
134
+ model = model.to(device)
135
  emb_model = SentenceTransformer("paraphrase-MiniLM-L3-v2")
136
  batch_size = 6
137
 
 
142
 
143
  elif MODEL == "ct":
144
  model = CovidTwitterBertClassifier.from_pretrained("ypesk/ct-baseline")
145
+ model = model.to(device)
146
  tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert')
147
 
148
  test_texts = [t['quote'] for t in test_dataset]
 
165
  model.eval()
166
  predictions = []
167
  for batch in tqdm(test_dataloader):
168
+ batch = tuple(t.to(device) for t in batch)
169
  with torch.no_grad():
170
  if MODEL =="mlp":
171
  b_texts = batch[0]