import html
import json
import textwrap
import time
from io import StringIO

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st

# Pipeline tags recognised when a model has no explicit ``pipeline_tag`` and the
# task has to be inferred from its ``tags`` list (the first matching tag wins).
_KNOWN_PIPELINE_TAGS = (
    "text-classification",
    "token-classification",
    "question-answering",
    "summarization",
    "translation",
    "text-generation",
    "fill-mask",
    "sentence-similarity",
    "image-classification",
    "object-detection",
    "image-segmentation",
    "text-to-image",
    "image-to-text",
)

# Pipelines that take a single free-text input. The last two are also
# recognised by tag detection and receive the generic sample response.
_TEXT_INPUT_TAGS = (
    "text-classification",
    "token-classification",
    "fill-mask",
    "text-generation",
    "summarization",
    "sentence-similarity",
    "text-to-image",
)

# Pipelines that take an image (uploaded file or URL).
_IMAGE_INPUT_TAGS = (
    "image-classification",
    "object-detection",
    "image-segmentation",
    "image-to-text",
)

# Pipelines that take an audio clip (uploaded file or URL).
_AUDIO_INPUT_TAGS = ("audio-classification", "automatic-speech-recognition")

_DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/"
    "transformers/model_doc/distilbert-base-uncased-finetuned-sst-2-english-architecture.png"
)

# Canned translations of the default "Hello, how are you?" prompt, keyed by
# target language; any other target echoes the English text.
_SAMPLE_TRANSLATIONS = {
    "French": "Bonjour, comment allez-vous?",
    "German": "Hallo, wie geht es dir?",
    "Spanish": "Hola, ¿cómo estás?",
    "Chinese": "你好,你好吗?",
}


def model_inference_dashboard(model_info):
    """Create a dashboard for testing model inference directly in the app.

    Renders input widgets suited to the model's pipeline, a "Run Inference"
    button that shows a *simulated* response (no real API call is made), and
    an expander with Python/JavaScript Inference API snippets.

    Args:
        model_info: Model metadata object. Read attributes are
            ``pipeline_tag``, ``tags`` and ``modelId`` (or ``id``). A falsy
            value shows an error and returns early.
    """
    if not model_info:
        st.error("Model information not found")
        return

    st.subheader("🧠 Model Inference Dashboard")

    pipeline_tag = _detect_pipeline_tag(model_info)
    st.info(
        "This dashboard allows you to test your model's inference capabilities. "
        f"Model pipeline: **{pipeline_tag}**"
    )

    input_data = _collect_input(pipeline_tag)

    if st.button("Run Inference", use_container_width=True):
        if input_data:
            with st.spinner("Running inference..."):
                # In a real implementation, this would call the HF Inference API.
                # For demo purposes, simulate latency and a canned response.
                time.sleep(2)
                result = _simulate_inference(pipeline_tag, input_data)

            st.markdown("### Inference Results")
            _render_results(pipeline_tag, input_data, result)

            # ensure_ascii=False keeps non-Latin output (e.g. Chinese) readable.
            st.download_button(
                label="Download Results",
                data=json.dumps(result, indent=2, ensure_ascii=False),
                file_name="inference_results.json",
                mime="application/json",
            )
        else:
            st.warning("Please provide input data for inference")

    with st.expander("API Integration"):
        _render_api_integration(model_info)


def _detect_pipeline_tag(model_info):
    """Return the model's pipeline tag, inferring it from ``tags`` if needed.

    Falls back to ``"text-classification"`` when neither ``pipeline_tag`` nor
    any recognised tag is available.
    """
    pipeline_tag = getattr(model_info, "pipeline_tag", None)
    if pipeline_tag:
        return pipeline_tag
    # ``tags`` may be absent or explicitly None; treat both as "no tags".
    for tag in getattr(model_info, "tags", None) or []:
        if tag in _KNOWN_PIPELINE_TAGS:
            return tag
    return "text-classification"  # Default fallback


def _escape_for_html(text):
    """Escape *text* for embedding in HTML rendered with ``unsafe_allow_html``.

    User-supplied text must not be able to inject markup. Newlines become
    ``<br>`` so multi-line text keeps its shape inside a single HTML block.
    """
    return html.escape(str(text)).replace("\n", "<br>")


def _collect_input(pipeline_tag):
    """Render the input widgets for *pipeline_tag* and return the payload.

    Returns a dict payload, or ``None`` when the pipeline has no input form or
    no image/audio source has been provided yet.
    """
    if pipeline_tag in _TEXT_INPUT_TAGS:
        return _collect_text_input(pipeline_tag)
    if pipeline_tag == "question-answering":
        return _collect_question_answering_input()
    if pipeline_tag == "translation":
        return _collect_translation_input()
    if pipeline_tag in _IMAGE_INPUT_TAGS:
        return _collect_image_input()
    if pipeline_tag in _AUDIO_INPUT_TAGS:
        return _collect_audio_input()
    return None


def _collect_text_input(pipeline_tag):
    """Text area plus generation/summarisation length controls where relevant."""
    st.markdown("### Text Input")
    input_text = st.text_area(
        "Enter text for inference", value="This model is amazing!", height=150
    )

    if pipeline_tag == "text-generation":
        col1, col2 = st.columns(2)
        with col1:
            max_length = st.slider("Max Length", min_value=10, max_value=500, value=100)
        with col2:
            temperature = st.slider(
                "Temperature", min_value=0.1, max_value=2.0, value=1.0, step=0.1
            )
        return {"text": input_text, "max_length": max_length, "temperature": temperature}

    if pipeline_tag == "summarization":
        max_length = st.slider("Max Summary Length", min_value=10, max_value=200, value=50)
        return {"text": input_text, "max_length": max_length}

    return {"text": input_text}


def _collect_question_answering_input():
    """Question field and context area for extractive QA."""
    st.markdown("### Question & Context")
    question = st.text_input("Question", value="What is this model about?")
    context = st.text_area(
        "Context",
        value=(
            "This model is a transformer-based language model designed for "
            "natural language understanding tasks."
        ),
        height=150,
    )
    return {"question": question, "context": context}


def _collect_translation_input():
    """Source/target language pickers and the text to translate."""
    st.markdown("### Translation")
    source_lang = st.selectbox(
        "Source Language", ["English", "French", "German", "Spanish", "Chinese"]
    )
    target_lang = st.selectbox(
        "Target Language", ["French", "English", "German", "Spanish", "Chinese"]
    )
    translation_text = st.text_area("Text to translate", value="Hello, how are you?", height=150)
    return {
        "text": translation_text,
        "source_language": source_lang,
        "target_language": target_lang,
    }


def _collect_image_input():
    """Image upload or URL; returns ``None`` until an image is provided."""
    st.markdown("### Image Input")
    upload_method = st.radio("Select input method", ["Upload Image", "Image URL"])
    if upload_method == "Upload Image":
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        if uploaded_file is None:
            return None
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        return {"image": uploaded_file}

    image_url = st.text_input("Image URL", value=_DEFAULT_IMAGE_URL)
    if not image_url:
        return None
    st.image(image_url, caption="Image from URL", use_column_width=True)
    return {"image_url": image_url}


def _collect_audio_input():
    """Audio upload or URL; returns ``None`` until a clip is provided."""
    st.markdown("### Audio Input")
    upload_method = st.radio("Select input method", ["Upload Audio", "Audio URL"])
    if upload_method == "Upload Audio":
        uploaded_file = st.file_uploader("Upload an audio file", type=["mp3", "wav", "ogg"])
        if uploaded_file is None:
            return None
        st.audio(uploaded_file)
        return {"audio": uploaded_file}

    audio_url = st.text_input("Audio URL")
    if not audio_url:
        return None
    st.audio(audio_url)
    return {"audio_url": audio_url}


