File size: 15,120 Bytes
9d2f4f2
 
bbf45d0
 
 
9d2f4f2
1308a67
bbf45d0
9d2f4f2
7262ace
9d2f4f2
 
7262ace
1308a67
 
 
7262ace
9d2f4f2
 
 
 
bbf45d0
9d2f4f2
 
 
 
 
 
 
 
 
 
 
 
 
bbf45d0
 
1308a67
 
 
 
9d2f4f2
1308a67
9d2f4f2
1308a67
 
9d2f4f2
 
 
 
 
 
 
1308a67
 
 
 
 
 
 
9d2f4f2
1308a67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d2f4f2
1308a67
 
 
31bb936
1308a67
 
 
 
31bb936
1308a67
 
bbf45d0
7262ace
 
9d2f4f2
7262ace
9d2f4f2
 
 
7262ace
9d2f4f2
 
 
 
 
 
1308a67
bbf45d0
9d2f4f2
1308a67
 
9d2f4f2
 
1308a67
9d2f4f2
 
 
 
 
1308a67
7262ace
9d2f4f2
 
 
 
1308a67
9d2f4f2
1308a67
 
 
 
 
 
 
 
9d2f4f2
 
1308a67
 
9d2f4f2
7262ace
1308a67
9d2f4f2
7262ace
 
 
 
9d2f4f2
 
7262ace
 
9d2f4f2
7262ace
 
bbf45d0
9d2f4f2
 
bbf45d0
 
9d2f4f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1308a67
9d2f4f2
 
 
 
 
 
 
 
 
 
 
 
 
1308a67
 
 
9d2f4f2
 
 
 
 
7262ace
1308a67
 
9d2f4f2
1308a67
 
 
9d2f4f2
1308a67
9d2f4f2
 
 
7262ace
9d2f4f2
7262ace
1308a67
9d2f4f2
 
 
 
 
 
 
7262ace
9d2f4f2
 
 
 
 
7262ace
9d2f4f2
 
 
 
 
 
 
 
 
 
 
 
1308a67
9d2f4f2
 
1308a67
9d2f4f2
 
1308a67
 
 
9d2f4f2
 
 
 
7262ace
 
9d2f4f2
bbf45d0
1308a67
 
9d2f4f2
 
 
1308a67
9d2f4f2
 
7262ace
9d2f4f2
 
 
1308a67
9d2f4f2
7262ace
1308a67
bbf45d0
1308a67
9d2f4f2
 
7262ace
1308a67
bbf45d0
9d2f4f2
 
 
 
bbf45d0
 
 
1308a67
 
 
9d2f4f2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# --- START OF FILE app.py ---

import gradio as gr
import pandas as pd
import plotly.express as px
import time
from datasets import load_dataset # Import the datasets library

# --- Constants ---
MODEL_SIZE_RANGES = {
    # Display label -> (lower, upper) bounds, in the GB units spelled out by each
    # label. The labels are also the values compared against the dataset's
    # 'size_category' column and the choices offered by the size-filter dropdown.
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float('inf')),
}

# Hugging Face Hub dataset that holds the pre-processed models data.
HF_DATASET_ID = "evijit/orgstats_daily_data"

# Choices for the tag-filter dropdown. Each entry must be a key of the
# tag -> flag-column mapping in make_treemap_data; otherwise selecting it
# applies no filtering.
TAG_FILTER_CHOICES = [
    "Audio & Speech",
    "Time series",
    "Robotics",
    "Music",
    "Video",
    "Images",
    "Text",
    "Biomedical",
    "Sciences",
]

