Spaces:
Sleeping
Sleeping
import re
import time

import pandas as pd
import streamlit as st
# Tag-operation choices offered in the "Update Tags" tab.
_ADD_TAGS = "Add tags (keep existing)"
_REPLACE_TAGS = "Replace tags (remove existing)"

# YAML front matter at the very start of a model card: an opening "---" line,
# the YAML body (named group "body"), and a closing "---" line. Anchored with
# \A so a markdown horizontal rule later in the document is never mistaken
# for metadata.
_FRONTMATTER_RE = re.compile(
    r"\A\s*---[ \t]*\r?\n(?P<body>.*?)^---[ \t]*\r?$",
    re.DOTALL | re.MULTILINE,
)


def _render_tags_yaml(tags):
    """Return a YAML ``tags:`` section with one ``- tag`` line per tag."""
    return "tags:\n" + "".join(f"- {tag}\n" for tag in tags)


def _parse_tag_values(tag_lines):
    """Extract tag strings from the lines of an existing ``tags:`` section.

    ``tag_lines[0]`` is the ``tags:`` key line; it may carry an inline value
    (``tags: [a, b]`` or ``tags: a``). Following lines contribute ``- a`` items.
    Surrounding quotes and blank values are dropped.
    """
    inline = tag_lines[0][len("tags:"):].strip()
    raw = []
    if inline.startswith("[") and inline.endswith("]"):
        raw.extend(inline[1:-1].split(","))
    elif inline:
        raw.append(inline)
    for line in tag_lines[1:]:
        item = line.strip()
        if item.startswith("-"):
            raw.append(item[1:])
    cleaned = (value.strip().strip("'\"") for value in raw)
    return [value for value in cleaned if value]


def merge_tags_into_model_card(content, tags, replace=False):
    """Return model-card text with its YAML front-matter ``tags`` updated.

    Args:
        content: Full README.md text of the model card.
        tags: Tags to apply.
        replace: If False (default), existing tags are kept and *tags* are
            appended. If True, only *tags* remain.

    Returns:
        The updated card text.
        - A card without front matter gets a new front-matter block holding
          only *tags*.
        - Front matter without a top-level ``tags`` key gets one appended.
        - Duplicate tags are dropped, keeping first-occurrence order.
        - Everything outside the tags section is preserved byte-for-byte.
    """
    new_tags = list(dict.fromkeys(tags))
    match = _FRONTMATTER_RE.match(content)
    if match is None:
        return "---\n" + _render_tags_yaml(new_tags) + "---\n\n" + content

    body = match.group("body")
    lines = body.splitlines(keepends=True)
    start = next((i for i, line in enumerate(lines) if line.startswith("tags:")), None)
    if start is None:
        # body is empty or ends with a newline (the closing fence starts a line),
        # so appending a new section keeps the closing "---" on its own line.
        new_body = body + _render_tags_yaml(new_tags)
    else:
        # The tags section spans the key line plus following non-blank lines
        # that are list items or indented continuations; it ends at the next
        # top-level key.
        end = start + 1
        while end < len(lines) and lines[end].strip() and lines[end][0] in "- \t":
            end += 1
        kept = [] if replace else _parse_tag_values(lines[start:end])
        merged = list(dict.fromkeys(kept + new_tags))
        new_body = "".join(lines[:start]) + _render_tags_yaml(merged) + "".join(lines[end:])
    return content[: match.start("body")] + new_body + content[match.end("body"):]


