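"""Gradio safety-monitoring demo: sends an uploaded image to Groq's Llama 3.2 90B
Vision model, parses its hazard observations, and draws annotated bounding boxes."""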
import gradio as gr
import cv2
import numpy as np
from groq import Groq
import time
from PIL import Image as PILImage
import io
import os
import base64

def create_monitor_interface():
    api_key = os.getenv("GROQ_API_KEY")
    
    class SafetyMonitor:
        def __init__(self):
            # Pass the key explicitly; the Groq client would otherwise read GROQ_API_KEY itself.
            self.client = Groq(api_key=api_key)
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)  # (width, height) cap before upload
            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
            self.last_analysis_time = 0
            self.analysis_interval = 2  # minimum seconds between model calls
            self.last_observations = []
            
        def resize_image(self, image):
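            """Downscale the image to fit within max_image_size, preserving aspect ratio."""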
            height, width = image.shape[:2]
            
            if height > self.max_image_size[1] or width > self.max_image_size[0]:
                aspect = width / height
                if width > height:
                    new_width = self.max_image_size[0]
                    new_height = int(new_width / aspect)
                else:
                    new_height = self.max_image_size[1]
                    new_width = int(new_height * aspect)
                return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
            return image

        def analyze_frame(self, frame: np.ndarray) -> str:
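            """Send the frame to the vision model and return its raw text analysis."""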
            if frame is None:
                return "No frame received"
                
            # Normalize to 3-channel RGB: grayscale and RGBA inputs are converted.
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif len(frame.shape) == 3 and frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            
            frame = self.resize_image(frame)
            frame_pil = PILImage.fromarray(frame)
            
            # Encode as a base64 JPEG data URL, which the vision API accepts inline.
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=85, optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            image_url = f"data:image/jpeg;base64,{img_base64}"
            
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": """Analyze this image for safety hazards. For each issue, describe:
                                    1. The location (top-left, center, bottom-right, etc.)
                                    2. The specific safety concern
                                    Format: - <location>position:description</location>"""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": image_url
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.1,
                    max_tokens=200,
                    top_p=1,
                    stream=False,
                    stop=None
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Detailed error: {str(e)}")
                return f"Analysis Error: {str(e)}"

        def get_region_coordinates(self, position: str, image_shape: tuple) -> tuple:
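            """Map a textual position (e.g. 'top-left') to a box in a 3x3 grid."""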
            height, width = image_shape[:2]
            regions = {
                'top-left': (0, 0, width//3, height//3),
                'top': (width//3, 0, 2*width//3, height//3),
                'top-right': (2*width//3, 0, width, height//3),
                'left': (0, height//3, width//3, 2*height//3),
                'center': (width//3, height//3, 2*width//3, 2*height//3),
                'right': (2*width//3, height//3, width, 2*height//3),
                'bottom-left': (0, 2*height//3, width//3, height),
                'bottom': (width//3, 2*height//3, 2*width//3, height),
                'bottom-right': (2*width//3, 2*height//3, width, height)
            }
            
            # Check longer names first so e.g. "bottom-right" isn't captured by "right".
            for region_name in sorted(regions, key=len, reverse=True):
                if region_name in position.lower():
                    return regions[region_name]
            
            return regions['center']

        def draw_observations(self, image, observations):
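            """Draw a colored box and truncated label for each parsed observation."""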
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.6
            thickness = 2
            
            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]
                
                parts = obs.split(':')
                if len(parts) >= 2:
                    position = parts[0]
                    description = ':'.join(parts[1:])
                else:
                    position = 'center'
                    description = obs
                
                x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
                
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                
                label = description[:50] + "..." if len(description) > 50 else description
                label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
                
                label_x = max(0, min(x1, width - label_size[0]))
                label_y = max(20, y1 - 5)
                
                cv2.rectangle(image, (label_x, label_y - 20), 
                            (label_x + label_size[0], label_y), color, -1)
                cv2.putText(image, label, (label_x, label_y - 5), 
                          font, font_scale, (255, 255, 255), thickness)
            
            return image

        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
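            """Annotate the frame, re-querying the model at most once per analysis_interval."""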
            if frame is None:
                return None, "No image provided"
            
            current_time = time.time()
            
            if current_time - self.last_analysis_time >= self.analysis_interval:
                analysis = self.analyze_frame(frame)
                self.last_analysis_time = current_time
                
                observations = []
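                # Keep only "- ... <location>position:description</location>" lines.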
                for line in analysis.split('\n'):
                    line = line.strip()
                    if line.startswith('-'):
                        if '<location>' in line and '</location>' in line:
                            start = line.find('<location>') + len('<location>')
                            end = line.find('</location>')
                            observation = line[start:end].strip()
                            if observation:
                                observations.append(observation)
                
                self.last_observations = observations
            
            display_frame = frame.copy()
            annotated_frame = self.draw_observations(display_frame, self.last_observations)
            
            return annotated_frame, '\n'.join([f"- {obs}" for obs in self.last_observations])

    # Create the main interface
    monitor = SafetyMonitor()
    
    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
        
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Analysis")
        
        analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
            
        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"
            
        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions:
        1. Upload an image to analyze safety concerns
        2. View annotated results and detailed analysis
        3. Each box highlights a potential safety issue
        """)

    return demo

if __name__ == "__main__":
    demo = create_monitor_interface()
    demo.launch()