CultriX's picture
Update app.py
144fe6c verified
raw
history blame
7.44 kB
import base64
import html
from io import BytesIO
from io import StringIO

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objs as go
import plotly.io as pio
import seaborn as sns
def parse_data(file_content):
    """Parse pasted benchmark text into a long-format DataFrame.

    A line starting with ``hf (pretrained=`` opens a new model section. The
    model name is the ``pretrained=`` value, cut at the first ``,`` or ``)``.
    Inside a section, every non-empty line that contains ``|`` and does not
    start with ``-`` is treated as a table row. The first cell is the task
    name and the second cell must parse as a float. Rows whose second cell
    is not numeric (e.g. header rows) are reported and skipped. Lines before
    the first model header are ignored.

    Args:
        file_content: The raw pasted text.

    Returns:
        DataFrame with columns ``['Model', 'Task', 'Value']``. It is empty
        when no row could be parsed.
    """
    model_data = []
    current_model = None
    for raw_line in file_content.splitlines():
        line = raw_line.strip()
        if line.startswith('hf (pretrained='):
            # Cut at ',' AND ')': when pretrained= is the last model argument
            # ("hf (pretrained=org/model), gen_kwargs: ..."), cutting only at
            # ',' would leave a trailing ')' in the model name.
            model_spec = line.split('pretrained=', 1)[1]
            current_model = model_spec.split(',', 1)[0].split(')', 1)[0].strip()
            continue
        if not (line and current_model):
            continue
        # Separator rows start with '-'; only pipe-delimited lines are rows.
        if line.startswith('-') or '|' not in line:
            continue
        # A line containing '|' always splits into at least two cells.
        parts = [p.strip() for p in line.split('|')]
        task_name = parts[0]
        try:
            value = float(parts[1])  # Second cell holds the numeric score.
        except ValueError:
            print(f"Skipping row due to invalid value: {parts}")
            continue
        model_data.append([current_model, task_name, value])
    if not model_data:
        print("No valid data found in the file.")
    return pd.DataFrame(model_data, columns=['Model', 'Task', 'Value'])
def calculate_averages(data):
    """Return the mean 'Value' of every model.

    Args:
        data: DataFrame with at least 'Model' and 'Value' columns.

    Returns:
        DataFrame with columns ``['Model', 'Average Performance']``, one row
        per model. It is empty (with those columns) when *data* is empty.
    """
    if not data.empty:
        per_model_mean = data.groupby('Model')['Value'].mean()
        return per_model_mean.rename('Average Performance').reset_index()
    print("No data available to calculate averages.")
    return pd.DataFrame(columns=['Model', 'Average Performance'])
def create_bar_chart(df, category):
    """Create a horizontal Plotly bar chart of models for one numeric column.

    Args:
        df: DataFrame containing a 'Model' column and the *category* column.
        category: Name of the numeric column to plot, e.g.
            'Average Performance' or 'Value'.

    Returns:
        A ``plotly.graph_objs.Figure``. Bars are added in ascending order of
        *category* and coloured by value with the Viridis scale. The title is
        "Leaderboard for <category> Scores".
    """
    sorted_df = df[['Model', category]].sort_values(by=category, ascending=True)
    values = sorted_df[category]
    fig = go.Figure(go.Bar(
        x=values,
        y=sorted_df['Model'],
        orientation='h',
        marker=dict(color=values, colorscale='Viridis'),
        hoverinfo='x+y',  # Hover still shows the exact, unrounded value.
        # Round on-bar labels; raw floats like 0.4321333333 clutter the bars.
        text=values.round(3),
        textposition='auto',
    ))
    fig.update_layout(
        # The title is drawn in the top margin, and 20px is too little room
        # for it, so give the top margin enough space.
        margin=dict(l=20, r=20, t=60, b=20),
        title=f"Leaderboard for {category} Scores",
    )
    return fig
