|
import gradio as gr |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
from transformers import pipeline |
|
import os |
|
|
|
|
|
# Shared image-classification pipeline, built once at import time and reused
# by every analysis request in this module.
agent = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
)
|
|
|
class GradioVisionAnalyzer:
    """Run the module-level classifier and package its output for Gradio.

    ``min_confidence`` holds the most recently requested threshold as a
    fraction (0.0-1.0). It is overwritten on every ``analyze_image`` call and
    read back by the plotting and formatting helpers, so an instance shared
    between callers always reflects the latest request.
    """

    def __init__(self):
        self.min_confidence = 0.1

    def analyze_image(self, image, confidence_threshold):
        """Classify *image* and build the plot and text outputs.

        Args:
            image: The image handed to the module-level ``agent`` pipeline,
                or ``None`` when nothing has been uploaded.
            confidence_threshold: Minimum score to keep, as a percentage
                (0-100). Stored on the instance as a fraction.

        Returns:
            A ``(figure, text)`` tuple. ``figure`` is ``None`` when no image
            was supplied, when no result reaches the threshold, or when the
            analysis raised; ``text`` then carries the explanation.
        """
        self.min_confidence = confidence_threshold / 100

        # Guard the "button pressed with no upload" case explicitly instead of
        # letting the classifier fail with an opaque error message.
        if image is None:
            return None, "No image provided. Please upload an image to analyze."

        try:
            results = agent(image)
            filtered_results = [
                r for r in results if r['score'] >= self.min_confidence
            ]

            if not filtered_results:
                return None, "No confident identifications found (adjust confidence threshold)"

            fig = self.create_visualization(image, filtered_results)
            return fig, self.format_results(filtered_results)

        except Exception as e:
            # UI boundary: report the failure as text rather than crashing
            # the request.
            return None, f"Error: {str(e)}"

    def create_visualization(self, img, results):
        """Draw the image next to a horizontal bar chart of the scores.

        Args:
            img: Image to show in the left panel.
            results: Non-empty list of dicts with ``'label'`` and ``'score'``.

        Returns:
            The matplotlib figure containing both panels.
        """
        fig, (image_ax, chart_ax) = plt.subplots(1, 2, figsize=(10, 5))
        try:
            image_ax.imshow(img)
            image_ax.axis('off')
            image_ax.set_title("Uploaded Image")

            labels = [r['label'] for r in results]
            scores = [r['score'] for r in results]

            # Colour bars relative to the best score; fall back to 1.0 so an
            # all-zero score list cannot divide by zero.
            top_score = max(scores) or 1.0
            colors = plt.cm.viridis([s / top_score for s in scores])

            bars = chart_ax.barh(labels, scores, color=colors)
            chart_ax.set_xlabel('Confidence Score')
            chart_ax.set_title(f'Results (Threshold: {self.min_confidence:.0%})')
            chart_ax.set_xlim(0, 1)

            for bar in bars:
                width = bar.get_width()
                # Keep the percentage label anchored inside the 0-1 x-range.
                chart_ax.text(
                    min(width + 0.01, 0.99),
                    bar.get_y() + bar.get_height() / 2,
                    f'{width:.0%}',
                    va='center',
                    ha='left',
                )

            fig.tight_layout()
        finally:
            # Unregister the figure from pyplot's global figure list so
            # repeated requests do not accumulate open figures. The Figure
            # object itself stays usable by the caller.
            plt.close(fig)
        return fig

    def format_results(self, results):
        """Return a numbered, human-readable summary of *results*.

        Args:
            results: List of dicts with ``'label'`` and ``'score'``.

        Returns:
            A string starting with the current minimum confidence, followed
            by one numbered line per result.
        """
        header = f"Minimum Confidence: {self.min_confidence:.0%}\n\n"
        lines = (
            f"{i}. {result['label']} ({result['score']:.0%} confidence)\n"
            for i, result in enumerate(results, 1)
        )
        return header + "".join(lines)
|
|
|
|
|
# Single analyzer instance used by the click handler wired up below.
analyzer = GradioVisionAnalyzer()

with gr.Blocks(title="AI Vision Agent for Security Compliance") as demo:
    gr.Markdown("""
    ## 🛡️ AI Security Compliance Assistant
    Upload images to detect policy violations (unattended devices, clean-desk issues, etc.)
    """)

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Security Image")
            confidence_slider = gr.Slider(0, 100, value=10, label="Confidence Threshold (%)")
            analyze_btn = gr.Button("Analyze Image")

        # Right column: outputs.
        with gr.Column():
            plot_output = gr.Plot(label="Detection Results")
            text_output = gr.Textbox(label="Detailed Findings", interactive=False)

    # Only offer sample images that are actually present on disk, so a
    # missing example file cannot break application start-up.
    sample_images = [
        path for path in ["example1.jpg", "example2.jpg"] if os.path.isfile(path)
    ]
    if sample_images:
        gr.Examples(
            examples=sample_images,
            inputs=image_input,
            label="Try sample images",
        )

    analyze_btn.click(
        fn=analyzer.analyze_image,
        inputs=[image_input, confidence_slider],
        outputs=[plot_output, text_output],
    )

# Start the server only when executed as a script, so importing this module
# (e.g. for reuse of `demo` or `analyzer`) has no launch side effect.
if __name__ == "__main__":
    demo.launch()