MrPio commited on
Commit
db49d02
·
1 Parent(s): 5bc5294

Add Flagging

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -1,9 +1,11 @@
1
  from typing import Any, Sequence
 
2
  import gradio as gr
 
 
3
  from gradio import CSVLogger, FlaggingCallback
4
  from gradio.components import Component
5
- import torch
6
- import tensorflow as tf
7
 
8
  USE_TENSORFLOW = True
9
 
@@ -14,7 +16,9 @@ CLASSES = {
14
  'no': 2,
15
  }
16
  tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)
17
- model = TFAutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base',dtype=tf.float16) if USE_TENSORFLOW else AutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base')
 
 
18
  if not USE_TENSORFLOW:
19
  model.eval()
20
  if torch.cuda.is_available():
 
1
  from typing import Any, Sequence
2
+
3
  import gradio as gr
4
+ import tensorflow as tf
5
+ import torch
6
  from gradio import CSVLogger, FlaggingCallback
7
  from gradio.components import Component
8
+ from transformers import DebertaV2Tokenizer, TFAutoModelForSequenceClassification, AutoModelForSequenceClassification
 
9
 
10
  USE_TENSORFLOW = True
11
 
 
16
  'no': 2,
17
  }
18
  tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)
19
+ model = TFAutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base',
20
+ dtype=tf.float16) if USE_TENSORFLOW else AutoModelForSequenceClassification.from_pretrained(
21
+ 'MrPio/TheSeagullStory-nli-deberta-v3-base')
22
  if not USE_TENSORFLOW:
23
  model.eval()
24
  if torch.cuda.is_available():