Update app.py
app.py CHANGED
@@ -241,7 +241,6 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
 
     conversation = [{"role": "user", "content": input_text}]
     input_ids = tokenizer(conversation[-1]['content'], return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         max_length=max_length,
@@ -250,30 +249,16 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
         top_k=top_k,
         temperature=temperature,
         repetition_penalty=penalty,
-        eos_token_id=[151329, 151336, 151338],
-        streamer=streamer
+        eos_token_id=[151329, 151336, 151338]
     )
 
-
-
-    def generate_text():
-        with torch.no_grad():
-            model.generate(input_ids['input_ids'], **generate_kwargs)
-
-    # Run generation in a separate thread
-    thread = Thread(target=generate_text)
-    thread.start()
-
-    # Collect generated text in real-time
-    for new_text in streamer:
-        buffer += new_text
-
-    # Wait for the generation thread to finish
-    thread.join()
+    with torch.no_grad():
+        generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
+        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
     # Process to remove any prefix or unwanted prompt
     text_original = input_text.strip()
-    results_text = buffer[len(text_original):].strip()
+    results_text = generated_text[len(text_original):].strip()
 
     print(" ")
     print("------")
@@ -289,7 +274,6 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
 
 
 
-
 # @spaces.GPU()
 # def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
 #     try: