# Page header captured together with the file (not part of the program):
# TheSeagullStory / app.py
# MrPio's picture
# Add Flagging
# 870933d
import os
import time
from datetime import datetime
from typing import Any, Sequence
import firebase_admin
import gradio as gr
import pytz
from dotenv import load_dotenv
from firebase_admin import credentials, firestore
import tensorflow as tf
import torch
from gradio import CSVLogger, FlaggingCallback
from gradio.components import Component
from transformers import DebertaV2Tokenizer, TFAutoModelForSequenceClassification, AutoModelForSequenceClassification
# Backend switch: True selects the TensorFlow model, False the PyTorch one.
USE_TENSORFLOW = True

# Torch device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Answer label -> class index (insertion order is kept: yes, irrelevant, no).
CLASSES = dict(yes=0, irrelevant=1, no=2)
# Tokenizer is loaded from the 'cross-encoder/nli-deberta-v3-base' checkpoint,
# the classifier from 'MrPio/TheSeagullStory-nli-deberta-v3-base'.
tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)

if USE_TENSORFLOW:
    model = TFAutoModelForSequenceClassification.from_pretrained(
        'MrPio/TheSeagullStory-nli-deberta-v3-base', dtype=tf.float16)
else:
    model = AutoModelForSequenceClassification.from_pretrained(
        'MrPio/TheSeagullStory-nli-deberta-v3-base')
    # The inputs are moved to `device` before the forward pass, so the weights
    # have to live on the same device or the call fails with a device mismatch.
    model.to(device)
    # Inference only: switch off dropout and friends.
    model.eval()
    if torch.cuda.is_available():
        # Half precision is only applied when running on the GPU.
        model.half()
# Read the story once at start-up and flatten it: double newlines collapse to one,
# then every remaining newline becomes a space. The context manager closes the file
# handle, and the explicit encoding keeps the result independent of the host locale.
with open('story.txt', encoding='utf-8') as story_file:
    story = story_file.read().replace("\n\n", "\n").replace("\n", " ").strip()
# Load variables from a .env file (if present) so PRIVATE_KEY is visible in os.environ.
load_dotenv()

# Service-account credentials. Only the private key is read from the environment
# (PRIVATE_KEY); the remaining fields are given inline.
cred = credentials.Certificate({
    "type": "service_account",
    "project_id": "scheda-dnd",
    "private_key": os.environ.get("PRIVATE_KEY"),
    "private_key_id": "948666ca297742d06eebd6a97f77f750d033c208",
    # Same account as the one named in client_x509_cert_url below
    # ("%40" in that URL is the URL-encoded "@").
    "client_email": "firebase-adminsdk-is4pg@scheda-dnd.iam.gserviceaccount.com",
    "client_id": "105104335855166557589",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
    "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-is4pg%40scheda-dnd.iam.gserviceaccount.com",
    "universe_domain": "googleapis.com",
})
firebase_admin.initialize_app(cred)

# Firestore client used to store flagged predictions.
db = firestore.client()
def ask(question):
    """Classify a statement about the story as yes / irrelevant / no.

    The story and the question are tokenized together as a pair and passed
    through the classifier selected by ``USE_TENSORFLOW``.

    Args:
        question: The user's statement, as text.

    Returns:
        A dict mapping every label in ``CLASSES`` to its softmax probability,
        rounded to three decimals and expressed as a plain Python ``float``.
    """
    # Named `encoded` rather than `input` so the builtin is not shadowed.
    encoded = tokenizer(story, question, truncation=True, padding=True,
                        return_tensors='tf' if USE_TENSORFLOW else 'pt')

    if USE_TENSORFLOW:
        output = model(encoded, training=False)
        prediction = tf.nn.softmax(output.logits, axis=-1).numpy().squeeze()
    else:
        encoded = {key: value.to(device) for key, value in encoded.items()}
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(**encoded)
        prediction = torch.softmax(output.logits, 1).squeeze()

    # float() turns NumPy / torch scalars into Python floats before rounding,
    # so both backends return the same plain, serialisable value type.
    return {label: round(float(prediction[index]), 3) for label, index in CLASSES.items()}
class Flagger(FlaggingCallback):
    """Flagging callback that writes flags to Firestore and to Gradio's CSV log.

    Every flag is forwarded to a ``CSVLogger``. Flags that carry a question
    longer than three characters and a prediction with confidences are also
    stored as a document in the Firestore flags collection.
    """

    def __init__(self):
        # CSV logger that every flag is delegated to.
        self.base_logger = CSVLogger()
        # Firestore collection receiving one document per accepted flag.
        self.flags_collection = db.collection("other_apps/seagull_story/seagull_story_flags")

    def setup(self, components: Sequence[Component], flagging_dir: str):
        """Forward the component list and flagging directory to the CSV logger."""
        self.base_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
        """Record one flagged sample.

        Args:
            flag_data: Flagged component values; element 0 is read as the
                question and element 1 as the prediction payload.
            flag_option: The option chosen by the user, if any.
            username: The flagging user's name, if any.

        Returns:
            Whatever the underlying ``CSVLogger.flag`` call returns.
        """
        question = flag_data[0] if len(flag_data) > 0 else None
        prediction = flag_data[1] if len(flag_data) > 1 else None

        # Flagging before any answer was produced leaves these values empty;
        # the type checks keep that case from raising and skipping the CSV log.
        has_question = isinstance(question, str) and len(question) > 3
        has_prediction = isinstance(prediction, dict) and 'confidences' in prediction

        if has_question and has_prediction:
            # Nanosecond timestamp used as the document id.
            self.flags_collection.document(str(time.time_ns())).set({
                "question": question,
                "prediction": prediction.get('label'),
                "confidences": prediction['confidences'],
                "flag": flag_option,
                "timestamp": datetime.now(pytz.utc),
                "username": username,
            })
        return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)
# Gradio UI: one text box in, one label widget (top three classes) out, with manual
# flagging routed through Flagger.
gradio = gr.Interface(
    ask,
    inputs=[gr.Textbox(value="", label="Your question, as an affirmative sentence:")],
    outputs=[gr.Label(label="Answer", num_top_classes=3)],
    title="The Seagull Story",
    flagging_mode='manual',
    flagging_callback=Flagger(),
    flagging_options=['Yes', 'No', 'Irrelevant'],
    # The opening curly quote was mis-encoded; it now pairs with the closing one.
    description="“ Albert and Dave find themselves on the pier. They go to a nearby restaurant where Albert orders "
                "seagull meat. The waiter promptly serves Albert the meal. After taking a bite, he realizes "
                "something. Albert pulls a gun out of his ruined jacket and shoots himself. ”\n\nWhy did Albert shoot "
                "himself?\n\nCan you unravel the truth behind this epilogue by asking only yes/no questions?\n\nPlease be specific about the time period you have in mind with your question.",
    article='Please refrain from embarrassing DeBERTa with dumb questions.\n\nCheck the repository for more detail: https://github.com/MrPio/The-Seagull-Story',
    examples=['Albert shoot himself for a reason',
              'Dave has a watch on his wrist',
              'Albert and Dave came to the pier on their own']
)

# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    gradio.launch(share=True)