File size: 3,871 Bytes
7bd4255
438c90e
 
 
7bd4255
 
438c90e
 
 
 
 
 
7bd4255
 
 
438c90e
 
 
7bd4255
 
 
 
 
 
 
 
438c90e
7bd4255
 
 
 
438c90e
 
7bd4255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438c90e
 
7bd4255
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import functools
import json

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import requests
import torch
from datasets import load_dataset
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification

@functools.lru_cache(maxsize=2)
def load_model(endpoint: str):
    """Load the tokenizer and sequence-classification model for ``endpoint``.

    ``endpoint`` is passed unchanged to ``AutoTokenizer.from_pretrained`` and
    ``AutoModelForSequenceClassification.from_pretrained``.

    Results are memoized per endpoint string. The Gradio callback calls this on
    every submission, so repeated runs against the same model reuse the already
    loaded weights instead of reloading them. The cache is bounded so switching
    between many models cannot grow memory without limit. Failed loads raise and
    are not cached.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(endpoint)
    model = AutoModelForSequenceClassification.from_pretrained(endpoint)
    return tokenizer, model


def test_model(tokenizer, model, test_data: list, label_map: dict):
    results = []
    for text, true_label in test_data:
        inputs = tokenizer(text, return_tensors="pt",
                           truncation=True, padding=True)
        outputs = model(**inputs)
        pred_label = label_map[int(outputs.logits.argmax(dim=-1))]
        results.append((text, true_label, pred_label))
    return results

def generate_label_map(dataset):
    """Map each integer class id of the dataset's ``label`` feature to its name.

    Args:
        dataset: Object exposing ``features["label"].names``, a sequence of
            class names ordered by class id.

    Returns:
        Dict ``{class_id: class_name}`` for every name, in order.
    """
    return dict(enumerate(dataset.features["label"].names))


def generate_report_card(results, label_map):
    """Build a Plotly heatmap of the confusion matrix for ``results``.

    Args:
        results: ``(text, true_label, pred_label)`` tuples as produced by
            ``test_model``.
        label_map: Class index -> label name. Its values fix the row/column
            order of the matrix and the axis tick labels.

    Returns:
        A ``plotly.graph_objects.Figure``, returned to the Gradio ``gr.Plot``
        output.
    """
    labels = list(label_map.values())
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]

    # Passing `labels` keeps the matrix square and ordered like the axes,
    # even when some classes never occur in this sample.
    cm = confusion_matrix(true_labels, pred_labels, labels=labels)

    fig = go.Figure(
        data=go.Heatmap(
            z=cm,
            x=labels,
            y=labels,
            colorscale='Viridis',
            colorbar=dict(title='Number of Samples')
        ),
        layout=go.Layout(
            title='Confusion Matrix',
            xaxis=dict(title='Predicted Labels'),
            # Reversed so the first label's row is drawn at the top.
            yaxis=dict(title='True Labels', autorange='reversed')
        )
    )

    fig.update_layout(height=600, width=800)
    return fig

def app(model_endpoint: str, dataset_name: str, config_name: str,
        dataset_split: str, num_samples: int, text_column: str = "sentence"):
    """Evaluate a model on a dataset slice and return a confusion-matrix figure.

    Args:
        model_endpoint: Model id/path passed to ``load_model``.
        dataset_name: Dataset name passed to ``load_dataset``.
        config_name: Dataset config; blank means "no config".
        dataset_split: Split name, e.g. ``"validation"``.
        num_samples: How many leading rows of the split to evaluate. The value
            is cast to ``int`` because the Gradio Number input may deliver a
            float.
        text_column: Column holding the input text. Defaults to ``"sentence"``,
            the column name used by the original hard-coded lookup.

    Returns:
        The Plotly figure built by ``generate_report_card``.

    Raises:
        ValueError: If ``num_samples`` is not positive, or the split contains a
            label id that is not one of the dataset's named classes.
    """
    num_samples = int(num_samples)
    if num_samples <= 0:
        raise ValueError("Number of Samples must be a positive integer.")

    # Load the dataset before the (much heavier) model so bad dataset inputs
    # fail fast. An empty config textbox means "no config", i.e. None.
    dataset = load_dataset(
        dataset_name, config_name or None,
        split=f"{dataset_split}[:{num_samples}]")
    label_map = generate_label_map(dataset)

    test_data = []
    for item in dataset:
        label_id = item["label"]
        # Some benchmark test splits (e.g. GLUE) hide labels as -1. Indexing a
        # name list with -1 would silently relabel every row as the last class,
        # so reject ids that are not real classes.
        if label_id not in label_map:
            raise ValueError(
                f"Split '{dataset_split}' contains unlabeled or unknown label "
                f"id {label_id!r}; choose a split with gold labels.")
        test_data.append((item[text_column], label_map[label_id]))

    tokenizer, model = load_model(model_endpoint)
    results = test_model(tokenizer, model, test_data, label_map)
    return generate_report_card(results, label_map)

# Gradio UI. The input components are passed positionally to `app` in this
# order: endpoint, dataset name, config, split, sample count.
# Uses the top-level component classes (gr.Textbox, ...) rather than the
# deprecated gr.inputs.* namespace, consistent with the gr.Plot() output.
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Textbox(lines=1, label="Model Endpoint",
                   placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english"),
        gr.Textbox(lines=1, label="Dataset Name",
                   placeholder="ex: glue"),
        gr.Textbox(lines=1, label="Config Name",
                   placeholder="ex: sst2"),
        # Default selection so `app` never receives None as the split name.
        gr.Dropdown(choices=["train", "validation", "test"],
                    value="validation", label="Dataset Split"),
        gr.Number(value=100, label="Number of Samples"),
    ],
    outputs=gr.Plot(),
    title="Fairness and Bias Testing",
    description="Enter a model endpoint and dataset to test for fairness and bias.",
)

# Module-level class-id -> name table for binary sentiment.
# NOTE(review): no function in this module reads this global; each one
# receives its own `label_map` argument built from the loaded dataset.
label_map = dict(enumerate(("negative", "positive")))

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    interface.launch()