|
import json |
|
import gradio as gr |
|
import pandas as pd |
|
import plotly.express as px |
|
import pyarrow.parquet as pq |
|
import os |
|
import requests |
|
from io import BytesIO |
|
import math |
|
|
|
|
|
# Pipeline tags recognized on the Hugging Face Hub. Used to populate the
# "Select Pipeline Tag" dropdown and to filter models by exact pipeline match.
PIPELINE_TAGS = [
    'text-generation',
    'text-to-image',
    'text-classification',
    'text2text-generation',
    'audio-to-audio',
    'feature-extraction',
    'image-classification',
    'translation',
    'reinforcement-learning',
    'fill-mask',
    'text-to-speech',
    'automatic-speech-recognition',
    'image-text-to-text',
    'token-classification',
    'sentence-similarity',
    'question-answering',
    'image-feature-extraction',
    'summarization',
    'zero-shot-image-classification',
    'object-detection',
    'image-segmentation',
    'image-to-image',
    'image-to-text',
    'audio-classification',
    'visual-question-answering',
    'text-to-video',
    'zero-shot-classification',
    'depth-estimation',
    'text-ranking',
    'image-to-video',
    'multiple-choice',
    'unconditional-image-generation',
    'video-classification',
    'text-to-audio',
    'time-series-forecasting',
    'any-to-any',
    'video-text-to-text',
    'table-question-answering',
]
|
|
|
|
|
# Size-bucket labels mapped to (min_gb, max_gb) half-open ranges [min, max).
# is_in_size_range() divides `safetensors.total` by 1024**3, i.e. the code
# treats that field as a byte count when bucketing.
MODEL_SIZE_RANGES = {
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float('inf'))
}
|
|
|
|
|
def is_audio_speech(repo_dct):
    """Return True if the repo looks audio/speech related.

    Matches the substrings 'audio' or 'speech' (case-insensitive) in
    either the pipeline tag or any entry of the tag list. Unlike the
    original, always returns a plain bool (the old short-circuit chain
    could return None) and scans the tag list only once instead of four
    redundant dict lookups.
    """
    pipeline = (repo_dct.get("pipeline_tag") or "").lower()
    if "audio" in pipeline or "speech" in pipeline:
        return True
    tags = repo_dct.get("tags") or []
    return any("audio" in t.lower() or "speech" in t.lower() for t in tags)
|
|
|
def is_music(repo_dct):
    """Truthy when any repo tag mentions 'music' (case-insensitive)."""
    tags = repo_dct.get("tags", None)
    # Preserve the original falsy return value (None / empty list) when
    # there are no tags at all.
    return tags and any("music" in t.lower() for t in tags)
|
|
|
def is_robotics(repo_dct):
    """Truthy when any repo tag mentions 'robot' (case-insensitive)."""
    tags = repo_dct.get("tags", None)
    if not tags:
        # Keep the original falsy return value (None or empty list).
        return tags
    return any("robot" in t.lower() for t in tags)
|
|
|
def is_biomed(repo_dct):
    """Return True if any tag contains 'bio' or 'medic' (case-insensitive).

    Fixes the original's double scan of the tag list (one pass for 'bio',
    one for 'medic') and normalizes the falsy return value to a plain
    bool instead of None/[].
    """
    tags = repo_dct.get("tags") or []
    return any("bio" in t.lower() or "medic" in t.lower() for t in tags)
|
|
|
def is_timeseries(repo_dct):
    """Truthy when any repo tag mentions 'series' (case-insensitive)."""
    tags = repo_dct.get("tags", None)
    # Falsy tag containers are returned unchanged, as in the original.
    return tags and any("series" in t.lower() for t in tags)
|
|
|
def is_science(repo_dct):
    """Truthy when a tag mentions 'science' without being a 'bigscience' tag.

    NOTE(review): the 'bigscience' exclusion is case-sensitive on the raw
    tag while the 'science' match is lowercased — quirk preserved as-is.
    """
    tags = repo_dct.get("tags", None)
    if not tags:
        return tags
    return any("science" in t.lower() and "bigscience" not in t for t in tags)
