sotosbarl commited on
Commit
805f6fe
·
1 Parent(s): 46cd5e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -4,20 +4,20 @@ import pickle
4
  import streamlit as st
5
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
6
 
7
- model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
10
 
11
 
12
  label_names = ["γάμος", "αλλοδαπός", "φορολογία", "κληρονομικά", "στέγη", "οικογενειακό", "εμπορικό","κλοπή","απάτη"]
13
 
14
 
15
  def classify(text):
16
- input = tokenizer(text, truncation=True, return_tensors="pt")
17
- output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
18
- prediction = torch.softmax(output["logits"][0], -1).tolist()
19
- prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
20
- return prediction
21
 
22
 
23
  text = st.text_input('Enter some text:') # Input field for new text
 
4
  import streamlit as st
5
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
6
 
7
+ # model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
8
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ # model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
+
11
+ classifier = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
12
+
13
 
14
 
15
  label_names = ["γάμος", "αλλοδαπός", "φορολογία", "κληρονομικά", "στέγη", "οικογενειακό", "εμπορικό","κλοπή","απάτη"]
16
 
17
 
18
  def classify(text):
19
+ output = classifier(text, label_names, multi_label=True)
20
+ return output
 
 
 
21
 
22
 
23
  text = st.text_input('Enter some text:') # Input field for new text