import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import base64
import torch


class RobustSafetyMonitor:
    def __init__(self):
        """Initialize the robust safety detection tool with configuration."""
        self.client = Groq()
        self.model_name = "llama-3.2-90b-vision-preview"
        self.max_image_size = (800, 800)
        # Palette interpreted as RGB, since Gradio supplies RGB arrays (not BGR)
        self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
        
        # Load YOLOv5 model for general object detection
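        # Note: torch.hub downloads the yolov5s weights on first run, so network access is required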
        self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

    def preprocess_image(self, frame):
        """Process image for analysis."""
        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif len(frame.shape) == 3 and frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
        
        return self.resize_image(frame)

    def resize_image(self, image):
        """Resize image while maintaining aspect ratio."""
        height, width = image.shape[:2]
        if height > self.max_image_size[1] or width > self.max_image_size[0]:
            aspect = width / height
            if width > height:
                new_width = self.max_image_size[0]
                new_height = int(new_width / aspect)
            else:
                new_height = self.max_image_size[1]
                new_width = int(new_height * aspect)
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
        return image

    def encode_image(self, frame):
        """Convert image to base64 encoding without extra formatting."""
        frame_pil = PILImage.fromarray(frame)
        buffered = io.BytesIO()
        frame_pil.save(buffered, format="JPEG", quality=95)  # Ensure JPEG format
        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return img_base64  # Raw base64 string; analyze_frame adds the data-URI prefix

    def detect_objects(self, frame):
        """Detect objects using YOLOv5."""
        results = self.yolo_model(frame)
        # Move tensors to CPU before converting, in case the model ran on GPU
        bbox_data = results.xyxy[0].cpu().numpy()  # Rows of [x1, y1, x2, y2, confidence, class_id]
        labels = results.names  # Maps class_id -> class name
        return bbox_data, labels

    def analyze_frame(self, frame):
        """Perform safety analysis on the frame using Llama Vision 3.2."""
        if frame is None:
            return "No frame received", {}
    
        frame = self.preprocess_image(frame)
        image_base64 = self.encode_image(frame)

        try:
            # Use Llama 3.2 Vision to analyze the image context and flag risks
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": """Analyze this workplace image and identify any potential safety risks. 
                                           Consider the positioning of workers, the equipment, materials, and environment. 
                                           Flag risks like improper equipment use, worker proximity to danger zones, unstable materials, and environmental hazards."""
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": image_base64  # Corrected: Send only the base64 string
                                }
                            }
                        ]
                    }
                ],
                temperature=0.7,
                max_tokens=1024,
                stream=False
            )
            # The SDK returns message objects, so use attribute access rather than subscripting
            response = completion.choices[0].message.content
            return self.parse_safety_analysis(response), response  # Parsed issues plus full response text

        except Exception as e:
            print(f"Analysis error: {str(e)}")
            # Match the success path's (issues, analysis) return order
            return [], f"Analysis Error: {str(e)}"

    def draw_bounding_boxes(self, image, bboxes, labels, safety_issues):
        """Draw bounding boxes around objects based on safety issues flagged by Llama Vision."""
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2
        
        for idx, bbox in enumerate(bboxes):
            x1, y1, x2, y2, conf, class_id = bbox
            label = labels[int(class_id)]
            color = self.colors[idx % len(self.colors)]
            
            # Draw bounding box
            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness)
            
            # Link detected objects to risks flagged in the Llama 3.2 Vision analysis
            for safety_issue in safety_issues:
                if safety_issue['object'].lower() in label.lower():
                    label_text = f"Risk: {safety_issue['description']}"
                    # Frames arrive from Gradio as RGB, so (255, 0, 0) renders as red
                    cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, (255, 0, 0), thickness)
                    break
            else:
                # for-else: no flagged issue matched this object, so show the plain detection label
                label_text = f"{label} {conf:.2f}"
                cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, color, thickness)

        return image

    def process_frame(self, frame):
        """Main processing pipeline for dynamic, robust safety analysis."""
        if frame is None:
            return None, "No image provided"
    
        try:
            # Get the contextual safety analysis from Llama 3.2 Vision
            safety_issues, analysis = self.analyze_frame(frame)

            # Detect objects in the image using YOLOv5
            bbox_data, labels = self.detect_objects(frame)

            # Draw once, on a copy, so the caller's frame is not mutated
            # and labels are not painted over each other
            annotated_frame = self.draw_bounding_boxes(frame.copy(), bbox_data, labels, safety_issues)

            return annotated_frame, analysis
    
        except Exception as e:
            print(f"Processing error: {str(e)}")
            return None, f"Error processing image: {str(e)}"
    
    def parse_safety_analysis(self, analysis):
        """Parse the safety analysis to identify contextual issues and link to objects."""
        safety_issues = []
        for line in analysis.split('\n'):
            if "risk" in line.lower() or "hazard" in line.lower():
                # Extract object involved and description
                parts = line.split(':', 1)
                if len(parts) == 2:
                    safety_issues.append({
                        "object": parts[0].strip(),
                        "description": parts[1].strip()
                    })
        return safety_issues


def create_monitor_interface():
    monitor = RobustSafetyMonitor()
    
    with gr.Blocks() as demo:
        gr.Markdown("# Robust Safety Analysis System powered by Llama Vision 3.2")
        
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Safety Analysis")
        
        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
            
        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"
            
        input_image.upload(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions:
        1. Upload any workplace/safety-related image
        2. View identified hazards and their locations
        3. Read detailed analysis of safety concerns based on the image
        """)

    return demo

if __name__ == "__main__":
    demo = create_monitor_interface()
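    # demo.launch(share=True) would additionally create a public share link (useful in notebooks)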
    demo.launch()