# ModelHubManager/components/documentation_generator.py
# Hugging Face file-viewer header: uploaded by S-Dreamer ("Upload 31 files"), commit 74dd3f1 (verified).
import streamlit as st
import re
import json
import time
# Front matter only counts when the README *starts* with a `---` line; this
# keeps Markdown horizontal rules further down from being read as YAML.
_FRONT_MATTER_RE = re.compile(
    r"\A---[ \t\r]*\n(.*?)^---[ \t\r]*$\n?(.*)", re.DOTALL | re.MULTILINE
)

# Section titles this generator writes itself; matching sections found in an
# existing card are regenerated rather than carried over.
_STANDARD_SECTIONS = frozenset(
    {
        "Model Description",
        "Training and Evaluation Data",
        "Model Performance",
        "Limitations and Biases",
        "Intended Uses & Limitations",
        "How to Use",
        "Citation",
    }
)

# Form label -> licence identifier written to the card metadata. Labels not
# listed here ("Proprietary", "Other") are written as "other".
_LICENSE_IDS = {
    "MIT": "mit",
    "Apache-2.0": "apache-2.0",
    "GPL-3.0": "gpl-3.0",
    "CC-BY-SA-4.0": "cc-by-sa-4.0",
    "CC-BY-4.0": "cc-by-4.0",
}

# Form label -> language code written to the card metadata.
_LANGUAGE_CODES = {
    "English": "en",
    "French": "fr",
    "German": "de",
    "Spanish": "es",
    "Chinese": "zh",
    "Japanese": "ja",
    "Multilingual": "multilingual",
}


def _split_front_matter(card_text):
    """Split README text into ``(yaml_front_matter, markdown_body)``.

    When the text does not open with a ``---`` front-matter block, the YAML
    part is empty and the whole text is returned as the Markdown body.
    """
    found = _FRONT_MATTER_RE.match(card_text)
    if found is None:
        return "", card_text.strip()
    return found.group(1).strip(), found.group(2).strip()


def _parse_yaml_tags(front_matter):
    """Return the values of the top-level ``tags:`` key in *front_matter*.

    Handles block lists (``- item`` lines, indented or not) and inline
    ``[a, b]`` lists. Returns an empty list when there is no ``tags:`` key.
    """
    lines = front_matter.splitlines()
    for index, line in enumerate(lines):
        key = re.match(r"tags:\s*(.*)$", line)
        if key is None:
            continue
        inline = key.group(1).strip()
        if inline.startswith("[") and inline.endswith("]"):
            raw_items = inline[1:-1].split(",")
        else:
            raw_items = []
            for follower in lines[index + 1:]:
                item = re.match(r"\s*-\s+(.*)$", follower)
                if item is None:
                    # First line that is not a list item ends the tag list.
                    break
                raw_items.append(item.group(1))
        cleaned = (raw.strip().strip("'\"") for raw in raw_items)
        return [tag for tag in cleaned if tag]
    return []


def _split_sections(markdown):
    """Split Markdown into ``(title, body)`` pairs, one per ``## `` heading.

    Text before the first ``## `` heading is discarded. Lines inside fenced
    code blocks are never treated as headings.
    """
    sections = []
    title = None
    body = []
    in_fence = False
    for line in markdown.splitlines():
        if line.lstrip().startswith("```"):
            in_fence = not in_fence
        if not in_fence and line.startswith("## "):
            if title is not None:
                sections.append((title, "\n".join(body).strip()))
            title = line[3:].strip()
            body = []
        elif title is not None:
            body.append(line)
    if title is not None:
        sections.append((title, "\n".join(body).strip()))
    return sections


