# Streamlit Space: automated Hugging Face model-card (README.md) generator.
import json
import re
import time

import streamlit as st
def _fetch_existing_card(model_info):
    """Fetch the model's current README.md and split out its YAML frontmatter.

    Args:
        model_info: Hub model metadata object; must expose ``.modelId``.

    Returns:
        Tuple ``(yaml_content, markdown_content)``. Both are ``""`` when the
        card cannot be fetched or has no ``--- ... ---`` frontmatter. Fetch
        failures are surfaced as a Streamlit warning, never raised.
    """
    yaml_content = ""
    markdown_content = ""
    try:
        repo_id = model_info.modelId
        model_card_url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
        # NOTE(review): _get_paginated is a private client helper being used as
        # a plain HTTP GET -- confirm it returns a requests-style response with
        # .status_code and .text; a public GET helper would be safer.
        response = st.session_state.client.api._get_paginated(model_card_url)
        if response.status_code == 200:
            model_card_content = response.text
            # YAML frontmatter is the first "--- ... ---" fenced section.
            yaml_match = re.search(r"---\s+(.*?)\s+---", model_card_content, re.DOTALL)
            if yaml_match:
                yaml_content = yaml_match.group(1)
            # Markdown body is everything after the frontmatter fence.
            markdown_match = re.search(r"---\s+.*?\s+---\s*(.*)", model_card_content, re.DOTALL)
            if markdown_match:
                markdown_content = markdown_match.group(1).strip()
    except Exception as e:
        st.warning(f"Couldn't load model card: {str(e)}")
    return yaml_content, markdown_content


def _existing_tags(yaml_content):
    """Extract the ``tags:`` list from YAML frontmatter text; ``[]`` when absent."""
    if not yaml_content:
        return []
    tags_match = re.search(r"tags:\s*((?:- .*?\n)+)", yaml_content, re.DOTALL)
    if not tags_match:
        return []
    return [
        line.strip("- \n")
        for line in tags_match.group(1).split("\n")
        if line.strip().startswith("-")
    ]


def _compose_model_card(*, model_title, description, architecture, model_type,
                        framework, model_size, language, training_data,
                        training_compute, training_time, eval_data, metrics_data,
                        limitations, use_cases, code_example, license_type,
                        citation, selected_tags, keep_existing_content,
                        markdown_content, additional_sections):
    """Assemble the final model card (YAML frontmatter + markdown body).

    All arguments are plain strings/lists/bools collected from the form.
    Returns the complete card text, ready to be written as README.md.
    """
    # One metric per non-blank line, e.g. "Accuracy: 0.92".
    metrics_list = [line.strip() for line in metrics_data.split("\n") if line.strip()]

    # --- YAML frontmatter -------------------------------------------------
    yaml_frontmatter = "tags:\n" + "\n".join("- " + tag for tag in selected_tags)
    yaml_frontmatter += f"\nlicense: {license_type}"
    if language and language != "Other":
        yaml_frontmatter += f"\nlanguage: {language.lower()}"
    if model_type and model_type != "Other":
        # Hub pipeline tags are kebab-case, e.g. "text-classification".
        yaml_frontmatter += f"\npipeline_tag: {model_type.lower().replace(' ', '-')}"

    # --- Markdown body ----------------------------------------------------
    md_content = (
        f"# {model_title}\n\n{description}\n\n## Model Description\n\n"
        f"This model is a {architecture}-based model for {model_type} tasks. "
        f"It was developed using {framework} and consists of "
        f"{model_size if model_size else 'multiple'} parameters.\n\n"
    )

    # Training section only when at least one training/eval detail was given.
    if training_data or eval_data or training_compute or training_time:
        md_content += "## Training and Evaluation Data\n\n"
        if training_data:
            md_content += f"The model was trained on {training_data}. "
        if training_compute:
            md_content += f"Training was performed using {training_compute}. "
        if training_time:
            md_content += f"The total training time was approximately {training_time}."
        md_content += "\n\n"
        if eval_data:
            md_content += f"Evaluation was performed on {eval_data}.\n\n"

    if metrics_list:
        md_content += "## Model Performance\n\n"
        md_content += "The model achieves the following performance metrics:\n\n"
        for metric in metrics_list:
            md_content += f"- {metric}\n"
        md_content += "\n"

    if limitations:
        md_content += "## Limitations and Biases\n\n"
        md_content += f"{limitations}\n\n"

    if use_cases:
        md_content += "## Intended Uses & Limitations\n\n"
        md_content += f"{use_cases}\n\n"

    if code_example:
        md_content += "## How to Use\n\n"
        md_content += "Here's an example of how to use this model:\n\n"
        md_content += f"{code_example}\n\n"

    if citation:
        md_content += "## Citation\n\n"
        md_content += f"{citation}\n\n"

    # Preserve any non-standard "## ..." sections from the previous card.
    if keep_existing_content and markdown_content:
        existing_sections = re.findall(
            r"^## (.+?)\n\n(.*?)(?=^## |\Z)", markdown_content, re.MULTILINE | re.DOTALL
        )
        standard_sections = ["Model Description", "Training and Evaluation Data", "Model Performance",
                             "Limitations and Biases", "Intended Uses & Limitations", "How to Use", "Citation"]
        for section_title, section_content in existing_sections:
            if section_title.strip() not in standard_sections:
                md_content += f"## {section_title}\n\n{section_content}\n\n"

    # Free-form extra sections are appended verbatim.
    if additional_sections:
        md_content += f"\n{additional_sections}\n"

    return f"---\n{yaml_frontmatter}\n---\n\n{md_content.strip()}"


