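# Workplace safety monitor: a Gradio demo that sends an uploaded image to a
# Groq-hosted Llama 3.2 vision model, parses the <location> tags in its reply,
# and overlays the reported hazards as labeled boxes on the image.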
import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import os
import base64

def create_monitor_interface():
    api_key = os.getenv("GROQ_API_KEY")
    
    class SafetyMonitor:
        def __init__(self):
            self.client = Groq(api_key=api_key)  # falls back to the GROQ_API_KEY env var when api_key is None
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)
            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]  # annotation palette (RGB; Gradio passes RGB arrays)

        def analyze_frame(self, frame: np.ndarray) -> str:
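            """Send the frame to the Groq vision model and return its hazard analysis."""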
            if frame is None:
                return "No frame received"
                
            frame = self.preprocess_image(frame)
            image_url = self.encode_image(frame)
            
            try:
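                # The Groq SDK mirrors the OpenAI chat format: a single user message
                # can mix "text" and "image_url" content parts.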
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": """Analyze this image for safety hazards and issues. For each identified hazard:

1. Specify the exact location in the image where the hazard exists
2. Describe the specific safety concern
3. Note any violations or risks

Format each observation exactly as:
- <location>area:hazard description</location>

Examples of locations: top-left, center, bottom-right, full-area, near-machine, workspace, etc.

Look for ALL types of safety issues including:
- Personal protective equipment (PPE)
- Machine and equipment hazards
- Ergonomic risks
- Environmental hazards
- Fire and electrical safety
- Chemical safety
- Fall protection
- Material handling
- Access/egress issues
- Housekeeping
- Tool safety
- Emergency equipment

Be specific about locations and provide detailed observations."""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": image_url
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.5,
                    max_tokens=500,
                    stream=False
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Analysis error: {str(e)}")
                return f"Analysis Error: {str(e)}"

        def preprocess_image(self, frame):
            """Prepare image for analysis."""
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif len(frame.shape) == 3 and frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            
            return self.resize_image(frame)

        def resize_image(self, image):
            """Resize image while maintaining aspect ratio."""
            height, width = image.shape[:2]
            if height > self.max_image_size[1] or width > self.max_image_size[0]:
                aspect = width / height
                if width > height:
                    new_width = self.max_image_size[0]
                    new_height = int(new_width / aspect)
                else:
                    new_height = self.max_image_size[1]
                    new_width = int(new_height * aspect)
                return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
            return image

        def encode_image(self, frame):
            """Convert image to base64 encoding."""
            frame_pil = PILImage.fromarray(frame)
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=95)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return f"data:image/jpeg;base64,{img_base64}"

        def parse_locations(self, observation: str) -> dict:
            """Parse location information from observation."""
            locations = {
                'full': (0, 0, 1, 1),
                'top': (0.2, 0, 0.8, 0.3),
                'bottom': (0.2, 0.7, 0.8, 1),
                'left': (0, 0.2, 0.3, 0.8),
                'right': (0.7, 0.2, 1, 0.8),
                'center': (0.3, 0.3, 0.7, 0.7),
                'top-left': (0, 0, 0.3, 0.3),
                'top-right': (0.7, 0, 1, 0.3),
                'bottom-left': (0, 0.7, 0.3, 1),
                'bottom-right': (0.7, 0.7, 1, 1),
                'workspace': (0.2, 0.2, 0.8, 0.8),
                'near-machine': (0.6, 0.1, 1, 0.9),
                'floor-area': (0, 0.7, 1, 1),
                'equipment': (0.5, 0.1, 1, 0.9)
            }
            
            # Score each named region against the observation text; hyphenated
            # names count how many of their words appear. Default to 'center'.
            text = observation.lower()
            best_match = 'center'
            max_match = 0
            
            for loc in locations.keys():
                if loc in text:
                    words = loc.split('-')
                    matches = sum(1 for word in words if word in text)
                    if matches > max_match:
                        max_match = matches
                        best_match = loc
            
            return locations[best_match]

        def draw_observations(self, image, observations):
            """Draw bounding boxes and labels for safety observations."""
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            thickness = 2
            padding = 10

            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]
                
                # Get relative coordinates and convert to absolute
                rel_coords = self.parse_locations(obs['location'])
                x1 = int(rel_coords[0] * width)
                y1 = int(rel_coords[1] * height)
                x2 = int(rel_coords[2] * width)
                y2 = int(rel_coords[3] * height)
                
                # Draw rectangle
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                
                # Prepare label
                label = obs['description'][:50]
                if len(obs['description']) > 50:
                    label += "..."
                
                # Calculate text position
                label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)
                text_x = max(0, x1)
                text_y = max(label_size[1] + padding, y1 - padding)
                
                # Draw label background
                cv2.rectangle(image, 
                            (text_x, text_y - label_size[1] - padding),
                            (text_x + label_size[0] + padding, text_y),
                            color, -1)
                
                # Draw label text
                cv2.putText(image, label,
                           (text_x + padding//2, text_y - padding//2),
                           font, font_scale, (255, 255, 255), thickness)
            
            return image

        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
            """Process frame and generate safety analysis with visualizations."""
            if frame is None:
                return None, "No image provided"
            
            # Get analysis
            analysis = self.analyze_frame(frame)
            display_frame = frame.copy()
            
            # Parse observations
            observations = []
            for line in analysis.split('\n'):
                line = line.strip()
                if line.startswith('-') and '<location>' in line and '</location>' in line:
                    start = line.find('<location>') + len('<location>')
                    end = line.find('</location>')
                    location_description = line[start:end].strip()
                    
                    # Split location and description
                    if ':' in location_description:
                        location, description = location_description.split(':', 1)
                        observations.append({
                            'location': location.strip(),
                            'description': description.strip()
                        })
            
            # Draw observations if any were found
            if observations:
                annotated_frame = self.draw_observations(display_frame, observations)
                return annotated_frame, analysis
            
            return display_frame, analysis

    # Create interface
    monitor = SafetyMonitor()
    
    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
        
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Safety Analysis")
        
        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
            
        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"
            
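        # Re-run the analysis whenever a new image is uploaded or replaced.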
        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions:
        1. Upload any workplace/safety-related image
        2. View identified hazards and safety concerns
        3. Check detailed analysis for recommendations
        """)

    return demo

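# Requires GROQ_API_KEY to be set in the environment before launching.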
demo = create_monitor_interface()
demo.launch()