# HuggingFace Space: GODEL dialog-generation demo.
# NOTE(review): the original first lines here were page-scrape residue
# ("Spaces:" / "Build error" banner text), not Python — converted to comments
# so the file parses.
import gradio as gr
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
# Load the GODEL v1.1 base checkpoint once at module import; it is a
# seq2seq (encoder-decoder) model, hence AutoModelForSeq2SeqLM.
tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")
# Preset demo inputs, selected via the "Examples" dropdown.
# Each tuple is (instruction, knowledge, query, instruction_type) and is
# unpacked in that order by change_example().
preset_examples = [
    ('Instruction: given a dialog context, you need to response empathically.',
     '', 'Does money buy happiness?', 'Chitchat'),
]
def generate(instruction, knowledge, dialog, top_p, min_length, max_length):
    """Build a GODEL prompt and sample a response from the model.

    Args:
        instruction: Task instruction string placed at the front of the prompt.
        knowledge: Optional grounding text; prefixed with '[KNOWLEDGE] ' when
            non-empty, otherwise contributes nothing but a trailing space.
        dialog: Iterable of dialog turns, joined with ' EOS ' (GODEL's turn
            separator).
        top_p: Nucleus-sampling threshold forwarded to model.generate.
        min_length: Minimum decode length; cast to int because Gradio Number
            widgets deliver floats.
        max_length: Maximum decode length; cast to int for the same reason.

    Returns:
        The decoded response string (special tokens stripped).
    """
    if knowledge != '':
        knowledge = '[KNOWLEDGE] ' + knowledge
    dialog = ' EOS '.join(dialog)
    query = f"{instruction} [CONTEXT] {dialog} {knowledge}"
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids,
        min_length=int(min_length),
        max_length=int(max_length),
        top_p=top_p,
        do_sample=True,
    )
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Echo prompt and response to the server log for debugging.
    print(query)
    print(output)
    return output
def api_call_generation(instruction, knowledge, query, top_p, min_length, max_length):
    """Gradio click handler: wrap the single user query as a one-turn dialog
    and delegate to generate().

    Parameters mirror generate(), except `query` is a single string rather
    than a list of turns.

    Returns:
        The model's response string.
    """
    dialog = [query]
    return generate(instruction, knowledge, dialog, top_p, min_length, max_length)
def change_example(choice):
    """Populate the input widgets from a preset example.

    Args:
        choice: Dropdown label such as 'Example 1'; the trailing number
            (1-based) selects the tuple in `preset_examples`.

    Returns:
        A list of four gr.update payloads for, in order, the instruction
        textbox, knowledge textbox, query textbox, and instruction-type radio.
    """
    choice_idx = int(choice.split()[-1]) - 1
    instruction, knowledge, query, instruction_type = preset_examples[choice_idx]
    return [
        gr.update(lines=1, visible=True, value=instruction),
        gr.update(visible=True, value=knowledge),
        gr.update(lines=1, visible=True, value=query),
        gr.update(visible=True, value=instruction_type),
    ]
def change_textbox(choice):
    """Return the default instruction text for the chosen instruction type.

    Args:
        choice: One of the radio options ('Chitchat', 'Grounded Response
            Generation', or 'Conversational Question Answering' as the
            fallback branch).

    Returns:
        A gr.update payload setting the instruction textbox value.
    """
    if choice == "Chitchat":
        return gr.update(lines=1, visible=True, value="Instruction: given a dialog context, you need to response empathically.")
    elif choice == "Grounded Response Generation":
        return gr.update(lines=1, visible=True, value="Instruction: given a dialog context and related knowledge, you need to response safely based on the knowledge.")
    else:
        return gr.update(lines=1, visible=True, value="Instruction: given a dialog context and related knowledge, you need to answer the question based on the knowledge.")
# --- UI layout ---------------------------------------------------------------
# Top: example dropdown, instruction-type radio, instruction / knowledge /
# query textboxes. Bottom row: response output (left) and sampling controls
# plus the Generate button (right).
with gr.Blocks() as demo:
    gr.Markdown("# The broken God")
    gr.Markdown('''All hail Mekhane. Reject flesh. Embrace metal''')

    dropdown = gr.Dropdown(
        [f"Example {i+1}" for i in range(1)], label='Examples')
    radio = gr.Radio(
        ["Conversational Question Answering", "Chitchat", "Grounded Response Generation"],
        label="Instruction Type",
        value='Conversational Question Answering',
    )
    instruction = gr.Textbox(
        lines=1, interactive=True, label="Instruction",
        value="Instruction: given a dialog context and related knowledge, you need to answer the question based on the knowledge.")
    # Switching the instruction type rewrites the instruction textbox.
    radio.change(fn=change_textbox, inputs=radio, outputs=instruction)

    knowledge = gr.Textbox(lines=6, label="Knowledge")
    query = gr.Textbox(lines=1, label="User Query")
    # Picking a preset example fills all four widgets at once.
    dropdown.change(change_example, dropdown, [instruction, knowledge, query, radio])

    with gr.Row():
        with gr.Column(scale=1):
            response = gr.Textbox(label="Response", lines=2)
        with gr.Column(scale=1):
            top_p = gr.Slider(0, 1, value=0.9, label='top_p')
            min_length = gr.Number(8, label='min_length')
            max_length = gr.Number(
                64, label='max_length (should be larger than min_length)')
            greet_btn = gr.Button("Generate")
            greet_btn.click(
                fn=api_call_generation,
                inputs=[instruction, knowledge, query, top_p, min_length, max_length],
                outputs=response,
            )

demo.launch()