|
|
|
def is_video(repo_dct):
    """Truthy when any repo tag mentions 'video' (case-insensitive)."""
    tags = repo_dct.get("tags", None)
    return tags and any("video" in t.lower() for t in tags)
|
|
|
def is_image(repo_dct):
    """Truthy when any repo tag mentions 'image' (case-insensitive)."""
    tags = repo_dct.get("tags", None)
    if not tags:
        return tags
    return any("image" in t.lower() for t in tags)
|
|
|
def is_text(repo_dct):
    """Truthy when any repo tag mentions 'text' (case-insensitive)."""
    tags = repo_dct.get("tags", None)
    return tags and any("text" in t.lower() for t in tags)
|
|
|
|
|
def is_in_size_range(repo_dct, size_range):
    """Return True when the model's safetensors size falls in the named bucket.

    size_range is a key of MODEL_SIZE_RANGES, or None to disable the
    filter entirely. Models without safetensors size metadata never
    match a concrete range.
    """
    if size_range is None:
        return True

    lower_gb, upper_gb = MODEL_SIZE_RANGES[size_range]

    safetensors = repo_dct.get("safetensors")
    if not safetensors or not safetensors.get("total"):
        return False

    # `total` is divided by 1024**3, i.e. treated as a byte count;
    # buckets are half-open: [lower, upper).
    size_gb = safetensors["total"] / (1024 ** 3)
    return lower_gb <= size_gb < upper_gb
