jennasparks commited on
Commit
19efa32
·
verified ·
1 Parent(s): 996c86a

Added some efficiency

Browse files
Files changed (1) hide show
  1. tasks/text.py +25 -0
tasks/text.py CHANGED
@@ -18,6 +18,31 @@ router = APIRouter()
18
  DESCRIPTION = "Electra_Base"
19
  ROUTE = "/text"
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @router.post(ROUTE, tags=["Text Task"],
22
  description=DESCRIPTION)
23
  async def evaluate_text(request: TextEvaluationRequest):
 
18
  DESCRIPTION = "Electra_Base"
19
  ROUTE = "/text"
20
 
21
+ class CustomTFDataset(tf.data.Dataset):
22
+ def __init__(self, texts, labels, tokenizer, max_length=128):
23
+ self.texts = texts
24
+ self.labels = labels
25
+ self.tokenizer = tokenizer
26
+ self.max_length = max_length
27
+
28
+ def __len__(self):
29
+ return len(self.texts)
30
+
31
+ def __iter__(self):
32
+ for text, label in zip(self.texts, self.labels):
33
+ encoding = self.tokenizer(
34
+ text,
35
+ truncation=True,
36
+ padding='max_length',
37
+ max_length=self.max_length,
38
+ return_tensors='tf'
39
+ )
40
+ yield {
41
+ 'input_ids': encoding['input_ids'][0],
42
+ 'attention_mask': encoding['attention_mask'][0],
43
+ 'label': tf.constant(label, dtype=tf.int32)
44
+ }
45
+
46
  @router.post(ROUTE, tags=["Text Task"],
47
  description=DESCRIPTION)
48
  async def evaluate_text(request: TextEvaluationRequest):