Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from transformers import pipeline | |
import os | |
# Shared image-classification pipeline. It is built once, when this module is
# imported, and every analysis request reuses the same instance.
agent = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
)
class GradioVisionAnalyzer:
    """Run the shared image classifier and render its results for Gradio.

    ``min_confidence`` records the most recently requested threshold as a
    fraction (0.1 == 10%). It is kept for backward compatibility. Each
    request passes its own threshold explicitly to the rendering helpers, so
    concurrent Gradio requests cannot overwrite each other's value.
    """

    def __init__(self):
        # Default threshold as a fraction (0.1 == 10%). The rendering helpers
        # use it when they are called without an explicit threshold.
        self.min_confidence = 0.1

    def analyze_image(self, image, confidence_threshold):
        """Classify ``image`` and return ``(figure, text)`` for the Gradio outputs.

        Args:
            image: The uploaded image, or ``None`` when nothing was uploaded.
                Expected to be something the classifier accepts (the UI
                below requests a PIL image).
            confidence_threshold: Minimum score in percent (0-100), as
                delivered by the slider.

        Returns:
            A ``(figure, message)`` tuple. ``figure`` is ``None`` when there
            is nothing to plot (no image, no result at or above the
            threshold, or an error). In that case ``message`` explains why.
        """
        min_confidence = confidence_threshold / 100  # slider % -> fraction
        # Still recorded on the instance for callers that read the attribute;
        # the rendering below uses the local value, not the shared attribute.
        self.min_confidence = min_confidence

        if image is None:
            return None, "Please upload an image first."

        try:
            results = agent(image)
            filtered_results = [r for r in results if r['score'] >= min_confidence]
            if not filtered_results:
                return None, "No confident identifications found (adjust confidence threshold)"
            fig = self.create_visualization(image, filtered_results, min_confidence)
            return fig, self.format_results(filtered_results, min_confidence)
        except Exception as e:
            # UI boundary: show the failure in the text box instead of
            # letting the Gradio event handler raise.
            return None, f"Error: {e}"

    def create_visualization(self, img, results, min_confidence=None):
        """Build a side-by-side figure: the image on the left, scores on the right.

        Args:
            img: Image to display (anything ``Axes.imshow`` accepts).
            results: List of ``{'label': ..., 'score': float}`` dicts, drawn
                top-to-bottom in list order.
            min_confidence: Threshold fraction shown in the chart title.
                Defaults to ``self.min_confidence``.

        Returns:
            The matplotlib ``Figure``. It has already been removed from
            pyplot's figure registry, so repeated calls do not accumulate open
            figures. The returned object can still be saved or rendered.
        """
        if min_confidence is None:
            min_confidence = self.min_confidence

        # Explicit Axes objects instead of the implicit "current figure" API,
        # so drawing never lands on a figure created by a concurrent request.
        fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(10, 5))

        # Left panel: the original image.
        ax_img.imshow(img)
        ax_img.axis('off')
        ax_img.set_title("Uploaded Image")

        # Right panel: horizontal bar chart of confidence scores.
        labels = [r['label'] for r in results]
        scores = [r['score'] for r in results]
        # Colours are scaled against the best score. Fall back to 1.0 so an
        # all-zero (or empty) score list cannot divide by zero.
        peak = max(scores, default=0) or 1.0
        colors = plt.cm.viridis([s / peak for s in scores])
        bars = ax_bar.barh(labels, scores, color=colors)
        # barh draws the first item at the bottom. Flip the axis so the chart
        # reads in the same order as format_results' numbered list.
        ax_bar.invert_yaxis()
        ax_bar.set_xlabel('Confidence Score')
        ax_bar.set_title(f'Results (Threshold: {min_confidence:.0%})')
        ax_bar.set_xlim(0, 1)

        # Percentage label just past the end of each bar, kept inside x < 1.
        for bar in bars:
            width = bar.get_width()
            ax_bar.text(min(width + 0.01, 0.99),
                        bar.get_y() + bar.get_height() / 2,
                        f'{width:.0%}',
                        va='center',
                        ha='left')

        fig.tight_layout()
        # Release pyplot's reference; otherwise every request leaks a figure.
        plt.close(fig)
        return fig

    def format_results(self, results, min_confidence=None):
        """Render results as a numbered plain-text list for the findings box.

        Args:
            results: List of ``{'label': ..., 'score': float}`` dicts.
            min_confidence: Threshold fraction echoed in the header.
                Defaults to ``self.min_confidence``.

        Returns:
            A header line and a blank line, followed by one
            ``"N. label (P% confidence)"`` line per result. Each line ends
            with a newline.
        """
        if min_confidence is None:
            min_confidence = self.min_confidence
        header = f"Minimum Confidence: {min_confidence:.0%}\n\n"
        body = "".join(
            f"{i}. {r['label']} ({r['score']:.0%} confidence)\n"
            for i, r in enumerate(results, 1)
        )
        return header + body
# Single analyzer instance shared by every Gradio request.
analyzer = GradioVisionAnalyzer()

# Gradio UI: image upload and threshold slider on the left, the results plot
# and text findings on the right, sample images below.
with gr.Blocks(title="AI Vision Agent for Security Compliance") as demo:
    # NOTE(review): this copy promises policy-violation detection, but the
    # analysis runs a general "image-classification" pipeline. Confirm the
    # wording matches what the model can actually report.
    gr.Markdown(
        "## 🛡️ AI Security Compliance Assistant\n"
        "Upload images to detect policy violations "
        "(unattended devices, clean-desk issues, etc.)"
    )
    with gr.Row():
        with gr.Column():
            # type="pil" hands the handler a PIL image rather than an array.
            image_input = gr.Image(type="pil", label="Upload Security Image")
            # Percent scale; analyze_image divides by 100.
            confidence_slider = gr.Slider(0, 100, value=10, label="Confidence Threshold (%)")
            analyze_btn = gr.Button("Analyze Image")
        with gr.Column():
            plot_output = gr.Plot(label="Detection Results")
            text_output = gr.Textbox(label="Detailed Findings", interactive=False)

    # Example images for quick testing. They are referenced by relative path,
    # so presumably they must ship alongside the app (TODO: confirm they are
    # committed).
    gr.Examples(
        examples=["apresentation.png", "image1.png", "image2.jpg", "image3.png"],
        inputs=image_input,
        label="Try sample images",
    )

    # The handler returns (figure_or_None, text), matching these two outputs.
    analyze_btn.click(
        fn=analyzer.analyze_image,
        inputs=[image_input, confidence_slider],
        outputs=[plot_output, text_output],
    )

# For Hugging Face Spaces
demo.launch()