"""Gradio app that parses pasted LLM benchmark results and charts them."""

import base64
import html
from io import BytesIO, StringIO

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objs as go
import plotly.io as pio
import seaborn as sns


# Read the data from the file
def parse_data(file_content):
    """Parse pasted benchmark text into a long-form DataFrame.

    A line starting with ``hf (pretrained=`` opens a new model section; the
    model name is the text between ``pretrained=`` and the next comma.
    Within a section, every non-empty line that contains ``|`` and does not
    start with ``-`` is split on ``|``: the first cell is taken as the task
    name and the second cell as the numeric value. Rows whose second cell is
    not a valid float are reported on stdout and skipped.

    Args:
        file_content: The raw text to parse.

    Returns:
        A DataFrame with columns ``Model``, ``Task`` and ``Value`` (empty
        when no row could be parsed).
    """
    # NOTE(review): rows are expected to look like "task | value". A table
    # whose rows start with "|" would yield an empty first cell and be
    # skipped - confirm against the real input format.
    model_data = []
    current_model = None

    for raw_line in file_content.splitlines():
        line = raw_line.strip()

        if line.startswith('hf (pretrained='):
            current_model = line.split('pretrained=')[1].split(',')[0]
            continue

        # Ignore blank lines and anything before the first model header.
        if not line or not current_model:
            continue
        if line.startswith('-') or '|' not in line:
            continue

        # Parse table row
        parts = [p.strip() for p in line.split('|')]
        if len(parts) < 2:
            continue
        try:
            value = float(parts[1])  # Extract the numeric value
        except ValueError:
            print(f"Skipping row due to invalid value: {parts}")
            continue
        model_data.append([current_model, parts[0], value])

    if not model_data:
        print("No valid data found in the file.")
    return pd.DataFrame(model_data, columns=['Model', 'Task', 'Value'])


# Calculate average performance
def calculate_averages(data):
    """Return the mean ``Value`` per model.

    Args:
        data: DataFrame with at least ``Model`` and ``Value`` columns.

    Returns:
        A DataFrame with columns ``Model`` and ``Average Performance``;
        empty (with those columns) when ``data`` is empty.
    """
    if data.empty:
        print("No data available to calculate averages.")
        return pd.DataFrame(columns=['Model', 'Average Performance'])
    return (
        data.groupby('Model')['Value']
        .mean()
        .reset_index()
        .rename(columns={'Value': 'Average Performance'})
    )


def create_bar_chart(df, category):
    """Create a horizontal bar chart for the specified category.

    Args:
        df: DataFrame containing a ``Model`` column and the ``category``
            column.
        category: Name of the numeric column to plot and sort by.

    Returns:
        A plotly ``Figure`` with one bar per row, sorted ascending by
        ``category``.
    """
    sorted_df = df[['Model', category]].sort_values(by=category, ascending=True)
    fig = go.Figure(go.Bar(
        x=sorted_df[category],
        y=sorted_df['Model'],
        orientation='h',
        marker=dict(color=sorted_df[category], colorscale='Viridis'),
        hoverinfo='x+y',
        text=sorted_df[category],
        textposition='auto',
    ))
    fig.update_layout(
        margin=dict(l=20, r=20, t=20, b=20),
        title=f"Leaderboard for {category} Scores",
    )
    return fig


