Spaces:
Running
Running
import json | |
import gradio as gr | |
import pandas as pd | |
import plotly.express as px | |
# Pipeline tags offered as choices in the "Select Pipeline Tag" dropdown.
# A repo matches a choice when its "pipeline_tag" equals the selected string,
# so the spellings here must stay exact. List order is the display order.
PIPELINE_TAGS = [
    "text-generation",
    "text-to-image",
    "text-classification",
    "text2text-generation",
    "audio-to-audio",
    "feature-extraction",
    "image-classification",
    "translation",
    "reinforcement-learning",
    "fill-mask",
    "text-to-speech",
    "automatic-speech-recognition",
    "image-text-to-text",
    "token-classification",
    "sentence-similarity",
    "question-answering",
    "image-feature-extraction",
    "summarization",
    "zero-shot-image-classification",
    "object-detection",
    "image-segmentation",
    "image-to-image",
    "image-to-text",
    "audio-classification",
    "visual-question-answering",
    "text-to-video",
    "zero-shot-classification",
    "depth-estimation",
    "text-ranking",
    "image-to-video",
    "multiple-choice",
    "unconditional-image-generation",
    "video-classification",
    "text-to-audio",
    "time-series-forecasting",
    "any-to-any",
    "video-text-to-text",
    "table-question-answering",
]
def is_audio_speech(repo_dct):
    """Return True when a repo looks audio- or speech-related.

    A repo matches when "audio" or "speech" occurs (case-insensitively) as a
    substring of its ``"pipeline_tag"`` or of any entry in its ``"tags"``.

    Args:
        repo_dct: Repo metadata dict. ``"pipeline_tag"`` (a string) and
            ``"tags"`` (an iterable of strings) are both optional and may be
            ``None``.

    Returns:
        bool: Always a real boolean, including when both fields are absent.
    """
    keywords = ("audio", "speech")
    # `or ""` / `or ()` cover both a missing key and an explicit None value.
    pipeline_tag = (repo_dct.get("pipeline_tag") or "").lower()
    if any(word in pipeline_tag for word in keywords):
        return True
    return any(
        word in tag.lower()
        for tag in repo_dct.get("tags") or ()
        for word in keywords
    )
def is_music(repo_dct):
    """Return True when any of the repo's tags contains "music" (case-insensitive).

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    tags = repo_dct.get("tags") or ()
    return any("music" in tag.lower() for tag in tags)
def is_robotics(repo_dct):
    """Return True when any of the repo's tags contains "robot" (case-insensitive).

    The substring match also covers tags such as "robotics" or "LeRobot".

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    tags = repo_dct.get("tags") or ()
    return any("robot" in tag.lower() for tag in tags)
def is_biomed(repo_dct):
    """Return True when any tag contains "bio" or "medic" (case-insensitive).

    Matching is by substring, so e.g. "biology", "medical" and "biomedical"
    all qualify.

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    for tag in repo_dct.get("tags") or ():
        lowered = tag.lower()
        if "bio" in lowered or "medic" in lowered:
            return True
    return False
def is_timeseries(repo_dct):
    """Return True when any of the repo's tags contains "series" (case-insensitive).

    The check is a plain substring match on "series", so it accepts
    "time-series" and "timeseries" alike (and any other tag containing it).

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    tags = repo_dct.get("tags") or ()
    return any("series" in tag.lower() for tag in tags)
def is_science(repo_dct):
    """Return True when any tag contains "science" but is not a "bigscience" tag.

    Both the match and the exclusion are case-insensitive, so "BigScience"
    is excluded just like "bigscience".

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    for tag in repo_dct.get("tags") or ():
        lowered = tag.lower()
        # Test the exclusion on the lowercased tag as well; checking the raw
        # tag would let differently-cased "BigScience" tags slip through.
        if "science" in lowered and "bigscience" not in lowered:
            return True
    return False
def is_video(repo_dct):
    """Return True when any of the repo's tags contains "video" (case-insensitive).

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    tags = repo_dct.get("tags") or ()
    return any("video" in tag.lower() for tag in tags)
