import streamlit as st import json import pandas as pd import numpy as np import plotly.express as px from io import StringIO import time def model_inference_dashboard(model_info): """Create a dashboard for testing model inference directly in the app""" if not model_info: st.error("Model information not found") return st.subheader("🧠 Model Inference Dashboard") # Get the pipeline type based on model tags or information pipeline_tag = getattr(model_info, "pipeline_tag", None) if not pipeline_tag: # Try to determine from tags tags = getattr(model_info, "tags", []) for tag in tags: if tag in [ "text-classification", "token-classification", "question-answering", "summarization", "translation", "text-generation", "fill-mask", "sentence-similarity", "image-classification", "object-detection", "image-segmentation", "text-to-image", "image-to-text" ]: pipeline_tag = tag break if not pipeline_tag: pipeline_tag = "text-classification" # Default fallback # Display information about the model st.info(f"This dashboard allows you to test your model's inference capabilities. Model pipeline: **{pipeline_tag}**") # Different input options based on pipeline type input_data = None if pipeline_tag in ["text-classification", "token-classification", "fill-mask", "text-generation", "summarization"]: # Text-based input st.markdown("### Text Input") input_text = st.text_area( "Enter text for inference", value="This model is amazing!", height=150 ) # Additional parameters for specific pipelines if pipeline_tag == "text-generation": col1, col2 = st.columns(2) with col1: max_length = st.slider("Max Length", min_value=10, max_value=500, value=100) with col2: temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.0, step=0.1) input_data = { "text": input_text, "max_length": max_length, "temperature": temperature } elif pipeline_tag == "summarization": max_length = st.slider("Max Summary Length", min_value=10, max_value=200, value=50) input_data = { "text": input_text, "max_length": max_length } else: input_data = {"text": input_text} elif pipeline_tag in ["question-answering"]: st.markdown("### Question & Context") question = st.text_input("Question", value="What is this model about?") context = st.text_area( "Context", value="This model is a transformer-based language model designed for natural language understanding tasks.", height=150 ) input_data = { "question": question, "context": context } elif pipeline_tag in ["translation"]: st.markdown("### Translation") source_lang = st.selectbox("Source Language", ["English", "French", "German", "Spanish", "Chinese"]) target_lang = st.selectbox("Target Language", ["French", "English", "German", "Spanish", "Chinese"]) translation_text = st.text_area("Text to translate", value="Hello, how are you?", height=150) input_data = { "text": translation_text, "source_language": source_lang, "target_language": target_lang } elif pipeline_tag in ["image-classification", "object-detection", "image-segmentation"]: st.markdown("### Image Input") upload_method = st.radio("Select input method", ["Upload Image", "Image URL"]) if upload_method == "Upload Image": uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: st.image(uploaded_file, caption="Uploaded Image", use_column_width=True) input_data = {"image": uploaded_file} else: image_url = st.text_input("Image URL", value="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/distilbert-base-uncased-finetuned-sst-2-english-architecture.png") if image_url: st.image(image_url, caption="Image from URL", use_column_width=True) input_data = {"image_url": image_url} elif pipeline_tag in ["audio-classification", "automatic-speech-recognition"]: st.markdown("### Audio Input") upload_method = st.radio("Select input method", ["Upload Audio", "Audio URL"]) if upload_method == "Upload Audio": uploaded_file = st.file_uploader("Upload an audio file", type=["mp3", "wav", "ogg"]) if uploaded_file is not None: st.audio(uploaded_file) input_data = {"audio": uploaded_file} else: audio_url = st.text_input("Audio URL") if audio_url: st.audio(audio_url) input_data = {"audio_url": audio_url} # Execute inference if st.button("Run Inference", use_container_width=True): if input_data: with st.spinner("Running inference..."): # In a real implementation, this would call the HF Inference API # For demo purposes, simulate a response time.sleep(2) # Generate a sample response based on the pipeline type if pipeline_tag == "text-classification": result = [ {"label": "POSITIVE", "score": 0.9231}, {"label": "NEGATIVE", "score": 0.0769} ] elif pipeline_tag == "token-classification": result = [ {"entity": "B-PER", "word": "This", "score": 0.2, "index": 0, "start": 0, "end": 4}, {"entity": "O", "word": "model", "score": 0.95, "index": 1, "start": 5, "end": 10}, {"entity": "O", "word": "is", "score": 0.99, "index": 2, "start": 11, "end": 13}, {"entity": "B-MISC", "word": "amazing", "score": 0.85, "index": 3, "start": 14, "end": 21} ] elif pipeline_tag == "text-generation": result = { "generated_text": input_data["text"] + " It provides state-of-the-art performance on a wide range of natural language processing tasks, including sentiment analysis, named entity recognition, and question answering. The model was trained on a diverse corpus of text data, allowing it to generate coherent and contextually relevant responses." } elif pipeline_tag == "summarization": result = { "summary_text": "This model provides excellent performance." } elif pipeline_tag == "question-answering": result = { "answer": "a transformer-based language model", "start": 9, "end": 45, "score": 0.953 } elif pipeline_tag == "translation": if input_data["target_language"] == "French": result = {"translation_text": "Bonjour, comment allez-vous?"} elif input_data["target_language"] == "German": result = {"translation_text": "Hallo, wie geht es dir?"} elif input_data["target_language"] == "Spanish": result = {"translation_text": "Hola, ¿cómo estás?"} elif input_data["target_language"] == "Chinese": result = {"translation_text": "你好,你好吗?"} else: result = {"translation_text": "Hello, how are you?"} elif pipeline_tag in ["image-classification"]: result = [ {"label": "diagram", "score": 0.9712}, {"label": "architecture", "score": 0.0231}, {"label": "document", "score": 0.0057} ] elif pipeline_tag in ["object-detection"]: result = [ {"label": "box", "score": 0.9712, "box": {"xmin": 10, "ymin": 20, "xmax": 100, "ymax": 80}}, {"label": "text", "score": 0.8923, "box": {"xmin": 120, "ymin": 30, "xmax": 250, "ymax": 60}} ] else: result = {"result": "Sample response for " + pipeline_tag} # Display the results st.markdown("### Inference Results") # Different visualizations based on the response type if pipeline_tag == "text-classification": # Create a bar chart for classification results result_df = pd.DataFrame(result) fig = px.bar( result_df, x="label", y="score", color="score", color_continuous_scale=px.colors.sequential.Viridis, title="Classification Results" ) st.plotly_chart(fig, use_container_width=True) # Show the raw results st.json(result) elif pipeline_tag == "token-classification": # Display entity highlighting st.markdown("#### Named Entities") # Create HTML with colored spans for entities html = "" input_text = input_data["text"] entities = {} for item in result: if item["entity"].startswith("B-") or item["entity"].startswith("I-"): entity_type = item["entity"][2:] # Remove B- or I- prefix entities[entity_type] = entities.get(entity_type, 0) + 1 # Create a color map for entity types colors = px.colors.qualitative.Plotly[:len(entities)] entity_colors = dict(zip(entities.keys(), colors)) # Create the HTML for item in result: word = item["word"] entity = item["entity"] if entity == "O": html += f"{word} " else: entity_type = entity[2:] if entity.startswith("B-") or entity.startswith("I-") else entity color = entity_colors.get(entity_type, "#CCCCCC") html += f'{word} ' st.markdown(f'