def _current_figure_to_base64():
    """Render the current pyplot figure to PNG, close it, return base64 text."""
    # PNG output is binary, so the target must be a bytes buffer; a text
    # buffer (StringIO) rejects the bytes that savefig writes.
    buffer = BytesIO()
    try:
        plt.savefig(buffer, format='png')
    finally:
        # Always release the figure, even when rendering fails.
        plt.close()
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def generate_visualizations(data, averages):
    """Build all charts for the parsed benchmark data.

    Args:
        data: Long-form DataFrame with ``Model``, ``Task`` and ``Value``
            columns. It is not modified.
        averages: DataFrame with ``Model`` and ``Average Performance``
            columns.

    Returns:
        A 6-tuple ``(image_avg, image_line, image_heatmap, image_boxplot,
        plotly_avg, plotly_tasks)``. The four images are base64-encoded PNG
        strings, ``plotly_avg`` is an HTML fragment and ``plotly_tasks`` maps
        each task name to an HTML fragment. All six entries are ``None`` when
        ``averages`` is empty; only ``image_avg`` is set when ``data`` is
        empty.
    """
    sns.set(style='whitegrid')
    if averages.empty:
        print("No averages to visualize.")
        return None, None, None, None, None, None

    averages = averages.sort_values(by='Average Performance')

    # Matplotlib average performance plot
    plt.figure(figsize=(12, 8))
    sns.barplot(data=averages, x='Average Performance', y='Model', palette='viridis')
    plt.title('Average Performance of Models', fontsize=16)
    plt.xlabel('Average Performance', fontsize=12)
    plt.ylabel('Model', fontsize=12)
    plt.tight_layout()
    image_avg = _current_figure_to_base64()

    # Order models by average score so every per-task chart lists them the
    # same way. assign() works on a copy, leaving the caller's frame intact.
    sorted_models = averages['Model'].tolist()
    data = data.assign(
        Model=pd.Categorical(data['Model'], categories=sorted_models, ordered=True)
    )
    data = data.sort_values(by=['Model', 'Task'])
    if data.empty:
        print("No data available for line plot.")
        return image_avg, None, None, None, None, None

    # Line plot for task performance by model
    plt.figure(figsize=(14, 10))
    sns.lineplot(data=data, x='Task', y='Value', hue='Model', marker='o')
    plt.title('Task Performance by Model', fontsize=16)
    plt.xlabel('Task', fontsize=12)
    plt.ylabel('Performance', fontsize=12)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Model')
    plt.xticks(rotation=45)
    plt.tight_layout()
    image_line = _current_figure_to_base64()

    # Heatmap of task performance
    pivot_table = data.pivot_table(index='Task', columns='Model', values='Value')
    plt.figure(figsize=(12, 10))
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap="coolwarm", cbar=True)
    plt.title('Task Performance Heatmap', fontsize=16)
    plt.xlabel('Model', fontsize=12)
    plt.ylabel('Task', fontsize=12)
    plt.tight_layout()
    image_heatmap = _current_figure_to_base64()

    # Boxplot of performance distribution per model
    plt.figure(figsize=(12, 8))
    sns.boxplot(data=data, x='Model', y='Value', palette='Set2')
    plt.title('Performance Distribution per Model', fontsize=16)
    plt.xlabel('Model', fontsize=12)
    plt.ylabel('Performance', fontsize=12)
    plt.xticks(rotation=45)
    plt.tight_layout()
    image_boxplot = _current_figure_to_base64()

    # Create plotly bar charts
    fig_avg = create_bar_chart(averages, 'Average Performance')
    plotly_avg = pio.to_html(fig_avg, full_html=False)

    # One leaderboard per task found in the data.
    plotly_tasks = {}
    for task in data['Task'].unique():
        task_data = data[data['Task'] == task]
        fig_task = create_bar_chart(task_data, 'Value')
        fig_task.update_layout(title=f"Leaderboard for {task} Scores")
        plotly_tasks[task] = pio.to_html(fig_task, full_html=False)

    return image_avg, image_line, image_heatmap, image_boxplot, plotly_avg, plotly_tasks


