Akjava's picture
Update app.py
3ab88cc verified
raw
history blame
2.3 kB
import spaces
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import TextIteratorStreamer
from threading import Thread
import gradio as gr
# Module-level setup: choose the checkpoint, read the optional Hub token and
# load the tokenizer and model once at import time.

# Not referenced anywhere in the visible code; kept so that any external
# reference to this module-level name still resolves.
text_generator = None

# Checkpoint served by this app.
model_id = "AXCXEPT/phi-4-deepseek-R1K-RL-EZO"
# Alternative checkpoint; the author noted it did not work well with this code:
# model_id = "AXCXEPT/phi-4-open-R1-Distill-EZOv1"

# Optional Hugging Face Hub access token taken from the environment.
# An unset or empty variable is normalised to None (anonymous access).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or None

# The model and the tokenized inputs are both placed on this device.
device = "cuda"
dtype = torch.bfloat16

if not huggingface_token:
    # Not fatal: only a hint for the operator in case the checkpoint turns
    # out to require authentication.
    print("no HUGGINGFACE_TOKEN if you need set secret ")

tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)

# Debug helpers for checking the special token IDs:
# print(tokenizer.special_tokens_map)
# print(tokenizer.eos_token_id)
# print(tokenizer.encode("<|im_end|>", add_special_tokens=False))
# print(model_id, device, dtype)

# Not referenced anywhere in the visible code; kept for the same reason as
# text_generator above.
histories = []

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=huggingface_token,
    torch_dtype=dtype,
    device_map=device,
)
# NOTE(review): device_map already requests placement on `device`; this
# explicit move is kept from the original and looks redundant -- confirm
# before removing.
model.to(device)
def generate_text(messages):
    """Stream the model's reply to a chat conversation.

    Generation runs in a background thread and feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer.

    Args:
        messages: Conversation handed to ``tokenizer.apply_chat_template``.

    Yields:
        str: The reply accumulated so far (not just the newest chunk), with
        every "<|im_end|>" marker removed.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been drained.
    """
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered chat template already contains the special tokens it
    # needs, so the tokenizer must not add another set on top of them.
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000)

    worker_errors = []

    def _run_generation():
        # Runs in the worker thread. If generation fails, record the error
        # and end the stream so the consumer loop below stops waiting
        # instead of blocking forever.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    generated_output = ""
    for new_text in streamer:
        # The end-of-turn marker can show up in the decoded text; drop it.
        generated_output += new_text.replace("<|im_end|>", "")
        yield generated_output

    # The stream has ended, so the worker is finished or about to finish.
    thread.join()
    if worker_errors:
        # Surface the worker's failure to the caller of this generator.
        raise worker_errors[0]
# SDK version is very important in README.md
@spaces.GPU(duration=120)
def call_generate_text(message, history):
    """Chat callback: stream the model's answer to ``message``.

    NOTE(review): ``spaces.GPU(duration=120)`` presumably requests a GPU
    slot for up to 120 seconds per call -- confirm against the ``spaces``
    package documentation.

    Args:
        message: Text of the new user turn.
        history: Earlier turns; the new user turn is appended to a copy of
            it, so it must support ``history + [dict]``.

    Yields:
        str: Each value produced by ``generate_text`` for the extended
        conversation, or a single empty string if a ``RuntimeError`` is
        raised while streaming.
    """
    user_turn = {"role": "user", "content": message}
    conversation = history + [user_turn]
    try:
        yield from generate_text(conversation)
    except RuntimeError as err:
        # Log the failure and finish the stream with an empty reply.
        print(f"An unexpected error occurred: {err}")
        yield ""
# Chat UI wired to the streaming callback; history is exchanged in the
# "messages" format (role/content dicts).
demo = gr.ChatInterface(fn=call_generate_text, type="messages")

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.queue()
    demo.launch()