MrPio commited on
Commit
dba8453
·
verified ·
1 Parent(s): 4ffcee4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoModelForSequenceClassification, DebertaV2Tokenizer
4
 
 
5
  CLASSES = {
6
  'yes': 0,
7
  'irrelevant': 1,
@@ -12,11 +13,17 @@ model = AutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStor
12
  model.eval()
13
  story = open('story.txt').read().replace("\n\n", "\n").replace("\n", " ").strip()
14
 
 
 
15
 
16
  def ask(question):
17
- inputs = tokenizer(story, question, truncation=True, padding=True)
18
- prediction = torch.softmax(model(**inputs).logits, 1).squeeze()
19
- return {c: round(prediction[i].item(), 3) for c, i in CLASSES}
 
 
 
 
20
 
21
 
22
  gradio = gr.Interface(
 
2
  import torch
3
  from transformers import AutoModelForSequenceClassification, DebertaV2Tokenizer
4
 
5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
  CLASSES = {
7
  'yes': 0,
8
  'irrelevant': 1,
 
13
  model.eval()
14
  story = open('story.txt').read().replace("\n\n", "\n").replace("\n", " ").strip()
15
 
16
+ if torch.cuda.is_available():
17
+ model.half()
18
 
19
  def ask(question):
20
+ with torch.no_grad():
21
+ input = tokenizer(story, question, truncation=True, padding=True,return_tensors="pt")
22
+ input = {key: value.to(device) for key, value in input.items()}
23
+ output=model(**input)
24
+ prediction = torch.softmax(output.logits, 1).squeeze()
25
+ print(prediction)
26
+ return {c: round(prediction[i].item(), 3) for c, i in CLASSES.items()}
27
 
28
 
29
  gradio = gr.Interface(