import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
from transformers import pipeline
import os

# Initialize the vision agent (same as your original)
agent = pipeline("image-classification", model="google/vit-base-patch16-224")


class GradioVisionAnalyzer:
    def __init__(self):
        self.min_confidence = 0.1  # Default confidence threshold

    def analyze_image(self, image, confidence_threshold):
        """Gradio-compatible analysis function"""
        self.min_confidence = confidence_threshold / 100  # Convert slider % to decimal
        try:
            results = agent(image)
            filtered_results = [r for r in results if r['score'] >= self.min_confidence]
            if not filtered_results:
                return None, "No confident identifications found (adjust confidence threshold)"

            # Create visualization
            fig = self.create_visualization(image, filtered_results)
            return fig, self.format_results(filtered_results)
        except Exception as e:
            return None, f"Error: {str(e)}"

    def create_visualization(self, img, results):
        """Adapted matplotlib visualization for Gradio"""
        plt.figure(figsize=(10, 5))

        # Show original image
        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.axis('off')
        plt.title("Uploaded Image")

        # Show results as bar chart
        plt.subplot(1, 2, 2)
        labels = [r['label'] for r in results]
        scores = [r['score'] for r in results]
        colors = plt.cm.viridis([s / max(scores) for s in scores])
        bars = plt.barh(labels, scores, color=colors)
        plt.xlabel('Confidence Score')
        plt.title(f'Results (Threshold: {self.min_confidence:.0%})')
        plt.xlim(0, 1)
        for bar in bars:
            width = bar.get_width()
            plt.text(min(width + 0.01, 0.99), bar.get_y() + bar.get_height() / 2,
                     f'{width:.0%}', va='center', ha='left')

        plt.tight_layout()
        return plt.gcf()  # Return the figure object

    def format_results(self, results):
        """Format results for text output"""
        output = f"Minimum Confidence: {self.min_confidence:.0%}\n\n"
        for i, result in enumerate(results, 1):
            output += f"{i}. {result['label']} ({result['score']:.0%} confidence)\n"
        return output


# Initialize analyzer
analyzer = GradioVisionAnalyzer()

# Create Gradio interface
with gr.Blocks(title="AI Vision Agent for Security Compliance") as demo:
    gr.Markdown("""
    ## 🛡️ AI Security Compliance Assistant
    Upload images to detect policy violations (unattended devices, clean-desk issues, etc.)
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Security Image")
            confidence_slider = gr.Slider(0, 100, value=10, label="Confidence Threshold (%)")
            analyze_btn = gr.Button("Analyze Image")
        with gr.Column():
            plot_output = gr.Plot(label="Detection Results")
            text_output = gr.Textbox(label="Detailed Findings", interactive=False)

    # Example images for quick testing
    gr.Examples(
        examples=["apresentation.png", "image1.png", "image2.jpg", "image3.png"],
        inputs=image_input,
        label="Try sample images"
    )

    analyze_btn.click(
        fn=analyzer.analyze_image,
        inputs=[image_input, confidence_slider],
        outputs=[plot_output, text_output]
    )

# For Hugging Face Spaces
demo.launch()
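
# Deployment note (a minimal sketch, not part of the original app): a Hugging Face
# Space running this file typically also needs a requirements.txt next to it. The
# package list below is an assumption derived from the imports above; the
# transformers image-classification pipeline also needs a backend such as torch.
#
#   gradio
#   transformers
#   torch
#   matplotlib
#   Pillow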