# ModelHubManager / pages/batch_operations.py
# S-Dreamer's picture
# Upload 31 files
# 74dd3f1 verified
import streamlit as st
import pandas as pd
import time
def _format_tags_block(tags):
    """Return the YAML lines for a top-level ``tags`` list.

    Tags are de-duplicated in first-seen order (comparison ignores surrounding
    quotes) so repeated runs produce a stable model card. An empty result is
    written as ``tags: []`` rather than a bare ``tags:`` key.
    """
    unique = []
    seen = set()
    for raw in tags:
        text = str(raw).strip()
        key = text.strip("'\"")
        if key and key not in seen:
            seen.add(key)
            unique.append(text)
    if not unique:
        return ["tags: []"]
    return ["tags:"] + [f"- {tag}" for tag in unique]


def _merge_tags_into_model_card(model_card_content, new_tags, replace_existing):
    """Return ``model_card_content`` with ``new_tags`` written to its front matter.

    Args:
        model_card_content: Full README.md text of the model card.
        new_tags: Tags chosen by the user.
        replace_existing: When True the existing tag list is discarded;
            otherwise the new tags are appended to the existing ones.

    Returns:
        The updated model card text. Everything outside the ``tags`` entry is
        preserved as-is.

    Front matter is only recognised when the very first line is ``---`` and a
    closing ``---`` line follows; a ``---`` horizontal rule further down the
    document is never treated as metadata. Without front matter, a new block
    containing only the tags is prepended.
    """
    lines = model_card_content.split("\n")
    closing = None
    if lines[0].rstrip() == "---":
        closing = next(
            (i for i in range(1, len(lines)) if lines[i].rstrip() == "---"),
            None,
        )

    if closing is None:
        header = ["---"] + _format_tags_block(new_tags) + ["---", "", ""]
        return "\n".join(header) + model_card_content

    front = lines[1:closing]
    # Default to "append at the end of the front matter" when no tags key exists.
    tags_start = tags_end = len(front)
    existing = []
    for index, line in enumerate(front):
        key, sep, value = line.partition(":")
        # Only an unindented ``tags`` key counts; nested keys keep their indent.
        if not sep or key.rstrip() != "tags":
            continue
        tags_start, tags_end = index, index + 1
        value = value.strip()
        if value.startswith("#"):
            # A trailing comment after the key means the list is on the next lines.
            value = ""
        if value.startswith("[") and value.endswith("]"):
            # Flow style: tags: [a, b]
            existing = [item.strip() for item in value[1:-1].split(",")]
        elif value:
            existing = [value]
        else:
            # Block style: consume the following "- item" lines, indented or not.
            while tags_end < len(front):
                item = front[tags_end].strip()
                if item != "-" and not item.startswith("- "):
                    break
                existing.append(item[1:].strip())
                tags_end += 1
        break

    combined = list(new_tags) if replace_existing else existing + list(new_tags)
    new_front = front[:tags_start] + _format_tags_block(combined) + front[tags_end:]
    # Re-use the original fence lines so the closing '---' stays on its own line.
    return "\n".join(lines[:1] + new_front + lines[closing:])


def _report_batch_results(successes, failures, success_message, verb):
    """Show the outcome of a batch run: one summary line plus one warning per failure."""
    if successes > 0:
        st.success(success_message)
    if failures:
        st.error(f"Failed to {verb} {len(failures)} models")
        for repo_id, error in failures:
            st.warning(f"Failed to {verb} {repo_id}: {error}")


def _update_model_card_tags(repo_id, selected_tags, replace_existing):
    """Rewrite the tags of one repository's model card.

    Returns:
        None on success, otherwise a short error message for the failure list.
    """
    client = st.session_state.client
    if not client.get_model_info(repo_id):
        return "Failed to fetch model info"

    model_card_url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
    # NOTE(review): `_get_paginated` is a private attribute of the wrapped API
    # object and is assumed to return a response with `status_code`/`text` --
    # confirm it exists in the installed client version.
    response = client.api._get_paginated(model_card_url)
    if response.status_code != 200:
        return "Failed to fetch model card"

    new_content = _merge_tags_into_model_card(
        response.text, selected_tags, replace_existing
    )
    success, _ = client.update_model_card(repo_id, new_content)
    return None if success else "Failed to update model card"


