import gradio as gr
import requests
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError
import pandas as pd
import plotly.express as px
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from collections import defaultdict
import numpy as np

HF_API = HfApi()
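
# HuggingfaceHubSearch is a custom Gradio component that ships in its own package
# (gradio_huggingfacehub_search); the rest are standard gradio / huggingface_hub /
# pandas / plotly / numpy imports.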


def apply_power_scaling(sizes, exponent=0.2):
    """Apply power scaling to the sizes; a None size is treated as 0."""
    return [size**exponent if size is not None else 0 for size in sizes]


def count_chunks(sizes):
    """Count 64 KB (64,000-byte) chunks per file size, rounding up; a None size counts as 0 chunks."""
    return [int(np.ceil(size / 64_000)) if size is not None else 0 for size in sizes]
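
# Worked examples for the two helpers above (inputs chosen purely for illustration):
#   apply_power_scaling([1, 1024, None])  -> approximately [1.0, 4.0, 0]  (1024 ** 0.2 == 4)
#   count_chunks([64_000, 65_000, None])  -> [1, 2, 0]  (65 KB spills into a second chunk)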


def build_hierarchy(siblings):
    """Builds a hierarchical structure from the list of RepoSibling objects."""
    hierarchy = defaultdict(dict)

    for sibling in siblings:
        path_parts = sibling.rfilename.split("/")
        size = sibling.lfs.size if sibling.lfs else sibling.size

        current_level = hierarchy
        for part in path_parts[:-1]:
            current_level = current_level.setdefault(part, {})
        current_level[path_parts[-1]] = size

    return hierarchy
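
# For a repository with files "config.json" and "onnx/model.onnx", the hierarchy
# looks roughly like {"config.json": 1234, "onnx": {"model.onnx": 56789}}
# (the byte counts here are made up for illustration).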


def calculate_directory_sizes(hierarchy):
    """Recursively calculates the size of each directory as the sum of its contents."""
    total_size = 0

    for key, value in hierarchy.items():
        if isinstance(value, dict):
            dir_size = calculate_directory_sizes(value)
            hierarchy[key] = {"__size__": dir_size, **value}
            total_size += dir_size
        else:
            total_size += value

    return total_size
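
# Continuing the illustrative example above, calculate_directory_sizes mutates the
# hierarchy into {"config.json": 1234, "onnx": {"__size__": 56789, "model.onnx": 56789}}
# and returns the repository total (1234 + 56789 = 58023 bytes).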


def flatten_hierarchy_with_directory_sizes(hierarchy, root_name="Repository"):
    """Flatten a nested dictionary into Plotly-compatible treemap data with a defined root node."""
    labels = []
    parents = []
    sizes = []

    def process_level(current_hierarchy, current_parent):
        for key, value in current_hierarchy.items():
            if isinstance(value, dict) and "__size__" in value:
                dir_size = value.pop("__size__")
                labels.append(key)
                parents.append(current_parent)
                sizes.append(dir_size)
                process_level(value, key)
            else:
                labels.append(key)
                parents.append(current_parent)
                sizes.append(value)

    total_size = calculate_directory_sizes(hierarchy)
    labels.append(root_name)
    parents.append("")
    sizes.append(total_size)

    process_level(hierarchy, root_name)

    return labels, parents, sizes
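
# For the same illustrative hierarchy, the flattened treemap data would be:
#   labels  = ["Repository", "config.json", "onnx",       "model.onnx"]
#   parents = ["",           "Repository",  "Repository", "onnx"]
#   sizes   = [58023,        1234,          56789,        56789]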


def visualize_repo_treemap(r_info):
    """Visualizes the repository as a treemap with directory sizes and human-readable tooltips."""
    siblings = r_info.siblings
    hierarchy = build_hierarchy(siblings)

    # flatten_hierarchy_with_directory_sizes calls calculate_directory_sizes itself;
    # running it a second time here would inflate the root's reported size.
    labels, parents, sizes = flatten_hierarchy_with_directory_sizes(hierarchy)

    scaled_sizes = apply_power_scaling(sizes)

    formatted_sizes = [
        format_repo_size(size) if size is not None else None for size in sizes
    ]
    chunks = count_chunks(sizes)

    fig = px.treemap(
        names=labels,
        parents=parents,
        values=scaled_sizes,
        title="Repo by Chunks",
        custom_data=[formatted_sizes, chunks],
    )

    fig.update_layout(
        title={
            "text": (
                "Repo File Size Treemap<br>"
                "<span style='font-size:14px;'>Hover over a directory or file "
                "to see its size and number of chunks</span>"
            ),
            "x": 0.5,
            "xanchor": "center",
        },
        margin=dict(t=50, l=25, r=25, b=25),
    )
    fig.update_traces(
        root_color="lightgrey",
        hovertemplate=(
            "<b>%{label}</b><br>"
            "Size: %{customdata[0]}<br>"
            "# of Chunks: %{customdata[1]}"
        ),
    )

    return fig
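
# The treemap areas use the power-scaled values so that small files stay visible next
# to multi-gigabyte weight shards; the hover text reports the real size and chunk
# count via custom_data.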


def format_repo_size(r_size: int) -> str:
    """Convert a size in bytes into a human-readable string (B, KB, MB, ...)."""
    units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB", 5: "PB"}
    order = 0
    while r_size >= 1024 and order < len(units) - 1:
        r_size /= 1024
        order += 1
    return f"{r_size:.2f} {units[order]}"


def repo_files(r_type: str, r_id: str) -> tuple:
    """Aggregate per-extension file sizes/counts and build the repo treemap figure."""
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    fig = visualize_repo_treemap(r_info)
    files = {}
    for sibling in r_info.siblings:
        ext = sibling.rfilename.split(".")[-1]
        if ext in files:
            files[ext]["size"] += sibling.size
            files[ext]["count"] += 1
        else:
            files[ext] = {"size": sibling.size, "count": 1}
    return files, fig
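
# Illustrative shape of the returned files dict (not real numbers):
#   {"json": {"size": 2_048, "count": 3}, "safetensors": {"size": 9_876_543_210, "count": 2}}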


def repo_size(r_type, r_id):
    try:
        r_refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    except RepositoryNotFoundError:
        gr.Warning(f"Repository not found or gated; branch information for {r_id} is not available.")
        return {}
    repo_sizes = {}
    for branch in r_refs.branches:
        try:
            response = requests.get(
                f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{branch.name}",
                timeout=1000,
            )
            response = response.json()
        except Exception:
            response = {}
        if response.get("error") and (
            "restricted" in response.get("error") or "gated" in response.get("error")
        ):
            gr.Warning(f"Branch information for {r_id} is not available.")
            return {}
        size = response.get("size")
        if size is not None:
            repo_sizes[branch.name] = size
    return repo_sizes
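
# Branch sizes come from the Hub's treesize endpoint (one request per branch); an
# "error" field mentioning "gated" or "restricted" is treated as "no branch info".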


def get_repo_info(r_type, r_id):
    try:
        repo_sizes = repo_size(r_type, r_id)
        repo_files_info, treemap_fig = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and selected the matching repository type."
        )
        return (
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
            gr.Plot(visible=False),
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
        )

    rf_sizes_df = (
        pd.DataFrame(repo_files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )

    if not repo_sizes:
        r_sizes_component = gr.Dataframe(visible=False)
        b_block = gr.Row(visible=False)
    else:
        r_sizes_df = pd.DataFrame(repo_sizes, index=["size"]).T.reset_index(names="branch")
        r_sizes_df["formatted_size"] = r_sizes_df["size"].apply(format_repo_size)
        r_sizes_df.columns = ["Branch", "bytes", "Size"]
        r_sizes_component = gr.Dataframe(value=r_sizes_df[["Branch", "Size"]], visible=True)
        b_block = gr.Row(visible=True)

    rf_sizes_df["formatted_size"] = rf_sizes_df["size"].apply(format_repo_size)
    rf_sizes_df.columns = ["Extension", "bytes", "Count", "Size"]
    # Note: rf_sizes_plot is built here but is not among the outputs wired up below.
    rf_sizes_plot = px.pie(
        rf_sizes_df,
        values="bytes",
        names="Extension",
        hover_data=["Size"],
        title=f"File Distribution in {r_id}",
        hole=0.3,
    )
    return (
        gr.Row(visible=True),
        gr.Dataframe(value=rf_sizes_df[["Extension", "Count", "Size"]], visible=True),
        gr.Plot(treemap_fig, visible=True),
        b_block,
        r_sizes_component,
    )
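
# The five return values above line up, in order, with the `outputs` list wired to
# search_button.click further down.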


with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Repository Information")
    gr.Markdown(
        "Search for a model or dataset repository using the autocomplete below, select the repository type, and get back information about the repository's files and branches."
    )
    with gr.Blocks():
        repo_id = HuggingfaceHubSearch(
            label="Hub Repository Search (enter a user, organization, or repository name to start searching)",
            placeholder="Search for model or dataset repositories on the Hugging Face Hub",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    with gr.Blocks():
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## File Information")
                file_info_plot = gr.Plot(visible=False)
                with gr.Row():
                    file_info = gr.Dataframe(visible=False)

        with gr.Row(visible=False) as branch_block:
            with gr.Column():
                gr.Markdown("## Branch Sizes")
                branch_sizes = gr.Dataframe(visible=False)

    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results_block, file_info, file_info_plot, branch_block, branch_sizes],
    )

demo.launch()
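
# A minimal way to try this locally (assuming the file is saved as app.py):
#   pip install gradio gradio_huggingfacehub_search huggingface_hub requests pandas plotly numpy
#   python app.py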