File size: 10,350 Bytes
7b04d4e
 
 
 
49a323c
7b04d4e
75c2b7c
aca1712
d2c67f3
 
aca1712
d2c67f3
 
63920cb
a5f647b
771e08a
a5f647b
771e08a
d2c67f3
771e08a
 
aca1712
a5f647b
7e37c85
d2c67f3
 
 
 
771e08a
d2c67f3
519704e
d2c67f3
 
 
771e08a
 
 
 
 
 
 
d2c67f3
771e08a
 
 
 
 
 
 
 
 
 
 
 
 
d2c67f3
160a45b
d2c67f3
 
 
 
 
 
 
 
a01cb2c
d2c67f3
aca1712
d2c67f3
 
 
 
 
 
 
 
 
 
 
519704e
d2c67f3
92928c5
519704e
d2c67f3
 
 
519704e
 
 
 
 
 
 
 
a5f647b
d2c67f3
 
519704e
 
 
 
d2c67f3
5f3406b
519704e
 
 
 
92928c5
a5f647b
519704e
 
d2c67f3
 
 
 
 
 
 
 
 
7870fce
519704e
 
d2c67f3
46e12d1
d2c67f3
 
 
 
519704e
 
 
a5f647b
aca1712
d2c67f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5f647b
d2c67f3
9bf83e0
d2c67f3
 
519704e
 
f6cffbc
519704e
d2c67f3
aca1712
d2c67f3
 
7870fce
d2c67f3
 
 
 
a5f647b
f6cffbc
519704e
 
 
f6cffbc
d2c67f3
 
f6cffbc
d2c67f3
 
 
 
f6cffbc
d2c67f3
 
 
 
 
 
 
 
 
7870fce
d2c67f3
 
7870fce
d2c67f3
 
 
 
f6cffbc
46e12d1
a5f647b
771e08a
d2c67f3
a5f647b
1cddd79
 
d2c67f3
 
7b04d4e
1cddd79
d2c67f3
 
 
 
1cddd79
b4f3ea6
b6ce847
d2c67f3
27eab0f
 
 
 
d2c67f3
 
33fd6ad
7e37c85
b4f3ea6
 
 
1cddd79
7b04d4e
bda20be
d2c67f3
 
 
 
 
 
 
 
 
 
bda20be
 
1cddd79
 
771e08a
 
d2c67f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import base64
import torch
import warnings
from typing import Tuple, List, Dict, Optional

# Suppress only the CUDA autocast FutureWarning (presumably the deprecation of
# `torch.cuda.amp.autocast(...)` hit during YOLOv5 inference -- confirm against
# the warning text you see). A blanket FutureWarning filter would also hide
# unrelated deprecation notices from every other library in the process.
warnings.filterwarnings('ignore', message=r".*autocast", category=FutureWarning)

