Akjava's picture
Update app.py
d2b26cf verified
raw
history blame
2.25 kB
import spaces
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import TextIteratorStreamer
from threading import Thread
import gradio as gr
# Module-level placeholders; nothing in this file reassigns or reads them.
text_generator = None
histories = []

# Checkpoint served by this app. The commented id is another checkpoint kept
# for reference.
# model_id = "AXCXEPT/phi-4-deepseek-R1K-RL-EZO"
model_id = "AXCXEPT/phi-4-open-R1-Distill-EZOv1"

# Access token read from the environment; None when the variable is unset.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if not huggingface_token:
    # A missing token is only reported; raising here was deliberately disabled.
    print("no HUGGINGFACE_TOKEN if you need set secret ")
    # raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

# Placement and precision used when loading the model and moving inputs.
# ("auto" / a torch.device chosen via torch.cuda.is_available() were earlier
# alternatives; the effective value is the fixed string below.)
device = "cuda"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
# Debug helpers for checking the special token IDs:
# print(tokenizer.special_tokens_map)
# print(tokenizer.eos_token_id)
# print(tokenizer.encode("<|im_end|>", add_special_tokens=False))
# print(model_id, device, dtype)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=huggingface_token,
    torch_dtype=dtype,
    device_map=device,
)
# Explicit move kept from the original even though device_map already names
# the same device.
model.to(device)
def generate_text(messages):
    """Stream the model's reply to a chat conversation.

    The conversation is rendered with the tokenizer's chat template,
    tokenized, and handed to ``model.generate`` on a background thread while
    this generator consumes the attached ``TextIteratorStreamer``.

    Args:
        messages: Conversation passed to ``tokenizer.apply_chat_template``;
            the caller in this file supplies a list of
            ``{"role": ..., "content": ...}`` dicts.

    Yields:
        str: The reply accumulated so far (cumulative, not the delta), with
        any literal ``<|im_end|>`` marker removed from each streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here once the stream has been drained.
    """
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=1000)

    worker_errors = []

    def _run_generation():
        # Runs on the worker thread. A failure must still terminate the
        # stream, otherwise the consumer loop below would block forever
        # waiting for a stop signal that never arrives.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    generated_output = ""
    for new_text in streamer:
        generated_output += new_text.replace("<|im_end|>", "")  # just replace
        yield generated_output

    # The stop signal has been received, so the worker is finishing; join it
    # and surface any failure on the consumer side where callers can catch it.
    thread.join()
    if worker_errors:
        raise worker_errors[0]
# SDK version is very important in README.md
@spaces.GPU(duration=120)
def call_generate_text(message, history):
    """Chat callback: stream the model's answer to ``message``.

    Args:
        message: The new user message text.
        history: Prior conversation as a list of message dicts; the new user
            turn is appended to a copy, leaving ``history`` itself unmodified.

    Yields:
        str: Each cumulative reply produced by ``generate_text``. If a
        ``RuntimeError`` escapes the generation, it is printed and a single
        empty string is yielded instead.
    """
    user_turn = {"role": "user", "content": message}
    conversation = history + [user_turn]
    try:
        yield from generate_text(conversation)
    except RuntimeError as e:
        print(f"An unexpected error occurred: {e}")
        yield ""
# Chat UI wired to the streaming callback; history is exchanged in the
# "messages" format (a list of role/content dicts).
demo = gr.ChatInterface(fn=call_generate_text, type="messages")

if __name__ == "__main__":
    # Enable the request queue, then start the web server.
    demo.queue().launch()