from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_NEW_TOKENS = 8192
MODEL_NAME = "Azure99/Blossom-V6.1-8B"

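# Load the model in bfloat16; device_map="auto" places it on available devices.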
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


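# Convert the Gradio (user, assistant) history plus the new message into input
# ids via the model's chat template.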
def get_input_ids(inst, history):
    conversation = []
    for user, assistant in history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": inst})
    return tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)


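# Runs on a ZeroGPU worker for up to 120 seconds per call. Generation happens in
# a background thread while TextIteratorStreamer yields partial text to the UI.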
@spaces.GPU(duration=120)
def chat(inst, history, temperature, top_p, repetition_penalty):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    input_ids = get_input_ids(inst, history)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        do_sample=True,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

    Thread(target=model.generate, kwargs=generation_kwargs).start()

    outputs = ""
    for new_text in streamer:
        outputs += new_text
        yield outputs


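# Sampling controls shown in the "Config" accordion below the chat box.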
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.5,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Controls randomness in choosing words.",
    ),
    gr.Slider(
        label="Top-P",
        value=0.85,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Picks words until their combined probability is at least top_p.",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.05,
        minimum=1.0,
        maximum=1.2,
        step=0.01,
        interactive=True,
        info="Repetition Penalty: Controls how much repetition is penalized.",
    ),
]

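# Assemble the chat UI; queue() serializes incoming GPU calls before launch.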
gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(
        show_label=False, height=500, show_copy_button=True, render_markdown=True
    ),
    textbox=gr.Textbox(placeholder="", container=False, scale=7),
    title="Blossom-V6.1-8B Demo",
    description="Hello, I am Blossom, an open source conversational large language model.🌠"
    '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
    theme="soft",
    examples=[
        ["Hello"],
        ["What is MBTI"],
        ["用Python实现二分查找"],
        ["为switch写一篇小红书种草文案,带上emoji"],
    ],
    cache_examples=False,
    additional_inputs=additional_inputs,
    additional_inputs_accordion=gr.Accordion(label="Config", open=True),
    clear_btn="🗑️Clear",
    undo_btn="↩️Undo",
    retry_btn="🔄Retry",
    submit_btn="➡️Submit",
).queue().launch()