MrPio commited on
Commit
870933d
·
1 Parent(s): db49d02

Add Flagging

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +34 -0
  3. requirements.txt +5 -2
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  .venv/
2
  .idea/
3
- .gradio/
 
 
1
  .venv/
2
  .idea/
3
+ .gradio/
4
+ .env
app.py CHANGED
@@ -1,6 +1,13 @@
 
 
 
1
  from typing import Any, Sequence
2
 
 
3
  import gradio as gr
 
 
 
4
  import tensorflow as tf
5
  import torch
6
  from gradio import CSVLogger, FlaggingCallback
@@ -25,6 +32,23 @@ if not USE_TENSORFLOW:
25
  model.half()
26
  story = open('story.txt').read().replace("\n\n", "\n").replace("\n", " ").strip()
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def ask(question):
30
  input = tokenizer(story, question, truncation=True, padding=True, return_tensors='tf' if USE_TENSORFLOW else 'pt')
@@ -42,11 +66,21 @@ def ask(question):
42
  class Flagger(FlaggingCallback):
43
  def __init__(self):
44
  self.base_logger = CSVLogger()
 
45
 
46
  def setup(self, components: Sequence[Component], flagging_dir: str):
47
  self.base_logger.setup(components=components, flagging_dir=flagging_dir)
48
 
49
  def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
 
 
 
 
 
 
 
 
 
50
  return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)
51
 
52
 
 
1
+ import os
2
+ import time
3
+ from datetime import datetime
4
  from typing import Any, Sequence
5
 
6
+ import firebase_admin
7
  import gradio as gr
8
+ import pytz
9
+ from dotenv import load_dotenv
10
+ from firebase_admin import credentials, firestore
11
  import tensorflow as tf
12
  import torch
13
  from gradio import CSVLogger, FlaggingCallback
 
32
  model.half()
33
  story = open('story.txt').read().replace("\n\n", "\n").replace("\n", " ").strip()
34
 
35
+ load_dotenv()
36
+ cred = credentials.Certificate({
37
+ "type": "service_account",
38
+ "project_id": "scheda-dnd",
39
+ "private_key": os.environ.get("PRIVATE_KEY"),
40
+ "private_key_id": "948666ca297742d06eebd6a97f77f750d033c208",
41
+ "client_email": "[email protected]",
42
+ "client_id": "105104335855166557589",
43
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
44
+ "token_uri": "https://oauth2.googleapis.com/token",
45
+ "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
46
+ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-is4pg%40scheda-dnd.iam.gserviceaccount.com",
47
+ "universe_domain": "googleapis.com",
48
+ })
49
+ firebase_admin.initialize_app(cred)
50
+ db = firestore.client()
51
+
52
 
53
  def ask(question):
54
  input = tokenizer(story, question, truncation=True, padding=True, return_tensors='tf' if USE_TENSORFLOW else 'pt')
 
66
  class Flagger(FlaggingCallback):
67
  def __init__(self):
68
  self.base_logger = CSVLogger()
69
+ self.flags_collection = db.collection("other_apps/seagull_story/seagull_story_flags")
70
 
71
  def setup(self, components: Sequence[Component], flagging_dir: str):
72
  self.base_logger.setup(components=components, flagging_dir=flagging_dir)
73
 
74
  def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
75
+ if len(flag_data[0]) > 3 and 'confidences' in flag_data[1]:
76
+ self.flags_collection.document(str(time.time_ns())).set({
77
+ "question": flag_data[0],
78
+ "prediction": flag_data[1]['label'],
79
+ "confidences": flag_data[1]['confidences'],
80
+ "flag": flag_option,
81
+ "timestamp": datetime.now(pytz.utc),
82
+ "username": username,
83
+ })
84
  return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)
85
 
86
 
requirements.txt CHANGED
@@ -2,7 +2,10 @@ huggingface_hub==0.25.2
2
  transformers
3
  tokenizers
4
  torch
5
- gradio
6
  sentencepiece
7
  tensorflow
8
- tf-keras
 
 
 
 
2
  transformers
3
  tokenizers
4
  torch
5
+ gradio~=5.9.1
6
  sentencepiece
7
  tensorflow
8
+ tf-keras
9
+ firebase-admin~=6.6.0
10
+ pytz~=2022.6
11
+ python-dotenv~=0.21.0