Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -124,27 +124,7 @@ Current conversation:
|
|
124 |
Human: Who is Dwight K Schrute?
|
125 |
AI:
|
126 |
""".strip()
|
127 |
-
|
128 |
-
input_ids = input_ids.to(model.device)
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
class StopGenerationCriteria(StoppingCriteria):
|
133 |
-
def __init__(
|
134 |
-
self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
|
135 |
-
):
|
136 |
-
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
|
137 |
-
self.stop_token_ids = [
|
138 |
-
torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
|
139 |
-
]
|
140 |
-
|
141 |
-
def __call__(
|
142 |
-
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
143 |
-
) -> bool:
|
144 |
-
for stop_ids in self.stop_token_ids:
|
145 |
-
if torch.eq(input_ids[0][-len(stop_ids) :], stop_ids).all():
|
146 |
-
return True
|
147 |
-
return False
|
148 |
|
149 |
|
150 |
stop_tokens = [["Human", ":"], ["AI", ":"]]
|
|
|
124 |
Human: Who is Dwight K Schrute?
|
125 |
AI:
|
126 |
""".strip()
|
127 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
|
130 |
stop_tokens = [["Human", ":"], ["AI", ":"]]
|