def _simulate_inference(pipeline_tag, input_data):
    """Return a canned response shaped like the real pipeline's output."""
    if pipeline_tag == "text-classification":
        return [
            {"label": "POSITIVE", "score": 0.9231},
            {"label": "NEGATIVE", "score": 0.0769},
        ]
    if pipeline_tag == "token-classification":
        return [
            {"entity": "B-PER", "word": "This", "score": 0.2, "index": 0, "start": 0, "end": 4},
            {"entity": "O", "word": "model", "score": 0.95, "index": 1, "start": 5, "end": 10},
            {"entity": "O", "word": "is", "score": 0.99, "index": 2, "start": 11, "end": 13},
            {"entity": "B-MISC", "word": "amazing", "score": 0.85, "index": 3, "start": 14, "end": 21},
        ]
    if pipeline_tag == "text-generation":
        return {
            "generated_text": input_data["text"]
            + " It provides state-of-the-art performance on a wide range of natural language"
            " processing tasks, including sentiment analysis, named entity recognition, and"
            " question answering. The model was trained on a diverse corpus of text data,"
            " allowing it to generate coherent and contextually relevant responses."
        }
    if pipeline_tag == "summarization":
        return {"summary_text": "This model provides excellent performance."}
    if pipeline_tag == "question-answering":
        # Offsets match the answer's position in the default context
        # ("This model is " is 14 chars; the answer is 34 chars long).
        return {
            "answer": "a transformer-based language model",
            "start": 14,
            "end": 48,
            "score": 0.953,
        }
    if pipeline_tag == "translation":
        return {
            "translation_text": _SAMPLE_TRANSLATIONS.get(
                input_data["target_language"], "Hello, how are you?"
            )
        }
    if pipeline_tag == "image-classification":
        return [
            {"label": "diagram", "score": 0.9712},
            {"label": "architecture", "score": 0.0231},
            {"label": "document", "score": 0.0057},
        ]
    if pipeline_tag == "object-detection":
        return [
            {"label": "box", "score": 0.9712, "box": {"xmin": 10, "ymin": 20, "xmax": 100, "ymax": 80}},
            {"label": "text", "score": 0.8923, "box": {"xmin": 120, "ymin": 30, "xmax": 250, "ymax": 60}},
        ]
    return {"result": "Sample response for " + pipeline_tag}


def _render_results(pipeline_tag, input_data, result):
    """Visualise *result* in the form best suited to *pipeline_tag*."""
    if pipeline_tag == "text-classification":
        fig = px.bar(
            pd.DataFrame(result),
            x="label",
            y="score",
            color="score",
            color_continuous_scale=px.colors.sequential.Viridis,
            title="Classification Results",
        )
        st.plotly_chart(fig, use_container_width=True)
        st.json(result)
    elif pipeline_tag == "token-classification":
        _render_token_classification(result)
        st.json(result)
    elif pipeline_tag in ("text-generation", "summarization", "translation"):
        _render_text_output(input_data, result)
    elif pipeline_tag == "question-answering":
        _render_question_answering(input_data, result)
    elif pipeline_tag == "image-classification":
        fig = px.bar(
            pd.DataFrame(result),
            x="score",
            y="label",
            orientation="h",
            color="score",
            color_continuous_scale=px.colors.sequential.Viridis,
            title="Image Classification Results",
        )
        # Highest-scoring label on top.
        fig.update_layout(yaxis={"categoryorder": "total ascending"})
        st.plotly_chart(fig, use_container_width=True)
        st.json(result)
    else:
        # Generic display for other types
        st.json(result)


def _render_token_classification(result):
    """Render token predictions as colour-highlighted text plus a legend."""
    st.markdown("#### Named Entities")

    # Entity types in first-seen order, with the "B-"/"I-" prefix stripped.
    entity_types = dict.fromkeys(
        item["entity"][2:] for item in result if item["entity"].startswith(("B-", "I-"))
    )
    palette = px.colors.qualitative.Plotly
    # Cycle the palette so every type gets a colour even past len(palette).
    entity_colors = {
        entity_type: palette[i % len(palette)] for i, entity_type in enumerate(entity_types)
    }

    pieces = []
    for item in result:
        word = _escape_for_html(item["word"])
        entity = item["entity"]
        if entity == "O":
            pieces.append(f"{word} ")
            continue
        entity_type = entity[2:] if entity.startswith(("B-", "I-")) else entity
        color = entity_colors.get(entity_type, "#CCCCCC")
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px 4px; border-radius: 4px;">'
            f"{word} <sup>{_escape_for_html(entity_type)}</sup></span> "
        )
    tagged_text = "".join(pieces)
    st.markdown(f'<div style="line-height: 2;">{tagged_text}</div>', unsafe_allow_html=True)

    st.markdown("#### Entity Legend")
    legend_html = "".join(
        f'<span style="background-color: {color}; padding: 2px 8px; margin-right: 8px; '
        f'border-radius: 4px;">{_escape_for_html(entity_type)}</span>'
        for entity_type, color in entity_colors.items()
    )
    st.markdown(f"<div>{legend_html}</div>", unsafe_allow_html=True)