def _rerun():
    """Rerun the script via ``st.rerun``, falling back to ``st.experimental_rerun`` on old Streamlit."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _build_model_rows(models):
    """Build one table row per model; models that fail to convert are reported and skipped."""
    rows = []
    for model in models:
        try:
            rows.append({
                "Select": False,  # Checkbox column
                "Model Name": model.modelId.split("/")[-1],
                "Full ID": model.modelId,
                "Downloads": getattr(model, "downloads", 0),
                "Likes": getattr(model, "likes", 0),
                "Private": getattr(model, "private", False),
                "Tags": ", ".join(getattr(model, "tags", []) or []),
            })
        except Exception as e:
            st.warning(f"Error processing model {getattr(model, 'modelId', 'unknown')}: {str(e)}")
    return rows


def _fetch_model_card(client, repo_id):
    """Return the README.md text of *repo_id*.

    NOTE(review): assumes ``client.api`` is a ``huggingface_hub.HfApi``.
    Its ``hf_hub_download`` should authenticate with the client's token, which
    is needed for private repos. Confirm against the client class.
    """
    local_path = client.api.hf_hub_download(repo_id=repo_id, filename="README.md")
    with open(local_path, encoding="utf-8") as handle:
        return handle.read()


def _update_repo_tags(client, repo_id, tags, replace):
    """Apply *tags* to one repository's model card.

    Returns:
        ``(True, None)`` on success, ``(False, message)`` on a reported failure.
        Exceptions from the client propagate to the caller.
    """
    if not client.get_model_info(repo_id):
        return False, "Failed to fetch model info"
    content = _fetch_model_card(client, repo_id)
    new_content = merge_tags_into_model_card(content, tags, replace=replace)
    success, _ = client.update_model_card(repo_id, new_content)
    return (True, None) if success else (False, "Failed to update model card")


def _run_batch(repo_ids, operation):
    """Call ``operation(repo_id)`` for each id and collect the outcome.

    *operation* returns ``(ok, message)``. Exceptions are recorded as failures
    so one bad repository does not abort the batch.

    Returns:
        ``(success_count, [(repo_id, error_message), ...])``.
    """
    successes = 0
    failures = []
    for repo_id in repo_ids:
        try:
            ok, message = operation(repo_id)
        except Exception as exc:  # UI boundary: report per repo, keep going
            ok, message = False, str(exc)
        if ok:
            successes += 1
        else:
            failures.append((repo_id, message))
    return successes, failures


def _report_batch_results(successes, failures, success_message, verb):
    """Show the success summary and one warning per failed repository."""
    if successes > 0:
        st.success(success_message)
    if failures:
        st.error(f"Failed to {verb} {len(failures)} models")
        for repo_id, error in failures:
            st.warning(f"Failed to {verb} {repo_id}: {error}")


def _render_tags_tab(client, repo_ids):
    """Tab: add or replace model-card tags on every selected repository."""
    st.subheader("Update Tags")
    available_tags = client.get_model_tags()
    selected_tags = st.multiselect(
        "Select tags to add to all selected models",
        options=available_tags,
        help="These tags will be added to all selected models",
    )
    tags_action = st.radio("Tag Operation", [_ADD_TAGS, _REPLACE_TAGS], index=0)
    if not st.button("Apply Tags", use_container_width=True, type="primary"):
        return
    if not selected_tags:
        st.warning("Please select at least one tag to add.")
        return

    replace = tags_action == _REPLACE_TAGS
    with st.spinner(f"Updating tags for {len(repo_ids)} models..."):
        successes, failures = _run_batch(
            repo_ids,
            lambda repo_id: _update_repo_tags(client, repo_id, selected_tags, replace),
        )
    _report_batch_results(
        successes, failures, f"Successfully updated tags for {successes} models", "update"
    )
    # Refresh models after batch operation
    st.session_state.models = client.get_user_models()
    st.info("Model list refreshed. You may need to wait a few minutes for all changes to propagate.")


def _render_visibility_tab(selected_count):
    """Tab: visibility change (placeholder, makes no server-side change)."""
    st.subheader("Update Visibility")
    visibility = st.radio(
        "Set visibility for selected models",
        ["Public", "Private"],
        index=0,
        help="Change the visibility of all selected models",
    )
    if st.button("Update Visibility", use_container_width=True, type="primary"):
        st.warning("This feature requires Hugging Face Pro or Enterprise subscription.")
        st.info("In the actual implementation, this would update the models' visibility settings.")
        # Placeholder: nothing is changed, so never report success. A user could
        # otherwise believe a model is private while it is still public.
        st.error(
            f"Visibility was not changed: setting {selected_count} models to "
            f"{visibility} is not implemented yet."
        )


def _render_collaborators_tab(selected_count):
    """Tab: add collaborators (placeholder, makes no server-side change)."""
    st.subheader("Add Collaborators")
    collaborators = st.text_area(
        "Enter usernames of collaborators (one per line)",
        help="These users will be added as collaborators to all selected models",
    )
    role = st.selectbox("Collaborator role", ["read", "write", "admin"], index=0)
    if not st.button("Add Collaborators", use_container_width=True, type="primary"):
        return
    collaborator_list = [name.strip() for name in collaborators.splitlines() if name.strip()]
    if not collaborator_list:
        st.warning("Please enter at least one collaborator username.")
        return
    st.info(f"Adding {len(collaborator_list)} collaborators with '{role}' role to {selected_count} models.")
    st.warning("This feature requires Hugging Face Pro or Enterprise subscription.")
    # Placeholder: nothing is changed, so never report success.
    st.error(f"No collaborators were added to {selected_count} models: this feature is not implemented yet.")


def _render_delete_tab(client, repo_ids):
    """Tab: permanently delete every selected repository after typed confirmation."""
    st.subheader("⚠️ Delete Models")
    st.warning(
        "This operation is irreversible. All selected models will be permanently deleted."
    )
    confirmation = st.text_input(
        "Type 'DELETE' to confirm deletion of all selected models",
        key="batch_delete_confirm",
    )
    if not st.button("Delete Selected Models", use_container_width=True, type="primary"):
        return
    if confirmation != "DELETE":
        st.error("Please type 'DELETE' to confirm.")
        return

    with st.spinner(f"Deleting {len(repo_ids)} models..."):
        # delete_model_repository already returns (success, message).
        successes, failures = _run_batch(repo_ids, client.delete_model_repository)
    _report_batch_results(successes, failures, f"Successfully deleted {successes} models", "delete")
    # Refresh models after batch operation
    st.session_state.models = client.get_user_models()


def render_batch_operations():
    """Render the batch operations page.

    The page does the following:
    - Lists ``st.session_state.models`` in an editable table with a "Select"
      checkbox column.
    - Once at least one model is selected, shows tabs for updating tags,
      visibility (placeholder), collaborators (placeholder), and deletion.
    - Sends the user back to the dashboard when there are no models.
    """
    st.title("🔄 Batch Operations")
    if not st.session_state.get("models"):
        st.info("No models found. Please create repositories first.")
        if st.button("Go to Dashboard", use_container_width=True):
            st.session_state.page = "home"
            _rerun()
        return

    models_data = _build_model_rows(st.session_state.models)
    if not models_data:
        st.error("Failed to process model data.")
        return

    df = pd.DataFrame(models_data)
    st.markdown("### Select Models for Batch Operations")
    st.markdown("Use the checkboxes to select models you want to operate on.")
    # Only "Select" is editable; every other column is display-only.
    edited_df = st.data_editor(
        df,
        column_config={
            "Select": st.column_config.CheckboxColumn(
                "Select",
                help="Select for batch operations",
                default=False,
            ),
            "Model Name": st.column_config.TextColumn(
                "Model Name",
                help="Repository name without the owner prefix",
                disabled=True,
            ),
            "Full ID": st.column_config.TextColumn(
                "Repository ID",
                help="Full repository ID",
                disabled=True,
            ),
            "Downloads": st.column_config.NumberColumn(
                "Downloads",
                help="Number of downloads",
                disabled=True,
            ),
            "Likes": st.column_config.NumberColumn(
                "Likes",
                help="Number of likes",
                disabled=True,
            ),
            "Private": st.column_config.CheckboxColumn(
                "Private",
                help="Repository visibility",
                disabled=True,
            ),
            "Tags": st.column_config.TextColumn(
                "Tags",
                help="Current tags",
                disabled=True,
            ),
        },
        hide_index=True,
        use_container_width=True,
    )

    selected_models = edited_df[edited_df["Select"].eq(True)]
    selected_count = len(selected_models)
    if selected_count == 0:
        st.info("Please select at least one model to perform batch operations.")
        return
    st.success(f"Selected {selected_count} models for batch operations.")

    client = st.session_state.client
    repo_ids = selected_models["Full ID"].tolist()
    tags_tab, visibility_tab, collaborators_tab, delete_tab = st.tabs(
        ["Update Tags", "Update Visibility", "Add Collaborators", "Delete"]
    )
    with tags_tab:
        _render_tags_tab(client, repo_ids)
    with visibility_tab:
        _render_visibility_tab(selected_count)
    with collaborators_tab:
        _render_collaborators_tab(selected_count)
    with delete_tab:
        _render_delete_tab(client, repo_ids)