import streamlit as st
import re
import json
import time


def _load_existing_card(model_info):
    """Fetch the model's existing README.md and split it into its parts.

    Returns a ``(yaml_content, markdown_content)`` tuple where
    *yaml_content* is the YAML frontmatter (without the ``---`` fences) and
    *markdown_content* is everything after it. Both are ``""`` when the card
    is missing or cannot be fetched; fetch failures surface as a warning
    rather than aborting the page.
    """
    yaml_content = ""
    markdown_content = ""
    try:
        repo_id = model_info.modelId
        model_card_url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
        # NOTE(review): this performs a plain GET through a private,
        # pagination-oriented client method — confirm the client exposes a
        # public raw-file download API and switch to it.
        response = st.session_state.client.api._get_paginated(model_card_url)
        if response.status_code == 200:
            model_card_content = response.text
            # YAML frontmatter sits between the first pair of --- fences.
            yaml_match = re.search(r"---\s+(.*?)\s+---", model_card_content, re.DOTALL)
            if yaml_match:
                yaml_content = yaml_match.group(1)
            # Markdown body is everything after the closing fence.
            markdown_match = re.search(r"---\s+.*?\s+---\s*(.*)", model_card_content, re.DOTALL)
            if markdown_match:
                markdown_content = markdown_match.group(1).strip()
    except Exception as e:
        st.warning(f"Couldn't load model card: {str(e)}")
    return yaml_content, markdown_content


def _extract_existing_tags(yaml_content):
    """Return the tag names listed under ``tags:`` in YAML frontmatter.

    Returns an empty list when there is no frontmatter or no tag list.
    """
    if not yaml_content:
        return []
    tags_match = re.search(r"tags:\s*((?:- .*?\n)+)", yaml_content, re.DOTALL)
    if not tags_match:
        return []
    return [
        line.strip("- \n")
        for line in tags_match.group(1).split("\n")
        if line.strip().startswith("-")
    ]


def _build_yaml_frontmatter(selected_tags, license_type, language, model_type):
    """Assemble the YAML frontmatter block (without the ``---`` fences)."""
    lines = ["tags:"]
    lines.extend(f"- {tag}" for tag in selected_tags)
    lines.append(f"license: {license_type}")
    if language and language != "Other":
        lines.append(f"language: {language.lower()}")
    if model_type and model_type != "Other":
        lines.append(f"pipeline_tag: {model_type.lower().replace(' ', '-')}")
    return "\n".join(lines)


def _build_markdown_body(fields, existing_markdown, keep_existing_content):
    """Build the markdown body of the model card from the form *fields* dict.

    *fields* carries the raw form values (title, description, training info,
    metrics list, etc.). When *keep_existing_content* is true, any ``##``
    section found in *existing_markdown* whose title is not one of the
    standard generated sections is appended verbatim so custom content
    survives regeneration.
    """
    model_size = fields["model_size"]
    md_content = f"""# {fields["model_title"]}

{fields["description"]}

## Model Description

This model is a {fields["architecture"]}-based model for {fields["model_type"]} tasks. It was developed using {fields["framework"]} and consists of {model_size if model_size else "multiple"} parameters.

"""

    # Training section — only emitted when at least one field was filled in.
    if fields["training_data"] or fields["eval_data"] or fields["training_compute"] or fields["training_time"]:
        md_content += "## Training and Evaluation Data\n\n"
        if fields["training_data"]:
            md_content += f"The model was trained on {fields['training_data']}. "
        if fields["training_compute"]:
            md_content += f"Training was performed using {fields['training_compute']}. "
        if fields["training_time"]:
            md_content += f"The total training time was approximately {fields['training_time']}."
        md_content += "\n\n"
        if fields["eval_data"]:
            md_content += f"Evaluation was performed on {fields['eval_data']}.\n\n"

    if fields["metrics_list"]:
        md_content += "## Model Performance\n\n"
        md_content += "The model achieves the following performance metrics:\n\n"
        for metric in fields["metrics_list"]:
            md_content += f"- {metric}\n"
        md_content += "\n"

    if fields["limitations"]:
        md_content += "## Limitations and Biases\n\n"
        md_content += f"{fields['limitations']}\n\n"

    if fields["use_cases"]:
        md_content += "## Intended Uses & Limitations\n\n"
        md_content += f"{fields['use_cases']}\n\n"

    if fields["code_example"]:
        md_content += "## How to Use\n\n"
        md_content += "Here's an example of how to use this model:\n\n"
        md_content += f"{fields['code_example']}\n\n"

    if fields["citation"]:
        md_content += "## Citation\n\n"
        md_content += f"{fields['citation']}\n\n"

    # Preserve custom sections from the previous card that we did not regenerate.
    if keep_existing_content and existing_markdown:
        existing_sections = re.findall(
            r"^## (.+?)\n\n(.*?)(?=^## |\Z)",
            existing_markdown,
            re.MULTILINE | re.DOTALL,
        )
        standard_sections = [
            "Model Description", "Training and Evaluation Data",
            "Model Performance", "Limitations and Biases",
            "Intended Uses & Limitations", "How to Use", "Citation",
        ]
        for section_title, section_content in existing_sections:
            if section_title.strip() not in standard_sections:
                md_content += f"## {section_title}\n\n{section_content}\n\n"

    if fields["additional_sections"]:
        md_content += f"\n{fields['additional_sections']}\n"

    return md_content


