import spaces
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import TextIteratorStreamer
from threading import Thread
import gradio as gr
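# ZeroGPU Spaces cannot keep a model resident between requests, so when
# is_hugging_face is True the model is re-created inside the @spaces.GPU
# function below instead of once at import time.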
text_generator = None
is_hugging_face = True
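# google/gemma-2-9b-it is a gated model, so downloading it requires an access token.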
model_id = "google/gemma-2-9b-it"
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
device = "auto" # torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"
dtype = torch.bfloat16
if not huggingface_token:
    print("HUGGINGFACE_TOKEN is not set; add it as a secret if the model requires authentication")
    # raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
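# The tokenizer is lightweight and CPU-only, so it is safe to load once at import time.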
tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
print(model_id, device, dtype)
histories = []
#model = None
if not is_hugging_face:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
    )
    # pipeline objects have no .to(device); placement is handled by device_map
    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device)

    if next(model.parameters()).is_cuda:
        print("The model is on a GPU")
    else:
        print("The model is on a CPU")

    # pipeline.device is a torch.device such as cuda:0, so compare its type rather than its string form
    if text_generator.device.type == "cuda":
        print("The pipeline is using a GPU")
    else:
        print("The pipeline is using a CPU")
print("initialized")
@spaces.GPU(duration=30)
def generate_text(messages):
    if is_hugging_face:  # ZeroGPU frees the GPU after each call, so the model must be re-initialized every time
        model = AutoModelForCausalLM.from_pretrained(
            model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
        )
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device)
    else:
        generator = text_generator  # reuse the pipeline built at import time

    # TextIteratorStreamer exposes generated text as a Python iterator, so
    # generation runs in a background thread while this function yields the
    # progressively growing output back to Gradio.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=generator,
        args=(messages,),
        kwargs=dict(max_new_tokens=256, do_sample=True, temperature=0.7, streamer=streamer),
    )
    thread.start()

    generated_output = ""
    for new_text in streamer:
        generated_output += new_text
        yield generated_output
    thread.join()
    print(generated_output)
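# Bridges Gradio's ChatInterface callback signature (message, history) to generate_text.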
def call_generate_text(message, history):
    print(message)
    print(history)
    messages = history + [{"role": "user", "content": message}]
try:
for text in generate_text(messages):
yield text
except RuntimeError as e:
print(f"An unexpected error occurred: {e}")
yield ""
demo = gr.ChatInterface(call_generate_text, type="messages")
if __name__ == "__main__":
demo.launch(share=True)