File size: 4,234 Bytes
d189df6 13d0b4d 722177d d189df6 de1935d d189df6 921a7bd d189df6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
import json
import subprocess
from threading import Thread
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
MODEL_ID = "infly/OpenCoder-8B-Instruct"
CHAT_TEMPLATE = "ChatML"
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = 1300
#EMOJI = os.environ.get("EMOJI")
#DESCRIPTION = os.environ.get("DESCRIPTION")
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
# Format history with a given chat template
if CHAT_TEMPLATE == "ChatML":
stop_tokens = ["<|endoftext|>", "<|im_end|>", "<|end_of_text|>", "<|eot_id|>", "assistant"]
instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
for human, assistant in history:
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
elif CHAT_TEMPLATE == "Mistral Instruct":
stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
instruction = '<s>[INST] ' + system_prompt
for human, assistant in history:
instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
instruction += ' ' + message + ' [/INST]'
else:
raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
print(instruction)
streamer = TextIteratorStreamer(tokenizer, timeout=90.0, skip_prompt=True, skip_special_tokens=True)
enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True, max_length=CONTEXT_LENGTH)
input_ids, attention_mask = enc.input_ids, enc.attention_mask
if input_ids.shape[1] > CONTEXT_LENGTH:
input_ids = input_ids[:, -CONTEXT_LENGTH:]
generate_kwargs = dict(
{"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
streamer=streamer,
do_sample=True,
temperature=temperature,
max_new_tokens=max_new_tokens,
top_k=top_k,
repetition_penalty=repetition_penalty,
top_p=top_p
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for new_token in streamer:
outputs.append(new_token)
if new_token in stop_tokens:
break
yield "".join(outputs)
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
trust_remote_code=True
)
css = """
.message-row {
justify-content: space-evenly !important;
}
.message-bubble-border {
border-radius: 6px !important;
}
.message-buttons-bot, .message-buttons-user {
right: 10px !important;
left: auto !important;
bottom: 2px !important;
}
.dark.message-bubble-border {
border-color: #15172c !important;
}
.dark.user {
background: #10132c !important;
}
.dark.assistant.dark, .dark.pending.dark {
background: #020417 !important;
}
"""
# Create Gradio interface
gr.ChatInterface(
predict,
title=EMOJI + " " + MODEL_NAME,
description=DESCRIPTION,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
additional_inputs=[
gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
gr.Slider(0, 1, 0.8, label="Temperature"),
gr.Slider(128, 4096, 512, label="Max new tokens"),
gr.Slider(1, 80, 40, label="Top K sampling"),
gr.Slider(0, 2, 1.1, label="Repetition penalty"),
gr.Slider(0, 1, 0.95, label="Top P sampling"),
],
theme = gr.themes.Ocean(
secondary_hue="emerald"
),
css=css,
retry_btn="Retry",
undo_btn="Undo",
clear_btn="Clear",
submit_btn="Send",
chatbot=gr.Chatbot(
scale=1,
show_copy_button=True
)
).queue().launch() |