def _current_figure_to_base64():
    """Save the active matplotlib figure as PNG, close it, and return base64 text."""
    # savefig writes PNG *bytes*; a text buffer such as StringIO raises TypeError.
    buffer = BytesIO()
    try:
        plt.savefig(buffer, format='png')
    finally:
        # Always release the figure so failed or repeated calls do not leak figures.
        plt.close()
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def generate_visualizations(data, averages):
    """Render matplotlib and Plotly charts for parsed benchmark results.

    Args:
        data: DataFrame with 'Model', 'Task' and 'Value' columns. It is not
            modified.
        averages: DataFrame with 'Model' and 'Average Performance' columns.

    Returns:
        A 6-tuple ``(image_avg, image_line, image_heatmap, image_boxplot,
        plotly_avg, plotly_tasks)``:

        - four base64-encoded PNG strings;
        - the Plotly HTML fragment for the averages chart;
        - a dict mapping each task name to its Plotly HTML fragment.

        All six are None when *averages* is empty. The last five are None
        when *data* is empty.
    """
    sns.set(style='whitegrid')
    if averages.empty:
        print("No averages to visualize.")
        return None, None, None, None, None, None
    averages = averages.sort_values(by='Average Performance')

    # Matplotlib average performance plot
    plt.figure(figsize=(12, 8))
    sns.barplot(data=averages, x='Average Performance', y='Model', palette='viridis')
    plt.title('Average Performance of Models', fontsize=16)
    plt.xlabel('Average Performance', fontsize=12)
    plt.ylabel('Model', fontsize=12)
    plt.tight_layout()
    image_avg = _current_figure_to_base64()

    # Order models by their average so every plot uses the same ranking.
    # assign() works on a copy, so the caller's DataFrame is left untouched.
    sorted_models = averages['Model'].tolist()
    data = data.assign(
        Model=pd.Categorical(data['Model'], categories=sorted_models, ordered=True)
    )
    data = data.sort_values(by=['Model', 'Task'])
    if data.empty:
        print("No data available for line plot.")
        return image_avg, None, None, None, None, None

    # Line plot for task performance by model
    plt.figure(figsize=(14, 10))
    sns.lineplot(data=data, x='Task', y='Value', hue='Model', marker='o')
    plt.title('Task Performance by Model', fontsize=16)
    plt.xlabel('Task', fontsize=12)
    plt.ylabel('Performance', fontsize=12)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Model')
    plt.xticks(rotation=45)
    plt.tight_layout()
    image_line = _current_figure_to_base64()

    # Heatmap of task performance
    pivot_table = data.pivot_table(index='Task', columns='Model', values='Value')
    plt.figure(figsize=(12, 10))
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap="coolwarm", cbar=True)
    plt.title('Task Performance Heatmap', fontsize=16)
    plt.xlabel('Model', fontsize=12)
    plt.ylabel('Task', fontsize=12)
    plt.tight_layout()
    image_heatmap = _current_figure_to_base64()

    # Boxplot of performance distribution per model
    plt.figure(figsize=(12, 8))
    sns.boxplot(data=data, x='Model', y='Value', palette='Set2')
    plt.title('Performance Distribution per Model', fontsize=16)
    plt.xlabel('Model', fontsize=12)
    plt.ylabel('Performance', fontsize=12)
    plt.xticks(rotation=45)
    plt.tight_layout()
    image_boxplot = _current_figure_to_base64()

    # Plotly bar charts. Reference plotly.js from the CDN instead of inlining
    # the whole library into every fragment (several MB per chart).
    fig1 = create_bar_chart(averages, 'Average Performance')
    plotly_avg = pio.to_html(fig1, full_html=False, include_plotlyjs='cdn')
    plotly_tasks = {}
    for task in data['Task'].unique():
        task_data = data[data['Task'] == task]
        fig2 = create_bar_chart(task_data, 'Value')
        fig2.update_layout(title=f"Leaderboard for {task} Scores")
        plotly_tasks[task] = pio.to_html(fig2, full_html=False, include_plotlyjs='cdn')
    return image_avg, image_line, image_heatmap, image_boxplot, plotly_avg, plotly_tasks