def is_image(repo_dct):
    """Return True when any of the repo's tags contains "image" (case-insensitive).

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    tags = repo_dct.get("tags") or ()
    return any("image" in tag.lower() for tag in tags)
def is_text(repo_dct):
    """Return True when any of the repo's tags contains "text" (case-insensitive).

    Args:
        repo_dct: Repo metadata dict; ``"tags"`` (an iterable of strings) is
            optional and may be ``None``.

    Returns:
        bool: Always a real boolean, also for missing or empty tags.
    """
    tags = repo_dct.get("tags") or ()
    return any("text" in tag.lower() for tag in tags)
# Maps each label shown in the "Select Tag" dropdown to the predicate that
# decides whether a repo dict belongs to that topic. The keys are used as the
# dropdown choices, so insertion order is the display order.
TAG_FILTER_FUNCS = {
    # Modalities
    "Audio & Speech": is_audio_speech,
    "Time series": is_timeseries,
    "Robotics": is_robotics,
    "Music": is_music,
    "Video": is_video,
    "Images": is_image,
    "Text": is_text,
    # Domains
    "Biomedical": is_biomed,
    "Sciences": is_science,
}
def make_org_stats(repo_type, count_by, org_stats, top_k=20, filter_func=None):
    """Rank organizations by their share of a metric and list their repos.

    Args:
        repo_type: ``"all"``, ``"datasets"`` or ``"models"``. ``"all"``
            combines an organization's datasets and models into one total.
        count_by: Metric key read from every repo dict: ``"likes"``,
            ``"downloads"`` or ``"downloads_all"``.
        org_stats: Mapping of organization name to a dict keyed by repo kind
            (``"datasets"`` / ``"models"``), each holding a list of repo
            dicts. Each counted repo dict must provide ``count_by`` and
            ``"id"``. A missing repo kind is treated as an empty list.
        top_k: Number of organizations reported individually; everything
            below that rank is folded into one ``"Others..."`` entry.
        filter_func: Optional predicate taking a repo dict. Only repos for
            which it returns a truthy value are counted. ``None`` keeps all.

    Returns:
        A tuple ``(org_shares, repo_rows)``:

        * ``org_shares``: ``(organization, percent)`` pairs in descending
          order, ending with ``"Others..."``; entries with a zero total are
          omitted.
        * ``repo_rows``: ``(organization, repo_id, percent)`` triples for
          every counted repo of the top organizations, followed by a single
          ``("Others...", "other", percent)`` row.

        Percentages are relative to the grand total of the selection. When
        that total is zero (e.g. the filter matches nothing), both lists are
        empty.
    """
    assert count_by in ["likes", "downloads", "downloads_all"]
    assert repo_type in ["all", "datasets", "models"]
    repos = ["datasets", "models"] if repo_type == "all" else [repo_type]
    if filter_func is None:
        def filter_func(_dct):
            return True

    # Select each organization's matching repos exactly once, merged across
    # the requested repo kinds. Ranking per (organization, repo kind) pair
    # would let one organization take two top-k slots and have its repos
    # listed twice when repo_type is "all".
    kept = {
        author: [
            dct
            for repo in repos
            for dct in author_dct.get(repo) or ()
            if filter_func(dct)
        ]
        for author, author_dct in org_stats.items()
    }
    sorted_stats = sorted(
        ((author, sum(dct[count_by] for dct in dcts)) for author, dcts in kept.items()),
        key=lambda pair: pair[1],
        reverse=True,
    )

    top_orgs = sorted_stats[:top_k]
    others_st = sum(st for _, st in sorted_stats[top_k:])
    total_st = sum(st for _, st in top_orgs) + others_st
    if not total_st:
        # Nothing to report; also avoids dividing by zero below.
        return [], []

    res_plot_df = [
        (org, dct["id"], dct[count_by] * 100 / total_st)
        for org, _ in top_orgs
        for dct in kept[org]
    ]
    res_plot_df.append(("Others...", "other", others_st * 100 / total_st))

    org_shares = [
        (org, 100 * st / total_st)
        for org, st in top_orgs + [("Others...", others_st)]
        if st > 0
    ]
    return (org_shares, res_plot_df)
def make_figure(count_by, repo_type, org_stats, tag_filter=None, pipeline_filter=None):
    """Build a treemap of organizations and their repos for one metric.

    Args:
        count_by: ``"downloads"``, ``"likes"`` or ``"downloads_all"``.
        repo_type: ``"all"``, ``"models"`` or ``"datasets"``.
        org_stats: Per-organization statistics, passed through unchanged to
            ``make_org_stats``.
        tag_filter: Optional key of ``TAG_FILTER_FUNCS`` selecting a topic
            predicate.
        pipeline_filter: Optional pipeline tag; only repos whose
            ``"pipeline_tag"`` equals it are kept. At most one of
            ``tag_filter`` / ``pipeline_filter`` may be given.

    Returns:
        A Plotly treemap figure (root -> organization -> repo), or ``None``
        when the selection contains nothing to plot.
    """
    assert count_by in ["downloads", "likes", "downloads_all"]
    assert repo_type in ["all", "models", "datasets"]
    assert tag_filter is None or pipeline_filter is None

    filter_func = None
    if tag_filter:
        filter_func = TAG_FILTER_FUNCS[tag_filter]
    if pipeline_filter:
        def filter_func(dct):
            # pipeline_filter is non-empty here, so a missing or None
            # pipeline_tag can never compare equal to it.
            return dct.get("pipeline_tag") == pipeline_filter

    try:
        _, res_plot_df = make_org_stats(
            repo_type, count_by, org_stats, top_k=25, filter_func=filter_func
        )
    except ZeroDivisionError:
        # Shares are computed relative to the grand total; a selection that
        # matches no repos has a zero total. Treat that as "nothing to plot"
        # instead of crashing the UI callback.
        res_plot_df = []
    if not res_plot_df:
        return None

    df = pd.DataFrame(res_plot_df, columns=["organizations", "repo", "stats"])
    df[repo_type] = repo_type  # constant column so the treemap has a single root node
    fig = px.treemap(df, path=[repo_type, "organizations", "repo"], values="stats")
    fig.update_layout(
        treemapcolorway=["pink"] * len(res_plot_df),
        margin=dict(t=50, l=25, r=25, b=25),
    )
    return fig
# UI definition. Components and event wiring live inside the Blocks context;
# the callbacks are defined alongside the wiring that uses them.
with gr.Blocks() as demo:
    # Holds the loaded organization statistics; stays None until the
    # demo.load handler below has run.
    org_stats_data = gr.State(value=None)

    with gr.Row():
        gr.Markdown("""
