AminFaraji commited on
Commit
078dec5
·
verified ·
1 Parent(s): 7d2deb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -21
app.py CHANGED
@@ -124,27 +124,7 @@ Current conversation:
124
  Human: Who is Dwight K Schrute?
125
  AI:
126
  """.strip()
127
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
128
- input_ids = input_ids.to(model.device)
129
-
130
-
131
-
132
- class StopGenerationCriteria(StoppingCriteria):
133
- def __init__(
134
- self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
135
- ):
136
- stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
137
- self.stop_token_ids = [
138
- torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
139
- ]
140
-
141
- def __call__(
142
- self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
143
- ) -> bool:
144
- for stop_ids in self.stop_token_ids:
145
- if torch.eq(input_ids[0][-len(stop_ids) :], stop_ids).all():
146
- return True
147
- return False
148
 
149
 
150
  stop_tokens = [["Human", ":"], ["AI", ":"]]
 
124
  Human: Who is Dwight K Schrute?
125
  AI:
126
  """.strip()
127
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
 
130
  stop_tokens = [["Human", ":"], ["AI", ":"]]