AminFaraji commited on
Commit
df78453
·
verified ·
1 Parent(s): 24f89fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -0
app.py CHANGED
@@ -18,6 +18,25 @@ generation_config.pad_token_id = tokenizer.eos_token_id
18
  generation_config.eos_token_id = tokenizer.eos_token_id
19
  generation_config
20
  stop_tokens = [["Human", ":"], ["AI", ":"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  stopping_criteria = StoppingCriteriaList(
22
  [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
23
  )
 
18
  generation_config.eos_token_id = tokenizer.eos_token_id
19
  generation_config
20
  stop_tokens = [["Human", ":"], ["AI", ":"]]
21
+
22
+ class StopGenerationCriteria(StoppingCriteria):
23
+ def __init__(
24
+ self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
25
+ ):
26
+ stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
27
+ self.stop_token_ids = [
28
+ torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
29
+ ]
30
+
31
+ def __call__(
32
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
33
+ ) -> bool:
34
+ for stop_ids in self.stop_token_ids:
35
+ if torch.eq(input_ids[0][-len(stop_ids) :], stop_ids).all():
36
+ return True
37
+ return False
38
+
39
+
40
  stopping_criteria = StoppingCriteriaList(
41
  [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
42
  )