import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import os
import base64
import random

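# A Gradio app: upload a workplace photo, send it to Groq's
# llama-3.2-90b-vision-preview model for a safety review, and overlay the
# reported concerns on the image. Requires the GROQ_API_KEY environment
# variable to be set.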
def create_monitor_interface():
    api_key = os.getenv("GROQ_API_KEY")
    
    class SafetyMonitor:
        def __init__(self):
            self.client = Groq(api_key=api_key)
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (640, 640)  # Cap on input resolution before sending to the model
            self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
            
        def resize_image(self, image):
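            """Fit the image within max_image_size, preserving aspect ratio (never upscales)."""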
            height, width = image.shape[:2]
            aspect = width / height
            
            if width > height:
                new_width = min(self.max_image_size[0], width)
                new_height = int(new_width / aspect)
            else:
                new_height = min(self.max_image_size[1], height)
                new_width = int(new_height * aspect)
                
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

        def analyze_frame(self, frame: np.ndarray) -> str:
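            """Send one frame to the vision model and return its text analysis.

            The frame is normalized to 3-channel RGB, downscaled, and encoded
            as a base64 JPEG data URL inside a chat completion request.
            """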
            if frame is None:
                return "No frame received"
                
            # Convert and resize image
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif len(frame.shape) == 3 and frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            
            frame = self.resize_image(frame)
            frame_pil = PILImage.fromarray(frame)
            
            # Encode as a low-quality JPEG to keep the base64 payload small
            buffered = io.BytesIO()
            frame_pil.save(buffered, 
                         format="JPEG", 
                         quality=30, 
                         optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            image_url = f"data:image/jpeg;base64,{img_base64}"
            
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": """Analyze this workplace image and describe each safety concern in this format:
                                    - <location>Description</location>
                                    Use one line per issue, starting with a dash, with the description wrapped in <location> tags."""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": image_url
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.1,
                    max_tokens=150,
                    top_p=1,
                    stream=False,
                    stop=None
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Detailed error: {str(e)}")
                return f"Analysis Error: {str(e)}"

        def draw_observations(self, image, observations):
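            """Overlay one labeled box per observation.

            The model returns free text without coordinates, so box positions
            are random placeholders rather than true locations of the issues.
            """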
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            thickness = 2
            
            # Generate random positions for each observation
            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]
                
                # Pick a random box position, leaving headroom above for the label
                box_width = width // 3
                box_height = height // 3
                x = random.randint(0, width - box_width)
                y = random.randint(20, height - box_height)
                
                # Draw rectangle
                cv2.rectangle(image, (x, y), (x + box_width, y + box_height), color, 2)
                
                # Add label with background
                label = obs[:40] + "..." if len(obs) > 40 else obs
                label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
                cv2.rectangle(image, (x, y - 20), (x + label_size[0], y), color, -1)
                cv2.putText(image, label, (x, y - 5), font, font_scale, (255, 255, 255), thickness)
            
            return image

        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
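            """Run the full pipeline: analyze the frame, parse the findings,
            and return (annotated image, raw analysis text)."""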
            if frame is None:
                return None, "No image provided"
                
            analysis = self.analyze_frame(frame)
            display_frame = self.resize_image(frame.copy())
            
            # Parse observations from the analysis
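            # Expected model output (per the prompt above; this line is illustrative):
            #   - <location>Worker on the scaffold is not wearing a harness</location>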
            observations = []
            for line in analysis.split('\n'):
                line = line.strip()
                if line.startswith('-'):
                    # Extract the description between the <location> tags if present
                    if '<location>' in line and '</location>' in line:
                        start = line.find('<location>') + len('<location>')
                        end = line.find('</location>')
                        observation = line[start:end].strip()
                    else:
                        observation = line[1:].strip()  # Remove the dash
                    if observation:
                        observations.append(observation)
            
            # Draw observations on the image
            annotated_frame = self.draw_observations(display_frame, observations)
            
            return annotated_frame, analysis

    # Create the main interface
    monitor = SafetyMonitor()
    
    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
        
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Annotated Results")
        
        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
            
        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"
            
        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

    return demo

demo = create_monitor_interface()
demo.launch()
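
# Usage sketch (assumptions: the script is saved as app.py; the key value is a
# placeholder, not a real credential):
#   export GROQ_API_KEY="..."
#   python app.py
# Gradio serves the interface on http://127.0.0.1:7860 by default.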