def _build_model_card(
    *,
    model_title,
    description,
    model_type,
    architecture,
    framework,
    model_size="",
    language="",
    license_type="Other",
    tags=(),
    training_data="",
    training_compute="",
    training_time="",
    eval_data="",
    metrics="",
    limitations="",
    use_cases="",
    code_example="",
    citation="",
    existing_markdown="",
    additional_sections="",
):
    """Assemble the full model card (YAML front matter + Markdown) as a string.

    All arguments are plain strings except *tags*, an iterable of strings.
    *metrics* holds one metric per line. Sections of *existing_markdown* whose
    titles are not produced by this generator are appended unchanged.
    """
    front_matter = []
    if tags:
        # Skip the key entirely when nothing is selected; an empty "tags:"
        # would otherwise be written as a null value.
        front_matter.append("tags:")
        front_matter.extend(f"- {tag}" for tag in tags)
    front_matter.append(f"license: {_LICENSE_IDS.get(license_type, 'other')}")
    if language and language != "Other":
        front_matter.append(
            f"language: {_LANGUAGE_CODES.get(language, language.lower())}"
        )
    if model_type and model_type != "Other":
        front_matter.append(
            f"pipeline_tag: {model_type.lower().replace(' ', '-')}"
        )

    # The form's own example is "110M parameters"; drop a trailing
    # "parameters"/"params" so the sentence does not repeat the word.
    size = re.sub(
        r"\s*param(?:eter)?s?\s*$", "", model_size.strip(), flags=re.IGNORECASE
    )
    parts = [
        f"# {model_title}",
        description.strip(),
        "## Model Description",
        f"This model is a {architecture}-based model for {model_type} tasks. "
        f"It was developed using {framework} and consists of "
        f"{size or 'multiple'} parameters.",
    ]

    training_sentences = []
    if training_data:
        training_sentences.append(f"The model was trained on {training_data}.")
    if training_compute:
        training_sentences.append(
            f"Training was performed using {training_compute}."
        )
    if training_time:
        training_sentences.append(
            f"The total training time was approximately {training_time}."
        )
    if training_sentences or eval_data:
        parts.append("## Training and Evaluation Data")
        if training_sentences:
            parts.append(" ".join(training_sentences))
        if eval_data:
            parts.append(f"Evaluation was performed on {eval_data}.")

    metric_lines = [line.strip() for line in metrics.splitlines() if line.strip()]
    if metric_lines:
        parts.append("## Model Performance")
        parts.append("The model achieves the following performance metrics:")
        parts.append("\n".join(f"- {metric}" for metric in metric_lines))

    if limitations.strip():
        parts.extend(["## Limitations and Biases", limitations.strip()])
    if use_cases.strip():
        parts.extend(["## Intended Uses & Limitations", use_cases.strip()])
    if code_example.strip():
        parts.extend(
            [
                "## How to Use",
                "Here's an example of how to use this model:",
                code_example.strip(),
            ]
        )
    if citation.strip():
        parts.extend(["## Citation", citation.strip()])

    for section_title, section_body in _split_sections(existing_markdown):
        if section_title not in _STANDARD_SECTIONS:
            parts.extend([f"## {section_title}", section_body])

    parts.append(additional_sections.strip())

    markdown_body = "\n\n".join(part for part in parts if part)
    return "---\n" + "\n".join(front_matter) + "\n---\n\n" + markdown_body


