Spaces:
Runtime error
Runtime error
Commit
·
b840e20
1
Parent(s):
2e85755
Fix to allow masked token after 512th token
Browse files. Sequences longer than 510 tokens are now truncated around the masked token for xlm-roberta-base, regardless of mask location.
app.py
CHANGED
@@ -31,8 +31,35 @@ xlmr_tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base', max_length=51
|
|
31 |
xlmr_p = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
32 |
|
33 |
def xlmr_base_fn(text):
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
pred_dict = {}
|
37 |
for pred in preds:
|
38 |
pred_dict[pred['token_str']] = pred['score']
|
|
|
xlmr_p = pipeline("fill-mask", model=model, tokenizer=tokenizer)

def xlmr_base_fn(text):
    """Run fill-mask prediction with xlm-roberta-base on *text*.

    Inputs longer than the model window are truncated to a window of
    subword tokens centred on the mask token, so the mask survives
    truncation regardless of where it appears in the input.

    Args:
        text: Input string containing exactly one mask token
            (``xlmr_tokenizer.mask_token``). Behaviour with zero masks is an
            IndexError; with several, the first is used as the centre.
    """
    # Locate the mask among the subword tokens. `in x` (not `== x`) because
    # SentencePiece may glue the word-boundary marker onto the mask piece.
    tokens = xlmr_tokenizer.tokenize(text)
    mask_token_idx = [i for i, x in enumerate(tokens) if xlmr_tokenizer.mask_token in x][0]

    # Usable window = model_max_length minus the two special tokens (<s>, </s>).
    # Some tokenizers report a huge sentinel for model_max_length, hence the
    # sanity clamp to 510 when the value doesn't look like a real 512-ish window.
    max_len = tokenizer.model_max_length
    max_len = max_len - 2 if max_len % 512 == 0 and max_len < 4096 else 510

    # Smart truncation for long sequences: keep a window centred on the mask.
    if not len(tokens) < max_len:
        # Symmetric bounds around the mask token (rbound is exclusive).
        lbound = max(0, mask_token_idx - (max_len // 2))
        rbound = min(len(tokens), mask_token_idx + (max_len // 2))

        # If one side hit an edge, expand the window in the other direction so
        # we still use the full budget.
        # FIX: rbound is an exclusive slice bound, so "reached the end" is
        # rbound == len(tokens); the original compared against len(tokens)-1.
        if lbound == 0 and rbound != len(tokens):
            rbound = min(len(tokens), max_len)
        elif rbound == len(tokens) and lbound != 0:
            lbound = max(0, len(tokens) - max_len)

        # Rejoin the kept subword tokens into plain text.
        truncated_text = ''.join(tokens[lbound:rbound])

        # SentencePiece marks word boundaries with U+2581 (LOWER ONE EIGHTH
        # BLOCK); map it back to an ordinary space.
        # FIX: the original iterated an undefined name `result` here, raising
        # NameError on every long input (the Space's "Runtime error").
        truncated_text = ''.join([x if ord(x) != 9601 else ' ' for x in truncated_text])
    else:
        truncated_text = text

    preds = xlmr_p(truncated_text)
    pred_dict = {}
    for pred in preds:
        pred_dict[pred['token_str']] = pred['score']
    # NOTE(review): the diff view is truncated here; the function presumably
    # returns pred_dict below this point — confirm against the full app.py.