import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import traceback
model_path = 'infly/OpenCoder-8B-Instruct'
# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
# Use CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
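# A minimal alternative for memory-constrained setups (an assumption, not used
# here): loading directly onto the available device(s) with `device_map` avoids
# materializing the full model on CPU first. Requires the `accelerate` package.
# model = AutoModelForCausalLM.from_pretrained(
#     model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto'
# )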
# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [96539]  # IDs of tokens where the generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return True
        return False
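# 96539 is presumably the ID of a model-specific stop token. A sketch of a
# tokenizer-derived alternative (assumption: the intended stop token is the
# sft_end_token defined below):
# stop_ids = [tokenizer.convert_tokens_to_ids(sft_end_token)]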
system_role = 'system'
user_role = 'user'
assistant_role = "assistant"
sft_start_token = "<|im_start|>"
sft_end_token = "<|im_end|>"
ct_end_token = "<|endoftext|>"
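# These are ChatML-style delimiters; tokenizer.apply_chat_template inserts them
# around each message, so they do not need to be added by hand here.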
# system_prompt= 'You are a CodeLLM developed by INF.'
# Function to generate model predictions.
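# @spaces.GPU requests a GPU from Hugging Face ZeroGPU only for the duration of
# this call; outside of Spaces the decorator has no effect.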
@spaces.GPU()
def predict(message, history):
    try:
        stop = StopOnTokens()
        model_messages = []
        # print(f'history: {history}')
        # Rebuild the conversation from Gradio's (user, assistant) history pairs.
        for item in history:
            model_messages.append({"role": user_role, "content": item[0]})
            model_messages.append({"role": assistant_role, "content": item[1]})
        model_messages.append({"role": user_role, "content": message})
        print(f'model_messages: {model_messages}')
        # print(f'model_final_inputs: {tokenizer.apply_chat_template(model_messages, add_generation_prompt=True, tokenize=False)}', flush=True)
        model_inputs = tokenizer.apply_chat_template(model_messages, add_generation_prompt=True, return_tensors="pt").to(device)
        # model_inputs = tokenizer([messages], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=model_inputs,
            streamer=streamer,
            max_new_tokens=1024,
            do_sample=False,
            stopping_criteria=StoppingCriteriaList([stop])
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()  # Starting the generation in a separate thread.
        partial_message = ""
        for new_token in streamer:
            partial_message += new_token
            if sft_end_token in partial_message:  # Stop streaming once the end token appears.
                break
            yield partial_message
    except Exception:
        print(traceback.format_exc())
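# For reference, a minimal non-streaming variant (a hypothetical helper, not used
# by the Gradio interface below; same chat template and greedy decoding as predict):
def generate_once(message):
    messages = [{"role": user_role, "content": message}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    output_ids = model.generate(input_ids, max_new_tokens=1024, do_sample=False)
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)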
css = """
full-height {
height: 100%;
}
"""
prompt_examples = [
    'Write a quick sort algorithm in Python.',
    'Write a colorful Snake game using pygame.',
    'How to use numpy?'
]
placeholder = """
<div style="opacity: 0.5;">
<img src="https://raw.githubusercontent.com/sail-sg/sailor-llm/main/misc/banner.jpg" style="width:30%;">
<br>Sailor models are designed to understand and generate text across diverse linguistic landscapes of these SEA regions:
<br>🇮🇩Indonesian, 🇹🇭Thai, 🇻🇳Vietnamese, 🇲🇾Malay, and 🇱🇦Lao.
</div>
"""
chatbot = gr.Chatbot(label='OpenCoder', placeholder=placeholder)

with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    gr.ChatInterface(predict, chatbot=chatbot, fill_height=True, examples=prompt_examples, css=css)

demo.launch()  # Launching the web interface.
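# On older Gradio versions, streaming generator callbacks require an explicit
# queue, e.g. demo.queue().launch(); recent versions enable queueing by default.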