def model_documentation_generator(model_info):
    """Render a Streamlit form that generates (and optionally publishes) a model card.

    Pre-fills fields from the model's existing README frontmatter where
    possible, assembles a YAML + markdown card on submit, and offers a button
    to push the result back to the Hub via
    ``st.session_state.client.update_model_card``.

    Args:
        model_info: Hub model metadata object; must expose ``.modelId`` and may
            expose ``.tags`` / ``.description``.
    """
    if not model_info:
        st.error("Model information not found")
        return
    # NOTE(review): the "π" below looks like a mis-encoded emoji from the
    # original source -- confirm the intended character.
    st.subheader("π Automated Model Documentation Generator")
    st.markdown("This tool generates a comprehensive model card based on model metadata and your input.")

    # Existing card content, used to pre-select tags and preserve custom sections.
    yaml_content, markdown_content = _fetch_existing_card(model_info)

    with st.form("model_doc_form"):
        st.markdown("### Model Metadata")

        # Basic Information
        st.markdown("#### Basic Information")
        col1, col2 = st.columns(2)
        with col1:
            # Repo IDs look like "org/name"; the last segment is the model name.
            model_name = model_info.modelId.split("/")[-1]
            model_title = st.text_input("Model Title", value=model_name.replace("-", " ").title())
        with col2:
            model_type_options = [
                "Text Classification",
                "Token Classification",
                "Question Answering",
                "Summarization",
                "Translation",
                "Text Generation",
                "Image Classification",
                "Object Detection",
                "Other"
            ]
            # Pre-select the model type when a matching repo tag exists
            # (tags may be kebab-case or snake_case, e.g. "text-classification").
            default_type_index = 0
            tags = getattr(model_info, "tags", [])
            for i, option in enumerate(model_type_options):
                option_key = option.lower().replace(" ", "-")
                if option_key in tags or option_key.replace("-", "_") in tags:
                    default_type_index = i
                    break
            model_type = st.selectbox(
                "Model Type",
                model_type_options,
                index=default_type_index
            )

        description = st.text_area(
            "Model Description",
            value=getattr(model_info, "description", "") or "",
            height=100,
            help="A brief overview of what the model does"
        )

        # Technical Information
        st.markdown("#### Technical Information")
        col1, col2 = st.columns(2)
        with col1:
            architecture_options = [
                "BERT", "GPT-2", "T5", "RoBERTa", "DeBERTa", "DistilBERT",
                "BART", "ResNet", "YOLO", "Other"
            ]
            architecture = st.selectbox("Model Architecture", architecture_options)
            framework_options = ["PyTorch", "TensorFlow", "JAX", "Other"]
            framework = st.selectbox("Framework", framework_options)
        with col2:
            model_size = st.text_input("Model Size (e.g., 110M parameters)")
            language_options = ["English", "French", "German", "Spanish", "Chinese", "Japanese", "Multilingual", "Other"]
            language = st.selectbox("Language", language_options)

        # Training Information
        st.markdown("#### Training Information")
        col1, col2 = st.columns(2)
        with col1:
            training_data = st.text_input("Training Dataset(s)")
            training_compute = st.text_input("Training Infrastructure (e.g., TPU v3-8, 4x A100)")
        with col2:
            eval_data = st.text_input("Evaluation Dataset(s)")
            training_time = st.text_input("Training Time (e.g., 3 days, 12 hours)")

        # Performance Metrics
        st.markdown("#### Performance Metrics")
        metrics_data = st.text_area(
            "Performance Metrics (one per line, e.g., 'Accuracy: 0.92')",
            height=100,
            help="Key metrics and their values"
        )

        # Limitations
        st.markdown("#### Limitations and Biases")
        limitations = st.text_area(
            "Known Limitations and Biases",
            height=100,
            help="Document any known limitations, biases, or ethical considerations"
        )

        # Usage Information
        st.markdown("#### Usage Information")
        use_cases = st.text_area(
            "Intended Use Cases",
            height=100,
            help="Describe how the model should be used"
        )
        code_example = st.text_area(
            "Code Example",
            height=150,
            value=f"""
```python
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("{model_info.modelId}")
model = AutoModel.from_pretrained("{model_info.modelId}")
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)
```
""",
            help="Provide a simple code example showing how to use the model"
        )

        # License and Citation
        st.markdown("#### License and Citation")
        license_options = ["MIT", "Apache-2.0", "GPL-3.0", "CC-BY-SA-4.0", "CC-BY-4.0", "Proprietary", "Other"]
        license_type = st.selectbox("License", license_options)
        citation = st.text_area(
            "Citation Information",
            height=100,
            help="Provide citation information if applicable"
        )

        # Tags
        st.markdown("#### Tags")
        available_tags = st.session_state.client.get_model_tags()
        selected_tags = st.multiselect(
            "Select tags for your model",
            options=available_tags,
            default=_existing_tags(yaml_content),
            help="Tags help others discover your model"
        )

        # Advanced options
        with st.expander("Advanced Options"):
            keep_existing_content = st.checkbox(
                "Keep existing custom content",
                value=True,
                help="If checked, we'll try to preserve custom sections from your existing model card"
            )
            additional_sections = st.text_area(
                "Additional Custom Sections (in Markdown)",
                height=200,
                help="Add any additional custom sections in Markdown format"
            )

        submitted = st.form_submit_button("Generate Model Card", use_container_width=True)

    if submitted:
        with st.spinner("Generating comprehensive model card..."):
            try:
                final_model_card = _compose_model_card(
                    model_title=model_title, description=description,
                    architecture=architecture, model_type=model_type,
                    framework=framework, model_size=model_size,
                    language=language, training_data=training_data,
                    training_compute=training_compute,
                    training_time=training_time, eval_data=eval_data,
                    metrics_data=metrics_data, limitations=limitations,
                    use_cases=use_cases, code_example=code_example,
                    license_type=license_type, citation=citation,
                    selected_tags=selected_tags,
                    keep_existing_content=keep_existing_content,
                    markdown_content=markdown_content,
                    additional_sections=additional_sections,
                )
                # Persist across reruns so the "Update Model Card" button below
                # still has the card after its own click triggers a rerun.
                st.session_state["generated_model_card"] = final_model_card
            except Exception as e:
                st.error(f"Error generating model card: {str(e)}")

    # BUG FIX: in the original, this display + button lived inside the
    # `if submitted:` branch (within the form flow). `st.button` is not allowed
    # inside a form, and clicking it reruns the script with `submitted` False,
    # so the update code could never execute. Rendering from session state
    # outside the form makes the button functional.
    final_model_card = st.session_state.get("generated_model_card")
    if final_model_card:
        st.markdown("### Generated Model Card")
        st.code(final_model_card, language="markdown")
        if st.button("Update Model Card", use_container_width=True, type="primary"):
            with st.spinner("Updating model card..."):
                try:
                    success, _ = st.session_state.client.update_model_card(
                        model_info.modelId, final_model_card
                    )
                    if success:
                        st.success("Model card updated successfully!")
                        time.sleep(1)  # Give API time to update
                        st.rerun()
                    else:
                        st.error("Failed to update model card")
                except Exception as e:
                    st.error(f"Error updating model card: {str(e)}")