# Choices for the pipeline-filter dropdown; matched exactly against the
# (string-converted) 'pipeline_tag' column.
PIPELINE_TAGS = [
    "text-generation",
    "text-to-image",
    "text-classification",
    "text2text-generation",
    "audio-to-audio",
    "feature-extraction",
    "image-classification",
    "translation",
    "reinforcement-learning",
    "fill-mask",
    "text-to-speech",
    "automatic-speech-recognition",
    "image-text-to-text",
    "token-classification",
    "sentence-similarity",
    "question-answering",
    "image-feature-extraction",
    "summarization",
    "zero-shot-image-classification",
    "object-detection",
    "image-segmentation",
    "image-to-image",
    "image-to-text",
    "audio-classification",
    "visual-question-answering",
    "text-to-video",
    "zero-shot-classification",
    "depth-estimation",
    "text-ranking",
    "image-to-video",
    "multiple-choice",
    "unconditional-image-generation",
    "video-classification",
    "text-to-audio",
    "time-series-forecasting",
    "any-to-any",
    "video-text-to-text",
    "table-question-answering",
]


def load_models_data():
    """
    Load the pre-processed models data from the Hugging Face Hub.

    Loads ``HF_DATASET_ID`` with the ``datasets`` library, converts its first
    split to a pandas DataFrame and verifies that every expected column is
    present.

    Returns:
        tuple: ``(df, success, message)``. ``df`` is the loaded DataFrame
        (an empty DataFrame on failure), ``success`` is a bool and ``message``
        is a human-readable status string. Failures are reported through this
        tuple and printed; they are not raised to the caller.
    """
    overall_start_time = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")

    # These are the columns expected to be in the pre-processed dataset.
    expected_cols = [
        'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params',
        'size_category', 'organization', 'has_audio', 'has_speech', 'has_music',
        'has_robot', 'has_bio', 'has_med', 'has_series', 'has_video', 'has_image',
        'has_text', 'has_science', 'is_audio_speech', 'is_biomed',
        'data_download_timestamp'
    ]

    try:
        # The datasets library caches the download locally after the first run.
        dataset_dict = load_dataset(HF_DATASET_ID)

        if not dataset_dict:
            raise ValueError(f"Dataset '{HF_DATASET_ID}' loaded but appears empty.")

        # Use the first split (e.g., 'train').
        split_name = next(iter(dataset_dict))
        print(f"Using dataset split: '{split_name}'. Converting to Pandas.")

        df = dataset_dict[split_name].to_pandas()

        # Validate that the loaded data has the columns we expect.
        missing_cols = [col for col in expected_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Loaded dataset is missing expected columns: {missing_cols}.")

        # Diagnostic: 'has_robot' is guaranteed present by the validation above,
        # so no "column not found" branch is needed here.
        robot_count = df['has_robot'].sum()
        print(f"DIAGNOSTIC (Dataset Load): 'has_robot' column found. Number of True values: {robot_count}")

        # Measured after validation so the reported time covers the whole load.
        elapsed = time.time() - overall_start_time
        msg = f"Successfully loaded dataset from HF Hub in {elapsed:.2f}s. Shape: {df.shape}"
        print(msg)
        return df, True, msg

    except Exception as e:
        # Top-level boundary for the UI: report the failure instead of raising.
        err_msg = f"Failed to load dataset from Hugging Face Hub. Error: {e}"
        print(err_msg)
        return pd.DataFrame(), False, err_msg


def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
    """
    Filter the models DataFrame and keep only models of the top-k organizations.

    Args:
        df: Models DataFrame with at least ``id`` and ``organization`` columns
            plus the flag/metric columns referenced by the filters.
        count_by: Metric column used to rank organizations (e.g. "downloads").
            Non-numeric values are coerced to 0; a missing column is treated as
            all zeros.
        top_k: Number of organizations to keep, ranked by the summed metric.
            Coerced to ``int``, because ``Series.nlargest`` expects an integer
            count (UI widgets may hand over a float).
        tag_filter: Optional key of the tag -> flag-column mapping below
            (e.g. "Robotics"); only rows whose flag equals True are kept.
        pipeline_filter: Optional exact ``pipeline_tag`` value to keep.
        size_filter: Optional key of ``MODEL_SIZE_RANGES`` matched against the
            ``size_category`` column; ``"None"`` disables the filter.
        skip_orgs: Optional list of organization names to exclude.

    Returns:
        DataFrame with columns ``id``, ``organization``, ``count_by`` and a
        constant ``root`` column ("models"), or an empty DataFrame when no rows
        remain or the input cannot be grouped by organization.
    """
    if df is None or df.empty:
        return pd.DataFrame()
    # No in-place mutation happens below (boolean indexing and .assign return
    # new frames), so the caller's DataFrame is never modified and no upfront
    # copy is needed.
    filtered_df = df
    col_map = {"Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
               "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
               "Video": "has_video", "Images": "has_image", "Text": "has_text"}

    if tag_filter and tag_filter in col_map:
        target_col = col_map[tag_filter]
        if target_col in filtered_df.columns:
            # Compare against True rather than masking with the raw column so
            # None/NaN flags count as "no" instead of raising
            # "Cannot mask with non-boolean array containing NA / NaN values".
            flag_mask = filtered_df[target_col].eq(True).fillna(False).astype(bool)
            filtered_df = filtered_df[flag_mask]
        else:
            print(f"Warning: Tag filter column '{target_col}' not found in DataFrame.")

    if pipeline_filter:
        if "pipeline_tag" in filtered_df.columns:
            # astype(str) keeps the comparison valid even if pipeline_tag has NaNs or mixed types.
            filtered_df = filtered_df[filtered_df["pipeline_tag"].astype(str) == pipeline_filter]
        else:
            print("Warning: 'pipeline_tag' column not found for filtering.")

    if size_filter and size_filter != "None" and size_filter in MODEL_SIZE_RANGES:
        if 'size_category' in filtered_df.columns:
            filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
        else:
            print("Warning: 'size_category' column not found for filtering.")

    if skip_orgs:
        if "organization" in filtered_df.columns:
            filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
        else:
            print("Warning: 'organization' column not found for filtering.")

    if filtered_df.empty:
        return pd.DataFrame()

    # Grouping below needs the organization column; fail soft like the filters do.
    if "organization" not in filtered_df.columns:
        print("Warning: 'organization' column not found; cannot build treemap data.")
        return pd.DataFrame()

    # Make the metric numeric (missing values -> 0). Using .assign avoids
    # SettingWithCopyWarning on the boolean-indexed frame.
    if count_by in filtered_df.columns:
        metric_values = pd.to_numeric(filtered_df[count_by], errors="coerce").fillna(0.0)
    else:
        print(f"Warning: Metric column '{count_by}' not found. Using 0.")
        metric_values = 0.0
    filtered_df = filtered_df.assign(**{count_by: metric_values})

    # Rank organizations by their summed metric and keep the top_k.
    org_totals = filtered_df.groupby("organization")[count_by].sum().nlargest(int(top_k), keep='first')
    in_top_orgs = filtered_df["organization"].isin(org_totals.index)

    # Rows for the treemap; the metric is already numeric at this point.
    treemap_data = filtered_df.loc[in_top_orgs, ["id", "organization", count_by]].assign(root="models")
    return treemap_data

def create_treemap(treemap_data, count_by, title=None):
    """
    Build a Plotly treemap with the hierarchy root -> organization -> model id.

    Args:
        treemap_data: DataFrame with ``root``, ``organization``, ``id`` and
            ``count_by`` columns, or an empty DataFrame.
        count_by: Metric column used for tile sizes and in the hover text.
        title: Optional chart title; defaults to one derived from ``count_by``.
            Not used when ``treemap_data`` is empty.

    Returns:
        A Plotly figure. For empty input, a single placeholder tile titled
        "No data matches the selected filters".
    """
    if treemap_data.empty:
        fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
        return fig
    # str.capitalize() would lowercase the rest of the name
    # ("downloadsAllTime" -> "Downloadsalltime"), so only upper-case the first char.
    metric_name = count_by[:1].upper() + count_by[1:]
    fig = px.treemap(
        treemap_data, path=["root", "organization", "id"], values=count_by,
        title=title or f"HuggingFace Models - {metric_name} by Organization",
        color_discrete_sequence=px.colors.qualitative.Plotly
    )
    fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
    fig.update_traces(textinfo="label+value+percent root", hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>")
    return fig

with gr.Blocks(title="HuggingFace Model Explorer", fill_width=True) as demo:
    # gr.State holders: the loaded models DataFrame and a flag set once loading succeeds.
    models_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)

    # Metric column -> display label. Shared by the metric dropdown, the chart
    # title and the plot statistics so the three cannot drift apart.
    metric_labels = {
        "downloads": "Downloads (last 30 days)",
        "downloadsAllTime": "Downloads (All Time)",
        "likes": "Likes",
    }

    with gr.Row():
        gr.Markdown("# HuggingFace Models TreeMap Visualization")
    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(label="Metric", choices=[(label, col) for col, label in metric_labels.items()], value="downloads")
            filter_choice_radio = gr.Radio(label="Filter Type", choices=["None", "Tag Filter", "Pipeline Filter"], value="None")
            tag_filter_dropdown = gr.Dropdown(label="Select Tag", choices=TAG_FILTER_CHOICES, value=None, visible=False)
            pipeline_filter_dropdown = gr.Dropdown(label="Select Pipeline Tag", choices=PIPELINE_TAGS, value=None, visible=False)
            size_filter_dropdown = gr.Dropdown(label="Model Size Filter", choices=["None"] + list(MODEL_SIZE_RANGES.keys()), value="None")
            top_k_slider = gr.Slider(label="Number of Top Organizations", minimum=5, maximum=50, value=25, step=5)
            skip_orgs_textbox = gr.Textbox(label="Organizations to Skip (comma-separated)", value="TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski")
            # Disabled until loading_complete_state becomes True (see below).
            generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)

        with gr.Column(scale=3):
            plot_output = gr.Plot()
            status_message_md = gr.Markdown("Initializing...")
            data_info_md = gr.Markdown("")

    def _update_button_interactivity(is_loaded_flag):
        """Enable the 'Generate Plot' button only when data has loaded."""
        return gr.update(interactive=is_loaded_flag)
    loading_complete_state.change(fn=_update_button_interactivity, inputs=loading_complete_state, outputs=generate_plot_button)

    def _toggle_filters_visibility(choice):
        """Show the tag or pipeline dropdown matching the selected filter type."""
        return gr.update(visible=choice == "Tag Filter"), gr.update(visible=choice == "Pipeline Filter")
    filter_choice_radio.change(fn=_toggle_filters_visibility, inputs=filter_choice_radio, outputs=[tag_filter_dropdown, pipeline_filter_dropdown])

    def _format_data_timestamp(df):
        """
        Return the first row's ``data_download_timestamp`` as a display string.

        Naive timestamps are treated as UTC. Falls back to a placeholder when the
        column is missing, the frame is empty, or the value is null/unparseable,
        so a bad timestamp never turns a successful load into a failure.
        """
        fallback = "Pre-processed (date unavailable)"
        if df.empty or 'data_download_timestamp' not in df.columns:
            return fallback
        raw_value = df['data_download_timestamp'].iloc[0]
        if not pd.notna(raw_value):
            return fallback
        try:
            timestamp = pd.to_datetime(raw_value)
        except (ValueError, TypeError, OverflowError) as e:
            print(f"Warning: could not parse data_download_timestamp {raw_value!r}: {e}")
            return fallback
        if pd.isna(timestamp):
            return fallback
        # Ensure the timestamp is timezone-aware for consistent formatting.
        if timestamp.tzinfo is None:
            timestamp = timestamp.tz_localize('UTC')
        return timestamp.strftime('%B %d, %Y, %H:%M:%S %Z')

    def ui_load_data_controller(progress=gr.Progress()):
        """
        Load the models data on page load and build the status texts.

        Returns:
            tuple: ``(df, loaded_ok, data_info_markdown, status_markdown)``,
            matching the outputs wired in ``demo.load`` below.
        """
        progress(0, desc=f"Loading dataset '{HF_DATASET_ID}' from Hugging Face Hub...")
        print("ui_load_data_controller called.")
        status_msg_ui = "Loading data..."
        data_info_text = ""
        current_df = pd.DataFrame()
        load_success_flag = False
        try:
            current_df, load_success_flag, status_msg_from_load = load_models_data()
            if load_success_flag:
                progress(0.9, desc="Processing loaded data...")
                data_as_of_date_display = _format_data_timestamp(current_df)

                # Per-size-category model counts for the summary.
                if 'size_category' in current_df.columns:
                    size_dist_lines = [
                        f"  - {cat}: {(current_df['size_category'] == cat).sum():,} models"
                        for cat in MODEL_SIZE_RANGES
                    ]
                else:
                    size_dist_lines = ["  - Size category information not available."]
                size_dist = "\n".join(size_dist_lines)

                data_info_text = (f"### Data Information\n"
                                  f"- Overall Status: {status_msg_from_load}\n"
                                  f"- Total models loaded: {len(current_df):,}\n"
                                  f"- Data as of: {data_as_of_date_display}\n"
                                  f"- Size categories:\n{size_dist}")

                status_msg_ui = "Data loaded successfully. Ready to generate plot."
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
                status_msg_ui = status_msg_from_load
        except Exception as e:
            # UI boundary: surface any unexpected error instead of crashing the event.
            status_msg_ui = f"An unexpected error occurred in ui_load_data_controller: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_ui}"
            print(f"Critical error in ui_load_data_controller: {e}")
            load_success_flag = False
        return current_df, load_success_flag, data_info_text, status_msg_ui

    def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
                                    size_choice, k_orgs, skip_orgs_input, df_current_models, progress=gr.Progress()):
        """
        Build the treemap figure and a Markdown statistics summary from the UI inputs.

        Returns:
            tuple: ``(figure, status_markdown)``.
        """
        if df_current_models is None or df_current_models.empty:
            error_msg = "Model data is not loaded or is empty. Please wait for data to load."
            empty_fig = create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded")
            # create_treemap uses a fixed "no data" title for empty input and
            # ignores the title argument, so apply the error title explicitly.
            empty_fig.update_layout(title="Error: Model Data Not Loaded")
            gr.Warning(error_msg)
            return empty_fig, error_msg

        progress(0.1, desc="Preparing data for visualization...")

        # Only the filter matching the selected filter type is applied.
        tag_to_use = tag_choice if filter_type == "Tag Filter" else None
        pipeline_to_use = pipeline_choice if filter_type == "Pipeline Filter" else None
        size_to_use = size_choice if size_choice != "None" else None
        orgs_to_skip = [org.strip() for org in skip_orgs_input.split(',') if org.strip()] if skip_orgs_input else []

        treemap_df = make_treemap_data(df_current_models, metric_choice, k_orgs, tag_to_use, pipeline_to_use, size_to_use, orgs_to_skip)

        progress(0.7, desc="Generating Plotly visualization...")

        metric_label = metric_labels.get(metric_choice, metric_choice)
        chart_title = f"HuggingFace Models - {metric_label} by Organization"
        plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)

        if treemap_df.empty:
            plot_stats_md = "No data matches the selected filters. Try adjusting your filters."
        else:
            total_items_in_plot = len(treemap_df['id'].unique())
            total_value_in_plot = treemap_df[metric_choice].sum()
            plot_stats_md = (f"## Plot Statistics\n- **Models shown**: {total_items_in_plot:,}\n- **Total {metric_label}**: {int(total_value_in_plot):,}")

        return plotly_fig, plot_stats_md

    # On app load, call the controller to fetch data using the datasets library.
    demo.load(
        fn=ui_load_data_controller,
        inputs=[],
        outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
    )

    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
                size_filter_dropdown, top_k_slider, skip_orgs_textbox, models_data_state],
        outputs=[plot_output, status_message_md]
    )

if __name__ == "__main__":
    print(f"Application starting. Data will be loaded from Hugging Face dataset: {HF_DATASET_ID}")
    # Enable Gradio's request queue with its default settings (no max_size is
    # passed, so the queue size is not changed here), then start the server.
    demo.queue().launch()

# --- END OF FILE app.py ---