# Hugging Face Space page header (scraped): "Running on Zero" (ZeroGPU hardware).
# File size: 3,074 bytes.
# Per-line commit annotations: 7c7f18b 8b23e9c 7c7f18b 8b23e9c 0aa17c3 8b23e9c 0aa17c3 7c7f18b 8b23e9c 7c7f18b 0aa17c3 7c7f18b 0aa17c3 8b23e9c f427a25 7c7f18b 0aa17c3 7c7f18b 0aa17c3 8b23e9c 7c7f18b 8b23e9c 7c7f18b 8b23e9c 7c7f18b 0aa17c3 7c7f18b 8b23e9c 0aa17c3
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Upper bound on tokens produced per reply, and the checkpoint served by this demo.
MAX_NEW_TOKENS = 8192
MODEL_NAME = "Azure99/Blossom-V6.1-8B"

# Weights are loaded in bfloat16; device_map="auto" lets transformers decide
# where to place them on the available hardware.
_model_load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_model_load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def get_input_ids(inst, history):
    """Tokenize the prior conversation plus a new user message for generation.

    Args:
        inst: The new user message.
        history: Prior turns; each item is unpacked as a ``(user, assistant)``
            pair (tuple-style chat history).

    Returns:
        The prompt token ids produced by ``tokenizer.apply_chat_template``
        (a PyTorch tensor, since ``return_tensors="pt"``), moved to
        ``model.device``.
    """
    conversation = []
    for user, assistant in history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": inst})
    # add_generation_prompt=True appends the template's assistant-turn opener
    # (if the template defines one), so the model replies as the assistant
    # instead of having to emit the turn header itself. Templates that do not
    # use the flag render exactly as before.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    return input_ids.to(model.device)
@spaces.GPU(duration=60)
def chat(inst, history, temperature, top_p, repetition_penalty):
    """Stream the model's reply to ``inst`` given the prior ``history``.

    Generation runs in a background thread and feeds a ``TextIteratorStreamer``.
    This generator drains that streamer and yields the growing reply.

    Args:
        inst: The new user message.
        history: Prior conversation turns, forwarded to ``get_input_ids``.
        temperature: Sampling temperature. Values <= 0 select greedy decoding,
            because transformers rejects a non-positive temperature when
            ``do_sample=True``.
        top_p: Nucleus-sampling cutoff (only used when sampling).
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The reply accumulated so far, one string per streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been closed.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=get_input_ids(inst, history),
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        repetition_penalty=repetition_penalty,
    )
    # The UI slider can deliver 0 (possibly as int 0). Sampling with a
    # temperature of 0 raises ValueError inside generate(), so use greedy
    # decoding for that case instead.
    temperature = float(temperature)
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        generation_kwargs["do_sample"] = False

    errors = []

    def _generate():
        """Run generation; on failure record the error and close the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised in the caller below
            errors.append(exc)
            # generate() only ends the streamer on success. Without this, the
            # consumer loop below would block forever waiting for more text.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()
    outputs = ""
    for new_text in streamer:
        outputs += new_text
        yield outputs
    worker.join()
    if errors:
        raise errors[0]
# Sampling controls shown in the "Config" accordion. Gradio passes their values
# to `chat` after (message, history), in the order listed here:
# temperature, top_p, repetition_penalty.
additional_inputs = [
    gr.Slider(
        label=label,
        value=default,
        minimum=low,
        maximum=high,
        step=step,
        interactive=True,
        info=info,
    )
    # Each spec: (label, default, minimum, maximum, step, help text).
    for label, default, low, high, step, info in (
        (
            "Temperature",
            0.5,
            0.0,
            1.0,
            0.05,
            "Controls randomness in choosing words.",
        ),
        (
            "Top-P",
            0.85,
            0.0,
            1.0,
            0.05,
            "Picks words until their combined probability is at least top_p.",
        ),
        (
            "Repetition penalty",
            1.05,
            1.0,
            1.2,
            0.01,
            "Repetition Penalty: Controls how much repetition is penalized.",
        ),
    )
]
# Assemble the demo UI: a chat window driven by the streaming `chat` generator,
# with the sampling sliders from `additional_inputs` in a "Config" accordion.
chat_window = gr.Chatbot(
    show_label=False, height=500, show_copy_button=True, render_markdown=True
)
input_box = gr.Textbox(placeholder="", container=False, scale=7)
demo = gr.ChatInterface(
    chat,
    chatbot=chat_window,
    textbox=input_box,
    title="Blossom-V6.1-8B Demo",
    description=(
        "Hello, I am Blossom, an open source conversational large language model.🌠"
        '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>'
    ),
    theme="soft",
    # Each example is a single-element list: the prefilled user message.
    examples=[
        [prompt]
        for prompt in (
            "Hello",
            "What is MBTI",
            "用Python实现二分查找",
            "为switch写一篇小红书种草文案,带上emoji",
        )
    ],
    cache_examples=False,
    additional_inputs=additional_inputs,
    additional_inputs_accordion=gr.Accordion(label="Config", open=True),
    clear_btn="🗑️Clear",
    undo_btn="↩️Undo",
    retry_btn="🔄Retry",
    submit_btn="➡️Submit",
)
# Enable request queuing, then start the server (by default launch() blocks
# the main thread while the app is running).
demo.queue().launch()
|