def _render_text_output(input_data, result):
    """Show generated/summarised/translated text and simple length statistics."""
    if "generated_text" in result:
        response_key = "generated_text"
    elif "summary_text" in result:
        response_key = "summary_text"
    else:
        response_key = "translation_text"
    output_text = result[response_key]

    st.markdown("#### Output Text")
    st.markdown(
        '<div style="padding: 1rem; border: 1px solid #ddd; border-radius: 0.5rem;">'
        f"{_escape_for_html(output_text)}</div>",
        unsafe_allow_html=True,
    )

    st.markdown("#### Text Statistics")
    input_length = len(input_data["text"]) if "text" in input_data else 0
    output_length = len(output_text)
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Input Length", input_length, "characters")
    with col2:
        st.metric("Output Length", output_length, "characters")
    with col3:
        # Percentage change is undefined for empty input; report 0% then.
        change_pct = (
            (output_length - input_length) / input_length * 100 if input_length > 0 else 0
        )
        st.metric("Length Change", f"{change_pct:.1f}%", f"{output_length - input_length} chars")


def _locate_answer(context, answer, start, end):
    """Return the ``(start, end)`` span of *answer* in *context*, or ``None``.

    The reported offsets are used when they really select *answer*.
    Otherwise the first literal occurrence is used, because the offsets can
    be stale when the user has edited the context.
    """
    if not answer:
        return None
    if 0 <= start <= end <= len(context) and context[start:end] == answer:
        return start, end
    found = context.find(answer)
    if found < 0:
        return None
    return found, found + len(answer)


def _render_question_answering(input_data, result):
    """Show the answer, highlight it inside the context, and show confidence."""
    st.markdown("#### Answer")
    st.markdown(
        '<div style="padding: 0.5rem 1rem; border-left: 4px solid #4CAF50; font-weight: bold;">'
        f"{_escape_for_html(result['answer'])}</div>",
        unsafe_allow_html=True,
    )

    if "context" in input_data:
        st.markdown("#### Answer in Context")
        context = input_data["context"]
        span = _locate_answer(context, result["answer"], result["start"], result["end"])
        if span is None:
            st.caption("The answer does not appear verbatim in the context.")
            highlighted_context = _escape_for_html(context)
        else:
            start, end = span
            highlighted_context = (
                _escape_for_html(context[:start])
                + f'<mark style="padding: 0 2px;">{_escape_for_html(context[start:end])}</mark>'
                + _escape_for_html(context[end:])
            )
        st.markdown(
            f'<div style="padding: 1rem; border: 1px solid #ddd; border-radius: 0.5rem;">'
            f"{highlighted_context}</div>",
            unsafe_allow_html=True,
        )

    st.markdown("#### Confidence")
    st.progress(result["score"])
    st.text(f"Confidence Score: {result['score']:.4f}")


def _render_api_integration(model_info):
    """Show Python and JavaScript snippets for calling the Inference API."""
    # The original code read ``modelId``; fall back to ``id`` when that is the
    # only identifier present. NOTE(review): confirm which attribute the
    # installed huggingface_hub ModelInfo exposes.
    model_id = getattr(model_info, "modelId", None) or getattr(model_info, "id", "")

    st.markdown("### Use this model in your application")

    st.markdown("#### Python")
    python_code = textwrap.dedent(f"""
        ```python
        import requests

        API_URL = "https://api-inference.huggingface.co/models/{model_id}"
        headers = {{"Authorization": "Bearer YOUR_API_KEY"}}

        def query(payload):
            response = requests.post(API_URL, headers=headers, json=payload)
            return response.json()

        # Example usage
        output = query({{
            "inputs": "This model is amazing!"
        }})
        print(output)
        ```
    """)
    st.markdown(python_code)

    st.markdown("#### JavaScript")
    js_code = textwrap.dedent(f"""
        ```javascript
        async function query(data) {{
            const response = await fetch(
                "https://api-inference.huggingface.co/models/{model_id}",
                {{
                    headers: {{ Authorization: "Bearer YOUR_API_KEY" }},
                    method: "POST",
                    body: JSON.stringify(data),
                }}
            );
            const result = await response.json();
            return result;
        }}

        // Example usage
        query({{"inputs": "This model is amazing!"}}).then((response) => {{
            console.log(JSON.stringify(response));
        }});
        ```
    """)
    st.markdown(js_code)