TheSeagullStory / app.py
MrPio's picture
Add Flagging
db49d02
raw
history blame
3.37 kB
from typing import Any, Sequence
import gradio as gr
import tensorflow as tf
import torch
from gradio import CSVLogger, FlaggingCallback
from gradio.components import Component
from transformers import DebertaV2Tokenizer, TFAutoModelForSequenceClassification, AutoModelForSequenceClassification
# Backend switch: True serves the TensorFlow checkpoint, False the PyTorch one.
USE_TENSORFLOW = True
# Target device for the PyTorch path; picks the GPU whenever CUDA is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Answer label -> index of that label's logit in the classifier output.
# Insertion order (yes, irrelevant, no) is also the order of the returned answer dict.
CLASSES = dict(yes=0, irrelevant=1, no=2)
# The tokenizer comes from the base cross-encoder checkpoint; the fine-tuned classifier
# weights come from the project's own checkpoint. Only the backend selected by
# USE_TENSORFLOW is instantiated.
tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)
if USE_TENSORFLOW:
    model = TFAutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base',
                                                                 dtype=tf.float16)
else:
    model = AutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base')
    # ask() moves its input tensors to `device`, so the weights must live on the same
    # device; without this, a CUDA host fails with a device-mismatch error at inference.
    model.to(device)
    model.eval()
    if torch.cuda.is_available():
        # Half precision is applied only when running on the GPU.
        model.half()
# Load the story once and flatten it onto a single line: a blank-line paragraph break
# becomes one space, and so does every remaining line break. A context manager closes
# the file, and an explicit encoding keeps non-ASCII text independent of the host locale.
with open('story.txt', encoding='utf-8') as story_file:
    story = story_file.read().replace("\n\n", "\n").replace("\n", " ").strip()
def ask(question):
    """Answer a yes/no question about the story with the fine-tuned NLI classifier.

    The story is encoded as the first segment and ``question`` as the second; the
    classifier's logits are converted to probabilities with a softmax.

    Args:
        question: The player's question, phrased as an affirmative sentence.

    Returns:
        dict[str, float]: one entry per label in ``CLASSES`` ('yes', 'irrelevant', 'no')
        mapping it to its probability rounded to three decimals. Values are plain Python
        floats on both backends, so the gradio Label receives the same type either way.
    """
    encoded = tokenizer(story, question, truncation=True, padding=True,
                        return_tensors='tf' if USE_TENSORFLOW else 'pt')
    if USE_TENSORFLOW:
        output = model(encoded, training=False)
        # A single (story, question) pair is encoded, so squeeze drops the batch axis;
        # tolist() turns the numpy scalars into Python floats.
        probabilities = tf.nn.softmax(output.logits, axis=-1).numpy().squeeze().tolist()
    else:
        encoded = {key: value.to(device) for key, value in encoded.items()}
        # Pure inference: disable autograd so no gradient graph is recorded per request.
        with torch.no_grad():
            output = model(**encoded)
        probabilities = torch.softmax(output.logits, dim=-1).squeeze().tolist()
    return {label: round(probabilities[index], 3) for label, index in CLASSES.items()}
class Flagger(FlaggingCallback):
    """Gradio flagging callback that delegates all of its work to a wrapped ``CSVLogger``."""

    def __init__(self):
        # The logger that actually writes flagged samples; every method forwards to it.
        self.base_logger = CSVLogger()

    def setup(self, components: Sequence[Component], flagging_dir: str):
        """Forward setup to the wrapped logger, passing ``components`` and ``flagging_dir`` through."""
        csv_logger = self.base_logger
        csv_logger.setup(flagging_dir=flagging_dir, components=components)

    def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
        """Record one flagged sample through the wrapped logger and return its result unchanged."""
        csv_logger = self.base_logger
        flagged = csv_logger.flag(
            flag_data=flag_data,
            username=username,
            flag_option=flag_option,
        )
        return flagged
# The web UI: one free-text question in, a three-way probability Label out. Manual
# flagging is routed through Flagger.
gradio = gr.Interface(
    ask,
    inputs=[gr.Textbox(value="", label="Your question, as an affirmative sentence:")],
    outputs=[gr.Label(label="Answer", num_top_classes=3)],
    title="The Seagull Story",
    flagging_mode='manual',
    flagging_callback=Flagger(),
    flagging_options=['Yes', 'No', 'Irrelevant'],
    # The opening quote was mis-encoded as "β€œ"; it is restored to a proper “.
    description="“ Albert and Dave find themselves on the pier. They go to a nearby restaurant where Albert orders "
                "seagull meat. The waiter promptly serves Albert the meal. After taking a bite, he realizes "
                "something. Albert pulls a gun out of his ruined jacket and shoots himself. ”\n\nWhy did Albert shoot "
                "himself?\n\nCan you unravel the truth behind this epilogue by asking only yes/no questions?\n\nPlease be specific about the time period you have in mind with your question.",
    article='Please refrain from embarrassing DeBERTa with dumb questions.\n\nCheck the repository for more detail: https://github.com/MrPio/The-Seagull-Story',
    examples=['Albert shoots himself for a reason',
              'Dave has a watch on his wrist',
              'Albert and Dave came to the pier on their own']
)

if __name__ == "__main__":
    gradio.launch(share=True)