def _render_update_tags_tab(selected_models, selected_count):
    """Render the "Update Tags" tab and apply the tag change when requested."""
    st.subheader("Update Tags")
    available_tags = st.session_state.client.get_model_tags()
    selected_tags = st.multiselect(
        "Select tags to add to all selected models",
        options=available_tags,
        help="These tags will be added to all selected models",
    )
    tags_action = st.radio(
        "Tag Operation",
        ["Add tags (keep existing)", "Replace tags (remove existing)"],
        index=0,
    )
    if not st.button("Apply Tags", use_container_width=True, type="primary"):
        return
    if not selected_tags:
        st.warning("Please select at least one tag to add.")
        return

    replace_existing = tags_action == "Replace tags (remove existing)"
    successes = 0
    failures = []
    with st.spinner(f"Updating tags for {selected_count} models..."):
        for _, row in selected_models.iterrows():
            repo_id = row["Full ID"]
            try:
                error = _update_model_card_tags(repo_id, selected_tags, replace_existing)
            except Exception as e:
                # One failing repository must not abort the rest of the batch.
                error = str(e)
            if error is None:
                successes += 1
            else:
                failures.append((repo_id, error))

    _report_batch_results(
        successes,
        failures,
        f"Successfully updated tags for {successes} models",
        "update",
    )
    # Refresh models after batch operation
    st.session_state.models = st.session_state.client.get_user_models()
    st.info("Model list refreshed. You may need to wait a few minutes for all changes to propagate.")


def _render_visibility_tab(selected_count):
    """Render the "Update Visibility" tab.

    NOTE(review): this is a placeholder -- no client call is made, the chosen
    visibility is not used, and a success message is shown regardless.
    """
    st.subheader("Update Visibility")
    st.radio(
        "Set visibility for selected models",
        ["Public", "Private"],
        index=0,
        help="Change the visibility of all selected models",
    )
    if st.button("Update Visibility", use_container_width=True, type="primary"):
        with st.spinner(f"Updating visibility for {selected_count} models..."):
            st.warning("This feature requires Hugging Face Pro or Enterprise subscription.")
            st.info("In the actual implementation, this would update the models' visibility settings.")
            # This is a placeholder for the actual implementation
            time.sleep(2)
            st.success(f"Successfully updated visibility for {selected_count} models")


def _render_collaborators_tab(selected_count):
    """Render the "Add Collaborators" tab.

    NOTE(review): this is a placeholder -- usernames are parsed and counted
    but no client call is made before the success message is shown.
    """
    st.subheader("Add Collaborators")
    collaborators = st.text_area(
        "Enter usernames of collaborators (one per line)",
        help="These users will be added as collaborators to all selected models",
    )
    role = st.selectbox(
        "Collaborator role",
        ["read", "write", "admin"],
        index=0,
    )
    if not st.button("Add Collaborators", use_container_width=True, type="primary"):
        return
    if not collaborators.strip():
        st.warning("Please enter at least one collaborator username.")
        return
    with st.spinner(f"Adding collaborators to {selected_count} models..."):
        # This is a placeholder for the actual implementation
        collaborator_list = [c.strip() for c in collaborators.split("\n") if c.strip()]
        st.info(f"Adding {len(collaborator_list)} collaborators with '{role}' role to {selected_count} models.")
        st.warning("This feature requires Hugging Face Pro or Enterprise subscription.")
        time.sleep(2)
        st.success(f"Successfully added collaborators to {selected_count} models")


