Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -2,6 +2,7 @@ import spaces
 import os
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import TextStreamer
 import gradio as gr
 
 text_generator = None
@@ -36,7 +37,7 @@ if not is_hugging_face:
     model = AutoModelForCausalLM.from_pretrained(
         model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
     )
-    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device ) #pipeline has not to(device)
+    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device,stream=True ) #pipeline has not to(device)
 
     if next(model.parameters()).is_cuda:
         print("The model is on a GPU")
@@ -57,9 +58,10 @@ def generate_text(messages):
     model = AutoModelForCausalLM.from_pretrained(
         model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
     )
-
+    streamer = TextStreamer(tokenizer, skip_prompt=True)
+    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device ,streamer=streamer) #pipeline has not to(device)
     result = text_generator(messages, max_new_tokens=256, do_sample=True, temperature=0.7)
-
+    print(result)
     generated_output = ""
     for token in result:
         generated_output += token["generated_token"]
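Note: two details in this commit are likely behind the "Runtime error" badge. First, stream=True is not a parameter of pipeline(); unrecognized keyword arguments are forwarded to model.generate(), which rejects unknown kwargs at call time. Second, the text-generation pipeline returns a list of dicts keyed "generated_text"; there is no "generated_token" key, so the final loop raises a KeyError. TextStreamer also only prints tokens to stdout, so it cannot feed a Gradio UI by itself.

Below is a minimal sketch of a working streaming setup, not the Space's actual code. It assumes model_id, huggingface_token, dtype, and device are defined as elsewhere in app.py, and it swaps TextStreamer for TextIteratorStreamer, which exposes generated text as an iterator that a Gradio generator callback can consume.

from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

# model_id, huggingface_token, dtype, and device are assumed from app.py.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
)
# device_map already places the model; the pipeline has no .to(device).
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)


def generate_text(messages):
    # One streamer per call: TextIteratorStreamer is not reusable across runs.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # Run generation in a background thread so this function can drain the
    # streamer; generate kwargs pass through the pipeline call.
    thread = Thread(
        target=text_generator,
        args=(messages,),
        kwargs=dict(
            max_new_tokens=256, do_sample=True, temperature=0.7, streamer=streamer
        ),
    )
    thread.start()
    generated_output = ""
    for chunk in streamer:  # yields decoded text pieces, not dicts
        generated_output += chunk
        yield generated_output
    thread.join()

If streaming is dropped instead, the non-streaming result should be read as result[0]["generated_text"] rather than token["generated_token"].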