vkovacs committed
Commit fb05782 · 1 Parent(s): 158b5a1

sentence split logic added

Files changed (1): app.py (+52 -8)
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import torch
+import spacy
 import numpy as np
 from transformers import AutoModelForSequenceClassification
 from transformers import AutoTokenizer
@@ -16,6 +17,23 @@ HF_TOKEN = os.environ["hf_read"]
 SENTIMENT_LABEL_NAMES = {0: "Negative", 1: "No sentiment or Neutral sentiment", 2: "Positive"}
 LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]
 
+def load_spacy_model(model_name="xx_sent_ud_sm"):
+    try:
+        model = spacy.load(model_name)
+    except OSError:
+        spacy.cli.download(model_name)
+        model = spacy.load(model_name)
+    return model
+
+def split_sentences(text, model):
+    # disable pipeline components not necessary for splitting
+    model.disable_pipes(model.pipe_names)  # first disable all the pipes
+    model.enable_pipe("senter")  # then enable the sentence splitter only
+
+    doc = model(text)
+    sentences = [sent.text for sent in doc.sents]
+
+    return sentences
 
 def build_huggingface_path(language: str):
     if language == "Czech" or language == "Slovakian":
@@ -39,22 +57,48 @@ def predict(text, model_id, tokenizer_id):
     logits = model(**inputs).logits
 
     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
-    output_pred = {model.config.id2label[i]: probs[i] for i in np.argsort(probs)[::-1]}
-    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
-    return output_pred, output_info
+    label_pred = model.config.id2label[probs.argmax()]
+    probability_pred = f"{100*probs.max()}%"
+    return label_pred, probability_pred
+
 
 def predict_wrapper(text, language):
     model_id = build_huggingface_path(language)
     tokenizer_id = "xlm-roberta-large"
-    return predict(text, model_id, tokenizer_id)
 
+    spacy_model = load_spacy_model()
+    sentences = split_sentences(text, spacy_model)
+
+    results = []
+    for sentence in sentences:
+        label, probability = predict(sentence, model_id, tokenizer_id)
+        results.append({
+            "Sentence": sentence,
+            "Prediction": label,
+            "Probability": probability
+        })
+
+    output_info = f'Prediction made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
+    return results, output_info
 
 with gr.Blocks() as demo:
-    gr.Interface(
+    with gr.Row():
+        with gr.Column():
+            input_text = gr.Textbox(lines=6, label="Input Text", placeholder="Enter your text here...")
+            language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
+            predict_button = gr.Button("Submit")
+
+        with gr.Column():
+            result_table = gr.Dataframe(headers=["Sentence", "Prediction", "Probability"],
+                                        label="Sentence-level Predictions")
+            model_info = gr.Markdown()
+
+    predict_button.click(
         fn=predict_wrapper,
-        inputs=[gr.Textbox(lines=6, label="Input"),
-                gr.Dropdown(LANGUAGES, label="Language")],
-        outputs=[gr.Label(num_top_classes=3, label="Output"), gr.Markdown()])
+        inputs=[input_text, language_choice],
+        outputs=[result_table, model_info]
+    )
 
 if __name__ == "__main__":
     demo.launch()
+
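
For reference, a minimal standalone sketch of the per-sentence flow this commit introduces: it reuses the new load_spacy_model / split_sentences helpers from app.py but stubs out the transformer call (stub_predict is a hypothetical placeholder, not part of the commit), so the sentence splitting can be tried without the Hugging Face model weights or the hf_read token.

# Standalone sketch of the sentence-split flow added in this commit.
# Requires only spaCy; the xx_sent_ud_sm model is downloaded on first use.
import spacy

def load_spacy_model(model_name="xx_sent_ud_sm"):
    try:
        model = spacy.load(model_name)
    except OSError:
        spacy.cli.download(model_name)
        model = spacy.load(model_name)
    return model

def split_sentences(text, model):
    # keep only the sentence segmenter ("senter"); other pipes are not needed here
    model.disable_pipes(model.pipe_names)
    model.enable_pipe("senter")
    return [sent.text for sent in model(text).sents]

def stub_predict(sentence):
    # hypothetical stand-in for predict(sentence, model_id, tokenizer_id)
    return "Positive", "99.0%"

if __name__ == "__main__":
    nlp = load_spacy_model()
    text = "The service was slow. The food, however, was excellent!"
    results = []
    for sentence in split_sentences(text, nlp):
        label, probability = stub_predict(sentence)
        results.append({"Sentence": sentence, "Prediction": label, "Probability": probability})
    print(results)  # one row per sentence, mirroring the new Dataframe output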