def _render_delete_tab(selected_models, selected_count):
    """Render the "Delete" tab; deletion runs only after the user types DELETE."""
    st.subheader("⚠️ Delete Models")
    st.warning(
        "This operation is irreversible. All selected models will be permanently deleted."
    )
    confirmation = st.text_input(
        "Type 'DELETE' to confirm deletion of all selected models",
        key="batch_delete_confirm",
    )
    if not st.button("Delete Selected Models", use_container_width=True, type="primary"):
        return
    if confirmation != "DELETE":
        st.error("Please type 'DELETE' to confirm.")
        return

    successes = 0
    failures = []
    with st.spinner(f"Deleting {selected_count} models..."):
        for _, row in selected_models.iterrows():
            repo_id = row["Full ID"]
            try:
                success, message = st.session_state.client.delete_model_repository(repo_id)
            except Exception as e:
                # Keep going so the remaining repositories are still processed.
                success, message = False, str(e)
            if success:
                successes += 1
            else:
                failures.append((repo_id, message))

    _report_batch_results(
        successes,
        failures,
        f"Successfully deleted {successes} models",
        "delete",
    )
    # Refresh models after batch operation
    st.session_state.models = st.session_state.client.get_user_models()


def render_batch_operations():
    """Render the batch operations page.

    Shows the models held in ``st.session_state.models`` in an editable table
    with a "Select" checkbox column, then offers tabs that act on the checked
    rows: update tags, update visibility, add collaborators, and delete.
    Reads ``st.session_state.client`` for all repository operations and
    replaces ``st.session_state.models`` after a tag update or a deletion.
    """
    st.title("🔄 Batch Operations")

    if "models" not in st.session_state or not st.session_state.models:
        st.info("No models found. Please create repositories first.")
        if st.button("Go to Dashboard", use_container_width=True):
            st.session_state.page = "home"
            # st.experimental_rerun was replaced by st.rerun in newer Streamlit
            # releases; fall back only when the new name is unavailable.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        return

    # One table row per model; a malformed model entry is reported and skipped.
    models_data = []
    for model in st.session_state.models:
        try:
            models_data.append({
                "Select": False,  # Checkbox column
                "Model Name": model.modelId.split("/")[-1],
                "Full ID": model.modelId,
                "Downloads": getattr(model, "downloads", 0),
                "Likes": getattr(model, "likes", 0),
                "Private": getattr(model, "private", False),
                "Tags": ", ".join(getattr(model, "tags", []) or []),
            })
        except Exception as e:
            st.warning(f"Error processing model {getattr(model, 'modelId', 'unknown')}: {str(e)}")

    if not models_data:
        st.error("Failed to process model data.")
        return

    df = pd.DataFrame(models_data)
    st.markdown("### Select Models for Batch Operations")
    st.markdown("Use the checkboxes to select models you want to operate on.")

    # Only the "Select" column is meant to be edited; the informational
    # columns below are explicitly disabled.
    edited_df = st.data_editor(
        df,
        column_config={
            "Select": st.column_config.CheckboxColumn(
                "Select",
                help="Select for batch operations",
                default=False,
            ),
            "Full ID": st.column_config.TextColumn(
                "Repository ID",
                help="Full repository ID",
                disabled=True,
            ),
            "Downloads": st.column_config.NumberColumn(
                "Downloads",
                help="Number of downloads",
                disabled=True,
            ),
            "Likes": st.column_config.NumberColumn(
                "Likes",
                help="Number of likes",
                disabled=True,
            ),
            "Private": st.column_config.CheckboxColumn(
                "Private",
                help="Repository visibility",
                disabled=True,
            ),
            "Tags": st.column_config.TextColumn(
                "Tags",
                help="Current tags",
                disabled=True,
            ),
        },
        hide_index=True,
        use_container_width=True,
    )

    selected_models = edited_df[edited_df["Select"].eq(True)]
    selected_count = len(selected_models)
    if selected_count == 0:
        st.info("Please select at least one model to perform batch operations.")
        return
    st.success(f"Selected {selected_count} models for batch operations.")

    tags_tab, visibility_tab, collaborators_tab, delete_tab = st.tabs(
        ["Update Tags", "Update Visibility", "Add Collaborators", "Delete"]
    )
    with tags_tab:
        _render_update_tags_tab(selected_models, selected_count)
    with visibility_tab:
        _render_visibility_tab(selected_count)
    with collaborators_tab:
        _render_collaborators_tab(selected_count)
    with delete_tab:
        _render_delete_tab(selected_models, selected_count)