import subprocess
from threading import Thread
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
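# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# compiling the CUDA kernels during install (a common Hugging Face Spaces workaround).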
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
MODEL_ID = "infly/OpenCoder-8B-Instruct"
CHAT_TEMPLATE = "ChatML"
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = 1300  # maximum number of prompt tokens kept when tokenizing
DESCRIPTION = "Infly OpenCoder-8B-Instruct"
DEFAULT_SYSTEM_PROMPT = "Perform the task to the best of your ability."

@spaces.GPU()  # allocate ZeroGPU hardware for the duration of each call
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Format the history with the selected chat template
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>", "<|end_of_text|>", "<|eot_id|>", "assistant"]
        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
        # Close every assistant turn with <|im_end|> so past replies are delimited
        for human, assistant in history:
            instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant + '\n<|im_end|>\n'
        instruction += '<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = '<s>[INST] ' + system_prompt
        for human, assistant in history:
            instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
        instruction += ' ' + message + ' [/INST]'
    else:
        raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, timeout=90.0, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    # Keep only the most recent CONTEXT_LENGTH tokens so the newest turns survive
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    # Generate on a background thread so tokens can be streamed as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        # skip_special_tokens strips most stop tokens from the stream; this check
        # mainly catches the literal "assistant" marker before it reaches the user
        if new_token in stop_tokens:
            break
        yield "".join(outputs)

def handle_retry(history, retry_data: gr.RetryData):
    # Drop the old reply and re-run generation for the retried message. Assumes
    # the tuple-style history used elsewhere in this app; the sliders are not
    # available in this event, so the interface defaults are reused below.
    new_history = history[:retry_data.index]
    previous_prompt = history[retry_data.index][0]
    for partial in predict(previous_prompt, new_history, DEFAULT_SYSTEM_PROMPT, 0.8, 512, 40, 1.1, 0.95):
        yield new_history + [[previous_prompt, partial]]

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # assumption: bf16 halves memory vs. the fp32 default
    trust_remote_code=True
)
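
# Some causal-LM tokenizers ship without a pad token, which would make the
# padding=True call in predict() raise; falling back to the EOS token is a
# common workaround (an assumption about this particular checkpoint).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token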

css = """
.message-row {
    justify-content: space-evenly !important;
}
.message-bubble-border {
    border-radius: 6px !important;
}
.message-buttons-bot, .message-buttons-user {
    right: 10px !important;
    left: auto !important;
    bottom: 2px !important;
}
.dark.message-bubble-border {
    border-color: #15172c !important;
}
.dark.user {
    background: #10132c !important;
}
.dark.assistant.dark, .dark.pending.dark {
    background: #020417 !important;
}
"""

# Create Gradio interface
chatbot = gr.Chatbot(scale=1, show_copy_button=True)

demo = gr.ChatInterface(
    predict,
    title="Infly " + MODEL_NAME,
    description=DESCRIPTION,
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox(DEFAULT_SYSTEM_PROMPT, label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 512, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Ocean(secondary_hue="emerald"),
    css=css,
    chatbot=chatbot,
)

# Wire the custom retry handler; event listeners must be attached inside a
# Blocks context, so the ChatInterface (itself a Blocks) is re-entered here.
with demo:
    chatbot.retry(handle_retry, chatbot, [chatbot])

demo.queue().launch()