Spaces:
Running
Running
Add Flagging
Browse files- .gitignore +2 -1
- app.py +34 -0
- requirements.txt +5 -2
.gitignore
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
.venv/
|
2 |
.idea/
|
3 |
-
.gradio/
|
|
|
|
1 |
.venv/
|
2 |
.idea/
|
3 |
+
.gradio/
|
4 |
+
.env
|
app.py
CHANGED
@@ -1,6 +1,13 @@
|
|
|
|
|
|
|
|
1 |
from typing import Any, Sequence
|
2 |
|
|
|
3 |
import gradio as gr
|
|
|
|
|
|
|
4 |
import tensorflow as tf
|
5 |
import torch
|
6 |
from gradio import CSVLogger, FlaggingCallback
|
@@ -25,6 +32,23 @@ if not USE_TENSORFLOW:
|
|
25 |
model.half()
|
26 |
story = open('story.txt').read().replace("\n\n", "\n").replace("\n", " ").strip()
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def ask(question):
|
30 |
input = tokenizer(story, question, truncation=True, padding=True, return_tensors='tf' if USE_TENSORFLOW else 'pt')
|
@@ -42,11 +66,21 @@ def ask(question):
|
|
42 |
class Flagger(FlaggingCallback):
|
43 |
def __init__(self):
|
44 |
self.base_logger = CSVLogger()
|
|
|
45 |
|
46 |
def setup(self, components: Sequence[Component], flagging_dir: str):
|
47 |
self.base_logger.setup(components=components, flagging_dir=flagging_dir)
|
48 |
|
49 |
def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)
|
51 |
|
52 |
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
from datetime import datetime
|
4 |
from typing import Any, Sequence
|
5 |
|
6 |
+
import firebase_admin
|
7 |
import gradio as gr
|
8 |
+
import pytz
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
from firebase_admin import credentials, firestore
|
11 |
import tensorflow as tf
|
12 |
import torch
|
13 |
from gradio import CSVLogger, FlaggingCallback
|
|
|
32 |
model.half()
|
33 |
story = open('story.txt').read().replace("\n\n", "\n").replace("\n", " ").strip()
|
34 |
|
35 |
+
load_dotenv()
|
36 |
+
cred = credentials.Certificate({
|
37 |
+
"type": "service_account",
|
38 |
+
"project_id": "scheda-dnd",
|
39 |
+
"private_key": os.environ.get("PRIVATE_KEY"),
|
40 |
+
"private_key_id": "948666ca297742d06eebd6a97f77f750d033c208",
|
41 |
+
"client_email": "[email protected]",
|
42 |
+
"client_id": "105104335855166557589",
|
43 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
44 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
45 |
+
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
46 |
+
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-is4pg%40scheda-dnd.iam.gserviceaccount.com",
|
47 |
+
"universe_domain": "googleapis.com",
|
48 |
+
})
|
49 |
+
firebase_admin.initialize_app(cred)
|
50 |
+
db = firestore.client()
|
51 |
+
|
52 |
|
53 |
def ask(question):
|
54 |
input = tokenizer(story, question, truncation=True, padding=True, return_tensors='tf' if USE_TENSORFLOW else 'pt')
|
|
|
66 |
class Flagger(FlaggingCallback):
|
67 |
def __init__(self):
|
68 |
self.base_logger = CSVLogger()
|
69 |
+
self.flags_collection = db.collection("other_apps/seagull_story/seagull_story_flags")
|
70 |
|
71 |
def setup(self, components: Sequence[Component], flagging_dir: str):
|
72 |
self.base_logger.setup(components=components, flagging_dir=flagging_dir)
|
73 |
|
74 |
def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
|
75 |
+
if len(flag_data[0]) > 3 and 'confidences' in flag_data[1]:
|
76 |
+
self.flags_collection.document(str(time.time_ns())).set({
|
77 |
+
"question": flag_data[0],
|
78 |
+
"prediction": flag_data[1]['label'],
|
79 |
+
"confidences": flag_data[1]['confidences'],
|
80 |
+
"flag": flag_option,
|
81 |
+
"timestamp": datetime.now(pytz.utc),
|
82 |
+
"username": username,
|
83 |
+
})
|
84 |
return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)
|
85 |
|
86 |
|
requirements.txt
CHANGED
@@ -2,7 +2,10 @@ huggingface_hub==0.25.2
|
|
2 |
transformers
|
3 |
tokenizers
|
4 |
torch
|
5 |
-
gradio
|
6 |
sentencepiece
|
7 |
tensorflow
|
8 |
-
tf-keras
|
|
|
|
|
|
|
|
2 |
transformers
|
3 |
tokenizers
|
4 |
torch
|
5 |
+
gradio~=5.9.1
|
6 |
sentencepiece
|
7 |
tensorflow
|
8 |
+
tf-keras
|
9 |
+
firebase-admin~=6.6.0
|
10 |
+
pytz~=2022.6
|
11 |
+
python-dotenv~=0.21.0
|