def _render_generated_card(model_info):
    """Show the most recently generated card and offer to push it to the Hub.

    Reads the card from ``st.session_state`` so the preview and the
    "Update Model Card" button survive the rerun that clicking the button
    triggers. Rendering nothing when no card has been generated yet.
    """
    final_model_card = st.session_state.get("generated_model_card")
    if not final_model_card:
        return

    st.markdown("### Generated Model Card")
    st.code(final_model_card, language="markdown")

    if st.button("Update Model Card", use_container_width=True, type="primary"):
        with st.spinner("Updating model card..."):
            try:
                success, _ = st.session_state.client.update_model_card(
                    model_info.modelId,
                    final_model_card
                )
                if success:
                    st.success("Model card updated successfully!")
                    time.sleep(1)  # Give API time to update
                    st.rerun()
                else:
                    st.error("Failed to update model card")
            except Exception as e:
                st.error(f"Error updating model card: {str(e)}")


def model_documentation_generator(model_info):
    """Render an automated model-card generator for *model_info*.

    Shows a form collecting model metadata (type, architecture, training
    details, metrics, license, tags, ...), generates a README-style model
    card with YAML frontmatter on submit, and offers a button to push the
    result back to the Hub via ``st.session_state.client``.

    Parameters:
        model_info: Hugging Face model-info object; must expose ``modelId``
            and may expose ``tags`` and ``description``.
    """
    if not model_info:
        st.error("Model information not found")
        return

    st.subheader("🔄 Automated Model Documentation Generator")
    st.markdown("This tool generates a comprehensive model card based on model metadata and your input.")

    # Pull in any existing card so we can pre-fill tags and keep custom sections.
    yaml_content, markdown_content = _load_existing_card(model_info)

    with st.form("model_doc_form"):
        st.markdown("### Model Metadata")

        # Basic Information
        st.markdown("#### Basic Information")
        col1, col2 = st.columns(2)
        with col1:
            # Derive a human-readable title from the repo name.
            model_name = model_info.modelId.split("/")[-1]
            model_title = st.text_input("Model Title", value=model_name.replace("-", " ").title())
        with col2:
            model_type_options = [
                "Text Classification", "Token Classification", "Question Answering",
                "Summarization", "Translation", "Text Generation",
                "Image Classification", "Object Detection", "Other"
            ]
            # Pre-select a type when one of the repo tags matches a pipeline tag
            # (either hyphenated or underscored form).
            default_type_index = 0
            tags = getattr(model_info, "tags", [])
            for i, option in enumerate(model_type_options):
                option_key = option.lower().replace(" ", "-")
                if option_key in tags or option_key.replace("-", "_") in tags:
                    default_type_index = i
                    break
            model_type = st.selectbox(
                "Model Type",
                model_type_options,
                index=default_type_index
            )

        description = st.text_area(
            "Model Description",
            value=getattr(model_info, "description", "") or "",
            height=100,
            help="A brief overview of what the model does"
        )

        # Technical Information
        st.markdown("#### Technical Information")
        col1, col2 = st.columns(2)
        with col1:
            architecture_options = [
                "BERT", "GPT-2", "T5", "RoBERTa", "DeBERTa", "DistilBERT",
                "BART", "ResNet", "YOLO", "Other"
            ]
            architecture = st.selectbox("Model Architecture", architecture_options)

            framework_options = ["PyTorch", "TensorFlow", "JAX", "Other"]
            framework = st.selectbox("Framework", framework_options)
        with col2:
            model_size = st.text_input("Model Size (e.g., 110M parameters)")

            language_options = ["English", "French", "German", "Spanish", "Chinese",
                                "Japanese", "Multilingual", "Other"]
            language = st.selectbox("Language", language_options)

        # Training Information
        st.markdown("#### Training Information")
        col1, col2 = st.columns(2)
        with col1:
            training_data = st.text_input("Training Dataset(s)")
            training_compute = st.text_input("Training Infrastructure (e.g., TPU v3-8, 4x A100)")
        with col2:
            eval_data = st.text_input("Evaluation Dataset(s)")
            training_time = st.text_input("Training Time (e.g., 3 days, 12 hours)")

        # Performance Metrics
        st.markdown("#### Performance Metrics")
        metrics_data = st.text_area(
            "Performance Metrics (one per line, e.g., 'Accuracy: 0.92')",
            height=100,
            help="Key metrics and their values"
        )

        # Limitations
        st.markdown("#### Limitations and Biases")
        limitations = st.text_area(
            "Known Limitations and Biases",
            height=100,
            help="Document any known limitations, biases, or ethical considerations"
        )

        # Usage Information
        st.markdown("#### Usage Information")
        use_cases = st.text_area(
            "Intended Use Cases",
            height=100,
            help="Describe how the model should be used"
        )

        code_example = st.text_area(
            "Code Example",
            height=150,
            value=f"""
```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("{model_info.modelId}")
model = AutoModel.from_pretrained("{model_info.modelId}")

inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)
```
""",
            help="Provide a simple code example showing how to use the model"
        )

        # License and Citation
        st.markdown("#### License and Citation")
        license_options = ["MIT", "Apache-2.0", "GPL-3.0", "CC-BY-SA-4.0",
                           "CC-BY-4.0", "Proprietary", "Other"]
        license_type = st.selectbox("License", license_options)

        citation = st.text_area(
            "Citation Information",
            height=100,
            help="Provide citation information if applicable"
        )

        # Tags
        st.markdown("#### Tags")
        available_tags = st.session_state.client.get_model_tags()
        existing_tags = _extract_existing_tags(yaml_content)
        selected_tags = st.multiselect(
            "Select tags for your model",
            options=available_tags,
            default=existing_tags,
            help="Tags help others discover your model"
        )

        # Advanced options
        with st.expander("Advanced Options"):
            keep_existing_content = st.checkbox(
                "Keep existing custom content",
                value=True,
                help="If checked, we'll try to preserve custom sections from your existing model card"
            )
            additional_sections = st.text_area(
                "Additional Custom Sections (in Markdown)",
                height=200,
                help="Add any additional custom sections in Markdown format"
            )

        submitted = st.form_submit_button("Generate Model Card", use_container_width=True)

    if submitted:
        with st.spinner("Generating comprehensive model card..."):
            try:
                fields = {
                    "model_title": model_title,
                    "description": description,
                    "architecture": architecture,
                    "model_type": model_type,
                    "framework": framework,
                    "model_size": model_size,
                    "training_data": training_data,
                    "training_compute": training_compute,
                    "training_time": training_time,
                    "eval_data": eval_data,
                    "metrics_list": [line.strip() for line in metrics_data.split("\n") if line.strip()],
                    "limitations": limitations,
                    "use_cases": use_cases,
                    "code_example": code_example,
                    "citation": citation,
                    "additional_sections": additional_sections,
                }
                yaml_frontmatter = _build_yaml_frontmatter(
                    selected_tags, license_type, language, model_type
                )
                md_content = _build_markdown_body(
                    fields, markdown_content, keep_existing_content
                )
                final_model_card = f"---\n{yaml_frontmatter}\n---\n\n{md_content.strip()}"
                # BUG FIX: the original rendered st.button("Update Model Card")
                # inside this `if submitted:` branch. Clicking that button
                # reruns the script with submitted == False, so the update
                # branch could never execute. Persist the card in session
                # state and render the preview/button on every run instead.
                st.session_state["generated_model_card"] = final_model_card
            except Exception as e:
                st.error(f"Error generating model card: {str(e)}")

    _render_generated_card(model_info)