class RobustSafetyMonitor:
    """Workplace safety analyser combining YOLOv5 detection with a Groq vision LLM.

    Frames are treated as RGB ``numpy`` arrays throughout: grayscale and RGBA
    inputs are converted to RGB, and every drawing colour used here is an RGB
    triple. (The Gradio front end in this file passes ``type="numpy"`` images.)
    """

    # Colour for risk annotations. Frames are RGB, so (255, 0, 0) is red;
    # OpenCV's BGR-style (0, 0, 255) would render blue on these images.
    _RISK_TEXT_COLOR = (255, 0, 0)

    def __init__(self):
        """Create the Groq client, analysis settings and a CPU-pinned YOLOv5s model.

        ``Groq()`` is built without arguments, so credentials come from the
        client library's own configuration (presumably the ``GROQ_API_KEY``
        environment variable -- confirm for your deployment).
        """
        self.client = Groq()
        self.model_name = "llama-3.2-11b-vision-preview"  # Updated to use the correct model
        # (max_width, max_height) of the image sent to the LLM.
        self.max_image_size = (800, 800)
        # Per-detection box colours (RGB), cycled by detection index.
        self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]

        # Load YOLOv5 model for general object detection
        self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

        # Force CPU inference if CUDA is causing issues
        self.yolo_model.cpu()
        self.yolo_model.eval()

    @staticmethod
    def _to_rgb(frame: np.ndarray) -> np.ndarray:
        """Return *frame* as a 3-channel RGB array.

        2-D grayscale, single-channel ``(H, W, 1)`` and 4-channel RGBA inputs
        are converted; any other array is returned unchanged (no copy).
        """
        if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        if frame.ndim == 3 and frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
        return frame

    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
        """Normalise *frame* to RGB and shrink it to fit ``max_image_size``.

        Raises:
            ValueError: if *frame* is ``None``.
        """
        if frame is None:
            raise ValueError("No image provided")
        return self.resize_image(self._to_rgb(frame))

    def resize_image(self, image: np.ndarray) -> np.ndarray:
        """Downscale *image* to fit ``max_image_size``, keeping its aspect ratio.

        Images already within both limits are returned as-is (same object).
        The limiting side is whichever overshoots its bound proportionally
        more, so the result fits even when ``max_image_size`` is not square.
        """
        max_w, max_h = self.max_image_size
        height, width = image.shape[:2]
        if height <= max_h and width <= max_w:
            return image

        aspect = width / height
        # width/max_w > height/max_h, compared without division.
        if width * max_h > height * max_w:
            new_width = max_w
            new_height = max(1, int(new_width / aspect))
        else:
            new_height = max_h
            new_width = max(1, int(new_height * aspect))
        return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

    def encode_image(self, frame: np.ndarray) -> str:
        """Encode an RGB *frame* as a ``data:image/jpeg;base64,...`` URL.

        Raises:
            ValueError: if PIL cannot convert or JPEG-encode the frame.
        """
        try:
            # The frame is already RGB (see preprocess_image); swapping channels
            # here would send the model an image with red and blue exchanged.
            frame_pil = PILImage.fromarray(frame)
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=95)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return f"data:image/jpeg;base64,{img_base64}"
        except Exception as e:
            raise ValueError(f"Error encoding image: {e}") from e

    def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Run YOLOv5 on *frame*; return raw detections and the class-name map.

        Returns:
            ``(bbox_data, labels)``: each ``bbox_data`` row is consumed by
            ``draw_bounding_boxes`` as ``x1, y1, x2, y2, conf, class_id``, and
            ``labels`` is the model's ``names`` lookup indexed by class id.

        Raises:
            ValueError: wrapping any error raised during inference.
        """
        try:
            with torch.no_grad():  # inference only; skip autograd bookkeeping
                results = self.yolo_model(frame)
            bbox_data = results.xyxy[0].cpu().numpy()
            labels = results.names
            return bbox_data, labels
        except Exception as e:
            raise ValueError(f"Error detecting objects: {e}") from e

    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
        """Ask the Llama vision model for safety risks in *frame*.

        Returns:
            ``(safety_issues, response_text)``. Errors are not raised: they are
            printed and reported as ``([], "Analysis Error: ...")``.
        """
        if frame is None:
            return [], "No frame received"

        try:
            frame = self.preprocess_image(frame)
            image_base64 = self.encode_image(frame)

            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": """Analyze this workplace image and identify any potential safety risks. 
                                List each risk on a new line starting with 'Risk:'.
                                Format: Risk: [Object/Area] - [Description of hazard]"""
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": image_base64
                                }
                            }
                        ]
                    }
                ],
                temperature=0.7,
                max_tokens=1024,
                stream=False
            )

            # Get the response content safely
            try:
                response = completion.choices[0].message.content
            except AttributeError:
                response = str(completion.choices[0].message)

            safety_issues = self.parse_safety_analysis(response)
            return safety_issues, response

        except Exception as e:
            print(f"Analysis error: {str(e)}")
            return [], f"Analysis Error: {str(e)}"

    @staticmethod
    def _issue_matches_label(issue_object: str, label: str) -> bool:
        """Return True if an LLM-reported risk object refers to a YOLO *label*.

        Matches when the (non-empty) object text is a substring of the label,
        or when the label appears as whole word(s) inside the object text
        (e.g. "office chair" vs "chair", but not "cupboard" vs "cup").
        An empty object never matches.
        """
        obj = (issue_object or "").strip().lower()
        lab = (label or "").strip().lower()
        if not obj or not lab:
            return False
        if obj in lab:
            return True
        # Replace punctuation with spaces so the label can match as whole words.
        words = "".join(ch if ch.isalnum() else " " for ch in obj)
        normalized = " ".join(words.split())
        return f" {lab} " in f" {normalized} "

    def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
                          labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
        """Return a copy of *image* with detection boxes and captions drawn.

        A box whose label matches a safety issue (see ``_issue_matches_label``)
        is captioned ``"Risk: <description>"`` in red; other boxes get
        ``"<label> <confidence>"`` in the box colour. A box that fails to draw
        is reported and skipped.
        """
        image_copy = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2

        for idx, bbox in enumerate(bboxes):
            try:
                x1, y1, x2, y2, conf, class_id = bbox
                label = str(labels[int(class_id)])
                color = self.colors[idx % len(self.colors)]

                # Convert coordinates to integers
                x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))

                # Draw bounding box
                cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)

                # First safety issue that refers to this detection, if any.
                issue = next(
                    (s for s in safety_issues
                     if self._issue_matches_label(s.get('object', ''), label)),
                    None,
                )
                if issue is not None:
                    label_text = f"Risk: {issue.get('description', '')}"
                    text_color = self._RISK_TEXT_COLOR
                else:
                    label_text = f"{label} {conf:.2f}"
                    text_color = color

                y_pos = max(y1 - 10, 20)  # keep the caption inside the top edge
                cv2.putText(image_copy, label_text, (x1, y_pos), font,
                            font_scale, text_color, thickness)
            except Exception as e:
                print(f"Error drawing box: {str(e)}")
                continue

        return image_copy

    def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
        """Run detection, LLM analysis and annotation on *frame*.

        Returns:
            ``(annotated_frame, analysis_text)``, or ``(None, message)`` when no
            frame is given or a stage raises.
        """
        if frame is None:
            return None, "No image provided"

        try:
            # Normalise channels once so YOLO, the LLM and the drawing step all
            # see the same 3-channel RGB image. Drawing 3-tuple colours onto an
            # RGBA frame would write alpha 0, i.e. invisible boxes.
            frame = self._to_rgb(frame)

            # Detect objects
            bbox_data, labels = self.detect_objects(frame)

            # Get safety analysis
            safety_issues, analysis = self.analyze_frame(frame)

            # Draw annotations
            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)

            return annotated_frame, analysis

        except Exception as e:
            print(f"Processing error: {str(e)}")
            return None, f"Error processing image: {str(e)}"

    def parse_safety_analysis(self, analysis: str) -> List[Dict]:
        """Parse ``Risk: <object> - <description>`` lines into dicts.

        Each returned dict has lower-cased ``"object"`` and ``"description"``
        keys. Separators are tried in order: the spaced `` - `` from the prompt
        first, so hyphenated names such as "fork-lift" stay intact, then a
        bare ``-``. A line with no separator uses its whole text for both
        fields. Markdown emphasis (``**Risk:**``) and ``[...]`` around the
        object are stripped, and empty ``Risk:`` lines are skipped.
        Non-string input yields ``[]``.
        """
        safety_issues: List[Dict] = []

        if not isinstance(analysis, str):
            return safety_issues

        marker = "risk:"
        for line in analysis.split('\n'):
            start = line.lower().find(marker)
            if start == -1:
                continue

            body = line[start + len(marker):].strip(" \t*")
            if not body:
                continue

            sep = " - " if " - " in body else "-"
            if sep in body:
                obj, desc = body.split(sep, 1)
            else:
                obj, desc = body, body

            safety_issues.append({
                "object": obj.strip(" \t*[]").lower(),
                "description": desc.strip(" \t*").lower(),
            })

        return safety_issues


def create_monitor_interface():
    """Create the Gradio interface for the safety monitoring system.

    Builds one ``RobustSafetyMonitor``, which creates the Groq client and
    loads YOLOv5 at construction time. It then wires the image-upload event
    to ``monitor.process_frame``.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    monitor = RobustSafetyMonitor()

    with gr.Blocks() as demo:
        gr.Markdown("# Workplace Safety Analysis System")
        # The analysis model configured in RobustSafetyMonitor is Llama 3.2
        # Vision (llama-3.2-11b-vision-preview), not LLaVA.
        gr.Markdown("Powered by Groq Llama 3.2 Vision and YOLOv5")

        with gr.Row():
            input_image = gr.Image(label="Upload Workplace Image", type="numpy")
            output_image = gr.Image(label="Safety Analysis Visualization")

        analysis_text = gr.Textbox(label="Detailed Safety Analysis", lines=5)

        def analyze_image(image):
            """Upload handler: return (annotated image, analysis text) for the UI."""
            if image is None:
                return None, "Please upload an image"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                # process_frame already catches its own errors; this is a last
                # guard so the UI always receives a two-value result.
                print(f"Analysis error: {str(e)}")
                return None, f"Error analyzing image: {str(e)}"

        input_image.upload(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions
        1. Upload a workplace image for safety analysis
        2. View detected hazards and their locations in the visualization
        3. Read the detailed safety analysis below the images
        
        ## Features
        - Real-time object detection
        - AI-powered safety risk analysis
        - Visual risk highlighting
        - Detailed safety recommendations
        """)

    return demo

if __name__ == "__main__":
    # Script entry point: build the UI, then serve it. share=True additionally
    # asks Gradio for a public share link alongside the local server.
    app = create_monitor_interface()
    app.launch(share=True)