File size: 765 Bytes
3965925
 
 
8318b40
 
3965925
 
d9a9784
9d2acd8
3965925
9d2acd8
3965925
9d2acd8
 
 
8e19006
ad63f3a
3965925
d9a9784
ad63f3a
3965925
 
ad63f3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Load the tokenizer and seq2seq model from the directory "." (the current
# working directory) using only files already on disk, then wrap both in a
# text2text-generation pipeline that predict() calls.
_MODEL_DIR = "."
_LOAD_KWARGS = {"local_files_only": True}

tokenizer = AutoTokenizer.from_pretrained(_MODEL_DIR, **_LOAD_KWARGS)
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_DIR, **_LOAD_KWARGS)
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
)

def predict(task, prompt, context="", auto_cot=False):
    """Build a task-tagged prompt, run it through the model, and return the text.

    The prompt sent to the model is ``[TASK: <TASK>] <prompt>``, followed by
    `` Context: <context>`` when a context is given, and by a newline plus
    "Let's think step by step." when ``auto_cot`` is set.

    Args:
        task: Task name; stripped and upper-cased inside the ``[TASK: ...]`` tag.
        prompt: Main input text (surrounding whitespace is stripped).
        context: Optional extra text; appended only when it is non-blank.
        auto_cot: When truthy, append a chain-of-thought cue to the prompt.

    Returns:
        The ``generated_text`` of the first pipeline result, generated with
        at most 128 new tokens.
    """
    # Fields may arrive as None (e.g. when called through the API rather than
    # the UI); treat them as empty strings instead of crashing on .upper().
    task = (task or "").strip()
    prompt = (prompt or "").strip()
    context = (context or "").strip()

    full_prompt = f"[TASK: {task.upper()}] {prompt}"
    # Whitespace-only context was stripped to "", so no dangling "Context:".
    if context:
        full_prompt += f" Context: {context}"
    if auto_cot:
        full_prompt += "\nLet's think step by step."
    output = pipe(full_prompt, max_new_tokens=128)[0]["generated_text"]
    return output

# Web UI: three text boxes and a checkbox feed predict(task, prompt, context,
# auto_cot) in order; the model's answer is shown in a single text output.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "text", "text", "checkbox"],
    outputs="text",
)

# Start the server only when this file is run as a script, so importing the
# module (e.g. from tests or other tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch()