Added a CustomTFDataset class that tokenizes texts lazily, for efficiency
Browse files- tasks/text.py +25 -0
tasks/text.py
CHANGED
@@ -18,6 +18,31 @@ router = APIRouter()
|
|
18 |
DESCRIPTION = "Electra_Base"
|
19 |
ROUTE = "/text"
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
@router.post(ROUTE, tags=["Text Task"],
|
22 |
description=DESCRIPTION)
|
23 |
async def evaluate_text(request: TextEvaluationRequest):
|
|
|
18 |
DESCRIPTION = "Electra_Base"
|
19 |
ROUTE = "/text"
|
20 |
|
21 |
+
class CustomTFDataset(tf.data.Dataset):
    """Lazily tokenized dataset of ``(text, label)`` pairs.

    ``tf.data.Dataset`` is an abstract base class, so a plain subclass
    that only defines ``__init__``/``__iter__`` cannot be instantiated.
    This class therefore acts as a factory: ``CustomTFDataset(...)``
    returns a dataset built with ``tf.data.Dataset.from_generator``.
    The returned object is a ``tf.data.Dataset`` (so ``.batch()``,
    ``.prefetch()``, ``len()`` and iteration work), but it is not an
    instance of ``CustomTFDataset`` and does not expose ``texts``,
    ``labels``, ``tokenizer`` or ``max_length`` attributes.

    Each element is a dict with the keys ``'input_ids'``,
    ``'attention_mask'`` and ``'label'``.

    Args:
        texts: Sized sequence of input strings.
        labels: Sized sequence of integer labels, one per text.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='tf')``
            whose result is indexable by ``'input_ids'`` and
            ``'attention_mask'``. Presumably a Hugging Face tokenizer --
            confirm against the caller.
        max_length: Length every sequence is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __new__(cls, texts, labels, tokenizer, max_length=128):
        # Fail early instead of letting zip() silently drop the surplus items.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )

        def generate():
            # Tokenize one example at a time so nothing is materialized
            # up front.
            for text, label in zip(texts, labels):
                encoding = tokenizer(
                    text,
                    truncation=True,
                    padding='max_length',
                    max_length=max_length,
                    return_tensors='tf',
                )
                yield {
                    # The tokenizer output carries a leading batch axis of
                    # size 1; [0] strips it.
                    'input_ids': encoding['input_ids'][0],
                    'attention_mask': encoding['attention_mask'][0],
                    'label': tf.constant(label, dtype=tf.int32),
                }

        # NOTE(review): assumes the tokenizer emits integer ids that fit in
        # int32 and exactly ``max_length`` tokens per example -- confirm.
        signature = {
            'input_ids': tf.TensorSpec(shape=(max_length,), dtype=tf.int32),
            'attention_mask': tf.TensorSpec(shape=(max_length,), dtype=tf.int32),
            'label': tf.TensorSpec(shape=(), dtype=tf.int32),
        }
        dataset = tf.data.Dataset.from_generator(
            generate,
            output_signature=signature,
        )
        # Generator-backed datasets have unknown cardinality; asserting it
        # keeps len(dataset) == len(texts) working as before.
        return dataset.apply(
            tf.data.experimental.assert_cardinality(len(texts))
        )
|
45 |
+
|
46 |
@router.post(ROUTE, tags=["Text Task"],
|
47 |
description=DESCRIPTION)
|
48 |
async def evaluate_text(request: TextEvaluationRequest):
|