import re
import time

import pandas as pd
import streamlit as st

# Options of the "Tag Operation" radio; the selected label is compared against these.
_ADD_TAGS = "Add tags (keep existing)"
_REPLACE_TAGS = "Replace tags (remove existing)"

# A top-level (unindented) `tags:` key inside model-card YAML front matter.
_TAGS_KEY_RE = re.compile(r"tags\s*:")


def render_batch_operations():
    """Render the batch operations page.

    Lists the models stored in ``st.session_state.models`` in an editable table
    whose "Select" checkbox column picks the targets, then offers tabbed batch
    actions on the selected repositories: tag updates, visibility changes,
    collaborator invitations and deletion.
    """
    st.title("🔄 Batch Operations")

    if "models" not in st.session_state or not st.session_state.models:
        st.info("No models found. Please create repositories first.")
        if st.button("Go to Dashboard", use_container_width=True):
            st.session_state.page = "home"
            _rerun()
        return

    models_data = _build_model_rows(st.session_state.models)
    if not models_data:
        st.error("Failed to process model data.")
        return

    st.markdown("### Select Models for Batch Operations")
    st.markdown("Use the checkboxes to select models you want to operate on.")

    edited_df = st.data_editor(
        pd.DataFrame(models_data),
        column_config=_model_table_column_config(),
        hide_index=True,
        use_container_width=True,
    )

    # `.eq(True)` counts only explicitly ticked rows (missing values compare False).
    selected_ids = edited_df.loc[edited_df["Select"].eq(True), "Full ID"].tolist()
    selected_count = len(selected_ids)

    if selected_count == 0:
        st.info("Please select at least one model to perform batch operations.")
        return
    st.success(f"Selected {selected_count} models for batch operations.")

    tab_tags, tab_visibility, tab_collaborators, tab_delete = st.tabs(
        ["Update Tags", "Update Visibility", "Add Collaborators", "Delete"]
    )
    with tab_tags:
        _render_tags_tab(selected_ids)
    with tab_visibility:
        _render_visibility_tab(selected_count)
    with tab_collaborators:
        _render_collaborators_tab(selected_count)
    with tab_delete:
        _render_delete_tab(selected_ids)


def _rerun():
    """Trigger a script rerun, preferring ``st.rerun`` over the deprecated ``st.experimental_rerun``."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _build_model_rows(models):
    """Turn model objects into table rows, warning about (and skipping) any that fail."""
    rows = []
    for model in models:
        try:
            rows.append({
                "Select": False,  # Checkbox column
                "Model Name": model.modelId.split("/")[-1],
                "Full ID": model.modelId,
                "Downloads": getattr(model, "downloads", 0),
                "Likes": getattr(model, "likes", 0),
                "Private": getattr(model, "private", False),
                "Tags": ", ".join(getattr(model, "tags", []) or []),
            })
        except Exception as e:
            st.warning(f"Error processing model {getattr(model, 'modelId', 'unknown')}: {e}")
    return rows


def _model_table_column_config():
    """Return the ``column_config`` mapping for the model selection table."""
    return {
        "Select": st.column_config.CheckboxColumn(
            "Select",
            help="Select for batch operations",
            default=False,
        ),
        "Full ID": st.column_config.TextColumn(
            "Repository ID",
            help="Full repository ID",
            disabled=True,
        ),
        "Downloads": st.column_config.NumberColumn(
            "Downloads",
            help="Number of downloads",
            disabled=True,
        ),
        "Likes": st.column_config.NumberColumn(
            "Likes",
            help="Number of likes",
            disabled=True,
        ),
        "Private": st.column_config.CheckboxColumn(
            "Private",
            help="Repository visibility",
            disabled=True,
        ),
        "Tags": st.column_config.TextColumn(
            "Tags",
            help="Current tags",
            disabled=True,
        ),
    }


def _run_batch(repo_ids, operation):
    """Apply ``operation`` to every repository id and collect the outcome.

    ``operation(repo_id)`` returns ``None`` on success or an error message.
    Exceptions are recorded as failures so one bad repository does not abort
    the rest of the batch.

    Returns:
        ``(successes, failures)`` where ``failures`` is a list of
        ``(repo_id, error_message)`` tuples.
    """
    successes, failures = 0, []
    for repo_id in repo_ids:
        try:
            error = operation(repo_id)
        except Exception as e:
            error = str(e)
        if error is None:
            successes += 1
        else:
            failures.append((repo_id, error))
    return successes, failures


def _report_results(successes, failures, success_message, verb):
    """Show the success summary plus one warning per failed ``(repo_id, error)`` pair."""
    if successes > 0:
        st.success(success_message)
    if failures:
        st.error(f"Failed to {verb} {len(failures)} models")
        for repo_id, error in failures:
            st.warning(f"Failed to {verb} {repo_id}: {error}")


def _render_tags_tab(selected_ids):
    """Render the "Update Tags" tab and, on click, rewrite each selected model card's tags."""
    client = st.session_state.client
    st.subheader("Update Tags")

    selected_tags = st.multiselect(
        "Select tags to add to all selected models",
        options=client.get_model_tags() or [],
        help="These tags will be added to all selected models",
    )
    tags_action = st.radio("Tag Operation", [_ADD_TAGS, _REPLACE_TAGS], index=0)

    if not st.button("Apply Tags", use_container_width=True, type="primary"):
        return
    if not selected_tags:
        st.warning("Please select at least one tag to add.")
        return

    replace = tags_action == _REPLACE_TAGS
    with st.spinner(f"Updating tags for {len(selected_ids)} models..."):
        successes, failures = _run_batch(
            selected_ids,
            lambda repo_id: _update_model_tags(client, repo_id, selected_tags, replace),
        )

    _report_results(successes, failures, f"Successfully updated tags for {successes} models", "update")

    # Refresh models after batch operation
    st.session_state.models = client.get_user_models()
    st.info("Model list refreshed. You may need to wait a few minutes for all changes to propagate.")


