import json
import gradio as gr
import pandas as pd
import plotly.express as px
import pyarrow.parquet as pq
import os
import requests
from io import BytesIO
import math

# Pipeline tags available on the Hub, used for the pipeline filter dropdown
PIPELINE_TAGS = [
    'text-generation',
    'text-to-image',
    'text-classification',
    'text2text-generation',
    'audio-to-audio',
    'feature-extraction',
    'image-classification',
    'translation',
    'reinforcement-learning',
    'fill-mask',
    'text-to-speech',
    'automatic-speech-recognition',
    'image-text-to-text',
    'token-classification',
    'sentence-similarity',
    'question-answering',
    'image-feature-extraction',
    'summarization',
    'zero-shot-image-classification',
    'object-detection',
    'image-segmentation',
    'image-to-image',
    'image-to-text',
    'audio-classification',
    'visual-question-answering',
    'text-to-video',
    'zero-shot-classification',
    'depth-estimation',
    'text-ranking',
    'image-to-video',
    'multiple-choice',
    'unconditional-image-generation',
    'video-classification',
    'text-to-audio',
    'time-series-forecasting',
    'any-to-any',
    'video-text-to-text',
    'table-question-answering',
]

# Model size categories in GB
MODEL_SIZE_RANGES = {
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float('inf'))
}

# Filter functions for tag categories
def is_audio_speech(repo_dct):
    res = (repo_dct.get("pipeline_tag", None) and "audio" in repo_dct.get("pipeline_tag", "").lower()) or \
          (repo_dct.get("pipeline_tag", None) and "speech" in repo_dct.get("pipeline_tag", "").lower()) or \
          (repo_dct.get("tags", None) and any("audio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
          (repo_dct.get("tags", None) and any("speech" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

def is_music(repo_dct):
    res = (repo_dct.get("tags", None) and any("music" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

def is_robotics(repo_dct):
    res = (repo_dct.get("tags", None) and any("robot" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

def is_biomed(repo_dct):
    res = (repo_dct.get("tags", None) and any("bio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
          (repo_dct.get("tags", None) and any("medic" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

def is_timeseries(repo_dct):
    res = (repo_dct.get("tags", None) and any("series" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

def is_science(repo_dct):
    res = (repo_dct.get("tags", None) and any("science" in tag.lower() and "bigscience" not in tag for tag in repo_dct.get("tags", [])))
    return res

def is_video(repo_dct):
    res = (repo_dct.get("tags", None) and any("video" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

def is_image(repo_dct):
    res = (repo_dct.get("tags", None) and any("image" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

def is_text(repo_dct):
    res = (repo_dct.get("tags", None) and any("text" in tag.lower() for tag in repo_dct.get("tags", [])))
    return res

# Model size filter
def is_in_size_range(repo_dct, size_range):
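    """Return True if the model's safetensors weight size (converted to GB) falls in the named range."""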
    if size_range is None:
        return True
    min_size, max_size = MODEL_SIZE_RANGES[size_range]
    # Get model size in GB from the safetensors total (if available)
    if repo_dct.get("safetensors") and repo_dct["safetensors"].get("total"):
        # Convert bytes to GB
        size_gb = repo_dct["safetensors"]["total"] / (1024 * 1024 * 1024)
        return min_size <= size_gb < max_size
    return False

TAG_FILTER_FUNCS = {
    "Audio & Speech": is_audio_speech,
    "Time series": is_timeseries,
    "Robotics": is_robotics,
    "Music": is_music,
    "Video": is_video,
    "Images": is_image,
    "Text": is_text,
    "Biomedical": is_biomed,
    "Sciences": is_science,
}

def make_org_stats(count_by, org_stats, top_k=20, filter_func=None, size_range=None):
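    """Aggregate the chosen metric ("likes" or "downloads") per organization.

    Returns a tuple: (top-organization shares as (org, percent) pairs,
    flat (org, model_id, percent) rows used to populate the treemap).
    """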
    assert count_by in ["likes", "downloads"]

    # Apply both filter_func and size_range if provided
    def combined_filter(dct):
        passes_tag_filter = filter_func(dct) if filter_func else True
        passes_size_filter = is_in_size_range(dct, size_range) if size_range else True
        return passes_tag_filter and passes_size_filter

    # Sort organizations by total count
    sorted_stats = sorted(
        [(
            org_id,
            sum(model[count_by] for model in models if combined_filter(model))
        ) for org_id, models in org_stats.items()],
        key=lambda x: x[1],
        reverse=True,
    )
    # Top organizations + an "Others..." category
    res = sorted_stats[:top_k] + [("Others...", sum(st for auth, st in sorted_stats[top_k:]))]
    total_st = sum(st for o, st in res)

    # Prepare data for the treemap
    res_plot_df = []
    for org, st in res:
        if org == "Others...":
            res_plot_df += [("Others...", "other", st * 100 / total_st if total_st > 0 else 0)]
        else:
            for model in org_stats[org]:
                if combined_filter(model):
                    res_plot_df += [(org, model["id"], model[count_by] * 100 / total_st if total_st > 0 else 0)]
    return ([(o, 100 * st / total_st if total_st > 0 else 0) for o, st in res if st > 0], res_plot_df)

def make_figure(count_by, org_stats, tag_filter=None, pipeline_filter=None, size_range=None):
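    """Build a Plotly treemap of models grouped by organization for the chosen metric."""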
    assert count_by in ["downloads", "likes"]

    # Determine which filter function to use
    filter_func = None
    if tag_filter:
        filter_func = TAG_FILTER_FUNCS[tag_filter]
    elif pipeline_filter:
        filter_func = lambda dct: dct.get("pipeline_tag", None) and dct.get("pipeline_tag", "") == pipeline_filter
    else:
        filter_func = lambda dct: True

    # Generate stats with filters
    _, res_plot_df = make_org_stats(count_by, org_stats, top_k=25, filter_func=filter_func, size_range=size_range)

    # Create DataFrame for Plotly
    df = pd.DataFrame(
        dict(
            organizations=[o for o, _, _ in res_plot_df],
            model=[r for _, r, _ in res_plot_df],
            stats=[s for _, _, s in res_plot_df],
        )
    )
    df["models"] = "models"  # Root node

    # Create treemap
    fig = px.treemap(df, path=["models", 'organizations', 'model'], values='stats',
                     title=f"HuggingFace Models - {count_by.capitalize()} by Organization")
    fig.update_layout(
        margin=dict(t=50, l=25, r=25, b=25)
    )
    return fig

def download_and_process_models():
    """Download and process the models data from the HuggingFace dataset."""
    try:
        # Create a cache directory
        if not os.path.exists('data'):
            os.makedirs('data')

        # Check if we have cached data
        if os.path.exists('data/processed_models.json'):
            print("Loading from cache...")
            with open('data/processed_models.json', 'r') as f:
                return json.load(f)

        # URL to the models.parquet file
        url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
        print(f"Downloading models data from {url}...")
        response = requests.get(url)
        if response.status_code != 200:
            raise Exception(f"Failed to download data: HTTP {response.status_code}")

        # Read the parquet file
        table = pq.read_table(BytesIO(response.content))
        df = table.to_pandas()
        print(f"Downloaded {len(df)} models")

        # Process the dataframe into the organization structure we need
        org_stats = {}
        for _, row in df.iterrows():
            model_id = row['id']
            # Extract the organization part of the model ID
            if '/' in model_id:
                org_id = model_id.split('/')[0]
            else:
                org_id = "unaffiliated"
            # Create model entry with needed fields (cast to plain Python types so JSON caching works)
            downloads = row.get('downloads', 0)
            likes = row.get('likes', 0)
            pipeline_tag = row.get('pipeline_tag')
            tags = row.get('tags')
            model_entry = {
                "id": model_id,
                "downloads": int(downloads) if pd.notna(downloads) else 0,
                "likes": int(likes) if pd.notna(likes) else 0,
                "pipeline_tag": pipeline_tag if isinstance(pipeline_tag, str) else None,
                "tags": [str(tag) for tag in tags] if tags is not None and not isinstance(tags, float) else [],
            }
            # Add safetensors information if available
            if 'safetensors' in row and row['safetensors'] is not None:
                if isinstance(row['safetensors'], dict) and 'total' in row['safetensors']:
                    model_entry["safetensors"] = {"total": int(row['safetensors']['total'])}
                elif isinstance(row['safetensors'], str):
                    # Try to parse a JSON-encoded safetensors field
                    try:
                        safetensors = json.loads(row['safetensors'])
                        if isinstance(safetensors, dict) and 'total' in safetensors:
                            model_entry["safetensors"] = {"total": int(safetensors['total'])}
                    except (json.JSONDecodeError, TypeError, ValueError):
                        pass
            # Add to organization stats
            if org_id not in org_stats:
                org_stats[org_id] = []
            org_stats[org_id].append(model_entry)

        # Cache the processed data
        with open('data/processed_models.json', 'w') as f:
            json.dump(org_stats, f)

        return org_stats
    except Exception as e:
        print(f"Error downloading or processing data: {e}")
        # Return sample data for testing if real data unavailable
        return create_sample_data()

def create_sample_data():
    """Create sample data for testing when real data is unavailable"""
    print("Creating sample data for testing...")
    sample_orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'stability', 'huggingface']
    org_stats = {}
    for org in sample_orgs:
        org_stats[org] = []
        num_models = 5  # Each org has 5 sample models
        for i in range(num_models):
            model_id = f"{org}/model-{i+1}"
            # Cycle through the pipeline tags
            pipeline_idx = i % len(PIPELINE_TAGS)
            pipeline_tag = PIPELINE_TAGS[pipeline_idx]
            # Simple tags
            tags = [pipeline_tag, "sample-data"]
            # Downloads and likes scaled by the organization's position in the list
            downloads = int(1000 * (10 ** (sample_orgs.index(org) % 3)))  # Different orders of magnitude
            likes = int(downloads * 0.05)  # 5% like rate
            # Model size in bytes (roughly 100MB to 10GB)
            model_size = (10**8) * (10 ** (i % 3))  # Different orders of magnitude
            org_stats[org].append({
                "id": model_id,
                "downloads": downloads,
                "likes": likes,
                "pipeline_tag": pipeline_tag,
                "tags": tags,
                "safetensors": {"total": model_size}
            })
    return org_stats

# Create Gradio interface
with gr.Blocks() as demo:
    models_data = gr.State(value=None)  # To store loaded data

    with gr.Row():
        gr.Markdown("""
## HuggingFace Models TreeMap

This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
""")

    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=["downloads", "likes"],
                value="downloads"
            )
            filter_choice_radio = gr.Radio(
                label="Filter by",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None"
            )
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False
            )
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False
            )
            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None"
            )
            generate_plot_button = gr.Button("Generate Plot")

        with gr.Column(scale=3):
            plot_output = gr.Plot()

    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, data):
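        """Build the treemap for the currently selected metric and filters; returns None until data has loaded."""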
        print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}")
        if data is None:
            print("Error: Data not loaded yet.")
            return None

        selected_tag_filter = None
        selected_pipeline_filter = None
        selected_size_filter = None

        if filter_choice == "Tag Filter":
            selected_tag_filter = tag_filter
        elif filter_choice == "Pipeline Filter":
            selected_pipeline_filter = pipeline_filter

        if size_filter != "None":
            selected_size_filter = size_filter

        fig = make_figure(
            count_by=count_by,
            org_stats=data,
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
            size_range=selected_size_filter
        )
        return fig

    def update_filter_visibility(filter_choice):
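        """Show only the dropdown that matches the selected filter mode."""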
        if filter_choice == "Tag Filter":
            return gr.update(visible=True), gr.update(visible=False)
        elif filter_choice == "Pipeline Filter":
            return gr.update(visible=False), gr.update(visible=True)
        else:  # "None"
            return gr.update(visible=False), gr.update(visible=False)

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
    )

    # Load data once at startup
    demo.load(
        fn=download_and_process_models,
        inputs=[],
        outputs=[models_data]
    )

    # Button click event to generate plot
    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            size_filter_dropdown,
            models_data
        ],
        outputs=[plot_output]
    )

if __name__ == "__main__":
    demo.launch()