ThiSecur commited on
Commit
f6cf3b1
Β·
verified Β·
1 Parent(s): f7201fd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ from transformers import pipeline
5
+ import os
6
+
7
+ # Initialize the vision agent (same as your original)
8
+ agent = pipeline("image-classification", model="google/vit-base-patch16-224")
9
+
10
+ class GradioVisionAnalyzer:
11
+ def __init__(self):
12
+ self.min_confidence = 0.1 # Default confidence threshold
13
+
14
+ def analyze_image(self, image, confidence_threshold):
15
+ """Gradio-compatible analysis function"""
16
+ self.min_confidence = confidence_threshold/100 # Convert slider % to decimal
17
+
18
+ try:
19
+ results = agent(image)
20
+ filtered_results = [r for r in results if r['score'] >= self.min_confidence]
21
+
22
+ if not filtered_results:
23
+ return None, "No confident identifications found (adjust confidence threshold)"
24
+
25
+ # Create visualization
26
+ fig = self.create_visualization(image, filtered_results)
27
+ return fig, self.format_results(filtered_results)
28
+
29
+ except Exception as e:
30
+ return None, f"Error: {str(e)}"
31
+
32
+ def create_visualization(self, img, results):
33
+ """Adapted matplotlib visualization for Gradio"""
34
+ plt.figure(figsize=(10, 5))
35
+
36
+ # Show original image
37
+ plt.subplot(1, 2, 1)
38
+ plt.imshow(img)
39
+ plt.axis('off')
40
+ plt.title("Uploaded Image")
41
+
42
+ # Show results as bar chart
43
+ plt.subplot(1, 2, 2)
44
+ labels = [r['label'] for r in results]
45
+ scores = [r['score'] for r in results]
46
+ colors = plt.cm.viridis([s/max(scores) for s in scores])
47
+
48
+ bars = plt.barh(labels, scores, color=colors)
49
+ plt.xlabel('Confidence Score')
50
+ plt.title(f'Results (Threshold: {self.min_confidence:.0%})')
51
+ plt.xlim(0, 1)
52
+
53
+ for bar in bars:
54
+ width = bar.get_width()
55
+ plt.text(min(width + 0.01, 0.99),
56
+ bar.get_y() + bar.get_height()/2,
57
+ f'{width:.0%}',
58
+ va='center',
59
+ ha='left')
60
+
61
+ plt.tight_layout()
62
+ return plt.gcf() # Return the figure object
63
+
64
+ def format_results(self, results):
65
+ """Format results for text output"""
66
+ output = f"Minimum Confidence: {self.min_confidence:.0%}\n\n"
67
+ for i, result in enumerate(results, 1):
68
+ output += f"{i}. {result['label']} ({result['score']:.0%} confidence)\n"
69
+ return output
70
+
71
+ # Initialize analyzer
72
+ analyzer = GradioVisionAnalyzer()
73
+
74
+ # Create Gradio interface
75
+ with gr.Blocks(title="AI Vision Agent for Security Compliance") as demo:
76
+ gr.Markdown("""
77
+ ## πŸ›‘οΈ AI Security Compliance Assistant
78
+ Upload images to detect policy violations (unattended devices, clean-desk issues, etc.)
79
+ """)
80
+
81
+ with gr.Row():
82
+ with gr.Column():
83
+ image_input = gr.Image(type="pil", label="Upload Security Image")
84
+ confidence_slider = gr.Slider(0, 100, value=10, label="Confidence Threshold (%)")
85
+ analyze_btn = gr.Button("Analyze Image")
86
+
87
+ with gr.Column():
88
+ plot_output = gr.Plot(label="Detection Results")
89
+ text_output = gr.Textbox(label="Detailed Findings", interactive=False)
90
+
91
+ # Example images for quick testing
92
+ gr.Examples(
93
+ examples=["example1.jpg", "example2.jpg"],
94
+ inputs=image_input,
95
+ label="Try sample images"
96
+ )
97
+
98
+ analyze_btn.click(
99
+ fn=analyzer.analyze_image,
100
+ inputs=[image_input, confidence_slider],
101
+ outputs=[plot_output, text_output]
102
+ )
103
+
104
+ # For Hugging Face Spaces
105
+ demo.launch()