def _update_model_tags(client, repo_id, tags, replace):
    """Write ``tags`` into ``repo_id``'s model card; return ``None`` or an error message."""
    if not client.get_model_info(repo_id):
        return "Failed to fetch model info"
    card = _fetch_model_card(client, repo_id)
    if card is None:
        return "Failed to fetch model card"
    success, _ = client.update_model_card(repo_id, _apply_tags_to_model_card(card, tags, replace=replace))
    return None if success else "Failed to update model card"


def _fetch_model_card(client, repo_id):
    """Return the raw README.md text of ``repo_id``, or ``None`` on a non-200 response."""
    # NOTE(review): `_get_paginated` is a private helper reached through `client.api`;
    # confirm it exists and returns a requests-style response (`.status_code`, `.text`).
    model_card_url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
    response = client.api._get_paginated(model_card_url)
    if response.status_code != 200:
        return None
    return response.text


def _split_front_matter(content):
    """Split a model card into ``(opening, front_matter_lines, closing, body)``.

    The front matter is the block between a leading ``---`` line (optionally
    preceded by blank lines) and the next ``---`` line. Returns ``None`` when
    the card has no such block. Delimiter lines are returned verbatim so the
    card can be reassembled byte-for-byte around the edited section.
    """
    lines = content.splitlines(keepends=True)
    first = next((i for i, line in enumerate(lines) if line.strip()), None)
    if first is None or lines[first].strip() != "---":
        return None
    for close in range(first + 1, len(lines)):
        if lines[close].strip() == "---":
            return (
                "".join(lines[: first + 1]),
                lines[first + 1 : close],
                lines[close],
                "".join(lines[close + 1 :]),
            )
    return None


def _find_tags_block(front_matter):
    """Return ``(start, end)`` line indices of the top-level ``tags`` entry, or ``None``.

    The entry is the ``tags:`` line plus the following indented or ``-`` lines
    (its list items). Blank lines are included only when more items follow, so
    a blank separator before the next key is left in place.
    """
    for start, line in enumerate(front_matter):
        if not _TAGS_KEY_RE.match(line):
            continue
        end = scan = start + 1
        while scan < len(front_matter):
            current = front_matter[scan]
            if not current.strip():
                scan += 1
            elif current.startswith((" ", "\t", "-")):
                scan += 1
                end = scan
            else:
                break
        return start, end
    return None


