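"""Gradio Space for "The Seagull Story": a DeBERTa-v3 NLI cross-encoder answers
yes / no / irrelevant to the player's questions about the story, and flagged
answers are logged both to a local CSV file and to a Firestore collection."""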
import os
import time
from datetime import datetime
from typing import Any, Sequence
import firebase_admin
import gradio as gr
import pytz
from dotenv import load_dotenv
from firebase_admin import credentials, firestore
import tensorflow as tf
import torch
from gradio import CSVLogger, FlaggingCallback
from gradio.components import Component
from transformers import DebertaV2Tokenizer, TFAutoModelForSequenceClassification, AutoModelForSequenceClassification
# Choose the inference backend: the TensorFlow checkpoint by default, the PyTorch one otherwise.
USE_TENSORFLOW = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mapping between the model's output logits and the answers shown to the player.
CLASSES = {
    'yes': 0,
    'irrelevant': 1,
    'no': 2,
}
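# Note: the base checkpoint is an NLI cross-encoder, so these three answers
# presumably correspond to the entailment / neutral / contradiction outputs of
# the fine-tuned model loaded below (the exact index order is an assumption).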
tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)
if USE_TENSORFLOW:
    model = TFAutoModelForSequenceClassification.from_pretrained(
        'MrPio/TheSeagullStory-nli-deberta-v3-base', dtype=tf.float16)
else:
    model = AutoModelForSequenceClassification.from_pretrained(
        'MrPio/TheSeagullStory-nli-deberta-v3-base')
    model.eval()
    if torch.cuda.is_available():
        # Run in half precision on GPU.
        model.half()
    # Move the weights to the same device the tokenized inputs are sent to in ask().
    model = model.to(device)
with open('story.txt') as story_file:
    story = story_file.read().replace("\n\n", "\n").replace("\n", " ").strip()
load_dotenv()
cred = credentials.Certificate({
    "type": "service_account",
    "project_id": "scheda-dnd",
    "private_key": os.environ.get("PRIVATE_KEY"),
    "private_key_id": "948666ca297742d06eebd6a97f77f750d033c208",
    "client_email": "[email protected]",
    "client_id": "105104335855166557589",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
    "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-is4pg%40scheda-dnd.iam.gserviceaccount.com",
    "universe_domain": "googleapis.com",
})
firebase_admin.initialize_app(cred)
db = firestore.client()
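# Illustrative only (not executed here): the flags written by Flagger below can
# be read back with this same Firestore client, e.g.
#   for doc in db.collection("other_apps/seagull_story/seagull_story_flags").stream():
#       print(doc.to_dict())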
def ask(question):
    # Run NLI with the story as premise and the player's question as hypothesis.
    inputs = tokenizer(story, question, truncation=True, padding=True,
                       return_tensors='tf' if USE_TENSORFLOW else 'pt')
    if USE_TENSORFLOW:
        output = model(inputs, training=False)
        prediction = tf.nn.softmax(output.logits, axis=-1).numpy().squeeze()
        return {c: round(float(prediction[i]), 3) for c, i in CLASSES.items()}
    else:
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.no_grad():
            output = model(**inputs)
        prediction = torch.softmax(output.logits, 1).squeeze()
        return {c: round(prediction[i].item(), 3) for c, i in CLASSES.items()}
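# Example (illustrative): ask("Dave has a watch on his wrist") returns a dict
# mapping 'yes' / 'irrelevant' / 'no' to rounded confidence scores.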
class Flagger(FlaggingCallback):
    """Logs flagged answers both to the local CSV file and to Firestore."""

    def __init__(self):
        self.base_logger = CSVLogger()
        self.flags_collection = db.collection("other_apps/seagull_story/seagull_story_flags")

    def setup(self, components: Sequence[Component], flagging_dir: str):
        self.base_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
        # Only store non-trivial questions that produced a labelled prediction.
        if len(flag_data[0]) > 3 and 'confidences' in flag_data[1]:
            self.flags_collection.document(str(time.time_ns())).set({
                "question": flag_data[0],
                "prediction": flag_data[1]['label'],
                "confidences": flag_data[1]['confidences'],
                "flag": flag_option,
                "timestamp": datetime.now(pytz.utc),
                "username": username,
            })
        return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)
gradio = gr.Interface(
    ask,
    inputs=[gr.Textbox(value="", label="Your question, as an affirmative sentence:")],
    outputs=[gr.Label(label="Answer", num_top_classes=3)],
    title="The Seagull Story",
    flagging_mode='manual',
    flagging_callback=Flagger(),
    flagging_options=['Yes', 'No', 'Irrelevant'],
    description="“ Albert and Dave find themselves on the pier. They go to a nearby restaurant where Albert orders "
                "seagull meat. The waiter promptly serves Albert the meal. After taking a bite, he realizes "
                "something. Albert pulls a gun out of his ruined jacket and shoots himself. ”\n\nWhy did Albert shoot "
                "himself?\n\nCan you unravel the truth behind this epilogue by asking only yes/no questions?\n\nPlease be specific about the time period you have in mind in your question.",
    article='Please refrain from embarrassing DeBERTa with dumb questions.\n\nCheck the repository for more details: https://github.com/MrPio/The-Seagull-Story',
    examples=['Albert shot himself for a reason',
              'Dave has a watch on his wrist',
              'Albert and Dave came to the pier on their own'],
)
if __name__ == "__main__":
    gradio.launch(share=True)