def model_documentation_generator(model_info):
    """Render the Streamlit UI that generates and optionally uploads a model card.

    The existing README is fetched on a best-effort basis, a form collects
    metadata, and on submit a model card is assembled and shown. The generated
    card is kept in ``st.session_state`` so the "Update Model Card" button is
    still rendered on the rerun triggered by clicking it.

    Args:
        model_info: Object exposing ``modelId`` and, optionally, ``tags`` and
            ``description``. A falsy value shows an error and returns.

    Returns:
        None. All output goes through Streamlit calls.
    """
    if not model_info:
        st.error("Model information not found")
        return

    st.subheader("🔄 Automated Model Documentation Generator")
    st.markdown("This tool generates a comprehensive model card based on model metadata and your input.")

    # Load the current model card if there is one (best effort).
    yaml_content = ""
    markdown_content = ""
    try:
        repo_id = model_info.modelId
        model_card_url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
        # NOTE(review): `_get_paginated` is a private helper of the project's
        # client wrapper; confirm it returns a response object exposing
        # `status_code` and `text` for a raw-file URL.
        response = st.session_state.client.api._get_paginated(model_card_url)
        if response.status_code == 200:
            yaml_content, markdown_content = _split_front_matter(response.text)
    except Exception as e:
        st.warning(f"Couldn't load model card: {str(e)}")

    # Form for model metadata input
    with st.form("model_doc_form"):
        st.markdown("### Model Metadata")

        # Basic Information
        st.markdown("#### Basic Information")
        col1, col2 = st.columns(2)
        with col1:
            # Derive a readable default title from the repo name.
            model_name = model_info.modelId.split("/")[-1]
            model_title = st.text_input("Model Title", value=model_name.replace("-", " ").title())
        with col2:
            model_type_options = [
                "Text Classification",
                "Token Classification",
                "Question Answering",
                "Summarization",
                "Translation",
                "Text Generation",
                "Image Classification",
                "Object Detection",
                "Other",
            ]
            # Preselect the first model type whose slug appears in the tags.
            # `tags` may be missing or None, so normalise to a list first.
            known_tags = getattr(model_info, "tags", None) or []
            default_type_index = 0
            for i, option in enumerate(model_type_options):
                option_key = option.lower().replace(" ", "-")
                if option_key in known_tags or option_key.replace("-", "_") in known_tags:
                    default_type_index = i
                    break
            model_type = st.selectbox(
                "Model Type",
                model_type_options,
                index=default_type_index,
            )

        description = st.text_area(
            "Model Description",
            value=getattr(model_info, "description", "") or "",
            height=100,
            help="A brief overview of what the model does",
        )

        # Technical Information
        st.markdown("#### Technical Information")
        col1, col2 = st.columns(2)
        with col1:
            architecture_options = [
                "BERT", "GPT-2", "T5", "RoBERTa", "DeBERTa", "DistilBERT",
                "BART", "ResNet", "YOLO", "Other",
            ]
            architecture = st.selectbox("Model Architecture", architecture_options)
            framework_options = ["PyTorch", "TensorFlow", "JAX", "Other"]
            framework = st.selectbox("Framework", framework_options)
        with col2:
            model_size = st.text_input("Model Size (e.g., 110M parameters)")
            language_options = ["English", "French", "German", "Spanish", "Chinese", "Japanese", "Multilingual", "Other"]
            language = st.selectbox("Language", language_options)

        # Training Information
        st.markdown("#### Training Information")
        col1, col2 = st.columns(2)
        with col1:
            training_data = st.text_input("Training Dataset(s)")
            training_compute = st.text_input("Training Infrastructure (e.g., TPU v3-8, 4x A100)")
        with col2:
            eval_data = st.text_input("Evaluation Dataset(s)")
            training_time = st.text_input("Training Time (e.g., 3 days, 12 hours)")

        # Performance Metrics
        st.markdown("#### Performance Metrics")
        metrics_data = st.text_area(
            "Performance Metrics (one per line, e.g., 'Accuracy: 0.92')",
            height=100,
            help="Key metrics and their values",
        )

        # Limitations
        st.markdown("#### Limitations and Biases")
        limitations = st.text_area(
            "Known Limitations and Biases",
            height=100,
            help="Document any known limitations, biases, or ethical considerations",
        )

        # Usage Information
        st.markdown("#### Usage Information")
        use_cases = st.text_area(
            "Intended Use Cases",
            height=100,
            help="Describe how the model should be used",
        )
        code_example = st.text_area(
            "Code Example",
            height=150,
            value=f"""
```python
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("{model_info.modelId}")
model = AutoModel.from_pretrained("{model_info.modelId}")
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)
```
""",
            help="Provide a simple code example showing how to use the model",
        )

        # License and Citation
        st.markdown("#### License and Citation")
        license_options = ["MIT", "Apache-2.0", "GPL-3.0", "CC-BY-SA-4.0", "CC-BY-4.0", "Proprietary", "Other"]
        license_type = st.selectbox("License", license_options)
        citation = st.text_area(
            "Citation Information",
            height=100,
            help="Provide citation information if applicable",
        )

        # Tags
        st.markdown("#### Tags")
        # presumably an iterable of tag strings — verify against the client
        available_tags = list(st.session_state.client.get_model_tags() or [])
        existing_tags = list(dict.fromkeys(_parse_yaml_tags(yaml_content)))
        # st.multiselect rejects defaults that are not among the options, so
        # tags already on the card are added to the choices.
        tag_options = available_tags + [
            tag for tag in existing_tags if tag not in available_tags
        ]
        selected_tags = st.multiselect(
            "Select tags for your model",
            options=tag_options,
            default=existing_tags,
            help="Tags help others discover your model",
        )

        # Advanced options
        with st.expander("Advanced Options"):
            keep_existing_content = st.checkbox(
                "Keep existing custom content",
                value=True,
                help="If checked, we'll try to preserve custom sections from your existing model card",
            )
            additional_sections = st.text_area(
                "Additional Custom Sections (in Markdown)",
                height=200,
                help="Add any additional custom sections in Markdown format",
            )

        submitted = st.form_submit_button("Generate Model Card", use_container_width=True)

    # `submitted` is True only on the run right after the form is submitted.
    # Clicking "Update Model Card" starts a new run where it is False, so the
    # generated card is stored in session state rather than in a local.
    card_state_key = f"generated_model_card_{model_info.modelId}"

    if submitted:
        with st.spinner("Generating comprehensive model card..."):
            try:
                st.session_state[card_state_key] = _build_model_card(
                    model_title=model_title,
                    description=description,
                    model_type=model_type,
                    architecture=architecture,
                    framework=framework,
                    model_size=model_size,
                    language=language,
                    license_type=license_type,
                    tags=selected_tags,
                    training_data=training_data,
                    training_compute=training_compute,
                    training_time=training_time,
                    eval_data=eval_data,
                    metrics=metrics_data,
                    limitations=limitations,
                    use_cases=use_cases,
                    code_example=code_example,
                    citation=citation,
                    existing_markdown=markdown_content if keep_existing_content else "",
                    additional_sections=additional_sections,
                )
            except Exception as e:
                st.error(f"Error generating model card: {str(e)}")

    final_model_card = st.session_state.get(card_state_key)
    if not final_model_card:
        return

    # Display the generated model card
    st.markdown("### Generated Model Card")
    st.code(final_model_card, language="markdown")

    # Option to update the model card
    if st.button("Update Model Card", use_container_width=True, type="primary"):
        with st.spinner("Updating model card..."):
            try:
                success, _ = st.session_state.client.update_model_card(
                    model_info.modelId, final_model_card
                )
            except Exception as e:
                st.error(f"Error updating model card: {str(e)}")
                return
        if success:
            # Drop the stored draft so it is not shown again after the rerun.
            st.session_state.pop(card_state_key, None)
            st.success("Model card updated successfully!")
            time.sleep(1)  # Give API time to update
            st.rerun()
        else:
            st.error("Failed to update model card")