Spaces:
Build error
Build error
import json

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import requests
import torch
from datasets import load_dataset
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def load_model(endpoint: str):
    """Load the tokenizer and sequence-classification model for *endpoint*.

    Args:
        endpoint: Identifier passed unchanged to both ``from_pretrained``
            calls (the UI placeholder suggests a Hub model name such as
            ``distilbert-base-uncased-finetuned-sst-2-english``).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is still
    # loaded before the model.
    return (
        AutoTokenizer.from_pretrained(endpoint),
        AutoModelForSequenceClassification.from_pretrained(endpoint),
    )
def test_model(tokenizer, model, test_data: list, label_map: dict): | |
results = [] | |
for text, true_label in test_data: | |
inputs = tokenizer(text, return_tensors="pt", | |
truncation=True, padding=True) | |
outputs = model(**inputs) | |
pred_label = label_map[int(outputs.logits.argmax(dim=-1))] | |
results.append((text, true_label, pred_label)) | |
return results | |
def generate_label_map(dataset):
    """Return a ``{class_index: label_name}`` dict for *dataset*.

    Args:
        dataset: Object whose ``features["label"].names`` is the ordered
            sequence of label names.

    Returns:
        A dict mapping each position in ``names`` to the name at that
        position.
    """
    return dict(enumerate(dataset.features["label"].names))
def generate_report_card(results, label_map):
    """Build a confusion-matrix heatmap from classification results.

    Args:
        results: Sequence of records whose item 1 is the true label and
            item 2 is the predicted label (the shape ``test_model`` returns).
        label_map: Dict whose values, in order, are the label names used for
            both axes of the matrix.

    Returns:
        A Plotly ``Figure`` (800x600) containing the heatmap.
    """
    class_names = list(label_map.values())
    actual = [record[1] for record in results]
    predicted = [record[2] for record in results]
    matrix = confusion_matrix(actual, predicted, labels=class_names)

    heatmap = go.Heatmap(
        z=matrix,
        x=class_names,
        y=class_names,
        colorscale='Viridis',
        colorbar=dict(title='Number of Samples'),
    )
    layout = go.Layout(
        title='Confusion Matrix',
        xaxis=dict(title='Predicted Labels'),
        # Reversed y axis so the first label's row is drawn at the top.
        yaxis=dict(title='True Labels', autorange='reversed'),
    )

    fig = go.Figure(data=heatmap, layout=layout)
    fig.update_layout(height=600, width=800)
    return fig
def app(model_endpoint: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int):
    """Evaluate a classifier on a dataset slice and return its confusion matrix.

    Args:
        model_endpoint: Model identifier handed to ``load_model``.
        dataset_name: Dataset name handed to ``load_dataset``.
        config_name: Dataset configuration; blank means "no configuration".
        dataset_split: Split name, e.g. ``"train"``.
        num_samples: How many leading rows of the split to evaluate.

    Returns:
        The Plotly figure produced by ``generate_report_card``.

    Raises:
        ValueError: If the slice is non-empty but none of its rows carries a
            known label (e.g. an unlabeled test split).

    Note:
        The input text is read from the ``"sentence"`` column, so only
        datasets that have that column are supported.
    """
    tokenizer, model = load_model(model_endpoint)

    # The numeric UI input may arrive as a float; the slice syntax needs an int.
    num_samples = int(num_samples)
    # An empty textbox arrives as ""; pass None so the default config is used.
    config_name = (config_name or "").strip() or None

    dataset = load_dataset(
        dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")
    label_map = generate_label_map(dataset)

    # Unlabeled rows store -1 (e.g. GLUE test splits). Indexing the name list
    # with -1 would silently assign them the last class, so skip them instead.
    test_data = [(item["sentence"], label_map[item["label"]])
                 for item in dataset
                 if item["label"] in label_map]
    if len(dataset) > 0 and not test_data:
        raise ValueError(
            f"No labeled examples found in split '{dataset_split}'; "
            "choose a split that has ground-truth labels.")

    results = test_model(tokenizer, model, test_data, label_map)
    return generate_report_card(results, label_map)
# Gradio UI: five inputs feeding app(), one Plotly figure as output.
# Components use the top-level gr.* classes (matching gr.Plot below) rather
# than the deprecated gr.inputs.* namespace, where Number took `default=`
# instead of `value=`.
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Textbox(lines=1, label="Model Endpoint",
                   placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english"),
        gr.Textbox(lines=1, label="Dataset Name",
                   placeholder="ex: glue"),
        gr.Textbox(lines=1, label="Config Name",
                   placeholder="ex: sst2"),
        gr.Dropdown(
            choices=["train", "validation", "test"], label="Dataset Split"),
        gr.Number(value=100, label="Number of Samples"),
    ],
    outputs=gr.Plot(),
    title="Fairness and Bias Testing",
    description="Enter a model endpoint and dataset to test for fairness and bias.",
)
# Default binary sentiment label map. The functions in this file receive or
# derive their own `label_map`, so this module-level value is not read by them;
# it is kept for anything that imports it from this module.
label_map = {0: "negative", 1: "positive"}

# Start the Gradio server only when the file is run as a script.
if __name__ == "__main__":
    interface.launch()