# bonito / app.py
# Commit 4d92f8a by Nihal Nayak: "instruction response pair" (file size 2.91 kB).
import gradio as gr
import spaces
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
model = AutoModelForCausalLM.from_pretrained("BatsResearch/bonito-v1")
tokenizer = AutoTokenizer.from_pretrained("BatsResearch/bonito-v1")
model.to("cuda")
@spaces.GPU
def respond(
    message,
    task_type,
    max_tokens,
    temperature,
    top_p,
):
    """Generate one instruction/response pair from a context passage with Bonito.

    The prompt sent to the model is::

        <|tasktype|>
        {task type, lowercased}
        <|context|>
        {context passage}
        <|task|>

    A continuation is sampled and decoded. It is then split on ``<|pipe|>`` into
    the instruction, with its ``{{context}}`` placeholder replaced by
    ``message``, and the response.

    Args:
        message: Context passage the task is generated from.
        task_type: Task type chosen in the dropdown; lowercased before use.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        ``(instruction, response)``. If the generation contains no
        ``<|pipe|>`` separator, the whole generation is returned as the
        instruction and ``response`` asks the user to regenerate.

    Raises:
        gr.Error: If no task type was selected.
    """
    if not task_type:
        # The dropdown can deliver None when nothing is selected; show a
        # readable UI error instead of crashing on `.lower()`.
        raise gr.Error("Please select a task type.")

    task_type = task_type.lower()
    input_text = (
        "<|tasktype|>\n"
        + task_type.strip()
        + "\n<|context|>\n"
        + message.strip()
        + "\n<|task|>\n"
    )
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cuda")
    output = model.generate(
        input_ids,
        # Slider values may arrive as floats; generate expects an int here.
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    pred_start = int(input_ids.shape[-1])
    # NOTE(review): the split below relies on "<|pipe|>" surviving
    # skip_special_tokens=True -- confirm it is not registered as a special
    # token in the bonito-v1 tokenizer.
    generation = tokenizer.decode(output[0][pred_start:], skip_special_tokens=True)

    if "<|pipe|>" in generation:
        pair = generation.split("<|pipe|>")
        instruction = pair[0].strip().replace("{{context}}", message)
        response = pair[1].strip()
    else:
        # No separator was generated: show the raw generation as the
        # instruction so the user can see what came out, and ask for a retry.
        instruction = generation.strip().replace("{{context}}", message)
        response = "Unable to generate response. Please regenerate."
    return instruction, response
# Task types offered in the "Task type" dropdown. The source names are
# lowercase (the form that ends up in the prompt, since `respond` lowercases
# the selection); each is capitalized here purely for display.
task_types = [
    label.capitalize()
    for label in (
        "extractive question answering",
        "multiple-choice question answering",
        "question generation",
        "question answering without choices",
        "yes-no question answering",
        "coreference resolution",
        "paraphrase generation",
        "paraphrase identification",
        "sentence completion",
        "sentiment",
        "summarization",
        "text generation",
        "topic classification",
        "word sense disambiguation",
        "textual entailment",
        "natural language inference",
    )
]
# Gradio UI. `respond` receives the main inputs first (context, task type) and
# then the additional inputs in the order listed, matching its signature
# (message, task_type, max_tokens, temperature, top_p).
demo = gr.Interface(
    fn=respond,
    inputs=[
        # The passage Bonito builds a task from; it is inserted after the
        # <|context|> marker in the prompt.
        gr.Textbox(
            label="Context",
            placeholder="Paste a passage of text to generate a task from.",
            lines=6,
        ),
        # Pre-select a task type so `respond` never receives an empty choice.
        gr.Dropdown(task_types, value=task_types[0], label="Task type"),
    ],
    outputs=[gr.Textbox(label="Input"), gr.Textbox(label="Output")],
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Bonito",
    description=(
        "Generate an instruction/response pair for the selected task type "
        "from a context passage using the BatsResearch/bonito-v1 model."
    ),
)

if __name__ == "__main__":
    demo.launch()