def process_and_visualize(file_content):
    """Parse benchmark text and produce the summary plus all charts.

    Args:
        file_content: The raw benchmark text.

    Returns:
        An 8-tuple: the text summary of per-model averages, the four
        base64-encoded PNG charts, the average-performance HTML fragment and
        the per-task HTML mapping. The mapping appears twice (positions 7
        and 8); that shape is kept unchanged for existing callers.
    """
    data = parse_data(file_content)
    averages = calculate_averages(data)
    (image_avg, image_line, image_heatmap, image_boxplot,
     plotly_avg, plotly_tasks) = generate_visualizations(data, averages)

    summary = averages.sort_values(by='Average Performance').to_string()
    output_text = f"Average Performance per Model:\n{summary}"
    return (output_text, image_avg, image_line, image_heatmap, image_boxplot,
            plotly_avg, plotly_tasks, plotly_tasks)


def _png_to_img_tag(png_b64, alt):
    """Wrap a base64-encoded PNG in an inline ``<img>`` tag ('' if missing)."""
    if not png_b64:
        return ""
    return (
        f'<img src="data:image/png;base64,{png_b64}" '
        f'alt="{html.escape(alt, quote=True)}" '
        'style="max-width:100%;height:auto;">'
    )


def _embed_fragment(fragment, height=480):
    """Wrap an HTML fragment in a sandboxed-by-document ``<iframe srcdoc>``.

    The plotly fragments rely on ``<script>`` tags. Markup inserted into an
    existing page as plain HTML does not run its scripts, whereas a document
    loaded through ``srcdoc`` does, so each fragment gets its own iframe.
    """
    # NOTE(review): assumes gr.HTML injects markup without executing
    # scripts - confirm against the installed Gradio version.
    # NOTE: each fragment from pio.to_html carries its own copy of
    # plotly.js by default; include_plotlyjs="cdn" would shrink the payload
    # at the cost of requiring network access in the browser.
    if not fragment:
        return ""
    document = (
        '<!DOCTYPE html><html><head><meta charset="utf-8"></head><body>'
        + fragment
        + '</body></html>'
    )
    return (
        f'<iframe srcdoc="{html.escape(document, quote=True)}" '
        f'style="width:100%;height:{height}px;border:none;"></iframe>'
    )


def _render_for_interface(file_content):
    """Adapt ``process_and_visualize`` to the seven Gradio output components.

    Returns the text summary followed by six HTML strings: four inline PNG
    charts, the average-performance plotly chart and one block holding
    every per-task plotly chart under its own heading.
    """
    (output_text, image_avg, image_line, image_heatmap, image_boxplot,
     plotly_avg, plotly_tasks, _duplicate) = process_and_visualize(file_content)

    task_sections = [
        f"<h3>{html.escape(str(task))}</h3>{_embed_fragment(fragment)}"
        for task, fragment in (plotly_tasks or {}).items()
    ]
    return (
        output_text,
        _png_to_img_tag(image_avg, "Average performance chart"),
        _png_to_img_tag(image_line, "Task performance line chart"),
        _png_to_img_tag(image_heatmap, "Task performance heatmap"),
        _png_to_img_tag(image_boxplot, "Performance distribution boxplot"),
        _embed_fragment(plotly_avg),
        "\n".join(task_sections),
    )


if __name__ == "__main__":
    # Charts are rendered inside the web server, not in a GUI session, so
    # use the non-interactive backend.
    plt.switch_backend("Agg")

    iface = gr.Interface(
        fn=_render_for_interface,
        inputs=gr.Textbox(lines=10, label="Paste your data here"),
        outputs=[
            gr.Textbox(label="Average Performance per Model"),
            gr.HTML(label="Matplotlib Average Performance Chart"),
            gr.HTML(label="Matplotlib Task Performance Line Chart"),
            gr.HTML(label="Matplotlib Task Performance Heatmap"),
            gr.HTML(label="Matplotlib Performance Distribution Boxplot"),
            gr.HTML(label="Plotly Average Performance Chart"),
            # The set of tasks depends on the pasted data, so all per-task
            # charts share one output instead of a fixed list of tabs.
            gr.HTML(label="Task Charts"),
        ],
        title="LLM Benchmark Visualizer",
        description="Upload your LLM benchmark data and visualize the results.",
    )
    iface.launch(share=True)