import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

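# Gradio Space for BatsResearch/bonito-v1 ("Bonito"): given a passage of
# unannotated text and a task type, the model generates a synthetic
# instruction-tuning task grounded in that passage.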
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# client = InferenceClient("BatsResearch/bonito-v1")
model = AutoModelForCausalLM.from_pretrained("BatsResearch/bonito-v1")
tokenizer = AutoTokenizer.from_pretrained("BatsResearch/bonito-v1")
model.to("cuda")

# `@spaces.GPU` requests a GPU for the duration of the call on ZeroGPU Spaces.
@spaces.GPU
def respond(
    message,
    task_type,
    max_tokens,
    temperature,
    top_p,
):
    # Bonito's prompt format:
    # <|tasktype|>\n{task type}\n<|context|>\n{unannotated text}\n<|task|>\n
    task_type = task_type.lower()
    input_text = "<|tasktype|>\n" + task_type.strip()
    input_text += "\n<|context|>\n" + message.strip() + "\n<|task|>\n"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cuda")
    output = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(
        output[0][input_ids.shape[1] :], skip_special_tokens=True
    )

    return response

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# demo = gr.ChatInterface(
#     respond,
#     additional_inputs=[
#         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
#         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
#         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
#         gr.Slider(
#             minimum=0.1,
#             maximum=1.0,
#             value=0.95,
#             step=0.05,
#             label="Top-p (nucleus sampling)",
#         ),
#     ],
# )
task_types = [
    "extractive question answering",
    "multiple-choice question answering",
    "question generation",
    "question answering without choices",
    "yes-no question answering",
    "coreference resolution",
    "paraphrase generation",
    "paraphrase identification",
    "sentence completion",
    "sentiment",
    "summarization",
    "text generation",
    "topic classification",
    "word sense disambiguation",
    "textual entailment",
    "natural language inference",
]
# capitalize for better readability
task_types = [task_type.capitalize() for task_type in task_types]

demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(
            label="Context",
            lines=5,
            placeholder="Paste the unannotated text to generate a task from",
        ),
        gr.Dropdown(task_types, label="Task type"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    outputs=gr.Textbox(label="Response"),
    title="Bonito Task Generation",
    description="Generate a synthetic instruction-tuning task from unannotated text with BatsResearch/bonito-v1.",
)


if __name__ == "__main__":
    demo.launch()