elo4 committed on
Commit
68bb257
·
verified ·
1 Parent(s): 291d162

Add TinyBERT demo

Browse files
Files changed (1) hide show
  1. app.py +23 -2
app.py CHANGED
@@ -9,6 +9,7 @@ from huggingface_hub import hf_hub_download
9
  import torch
10
  import pickle
11
  import numpy as np
 
12
 
13
  # Load models and tokenizers
14
  models = {
@@ -20,6 +21,10 @@ models = {
20
  "BERT Multilingual (NLP Town)": {
21
  "tokenizer": AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
22
  "model": AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
 
 
 
 
23
  }
24
  }
25
 
@@ -68,12 +73,24 @@ def predict_with_bert_multilingual(text):
68
  predictions = logits.argmax(axis=-1).cpu().numpy()
69
  return int(predictions[0] + 1)
70
 
 
 
 
 
 
 
 
 
 
 
 
71
  # Unified function for sentiment analysis and statistics
72
  def analyze_sentiment_and_statistics(text):
73
  results = {
74
  "DistilBERT": predict_with_distilbert(text),
75
  "Logistic Regression": predict_with_logistic_regression(text),
76
  "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
 
77
  }
78
 
79
  # Calculate statistics
@@ -133,7 +150,8 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
133
  with gr.Column():
134
  distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
135
  log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
136
- bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
 
137
 
138
  with gr.Column():
139
  statistics_output = gr.Textbox(label="Statistics (Lowest, Highest, Average)", interactive=False)
@@ -145,14 +163,17 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
145
  f"{results['DistilBERT']}",
146
  f"{results['Logistic Regression']}",
147
  f"{results['BERT Multilingual (NLP Town)']}",
 
148
  f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
149
  )
150
 
151
  analyze_button.click(
152
  process_input_and_analyze,
153
  inputs=[text_input],
154
- outputs=[distilbert_output, log_reg_output, bert_output, statistics_output]
155
  )
156
 
 
 
157
  # Launch the app
158
  demo.launch()
 
9
  import torch
10
  import pickle
11
  import numpy as np
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
 
14
  # Load models and tokenizers
15
  models = {
 
21
  "BERT Multilingual (NLP Town)": {
22
  "tokenizer": AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
23
  "model": AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
24
+ },
25
+ "TinyBERT": {
26
+ "tokenizer": AutoTokenizer.from_pretrained("elo4/TinyBERT-sentiment-model"),
27
+ "model": AutoModelForSequenceClassification.from_pretrained("elo4/TinyBERT-sentiment-model"),
28
  }
29
  }
30
 
 
73
  predictions = logits.argmax(axis=-1).cpu().numpy()
74
  return int(predictions[0] + 1)
75
 
76
+ def predict_with_tinybert(text):
77
+ tokenizer = models["TinyBERT"]["tokenizer"]
78
+ model = models["TinyBERT"]["model"]
79
+ encodings = tokenizer([text], padding=True, truncation=True, max_length=128, return_tensors="pt").to(device)
80
+ with torch.no_grad():
81
+ outputs = model(**encodings)
82
+ logits = outputs.logits
83
+ predictions = logits.argmax(axis=-1).cpu().numpy()
84
+ return int(predictions[0])
85
+
86
+
87
  # Unified function for sentiment analysis and statistics
88
  def analyze_sentiment_and_statistics(text):
89
  results = {
90
  "DistilBERT": predict_with_distilbert(text),
91
  "Logistic Regression": predict_with_logistic_regression(text),
92
  "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
93
+ "TinyBERT": predict_with_tinybert(text),
94
  }
95
 
96
  # Calculate statistics
 
150
  with gr.Column():
151
  distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
152
  log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
153
+ bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False),
154
+ tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
155
 
156
  with gr.Column():
157
  statistics_output = gr.Textbox(label="Statistics (Lowest, Highest, Average)", interactive=False)
 
163
  f"{results['DistilBERT']}",
164
  f"{results['Logistic Regression']}",
165
  f"{results['BERT Multilingual (NLP Town)']}",
166
+ f"{results['TinyBERT']}",
167
  f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
168
  )
169
 
170
  analyze_button.click(
171
  process_input_and_analyze,
172
  inputs=[text_input],
173
+ outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
174
  )
175
 
176
+
177
+
178
  # Launch the app
179
  demo.launch()