# NOTE: removed scraped file-view residue (file size, commit hashes,
# rendered line numbers) that preceded the code and was not part of it.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Load the fine-tuned MTPE model and its tokenizer from the local
# checkpoint directory, then wrap both in a text2text-generation
# pipeline used by predict() below.
# NOTE(review): "./mtpe-model" is resolved relative to the current
# working directory — confirm the app is always launched from the
# directory containing the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("./mtpe-model")
model = AutoModelForSeq2SeqLM.from_pretrained("./mtpe-model")
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

def predict(task, prompt, context="", auto_cot=False):
    """Generate model output for a task-tagged prompt.

    Prefixes the prompt with an upper-cased "[TASK: ...]" tag,
    appends the context (when non-empty) and, if *auto_cot* is set,
    a chain-of-thought trigger line, then returns the pipeline's
    generated text (up to 128 new tokens).
    """
    pieces = [f"[TASK: {task.upper()}] {prompt}"]
    if context:
        pieces.append(f"Context: {context}")
    model_input = " ".join(pieces)
    if auto_cot:
        model_input += "\nLet's think step by step."
    # pipeline returns a list of dicts; take the first generation.
    return pipe(model_input, max_new_tokens=128)[0]["generated_text"]

# Gradio UI: three text boxes (task, prompt, context) and one checkbox
# (auto_cot), mapped positionally onto predict()'s parameters; the
# model's generated text is shown in a single text output.
app = gr.Interface(
    fn=predict,
    inputs=["text", "text", "text", "checkbox"],
    outputs="text"
)

# Start the local Gradio web server (blocks until shut down).
app.launch()