import json
import os
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import pyarrow.parquet as pq
import requests
# Pipeline tags offered in the pipeline filter dropdown
PIPELINE_TAGS = [
    'text-generation',
    'text-to-image',
    'text-classification',
    'text2text-generation',
    'audio-to-audio',
    'feature-extraction',
    'image-classification',
    'translation',
    'reinforcement-learning',
    'fill-mask',
    'text-to-speech',
    'automatic-speech-recognition',
    'image-text-to-text',
    'token-classification',
    'sentence-similarity',
    'question-answering',
    'image-feature-extraction',
    'summarization',
    'zero-shot-image-classification',
    'object-detection',
    'image-segmentation',
    'image-to-image',
    'image-to-text',
    'audio-classification',
    'visual-question-answering',
    'text-to-video',
    'zero-shot-classification',
    'depth-estimation',
    'text-ranking',
    'image-to-video',
    'multiple-choice',
    'unconditional-image-generation',
    'video-classification',
    'text-to-audio',
    'time-series-forecasting',
    'any-to-any',
    'video-text-to-text',
    'table-question-answering',
]
# Model size categories in GB
MODEL_SIZE_RANGES = {
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float('inf')),
}
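
# Note: is_in_size_range() below treats each range as half-open, [min, max),
# so a model of exactly 5 GB falls in "Large (5-20GB)", not "Medium (1-5GB)".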
# Tag filter predicates: each takes a model row (dict-like) and returns a bool
def is_audio_speech(model_dict):
    tags = model_dict.get("tags", [])
    pipeline_tag = model_dict.get("pipeline_tag", "")
    return (pipeline_tag and ("audio" in pipeline_tag.lower() or "speech" in pipeline_tag.lower())) or \
        any("audio" in tag.lower() for tag in tags) or \
        any("speech" in tag.lower() for tag in tags)

def is_music(model_dict):
    tags = model_dict.get("tags", [])
    return any("music" in tag.lower() for tag in tags)

def is_robotics(model_dict):
    tags = model_dict.get("tags", [])
    return any("robot" in tag.lower() for tag in tags)

def is_biomed(model_dict):
    tags = model_dict.get("tags", [])
    return any("bio" in tag.lower() for tag in tags) or \
        any("medic" in tag.lower() for tag in tags)

def is_timeseries(model_dict):
    tags = model_dict.get("tags", [])
    return any("series" in tag.lower() for tag in tags)

def is_science(model_dict):
    tags = model_dict.get("tags", [])
    # Exclude the "bigscience" org tag, which is not a science-domain tag
    return any("science" in tag.lower() and "bigscience" not in tag for tag in tags)

def is_video(model_dict):
    tags = model_dict.get("tags", [])
    return any("video" in tag.lower() for tag in tags)

def is_image(model_dict):
    tags = model_dict.get("tags", [])
    return any("image" in tag.lower() for tag in tags)

def is_text(model_dict):
    tags = model_dict.get("tags", [])
    return any("text" in tag.lower() for tag in tags)
# Model size filter based on the safetensors metadata
def is_in_size_range(model_dict, size_range):
    if size_range is None:
        return True
    min_size, max_size = MODEL_SIZE_RANGES[size_range]
    # Get the model size in GB from the safetensors total (if available)
    safetensors = model_dict.get("safetensors", None)
    if safetensors and isinstance(safetensors, dict) and "total" in safetensors:
        # Convert bytes to GB
        size_gb = safetensors["total"] / (1024 * 1024 * 1024)
        return min_size <= size_gb < max_size
    # Models without safetensors metadata are excluded whenever a size filter is active
    return False
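
# Illustrative check (hypothetical row): a model whose safetensors "total" is
# 3 GiB worth of bytes lands in the "Medium (1-5GB)" bucket:
#   is_in_size_range({"safetensors": {"total": 3 * 1024**3}}, "Medium (1-5GB)")  # -> True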
TAG_FILTER_FUNCS = {
    "Audio & Speech": is_audio_speech,
    "Time series": is_timeseries,
    "Robotics": is_robotics,
    "Music": is_music,
    "Video": is_video,
    "Images": is_image,
    "Text": is_text,
    "Biomedical": is_biomed,
    "Sciences": is_science,
}
def extract_org_from_id(model_id):
    """Extract the organization name from a model ID"""
    if "/" in model_id:
        return model_id.split("/")[0]
    return "unaffiliated"
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None):
    """Process the DataFrame into treemap format with filters applied"""
    # Work on a copy to avoid modifying the original
    filtered_df = df.copy()

    # Apply filters
    if tag_filter and tag_filter in TAG_FILTER_FUNCS:
        filter_func = TAG_FILTER_FUNCS[tag_filter]
        filtered_df = filtered_df[filtered_df.apply(filter_func, axis=1)]
    if pipeline_filter:
        filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
    if size_filter and size_filter in MODEL_SIZE_RANGES:
        filtered_df = filtered_df[filtered_df.apply(lambda row: is_in_size_range(row, size_filter), axis=1)]

    # Add an organization column
    filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)

    # Aggregate by organization and keep the top k
    org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
    org_totals = org_totals.sort_values(by=count_by, ascending=False)
    top_orgs = org_totals.head(top_k)["organization"].tolist()

    # Keep only models from the top organizations
    filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]

    # Prepare data for the treemap, adding a single root node
    treemap_data = filtered_df[["id", "organization", count_by]].copy()
    treemap_data["root"] = "models"

    # Ensure numeric values
    treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
    return treemap_data
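
# Illustrative call (assumes `df` has the columns used above):
#   data = make_treemap_data(df, "downloads", top_k=10, pipeline_filter="text-generation")
# The result has exactly the columns create_treemap() expects:
# "root", "organization", "id", and the chosen metric.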
def create_treemap(treemap_data, count_by, title=None):
    """Create a Plotly treemap from the prepared data"""
    if treemap_data.empty:
        # Create an empty figure with a message
        fig = px.treemap(
            names=["No data matches the selected filters"],
            values=[1]
        )
        fig.update_layout(
            title="No data matches the selected filters",
            margin=dict(t=50, l=25, r=25, b=25)
        )
        return fig

    # Create the treemap: root -> organization -> model
    fig = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization"
    )
    fig.update_layout(
        margin=dict(t=50, l=25, r=25, b=25)
    )
    # Improve readability; the empty <extra></extra> suppresses the secondary hover box
    fig.update_traces(
        textinfo="label+value+percent root",
        hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
    )
    return fig
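
# Quick smoke test (hypothetical two-row frame):
#   sample = pd.DataFrame({"root": ["models"] * 2, "organization": ["org-a", "org-b"],
#                          "id": ["org-a/x", "org-b/y"], "downloads": [10, 20]})
#   create_treemap(sample, "downloads").show()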
def download_with_progress(url, progress=None):
    """Download a file into memory with optional progress tracking"""
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Fail early on HTTP errors instead of parsing an error page
        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 KiB
        data = BytesIO()
        if total_size == 0:
            # Content length unknown: we can't show an accurate percentage
            if progress is not None:
                progress(0, "Starting download...")
            for chunk in response.iter_content(block_size):
                data.write(chunk)
                if progress is not None:
                    progress(0, "Downloading... (unknown size)")
        else:
            downloaded = 0
            for chunk in response.iter_content(block_size):
                downloaded += len(chunk)
                data.write(chunk)
                if progress is not None:
                    percent = int(100 * downloaded / total_size)
                    progress(percent / 100, f"Downloading... {percent}% ({downloaded//(1024*1024)}MB/{total_size//(1024*1024)}MB)")
        return data.getvalue()
    except Exception as e:
        print(f"Error in download_with_progress: {e}")
        raise
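
# Illustrative use without a progress callback (the URL is a placeholder):
#   raw = download_with_progress("https://example.com/models.parquet")
#   df = pq.read_table(BytesIO(raw)).to_pandas()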
def update_progress(progress_obj, value, description):
    """Safely update progress with error handling"""
    try:
        if progress_obj is not None:
            progress_obj(value, description)
    except Exception as e:
        print(f"Error updating progress: {e}")
def download_and_process_models(progress=None):
    """Download and process the models data from the HuggingFace dataset with progress tracking"""
    try:
        # Create a cache directory
        os.makedirs('data', exist_ok=True)

        # Use cached data if present
        if os.path.exists('data/processed_models.parquet'):
            update_progress(progress, 1.0, "Loading from cache...")
            print("Loading models from cache...")
            df = pd.read_parquet('data/processed_models.parquet')
            return df

        # URL of the models.parquet file
        url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
        update_progress(progress, 0.0, "Starting download...")
        print(f"Downloading models data from {url}...")
        try:
            # Download with progress tracking
            file_content = download_with_progress(url, progress)
            update_progress(progress, 0.9, "Parsing parquet file...")

            # Read the parquet file
            table = pq.read_table(BytesIO(file_content))
            df = table.to_pandas()
            print(f"Downloaded {len(df)} models")
            update_progress(progress, 0.95, "Processing data...")

            # Parse the safetensors column if it arrives as a JSON string
            if 'safetensors' in df.columns:
                def parse_safetensors(val):
                    if isinstance(val, str):
                        try:
                            return json.loads(val)
                        except (ValueError, TypeError):
                            return None
                    return val
                df['safetensors'] = df['safetensors'].apply(parse_safetensors)

            # Normalize the tags column to Python lists
            # (parquet list columns typically come back as numpy arrays)
            if 'tags' in df.columns and len(df) > 0 and not isinstance(df['tags'].iloc[0], list):
                def parse_tags(val):
                    if isinstance(val, str):
                        try:
                            return json.loads(val)
                        except (ValueError, TypeError):
                            return []
                    if isinstance(val, np.ndarray):
                        return val.tolist()
                    return val if isinstance(val, list) else []
                df['tags'] = df['tags'].apply(parse_tags)

            # Cache the processed data
            update_progress(progress, 0.98, "Saving to cache...")
            df.to_parquet('data/processed_models.parquet')
            update_progress(progress, 1.0, "Data ready!")
            return df
        except Exception as download_error:
            print(f"Download failed: {download_error}")
            update_progress(progress, 0.5, "Download failed, generating sample data...")
            return create_sample_data(progress)
    except Exception as e:
        print(f"Error downloading or processing data: {e}")
        update_progress(progress, 1.0, "Using sample data (error occurred)")
        # Fall back to sample data for testing when the real data is unavailable
        return create_sample_data(progress)
def create_sample_data(progress=None):
    """Create sample data for testing when the real data is unavailable"""
    print("Creating sample data for testing...")
    if progress:
        progress(0.3, "Creating sample data...")

    # Sample organizations
    orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'nvidia', 'huggingface',
            'deepseek-ai', 'stability-ai', 'mistralai', 'cerebras', 'databricks', 'together',
            'facebook', 'amazon', 'deepmind', 'cohere', 'bigscience', 'eleutherai']

    # Common model name formats
    model_name_patterns = [
        "model-{size}-{version}",
        "{prefix}-{size}b",
        "{prefix}-{size}b-{variant}",
        "llama-{size}b-{variant}",
        "gpt-{variant}-{size}b",
        "{prefix}-instruct-{size}b",
        "{prefix}-chat-{size}b",
        "{prefix}-coder-{size}b",
        "stable-diffusion-{version}",
        "whisper-{size}",
        "bert-{size}-{variant}",
        "roberta-{size}",
        "t5-{size}",
        "{prefix}-vision-{size}b"
    ]

    # Common name parts
    prefixes = ["falcon", "llama", "mistral", "gpt", "phi", "gemma", "qwen", "yi", "mpt", "bloom"]
    sizes = ["7", "13", "34", "70", "1", "3", "7b", "13b", "70b", "8b", "2b", "1b", "0.5b", "small", "base", "large", "huge"]
    variants = ["chat", "instruct", "base", "v1.0", "v2", "beta", "turbo", "fast", "xl", "xxl"]

    # Draw the per-org model counts (5-20 each) up front so the progress
    # denominator matches the number of models actually created
    models_per_org = [np.random.randint(5, 20) for _ in orgs]
    total_models = sum(models_per_org)

    # Generate sample data
    data = []
    models_created = 0
    for org_idx, org in enumerate(orgs):
        num_models = models_per_org[org_idx]
        for i in range(num_models):
            # Build a realistic model name
            pattern = np.random.choice(model_name_patterns)
            prefix = np.random.choice(prefixes)
            size = np.random.choice(sizes)
            version = f"v{np.random.randint(1, 4)}"
            variant = np.random.choice(variants)
            model_name = pattern.format(prefix=prefix, size=size, version=version, variant=variant)
            model_id = f"{org}/{model_name}"

            # Select a realistic pipeline tag based on the name
            if "diffusion" in model_name or "image" in model_name:
                pipeline_tag = np.random.choice(["text-to-image", "image-to-image", "image-segmentation"])
            elif "whisper" in model_name or "speech" in model_name:
                pipeline_tag = np.random.choice(["automatic-speech-recognition", "text-to-speech"])
            elif "coder" in model_name or "code" in model_name:
                pipeline_tag = "text-generation"
            elif "bert" in model_name or "roberta" in model_name:
                pipeline_tag = np.random.choice(["fill-mask", "text-classification", "token-classification"])
            elif "vision" in model_name:
                pipeline_tag = np.random.choice(["image-classification", "image-to-text", "visual-question-answering"])
            else:
                pipeline_tag = "text-generation"  # Most common

            # Generate realistic tags
            tags = [pipeline_tag]
            if "text-generation" in pipeline_tag:
                tags.extend(["language-model", "text", "gpt", "llm"])
                if "instruct" in model_name:
                    tags.append("instruction-following")
                if "chat" in model_name:
                    tags.append("chat")
            elif "speech" in pipeline_tag:
                tags.extend(["audio", "speech", "voice"])
            elif "image" in pipeline_tag:
                tags.extend(["vision", "image", "diffusion"])

            # Add language tags
            if np.random.random() < 0.8:  # 80% chance of English
                tags.append("en")
            if np.random.random() < 0.3:  # 30% chance of multilingual
                tags.append("multilingual")

            # Generate downloads and likes, weighted by org position so earlier
            # orgs get more downloads and the visualization stays interesting
            popularity_factor = (len(orgs) - org_idx) / len(orgs)  # 1.0 down to ~0.0
            base_downloads = 1000 * (10 ** (2 * popularity_factor))
            downloads = int(base_downloads * np.random.uniform(0.3, 3.0))
            likes = int(downloads * np.random.uniform(0.01, 0.1))  # 1-10% like ratio

            # Model size (safetensors total, in bytes) should correlate
            # somewhat with the size in the name
            size_indicator = 1
            for s in ["70b", "13b", "7b", "3b", "2b", "1b", "large", "huge", "xl", "xxl"]:
                if s in model_name.lower():
                    size_indicator = float(s.replace("b", "")) if s[0].isdigit() else 3
                    break

            # Size in GB (capped at 100 GB), then converted to bytes
            size_gb = min(np.random.uniform(0.1, 2.0) * size_indicator, 100)
            size_bytes = int(size_gb * 1e9)

            # Assemble the model entry
            model = {
                "id": model_id,
                "downloads": downloads,
                "downloadsAllTime": int(downloads * np.random.uniform(1.5, 3.0)),  # all-time > recent
                "likes": likes,
                "pipeline_tag": pipeline_tag,
                "tags": tags,
                "safetensors": {"total": size_bytes}
            }
            data.append(model)
            models_created += 1
            if progress and i % 5 == 0:
                progress(0.3 + 0.6 * (models_created / total_models), f"Created {models_created}/{total_models} sample models...")

    # Convert to a DataFrame
    df = pd.DataFrame(data)
    if progress:
        progress(0.95, "Finalizing sample data...")
    return df
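
# The sample frame deliberately mirrors the columns the app reads from the real
# dataset: "id", "downloads", "downloadsAllTime", "likes", "pipeline_tag",
# "tags", and "safetensors" (a dict with a "total" byte count).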
# Create the Gradio interface
with gr.Blocks() as demo:
    models_data = gr.State()  # Holds the loaded DataFrame

    # Loading screen components
    with gr.Row(visible=True) as loading_screen:
        with gr.Column(scale=1):
            gr.Markdown("""
            # HuggingFace Models TreeMap Visualization

            Loading data... This might take a moment.
            """)
    # Progress for the initial load is surfaced through the gr.Progress()
    # argument of load_data_with_progress below (gr.Progress is not a layout component)
    # Main application components (initially hidden)
    with gr.Row(visible=False) as main_app:
        gr.Markdown("""
        # HuggingFace Models TreeMap Visualization

        This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
        Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
        The treemap visualizes models grouped by organization, with the size of each box representing the selected metric (downloads or likes).
        """)
    with gr.Row(visible=False) as control_panel:
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=["downloads", "downloadsAllTime", "likes"],
                value="downloads",
                info="Select the metric to determine box sizes"
            )
            filter_choice_radio = gr.Radio(
                label="Filter Type",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
                info="Choose how to filter the models"
            )
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
                info="Filter models by domain/category"
            )
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
                info="Filter models by specific pipeline"
            )
            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None",
                info="Filter models by their size (from safetensors['total'])"
            )
            top_k_slider = gr.Slider(
                label="Number of Top Organizations",
                minimum=5,
                maximum=50,
                value=25,
                step=5,
                info="Number of top organizations to include"
            )
            generate_plot_button = gr.Button("Generate Plot", variant="primary")
        with gr.Column(scale=3):
            plot_output = gr.Plot()
            stats_output = gr.Markdown("*Generate a plot to see statistics*")
    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, data_df):
        print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}, Top K={top_k}")
        if data_df is None or len(data_df) == 0:
            return None, "Error: No data available. Please try again."

        # Resolve which filters are active
        selected_tag_filter = None
        selected_pipeline_filter = None
        selected_size_filter = None
        if filter_choice == "Tag Filter":
            selected_tag_filter = tag_filter
        elif filter_choice == "Pipeline Filter":
            selected_pipeline_filter = pipeline_filter
        if size_filter != "None":
            selected_size_filter = size_filter

        # Process data for the treemap
        treemap_data = make_treemap_data(
            df=data_df,
            count_by=count_by,
            top_k=top_k,
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
            size_filter=selected_size_filter
        )

        # Create the plot
        fig = create_treemap(
            treemap_data=treemap_data,
            count_by=count_by,
            title=f"HuggingFace Models - {count_by.capitalize()} by Organization"
        )

        # Generate statistics
        if treemap_data.empty:
            stats_md = "No data matches the selected filters."
        else:
            total_models = len(treemap_data)
            total_value = treemap_data[count_by].sum()
            top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
            # Keep the Markdown flush-left so the list and table syntax renders reliably
            stats_md = f"""
### Statistics
- **Total models shown**: {total_models:,}
- **Total {count_by}**: {total_value:,}

### Top 5 Organizations
| Organization | {count_by.capitalize()} | % of Total |
| --- | --- | --- |
"""
            for org, value in top_5_orgs.items():
                percentage = (value / total_value) * 100
                stats_md += f"| {org} | {value:,} | {percentage:.2f}% |\n"
        return fig, stats_md
    def update_filter_visibility(filter_choice):
        if filter_choice == "Tag Filter":
            return gr.update(visible=True), gr.update(visible=False)
        elif filter_choice == "Pipeline Filter":
            return gr.update(visible=False), gr.update(visible=True)
        else:  # "None"
            return gr.update(visible=False), gr.update(visible=False)

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
    )
    def load_data_with_progress(progress=gr.Progress()):
        """Load data with progress tracking and update UI visibility"""
        data_df = download_and_process_models(progress)
        # Return both the data and the visibility updates
        return data_df, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)

    # Load data once at startup, with a progress bar
    demo.load(
        fn=load_data_with_progress,
        inputs=[],
        outputs=[models_data, loading_screen, main_app, control_panel]
    )
    # Button click event to generate the plot
    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            size_filter_dropdown,
            top_k_slider,
            models_data
        ],
        outputs=[plot_output, stats_output]
    )

if __name__ == "__main__":
    demo.launch()