## Hugging Face Organization Stats
This app shows how different organizations are contributing to different aspects of the open AI ecosystem.
Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest.
""")

    with gr.Row():
        with gr.Column(scale=1):
            repo_type_dropdown = gr.Dropdown(
                label="Repository Type",
                choices=["all", "models", "datasets"],
                value="all"
            )
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=["downloads", "likes", "downloads_all"],
                value="downloads"
            )
            filter_choice_radio = gr.Radio(
                label="Filter by",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None"
            )
            # Both filter dropdowns start hidden; update_filter_visibility
            # reveals the one matching the selected filter mode.
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False
            )
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False
            )
            generate_plot_button = gr.Button("Generate Plot")
        with gr.Column(scale=3):
            plot_output = gr.Plot()

    def generate_plot_on_click(repo_type, count_by, filter_choice, tag_filter, pipeline_filter, data):
        """Build the treemap for the current selections.

        Returns the figure from ``make_figure``, or ``None`` when the
        statistics have not been loaded yet.
        """
        # Print the current state of the input variables
        print("Generating plot with the following inputs:")
        print(f" Repository Type: {repo_type}")
        print(f" Metric (Count By): {count_by}")
        print(f" Filter Choice: {filter_choice}")
        if filter_choice == "Tag Filter":
            print(f" Tag Filter: {tag_filter}")
        elif filter_choice == "Pipeline Filter":
            print(f" Pipeline Filter: {pipeline_filter}")

        if data is None:
            print("Error: Data not loaded yet.")
            return None

        # Only the dropdown that belongs to the active filter mode is used; a
        # value left over in the other (hidden) dropdown is ignored.
        selected_tag_filter = tag_filter if filter_choice == "Tag Filter" else None
        selected_pipeline_filter = pipeline_filter if filter_choice == "Pipeline Filter" else None

        return make_figure(
            count_by=count_by,
            repo_type=repo_type,
            org_stats=data,
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter
        )

    def update_filter_visibility(filter_choice):
        """Return visibility updates for (tag dropdown, pipeline dropdown)."""
        return (
            gr.update(visible=filter_choice == "Tag Filter"),
            gr.update(visible=filter_choice == "Pipeline Filter"),
        )

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
    )

    def load_org_data():
        """Read the organization statistics JSON file and return its content."""
        print("Loading organization statistics data...")
        # Context manager closes the file handle even if parsing fails.
        with open("org_to_artifacts_2l_stats.json", encoding="utf-8") as stats_file:
            loaded_org_stats = json.load(stats_file)
        print("Data loaded successfully.")
        return loaded_org_stats

    # Load the data through the Blocks load event and keep it in the state
    # component, from where the click handler receives it.
    demo.load(
        fn=load_org_data,
        inputs=[],  # No inputs needed to just load data
        outputs=[org_stats_data]  # Only output to the state
    )

    # Button click event to generate plot
    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            repo_type_dropdown,
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            org_stats_data
        ],
        outputs=[plot_output]
    )
if __name__ == "__main__":
    # No data loading here: the statistics file is read by the handler
    # registered with demo.load.
    demo.launch()