|
|
|
# Human-readable labels for the "Tag Filter" dropdown mapped to their
# predicate functions; each predicate takes a model dict and returns a
# truthy value when the model belongs to that category.
TAG_FILTER_FUNCS = {
    "Audio & Speech": is_audio_speech,
    "Time series": is_timeseries,
    "Robotics": is_robotics,
    "Music": is_music,
    "Video": is_video,
    "Images": is_image,
    "Text": is_text,
    "Biomedical": is_biomed,
    "Sciences": is_science,
}
|
|
|
def make_org_stats(count_by, org_stats, top_k=20, filter_func=None, size_range=None):
    """Aggregate per-organization totals of `count_by` as percentages.

    Args:
        count_by: metric to sum, either "likes" or "downloads".
        org_stats: dict mapping org id -> list of model dicts.
        top_k: number of organizations shown individually; the remainder
            is rolled into a single "Others..." bucket.
        filter_func: optional predicate on a model dict (truthy = keep).
        size_range: optional MODEL_SIZE_RANGES key passed to is_in_size_range.

    Returns:
        (org_share, model_rows) where org_share is [(org, percent), ...]
        for orgs with a non-zero share, and model_rows is
        [(org, model_id, percent), ...] suitable for treemap plotting.

    Fix vs. original: the "Others..." bucket is only emitted when there
    actually are organizations beyond top_k — the original always
    appended a zero-valued "Others..." row to the plot data.
    """
    assert count_by in ["likes", "downloads"]

    def combined_filter(dct):
        # Both filters default to "pass" when not configured.
        passes_tag = filter_func(dct) if filter_func else True
        passes_size = is_in_size_range(dct, size_range) if size_range else True
        return passes_tag and passes_size

    sorted_stats = sorted(
        (
            (org_id, sum(m[count_by] for m in models if combined_filter(m)))
            for org_id, models in org_stats.items()
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )

    res = sorted_stats[:top_k]
    if len(sorted_stats) > top_k:
        res = res + [("Others...", sum(st for _, st in sorted_stats[top_k:]))]
    total_st = sum(st for _, st in res)

    def pct(value):
        # Guard against an all-zero (or fully filtered-out) dataset.
        return value * 100 / total_st if total_st > 0 else 0

    res_plot_df = []
    for org, st in res:
        if org == "Others...":
            res_plot_df.append(("Others...", "other", pct(st)))
        else:
            for model in org_stats[org]:
                if combined_filter(model):
                    res_plot_df.append((org, model["id"], pct(model[count_by])))

    return ([(o, pct(st)) for o, st in res if st > 0], res_plot_df)
|
|
|
def make_figure(count_by, org_stats, tag_filter=None, pipeline_filter=None, size_range=None):
    """Build a plotly treemap of per-organization model stats.

    count_by is "downloads" or "likes". At most one of tag_filter
    (a TAG_FILTER_FUNCS key) and pipeline_filter (an exact pipeline tag)
    is applied — tag_filter takes precedence. size_range is forwarded to
    make_org_stats.
    """
    assert count_by in ["downloads", "likes"]

    # Resolve the active model predicate: named tag filter wins over
    # pipeline filter; otherwise everything passes.
    if tag_filter:
        filter_func = TAG_FILTER_FUNCS[tag_filter]
    elif pipeline_filter:
        def filter_func(dct):
            return dct.get("pipeline_tag", None) and dct.get("pipeline_tag", "") == pipeline_filter
    else:
        def filter_func(dct):
            return True

    _, res_plot_df = make_org_stats(count_by, org_stats, top_k=25, filter_func=filter_func, size_range=size_range)

    org_col, model_col, stat_col = [], [], []
    for org_name, model_id, share in res_plot_df:
        org_col.append(org_name)
        model_col.append(model_id)
        stat_col.append(share)
    df = pd.DataFrame(dict(organizations=org_col, model=model_col, stats=stat_col))

    # Single synthetic root node so the treemap has one top-level box.
    df["models"] = "models"

    fig = px.treemap(
        df,
        path=["models", 'organizations', 'model'],
        values='stats',
        title=f"HuggingFace Models - {count_by.capitalize()} by Organization",
    )
    fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
    return fig
|
|
|
def _safe_int(value, default=0):
    """Best-effort int conversion (handles numpy scalars); default on failure."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def download_and_process_models():
    """Download and process the models data from HuggingFace dataset.

    Downloads the hub-stats models parquet, groups models by owning
    organization, and caches the result to data/processed_models.json.
    On any failure, falls back to create_sample_data().

    Returns:
        dict mapping org id -> list of model entry dicts with keys
        id/downloads/likes/pipeline_tag/tags and optionally safetensors.
    """
    try:
        if not os.path.exists('data'):
            os.makedirs('data')

        # Serve from the local cache when a previous run already processed it.
        if os.path.exists('data/processed_models.json'):
            print("Loading from cache...")
            with open('data/processed_models.json', 'r') as f:
                return json.load(f)

        url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"

        print(f"Downloading models data from {url}...")
        response = requests.get(url)
        if response.status_code != 200:
            raise Exception(f"Failed to download data: HTTP {response.status_code}")

        table = pq.read_table(BytesIO(response.content))
        df = table.to_pandas()

        print(f"Downloaded {len(df)} models")

        org_stats = {}

        for _, row in df.iterrows():
            model_id = row['id']

            # Everything before the first '/' is the owning org/user;
            # bare ids go into a synthetic "unaffiliated" bucket.
            if '/' in model_id:
                org_id = model_id.split('/')[0]
            else:
                org_id = "unaffiliated"

            tags = row.get('tags', [])
            # pyarrow hands list columns back as numpy arrays, which are
            # not JSON serializable; normalize to a plain list so the
            # json.dump of the cache below does not blow up. Likewise cast
            # numpy integer scalars for downloads/likes.
            tags = [] if tags is None else list(tags)

            model_entry = {
                "id": model_id,
                "downloads": _safe_int(row.get('downloads', 0)),
                "likes": _safe_int(row.get('likes', 0)),
                "pipeline_tag": row.get('pipeline_tag'),
                "tags": tags,
            }

            # Keep only the total size from the safetensors metadata; the
            # column may arrive as a dict or as a JSON-encoded string.
            if 'safetensors' in row and row['safetensors']:
                raw_st = row['safetensors']
                if isinstance(raw_st, dict) and 'total' in raw_st:
                    model_entry["safetensors"] = {"total": raw_st['total']}
                elif isinstance(raw_st, str):
                    try:
                        parsed = json.loads(raw_st)
                        if isinstance(parsed, dict) and 'total' in parsed:
                            model_entry["safetensors"] = {"total": parsed['total']}
                    except (json.JSONDecodeError, TypeError):
                        # Malformed metadata: skip size info for this model.
                        # (Was a bare `except:` that hid real errors.)
                        pass

            org_stats.setdefault(org_id, []).append(model_entry)

        with open('data/processed_models.json', 'w') as f:
            json.dump(org_stats, f)

        return org_stats

    except Exception as e:
        # Best-effort by design: any failure degrades to sample data so
        # the UI still renders.
        print(f"Error downloading or processing data: {e}")
        return create_sample_data()
|
|
|
def create_sample_data():
    """Create sample data for testing when real data is unavailable.

    Returns a dict of org id -> list of 5 synthetic model entries each,
    with deterministic download/like counts and sizes so the treemap
    has something to show.

    BUG FIX: the original computed `org_stats.keys().index(org)` —
    dict_keys has no .index() in Python 3, so this function always
    raised AttributeError. The org's position in sample_orgs (which is
    also its insertion order in org_stats) is used instead.
    """
    print("Creating sample data for testing...")

    sample_orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'stability', 'huggingface']
    org_stats = {}

    for org_idx, org in enumerate(sample_orgs):
        org_stats[org] = []
        num_models = 5

        for i in range(num_models):
            model_id = f"{org}/model-{i+1}"

            # Cycle through the known pipeline tags.
            pipeline_idx = i % len(PIPELINE_TAGS)
            pipeline_tag = PIPELINE_TAGS[pipeline_idx]

            tags = [pipeline_tag, "sample-data"]

            # Downloads scale by org position (1k / 10k / 100k cycle);
            # likes are a fixed 5% of downloads.
            downloads = int(1000 * (10 ** (org_idx % 3)))
            likes = int(downloads * 0.05)

            # Sizes cycle through 1e8 / 1e9 / 1e10 "bytes".
            model_size = (10**8) * (10 ** (i % 3))

            org_stats[org].append({
                "id": model_id,
                "downloads": downloads,
                "likes": likes,
                "pipeline_tag": pipeline_tag,
                "tags": tags,
                "safetensors": {"total": model_size}
            })

    return org_stats
|
|
|
|
|
with gr.Blocks() as demo:
    # Holds the org -> models mapping; populated once on load by
    # download_and_process_models (see demo.load below).
    models_data = gr.State(value=None)

    with gr.Row():
        gr.Markdown("""
        ## HuggingFace Models TreeMap

        This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
        Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
        """)

    with gr.Row():
        with gr.Column(scale=1):
            # Metric that sizes the treemap rectangles.
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=["downloads", "likes"],
                value="downloads"
            )

            filter_choice_radio = gr.Radio(
                label="Filter by",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None"
            )

            # Only one of the next two dropdowns is visible at a time;
            # update_filter_visibility toggles them based on the radio.
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False
            )

            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False
            )

            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None"
            )

            generate_plot_button = gr.Button("Generate Plot")

        with gr.Column(scale=3):
            plot_output = gr.Plot()

    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, data):
        """Translate the UI selections into make_figure arguments and return the figure."""
        print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}")

        if data is None:
            # demo.load has not delivered the dataset yet.
            print("Error: Data not loaded yet.")
            return None

        selected_tag_filter = None
        selected_pipeline_filter = None
        selected_size_filter = None

        # The radio decides which dropdown's value is actually used.
        if filter_choice == "Tag Filter":
            selected_tag_filter = tag_filter
        elif filter_choice == "Pipeline Filter":
            selected_pipeline_filter = pipeline_filter

        if size_filter != "None":
            selected_size_filter = size_filter

        fig = make_figure(
            count_by=count_by,
            org_stats=data,
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
            size_range=selected_size_filter
        )
        return fig

    def update_filter_visibility(filter_choice):
        """Show the dropdown matching the chosen filter mode and hide the other."""
        if filter_choice == "Tag Filter":
            return gr.update(visible=True), gr.update(visible=False)
        elif filter_choice == "Pipeline Filter":
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
    )

    # Fetch (or load the cached) model stats as soon as the app starts.
    demo.load(
        fn=download_and_process_models,
        inputs=[],
        outputs=[models_data]
    )

    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            size_filter_dropdown,
            models_data
        ],
        outputs=[plot_output]
    )
|
|
|
|
|
# Launch the Gradio server when executed as a script.
if __name__ == "__main__":
    demo.launch()