def _parse_existing_tags(tag_lines):
    """Extract tag values from a ``tags`` entry written as a block list, flow list or scalar."""
    inline = tag_lines[0].split(":", 1)[1].strip()
    if inline.startswith("[") and inline.endswith("]"):
        values = inline[1:-1].split(",")
    elif inline:
        values = [inline]
    else:
        values = []
    values += [line.strip()[1:] for line in tag_lines[1:] if line.strip().startswith("-")]
    cleaned = (value.strip().strip("'\"") for value in values)
    return [value for value in cleaned if value]


def _tag_lines(tags):
    """Render ``tags`` as a YAML block list under a ``tags:`` key."""
    return ["tags:\n", *(f"- {tag}\n" for tag in tags)]


def _apply_tags_to_model_card(content, new_tags, replace=False):
    """Return ``content`` with its front-matter ``tags`` updated.

    With ``replace=False`` the new tags are appended to the existing ones
    (duplicates dropped, original order kept); with ``replace=True`` they
    replace them. A missing ``tags`` key is added at the end of the front
    matter, and a card without front matter gets a new front-matter block.
    """
    parts = _split_front_matter(content)
    if parts is None:
        header = "---\n" + "".join(_tag_lines(dict.fromkeys(new_tags))) + "---\n\n"
        return header + content

    opening, front_matter, closing, body = parts
    span = _find_tags_block(front_matter)
    if span is None:
        start = end = len(front_matter)
        existing = []
    else:
        start, end = span
        existing = [] if replace else _parse_existing_tags(front_matter[start:end])

    merged = list(dict.fromkeys([*existing, *new_tags]))
    updated = front_matter[:start] + _tag_lines(merged) + front_matter[end:]
    return opening + "".join(updated) + closing + body


def _render_visibility_tab(selected_count):
    """Render the "Update Visibility" tab (placeholder: no repository is modified)."""
    st.subheader("Update Visibility")
    visibility = st.radio(
        "Set visibility for selected models",
        ["Public", "Private"],
        index=0,
        help="Change the visibility of all selected models",
    )
    if st.button("Update Visibility", use_container_width=True, type="primary"):
        st.warning("This feature requires Hugging Face Pro or Enterprise subscription.")
        st.info("In the actual implementation, this would update the models' visibility settings.")
        # Do not claim success for an operation that was never performed.
        st.info(f"No changes were made: setting {selected_count} models to {visibility} is not implemented yet.")


def _render_collaborators_tab(selected_count):
    """Render the "Add Collaborators" tab (placeholder: no collaborator is added)."""
    st.subheader("Add Collaborators")
    collaborators = st.text_area(
        "Enter usernames of collaborators (one per line)",
        help="These users will be added as collaborators to all selected models",
    )
    role = st.selectbox("Collaborator role", ["read", "write", "admin"], index=0)

    if st.button("Add Collaborators", use_container_width=True, type="primary"):
        if not collaborators.strip():
            st.warning("Please enter at least one collaborator username.")
            return
        collaborator_list = [c.strip() for c in collaborators.splitlines() if c.strip()]
        st.info(f"Adding {len(collaborator_list)} collaborators with '{role}' role to {selected_count} models.")
        st.warning("This feature requires Hugging Face Pro or Enterprise subscription.")
        # Do not claim success for an operation that was never performed.
        st.info("No collaborators were added: this feature is not implemented yet.")


def _render_delete_tab(selected_ids):
    """Render the "Delete" tab; deletes the selected repositories after typed confirmation."""
    st.subheader("⚠️ Delete Models")
    st.warning(
        "This operation is irreversible. All selected models will be permanently deleted."
    )
    confirmation = st.text_input(
        "Type 'DELETE' to confirm deletion of all selected models",
        key="batch_delete_confirm",
    )

    if not st.button("Delete Selected Models", use_container_width=True, type="primary"):
        return
    if confirmation != "DELETE":
        st.error("Please type 'DELETE' to confirm.")
        return

    client = st.session_state.client

    def delete_one(repo_id):
        success, message = client.delete_model_repository(repo_id)
        # A failure must never be counted as a success just because no message came back.
        return None if success else (message or "Unknown error")

    with st.spinner(f"Deleting {len(selected_ids)} models..."):
        successes, failures = _run_batch(selected_ids, delete_one)

    _report_results(successes, failures, f"Successfully deleted {successes} models", "delete")

    # Refresh models after batch operation
    st.session_state.models = client.get_user_models()