import os
import time
from datetime import datetime
from typing import Any, Sequence

import firebase_admin
import gradio as gr
import pytz
from dotenv import load_dotenv
from firebase_admin import credentials, firestore
import tensorflow as tf
import torch
from gradio import CSVLogger, FlaggingCallback
from gradio.components import Component
from transformers import DebertaV2Tokenizer, TFAutoModelForSequenceClassification, AutoModelForSequenceClassification

# Toggle between the TensorFlow and the PyTorch inference back-end.
USE_TENSORFLOW = True

# Torch device selection; only used on the PyTorch code path in ask().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Maps each UI class name to its logit index in the model output.
# NOTE(review): assumes the fine-tuned model keeps the NLI head order
# entailment/neutral/contradiction — confirm against the model card.
CLASSES = {
    'yes': 0,
    'irrelevant': 1,
    'no': 2,
}
tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)
model = TFAutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base',
                                                             dtype=tf.float16) if USE_TENSORFLOW else AutoModelForSequenceClassification.from_pretrained(
    'MrPio/TheSeagullStory-nli-deberta-v3-base')
if not USE_TENSORFLOW:
    model.eval()  # inference mode: disable dropout etc.
    if torch.cuda.is_available():
        model.half()  # fp16 on GPU to cut memory and latency
# Fix: `open('story.txt').read()` leaked the file handle (closed only by GC)
# and used the platform default encoding. Read with a context manager and an
# explicit encoding; collapse the story to a single line for the tokenizer.
with open('story.txt', encoding='utf-8') as _story_file:
    story = _story_file.read().replace("\n\n", "\n").replace("\n", " ").strip()

# Load .env so PRIVATE_KEY is available below without hard-coding the secret.
load_dotenv()
# Firebase service-account credentials: only the private key comes from the
# environment; the remaining fields are project metadata embedded in source.
# NOTE(review): private_key_id/client_id are not secrets on their own, but
# consider moving the whole service-account JSON out of source control.
cred = credentials.Certificate({
    "type": "service_account",
    "project_id": "scheda-dnd",
    "private_key": os.environ.get("PRIVATE_KEY"),
    "private_key_id": "948666ca297742d06eebd6a97f77f750d033c208",
    "client_email": "[email protected]",
    "client_id": "105104335855166557589",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
    "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-is4pg%40scheda-dnd.iam.gserviceaccount.com",
    "universe_domain": "googleapis.com",
})
firebase_admin.initialize_app(cred)
# Firestore client; used by the Flagger callback to persist flagged questions.
db = firestore.client()


def ask(question):
    """Classify `question` against the story as yes / irrelevant / no.

    Tokenizes the (story, question) pair as an NLI premise/hypothesis input,
    runs the active back-end (TensorFlow or PyTorch, per USE_TENSORFLOW), and
    returns a dict mapping each class name in CLASSES to its softmax
    probability rounded to 3 decimals.
    """
    # Renamed from `input` to avoid shadowing the builtin.
    inputs = tokenizer(story, question, truncation=True, padding=True,
                       return_tensors='tf' if USE_TENSORFLOW else 'pt')
    if USE_TENSORFLOW:
        output = model(inputs, training=False)
        prediction = tf.nn.softmax(output.logits, axis=-1).numpy().squeeze()
        # Cast numpy scalars to Python floats so the result is JSON-serializable.
        return {c: round(float(prediction[i]), 3) for c, i in CLASSES.items()}
    # PyTorch path: move tensors to the selected device.
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # Fix: inference only — disable autograd to avoid building a graph.
    with torch.no_grad():
        output = model(**inputs)
    prediction = torch.softmax(output.logits, 1).squeeze()
    return {c: round(prediction[i].item(), 3) for c, i in CLASSES.items()}


class Flagger(FlaggingCallback):
    """Flagging callback that mirrors each flag to Firestore and to a local CSV."""

    def __init__(self):
        self.base_logger = CSVLogger()
        self.flags_collection = db.collection("other_apps/seagull_story/seagull_story_flags")

    def setup(self, components: Sequence[Component], flagging_dir: str):
        # Local CSV bookkeeping is fully delegated to the wrapped logger.
        self.base_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
        question, answer = flag_data[0], flag_data[1]
        # Persist to Firestore only when the flag looks meaningful: a question
        # longer than 3 characters and a prediction payload with confidences.
        if len(question) > 3 and 'confidences' in answer:
            record = {
                "question": question,
                "prediction": answer['label'],
                "confidences": answer['confidences'],
                "flag": flag_option,
                "timestamp": datetime.now(pytz.utc),
                "username": username,
            }
            # The nanosecond timestamp doubles as a unique document id.
            self.flags_collection.document(str(time.time_ns())).set(record)
        return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)


# Gradio UI: one free-text question in, a 3-class label with confidences out.
# Manual flagging routes through Flagger (Firestore + CSV).
gradio = gr.Interface(
    ask,
    inputs=[gr.Textbox(value="", label="Your question, as an affirmative sentence:")],
    outputs=[gr.Label(label="Answer", num_top_classes=3)],
    title="The Seagull Story",
    flagging_mode='manual',
    flagging_callback=Flagger(),
    flagging_options=['Yes', 'No', 'Irrelevant'],
    description="“ Albert and Dave find themselves on the pier. They go to a nearby restaurant where Albert orders "
                "seagull meat. The waiter promptly serves Albert the meal. After taking a bite, he realizes "
                "something. Albert pulls a gun out of his ruined jacket and shoots himself. ”\n\nWhy did Albert shoot "
                "himself?\n\nCan you unravel the truth behind this epilogue by asking only yes/no questions?\n\nPlease be specific about the time period you have in mind with your question.",
    article='Please refrain from embarrassing DeBERTa with dumb questions.\n\nCheck the repository for more detail: https://github.com/MrPio/The-Seagull-Story',
    examples=['Albert shoot himself for a reason',
              'Dave has a watch on his wrist',
              'Albert and Dave came to the pier on their own']
)

if __name__ == "__main__":
    # share=True requests a public share link in addition to the local server.
    gradio.launch(share=True)