Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -18,6 +18,25 @@ generation_config.pad_token_id = tokenizer.eos_token_id
|
|
18 |
# Terminate generation on the tokenizer's end-of-sequence token.
generation_config.eos_token_id = tokenizer.eos_token_id
# NOTE(review): removed a bare `generation_config` expression here — it was a
# notebook cell-display artifact with no effect in a script (dead code).

# Chat turn markers; generation stops when the model emits one of these
# sequences (consumed by StopGenerationCriteria below).
stop_tokens = [["Human", ":"], ["AI", ":"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
# Wrap the custom stop criterion so `model.generate` can consume it and halt
# generation at chat turn markers.
_turn_marker_criterion = StopGenerationCriteria(stop_tokens, tokenizer, model.device)
stopping_criteria = StoppingCriteriaList([_turn_marker_criterion])
|
|
|
18 |
# Terminate generation on the tokenizer's end-of-sequence token.
generation_config.eos_token_id = tokenizer.eos_token_id
# NOTE(review): removed a bare `generation_config` expression here — it was a
# notebook cell-display artifact with no effect in a script (dead code).

# Chat turn markers; generation stops when the model emits one of these
# sequences (consumed by StopGenerationCriteria below).
stop_tokens = [["Human", ":"], ["AI", ":"]]
|
21 |
+
|
22 |
+
class StopGenerationCriteria(StoppingCriteria):
    """Halt generation once the output ends with any configured token sequence.

    Each entry in *tokens* is a sequence of token strings (e.g. ``["Human", ":"]``);
    when the tail of the generated ids matches one of them, generation stops.
    """

    def __init__(
        self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
    ):
        # Convert each token-string sequence to ids once, up front, and keep the
        # patterns as long tensors on the target device so the per-step
        # comparison in __call__ needs no conversion or transfer.
        self.stop_token_ids = [
            torch.tensor(
                tokenizer.convert_tokens_to_ids(sequence),
                dtype=torch.long,
                device=device,
            )
            for sequence in tokens
        ]

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Check whether the generated sequence (batch element 0) currently
        # ends with any of the stop patterns.
        for pattern in self.stop_token_ids:
            tail = input_ids[0][-len(pattern):]
            if torch.eq(tail, pattern).all():
                return True
        return False
|
38 |
+
|
39 |
+
|
40 |
# Wrap the custom stop criterion so `model.generate` can consume it and halt
# generation at chat turn markers.
_turn_marker_criterion = StopGenerationCriteria(stop_tokens, tokenizer, model.device)
stopping_criteria = StoppingCriteriaList([_turn_marker_criterion])
|