|
from huggingface_hub import list_models |
|
from cachetools import cached, TTLCache |
|
from toolz import groupby, valmap |
|
import gradio as gr |
|
from tqdm.auto import tqdm |
|
import pandas as pd |
|
|
|
|
|
@cached(TTLCache(maxsize=10, ttl=3 * 60 * 60))
def get_all_models():
    """List every model on the Hub, requesting card data, minus ``None`` entries.

    The function takes no arguments, so the TTL cache holds a single entry:
    repeated calls within three hours (3 * 60 * 60 seconds) reuse the first
    result instead of listing the Hub again.

    Returns:
        A list of the model objects yielded by ``list_models(cardData=True)``,
        with any ``None`` items removed.
    """
    # tqdm only adds a console progress counter while the listing is consumed.
    listing = tqdm(iter(list_models(cardData=True)))
    fetched = list(listing)
    return list(filter(lambda item: item is not None, fetched))
|
|
|
|
|
def has_base_model_info(model):
    """Tell whether ``model``'s card declares a single, non-empty base model.

    Only a plain string ``base_model`` counts. Missing, empty or non-string
    values (for example a list) are treated as "no info", so downstream code
    can use the value directly as a grouping key.

    Args:
        model: A Hub model object; its ``cardData`` attribute is read.

    Returns:
        ``True`` if ``model.cardData`` is truthy and its ``"base_model"`` entry
        is a truthy ``str``; ``False`` otherwise, including when ``model`` has
        no ``cardData`` or the card data has no ``.get`` (AttributeError).
    """
    try:
        card_data = model.cardData
        if not card_data:
            return False
        base_model = card_data.get("base_model")
        if not base_model:
            return False
    except AttributeError:
        return False
    return isinstance(base_model, str)
|
|
|
|
|
# Partition every Hub model by whether its card declares a single string
# ``base_model``; has_base_model_info only returns True/False, so those are
# the only possible keys.
grouped_by_has_base_model_info = groupby(has_base_model_info, get_all_models())

print(valmap(len, grouped_by_has_base_model_info))

# toolz.groupby only creates keys that actually occur. Fall back to an empty
# list instead of None so len()/iteration below cannot raise TypeError when
# every model (or no model) has base-model info.
models_with_base_model_info = grouped_by_has_base_model_info.get(True, [])
models_without_base_model_info = grouped_by_has_base_model_info.get(False, [])

# The two groups partition the full model list, so their sizes sum to the
# total. This avoids a second get_all_models() call (which could refetch if
# the TTL cache expired) and lets us guard against dividing by zero.
n_with_base_model_info = len(models_with_base_model_info)
n_models = n_with_base_model_info + len(models_without_base_model_info)
pct_with_base_model_info = (
    round(n_with_base_model_info / n_models * 100, 2) if n_models else 0.0
)

summary = f"""{n_with_base_model_info:,} models have base model info.
{len(models_without_base_model_info):,} models don't have base model info.
Currently {pct_with_base_model_info}% of models have base model info."""

# One entry per model that declares a base model; duplicates are kept so they
# can be counted below.
base_models = [
    model.cardData.get("base_model") for model in models_with_base_model_info
]

# Frequency table of base models, most common first. After reset_index the
# columns are "base_model" plus the value_counts column (presumably named
# "count", pandas >= 2.0 naming, which the UI below reads -- confirm).
df = pd.DataFrame(
    pd.DataFrame({"base_model": base_models}).value_counts()
).reset_index()

# Independent copy that receives an extra "org" column without touching df.
df_with_org = df.copy(deep=True)
|
|
|
|
|
def parse_org(hub_id):
    """Return the namespace (organisation/user) of an ``"org/name"`` Hub id.

    Ids that do not contain exactly one ``/`` -- e.g. un-namespaced ids such
    as ``"bert-base-cased"`` or ids with extra path segments -- give ``None``.

    Args:
        hub_id: The model id string to split.

    Returns:
        The text before the single ``/`` (possibly empty), or ``None``.
    """
    if hub_id.count("/") != 1:
        return None
    org, _separator, _name = hub_id.partition("/")
    return org
|
|
|
|
|
# Attach the namespace of each base model and keep only rows where one could
# be parsed (parse_org returns None otherwise).
df_with_org["org"] = df_with_org["base_model"].map(parse_org)
df_with_org = df_with_org[df_with_org["org"].notna()]


def _base_model_of(model):
    """Grouping key: the ``base_model`` entry of a model's card data."""
    return model.cardData.get("base_model")


# Maps each base-model id to the list of models that declare it.
grouped_by_base_model = groupby(_base_model_of, models_with_base_model_info)

# Dropdown choices, in the same (most-common-first) order as ``df``.
all_base_models = df["base_model"].tolist()
|
|
|
|
|
def return_models_for_base_model(base_model):
    """Render Markdown listing the models that declare ``base_model``.

    Children are ordered by download count, most downloaded first.

    Args:
        base_model: Hub id selected in the dropdown. It can be ``None`` (e.g.
            when the selection is cleared) or an id with no recorded
            children.

    Returns:
        A Markdown string: a heading, the child count, and one linked line
        per child model, each followed by a blank line. An empty string when
        there are no children to show, instead of raising on ``sorted(None)``.
    """
    models = grouped_by_base_model.get(base_model)
    if not models:
        return ""

    # Defensively treat a missing (None) download count as 0 so the sort
    # never compares None with int; the displayed value is left untouched.
    models = sorted(models, key=lambda m: m.downloads or 0, reverse=True)

    sections = [
        f"## {base_model} children",
        f"{base_model} has {len(models)} children",
    ]
    sections.extend(
        f"[{m.modelId}](https://huggingface.co/{m.modelId})"
        f" | number of downloads {m.downloads}"
        for m in models
    )
    # Every section is its own Markdown paragraph, including a trailing break.
    return "\n\n".join(sections) + "\n\n"
|
|
|
|
|
# The 50 most frequent base models.
top_base_models = df.head(50)

# Top 50 namespaces by the summed "count" of all their base models, i.e. how
# many models declare any base model owned by that namespace.
org_totals = (
    df_with_org.groupby("org")["count"].sum().sort_values(ascending=False).head(50)
)
org_ranking = pd.DataFrame(org_totals).reset_index().sort_values(
    "count", ascending=False
)

with gr.Blocks() as demo:
    gr.Markdown("# Base model explorer")
    gr.Markdown(
        """When sharing models to the Hub it is possible to specify a base model in the model card i.e. that your model is a fine-tuned version of [bert-base-cased](https://huggingface.co/bert-base-cased).
This Space allows you to find children models for a given base model and view the popularity of models for fine-tuning."""
    )
    gr.Markdown(summary)

    # Selecting a base model re-renders the list of its children.
    gr.Markdown("### Find all models trained from a base model")
    base_model = gr.Dropdown(all_base_models, label="Base Model")
    results = gr.Markdown()
    base_model.change(return_models_for_base_model, base_model, results)

    with gr.Accordion("Base model popularity ranking", open=False):
        gr.DataFrame(top_base_models)
    with gr.Accordion("Base model popularity ranking by organization", open=False):
        gr.DataFrame(org_ranking)

demo.launch()
|
|