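"""Gradio Space: sentiment and gender-bias comparison.

Scores a sentence and its gender-swapped counterpart with three Hugging Face
sentiment models, then reports the per-model score gap as a rough bias signal.
"""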
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
import re
import os

# Model definitions
models = {
    "NLPTown (1–5 Stars)": {
        "id": "nlptown/bert-base-multilingual-uncased-sentiment",
        "scoring": lambda probs: np.dot(probs, np.arange(1, 6)),
        "label": "Avg Star Rating (1-5)"
    },
    "SST-2 (Positive/Negative)": {
        "id": "distilbert-base-uncased-finetuned-sst-2-english",
        "scoring": lambda probs: probs[1],
        "label": "Prob(Positive)"
    },
    "BERTweet (Pos/Neu/Neg)": {
        "id": "finiteautomata/bertweet-base-sentiment-analysis",
        "scoring": lambda probs: probs[2],
        "label": "Prob(Positive)"
    }
}

# Load models (BERTweet uses the slow tokenizer; the others use the fast one)
loaded_models = {}
for name, config in models.items():
    print(f"Loading {name}...")
    tokenizer = AutoTokenizer.from_pretrained(config["id"], use_fast=False if "bertweet" in config["id"] else True)
    model = AutoModelForSequenceClassification.from_pretrained(config["id"])
    loaded_models[name] = (tokenizer, model)
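# Note: from_pretrained already returns the models in eval mode, so no explicit
# model.eval() call is needed before inference below.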

def gender_swap(text):
    # Note: "her" is ambiguous (objective "him" vs. possessive "his"), so this
    # simple one-to-one lookup makes the swap approximate for some sentences.
    swap_dict = {
        "he": "she", "she": "he",
        "him": "her", "her": "him",
        "his": "her", "hers": "his",
        "He": "She", "She": "He",
        "Him": "Her", "Her": "Him",
        "His": "Her", "Hers": "His"
    }
    pattern = re.compile(r'\b(' + '|'.join(re.escape(k) for k in swap_dict.keys()) + r')\b')
    return pattern.sub(lambda x: swap_dict[x.group()], text)

def ensure_directory_exists(path):
    if not os.path.exists(path):
        os.makedirs(path)

def sanitize_filename(name):
    return re.sub(r'\W+', '_', name)
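# Example: sanitize_filename("NLPTown (1–5 Stars)") -> "NLPTown_1_5_Stars_"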

def plot_scores(name, score_orig, score_swap, ylabel):
    plt.figure(figsize=(4, 4))
    plt.bar(["Original", "Swapped"], [score_orig, score_swap], color=["skyblue", "lightpink"])
    plt.ylabel(ylabel)
    plt.title(f"{name} Bias Comparison")
    plt.tight_layout()
    safe_name = sanitize_filename(name)
    save_directory = './plots'
    ensure_directory_exists(save_directory)
    image_path = os.path.join(save_directory, f"{safe_name}_bias_plot.png")
    plt.savefig(image_path)
    plt.close()
    return image_path

def analyze_sentiment(text):
    results = {}
    swapped_text = gender_swap(text)
    for name, (tokenizer, model) in loaded_models.items():
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=1).squeeze().numpy()
        score_orig = models[name]["scoring"](probs)
        inputs_swapped = tokenizer(swapped_text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits_swapped = model(**inputs_swapped).logits
        probs_swapped = torch.nn.functional.softmax(logits_swapped, dim=1).squeeze().numpy()
        score_swap = models[name]["scoring"](probs_swapped)
        image_path = plot_scores(name, score_orig, score_swap, models[name]["label"])
        results[name] = {
            "Original": round(score_orig, 4),
            "Gender-Swapped": round(score_swap, 4),
            "Bias (abs diff)": round(abs(score_orig - score_swap), 4),
            "Plot": image_path
        }
    return results
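# Usage sketch (actual scores depend on the model weights):
#   analyze_sentiment("He is a talented engineer.")
#   -> {model_name: {"Original": ..., "Gender-Swapped": ..., "Bias (abs diff)": ..., "Plot": <png path>}}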

def display_results(text):
    results = analyze_sentiment(text)
    model_outputs = []
    plot_outputs = []
    for model_name, result in results.items():
        model_output = f"### **{model_name}**\n"
        model_output += f"*Original Score*: {result['Original']:.4f}\n"
        model_output += f"\n*Gender-Swapped Score*: {result['Gender-Swapped']:.4f}\n"
        model_output += f"\n*Bias (abs diff)*: {result['Bias (abs diff)']:.4f}\n"
        model_outputs.append(model_output)
        plot_outputs.append(result['Plot'])
    return "\n\n".join(model_outputs), plot_outputs

with gr.Blocks() as demo:
    gr.Markdown("# Sentiment & Gender Bias Analysis", elem_id="header")
    gr.Markdown("Enter a sentence below to analyze sentiment and gender bias across different models.", elem_id="subheader")
    inp = gr.Textbox(label="Enter a sentence", placeholder="Enter sentence here...", elem_id="input_box")
    out_text = gr.Markdown(elem_id="output_text")
    out_plots = gr.Gallery(label="Model-wise Bias Plots", columns=3, height="auto", elem_id="output_plots")
    btn = gr.Button("Analyze", elem_id="analyze_button")
    btn.click(fn=display_results, inputs=inp, outputs=[out_text, out_plots])

demo.launch()
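# Assumption: the default launch() is sufficient on Spaces; when running locally,
# demo.launch(share=True) can be used to get a temporary public link instead.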