MrPio committed (verified)
Commit 9c9f43c · 1 Parent(s): dba8453

Update app.py

Files changed (1)
  1. app.py +19 -13
app.py CHANGED
@@ -1,6 +1,9 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForSequenceClassification, DebertaV2Tokenizer
+import tensorflow as tf
+from transformers import AutoModelForSequenceClassification, DebertaV2Tokenizer, TFAutoModelForSequenceClassification
+
+USE_TENSORFLOW = True
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 CLASSES = {
@@ -9,21 +12,24 @@ CLASSES = {
     'no': 2,
 }
 tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)
-model = AutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base')
-model.eval()
+model = TFAutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base', dtype=tf.float16) if USE_TENSORFLOW else AutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base')
+if not USE_TENSORFLOW:
+    model.eval()
+    if torch.cuda.is_available():
+        model.half()
 story = open('story.txt').read().replace("\n\n", "\n").replace("\n", " ").strip()
 
-if torch.cuda.is_available():
-    model.half()
-
 def ask(question):
-    with torch.no_grad():
-        input = tokenizer(story, question, truncation=True, padding=True, return_tensors="pt")
-        input = {key: value.to(device) for key, value in input.items()}
-        output = model(**input)
-        prediction = torch.softmax(output.logits, 1).squeeze()
-        print(prediction)
-        return {c: round(prediction[i].item(), 3) for c, i in CLASSES.items()}
+    input = tokenizer(story, question, truncation=True, padding=True, return_tensors='tf' if USE_TENSORFLOW else 'pt')
+    if not USE_TENSORFLOW:
+        input = {key: value.to(device) for key, value in input.items()}
+        output = model(**input)
+        prediction = torch.softmax(output.logits, 1).squeeze()
+        return {c: round(prediction[i].item(), 3) for c, i in CLASSES.items()}
+    else:
+        output = model(input, training=False)
+        prediction = tf.nn.softmax(output.logits, axis=-1).numpy().squeeze()
+        return {c: round(prediction[i], 3) for c, i in CLASSES.items()}
 
 
 gradio = gr.Interface(
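The diff cuts off at the gradio = gr.Interface( call, so the interface arguments are not visible in this view. A minimal sketch of how ask is typically wired into Gradio, assuming a Textbox input and a Label output (both are assumptions, not taken from the commit):

gradio = gr.Interface(
    fn=ask,                               # maps a question string to a dict of class -> probability
    inputs=gr.Textbox(label="Question"),  # assumed input component
    outputs=gr.Label(),                   # gr.Label renders a dict of label -> confidence directly
)
gradio.launch()

Since ask already returns class names mapped to rounded probabilities, gr.Label can consume its return value without further conversion.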
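One behavioral difference worth noting: the previous PyTorch path ran the forward pass under torch.no_grad(), which the new USE_TENSORFLOW=False branch no longer does, so autograd state is tracked during inference. A minimal sketch of that branch with the guard restored (ask_torch is a hypothetical name for illustration; the logic mirrors the branch above):

def ask_torch(question):
    # Same as the USE_TENSORFLOW=False branch, but the forward pass is
    # wrapped in torch.no_grad() so no autograd graph is built at inference.
    inputs = tokenizer(story, question, truncation=True, padding=True, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        output = model(**inputs)
    prediction = torch.softmax(output.logits, 1).squeeze()
    return {c: round(prediction[i].item(), 3) for c, i in CLASSES.items()}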