def process_and_visualize(file_content):
    """Run the parse -> average -> visualize pipeline on pasted benchmark text.

    Returns:
        An 8-tuple with these elements, in order:

        - a text summary of the averages, sorted ascending;
        - the four base64 PNG strings from generate_visualizations;
        - the Plotly averages fragment;
        - the per-task Plotly mapping, which appears twice (the last two
          elements are the same object).

        Any of the visualization slots may be None when there is nothing to
        plot.
    """
    parsed = parse_data(file_content)
    summary = calculate_averages(parsed)
    (
        avg_png,
        line_png,
        heatmap_png,
        boxplot_png,
        avg_chart_html,
        task_chart_html,
    ) = generate_visualizations(parsed, summary)
    ranked = summary.sort_values(by='Average Performance')
    report = "Average Performance per Model:\n" + ranked.to_string()
    return (
        report,
        avg_png,
        line_png,
        heatmap_png,
        boxplot_png,
        avg_chart_html,
        task_chart_html,
        task_chart_html,
    )
def png_to_html_img(image_b64):
    """Wrap a base64-encoded PNG in an <img> data-URI tag; return "" if there is none."""
    if not image_b64:
        return ""
    return f'<img src="data:image/png;base64,{image_b64}" style="max-width:100%;">'


def plotly_fragment_to_iframe(fragment):
    """Embed a Plotly HTML fragment in an iframe via ``srcdoc``; return "" if empty.

    The fragment contains <script> tags. NOTE(review): gr.HTML is understood
    not to execute inline scripts, so the fragment is escaped into an iframe
    document where its scripts can run.
    """
    if not fragment:
        return ""
    return (
        f'<iframe srcdoc="{html.escape(fragment, quote=True)}" '
        'style="width:100%;height:480px;border:0;"></iframe>'
    )


def task_charts_to_html(plotly_tasks):
    """Concatenate per-task Plotly fragments into one HTML string, one titled section each."""
    if not plotly_tasks:
        return ""
    return "".join(
        f"<h3>{html.escape(str(task))}</h3>{plotly_fragment_to_iframe(fragment)}"
        for task, fragment in plotly_tasks.items()
    )


def gradio_outputs(file_content):
    """Adapt process_and_visualize's 8-tuple to the 7 output components below.

    Base64 PNGs become <img> tags and Plotly fragments become iframes. The
    per-task mapping, which is returned twice, becomes one HTML string. The
    duplicate copy is ignored.
    """
    (output_text, image_avg, image_line, image_heatmap, image_boxplot,
     plotly_avg, plotly_tasks, _duplicate_tasks) = process_and_visualize(file_content)
    return (
        output_text,
        png_to_html_img(image_avg),
        png_to_html_img(image_line),
        png_to_html_img(image_heatmap),
        png_to_html_img(image_boxplot),
        plotly_fragment_to_iframe(plotly_avg),
        task_charts_to_html(plotly_tasks),
    )


if __name__ == "__main__":
    # One output component per value returned by gradio_outputs (7 total).
    # The task list is dynamic, so all task charts share one HTML output
    # instead of hard-coded tabs.
    iface = gr.Interface(
        fn=gradio_outputs,
        inputs=gr.Textbox(lines=10, label="Paste your data here"),
        outputs=[
            gr.Textbox(label="Average Performance per Model"),
            gr.HTML(label="Matplotlib Average Performance Chart"),
            gr.HTML(label="Matplotlib Task Performance Line Chart"),
            gr.HTML(label="Matplotlib Task Performance Heatmap"),
            gr.HTML(label="Matplotlib Performance Distribution Boxplot"),
            gr.HTML(label="Plotly Average Performance Chart"),
            gr.HTML(label="Plotly Task Charts"),
        ],
        title="LLM Benchmark Visualizer",
        description="Paste your LLM benchmark data and visualize the results.",
    )
    iface.launch(share=True)